Fix: Enforce default embedding model in create_dataset / update_dataset (#8486)

### What problem does this PR solve?

Previous:
- Defaulted to hardcoded model 'BAAI/bge-large-zh-v1.5@BAAI'
- Did not respect user-configured default embedding_model

Now:
- Correctly prioritizes user-configured default embedding_model

Other:
- Make embedding_model optional in CreateDatasetReq with proper None
handling
- Add default embedding model fallback in dataset update when empty
- Enhance validation utils to handle None values and string
normalization
- Update SDK default embedding model to None to match API changes
- Adjust related test cases to reflect new validation rules

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
Liu An
2025-06-25 16:41:32 +08:00
committed by GitHub
parent 340354b79c
commit dac5bcdf17
7 changed files with 52 additions and 30 deletions

View File

@ -260,19 +260,21 @@ class TestDatasetCreate:
@pytest.mark.parametrize(
"name, embedding_model",
[
("empty", ""),
("space", " "),
("missing_at", "BAAI/bge-large-zh-v1.5BAAI"),
("missing_model_name", "@BAAI"),
("missing_provider", "BAAI/bge-large-zh-v1.5@"),
("whitespace_only_model_name", " @BAAI"),
("whitespace_only_provider", "BAAI/bge-large-zh-v1.5@ "),
],
ids=["missing_at", "empty_model_name", "empty_provider", "whitespace_only_model_name", "whitespace_only_provider"],
ids=["empty", "space", "missing_at", "empty_model_name", "empty_provider", "whitespace_only_model_name", "whitespace_only_provider"],
)
def test_embedding_model_format(self, HttpApiAuth, name, embedding_model):
payload = {"name": name, "embedding_model": embedding_model}
res = create_dataset(HttpApiAuth, payload)
assert res["code"] == 101, res
if name == "missing_at":
if name in ["empty", "space", "missing_at"]:
assert "Embedding model identifier must follow <model_name>@<provider> format" in res["message"], res
else:
assert "Both model_name and provider must be non-empty strings" in res["message"], res
@ -288,8 +290,8 @@ class TestDatasetCreate:
def test_embedding_model_none(self, HttpApiAuth):
payload = {"name": "embedding_model_none", "embedding_model": None}
res = create_dataset(HttpApiAuth, payload)
assert res["code"] == 101, res
assert "Input should be a valid string" in res["message"], res
assert res["code"] == 0, res
assert res["data"]["embedding_model"] == "BAAI/bge-large-zh-v1.5@BAAI", res
@pytest.mark.p1
@pytest.mark.parametrize(

View File

@ -300,20 +300,22 @@ class TestDatasetUpdate:
@pytest.mark.parametrize(
"name, embedding_model",
[
("empty", ""),
("space", " "),
("missing_at", "BAAI/bge-large-zh-v1.5BAAI"),
("missing_model_name", "@BAAI"),
("missing_provider", "BAAI/bge-large-zh-v1.5@"),
("whitespace_only_model_name", " @BAAI"),
("whitespace_only_provider", "BAAI/bge-large-zh-v1.5@ "),
],
ids=["missing_at", "empty_model_name", "empty_provider", "whitespace_only_model_name", "whitespace_only_provider"],
ids=["empty", "space", "missing_at", "empty_model_name", "empty_provider", "whitespace_only_model_name", "whitespace_only_provider"],
)
def test_embedding_model_format(self, HttpApiAuth, add_dataset_func, name, embedding_model):
dataset_id = add_dataset_func
payload = {"name": name, "embedding_model": embedding_model}
res = update_dataset(HttpApiAuth, dataset_id, payload)
assert res["code"] == 101, res
if name == "missing_at":
if name in ["empty", "space", "missing_at"]:
assert "Embedding model identifier must follow <model_name>@<provider> format" in res["message"], res
else:
assert "Both model_name and provider must be non-empty strings" in res["message"], res
@ -323,8 +325,11 @@ class TestDatasetUpdate:
dataset_id = add_dataset_func
payload = {"embedding_model": None}
res = update_dataset(HttpApiAuth, dataset_id, payload)
assert res["code"] == 101, res
assert "Input should be a valid string" in res["message"], res
assert res["code"] == 0, res
res = list_datasets(HttpApiAuth)
assert res["code"] == 0, res
assert res["data"][0]["embedding_model"] == "BAAI/bge-large-zh-v1.5@BAAI", res
@pytest.mark.p1
@pytest.mark.parametrize(

View File

@ -217,19 +217,21 @@ class TestDatasetCreate:
@pytest.mark.parametrize(
"name, embedding_model",
[
("empty", ""),
("space", " "),
("missing_at", "BAAI/bge-large-zh-v1.5BAAI"),
("missing_model_name", "@BAAI"),
("missing_provider", "BAAI/bge-large-zh-v1.5@"),
("whitespace_only_model_name", " @BAAI"),
("whitespace_only_provider", "BAAI/bge-large-zh-v1.5@ "),
],
ids=["missing_at", "empty_model_name", "empty_provider", "whitespace_only_model_name", "whitespace_only_provider"],
ids=["empty", "space", "missing_at", "empty_model_name", "empty_provider", "whitespace_only_model_name", "whitespace_only_provider"],
)
def test_embedding_model_format(self, client, name, embedding_model):
payload = {"name": name, "embedding_model": embedding_model}
with pytest.raises(Exception) as excinfo:
client.create_dataset(**payload)
if name == "missing_at":
if name in ["empty", "space", "missing_at"]:
assert "Embedding model identifier must follow <model_name>@<provider> format" in str(excinfo.value), str(excinfo.value)
else:
assert "Both model_name and provider must be non-empty strings" in str(excinfo.value), str(excinfo.value)
@ -243,9 +245,8 @@ class TestDatasetCreate:
@pytest.mark.p2
def test_embedding_model_none(self, client):
payload = {"name": "embedding_model_none", "embedding_model": None}
with pytest.raises(Exception) as excinfo:
client.create_dataset(**payload)
assert "Input should be a valid string" in str(excinfo.value), str(excinfo.value)
dataset = client.create_dataset(**payload)
assert dataset.embedding_model == "BAAI/bge-large-zh-v1.5@BAAI", str(dataset)
@pytest.mark.p1
@pytest.mark.parametrize(

View File

@ -207,30 +207,34 @@ class TestDatasetUpdate:
@pytest.mark.parametrize(
"name, embedding_model",
[
("empty", ""),
("space", " "),
("missing_at", "BAAI/bge-large-zh-v1.5BAAI"),
("missing_model_name", "@BAAI"),
("missing_provider", "BAAI/bge-large-zh-v1.5@"),
("whitespace_only_model_name", " @BAAI"),
("whitespace_only_provider", "BAAI/bge-large-zh-v1.5@ "),
],
ids=["missing_at", "empty_model_name", "empty_provider", "whitespace_only_model_name", "whitespace_only_provider"],
ids=["empty", "space", "missing_at", "empty_model_name", "empty_provider", "whitespace_only_model_name", "whitespace_only_provider"],
)
def test_embedding_model_format(self, add_dataset_func, name, embedding_model):
dataset = add_dataset_func
with pytest.raises(Exception) as excinfo:
dataset.update({"name": name, "embedding_model": embedding_model})
error_msg = str(excinfo.value)
if name == "missing_at":
if name in ["empty", "space", "missing_at"]:
assert "Embedding model identifier must follow <model_name>@<provider> format" in error_msg, error_msg
else:
assert "Both model_name and provider must be non-empty strings" in error_msg, error_msg
@pytest.mark.p2
def test_embedding_model_none(self, add_dataset_func):
def test_embedding_model_none(self, client, add_dataset_func):
dataset = add_dataset_func
with pytest.raises(Exception) as excinfo:
dataset.update({"embedding_model": None})
assert "Input should be a valid string" in str(excinfo.value), str(excinfo.value)
dataset.update({"embedding_model": None})
assert dataset.embedding_model == "BAAI/bge-large-zh-v1.5@BAAI", str(dataset)
retrieved_dataset = client.get_dataset(name=dataset.name)
assert retrieved_dataset.embedding_model == "BAAI/bge-large-zh-v1.5@BAAI", str(retrieved_dataset)
@pytest.mark.p1
@pytest.mark.parametrize(