Refa: HTTP API update dataset / test cases / docs (#7564)
### What problem does this PR solve?

This PR introduces Pydantic-based validation for the update dataset HTTP API, improving code clarity and robustness. Key changes include:

1. Pydantic Validation
2. Error Handling
3. Test Updates
4. Documentation Updates
5. Fix bug: #5915

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] Documentation Update
- [x] Refactoring
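For context, here is a minimal, self-contained sketch of the Pydantic-based validation pattern this PR describes. The simplified `UpdateDatasetSketch` model and its field choices are illustrative assumptions, not the actual ragflow schema:

```python
# Standalone sketch (assumed names): a strict request model with constrained
# fields and a UUID-to-hex field serializer, mirroring the approach in this PR.
from uuid import UUID, uuid1

from pydantic import BaseModel, Field, ValidationError, field_serializer


class UpdateDatasetSketch(BaseModel):
    model_config = {"extra": "forbid"}  # reject unknown keys

    dataset_id: UUID
    name: str = Field(default="", max_length=128)
    pagerank: int = Field(default=0, ge=0, le=100)

    @field_serializer("dataset_id")
    def serialize_uuid_to_hex(self, v: UUID) -> str:
        return v.hex  # persist the id as a 32-character hex string


payload = {"dataset_id": str(uuid1()), "name": "Dataset1", "pagerank": 10}
try:
    req = UpdateDatasetSketch(**payload)
    print(req.model_dump())  # dataset_id is serialized to hex
except ValidationError as e:
    print(e.errors())  # structured, field-level error details
```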
@@ -19,6 +19,7 @@ import logging
import random
import time
from base64 import b64encode
from copy import deepcopy
from functools import wraps
from hmac import HMAC
from io import BytesIO
@@ -333,22 +334,6 @@ def generate_confirmation_token(tenant_id):
    return "ragflow-" + serializer.dumps(get_uuid(), salt=tenant_id)[2:34]


def valid(permission, valid_permission, chunk_method, valid_chunk_method):
    if valid_parameter(permission, valid_permission):
        return valid_parameter(permission, valid_permission)
    if valid_parameter(chunk_method, valid_chunk_method):
        return valid_parameter(chunk_method, valid_chunk_method)


def valid_parameter(parameter, valid_values):
    if parameter and parameter not in valid_values:
        return get_error_data_result(f"'{parameter}' is not in {valid_values}")


def dataset_readonly_fields(field_name):
    return field_name in ["chunk_count", "create_date", "create_time", "update_date", "update_time", "created_by", "document_count", "token_num", "status", "tenant_id", "id"]


def get_parser_config(chunk_method, parser_config):
    if parser_config:
        return parser_config
@@ -402,43 +387,6 @@ def get_data_openai(
    }


def valid_parser_config(parser_config):
    if not parser_config:
        return
    scopes = set(
        [
            "chunk_token_num",
            "delimiter",
            "raptor",
            "graphrag",
            "layout_recognize",
            "task_page_size",
            "pages",
            "html4excel",
            "auto_keywords",
            "auto_questions",
            "tag_kb_ids",
            "topn_tags",
            "filename_embd_weight",
        ]
    )
    for k in parser_config.keys():
        assert k in scopes, f"Abnormal 'parser_config'. Invalid key: {k}"

    assert isinstance(parser_config.get("chunk_token_num", 1), int), "chunk_token_num should be int"
    assert 1 <= parser_config.get("chunk_token_num", 1) < 100000000, "chunk_token_num should be in range from 1 to 100000000"
    assert isinstance(parser_config.get("task_page_size", 1), int), "task_page_size should be int"
    assert 1 <= parser_config.get("task_page_size", 1) < 100000000, "task_page_size should be in range from 1 to 100000000"
    assert isinstance(parser_config.get("auto_keywords", 1), int), "auto_keywords should be int"
    assert 0 <= parser_config.get("auto_keywords", 0) < 32, "auto_keywords should be in range from 0 to 32"
    assert isinstance(parser_config.get("auto_questions", 1), int), "auto_questions should be int"
    assert 0 <= parser_config.get("auto_questions", 0) < 10, "auto_questions should be in range from 0 to 10"
    assert isinstance(parser_config.get("topn_tags", 1), int), "topn_tags should be int"
    assert 0 <= parser_config.get("topn_tags", 0) < 10, "topn_tags should be in range from 0 to 10"
    assert isinstance(parser_config.get("html4excel", False), bool), "html4excel should be True or False"
    assert isinstance(parser_config.get("delimiter", ""), str), "delimiter should be str"


def check_duplicate_ids(ids, id_type="item"):
    """
    Check for duplicate IDs in a list and return unique IDs and error messages.
@@ -469,7 +417,8 @@ def check_duplicate_ids(ids, id_type="item"):


def verify_embedding_availability(embd_id: str, tenant_id: str) -> tuple[bool, Response | None]:
    """Verifies availability of an embedding model for a specific tenant.
    """
    Verifies availability of an embedding model for a specific tenant.

    Implements a four-stage validation process:
    1. Model identifier parsing and validation
@@ -518,3 +467,50 @@ def verify_embedding_availability(embd_id: str, tenant_id: str) -> tuple[bool, R
        return False, get_error_data_result(message="Database operation failed")

    return True, None


def deep_merge(default: dict, custom: dict) -> dict:
    """
    Recursively merges two dictionaries with priority given to `custom` values.

    Creates a deep copy of the `default` dictionary and iteratively merges nested
    dictionaries using a stack-based approach. Non-dict values in `custom` will
    completely override corresponding entries in `default`.

    Args:
        default (dict): Base dictionary containing default values.
        custom (dict): Dictionary containing overriding values.

    Returns:
        dict: New merged dictionary combining values from both inputs.

    Example:
        >>> from copy import deepcopy
        >>> default = {"a": 1, "nested": {"x": 10, "y": 20}}
        >>> custom = {"b": 2, "nested": {"y": 99, "z": 30}}
        >>> deep_merge(default, custom)
        {'a': 1, 'b': 2, 'nested': {'x': 10, 'y': 99, 'z': 30}}

        >>> deep_merge({"config": {"mode": "auto"}}, {"config": "manual"})
        {'config': 'manual'}

    Notes:
        1. Merge priority is always given to `custom` values at all nesting levels
        2. Non-dict values (e.g. list, str) in `custom` will replace entire values
           in `default`, even if the original value was a dictionary
        3. Time complexity: O(N) where N is total key-value pairs in `custom`
        4. Recommended for configuration merging and nested data updates
    """
    merged = deepcopy(default)
    stack = [(merged, custom)]

    while stack:
        base_dict, override_dict = stack.pop()

        for key, val in override_dict.items():
            if key in base_dict and isinstance(val, dict) and isinstance(base_dict[key], dict):
                stack.append((base_dict[key], val))
            else:
                base_dict[key] = val

    return merged
@@ -13,21 +13,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import uuid
from enum import auto
from typing import Annotated, Any

from flask import Request
from pydantic import BaseModel, Field, StringConstraints, ValidationError, field_validator
from pydantic import UUID1, BaseModel, Field, StringConstraints, ValidationError, field_serializer, field_validator
from strenum import StrEnum
from werkzeug.exceptions import BadRequest, UnsupportedMediaType

from api.constants import DATASET_NAME_LIMIT


def validate_and_parse_json_request(request: Request, validator: type[BaseModel]) -> tuple[dict[str, Any] | None, str | None]:
    """Validates and parses JSON requests through a multi-stage validation pipeline.
def validate_and_parse_json_request(request: Request, validator: type[BaseModel], *, extras: dict[str, Any] | None = None, exclude_unset: bool = False) -> tuple[dict[str, Any] | None, str | None]:
    """
    Validates and parses JSON requests through a multi-stage validation pipeline.

    Implements a robust four-stage validation process:
    Implements a four-stage validation process:
    1. Content-Type verification (must be application/json)
    2. JSON syntax validation
    3. Payload structure type checking
@@ -35,6 +37,10 @@ def validate_and_parse_json_request(request: Request, validator: type[BaseModel]

    Args:
        request (Request): Flask request object containing HTTP payload
        validator (type[BaseModel]): Pydantic model class for data validation
        extras (dict[str, Any] | None): Additional fields to merge into payload
            before validation. These fields will be removed from the final output
        exclude_unset (bool): Whether to exclude fields that have not been explicitly set

    Returns:
        tuple[Dict[str, Any] | None, str | None]:
@@ -46,26 +52,26 @@ def validate_and_parse_json_request(request: Request, validator: type[BaseModel]
                - Diagnostic error message on failure

    Raises:
        UnsupportedMediaType: When Content-Type ≠ application/json
        UnsupportedMediaType: When Content-Type header is not application/json
        BadRequest: For structural JSON syntax errors
        ValidationError: When payload violates Pydantic schema rules

    Examples:
        Successful validation:
        ```python
        # Input: {"name": "Dataset1", "format": "csv"}
        # Returns: ({"name": "Dataset1", "format": "csv"}, None)
        ```
        >>> validate_and_parse_json_request(valid_request, DatasetSchema)
        ({"name": "Dataset1", "format": "csv"}, None)

        Invalid Content-Type:
        ```python
        # Returns: (None, "Unsupported content type: Expected application/json, got text/xml")
        ```
        >>> validate_and_parse_json_request(xml_request, DatasetSchema)
        (None, "Unsupported content type: Expected application/json, got text/xml")

        Malformed JSON:
        ```python
        # Returns: (None, "Malformed JSON syntax: Missing commas/brackets or invalid encoding")
        ```
        >>> validate_and_parse_json_request(bad_json_request, DatasetSchema)
        (None, "Malformed JSON syntax: Missing commas/brackets or invalid encoding")

    Notes:
        1. Validation Priority:
            - Content-Type verification precedes JSON parsing
            - Structural validation occurs before schema validation
        2. Extra fields added via `extras` parameter are automatically removed
           from the final output after validation
    """
    try:
        payload = request.get_json() or {}
@@ -78,17 +84,25 @@ def validate_and_parse_json_request(request: Request, validator: type[BaseModel]
        return None, f"Invalid request payload: expected object, got {type(payload).__name__}"

    try:
        if extras is not None:
            payload.update(extras)
        validated_request = validator(**payload)
    except ValidationError as e:
        return None, format_validation_error_message(e)

    parsed_payload = validated_request.model_dump(by_alias=True)
    parsed_payload = validated_request.model_dump(by_alias=True, exclude_unset=exclude_unset)

    if extras is not None:
        for key in list(parsed_payload.keys()):
            if key in extras:
                del parsed_payload[key]

    return parsed_payload, None


def format_validation_error_message(e: ValidationError) -> str:
    """Formats validation errors into a standardized string format.
    """
    Formats validation errors into a standardized string format.

    Processes pydantic ValidationError objects to create human-readable error messages
    containing field locations, error descriptions, and input values.
@@ -155,7 +169,6 @@ class GraphragMethodEnum(StrEnum):
class Base(BaseModel):
    class Config:
        extra = "forbid"
        json_schema_extra = {"charset": "utf8mb4", "collation": "utf8mb4_0900_ai_ci"}


class RaptorConfig(Base):
@@ -201,16 +214,17 @@ class CreateDatasetReq(Base):
    name: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1, max_length=DATASET_NAME_LIMIT), Field(...)]
    avatar: str | None = Field(default=None, max_length=65535)
    description: str | None = Field(default=None, max_length=65535)
    embedding_model: Annotated[str | None, StringConstraints(strip_whitespace=True, max_length=255), Field(default=None, serialization_alias="embd_id")]
    embedding_model: Annotated[str, StringConstraints(strip_whitespace=True, max_length=255), Field(default="", serialization_alias="embd_id")]
    permission: Annotated[PermissionEnum, StringConstraints(strip_whitespace=True, min_length=1, max_length=16), Field(default=PermissionEnum.me)]
    chunk_method: Annotated[ChunkMethodnEnum, StringConstraints(strip_whitespace=True, min_length=1, max_length=32), Field(default=ChunkMethodnEnum.naive, serialization_alias="parser_id")]
    pagerank: int = Field(default=0, ge=0, le=100)
    parser_config: ParserConfig | None = Field(default=None)
    parser_config: ParserConfig = Field(default_factory=dict)

    @field_validator("avatar")
    @classmethod
    def validate_avatar_base64(cls, v: str) -> str:
        """Validates Base64-encoded avatar string format and MIME type compliance.
    def validate_avatar_base64(cls, v: str | None) -> str | None:
        """
        Validates Base64-encoded avatar string format and MIME type compliance.

        Implements a three-stage validation workflow:
        1. MIME prefix existence check
@@ -259,7 +273,8 @@ class CreateDatasetReq(Base):
    @field_validator("embedding_model", mode="after")
    @classmethod
    def validate_embedding_model(cls, v: str) -> str:
        """Validates embedding model identifier format compliance.
        """
        Validates embedding model identifier format compliance.

        Validation pipeline:
        1. Structural format verification
@@ -298,11 +313,12 @@ class CreateDatasetReq(Base):

    @field_validator("permission", mode="before")
    @classmethod
    def permission_auto_lowercase(cls, v: str) -> str:
        """Normalize permission input to lowercase for consistent PermissionEnum matching.
    def permission_auto_lowercase(cls, v: Any) -> Any:
        """
        Normalize permission input to lowercase for consistent PermissionEnum matching.

        Args:
            v (str): Raw input value for the permission field
            v (Any): Raw input value for the permission field

        Returns:
            Lowercase string if input is string type, otherwise returns original value
@@ -316,13 +332,13 @@ class CreateDatasetReq(Base):

    @field_validator("parser_config", mode="after")
    @classmethod
    def validate_parser_config_json_length(cls, v: ParserConfig | None) -> ParserConfig | None:
        """Validates serialized JSON length constraints for parser configuration.
    def validate_parser_config_json_length(cls, v: ParserConfig) -> ParserConfig:
        """
        Validates serialized JSON length constraints for parser configuration.

        Implements a three-stage validation workflow:
        1. Null check - bypass validation for empty configurations
        2. Model serialization - convert Pydantic model to JSON string
        3. Size verification - enforce maximum allowed payload size
        Implements a two-stage validation workflow:
        1. Model serialization - convert Pydantic model to JSON string
        2. Size verification - enforce maximum allowed payload size

        Args:
            v (ParserConfig | None): Raw parser configuration object
@@ -333,9 +349,15 @@ class CreateDatasetReq(Base):
        Raises:
            ValueError: When serialized JSON exceeds 65,535 characters
        """
        if v is None:
            return v

        if (json_str := v.model_dump_json()) and len(json_str) > 65535:
            raise ValueError(f"Parser config exceeds size limit (max 65,535 characters). Current size: {len(json_str):,}")
        return v


class UpdateDatasetReq(CreateDatasetReq):
    dataset_id: UUID1 = Field(...)
    name: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1, max_length=DATASET_NAME_LIMIT), Field(default="")]

    @field_serializer("dataset_id")
    def serialize_uuid_to_hex(self, v: uuid.UUID) -> str:
        return v.hex
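For reference, a hedged sketch of the calling pattern the new `extras` / `exclude_unset` options enable. The import path, the view function, and the error envelope below are assumptions; only `validate_and_parse_json_request` and `UpdateDatasetReq` come from the diff above:

```python
# Sketch only: the module path and route wiring are assumed, not taken from the PR.
from flask import request

from api.utils.validation_utils import UpdateDatasetReq, validate_and_parse_json_request  # assumed module path


def update_dataset(tenant_id: str, dataset_id: str):  # hypothetical PUT /datasets/<dataset_id> view
    # dataset_id comes from the URL, so it is merged in via `extras`, validated
    # against the UUID1 field, then stripped from the returned payload.
    req, err = validate_and_parse_json_request(
        request,
        UpdateDatasetReq,
        extras={"dataset_id": dataset_id},
        exclude_unset=True,  # keep only the fields the client actually sent
    )
    if err is not None:
        return {"code": 101, "message": err}  # placeholder error envelope
    # `req` now holds DB-ready keys (e.g. embd_id, parser_id) thanks to the
    # serialization aliases declared on the model.
    return {"code": 0, "data": req}
```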