Refa: http API create dataset and test cases (#7393)

### What problem does this PR solve?

This PR introduces Pydantic-based validation for the create dataset HTTP
API, improving code clarity and robustness. Key changes include:
1. Pydantic Validation: request payloads are now parsed into a `CreateDatasetReq` model instead of ad-hoc field checks (see the sketch after this list).
2. Error Handling: validation failures return structured argument errors, and database calls are wrapped to catch `OperationalError`.
3. Test Updates: the HTTP API test cases are updated to match the new validation behavior.
4. Documentation: the Swagger docstring now documents `avatar` and `pagerank` and lists the `chunk_method` enum in sorted order.
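
A minimal sketch of the new validation flow, mirroring what `create()` now does (the payload values here are made up; `CreateDatasetReq` and `format_validation_error_message` come from `api/utils/validation_utils.py` added in this PR):

```python
from pydantic import ValidationError

from api.utils.validation_utils import CreateDatasetReq, format_validation_error_message

payload = {"name": "kb1", "chunk_method": "naive", "pagerank": 0}
try:
    req = CreateDatasetReq(**payload)
    # serialization aliases rename fields on dump:
    # embedding_model -> embd_id, chunk_method -> parser_id
    print(req.model_dump(by_alias=True))
except ValidationError as e:
    # one "Field: <...> - Message: <...> - Value: <...>" line per error
    print(format_validation_error_message(e))
```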

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] Documentation Update
- [x] Refactoring
liu an
2025-04-29 16:53:57 +08:00
committed by GitHub
parent c88e4b3fc0
commit 78380fa181
11 changed files with 1239 additions and 812 deletions

View File

@@ -14,24 +14,35 @@
# limitations under the License.
#
import logging

from flask import request
from peewee import OperationalError
from pydantic import ValidationError

from api import settings
from api.db import FileSource, StatusEnum
from api.db.db_models import File
from api.db.services.document_service import DocumentService
from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMService, TenantLLMService
from api.db.services.user_service import TenantService
from api.utils import get_uuid
from api.utils.api_utils import (
    check_duplicate_ids,
    dataset_readonly_fields,
    get_error_argument_result,
    get_error_data_result,
    get_parser_config,
    get_result,
    token_required,
    valid,
    valid_parser_config,
)
from api.utils.validation_utils import CreateDatasetReq, format_validation_error_message
@manager.route("/datasets", methods=["POST"]) # noqa: F821
@@ -62,16 +73,28 @@ def create(tenant_id):
name:
type: string
description: Name of the dataset.
avatar:
type: string
description: Base64 encoding of the avatar.
description:
type: string
description: Description of the dataset.
embedding_model:
type: string
description: Embedding model Name.
permission:
type: string
enum: ['me', 'team']
description: Dataset permission.
chunk_method:
type: string
enum: ["naive", "manual", "qa", "table", "paper", "book", "laws",
"presentation", "picture", "one", "email", "tag"
enum: ["naive", "book", "email", "laws", "manual", "one", "paper",
"picture", "presentation", "qa", "table", "tag"
]
description: Chunking method.
pagerank:
type: integer
description: Set page rank.
parser_config:
type: object
description: Parser configuration.
@@ -84,106 +107,87 @@ def create(tenant_id):
data:
type: object
"""
req_i = request.json
if not isinstance(req_i, dict):
return get_error_argument_result(f"Invalid request payload: expected object, got {type(req_i).__name__}")
try:
req_v = CreateDatasetReq(**req_i)
except ValidationError as e:
return get_error_argument_result(format_validation_error_message(e))
# Field name transformations during model dump:
# | Original | Dump Output |
# |----------------|-------------|
# | embedding_model| embd_id |
# | chunk_method | parser_id |
req = req_v.model_dump(by_alias=True)
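    # e.g. {"name": "kb1", "embd_id": None, "parser_id": "naive", "pagerank": 0, "parser_config": None, ...} (illustrative)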
try:
if KnowledgebaseService.query(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value):
return get_error_argument_result(message=f"Dataset name '{req['name']}' already exists")
except OperationalError as e:
logging.exception(e)
return get_error_data_result(message="Database operation failed")
req["parser_config"] = get_parser_config(req["parser_id"], req["parser_config"])
req["id"] = get_uuid()
req["name"] = req["name"].strip()
if req["name"] == "":
return get_error_data_result(message="`name` is not empty string!")
if len(req["name"]) >= 128:
return get_error_data_result(
message="Dataset name should not be longer than 128 characters."
)
if KnowledgebaseService.query(
name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value
):
return get_error_data_result(
message="Duplicated dataset name in creating dataset."
)
req["tenant_id"] = tenant_id
req["created_by"] = tenant_id
if not req.get("embedding_model"):
req["embedding_model"] = t.embd_id
try:
ok, t = TenantService.get_by_id(tenant_id)
if not ok:
return get_error_data_result(message="Tenant not found")
except OperationalError as e:
logging.exception(e)
return get_error_data_result(message="Database operation failed")
if not req.get("embd_id"):
req["embd_id"] = t.embd_id
else:
builtin_embedding_models = [
"BAAI/bge-large-zh-v1.5@BAAI",
"maidalun1020/bce-embedding-base_v1@Youdao",
]
is_builtin_model = req["embd_id"] in builtin_embedding_models
try:
# model name must be model_name@model_factory
llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(req["embd_id"])
is_tenant_model = TenantLLMService.query(tenant_id=tenant_id, llm_name=llm_name, llm_factory=llm_factory, model_type="embedding")
is_supported_model = LLMService.query(llm_name=llm_name, fid=llm_factory, model_type="embedding")
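            # a model passes only if LLMService knows it AND it is either a builtin default or configured for this tenant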
if not (is_supported_model and (is_builtin_model or is_tenant_model)):
return get_error_argument_result(f"The embedding_model '{req['embd_id']}' is not supported")
except OperationalError as e:
logging.exception(e)
return get_error_data_result(message="Database operation failed")
try:
if not KnowledgebaseService.save(**req):
return get_error_data_result(message="Database operation failed")
except OperationalError as e:
logging.exception(e)
return get_error_data_result(message="Database operation failed")
try:
ok, k = KnowledgebaseService.get_by_id(req["id"])
if not ok:
return get_error_data_result(message="Dataset created failed")
except OperationalError as e:
logging.exception(e)
return get_error_data_result(message="Database operation failed")
response_data = {}
key_mapping = {
"chunk_num": "chunk_count",
"doc_num": "document_count",
"parser_id": "chunk_method",
"embd_id": "embedding_model",
}
for key, value in k.to_dict().items():
new_key = key_mapping.get(key, key)
response_data[new_key] = value
return get_result(data=response_data)
@manager.route("/datasets", methods=["DELETE"]) # noqa: F821
@@ -254,29 +258,28 @@ def delete(tenant_id):
]
)
File2DocumentService.delete_by_document_id(doc.id)
FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.type == "folder", File.name == kbs[0].name])
if not KnowledgebaseService.delete_by_id(id):
errors.append(f"Delete dataset error for {id}")
continue
success_count += 1
if errors:
if success_count > 0:
return get_result(data={"success_count": success_count, "errors": errors}, message=f"Partially deleted {success_count} datasets with {len(errors)} errors")
else:
return get_error_data_result(message="; ".join(errors))
if duplicate_messages:
if success_count > 0:
return get_result(
message=f"Partially deleted {success_count} datasets with {len(duplicate_messages)} errors",
data={"success_count": success_count, "errors": duplicate_messages},
)
else:
return get_error_data_result(message=";".join(duplicate_messages))
return get_result(code=settings.RetCode.SUCCESS)
@manager.route("/datasets/<dataset_id>", methods=["PUT"]) # noqa: F821
@manager.route("/datasets/<dataset_id>", methods=["PUT"]) # noqa: F821
@token_required
def update(tenant_id, dataset_id):
"""
@@ -333,7 +336,7 @@ def update(tenant_id, dataset_id):
if dataset_readonly_fields(k):
return get_result(code=settings.RetCode.ARGUMENT_ERROR, message=f"'{k}' is readonly.")
e, t = TenantService.get_by_id(tenant_id)
invalid_keys = {"id", "embd_id", "chunk_num", "doc_num", "parser_id", "create_date", "create_time", "created_by", "status","token_num","update_date","update_time"}
invalid_keys = {"id", "embd_id", "chunk_num", "doc_num", "parser_id", "create_date", "create_time", "created_by", "status", "token_num", "update_date", "update_time"}
if any(key in req for key in invalid_keys):
return get_error_data_result(message="The input parameters are invalid.")
permission = req.get("permission")
@@ -341,20 +344,7 @@ def update(tenant_id, dataset_id):
parser_config = req.get("parser_config")
valid_parser_config(parser_config)
valid_permission = ["me", "team"]
valid_chunk_method = ["naive", "manual", "qa", "table", "paper", "book", "laws", "presentation", "picture", "one", "email", "tag"]
check_validation = valid(
permission,
valid_permission,
@@ -381,18 +371,14 @@ def update(tenant_id, dataset_id):
req.pop("document_count")
if req.get("chunk_method"):
if kb.chunk_num != 0 and req["chunk_method"] != kb.parser_id:
return get_error_data_result(message="If `chunk_count` is not 0, `chunk_method` is not changeable.")
req["parser_id"] = req.pop("chunk_method")
if req["parser_id"] != kb.parser_id:
if not req.get("parser_config"):
req["parser_config"] = get_parser_config(chunk_method, parser_config)
if "embedding_model" in req:
if kb.chunk_num != 0 and req["embedding_model"] != kb.embd_id:
return get_error_data_result(message="If `chunk_count` is not 0, `embedding_model` is not changeable.")
if not req.get("embedding_model"):
return get_error_data_result("`embedding_model` can't be empty")
valid_embedding_models = [
@@ -409,38 +395,26 @@ def update(tenant_id, dataset_id):
"text-embedding-v3",
"maidalun1020/bce-embedding-base_v1",
]
embd_model = LLMService.query(llm_name=req["embedding_model"], model_type="embedding")
if embd_model:
if req["embedding_model"] not in valid_embedding_models and not TenantLLMService.query(tenant_id=tenant_id,model_type="embedding",llm_name=req.get("embedding_model"),):
if req["embedding_model"] not in valid_embedding_models and not TenantLLMService.query(
tenant_id=tenant_id,
model_type="embedding",
llm_name=req.get("embedding_model"),
):
return get_error_data_result(f"`embedding_model` {req.get('embedding_model')} doesn't exist")
if not embd_model:
embd_model = TenantLLMService.query(tenant_id=tenant_id, model_type="embedding", llm_name=req.get("embedding_model"))
if not embd_model:
return get_error_data_result(f"`embedding_model` {req.get('embedding_model')} doesn't exist")
req["embd_id"] = req.pop("embedding_model")
if "name" in req:
req["name"] = req["name"].strip()
if len(req["name"]) >= 128:
return get_error_data_result(message="Dataset name should not be longer than 128 characters.")
if req["name"].lower() != kb.name.lower() and len(KnowledgebaseService.query(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value)) > 0:
return get_error_data_result(message="Duplicated dataset name in updating dataset.")
flds = list(req.keys())
for f in flds:
if req[f] == "" and f in ["permission", "parser_id", "chunk_method"]:
@@ -511,11 +485,11 @@ def list_datasets(tenant_id):
id = request.args.get("id")
name = request.args.get("name")
if id:
kbs = KnowledgebaseService.get_kb_by_id(id, tenant_id)
if not kbs:
return get_error_data_result(f"You don't own the dataset {id}")
if name:
kbs = KnowledgebaseService.get_kb_by_name(name, tenant_id)
if not kbs:
return get_error_data_result(f"You don't own the dataset {name}")
page_number = int(request.args.get("page", 1))

View File

@@ -322,6 +322,10 @@ def get_error_data_result(
return jsonify(response)
def get_error_argument_result(message="Invalid arguments"):
return get_result(code=settings.RetCode.ARGUMENT_ERROR, message=message)
def generate_confirmation_token(tenant_id):
serializer = URLSafeTimedSerializer(tenant_id)
return "ragflow-" + serializer.dumps(get_uuid(), salt=tenant_id)[2:34]
@@ -368,46 +372,34 @@ def get_parser_config(chunk_method, parser_config):
return parser_config
def get_data_openai(
id=None,
created=None,
model=None,
prompt_tokens=0,
completion_tokens=0,
content=None,
finish_reason=None,
object="chat.completion",
param=None,
):
total_tokens = prompt_tokens + completion_tokens
return {
"id":f"{id}",
"id": f"{id}",
"object": object,
"created": int(time.time()) if created else None,
"model": model,
"param":param,
"param": param,
"usage": {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": total_tokens,
"completion_tokens_details": {
"reasoning_tokens": 0,
"accepted_prediction_tokens": 0,
"rejected_prediction_tokens": 0
}
"completion_tokens_details": {"reasoning_tokens": 0, "accepted_prediction_tokens": 0, "rejected_prediction_tokens": 0},
},
"choices": [
{
"message": {
"role": "assistant",
"content": content
},
"logprobs": None,
"finish_reason": finish_reason,
"index": 0
}
]
}
"choices": [{"message": {"role": "assistant", "content": content}, "logprobs": None, "finish_reason": finish_reason, "index": 0}],
}
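# Illustrative call: get_data_openai(id="chat-1", model="m", prompt_tokens=10, completion_tokens=5,
# content="hi", finish_reason="stop") returns an OpenAI-style "chat.completion" dict
# with usage.total_tokens == 15.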
def valid_parser_config(parser_config):
if not parser_config:
return

View File

@@ -0,0 +1,162 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from enum import auto
from typing import Annotated, List, Optional
from pydantic import BaseModel, Field, StringConstraints, ValidationError, field_validator
from strenum import StrEnum
def format_validation_error_message(e: ValidationError):
error_messages = []
for error in e.errors():
field = ".".join(map(str, error["loc"]))
msg = error["msg"]
input_val = error["input"]
input_str = str(input_val)
if len(input_str) > 128:
input_str = input_str[:125] + "..."
error_msg = f"Field: <{field}> - Message: <{msg}> - Value: <{input_str}>"
error_messages.append(error_msg)
return "\n".join(error_messages)
class PermissionEnum(StrEnum):
me = auto()
team = auto()
class ChunkMethodnEnum(StrEnum):
naive = auto()
book = auto()
email = auto()
laws = auto()
manual = auto()
one = auto()
paper = auto()
picture = auto()
presentation = auto()
qa = auto()
table = auto()
tag = auto()
class GraphragMethodEnum(StrEnum):
light = auto()
general = auto()
class Base(BaseModel):
class Config:
extra = "forbid"
json_schema_extra = {"charset": "utf8mb4", "collation": "utf8mb4_0900_ai_ci"}
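        # extra="forbid" turns unknown keys (e.g. tenant_id, chunk_count) into validation errors,
        # which replaces the old per-field readonly checks in the route handler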
class RaptorConfig(Base):
use_raptor: bool = Field(default=False)
prompt: Annotated[
str,
StringConstraints(strip_whitespace=True, min_length=1),
Field(
default="Please summarize the following paragraphs. Be careful with the numbers, do not make things up. Paragraphs as following:\n {cluster_content}\nThe above is the content you need to summarize."
),
]
max_token: int = Field(default=256, ge=1, le=2048)
threshold: float = Field(default=0.1, ge=0.0, le=1.0)
max_cluster: int = Field(default=64, ge=1, le=1024)
random_seed: int = Field(default=0, ge=0, le=10_000)
class GraphragConfig(Base):
use_graphrag: bool = Field(default=False)
entity_types: List[str] = Field(default_factory=lambda: ["organization", "person", "geo", "event", "category"])
method: GraphragMethodEnum = Field(default=GraphragMethodEnum.light)
community: bool = Field(default=False)
resolution: bool = Field(default=False)
class ParserConfig(Base):
auto_keywords: int = Field(default=0, ge=0, le=32)
auto_questions: int = Field(default=0, ge=0, le=10)
chunk_token_num: int = Field(default=128, ge=1, le=2048)
delimiter: str = Field(default=r"\n!?;。;!?", min_length=1)
graphrag: Optional[GraphragConfig] = None
html4excel: bool = False
layout_recognize: str = "DeepDOC"
raptor: Optional[RaptorConfig] = None
tag_kb_ids: List[str] = Field(default_factory=list)
topn_tags: int = Field(default=1, ge=1, le=10)
filename_embd_weight: Optional[float] = Field(default=None, ge=0.0, le=1.0)
task_page_size: Optional[int] = Field(default=None, ge=1, le=10_000)
pages: Optional[List[List[int]]] = None
class CreateDatasetReq(Base):
name: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1, max_length=128), Field(...)]
avatar: Optional[str] = Field(default=None, max_length=65535)
description: Optional[str] = Field(default=None, max_length=65535)
embedding_model: Annotated[Optional[str], StringConstraints(strip_whitespace=True, max_length=255), Field(default=None, serialization_alias="embd_id")]
permission: Annotated[PermissionEnum, StringConstraints(strip_whitespace=True, min_length=1, max_length=16), Field(default=PermissionEnum.me)]
chunk_method: Annotated[ChunkMethodnEnum, StringConstraints(strip_whitespace=True, min_length=1, max_length=32), Field(default=ChunkMethodnEnum.naive, serialization_alias="parser_id")]
pagerank: int = Field(default=0, ge=0, le=100)
parser_config: Optional[ParserConfig] = Field(default=None)
@field_validator("avatar")
@classmethod
    def validate_avatar_base64(cls, v: Optional[str]) -> Optional[str]:
if v is None:
return v
if "," in v:
prefix, _ = v.split(",", 1)
if not prefix.startswith("data:"):
raise ValueError("Invalid MIME prefix format. Must start with 'data:'")
mime_type = prefix[5:].split(";")[0]
supported_mime_types = ["image/jpeg", "image/png"]
if mime_type not in supported_mime_types:
raise ValueError(f"Unsupported MIME type. Allowed: {supported_mime_types}")
return v
else:
raise ValueError("Missing MIME prefix. Expected format: data:<mime>;base64,<data>")
@field_validator("embedding_model", mode="after")
@classmethod
def validate_embedding_model(cls, v: str) -> str:
if "@" not in v:
raise ValueError("Embedding model must be xxx@yyy")
return v
@field_validator("permission", mode="before")
@classmethod
def permission_auto_lowercase(cls, v: str) -> str:
if isinstance(v, str):
return v.lower()
return v
@field_validator("parser_config", mode="after")
@classmethod
def validate_parser_config_json_length(cls, v: Optional[ParserConfig]) -> Optional[ParserConfig]:
if v is not None:
json_str = v.model_dump_json()
            if len(json_str) > 65535:
                raise ValueError("Parser config must have at most 65535 characters")
return v
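
A hedged illustration of how these validators surface errors (the payloads are made up; printed messages abbreviated):

```python
from pydantic import ValidationError

from api.utils.validation_utils import CreateDatasetReq

# avatar must carry a data:image/jpeg or data:image/png base64 prefix
try:
    CreateDatasetReq(name="kb1", avatar="data:image/gif;base64,AAAA")
except ValidationError as e:
    print(e)  # Unsupported MIME type. Allowed: ['image/jpeg', 'image/png']

# embedding_model must be qualified as model_name@model_factory
try:
    CreateDatasetReq(name="kb1", embedding_model="bge-m3")
except ValidationError as e:
    print(e)  # embedding model must be in the format 'model_name@model_factory'
```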