mirror of https://github.com/infiniflow/ragflow.git — synced 2025-12-08 20:42:30 +08:00

refa: Optimize create dataset validation (#7451)

### What problem does this PR solve?

Optimize dataset validation and add function docs

### Type of change

- [x] Refactoring

@@ -19,7 +19,6 @@ import logging
 from flask import request
 from peewee import OperationalError
-from pydantic import ValidationError
 
 from api import settings
 from api.db import FileSource, StatusEnum

@@ -41,8 +40,9 @@ from api.utils.api_utils import (
     token_required,
     valid,
     valid_parser_config,
+    verify_embedding_availability,
 )
-from api.utils.validation_utils import CreateDatasetReq, format_validation_error_message
+from api.utils.validation_utils import CreateDatasetReq, validate_and_parse_json_request
 
 
 @manager.route("/datasets", methods=["POST"])  # noqa: F821

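The newly imported `validate_and_parse_json_request` consolidates the body parsing, Pydantic validation, and alias-aware dump that `create()` previously did inline. Below is a minimal sketch of such a helper, inferred from the inline logic removed in the next hunk; the real implementation lives in `api/utils/validation_utils.py` and may differ:

```python
# Hypothetical sketch only, inferred from the inline logic this PR removes;
# not the actual ragflow implementation in api/utils/validation_utils.py.
from typing import Any

from flask import Request
from pydantic import BaseModel, ValidationError


def validate_and_parse_json_request(request: Request, validator: type[BaseModel]) -> tuple[dict[str, Any] | None, str | None]:
    """Parse the JSON body, validate it against `validator`, return (payload, error_message)."""
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        return None, f"Invalid request payload: expected object, got {type(payload).__name__}"

    try:
        validated = validator(**payload)
    except ValidationError as e:
        return None, str(e)  # the real helper likely formats this message more carefully

    # by_alias=True renames request fields to storage fields,
    # e.g. embedding_model -> embd_id, chunk_method -> parser_id
    return validated.model_dump(by_alias=True), None
```

With this shape, the caller only has to check `err` and wrap it with `get_error_argument_result`, which is exactly what the updated `create()` does below.
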
@@ -107,21 +107,14 @@ def create(tenant_id):
                         data:
                             type: object
     """
-    req_i = request.json
-    if not isinstance(req_i, dict):
-        return get_error_argument_result(f"Invalid request payload: expected object, got {type(req_i).__name__}")
-
-    try:
-        req_v = CreateDatasetReq(**req_i)
-    except ValidationError as e:
-        return get_error_argument_result(format_validation_error_message(e))
-
-    # Field name transformations during model dump:
-    # | Original        | Dump Output |
-    # |-----------------|-------------|
-    # | embedding_model | embd_id     |
-    # | chunk_method    | parser_id   |
-    req = req_v.model_dump(by_alias=True)
+    req, err = validate_and_parse_json_request(request, CreateDatasetReq)
+    if err is not None:
+        return get_error_argument_result(err)
 
     try:
         if KnowledgebaseService.query(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value):

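The removed comment table was the only in-handler documentation of the alias mapping; after the refactor, `create()` receives a dict already keyed by the storage names. A small illustration of how that mapping works with Pydantic, using a made-up, trimmed stand-in for `CreateDatasetReq` (the real model defines more fields, defaults, and validators):

```python
# CreateDatasetReqSketch is a hypothetical, trimmed stand-in for CreateDatasetReq,
# only to illustrate the alias mapping from the removed comments
# (embedding_model -> embd_id, chunk_method -> parser_id).
from pydantic import BaseModel, Field


class CreateDatasetReqSketch(BaseModel):
    name: str
    embedding_model: str | None = Field(default=None, serialization_alias="embd_id")
    chunk_method: str = Field(default="naive", serialization_alias="parser_id")


req = CreateDatasetReqSketch(name="kb1", embedding_model="BAAI/bge-large-zh-v1.5@BAAI").model_dump(by_alias=True)
# {'name': 'kb1', 'embd_id': 'BAAI/bge-large-zh-v1.5@BAAI', 'parser_id': 'naive'}
```
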
@@ -146,21 +139,9 @@ def create(tenant_id):
     if not req.get("embd_id"):
         req["embd_id"] = t.embd_id
     else:
-        builtin_embedding_models = [
-            "BAAI/bge-large-zh-v1.5@BAAI",
-            "maidalun1020/bce-embedding-base_v1@Youdao",
-        ]
-        is_builtin_model = req["embd_id"] in builtin_embedding_models
-        try:
-            # model name must be model_name@model_factory
-            llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(req["embd_id"])
-            is_tenant_model = TenantLLMService.query(tenant_id=tenant_id, llm_name=llm_name, llm_factory=llm_factory, model_type="embedding")
-            is_supported_model = LLMService.query(llm_name=llm_name, fid=llm_factory, model_type="embedding")
-            if not (is_supported_model and (is_builtin_model or is_tenant_model)):
-                return get_error_argument_result(f"The embedding_model '{req['embd_id']}' is not supported")
-        except OperationalError as e:
-            logging.exception(e)
-            return get_error_data_result(message="Database operation failed")
+        ok, err = verify_embedding_availability(req["embd_id"], tenant_id)
+        if not ok:
+            return err
 
     try:
         if not KnowledgebaseService.save(**req):

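The builtin-list, tenant-model, and supported-model checks, together with the `OperationalError` handling, now sit behind `verify_embedding_availability`, imported from `api.utils.api_utils` in the earlier hunk. A sketch assembled from the removed lines, assuming a `(ok, error_response)` return shape and that the service classes and error helpers used in the surrounding module are in scope; the actual helper may differ in detail:

```python
# Hypothetical reconstruction from the removed inline logic; the real
# verify_embedding_availability lives in api.utils.api_utils and may differ.
# Assumes TenantLLMService, LLMService, get_error_argument_result, and
# get_error_data_result are importable as in the surrounding module.
import logging

from peewee import OperationalError


def verify_embedding_availability(embd_id: str, tenant_id: str):
    """Return (True, None) if the tenant can use embd_id, else (False, error_response)."""
    builtin_embedding_models = [
        "BAAI/bge-large-zh-v1.5@BAAI",
        "maidalun1020/bce-embedding-base_v1@Youdao",
    ]
    is_builtin_model = embd_id in builtin_embedding_models
    try:
        # model name must be model_name@model_factory
        llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(embd_id)
        is_tenant_model = TenantLLMService.query(tenant_id=tenant_id, llm_name=llm_name, llm_factory=llm_factory, model_type="embedding")
        is_supported_model = LLMService.query(llm_name=llm_name, fid=llm_factory, model_type="embedding")
        if not (is_supported_model and (is_builtin_model or is_tenant_model)):
            return False, get_error_argument_result(f"The embedding_model '{embd_id}' is not supported")
    except OperationalError as e:
        logging.exception(e)
        return False, get_error_data_result(message="Database operation failed")
    return True, None
```

Keeping the check behind one helper means other endpoints that accept an `embedding_model` can reuse the same policy instead of duplicating the inline logic.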