diff --git a/api/apps/sdk/dataset.py b/api/apps/sdk/dataset.py index e3675b8cd..0be206ec0 100644 --- a/api/apps/sdk/dataset.py +++ b/api/apps/sdk/dataset.py @@ -347,6 +347,8 @@ def update(tenant_id, dataset_id): return get_error_data_result(message=f"Dataset name '{req['name']}' already exists") if "embd_id" in req: + if not req["embd_id"]: + req["embd_id"] = kb.embd_id if kb.chunk_num != 0 and req["embd_id"] != kb.embd_id: return get_error_data_result(message=f"When chunk_num ({kb.chunk_num}) > 0, embedding_model must remain {kb.embd_id}") ok, err = verify_embedding_availability(req["embd_id"], tenant_id) diff --git a/api/utils/validation_utils.py b/api/utils/validation_utils.py index 3c53a6332..d87d8945d 100644 --- a/api/utils/validation_utils.py +++ b/api/utils/validation_utils.py @@ -380,7 +380,7 @@ class CreateDatasetReq(Base): name: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1, max_length=DATASET_NAME_LIMIT), Field(...)] avatar: str | None = Field(default=None, max_length=65535) description: str | None = Field(default=None, max_length=65535) - embedding_model: Annotated[str, StringConstraints(strip_whitespace=True, max_length=255), Field(default="", serialization_alias="embd_id")] + embedding_model: str | None = Field(default=None, max_length=255, serialization_alias="embd_id") permission: PermissionEnum = Field(default=PermissionEnum.me, min_length=1, max_length=16) chunk_method: ChunkMethodEnum = Field(default=ChunkMethodEnum.naive, min_length=1, max_length=32, serialization_alias="parser_id") parser_config: ParserConfig | None = Field(default=None) @@ -435,9 +435,16 @@ class CreateDatasetReq(Base): else: raise PydanticCustomError("format_invalid", "Missing MIME prefix. Expected format: data:;base64,") + @field_validator("embedding_model", mode="before") + @classmethod + def normalize_embedding_model(cls, v: Any) -> Any: + if isinstance(v, str): + return v.strip() + return v + @field_validator("embedding_model", mode="after") @classmethod - def validate_embedding_model(cls, v: str) -> str: + def validate_embedding_model(cls, v: str | None) -> str | None: """ Validates embedding model identifier format compliance. @@ -464,16 +471,17 @@ class CreateDatasetReq(Base): Invalid: "@openai" (empty model_name) Invalid: "text-embedding-3-large@" (empty provider) """ - if "@" not in v: - raise PydanticCustomError("format_invalid", "Embedding model identifier must follow @ format") + if isinstance(v, str): + if "@" not in v: + raise PydanticCustomError("format_invalid", "Embedding model identifier must follow @ format") - components = v.split("@", 1) - if len(components) != 2 or not all(components): - raise PydanticCustomError("format_invalid", "Both model_name and provider must be non-empty strings") + components = v.split("@", 1) + if len(components) != 2 or not all(components): + raise PydanticCustomError("format_invalid", "Both model_name and provider must be non-empty strings") - model_name, provider = components - if not model_name.strip() or not provider.strip(): - raise PydanticCustomError("format_invalid", "Model name and provider cannot be whitespace-only strings") + model_name, provider = components + if not model_name.strip() or not provider.strip(): + raise PydanticCustomError("format_invalid", "Model name and provider cannot be whitespace-only strings") return v @field_validator("permission", mode="before") diff --git a/sdk/python/ragflow_sdk/ragflow.py b/sdk/python/ragflow_sdk/ragflow.py index 5b65d6201..95020fda3 100644 --- a/sdk/python/ragflow_sdk/ragflow.py +++ b/sdk/python/ragflow_sdk/ragflow.py @@ -53,7 +53,7 @@ class RAGFlow: name: str, avatar: Optional[str] = None, description: Optional[str] = None, - embedding_model: Optional[str] = "BAAI/bge-large-zh-v1.5@BAAI", + embedding_model: Optional[str] = None, permission: str = "me", chunk_method: str = "naive", parser_config: Optional[DataSet.ParserConfig] = None, diff --git a/test/testcases/test_http_api/test_dataset_mangement/test_create_dataset.py b/test/testcases/test_http_api/test_dataset_mangement/test_create_dataset.py index b3b3f9b8a..22772ad68 100644 --- a/test/testcases/test_http_api/test_dataset_mangement/test_create_dataset.py +++ b/test/testcases/test_http_api/test_dataset_mangement/test_create_dataset.py @@ -260,19 +260,21 @@ class TestDatasetCreate: @pytest.mark.parametrize( "name, embedding_model", [ + ("empty", ""), + ("space", " "), ("missing_at", "BAAI/bge-large-zh-v1.5BAAI"), ("missing_model_name", "@BAAI"), ("missing_provider", "BAAI/bge-large-zh-v1.5@"), ("whitespace_only_model_name", " @BAAI"), ("whitespace_only_provider", "BAAI/bge-large-zh-v1.5@ "), ], - ids=["missing_at", "empty_model_name", "empty_provider", "whitespace_only_model_name", "whitespace_only_provider"], + ids=["empty", "space", "missing_at", "empty_model_name", "empty_provider", "whitespace_only_model_name", "whitespace_only_provider"], ) def test_embedding_model_format(self, HttpApiAuth, name, embedding_model): payload = {"name": name, "embedding_model": embedding_model} res = create_dataset(HttpApiAuth, payload) assert res["code"] == 101, res - if name == "missing_at": + if name in ["empty", "space", "missing_at"]: assert "Embedding model identifier must follow @ format" in res["message"], res else: assert "Both model_name and provider must be non-empty strings" in res["message"], res @@ -288,8 +290,8 @@ class TestDatasetCreate: def test_embedding_model_none(self, HttpApiAuth): payload = {"name": "embedding_model_none", "embedding_model": None} res = create_dataset(HttpApiAuth, payload) - assert res["code"] == 101, res - assert "Input should be a valid string" in res["message"], res + assert res["code"] == 0, res + assert res["data"]["embedding_model"] == "BAAI/bge-large-zh-v1.5@BAAI", res @pytest.mark.p1 @pytest.mark.parametrize( diff --git a/test/testcases/test_http_api/test_dataset_mangement/test_update_dataset.py b/test/testcases/test_http_api/test_dataset_mangement/test_update_dataset.py index 152788000..36d55795f 100644 --- a/test/testcases/test_http_api/test_dataset_mangement/test_update_dataset.py +++ b/test/testcases/test_http_api/test_dataset_mangement/test_update_dataset.py @@ -300,20 +300,22 @@ class TestDatasetUpdate: @pytest.mark.parametrize( "name, embedding_model", [ + ("empty", ""), + ("space", " "), ("missing_at", "BAAI/bge-large-zh-v1.5BAAI"), ("missing_model_name", "@BAAI"), ("missing_provider", "BAAI/bge-large-zh-v1.5@"), ("whitespace_only_model_name", " @BAAI"), ("whitespace_only_provider", "BAAI/bge-large-zh-v1.5@ "), ], - ids=["missing_at", "empty_model_name", "empty_provider", "whitespace_only_model_name", "whitespace_only_provider"], + ids=["empty", "space", "missing_at", "empty_model_name", "empty_provider", "whitespace_only_model_name", "whitespace_only_provider"], ) def test_embedding_model_format(self, HttpApiAuth, add_dataset_func, name, embedding_model): dataset_id = add_dataset_func payload = {"name": name, "embedding_model": embedding_model} res = update_dataset(HttpApiAuth, dataset_id, payload) assert res["code"] == 101, res - if name == "missing_at": + if name in ["empty", "space", "missing_at"]: assert "Embedding model identifier must follow @ format" in res["message"], res else: assert "Both model_name and provider must be non-empty strings" in res["message"], res @@ -323,8 +325,11 @@ class TestDatasetUpdate: dataset_id = add_dataset_func payload = {"embedding_model": None} res = update_dataset(HttpApiAuth, dataset_id, payload) - assert res["code"] == 101, res - assert "Input should be a valid string" in res["message"], res + assert res["code"] == 0, res + + res = list_datasets(HttpApiAuth) + assert res["code"] == 0, res + assert res["data"][0]["embedding_model"] == "BAAI/bge-large-zh-v1.5@BAAI", res @pytest.mark.p1 @pytest.mark.parametrize( diff --git a/test/testcases/test_sdk_api/test_dataset_mangement/test_create_dataset.py b/test/testcases/test_sdk_api/test_dataset_mangement/test_create_dataset.py index 4ba269648..ffeaaf103 100644 --- a/test/testcases/test_sdk_api/test_dataset_mangement/test_create_dataset.py +++ b/test/testcases/test_sdk_api/test_dataset_mangement/test_create_dataset.py @@ -217,19 +217,21 @@ class TestDatasetCreate: @pytest.mark.parametrize( "name, embedding_model", [ + ("empty", ""), + ("space", " "), ("missing_at", "BAAI/bge-large-zh-v1.5BAAI"), ("missing_model_name", "@BAAI"), ("missing_provider", "BAAI/bge-large-zh-v1.5@"), ("whitespace_only_model_name", " @BAAI"), ("whitespace_only_provider", "BAAI/bge-large-zh-v1.5@ "), ], - ids=["missing_at", "empty_model_name", "empty_provider", "whitespace_only_model_name", "whitespace_only_provider"], + ids=["empty", "space", "missing_at", "empty_model_name", "empty_provider", "whitespace_only_model_name", "whitespace_only_provider"], ) def test_embedding_model_format(self, client, name, embedding_model): payload = {"name": name, "embedding_model": embedding_model} with pytest.raises(Exception) as excinfo: client.create_dataset(**payload) - if name == "missing_at": + if name in ["empty", "space", "missing_at"]: assert "Embedding model identifier must follow @ format" in str(excinfo.value), str(excinfo.value) else: assert "Both model_name and provider must be non-empty strings" in str(excinfo.value), str(excinfo.value) @@ -243,9 +245,8 @@ class TestDatasetCreate: @pytest.mark.p2 def test_embedding_model_none(self, client): payload = {"name": "embedding_model_none", "embedding_model": None} - with pytest.raises(Exception) as excinfo: - client.create_dataset(**payload) - assert "Input should be a valid string" in str(excinfo.value), str(excinfo.value) + dataset = client.create_dataset(**payload) + assert dataset.embedding_model == "BAAI/bge-large-zh-v1.5@BAAI", str(dataset) @pytest.mark.p1 @pytest.mark.parametrize( diff --git a/test/testcases/test_sdk_api/test_dataset_mangement/test_update_dataset.py b/test/testcases/test_sdk_api/test_dataset_mangement/test_update_dataset.py index f4a0a9163..94c4ddb37 100644 --- a/test/testcases/test_sdk_api/test_dataset_mangement/test_update_dataset.py +++ b/test/testcases/test_sdk_api/test_dataset_mangement/test_update_dataset.py @@ -207,30 +207,34 @@ class TestDatasetUpdate: @pytest.mark.parametrize( "name, embedding_model", [ + ("empty", ""), + ("space", " "), ("missing_at", "BAAI/bge-large-zh-v1.5BAAI"), ("missing_model_name", "@BAAI"), ("missing_provider", "BAAI/bge-large-zh-v1.5@"), ("whitespace_only_model_name", " @BAAI"), ("whitespace_only_provider", "BAAI/bge-large-zh-v1.5@ "), ], - ids=["missing_at", "empty_model_name", "empty_provider", "whitespace_only_model_name", "whitespace_only_provider"], + ids=["empty", "space", "missing_at", "empty_model_name", "empty_provider", "whitespace_only_model_name", "whitespace_only_provider"], ) def test_embedding_model_format(self, add_dataset_func, name, embedding_model): dataset = add_dataset_func with pytest.raises(Exception) as excinfo: dataset.update({"name": name, "embedding_model": embedding_model}) error_msg = str(excinfo.value) - if name == "missing_at": + if name in ["empty", "space", "missing_at"]: assert "Embedding model identifier must follow @ format" in error_msg, error_msg else: assert "Both model_name and provider must be non-empty strings" in error_msg, error_msg @pytest.mark.p2 - def test_embedding_model_none(self, add_dataset_func): + def test_embedding_model_none(self, client, add_dataset_func): dataset = add_dataset_func - with pytest.raises(Exception) as excinfo: - dataset.update({"embedding_model": None}) - assert "Input should be a valid string" in str(excinfo.value), str(excinfo.value) + dataset.update({"embedding_model": None}) + assert dataset.embedding_model == "BAAI/bge-large-zh-v1.5@BAAI", str(dataset) + + retrieved_dataset = client.get_dataset(name=dataset.name) + assert retrieved_dataset.embedding_model == "BAAI/bge-large-zh-v1.5@BAAI", str(retrieved_dataset) @pytest.mark.p1 @pytest.mark.parametrize(