From 0b40eb3e90029baf09b3f39f7ce9f79e3a9a75f3 Mon Sep 17 00:00:00 2001 From: Liu An Date: Wed, 2 Jul 2025 09:49:08 +0800 Subject: [PATCH] Test: Add tests for chunk API endpoints (#8616) ### What problem does this PR solve? - Add comprehensive test suite for chunk operations including: - Test files for create, list, retrieve, update, and delete chunks - Authorization tests - Batch operations tests - Update test configurations and common utilities - Validate `important_kwd` and `question_kwd` fields are lists in chunk_app.py - Reorganize imports and clean up duplicate code ### Type of change - [x] Add test cases --- api/apps/chunk_app.py | 32 +- test/testcases/test_http_api/conftest.py | 1 - test/testcases/test_web_api/common.py | 43 ++- test/testcases/test_web_api/conftest.py | 59 +++- .../test_web_api/test_chunk_app/conftest.py | 49 +++ .../test_chunk_app/test_create_chunk.py | 223 +++++++++++++ .../test_chunk_app/test_list_chunks.py | 145 +++++++++ .../test_chunk_app/test_retrieval_chunks.py | 308 ++++++++++++++++++ .../test_chunk_app/test_rm_chunks.py | 161 +++++++++ .../test_chunk_app/test_update_chunk.py | 232 +++++++++++++ 10 files changed, 1226 insertions(+), 27 deletions(-) create mode 100644 test/testcases/test_web_api/test_chunk_app/conftest.py create mode 100644 test/testcases/test_web_api/test_chunk_app/test_create_chunk.py create mode 100644 test/testcases/test_web_api/test_chunk_app/test_list_chunks.py create mode 100644 test/testcases/test_web_api/test_chunk_app/test_retrieval_chunks.py create mode 100644 test/testcases/test_web_api/test_chunk_app/test_rm_chunks.py create mode 100644 test/testcases/test_web_api/test_chunk_app/test_update_chunk.py diff --git a/api/apps/chunk_app.py b/api/apps/chunk_app.py index c5bdee502..d6603fcf3 100644 --- a/api/apps/chunk_app.py +++ b/api/apps/chunk_app.py @@ -15,27 +15,25 @@ # import datetime import json +import re +import xxhash from flask import request -from flask_login import login_required, current_user +from 
flask_login import current_user, login_required -from rag.app.qa import rmPrefix, beAdoc -from rag.app.tag import label_question -from rag.nlp import search, rag_tokenizer -from rag.prompts import keyword_extraction, cross_languages -from rag.settings import PAGERANK_FLD -from rag.utils import rmSpace +from api import settings from api.db import LLMType, ParserType +from api.db.services.document_service import DocumentService from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.llm_service import LLMBundle from api.db.services.user_service import UserTenantService -from api.utils.api_utils import server_error_response, get_data_error_result, validate_request -from api.db.services.document_service import DocumentService -from api import settings -from api.utils.api_utils import get_json_result -import xxhash -import re - +from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request +from rag.app.qa import beAdoc, rmPrefix +from rag.app.tag import label_question +from rag.nlp import rag_tokenizer, search +from rag.prompts import cross_languages, keyword_extraction +from rag.settings import PAGERANK_FLD +from rag.utils import rmSpace @manager.route('/list', methods=['POST']) # noqa: F821 @@ -129,9 +127,13 @@ def set(): d["content_ltks"] = rag_tokenizer.tokenize(req["content_with_weight"]) d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"]) if "important_kwd" in req: + if not isinstance(req["important_kwd"], list): + return get_data_error_result(message="`important_kwd` should be a list") d["important_kwd"] = req["important_kwd"] d["important_tks"] = rag_tokenizer.tokenize(" ".join(req["important_kwd"])) if "question_kwd" in req: + if not isinstance(req["question_kwd"], list): + return get_data_error_result(message="`question_kwd` should be a list") d["question_kwd"] = req["question_kwd"] d["question_tks"] = 
rag_tokenizer.tokenize("\n".join(req["question_kwd"])) if "tag_kwd" in req: @@ -235,6 +237,6 @@ def create(): d["create_timestamp_flt"] = datetime.datetime.now().timestamp() if "tag_feas" in req: d["tag_feas"] = req["tag_feas"] try: e, doc = DocumentService.get_by_id(req["doc_id"]) diff --git a/test/testcases/test_http_api/conftest.py b/test/testcases/test_http_api/conftest.py index 983ef8aee..eab05d09b 100644 --- a/test/testcases/test_http_api/conftest.py +++ b/test/testcases/test_http_api/conftest.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # - from time import sleep import pytest diff --git a/test/testcases/test_web_api/common.py b/test/testcases/test_web_api/common.py index 7181018a4..b9b75c1aa 100644 --- a/test/testcases/test_web_api/common.py +++ b/test/testcases/test_web_api/common.py @@ -24,9 +24,7 @@ HEADERS = {"Content-Type": "application/json"} KB_APP_URL = f"/{VERSION}/kb" DOCUMENT_APP_URL = f"/{VERSION}/document" -# FILE_API_URL = "/api/v1/datasets/{dataset_id}/documents" -# FILE_CHUNK_API_URL = "/api/v1/datasets/{dataset_id}/chunks" -# CHUNK_API_URL = "/api/v1/datasets/{dataset_id}/documents/{document_id}/chunks" +CHUNK_API_URL = f"/{VERSION}/chunk" # CHAT_ASSISTANT_API_URL = "/api/v1/chats" # SESSION_WITH_CHAT_ASSISTANT_API_URL = "/api/v1/chats/{chat_id}/sessions" # SESSION_WITH_AGENT_API_URL = "/api/v1/agents/{agent_id}/sessions" @@ -164,3 +162,42 @@ def bulk_upload_documents(auth, kb_id, num, tmp_path): for document in res["data"]: document_ids.append(document["id"]) return document_ids + + +# CHUNK APP +def add_chunk(auth, payload=None, *, headers=HEADERS, data=None): + res = requests.post(url=f"{HOST_ADDRESS}{CHUNK_API_URL}/create", headers=headers, auth=auth, json=payload, data=data) + return res.json() + + +def list_chunks(auth, payload=None, *, headers=HEADERS): + res = 
requests.post(url=f"{HOST_ADDRESS}{CHUNK_API_URL}/list", headers=headers, auth=auth, json=payload) + return res.json() + + +def get_chunk(auth, params=None, *, headers=HEADERS): + res = requests.get(url=f"{HOST_ADDRESS}{CHUNK_API_URL}/get", headers=headers, auth=auth, params=params) + return res.json() + + +def update_chunk(auth, payload=None, *, headers=HEADERS): + res = requests.post(url=f"{HOST_ADDRESS}{CHUNK_API_URL}/set", headers=headers, auth=auth, json=payload) + return res.json() + + +def delete_chunks(auth, payload=None, *, headers=HEADERS): + res = requests.post(url=f"{HOST_ADDRESS}{CHUNK_API_URL}/rm", headers=headers, auth=auth, json=payload) + return res.json() + + +def retrieval_chunks(auth, payload=None, *, headers=HEADERS): + res = requests.post(url=f"{HOST_ADDRESS}{CHUNK_API_URL}/retrieval_test", headers=headers, auth=auth, json=payload) + return res.json() + + +def batch_add_chunks(auth, doc_id, num): + chunk_ids = [] + for i in range(num): + res = add_chunk(auth, {"doc_id": doc_id, "content_with_weight": f"chunk test {i}"}) + chunk_ids.append(res["data"]["chunk_id"]) + return chunk_ids diff --git a/test/testcases/test_web_api/conftest.py b/test/testcases/test_web_api/conftest.py index 82fcf982f..ebe0e6c29 100644 --- a/test/testcases/test_web_api/conftest.py +++ b/test/testcases/test_web_api/conftest.py @@ -13,18 +13,23 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# +from time import sleep + import pytest from common import ( + batch_add_chunks, batch_create_datasets, + bulk_upload_documents, + delete_chunks, + list_chunks, + list_documents, list_kbs, + parse_documents, rm_kb, ) - -# from configs import HOST_ADDRESS, VERSION from libs.auth import RAGFlowWebApiAuth from pytest import FixtureRequest - -# from ragflow_sdk import RAGFlow +from utils import wait_for from utils.file_utils import ( create_docx_file, create_eml_file, @@ -39,6 +44,15 @@ from utils.file_utils import ( ) +@wait_for(30, 1, "Document parsing timeout") +def condition(_auth, _kb_id): + res = list_documents(_auth, {"kb_id": _kb_id}) + for doc in res["data"]["docs"]: + if doc["run"] != "3": + return False + return True + + @pytest.fixture def generate_test_files(request: FixtureRequest, tmp_path): file_creators = { @@ -73,11 +87,6 @@ def WebApiAuth(auth): return RAGFlowWebApiAuth(auth) -# @pytest.fixture(scope="session") -# def client(token: str) -> RAGFlow: -# return RAGFlow(api_key=token, base_url=HOST_ADDRESS, version=VERSION) - - @pytest.fixture(scope="function") def clear_datasets(request: FixtureRequest, WebApiAuth: RAGFlowWebApiAuth): def cleanup(): @@ -108,3 +117,35 @@ def add_dataset_func(request: FixtureRequest, WebApiAuth: RAGFlowWebApiAuth) -> request.addfinalizer(cleanup) return batch_create_datasets(WebApiAuth, 1)[0] + + +@pytest.fixture(scope="class") +def add_document(request, WebApiAuth, add_dataset, ragflow_tmp_dir): + # def cleanup(): + # res = list_documents(WebApiAuth, {"kb_id": dataset_id}) + # for doc in res["data"]["docs"]: + # delete_document(WebApiAuth, {"doc_id": doc["id"]}) + + # request.addfinalizer(cleanup) + + dataset_id = add_dataset + return dataset_id, bulk_upload_documents(WebApiAuth, dataset_id, 1, ragflow_tmp_dir)[0] + + +@pytest.fixture(scope="class") +def add_chunks(request, WebApiAuth, add_document): + def cleanup(): + res = list_chunks(WebApiAuth, {"doc_id": document_id}) + if res["code"] == 0: + chunk_ids = 
[chunk["chunk_id"] for chunk in res["data"]["chunks"]] + delete_chunks(WebApiAuth, {"doc_id": document_id, "chunk_ids": chunk_ids}) + + request.addfinalizer(cleanup) + + kb_id, document_id = add_document + parse_documents(WebApiAuth, {"doc_ids": [document_id], "run": "1"}) + condition(WebApiAuth, kb_id) + chunk_ids = batch_add_chunks(WebApiAuth, document_id, 4) + # issues/6487 + sleep(1) + return kb_id, document_id, chunk_ids diff --git a/test/testcases/test_web_api/test_chunk_app/conftest.py b/test/testcases/test_web_api/test_chunk_app/conftest.py new file mode 100644 index 000000000..e51a2f09b --- /dev/null +++ b/test/testcases/test_web_api/test_chunk_app/conftest.py @@ -0,0 +1,49 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + + +from time import sleep + +import pytest +from common import batch_add_chunks, delete_chunks, list_chunks, list_documents, parse_documents +from utils import wait_for + + +@wait_for(30, 1, "Document parsing timeout") +def condition(_auth, _kb_id): + res = list_documents(_auth, {"kb_id": _kb_id}) + for doc in res["data"]["docs"]: + if doc["run"] != "3": + return False + return True + + +@pytest.fixture(scope="function") +def add_chunks_func(request, WebApiAuth, add_document): + def cleanup(): + res = list_chunks(WebApiAuth, {"doc_id": document_id}) + chunk_ids = [chunk["chunk_id"] for chunk in res["data"]["chunks"]] + delete_chunks(WebApiAuth, {"doc_id": document_id, "chunk_ids": chunk_ids}) + + request.addfinalizer(cleanup) + + kb_id, document_id = add_document + parse_documents(WebApiAuth, {"doc_ids": [document_id], "run": "1"}) + condition(WebApiAuth, kb_id) + chunk_ids = batch_add_chunks(WebApiAuth, document_id, 4) + # issues/6487 + sleep(1) + return kb_id, document_id, chunk_ids diff --git a/test/testcases/test_web_api/test_chunk_app/test_create_chunk.py b/test/testcases/test_web_api/test_chunk_app/test_create_chunk.py new file mode 100644 index 000000000..c2731b421 --- /dev/null +++ b/test/testcases/test_web_api/test_chunk_app/test_create_chunk.py @@ -0,0 +1,223 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +from concurrent.futures import ThreadPoolExecutor, as_completed + +import pytest +from common import add_chunk, delete_document, get_chunk, list_chunks +from configs import INVALID_API_TOKEN +from libs.auth import RAGFlowWebApiAuth + + +def validate_chunk_details(auth, kb_id, doc_id, payload, res): + chunk_id = res["data"]["chunk_id"] + res = get_chunk(auth, {"chunk_id": chunk_id}) + assert res["code"] == 0, res + chunk = res["data"] + assert chunk["doc_id"] == doc_id + assert chunk["kb_id"] == kb_id + assert chunk["content_with_weight"] == payload["content_with_weight"] + if "important_kwd" in payload: + assert chunk["important_kwd"] == payload["important_kwd"] + if "question_kwd" in payload: + expected = [str(q).strip() for q in payload.get("question_kwd", [])] + assert chunk["question_kwd"] == expected + + +@pytest.mark.p1 +class TestAuthorization: + @pytest.mark.parametrize( + "invalid_auth, expected_code, expected_message", + [ + (None, 401, ""), + (RAGFlowWebApiAuth(INVALID_API_TOKEN), 401, ""), + ], + ) + def test_invalid_auth(self, invalid_auth, expected_code, expected_message): + res = add_chunk(invalid_auth) + assert res["code"] == expected_code, res + assert res["message"] == expected_message, res + + +class TestAddChunk: + @pytest.mark.p1 + @pytest.mark.parametrize( + "payload, expected_code, expected_message", + [ + ({"content_with_weight": None}, 100, """TypeError("unsupported operand type(s) for +: 'NoneType' and 'str'")"""), + ({"content_with_weight": ""}, 0, ""), + pytest.param( + {"content_with_weight": 1}, + 100, + """TypeError("unsupported operand type(s) for +: 'int' and 'str'")""", + marks=pytest.mark.skip, + ), + ({"content_with_weight": "a"}, 0, ""), + ({"content_with_weight": " "}, 0, ""), + ({"content_with_weight": "\n!?。;!?\"'"}, 0, ""), + ], + ) + def test_content(self, WebApiAuth, add_document, payload, expected_code, expected_message): + kb_id, doc_id = add_document + res = list_chunks(WebApiAuth, {"doc_id": doc_id}) + if 
res["code"] == 0: + chunks_count = res["data"]["doc"]["chunk_num"] + else: + chunks_count = 0 + res = add_chunk(WebApiAuth, {**payload, "doc_id": doc_id}) + assert res["code"] == expected_code, res + if expected_code == 0: + validate_chunk_details(WebApiAuth, kb_id, doc_id, payload, res) + res = list_chunks(WebApiAuth, {"doc_id": doc_id}) + assert res["code"] == 0, res + assert res["data"]["doc"]["chunk_num"] == chunks_count + 1, res + else: + assert res["message"] == expected_message, res + + @pytest.mark.p2 + @pytest.mark.parametrize( + "payload, expected_code, expected_message", + [ + ({"content_with_weight": "chunk test", "important_kwd": ["a", "b", "c"]}, 0, ""), + ({"content_with_weight": "chunk test", "important_kwd": [""]}, 0, ""), + ( + {"content_with_weight": "chunk test", "important_kwd": [1]}, + 100, + "TypeError('sequence item 0: expected str instance, int found')", + ), + ({"content_with_weight": "chunk test", "important_kwd": ["a", "a"]}, 0, ""), + ({"content_with_weight": "chunk test", "important_kwd": "abc"}, 102, "`important_kwd` is required to be a list"), + ({"content_with_weight": "chunk test", "important_kwd": 123}, 102, "`important_kwd` is required to be a list"), + ], + ) + def test_important_keywords(self, WebApiAuth, add_document, payload, expected_code, expected_message): + kb_id, doc_id = add_document + res = list_chunks(WebApiAuth, {"doc_id": doc_id}) + if res["code"] == 0: + chunks_count = res["data"]["doc"]["chunk_num"] + else: + chunks_count = 0 + res = add_chunk(WebApiAuth, {**payload, "doc_id": doc_id}) + assert res["code"] == expected_code, res + if expected_code == 0: + validate_chunk_details(WebApiAuth, kb_id, doc_id, payload, res) + res = list_chunks(WebApiAuth, {"doc_id": doc_id}) + assert res["code"] == 0, res + assert res["data"]["doc"]["chunk_num"] == chunks_count + 1, res + else: + assert res["message"] == expected_message, res + + @pytest.mark.p2 + @pytest.mark.parametrize( + "payload, expected_code, expected_message", + 
[ + ({"content_with_weight": "chunk test", "question_kwd": ["a", "b", "c"]}, 0, ""), + ({"content_with_weight": "chunk test", "question_kwd": [""]}, 0, ""), + ({"content_with_weight": "chunk test", "question_kwd": [1]}, 100, "TypeError('sequence item 0: expected str instance, int found')"), + ({"content_with_weight": "chunk test", "question_kwd": ["a", "a"]}, 0, ""), + ({"content_with_weight": "chunk test", "question_kwd": "abc"}, 102, "`question_kwd` is required to be a list"), + ({"content_with_weight": "chunk test", "question_kwd": 123}, 102, "`question_kwd` is required to be a list"), + ], + ) + def test_questions(self, WebApiAuth, add_document, payload, expected_code, expected_message): + kb_id, doc_id = add_document + res = list_chunks(WebApiAuth, {"doc_id": doc_id}) + if res["code"] == 0: + chunks_count = res["data"]["doc"]["chunk_num"] + else: + chunks_count = 0 + res = add_chunk(WebApiAuth, {**payload, "doc_id": doc_id}) + assert res["code"] == expected_code, res + if expected_code == 0: + validate_chunk_details(WebApiAuth, kb_id, doc_id, payload, res) + res = list_chunks(WebApiAuth, {"doc_id": doc_id}) + assert res["code"] == 0, res + assert res["data"]["doc"]["chunk_num"] == chunks_count + 1, res + else: + assert res["message"] == expected_message, res + + @pytest.mark.p3 + @pytest.mark.parametrize( + "doc_id, expected_code, expected_message", + [ + ("", 102, "Document not found!"), + ("invalid_document_id", 102, "Document not found!"), + ], + ) + def test_invalid_document_id(self, WebApiAuth, add_document, doc_id, expected_code, expected_message): + _, _ = add_document + res = add_chunk(WebApiAuth, {"doc_id": doc_id, "content_with_weight": "chunk test"}) + assert res["code"] == expected_code, res + assert res["message"] == expected_message, res + + @pytest.mark.p3 + def test_repeated_add_chunk(self, WebApiAuth, add_document): + payload = {"content_with_weight": "chunk test"} + kb_id, doc_id = add_document + res = list_chunks(WebApiAuth, {"doc_id": 
doc_id}) + if res["code"] != 0: + assert False, res + chunks_count = res["data"]["doc"]["chunk_num"] + + res = add_chunk(WebApiAuth, {**payload, "doc_id": doc_id}) + assert res["code"] == 0, res + validate_chunk_details(WebApiAuth, kb_id, doc_id, payload, res) + res = list_chunks(WebApiAuth, {"doc_id": doc_id}) + if res["code"] != 0: + assert False, res + assert res["data"]["doc"]["chunk_num"] == chunks_count + 1, res + + res = add_chunk(WebApiAuth, {**payload, "doc_id": doc_id}) + assert res["code"] == 0, res + validate_chunk_details(WebApiAuth, kb_id, doc_id, payload, res) + res = list_chunks(WebApiAuth, {"doc_id": doc_id}) + if res["code"] != 0: + assert False, res + assert res["data"]["doc"]["chunk_num"] == chunks_count + 2, res + + @pytest.mark.p2 + def test_add_chunk_to_deleted_document(self, WebApiAuth, add_document): + _, doc_id = add_document + delete_document(WebApiAuth, {"doc_id": doc_id}) + res = add_chunk(WebApiAuth, {"doc_id": doc_id, "content_with_weight": "chunk test"}) + assert res["code"] == 102, res + assert res["message"] == "Document not found!", res + + @pytest.mark.skip(reason="issues/6411") + @pytest.mark.p3 + def test_concurrent_add_chunk(self, WebApiAuth, add_document): + count = 50 + _, doc_id = add_document + res = list_chunks(WebApiAuth, {"doc_id": doc_id}) + if res["code"] == 0: + chunks_count = res["data"]["doc"]["chunk_num"] + else: + chunks_count = 0 + + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [ + executor.submit( + add_chunk, + WebApiAuth, + {"doc_id": doc_id, "content_with_weight": f"chunk test {i}"}, + ) + for i in range(count) + ] + responses = list(as_completed(futures)) + assert len(responses) == count, responses + assert all(future.result()["code"] == 0 for future in futures) + res = list_chunks(WebApiAuth, {"doc_id": doc_id}) + assert res["code"] == 0, res + assert res["data"]["doc"]["chunk_num"] == chunks_count + count diff --git a/test/testcases/test_web_api/test_chunk_app/test_list_chunks.py 
b/test/testcases/test_web_api/test_chunk_app/test_list_chunks.py new file mode 100644 index 000000000..dd567e01d --- /dev/null +++ b/test/testcases/test_web_api/test_chunk_app/test_list_chunks.py @@ -0,0 +1,145 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import os +from concurrent.futures import ThreadPoolExecutor, as_completed + +import pytest +from common import batch_add_chunks, list_chunks +from configs import INVALID_API_TOKEN +from libs.auth import RAGFlowWebApiAuth + + +@pytest.mark.p1 +class TestAuthorization: + @pytest.mark.parametrize( + "invalid_auth, expected_code, expected_message", + [ + (None, 401, ""), + (RAGFlowWebApiAuth(INVALID_API_TOKEN), 401, ""), + ], + ) + def test_invalid_auth(self, invalid_auth, expected_code, expected_message): + res = list_chunks(invalid_auth, {"doc_id": "document_id"}) + assert res["code"] == expected_code, res + assert res["message"] == expected_message, res + + +class TestChunksList: + @pytest.mark.p1 + @pytest.mark.parametrize( + "params, expected_code, expected_page_size, expected_message", + [ + pytest.param({"page": None, "size": 2}, 100, 0, """TypeError("int() argument must be a string, a bytes-like object or a real number, not 'NoneType'")""", marks=pytest.mark.skip), + pytest.param({"page": 0, "size": 2}, 100, 0, "ValueError('Search does not support negative slicing.')", marks=pytest.mark.skip), + ({"page": 2, "size": 2}, 0, 2, ""), + 
({"page": 3, "size": 2}, 0, 1, ""), + ({"page": "3", "size": 2}, 0, 1, ""), + pytest.param({"page": -1, "size": 2}, 100, 0, "ValueError('Search does not support negative slicing.')", marks=pytest.mark.skip), + pytest.param({"page": "a", "size": 2}, 100, 0, """ValueError("invalid literal for int() with base 10: \'a\'")""", marks=pytest.mark.skip), + ], + ) + def test_page(self, WebApiAuth, add_chunks, params, expected_code, expected_page_size, expected_message): + _, doc_id, _ = add_chunks + payload = {"doc_id": doc_id} + if params: + payload.update(params) + res = list_chunks(WebApiAuth, payload) + assert res["code"] == expected_code, res + if expected_code == 0: + assert len(res["data"]["chunks"]) == expected_page_size, res + else: + assert res["message"] == expected_message, res + + @pytest.mark.p1 + @pytest.mark.parametrize( + "params, expected_code, expected_page_size, expected_message", + [ + ({"size": None}, 100, 0, """TypeError("int() argument must be a string, a bytes-like object or a real number, not 'NoneType'")"""), + pytest.param({"size": 0}, 0, 5, "", marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="Infinity does not support page_size=0")), + pytest.param({"size": 0}, 100, 0, "3013", marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in [None, "opensearch", "elasticsearch"], reason="Infinity does not support page_size=0")), + ({"size": 1}, 0, 1, ""), + ({"size": 6}, 0, 5, ""), + ({"size": "1"}, 0, 1, ""), + pytest.param({"size": -1}, 0, 5, "", marks=pytest.mark.skip), + pytest.param({"size": "a"}, 100, 0, """ValueError("invalid literal for int() with base 10: \'a\'")""", marks=pytest.mark.skip), + ], + ) + def test_page_size(self, WebApiAuth, add_chunks, params, expected_code, expected_page_size, expected_message): + _, doc_id, _ = add_chunks + payload = {"doc_id": doc_id} + if params: + payload.update(params) + res = list_chunks(WebApiAuth, payload) + assert res["code"] == expected_code, res + if expected_code == 0: + assert 
len(res["data"]["chunks"]) == expected_page_size, res + else: + assert res["message"] == expected_message, res + + @pytest.mark.p2 + @pytest.mark.parametrize( + "params, expected_page_size", + [ + ({"keywords": None}, 5), + ({"keywords": ""}, 5), + ({"keywords": "1"}, 1), + pytest.param({"keywords": "chunk"}, 4, marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="issues/6509")), + ({"keywords": "content"}, 1), + ({"keywords": "unknown"}, 0), + ], + ) + def test_keywords(self, WebApiAuth, add_chunks, params, expected_page_size): + _, doc_id, _ = add_chunks + payload = {"doc_id": doc_id} + if params: + payload.update(params) + res = list_chunks(WebApiAuth, payload) + assert res["code"] == 0, res + assert len(res["data"]["chunks"]) == expected_page_size, res + + @pytest.mark.p3 + def test_invalid_params(self, WebApiAuth, add_chunks): + _, doc_id, _ = add_chunks + payload = {"doc_id": doc_id, "a": "b"} + res = list_chunks(WebApiAuth, payload) + assert res["code"] == 0, res + assert len(res["data"]["chunks"]) == 5, res + + @pytest.mark.p3 + def test_concurrent_list(self, WebApiAuth, add_chunks): + _, doc_id, _ = add_chunks + count = 100 + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(list_chunks, WebApiAuth, {"doc_id": doc_id}) for i in range(count)] + responses = list(as_completed(futures)) + assert len(responses) == count, responses + assert all(len(future.result()["data"]["chunks"]) == 5 for future in futures) + + @pytest.mark.p1 + def test_default(self, WebApiAuth, add_document): + _, doc_id = add_document + + res = list_chunks(WebApiAuth, {"doc_id": doc_id}) + chunks_count = res["data"]["doc"]["chunk_num"] + batch_add_chunks(WebApiAuth, doc_id, 31) + # issues/6487 + from time import sleep + + sleep(3) + res = list_chunks(WebApiAuth, {"doc_id": doc_id}) + assert res["code"] == 0 + assert len(res["data"]["chunks"]) == 30 + assert res["data"]["doc"]["chunk_num"] == chunks_count + 31 diff --git 
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
from concurrent.futures import ThreadPoolExecutor, as_completed

import pytest
from common import retrieval_chunks
from configs import INVALID_API_TOKEN
from libs.auth import RAGFlowWebApiAuth


def _assert_page_result(res, expected_code, expected_page_size, expected_message):
    """Shared assertion for retrieval responses.

    On success (code 0) the number of returned chunks is checked; on failure
    the exact error message is checked instead (a failing response carries no
    usable "data" payload).
    """
    assert res["code"] == expected_code, res
    if expected_code == 0:
        assert len(res["data"]["chunks"]) == expected_page_size, res
    else:
        assert res["message"] == expected_message, res


@pytest.mark.p1
class TestAuthorization:
    @pytest.mark.parametrize(
        "invalid_auth, expected_code, expected_message",
        [
            (None, 401, ""),
            (RAGFlowWebApiAuth(INVALID_API_TOKEN), 401, ""),
        ],
    )
    def test_invalid_auth(self, invalid_auth, expected_code, expected_message):
        # Retrieval must reject missing or bogus credentials before touching data.
        res = retrieval_chunks(invalid_auth, {"kb_id": "dummy_kb_id", "question": "dummy question"})
        assert res["code"] == expected_code, res
        assert res["message"] == expected_message, res


class TestChunksRetrieval:
    @pytest.mark.p1
    @pytest.mark.parametrize(
        "payload, expected_code, expected_page_size, expected_message",
        [
            ({"question": "chunk", "kb_id": None}, 0, 4, ""),
            ({"question": "chunk", "doc_ids": None}, 101, 0, "required argument are missing: kb_id; "),
            ({"question": "chunk", "kb_id": None, "doc_ids": None}, 0, 4, ""),
            ({"question": "chunk"}, 101, 0, "required argument are missing: kb_id; "),
        ],
    )
    def test_basic_scenarios(self, WebApiAuth, add_chunks, payload, expected_code, expected_page_size, expected_message):
        dataset_id, document_id, _ = add_chunks
        # The parametrized payloads use None as a placeholder; real ids only
        # exist once the fixture has run, so they are filled in here.
        if "kb_id" in payload:
            payload["kb_id"] = [dataset_id]
        if "doc_ids" in payload:
            payload["doc_ids"] = [document_id]
        res = retrieval_chunks(WebApiAuth, payload)
        _assert_page_result(res, expected_code, expected_page_size, expected_message)

    @pytest.mark.p2
    @pytest.mark.parametrize(
        "payload, expected_code, expected_page_size, expected_message",
        [
            pytest.param(
                {"page": None, "size": 2},
                100,
                0,
                """TypeError("int() argument must be a string, a bytes-like object or a real number, not 'NoneType'")""",
                marks=pytest.mark.skip,
            ),
            pytest.param(
                {"page": 0, "size": 2},
                100,
                0,
                "ValueError('Search does not support negative slicing.')",
                marks=pytest.mark.skip,
            ),
            pytest.param({"page": 2, "size": 2}, 0, 2, "", marks=pytest.mark.skip(reason="issues/6646")),
            ({"page": 3, "size": 2}, 0, 0, ""),
            ({"page": "3", "size": 2}, 0, 0, ""),
            pytest.param(
                {"page": -1, "size": 2},
                100,
                0,
                "ValueError('Search does not support negative slicing.')",
                marks=pytest.mark.skip,
            ),
            pytest.param(
                {"page": "a", "size": 2},
                100,
                0,
                """ValueError("invalid literal for int() with base 10: 'a'")""",
                marks=pytest.mark.skip,
            ),
        ],
    )
    def test_page(self, WebApiAuth, add_chunks, payload, expected_code, expected_page_size, expected_message):
        dataset_id, _, _ = add_chunks
        payload.update({"question": "chunk", "kb_id": [dataset_id]})
        res = retrieval_chunks(WebApiAuth, payload)
        _assert_page_result(res, expected_code, expected_page_size, expected_message)

    @pytest.mark.p3
    @pytest.mark.parametrize(
        "payload, expected_code, expected_page_size, expected_message",
        [
            pytest.param(
                {"size": None},
                100,
                0,
                """TypeError("int() argument must be a string, a bytes-like object or a real number, not 'NoneType'")""",
                marks=pytest.mark.skip,
            ),
            # ({"size": 0}, 0, 0, ""),
            ({"size": 1}, 0, 1, ""),
            ({"size": 5}, 0, 4, ""),
            ({"size": "1"}, 0, 1, ""),
            # ({"size": -1}, 0, 0, ""),
            pytest.param(
                {"size": "a"},
                100,
                0,
                """ValueError("invalid literal for int() with base 10: 'a'")""",
                marks=pytest.mark.skip,
            ),
        ],
    )
    def test_page_size(self, WebApiAuth, add_chunks, payload, expected_code, expected_page_size, expected_message):
        dataset_id, _, _ = add_chunks
        payload.update({"question": "chunk", "kb_id": [dataset_id]})
        res = retrieval_chunks(WebApiAuth, payload)
        _assert_page_result(res, expected_code, expected_page_size, expected_message)

    @pytest.mark.p3
    @pytest.mark.parametrize(
        "payload, expected_code, expected_page_size, expected_message",
        [
            ({"vector_similarity_weight": 0}, 0, 4, ""),
            ({"vector_similarity_weight": 0.5}, 0, 4, ""),
            ({"vector_similarity_weight": 10}, 0, 4, ""),
            pytest.param(
                {"vector_similarity_weight": "a"},
                100,
                0,
                """ValueError("could not convert string to float: 'a'")""",
                marks=pytest.mark.skip,
            ),
        ],
    )
    def test_vector_similarity_weight(self, WebApiAuth, add_chunks, payload, expected_code, expected_page_size, expected_message):
        dataset_id, _, _ = add_chunks
        payload.update({"question": "chunk", "kb_id": [dataset_id]})
        res = retrieval_chunks(WebApiAuth, payload)
        _assert_page_result(res, expected_code, expected_page_size, expected_message)

    @pytest.mark.p2
    @pytest.mark.parametrize(
        "payload, expected_code, expected_page_size, expected_message",
        [
            ({"top_k": 10}, 0, 4, ""),
            pytest.param(
                {"top_k": 1},
                0,
                4,
                "",
                marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in ["infinity", "opensearch"], reason="Infinity"),
            ),
            pytest.param(
                {"top_k": 1},
                0,
                1,
                "",
                marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in [None, "opensearch", "elasticsearch"], reason="elasticsearch"),
            ),
            pytest.param(
                {"top_k": -1},
                100,
                4,
                "must be greater than 0",
                marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in ["infinity", "opensearch"], reason="Infinity"),
            ),
            pytest.param(
                {"top_k": -1},
                100,
                4,
                "3014",
                marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in [None, "opensearch", "elasticsearch"], reason="elasticsearch"),
            ),
            pytest.param(
                {"top_k": "a"},
                100,
                0,
                """ValueError("invalid literal for int() with base 10: 'a'")""",
                marks=pytest.mark.skip,
            ),
        ],
    )
    def test_top_k(self, WebApiAuth, add_chunks, payload, expected_code, expected_page_size, expected_message):
        dataset_id, _, _ = add_chunks
        payload.update({"question": "chunk", "kb_id": [dataset_id]})
        res = retrieval_chunks(WebApiAuth, payload)
        assert res["code"] == expected_code, res
        if expected_code == 0:
            assert len(res["data"]["chunks"]) == expected_page_size, res
        else:
            # Error text differs per document engine, so only containment is checked.
            assert expected_message in res["message"], res

    @pytest.mark.skip
    @pytest.mark.parametrize(
        "payload, expected_code, expected_message",
        [
            ({"rerank_id": "BAAI/bge-reranker-v2-m3"}, 0, ""),
            pytest.param({"rerank_id": "unknown"}, 100, "LookupError('Model(unknown) not authorized')", marks=pytest.mark.skip),
        ],
    )
    def test_rerank_id(self, WebApiAuth, add_chunks, payload, expected_code, expected_message):
        dataset_id, _, _ = add_chunks
        payload.update({"question": "chunk", "kb_id": [dataset_id]})
        res = retrieval_chunks(WebApiAuth, payload)
        assert res["code"] == expected_code, res
        if expected_code == 0:
            # Reranking may reorder results; only non-emptiness is meaningful here.
            assert len(res["data"]["chunks"]) > 0, res
        else:
            assert expected_message in res["message"], res

    @pytest.mark.skip
    @pytest.mark.parametrize(
        "payload, expected_code, expected_page_size, expected_message",
        [
            ({"keyword": True}, 0, 5, ""),
            ({"keyword": "True"}, 0, 5, ""),
            ({"keyword": False}, 0, 5, ""),
            ({"keyword": "False"}, 0, 5, ""),
            ({"keyword": None}, 0, 5, ""),
        ],
    )
    def test_keyword(self, WebApiAuth, add_chunks, payload, expected_code, expected_page_size, expected_message):
        dataset_id, _, _ = add_chunks
        payload.update({"question": "chunk test", "kb_id": [dataset_id]})
        res = retrieval_chunks(WebApiAuth, payload)
        _assert_page_result(res, expected_code, expected_page_size, expected_message)

    @pytest.mark.p3
    @pytest.mark.parametrize(
        "payload, expected_code, expected_highlight, expected_message",
        [
            ({"highlight": True}, 0, True, ""),
            ({"highlight": "True"}, 0, True, ""),
            pytest.param({"highlight": False}, 0, False, "", marks=pytest.mark.skip(reason="issues/6648")),
            pytest.param({"highlight": "False"}, 0, False, "", marks=pytest.mark.skip(reason="issues/6648")),
            pytest.param({"highlight": None}, 0, False, "", marks=pytest.mark.skip(reason="issues/6648")),
        ],
    )
    def test_highlight(self, WebApiAuth, add_chunks, payload, expected_code, expected_highlight, expected_message):
        dataset_id, _, _ = add_chunks
        payload.update({"question": "chunk", "kb_id": [dataset_id]})
        res = retrieval_chunks(WebApiAuth, payload)
        assert res["code"] == expected_code, res
        # Handle the failure branch first: a failing response has no "data"
        # payload, so inspecting chunks before checking the code would raise.
        if expected_code != 0:
            assert res["message"] == expected_message, res
            return
        for chunk in res["data"]["chunks"]:
            if expected_highlight:
                assert "highlight" in chunk, res
            else:
                assert "highlight" not in chunk, res

    @pytest.mark.p3
    def test_invalid_params(self, WebApiAuth, add_chunks):
        # Unknown keys must be ignored rather than rejected.
        dataset_id, _, _ = add_chunks
        payload = {"question": "chunk", "kb_id": [dataset_id], "a": "b"}
        res = retrieval_chunks(WebApiAuth, payload)
        assert res["code"] == 0, res
        assert len(res["data"]["chunks"]) == 4, res

    @pytest.mark.p3
    def test_concurrent_retrieval(self, WebApiAuth, add_chunks):
        """Fire many identical retrievals in parallel; all must succeed."""
        dataset_id, _, _ = add_chunks
        count = 100
        payload = {"question": "chunk", "kb_id": [dataset_id]}

        with ThreadPoolExecutor(max_workers=5) as executor:
            futures = [executor.submit(retrieval_chunks, WebApiAuth, payload) for _ in range(count)]
            responses = [future.result() for future in as_completed(futures)]
        assert len(responses) == count, responses
        assert all(res["code"] == 0 for res in responses)
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from concurrent.futures import ThreadPoolExecutor, as_completed
from time import sleep

import pytest
from common import batch_add_chunks, delete_chunks, list_chunks
from configs import INVALID_API_TOKEN
from libs.auth import RAGFlowWebApiAuth


def _assert_remaining_chunks(WebApiAuth, doc_id, expected):
    """List the document's chunks and assert both the page and total match *expected*."""
    res = list_chunks(WebApiAuth, {"doc_id": doc_id})
    assert res["code"] == 0, res
    assert len(res["data"]["chunks"]) == expected, res
    assert res["data"]["total"] == expected, res


@pytest.mark.p1
class TestAuthorization:
    @pytest.mark.parametrize(
        "invalid_auth, expected_code, expected_message",
        [
            (None, 401, ""),
            (RAGFlowWebApiAuth(INVALID_API_TOKEN), 401, ""),
        ],
    )
    def test_invalid_auth(self, invalid_auth, expected_code, expected_message):
        # Deletion must reject missing or bogus credentials before touching data.
        res = delete_chunks(invalid_auth, {"doc_id": "document_id", "chunk_ids": ["1"]})
        assert res["code"] == expected_code, res
        assert res["message"] == expected_message, res


class TestChunksDeletion:
    @pytest.mark.p3
    @pytest.mark.parametrize(
        "doc_id, expected_code, expected_message",
        [
            ("", 102, "Document not found!"),
            ("invalid_document_id", 102, "Document not found!"),
        ],
    )
    def test_invalid_document_id(self, WebApiAuth, add_chunks_func, doc_id, expected_code, expected_message):
        _, _, chunk_ids = add_chunks_func
        res = delete_chunks(WebApiAuth, {"doc_id": doc_id, "chunk_ids": chunk_ids})
        assert res["code"] == expected_code, res
        assert res["message"] == expected_message, res

    @pytest.mark.parametrize(
        "payload",
        [
            # Each lambda receives the real chunk ids and mixes in an invalid one
            # at a different position; valid ids must still be deleted.
            pytest.param(lambda r: {"chunk_ids": ["invalid_id"] + r}, marks=pytest.mark.p3),
            pytest.param(lambda r: {"chunk_ids": r[:1] + ["invalid_id"] + r[1:4]}, marks=pytest.mark.p1),
            pytest.param(lambda r: {"chunk_ids": r + ["invalid_id"]}, marks=pytest.mark.p3),
        ],
    )
    def test_delete_partial_invalid_id(self, WebApiAuth, add_chunks_func, payload):
        _, doc_id, chunk_ids = add_chunks_func
        if callable(payload):
            payload = payload(chunk_ids)
        payload["doc_id"] = doc_id
        res = delete_chunks(WebApiAuth, payload)
        assert res["code"] == 0, res

        _assert_remaining_chunks(WebApiAuth, doc_id, 0)

    @pytest.mark.p3
    def test_repeated_deletion(self, WebApiAuth, add_chunks_func):
        _, doc_id, chunk_ids = add_chunks_func
        payload = {"chunk_ids": chunk_ids, "doc_id": doc_id}
        res = delete_chunks(WebApiAuth, payload)
        assert res["code"] == 0, res

        # Deleting ids that are already gone is reported as an index failure.
        res = delete_chunks(WebApiAuth, payload)
        assert res["code"] == 102, res
        assert res["message"] == "Index updating failure", res

    @pytest.mark.p3
    def test_duplicate_deletion(self, WebApiAuth, add_chunks_func):
        # Each id appears twice in one request; the call must still succeed.
        _, doc_id, chunk_ids = add_chunks_func
        payload = {"chunk_ids": chunk_ids * 2, "doc_id": doc_id}
        res = delete_chunks(WebApiAuth, payload)
        assert res["code"] == 0, res

        _assert_remaining_chunks(WebApiAuth, doc_id, 0)

    @pytest.mark.p3
    def test_concurrent_deletion(self, WebApiAuth, add_document):
        """Delete 100 chunks one-per-request in parallel; every request must succeed."""
        count = 100
        _, doc_id = add_document
        chunk_ids = batch_add_chunks(WebApiAuth, doc_id, count)

        with ThreadPoolExecutor(max_workers=5) as executor:
            futures = [
                executor.submit(
                    delete_chunks,
                    WebApiAuth,
                    {"doc_id": doc_id, "chunk_ids": chunk_ids[i : i + 1]},
                )
                for i in range(count)
            ]
            responses = [future.result() for future in as_completed(futures)]
        assert len(responses) == count, responses
        assert all(res["code"] == 0 for res in responses)

    @pytest.mark.p3
    def test_delete_1k(self, WebApiAuth, add_document):
        chunks_num = 1_000
        _, doc_id = add_document
        chunk_ids = batch_add_chunks(WebApiAuth, doc_id, chunks_num)

        # Indexing is asynchronous; give the engine a moment before deleting.
        sleep(1)

        res = delete_chunks(WebApiAuth, {"doc_id": doc_id, "chunk_ids": chunk_ids})
        assert res["code"] == 0, res

        _assert_remaining_chunks(WebApiAuth, doc_id, 0)

    @pytest.mark.parametrize(
        "payload, expected_code, expected_message, remaining",
        [
            pytest.param(None, 100, """TypeError("argument of type 'NoneType' is not iterable")""", 5, marks=pytest.mark.skip),
            pytest.param({"chunk_ids": ["invalid_id"]}, 102, "Index updating failure", 4, marks=pytest.mark.p3),
            pytest.param("not json", 100, """UnboundLocalError("local variable 'duplicate_messages' referenced before assignment")""", 5, marks=pytest.mark.skip(reason="pull/6376")),
            pytest.param(lambda r: {"chunk_ids": r[:1]}, 0, "", 3, marks=pytest.mark.p3),
            pytest.param(lambda r: {"chunk_ids": r}, 0, "", 0, marks=pytest.mark.p1),
            pytest.param({"chunk_ids": []}, 0, "", 0, marks=pytest.mark.p3),
        ],
    )
    def test_basic_scenarios(self, WebApiAuth, add_chunks_func, payload, expected_code, expected_message, remaining):
        _, doc_id, chunk_ids = add_chunks_func
        if callable(payload):
            payload = payload(chunk_ids)
        payload["doc_id"] = doc_id
        res = delete_chunks(WebApiAuth, payload)
        assert res["code"] == expected_code, res
        if res["code"] != 0:
            assert res["message"] == expected_message, res

        _assert_remaining_chunks(WebApiAuth, doc_id, remaining)
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from random import randint
from time import sleep

import pytest
from common import delete_document, list_chunks, update_chunk
from configs import INVALID_API_TOKEN
from libs.auth import RAGFlowWebApiAuth


def _assert_chunk_field(WebApiAuth, doc_id, chunk_id, field, expected):
    """Read the document's chunks back and assert *field* of *chunk_id* equals *expected*.

    Indexing is asynchronous, so wait briefly before listing.
    NOTE(review): if the chunk id is absent from the listing this passes
    silently (original behavior, preserved) — consider requiring presence.
    """
    sleep(1)
    res = list_chunks(WebApiAuth, {"doc_id": doc_id})
    for chunk in res["data"]["chunks"]:
        if chunk["chunk_id"] == chunk_id:
            assert chunk[field] == expected, res


@pytest.mark.p1
class TestAuthorization:
    @pytest.mark.parametrize(
        "invalid_auth, expected_code, expected_message",
        [
            (None, 401, ""),
            (RAGFlowWebApiAuth(INVALID_API_TOKEN), 401, ""),
        ],
    )
    def test_invalid_auth(self, invalid_auth, expected_code, expected_message):
        # Updates must reject missing or bogus credentials before touching data.
        res = update_chunk(invalid_auth, {"doc_id": "doc_id", "chunk_id": "chunk_id", "content_with_weight": "test"})
        assert res["code"] == expected_code, res
        assert res["message"] == expected_message, res


class TestUpdateChunk:
    @pytest.mark.p1
    @pytest.mark.parametrize(
        "payload, expected_code, expected_message",
        [
            ({"content_with_weight": None}, 100, "TypeError('expected string or bytes-like object')"),
            ({"content_with_weight": ""}, 0, ""),
            ({"content_with_weight": 1}, 100, "TypeError('expected string or bytes-like object')"),
            ({"content_with_weight": "update chunk"}, 0, ""),
            ({"content_with_weight": " "}, 0, ""),
            ({"content_with_weight": "\n!?。;!?\"'"}, 0, ""),
        ],
    )
    def test_content(self, WebApiAuth, add_chunks, payload, expected_code, expected_message):
        _, doc_id, chunk_ids = add_chunks
        chunk_id = chunk_ids[0]
        update_payload = {"doc_id": doc_id, "chunk_id": chunk_id}
        if payload:
            update_payload.update(payload)
        res = update_chunk(WebApiAuth, update_payload)
        assert res["code"] == expected_code, res
        if expected_code != 0:
            assert res["message"] == expected_message, res
        else:
            _assert_chunk_field(WebApiAuth, doc_id, chunk_id, "content_with_weight", payload["content_with_weight"])

    @pytest.mark.p2
    @pytest.mark.parametrize(
        "payload, expected_code, expected_message",
        [
            ({"important_kwd": ["a", "b", "c"]}, 0, ""),
            ({"important_kwd": [""]}, 0, ""),
            ({"important_kwd": [1]}, 100, "TypeError('sequence item 0: expected str instance, int found')"),
            ({"important_kwd": ["a", "a"]}, 0, ""),
            ({"important_kwd": "abc"}, 102, "`important_kwd` should be a list"),
            ({"important_kwd": 123}, 102, "`important_kwd` should be a list"),
        ],
    )
    def test_important_keywords(self, WebApiAuth, add_chunks, payload, expected_code, expected_message):
        _, doc_id, chunk_ids = add_chunks
        chunk_id = chunk_ids[0]
        # content_with_weight is a required field of the update endpoint.
        update_payload = {"doc_id": doc_id, "chunk_id": chunk_id, "content_with_weight": "unchanged content"}
        if payload:
            update_payload.update(payload)
        res = update_chunk(WebApiAuth, update_payload)
        assert res["code"] == expected_code, res
        if expected_code != 0:
            assert res["message"] == expected_message, res
        else:
            _assert_chunk_field(WebApiAuth, doc_id, chunk_id, "important_kwd", payload["important_kwd"])

    @pytest.mark.p2
    @pytest.mark.parametrize(
        "payload, expected_code, expected_message",
        [
            ({"question_kwd": ["a", "b", "c"]}, 0, ""),
            ({"question_kwd": [""]}, 0, ""),
            ({"question_kwd": [1]}, 100, "TypeError('sequence item 0: expected str instance, int found')"),
            ({"question_kwd": ["a", "a"]}, 0, ""),
            ({"question_kwd": "abc"}, 102, "`question_kwd` should be a list"),
            ({"question_kwd": 123}, 102, "`question_kwd` should be a list"),
        ],
    )
    def test_questions(self, WebApiAuth, add_chunks, payload, expected_code, expected_message):
        _, doc_id, chunk_ids = add_chunks
        chunk_id = chunk_ids[0]
        # content_with_weight is a required field of the update endpoint.
        update_payload = {"doc_id": doc_id, "chunk_id": chunk_id, "content_with_weight": "unchanged content"}
        if payload:
            update_payload.update(payload)

        res = update_chunk(WebApiAuth, update_payload)
        assert res["code"] == expected_code, res
        if expected_code != 0:
            assert res["message"] == expected_message, res
        else:
            _assert_chunk_field(WebApiAuth, doc_id, chunk_id, "question_kwd", payload["question_kwd"])

    @pytest.mark.p2
    @pytest.mark.parametrize(
        "payload, expected_code, expected_message",
        [
            ({"available_int": 1}, 0, ""),
            ({"available_int": 0}, 0, ""),
        ],
    )
    def test_available(self, WebApiAuth, add_chunks, payload, expected_code, expected_message):
        _, doc_id, chunk_ids = add_chunks
        chunk_id = chunk_ids[0]
        update_payload = {"doc_id": doc_id, "chunk_id": chunk_id, "content_with_weight": "unchanged content"}
        if payload:
            update_payload.update(payload)

        res = update_chunk(WebApiAuth, update_payload)
        assert res["code"] == expected_code, res
        if expected_code != 0:
            assert res["message"] == expected_message, res
        else:
            _assert_chunk_field(WebApiAuth, doc_id, chunk_id, "available_int", payload["available_int"])

    @pytest.mark.p3
    @pytest.mark.parametrize(
        "doc_id_param, expected_code, expected_message",
        [
            ("", 102, "Tenant not found!"),
            ("invalid_doc_id", 102, "Tenant not found!"),
        ],
    )
    def test_invalid_document_id_for_update(self, WebApiAuth, add_chunks, doc_id_param, expected_code, expected_message):
        _, _, chunk_ids = add_chunks
        chunk_id = chunk_ids[0]

        payload = {"doc_id": doc_id_param, "chunk_id": chunk_id, "content_with_weight": "test content"}
        res = update_chunk(WebApiAuth, payload)
        assert res["code"] == expected_code, res
        assert expected_message in res["message"], res

    @pytest.mark.p3
    def test_repeated_update_chunk(self, WebApiAuth, add_chunks):
        # Updating the same chunk twice in a row must succeed both times.
        _, doc_id, chunk_ids = add_chunks
        payload1 = {"doc_id": doc_id, "chunk_id": chunk_ids[0], "content_with_weight": "chunk test 1"}
        res = update_chunk(WebApiAuth, payload1)
        assert res["code"] == 0, res

        payload2 = {"doc_id": doc_id, "chunk_id": chunk_ids[0], "content_with_weight": "chunk test 2"}
        res = update_chunk(WebApiAuth, payload2)
        assert res["code"] == 0, res

    @pytest.mark.p3
    @pytest.mark.parametrize(
        "payload, expected_code, expected_message",
        [
            ({"unknown_key": "unknown_value"}, 0, ""),
            ({}, 0, ""),
            pytest.param(None, 100, """TypeError("int() argument must be a string, a bytes-like object or a real number, not 'NoneType'")""", marks=pytest.mark.skip),
        ],
    )
    def test_invalid_params(self, WebApiAuth, add_chunks, payload, expected_code, expected_message):
        # Unknown keys must be ignored rather than rejected.
        _, doc_id, chunk_ids = add_chunks
        chunk_id = chunk_ids[0]
        update_payload = {"doc_id": doc_id, "chunk_id": chunk_id, "content_with_weight": "unchanged content"}
        if payload is not None:
            update_payload.update(payload)

        res = update_chunk(WebApiAuth, update_payload)
        assert res["code"] == expected_code, res
        if expected_code != 0:
            assert res["message"] == expected_message, res

    @pytest.mark.p3
    @pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="issues/6554")
    def test_concurrent_update_chunk(self, WebApiAuth, add_chunks):
        """Hammer a small set of chunks with parallel updates; all must succeed."""
        count = 50
        _, doc_id, chunk_ids = add_chunks

        with ThreadPoolExecutor(max_workers=5) as executor:
            futures = [
                executor.submit(
                    update_chunk,
                    WebApiAuth,
                    {"doc_id": doc_id, "chunk_id": chunk_ids[randint(0, 3)], "content_with_weight": f"update chunk test {i}"},
                )
                for i in range(count)
            ]
            responses = [future.result() for future in as_completed(futures)]
        assert len(responses) == count, responses
        assert all(res["code"] == 0 for res in responses)

    @pytest.mark.p3
    def test_update_chunk_to_deleted_document(self, WebApiAuth, add_chunks):
        _, doc_id, chunk_ids = add_chunks
        delete_document(WebApiAuth, {"doc_id": doc_id})
        payload = {"doc_id": doc_id, "chunk_id": chunk_ids[0], "content_with_weight": "test content"}
        res = update_chunk(WebApiAuth, payload)
        assert res["code"] == 102, res
        assert res["message"] == "Tenant not found!", res