From 5825a24d26c86316f8c87b71a83bf42f1fd5cfbe Mon Sep 17 00:00:00 2001 From: Liu An Date: Fri, 6 Jun 2025 19:43:14 +0800 Subject: [PATCH] Test: Refactor test concurrency handling and add SDK chunk management tests (#8112) ### What problem does this PR solve? - Improve concurrent test cases by using as_completed for better reliability - Rename variables for clarity (chunk_num -> count) - Add new SDK API test suite for chunk management operations - Update HTTP API tests with consistent concurrency patterns ### Type of change - [x] Add test cases - [x] Refactoring --- .../test_add_chunk.py | 13 +- .../test_delete_chunks.py | 13 +- .../test_list_chunks.py | 12 +- .../test_retrieval_chunks.py | 11 +- .../test_update_chunk.py | 11 +- .../test_create_dataset.py | 2 +- .../test_delete_datasets.py | 2 +- .../test_list_datasets.py | 2 +- .../test_update_dataset.py | 2 +- .../test_delete_documents.py | 3 +- .../test_list_documents.py | 2 +- .../test_parse_documents.py | 2 +- .../test_upload_documents.py | 2 +- test/testcases/test_sdk_api/common.py | 11 +- test/testcases/test_sdk_api/conftest.py | 43 ++- .../conftest.py | 49 ++++ .../test_add_chunk.py | 160 +++++++++++ .../test_delete_chunks.py | 113 ++++++++ .../test_list_chunks.py | 140 ++++++++++ .../test_retrieval_chunks.py | 254 ++++++++++++++++++ .../test_update_chunk.py | 154 +++++++++++ 21 files changed, 946 insertions(+), 55 deletions(-) create mode 100644 test/testcases/test_sdk_api/test_chunk_management_within_dataset/conftest.py create mode 100644 test/testcases/test_sdk_api/test_chunk_management_within_dataset/test_add_chunk.py create mode 100644 test/testcases/test_sdk_api/test_chunk_management_within_dataset/test_delete_chunks.py create mode 100644 test/testcases/test_sdk_api/test_chunk_management_within_dataset/test_list_chunks.py create mode 100644 test/testcases/test_sdk_api/test_chunk_management_within_dataset/test_retrieval_chunks.py create mode 100644 test/testcases/test_sdk_api/test_chunk_management_within_dataset/test_update_chunk.py diff --git a/test/testcases/test_http_api/test_chunk_management_within_dataset/test_add_chunk.py b/test/testcases/test_http_api/test_chunk_management_within_dataset/test_add_chunk.py index 917dc8357..ab1bfac0b 100644 --- a/test/testcases/test_http_api/test_chunk_management_within_dataset/test_add_chunk.py +++ b/test/testcases/test_http_api/test_chunk_management_within_dataset/test_add_chunk.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from concurrent.futures import ThreadPoolExecutor +from concurrent.futures import ThreadPoolExecutor, as_completed import pytest from common import INVALID_API_TOKEN, add_chunk, delete_documents, list_chunks @@ -224,7 +224,7 @@ class TestAddChunk: @pytest.mark.skip(reason="issues/6411") def test_concurrent_add_chunk(self, api_key, add_document): - chunk_num = 50 + count = 50 dataset_id, document_id = add_document res = list_chunks(api_key, dataset_id, document_id) if res["code"] != 0: @@ -240,11 +240,12 @@ class TestAddChunk: document_id, {"content": f"chunk test {i}"}, ) - for i in range(chunk_num) + for i in range(count) ] - responses = [f.result() for f in futures] - assert all(r["code"] == 0 for r in responses) + responses = list(as_completed(futures)) + assert len(responses) == count, responses + assert all(future.result()["code"] == 0 for future in futures) res = list_chunks(api_key, dataset_id, document_id) if res["code"] != 0: assert False, res - assert res["data"]["doc"]["chunk_count"] == chunks_count + chunk_num + assert res["data"]["doc"]["chunk_count"] == chunks_count + count diff --git a/test/testcases/test_http_api/test_chunk_management_within_dataset/test_delete_chunks.py b/test/testcases/test_http_api/test_chunk_management_within_dataset/test_delete_chunks.py index 4475c9660..4813f9f8b 100644 --- a/test/testcases/test_http_api/test_chunk_management_within_dataset/test_delete_chunks.py +++ b/test/testcases/test_http_api/test_chunk_management_within_dataset/test_delete_chunks.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from concurrent.futures import ThreadPoolExecutor +from concurrent.futures import ThreadPoolExecutor, as_completed import pytest from common import INVALID_API_TOKEN, batch_add_chunks, delete_chunks, list_chunks @@ -121,9 +121,9 @@ class TestChunksDeletion: @pytest.mark.p3 def test_concurrent_deletion(self, api_key, add_document): - chunks_num = 100 + count = 100 dataset_id, document_id = add_document - chunk_ids = batch_add_chunks(api_key, dataset_id, document_id, chunks_num) + chunk_ids = batch_add_chunks(api_key, dataset_id, document_id, count) with ThreadPoolExecutor(max_workers=5) as executor: futures = [ @@ -134,10 +134,11 @@ class TestChunksDeletion: document_id, {"chunk_ids": chunk_ids[i : i + 1]}, ) - for i in range(chunks_num) + for i in range(count) ] - responses = [f.result() for f in futures] - assert all(r["code"] == 0 for r in responses) + responses = list(as_completed(futures)) + assert len(responses) == count, responses + assert all(future.result()["code"] == 0 for future in futures) @pytest.mark.p3 def test_delete_1k(self, api_key, add_document): diff --git a/test/testcases/test_http_api/test_chunk_management_within_dataset/test_list_chunks.py b/test/testcases/test_http_api/test_chunk_management_within_dataset/test_list_chunks.py index 45fe5377a..6861ebaf3 100644 --- a/test/testcases/test_http_api/test_chunk_management_within_dataset/test_list_chunks.py +++ b/test/testcases/test_http_api/test_chunk_management_within_dataset/test_list_chunks.py @@ -14,7 +14,7 @@ # limitations under the License. # import os -from concurrent.futures import ThreadPoolExecutor +from concurrent.futures import ThreadPoolExecutor, as_completed import pytest from common import INVALID_API_TOKEN, batch_add_chunks, list_chunks @@ -149,12 +149,12 @@ class TestChunksList: @pytest.mark.p3 def test_concurrent_list(self, api_key, add_chunks): dataset_id, document_id, _ = add_chunks - + count = 100 with ThreadPoolExecutor(max_workers=5) as executor: - futures = [executor.submit(list_chunks, api_key, dataset_id, document_id) for i in range(100)] - responses = [f.result() for f in futures] - assert all(r["code"] == 0 for r in responses) - assert all(len(r["data"]["chunks"]) == 5 for r in responses) + futures = [executor.submit(list_chunks, api_key, dataset_id, document_id) for i in range(count)] + responses = list(as_completed(futures)) + assert len(responses) == count, responses + assert all(len(future.result()["data"]["chunks"]) == 5 for future in futures) @pytest.mark.p1 def test_default(self, api_key, add_document): diff --git a/test/testcases/test_http_api/test_chunk_management_within_dataset/test_retrieval_chunks.py b/test/testcases/test_http_api/test_chunk_management_within_dataset/test_retrieval_chunks.py index 0c54d025e..6a411aab2 100644 --- a/test/testcases/test_http_api/test_chunk_management_within_dataset/test_retrieval_chunks.py +++ b/test/testcases/test_http_api/test_chunk_management_within_dataset/test_retrieval_chunks.py @@ -14,6 +14,7 @@ # limitations under the License. # import os +from concurrent.futures import ThreadPoolExecutor, as_completed import pytest from common import ( @@ -302,12 +303,12 @@ class TestChunksRetrieval: @pytest.mark.p3 def test_concurrent_retrieval(self, api_key, add_chunks): - from concurrent.futures import ThreadPoolExecutor - dataset_id, _, _ = add_chunks + count = 100 payload = {"question": "chunk", "dataset_ids": [dataset_id]} with ThreadPoolExecutor(max_workers=5) as executor: - futures = [executor.submit(retrieval_chunks, api_key, payload) for i in range(100)] - responses = [f.result() for f in futures] - assert all(r["code"] == 0 for r in responses) + futures = [executor.submit(retrieval_chunks, api_key, payload) for i in range(count)] + responses = list(as_completed(futures)) + assert len(responses) == count, responses + assert all(future.result()["code"] == 0 for future in futures) diff --git a/test/testcases/test_http_api/test_chunk_management_within_dataset/test_update_chunk.py b/test/testcases/test_http_api/test_chunk_management_within_dataset/test_update_chunk.py index d70f0925a..dacb7bcbe 100644 --- a/test/testcases/test_http_api/test_chunk_management_within_dataset/test_update_chunk.py +++ b/test/testcases/test_http_api/test_chunk_management_within_dataset/test_update_chunk.py @@ -14,7 +14,7 @@ # limitations under the License. # import os -from concurrent.futures import ThreadPoolExecutor +from concurrent.futures import ThreadPoolExecutor, as_completed from random import randint import pytest @@ -219,7 +219,7 @@ class TestUpdatedChunk: @pytest.mark.p3 @pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="issues/6554") def test_concurrent_update_chunk(self, api_key, add_chunks): - chunk_num = 50 + count = 50 dataset_id, document_id, chunk_ids = add_chunks with ThreadPoolExecutor(max_workers=5) as executor: @@ -232,10 +232,11 @@ class TestUpdatedChunk: chunk_ids[randint(0, 3)], {"content": f"update chunk test {i}"}, ) - for i in range(chunk_num) + for i in range(count) ] - responses = [f.result() for f in futures] - assert all(r["code"] == 0 for r in responses) + responses = list(as_completed(futures)) + assert len(responses) == count, responses + assert all(future.result()["code"] == 0 for future in futures) @pytest.mark.p3 def test_update_chunk_to_deleted_document(self, api_key, add_chunks): diff --git a/test/testcases/test_http_api/test_dataset_mangement/test_create_dataset.py b/test/testcases/test_http_api/test_dataset_mangement/test_create_dataset.py index e4ac174a6..4442912ea 100644 --- a/test/testcases/test_http_api/test_dataset_mangement/test_create_dataset.py +++ b/test/testcases/test_http_api/test_dataset_mangement/test_create_dataset.py @@ -85,7 +85,7 @@ class TestCapability: futures = [executor.submit(create_dataset, api_key, {"name": f"dataset_{i}"}) for i in range(count)] responses = list(as_completed(futures)) assert len(responses) == count, responses - assert all(futures.result()["code"] == 0 for futures in futures) + assert all(future.result()["code"] == 0 for future in futures) @pytest.mark.usefixtures("clear_datasets") diff --git a/test/testcases/test_http_api/test_dataset_mangement/test_delete_datasets.py b/test/testcases/test_http_api/test_dataset_mangement/test_delete_datasets.py index d79a85ba4..e822ee480 100644 --- a/test/testcases/test_http_api/test_dataset_mangement/test_delete_datasets.py +++ b/test/testcases/test_http_api/test_dataset_mangement/test_delete_datasets.py @@ -93,7 +93,7 @@ class TestCapability: futures = [executor.submit(delete_datasets, api_key, {"ids": ids[i : i + 1]}) for i in range(count)] responses = list(as_completed(futures)) assert len(responses) == count, responses - assert all(futures.result()["code"] == 0 for futures in futures) + assert all(future.result()["code"] == 0 for future in futures) class TestDatasetsDelete: diff --git a/test/testcases/test_http_api/test_dataset_mangement/test_list_datasets.py b/test/testcases/test_http_api/test_dataset_mangement/test_list_datasets.py index 0af4c004c..a96ce7762 100644 --- a/test/testcases/test_http_api/test_dataset_mangement/test_list_datasets.py +++ b/test/testcases/test_http_api/test_dataset_mangement/test_list_datasets.py @@ -49,7 +49,7 @@ class TestCapability: futures = [executor.submit(list_datasets, api_key) for i in range(count)] responses = list(as_completed(futures)) assert len(responses) == count, responses - assert all(futures.result()["code"] == 0 for futures in futures) + assert all(future.result()["code"] == 0 for future in futures) @pytest.mark.usefixtures("add_datasets") diff --git a/test/testcases/test_http_api/test_dataset_mangement/test_update_dataset.py b/test/testcases/test_http_api/test_dataset_mangement/test_update_dataset.py index afc2cfd63..29e4355ec 100644 --- a/test/testcases/test_http_api/test_dataset_mangement/test_update_dataset.py +++ b/test/testcases/test_http_api/test_dataset_mangement/test_update_dataset.py @@ -95,7 +95,7 @@ class TestCapability: futures = [executor.submit(update_dataset, api_key, dataset_id, {"name": f"dataset_{i}"}) for i in range(count)] responses = list(as_completed(futures)) assert len(responses) == count, responses - assert all(futures.result()["code"] == 0 for futures in futures) + assert all(future.result()["code"] == 0 for future in futures) class TestDatasetUpdate: diff --git a/test/testcases/test_http_api/test_file_management_within_dataset/test_delete_documents.py b/test/testcases/test_http_api/test_file_management_within_dataset/test_delete_documents.py index b44534aa5..04c0b97d4 100644 --- a/test/testcases/test_http_api/test_file_management_within_dataset/test_delete_documents.py +++ b/test/testcases/test_http_api/test_file_management_within_dataset/test_delete_documents.py @@ -15,7 +15,6 @@ # from concurrent.futures import ThreadPoolExecutor, as_completed - import pytest from common import INVALID_API_TOKEN, bulk_upload_documents, delete_documents, list_documents from libs.auth import RAGFlowHttpApiAuth @@ -165,7 +164,7 @@ def test_concurrent_deletion(api_key, add_dataset, tmp_path): ] responses = list(as_completed(futures)) assert len(responses) == count, responses - assert all(futures.result()["code"] == 0 for futures in futures) + assert all(future.result()["code"] == 0 for future in futures) @pytest.mark.p3 diff --git a/test/testcases/test_http_api/test_file_management_within_dataset/test_list_documents.py b/test/testcases/test_http_api/test_file_management_within_dataset/test_list_documents.py index d4a6d6406..145fd839a 100644 --- a/test/testcases/test_http_api/test_file_management_within_dataset/test_list_documents.py +++ b/test/testcases/test_http_api/test_file_management_within_dataset/test_list_documents.py @@ -348,7 +348,7 @@ class TestDocumentsList: futures = [executor.submit(list_documents, api_key, dataset_id) for i in range(count)] responses = list(as_completed(futures)) assert len(responses) == count, responses - assert all(futures.result()["code"] == 0 for futures in futures) + assert all(future.result()["code"] == 0 for future in futures) @pytest.mark.p3 def test_invalid_params(self, api_key, add_documents): diff --git a/test/testcases/test_http_api/test_file_management_within_dataset/test_parse_documents.py b/test/testcases/test_http_api/test_file_management_within_dataset/test_parse_documents.py index f6bc9b768..489b315a9 100644 --- a/test/testcases/test_http_api/test_file_management_within_dataset/test_parse_documents.py +++ b/test/testcases/test_http_api/test_file_management_within_dataset/test_parse_documents.py @@ -211,7 +211,7 @@ def test_concurrent_parse(api_key, add_dataset_func, tmp_path): ] responses = list(as_completed(futures)) assert len(responses) == count, responses - assert all(futures.result()["code"] == 0 for futures in futures) + assert all(future.result()["code"] == 0 for future in futures) condition(api_key, dataset_id, count) diff --git a/test/testcases/test_http_api/test_file_management_within_dataset/test_upload_documents.py b/test/testcases/test_http_api/test_file_management_within_dataset/test_upload_documents.py index 801b0b29b..b149fa7fe 100644 --- a/test/testcases/test_http_api/test_file_management_within_dataset/test_upload_documents.py +++ b/test/testcases/test_http_api/test_file_management_within_dataset/test_upload_documents.py @@ -213,7 +213,7 @@ class TestDocumentsUpload: futures = [executor.submit(upload_documents, api_key, dataset_id, fps[i : i + 1]) for i in range(count)] responses = list(as_completed(futures)) assert len(responses) == count, responses - assert all(futures.result()["code"] == 0 for futures in futures) + assert all(future.result()["code"] == 0 for future in futures) res = list_datasets(api_key, {"id": dataset_id}) assert res["data"][0]["document_count"] == count diff --git a/test/testcases/test_sdk_api/common.py b/test/testcases/test_sdk_api/common.py index 4a1092f97..65fac9363 100644 --- a/test/testcases/test_sdk_api/common.py +++ b/test/testcases/test_sdk_api/common.py @@ -22,11 +22,7 @@ from utils.file_utils import create_txt_file # DATASET MANAGEMENT def batch_create_datasets(client: RAGFlow, num: int) -> list[DataSet]: - datasets = [] - for i in range(num): - dataset = client.create_dataset(name=f"dataset_{i}") - datasets.append(dataset) - return datasets + return [client.create_dataset(name=f"dataset_{i}") for i in range(num)] # FILE MANAGEMENT WITHIN DATASET @@ -39,3 +35,8 @@ def bulk_upload_documents(dataset: DataSet, num: int, tmp_path: Path) -> list[Do document_infos.append({"display_name": fp.name, "blob": blob}) return dataset.upload_documents(document_infos) + + +# CHUNK MANAGEMENT WITHIN DATASET +def batch_add_chunks(document: Document, num: int): + return [document.add_chunk(content=f"chunk test {i}") for i in range(num)] diff --git a/test/testcases/test_sdk_api/conftest.py b/test/testcases/test_sdk_api/conftest.py index e4c07bf67..215228c9a 100644 --- a/test/testcases/test_sdk_api/conftest.py +++ b/test/testcases/test_sdk_api/conftest.py @@ -23,7 +23,7 @@ from common import ( ) from configs import HOST_ADDRESS, VERSION from pytest import FixtureRequest -from ragflow_sdk import DataSet, RAGFlow +from ragflow_sdk import Chunk, DataSet, Document, RAGFlow from utils import wait_for from utils.file_utils import ( create_docx_file, @@ -41,7 +41,7 @@ from utils.file_utils import ( @wait_for(30, 1, "Document parsing timeout") def condition(_dataset: DataSet): - documents = DataSet.list_documents(page_size=1000) + documents = _dataset.list_documents(page_size=1000) for document in documents: if document.run != "DONE": return False @@ -49,7 +49,7 @@ def condition(_dataset: DataSet): @pytest.fixture -def generate_test_files(request, tmp_path): +def generate_test_files(request: FixtureRequest, tmp_path: Path): file_creators = { "docx": (tmp_path / "ragflow_test.docx", create_docx_file), "excel": (tmp_path / "ragflow_test.xlsx", create_excel_file), @@ -72,13 +72,13 @@ def generate_test_files(request, tmp_path): @pytest.fixture(scope="class") -def ragflow_tmp_dir(request, tmp_path_factory) -> Path: +def ragflow_tmp_dir(request: FixtureRequest, tmp_path_factory: Path) -> Path: class_name = request.cls.__name__ return tmp_path_factory.mktemp(class_name) @pytest.fixture(scope="session") -def client(token) -> RAGFlow: +def client(token: str) -> RAGFlow: return RAGFlow(api_key=token, base_url=HOST_ADDRESS, version=VERSION) @@ -96,9 +96,7 @@ def add_dataset(request: FixtureRequest, client: RAGFlow): client.delete_datasets(ids=None) request.addfinalizer(cleanup) - - dataset_ids = batch_create_datasets(client, 1) - return dataset_ids[0] + return batch_create_datasets(client, 1)[0] @pytest.fixture(scope="function") @@ -111,12 +109,31 @@ def add_dataset_func(request: FixtureRequest, client: RAGFlow) -> DataSet: @pytest.fixture(scope="class") -def add_document(request: FixtureRequest, add_dataset: DataSet, ragflow_tmp_dir): - dataset = add_dataset - documents = bulk_upload_documents(dataset, 1, ragflow_tmp_dir) +def add_document(add_dataset: DataSet, ragflow_tmp_dir: Path) -> tuple[DataSet, Document]: + return add_dataset, bulk_upload_documents(add_dataset, 1, ragflow_tmp_dir)[0] + + +@pytest.fixture(scope="class") +def add_chunks(request: FixtureRequest, add_document: tuple[DataSet, Document]) -> tuple[DataSet, Document, list[Chunk]]: + dataset, document = add_document + dataset.async_parse_documents([document.id]) + condition(dataset) + + chunks = [] + for i in range(4): + chunk = document.add_chunk(content=f"chunk test {i}") + chunks.append(chunk) + + # issues/6487 + from time import sleep + + sleep(1) def cleanup(): - dataset.delete_documents(ids=None) + try: + document.delete_chunks(ids=[]) + except Exception: + pass request.addfinalizer(cleanup) - return dataset, documents[0] + return dataset, document, chunks diff --git a/test/testcases/test_sdk_api/test_chunk_management_within_dataset/conftest.py b/test/testcases/test_sdk_api/test_chunk_management_within_dataset/conftest.py new file mode 100644 index 000000000..627d89e5a --- /dev/null +++ b/test/testcases/test_sdk_api/test_chunk_management_within_dataset/conftest.py @@ -0,0 +1,49 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +import pytest +from pytest import FixtureRequest +from ragflow_sdk import Chunk, DataSet, Document +from utils import wait_for + + +@wait_for(30, 1, "Document parsing timeout") +def condition(_dataset: DataSet): + documents = _dataset.list_documents(page_size=1000) + for document in documents: + if document.run != "DONE": + return False + return True + + +@pytest.fixture(scope="function") +def add_chunks_func(request: FixtureRequest, add_document: tuple[DataSet, Document]) -> tuple[DataSet, Document, list[Chunk]]: + dataset, document = add_document + dataset.async_parse_documents([document.id]) + condition(dataset) + chunks = [document.add_chunk(content=f"chunk test {i}") for i in range(4)] + + # issues/6487 + from time import sleep + + sleep(1) + + def cleanup(): + document.delete_chunks(ids=[]) + + request.addfinalizer(cleanup) + return dataset, document, chunks diff --git a/test/testcases/test_sdk_api/test_chunk_management_within_dataset/test_add_chunk.py b/test/testcases/test_sdk_api/test_chunk_management_within_dataset/test_add_chunk.py new file mode 100644 index 000000000..5d1638db2 --- /dev/null +++ b/test/testcases/test_sdk_api/test_chunk_management_within_dataset/test_add_chunk.py @@ -0,0 +1,160 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from concurrent.futures import ThreadPoolExecutor, as_completed +from time import sleep + +import pytest +from ragflow_sdk import Chunk + + +def validate_chunk_details(dataset_id: str, document_id: str, payload: dict, chunk: Chunk): + assert chunk.dataset_id == dataset_id + assert chunk.document_id == document_id + assert chunk.content == payload["content"] + if "important_keywords" in payload: + assert chunk.important_keywords == payload["important_keywords"] + if "questions" in payload: + assert chunk.questions == [str(q).strip() for q in payload.get("questions", []) if str(q).strip()] + + +class TestAddChunk: + @pytest.mark.p1 + @pytest.mark.parametrize( + "payload, expected_message", + [ + ({"content": None}, "not instance of"), + ({"content": ""}, "`content` is required"), + ({"content": 1}, "not instance of"), + ({"content": "a"}, ""), + ({"content": " "}, "`content` is required"), + ({"content": "\n!?。;!?\"'"}, ""), + ], + ) + def test_content(self, add_document, payload, expected_message): + dataset, document = add_document + chunks_count = len(document.list_chunks()) + + if expected_message: + with pytest.raises(Exception) as excinfo: + document.add_chunk(**payload) + assert expected_message in str(excinfo.value), str(excinfo.value) + else: + chunk = document.add_chunk(**payload) + validate_chunk_details(dataset.id, document.id, payload, chunk) + + sleep(1) + chunks = document.list_chunks() + assert len(chunks) == chunks_count + 1, str(chunks) + + @pytest.mark.p2 + @pytest.mark.parametrize( + "payload, expected_message", + [ + ({"content": "chunk test important_keywords 1", "important_keywords": ["a", "b", "c"]}, ""), + ({"content": "chunk test important_keywords 2", "important_keywords": [""]}, ""), + ({"content": "chunk test important_keywords 3", "important_keywords": [1]}, "not instance of"), + ({"content": "chunk test important_keywords 4", "important_keywords": ["a", "a"]}, ""), + ({"content": "chunk test important_keywords 5", "important_keywords": "abc"}, "not instance of"), + ({"content": "chunk test important_keywords 6", "important_keywords": 123}, "not instance of"), + ], + ) + def test_important_keywords(self, add_document, payload, expected_message): + dataset, document = add_document + chunks_count = len(document.list_chunks()) + + if expected_message: + with pytest.raises(Exception) as excinfo: + document.add_chunk(**payload) + assert expected_message in str(excinfo.value), str(excinfo.value) + else: + chunk = document.add_chunk(**payload) + validate_chunk_details(dataset.id, document.id, payload, chunk) + + sleep(1) + chunks = document.list_chunks() + assert len(chunks) == chunks_count + 1, str(chunks) + + @pytest.mark.p2 + @pytest.mark.parametrize( + "payload, expected_message", + [ + ({"content": "chunk test test_questions 1", "questions": ["a", "b", "c"]}, ""), + ({"content": "chunk test test_questions 2", "questions": [""]}, ""), + ({"content": "chunk test test_questions 3", "questions": [1]}, "not instance of"), + ({"content": "chunk test test_questions 4", "questions": ["a", "a"]}, ""), + ({"content": "chunk test test_questions 5", "questions": "abc"}, "not instance of"), + ({"content": "chunk test test_questions 6", "questions": 123}, "not instance of"), + ], + ) + def test_questions(self, add_document, payload, expected_message): + dataset, document = add_document + chunks_count = len(document.list_chunks()) + + if expected_message: + with pytest.raises(Exception) as excinfo: + document.add_chunk(**payload) + assert expected_message in str(excinfo.value), str(excinfo.value) + else: + chunk = document.add_chunk(**payload) + validate_chunk_details(dataset.id, document.id, payload, chunk) + + sleep(1) + chunks = document.list_chunks() + assert len(chunks) == chunks_count + 1, str(chunks) + + @pytest.mark.p3 + def test_repeated_add_chunk(self, add_document): + payload = {"content": "chunk test repeated_add_chunk"} + dataset, document = add_document + chunks_count = len(document.list_chunks()) + + chunk1 = document.add_chunk(**payload) + validate_chunk_details(dataset.id, document.id, payload, chunk1) + sleep(1) + chunks = document.list_chunks() + assert len(chunks) == chunks_count + 1, str(chunks) + + chunk2 = document.add_chunk(**payload) + validate_chunk_details(dataset.id, document.id, payload, chunk2) + sleep(1) + chunks = document.list_chunks() + assert len(chunks) == chunks_count + 1, str(chunks) + + @pytest.mark.p2 + def test_add_chunk_to_deleted_document(self, add_document): + dataset, document = add_document + dataset.delete_documents(ids=[document.id]) + + with pytest.raises(Exception) as excinfo: + document.add_chunk(content="chunk test") + assert f"You don't own the document {document.id}" in str(excinfo.value), str(excinfo.value) + + @pytest.mark.skip(reason="issues/6411") + @pytest.mark.p3 + def test_concurrent_add_chunk(self, add_document): + count = 50 + _, document = add_document + initial_chunk_count = len(document.list_chunks()) + + def add_chunk_task(i): + return document.add_chunk(content=f"chunk test concurrent {i}") + + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(add_chunk_task, i) for i in range(count)] + responses = list(as_completed(futures)) + assert len(responses) == count, responses + sleep(5) + assert len(document.list_chunks(page_size=100)) == initial_chunk_count + count diff --git a/test/testcases/test_sdk_api/test_chunk_management_within_dataset/test_delete_chunks.py b/test/testcases/test_sdk_api/test_chunk_management_within_dataset/test_delete_chunks.py new file mode 100644 index 000000000..25aac7b88 --- /dev/null +++ b/test/testcases/test_sdk_api/test_chunk_management_within_dataset/test_delete_chunks.py @@ -0,0 +1,113 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from concurrent.futures import ThreadPoolExecutor, as_completed + +import pytest +from common import batch_add_chunks + + +class TestChunksDeletion: + @pytest.mark.parametrize( + "payload", + [ + pytest.param(lambda r: {"ids": ["invalid_id"] + r}, marks=pytest.mark.p3), + pytest.param(lambda r: {"ids": r[:1] + ["invalid_id"] + r[1:4]}, marks=pytest.mark.p1), + pytest.param(lambda r: {"ids": r + ["invalid_id"]}, marks=pytest.mark.p3), + ], + ) + def test_delete_partial_invalid_id(self, add_chunks_func, payload): + _, document, chunks = add_chunks_func + chunk_ids = [chunk.id for chunk in chunks] + payload = payload(chunk_ids) + + with pytest.raises(Exception) as excinfo: + document.delete_chunks(**payload) + assert "rm_chunk deleted chunks" in str(excinfo.value), str(excinfo.value) + + remaining_chunks = document.list_chunks() + assert len(remaining_chunks) == 1, str(remaining_chunks) + + @pytest.mark.p3 + def test_repeated_deletion(self, add_chunks_func): + _, document, chunks = add_chunks_func + chunk_ids = [chunk.id for chunk in chunks] + document.delete_chunks(ids=chunk_ids) + + with pytest.raises(Exception) as excinfo: + document.delete_chunks(ids=chunk_ids) + assert "rm_chunk deleted chunks 0, expect" in str(excinfo.value), str(excinfo.value) + + @pytest.mark.p3 + def test_duplicate_deletion(self, add_chunks_func): + _, document, chunks = add_chunks_func + chunk_ids = [chunk.id for chunk in chunks] + document.delete_chunks(ids=chunk_ids * 2) + remaining_chunks = document.list_chunks() + assert len(remaining_chunks) == 1, str(remaining_chunks) + + @pytest.mark.p3 + def test_concurrent_deletion(self, add_document): + count = 100 + _, document = add_document + chunks = batch_add_chunks(document, count) + chunk_ids = [chunk.id for chunk in chunks] + + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(document.delete_chunks, ids=[chunk_id]) for chunk_id in chunk_ids] + responses = list(as_completed(futures)) + assert len(responses) == count, responses + + @pytest.mark.p3 + def test_delete_1k(self, add_document): + count = 1_000 + _, document = add_document + chunks = batch_add_chunks(document, count) + chunk_ids = [chunk.id for chunk in chunks] + + from time import sleep + + sleep(1) + + document.delete_chunks(ids=chunk_ids) + remaining_chunks = document.list_chunks() + assert len(remaining_chunks) == 0, str(remaining_chunks) + + @pytest.mark.parametrize( + "payload, expected_message, remaining", + [ + pytest.param(None, "TypeError", 5, marks=pytest.mark.skip), + pytest.param({"ids": ["invalid_id"]}, "rm_chunk deleted chunks 0, expect 1", 5, marks=pytest.mark.p3), + pytest.param("not json", "UnboundLocalError", 5, marks=pytest.mark.skip(reason="pull/6376")), + pytest.param(lambda r: {"ids": r[:1]}, "", 4, marks=pytest.mark.p3), + pytest.param(lambda r: {"ids": r}, "", 1, marks=pytest.mark.p1), + pytest.param({"ids": []}, "", 0, marks=pytest.mark.p3), + ], + ) + def test_basic_scenarios(self, add_chunks_func, payload, expected_message, remaining): + _, document, chunks = add_chunks_func + chunk_ids = [chunk.id for chunk in chunks] + if callable(payload): + payload = payload(chunk_ids) + + if expected_message: + with pytest.raises(Exception) as excinfo: + document.delete_chunks(**payload) + assert expected_message in str(excinfo.value), str(excinfo.value) + else: + document.delete_chunks(**payload) + + remaining_chunks = document.list_chunks() + assert len(remaining_chunks) == remaining, str(remaining_chunks) diff --git a/test/testcases/test_sdk_api/test_chunk_management_within_dataset/test_list_chunks.py b/test/testcases/test_sdk_api/test_chunk_management_within_dataset/test_list_chunks.py new file mode 100644 index 000000000..76f9da5e0 --- /dev/null +++ b/test/testcases/test_sdk_api/test_chunk_management_within_dataset/test_list_chunks.py @@ -0,0 +1,140 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import os +from concurrent.futures import ThreadPoolExecutor, as_completed + +import pytest +from common import batch_add_chunks + + +class TestChunksList: + @pytest.mark.p1 + @pytest.mark.parametrize( + "params, expected_page_size, expected_message", + [ + ({"page": None, "page_size": 2}, 2, ""), + pytest.param({"page": 0, "page_size": 2}, 0, "ValueError('Search does not support negative slicing.')", marks=pytest.mark.skip), + ({"page": 2, "page_size": 2}, 2, ""), + ({"page": 3, "page_size": 2}, 1, ""), + ({"page": "3", "page_size": 2}, 1, ""), + pytest.param({"page": -1, "page_size": 2}, 0, "ValueError('Search does not support negative slicing.')", marks=pytest.mark.skip), + pytest.param({"page": "a", "page_size": 2}, 0, """ValueError("invalid literal for int() with base 10: \'a\'")""", marks=pytest.mark.skip), + ], + ) + def test_page(self, add_chunks, params, expected_page_size, expected_message): + _, document, _ = add_chunks + + if expected_message: + with pytest.raises(Exception) as excinfo: + document.list_chunks(**params) + assert expected_message in str(excinfo.value), str(excinfo.value) + else: + chunks = document.list_chunks(**params) + assert len(chunks) == expected_page_size, str(chunks) + + @pytest.mark.p1 + @pytest.mark.parametrize( + "params, expected_page_size, expected_message", + [ + ({"page_size": None}, 5, ""), + pytest.param({"page_size": 0}, 5, "", marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="Infinity does not support page_size=0")), + pytest.param({"page_size": 0}, 0, "3013", marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in [None, "opensearch", "elasticsearch"], reason="Infinity does not support page_size=0")), + ({"page_size": 1}, 1, ""), + ({"page_size": 6}, 5, ""), + ({"page_size": "1"}, 1, ""), + pytest.param({"page_size": -1}, 5, "", marks=pytest.mark.skip), + pytest.param({"page_size": "a"}, 0, """ValueError("invalid literal for int() with base 10: \'a\'")""", marks=pytest.mark.skip), + ], + ) + def test_page_size(self, add_chunks, params, expected_page_size, expected_message): + _, document, _ = add_chunks + + if expected_message: + with pytest.raises(Exception) as excinfo: + document.list_chunks(**params) + assert expected_message in str(excinfo.value), str(excinfo.value) + else: + chunks = document.list_chunks(**params) + assert len(chunks) == expected_page_size, str(chunks) + + @pytest.mark.p2 + @pytest.mark.parametrize( + "params, expected_page_size", + [ + ({"keywords": None}, 5), + ({"keywords": ""}, 5), + ({"keywords": "1"}, 1), + pytest.param({"keywords": "chunk"}, 4, marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="issues/6509")), + ({"keywords": "ragflow"}, 1), + ({"keywords": "unknown"}, 0), + ], + ) + def test_keywords(self, add_chunks, params, expected_page_size): + _, document, _ = add_chunks + chunks = document.list_chunks(**params) + assert len(chunks) == expected_page_size, str(chunks) + + @pytest.mark.p1 + @pytest.mark.parametrize( + "chunk_id, expected_page_size, expected_message", + [ + (None, 5, ""), + ("", 5, ""), + pytest.param(lambda r: r[0], 1, "", marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="issues/6499")), + pytest.param("unknown", 0, """AttributeError("\'NoneType\' object has no attribute \'keys\'")""", marks=pytest.mark.skip), + ], + ) + def test_id(self, add_chunks, chunk_id, expected_page_size, expected_message): + _, document, chunks = add_chunks + chunk_ids = [chunk.id for chunk in chunks] + if callable(chunk_id): + params = {"id": chunk_id(chunk_ids)} + else: + params = {"id": chunk_id} + + if expected_message: + with pytest.raises(Exception) as excinfo: + document.list_chunks(**params) + assert expected_message in str(excinfo.value), str(excinfo.value) + else: + chunks = document.list_chunks(**params) + if params["id"] in [None, ""]: + assert len(chunks) == expected_page_size, str(chunks) + else: + assert chunks[0].id == params["id"], str(chunks) + + @pytest.mark.p3 + def test_concurrent_list(self, add_chunks): + _, document, _ = add_chunks + count = 100 + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(document.list_chunks) for _ in range(count)] + + responses = list(as_completed(futures)) + assert len(responses) == count, responses + assert all(len(future.result()) == 5 for future in futures) + + @pytest.mark.p1 + def test_default(self, add_document): + _, document = add_document + batch_add_chunks(document, 31) + + from time import sleep + + sleep(3) + + chunks = document.list_chunks() + assert len(chunks) == 30, str(chunks) diff --git a/test/testcases/test_sdk_api/test_chunk_management_within_dataset/test_retrieval_chunks.py b/test/testcases/test_sdk_api/test_chunk_management_within_dataset/test_retrieval_chunks.py new file mode 100644 index 000000000..e1b3fa7f0 --- /dev/null +++ b/test/testcases/test_sdk_api/test_chunk_management_within_dataset/test_retrieval_chunks.py @@ -0,0 +1,254 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import os +from concurrent.futures import ThreadPoolExecutor, as_completed + +import pytest + + +class TestChunksRetrieval: + @pytest.mark.p1 + @pytest.mark.parametrize( + "payload, expected_page_size, expected_message", + [ + ({"question": "chunk", "dataset_ids": None}, 4, ""), + ({"question": "chunk", "document_ids": None}, 0, "missing 1 required positional argument"), + ({"question": "chunk", "dataset_ids": None, "document_ids": None}, 4, ""), + ({"question": "chunk"}, 0, "missing 1 required positional argument"), + ], + ) + def test_basic_scenarios(self, client, add_chunks, payload, expected_page_size, expected_message): + dataset, document, _ = add_chunks + if "dataset_ids" in payload: + payload["dataset_ids"] = [dataset.id] + if "document_ids" in payload: + payload["document_ids"] = [document.id] + + if expected_message: + with pytest.raises(Exception) as excinfo: + client.retrieve(**payload) + assert expected_message in str(excinfo.value), str(excinfo.value) + else: + chunks = client.retrieve(**payload) + assert len(chunks) == expected_page_size, str(chunks) + + @pytest.mark.p2 + @pytest.mark.parametrize( + "payload, expected_page_size, expected_message", + [ + pytest.param( + {"page": None, "page_size": 2}, + 2, + """TypeError("int() argument must be a string, a bytes-like object or a real number, not \'NoneType\'")""", + marks=pytest.mark.skip, + ), + pytest.param( + {"page": 0, "page_size": 2}, + 0, + "ValueError('Search does not support negative slicing.')", + marks=pytest.mark.skip, + ), + pytest.param({"page": 2, "page_size": 2}, 2, "", marks=pytest.mark.skip(reason="issues/6646")), + ({"page": 3, "page_size": 2}, 0, ""), + ({"page": "3", "page_size": 2}, 0, ""), + pytest.param( + {"page": -1, "page_size": 2}, + 0, + "ValueError('Search does not support negative slicing.')", + marks=pytest.mark.skip, + ), + pytest.param( + {"page": "a", "page_size": 2}, + 0, + """ValueError("invalid literal for int() with base 10: \'a\'")""", + marks=pytest.mark.skip, + ), + ], + ) + def test_page(self, client, add_chunks, payload, expected_page_size, expected_message): + dataset, _, _ = add_chunks + payload.update({"question": "chunk", "dataset_ids": [dataset.id]}) + + if expected_message: + with pytest.raises(Exception) as excinfo: + client.retrieve(**payload) + assert expected_message in str(excinfo.value), str(excinfo.value) + else: + chunks = client.retrieve(**payload) + assert len(chunks) == expected_page_size, str(chunks) + + @pytest.mark.p3 + @pytest.mark.parametrize( + "payload, expected_page_size, expected_message", + [ + pytest.param( + {"page_size": None}, + 0, + """TypeError("int() argument must be a string, a bytes-like object or a real number, not \'NoneType\'")""", + marks=pytest.mark.skip, + ), + ({"page_size": 1}, 1, ""), + ({"page_size": 5}, 4, ""), + ({"page_size": "1"}, 1, ""), + pytest.param( + {"page_size": "a"}, + 0, + """ValueError("invalid literal for int() with base 10: \'a\'")""", + marks=pytest.mark.skip, + ), + ], + ) + def test_page_size(self, client, add_chunks, payload, expected_page_size, expected_message): + dataset, _, _ = add_chunks + payload.update({"question": "chunk", "dataset_ids": [dataset.id]}) + + if expected_message: + with pytest.raises(Exception) as excinfo: + client.retrieve(**payload) + assert expected_message in str(excinfo.value), str(excinfo.value) + else: + chunks = client.retrieve(**payload) + assert len(chunks) == expected_page_size, str(chunks) + + @pytest.mark.p3 + @pytest.mark.parametrize( + "payload, expected_page_size, expected_message", + [ + ({"vector_similarity_weight": 0}, 4, ""), + ({"vector_similarity_weight": 0.5}, 4, ""), + ({"vector_similarity_weight": 10}, 4, ""), + pytest.param( + {"vector_similarity_weight": "a"}, + 0, + """ValueError("could not convert string to float: 'a'")""", + marks=pytest.mark.skip, + ), + ], + ) + def test_vector_similarity_weight(self, client, add_chunks, payload, expected_page_size, expected_message): + dataset, _, _ = add_chunks + payload.update({"question": "chunk", "dataset_ids": [dataset.id]}) + + if expected_message: + with pytest.raises(Exception) as excinfo: + client.retrieve(**payload) + assert expected_message in str(excinfo.value), str(excinfo.value) + else: + chunks = client.retrieve(**payload) + assert len(chunks) == expected_page_size, str(chunks) + + @pytest.mark.p2 + @pytest.mark.parametrize( + "payload, expected_page_size, expected_message", + [ + ({"top_k": 10}, 4, ""), + pytest.param( + {"top_k": 1}, + 4, + "", + marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in ["infinity", "opensearch"], reason="Infinity"), + ), + pytest.param( + {"top_k": 1}, + 1, + "", + marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in [None, "opensearch", "elasticsearch"], reason="elasticsearch"), + ), + pytest.param( + {"top_k": -1}, + 4, + "must be greater than 0", + marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in ["infinity", "opensearch"], reason="Infinity"), + ), + pytest.param( + {"top_k": -1}, + 4, + "3014", + marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in [None, "opensearch", "elasticsearch"], reason="elasticsearch"), + ), + pytest.param( + {"top_k": "a"}, + 0, + """ValueError("invalid literal for int() with base 10: \'a\'")""", + marks=pytest.mark.skip, + ), + ], + ) + def test_top_k(self, client, add_chunks, payload, expected_page_size, expected_message): + dataset, _, _ = add_chunks + payload.update({"question": "chunk", "dataset_ids": [dataset.id]}) + + if expected_message: + with pytest.raises(Exception) as excinfo: + client.retrieve(**payload) + assert expected_message in str(excinfo.value), str(excinfo.value) + else: + chunks = client.retrieve(**payload) + assert len(chunks) == expected_page_size, str(chunks) + + @pytest.mark.skip + @pytest.mark.parametrize( + "payload, expected_message", + [ + ({"rerank_id": "BAAI/bge-reranker-v2-m3"}, ""), + pytest.param({"rerank_id": "unknown"}, "LookupError('Model(unknown) not authorized')", marks=pytest.mark.skip), + ], + ) + def test_rerank_id(self, client, add_chunks, payload, expected_message): + dataset, _, _ = add_chunks + payload.update({"question": "chunk", "dataset_ids": [dataset.id]}) + + if expected_message: + with pytest.raises(Exception) as excinfo: + client.retrieve(**payload) + assert expected_message in str(excinfo.value), str(excinfo.value) + else: + chunks = client.retrieve(**payload) + assert len(chunks) > 0, str(chunks) + + @pytest.mark.skip + @pytest.mark.parametrize( + "payload, expected_page_size, expected_message", + [ + ({"keyword": True}, 5, ""), + ({"keyword": "True"}, 5, ""), + ({"keyword": False}, 5, ""), + ({"keyword": "False"}, 5, ""), + ({"keyword": None}, 5, ""), + ], + ) + def test_keyword(self, client, add_chunks, payload, expected_page_size, expected_message): + dataset, _, _ = add_chunks + payload.update({"question": "chunk test", "dataset_ids": [dataset.id]}) + + if expected_message: + with pytest.raises(Exception) as excinfo: + client.retrieve(**payload) + assert expected_message in str(excinfo.value), str(excinfo.value) + else: + chunks = client.retrieve(**payload) + assert len(chunks) == expected_page_size, str(chunks) + + @pytest.mark.p3 + def test_concurrent_retrieval(self, client, add_chunks): + dataset, _, _ = add_chunks + count = 100 + payload = {"question": "chunk", "dataset_ids": [dataset.id]} + + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(client.retrieve, **payload) for _ in range(count)] + responses = list(as_completed(futures)) + assert len(responses) == count, responses diff --git a/test/testcases/test_sdk_api/test_chunk_management_within_dataset/test_update_chunk.py b/test/testcases/test_sdk_api/test_chunk_management_within_dataset/test_update_chunk.py new file mode 100644 index 000000000..dc85d6385 --- /dev/null +++ b/test/testcases/test_sdk_api/test_chunk_management_within_dataset/test_update_chunk.py @@ -0,0 +1,154 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import os +from concurrent.futures import ThreadPoolExecutor, as_completed +from random import randint + +import pytest + + +class TestUpdatedChunk: + @pytest.mark.p1 + @pytest.mark.parametrize( + "payload, expected_message", + [ + ({"content": None}, "TypeError('expected string or bytes-like object')"), + pytest.param( + {"content": ""}, + """APIRequestFailedError(\'Error code: 400, with error text {"error":{"code":"1213","message":"未正常接收到prompt参数。"}}\')""", + marks=pytest.mark.skip(reason="issues/6541"), + ), + pytest.param( + {"content": 1}, + "TypeError('expected string or bytes-like object')", + marks=pytest.mark.skip, + ), + ({"content": "update chunk"}, ""), + pytest.param( + {"content": " "}, + """APIRequestFailedError(\'Error code: 400, with error text {"error":{"code":"1213","message":"未正常接收到prompt参数。"}}\')""", + marks=pytest.mark.skip(reason="issues/6541"), + ), + ({"content": "\n!?。;!?\"'"}, ""), + ], + ) + def test_content(self, add_chunks, payload, expected_message): + _, _, chunks = add_chunks + chunk = chunks[0] + + if expected_message: + with pytest.raises(Exception) as excinfo: + chunk.update(payload) + assert expected_message in str(excinfo.value), str(excinfo.value) + else: + chunk.update(payload) + + @pytest.mark.p2 + @pytest.mark.parametrize( + "payload, expected_message", + [ + ({"important_keywords": ["a", "b", "c"]}, ""), + ({"important_keywords": [""]}, ""), + ({"important_keywords": [1]}, "TypeError('sequence item 0: expected str instance, int found')"), + ({"important_keywords": ["a", "a"]}, ""), + ({"important_keywords": "abc"}, "`important_keywords` should be a list"), + ({"important_keywords": 123}, "`important_keywords` should be a list"), + ], + ) + def test_important_keywords(self, add_chunks, payload, expected_message): + _, _, chunks = add_chunks + chunk = chunks[0] + + if expected_message: + with pytest.raises(Exception) as excinfo: + chunk.update(payload) + assert expected_message in str(excinfo.value), str(excinfo.value) + else: + chunk.update(payload) + + @pytest.mark.p2 + @pytest.mark.parametrize( + "payload, expected_message", + [ + ({"questions": ["a", "b", "c"]}, ""), + ({"questions": [""]}, ""), + ({"questions": [1]}, "TypeError('sequence item 0: expected str instance, int found')"), + ({"questions": ["a", "a"]}, ""), + ({"questions": "abc"}, "`questions` should be a list"), + ({"questions": 123}, "`questions` should be a list"), + ], + ) + def test_questions(self, add_chunks, payload, expected_message): + _, _, chunks = add_chunks + chunk = chunks[0] + + if expected_message: + with pytest.raises(Exception) as excinfo: + chunk.update(payload) + assert expected_message in str(excinfo.value), str(excinfo.value) + else: + chunk.update(payload) + + @pytest.mark.p2 + @pytest.mark.parametrize( + "payload, expected_message", + [ + ({"available": True}, ""), + pytest.param({"available": "True"}, """ValueError("invalid literal for int() with base 10: \'True\'")""", marks=pytest.mark.skip), + ({"available": 1}, ""), + ({"available": False}, ""), + pytest.param({"available": "False"}, """ValueError("invalid literal for int() with base 10: \'False\'")""", marks=pytest.mark.skip), + ({"available": 0}, ""), + ], + ) + def test_available(self, add_chunks, payload, expected_message): + _, _, chunks = add_chunks + chunk = chunks[0] + + if expected_message: + with pytest.raises(Exception) as excinfo: + chunk.update(payload) + assert expected_message in str(excinfo.value), str(excinfo.value) + else: + chunk.update(payload) + + @pytest.mark.p3 + def test_repeated_update_chunk(self, add_chunks): + _, _, chunks = add_chunks + chunk = chunks[0] + + chunk.update({"content": "chunk test 1"}) + chunk.update({"content": "chunk test 2"}) + + @pytest.mark.p3 + @pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="issues/6554") + def test_concurrent_update_chunk(self, add_chunks): + count = 50 + _, _, chunks = add_chunks + + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(chunks[randint(0, 3)].update, {"content": f"update chunk test {i}"}) for i in range(count)] + responses = list(as_completed(futures)) + assert len(responses) == count, responses + + @pytest.mark.p3 + def test_update_chunk_to_deleted_document(self, add_chunks): + dataset, document, chunks = add_chunks + dataset.delete_documents(ids=[document.id]) + + with pytest.raises(Exception) as excinfo: + chunks[0].update({}) + assert f"Can't find this chunk {chunks[0].id}" in str(excinfo.value), str(excinfo.value)