mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
refa: Optimize create dataset validation (#7451)
### What problem does this PR solve? Optimize dataset validation and add function docs ### Type of change - [x] Refactoring
This commit is contained in:
@ -36,11 +36,13 @@ from flask import (
|
||||
request as flask_request,
|
||||
)
|
||||
from itsdangerous import URLSafeTimedSerializer
|
||||
from peewee import OperationalError
|
||||
from werkzeug.http import HTTP_STATUS_CODES
|
||||
|
||||
from api import settings
|
||||
from api.constants import REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC
|
||||
from api.db.db_models import APIToken
|
||||
from api.db.services.llm_service import LLMService, TenantLLMService
|
||||
from api.utils import CustomJSONEncoder, get_uuid, json_dumps
|
||||
|
||||
requests.models.complexjson.dumps = functools.partial(json.dumps, cls=CustomJSONEncoder)
|
||||
@ -464,3 +466,55 @@ def check_duplicate_ids(ids, id_type="item"):
|
||||
|
||||
# Return unique IDs and error messages
|
||||
return list(set(ids)), duplicate_messages
|
||||
|
||||
|
||||
def verify_embedding_availability(embd_id: str, tenant_id: str) -> tuple[bool, Response | None]:
|
||||
"""Verifies availability of an embedding model for a specific tenant.
|
||||
|
||||
Implements a four-stage validation process:
|
||||
1. Model identifier parsing and validation
|
||||
2. System support verification
|
||||
3. Tenant authorization check
|
||||
4. Database operation error handling
|
||||
|
||||
Args:
|
||||
embd_id (str): Unique identifier for the embedding model in format "model_name@factory"
|
||||
tenant_id (str): Tenant identifier for access control
|
||||
|
||||
Returns:
|
||||
tuple[bool, Response | None]:
|
||||
- First element (bool):
|
||||
- True: Model is available and authorized
|
||||
- False: Validation failed
|
||||
- Second element contains:
|
||||
- None on success
|
||||
- Error detail dict on failure
|
||||
|
||||
Raises:
|
||||
ValueError: When model identifier format is invalid
|
||||
OperationalError: When database connection fails (auto-handled)
|
||||
|
||||
Examples:
|
||||
>>> verify_embedding_availability("text-embedding@openai", "tenant_123")
|
||||
(True, None)
|
||||
|
||||
>>> verify_embedding_availability("invalid_model", "tenant_123")
|
||||
(False, {'code': 101, 'message': "Unsupported model: <invalid_model>"})
|
||||
"""
|
||||
try:
|
||||
llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(embd_id)
|
||||
if not LLMService.query(llm_name=llm_name, fid=llm_factory, model_type="embedding"):
|
||||
return False, get_error_argument_result(f"Unsupported model: <{embd_id}>")
|
||||
|
||||
# Tongyi-Qianwen is added to TenantLLM by default, but remains unusable with empty api_key
|
||||
tenant_llms = TenantLLMService.get_my_llms(tenant_id=tenant_id)
|
||||
is_tenant_model = any(llm["llm_name"] == llm_name and llm["llm_factory"] == llm_factory and llm["model_type"] == "embedding" for llm in tenant_llms)
|
||||
|
||||
is_builtin_model = embd_id in settings.BUILTIN_EMBEDDING_MODELS
|
||||
if not (is_builtin_model or is_tenant_model):
|
||||
return False, get_error_argument_result(f"Unauthorized model: <{embd_id}>")
|
||||
except OperationalError as e:
|
||||
logging.exception(e)
|
||||
return False, get_error_data_result(message="Database operation failed")
|
||||
|
||||
return True, None
|
||||
|
||||
@ -14,13 +14,102 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
from enum import auto
|
||||
from typing import Annotated, List, Optional
|
||||
from typing import Annotated, Any
|
||||
|
||||
from flask import Request
|
||||
from pydantic import BaseModel, Field, StringConstraints, ValidationError, field_validator
|
||||
from strenum import StrEnum
|
||||
from werkzeug.exceptions import BadRequest, UnsupportedMediaType
|
||||
|
||||
from api.constants import DATASET_NAME_LIMIT
|
||||
|
||||
|
||||
def validate_and_parse_json_request(request: Request, validator: type[BaseModel]) -> tuple[dict[str, Any] | None, str | None]:
|
||||
"""Validates and parses JSON requests through a multi-stage validation pipeline.
|
||||
|
||||
Implements a robust four-stage validation process:
|
||||
1. Content-Type verification (must be application/json)
|
||||
2. JSON syntax validation
|
||||
3. Payload structure type checking
|
||||
4. Pydantic model validation with error formatting
|
||||
|
||||
Args:
|
||||
request (Request): Flask request object containing HTTP payload
|
||||
|
||||
Returns:
|
||||
tuple[Dict[str, Any] | None, str | None]:
|
||||
- First element:
|
||||
- Validated dictionary on success
|
||||
- None on validation failure
|
||||
- Second element:
|
||||
- None on success
|
||||
- Diagnostic error message on failure
|
||||
|
||||
Raises:
|
||||
UnsupportedMediaType: When Content-Type ≠ application/json
|
||||
BadRequest: For structural JSON syntax errors
|
||||
ValidationError: When payload violates Pydantic schema rules
|
||||
|
||||
Examples:
|
||||
Successful validation:
|
||||
```python
|
||||
# Input: {"name": "Dataset1", "format": "csv"}
|
||||
# Returns: ({"name": "Dataset1", "format": "csv"}, None)
|
||||
```
|
||||
|
||||
Invalid Content-Type:
|
||||
```python
|
||||
# Returns: (None, "Unsupported content type: Expected application/json, got text/xml")
|
||||
```
|
||||
|
||||
Malformed JSON:
|
||||
```python
|
||||
# Returns: (None, "Malformed JSON syntax: Missing commas/brackets or invalid encoding")
|
||||
```
|
||||
"""
|
||||
try:
|
||||
payload = request.get_json() or {}
|
||||
except UnsupportedMediaType:
|
||||
return None, f"Unsupported content type: Expected application/json, got {request.content_type}"
|
||||
except BadRequest:
|
||||
return None, "Malformed JSON syntax: Missing commas/brackets or invalid encoding"
|
||||
|
||||
if not isinstance(payload, dict):
|
||||
return None, f"Invalid request payload: expected object, got {type(payload).__name__}"
|
||||
|
||||
try:
|
||||
validated_request = validator(**payload)
|
||||
except ValidationError as e:
|
||||
return None, format_validation_error_message(e)
|
||||
|
||||
parsed_payload = validated_request.model_dump(by_alias=True)
|
||||
|
||||
return parsed_payload, None
|
||||
|
||||
|
||||
def format_validation_error_message(e: ValidationError) -> str:
|
||||
"""Formats validation errors into a standardized string format.
|
||||
|
||||
Processes pydantic ValidationError objects to create human-readable error messages
|
||||
containing field locations, error descriptions, and input values.
|
||||
|
||||
Args:
|
||||
e (ValidationError): The validation error instance containing error details
|
||||
|
||||
Returns:
|
||||
str: Formatted error messages joined by newlines. Each line contains:
|
||||
- Field path (dot-separated)
|
||||
- Error message
|
||||
- Truncated input value (max 128 chars)
|
||||
|
||||
Example:
|
||||
>>> try:
|
||||
... UserModel(name=123, email="invalid")
|
||||
... except ValidationError as e:
|
||||
... print(format_validation_error_message(e))
|
||||
Field: <name> - Message: <Input should be a valid string> - Value: <123>
|
||||
Field: <email> - Message: <value is not a valid email address> - Value: <invalid>
|
||||
"""
|
||||
error_messages = []
|
||||
|
||||
for error in e.errors():
|
||||
@ -86,7 +175,7 @@ class RaptorConfig(Base):
|
||||
|
||||
class GraphragConfig(Base):
|
||||
use_graphrag: bool = Field(default=False)
|
||||
entity_types: List[str] = Field(default_factory=lambda: ["organization", "person", "geo", "event", "category"])
|
||||
entity_types: list[str] = Field(default_factory=lambda: ["organization", "person", "geo", "event", "category"])
|
||||
method: GraphragMethodEnum = Field(default=GraphragMethodEnum.light)
|
||||
community: bool = Field(default=False)
|
||||
resolution: bool = Field(default=False)
|
||||
@ -97,30 +186,59 @@ class ParserConfig(Base):
|
||||
auto_questions: int = Field(default=0, ge=0, le=10)
|
||||
chunk_token_num: int = Field(default=128, ge=1, le=2048)
|
||||
delimiter: str = Field(default=r"\n", min_length=1)
|
||||
graphrag: Optional[GraphragConfig] = None
|
||||
graphrag: GraphragConfig | None = None
|
||||
html4excel: bool = False
|
||||
layout_recognize: str = "DeepDOC"
|
||||
raptor: Optional[RaptorConfig] = None
|
||||
tag_kb_ids: List[str] = Field(default_factory=list)
|
||||
raptor: RaptorConfig | None = None
|
||||
tag_kb_ids: list[str] = Field(default_factory=list)
|
||||
topn_tags: int = Field(default=1, ge=1, le=10)
|
||||
filename_embd_weight: Optional[float] = Field(default=None, ge=0.0, le=1.0)
|
||||
task_page_size: Optional[int] = Field(default=None, ge=1)
|
||||
pages: Optional[List[List[int]]] = None
|
||||
filename_embd_weight: float | None = Field(default=None, ge=0.0, le=1.0)
|
||||
task_page_size: int | None = Field(default=None, ge=1)
|
||||
pages: list[list[int]] | None = None
|
||||
|
||||
|
||||
class CreateDatasetReq(Base):
|
||||
name: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1, max_length=128), Field(...)]
|
||||
avatar: Optional[str] = Field(default=None, max_length=65535)
|
||||
description: Optional[str] = Field(default=None, max_length=65535)
|
||||
embedding_model: Annotated[Optional[str], StringConstraints(strip_whitespace=True, max_length=255), Field(default=None, serialization_alias="embd_id")]
|
||||
name: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1, max_length=DATASET_NAME_LIMIT), Field(...)]
|
||||
avatar: str | None = Field(default=None, max_length=65535)
|
||||
description: str | None = Field(default=None, max_length=65535)
|
||||
embedding_model: Annotated[str | None, StringConstraints(strip_whitespace=True, max_length=255), Field(default=None, serialization_alias="embd_id")]
|
||||
permission: Annotated[PermissionEnum, StringConstraints(strip_whitespace=True, min_length=1, max_length=16), Field(default=PermissionEnum.me)]
|
||||
chunk_method: Annotated[ChunkMethodnEnum, StringConstraints(strip_whitespace=True, min_length=1, max_length=32), Field(default=ChunkMethodnEnum.naive, serialization_alias="parser_id")]
|
||||
pagerank: int = Field(default=0, ge=0, le=100)
|
||||
parser_config: Optional[ParserConfig] = Field(default=None)
|
||||
parser_config: ParserConfig | None = Field(default=None)
|
||||
|
||||
@field_validator("avatar")
|
||||
@classmethod
|
||||
def validate_avatar_base64(cls, v: str) -> str:
|
||||
"""Validates Base64-encoded avatar string format and MIME type compliance.
|
||||
|
||||
Implements a three-stage validation workflow:
|
||||
1. MIME prefix existence check
|
||||
2. MIME type format validation
|
||||
3. Supported type verification
|
||||
|
||||
Args:
|
||||
v (str): Raw avatar field value
|
||||
|
||||
Returns:
|
||||
str: Validated Base64 string
|
||||
|
||||
Raises:
|
||||
ValueError: For structural errors in these cases:
|
||||
- Missing MIME prefix header
|
||||
- Invalid MIME prefix format
|
||||
- Unsupported image MIME type
|
||||
|
||||
Example:
|
||||
```python
|
||||
# Valid case
|
||||
CreateDatasetReq(avatar="...")
|
||||
|
||||
# Invalid cases
|
||||
CreateDatasetReq(avatar="image/jpeg;base64,...") # Missing 'data:' prefix
|
||||
CreateDatasetReq(avatar="data:video/mp4;base64,...") # Unsupported MIME type
|
||||
```
|
||||
"""
|
||||
if v is None:
|
||||
return v
|
||||
|
||||
@ -141,22 +259,83 @@ class CreateDatasetReq(Base):
|
||||
@field_validator("embedding_model", mode="after")
|
||||
@classmethod
|
||||
def validate_embedding_model(cls, v: str) -> str:
|
||||
"""Validates embedding model identifier format compliance.
|
||||
|
||||
Validation pipeline:
|
||||
1. Structural format verification
|
||||
2. Component non-empty check
|
||||
3. Value normalization
|
||||
|
||||
Args:
|
||||
v (str): Raw model identifier
|
||||
|
||||
Returns:
|
||||
str: Validated <model_name>@<provider> format
|
||||
|
||||
Raises:
|
||||
ValueError: For these violations:
|
||||
- Missing @ separator
|
||||
- Empty model_name/provider
|
||||
- Invalid component structure
|
||||
|
||||
Examples:
|
||||
Valid: "text-embedding-3-large@openai"
|
||||
Invalid: "invalid_model" (no @)
|
||||
Invalid: "@openai" (empty model_name)
|
||||
Invalid: "text-embedding-3-large@" (empty provider)
|
||||
"""
|
||||
if "@" not in v:
|
||||
raise ValueError("Embedding model must be xxx@yyy")
|
||||
raise ValueError("Embedding model identifier must follow <model_name>@<provider> format")
|
||||
|
||||
components = v.split("@", 1)
|
||||
if len(components) != 2 or not all(components):
|
||||
raise ValueError("Both model_name and provider must be non-empty strings")
|
||||
|
||||
model_name, provider = components
|
||||
if not model_name.strip() or not provider.strip():
|
||||
raise ValueError("Model name and provider cannot be whitespace-only strings")
|
||||
return v
|
||||
|
||||
@field_validator("permission", mode="before")
|
||||
@classmethod
|
||||
def permission_auto_lowercase(cls, v: str) -> str:
|
||||
if isinstance(v, str):
|
||||
return v.lower()
|
||||
return v
|
||||
"""Normalize permission input to lowercase for consistent PermissionEnum matching.
|
||||
|
||||
Args:
|
||||
v (str): Raw input value for the permission field
|
||||
|
||||
Returns:
|
||||
Lowercase string if input is string type, otherwise returns original value
|
||||
|
||||
Behavior:
|
||||
- Converts string inputs to lowercase (e.g., "ME" → "me")
|
||||
- Non-string values pass through unchanged
|
||||
- Works in validation pre-processing stage (before enum conversion)
|
||||
"""
|
||||
return v.lower() if isinstance(v, str) else v
|
||||
|
||||
@field_validator("parser_config", mode="after")
|
||||
@classmethod
|
||||
def validate_parser_config_json_length(cls, v: Optional[ParserConfig]) -> Optional[ParserConfig]:
|
||||
if v is not None:
|
||||
json_str = v.model_dump_json()
|
||||
if len(json_str) > 65535:
|
||||
raise ValueError("Parser config have at most 65535 characters")
|
||||
def validate_parser_config_json_length(cls, v: ParserConfig | None) -> ParserConfig | None:
|
||||
"""Validates serialized JSON length constraints for parser configuration.
|
||||
|
||||
Implements a three-stage validation workflow:
|
||||
1. Null check - bypass validation for empty configurations
|
||||
2. Model serialization - convert Pydantic model to JSON string
|
||||
3. Size verification - enforce maximum allowed payload size
|
||||
|
||||
Args:
|
||||
v (ParserConfig | None): Raw parser configuration object
|
||||
|
||||
Returns:
|
||||
ParserConfig | None: Validated configuration object
|
||||
|
||||
Raises:
|
||||
ValueError: When serialized JSON exceeds 65,535 characters
|
||||
"""
|
||||
if v is None:
|
||||
return v
|
||||
|
||||
if (json_str := v.model_dump_json()) and len(json_str) > 65535:
|
||||
raise ValueError(f"Parser config exceeds size limit (max 65,535 characters). Current size: {len(json_str):,}")
|
||||
return v
|
||||
|
||||
Reference in New Issue
Block a user