Refa: HTTP API update dataset / test cases / docs (#7564)

### What problem does this PR solve?

This PR introduces Pydantic-based validation for the update dataset HTTP
API, improving code clarity and robustness. Key changes include:
1. Pydantic Validation (a minimal sketch follows this list)
2. Error Handling
3. Test Updates
4. Documentation Updates
5. Fix bug #5915
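
At its core, the endpoint now declares the request shape as a Pydantic model and lets validation run before any database work. Below is a minimal sketch of the idea with a simplified field set and names; it is not the exact `UpdateDatasetReq` shipped in this PR:

```python
# Simplified sketch of the validation flow; fields and defaults are illustrative.
from uuid import UUID, uuid1

from pydantic import BaseModel, Field, ValidationError, field_serializer


class UpdateDatasetSketch(BaseModel):
    # dataset_id comes from the URL path and is injected into the payload
    # via the `extras` argument of validate_and_parse_json_request
    dataset_id: UUID = Field(...)
    name: str = Field(default="", max_length=128)
    pagerank: int = Field(default=0, ge=0, le=100)

    @field_serializer("dataset_id")
    def serialize_uuid_to_hex(self, v: UUID) -> str:
        return v.hex


payload = {"dataset_id": uuid1(), "name": "My dataset"}
try:
    # exclude_unset=True keeps only the fields the caller actually sent
    req = UpdateDatasetSketch(**payload).model_dump(exclude_unset=True)
except ValidationError as e:
    req = None
    print(e.errors())  # surfaced to the client as an argument error
```

In the endpoint itself, a validation failure short-circuits into `get_error_argument_result` before any database query runs.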

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] Documentation Update
- [x] Refactoring
Author: liu an
Date: 2025-05-09 19:17:08 +08:00
Committed by: GitHub
Parent: 31718581b5
Commit: 35e36cb945
12 changed files with 1283 additions and 552 deletions

View File — api/apps/sdk/dataset.py

@@ -27,22 +27,19 @@ from api.db.services.document_service import DocumentService
from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMService, TenantLLMService
from api.db.services.user_service import TenantService
from api.utils import get_uuid
from api.utils.api_utils import (
check_duplicate_ids,
dataset_readonly_fields,
deep_merge,
get_error_argument_result,
get_error_data_result,
get_parser_config,
get_result,
token_required,
valid,
valid_parser_config,
verify_embedding_availability,
)
from api.utils.validation_utils import CreateDatasetReq, validate_and_parse_json_request
from api.utils.validation_utils import CreateDatasetReq, UpdateDatasetReq, validate_and_parse_json_request
@manager.route("/datasets", methods=["POST"]) # noqa: F821
@@ -117,8 +114,8 @@ def create(tenant_id):
return get_error_argument_result(err)
try:
if KnowledgebaseService.query(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value):
return get_error_argument_result(message=f"Dataset name '{req['name']}' already exists")
if KnowledgebaseService.get_or_none(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value):
return get_error_data_result(message=f"Dataset name '{req['name']}' already exists")
except OperationalError as e:
logging.exception(e)
return get_error_data_result(message="Database operation failed")
@@ -145,7 +142,7 @@ def create(tenant_id):
try:
if not KnowledgebaseService.save(**req):
return get_error_data_result(message="Database operation failed")
return get_error_data_result(message="Create dataset error.(Database error)")
except OperationalError as e:
logging.exception(e)
return get_error_data_result(message="Database operation failed")
@@ -205,6 +202,7 @@ def delete(tenant_id):
schema:
type: object
"""
errors = []
success_count = 0
req = request.json
@@ -291,16 +289,28 @@ def update(tenant_id, dataset_id):
name:
type: string
description: New name of the dataset.
avatar:
type: string
description: Updated base64 encoding of the avatar.
description:
type: string
description: Updated description of the dataset.
embedding_model:
type: string
description: Updated embedding model Name.
permission:
type: string
enum: ['me', 'team']
description: Updated permission.
description: Updated dataset permission.
chunk_method:
type: string
enum: ["naive", "manual", "qa", "table", "paper", "book", "laws",
"presentation", "picture", "one", "email", "tag"
enum: ["naive", "book", "email", "laws", "manual", "one", "paper",
"picture", "presentation", "qa", "table", "tag"
]
description: Updated chunking method.
pagerank:
type: integer
description: Updated page rank.
parser_config:
type: object
description: Updated parser configuration.
@@ -310,98 +320,56 @@ def update(tenant_id, dataset_id):
schema:
type: object
"""
if not KnowledgebaseService.query(id=dataset_id, tenant_id=tenant_id):
return get_error_data_result(message="You don't own the dataset")
req = request.json
for k in req.keys():
if dataset_readonly_fields(k):
return get_result(code=settings.RetCode.ARGUMENT_ERROR, message=f"'{k}' is readonly.")
e, t = TenantService.get_by_id(tenant_id)
invalid_keys = {"id", "embd_id", "chunk_num", "doc_num", "parser_id", "create_date", "create_time", "created_by", "status", "token_num", "update_date", "update_time"}
if any(key in req for key in invalid_keys):
return get_error_data_result(message="The input parameters are invalid.")
permission = req.get("permission")
chunk_method = req.get("chunk_method")
parser_config = req.get("parser_config")
valid_parser_config(parser_config)
valid_permission = ["me", "team"]
valid_chunk_method = ["naive", "manual", "qa", "table", "paper", "book", "laws", "presentation", "picture", "one", "email", "tag"]
check_validation = valid(
permission,
valid_permission,
chunk_method,
valid_chunk_method,
)
if check_validation:
return check_validation
if "tenant_id" in req:
if req["tenant_id"] != tenant_id:
return get_error_data_result(message="Can't change `tenant_id`.")
e, kb = KnowledgebaseService.get_by_id(dataset_id)
if "parser_config" in req:
temp_dict = kb.parser_config
temp_dict.update(req["parser_config"])
req["parser_config"] = temp_dict
if "chunk_count" in req:
if req["chunk_count"] != kb.chunk_num:
return get_error_data_result(message="Can't change `chunk_count`.")
req.pop("chunk_count")
if "document_count" in req:
if req["document_count"] != kb.doc_num:
return get_error_data_result(message="Can't change `document_count`.")
req.pop("document_count")
if req.get("chunk_method"):
if kb.chunk_num != 0 and req["chunk_method"] != kb.parser_id:
return get_error_data_result(message="If `chunk_count` is not 0, `chunk_method` is not changeable.")
req["parser_id"] = req.pop("chunk_method")
if req["parser_id"] != kb.parser_id:
if not req.get("parser_config"):
req["parser_config"] = get_parser_config(chunk_method, parser_config)
if "embedding_model" in req:
if kb.chunk_num != 0 and req["embedding_model"] != kb.embd_id:
return get_error_data_result(message="If `chunk_count` is not 0, `embedding_model` is not changeable.")
if not req.get("embedding_model"):
return get_error_data_result("`embedding_model` can't be empty")
valid_embedding_models = [
"BAAI/bge-large-zh-v1.5",
"BAAI/bge-base-en-v1.5",
"BAAI/bge-large-en-v1.5",
"BAAI/bge-small-en-v1.5",
"BAAI/bge-small-zh-v1.5",
"jinaai/jina-embeddings-v2-base-en",
"jinaai/jina-embeddings-v2-small-en",
"nomic-ai/nomic-embed-text-v1.5",
"sentence-transformers/all-MiniLM-L6-v2",
"text-embedding-v2",
"text-embedding-v3",
"maidalun1020/bce-embedding-base_v1",
]
embd_model = LLMService.query(llm_name=req["embedding_model"], model_type="embedding")
if embd_model:
if req["embedding_model"] not in valid_embedding_models and not TenantLLMService.query(
tenant_id=tenant_id,
model_type="embedding",
llm_name=req.get("embedding_model"),
):
return get_error_data_result(f"`embedding_model` {req.get('embedding_model')} doesn't exist")
if not embd_model:
embd_model = TenantLLMService.query(tenant_id=tenant_id, model_type="embedding", llm_name=req.get("embedding_model"))
# Field name transformations during model dump:
# | Original | Dump Output |
# |----------------|-------------|
# | embedding_model| embd_id |
# | chunk_method | parser_id |
extras = {"dataset_id": dataset_id}
req, err = validate_and_parse_json_request(request, UpdateDatasetReq, extras=extras, exclude_unset=True)
if err is not None:
return get_error_argument_result(err)
if not req:
return get_error_argument_result(message="No properties were modified")
try:
kb = KnowledgebaseService.get_or_none(id=dataset_id, tenant_id=tenant_id)
if kb is None:
return get_error_data_result(message=f"User '{tenant_id}' lacks permission for dataset '{dataset_id}'")
except OperationalError as e:
logging.exception(e)
return get_error_data_result(message="Database operation failed")
if req.get("parser_config"):
req["parser_config"] = deep_merge(kb.parser_config, req["parser_config"])
if (chunk_method := req.get("parser_id")) and chunk_method != kb.parser_id and req.get("parser_config") is None:
req["parser_config"] = get_parser_config(chunk_method, None)
if "name" in req and req["name"].lower() != kb.name.lower():
try:
exists = KnowledgebaseService.get_or_none(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value)
if exists:
return get_error_data_result(message=f"Dataset name '{req['name']}' already exists")
except OperationalError as e:
logging.exception(e)
return get_error_data_result(message="Database operation failed")
if "embd_id" in req:
if kb.chunk_num != 0 and req["embd_id"] != kb.embd_id:
return get_error_data_result(message=f"When chunk_num ({kb.chunk_num}) > 0, embedding_model must remain {kb.embd_id}")
ok, err = verify_embedding_availability(req["embd_id"], tenant_id)
if not ok:
return err
try:
if not KnowledgebaseService.update_by_id(kb.id, req):
return get_error_data_result(message="Update dataset error.(Database error)")
except OperationalError as e:
logging.exception(e)
return get_error_data_result(message="Database operation failed")
if not embd_model:
return get_error_data_result(f"`embedding_model` {req.get('embedding_model')} doesn't exist")
req["embd_id"] = req.pop("embedding_model")
if "name" in req:
req["name"] = req["name"].strip()
if len(req["name"]) >= 128:
return get_error_data_result(message="Dataset name should not be longer than 128 characters.")
if req["name"].lower() != kb.name.lower() and len(KnowledgebaseService.query(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value)) > 0:
return get_error_data_result(message="Duplicated dataset name in updating dataset.")
flds = list(req.keys())
for f in flds:
if req[f] == "" and f in ["permission", "parser_id", "chunk_method"]:
del req[f]
if not KnowledgebaseService.update_by_id(kb.id, req):
return get_error_data_result(message="Update dataset error.(Database error)")
return get_result(code=settings.RetCode.SUCCESS)
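
For orientation, here is a hypothetical client-side call against the rewritten update endpoint. The route prefix, port, and Bearer-token auth scheme are assumptions (the route decorator is not part of this hunk):

```python
import requests

BASE_URL = "http://localhost:9380/api/v1"        # assumed deployment address
API_KEY = "ragflow-xxxxxxxxxxxxxxxxxxxxxxxxxxxx"  # placeholder token
DATASET_ID = "d94a8dc02c9711f0930f7fbc369eab6d"   # placeholder dataset id

resp = requests.put(
    f"{BASE_URL}/datasets/{DATASET_ID}",
    headers={"Authorization": f"Bearer {API_KEY}"},
    json={
        "name": "renamed_dataset",
        "permission": "team",
        # nested keys are deep-merged into the stored parser_config
        "parser_config": {"chunk_token_num": 256},
    },
)
print(resp.json())  # {"code": 0, ...} on success; an error code plus message otherwise
```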

View File — api/utils/api_utils.py

@@ -19,6 +19,7 @@ import logging
import random
import time
from base64 import b64encode
from copy import deepcopy
from functools import wraps
from hmac import HMAC
from io import BytesIO
@@ -333,22 +334,6 @@ def generate_confirmation_token(tenant_id):
return "ragflow-" + serializer.dumps(get_uuid(), salt=tenant_id)[2:34]
def valid(permission, valid_permission, chunk_method, valid_chunk_method):
if valid_parameter(permission, valid_permission):
return valid_parameter(permission, valid_permission)
if valid_parameter(chunk_method, valid_chunk_method):
return valid_parameter(chunk_method, valid_chunk_method)
def valid_parameter(parameter, valid_values):
if parameter and parameter not in valid_values:
return get_error_data_result(f"'{parameter}' is not in {valid_values}")
def dataset_readonly_fields(field_name):
return field_name in ["chunk_count", "create_date", "create_time", "update_date", "update_time", "created_by", "document_count", "token_num", "status", "tenant_id", "id"]
def get_parser_config(chunk_method, parser_config):
if parser_config:
return parser_config
@@ -402,43 +387,6 @@ def get_data_openai(
}
def valid_parser_config(parser_config):
if not parser_config:
return
scopes = set(
[
"chunk_token_num",
"delimiter",
"raptor",
"graphrag",
"layout_recognize",
"task_page_size",
"pages",
"html4excel",
"auto_keywords",
"auto_questions",
"tag_kb_ids",
"topn_tags",
"filename_embd_weight",
]
)
for k in parser_config.keys():
assert k in scopes, f"Abnormal 'parser_config'. Invalid key: {k}"
assert isinstance(parser_config.get("chunk_token_num", 1), int), "chunk_token_num should be int"
assert 1 <= parser_config.get("chunk_token_num", 1) < 100000000, "chunk_token_num should be in range from 1 to 100000000"
assert isinstance(parser_config.get("task_page_size", 1), int), "task_page_size should be int"
assert 1 <= parser_config.get("task_page_size", 1) < 100000000, "task_page_size should be in range from 1 to 100000000"
assert isinstance(parser_config.get("auto_keywords", 1), int), "auto_keywords should be int"
assert 0 <= parser_config.get("auto_keywords", 0) < 32, "auto_keywords should be in range from 0 to 32"
assert isinstance(parser_config.get("auto_questions", 1), int), "auto_questions should be int"
assert 0 <= parser_config.get("auto_questions", 0) < 10, "auto_questions should be in range from 0 to 10"
assert isinstance(parser_config.get("topn_tags", 1), int), "topn_tags should be int"
assert 0 <= parser_config.get("topn_tags", 0) < 10, "topn_tags should be in range from 0 to 10"
assert isinstance(parser_config.get("html4excel", False), bool), "html4excel should be True or False"
assert isinstance(parser_config.get("delimiter", ""), str), "delimiter should be str"
def check_duplicate_ids(ids, id_type="item"):
"""
Check for duplicate IDs in a list and return unique IDs and error messages.
@@ -469,7 +417,8 @@ def check_duplicate_ids(ids, id_type="item"):
def verify_embedding_availability(embd_id: str, tenant_id: str) -> tuple[bool, Response | None]:
"""Verifies availability of an embedding model for a specific tenant.
"""
Verifies availability of an embedding model for a specific tenant.
Implements a four-stage validation process:
1. Model identifier parsing and validation
@@ -518,3 +467,50 @@ def verify_embedding_availability(embd_id: str, tenant_id: str) -> tuple[bool, R
return False, get_error_data_result(message="Database operation failed")
return True, None
def deep_merge(default: dict, custom: dict) -> dict:
"""
Recursively merges two dictionaries with priority given to `custom` values.
Creates a deep copy of the `default` dictionary and iteratively merges nested
dictionaries using a stack-based approach. Non-dict values in `custom` will
completely override corresponding entries in `default`.
Args:
default (dict): Base dictionary containing default values.
custom (dict): Dictionary containing overriding values.
Returns:
dict: New merged dictionary combining values from both inputs.
Example:
>>> from copy import deepcopy
>>> default = {"a": 1, "nested": {"x": 10, "y": 20}}
>>> custom = {"b": 2, "nested": {"y": 99, "z": 30}}
>>> deep_merge(default, custom)
{'a': 1, 'b': 2, 'nested': {'x': 10, 'y': 99, 'z': 30}}
>>> deep_merge({"config": {"mode": "auto"}}, {"config": "manual"})
{'config': 'manual'}
Notes:
1. Merge priority is always given to `custom` values at all nesting levels
2. Non-dict values (e.g. list, str) in `custom` will replace entire values
in `default`, even if the original value was a dictionary
3. Time complexity: O(N) where N is total key-value pairs in `custom`
4. Recommended for configuration merging and nested data updates
"""
merged = deepcopy(default)
stack = [(merged, custom)]
while stack:
base_dict, override_dict = stack.pop()
for key, val in override_dict.items():
if key in base_dict and isinstance(val, dict) and isinstance(base_dict[key], dict):
stack.append((base_dict[key], val))
else:
base_dict[key] = val
return merged

View File — api/utils/validation_utils.py

@@ -13,21 +13,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import uuid
from enum import auto
from typing import Annotated, Any
from flask import Request
from pydantic import BaseModel, Field, StringConstraints, ValidationError, field_validator
from pydantic import UUID1, BaseModel, Field, StringConstraints, ValidationError, field_serializer, field_validator
from strenum import StrEnum
from werkzeug.exceptions import BadRequest, UnsupportedMediaType
from api.constants import DATASET_NAME_LIMIT
def validate_and_parse_json_request(request: Request, validator: type[BaseModel]) -> tuple[dict[str, Any] | None, str | None]:
"""Validates and parses JSON requests through a multi-stage validation pipeline.
def validate_and_parse_json_request(request: Request, validator: type[BaseModel], *, extras: dict[str, Any] | None = None, exclude_unset: bool = False) -> tuple[dict[str, Any] | None, str | None]:
"""
Validates and parses JSON requests through a multi-stage validation pipeline.
Implements a robust four-stage validation process:
Implements a four-stage validation process:
1. Content-Type verification (must be application/json)
2. JSON syntax validation
3. Payload structure type checking
@@ -35,6 +37,10 @@ def validate_and_parse_json_request(request: Request, validator: type[BaseModel]
Args:
request (Request): Flask request object containing HTTP payload
validator (type[BaseModel]): Pydantic model class for data validation
extras (dict[str, Any] | None): Additional fields to merge into payload
before validation. These fields will be removed from the final output
exclude_unset (bool): Whether to exclude fields that have not been explicitly set
Returns:
tuple[Dict[str, Any] | None, str | None]:
@@ -46,26 +52,26 @@ def validate_and_parse_json_request(request: Request, validator: type[BaseModel]
- Diagnostic error message on failure
Raises:
UnsupportedMediaType: When Content-Type application/json
UnsupportedMediaType: When Content-Type header is not application/json
BadRequest: For structural JSON syntax errors
ValidationError: When payload violates Pydantic schema rules
Examples:
Successful validation:
```python
# Input: {"name": "Dataset1", "format": "csv"}
# Returns: ({"name": "Dataset1", "format": "csv"}, None)
```
>>> validate_and_parse_json_request(valid_request, DatasetSchema)
({"name": "Dataset1", "format": "csv"}, None)
Invalid Content-Type:
```python
# Returns: (None, "Unsupported content type: Expected application/json, got text/xml")
```
>>> validate_and_parse_json_request(xml_request, DatasetSchema)
(None, "Unsupported content type: Expected application/json, got text/xml")
Malformed JSON:
```python
# Returns: (None, "Malformed JSON syntax: Missing commas/brackets or invalid encoding")
```
>>> validate_and_parse_json_request(bad_json_request, DatasetSchema)
(None, "Malformed JSON syntax: Missing commas/brackets or invalid encoding")
Notes:
1. Validation Priority:
- Content-Type verification precedes JSON parsing
- Structural validation occurs before schema validation
2. Extra fields added via `extras` parameter are automatically removed
from the final output after validation
"""
try:
payload = request.get_json() or {}
@@ -78,17 +84,25 @@ def validate_and_parse_json_request(request: Request, validator: type[BaseModel]
return None, f"Invalid request payload: expected object, got {type(payload).__name__}"
try:
if extras is not None:
payload.update(extras)
validated_request = validator(**payload)
except ValidationError as e:
return None, format_validation_error_message(e)
parsed_payload = validated_request.model_dump(by_alias=True)
parsed_payload = validated_request.model_dump(by_alias=True, exclude_unset=exclude_unset)
if extras is not None:
for key in list(parsed_payload.keys()):
if key in extras:
del parsed_payload[key]
return parsed_payload, None
def format_validation_error_message(e: ValidationError) -> str:
"""Formats validation errors into a standardized string format.
"""
Formats validation errors into a standardized string format.
Processes pydantic ValidationError objects to create human-readable error messages
containing field locations, error descriptions, and input values.
@@ -155,7 +169,6 @@ class GraphragMethodEnum(StrEnum):
class Base(BaseModel):
class Config:
extra = "forbid"
json_schema_extra = {"charset": "utf8mb4", "collation": "utf8mb4_0900_ai_ci"}
class RaptorConfig(Base):
@@ -201,16 +214,17 @@ class CreateDatasetReq(Base):
name: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1, max_length=DATASET_NAME_LIMIT), Field(...)]
avatar: str | None = Field(default=None, max_length=65535)
description: str | None = Field(default=None, max_length=65535)
embedding_model: Annotated[str | None, StringConstraints(strip_whitespace=True, max_length=255), Field(default=None, serialization_alias="embd_id")]
embedding_model: Annotated[str, StringConstraints(strip_whitespace=True, max_length=255), Field(default="", serialization_alias="embd_id")]
permission: Annotated[PermissionEnum, StringConstraints(strip_whitespace=True, min_length=1, max_length=16), Field(default=PermissionEnum.me)]
chunk_method: Annotated[ChunkMethodnEnum, StringConstraints(strip_whitespace=True, min_length=1, max_length=32), Field(default=ChunkMethodnEnum.naive, serialization_alias="parser_id")]
pagerank: int = Field(default=0, ge=0, le=100)
parser_config: ParserConfig | None = Field(default=None)
parser_config: ParserConfig = Field(default_factory=dict)
@field_validator("avatar")
@classmethod
def validate_avatar_base64(cls, v: str) -> str:
"""Validates Base64-encoded avatar string format and MIME type compliance.
def validate_avatar_base64(cls, v: str | None) -> str | None:
"""
Validates Base64-encoded avatar string format and MIME type compliance.
Implements a three-stage validation workflow:
1. MIME prefix existence check
@@ -259,7 +273,8 @@ class CreateDatasetReq(Base):
@field_validator("embedding_model", mode="after")
@classmethod
def validate_embedding_model(cls, v: str) -> str:
"""Validates embedding model identifier format compliance.
"""
Validates embedding model identifier format compliance.
Validation pipeline:
1. Structural format verification
@@ -298,11 +313,12 @@ class CreateDatasetReq(Base):
@field_validator("permission", mode="before")
@classmethod
def permission_auto_lowercase(cls, v: str) -> str:
"""Normalize permission input to lowercase for consistent PermissionEnum matching.
def permission_auto_lowercase(cls, v: Any) -> Any:
"""
Normalize permission input to lowercase for consistent PermissionEnum matching.
Args:
v (str): Raw input value for the permission field
v (Any): Raw input value for the permission field
Returns:
Lowercase string if input is string type, otherwise returns original value
@@ -316,13 +332,13 @@ class CreateDatasetReq(Base):
@field_validator("parser_config", mode="after")
@classmethod
def validate_parser_config_json_length(cls, v: ParserConfig | None) -> ParserConfig | None:
"""Validates serialized JSON length constraints for parser configuration.
def validate_parser_config_json_length(cls, v: ParserConfig) -> ParserConfig:
"""
Validates serialized JSON length constraints for parser configuration.
Implements a three-stage validation workflow:
1. Null check - bypass validation for empty configurations
2. Model serialization - convert Pydantic model to JSON string
3. Size verification - enforce maximum allowed payload size
Implements a two-stage validation workflow:
1. Model serialization - convert Pydantic model to JSON string
2. Size verification - enforce maximum allowed payload size
Args:
v (ParserConfig | None): Raw parser configuration object
@@ -333,9 +349,15 @@ class CreateDatasetReq(Base):
Raises:
ValueError: When serialized JSON exceeds 65,535 characters
"""
if v is None:
return v
if (json_str := v.model_dump_json()) and len(json_str) > 65535:
raise ValueError(f"Parser config exceeds size limit (max 65,535 characters). Current size: {len(json_str):,}")
return v
class UpdateDatasetReq(CreateDatasetReq):
dataset_id: UUID1 = Field(...)
name: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1, max_length=DATASET_NAME_LIMIT), Field(default="")]
@field_serializer("dataset_id")
def serialize_uuid_to_hex(self, v: uuid.UUID) -> str:
return v.hex
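
Taken together with the "Field name transformations during model dump" comment earlier in the diff, the following sketch (an assumed field subset, not the full model) shows how `model_dump(by_alias=True, exclude_unset=True)` turns the validated request into ORM-ready column names:

```python
# Illustrative only: a trimmed-down model mirroring the aliasing behaviour.
import uuid

from pydantic import BaseModel, Field, field_serializer


class UpdateReqSketch(BaseModel):
    dataset_id: uuid.UUID = Field(...)
    embedding_model: str = Field(default="", serialization_alias="embd_id")
    chunk_method: str = Field(default="naive", serialization_alias="parser_id")

    @field_serializer("dataset_id")
    def serialize_uuid_to_hex(self, v: uuid.UUID) -> str:
        return v.hex


req = UpdateReqSketch(dataset_id=uuid.uuid1(), chunk_method="qa")
print(req.model_dump(by_alias=True, exclude_unset=True))
# {'dataset_id': '<32-char hex>', 'parser_id': 'qa'}
# embedding_model is dropped (never set); chunk_method is emitted as parser_id;
# the endpoint then strips dataset_id because it was injected via `extras`.
```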