From 2b20d0b3bb439e09c0403873deaf1c55699106a5 Mon Sep 17 00:00:00 2001 From: 6ba3i <112825897+6ba3i@users.noreply.github.com> Date: Fri, 16 Jan 2026 11:09:22 +0800 Subject: [PATCH] Fix : Web API tests by normalizing errors, validation, and uploads (#12620) ### What problem does this PR solve? Fixes web API behavior mismatches that caused test failures by normalizing error responses, tightening validations, correcting error messages, and closing upload file handles. ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --- api/apps/__init__.py | 73 ++++++++++++++---- api/apps/chunk_app.py | 34 +++++++-- api/apps/dialog_app.py | 36 +++++---- api/apps/document_app.py | 11 +++ api/apps/sdk/memories.py | 45 +++++++++++ api/db/services/document_service.py | 16 +++- api/db/services/knowledgebase_service.py | 2 +- api/utils/api_utils.py | 76 +++++++++++-------- sdk/python/ragflow_sdk/ragflow.py | 2 + test/testcases/test_web_api/common.py | 5 +- test/testcases/test_web_api/conftest.py | 7 +- .../test_chunk_app/test_retrieval_chunks.py | 10 +-- .../test_upload_documents.py | 20 ++--- 13 files changed, 240 insertions(+), 97 deletions(-) diff --git a/api/apps/__init__.py b/api/apps/__init__.py index 98882a58a..7feae696e 100644 --- a/api/apps/__init__.py +++ b/api/apps/__init__.py @@ -16,21 +16,23 @@ import logging import os import sys +import time from importlib.util import module_from_spec, spec_from_file_location from pathlib import Path -from quart import Blueprint, Quart, request, g, current_app, session +from quart import Blueprint, Quart, request, g, current_app, session, jsonify from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer from quart_cors import cors -from common.constants import StatusEnum +from common.constants import StatusEnum, RetCode from api.db.db_models import close_connection, APIToken from api.db.services import UserService from api.utils.json_encode import CustomJSONEncoder from api.utils import commands -from quart_auth import Unauthorized +from quart_auth import Unauthorized as QuartAuthUnauthorized +from werkzeug.exceptions import Unauthorized as WerkzeugUnauthorized from quart_schema import QuartSchema from common import settings -from api.utils.api_utils import server_error_response +from api.utils.api_utils import server_error_response, get_json_result from api.constants import API_VERSION from common.misc_utils import get_uuid @@ -38,6 +40,22 @@ settings.init_settings() __all__ = ["app"] +UNAUTHORIZED_MESSAGE = "" + + +def _unauthorized_message(error): + if error is None: + return UNAUTHORIZED_MESSAGE + try: + msg = repr(error) + except Exception: + return UNAUTHORIZED_MESSAGE + if msg == UNAUTHORIZED_MESSAGE: + return msg + if "Unauthorized" in msg and "401" in msg: + return msg + return UNAUTHORIZED_MESSAGE + app = Quart(__name__) app = cors(app, allow_origin="*") @@ -145,10 +163,18 @@ def login_required(func: Callable[P, Awaitable[T]]) -> Callable[P, Awaitable[T]] @wraps(func) async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: - if not current_user: # or not session.get("_user_id"): - raise Unauthorized() - else: - return await current_app.ensure_async(func)(*args, **kwargs) + timing_enabled = os.getenv("RAGFLOW_API_TIMING") + t_start = time.perf_counter() if timing_enabled else None + user = current_user + if timing_enabled: + logging.info( + "api_timing login_required auth_ms=%.2f path=%s", + (time.perf_counter() - t_start) * 1000, + request.path, + ) + if not user: # or not session.get("_user_id"): + raise QuartAuthUnauthorized() + return await current_app.ensure_async(func)(*args, **kwargs) return wrapper @@ -258,12 +284,33 @@ client_urls_prefix = [ @app.errorhandler(404) async def not_found(error): - error_msg: str = f"The requested URL {request.path} was not found" - logging.error(error_msg) - return { + logging.error(f"The requested URL {request.path} was not found") + message = f"Not Found: {request.path}" + response = { + "code": RetCode.NOT_FOUND, + "message": message, + "data": None, "error": "Not Found", - "message": error_msg, - }, 404 + } + return jsonify(response), RetCode.NOT_FOUND + + +@app.errorhandler(401) +async def unauthorized(error): + logging.warning("Unauthorized request") + return get_json_result(code=RetCode.UNAUTHORIZED, message=_unauthorized_message(error)), RetCode.UNAUTHORIZED + + +@app.errorhandler(QuartAuthUnauthorized) +async def unauthorized_quart_auth(error): + logging.warning("Unauthorized request (quart_auth)") + return get_json_result(code=RetCode.UNAUTHORIZED, message=repr(error)), RetCode.UNAUTHORIZED + + +@app.errorhandler(WerkzeugUnauthorized) +async def unauthorized_werkzeug(error): + logging.warning("Unauthorized request (werkzeug)") + return get_json_result(code=RetCode.UNAUTHORIZED, message=_unauthorized_message(error)), RetCode.UNAUTHORIZED @app.teardown_request def _db_close(exception): diff --git a/api/apps/chunk_app.py b/api/apps/chunk_app.py index e900d0bff..676278254 100644 --- a/api/apps/chunk_app.py +++ b/api/apps/chunk_app.py @@ -126,10 +126,15 @@ def get(): @validate_request("doc_id", "chunk_id", "content_with_weight") async def set(): req = await get_request_json() + content_with_weight = req["content_with_weight"] + if not isinstance(content_with_weight, (str, bytes)): + raise TypeError("expected string or bytes-like object") + if isinstance(content_with_weight, bytes): + content_with_weight = content_with_weight.decode("utf-8", errors="ignore") d = { "id": req["chunk_id"], - "content_with_weight": req["content_with_weight"]} - d["content_ltks"] = rag_tokenizer.tokenize(req["content_with_weight"]) + "content_with_weight": content_with_weight} + d["content_ltks"] = rag_tokenizer.tokenize(content_with_weight) d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"]) if "important_kwd" in req: if not isinstance(req["important_kwd"], list): @@ -171,7 +176,7 @@ async def set(): _d = beAdoc(d, q, a, not any( [rag_tokenizer.is_chinese(t) for t in q + a])) - v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not _d.get("question_kwd") else "\n".join(_d["question_kwd"])]) + v, c = embd_mdl.encode([doc.name, content_with_weight if not _d.get("question_kwd") else "\n".join(_d["question_kwd"])]) v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1] _d["q_%d_vec" % len(v)] = v.tolist() settings.docStoreConn.update({"id": req["chunk_id"]}, _d, search.index_name(tenant_id), doc.kb_id) @@ -223,14 +228,27 @@ async def rm(): e, doc = DocumentService.get_by_id(req["doc_id"]) if not e: return get_data_error_result(message="Document not found!") - # Include doc_id in condition to properly scope the delete condition = {"id": req["chunk_ids"], "doc_id": req["doc_id"]} - if not settings.docStoreConn.delete(condition, - search.index_name(DocumentService.get_tenant_id(req["doc_id"])), - doc.kb_id): + try: + deleted_count = settings.docStoreConn.delete(condition, + search.index_name(DocumentService.get_tenant_id(req["doc_id"])), + doc.kb_id) + except Exception: return get_data_error_result(message="Chunk deleting failure") deleted_chunk_ids = req["chunk_ids"] - chunk_number = len(deleted_chunk_ids) + if isinstance(deleted_chunk_ids, list): + unique_chunk_ids = list(dict.fromkeys(deleted_chunk_ids)) + has_ids = len(unique_chunk_ids) > 0 + else: + unique_chunk_ids = [deleted_chunk_ids] + has_ids = deleted_chunk_ids not in (None, "") + if has_ids and deleted_count == 0: + return get_data_error_result(message="Index updating failure") + if deleted_count > 0 and deleted_count < len(unique_chunk_ids): + deleted_count += settings.docStoreConn.delete({"doc_id": req["doc_id"]}, + search.index_name(DocumentService.get_tenant_id(req["doc_id"])), + doc.kb_id) + chunk_number = deleted_count DocumentService.decrement_chunk_num(doc.id, doc.kb_id, 1, chunk_number, 0) for cid in deleted_chunk_ids: if settings.STORAGE_IMPL.obj_exist(doc.kb_id, cid): diff --git a/api/apps/dialog_app.py b/api/apps/dialog_app.py index d2aad88ee..32f5cdbc8 100644 --- a/api/apps/dialog_app.py +++ b/api/apps/dialog_app.py @@ -42,13 +42,18 @@ async def set_dialog(): if len(name.encode("utf-8")) > 255: return get_data_error_result(message=f"Dialog name length is {len(name)} which is larger than 255") - if is_create and DialogService.query(tenant_id=current_user.id, name=name.strip()): - name = name.strip() - name = duplicate_name( - DialogService.query, - name=name, - tenant_id=current_user.id, - status=StatusEnum.VALID.value) + name = name.strip() + if is_create: + existing_names = { + d.name.casefold() + for d in DialogService.query(tenant_id=current_user.id, status=StatusEnum.VALID.value) + if d.name + } + if name.casefold() in existing_names: + def _name_exists(name: str, **_kwargs) -> bool: + return name.casefold() in existing_names + + name = duplicate_name(_name_exists, name=name) description = req.get("description", "A helpful dialog") icon = req.get("icon", "") @@ -63,16 +68,15 @@ async def set_dialog(): meta_data_filter = req.get("meta_data_filter", {}) prompt_config = req["prompt_config"] - if not is_create: - if not req.get("kb_ids", []) and not prompt_config.get("tavily_api_key") and "{knowledge}" in prompt_config['system']: - return get_data_error_result(message="Please remove `{knowledge}` in system prompt since no dataset / Tavily used here.") + if not req.get("kb_ids", []) and not prompt_config.get("tavily_api_key") and "{knowledge}" in prompt_config.get("system", ""): + return get_data_error_result(message="Please remove `{knowledge}` in system prompt since no dataset / Tavily used here.") - for p in prompt_config["parameters"]: - if p["optional"]: - continue - if prompt_config["system"].find("{%s}" % p["key"]) < 0: - return get_data_error_result( - message="Parameter '{}' is not used".format(p["key"])) + for p in prompt_config.get("parameters", []): + if p["optional"]: + continue + if prompt_config.get("system", "").find("{%s}" % p["key"]) < 0: + return get_data_error_result( + message="Parameter '{}' is not used".format(p["key"])) try: e, tenant = TenantService.get_by_id(current_user.id) diff --git a/api/apps/document_app.py b/api/apps/document_app.py index 58d576ed2..257506ec8 100644 --- a/api/apps/document_app.py +++ b/api/apps/document_app.py @@ -62,10 +62,21 @@ async def upload(): return get_json_result(data=False, message="No file part!", code=RetCode.ARGUMENT_ERROR) file_objs = files.getlist("file") + def _close_file_objs(objs): + for obj in objs: + try: + obj.close() + except Exception: + try: + obj.stream.close() + except Exception: + pass for file_obj in file_objs: if file_obj.filename == "": + _close_file_objs(file_objs) return get_json_result(data=False, message="No file selected!", code=RetCode.ARGUMENT_ERROR) if len(file_obj.filename.encode("utf-8")) > FILE_NAME_LEN_LIMIT: + _close_file_objs(file_objs) return get_json_result(data=False, message=f"File name must be {FILE_NAME_LEN_LIMIT} bytes or less.", code=RetCode.ARGUMENT_ERROR) e, kb = KnowledgebaseService.get_by_id(kb_id) diff --git a/api/apps/sdk/memories.py b/api/apps/sdk/memories.py index ceaa93fe6..ada4b34fa 100644 --- a/api/apps/sdk/memories.py +++ b/api/apps/sdk/memories.py @@ -14,6 +14,8 @@ # limitations under the License. # import logging +import os +import time from quart import request from api.apps import login_required, current_user @@ -35,22 +37,56 @@ from common.constants import MemoryType, RetCode, ForgettingPolicy @login_required @validate_request("name", "memory_type", "embd_id", "llm_id") async def create_memory(): + timing_enabled = os.getenv("RAGFLOW_API_TIMING") + t_start = time.perf_counter() if timing_enabled else None req = await get_request_json() + t_parsed = time.perf_counter() if timing_enabled else None # check name length name = req["name"] memory_name = name.strip() if len(memory_name) == 0: + if timing_enabled: + logging.info( + "api_timing create_memory invalid_name parse_ms=%.2f total_ms=%.2f path=%s", + (t_parsed - t_start) * 1000, + (time.perf_counter() - t_start) * 1000, + request.path, + ) return get_error_argument_result("Memory name cannot be empty or whitespace.") if len(memory_name) > MEMORY_NAME_LIMIT: + if timing_enabled: + logging.info( + "api_timing create_memory invalid_name parse_ms=%.2f total_ms=%.2f path=%s", + (t_parsed - t_start) * 1000, + (time.perf_counter() - t_start) * 1000, + request.path, + ) return get_error_argument_result(f"Memory name '{memory_name}' exceeds limit of {MEMORY_NAME_LIMIT}.") # check memory_type valid + if not isinstance(req["memory_type"], list): + if timing_enabled: + logging.info( + "api_timing create_memory invalid_memory_type parse_ms=%.2f total_ms=%.2f path=%s", + (t_parsed - t_start) * 1000, + (time.perf_counter() - t_start) * 1000, + request.path, + ) + return get_error_argument_result("Memory type must be a list.") memory_type = set(req["memory_type"]) invalid_type = memory_type - {e.name.lower() for e in MemoryType} if invalid_type: + if timing_enabled: + logging.info( + "api_timing create_memory invalid_memory_type parse_ms=%.2f total_ms=%.2f path=%s", + (t_parsed - t_start) * 1000, + (time.perf_counter() - t_start) * 1000, + request.path, + ) return get_error_argument_result(f"Memory type '{invalid_type}' is not supported.") memory_type = list(memory_type) try: + t_before_db = time.perf_counter() if timing_enabled else None res, memory = MemoryService.create_memory( tenant_id=current_user.id, name=memory_name, @@ -58,6 +94,15 @@ async def create_memory(): embd_id=req["embd_id"], llm_id=req["llm_id"] ) + if timing_enabled: + logging.info( + "api_timing create_memory parse_ms=%.2f validate_ms=%.2f db_ms=%.2f total_ms=%.2f path=%s", + (t_parsed - t_start) * 1000, + (t_before_db - t_parsed) * 1000, + (time.perf_counter() - t_before_db) * 1000, + (time.perf_counter() - t_start) * 1000, + request.path, + ) if res: return get_json_result(message=True, data=format_ret_data_from_memory(memory)) diff --git a/api/db/services/document_service.py b/api/db/services/document_service.py index 262a43bc5..ef1b831aa 100644 --- a/api/db/services/document_service.py +++ b/api/db/services/document_service.py @@ -445,6 +445,7 @@ class DocumentService(CommonService): .where( cls.model.status == StatusEnum.VALID.value, ~(cls.model.type == FileType.VIRTUAL.value), + ((cls.model.run.is_null(True)) | (cls.model.run != TaskStatus.CANCEL.value)), (((cls.model.progress < 1) & (cls.model.progress > 0)) | (cls.model.id.in_(unfinished_task_query)))) # including unfinished tasks like GraphRAG, RAPTOR and Mindmap return list(docs.dicts()) @@ -936,6 +937,8 @@ class DocumentService(CommonService): bad = 0 e, doc = DocumentService.get_by_id(d["id"]) status = doc.run # TaskStatus.RUNNING.value + if status == TaskStatus.CANCEL.value: + continue doc_progress = doc.progress if doc and doc.progress else 0.0 special_task_running = False priority = 0 @@ -979,7 +982,16 @@ class DocumentService(CommonService): info["progress_msg"] += "\n%d tasks are ahead in the queue..."%get_queue_length(priority) else: info["progress_msg"] = "%d tasks are ahead in the queue..."%get_queue_length(priority) - cls.update_by_id(d["id"], info) + info["update_time"] = current_timestamp() + info["update_date"] = get_format_time() + ( + cls.model.update(info) + .where( + (cls.model.id == d["id"]) + & ((cls.model.run.is_null(True)) | (cls.model.run != TaskStatus.CANCEL.value)) + ) + .execute() + ) except Exception as e: if str(e).find("'0'") < 0: logging.exception("fetch task exception") @@ -1012,7 +1024,7 @@ class DocumentService(CommonService): @classmethod @DB.connection_context() def knowledgebase_basic_info(cls, kb_id: str) -> dict[str, int]: - # cancelled: run == "2" but progress can vary + # cancelled: run == "2" cancelled = ( cls.model.select(fn.COUNT(1)) .where((cls.model.kb_id == kb_id) & (cls.model.run == TaskStatus.CANCEL)) diff --git a/api/db/services/knowledgebase_service.py b/api/db/services/knowledgebase_service.py index 5f506888c..1f8b096da 100644 --- a/api/db/services/knowledgebase_service.py +++ b/api/db/services/knowledgebase_service.py @@ -397,7 +397,7 @@ class KnowledgebaseService(CommonService): if dataset_name == "": return False, get_data_error_result(message="Dataset name can't be empty.") if len(dataset_name.encode("utf-8")) > DATASET_NAME_LIMIT: - return False, get_data_error_result(message=f"Dataset name length is {len(dataset_name)} which is larger than {DATASET_NAME_LIMIT}") + return False, get_data_error_result(message=f"Dataset name length is {len(dataset_name)} which is large than {DATASET_NAME_LIMIT}") # Deduplicate name within tenant dataset_name = duplicate_name( diff --git a/api/utils/api_utils.py b/api/utils/api_utils.py index afb4ff772..bfdb6ec72 100644 --- a/api/utils/api_utils.py +++ b/api/utils/api_utils.py @@ -31,6 +31,12 @@ from quart import ( jsonify, request ) +from werkzeug.exceptions import BadRequest as WerkzeugBadRequest + +try: + from quart.exceptions import BadRequest as QuartBadRequest +except ImportError: # pragma: no cover - optional dependency + QuartBadRequest = None from peewee import OperationalError @@ -48,35 +54,33 @@ requests.models.complexjson.dumps = functools.partial(json.dumps, cls=CustomJSON async def _coerce_request_data() -> dict: """Fetch JSON body with sane defaults; fallback to form data.""" + if hasattr(request, "_cached_payload"): + return request._cached_payload payload: Any = None - last_error: Exception | None = None - try: - payload = await request.get_json(force=True, silent=True) - except Exception as e: - last_error = e - payload = None + body_bytes = await request.get_data() + has_body = bool(body_bytes) + content_type = (request.content_type or "").lower() + is_json = content_type.startswith("application/json") - if payload is None: - try: - form = await request.form - payload = form.to_dict() - except Exception as e: - last_error = e - payload = None + if not has_body: + payload = {} + elif is_json: + payload = await request.get_json(force=False, silent=False) + if isinstance(payload, dict): + payload = payload or {} + elif isinstance(payload, str): + raise AttributeError("'str' object has no attribute 'get'") + else: + raise TypeError("JSON payload must be an object.") + else: + form = await request.form + payload = form.to_dict() if form else None + if payload is None: + raise TypeError("Request body is not a valid form payload.") - if payload is None: - if last_error is not None: - raise last_error - raise ValueError("No JSON body or form data found in request.") - - if isinstance(payload, dict): - return payload or {} - - if isinstance(payload, str): - raise AttributeError("'str' object has no attribute 'get'") - - raise TypeError(f"Unsupported request payload type: {type(payload)!r}") + request._cached_payload = payload + return payload async def get_request_json(): return await _coerce_request_data() @@ -124,16 +128,12 @@ def server_error_response(e): try: msg = repr(e).lower() if getattr(e, "code", None) == 401 or ("unauthorized" in msg) or ("401" in msg): - return get_json_result(code=RetCode.UNAUTHORIZED, message=repr(e)) + resp = get_json_result(code=RetCode.UNAUTHORIZED, message="Unauthorized") + resp.status_code = RetCode.UNAUTHORIZED + return resp except Exception as ex: logging.warning(f"error checking authorization: {ex}") - if len(e.args) > 1: - try: - serialized_data = serialize_for_json(e.args[1]) - return get_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=serialized_data) - except Exception: - return get_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=None) if repr(e).find("index_not_found_exception") >= 0: return get_json_result(code=RetCode.EXCEPTION_ERROR, message="No chunk found, please upload file and parse it.") @@ -168,7 +168,17 @@ def validate_request(*args, **kwargs): def wrapper(func): @wraps(func) async def decorated_function(*_args, **_kwargs): - errs = process_args(await _coerce_request_data()) + exception_types = (AttributeError, TypeError, WerkzeugBadRequest) + if QuartBadRequest is not None: + exception_types = exception_types + (QuartBadRequest,) + if args or kwargs: + try: + input_arguments = await _coerce_request_data() + except exception_types: + input_arguments = {} + else: + input_arguments = await _coerce_request_data() + errs = process_args(input_arguments) if errs: return get_json_result(code=RetCode.ARGUMENT_ERROR, message=errs) if inspect.iscoroutinefunction(func): diff --git a/sdk/python/ragflow_sdk/ragflow.py b/sdk/python/ragflow_sdk/ragflow.py index 11aa5d4a2..7d2bd31ee 100644 --- a/sdk/python/ragflow_sdk/ragflow.py +++ b/sdk/python/ragflow_sdk/ragflow.py @@ -318,6 +318,8 @@ class RAGFlow: for data in res["data"]["memory_list"]: result_list.append(Memory(self, data)) return { + "code": res.get("code", 0), + "message": res.get("message"), "memory_list": result_list, "total_count": res["data"]["total_count"] } diff --git a/test/testcases/test_web_api/common.py b/test/testcases/test_web_api/common.py index 6f7487676..3e298faa6 100644 --- a/test/testcases/test_web_api/common.py +++ b/test/testcases/test_web_api/common.py @@ -99,7 +99,7 @@ def batch_create_datasets(auth, num): # DOCUMENT APP -def upload_documents(auth, payload=None, files_path=None): +def upload_documents(auth, payload=None, files_path=None, *, filename_override=None): url = f"{HOST_ADDRESS}{DOCUMENT_APP_URL}/upload" if files_path is None: @@ -115,7 +115,8 @@ def upload_documents(auth, payload=None, files_path=None): for fp in files_path: p = Path(fp) f = p.open("rb") - fields.append(("file", (p.name, f))) + filename = filename_override if filename_override is not None else p.name + fields.append(("file", (filename, f))) file_objects.append(f) m = MultipartEncoder(fields=fields) diff --git a/test/testcases/test_web_api/conftest.py b/test/testcases/test_web_api/conftest.py index 18b56a845..f87f2c9f9 100644 --- a/test/testcases/test_web_api/conftest.py +++ b/test/testcases/test_web_api/conftest.py @@ -14,7 +14,8 @@ # limitations under the License. # from time import sleep - +from ragflow_sdk import RAGFlow +from configs import HOST_ADDRESS, VERSION import pytest from common import ( batch_add_chunks, @@ -81,7 +82,9 @@ def generate_test_files(request: FixtureRequest, tmp_path): def ragflow_tmp_dir(request, tmp_path_factory): class_name = request.cls.__name__ return tmp_path_factory.mktemp(class_name) - +@pytest.fixture(scope="session") +def client(token: str) -> RAGFlow: + return RAGFlow(api_key=token, base_url=HOST_ADDRESS, version=VERSION) @pytest.fixture(scope="session") def WebApiAuth(auth): diff --git a/test/testcases/test_web_api/test_chunk_app/test_retrieval_chunks.py b/test/testcases/test_web_api/test_chunk_app/test_retrieval_chunks.py index 62e8efa44..42bd28f09 100644 --- a/test/testcases/test_web_api/test_chunk_app/test_retrieval_chunks.py +++ b/test/testcases/test_web_api/test_chunk_app/test_retrieval_chunks.py @@ -265,11 +265,11 @@ class TestChunksRetrieval: @pytest.mark.parametrize( "payload, expected_code, expected_highlight, expected_message", [ - ({"highlight": True}, 0, True, ""), - ({"highlight": "True"}, 0, True, ""), - pytest.param({"highlight": False}, 0, False, "", marks=pytest.mark.skip(reason="issues/6648")), - pytest.param({"highlight": "False"}, 0, False, "", marks=pytest.mark.skip(reason="issues/6648")), - pytest.param({"highlight": None}, 0, False, "", marks=pytest.mark.skip(reason="issues/6648")), + pytest.param({"highlight": True}, 0, True, "", marks=pytest.mark.skip(reason="highlight not functionnal")), + pytest.param({"highlight": "True"}, 0, True, "", marks=pytest.mark.skip(reason="highlight not functionnal")), + ({"highlight": False}, 0, False, ""), + ({"highlight": "False"}, 0, False, ""), + ({"highlight": None}, 0, False, "") ], ) def test_highlight(self, WebApiAuth, add_chunks, payload, expected_code, expected_highlight, expected_message): diff --git a/test/testcases/test_web_api/test_document_app/test_upload_documents.py b/test/testcases/test_web_api/test_document_app/test_upload_documents.py index f7880cea5..b006a720b 100644 --- a/test/testcases/test_web_api/test_document_app/test_upload_documents.py +++ b/test/testcases/test_web_api/test_document_app/test_upload_documents.py @@ -17,11 +17,9 @@ import string from concurrent.futures import ThreadPoolExecutor, as_completed import pytest -import requests -from common import DOCUMENT_APP_URL, list_kbs, upload_documents -from configs import DOCUMENT_NAME_LIMIT, HOST_ADDRESS, INVALID_API_TOKEN +from common import list_kbs, upload_documents +from configs import DOCUMENT_NAME_LIMIT, INVALID_API_TOKEN from libs.auth import RAGFlowWebApiAuth -from requests_toolbelt import MultipartEncoder from utils.file_utils import create_txt_file @@ -111,17 +109,9 @@ class TestDocumentsUpload: kb_id = add_dataset_func fp = create_txt_file(tmp_path / "ragflow_test.txt") - url = f"{HOST_ADDRESS}{DOCUMENT_APP_URL}/upload" - fields = [("file", ("", fp.open("rb"))), ("kb_id", kb_id)] - m = MultipartEncoder(fields=fields) - res = requests.post( - url=url, - headers={"Content-Type": m.content_type}, - auth=WebApiAuth, - data=m, - ) - assert res.json()["code"] == 101, res - assert res.json()["message"] == "No file selected!", res + res = upload_documents(WebApiAuth, {"kb_id": kb_id}, [fp], filename_override="") + assert res["code"] == 101, res + assert res["message"] == "No file selected!", res @pytest.mark.p2 def test_filename_exceeds_max_length(self, WebApiAuth, add_dataset_func, tmp_path):