mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Refa: validation utils to use Pydantic v2 style models (#9037)
### What problem does this PR solve? - Update BaseModel to use model_config instead of Config class - Replace StrEnum with Literal types for method fields - Convert Field declarations to Annotated style ### Type of change - [x] Refactoring
This commit is contained in:
@ -14,14 +14,19 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
from collections import Counter
|
||||
from enum import auto
|
||||
from typing import Annotated, Any
|
||||
from typing import Annotated, Any, Literal
|
||||
from uuid import UUID
|
||||
|
||||
from flask import Request
|
||||
from pydantic import BaseModel, Field, StringConstraints, ValidationError, field_validator
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
StringConstraints,
|
||||
ValidationError,
|
||||
field_validator,
|
||||
)
|
||||
from pydantic_core import PydanticCustomError
|
||||
from strenum import StrEnum
|
||||
from werkzeug.exceptions import BadRequest, UnsupportedMediaType
|
||||
|
||||
from api.constants import DATASET_NAME_LIMIT
|
||||
@ -307,38 +312,12 @@ def validate_uuid1_hex(v: Any) -> str:
|
||||
raise PydanticCustomError("invalid_UUID1_format", "Invalid UUID1 format")
|
||||
|
||||
|
||||
class PermissionEnum(StrEnum):
|
||||
me = auto()
|
||||
team = auto()
|
||||
|
||||
|
||||
class ChunkMethodEnum(StrEnum):
|
||||
naive = auto()
|
||||
book = auto()
|
||||
email = auto()
|
||||
laws = auto()
|
||||
manual = auto()
|
||||
one = auto()
|
||||
paper = auto()
|
||||
picture = auto()
|
||||
presentation = auto()
|
||||
qa = auto()
|
||||
table = auto()
|
||||
tag = auto()
|
||||
|
||||
|
||||
class GraphragMethodEnum(StrEnum):
|
||||
light = auto()
|
||||
general = auto()
|
||||
|
||||
|
||||
class Base(BaseModel):
|
||||
class Config:
|
||||
extra = "forbid"
|
||||
model_config = ConfigDict(extra="forbid", strict=True)
|
||||
|
||||
|
||||
class RaptorConfig(Base):
|
||||
use_raptor: bool = Field(default=False)
|
||||
use_raptor: Annotated[bool, Field(default=False)]
|
||||
prompt: Annotated[
|
||||
str,
|
||||
StringConstraints(strip_whitespace=True, min_length=1),
|
||||
@ -346,46 +325,49 @@ class RaptorConfig(Base):
|
||||
default="Please summarize the following paragraphs. Be careful with the numbers, do not make things up. Paragraphs as following:\n {cluster_content}\nThe above is the content you need to summarize."
|
||||
),
|
||||
]
|
||||
max_token: int = Field(default=256, ge=1, le=2048)
|
||||
threshold: float = Field(default=0.1, ge=0.0, le=1.0)
|
||||
max_cluster: int = Field(default=64, ge=1, le=1024)
|
||||
random_seed: int = Field(default=0, ge=0)
|
||||
max_token: Annotated[int, Field(default=256, ge=1, le=2048)]
|
||||
threshold: Annotated[float, Field(default=0.1, ge=0.0, le=1.0)]
|
||||
max_cluster: Annotated[int, Field(default=64, ge=1, le=1024)]
|
||||
random_seed: Annotated[int, Field(default=0, ge=0)]
|
||||
|
||||
|
||||
class GraphragConfig(Base):
|
||||
use_graphrag: bool = Field(default=False)
|
||||
entity_types: list[str] = Field(default_factory=lambda: ["organization", "person", "geo", "event", "category"])
|
||||
method: GraphragMethodEnum = Field(default=GraphragMethodEnum.light)
|
||||
community: bool = Field(default=False)
|
||||
resolution: bool = Field(default=False)
|
||||
use_graphrag: Annotated[bool, Field(default=False)]
|
||||
entity_types: Annotated[list[str], Field(default_factory=lambda: ["organization", "person", "geo", "event", "category"])]
|
||||
method: Annotated[Literal["light", "general"], Field(default="light")]
|
||||
community: Annotated[bool, Field(default=False)]
|
||||
resolution: Annotated[bool, Field(default=False)]
|
||||
|
||||
|
||||
class ParserConfig(Base):
|
||||
auto_keywords: int = Field(default=0, ge=0, le=32)
|
||||
auto_questions: int = Field(default=0, ge=0, le=10)
|
||||
chunk_token_num: int = Field(default=512, ge=1, le=2048)
|
||||
delimiter: str = Field(default=r"\n", min_length=1)
|
||||
graphrag: GraphragConfig = Field(default_factory=lambda: GraphragConfig(use_graphrag=False))
|
||||
html4excel: bool = False
|
||||
layout_recognize: str = "DeepDOC"
|
||||
raptor: RaptorConfig = Field(default_factory=lambda: RaptorConfig(use_raptor=False))
|
||||
tag_kb_ids: list[str] = Field(default_factory=list)
|
||||
topn_tags: int = Field(default=1, ge=1, le=10)
|
||||
filename_embd_weight: float | None = Field(default=0.1, ge=0.0, le=1.0)
|
||||
task_page_size: int | None = Field(default=None, ge=1)
|
||||
pages: list[list[int]] | None = None
|
||||
auto_keywords: Annotated[int, Field(default=0, ge=0, le=32)]
|
||||
auto_questions: Annotated[int, Field(default=0, ge=0, le=10)]
|
||||
chunk_token_num: Annotated[int, Field(default=512, ge=1, le=2048)]
|
||||
delimiter: Annotated[str, Field(default=r"\n", min_length=1)]
|
||||
graphrag: Annotated[GraphragConfig, Field(default_factory=lambda: GraphragConfig(use_graphrag=False))]
|
||||
html4excel: Annotated[bool, Field(default=False)]
|
||||
layout_recognize: Annotated[str, Field(default="DeepDOC")]
|
||||
raptor: Annotated[RaptorConfig, Field(default_factory=lambda: RaptorConfig(use_raptor=False))]
|
||||
tag_kb_ids: Annotated[list[str], Field(default_factory=list)]
|
||||
topn_tags: Annotated[int, Field(default=1, ge=1, le=10)]
|
||||
filename_embd_weight: Annotated[float | None, Field(default=0.1, ge=0.0, le=1.0)]
|
||||
task_page_size: Annotated[int | None, Field(default=None, ge=1)]
|
||||
pages: Annotated[list[list[int]] | None, Field(default=None)]
|
||||
|
||||
|
||||
class CreateDatasetReq(Base):
|
||||
name: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1, max_length=DATASET_NAME_LIMIT), Field(...)]
|
||||
avatar: str | None = Field(default=None, max_length=65535)
|
||||
description: str | None = Field(default=None, max_length=65535)
|
||||
embedding_model: str | None = Field(default=None, max_length=255, serialization_alias="embd_id")
|
||||
permission: PermissionEnum = Field(default=PermissionEnum.me, min_length=1, max_length=16)
|
||||
chunk_method: ChunkMethodEnum = Field(default=ChunkMethodEnum.naive, min_length=1, max_length=32, serialization_alias="parser_id")
|
||||
parser_config: ParserConfig | None = Field(default=None)
|
||||
avatar: Annotated[str | None, Field(default=None, max_length=65535)]
|
||||
description: Annotated[str | None, Field(default=None, max_length=65535)]
|
||||
embedding_model: Annotated[str | None, Field(default=None, max_length=255, serialization_alias="embd_id")]
|
||||
permission: Annotated[Literal["me", "team"], Field(default="me", min_length=1, max_length=16)]
|
||||
chunk_method: Annotated[
|
||||
Literal["naive", "book", "email", "laws", "manual", "one", "paper", "picture", "presentation", "qa", "table", "tag"],
|
||||
Field(default="naive", min_length=1, max_length=32, serialization_alias="parser_id"),
|
||||
]
|
||||
parser_config: Annotated[ParserConfig | None, Field(default=None)]
|
||||
|
||||
@field_validator("avatar")
|
||||
@field_validator("avatar", mode="after")
|
||||
@classmethod
|
||||
def validate_avatar_base64(cls, v: str | None) -> str | None:
|
||||
"""
|
||||
@ -438,6 +420,7 @@ class CreateDatasetReq(Base):
|
||||
@field_validator("embedding_model", mode="before")
|
||||
@classmethod
|
||||
def normalize_embedding_model(cls, v: Any) -> Any:
|
||||
"""Normalize embedding model string by stripping whitespace"""
|
||||
if isinstance(v, str):
|
||||
return v.strip()
|
||||
return v
|
||||
@ -484,10 +467,10 @@ class CreateDatasetReq(Base):
|
||||
raise PydanticCustomError("format_invalid", "Model name and provider cannot be whitespace-only strings")
|
||||
return v
|
||||
|
||||
@field_validator("permission", mode="before")
|
||||
@classmethod
|
||||
def normalize_permission(cls, v: Any) -> Any:
|
||||
return normalize_str(v)
|
||||
# @field_validator("permission", mode="before")
|
||||
# @classmethod
|
||||
# def normalize_permission(cls, v: Any) -> Any:
|
||||
# return normalize_str(v)
|
||||
|
||||
@field_validator("parser_config", mode="before")
|
||||
@classmethod
|
||||
@ -544,9 +527,9 @@ class CreateDatasetReq(Base):
|
||||
|
||||
|
||||
class UpdateDatasetReq(CreateDatasetReq):
|
||||
dataset_id: str = Field(...)
|
||||
dataset_id: Annotated[str, Field(...)]
|
||||
name: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1, max_length=DATASET_NAME_LIMIT), Field(default="")]
|
||||
pagerank: int = Field(default=0, ge=0, le=100)
|
||||
pagerank: Annotated[int, Field(default=0, ge=0, le=100)]
|
||||
|
||||
@field_validator("dataset_id", mode="before")
|
||||
@classmethod
|
||||
@ -555,7 +538,7 @@ class UpdateDatasetReq(CreateDatasetReq):
|
||||
|
||||
|
||||
class DeleteReq(Base):
|
||||
ids: list[str] | None = Field(...)
|
||||
ids: Annotated[list[str] | None, Field(...)]
|
||||
|
||||
@field_validator("ids", mode="after")
|
||||
@classmethod
|
||||
@ -634,28 +617,20 @@ class DeleteReq(Base):
|
||||
class DeleteDatasetReq(DeleteReq): ...
|
||||
|
||||
|
||||
class OrderByEnum(StrEnum):
|
||||
create_time = auto()
|
||||
update_time = auto()
|
||||
class BaseListReq(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class BaseListReq(Base):
|
||||
id: str | None = None
|
||||
name: str | None = None
|
||||
page: int = Field(default=1, ge=1)
|
||||
page_size: int = Field(default=30, ge=1)
|
||||
orderby: OrderByEnum = Field(default=OrderByEnum.create_time)
|
||||
desc: bool = Field(default=True)
|
||||
id: Annotated[str | None, Field(default=None)]
|
||||
name: Annotated[str | None, Field(default=None)]
|
||||
page: Annotated[int, Field(default=1, ge=1)]
|
||||
page_size: Annotated[int, Field(default=30, ge=1)]
|
||||
orderby: Annotated[Literal["create_time", "update_time"], Field(default="create_time")]
|
||||
desc: Annotated[bool, Field(default=True)]
|
||||
|
||||
@field_validator("id", mode="before")
|
||||
@classmethod
|
||||
def validate_id(cls, v: Any) -> str:
|
||||
return validate_uuid1_hex(v)
|
||||
|
||||
@field_validator("orderby", mode="before")
|
||||
@classmethod
|
||||
def normalize_orderby(cls, v: Any) -> Any:
|
||||
return normalize_str(v)
|
||||
|
||||
|
||||
class ListDatasetReq(BaseListReq): ...
|
||||
|
||||
Reference in New Issue
Block a user