Fix: Enforce default embedding model in create_dataset / update_dataset (#8486)

### What problem does this PR solve?

Previously:
- Defaulted to the hardcoded model 'BAAI/bge-large-zh-v1.5@BAAI'
- Did not respect the user-configured default embedding_model

Now:
- Correctly prioritizes the user-configured default embedding_model (see the sketch after the lists below)

Other:
- Make embedding_model optional in CreateDatasetReq with proper None
handling
- Add default embedding model fallback in dataset update when empty
- Enhance validation utils to handle None values and string
normalization
- Update SDK default embedding model to None to match API changes
- Adjust related test cases to reflect new validation rules
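
The key behavioral change is the resolution order for the embedding model when a dataset is created: an explicit value in the request wins, then the tenant's configured default, and only then a built-in fallback. Below is a minimal sketch of that order; `TenantSettings`, `resolve_embedding_model`, and `BUILTIN_FALLBACK` are hypothetical names for illustration, not the helpers used in this repository.

```python
from dataclasses import dataclass


@dataclass
class TenantSettings:
    """Hypothetical stand-in for a tenant's stored configuration."""
    embd_id: str | None = None  # user-configured default embedding model, if any


# The model that was previously always used; now only a last resort.
BUILTIN_FALLBACK = "BAAI/bge-large-zh-v1.5@BAAI"


def resolve_embedding_model(requested: str | None, tenant: TenantSettings) -> str:
    """Priority: explicit request value > tenant default > built-in fallback."""
    if requested:        # caller supplied a non-empty model identifier
        return requested
    if tenant.embd_id:   # the user-configured default now takes precedence over the hardcoded model
        return tenant.embd_id
    return BUILTIN_FALLBACK


# A tenant with a configured default no longer silently gets the hardcoded model.
print(resolve_embedding_model(None, TenantSettings(embd_id="text-embedding-3-large@OpenAI")))
```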

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

Author: Liu An
Date: 2025-06-25 16:41:32 +08:00 (committed by GitHub)
Commit: dac5bcdf17 (parent: 340354b79c)
7 changed files with 52 additions and 30 deletions


```diff
@@ -347,6 +347,8 @@ def update(tenant_id, dataset_id):
             return get_error_data_result(message=f"Dataset name '{req['name']}' already exists")
     if "embd_id" in req:
+        if not req["embd_id"]:
+            req["embd_id"] = kb.embd_id
         if kb.chunk_num != 0 and req["embd_id"] != kb.embd_id:
             return get_error_data_result(message=f"When chunk_num ({kb.chunk_num}) > 0, embedding_model must remain {kb.embd_id}")
         ok, err = verify_embedding_availability(req["embd_id"], tenant_id)
```
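
In isolation, the added branch means an update that sends an empty `embedding_model` now keeps the dataset's current model rather than tripping the chunk_num guard, while a genuine attempt to switch models on a non-empty dataset is still rejected. A standalone restatement of that logic (simplified; not the actual service code, which also verifies model availability):

```python
def apply_embd_update(req: dict, kb_embd_id: str, kb_chunk_num: int) -> str:
    """Simplified mirror of the update() branch shown above."""
    if "embd_id" in req:
        if not req["embd_id"]:           # empty string or None: fall back to the current model
            req["embd_id"] = kb_embd_id
        if kb_chunk_num != 0 and req["embd_id"] != kb_embd_id:
            raise ValueError(f"When chunk_num ({kb_chunk_num}) > 0, embedding_model must remain {kb_embd_id}")
    return req.get("embd_id", kb_embd_id)


# An empty value keeps the existing model instead of being treated as a model change.
assert apply_embd_update({"embd_id": ""}, "bge-large-zh-v1.5@BAAI", 42) == "bge-large-zh-v1.5@BAAI"
```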


```diff
@@ -380,7 +380,7 @@ class CreateDatasetReq(Base):
     name: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1, max_length=DATASET_NAME_LIMIT), Field(...)]
     avatar: str | None = Field(default=None, max_length=65535)
     description: str | None = Field(default=None, max_length=65535)
-    embedding_model: Annotated[str, StringConstraints(strip_whitespace=True, max_length=255), Field(default="", serialization_alias="embd_id")]
+    embedding_model: str | None = Field(default=None, max_length=255, serialization_alias="embd_id")
     permission: PermissionEnum = Field(default=PermissionEnum.me, min_length=1, max_length=16)
     chunk_method: ChunkMethodEnum = Field(default=ChunkMethodEnum.naive, min_length=1, max_length=32, serialization_alias="parser_id")
     parser_config: ParserConfig | None = Field(default=None)
```
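
The effect of the field change: an omitted `embedding_model` now survives as `None` (still serialized under the `embd_id` alias) instead of being coerced to an empty string, which is what lets the API layer substitute the configured default later. A stripped-down stand-in model, not the real `CreateDatasetReq`:

```python
from pydantic import BaseModel, Field


class DatasetReqSketch(BaseModel):
    """Only the field under discussion; the real model carries many more."""
    embedding_model: str | None = Field(default=None, max_length=255, serialization_alias="embd_id")


req = DatasetReqSketch()              # embedding_model omitted by the client
print(req.embedding_model)            # None -> caller decides on a default
print(req.model_dump(by_alias=True))  # {'embd_id': None}
```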
```diff
@@ -435,9 +435,16 @@
         else:
             raise PydanticCustomError("format_invalid", "Missing MIME prefix. Expected format: data:<mime>;base64,<data>")
 
+    @field_validator("embedding_model", mode="before")
+    @classmethod
+    def normalize_embedding_model(cls, v: Any) -> Any:
+        if isinstance(v, str):
+            return v.strip()
+        return v
+
     @field_validator("embedding_model", mode="after")
     @classmethod
-    def validate_embedding_model(cls, v: str) -> str:
+    def validate_embedding_model(cls, v: str | None) -> str | None:
         """
         Validates embedding model identifier format compliance.
 
@@ -464,16 +471,17 @@
             Invalid: "@openai" (empty model_name)
             Invalid: "text-embedding-3-large@" (empty provider)
         """
-        if "@" not in v:
-            raise PydanticCustomError("format_invalid", "Embedding model identifier must follow <model_name>@<provider> format")
+        if isinstance(v, str):
+            if "@" not in v:
+                raise PydanticCustomError("format_invalid", "Embedding model identifier must follow <model_name>@<provider> format")
 
-        components = v.split("@", 1)
-        if len(components) != 2 or not all(components):
-            raise PydanticCustomError("format_invalid", "Both model_name and provider must be non-empty strings")
+            components = v.split("@", 1)
+            if len(components) != 2 or not all(components):
+                raise PydanticCustomError("format_invalid", "Both model_name and provider must be non-empty strings")
 
-        model_name, provider = components
-        if not model_name.strip() or not provider.strip():
-            raise PydanticCustomError("format_invalid", "Model name and provider cannot be whitespace-only strings")
+            model_name, provider = components
+            if not model_name.strip() or not provider.strip():
+                raise PydanticCustomError("format_invalid", "Model name and provider cannot be whitespace-only strings")
 
         return v
 
     @field_validator("permission", mode="before")
```