# Test: Refactor test concurrency handling and add SDK chunk management tests (#8112)
### What problem does this PR solve?

- Improve concurrent test cases by using `as_completed` for better reliability
- Rename variables for clarity (`chunk_num` -> `count`)
- Add a new SDK API test suite for chunk management operations
- Update HTTP API tests with consistent concurrency patterns

### Type of change

- [x] Add test cases
- [x] Refactoring
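The concurrency change repeated across these tests is the switch from collecting results in submission order (`[f.result() for f in futures]`) to draining the pool with `as_completed` and asserting on the completion count before inspecting any result. A minimal sketch of the new pattern, distilled from the diff below (`task` is a hypothetical stand-in for helpers such as `add_chunk` or `list_chunks`):

```python
from concurrent.futures import ThreadPoolExecutor, as_completed

def run_concurrently(task, count=50, max_workers=5):
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(task, i) for i in range(count)]
        # as_completed yields each future as it finishes, regardless of
        # submission order, so one slow request cannot mask the others.
        responses = list(as_completed(futures))
    # First confirm every task completed, then check every result code.
    assert len(responses) == count, responses
    assert all(future.result()["code"] == 0 for future in futures)
```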
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-from concurrent.futures import ThreadPoolExecutor
+from concurrent.futures import ThreadPoolExecutor, as_completed
 
 import pytest
 from common import INVALID_API_TOKEN, add_chunk, delete_documents, list_chunks
@@ -224,7 +224,7 @@ class TestAddChunk:
 
     @pytest.mark.skip(reason="issues/6411")
     def test_concurrent_add_chunk(self, api_key, add_document):
-        chunk_num = 50
+        count = 50
         dataset_id, document_id = add_document
         res = list_chunks(api_key, dataset_id, document_id)
         if res["code"] != 0:
@@ -240,11 +240,12 @@ class TestAddChunk:
                     document_id,
                     {"content": f"chunk test {i}"},
                 )
-                for i in range(chunk_num)
+                for i in range(count)
             ]
-        responses = [f.result() for f in futures]
-        assert all(r["code"] == 0 for r in responses)
+        responses = list(as_completed(futures))
+        assert len(responses) == count, responses
+        assert all(future.result()["code"] == 0 for future in futures)
         res = list_chunks(api_key, dataset_id, document_id)
         if res["code"] != 0:
             assert False, res
-        assert res["data"]["doc"]["chunk_count"] == chunks_count + chunk_num
+        assert res["data"]["doc"]["chunk_count"] == chunks_count + count
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-from concurrent.futures import ThreadPoolExecutor
+from concurrent.futures import ThreadPoolExecutor, as_completed
 
 import pytest
 from common import INVALID_API_TOKEN, batch_add_chunks, delete_chunks, list_chunks
@@ -121,9 +121,9 @@ class TestChunksDeletion:
 
     @pytest.mark.p3
     def test_concurrent_deletion(self, api_key, add_document):
-        chunks_num = 100
+        count = 100
         dataset_id, document_id = add_document
-        chunk_ids = batch_add_chunks(api_key, dataset_id, document_id, chunks_num)
+        chunk_ids = batch_add_chunks(api_key, dataset_id, document_id, count)
 
         with ThreadPoolExecutor(max_workers=5) as executor:
             futures = [
@@ -134,10 +134,11 @@ class TestChunksDeletion:
                     document_id,
                     {"chunk_ids": chunk_ids[i : i + 1]},
                 )
-                for i in range(chunks_num)
+                for i in range(count)
             ]
-        responses = [f.result() for f in futures]
-        assert all(r["code"] == 0 for r in responses)
+        responses = list(as_completed(futures))
+        assert len(responses) == count, responses
+        assert all(future.result()["code"] == 0 for future in futures)
 
     @pytest.mark.p3
     def test_delete_1k(self, api_key, add_document):
@@ -14,7 +14,7 @@
 # limitations under the License.
 #
 import os
-from concurrent.futures import ThreadPoolExecutor
+from concurrent.futures import ThreadPoolExecutor, as_completed
 
 import pytest
 from common import INVALID_API_TOKEN, batch_add_chunks, list_chunks
@@ -149,12 +149,12 @@ class TestChunksList:
     @pytest.mark.p3
     def test_concurrent_list(self, api_key, add_chunks):
         dataset_id, document_id, _ = add_chunks
-
+        count = 100
         with ThreadPoolExecutor(max_workers=5) as executor:
-            futures = [executor.submit(list_chunks, api_key, dataset_id, document_id) for i in range(100)]
-        responses = [f.result() for f in futures]
-        assert all(r["code"] == 0 for r in responses)
-        assert all(len(r["data"]["chunks"]) == 5 for r in responses)
+            futures = [executor.submit(list_chunks, api_key, dataset_id, document_id) for i in range(count)]
+        responses = list(as_completed(futures))
+        assert len(responses) == count, responses
+        assert all(len(future.result()["data"]["chunks"]) == 5 for future in futures)
 
     @pytest.mark.p1
     def test_default(self, api_key, add_document):
@@ -14,6 +14,7 @@
 # limitations under the License.
 #
 import os
+from concurrent.futures import ThreadPoolExecutor, as_completed
 
 import pytest
 from common import (
@@ -302,12 +303,12 @@ class TestChunksRetrieval:
 
     @pytest.mark.p3
     def test_concurrent_retrieval(self, api_key, add_chunks):
-        from concurrent.futures import ThreadPoolExecutor
-
         dataset_id, _, _ = add_chunks
+        count = 100
         payload = {"question": "chunk", "dataset_ids": [dataset_id]}
 
         with ThreadPoolExecutor(max_workers=5) as executor:
-            futures = [executor.submit(retrieval_chunks, api_key, payload) for i in range(100)]
-        responses = [f.result() for f in futures]
-        assert all(r["code"] == 0 for r in responses)
+            futures = [executor.submit(retrieval_chunks, api_key, payload) for i in range(count)]
+        responses = list(as_completed(futures))
+        assert len(responses) == count, responses
+        assert all(future.result()["code"] == 0 for future in futures)
@@ -14,7 +14,7 @@
 # limitations under the License.
 #
 import os
-from concurrent.futures import ThreadPoolExecutor
+from concurrent.futures import ThreadPoolExecutor, as_completed
 from random import randint
 
 import pytest
@@ -219,7 +219,7 @@ class TestUpdatedChunk:
     @pytest.mark.p3
     @pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="issues/6554")
     def test_concurrent_update_chunk(self, api_key, add_chunks):
-        chunk_num = 50
+        count = 50
         dataset_id, document_id, chunk_ids = add_chunks
 
         with ThreadPoolExecutor(max_workers=5) as executor:
@@ -232,10 +232,11 @@ class TestUpdatedChunk:
                     chunk_ids[randint(0, 3)],
                     {"content": f"update chunk test {i}"},
                 )
-                for i in range(chunk_num)
+                for i in range(count)
             ]
-        responses = [f.result() for f in futures]
-        assert all(r["code"] == 0 for r in responses)
+        responses = list(as_completed(futures))
+        assert len(responses) == count, responses
+        assert all(future.result()["code"] == 0 for future in futures)
 
     @pytest.mark.p3
     def test_update_chunk_to_deleted_document(self, api_key, add_chunks):
@@ -85,7 +85,7 @@ class TestCapability:
             futures = [executor.submit(create_dataset, api_key, {"name": f"dataset_{i}"}) for i in range(count)]
         responses = list(as_completed(futures))
         assert len(responses) == count, responses
-        assert all(futures.result()["code"] == 0 for futures in futures)
+        assert all(future.result()["code"] == 0 for future in futures)
 
 
 @pytest.mark.usefixtures("clear_datasets")
@@ -93,7 +93,7 @@ class TestCapability:
             futures = [executor.submit(delete_datasets, api_key, {"ids": ids[i : i + 1]}) for i in range(count)]
         responses = list(as_completed(futures))
         assert len(responses) == count, responses
-        assert all(futures.result()["code"] == 0 for futures in futures)
+        assert all(future.result()["code"] == 0 for future in futures)
 
 
 class TestDatasetsDelete:
@@ -49,7 +49,7 @@ class TestCapability:
             futures = [executor.submit(list_datasets, api_key) for i in range(count)]
         responses = list(as_completed(futures))
         assert len(responses) == count, responses
-        assert all(futures.result()["code"] == 0 for futures in futures)
+        assert all(future.result()["code"] == 0 for future in futures)
 
 
 @pytest.mark.usefixtures("add_datasets")
@@ -95,7 +95,7 @@ class TestCapability:
            futures = [executor.submit(update_dataset, api_key, dataset_id, {"name": f"dataset_{i}"}) for i in range(count)]
         responses = list(as_completed(futures))
         assert len(responses) == count, responses
-        assert all(futures.result()["code"] == 0 for futures in futures)
+        assert all(future.result()["code"] == 0 for future in futures)
 
 
 class TestDatasetUpdate:
@@ -15,7 +15,6 @@
 #
 from concurrent.futures import ThreadPoolExecutor, as_completed
 
-
 import pytest
 from common import INVALID_API_TOKEN, bulk_upload_documents, delete_documents, list_documents
 from libs.auth import RAGFlowHttpApiAuth
@@ -165,7 +164,7 @@ def test_concurrent_deletion(api_key, add_dataset, tmp_path):
         ]
     responses = list(as_completed(futures))
     assert len(responses) == count, responses
-    assert all(futures.result()["code"] == 0 for futures in futures)
+    assert all(future.result()["code"] == 0 for future in futures)
 
 
 @pytest.mark.p3
@@ -348,7 +348,7 @@ class TestDocumentsList:
             futures = [executor.submit(list_documents, api_key, dataset_id) for i in range(count)]
         responses = list(as_completed(futures))
         assert len(responses) == count, responses
-        assert all(futures.result()["code"] == 0 for futures in futures)
+        assert all(future.result()["code"] == 0 for future in futures)
 
     @pytest.mark.p3
     def test_invalid_params(self, api_key, add_documents):
@@ -211,7 +211,7 @@ def test_concurrent_parse(api_key, add_dataset_func, tmp_path):
         ]
     responses = list(as_completed(futures))
     assert len(responses) == count, responses
-    assert all(futures.result()["code"] == 0 for futures in futures)
+    assert all(future.result()["code"] == 0 for future in futures)
 
     condition(api_key, dataset_id, count)
 
@@ -213,7 +213,7 @@ class TestDocumentsUpload:
             futures = [executor.submit(upload_documents, api_key, dataset_id, fps[i : i + 1]) for i in range(count)]
         responses = list(as_completed(futures))
         assert len(responses) == count, responses
-        assert all(futures.result()["code"] == 0 for futures in futures)
+        assert all(future.result()["code"] == 0 for future in futures)
 
         res = list_datasets(api_key, {"id": dataset_id})
         assert res["data"][0]["document_count"] == count
@@ -22,11 +22,7 @@ from utils.file_utils import create_txt_file
 
 # DATASET MANAGEMENT
 def batch_create_datasets(client: RAGFlow, num: int) -> list[DataSet]:
-    datasets = []
-    for i in range(num):
-        dataset = client.create_dataset(name=f"dataset_{i}")
-        datasets.append(dataset)
-    return datasets
+    return [client.create_dataset(name=f"dataset_{i}") for i in range(num)]
 
 
 # FILE MANAGEMENT WITHIN DATASET
@@ -39,3 +35,8 @@ def bulk_upload_documents(dataset: DataSet, num: int, tmp_path: Path) -> list[Do
         document_infos.append({"display_name": fp.name, "blob": blob})
 
     return dataset.upload_documents(document_infos)
+
+
+# CHUNK MANAGEMENT WITHIN DATASET
+def batch_add_chunks(document: Document, num: int):
+    return [document.add_chunk(content=f"chunk test {i}") for i in range(num)]
@@ -23,7 +23,7 @@ from common import (
 )
 from configs import HOST_ADDRESS, VERSION
 from pytest import FixtureRequest
-from ragflow_sdk import DataSet, RAGFlow
+from ragflow_sdk import Chunk, DataSet, Document, RAGFlow
 from utils import wait_for
 from utils.file_utils import (
     create_docx_file,
@@ -41,7 +41,7 @@ from utils.file_utils import (
 
 @wait_for(30, 1, "Document parsing timeout")
 def condition(_dataset: DataSet):
-    documents = DataSet.list_documents(page_size=1000)
+    documents = _dataset.list_documents(page_size=1000)
     for document in documents:
         if document.run != "DONE":
             return False
@@ -49,7 +49,7 @@ def condition(_dataset: DataSet):
 
 
 @pytest.fixture
-def generate_test_files(request, tmp_path):
+def generate_test_files(request: FixtureRequest, tmp_path: Path):
     file_creators = {
         "docx": (tmp_path / "ragflow_test.docx", create_docx_file),
         "excel": (tmp_path / "ragflow_test.xlsx", create_excel_file),
@@ -72,13 +72,13 @@ def generate_test_files(request, tmp_path):
 
 
 @pytest.fixture(scope="class")
-def ragflow_tmp_dir(request, tmp_path_factory) -> Path:
+def ragflow_tmp_dir(request: FixtureRequest, tmp_path_factory: Path) -> Path:
     class_name = request.cls.__name__
     return tmp_path_factory.mktemp(class_name)
 
 
 @pytest.fixture(scope="session")
-def client(token) -> RAGFlow:
+def client(token: str) -> RAGFlow:
     return RAGFlow(api_key=token, base_url=HOST_ADDRESS, version=VERSION)
 
 
@@ -96,9 +96,7 @@ def add_dataset(request: FixtureRequest, client: RAGFlow):
         client.delete_datasets(ids=None)
 
     request.addfinalizer(cleanup)
-    dataset_ids = batch_create_datasets(client, 1)
-    return dataset_ids[0]
+    return batch_create_datasets(client, 1)[0]
 
 
 @pytest.fixture(scope="function")
@@ -111,12 +109,31 @@ def add_dataset_func(request: FixtureRequest, client: RAGFlow) -> DataSet:
 
 
 @pytest.fixture(scope="class")
-def add_document(request: FixtureRequest, add_dataset: DataSet, ragflow_tmp_dir):
-    dataset = add_dataset
-    documents = bulk_upload_documents(dataset, 1, ragflow_tmp_dir)
+def add_document(add_dataset: DataSet, ragflow_tmp_dir: Path) -> tuple[DataSet, Document]:
+    return add_dataset, bulk_upload_documents(add_dataset, 1, ragflow_tmp_dir)[0]
+
+
+@pytest.fixture(scope="class")
+def add_chunks(request: FixtureRequest, add_document: tuple[DataSet, Document]) -> tuple[DataSet, Document, list[Chunk]]:
+    dataset, document = add_document
+    dataset.async_parse_documents([document.id])
+    condition(dataset)
+
+    chunks = []
+    for i in range(4):
+        chunk = document.add_chunk(content=f"chunk test {i}")
+        chunks.append(chunk)
+
+    # issues/6487
+    from time import sleep
+
+    sleep(1)
 
     def cleanup():
-        dataset.delete_documents(ids=None)
+        try:
+            document.delete_chunks(ids=[])
+        except Exception:
+            pass
 
     request.addfinalizer(cleanup)
-    return dataset, documents[0]
+    return dataset, document, chunks
@@ -0,0 +1,49 @@
+#
+# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+import pytest
+from pytest import FixtureRequest
+from ragflow_sdk import Chunk, DataSet, Document
+from utils import wait_for
+
+
+@wait_for(30, 1, "Document parsing timeout")
+def condition(_dataset: DataSet):
+    documents = _dataset.list_documents(page_size=1000)
+    for document in documents:
+        if document.run != "DONE":
+            return False
+    return True
+
+
+@pytest.fixture(scope="function")
+def add_chunks_func(request: FixtureRequest, add_document: tuple[DataSet, Document]) -> tuple[DataSet, Document, list[Chunk]]:
+    dataset, document = add_document
+    dataset.async_parse_documents([document.id])
+    condition(dataset)
+    chunks = [document.add_chunk(content=f"chunk test {i}") for i in range(4)]
+
+    # issues/6487
+    from time import sleep
+
+    sleep(1)
+
+    def cleanup():
+        document.delete_chunks(ids=[])
+
+    request.addfinalizer(cleanup)
+    return dataset, document, chunks
@@ -0,0 +1,160 @@
+#
+# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from time import sleep
+
+import pytest
+from ragflow_sdk import Chunk
+
+
+def validate_chunk_details(dataset_id: str, document_id: str, payload: dict, chunk: Chunk):
+    assert chunk.dataset_id == dataset_id
+    assert chunk.document_id == document_id
+    assert chunk.content == payload["content"]
+    if "important_keywords" in payload:
+        assert chunk.important_keywords == payload["important_keywords"]
+    if "questions" in payload:
+        assert chunk.questions == [str(q).strip() for q in payload.get("questions", []) if str(q).strip()]
+
+
+class TestAddChunk:
+    @pytest.mark.p1
+    @pytest.mark.parametrize(
+        "payload, expected_message",
+        [
+            ({"content": None}, "not instance of"),
+            ({"content": ""}, "`content` is required"),
+            ({"content": 1}, "not instance of"),
+            ({"content": "a"}, ""),
+            ({"content": " "}, "`content` is required"),
+            ({"content": "\n!?。;!?\"'"}, ""),
+        ],
+    )
+    def test_content(self, add_document, payload, expected_message):
+        dataset, document = add_document
+        chunks_count = len(document.list_chunks())
+
+        if expected_message:
+            with pytest.raises(Exception) as excinfo:
+                document.add_chunk(**payload)
+            assert expected_message in str(excinfo.value), str(excinfo.value)
+        else:
+            chunk = document.add_chunk(**payload)
+            validate_chunk_details(dataset.id, document.id, payload, chunk)
+
+            sleep(1)
+            chunks = document.list_chunks()
+            assert len(chunks) == chunks_count + 1, str(chunks)
+
+    @pytest.mark.p2
+    @pytest.mark.parametrize(
+        "payload, expected_message",
+        [
+            ({"content": "chunk test important_keywords 1", "important_keywords": ["a", "b", "c"]}, ""),
+            ({"content": "chunk test important_keywords 2", "important_keywords": [""]}, ""),
+            ({"content": "chunk test important_keywords 3", "important_keywords": [1]}, "not instance of"),
+            ({"content": "chunk test important_keywords 4", "important_keywords": ["a", "a"]}, ""),
+            ({"content": "chunk test important_keywords 5", "important_keywords": "abc"}, "not instance of"),
+            ({"content": "chunk test important_keywords 6", "important_keywords": 123}, "not instance of"),
+        ],
+    )
+    def test_important_keywords(self, add_document, payload, expected_message):
+        dataset, document = add_document
+        chunks_count = len(document.list_chunks())
+
+        if expected_message:
+            with pytest.raises(Exception) as excinfo:
+                document.add_chunk(**payload)
+            assert expected_message in str(excinfo.value), str(excinfo.value)
+        else:
+            chunk = document.add_chunk(**payload)
+            validate_chunk_details(dataset.id, document.id, payload, chunk)
+
+            sleep(1)
+            chunks = document.list_chunks()
+            assert len(chunks) == chunks_count + 1, str(chunks)
+
+    @pytest.mark.p2
+    @pytest.mark.parametrize(
+        "payload, expected_message",
+        [
+            ({"content": "chunk test test_questions 1", "questions": ["a", "b", "c"]}, ""),
+            ({"content": "chunk test test_questions 2", "questions": [""]}, ""),
+            ({"content": "chunk test test_questions 3", "questions": [1]}, "not instance of"),
+            ({"content": "chunk test test_questions 4", "questions": ["a", "a"]}, ""),
+            ({"content": "chunk test test_questions 5", "questions": "abc"}, "not instance of"),
+            ({"content": "chunk test test_questions 6", "questions": 123}, "not instance of"),
+        ],
+    )
+    def test_questions(self, add_document, payload, expected_message):
+        dataset, document = add_document
+        chunks_count = len(document.list_chunks())
+
+        if expected_message:
+            with pytest.raises(Exception) as excinfo:
+                document.add_chunk(**payload)
+            assert expected_message in str(excinfo.value), str(excinfo.value)
+        else:
+            chunk = document.add_chunk(**payload)
+            validate_chunk_details(dataset.id, document.id, payload, chunk)
+
+            sleep(1)
+            chunks = document.list_chunks()
+            assert len(chunks) == chunks_count + 1, str(chunks)
+
+    @pytest.mark.p3
+    def test_repeated_add_chunk(self, add_document):
+        payload = {"content": "chunk test repeated_add_chunk"}
+        dataset, document = add_document
+        chunks_count = len(document.list_chunks())
+
+        chunk1 = document.add_chunk(**payload)
+        validate_chunk_details(dataset.id, document.id, payload, chunk1)
+        sleep(1)
+        chunks = document.list_chunks()
+        assert len(chunks) == chunks_count + 1, str(chunks)
+
+        chunk2 = document.add_chunk(**payload)
+        validate_chunk_details(dataset.id, document.id, payload, chunk2)
+        sleep(1)
+        chunks = document.list_chunks()
+        assert len(chunks) == chunks_count + 1, str(chunks)
+
+    @pytest.mark.p2
+    def test_add_chunk_to_deleted_document(self, add_document):
+        dataset, document = add_document
+        dataset.delete_documents(ids=[document.id])
+
+        with pytest.raises(Exception) as excinfo:
+            document.add_chunk(content="chunk test")
+        assert f"You don't own the document {document.id}" in str(excinfo.value), str(excinfo.value)
+
+    @pytest.mark.skip(reason="issues/6411")
+    @pytest.mark.p3
+    def test_concurrent_add_chunk(self, add_document):
+        count = 50
+        _, document = add_document
+        initial_chunk_count = len(document.list_chunks())
+
+        def add_chunk_task(i):
+            return document.add_chunk(content=f"chunk test concurrent {i}")
+
+        with ThreadPoolExecutor(max_workers=5) as executor:
+            futures = [executor.submit(add_chunk_task, i) for i in range(count)]
+        responses = list(as_completed(futures))
+        assert len(responses) == count, responses
+        sleep(5)
+        assert len(document.list_chunks(page_size=100)) == initial_chunk_count + count
@@ -0,0 +1,113 @@
+#
+# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from concurrent.futures import ThreadPoolExecutor, as_completed
+
+import pytest
+from common import batch_add_chunks
+
+
+class TestChunksDeletion:
+    @pytest.mark.parametrize(
+        "payload",
+        [
+            pytest.param(lambda r: {"ids": ["invalid_id"] + r}, marks=pytest.mark.p3),
+            pytest.param(lambda r: {"ids": r[:1] + ["invalid_id"] + r[1:4]}, marks=pytest.mark.p1),
+            pytest.param(lambda r: {"ids": r + ["invalid_id"]}, marks=pytest.mark.p3),
+        ],
+    )
+    def test_delete_partial_invalid_id(self, add_chunks_func, payload):
+        _, document, chunks = add_chunks_func
+        chunk_ids = [chunk.id for chunk in chunks]
+        payload = payload(chunk_ids)
+
+        with pytest.raises(Exception) as excinfo:
+            document.delete_chunks(**payload)
+        assert "rm_chunk deleted chunks" in str(excinfo.value), str(excinfo.value)
+
+        remaining_chunks = document.list_chunks()
+        assert len(remaining_chunks) == 1, str(remaining_chunks)
+
+    @pytest.mark.p3
+    def test_repeated_deletion(self, add_chunks_func):
+        _, document, chunks = add_chunks_func
+        chunk_ids = [chunk.id for chunk in chunks]
+        document.delete_chunks(ids=chunk_ids)
+
+        with pytest.raises(Exception) as excinfo:
+            document.delete_chunks(ids=chunk_ids)
+        assert "rm_chunk deleted chunks 0, expect" in str(excinfo.value), str(excinfo.value)
+
+    @pytest.mark.p3
+    def test_duplicate_deletion(self, add_chunks_func):
+        _, document, chunks = add_chunks_func
+        chunk_ids = [chunk.id for chunk in chunks]
+        document.delete_chunks(ids=chunk_ids * 2)
+        remaining_chunks = document.list_chunks()
+        assert len(remaining_chunks) == 1, str(remaining_chunks)
+
+    @pytest.mark.p3
+    def test_concurrent_deletion(self, add_document):
+        count = 100
+        _, document = add_document
+        chunks = batch_add_chunks(document, count)
+        chunk_ids = [chunk.id for chunk in chunks]
+
+        with ThreadPoolExecutor(max_workers=5) as executor:
+            futures = [executor.submit(document.delete_chunks, ids=[chunk_id]) for chunk_id in chunk_ids]
+        responses = list(as_completed(futures))
+        assert len(responses) == count, responses
+
+    @pytest.mark.p3
+    def test_delete_1k(self, add_document):
+        count = 1_000
+        _, document = add_document
+        chunks = batch_add_chunks(document, count)
+        chunk_ids = [chunk.id for chunk in chunks]
+
+        from time import sleep
+
+        sleep(1)
+
+        document.delete_chunks(ids=chunk_ids)
+        remaining_chunks = document.list_chunks()
+        assert len(remaining_chunks) == 0, str(remaining_chunks)
+
+    @pytest.mark.parametrize(
+        "payload, expected_message, remaining",
+        [
+            pytest.param(None, "TypeError", 5, marks=pytest.mark.skip),
+            pytest.param({"ids": ["invalid_id"]}, "rm_chunk deleted chunks 0, expect 1", 5, marks=pytest.mark.p3),
+            pytest.param("not json", "UnboundLocalError", 5, marks=pytest.mark.skip(reason="pull/6376")),
+            pytest.param(lambda r: {"ids": r[:1]}, "", 4, marks=pytest.mark.p3),
+            pytest.param(lambda r: {"ids": r}, "", 1, marks=pytest.mark.p1),
+            pytest.param({"ids": []}, "", 0, marks=pytest.mark.p3),
+        ],
+    )
+    def test_basic_scenarios(self, add_chunks_func, payload, expected_message, remaining):
+        _, document, chunks = add_chunks_func
+        chunk_ids = [chunk.id for chunk in chunks]
+        if callable(payload):
+            payload = payload(chunk_ids)
+
+        if expected_message:
+            with pytest.raises(Exception) as excinfo:
+                document.delete_chunks(**payload)
+            assert expected_message in str(excinfo.value), str(excinfo.value)
+        else:
+            document.delete_chunks(**payload)
+
+        remaining_chunks = document.list_chunks()
+        assert len(remaining_chunks) == remaining, str(remaining_chunks)
@@ -0,0 +1,140 @@
+#
+# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import os
+from concurrent.futures import ThreadPoolExecutor, as_completed
+
+import pytest
+from common import batch_add_chunks
+
+
+class TestChunksList:
+    @pytest.mark.p1
+    @pytest.mark.parametrize(
+        "params, expected_page_size, expected_message",
+        [
+            ({"page": None, "page_size": 2}, 2, ""),
+            pytest.param({"page": 0, "page_size": 2}, 0, "ValueError('Search does not support negative slicing.')", marks=pytest.mark.skip),
+            ({"page": 2, "page_size": 2}, 2, ""),
+            ({"page": 3, "page_size": 2}, 1, ""),
+            ({"page": "3", "page_size": 2}, 1, ""),
+            pytest.param({"page": -1, "page_size": 2}, 0, "ValueError('Search does not support negative slicing.')", marks=pytest.mark.skip),
+            pytest.param({"page": "a", "page_size": 2}, 0, """ValueError("invalid literal for int() with base 10: \'a\'")""", marks=pytest.mark.skip),
+        ],
+    )
+    def test_page(self, add_chunks, params, expected_page_size, expected_message):
+        _, document, _ = add_chunks
+
+        if expected_message:
+            with pytest.raises(Exception) as excinfo:
+                document.list_chunks(**params)
+            assert expected_message in str(excinfo.value), str(excinfo.value)
+        else:
+            chunks = document.list_chunks(**params)
+            assert len(chunks) == expected_page_size, str(chunks)
+
+    @pytest.mark.p1
+    @pytest.mark.parametrize(
+        "params, expected_page_size, expected_message",
+        [
+            ({"page_size": None}, 5, ""),
+            pytest.param({"page_size": 0}, 5, "", marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="Infinity does not support page_size=0")),
+            pytest.param({"page_size": 0}, 0, "3013", marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in [None, "opensearch", "elasticsearch"], reason="Infinity does not support page_size=0")),
+            ({"page_size": 1}, 1, ""),
+            ({"page_size": 6}, 5, ""),
+            ({"page_size": "1"}, 1, ""),
+            pytest.param({"page_size": -1}, 5, "", marks=pytest.mark.skip),
+            pytest.param({"page_size": "a"}, 0, """ValueError("invalid literal for int() with base 10: \'a\'")""", marks=pytest.mark.skip),
+        ],
+    )
+    def test_page_size(self, add_chunks, params, expected_page_size, expected_message):
+        _, document, _ = add_chunks
+
+        if expected_message:
+            with pytest.raises(Exception) as excinfo:
+                document.list_chunks(**params)
+            assert expected_message in str(excinfo.value), str(excinfo.value)
+        else:
+            chunks = document.list_chunks(**params)
+            assert len(chunks) == expected_page_size, str(chunks)
+
+    @pytest.mark.p2
+    @pytest.mark.parametrize(
+        "params, expected_page_size",
+        [
+            ({"keywords": None}, 5),
+            ({"keywords": ""}, 5),
+            ({"keywords": "1"}, 1),
+            pytest.param({"keywords": "chunk"}, 4, marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="issues/6509")),
+            ({"keywords": "ragflow"}, 1),
+            ({"keywords": "unknown"}, 0),
+        ],
+    )
+    def test_keywords(self, add_chunks, params, expected_page_size):
+        _, document, _ = add_chunks
+        chunks = document.list_chunks(**params)
+        assert len(chunks) == expected_page_size, str(chunks)
+
+    @pytest.mark.p1
+    @pytest.mark.parametrize(
+        "chunk_id, expected_page_size, expected_message",
+        [
+            (None, 5, ""),
+            ("", 5, ""),
+            pytest.param(lambda r: r[0], 1, "", marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="issues/6499")),
+            pytest.param("unknown", 0, """AttributeError("\'NoneType\' object has no attribute \'keys\'")""", marks=pytest.mark.skip),
+        ],
+    )
+    def test_id(self, add_chunks, chunk_id, expected_page_size, expected_message):
+        _, document, chunks = add_chunks
+        chunk_ids = [chunk.id for chunk in chunks]
+        if callable(chunk_id):
+            params = {"id": chunk_id(chunk_ids)}
+        else:
+            params = {"id": chunk_id}
+
+        if expected_message:
+            with pytest.raises(Exception) as excinfo:
+                document.list_chunks(**params)
+            assert expected_message in str(excinfo.value), str(excinfo.value)
+        else:
+            chunks = document.list_chunks(**params)
+            if params["id"] in [None, ""]:
+                assert len(chunks) == expected_page_size, str(chunks)
+            else:
+                assert chunks[0].id == params["id"], str(chunks)
+
+    @pytest.mark.p3
+    def test_concurrent_list(self, add_chunks):
+        _, document, _ = add_chunks
+        count = 100
+        with ThreadPoolExecutor(max_workers=5) as executor:
+            futures = [executor.submit(document.list_chunks) for _ in range(count)]
+
+        responses = list(as_completed(futures))
+        assert len(responses) == count, responses
+        assert all(len(future.result()) == 5 for future in futures)
+
+    @pytest.mark.p1
+    def test_default(self, add_document):
+        _, document = add_document
+        batch_add_chunks(document, 31)
+
+        from time import sleep
+
+        sleep(3)
+
+        chunks = document.list_chunks()
+        assert len(chunks) == 30, str(chunks)
@@ -0,0 +1,254 @@
+#
+# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import os
+from concurrent.futures import ThreadPoolExecutor, as_completed
+
+import pytest
+
+
+class TestChunksRetrieval:
+    @pytest.mark.p1
+    @pytest.mark.parametrize(
+        "payload, expected_page_size, expected_message",
+        [
+            ({"question": "chunk", "dataset_ids": None}, 4, ""),
+            ({"question": "chunk", "document_ids": None}, 0, "missing 1 required positional argument"),
+            ({"question": "chunk", "dataset_ids": None, "document_ids": None}, 4, ""),
+            ({"question": "chunk"}, 0, "missing 1 required positional argument"),
+        ],
+    )
+    def test_basic_scenarios(self, client, add_chunks, payload, expected_page_size, expected_message):
+        dataset, document, _ = add_chunks
+        if "dataset_ids" in payload:
+            payload["dataset_ids"] = [dataset.id]
+        if "document_ids" in payload:
+            payload["document_ids"] = [document.id]
+
+        if expected_message:
+            with pytest.raises(Exception) as excinfo:
+                client.retrieve(**payload)
+            assert expected_message in str(excinfo.value), str(excinfo.value)
+        else:
+            chunks = client.retrieve(**payload)
+            assert len(chunks) == expected_page_size, str(chunks)
+
+    @pytest.mark.p2
+    @pytest.mark.parametrize(
+        "payload, expected_page_size, expected_message",
+        [
+            pytest.param(
+                {"page": None, "page_size": 2},
+                2,
+                """TypeError("int() argument must be a string, a bytes-like object or a real number, not \'NoneType\'")""",
+                marks=pytest.mark.skip,
+            ),
+            pytest.param(
+                {"page": 0, "page_size": 2},
+                0,
+                "ValueError('Search does not support negative slicing.')",
+                marks=pytest.mark.skip,
+            ),
+            pytest.param({"page": 2, "page_size": 2}, 2, "", marks=pytest.mark.skip(reason="issues/6646")),
+            ({"page": 3, "page_size": 2}, 0, ""),
+            ({"page": "3", "page_size": 2}, 0, ""),
+            pytest.param(
+                {"page": -1, "page_size": 2},
+                0,
+                "ValueError('Search does not support negative slicing.')",
+                marks=pytest.mark.skip,
+            ),
+            pytest.param(
+                {"page": "a", "page_size": 2},
+                0,
+                """ValueError("invalid literal for int() with base 10: \'a\'")""",
+                marks=pytest.mark.skip,
+            ),
+        ],
+    )
+    def test_page(self, client, add_chunks, payload, expected_page_size, expected_message):
+        dataset, _, _ = add_chunks
+        payload.update({"question": "chunk", "dataset_ids": [dataset.id]})
+
+        if expected_message:
+            with pytest.raises(Exception) as excinfo:
+                client.retrieve(**payload)
+            assert expected_message in str(excinfo.value), str(excinfo.value)
+        else:
+            chunks = client.retrieve(**payload)
+            assert len(chunks) == expected_page_size, str(chunks)
+
+    @pytest.mark.p3
+    @pytest.mark.parametrize(
+        "payload, expected_page_size, expected_message",
+        [
+            pytest.param(
+                {"page_size": None},
+                0,
+                """TypeError("int() argument must be a string, a bytes-like object or a real number, not \'NoneType\'")""",
+                marks=pytest.mark.skip,
+            ),
+            ({"page_size": 1}, 1, ""),
+            ({"page_size": 5}, 4, ""),
+            ({"page_size": "1"}, 1, ""),
+            pytest.param(
+                {"page_size": "a"},
+                0,
+                """ValueError("invalid literal for int() with base 10: \'a\'")""",
+                marks=pytest.mark.skip,
+            ),
+        ],
+    )
+    def test_page_size(self, client, add_chunks, payload, expected_page_size, expected_message):
+        dataset, _, _ = add_chunks
+        payload.update({"question": "chunk", "dataset_ids": [dataset.id]})
+
+        if expected_message:
+            with pytest.raises(Exception) as excinfo:
+                client.retrieve(**payload)
+            assert expected_message in str(excinfo.value), str(excinfo.value)
+        else:
+            chunks = client.retrieve(**payload)
+            assert len(chunks) == expected_page_size, str(chunks)
+
+    @pytest.mark.p3
+    @pytest.mark.parametrize(
+        "payload, expected_page_size, expected_message",
+        [
+            ({"vector_similarity_weight": 0}, 4, ""),
+            ({"vector_similarity_weight": 0.5}, 4, ""),
+            ({"vector_similarity_weight": 10}, 4, ""),
+            pytest.param(
+                {"vector_similarity_weight": "a"},
+                0,
+                """ValueError("could not convert string to float: 'a'")""",
+                marks=pytest.mark.skip,
+            ),
+        ],
+    )
+    def test_vector_similarity_weight(self, client, add_chunks, payload, expected_page_size, expected_message):
+        dataset, _, _ = add_chunks
+        payload.update({"question": "chunk", "dataset_ids": [dataset.id]})
+
+        if expected_message:
+            with pytest.raises(Exception) as excinfo:
+                client.retrieve(**payload)
+            assert expected_message in str(excinfo.value), str(excinfo.value)
+        else:
+            chunks = client.retrieve(**payload)
+            assert len(chunks) == expected_page_size, str(chunks)
+
+    @pytest.mark.p2
+    @pytest.mark.parametrize(
+        "payload, expected_page_size, expected_message",
+        [
+            ({"top_k": 10}, 4, ""),
+            pytest.param(
+                {"top_k": 1},
+                4,
+                "",
+                marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in ["infinity", "opensearch"], reason="Infinity"),
+            ),
+            pytest.param(
+                {"top_k": 1},
+                1,
+                "",
+                marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in [None, "opensearch", "elasticsearch"], reason="elasticsearch"),
+            ),
+            pytest.param(
+                {"top_k": -1},
+                4,
+                "must be greater than 0",
+                marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in ["infinity", "opensearch"], reason="Infinity"),
+            ),
+            pytest.param(
+                {"top_k": -1},
+                4,
+                "3014",
+                marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in [None, "opensearch", "elasticsearch"], reason="elasticsearch"),
+            ),
+            pytest.param(
+                {"top_k": "a"},
+                0,
+                """ValueError("invalid literal for int() with base 10: \'a\'")""",
+                marks=pytest.mark.skip,
+            ),
+        ],
+    )
+    def test_top_k(self, client, add_chunks, payload, expected_page_size, expected_message):
+        dataset, _, _ = add_chunks
+        payload.update({"question": "chunk", "dataset_ids": [dataset.id]})
+
+        if expected_message:
+            with pytest.raises(Exception) as excinfo:
+                client.retrieve(**payload)
+            assert expected_message in str(excinfo.value), str(excinfo.value)
+        else:
+            chunks = client.retrieve(**payload)
+            assert len(chunks) == expected_page_size, str(chunks)
+
+    @pytest.mark.skip
+    @pytest.mark.parametrize(
+        "payload, expected_message",
+        [
+            ({"rerank_id": "BAAI/bge-reranker-v2-m3"}, ""),
+            pytest.param({"rerank_id": "unknown"}, "LookupError('Model(unknown) not authorized')", marks=pytest.mark.skip),
+        ],
+    )
+    def test_rerank_id(self, client, add_chunks, payload, expected_message):
+        dataset, _, _ = add_chunks
+        payload.update({"question": "chunk", "dataset_ids": [dataset.id]})
+
+        if expected_message:
+            with pytest.raises(Exception) as excinfo:
+                client.retrieve(**payload)
+            assert expected_message in str(excinfo.value), str(excinfo.value)
+        else:
+            chunks = client.retrieve(**payload)
+            assert len(chunks) > 0, str(chunks)
+
+    @pytest.mark.skip
+    @pytest.mark.parametrize(
+        "payload, expected_page_size, expected_message",
+        [
+            ({"keyword": True}, 5, ""),
+            ({"keyword": "True"}, 5, ""),
+            ({"keyword": False}, 5, ""),
+            ({"keyword": "False"}, 5, ""),
+            ({"keyword": None}, 5, ""),
+        ],
+    )
+    def test_keyword(self, client, add_chunks, payload, expected_page_size, expected_message):
+        dataset, _, _ = add_chunks
+        payload.update({"question": "chunk test", "dataset_ids": [dataset.id]})
+
+        if expected_message:
+            with pytest.raises(Exception) as excinfo:
+                client.retrieve(**payload)
+            assert expected_message in str(excinfo.value), str(excinfo.value)
+        else:
+            chunks = client.retrieve(**payload)
+            assert len(chunks) == expected_page_size, str(chunks)
+
+    @pytest.mark.p3
+    def test_concurrent_retrieval(self, client, add_chunks):
+        dataset, _, _ = add_chunks
+        count = 100
+        payload = {"question": "chunk", "dataset_ids": [dataset.id]}
+
+        with ThreadPoolExecutor(max_workers=5) as executor:
+            futures = [executor.submit(client.retrieve, **payload) for _ in range(count)]
+        responses = list(as_completed(futures))
+        assert len(responses) == count, responses
@@ -0,0 +1,154 @@
+#
+# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import os
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from random import randint
+
+import pytest
+
+
+class TestUpdatedChunk:
+    @pytest.mark.p1
+    @pytest.mark.parametrize(
+        "payload, expected_message",
+        [
+            ({"content": None}, "TypeError('expected string or bytes-like object')"),
+            pytest.param(
+                {"content": ""},
+                """APIRequestFailedError(\'Error code: 400, with error text {"error":{"code":"1213","message":"未正常接收到prompt参数。"}}\')""",
+                marks=pytest.mark.skip(reason="issues/6541"),
+            ),
+            pytest.param(
+                {"content": 1},
+                "TypeError('expected string or bytes-like object')",
+                marks=pytest.mark.skip,
+            ),
+            ({"content": "update chunk"}, ""),
+            pytest.param(
+                {"content": " "},
+                """APIRequestFailedError(\'Error code: 400, with error text {"error":{"code":"1213","message":"未正常接收到prompt参数。"}}\')""",
+                marks=pytest.mark.skip(reason="issues/6541"),
+            ),
+            ({"content": "\n!?。;!?\"'"}, ""),
+        ],
+    )
+    def test_content(self, add_chunks, payload, expected_message):
+        _, _, chunks = add_chunks
+        chunk = chunks[0]
+
+        if expected_message:
+            with pytest.raises(Exception) as excinfo:
+                chunk.update(payload)
+            assert expected_message in str(excinfo.value), str(excinfo.value)
+        else:
+            chunk.update(payload)
+
+    @pytest.mark.p2
+    @pytest.mark.parametrize(
+        "payload, expected_message",
+        [
+            ({"important_keywords": ["a", "b", "c"]}, ""),
+            ({"important_keywords": [""]}, ""),
+            ({"important_keywords": [1]}, "TypeError('sequence item 0: expected str instance, int found')"),
+            ({"important_keywords": ["a", "a"]}, ""),
+            ({"important_keywords": "abc"}, "`important_keywords` should be a list"),
+            ({"important_keywords": 123}, "`important_keywords` should be a list"),
+        ],
+    )
+    def test_important_keywords(self, add_chunks, payload, expected_message):
+        _, _, chunks = add_chunks
+        chunk = chunks[0]
+
+        if expected_message:
+            with pytest.raises(Exception) as excinfo:
+                chunk.update(payload)
+            assert expected_message in str(excinfo.value), str(excinfo.value)
+        else:
+            chunk.update(payload)
+
+    @pytest.mark.p2
+    @pytest.mark.parametrize(
+        "payload, expected_message",
+        [
+            ({"questions": ["a", "b", "c"]}, ""),
+            ({"questions": [""]}, ""),
+            ({"questions": [1]}, "TypeError('sequence item 0: expected str instance, int found')"),
+            ({"questions": ["a", "a"]}, ""),
+            ({"questions": "abc"}, "`questions` should be a list"),
+            ({"questions": 123}, "`questions` should be a list"),
+        ],
+    )
+    def test_questions(self, add_chunks, payload, expected_message):
+        _, _, chunks = add_chunks
+        chunk = chunks[0]
+
+        if expected_message:
+            with pytest.raises(Exception) as excinfo:
+                chunk.update(payload)
+            assert expected_message in str(excinfo.value), str(excinfo.value)
+        else:
+            chunk.update(payload)
+
+    @pytest.mark.p2
+    @pytest.mark.parametrize(
+        "payload, expected_message",
+        [
+            ({"available": True}, ""),
+            pytest.param({"available": "True"}, """ValueError("invalid literal for int() with base 10: \'True\'")""", marks=pytest.mark.skip),
+            ({"available": 1}, ""),
+            ({"available": False}, ""),
+            pytest.param({"available": "False"}, """ValueError("invalid literal for int() with base 10: \'False\'")""", marks=pytest.mark.skip),
+            ({"available": 0}, ""),
+        ],
+    )
+    def test_available(self, add_chunks, payload, expected_message):
+        _, _, chunks = add_chunks
+        chunk = chunks[0]
+
+        if expected_message:
+            with pytest.raises(Exception) as excinfo:
+                chunk.update(payload)
+            assert expected_message in str(excinfo.value), str(excinfo.value)
+        else:
+            chunk.update(payload)
+
+    @pytest.mark.p3
+    def test_repeated_update_chunk(self, add_chunks):
+        _, _, chunks = add_chunks
+        chunk = chunks[0]
+
+        chunk.update({"content": "chunk test 1"})
+        chunk.update({"content": "chunk test 2"})
+
+    @pytest.mark.p3
+    @pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="issues/6554")
+    def test_concurrent_update_chunk(self, add_chunks):
+        count = 50
+        _, _, chunks = add_chunks
+
+        with ThreadPoolExecutor(max_workers=5) as executor:
+            futures = [executor.submit(chunks[randint(0, 3)].update, {"content": f"update chunk test {i}"}) for i in range(count)]
+        responses = list(as_completed(futures))
+        assert len(responses) == count, responses
+
+    @pytest.mark.p3
+    def test_update_chunk_to_deleted_document(self, add_chunks):
+        dataset, document, chunks = add_chunks
+        dataset.delete_documents(ids=[document.id])
+
+        with pytest.raises(Exception) as excinfo:
+            chunks[0].update({})
+        assert f"Can't find this chunk {chunks[0].id}" in str(excinfo.value), str(excinfo.value)
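Taken together, the new suites exercise the full SDK chunk lifecycle: parse a document, add chunks, list and update them, retrieve against the dataset, and delete. A hedged end-to-end sketch of that surface, using only calls that appear in the diff (the API key and base URL are placeholders, not values from this PR):

```python
from ragflow_sdk import RAGFlow

client = RAGFlow(api_key="<api_key>", base_url="<HOST_ADDRESS>")
dataset = client.create_dataset(name="dataset_demo")
document = dataset.upload_documents([{"display_name": "demo.txt", "blob": b"hello"}])[0]
dataset.async_parse_documents([document.id])  # parsing is asynchronous; the fixtures poll until run == "DONE"

chunk = document.add_chunk(content="chunk test 0")                    # create
chunks = document.list_chunks(page_size=30)                           # list
chunk.update({"content": "updated", "important_keywords": ["a"]})     # update
hits = client.retrieve(question="chunk", dataset_ids=[dataset.id],
                       document_ids=[document.id])                    # retrieve
document.delete_chunks(ids=[chunk.id])                                # delete
```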