diff --git a/sdk/python/test/conftest.py b/sdk/python/test/conftest.py index c32a096ce..4dc87ed20 100644 --- a/sdk/python/test/conftest.py +++ b/sdk/python/test/conftest.py @@ -15,26 +15,27 @@ # import os + import pytest import requests - from libs.auth import RAGFlowHttpApiAuth -HOST_ADDRESS = os.getenv('HOST_ADDRESS', 'http://127.0.0.1:9380') +HOST_ADDRESS = os.getenv("HOST_ADDRESS", "http://127.0.0.1:9380") # def generate_random_email(): # return 'user_' + ''.join(random.choices(string.ascii_lowercase + string.digits, k=8))+'@1.com' + def generate_email(): - return 'user_123@1.com' + return "user_123@1.com" EMAIL = generate_email() # password is "123" -PASSWORD = '''ctAseGvejiaSWWZ88T/m4FQVOpQyUvP+x7sXtdv3feqZACiQleuewkUi35E16wSd5C5QcnkkcV9cYc8TKPTRZlxappDuirxghxoOvFcJxFU4ixLsD +PASSWORD = """ctAseGvejiaSWWZ88T/m4FQVOpQyUvP+x7sXtdv3feqZACiQleuewkUi35E16wSd5C5QcnkkcV9cYc8TKPTRZlxappDuirxghxoOvFcJxFU4ixLsD fN33jCHRoDUW81IH9zjij/vaw8IbVyb6vuwg6MX6inOEBRRzVbRYxXOu1wkWY6SsI8X70oF9aeLFp/PzQpjoe/YbSqpTq8qqrmHzn9vO+yvyYyvmDsphXe -X8f7fp9c7vUsfOCkM+gHY3PadG+QHa7KI7mzTKgUTZImK6BZtfRBATDTthEUbbaTewY4H0MnWiCeeDhcbeQao6cFy1To8pE3RpmxnGnS8BsBn8w==''' +X8f7fp9c7vUsfOCkM+gHY3PadG+QHa7KI7mzTKgUTZImK6BZtfRBATDTthEUbbaTewY4H0MnWiCeeDhcbeQao6cFy1To8pE3RpmxnGnS8BsBn8w==""" def register(): @@ -92,3 +93,64 @@ def get_email(): @pytest.fixture(scope="session") def get_http_api_auth(get_api_key_fixture): return RAGFlowHttpApiAuth(get_api_key_fixture) + + +def get_my_llms(auth, name): + url = HOST_ADDRESS + "/v1/llm/my_llms" + authorization = {"Authorization": auth} + response = requests.get(url=url, headers=authorization) + res = response.json() + if res.get("code") != 0: + raise Exception(res.get("message")) + if name in res.get("data"): + return True + return False + + +def add_models(auth): + url = HOST_ADDRESS + "/v1/llm/set_api_key" + authorization = {"Authorization": auth} + models_info = { + "ZHIPU-AI": {"llm_factory": "ZHIPU-AI", "api_key": "d06253dacd404180aa8afb096fcb6c30.KatwBIUpvCSml9sU"}, + } + + for name, model_info in models_info.items(): + if not get_my_llms(auth, name): + response = requests.post(url=url, headers=authorization, json=model_info) + res = response.json() + if res.get("code") != 0: + raise Exception(res.get("message")) + + +def get_tenant_info(auth): + url = HOST_ADDRESS + "/v1/user/tenant_info" + authorization = {"Authorization": auth} + response = requests.get(url=url, headers=authorization) + res = response.json() + if res.get("code") != 0: + raise Exception(res.get("message")) + return res["data"].get("tenant_id") + + +@pytest.fixture(scope="session", autouse=True) +def set_tenant_info(get_auth): + auth = get_auth + try: + add_models(auth) + tenant_id = get_tenant_info(auth) + except Exception as e: + raise Exception(e) + url = HOST_ADDRESS + "/v1/user/set_tenant_info" + authorization = {"Authorization": get_auth} + tenant_info = { + "tenant_id": tenant_id, + "llm_id": "glm-4-flash@ZHIPU-AI", + "embd_id": "embedding-3@ZHIPU-AI", + "img2txt_id": "glm-4v@ZHIPU-AI", + "asr_id": "", + "tts_id": None, + } + response = requests.post(url=url, headers=authorization, json=tenant_info) + res = response.json() + if res.get("code") != 0: + raise Exception(res.get("message")) diff --git a/sdk/python/test/test_http_api/common.py b/sdk/python/test/test_http_api/common.py index 739fd06da..1ec6ad8d1 100644 --- a/sdk/python/test/test_http_api/common.py +++ b/sdk/python/test/test_http_api/common.py @@ -27,6 +27,7 @@ DATASETS_API_URL = "/api/v1/datasets" FILE_API_URL = 
"/api/v1/datasets/{dataset_id}/documents" FILE_CHUNK_API_URL = "/api/v1/datasets/{dataset_id}/chunks" CHUNK_API_URL = "/api/v1/datasets/{dataset_id}/documents/{document_id}/chunks" +CHAT_ASSISTANT_API_URL = "/api/v1/chats" INVALID_API_TOKEN = "invalid_key_123" DATASET_NAME_LIMIT = 128 @@ -39,7 +40,7 @@ def create_dataset(auth, payload=None): return res.json() -def list_dataset(auth, params=None): +def list_datasets(auth, params=None): res = requests.get(url=f"{HOST_ADDRESS}{DATASETS_API_URL}", headers=HEADERS, auth=auth, params=params) return res.json() @@ -49,7 +50,7 @@ def update_dataset(auth, dataset_id, payload=None): return res.json() -def delete_dataset(auth, payload=None): +def delete_datasets(auth, payload=None): res = requests.delete(url=f"{HOST_ADDRESS}{DATASETS_API_URL}", headers=HEADERS, auth=auth, json=payload) return res.json() @@ -105,7 +106,7 @@ def download_document(auth, dataset_id, document_id, save_path): return res -def list_documnet(auth, dataset_id, params=None): +def list_documnets(auth, dataset_id, params=None): url = f"{HOST_ADDRESS}{FILE_API_URL}".format(dataset_id=dataset_id) res = requests.get(url=url, headers=HEADERS, auth=auth, params=params) return res.json() @@ -117,19 +118,19 @@ def update_documnet(auth, dataset_id, document_id, payload=None): return res.json() -def delete_documnet(auth, dataset_id, payload=None): +def delete_documnets(auth, dataset_id, payload=None): url = f"{HOST_ADDRESS}{FILE_API_URL}".format(dataset_id=dataset_id) res = requests.delete(url=url, headers=HEADERS, auth=auth, json=payload) return res.json() -def parse_documnet(auth, dataset_id, payload=None): +def parse_documnets(auth, dataset_id, payload=None): url = f"{HOST_ADDRESS}{FILE_CHUNK_API_URL}".format(dataset_id=dataset_id) res = requests.post(url=url, headers=HEADERS, auth=auth, json=payload) return res.json() -def stop_parse_documnet(auth, dataset_id, payload=None): +def stop_parse_documnets(auth, dataset_id, payload=None): url = f"{HOST_ADDRESS}{FILE_CHUNK_API_URL}".format(dataset_id=dataset_id) res = requests.delete(url=url, headers=HEADERS, auth=auth, json=payload) return res.json() @@ -184,3 +185,36 @@ def batch_add_chunks(auth, dataset_id, document_id, num): res = add_chunk(auth, dataset_id, document_id, {"content": f"chunk test {i}"}) chunk_ids.append(res["data"]["chunk"]["id"]) return chunk_ids + + +# CHAT ASSISTANT MANAGEMENT +def create_chat_assistant(auth, payload=None): + url = f"{HOST_ADDRESS}{CHAT_ASSISTANT_API_URL}" + res = requests.post(url=url, headers=HEADERS, auth=auth, json=payload) + return res.json() + + +def list_chat_assistants(auth, params=None): + url = f"{HOST_ADDRESS}{CHAT_ASSISTANT_API_URL}" + res = requests.get(url=url, headers=HEADERS, auth=auth, params=params) + return res.json() + + +def update_chat_assistant(auth, chat_assistant_id, payload=None): + url = f"{HOST_ADDRESS}{CHAT_ASSISTANT_API_URL}/{chat_assistant_id}" + res = requests.put(url=url, headers=HEADERS, auth=auth, json=payload) + return res.json() + + +def delete_chat_assistants(auth, payload=None): + url = f"{HOST_ADDRESS}{CHAT_ASSISTANT_API_URL}" + res = requests.delete(url=url, headers=HEADERS, auth=auth, json=payload) + return res.json() + + +def batch_create_chat_assistants(auth, num): + chat_assistant_ids = [] + for i in range(num): + res = create_chat_assistant(auth, {"name": f"test_chat_assistant_{i}"}) + chat_assistant_ids.append(res["data"]["id"]) + return chat_assistant_ids diff --git a/sdk/python/test/test_http_api/conftest.py b/sdk/python/test/test_http_api/conftest.py 
index 704e55476..8826392f1 100644 --- a/sdk/python/test/test_http_api/conftest.py +++ b/sdk/python/test/test_http_api/conftest.py @@ -14,9 +14,8 @@ # limitations under the License. # - import pytest -from common import delete_dataset +from common import batch_create_datasets, bulk_upload_documents, delete_datasets from libs.utils.file_utils import ( create_docx_file, create_eml_file, @@ -34,7 +33,7 @@ from libs.utils.file_utils import ( @pytest.fixture(scope="function") def clear_datasets(get_http_api_auth): yield - delete_dataset(get_http_api_auth) + delete_datasets(get_http_api_auth) @pytest.fixture @@ -58,3 +57,38 @@ def generate_test_files(request, tmp_path): creator_func(file_path) files[file_type] = file_path return files + + +@pytest.fixture(scope="class") +def ragflow_tmp_dir(request, tmp_path_factory): + class_name = request.cls.__name__ + return tmp_path_factory.mktemp(class_name) + + +@pytest.fixture(scope="class") +def add_dataset(request, get_http_api_auth): + def cleanup(): + delete_datasets(get_http_api_auth) + + request.addfinalizer(cleanup) + + dataset_ids = batch_create_datasets(get_http_api_auth, 1) + return dataset_ids[0] + + +@pytest.fixture(scope="function") +def add_dataset_func(request, get_http_api_auth): + def cleanup(): + delete_datasets(get_http_api_auth) + + request.addfinalizer(cleanup) + + dataset_ids = batch_create_datasets(get_http_api_auth, 1) + return dataset_ids[0] + + +@pytest.fixture(scope="class") +def add_document(get_http_api_auth, add_dataset, ragflow_tmp_dir): + dataset_id = add_dataset + document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, 1, ragflow_tmp_dir) + return dataset_id, document_ids[0] diff --git a/sdk/python/test/test_http_api/test_chunk_management_within_dataset/conftest.py b/sdk/python/test/test_http_api/test_chunk_management_within_dataset/conftest.py index 69de33600..a2fe92bc9 100644 --- a/sdk/python/test/test_http_api/test_chunk_management_within_dataset/conftest.py +++ b/sdk/python/test/test_http_api/test_chunk_management_within_dataset/conftest.py @@ -16,13 +16,13 @@ import pytest -from common import add_chunk, batch_create_datasets, bulk_upload_documents, delete_chunks, delete_dataset, list_documnet, parse_documnet +from common import add_chunk, delete_chunks, list_documnets, parse_documnets from libs.utils import wait_for @wait_for(10, 1, "Document parsing timeout") def condition(_auth, _dataset_id): - res = list_documnet(_auth, _dataset_id) + res = list_documnets(_auth, _dataset_id) for doc in res["data"]["docs"]: if doc["run"] != "DONE": return False @@ -30,29 +30,11 @@ def condition(_auth, _dataset_id): @pytest.fixture(scope="class") -def chunk_management_tmp_dir(tmp_path_factory): - return tmp_path_factory.mktemp("chunk_management") - - -@pytest.fixture(scope="class") -def get_dataset_id_and_document_id(get_http_api_auth, chunk_management_tmp_dir, request): - def cleanup(): - delete_dataset(get_http_api_auth) - - request.addfinalizer(cleanup) - - dataset_ids = batch_create_datasets(get_http_api_auth, 1) - dataset_id = dataset_ids[0] - document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, 1, chunk_management_tmp_dir) - parse_documnet(get_http_api_auth, dataset_id, {"document_ids": document_ids}) +def add_chunks(get_http_api_auth, add_document): + dataset_id, document_id = add_document + parse_documnets(get_http_api_auth, dataset_id, {"document_ids": [document_id]}) condition(get_http_api_auth, dataset_id) - return dataset_id, document_ids[0] - - -@pytest.fixture(scope="class") -def 
add_chunks(get_http_api_auth, get_dataset_id_and_document_id): - dataset_id, document_id = get_dataset_id_and_document_id chunk_ids = [] for i in range(4): res = add_chunk(get_http_api_auth, dataset_id, document_id, {"content": f"chunk test {i}"}) @@ -66,8 +48,10 @@ def add_chunks(get_http_api_auth, get_dataset_id_and_document_id): @pytest.fixture(scope="function") -def add_chunks_func(get_http_api_auth, get_dataset_id_and_document_id, request): - dataset_id, document_id = get_dataset_id_and_document_id +def add_chunks_func(request, get_http_api_auth, add_document): + dataset_id, document_id = add_document + parse_documnets(get_http_api_auth, dataset_id, {"document_ids": [document_id]}) + condition(get_http_api_auth, dataset_id) chunk_ids = [] for i in range(4): diff --git a/sdk/python/test/test_http_api/test_chunk_management_within_dataset/test_add_chunk.py b/sdk/python/test/test_http_api/test_chunk_management_within_dataset/test_add_chunk.py index d8691f1f9..0386b3fcd 100644 --- a/sdk/python/test/test_http_api/test_chunk_management_within_dataset/test_add_chunk.py +++ b/sdk/python/test/test_http_api/test_chunk_management_within_dataset/test_add_chunk.py @@ -16,7 +16,7 @@ from concurrent.futures import ThreadPoolExecutor import pytest -from common import INVALID_API_TOKEN, add_chunk, delete_documnet, list_chunks +from common import INVALID_API_TOKEN, add_chunk, delete_documnets, list_chunks from libs.auth import RAGFlowHttpApiAuth @@ -44,7 +44,7 @@ class TestAuthorization: ], ) def test_invalid_auth(self, auth, expected_code, expected_message): - res = add_chunk(auth, "dataset_id", "document_id", {}) + res = add_chunk(auth, "dataset_id", "document_id") assert res["code"] == expected_code assert res["message"] == expected_message @@ -66,8 +66,8 @@ class TestAddChunk: ({"content": "\n!?。;!?\"'"}, 0, ""), ], ) - def test_content(self, get_http_api_auth, get_dataset_id_and_document_id, payload, expected_code, expected_message): - dataset_id, document_id = get_dataset_id_and_document_id + def test_content(self, get_http_api_auth, add_document, payload, expected_code, expected_message): + dataset_id, document_id = add_document res = list_chunks(get_http_api_auth, dataset_id, document_id) if res["code"] != 0: assert False, res @@ -98,8 +98,8 @@ class TestAddChunk: ({"content": "chunk test", "important_keywords": 123}, 102, "`important_keywords` is required to be a list"), ], ) - def test_important_keywords(self, get_http_api_auth, get_dataset_id_and_document_id, payload, expected_code, expected_message): - dataset_id, document_id = get_dataset_id_and_document_id + def test_important_keywords(self, get_http_api_auth, add_document, payload, expected_code, expected_message): + dataset_id, document_id = add_document res = list_chunks(get_http_api_auth, dataset_id, document_id) if res["code"] != 0: assert False, res @@ -126,8 +126,8 @@ class TestAddChunk: ({"content": "chunk test", "questions": 123}, 102, "`questions` is required to be a list"), ], ) - def test_questions(self, get_http_api_auth, get_dataset_id_and_document_id, payload, expected_code, expected_message): - dataset_id, document_id = get_dataset_id_and_document_id + def test_questions(self, get_http_api_auth, add_document, payload, expected_code, expected_message): + dataset_id, document_id = add_document res = list_chunks(get_http_api_auth, dataset_id, document_id) if res["code"] != 0: assert False, res @@ -157,12 +157,12 @@ class TestAddChunk: def test_invalid_dataset_id( self, get_http_api_auth, - get_dataset_id_and_document_id, + 
add_document, dataset_id, expected_code, expected_message, ): - _, document_id = get_dataset_id_and_document_id + _, document_id = add_document res = add_chunk(get_http_api_auth, dataset_id, document_id, {"content": "a"}) assert res["code"] == expected_code assert res["message"] == expected_message @@ -178,15 +178,15 @@ class TestAddChunk: ), ], ) - def test_invalid_document_id(self, get_http_api_auth, get_dataset_id_and_document_id, document_id, expected_code, expected_message): - dataset_id, _ = get_dataset_id_and_document_id + def test_invalid_document_id(self, get_http_api_auth, add_document, document_id, expected_code, expected_message): + dataset_id, _ = add_document res = add_chunk(get_http_api_auth, dataset_id, document_id, {"content": "chunk test"}) assert res["code"] == expected_code assert res["message"] == expected_message - def test_repeated_add_chunk(self, get_http_api_auth, get_dataset_id_and_document_id): + def test_repeated_add_chunk(self, get_http_api_auth, add_document): payload = {"content": "chunk test"} - dataset_id, document_id = get_dataset_id_and_document_id + dataset_id, document_id = add_document res = list_chunks(get_http_api_auth, dataset_id, document_id) if res["code"] != 0: assert False, res @@ -207,17 +207,17 @@ class TestAddChunk: assert False, res assert res["data"]["doc"]["chunk_count"] == chunks_count + 2 - def test_add_chunk_to_deleted_document(self, get_http_api_auth, get_dataset_id_and_document_id): - dataset_id, document_id = get_dataset_id_and_document_id - delete_documnet(get_http_api_auth, dataset_id, {"ids": [document_id]}) + def test_add_chunk_to_deleted_document(self, get_http_api_auth, add_document): + dataset_id, document_id = add_document + delete_documnets(get_http_api_auth, dataset_id, {"ids": [document_id]}) res = add_chunk(get_http_api_auth, dataset_id, document_id, {"content": "chunk test"}) assert res["code"] == 102 assert res["message"] == f"You don't own the document {document_id}." 
@pytest.mark.skip(reason="issues/6411") - def test_concurrent_add_chunk(self, get_http_api_auth, get_dataset_id_and_document_id): + def test_concurrent_add_chunk(self, get_http_api_auth, add_document): chunk_num = 50 - dataset_id, document_id = get_dataset_id_and_document_id + dataset_id, document_id = add_document res = list_chunks(get_http_api_auth, dataset_id, document_id) if res["code"] != 0: assert False, res diff --git a/sdk/python/test/test_http_api/test_chunk_management_within_dataset/test_delete_chunks.py b/sdk/python/test/test_http_api/test_chunk_management_within_dataset/test_delete_chunks.py index ba97dffbc..f04bc9622 100644 --- a/sdk/python/test/test_http_api/test_chunk_management_within_dataset/test_delete_chunks.py +++ b/sdk/python/test/test_http_api/test_chunk_management_within_dataset/test_delete_chunks.py @@ -39,7 +39,7 @@ class TestAuthorization: assert res["message"] == expected_message -class TestChunkstDeletion: +class TestChunksDeletion: @pytest.mark.parametrize( "dataset_id, expected_code, expected_message", [ @@ -61,25 +61,14 @@ class TestChunkstDeletion: "document_id, expected_code, expected_message", [ ("", 100, ""), - pytest.param( - "invalid_document_id", - 100, - "LookupError('Document not found which is supposed to be there')", - marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="issues/6611"), - ), - pytest.param( - "invalid_document_id", - 100, - "rm_chunk deleted chunks 0, expect 4", - marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in [None, "elasticsearch"], reason="issues/6611"), - ), + ("invalid_document_id", 100, """LookupError("Can't find the document with ID invalid_document_id!")"""), ], ) def test_invalid_document_id(self, get_http_api_auth, add_chunks_func, document_id, expected_code, expected_message): dataset_id, _, chunk_ids = add_chunks_func res = delete_chunks(get_http_api_auth, dataset_id, document_id, {"chunk_ids": chunk_ids}) assert res["code"] == expected_code - #assert res["message"] == expected_message + assert res["message"] == expected_message @pytest.mark.parametrize( "payload", diff --git a/sdk/python/test/test_http_api/test_chunk_management_within_dataset/test_list_chunks.py b/sdk/python/test/test_http_api/test_chunk_management_within_dataset/test_list_chunks.py index cf70fd7be..f930c971f 100644 --- a/sdk/python/test/test_http_api/test_chunk_management_within_dataset/test_list_chunks.py +++ b/sdk/python/test/test_http_api/test_chunk_management_within_dataset/test_list_chunks.py @@ -17,11 +17,7 @@ import os from concurrent.futures import ThreadPoolExecutor import pytest -from common import ( - INVALID_API_TOKEN, - batch_add_chunks, - list_chunks, -) +from common import INVALID_API_TOKEN, batch_add_chunks, list_chunks from libs.auth import RAGFlowHttpApiAuth @@ -153,8 +149,9 @@ class TestChunksList: assert all(r["code"] == 0 for r in responses) assert all(len(r["data"]["chunks"]) == 5 for r in responses) - def test_default(self, get_http_api_auth, get_dataset_id_and_document_id): - dataset_id, document_id = get_dataset_id_and_document_id + def test_default(self, get_http_api_auth, add_document): + dataset_id, document_id = add_document + res = list_chunks(get_http_api_auth, dataset_id, document_id) chunks_count = res["data"]["doc"]["chunk_count"] batch_add_chunks(get_http_api_auth, dataset_id, document_id, 31) diff --git a/sdk/python/test/test_http_api/test_chunk_management_within_dataset/test_retrieval_chunks.py b/sdk/python/test/test_http_api/test_chunk_management_within_dataset/test_retrieval_chunks.py index 
5b2103640..fd10443e6 100644 --- a/sdk/python/test/test_http_api/test_chunk_management_within_dataset/test_retrieval_chunks.py +++ b/sdk/python/test/test_http_api/test_chunk_management_within_dataset/test_retrieval_chunks.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # - import os import pytest @@ -52,9 +51,7 @@ class TestChunksRetrieval: ({"question": "chunk"}, 102, 0, "`dataset_ids` is required."), ], ) - def test_basic_scenarios( - self, get_http_api_auth, add_chunks, payload, expected_code, expected_page_size, expected_message - ): + def test_basic_scenarios(self, get_http_api_auth, add_chunks, payload, expected_code, expected_page_size, expected_message): dataset_id, document_id, _ = add_chunks if "dataset_ids" in payload: payload["dataset_ids"] = [dataset_id] @@ -137,9 +134,7 @@ class TestChunksRetrieval: ), ], ) - def test_page_size( - self, get_http_api_auth, add_chunks, payload, expected_code, expected_page_size, expected_message - ): + def test_page_size(self, get_http_api_auth, add_chunks, payload, expected_code, expected_page_size, expected_message): dataset_id, _, _ = add_chunks payload.update({"question": "chunk", "dataset_ids": [dataset_id]}) @@ -165,9 +160,7 @@ class TestChunksRetrieval: ), ], ) - def test_vector_similarity_weight( - self, get_http_api_auth, add_chunks, payload, expected_code, expected_page_size, expected_message - ): + def test_vector_similarity_weight(self, get_http_api_auth, add_chunks, payload, expected_code, expected_page_size, expected_message): dataset_id, _, _ = add_chunks payload.update({"question": "chunk", "dataset_ids": [dataset_id]}) res = retrieval_chunks(get_http_api_auth, payload) @@ -233,9 +226,7 @@ class TestChunksRetrieval: "payload, expected_code, expected_message", [ ({"rerank_id": "BAAI/bge-reranker-v2-m3"}, 0, ""), - pytest.param( - {"rerank_id": "unknown"}, 100, "LookupError('Model(unknown) not authorized')", marks=pytest.mark.skip - ), + pytest.param({"rerank_id": "unknown"}, 100, "LookupError('Model(unknown) not authorized')", marks=pytest.mark.skip), ], ) def test_rerank_id(self, get_http_api_auth, add_chunks, payload, expected_code, expected_message): @@ -248,7 +239,6 @@ class TestChunksRetrieval: else: assert expected_message in res["message"] - @pytest.mark.skip(reason="chat model is not set") @pytest.mark.parametrize( "payload, expected_code, expected_page_size, expected_message", [ @@ -279,9 +269,7 @@ class TestChunksRetrieval: pytest.param({"highlight": None}, 0, False, "", marks=pytest.mark.skip(reason="issues/6648")), ], ) - def test_highlight( - self, get_http_api_auth, add_chunks, payload, expected_code, expected_highlight, expected_message - ): + def test_highlight(self, get_http_api_auth, add_chunks, payload, expected_code, expected_highlight, expected_message): dataset_id, _, _ = add_chunks payload.update({"question": "chunk", "dataset_ids": [dataset_id]}) res = retrieval_chunks(get_http_api_auth, payload) @@ -302,3 +290,14 @@ class TestChunksRetrieval: res = retrieval_chunks(get_http_api_auth, payload) assert res["code"] == 0 assert len(res["data"]["chunks"]) == 4 + + def test_concurrent_retrieval(self, get_http_api_auth, add_chunks): + from concurrent.futures import ThreadPoolExecutor + + dataset_id, _, _ = add_chunks + payload = {"question": "chunk", "dataset_ids": [dataset_id]} + + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(retrieval_chunks, get_http_api_auth, payload) for i in range(100)] + responses 
= [f.result() for f in futures] + assert all(r["code"] == 0 for r in responses) diff --git a/sdk/python/test/test_http_api/test_chunk_management_within_dataset/test_update_chunk.py b/sdk/python/test/test_http_api/test_chunk_management_within_dataset/test_update_chunk.py index 57bad2715..b3fcfa008 100644 --- a/sdk/python/test/test_http_api/test_chunk_management_within_dataset/test_update_chunk.py +++ b/sdk/python/test/test_http_api/test_chunk_management_within_dataset/test_update_chunk.py @@ -18,7 +18,7 @@ from concurrent.futures import ThreadPoolExecutor from random import randint import pytest -from common import INVALID_API_TOKEN, delete_documnet, update_chunk +from common import INVALID_API_TOKEN, delete_documnets, update_chunk from libs.auth import RAGFlowHttpApiAuth @@ -233,7 +233,7 @@ class TestUpdatedChunk: def test_update_chunk_to_deleted_document(self, get_http_api_auth, add_chunks): dataset_id, document_id, chunk_ids = add_chunks - delete_documnet(get_http_api_auth, dataset_id, {"ids": [document_id]}) + delete_documnets(get_http_api_auth, dataset_id, {"ids": [document_id]}) res = update_chunk(get_http_api_auth, dataset_id, document_id, chunk_ids[0]) assert res["code"] == 102 assert res["message"] == f"Can't find this chunk {chunk_ids[0]}" diff --git a/sdk/python/test/test_http_api/test_dataset_mangement/conftest.py b/sdk/python/test/test_http_api/test_dataset_mangement/conftest.py index a9b22a0ef..6f7582912 100644 --- a/sdk/python/test/test_http_api/test_dataset_mangement/conftest.py +++ b/sdk/python/test/test_http_api/test_dataset_mangement/conftest.py @@ -16,14 +16,24 @@ import pytest -from common import batch_create_datasets, delete_dataset +from common import batch_create_datasets, delete_datasets @pytest.fixture(scope="class") -def get_dataset_ids(get_http_api_auth, request): +def add_datasets(get_http_api_auth, request): def cleanup(): - delete_dataset(get_http_api_auth) + delete_datasets(get_http_api_auth) request.addfinalizer(cleanup) return batch_create_datasets(get_http_api_auth, 5) + + +@pytest.fixture(scope="function") +def add_datasets_func(get_http_api_auth, request): + def cleanup(): + delete_datasets(get_http_api_auth) + + request.addfinalizer(cleanup) + + return batch_create_datasets(get_http_api_auth, 3) diff --git a/sdk/python/test/test_http_api/test_dataset_mangement/test_create_dataset.py b/sdk/python/test/test_http_api/test_dataset_mangement/test_create_dataset.py index d20c180cd..5f0a9d0fb 100644 --- a/sdk/python/test/test_http_api/test_dataset_mangement/test_create_dataset.py +++ b/sdk/python/test/test_http_api/test_dataset_mangement/test_create_dataset.py @@ -75,9 +75,6 @@ class TestDatasetCreation: res = create_dataset(get_http_api_auth, payload) assert res["code"] == 0, f"Failed to create dataset {i}" - -@pytest.mark.usefixtures("clear_datasets") -class TestAdvancedConfigurations: def test_avatar(self, get_http_api_auth, tmp_path): fn = create_image_file(tmp_path / "ragflow_test.png") payload = { diff --git a/sdk/python/test/test_http_api/test_dataset_mangement/test_delete_datasets.py b/sdk/python/test/test_http_api/test_dataset_mangement/test_delete_datasets.py index 5f40ca7f1..cd9a92e45 100644 --- a/sdk/python/test/test_http_api/test_dataset_mangement/test_delete_datasets.py +++ b/sdk/python/test/test_http_api/test_dataset_mangement/test_delete_datasets.py @@ -20,13 +20,12 @@ import pytest from common import ( INVALID_API_TOKEN, batch_create_datasets, - delete_dataset, - list_dataset, + delete_datasets, + list_datasets, ) from libs.auth import 
RAGFlowHttpApiAuth -@pytest.mark.usefixtures("clear_datasets") class TestAuthorization: @pytest.mark.parametrize( "auth, expected_code, expected_message", @@ -39,18 +38,13 @@ class TestAuthorization: ), ], ) - def test_invalid_auth(self, get_http_api_auth, auth, expected_code, expected_message): - ids = batch_create_datasets(get_http_api_auth, 1) - res = delete_dataset(auth, {"ids": ids}) + def test_invalid_auth(self, auth, expected_code, expected_message): + res = delete_datasets(auth) assert res["code"] == expected_code assert res["message"] == expected_message - res = list_dataset(get_http_api_auth) - assert len(res["data"]) == 1 - -@pytest.mark.usefixtures("clear_datasets") -class TestDatasetDeletion: +class TestDatasetsDeletion: @pytest.mark.parametrize( "payload, expected_code, expected_message, remaining", [ @@ -73,16 +67,16 @@ class TestDatasetDeletion: (lambda r: {"ids": r}, 0, "", 0), ], ) - def test_basic_scenarios(self, get_http_api_auth, payload, expected_code, expected_message, remaining): - ids = batch_create_datasets(get_http_api_auth, 3) + def test_basic_scenarios(self, get_http_api_auth, add_datasets_func, payload, expected_code, expected_message, remaining): + dataset_ids = add_datasets_func if callable(payload): - payload = payload(ids) - res = delete_dataset(get_http_api_auth, payload) + payload = payload(dataset_ids) + res = delete_datasets(get_http_api_auth, payload) assert res["code"] == expected_code if res["code"] != 0: assert res["message"] == expected_message - res = list_dataset(get_http_api_auth) + res = list_datasets(get_http_api_auth) assert len(res["data"]) == remaining @pytest.mark.parametrize( @@ -93,50 +87,50 @@ class TestDatasetDeletion: lambda r: {"ids": r + ["invalid_id"]}, ], ) - def test_delete_partial_invalid_id(self, get_http_api_auth, payload): - ids = batch_create_datasets(get_http_api_auth, 3) + def test_delete_partial_invalid_id(self, get_http_api_auth, add_datasets_func, payload): + dataset_ids = add_datasets_func if callable(payload): - payload = payload(ids) - res = delete_dataset(get_http_api_auth, payload) + payload = payload(dataset_ids) + res = delete_datasets(get_http_api_auth, payload) assert res["code"] == 0 assert res["data"]["errors"][0] == "You don't own the dataset invalid_id" assert res["data"]["success_count"] == 3 - res = list_dataset(get_http_api_auth) + res = list_datasets(get_http_api_auth) assert len(res["data"]) == 0 - def test_repeated_deletion(self, get_http_api_auth): - ids = batch_create_datasets(get_http_api_auth, 1) - res = delete_dataset(get_http_api_auth, {"ids": ids}) + def test_repeated_deletion(self, get_http_api_auth, add_datasets_func): + dataset_ids = add_datasets_func + res = delete_datasets(get_http_api_auth, {"ids": dataset_ids}) assert res["code"] == 0 - res = delete_dataset(get_http_api_auth, {"ids": ids}) + res = delete_datasets(get_http_api_auth, {"ids": dataset_ids}) assert res["code"] == 102 - assert res["message"] == f"You don't own the dataset {ids[0]}" + assert "You don't own the dataset" in res["message"] - def test_duplicate_deletion(self, get_http_api_auth): - ids = batch_create_datasets(get_http_api_auth, 1) - res = delete_dataset(get_http_api_auth, {"ids": ids + ids}) + def test_duplicate_deletion(self, get_http_api_auth, add_datasets_func): + dataset_ids = add_datasets_func + res = delete_datasets(get_http_api_auth, {"ids": dataset_ids + dataset_ids}) assert res["code"] == 0 - assert res["data"]["errors"][0] == f"Duplicate dataset ids: {ids[0]}" - assert res["data"]["success_count"] == 1 + 
assert "Duplicate dataset ids" in res["data"]["errors"][0] + assert res["data"]["success_count"] == 3 - res = list_dataset(get_http_api_auth) + res = list_datasets(get_http_api_auth) assert len(res["data"]) == 0 def test_concurrent_deletion(self, get_http_api_auth): ids = batch_create_datasets(get_http_api_auth, 100) with ThreadPoolExecutor(max_workers=5) as executor: - futures = [executor.submit(delete_dataset, get_http_api_auth, {"ids": ids[i : i + 1]}) for i in range(100)] + futures = [executor.submit(delete_datasets, get_http_api_auth, {"ids": ids[i : i + 1]}) for i in range(100)] responses = [f.result() for f in futures] assert all(r["code"] == 0 for r in responses) @pytest.mark.slow def test_delete_10k(self, get_http_api_auth): ids = batch_create_datasets(get_http_api_auth, 10_000) - res = delete_dataset(get_http_api_auth, {"ids": ids}) + res = delete_datasets(get_http_api_auth, {"ids": ids}) assert res["code"] == 0 - res = list_dataset(get_http_api_auth) + res = list_datasets(get_http_api_auth) assert len(res["data"]) == 0 diff --git a/sdk/python/test/test_http_api/test_dataset_mangement/test_list_datasets.py b/sdk/python/test/test_http_api/test_dataset_mangement/test_list_datasets.py index e19ed16ce..35f3057de 100644 --- a/sdk/python/test/test_http_api/test_dataset_mangement/test_list_datasets.py +++ b/sdk/python/test/test_http_api/test_dataset_mangement/test_list_datasets.py @@ -16,7 +16,7 @@ from concurrent.futures import ThreadPoolExecutor import pytest -from common import INVALID_API_TOKEN, list_dataset +from common import INVALID_API_TOKEN, list_datasets from libs.auth import RAGFlowHttpApiAuth @@ -25,7 +25,6 @@ def is_sorted(data, field, descending=True): return all(a >= b for a, b in zip(timestamps, timestamps[1:])) if descending else all(a <= b for a, b in zip(timestamps, timestamps[1:])) -@pytest.mark.usefixtures("clear_datasets") class TestAuthorization: @pytest.mark.parametrize( "auth, expected_code, expected_message", @@ -39,15 +38,15 @@ class TestAuthorization: ], ) def test_invalid_auth(self, auth, expected_code, expected_message): - res = list_dataset(auth) + res = list_datasets(auth) assert res["code"] == expected_code assert res["message"] == expected_message -@pytest.mark.usefixtures("get_dataset_ids") -class TestDatasetList: +@pytest.mark.usefixtures("add_datasets") +class TestDatasetsList: def test_default(self, get_http_api_auth): - res = list_dataset(get_http_api_auth, params={}) + res = list_datasets(get_http_api_auth, params={}) assert res["code"] == 0 assert len(res["data"]) == 5 @@ -77,7 +76,7 @@ class TestDatasetList: ], ) def test_page(self, get_http_api_auth, params, expected_code, expected_page_size, expected_message): - res = list_dataset(get_http_api_auth, params=params) + res = list_datasets(get_http_api_auth, params=params) assert res["code"] == expected_code if expected_code == 0: assert len(res["data"]) == expected_page_size @@ -116,7 +115,7 @@ class TestDatasetList: expected_page_size, expected_message, ): - res = list_dataset(get_http_api_auth, params=params) + res = list_datasets(get_http_api_auth, params=params) assert res["code"] == expected_code if expected_code == 0: assert len(res["data"]) == expected_page_size @@ -168,7 +167,7 @@ class TestDatasetList: assertions, expected_message, ): - res = list_dataset(get_http_api_auth, params=params) + res = list_datasets(get_http_api_auth, params=params) assert res["code"] == expected_code if expected_code == 0: if callable(assertions): @@ -244,7 +243,7 @@ class TestDatasetList: assertions, 
expected_message, ): - res = list_dataset(get_http_api_auth, params=params) + res = list_datasets(get_http_api_auth, params=params) assert res["code"] == expected_code if expected_code == 0: if callable(assertions): @@ -262,7 +261,7 @@ class TestDatasetList: ], ) def test_name(self, get_http_api_auth, params, expected_code, expected_num, expected_message): - res = list_dataset(get_http_api_auth, params=params) + res = list_datasets(get_http_api_auth, params=params) assert res["code"] == expected_code if expected_code == 0: if params["name"] in [None, ""]: @@ -284,19 +283,19 @@ class TestDatasetList: def test_id( self, get_http_api_auth, - get_dataset_ids, + add_datasets, dataset_id, expected_code, expected_num, expected_message, ): - dataset_ids = get_dataset_ids + dataset_ids = add_datasets if callable(dataset_id): params = {"id": dataset_id(dataset_ids)} else: params = {"id": dataset_id} - res = list_dataset(get_http_api_auth, params=params) + res = list_datasets(get_http_api_auth, params=params) assert res["code"] == expected_code if expected_code == 0: if params["id"] in [None, ""]: @@ -318,20 +317,20 @@ class TestDatasetList: def test_name_and_id( self, get_http_api_auth, - get_dataset_ids, + add_datasets, dataset_id, name, expected_code, expected_num, expected_message, ): - dataset_ids = get_dataset_ids + dataset_ids = add_datasets if callable(dataset_id): params = {"id": dataset_id(dataset_ids), "name": name} else: params = {"id": dataset_id, "name": name} - res = list_dataset(get_http_api_auth, params=params) + res = list_datasets(get_http_api_auth, params=params) if expected_code == 0: assert len(res["data"]) == expected_num else: @@ -339,12 +338,12 @@ class TestDatasetList: def test_concurrent_list(self, get_http_api_auth): with ThreadPoolExecutor(max_workers=5) as executor: - futures = [executor.submit(list_dataset, get_http_api_auth) for i in range(100)] + futures = [executor.submit(list_datasets, get_http_api_auth) for i in range(100)] responses = [f.result() for f in futures] assert all(r["code"] == 0 for r in responses) def test_invalid_params(self, get_http_api_auth): params = {"a": "b"} - res = list_dataset(get_http_api_auth, params=params) + res = list_datasets(get_http_api_auth, params=params) assert res["code"] == 0 assert len(res["data"]) == 5 diff --git a/sdk/python/test/test_http_api/test_dataset_mangement/test_update_dataset.py b/sdk/python/test/test_http_api/test_dataset_mangement/test_update_dataset.py index f160a3b74..f189855dd 100644 --- a/sdk/python/test/test_http_api/test_dataset_mangement/test_update_dataset.py +++ b/sdk/python/test/test_http_api/test_dataset_mangement/test_update_dataset.py @@ -19,8 +19,7 @@ import pytest from common import ( DATASET_NAME_LIMIT, INVALID_API_TOKEN, - batch_create_datasets, - list_dataset, + list_datasets, update_dataset, ) from libs.auth import RAGFlowHttpApiAuth @@ -30,7 +29,6 @@ from libs.utils.file_utils import create_image_file # TODO: Missing scenario for updating embedding_model with chunk_count != 0 -@pytest.mark.usefixtures("clear_datasets") class TestAuthorization: @pytest.mark.parametrize( "auth, expected_code, expected_message", @@ -43,14 +41,12 @@ class TestAuthorization: ), ], ) - def test_invalid_auth(self, get_http_api_auth, auth, expected_code, expected_message): - ids = batch_create_datasets(get_http_api_auth, 1) - res = update_dataset(auth, ids[0], {"name": "new_name"}) + def test_invalid_auth(self, auth, expected_code, expected_message): + res = update_dataset(auth, "dataset_id") assert res["code"] == 
expected_code assert res["message"] == expected_message -@pytest.mark.usefixtures("clear_datasets") class TestDatasetUpdate: @pytest.mark.parametrize( "name, expected_code, expected_message", @@ -72,12 +68,12 @@ class TestDatasetUpdate: ("DATASET_1", 102, "Duplicated dataset name in updating dataset."), ], ) - def test_name(self, get_http_api_auth, name, expected_code, expected_message): - ids = batch_create_datasets(get_http_api_auth, 2) - res = update_dataset(get_http_api_auth, ids[0], {"name": name}) + def test_name(self, get_http_api_auth, add_datasets_func, name, expected_code, expected_message): + dataset_ids = add_datasets_func + res = update_dataset(get_http_api_auth, dataset_ids[0], {"name": name}) assert res["code"] == expected_code if expected_code == 0: - res = list_dataset(get_http_api_auth, {"id": ids[0]}) + res = list_datasets(get_http_api_auth, {"id": dataset_ids[0]}) assert res["data"][0]["name"] == name else: assert res["message"] == expected_message @@ -95,12 +91,12 @@ class TestDatasetUpdate: (None, 102, "`embedding_model` can't be empty"), ], ) - def test_embedding_model(self, get_http_api_auth, embedding_model, expected_code, expected_message): - ids = batch_create_datasets(get_http_api_auth, 1) - res = update_dataset(get_http_api_auth, ids[0], {"embedding_model": embedding_model}) + def test_embedding_model(self, get_http_api_auth, add_dataset_func, embedding_model, expected_code, expected_message): + dataset_id = add_dataset_func + res = update_dataset(get_http_api_auth, dataset_id, {"embedding_model": embedding_model}) assert res["code"] == expected_code if expected_code == 0: - res = list_dataset(get_http_api_auth, {"id": ids[0]}) + res = list_datasets(get_http_api_auth, {"id": dataset_id}) assert res["data"][0]["embedding_model"] == embedding_model else: assert res["message"] == expected_message @@ -129,12 +125,12 @@ class TestDatasetUpdate: ), ], ) - def test_chunk_method(self, get_http_api_auth, chunk_method, expected_code, expected_message): - ids = batch_create_datasets(get_http_api_auth, 1) - res = update_dataset(get_http_api_auth, ids[0], {"chunk_method": chunk_method}) + def test_chunk_method(self, get_http_api_auth, add_dataset_func, chunk_method, expected_code, expected_message): + dataset_id = add_dataset_func + res = update_dataset(get_http_api_auth, dataset_id, {"chunk_method": chunk_method}) assert res["code"] == expected_code if expected_code == 0: - res = list_dataset(get_http_api_auth, {"id": ids[0]}) + res = list_datasets(get_http_api_auth, {"id": dataset_id}) if chunk_method != "": assert res["data"][0]["chunk_method"] == chunk_method else: @@ -142,38 +138,38 @@ class TestDatasetUpdate: else: assert res["message"] == expected_message - def test_avatar(self, get_http_api_auth, tmp_path): - ids = batch_create_datasets(get_http_api_auth, 1) + def test_avatar(self, get_http_api_auth, add_dataset_func, tmp_path): + dataset_id = add_dataset_func fn = create_image_file(tmp_path / "ragflow_test.png") payload = {"avatar": encode_avatar(fn)} - res = update_dataset(get_http_api_auth, ids[0], payload) + res = update_dataset(get_http_api_auth, dataset_id, payload) assert res["code"] == 0 - def test_description(self, get_http_api_auth): - ids = batch_create_datasets(get_http_api_auth, 1) + def test_description(self, get_http_api_auth, add_dataset_func): + dataset_id = add_dataset_func payload = {"description": "description"} - res = update_dataset(get_http_api_auth, ids[0], payload) + res = update_dataset(get_http_api_auth, dataset_id, payload) assert 
res["code"] == 0 - res = list_dataset(get_http_api_auth, {"id": ids[0]}) + res = list_datasets(get_http_api_auth, {"id": dataset_id}) assert res["data"][0]["description"] == "description" - def test_pagerank(self, get_http_api_auth): - ids = batch_create_datasets(get_http_api_auth, 1) + def test_pagerank(self, get_http_api_auth, add_dataset_func): + dataset_id = add_dataset_func payload = {"pagerank": 1} - res = update_dataset(get_http_api_auth, ids[0], payload) + res = update_dataset(get_http_api_auth, dataset_id, payload) assert res["code"] == 0 - res = list_dataset(get_http_api_auth, {"id": ids[0]}) + res = list_datasets(get_http_api_auth, {"id": dataset_id}) assert res["data"][0]["pagerank"] == 1 - def test_similarity_threshold(self, get_http_api_auth): - ids = batch_create_datasets(get_http_api_auth, 1) + def test_similarity_threshold(self, get_http_api_auth, add_dataset_func): + dataset_id = add_dataset_func payload = {"similarity_threshold": 1} - res = update_dataset(get_http_api_auth, ids[0], payload) + res = update_dataset(get_http_api_auth, dataset_id, payload) assert res["code"] == 0 - res = list_dataset(get_http_api_auth, {"id": ids[0]}) + res = list_datasets(get_http_api_auth, {"id": dataset_id}) assert res["data"][0]["similarity_threshold"] == 1 @pytest.mark.parametrize( @@ -187,29 +183,28 @@ class TestDatasetUpdate: ("other_permission", 102), ], ) - def test_permission(self, get_http_api_auth, permission, expected_code): - ids = batch_create_datasets(get_http_api_auth, 1) + def test_permission(self, get_http_api_auth, add_dataset_func, permission, expected_code): + dataset_id = add_dataset_func payload = {"permission": permission} - res = update_dataset(get_http_api_auth, ids[0], payload) + res = update_dataset(get_http_api_auth, dataset_id, payload) assert res["code"] == expected_code - res = list_dataset(get_http_api_auth, {"id": ids[0]}) + res = list_datasets(get_http_api_auth, {"id": dataset_id}) if expected_code == 0 and permission != "": assert res["data"][0]["permission"] == permission if permission == "": assert res["data"][0]["permission"] == "me" - def test_vector_similarity_weight(self, get_http_api_auth): - ids = batch_create_datasets(get_http_api_auth, 1) + def test_vector_similarity_weight(self, get_http_api_auth, add_dataset_func): + dataset_id = add_dataset_func payload = {"vector_similarity_weight": 1} - res = update_dataset(get_http_api_auth, ids[0], payload) + res = update_dataset(get_http_api_auth, dataset_id, payload) assert res["code"] == 0 - res = list_dataset(get_http_api_auth, {"id": ids[0]}) + res = list_datasets(get_http_api_auth, {"id": dataset_id}) assert res["data"][0]["vector_similarity_weight"] == 1 def test_invalid_dataset_id(self, get_http_api_auth): - batch_create_datasets(get_http_api_auth, 1) res = update_dataset(get_http_api_auth, "invalid_dataset_id", {"name": "invalid_dataset_id"}) assert res["code"] == 102 assert res["message"] == "You don't own the dataset" @@ -230,21 +225,21 @@ class TestDatasetUpdate: {"update_time": 1741671443339}, ], ) - def test_modify_read_only_field(self, get_http_api_auth, payload): - ids = batch_create_datasets(get_http_api_auth, 1) - res = update_dataset(get_http_api_auth, ids[0], payload) + def test_modify_read_only_field(self, get_http_api_auth, add_dataset_func, payload): + dataset_id = add_dataset_func + res = update_dataset(get_http_api_auth, dataset_id, payload) assert res["code"] == 101 assert "is readonly" in res["message"] - def test_modify_unknown_field(self, get_http_api_auth): - ids = 
batch_create_datasets(get_http_api_auth, 1) - res = update_dataset(get_http_api_auth, ids[0], {"unknown_field": 0}) + def test_modify_unknown_field(self, get_http_api_auth, add_dataset_func): + dataset_id = add_dataset_func + res = update_dataset(get_http_api_auth, dataset_id, {"unknown_field": 0}) assert res["code"] == 100 - def test_concurrent_update(self, get_http_api_auth): - ids = batch_create_datasets(get_http_api_auth, 1) + def test_concurrent_update(self, get_http_api_auth, add_dataset_func): + dataset_id = add_dataset_func with ThreadPoolExecutor(max_workers=5) as executor: - futures = [executor.submit(update_dataset, get_http_api_auth, ids[0], {"name": f"dataset_{i}"}) for i in range(100)] + futures = [executor.submit(update_dataset, get_http_api_auth, dataset_id, {"name": f"dataset_{i}"}) for i in range(100)] responses = [f.result() for f in futures] assert all(r["code"] == 0 for r in responses) diff --git a/sdk/python/test/test_http_api/test_file_management_within_dataset/conftest.py b/sdk/python/test/test_http_api/test_file_management_within_dataset/conftest.py index df871507a..3f48b2056 100644 --- a/sdk/python/test/test_http_api/test_file_management_within_dataset/conftest.py +++ b/sdk/python/test/test_http_api/test_file_management_within_dataset/conftest.py @@ -16,22 +16,36 @@ import pytest -from common import batch_create_datasets, bulk_upload_documents, delete_dataset +from common import bulk_upload_documents, delete_documnets -@pytest.fixture(scope="class") -def file_management_tmp_dir(tmp_path_factory): - return tmp_path_factory.mktemp("file_management") +@pytest.fixture(scope="function") +def add_document_func(request, get_http_api_auth, add_dataset, ragflow_tmp_dir): + dataset_id = add_dataset + document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, 1, ragflow_tmp_dir) - -@pytest.fixture(scope="class") -def get_dataset_id_and_document_ids(get_http_api_auth, file_management_tmp_dir, request): def cleanup(): - delete_dataset(get_http_api_auth) + delete_documnets(get_http_api_auth, dataset_id, {"ids": document_ids}) request.addfinalizer(cleanup) + return dataset_id, document_ids[0] + + +@pytest.fixture(scope="class") +def add_documents(request, get_http_api_auth, add_dataset, ragflow_tmp_dir): + dataset_id = add_dataset + document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, 5, ragflow_tmp_dir) + + def cleanup(): + delete_documnets(get_http_api_auth, dataset_id, {"ids": document_ids}) + + request.addfinalizer(cleanup) + return dataset_id, document_ids + + +@pytest.fixture(scope="function") +def add_documents_func(get_http_api_auth, add_dataset_func, ragflow_tmp_dir): + dataset_id = add_dataset_func + document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, 3, ragflow_tmp_dir) - dataset_ids = batch_create_datasets(get_http_api_auth, 1) - dataset_id = dataset_ids[0] - document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, 5, file_management_tmp_dir) return dataset_id, document_ids diff --git a/sdk/python/test/test_http_api/test_file_management_within_dataset/test_delete_documents.py b/sdk/python/test/test_http_api/test_file_management_within_dataset/test_delete_documents.py index acfc74343..60aaa7ae9 100644 --- a/sdk/python/test/test_http_api/test_file_management_within_dataset/test_delete_documents.py +++ b/sdk/python/test/test_http_api/test_file_management_within_dataset/test_delete_documents.py @@ -16,13 +16,7 @@ from concurrent.futures import ThreadPoolExecutor import pytest -from common import ( - INVALID_API_TOKEN, - 
batch_create_datasets, - bulk_upload_documents, - delete_documnet, - list_documnet, -) +from common import INVALID_API_TOKEN, bulk_upload_documents, delete_documnets, list_documnets from libs.auth import RAGFlowHttpApiAuth @@ -38,15 +32,13 @@ class TestAuthorization: ), ], ) - def test_invalid_auth(self, get_dataset_id_and_document_ids, auth, expected_code, expected_message): - dataset_id, document_ids = get_dataset_id_and_document_ids - res = delete_documnet(auth, dataset_id, {"ids": document_ids}) + def test_invalid_auth(self, auth, expected_code, expected_message): + res = delete_documnets(auth, "dataset_id") assert res["code"] == expected_code assert res["message"] == expected_message -@pytest.mark.usefixtures("clear_datasets") -class TestDocumentDeletion: +class TestDocumentsDeletion: @pytest.mark.parametrize( "payload, expected_code, expected_message, remaining", [ @@ -72,22 +64,21 @@ class TestDocumentDeletion: def test_basic_scenarios( self, get_http_api_auth, - tmp_path, + add_documents_func, payload, expected_code, expected_message, remaining, ): - ids = batch_create_datasets(get_http_api_auth, 1) - document_ids = bulk_upload_documents(get_http_api_auth, ids[0], 3, tmp_path) + dataset_id, document_ids = add_documents_func if callable(payload): payload = payload(document_ids) - res = delete_documnet(get_http_api_auth, ids[0], payload) + res = delete_documnets(get_http_api_auth, dataset_id, payload) assert res["code"] == expected_code if res["code"] != 0: assert res["message"] == expected_message - res = list_documnet(get_http_api_auth, ids[0]) + res = list_documnets(get_http_api_auth, dataset_id) assert len(res["data"]["docs"]) == remaining assert res["data"]["total"] == remaining @@ -102,10 +93,9 @@ class TestDocumentDeletion: ), ], ) - def test_invalid_dataset_id(self, get_http_api_auth, tmp_path, dataset_id, expected_code, expected_message): - ids = batch_create_datasets(get_http_api_auth, 1) - document_ids = bulk_upload_documents(get_http_api_auth, ids[0], 3, tmp_path) - res = delete_documnet(get_http_api_auth, dataset_id, {"ids": document_ids[:1]}) + def test_invalid_dataset_id(self, get_http_api_auth, add_documents_func, dataset_id, expected_code, expected_message): + _, document_ids = add_documents_func + res = delete_documnets(get_http_api_auth, dataset_id, {"ids": document_ids[:1]}) assert res["code"] == expected_code assert res["message"] == expected_message @@ -117,69 +107,68 @@ class TestDocumentDeletion: lambda r: {"ids": r + ["invalid_id"]}, ], ) - def test_delete_partial_invalid_id(self, get_http_api_auth, tmp_path, payload): - ids = batch_create_datasets(get_http_api_auth, 1) - document_ids = bulk_upload_documents(get_http_api_auth, ids[0], 3, tmp_path) + def test_delete_partial_invalid_id(self, get_http_api_auth, add_documents_func, payload): + dataset_id, document_ids = add_documents_func if callable(payload): payload = payload(document_ids) - res = delete_documnet(get_http_api_auth, ids[0], payload) + res = delete_documnets(get_http_api_auth, dataset_id, payload) assert res["code"] == 102 assert res["message"] == "Documents not found: ['invalid_id']" - res = list_documnet(get_http_api_auth, ids[0]) + res = list_documnets(get_http_api_auth, dataset_id) assert len(res["data"]["docs"]) == 0 assert res["data"]["total"] == 0 - def test_repeated_deletion(self, get_http_api_auth, tmp_path): - ids = batch_create_datasets(get_http_api_auth, 1) - document_ids = bulk_upload_documents(get_http_api_auth, ids[0], 1, tmp_path) - res = delete_documnet(get_http_api_auth, ids[0], 
{"ids": document_ids}) + def test_repeated_deletion(self, get_http_api_auth, add_documents_func): + dataset_id, document_ids = add_documents_func + res = delete_documnets(get_http_api_auth, dataset_id, {"ids": document_ids}) assert res["code"] == 0 - res = delete_documnet(get_http_api_auth, ids[0], {"ids": document_ids}) + res = delete_documnets(get_http_api_auth, dataset_id, {"ids": document_ids}) assert res["code"] == 102 - assert res["message"] == f"Documents not found: {document_ids}" + assert "Documents not found" in res["message"] - def test_duplicate_deletion(self, get_http_api_auth, tmp_path): - ids = batch_create_datasets(get_http_api_auth, 1) - document_ids = bulk_upload_documents(get_http_api_auth, ids[0], 1, tmp_path) - res = delete_documnet(get_http_api_auth, ids[0], {"ids": document_ids + document_ids}) + def test_duplicate_deletion(self, get_http_api_auth, add_documents_func): + dataset_id, document_ids = add_documents_func + res = delete_documnets(get_http_api_auth, dataset_id, {"ids": document_ids + document_ids}) assert res["code"] == 0 - assert res["data"]["errors"][0] == f"Duplicate document ids: {document_ids[0]}" - assert res["data"]["success_count"] == 1 + assert "Duplicate document ids" in res["data"]["errors"][0] + assert res["data"]["success_count"] == 3 - res = list_documnet(get_http_api_auth, ids[0]) + res = list_documnets(get_http_api_auth, dataset_id) assert len(res["data"]["docs"]) == 0 assert res["data"]["total"] == 0 - def test_concurrent_deletion(self, get_http_api_auth, tmp_path): - documnets_num = 100 - ids = batch_create_datasets(get_http_api_auth, 1) - document_ids = bulk_upload_documents(get_http_api_auth, ids[0], documnets_num, tmp_path) - with ThreadPoolExecutor(max_workers=5) as executor: - futures = [ - executor.submit( - delete_documnet, - get_http_api_auth, - ids[0], - {"ids": document_ids[i : i + 1]}, - ) - for i in range(documnets_num) - ] - responses = [f.result() for f in futures] - assert all(r["code"] == 0 for r in responses) +def test_concurrent_deletion(get_http_api_auth, add_dataset, tmp_path): + documnets_num = 100 + dataset_id = add_dataset + document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, documnets_num, tmp_path) - @pytest.mark.slow - def test_delete_1k(self, get_http_api_auth, tmp_path): - documnets_num = 1_000 - ids = batch_create_datasets(get_http_api_auth, 1) - document_ids = bulk_upload_documents(get_http_api_auth, ids[0], documnets_num, tmp_path) - res = list_documnet(get_http_api_auth, ids[0]) - assert res["data"]["total"] == documnets_num + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [ + executor.submit( + delete_documnets, + get_http_api_auth, + dataset_id, + {"ids": document_ids[i : i + 1]}, + ) + for i in range(documnets_num) + ] + responses = [f.result() for f in futures] + assert all(r["code"] == 0 for r in responses) - res = delete_documnet(get_http_api_auth, ids[0], {"ids": document_ids}) - assert res["code"] == 0 - res = list_documnet(get_http_api_auth, ids[0]) - assert res["data"]["total"] == 0 +@pytest.mark.slow +def test_delete_1k(get_http_api_auth, add_dataset, tmp_path): + documnets_num = 1_000 + dataset_id = add_dataset + document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, documnets_num, tmp_path) + res = list_documnets(get_http_api_auth, dataset_id) + assert res["data"]["total"] == documnets_num + + res = delete_documnets(get_http_api_auth, dataset_id, {"ids": document_ids}) + assert res["code"] == 0 + + res = list_documnets(get_http_api_auth, dataset_id) 
+ assert res["data"]["total"] == 0 diff --git a/sdk/python/test/test_http_api/test_file_management_within_dataset/test_download_document.py b/sdk/python/test/test_http_api/test_file_management_within_dataset/test_download_document.py index 0ec4c5fc9..0b11218e0 100644 --- a/sdk/python/test/test_http_api/test_file_management_within_dataset/test_download_document.py +++ b/sdk/python/test/test_http_api/test_file_management_within_dataset/test_download_document.py @@ -18,7 +18,7 @@ import json from concurrent.futures import ThreadPoolExecutor import pytest -from common import INVALID_API_TOKEN, batch_create_datasets, bulk_upload_documents, download_document, upload_documnets +from common import INVALID_API_TOKEN, bulk_upload_documents, download_document, upload_documnets from libs.auth import RAGFlowHttpApiAuth from libs.utils import compare_by_hash from requests import codes @@ -36,9 +36,8 @@ class TestAuthorization: ), ], ) - def test_invalid_auth(self, get_dataset_id_and_document_ids, tmp_path, auth, expected_code, expected_message): - dataset_id, document_ids = get_dataset_id_and_document_ids - res = download_document(auth, dataset_id, document_ids[0], tmp_path / "ragflow_tes.txt") + def test_invalid_auth(self, tmp_path, auth, expected_code, expected_message): + res = download_document(auth, "dataset_id", "document_id", tmp_path / "ragflow_tes.txt") assert res.status_code == codes.ok with (tmp_path / "ragflow_tes.txt").open("r") as f: response_json = json.load(f) @@ -46,7 +45,6 @@ class TestAuthorization: assert response_json["message"] == expected_message -@pytest.mark.usefixtures("clear_datasets") @pytest.mark.parametrize( "generate_test_files", [ @@ -63,15 +61,15 @@ class TestAuthorization: ], indirect=True, ) -def test_file_type_validation(get_http_api_auth, generate_test_files, request): - ids = batch_create_datasets(get_http_api_auth, 1) +def test_file_type_validation(get_http_api_auth, add_dataset, generate_test_files, request): + dataset_id = add_dataset fp = generate_test_files[request.node.callspec.params["generate_test_files"]] - res = upload_documnets(get_http_api_auth, ids[0], [fp]) + res = upload_documnets(get_http_api_auth, dataset_id, [fp]) document_id = res["data"][0]["id"] res = download_document( get_http_api_auth, - ids[0], + dataset_id, document_id, fp.with_stem("ragflow_test_download"), ) @@ -93,8 +91,8 @@ class TestDocumentDownload: ), ], ) - def test_invalid_document_id(self, get_http_api_auth, get_dataset_id_and_document_ids, tmp_path, document_id, expected_code, expected_message): - dataset_id, _ = get_dataset_id_and_document_ids + def test_invalid_document_id(self, get_http_api_auth, add_documents, tmp_path, document_id, expected_code, expected_message): + dataset_id, _ = add_documents res = download_document( get_http_api_auth, dataset_id, @@ -118,8 +116,8 @@ class TestDocumentDownload: ), ], ) - def test_invalid_dataset_id(self, get_http_api_auth, get_dataset_id_and_document_ids, tmp_path, dataset_id, expected_code, expected_message): - _, document_ids = get_dataset_id_and_document_ids + def test_invalid_dataset_id(self, get_http_api_auth, add_documents, tmp_path, dataset_id, expected_code, expected_message): + _, document_ids = add_documents res = download_document( get_http_api_auth, dataset_id, @@ -132,9 +130,9 @@ class TestDocumentDownload: assert response_json["code"] == expected_code assert response_json["message"] == expected_message - def test_same_file_repeat(self, get_http_api_auth, get_dataset_id_and_document_ids, tmp_path, file_management_tmp_dir): + 
def test_same_file_repeat(self, get_http_api_auth, add_documents, tmp_path, ragflow_tmp_dir): num = 5 - dataset_id, document_ids = get_dataset_id_and_document_ids + dataset_id, document_ids = add_documents for i in range(num): res = download_document( get_http_api_auth, @@ -144,23 +142,22 @@ class TestDocumentDownload: ) assert res.status_code == codes.ok assert compare_by_hash( - file_management_tmp_dir / "ragflow_test_upload_0.txt", + ragflow_tmp_dir / "ragflow_test_upload_0.txt", tmp_path / f"ragflow_test_download_{i}.txt", ) -@pytest.mark.usefixtures("clear_datasets") -def test_concurrent_download(get_http_api_auth, tmp_path): +def test_concurrent_download(get_http_api_auth, add_dataset, tmp_path): document_count = 20 - ids = batch_create_datasets(get_http_api_auth, 1) - document_ids = bulk_upload_documents(get_http_api_auth, ids[0], document_count, tmp_path) + dataset_id = add_dataset + document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, document_count, tmp_path) with ThreadPoolExecutor(max_workers=5) as executor: futures = [ executor.submit( download_document, get_http_api_auth, - ids[0], + dataset_id, document_ids[i], tmp_path / f"ragflow_test_download_{i}.txt", ) diff --git a/sdk/python/test/test_http_api/test_file_management_within_dataset/test_list_documents.py b/sdk/python/test/test_http_api/test_file_management_within_dataset/test_list_documents.py index be623318c..15111ece9 100644 --- a/sdk/python/test/test_http_api/test_file_management_within_dataset/test_list_documents.py +++ b/sdk/python/test/test_http_api/test_file_management_within_dataset/test_list_documents.py @@ -16,10 +16,7 @@ from concurrent.futures import ThreadPoolExecutor import pytest -from common import ( - INVALID_API_TOKEN, - list_documnet, -) +from common import INVALID_API_TOKEN, list_documnets from libs.auth import RAGFlowHttpApiAuth @@ -40,17 +37,16 @@ class TestAuthorization: ), ], ) - def test_invalid_auth(self, get_dataset_id_and_document_ids, auth, expected_code, expected_message): - dataset_id, _ = get_dataset_id_and_document_ids - res = list_documnet(auth, dataset_id) + def test_invalid_auth(self, auth, expected_code, expected_message): + res = list_documnets(auth, "dataset_id") assert res["code"] == expected_code assert res["message"] == expected_message -class TestDocumentList: - def test_default(self, get_http_api_auth, get_dataset_id_and_document_ids): - dataset_id, _ = get_dataset_id_and_document_ids - res = list_documnet(get_http_api_auth, dataset_id) +class TestDocumentsList: + def test_default(self, get_http_api_auth, add_documents): + dataset_id, _ = add_documents + res = list_documnets(get_http_api_auth, dataset_id) assert res["code"] == 0 assert len(res["data"]["docs"]) == 5 assert res["data"]["total"] == 5 @@ -66,8 +62,8 @@ class TestDocumentList: ), ], ) - def test_invalid_dataset_id(self, get_http_api_auth, get_dataset_id_and_document_ids, dataset_id, expected_code, expected_message): - res = list_documnet(get_http_api_auth, dataset_id) + def test_invalid_dataset_id(self, get_http_api_auth, dataset_id, expected_code, expected_message): + res = list_documnets(get_http_api_auth, dataset_id) assert res["code"] == expected_code assert res["message"] == expected_message @@ -98,14 +94,14 @@ class TestDocumentList: def test_page( self, get_http_api_auth, - get_dataset_id_and_document_ids, + add_documents, params, expected_code, expected_page_size, expected_message, ): - dataset_id, _ = get_dataset_id_and_document_ids - res = list_documnet(get_http_api_auth, dataset_id, 
params=params) + dataset_id, _ = add_documents + res = list_documnets(get_http_api_auth, dataset_id, params=params) assert res["code"] == expected_code if expected_code == 0: assert len(res["data"]["docs"]) == expected_page_size @@ -140,14 +136,14 @@ class TestDocumentList: def test_page_size( self, get_http_api_auth, - get_dataset_id_and_document_ids, + add_documents, params, expected_code, expected_page_size, expected_message, ): - dataset_id, _ = get_dataset_id_and_document_ids - res = list_documnet(get_http_api_auth, dataset_id, params=params) + dataset_id, _ = add_documents + res = list_documnets(get_http_api_auth, dataset_id, params=params) assert res["code"] == expected_code if expected_code == 0: assert len(res["data"]["docs"]) == expected_page_size @@ -194,14 +190,14 @@ class TestDocumentList: def test_orderby( self, get_http_api_auth, - get_dataset_id_and_document_ids, + add_documents, params, expected_code, assertions, expected_message, ): - dataset_id, _ = get_dataset_id_and_document_ids - res = list_documnet(get_http_api_auth, dataset_id, params=params) + dataset_id, _ = add_documents + res = list_documnets(get_http_api_auth, dataset_id, params=params) assert res["code"] == expected_code if expected_code == 0: if callable(assertions): @@ -273,14 +269,14 @@ class TestDocumentList: def test_desc( self, get_http_api_auth, - get_dataset_id_and_document_ids, + add_documents, params, expected_code, assertions, expected_message, ): - dataset_id, _ = get_dataset_id_and_document_ids - res = list_documnet(get_http_api_auth, dataset_id, params=params) + dataset_id, _ = add_documents + res = list_documnets(get_http_api_auth, dataset_id, params=params) assert res["code"] == expected_code if expected_code == 0: if callable(assertions): @@ -298,9 +294,9 @@ class TestDocumentList: ({"keywords": "unknown"}, 0), ], ) - def test_keywords(self, get_http_api_auth, get_dataset_id_and_document_ids, params, expected_num): - dataset_id, _ = get_dataset_id_and_document_ids - res = list_documnet(get_http_api_auth, dataset_id, params=params) + def test_keywords(self, get_http_api_auth, add_documents, params, expected_num): + dataset_id, _ = add_documents + res = list_documnets(get_http_api_auth, dataset_id, params=params) assert res["code"] == 0 assert len(res["data"]["docs"]) == expected_num assert res["data"]["total"] == expected_num @@ -322,14 +318,14 @@ class TestDocumentList: def test_name( self, get_http_api_auth, - get_dataset_id_and_document_ids, + add_documents, params, expected_code, expected_num, expected_message, ): - dataset_id, _ = get_dataset_id_and_document_ids - res = list_documnet(get_http_api_auth, dataset_id, params=params) + dataset_id, _ = add_documents + res = list_documnets(get_http_api_auth, dataset_id, params=params) assert res["code"] == expected_code if expected_code == 0: if params["name"] in [None, ""]: @@ -351,18 +347,18 @@ class TestDocumentList: def test_id( self, get_http_api_auth, - get_dataset_id_and_document_ids, + add_documents, document_id, expected_code, expected_num, expected_message, ): - dataset_id, document_ids = get_dataset_id_and_document_ids + dataset_id, document_ids = add_documents if callable(document_id): params = {"id": document_id(document_ids)} else: params = {"id": document_id} - res = list_documnet(get_http_api_auth, dataset_id, params=params) + res = list_documnets(get_http_api_auth, dataset_id, params=params) assert res["code"] == expected_code if expected_code == 0: @@ -391,36 +387,36 @@ class TestDocumentList: def test_name_and_id( self, 
get_http_api_auth, - get_dataset_id_and_document_ids, + add_documents, document_id, name, expected_code, expected_num, expected_message, ): - dataset_id, document_ids = get_dataset_id_and_document_ids + dataset_id, document_ids = add_documents if callable(document_id): params = {"id": document_id(document_ids), "name": name} else: params = {"id": document_id, "name": name} - res = list_documnet(get_http_api_auth, dataset_id, params=params) + res = list_documnets(get_http_api_auth, dataset_id, params=params) if expected_code == 0: assert len(res["data"]["docs"]) == expected_num else: assert res["message"] == expected_message - def test_concurrent_list(self, get_http_api_auth, get_dataset_id_and_document_ids): - dataset_id, _ = get_dataset_id_and_document_ids + def test_concurrent_list(self, get_http_api_auth, add_documents): + dataset_id, _ = add_documents with ThreadPoolExecutor(max_workers=5) as executor: - futures = [executor.submit(list_documnet, get_http_api_auth, dataset_id) for i in range(100)] + futures = [executor.submit(list_documnets, get_http_api_auth, dataset_id) for i in range(100)] responses = [f.result() for f in futures] assert all(r["code"] == 0 for r in responses) - def test_invalid_params(self, get_http_api_auth, get_dataset_id_and_document_ids): - dataset_id, _ = get_dataset_id_and_document_ids + def test_invalid_params(self, get_http_api_auth, add_documents): + dataset_id, _ = add_documents params = {"a": "b"} - res = list_documnet(get_http_api_auth, dataset_id, params=params) + res = list_documnets(get_http_api_auth, dataset_id, params=params) assert res["code"] == 0 assert len(res["data"]["docs"]) == 5 diff --git a/sdk/python/test/test_http_api/test_file_management_within_dataset/test_parse_documents.py b/sdk/python/test/test_http_api/test_file_management_within_dataset/test_parse_documents.py index 1c84bc5c9..83fa40cf9 100644 --- a/sdk/python/test/test_http_api/test_file_management_within_dataset/test_parse_documents.py +++ b/sdk/python/test/test_http_api/test_file_management_within_dataset/test_parse_documents.py @@ -16,20 +16,14 @@ from concurrent.futures import ThreadPoolExecutor import pytest -from common import ( - INVALID_API_TOKEN, - batch_create_datasets, - bulk_upload_documents, - list_documnet, - parse_documnet, -) +from common import INVALID_API_TOKEN, bulk_upload_documents, list_documnets, parse_documnets from libs.auth import RAGFlowHttpApiAuth from libs.utils import wait_for def validate_document_details(auth, dataset_id, document_ids): for document_id in document_ids: - res = list_documnet(auth, dataset_id, params={"id": document_id}) + res = list_documnets(auth, dataset_id, params={"id": document_id}) doc = res["data"]["docs"][0] assert doc["run"] == "DONE" assert len(doc["process_begin_at"]) > 0 @@ -50,14 +44,12 @@ class TestAuthorization: ), ], ) - def test_invalid_auth(self, get_dataset_id_and_document_ids, auth, expected_code, expected_message): - dataset_id, document_ids = get_dataset_id_and_document_ids - res = parse_documnet(auth, dataset_id, {"document_ids": document_ids}) + def test_invalid_auth(self, auth, expected_code, expected_message): + res = parse_documnets(auth, "dataset_id") assert res["code"] == expected_code assert res["message"] == expected_message -@pytest.mark.usefixtures("clear_datasets") class TestDocumentsParse: @pytest.mark.parametrize( "payload, expected_code, expected_message", @@ -89,21 +81,19 @@ class TestDocumentsParse: (lambda r: {"document_ids": r}, 0, ""), ], ) - def test_basic_scenarios(self, get_http_api_auth, 
tmp_path, payload, expected_code, expected_message): + def test_basic_scenarios(self, get_http_api_auth, add_documents_func, payload, expected_code, expected_message): @wait_for(10, 1, "Document parsing timeout") def condition(_auth, _dataset_id, _document_ids): for _document_id in _document_ids: - res = list_documnet(_auth, _dataset_id, {"id": _document_id}) + res = list_documnets(_auth, _dataset_id, {"id": _document_id}) if res["data"]["docs"][0]["run"] != "DONE": return False return True - ids = batch_create_datasets(get_http_api_auth, 1) - dataset_id = ids[0] - document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, 3, tmp_path) + dataset_id, document_ids = add_documents_func if callable(payload): payload = payload(document_ids) - res = parse_documnet(get_http_api_auth, dataset_id, payload) + res = parse_documnets(get_http_api_auth, dataset_id, payload) assert res["code"] == expected_code if expected_code != 0: assert res["message"] == expected_message @@ -125,14 +115,13 @@ class TestDocumentsParse: def test_invalid_dataset_id( self, get_http_api_auth, - tmp_path, + add_documents_func, dataset_id, expected_code, expected_message, ): - ids = batch_create_datasets(get_http_api_auth, 1) - document_ids = bulk_upload_documents(get_http_api_auth, ids[0], 1, tmp_path) - res = parse_documnet(get_http_api_auth, dataset_id, {"document_ids": document_ids}) + _, document_ids = add_documents_func + res = parse_documnets(get_http_api_auth, dataset_id, {"document_ids": document_ids}) assert res["code"] == expected_code assert res["message"] == expected_message @@ -144,21 +133,19 @@ class TestDocumentsParse: lambda r: {"document_ids": r + ["invalid_id"]}, ], ) - def test_parse_partial_invalid_document_id(self, get_http_api_auth, tmp_path, payload): + def test_parse_partial_invalid_document_id(self, get_http_api_auth, add_documents_func, payload): @wait_for(10, 1, "Document parsing timeout") def condition(_auth, _dataset_id): - res = list_documnet(_auth, _dataset_id) + res = list_documnets(_auth, _dataset_id) for doc in res["data"]["docs"]: if doc["run"] != "DONE": return False return True - ids = batch_create_datasets(get_http_api_auth, 1) - dataset_id = ids[0] - document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, 3, tmp_path) + dataset_id, document_ids = add_documents_func if callable(payload): payload = payload(document_ids) - res = parse_documnet(get_http_api_auth, dataset_id, payload) + res = parse_documnets(get_http_api_auth, dataset_id, payload) assert res["code"] == 102 assert res["message"] == "Documents not found: ['invalid_id']" @@ -166,96 +153,92 @@ class TestDocumentsParse: validate_document_details(get_http_api_auth, dataset_id, document_ids) - def test_repeated_parse(self, get_http_api_auth, tmp_path): + def test_repeated_parse(self, get_http_api_auth, add_documents_func): @wait_for(10, 1, "Document parsing timeout") def condition(_auth, _dataset_id): - res = list_documnet(_auth, _dataset_id) + res = list_documnets(_auth, _dataset_id) for doc in res["data"]["docs"]: if doc["run"] != "DONE": return False return True - ids = batch_create_datasets(get_http_api_auth, 1) - dataset_id = ids[0] - document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, 1, tmp_path) - res = parse_documnet(get_http_api_auth, dataset_id, {"document_ids": document_ids}) + dataset_id, document_ids = add_documents_func + res = parse_documnets(get_http_api_auth, dataset_id, {"document_ids": document_ids}) assert res["code"] == 0 condition(get_http_api_auth, dataset_id) - res = 
parse_documnet(get_http_api_auth, dataset_id, {"document_ids": document_ids}) + res = parse_documnets(get_http_api_auth, dataset_id, {"document_ids": document_ids}) assert res["code"] == 0 - def test_duplicate_parse(self, get_http_api_auth, tmp_path): + def test_duplicate_parse(self, get_http_api_auth, add_documents_func): @wait_for(10, 1, "Document parsing timeout") def condition(_auth, _dataset_id): - res = list_documnet(_auth, _dataset_id) + res = list_documnets(_auth, _dataset_id) for doc in res["data"]["docs"]: if doc["run"] != "DONE": return False return True - ids = batch_create_datasets(get_http_api_auth, 1) - dataset_id = ids[0] - document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, 1, tmp_path) - res = parse_documnet(get_http_api_auth, dataset_id, {"document_ids": document_ids + document_ids}) + dataset_id, document_ids = add_documents_func + res = parse_documnets(get_http_api_auth, dataset_id, {"document_ids": document_ids + document_ids}) assert res["code"] == 0 - assert res["data"]["errors"][0] == f"Duplicate document ids: {document_ids[0]}" - assert res["data"]["success_count"] == 1 + assert "Duplicate document ids" in res["data"]["errors"][0] + assert res["data"]["success_count"] == 3 condition(get_http_api_auth, dataset_id) validate_document_details(get_http_api_auth, dataset_id, document_ids) - @pytest.mark.slow - def test_parse_100_files(self, get_http_api_auth, tmp_path): - @wait_for(100, 1, "Document parsing timeout") - def condition(_auth, _dataset_id, _document_num): - res = list_documnet(_auth, _dataset_id, {"page_size": _document_num}) - for doc in res["data"]["docs"]: - if doc["run"] != "DONE": - return False - return True - document_num = 100 - ids = batch_create_datasets(get_http_api_auth, 1) - dataset_id = ids[0] - document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, document_num, tmp_path) - res = parse_documnet(get_http_api_auth, dataset_id, {"document_ids": document_ids}) - assert res["code"] == 0 +@pytest.mark.slow +def test_parse_100_files(get_http_api_auth, add_datase_func, tmp_path): + @wait_for(100, 1, "Document parsing timeout") + def condition(_auth, _dataset_id, _document_num): + res = list_documnets(_auth, _dataset_id, {"page_size": _document_num}) + for doc in res["data"]["docs"]: + if doc["run"] != "DONE": + return False + return True - condition(get_http_api_auth, dataset_id, document_num) + document_num = 100 + dataset_id = add_datase_func + document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, document_num, tmp_path) + res = parse_documnets(get_http_api_auth, dataset_id, {"document_ids": document_ids}) + assert res["code"] == 0 - validate_document_details(get_http_api_auth, dataset_id, document_ids) + condition(get_http_api_auth, dataset_id, document_num) - @pytest.mark.slow - def test_concurrent_parse(self, get_http_api_auth, tmp_path): - @wait_for(120, 1, "Document parsing timeout") - def condition(_auth, _dataset_id, _document_num): - res = list_documnet(_auth, _dataset_id, {"page_size": _document_num}) - for doc in res["data"]["docs"]: - if doc["run"] != "DONE": - return False - return True + validate_document_details(get_http_api_auth, dataset_id, document_ids) - document_num = 100 - ids = batch_create_datasets(get_http_api_auth, 1) - dataset_id = ids[0] - document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, document_num, tmp_path) - with ThreadPoolExecutor(max_workers=5) as executor: - futures = [ - executor.submit( - parse_documnet, - get_http_api_auth, - dataset_id, - 
{"document_ids": document_ids[i : i + 1]}, - ) - for i in range(document_num) - ] - responses = [f.result() for f in futures] - assert all(r["code"] == 0 for r in responses) +@pytest.mark.slow +def test_concurrent_parse(get_http_api_auth, add_datase_func, tmp_path): + @wait_for(120, 1, "Document parsing timeout") + def condition(_auth, _dataset_id, _document_num): + res = list_documnets(_auth, _dataset_id, {"page_size": _document_num}) + for doc in res["data"]["docs"]: + if doc["run"] != "DONE": + return False + return True - condition(get_http_api_auth, dataset_id, document_num) + document_num = 100 + dataset_id = add_datase_func + document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, document_num, tmp_path) - validate_document_details(get_http_api_auth, dataset_id, document_ids) + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [ + executor.submit( + parse_documnets, + get_http_api_auth, + dataset_id, + {"document_ids": document_ids[i : i + 1]}, + ) + for i in range(document_num) + ] + responses = [f.result() for f in futures] + assert all(r["code"] == 0 for r in responses) + + condition(get_http_api_auth, dataset_id, document_num) + + validate_document_details(get_http_api_auth, dataset_id, document_ids) diff --git a/sdk/python/test/test_http_api/test_file_management_within_dataset/test_stop_parse_documents.py b/sdk/python/test/test_http_api/test_file_management_within_dataset/test_stop_parse_documents.py index 264ab486e..f425d5c03 100644 --- a/sdk/python/test/test_http_api/test_file_management_within_dataset/test_stop_parse_documents.py +++ b/sdk/python/test/test_http_api/test_file_management_within_dataset/test_stop_parse_documents.py @@ -16,21 +16,14 @@ from concurrent.futures import ThreadPoolExecutor import pytest -from common import ( - INVALID_API_TOKEN, - batch_create_datasets, - bulk_upload_documents, - list_documnet, - parse_documnet, - stop_parse_documnet, -) +from common import INVALID_API_TOKEN, bulk_upload_documents, list_documnets, parse_documnets, stop_parse_documnets from libs.auth import RAGFlowHttpApiAuth from libs.utils import wait_for def validate_document_parse_done(auth, dataset_id, document_ids): for document_id in document_ids: - res = list_documnet(auth, dataset_id, params={"id": document_id}) + res = list_documnets(auth, dataset_id, params={"id": document_id}) doc = res["data"]["docs"][0] assert doc["run"] == "DONE" assert len(doc["process_begin_at"]) > 0 @@ -41,14 +34,13 @@ def validate_document_parse_done(auth, dataset_id, document_ids): def validate_document_parse_cancel(auth, dataset_id, document_ids): for document_id in document_ids: - res = list_documnet(auth, dataset_id, params={"id": document_id}) + res = list_documnets(auth, dataset_id, params={"id": document_id}) doc = res["data"]["docs"][0] assert doc["run"] == "CANCEL" assert len(doc["process_begin_at"]) > 0 assert doc["progress"] == 0.0 -@pytest.mark.usefixtures("clear_datasets") class TestAuthorization: @pytest.mark.parametrize( "auth, expected_code, expected_message", @@ -61,15 +53,13 @@ class TestAuthorization: ), ], ) - def test_invalid_auth(self, get_http_api_auth, auth, expected_code, expected_message): - ids = batch_create_datasets(get_http_api_auth, 1) - res = stop_parse_documnet(auth, ids[0]) + def test_invalid_auth(self, auth, expected_code, expected_message): + res = stop_parse_documnets(auth, "dataset_id") assert res["code"] == expected_code assert res["message"] == expected_message @pytest.mark.skip -@pytest.mark.usefixtures("clear_datasets") class 
TestDocumentsParseStop: @pytest.mark.parametrize( "payload, expected_code, expected_message", @@ -101,24 +91,22 @@ class TestDocumentsParseStop: (lambda r: {"document_ids": r}, 0, ""), ], ) - def test_basic_scenarios(self, get_http_api_auth, tmp_path, payload, expected_code, expected_message): + def test_basic_scenarios(self, get_http_api_auth, add_documents_func, payload, expected_code, expected_message): @wait_for(10, 1, "Document parsing timeout") def condition(_auth, _dataset_id, _document_ids): for _document_id in _document_ids: - res = list_documnet(_auth, _dataset_id, {"id": _document_id}) + res = list_documnets(_auth, _dataset_id, {"id": _document_id}) if res["data"]["docs"][0]["run"] != "DONE": return False return True - ids = batch_create_datasets(get_http_api_auth, 1) - dataset_id = ids[0] - document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, 3, tmp_path) - parse_documnet(get_http_api_auth, dataset_id, {"document_ids": document_ids}) + dataset_id, document_ids = add_documents_func + parse_documnets(get_http_api_auth, dataset_id, {"document_ids": document_ids}) if callable(payload): payload = payload(document_ids) - res = stop_parse_documnet(get_http_api_auth, dataset_id, payload) + res = stop_parse_documnets(get_http_api_auth, dataset_id, payload) assert res["code"] == expected_code if expected_code != 0: assert res["message"] == expected_message @@ -129,7 +117,7 @@ class TestDocumentsParseStop: validate_document_parse_done(get_http_api_auth, dataset_id, completed_document_ids) @pytest.mark.parametrize( - "dataset_id, expected_code, expected_message", + "invalid_dataset_id, expected_code, expected_message", [ ("", 100, ""), ( @@ -142,14 +130,14 @@ class TestDocumentsParseStop: def test_invalid_dataset_id( self, get_http_api_auth, - tmp_path, - dataset_id, + add_documents_func, + invalid_dataset_id, expected_code, expected_message, ): - ids = batch_create_datasets(get_http_api_auth, 1) - document_ids = bulk_upload_documents(get_http_api_auth, ids[0], 1, tmp_path) - res = stop_parse_documnet(get_http_api_auth, dataset_id, {"document_ids": document_ids}) + dataset_id, document_ids = add_documents_func + parse_documnets(get_http_api_auth, dataset_id, {"document_ids": document_ids}) + res = stop_parse_documnets(get_http_api_auth, invalid_dataset_id, {"document_ids": document_ids}) assert res["code"] == expected_code assert res["message"] == expected_message @@ -162,71 +150,65 @@ class TestDocumentsParseStop: lambda r: {"document_ids": r + ["invalid_id"]}, ], ) - def test_stop_parse_partial_invalid_document_id(self, get_http_api_auth, tmp_path, payload): - ids = batch_create_datasets(get_http_api_auth, 1) - dataset_id = ids[0] - document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, 3, tmp_path) - parse_documnet(get_http_api_auth, dataset_id, {"document_ids": document_ids}) + def test_stop_parse_partial_invalid_document_id(self, get_http_api_auth, add_documents_func, payload): + dataset_id, document_ids = add_documents_func + parse_documnets(get_http_api_auth, dataset_id, {"document_ids": document_ids}) if callable(payload): payload = payload(document_ids) - res = stop_parse_documnet(get_http_api_auth, dataset_id, payload) + res = stop_parse_documnets(get_http_api_auth, dataset_id, payload) assert res["code"] == 102 assert res["message"] == "You don't own the document invalid_id." 
validate_document_parse_cancel(get_http_api_auth, dataset_id, document_ids) - def test_repeated_stop_parse(self, get_http_api_auth, tmp_path): - ids = batch_create_datasets(get_http_api_auth, 1) - dataset_id = ids[0] - document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, 1, tmp_path) - parse_documnet(get_http_api_auth, dataset_id, {"document_ids": document_ids}) - res = stop_parse_documnet(get_http_api_auth, dataset_id, {"document_ids": document_ids}) + def test_repeated_stop_parse(self, get_http_api_auth, add_documents_func): + dataset_id, document_ids = add_documents_func + parse_documnets(get_http_api_auth, dataset_id, {"document_ids": document_ids}) + res = stop_parse_documnets(get_http_api_auth, dataset_id, {"document_ids": document_ids}) assert res["code"] == 0 - res = stop_parse_documnet(get_http_api_auth, dataset_id, {"document_ids": document_ids}) + res = stop_parse_documnets(get_http_api_auth, dataset_id, {"document_ids": document_ids}) assert res["code"] == 102 assert res["message"] == "Can't stop parsing document with progress at 0 or 1" - def test_duplicate_stop_parse(self, get_http_api_auth, tmp_path): - ids = batch_create_datasets(get_http_api_auth, 1) - dataset_id = ids[0] - document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, 1, tmp_path) - parse_documnet(get_http_api_auth, dataset_id, {"document_ids": document_ids}) - res = stop_parse_documnet(get_http_api_auth, dataset_id, {"document_ids": document_ids + document_ids}) + def test_duplicate_stop_parse(self, get_http_api_auth, add_documents_func): + dataset_id, document_ids = add_documents_func + parse_documnets(get_http_api_auth, dataset_id, {"document_ids": document_ids}) + res = stop_parse_documnets(get_http_api_auth, dataset_id, {"document_ids": document_ids + document_ids}) assert res["code"] == 0 - assert res["data"]["success_count"] == 1 + assert res["data"]["success_count"] == 3 assert f"Duplicate document ids: {document_ids[0]}" in res["data"]["errors"] - @pytest.mark.slow - def test_stop_parse_100_files(self, get_http_api_auth, tmp_path): - document_num = 100 - ids = batch_create_datasets(get_http_api_auth, 1) - dataset_id = ids[0] - document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, document_num, tmp_path) - parse_documnet(get_http_api_auth, dataset_id, {"document_ids": document_ids}) - res = stop_parse_documnet(get_http_api_auth, dataset_id, {"document_ids": document_ids}) - assert res["code"] == 0 - validate_document_parse_cancel(get_http_api_auth, dataset_id, document_ids) - @pytest.mark.slow - def test_concurrent_parse(self, get_http_api_auth, tmp_path): - document_num = 50 - ids = batch_create_datasets(get_http_api_auth, 1) - dataset_id = ids[0] - document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, document_num, tmp_path) - parse_documnet(get_http_api_auth, dataset_id, {"document_ids": document_ids}) +@pytest.mark.slow +def test_stop_parse_100_files(get_http_api_auth, add_datase_func, tmp_path): + document_num = 100 + dataset_id = add_datase_func + document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, document_num, tmp_path) + parse_documnets(get_http_api_auth, dataset_id, {"document_ids": document_ids}) + res = stop_parse_documnets(get_http_api_auth, dataset_id, {"document_ids": document_ids}) + assert res["code"] == 0 + validate_document_parse_cancel(get_http_api_auth, dataset_id, document_ids) - with ThreadPoolExecutor(max_workers=5) as executor: - futures = [ - executor.submit( - stop_parse_documnet, - get_http_api_auth, - 
dataset_id, - {"document_ids": document_ids[i : i + 1]}, - ) - for i in range(document_num) - ] - responses = [f.result() for f in futures] - assert all(r["code"] == 0 for r in responses) - validate_document_parse_cancel(get_http_api_auth, dataset_id, document_ids) + +@pytest.mark.slow +def test_concurrent_parse(get_http_api_auth, add_datase_func, tmp_path): + document_num = 50 + dataset_id = add_datase_func + document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, document_num, tmp_path) + parse_documnets(get_http_api_auth, dataset_id, {"document_ids": document_ids}) + + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [ + executor.submit( + stop_parse_documnets, + get_http_api_auth, + dataset_id, + {"document_ids": document_ids[i : i + 1]}, + ) + for i in range(document_num) + ] + responses = [f.result() for f in futures] + assert all(r["code"] == 0 for r in responses) + validate_document_parse_cancel(get_http_api_auth, dataset_id, document_ids) diff --git a/sdk/python/test/test_http_api/test_file_management_within_dataset/test_update_document.py b/sdk/python/test/test_http_api/test_file_management_within_dataset/test_update_document.py index 3096e5832..56814bc80 100644 --- a/sdk/python/test/test_http_api/test_file_management_within_dataset/test_update_document.py +++ b/sdk/python/test/test_http_api/test_file_management_within_dataset/test_update_document.py @@ -16,7 +16,7 @@ import pytest -from common import DOCUMENT_NAME_LIMIT, INVALID_API_TOKEN, batch_create_datasets, bulk_upload_documents, list_documnet, update_documnet +from common import DOCUMENT_NAME_LIMIT, INVALID_API_TOKEN, list_documnets, update_documnet from libs.auth import RAGFlowHttpApiAuth @@ -32,14 +32,13 @@ class TestAuthorization: ), ], ) - def test_invalid_auth(self, get_dataset_id_and_document_ids, auth, expected_code, expected_message): - dataset_id, document_ids = get_dataset_id_and_document_ids - res = update_documnet(auth, dataset_id, document_ids[0], {"name": "auth_test.txt"}) + def test_invalid_auth(self, auth, expected_code, expected_message): + res = update_documnet(auth, "dataset_id", "document_id") assert res["code"] == expected_code assert res["message"] == expected_message -class TestUpdatedDocument: +class TestDocumentsUpdated: @pytest.mark.parametrize( "name, expected_code, expected_message", [ @@ -81,12 +80,12 @@ class TestUpdatedDocument: ), ], ) - def test_name(self, get_http_api_auth, get_dataset_id_and_document_ids, name, expected_code, expected_message): - dataset_id, document_ids = get_dataset_id_and_document_ids + def test_name(self, get_http_api_auth, add_documents, name, expected_code, expected_message): + dataset_id, document_ids = add_documents res = update_documnet(get_http_api_auth, dataset_id, document_ids[0], {"name": name}) assert res["code"] == expected_code if expected_code == 0: - res = list_documnet(get_http_api_auth, dataset_id, {"id": document_ids[0]}) + res = list_documnets(get_http_api_auth, dataset_id, {"id": document_ids[0]}) assert res["data"]["docs"][0]["name"] == name else: assert res["message"] == expected_message @@ -102,8 +101,8 @@ class TestUpdatedDocument: ), ], ) - def test_invalid_document_id(self, get_http_api_auth, get_dataset_id_and_document_ids, document_id, expected_code, expected_message): - dataset_id, _ = get_dataset_id_and_document_ids + def test_invalid_document_id(self, get_http_api_auth, add_documents, document_id, expected_code, expected_message): + dataset_id, _ = add_documents res = update_documnet(get_http_api_auth, 
dataset_id, document_id, {"name": "new_name.txt"}) assert res["code"] == expected_code assert res["message"] == expected_message @@ -119,8 +118,8 @@ class TestUpdatedDocument: ), ], ) - def test_invalid_dataset_id(self, get_http_api_auth, get_dataset_id_and_document_ids, dataset_id, expected_code, expected_message): - _, document_ids = get_dataset_id_and_document_ids + def test_invalid_dataset_id(self, get_http_api_auth, add_documents, dataset_id, expected_code, expected_message): + _, document_ids = add_documents res = update_documnet(get_http_api_auth, dataset_id, document_ids[0], {"name": "new_name.txt"}) assert res["code"] == expected_code assert res["message"] == expected_message @@ -129,11 +128,11 @@ class TestUpdatedDocument: "meta_fields, expected_code, expected_message", [({"test": "test"}, 0, ""), ("test", 102, "meta_fields must be a dictionary")], ) - def test_meta_fields(self, get_http_api_auth, get_dataset_id_and_document_ids, meta_fields, expected_code, expected_message): - dataset_id, document_ids = get_dataset_id_and_document_ids + def test_meta_fields(self, get_http_api_auth, add_documents, meta_fields, expected_code, expected_message): + dataset_id, document_ids = add_documents res = update_documnet(get_http_api_auth, dataset_id, document_ids[0], {"meta_fields": meta_fields}) if expected_code == 0: - res = list_documnet(get_http_api_auth, dataset_id, {"id": document_ids[0]}) + res = list_documnets(get_http_api_auth, dataset_id, {"id": document_ids[0]}) assert res["data"]["docs"][0]["meta_fields"] == meta_fields else: assert res["message"] == expected_message @@ -162,12 +161,12 @@ class TestUpdatedDocument: ), ], ) - def test_chunk_method(self, get_http_api_auth, get_dataset_id_and_document_ids, chunk_method, expected_code, expected_message): - dataset_id, document_ids = get_dataset_id_and_document_ids + def test_chunk_method(self, get_http_api_auth, add_documents, chunk_method, expected_code, expected_message): + dataset_id, document_ids = add_documents res = update_documnet(get_http_api_auth, dataset_id, document_ids[0], {"chunk_method": chunk_method}) assert res["code"] == expected_code if expected_code == 0: - res = list_documnet(get_http_api_auth, dataset_id, {"id": document_ids[0]}) + res = list_documnets(get_http_api_auth, dataset_id, {"id": document_ids[0]}) if chunk_method != "": assert res["data"]["docs"][0]["chunk_method"] == chunk_method else: @@ -282,259 +281,259 @@ class TestUpdatedDocument: def test_invalid_field( self, get_http_api_auth, - get_dataset_id_and_document_ids, + add_documents, payload, expected_code, expected_message, ): - dataset_id, document_ids = get_dataset_id_and_document_ids + dataset_id, document_ids = add_documents res = update_documnet(get_http_api_auth, dataset_id, document_ids[0], payload) assert res["code"] == expected_code assert res["message"] == expected_message -@pytest.mark.usefixtures("clear_datasets") -@pytest.mark.parametrize( - "chunk_method, parser_config, expected_code, expected_message", - [ - ("naive", {}, 0, ""), - ( - "naive", - { - "chunk_token_num": 128, - "layout_recognize": "DeepDOC", - "html4excel": False, - "delimiter": "\\n!?;。;!?", - "task_page_size": 12, - "raptor": {"use_raptor": False}, - }, - 0, - "", - ), - pytest.param( - "naive", - {"chunk_token_num": -1}, - 100, - "AssertionError('chunk_token_num should be in range from 1 to 100000000')", - marks=pytest.mark.skip(reason="issues/6098"), - ), - pytest.param( - "naive", - {"chunk_token_num": 0}, - 100, - "AssertionError('chunk_token_num should be in range 
from 1 to 100000000')", - marks=pytest.mark.skip(reason="issues/6098"), - ), - pytest.param( - "naive", - {"chunk_token_num": 100000000}, - 100, - "AssertionError('chunk_token_num should be in range from 1 to 100000000')", - marks=pytest.mark.skip(reason="issues/6098"), - ), - pytest.param( - "naive", - {"chunk_token_num": 3.14}, - 102, - "", - marks=pytest.mark.skip(reason="issues/6098"), - ), - pytest.param( - "naive", - {"chunk_token_num": "1024"}, - 100, - "", - marks=pytest.mark.skip(reason="issues/6098"), - ), - ( - "naive", - {"layout_recognize": "DeepDOC"}, - 0, - "", - ), - ( - "naive", - {"layout_recognize": "Naive"}, - 0, - "", - ), - ("naive", {"html4excel": True}, 0, ""), - ("naive", {"html4excel": False}, 0, ""), - pytest.param( - "naive", - {"html4excel": 1}, - 100, - "AssertionError('html4excel should be True or False')", - marks=pytest.mark.skip(reason="issues/6098"), - ), - ("naive", {"delimiter": ""}, 0, ""), - ("naive", {"delimiter": "`##`"}, 0, ""), - pytest.param( - "naive", - {"delimiter": 1}, - 100, - "", - marks=pytest.mark.skip(reason="issues/6098"), - ), - pytest.param( - "naive", - {"task_page_size": -1}, - 100, - "AssertionError('task_page_size should be in range from 1 to 100000000')", - marks=pytest.mark.skip(reason="issues/6098"), - ), - pytest.param( - "naive", - {"task_page_size": 0}, - 100, - "AssertionError('task_page_size should be in range from 1 to 100000000')", - marks=pytest.mark.skip(reason="issues/6098"), - ), - pytest.param( - "naive", - {"task_page_size": 100000000}, - 100, - "AssertionError('task_page_size should be in range from 1 to 100000000')", - marks=pytest.mark.skip(reason="issues/6098"), - ), - pytest.param( - "naive", - {"task_page_size": 3.14}, - 100, - "", - marks=pytest.mark.skip(reason="issues/6098"), - ), - pytest.param( - "naive", - {"task_page_size": "1024"}, - 100, - "", - marks=pytest.mark.skip(reason="issues/6098"), - ), - ("naive", {"raptor": {"use_raptor": True}}, 0, ""), - ("naive", {"raptor": {"use_raptor": False}}, 0, ""), - pytest.param( - "naive", - {"invalid_key": "invalid_value"}, - 100, - """AssertionError("Abnormal \'parser_config\'. 
Invalid key: invalid_key")""", - marks=pytest.mark.skip(reason="issues/6098"), - ), - pytest.param( - "naive", - {"auto_keywords": -1}, - 100, - "AssertionError('auto_keywords should be in range from 0 to 32')", - marks=pytest.mark.skip(reason="issues/6098"), - ), - pytest.param( - "naive", - {"auto_keywords": 32}, - 100, - "AssertionError('auto_keywords should be in range from 0 to 32')", - marks=pytest.mark.skip(reason="issues/6098"), - ), - pytest.param( - "naive", - {"auto_questions": 3.14}, - 100, - "", - marks=pytest.mark.skip(reason="issues/6098"), - ), - pytest.param( - "naive", - {"auto_keywords": "1024"}, - 100, - "", - marks=pytest.mark.skip(reason="issues/6098"), - ), - pytest.param( - "naive", - {"auto_questions": -1}, - 100, - "AssertionError('auto_questions should be in range from 0 to 10')", - marks=pytest.mark.skip(reason="issues/6098"), - ), - pytest.param( - "naive", - {"auto_questions": 10}, - 100, - "AssertionError('auto_questions should be in range from 0 to 10')", - marks=pytest.mark.skip(reason="issues/6098"), - ), - pytest.param( - "naive", - {"auto_questions": 3.14}, - 100, - "", - marks=pytest.mark.skip(reason="issues/6098"), - ), - pytest.param( - "naive", - {"auto_questions": "1024"}, - 100, - "", - marks=pytest.mark.skip(reason="issues/6098"), - ), - pytest.param( - "naive", - {"topn_tags": -1}, - 100, - "AssertionError('topn_tags should be in range from 0 to 10')", - marks=pytest.mark.skip(reason="issues/6098"), - ), - pytest.param( - "naive", - {"topn_tags": 10}, - 100, - "AssertionError('topn_tags should be in range from 0 to 10')", - marks=pytest.mark.skip(reason="issues/6098"), - ), - pytest.param( - "naive", - {"topn_tags": 3.14}, - 100, - "", - marks=pytest.mark.skip(reason="issues/6098"), - ), - pytest.param( - "naive", - {"topn_tags": "1024"}, - 100, - "", - marks=pytest.mark.skip(reason="issues/6098"), - ), - ], -) -def test_parser_config( - get_http_api_auth, - tmp_path, - chunk_method, - parser_config, - expected_code, - expected_message, -): - ids = batch_create_datasets(get_http_api_auth, 1) - document_ids = bulk_upload_documents(get_http_api_auth, ids[0], 1, tmp_path) - res = update_documnet( - get_http_api_auth, - ids[0], - document_ids[0], - {"chunk_method": chunk_method, "parser_config": parser_config}, +class TestUpdateDocumentParserConfig: + @pytest.mark.parametrize( + "chunk_method, parser_config, expected_code, expected_message", + [ + ("naive", {}, 0, ""), + ( + "naive", + { + "chunk_token_num": 128, + "layout_recognize": "DeepDOC", + "html4excel": False, + "delimiter": "\\n!?;。;!?", + "task_page_size": 12, + "raptor": {"use_raptor": False}, + }, + 0, + "", + ), + pytest.param( + "naive", + {"chunk_token_num": -1}, + 100, + "AssertionError('chunk_token_num should be in range from 1 to 100000000')", + marks=pytest.mark.skip(reason="issues/6098"), + ), + pytest.param( + "naive", + {"chunk_token_num": 0}, + 100, + "AssertionError('chunk_token_num should be in range from 1 to 100000000')", + marks=pytest.mark.skip(reason="issues/6098"), + ), + pytest.param( + "naive", + {"chunk_token_num": 100000000}, + 100, + "AssertionError('chunk_token_num should be in range from 1 to 100000000')", + marks=pytest.mark.skip(reason="issues/6098"), + ), + pytest.param( + "naive", + {"chunk_token_num": 3.14}, + 102, + "", + marks=pytest.mark.skip(reason="issues/6098"), + ), + pytest.param( + "naive", + {"chunk_token_num": "1024"}, + 100, + "", + marks=pytest.mark.skip(reason="issues/6098"), + ), + ( + "naive", + {"layout_recognize": "DeepDOC"}, + 0, + "", + 
), + ( + "naive", + {"layout_recognize": "Naive"}, + 0, + "", + ), + ("naive", {"html4excel": True}, 0, ""), + ("naive", {"html4excel": False}, 0, ""), + pytest.param( + "naive", + {"html4excel": 1}, + 100, + "AssertionError('html4excel should be True or False')", + marks=pytest.mark.skip(reason="issues/6098"), + ), + ("naive", {"delimiter": ""}, 0, ""), + ("naive", {"delimiter": "`##`"}, 0, ""), + pytest.param( + "naive", + {"delimiter": 1}, + 100, + "", + marks=pytest.mark.skip(reason="issues/6098"), + ), + pytest.param( + "naive", + {"task_page_size": -1}, + 100, + "AssertionError('task_page_size should be in range from 1 to 100000000')", + marks=pytest.mark.skip(reason="issues/6098"), + ), + pytest.param( + "naive", + {"task_page_size": 0}, + 100, + "AssertionError('task_page_size should be in range from 1 to 100000000')", + marks=pytest.mark.skip(reason="issues/6098"), + ), + pytest.param( + "naive", + {"task_page_size": 100000000}, + 100, + "AssertionError('task_page_size should be in range from 1 to 100000000')", + marks=pytest.mark.skip(reason="issues/6098"), + ), + pytest.param( + "naive", + {"task_page_size": 3.14}, + 100, + "", + marks=pytest.mark.skip(reason="issues/6098"), + ), + pytest.param( + "naive", + {"task_page_size": "1024"}, + 100, + "", + marks=pytest.mark.skip(reason="issues/6098"), + ), + ("naive", {"raptor": {"use_raptor": True}}, 0, ""), + ("naive", {"raptor": {"use_raptor": False}}, 0, ""), + pytest.param( + "naive", + {"invalid_key": "invalid_value"}, + 100, + """AssertionError("Abnormal \'parser_config\'. Invalid key: invalid_key")""", + marks=pytest.mark.skip(reason="issues/6098"), + ), + pytest.param( + "naive", + {"auto_keywords": -1}, + 100, + "AssertionError('auto_keywords should be in range from 0 to 32')", + marks=pytest.mark.skip(reason="issues/6098"), + ), + pytest.param( + "naive", + {"auto_keywords": 32}, + 100, + "AssertionError('auto_keywords should be in range from 0 to 32')", + marks=pytest.mark.skip(reason="issues/6098"), + ), + pytest.param( + "naive", + {"auto_questions": 3.14}, + 100, + "", + marks=pytest.mark.skip(reason="issues/6098"), + ), + pytest.param( + "naive", + {"auto_keywords": "1024"}, + 100, + "", + marks=pytest.mark.skip(reason="issues/6098"), + ), + pytest.param( + "naive", + {"auto_questions": -1}, + 100, + "AssertionError('auto_questions should be in range from 0 to 10')", + marks=pytest.mark.skip(reason="issues/6098"), + ), + pytest.param( + "naive", + {"auto_questions": 10}, + 100, + "AssertionError('auto_questions should be in range from 0 to 10')", + marks=pytest.mark.skip(reason="issues/6098"), + ), + pytest.param( + "naive", + {"auto_questions": 3.14}, + 100, + "", + marks=pytest.mark.skip(reason="issues/6098"), + ), + pytest.param( + "naive", + {"auto_questions": "1024"}, + 100, + "", + marks=pytest.mark.skip(reason="issues/6098"), + ), + pytest.param( + "naive", + {"topn_tags": -1}, + 100, + "AssertionError('topn_tags should be in range from 0 to 10')", + marks=pytest.mark.skip(reason="issues/6098"), + ), + pytest.param( + "naive", + {"topn_tags": 10}, + 100, + "AssertionError('topn_tags should be in range from 0 to 10')", + marks=pytest.mark.skip(reason="issues/6098"), + ), + pytest.param( + "naive", + {"topn_tags": 3.14}, + 100, + "", + marks=pytest.mark.skip(reason="issues/6098"), + ), + pytest.param( + "naive", + {"topn_tags": "1024"}, + 100, + "", + marks=pytest.mark.skip(reason="issues/6098"), + ), + ], ) - assert res["code"] == expected_code - if expected_code == 0: - res = list_documnet(get_http_api_auth, 
ids[0], {"id": document_ids[0]}) - if parser_config != {}: - for k, v in parser_config.items(): - assert res["data"]["docs"][0]["parser_config"][k] == v - else: - assert res["data"]["docs"][0]["parser_config"] == { - "chunk_token_num": 128, - "delimiter": "\\n!?;。;!?", - "html4excel": False, - "layout_recognize": "DeepDOC", - "raptor": {"use_raptor": False}, - } - if expected_code != 0 or expected_message: - assert res["message"] == expected_message + def test_parser_config( + self, + get_http_api_auth, + add_documents, + chunk_method, + parser_config, + expected_code, + expected_message, + ): + dataset_id, document_ids = add_documents + res = update_documnet( + get_http_api_auth, + dataset_id, + document_ids[0], + {"chunk_method": chunk_method, "parser_config": parser_config}, + ) + assert res["code"] == expected_code + if expected_code == 0: + res = list_documnets(get_http_api_auth, dataset_id, {"id": document_ids[0]}) + if parser_config != {}: + for k, v in parser_config.items(): + assert res["data"]["docs"][0]["parser_config"][k] == v + else: + assert res["data"]["docs"][0]["parser_config"] == { + "chunk_token_num": 128, + "delimiter": "\\n!?;。;!?", + "html4excel": False, + "layout_recognize": "DeepDOC", + "raptor": {"use_raptor": False}, + } + if expected_code != 0 or expected_message: + assert res["message"] == expected_message diff --git a/sdk/python/test/test_http_api/test_file_management_within_dataset/test_upload_documents.py b/sdk/python/test/test_http_api/test_file_management_within_dataset/test_upload_documents.py index 443364afc..bf8576f4d 100644 --- a/sdk/python/test/test_http_api/test_file_management_within_dataset/test_upload_documents.py +++ b/sdk/python/test/test_http_api/test_file_management_within_dataset/test_upload_documents.py @@ -19,15 +19,7 @@ from concurrent.futures import ThreadPoolExecutor import pytest import requests -from common import ( - DOCUMENT_NAME_LIMIT, - FILE_API_URL, - HOST_ADDRESS, - INVALID_API_TOKEN, - batch_create_datasets, - list_dataset, - upload_documnets, -) +from common import DOCUMENT_NAME_LIMIT, FILE_API_URL, HOST_ADDRESS, INVALID_API_TOKEN, list_datasets, upload_documnets from libs.auth import RAGFlowHttpApiAuth from libs.utils.file_utils import create_txt_file from requests_toolbelt import MultipartEncoder @@ -46,21 +38,19 @@ class TestAuthorization: ), ], ) - def test_invalid_auth(self, get_http_api_auth, auth, expected_code, expected_message): - ids = batch_create_datasets(get_http_api_auth, 1) - res = upload_documnets(auth, ids[0]) + def test_invalid_auth(self, auth, expected_code, expected_message): + res = upload_documnets(auth, "dataset_id") assert res["code"] == expected_code assert res["message"] == expected_message -@pytest.mark.usefixtures("clear_datasets") -class TestUploadDocuments: - def test_valid_single_upload(self, get_http_api_auth, tmp_path): - ids = batch_create_datasets(get_http_api_auth, 1) +class TestDocumentsUpload: + def test_valid_single_upload(self, get_http_api_auth, add_dataset_func, tmp_path): + dataset_id = add_dataset_func fp = create_txt_file(tmp_path / "ragflow_test.txt") - res = upload_documnets(get_http_api_auth, ids[0], [fp]) + res = upload_documnets(get_http_api_auth, dataset_id, [fp]) assert res["code"] == 0 - assert res["data"][0]["dataset_id"] == ids[0] + assert res["data"][0]["dataset_id"] == dataset_id assert res["data"][0]["name"] == fp.name @pytest.mark.parametrize( @@ -79,45 +69,45 @@ class TestUploadDocuments: ], indirect=True, ) - def test_file_type_validation(self, get_http_api_auth, 
generate_test_files, request): - ids = batch_create_datasets(get_http_api_auth, 1) + def test_file_type_validation(self, get_http_api_auth, add_dataset_func, generate_test_files, request): + dataset_id = add_dataset_func fp = generate_test_files[request.node.callspec.params["generate_test_files"]] - res = upload_documnets(get_http_api_auth, ids[0], [fp]) + res = upload_documnets(get_http_api_auth, dataset_id, [fp]) assert res["code"] == 0 - assert res["data"][0]["dataset_id"] == ids[0] + assert res["data"][0]["dataset_id"] == dataset_id assert res["data"][0]["name"] == fp.name @pytest.mark.parametrize( "file_type", ["exe", "unknown"], ) - def test_unsupported_file_type(self, get_http_api_auth, tmp_path, file_type): - ids = batch_create_datasets(get_http_api_auth, 1) + def test_unsupported_file_type(self, get_http_api_auth, add_dataset_func, tmp_path, file_type): + dataset_id = add_dataset_func fp = tmp_path / f"ragflow_test.{file_type}" fp.touch() - res = upload_documnets(get_http_api_auth, ids[0], [fp]) + res = upload_documnets(get_http_api_auth, dataset_id, [fp]) assert res["code"] == 500 assert res["message"] == f"ragflow_test.{file_type}: This type of file has not been supported yet!" - def test_missing_file(self, get_http_api_auth): - ids = batch_create_datasets(get_http_api_auth, 1) - res = upload_documnets(get_http_api_auth, ids[0]) + def test_missing_file(self, get_http_api_auth, add_dataset_func): + dataset_id = add_dataset_func + res = upload_documnets(get_http_api_auth, dataset_id) assert res["code"] == 101 assert res["message"] == "No file part!" - def test_empty_file(self, get_http_api_auth, tmp_path): - ids = batch_create_datasets(get_http_api_auth, 1) + def test_empty_file(self, get_http_api_auth, add_dataset_func, tmp_path): + dataset_id = add_dataset_func fp = tmp_path / "empty.txt" fp.touch() - res = upload_documnets(get_http_api_auth, ids[0], [fp]) + res = upload_documnets(get_http_api_auth, dataset_id, [fp]) assert res["code"] == 0 assert res["data"][0]["size"] == 0 - def test_filename_empty(self, get_http_api_auth, tmp_path): - ids = batch_create_datasets(get_http_api_auth, 1) + def test_filename_empty(self, get_http_api_auth, add_dataset_func, tmp_path): + dataset_id = add_dataset_func fp = create_txt_file(tmp_path / "ragflow_test.txt") - url = f"{HOST_ADDRESS}{FILE_API_URL}".format(dataset_id=ids[0]) + url = f"{HOST_ADDRESS}{FILE_API_URL}".format(dataset_id=dataset_id) fields = (("file", ("", fp.open("rb"))),) m = MultipartEncoder(fields=fields) res = requests.post( @@ -129,11 +119,11 @@ class TestUploadDocuments: assert res.json()["code"] == 101 assert res.json()["message"] == "No file selected!" - def test_filename_exceeds_max_length(self, get_http_api_auth, tmp_path): - ids = batch_create_datasets(get_http_api_auth, 1) + def test_filename_exceeds_max_length(self, get_http_api_auth, add_dataset_func, tmp_path): + dataset_id = add_dataset_func # filename_length = 129 fp = create_txt_file(tmp_path / f"{'a' * (DOCUMENT_NAME_LIMIT - 3)}.txt") - res = upload_documnets(get_http_api_auth, ids[0], [fp]) + res = upload_documnets(get_http_api_auth, dataset_id, [fp]) assert res["code"] == 101 assert res["message"] == "File name should be less than 128 bytes." 
@@ -143,61 +133,61 @@ class TestUploadDocuments: assert res["code"] == 100 assert res["message"] == """LookupError("Can\'t find the dataset with ID invalid_dataset_id!")""" - def test_duplicate_files(self, get_http_api_auth, tmp_path): - ids = batch_create_datasets(get_http_api_auth, 1) + def test_duplicate_files(self, get_http_api_auth, add_dataset_func, tmp_path): + dataset_id = add_dataset_func fp = create_txt_file(tmp_path / "ragflow_test.txt") - res = upload_documnets(get_http_api_auth, ids[0], [fp, fp]) + res = upload_documnets(get_http_api_auth, dataset_id, [fp, fp]) assert res["code"] == 0 assert len(res["data"]) == 2 for i in range(len(res["data"])): - assert res["data"][i]["dataset_id"] == ids[0] + assert res["data"][i]["dataset_id"] == dataset_id expected_name = fp.name if i != 0: expected_name = f"{fp.stem}({i}){fp.suffix}" assert res["data"][i]["name"] == expected_name - def test_same_file_repeat(self, get_http_api_auth, tmp_path): - ids = batch_create_datasets(get_http_api_auth, 1) + def test_same_file_repeat(self, get_http_api_auth, add_dataset_func, tmp_path): + dataset_id = add_dataset_func fp = create_txt_file(tmp_path / "ragflow_test.txt") for i in range(10): - res = upload_documnets(get_http_api_auth, ids[0], [fp]) + res = upload_documnets(get_http_api_auth, dataset_id, [fp]) assert res["code"] == 0 assert len(res["data"]) == 1 - assert res["data"][0]["dataset_id"] == ids[0] + assert res["data"][0]["dataset_id"] == dataset_id expected_name = fp.name if i != 0: expected_name = f"{fp.stem}({i}){fp.suffix}" assert res["data"][0]["name"] == expected_name - def test_filename_special_characters(self, get_http_api_auth, tmp_path): - ids = batch_create_datasets(get_http_api_auth, 1) + def test_filename_special_characters(self, get_http_api_auth, add_dataset_func, tmp_path): + dataset_id = add_dataset_func illegal_chars = '<>:"/\\|?*' translation_table = str.maketrans({char: "_" for char in illegal_chars}) safe_filename = string.punctuation.translate(translation_table) fp = tmp_path / f"{safe_filename}.txt" fp.write_text("Sample text content") - res = upload_documnets(get_http_api_auth, ids[0], [fp]) + res = upload_documnets(get_http_api_auth, dataset_id, [fp]) assert res["code"] == 0 assert len(res["data"]) == 1 - assert res["data"][0]["dataset_id"] == ids[0] + assert res["data"][0]["dataset_id"] == dataset_id assert res["data"][0]["name"] == fp.name - def test_multiple_files(self, get_http_api_auth, tmp_path): - ids = batch_create_datasets(get_http_api_auth, 1) + def test_multiple_files(self, get_http_api_auth, add_dataset_func, tmp_path): + dataset_id = add_dataset_func expected_document_count = 20 fps = [] for i in range(expected_document_count): fp = create_txt_file(tmp_path / f"ragflow_test_{i}.txt") fps.append(fp) - res = upload_documnets(get_http_api_auth, ids[0], fps) + res = upload_documnets(get_http_api_auth, dataset_id, fps) assert res["code"] == 0 - res = list_dataset(get_http_api_auth, {"id": ids[0]}) + res = list_datasets(get_http_api_auth, {"id": dataset_id}) assert res["data"][0]["document_count"] == expected_document_count - def test_concurrent_upload(self, get_http_api_auth, tmp_path): - ids = batch_create_datasets(get_http_api_auth, 1) + def test_concurrent_upload(self, get_http_api_auth, add_dataset_func, tmp_path): + dataset_id = add_dataset_func expected_document_count = 20 fps = [] @@ -206,9 +196,9 @@ class TestUploadDocuments: fps.append(fp) with ThreadPoolExecutor(max_workers=5) as executor: - futures = [executor.submit(upload_documnets, 
get_http_api_auth, ids[0], fps[i : i + 1]) for i in range(expected_document_count)] + futures = [executor.submit(upload_documnets, get_http_api_auth, dataset_id, fps[i : i + 1]) for i in range(expected_document_count)] responses = [f.result() for f in futures] assert all(r["code"] == 0 for r in responses) - res = list_dataset(get_http_api_auth, {"id": ids[0]}) + res = list_datasets(get_http_api_auth, {"id": dataset_id}) assert res["data"][0]["document_count"] == expected_document_count
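
Note on fixtures: the tests above consume shared fixtures (add_dataset, add_dataset_func, add_documents, add_documents_func, ragflow_tmp_dir) instead of creating datasets inline with batch_create_datasets and bulk_upload_documents; the slow parse and stop-parse tests request add_datase_func, presumably the same per-function dataset fixture. The test_http_api/conftest.py hunk that defines these fixtures is not shown in this part of the diff, so the sketch below is a hypothetical reconstruction inferred only from how the tests use them: the add_dataset* fixtures yield a bare dataset id, add_documents yields a (dataset_id, document_ids) tuple with five documents shared across a class, and add_documents_func yields three documents per test. Scopes, document counts, and cleanup behavior are assumptions, not the actual implementation.

import pytest
from common import batch_create_datasets, bulk_upload_documents


@pytest.fixture(scope="class")
def ragflow_tmp_dir(tmp_path_factory):
    # Directory holding the files uploaded by the class-scoped fixtures
    # (test_same_file_repeat reads ragflow_test_upload_0.txt back from it).
    return tmp_path_factory.mktemp("ragflow_test")


@pytest.fixture(scope="class")
def add_dataset(get_http_api_auth):
    # One dataset shared by every test in a class; cleanup omitted in this sketch.
    return batch_create_datasets(get_http_api_auth, 1)[0]


@pytest.fixture
def add_dataset_func(get_http_api_auth):
    # Fresh dataset per test function (upload tests and the slow parse tests).
    return batch_create_datasets(get_http_api_auth, 1)[0]


@pytest.fixture(scope="class")
def add_documents(get_http_api_auth, add_dataset, ragflow_tmp_dir):
    # Dataset plus five uploaded documents, matching the counts asserted in
    # TestDocumentsList.test_default and test_invalid_params.
    document_ids = bulk_upload_documents(get_http_api_auth, add_dataset, 5, ragflow_tmp_dir)
    return add_dataset, document_ids


@pytest.fixture
def add_documents_func(get_http_api_auth, add_dataset_func, tmp_path):
    # Dataset plus three uploaded documents, recreated for each test so that
    # parse / stop-parse state does not leak between tests (success_count == 3).
    document_ids = bulk_upload_documents(get_http_api_auth, add_dataset_func, 3, tmp_path)
    return add_dataset_func, document_ids

The class/function scope split mirrors how the tests consume these fixtures: read-only listing, download, and update tests share one class-scoped dataset, while parse and stop-parse tests mutate document run state and therefore get a fresh dataset and documents per test.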