mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Fix: Enforce default embedding model in create_dataset / update_dataset (#8486)
### What problem does this PR solve?

Previous:
- Defaulted to hardcoded model 'BAAI/bge-large-zh-v1.5@BAAI'
- Did not respect user-configured default embedding_model

Now:
- Correctly prioritizes user-configured default embedding_model

Other:
- Make embedding_model optional in CreateDatasetReq with proper None handling
- Add default embedding model fallback in dataset update when empty
- Enhance validation utils to handle None values and string normalization
- Update SDK default embedding model to None to match API changes
- Adjust related test cases to reflect new validation rules

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
@ -260,19 +260,21 @@ class TestDatasetCreate:
|
||||
@pytest.mark.parametrize(
|
||||
"name, embedding_model",
|
||||
[
|
||||
("empty", ""),
|
||||
("space", " "),
|
||||
("missing_at", "BAAI/bge-large-zh-v1.5BAAI"),
|
||||
("missing_model_name", "@BAAI"),
|
||||
("missing_provider", "BAAI/bge-large-zh-v1.5@"),
|
||||
("whitespace_only_model_name", " @BAAI"),
|
||||
("whitespace_only_provider", "BAAI/bge-large-zh-v1.5@ "),
|
||||
],
|
||||
ids=["missing_at", "empty_model_name", "empty_provider", "whitespace_only_model_name", "whitespace_only_provider"],
|
||||
ids=["empty", "space", "missing_at", "empty_model_name", "empty_provider", "whitespace_only_model_name", "whitespace_only_provider"],
|
||||
)
|
||||
def test_embedding_model_format(self, HttpApiAuth, name, embedding_model):
|
||||
payload = {"name": name, "embedding_model": embedding_model}
|
||||
res = create_dataset(HttpApiAuth, payload)
|
||||
assert res["code"] == 101, res
|
||||
if name == "missing_at":
|
||||
if name in ["empty", "space", "missing_at"]:
|
||||
assert "Embedding model identifier must follow <model_name>@<provider> format" in res["message"], res
|
||||
else:
|
||||
assert "Both model_name and provider must be non-empty strings" in res["message"], res
|
||||
@ -288,8 +290,8 @@ class TestDatasetCreate:
|
||||
def test_embedding_model_none(self, HttpApiAuth):
|
||||
payload = {"name": "embedding_model_none", "embedding_model": None}
|
||||
res = create_dataset(HttpApiAuth, payload)
|
||||
assert res["code"] == 101, res
|
||||
assert "Input should be a valid string" in res["message"], res
|
||||
assert res["code"] == 0, res
|
||||
assert res["data"]["embedding_model"] == "BAAI/bge-large-zh-v1.5@BAAI", res
|
||||
|
||||
@pytest.mark.p1
|
||||
@pytest.mark.parametrize(
|
||||
|
||||
@ -300,20 +300,22 @@ class TestDatasetUpdate:
|
||||
@pytest.mark.parametrize(
|
||||
"name, embedding_model",
|
||||
[
|
||||
("empty", ""),
|
||||
("space", " "),
|
||||
("missing_at", "BAAI/bge-large-zh-v1.5BAAI"),
|
||||
("missing_model_name", "@BAAI"),
|
||||
("missing_provider", "BAAI/bge-large-zh-v1.5@"),
|
||||
("whitespace_only_model_name", " @BAAI"),
|
||||
("whitespace_only_provider", "BAAI/bge-large-zh-v1.5@ "),
|
||||
],
|
||||
ids=["missing_at", "empty_model_name", "empty_provider", "whitespace_only_model_name", "whitespace_only_provider"],
|
||||
ids=["empty", "space", "missing_at", "empty_model_name", "empty_provider", "whitespace_only_model_name", "whitespace_only_provider"],
|
||||
)
|
||||
def test_embedding_model_format(self, HttpApiAuth, add_dataset_func, name, embedding_model):
|
||||
dataset_id = add_dataset_func
|
||||
payload = {"name": name, "embedding_model": embedding_model}
|
||||
res = update_dataset(HttpApiAuth, dataset_id, payload)
|
||||
assert res["code"] == 101, res
|
||||
if name == "missing_at":
|
||||
if name in ["empty", "space", "missing_at"]:
|
||||
assert "Embedding model identifier must follow <model_name>@<provider> format" in res["message"], res
|
||||
else:
|
||||
assert "Both model_name and provider must be non-empty strings" in res["message"], res
|
||||
@ -323,8 +325,11 @@ class TestDatasetUpdate:
|
||||
dataset_id = add_dataset_func
|
||||
payload = {"embedding_model": None}
|
||||
res = update_dataset(HttpApiAuth, dataset_id, payload)
|
||||
assert res["code"] == 101, res
|
||||
assert "Input should be a valid string" in res["message"], res
|
||||
assert res["code"] == 0, res
|
||||
|
||||
res = list_datasets(HttpApiAuth)
|
||||
assert res["code"] == 0, res
|
||||
assert res["data"][0]["embedding_model"] == "BAAI/bge-large-zh-v1.5@BAAI", res
|
||||
|
||||
@pytest.mark.p1
|
||||
@pytest.mark.parametrize(
|
||||
|
||||
Reference in New Issue
Block a user