mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Fix: Enforce default embedding model in create_dataset / update_dataset (#8486)
### What problem does this PR solve?

Previous:
- Defaulted to the hardcoded model 'BAAI/bge-large-zh-v1.5@BAAI'
- Did not respect the user-configured default embedding_model

Now:
- Correctly prioritizes the user-configured default embedding_model

Other:
- Make embedding_model optional in CreateDatasetReq with proper None handling
- Add default embedding model fallback in dataset update when empty
- Enhance validation utils to handle None values and string normalization
- Update SDK default embedding model to None to match API changes
- Adjust related test cases to reflect new validation rules

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
@ -347,6 +347,8 @@ def update(tenant_id, dataset_id):
|
|||||||
return get_error_data_result(message=f"Dataset name '{req['name']}' already exists")
|
return get_error_data_result(message=f"Dataset name '{req['name']}' already exists")
|
||||||
|
|
||||||
if "embd_id" in req:
|
if "embd_id" in req:
|
||||||
|
if not req["embd_id"]:
|
||||||
|
req["embd_id"] = kb.embd_id
|
||||||
if kb.chunk_num != 0 and req["embd_id"] != kb.embd_id:
|
if kb.chunk_num != 0 and req["embd_id"] != kb.embd_id:
|
||||||
return get_error_data_result(message=f"When chunk_num ({kb.chunk_num}) > 0, embedding_model must remain {kb.embd_id}")
|
return get_error_data_result(message=f"When chunk_num ({kb.chunk_num}) > 0, embedding_model must remain {kb.embd_id}")
|
||||||
ok, err = verify_embedding_availability(req["embd_id"], tenant_id)
|
ok, err = verify_embedding_availability(req["embd_id"], tenant_id)
|
||||||
|
|||||||
@ -380,7 +380,7 @@ class CreateDatasetReq(Base):
|
|||||||
name: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1, max_length=DATASET_NAME_LIMIT), Field(...)]
|
name: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1, max_length=DATASET_NAME_LIMIT), Field(...)]
|
||||||
avatar: str | None = Field(default=None, max_length=65535)
|
avatar: str | None = Field(default=None, max_length=65535)
|
||||||
description: str | None = Field(default=None, max_length=65535)
|
description: str | None = Field(default=None, max_length=65535)
|
||||||
embedding_model: Annotated[str, StringConstraints(strip_whitespace=True, max_length=255), Field(default="", serialization_alias="embd_id")]
|
embedding_model: str | None = Field(default=None, max_length=255, serialization_alias="embd_id")
|
||||||
permission: PermissionEnum = Field(default=PermissionEnum.me, min_length=1, max_length=16)
|
permission: PermissionEnum = Field(default=PermissionEnum.me, min_length=1, max_length=16)
|
||||||
chunk_method: ChunkMethodEnum = Field(default=ChunkMethodEnum.naive, min_length=1, max_length=32, serialization_alias="parser_id")
|
chunk_method: ChunkMethodEnum = Field(default=ChunkMethodEnum.naive, min_length=1, max_length=32, serialization_alias="parser_id")
|
||||||
parser_config: ParserConfig | None = Field(default=None)
|
parser_config: ParserConfig | None = Field(default=None)
|
||||||
@ -435,9 +435,16 @@ class CreateDatasetReq(Base):
|
|||||||
else:
|
else:
|
||||||
raise PydanticCustomError("format_invalid", "Missing MIME prefix. Expected format: data:<mime>;base64,<data>")
|
raise PydanticCustomError("format_invalid", "Missing MIME prefix. Expected format: data:<mime>;base64,<data>")
|
||||||
|
|
||||||
|
@field_validator("embedding_model", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def normalize_embedding_model(cls, v: Any) -> Any:
|
||||||
|
if isinstance(v, str):
|
||||||
|
return v.strip()
|
||||||
|
return v
|
||||||
|
|
||||||
@field_validator("embedding_model", mode="after")
|
@field_validator("embedding_model", mode="after")
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_embedding_model(cls, v: str) -> str:
|
def validate_embedding_model(cls, v: str | None) -> str | None:
|
||||||
"""
|
"""
|
||||||
Validates embedding model identifier format compliance.
|
Validates embedding model identifier format compliance.
|
||||||
|
|
||||||
@ -464,16 +471,17 @@ class CreateDatasetReq(Base):
|
|||||||
Invalid: "@openai" (empty model_name)
|
Invalid: "@openai" (empty model_name)
|
||||||
Invalid: "text-embedding-3-large@" (empty provider)
|
Invalid: "text-embedding-3-large@" (empty provider)
|
||||||
"""
|
"""
|
||||||
if "@" not in v:
|
if isinstance(v, str):
|
||||||
raise PydanticCustomError("format_invalid", "Embedding model identifier must follow <model_name>@<provider> format")
|
if "@" not in v:
|
||||||
|
raise PydanticCustomError("format_invalid", "Embedding model identifier must follow <model_name>@<provider> format")
|
||||||
|
|
||||||
components = v.split("@", 1)
|
components = v.split("@", 1)
|
||||||
if len(components) != 2 or not all(components):
|
if len(components) != 2 or not all(components):
|
||||||
raise PydanticCustomError("format_invalid", "Both model_name and provider must be non-empty strings")
|
raise PydanticCustomError("format_invalid", "Both model_name and provider must be non-empty strings")
|
||||||
|
|
||||||
model_name, provider = components
|
model_name, provider = components
|
||||||
if not model_name.strip() or not provider.strip():
|
if not model_name.strip() or not provider.strip():
|
||||||
raise PydanticCustomError("format_invalid", "Model name and provider cannot be whitespace-only strings")
|
raise PydanticCustomError("format_invalid", "Model name and provider cannot be whitespace-only strings")
|
||||||
return v
|
return v
|
||||||
|
|
||||||
@field_validator("permission", mode="before")
|
@field_validator("permission", mode="before")
|
||||||
|
|||||||
@ -53,7 +53,7 @@ class RAGFlow:
|
|||||||
name: str,
|
name: str,
|
||||||
avatar: Optional[str] = None,
|
avatar: Optional[str] = None,
|
||||||
description: Optional[str] = None,
|
description: Optional[str] = None,
|
||||||
embedding_model: Optional[str] = "BAAI/bge-large-zh-v1.5@BAAI",
|
embedding_model: Optional[str] = None,
|
||||||
permission: str = "me",
|
permission: str = "me",
|
||||||
chunk_method: str = "naive",
|
chunk_method: str = "naive",
|
||||||
parser_config: Optional[DataSet.ParserConfig] = None,
|
parser_config: Optional[DataSet.ParserConfig] = None,
|
||||||
|
|||||||
@ -260,19 +260,21 @@ class TestDatasetCreate:
|
|||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"name, embedding_model",
|
"name, embedding_model",
|
||||||
[
|
[
|
||||||
|
("empty", ""),
|
||||||
|
("space", " "),
|
||||||
("missing_at", "BAAI/bge-large-zh-v1.5BAAI"),
|
("missing_at", "BAAI/bge-large-zh-v1.5BAAI"),
|
||||||
("missing_model_name", "@BAAI"),
|
("missing_model_name", "@BAAI"),
|
||||||
("missing_provider", "BAAI/bge-large-zh-v1.5@"),
|
("missing_provider", "BAAI/bge-large-zh-v1.5@"),
|
||||||
("whitespace_only_model_name", " @BAAI"),
|
("whitespace_only_model_name", " @BAAI"),
|
||||||
("whitespace_only_provider", "BAAI/bge-large-zh-v1.5@ "),
|
("whitespace_only_provider", "BAAI/bge-large-zh-v1.5@ "),
|
||||||
],
|
],
|
||||||
ids=["missing_at", "empty_model_name", "empty_provider", "whitespace_only_model_name", "whitespace_only_provider"],
|
ids=["empty", "space", "missing_at", "empty_model_name", "empty_provider", "whitespace_only_model_name", "whitespace_only_provider"],
|
||||||
)
|
)
|
||||||
def test_embedding_model_format(self, HttpApiAuth, name, embedding_model):
|
def test_embedding_model_format(self, HttpApiAuth, name, embedding_model):
|
||||||
payload = {"name": name, "embedding_model": embedding_model}
|
payload = {"name": name, "embedding_model": embedding_model}
|
||||||
res = create_dataset(HttpApiAuth, payload)
|
res = create_dataset(HttpApiAuth, payload)
|
||||||
assert res["code"] == 101, res
|
assert res["code"] == 101, res
|
||||||
if name == "missing_at":
|
if name in ["empty", "space", "missing_at"]:
|
||||||
assert "Embedding model identifier must follow <model_name>@<provider> format" in res["message"], res
|
assert "Embedding model identifier must follow <model_name>@<provider> format" in res["message"], res
|
||||||
else:
|
else:
|
||||||
assert "Both model_name and provider must be non-empty strings" in res["message"], res
|
assert "Both model_name and provider must be non-empty strings" in res["message"], res
|
||||||
@ -288,8 +290,8 @@ class TestDatasetCreate:
|
|||||||
def test_embedding_model_none(self, HttpApiAuth):
|
def test_embedding_model_none(self, HttpApiAuth):
|
||||||
payload = {"name": "embedding_model_none", "embedding_model": None}
|
payload = {"name": "embedding_model_none", "embedding_model": None}
|
||||||
res = create_dataset(HttpApiAuth, payload)
|
res = create_dataset(HttpApiAuth, payload)
|
||||||
assert res["code"] == 101, res
|
assert res["code"] == 0, res
|
||||||
assert "Input should be a valid string" in res["message"], res
|
assert res["data"]["embedding_model"] == "BAAI/bge-large-zh-v1.5@BAAI", res
|
||||||
|
|
||||||
@pytest.mark.p1
|
@pytest.mark.p1
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
|||||||
@ -300,20 +300,22 @@ class TestDatasetUpdate:
|
|||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"name, embedding_model",
|
"name, embedding_model",
|
||||||
[
|
[
|
||||||
|
("empty", ""),
|
||||||
|
("space", " "),
|
||||||
("missing_at", "BAAI/bge-large-zh-v1.5BAAI"),
|
("missing_at", "BAAI/bge-large-zh-v1.5BAAI"),
|
||||||
("missing_model_name", "@BAAI"),
|
("missing_model_name", "@BAAI"),
|
||||||
("missing_provider", "BAAI/bge-large-zh-v1.5@"),
|
("missing_provider", "BAAI/bge-large-zh-v1.5@"),
|
||||||
("whitespace_only_model_name", " @BAAI"),
|
("whitespace_only_model_name", " @BAAI"),
|
||||||
("whitespace_only_provider", "BAAI/bge-large-zh-v1.5@ "),
|
("whitespace_only_provider", "BAAI/bge-large-zh-v1.5@ "),
|
||||||
],
|
],
|
||||||
ids=["missing_at", "empty_model_name", "empty_provider", "whitespace_only_model_name", "whitespace_only_provider"],
|
ids=["empty", "space", "missing_at", "empty_model_name", "empty_provider", "whitespace_only_model_name", "whitespace_only_provider"],
|
||||||
)
|
)
|
||||||
def test_embedding_model_format(self, HttpApiAuth, add_dataset_func, name, embedding_model):
|
def test_embedding_model_format(self, HttpApiAuth, add_dataset_func, name, embedding_model):
|
||||||
dataset_id = add_dataset_func
|
dataset_id = add_dataset_func
|
||||||
payload = {"name": name, "embedding_model": embedding_model}
|
payload = {"name": name, "embedding_model": embedding_model}
|
||||||
res = update_dataset(HttpApiAuth, dataset_id, payload)
|
res = update_dataset(HttpApiAuth, dataset_id, payload)
|
||||||
assert res["code"] == 101, res
|
assert res["code"] == 101, res
|
||||||
if name == "missing_at":
|
if name in ["empty", "space", "missing_at"]:
|
||||||
assert "Embedding model identifier must follow <model_name>@<provider> format" in res["message"], res
|
assert "Embedding model identifier must follow <model_name>@<provider> format" in res["message"], res
|
||||||
else:
|
else:
|
||||||
assert "Both model_name and provider must be non-empty strings" in res["message"], res
|
assert "Both model_name and provider must be non-empty strings" in res["message"], res
|
||||||
@ -323,8 +325,11 @@ class TestDatasetUpdate:
|
|||||||
dataset_id = add_dataset_func
|
dataset_id = add_dataset_func
|
||||||
payload = {"embedding_model": None}
|
payload = {"embedding_model": None}
|
||||||
res = update_dataset(HttpApiAuth, dataset_id, payload)
|
res = update_dataset(HttpApiAuth, dataset_id, payload)
|
||||||
assert res["code"] == 101, res
|
assert res["code"] == 0, res
|
||||||
assert "Input should be a valid string" in res["message"], res
|
|
||||||
|
res = list_datasets(HttpApiAuth)
|
||||||
|
assert res["code"] == 0, res
|
||||||
|
assert res["data"][0]["embedding_model"] == "BAAI/bge-large-zh-v1.5@BAAI", res
|
||||||
|
|
||||||
@pytest.mark.p1
|
@pytest.mark.p1
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
|||||||
@ -217,19 +217,21 @@ class TestDatasetCreate:
|
|||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"name, embedding_model",
|
"name, embedding_model",
|
||||||
[
|
[
|
||||||
|
("empty", ""),
|
||||||
|
("space", " "),
|
||||||
("missing_at", "BAAI/bge-large-zh-v1.5BAAI"),
|
("missing_at", "BAAI/bge-large-zh-v1.5BAAI"),
|
||||||
("missing_model_name", "@BAAI"),
|
("missing_model_name", "@BAAI"),
|
||||||
("missing_provider", "BAAI/bge-large-zh-v1.5@"),
|
("missing_provider", "BAAI/bge-large-zh-v1.5@"),
|
||||||
("whitespace_only_model_name", " @BAAI"),
|
("whitespace_only_model_name", " @BAAI"),
|
||||||
("whitespace_only_provider", "BAAI/bge-large-zh-v1.5@ "),
|
("whitespace_only_provider", "BAAI/bge-large-zh-v1.5@ "),
|
||||||
],
|
],
|
||||||
ids=["missing_at", "empty_model_name", "empty_provider", "whitespace_only_model_name", "whitespace_only_provider"],
|
ids=["empty", "space", "missing_at", "empty_model_name", "empty_provider", "whitespace_only_model_name", "whitespace_only_provider"],
|
||||||
)
|
)
|
||||||
def test_embedding_model_format(self, client, name, embedding_model):
|
def test_embedding_model_format(self, client, name, embedding_model):
|
||||||
payload = {"name": name, "embedding_model": embedding_model}
|
payload = {"name": name, "embedding_model": embedding_model}
|
||||||
with pytest.raises(Exception) as excinfo:
|
with pytest.raises(Exception) as excinfo:
|
||||||
client.create_dataset(**payload)
|
client.create_dataset(**payload)
|
||||||
if name == "missing_at":
|
if name in ["empty", "space", "missing_at"]:
|
||||||
assert "Embedding model identifier must follow <model_name>@<provider> format" in str(excinfo.value), str(excinfo.value)
|
assert "Embedding model identifier must follow <model_name>@<provider> format" in str(excinfo.value), str(excinfo.value)
|
||||||
else:
|
else:
|
||||||
assert "Both model_name and provider must be non-empty strings" in str(excinfo.value), str(excinfo.value)
|
assert "Both model_name and provider must be non-empty strings" in str(excinfo.value), str(excinfo.value)
|
||||||
@ -243,9 +245,8 @@ class TestDatasetCreate:
|
|||||||
@pytest.mark.p2
|
@pytest.mark.p2
|
||||||
def test_embedding_model_none(self, client):
|
def test_embedding_model_none(self, client):
|
||||||
payload = {"name": "embedding_model_none", "embedding_model": None}
|
payload = {"name": "embedding_model_none", "embedding_model": None}
|
||||||
with pytest.raises(Exception) as excinfo:
|
dataset = client.create_dataset(**payload)
|
||||||
client.create_dataset(**payload)
|
assert dataset.embedding_model == "BAAI/bge-large-zh-v1.5@BAAI", str(dataset)
|
||||||
assert "Input should be a valid string" in str(excinfo.value), str(excinfo.value)
|
|
||||||
|
|
||||||
@pytest.mark.p1
|
@pytest.mark.p1
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
|||||||
@ -207,30 +207,34 @@ class TestDatasetUpdate:
|
|||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"name, embedding_model",
|
"name, embedding_model",
|
||||||
[
|
[
|
||||||
|
("empty", ""),
|
||||||
|
("space", " "),
|
||||||
("missing_at", "BAAI/bge-large-zh-v1.5BAAI"),
|
("missing_at", "BAAI/bge-large-zh-v1.5BAAI"),
|
||||||
("missing_model_name", "@BAAI"),
|
("missing_model_name", "@BAAI"),
|
||||||
("missing_provider", "BAAI/bge-large-zh-v1.5@"),
|
("missing_provider", "BAAI/bge-large-zh-v1.5@"),
|
||||||
("whitespace_only_model_name", " @BAAI"),
|
("whitespace_only_model_name", " @BAAI"),
|
||||||
("whitespace_only_provider", "BAAI/bge-large-zh-v1.5@ "),
|
("whitespace_only_provider", "BAAI/bge-large-zh-v1.5@ "),
|
||||||
],
|
],
|
||||||
ids=["missing_at", "empty_model_name", "empty_provider", "whitespace_only_model_name", "whitespace_only_provider"],
|
ids=["empty", "space", "missing_at", "empty_model_name", "empty_provider", "whitespace_only_model_name", "whitespace_only_provider"],
|
||||||
)
|
)
|
||||||
def test_embedding_model_format(self, add_dataset_func, name, embedding_model):
|
def test_embedding_model_format(self, add_dataset_func, name, embedding_model):
|
||||||
dataset = add_dataset_func
|
dataset = add_dataset_func
|
||||||
with pytest.raises(Exception) as excinfo:
|
with pytest.raises(Exception) as excinfo:
|
||||||
dataset.update({"name": name, "embedding_model": embedding_model})
|
dataset.update({"name": name, "embedding_model": embedding_model})
|
||||||
error_msg = str(excinfo.value)
|
error_msg = str(excinfo.value)
|
||||||
if name == "missing_at":
|
if name in ["empty", "space", "missing_at"]:
|
||||||
assert "Embedding model identifier must follow <model_name>@<provider> format" in error_msg, error_msg
|
assert "Embedding model identifier must follow <model_name>@<provider> format" in error_msg, error_msg
|
||||||
else:
|
else:
|
||||||
assert "Both model_name and provider must be non-empty strings" in error_msg, error_msg
|
assert "Both model_name and provider must be non-empty strings" in error_msg, error_msg
|
||||||
|
|
||||||
@pytest.mark.p2
|
@pytest.mark.p2
|
||||||
def test_embedding_model_none(self, add_dataset_func):
|
def test_embedding_model_none(self, client, add_dataset_func):
|
||||||
dataset = add_dataset_func
|
dataset = add_dataset_func
|
||||||
with pytest.raises(Exception) as excinfo:
|
dataset.update({"embedding_model": None})
|
||||||
dataset.update({"embedding_model": None})
|
assert dataset.embedding_model == "BAAI/bge-large-zh-v1.5@BAAI", str(dataset)
|
||||||
assert "Input should be a valid string" in str(excinfo.value), str(excinfo.value)
|
|
||||||
|
retrieved_dataset = client.get_dataset(name=dataset.name)
|
||||||
|
assert retrieved_dataset.embedding_model == "BAAI/bge-large-zh-v1.5@BAAI", str(retrieved_dataset)
|
||||||
|
|
||||||
@pytest.mark.p1
|
@pytest.mark.p1
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
|||||||
Reference in New Issue
Block a user