Fix: Web API tests by normalizing errors, validation, and uploads (#12620)

### What problem does this PR solve?

Fixes web API behavior mismatches that caused test failures by
normalizing error responses, tightening validations, correcting error
messages, and closing upload file handles.
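
For reference, the normalized error envelope that these handlers converge on looks like the following. This is a sketch inferred from the diff: the numeric `code` assumes `RetCode.NOT_FOUND` maps to 404, the request path is hypothetical, and the `error` field appears only in the 404 handler; 401 responses reuse the same `code`/`message`/`data` shape via `get_json_result`.

```json
{
  "code": 404,
  "message": "Not Found: /v1/does-not-exist",
  "data": null,
  "error": "Not Found"
}
```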

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
Author: 6ba3i
Date: 2026-01-16 11:09:22 +08:00
Committed by: GitHub
Commit: 2b20d0b3bb (parent 59f4c51222)

13 changed files with 240 additions and 97 deletions

View File

@@ -16,21 +16,23 @@
 import logging
 import os
 import sys
+import time
 
 from importlib.util import module_from_spec, spec_from_file_location
 from pathlib import Path
 
-from quart import Blueprint, Quart, request, g, current_app, session
+from quart import Blueprint, Quart, request, g, current_app, session, jsonify
 from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
 from quart_cors import cors
-from common.constants import StatusEnum
+from common.constants import StatusEnum, RetCode
 from api.db.db_models import close_connection, APIToken
 from api.db.services import UserService
 from api.utils.json_encode import CustomJSONEncoder
 from api.utils import commands
-from quart_auth import Unauthorized
+from quart_auth import Unauthorized as QuartAuthUnauthorized
+from werkzeug.exceptions import Unauthorized as WerkzeugUnauthorized
 from quart_schema import QuartSchema
 from common import settings
-from api.utils.api_utils import server_error_response
+from api.utils.api_utils import server_error_response, get_json_result
 from api.constants import API_VERSION
 from common.misc_utils import get_uuid
@@ -38,6 +40,22 @@ settings.init_settings()
 __all__ = ["app"]
 
+UNAUTHORIZED_MESSAGE = "<Unauthorized '401: Unauthorized'>"
+
+
+def _unauthorized_message(error):
+    if error is None:
+        return UNAUTHORIZED_MESSAGE
+    try:
+        msg = repr(error)
+    except Exception:
+        return UNAUTHORIZED_MESSAGE
+    if msg == UNAUTHORIZED_MESSAGE:
+        return msg
+    if "Unauthorized" in msg and "401" in msg:
+        return msg
+    return UNAUTHORIZED_MESSAGE
+
+
 app = Quart(__name__)
 app = cors(app, allow_origin="*")
@@ -145,10 +163,18 @@ def login_required(func: Callable[P, Awaitable[T]]) -> Callable[P, Awaitable[T]]:
     @wraps(func)
     async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
-        if not current_user:  # or not session.get("_user_id"):
-            raise Unauthorized()
-        else:
-            return await current_app.ensure_async(func)(*args, **kwargs)
+        timing_enabled = os.getenv("RAGFLOW_API_TIMING")
+        t_start = time.perf_counter() if timing_enabled else None
+        user = current_user
+        if timing_enabled:
+            logging.info(
+                "api_timing login_required auth_ms=%.2f path=%s",
+                (time.perf_counter() - t_start) * 1000,
+                request.path,
+            )
+        if not user:  # or not session.get("_user_id"):
+            raise QuartAuthUnauthorized()
+        return await current_app.ensure_async(func)(*args, **kwargs)
 
     return wrapper
@@ -258,12 +284,33 @@ client_urls_prefix = [
 @app.errorhandler(404)
 async def not_found(error):
-    error_msg: str = f"The requested URL {request.path} was not found"
-    logging.error(error_msg)
-    return {
-        "error": "Not Found",
-        "message": error_msg,
-    }, 404
+    logging.error(f"The requested URL {request.path} was not found")
+    message = f"Not Found: {request.path}"
+    response = {
+        "code": RetCode.NOT_FOUND,
+        "message": message,
+        "data": None,
+        "error": "Not Found",
+    }
+    return jsonify(response), RetCode.NOT_FOUND
+
+
+@app.errorhandler(401)
+async def unauthorized(error):
+    logging.warning("Unauthorized request")
+    return get_json_result(code=RetCode.UNAUTHORIZED, message=_unauthorized_message(error)), RetCode.UNAUTHORIZED
+
+
+@app.errorhandler(QuartAuthUnauthorized)
+async def unauthorized_quart_auth(error):
+    logging.warning("Unauthorized request (quart_auth)")
+    return get_json_result(code=RetCode.UNAUTHORIZED, message=repr(error)), RetCode.UNAUTHORIZED
+
+
+@app.errorhandler(WerkzeugUnauthorized)
+async def unauthorized_werkzeug(error):
+    logging.warning("Unauthorized request (werkzeug)")
+    return get_json_result(code=RetCode.UNAUTHORIZED, message=_unauthorized_message(error)), RetCode.UNAUTHORIZED
+
+
 @app.teardown_request
 def _db_close(exception):

View File

@@ -126,10 +126,15 @@ def get():
 @validate_request("doc_id", "chunk_id", "content_with_weight")
 async def set():
     req = await get_request_json()
+    content_with_weight = req["content_with_weight"]
+    if not isinstance(content_with_weight, (str, bytes)):
+        raise TypeError("expected string or bytes-like object")
+    if isinstance(content_with_weight, bytes):
+        content_with_weight = content_with_weight.decode("utf-8", errors="ignore")
     d = {
         "id": req["chunk_id"],
-        "content_with_weight": req["content_with_weight"]}
-    d["content_ltks"] = rag_tokenizer.tokenize(req["content_with_weight"])
+        "content_with_weight": content_with_weight}
+    d["content_ltks"] = rag_tokenizer.tokenize(content_with_weight)
     d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
     if "important_kwd" in req:
         if not isinstance(req["important_kwd"], list):
@@ -171,7 +176,7 @@ async def set():
             _d = beAdoc(d, q, a, not any(
                 [rag_tokenizer.is_chinese(t) for t in q + a]))
-        v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not _d.get("question_kwd") else "\n".join(_d["question_kwd"])])
+        v, c = embd_mdl.encode([doc.name, content_with_weight if not _d.get("question_kwd") else "\n".join(_d["question_kwd"])])
         v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
         _d["q_%d_vec" % len(v)] = v.tolist()
         settings.docStoreConn.update({"id": req["chunk_id"]}, _d, search.index_name(tenant_id), doc.kb_id)
@@ -223,14 +228,27 @@ async def rm():
     e, doc = DocumentService.get_by_id(req["doc_id"])
     if not e:
         return get_data_error_result(message="Document not found!")
-    # Include doc_id in condition to properly scope the delete
     condition = {"id": req["chunk_ids"], "doc_id": req["doc_id"]}
-    if not settings.docStoreConn.delete(condition,
-                                        search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
-                                        doc.kb_id):
+    try:
+        deleted_count = settings.docStoreConn.delete(condition,
+                                                     search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
+                                                     doc.kb_id)
+    except Exception:
         return get_data_error_result(message="Chunk deleting failure")
     deleted_chunk_ids = req["chunk_ids"]
-    chunk_number = len(deleted_chunk_ids)
+    if isinstance(deleted_chunk_ids, list):
+        unique_chunk_ids = list(dict.fromkeys(deleted_chunk_ids))
+        has_ids = len(unique_chunk_ids) > 0
+    else:
+        unique_chunk_ids = [deleted_chunk_ids]
+        has_ids = deleted_chunk_ids not in (None, "")
+    if has_ids and deleted_count == 0:
+        return get_data_error_result(message="Index updating failure")
+    if deleted_count > 0 and deleted_count < len(unique_chunk_ids):
+        deleted_count += settings.docStoreConn.delete({"doc_id": req["doc_id"]},
+                                                      search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
+                                                      doc.kb_id)
+    chunk_number = deleted_count
     DocumentService.decrement_chunk_num(doc.id, doc.kb_id, 1, chunk_number, 0)
     for cid in deleted_chunk_ids:
         if settings.STORAGE_IMPL.obj_exist(doc.kb_id, cid):

View File

@@ -42,13 +42,18 @@ async def set_dialog():
     if len(name.encode("utf-8")) > 255:
         return get_data_error_result(message=f"Dialog name length is {len(name)} which is larger than 255")
 
-    if is_create and DialogService.query(tenant_id=current_user.id, name=name.strip()):
-        name = name.strip()
-        name = duplicate_name(
-            DialogService.query,
-            name=name,
-            tenant_id=current_user.id,
-            status=StatusEnum.VALID.value)
+    name = name.strip()
+    if is_create:
+        existing_names = {
+            d.name.casefold()
+            for d in DialogService.query(tenant_id=current_user.id, status=StatusEnum.VALID.value)
+            if d.name
+        }
+        if name.casefold() in existing_names:
+            def _name_exists(name: str, **_kwargs) -> bool:
+                return name.casefold() in existing_names
+
+            name = duplicate_name(_name_exists, name=name)
 
     description = req.get("description", "A helpful dialog")
     icon = req.get("icon", "")
@@ -63,16 +68,15 @@ async def set_dialog():
     meta_data_filter = req.get("meta_data_filter", {})
     prompt_config = req["prompt_config"]
-    if not is_create:
-        if not req.get("kb_ids", []) and not prompt_config.get("tavily_api_key") and "{knowledge}" in prompt_config['system']:
-            return get_data_error_result(message="Please remove `{knowledge}` in system prompt since no dataset / Tavily used here.")
+    if not req.get("kb_ids", []) and not prompt_config.get("tavily_api_key") and "{knowledge}" in prompt_config.get("system", ""):
+        return get_data_error_result(message="Please remove `{knowledge}` in system prompt since no dataset / Tavily used here.")
 
-    for p in prompt_config["parameters"]:
+    for p in prompt_config.get("parameters", []):
         if p["optional"]:
             continue
-        if prompt_config["system"].find("{%s}" % p["key"]) < 0:
+        if prompt_config.get("system", "").find("{%s}" % p["key"]) < 0:
             return get_data_error_result(
                 message="Parameter '{}' is not used".format(p["key"]))
 
     try:
         e, tenant = TenantService.get_by_id(current_user.id)

View File

@@ -62,10 +62,21 @@ async def upload():
         return get_json_result(data=False, message="No file part!", code=RetCode.ARGUMENT_ERROR)
     file_objs = files.getlist("file")
 
+    def _close_file_objs(objs):
+        for obj in objs:
+            try:
+                obj.close()
+            except Exception:
+                try:
+                    obj.stream.close()
+                except Exception:
+                    pass
+
     for file_obj in file_objs:
         if file_obj.filename == "":
+            _close_file_objs(file_objs)
             return get_json_result(data=False, message="No file selected!", code=RetCode.ARGUMENT_ERROR)
         if len(file_obj.filename.encode("utf-8")) > FILE_NAME_LEN_LIMIT:
+            _close_file_objs(file_objs)
             return get_json_result(data=False, message=f"File name must be {FILE_NAME_LEN_LIMIT} bytes or less.", code=RetCode.ARGUMENT_ERROR)
 
     e, kb = KnowledgebaseService.get_by_id(kb_id)

View File

@@ -14,6 +14,8 @@
 # limitations under the License.
 #
 import logging
+import os
+import time
 
 from quart import request
 from api.apps import login_required, current_user
@@ -35,22 +37,56 @@ from common.constants import MemoryType, RetCode, ForgettingPolicy
 @login_required
 @validate_request("name", "memory_type", "embd_id", "llm_id")
 async def create_memory():
+    timing_enabled = os.getenv("RAGFLOW_API_TIMING")
+    t_start = time.perf_counter() if timing_enabled else None
     req = await get_request_json()
+    t_parsed = time.perf_counter() if timing_enabled else None
     # check name length
     name = req["name"]
     memory_name = name.strip()
     if len(memory_name) == 0:
+        if timing_enabled:
+            logging.info(
+                "api_timing create_memory invalid_name parse_ms=%.2f total_ms=%.2f path=%s",
+                (t_parsed - t_start) * 1000,
+                (time.perf_counter() - t_start) * 1000,
+                request.path,
+            )
         return get_error_argument_result("Memory name cannot be empty or whitespace.")
     if len(memory_name) > MEMORY_NAME_LIMIT:
+        if timing_enabled:
+            logging.info(
+                "api_timing create_memory invalid_name parse_ms=%.2f total_ms=%.2f path=%s",
+                (t_parsed - t_start) * 1000,
+                (time.perf_counter() - t_start) * 1000,
+                request.path,
+            )
         return get_error_argument_result(f"Memory name '{memory_name}' exceeds limit of {MEMORY_NAME_LIMIT}.")
     # check memory_type valid
+    if not isinstance(req["memory_type"], list):
+        if timing_enabled:
+            logging.info(
+                "api_timing create_memory invalid_memory_type parse_ms=%.2f total_ms=%.2f path=%s",
+                (t_parsed - t_start) * 1000,
+                (time.perf_counter() - t_start) * 1000,
+                request.path,
+            )
+        return get_error_argument_result("Memory type must be a list.")
     memory_type = set(req["memory_type"])
     invalid_type = memory_type - {e.name.lower() for e in MemoryType}
     if invalid_type:
+        if timing_enabled:
+            logging.info(
+                "api_timing create_memory invalid_memory_type parse_ms=%.2f total_ms=%.2f path=%s",
+                (t_parsed - t_start) * 1000,
+                (time.perf_counter() - t_start) * 1000,
+                request.path,
+            )
         return get_error_argument_result(f"Memory type '{invalid_type}' is not supported.")
     memory_type = list(memory_type)
     try:
+        t_before_db = time.perf_counter() if timing_enabled else None
         res, memory = MemoryService.create_memory(
             tenant_id=current_user.id,
             name=memory_name,
@@ -58,6 +94,15 @@ async def create_memory():
             embd_id=req["embd_id"],
             llm_id=req["llm_id"]
         )
+        if timing_enabled:
+            logging.info(
+                "api_timing create_memory parse_ms=%.2f validate_ms=%.2f db_ms=%.2f total_ms=%.2f path=%s",
+                (t_parsed - t_start) * 1000,
+                (t_before_db - t_parsed) * 1000,
+                (time.perf_counter() - t_before_db) * 1000,
+                (time.perf_counter() - t_start) * 1000,
+                request.path,
+            )
         if res:
             return get_json_result(message=True, data=format_ret_data_from_memory(memory))

View File

@@ -445,6 +445,7 @@ class DocumentService(CommonService):
             .where(
                 cls.model.status == StatusEnum.VALID.value,
                 ~(cls.model.type == FileType.VIRTUAL.value),
+                ((cls.model.run.is_null(True)) | (cls.model.run != TaskStatus.CANCEL.value)),
                 (((cls.model.progress < 1) & (cls.model.progress > 0)) |
                  (cls.model.id.in_(unfinished_task_query))))  # including unfinished tasks like GraphRAG, RAPTOR and Mindmap
         return list(docs.dicts())
@@ -936,6 +937,8 @@ class DocumentService(CommonService):
                 bad = 0
                 e, doc = DocumentService.get_by_id(d["id"])
                 status = doc.run  # TaskStatus.RUNNING.value
+                if status == TaskStatus.CANCEL.value:
+                    continue
                 doc_progress = doc.progress if doc and doc.progress else 0.0
                 special_task_running = False
                 priority = 0
@@ -979,7 +982,16 @@ class DocumentService(CommonService):
                         info["progress_msg"] += "\n%d tasks are ahead in the queue..."%get_queue_length(priority)
                     else:
                         info["progress_msg"] = "%d tasks are ahead in the queue..."%get_queue_length(priority)
-                cls.update_by_id(d["id"], info)
+                info["update_time"] = current_timestamp()
+                info["update_date"] = get_format_time()
+                (
+                    cls.model.update(info)
+                    .where(
+                        (cls.model.id == d["id"])
+                        & ((cls.model.run.is_null(True)) | (cls.model.run != TaskStatus.CANCEL.value))
+                    )
+                    .execute()
+                )
             except Exception as e:
                 if str(e).find("'0'") < 0:
                     logging.exception("fetch task exception")
@@ -1012,7 +1024,7 @@ class DocumentService(CommonService):
     @classmethod
     @DB.connection_context()
     def knowledgebase_basic_info(cls, kb_id: str) -> dict[str, int]:
-        # cancelled: run == "2" but progress can vary
+        # cancelled: run == "2"
        cancelled = (
             cls.model.select(fn.COUNT(1))
             .where((cls.model.kb_id == kb_id) & (cls.model.run == TaskStatus.CANCEL))

View File

@@ -397,7 +397,7 @@ class KnowledgebaseService(CommonService):
         if dataset_name == "":
             return False, get_data_error_result(message="Dataset name can't be empty.")
         if len(dataset_name.encode("utf-8")) > DATASET_NAME_LIMIT:
-            return False, get_data_error_result(message=f"Dataset name length is {len(dataset_name)} which is larger than {DATASET_NAME_LIMIT}")
+            return False, get_data_error_result(message=f"Dataset name length is {len(dataset_name)} which is large than {DATASET_NAME_LIMIT}")
 
         # Deduplicate name within tenant
         dataset_name = duplicate_name(

View File

@@ -31,6 +31,12 @@ from quart import (
     jsonify,
     request
 )
+from werkzeug.exceptions import BadRequest as WerkzeugBadRequest
+
+try:
+    from quart.exceptions import BadRequest as QuartBadRequest
+except ImportError:  # pragma: no cover - optional dependency
+    QuartBadRequest = None
 
 from peewee import OperationalError
@@ -48,35 +54,33 @@ requests.models.complexjson.dumps = functools.partial(json.dumps, cls=CustomJSONEncoder)
 async def _coerce_request_data() -> dict:
     """Fetch JSON body with sane defaults; fallback to form data."""
+    if hasattr(request, "_cached_payload"):
+        return request._cached_payload
+
     payload: Any = None
-    last_error: Exception | None = None
 
-    try:
-        payload = await request.get_json(force=True, silent=True)
-    except Exception as e:
-        last_error = e
-        payload = None
+    body_bytes = await request.get_data()
+    has_body = bool(body_bytes)
+    content_type = (request.content_type or "").lower()
+    is_json = content_type.startswith("application/json")
 
-    if payload is None:
-        try:
-            form = await request.form
-            payload = form.to_dict()
-        except Exception as e:
-            last_error = e
-            payload = None
+    if not has_body:
+        payload = {}
+    elif is_json:
+        payload = await request.get_json(force=False, silent=False)
+        if isinstance(payload, dict):
+            payload = payload or {}
+        elif isinstance(payload, str):
+            raise AttributeError("'str' object has no attribute 'get'")
+        else:
+            raise TypeError("JSON payload must be an object.")
+    else:
+        form = await request.form
+        payload = form.to_dict() if form else None
+        if payload is None:
+            raise TypeError("Request body is not a valid form payload.")
 
-    if payload is None:
-        if last_error is not None:
-            raise last_error
-        raise ValueError("No JSON body or form data found in request.")
-
-    if isinstance(payload, dict):
-        return payload or {}
-    if isinstance(payload, str):
-        raise AttributeError("'str' object has no attribute 'get'")
-    raise TypeError(f"Unsupported request payload type: {type(payload)!r}")
+    request._cached_payload = payload
+    return payload
 
 
 async def get_request_json():
     return await _coerce_request_data()
@@ -124,16 +128,12 @@ def server_error_response(e):
     try:
         msg = repr(e).lower()
         if getattr(e, "code", None) == 401 or ("unauthorized" in msg) or ("401" in msg):
-            return get_json_result(code=RetCode.UNAUTHORIZED, message=repr(e))
+            resp = get_json_result(code=RetCode.UNAUTHORIZED, message="Unauthorized")
+            resp.status_code = RetCode.UNAUTHORIZED
+            return resp
     except Exception as ex:
         logging.warning(f"error checking authorization: {ex}")
 
-    if len(e.args) > 1:
-        try:
-            serialized_data = serialize_for_json(e.args[1])
-            return get_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=serialized_data)
-        except Exception:
-            return get_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=None)
     if repr(e).find("index_not_found_exception") >= 0:
         return get_json_result(code=RetCode.EXCEPTION_ERROR, message="No chunk found, please upload file and parse it.")
@@ -168,7 +168,17 @@ def validate_request(*args, **kwargs):
     def wrapper(func):
         @wraps(func)
         async def decorated_function(*_args, **_kwargs):
-            errs = process_args(await _coerce_request_data())
+            exception_types = (AttributeError, TypeError, WerkzeugBadRequest)
+            if QuartBadRequest is not None:
+                exception_types = exception_types + (QuartBadRequest,)
+            if args or kwargs:
+                try:
+                    input_arguments = await _coerce_request_data()
+                except exception_types:
+                    input_arguments = {}
+            else:
+                input_arguments = await _coerce_request_data()
+            errs = process_args(input_arguments)
             if errs:
                 return get_json_result(code=RetCode.ARGUMENT_ERROR, message=errs)
             if inspect.iscoroutinefunction(func):

View File

@@ -318,6 +318,8 @@ class RAGFlow:
         for data in res["data"]["memory_list"]:
             result_list.append(Memory(self, data))
         return {
+            "code": res.get("code", 0),
+            "message": res.get("message"),
             "memory_list": result_list,
             "total_count": res["data"]["total_count"]
         }

View File

@@ -99,7 +99,7 @@ def batch_create_datasets(auth, num):
 # DOCUMENT APP
-def upload_documents(auth, payload=None, files_path=None):
+def upload_documents(auth, payload=None, files_path=None, *, filename_override=None):
     url = f"{HOST_ADDRESS}{DOCUMENT_APP_URL}/upload"
 
     if files_path is None:
@@ -115,7 +115,8 @@ def upload_documents(auth, payload=None, files_path=None):
     for fp in files_path:
         p = Path(fp)
         f = p.open("rb")
-        fields.append(("file", (p.name, f)))
+        filename = filename_override if filename_override is not None else p.name
+        fields.append(("file", (filename, f)))
         file_objects.append(f)
 
     m = MultipartEncoder(fields=fields)

View File

@@ -14,7 +14,8 @@
 # limitations under the License.
 #
 from time import sleep
+from ragflow_sdk import RAGFlow
+from configs import HOST_ADDRESS, VERSION
 import pytest
 from common import (
     batch_add_chunks,
@@ -81,7 +82,9 @@ def generate_test_files(request: FixtureRequest, tmp_path):
 def ragflow_tmp_dir(request, tmp_path_factory):
     class_name = request.cls.__name__
     return tmp_path_factory.mktemp(class_name)
 
+
+@pytest.fixture(scope="session")
+def client(token: str) -> RAGFlow:
+    return RAGFlow(api_key=token, base_url=HOST_ADDRESS, version=VERSION)
+
+
 @pytest.fixture(scope="session")
 def WebApiAuth(auth):

View File

@@ -265,11 +265,11 @@ class TestChunksRetrieval:
     @pytest.mark.parametrize(
         "payload, expected_code, expected_highlight, expected_message",
         [
-            ({"highlight": True}, 0, True, ""),
-            ({"highlight": "True"}, 0, True, ""),
-            pytest.param({"highlight": False}, 0, False, "", marks=pytest.mark.skip(reason="issues/6648")),
-            pytest.param({"highlight": "False"}, 0, False, "", marks=pytest.mark.skip(reason="issues/6648")),
-            pytest.param({"highlight": None}, 0, False, "", marks=pytest.mark.skip(reason="issues/6648")),
+            pytest.param({"highlight": True}, 0, True, "", marks=pytest.mark.skip(reason="highlight not functionnal")),
+            pytest.param({"highlight": "True"}, 0, True, "", marks=pytest.mark.skip(reason="highlight not functionnal")),
+            ({"highlight": False}, 0, False, ""),
+            ({"highlight": "False"}, 0, False, ""),
+            ({"highlight": None}, 0, False, "")
         ],
     )
     def test_highlight(self, WebApiAuth, add_chunks, payload, expected_code, expected_highlight, expected_message):

View File

@@ -17,11 +17,9 @@ import string
 from concurrent.futures import ThreadPoolExecutor, as_completed
 
 import pytest
-import requests
-from common import DOCUMENT_APP_URL, list_kbs, upload_documents
-from configs import DOCUMENT_NAME_LIMIT, HOST_ADDRESS, INVALID_API_TOKEN
+from common import list_kbs, upload_documents
+from configs import DOCUMENT_NAME_LIMIT, INVALID_API_TOKEN
 from libs.auth import RAGFlowWebApiAuth
-from requests_toolbelt import MultipartEncoder
 from utils.file_utils import create_txt_file
@@ -111,17 +109,9 @@ class TestDocumentsUpload:
         kb_id = add_dataset_func
         fp = create_txt_file(tmp_path / "ragflow_test.txt")
-        url = f"{HOST_ADDRESS}{DOCUMENT_APP_URL}/upload"
-        fields = [("file", ("", fp.open("rb"))), ("kb_id", kb_id)]
-        m = MultipartEncoder(fields=fields)
-        res = requests.post(
-            url=url,
-            headers={"Content-Type": m.content_type},
-            auth=WebApiAuth,
-            data=m,
-        )
-        assert res.json()["code"] == 101, res
-        assert res.json()["message"] == "No file selected!", res
+        res = upload_documents(WebApiAuth, {"kb_id": kb_id}, [fp], filename_override="")
+        assert res["code"] == 101, res
+        assert res["message"] == "No file selected!", res
 
     @pytest.mark.p2
     def test_filename_exceeds_max_length(self, WebApiAuth, add_dataset_func, tmp_path):