mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Test: Add tests for chunk API endpoints (#8616)
### What problem does this PR solve?

- Add a comprehensive test suite for chunk operations, including:
  - Test files for create, list, retrieve, update, and delete chunks
  - Authorization tests
  - Batch-operation tests
- Update test configurations and common utilities
- Validate that the `important_kwd` and `question_kwd` fields are lists in chunk_app.py
- Reorganize imports and clean up duplicate code

### Type of change

- [x] Add test cases
This commit is contained in:
@ -15,27 +15,25 @@
|
||||
#
|
||||
import datetime
|
||||
import json
|
||||
import re
|
||||
|
||||
import xxhash
|
||||
from flask import request
|
||||
from flask_login import login_required, current_user
|
||||
from flask_login import current_user, login_required
|
||||
|
||||
from rag.app.qa import rmPrefix, beAdoc
|
||||
from rag.app.tag import label_question
|
||||
from rag.nlp import search, rag_tokenizer
|
||||
from rag.prompts import keyword_extraction, cross_languages
|
||||
from rag.settings import PAGERANK_FLD
|
||||
from rag.utils import rmSpace
|
||||
from api import settings
|
||||
from api.db import LLMType, ParserType
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.db.services.user_service import UserTenantService
|
||||
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api import settings
|
||||
from api.utils.api_utils import get_json_result
|
||||
import xxhash
|
||||
import re
|
||||
|
||||
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request
|
||||
from rag.app.qa import beAdoc, rmPrefix
|
||||
from rag.app.tag import label_question
|
||||
from rag.nlp import rag_tokenizer, search
|
||||
from rag.prompts import cross_languages, keyword_extraction
|
||||
from rag.settings import PAGERANK_FLD
|
||||
from rag.utils import rmSpace
|
||||
|
||||
|
||||
@manager.route('/list', methods=['POST']) # noqa: F821
|
||||
@ -129,9 +127,13 @@ def set():
|
||||
d["content_ltks"] = rag_tokenizer.tokenize(req["content_with_weight"])
|
||||
d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
|
||||
if "important_kwd" in req:
|
||||
if not isinstance(req["important_kwd"], list):
|
||||
return get_data_error_result(message="`important_kwd` should be a list")
|
||||
d["important_kwd"] = req["important_kwd"]
|
||||
d["important_tks"] = rag_tokenizer.tokenize(" ".join(req["important_kwd"]))
|
||||
if "question_kwd" in req:
|
||||
if not isinstance(req["question_kwd"], list):
|
||||
return get_data_error_result(message="`question_kwd` should be a list")
|
||||
d["question_kwd"] = req["question_kwd"]
|
||||
d["question_tks"] = rag_tokenizer.tokenize("\n".join(req["question_kwd"]))
|
||||
if "tag_kwd" in req:
|
||||
@ -235,6 +237,8 @@ def create():
|
||||
d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
|
||||
if "tag_feas" in req:
|
||||
d["tag_feas"] = req["tag_feas"]
|
||||
if "tag_feas" in req:
|
||||
d["tag_feas"] = req["tag_feas"]
|
||||
|
||||
try:
|
||||
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||
|
||||
@ -13,7 +13,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from time import sleep
|
||||
|
||||
import pytest
|
||||
|
||||
@ -24,9 +24,7 @@ HEADERS = {"Content-Type": "application/json"}
|
||||
|
||||
KB_APP_URL = f"/{VERSION}/kb"
|
||||
DOCUMENT_APP_URL = f"/{VERSION}/document"
|
||||
# FILE_API_URL = "/api/v1/datasets/{dataset_id}/documents"
|
||||
# FILE_CHUNK_API_URL = "/api/v1/datasets/{dataset_id}/chunks"
|
||||
# CHUNK_API_URL = "/api/v1/datasets/{dataset_id}/documents/{document_id}/chunks"
|
||||
CHUNK_API_URL = f"/{VERSION}/chunk"
|
||||
# CHAT_ASSISTANT_API_URL = "/api/v1/chats"
|
||||
# SESSION_WITH_CHAT_ASSISTANT_API_URL = "/api/v1/chats/{chat_id}/sessions"
|
||||
# SESSION_WITH_AGENT_API_URL = "/api/v1/agents/{agent_id}/sessions"
|
||||
@ -164,3 +162,42 @@ def bulk_upload_documents(auth, kb_id, num, tmp_path):
|
||||
for document in res["data"]:
|
||||
document_ids.append(document["id"])
|
||||
return document_ids
|
||||
|
||||
|
||||
# CHUNK APP
|
||||
def add_chunk(auth, payload=None, *, headers=HEADERS, data=None):
    """POST /chunk/create with *payload* and return the decoded JSON body."""
    endpoint = f"{HOST_ADDRESS}{CHUNK_API_URL}/create"
    response = requests.post(url=endpoint, headers=headers, auth=auth, json=payload, data=data)
    return response.json()
|
||||
|
||||
|
||||
def list_chunks(auth, payload=None, *, headers=HEADERS):
    """POST /chunk/list with *payload* and return the decoded JSON body."""
    endpoint = f"{HOST_ADDRESS}{CHUNK_API_URL}/list"
    response = requests.post(url=endpoint, headers=headers, auth=auth, json=payload)
    return response.json()
|
||||
|
||||
|
||||
def get_chunk(auth, params=None, *, headers=HEADERS):
    """GET /chunk/get with *params* as the query string; return decoded JSON."""
    endpoint = f"{HOST_ADDRESS}{CHUNK_API_URL}/get"
    response = requests.get(url=endpoint, headers=headers, auth=auth, params=params)
    return response.json()
|
||||
|
||||
|
||||
def update_chunk(auth, payload=None, *, headers=HEADERS):
    """POST /chunk/set with *payload* and return the decoded JSON body."""
    endpoint = f"{HOST_ADDRESS}{CHUNK_API_URL}/set"
    response = requests.post(url=endpoint, headers=headers, auth=auth, json=payload)
    return response.json()
|
||||
|
||||
|
||||
def delete_chunks(auth, payload=None, *, headers=HEADERS):
    """POST /chunk/rm with *payload* and return the decoded JSON body."""
    endpoint = f"{HOST_ADDRESS}{CHUNK_API_URL}/rm"
    response = requests.post(url=endpoint, headers=headers, auth=auth, json=payload)
    return response.json()
|
||||
|
||||
|
||||
def retrieval_chunks(auth, payload=None, *, headers=HEADERS):
    """POST /chunk/retrieval_test with *payload*; return the decoded JSON body."""
    endpoint = f"{HOST_ADDRESS}{CHUNK_API_URL}/retrieval_test"
    response = requests.post(url=endpoint, headers=headers, auth=auth, json=payload)
    return response.json()
|
||||
|
||||
|
||||
def batch_add_chunks(auth, doc_id, num):
    """Create *num* chunks on document *doc_id* and return their chunk ids.

    Each chunk's content is "chunk test <index>"; a failed create raises
    KeyError here because the response carries no "data" payload.
    """
    return [
        add_chunk(auth, {"doc_id": doc_id, "content_with_weight": f"chunk test {index}"})["data"]["chunk_id"]
        for index in range(num)
    ]
|
||||
|
||||
@ -13,18 +13,23 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from time import sleep
|
||||
|
||||
import pytest
|
||||
from common import (
|
||||
batch_add_chunks,
|
||||
batch_create_datasets,
|
||||
bulk_upload_documents,
|
||||
delete_chunks,
|
||||
list_chunks,
|
||||
list_documents,
|
||||
list_kbs,
|
||||
parse_documents,
|
||||
rm_kb,
|
||||
)
|
||||
|
||||
# from configs import HOST_ADDRESS, VERSION
|
||||
from libs.auth import RAGFlowWebApiAuth
|
||||
from pytest import FixtureRequest
|
||||
|
||||
# from ragflow_sdk import RAGFlow
|
||||
from utils import wait_for
|
||||
from utils.file_utils import (
|
||||
create_docx_file,
|
||||
create_eml_file,
|
||||
@ -39,6 +44,15 @@ from utils.file_utils import (
|
||||
)
|
||||
|
||||
|
||||
@wait_for(30, 1, "Document parsing timeout")
def condition(_auth, _kb_id):
    """Poll until every document in the KB reports run == "3" (parsed)."""
    listing = list_documents(_auth, {"kb_id": _kb_id})
    return all(doc["run"] == "3" for doc in listing["data"]["docs"])
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def generate_test_files(request: FixtureRequest, tmp_path):
|
||||
file_creators = {
|
||||
@ -73,11 +87,6 @@ def WebApiAuth(auth):
|
||||
return RAGFlowWebApiAuth(auth)
|
||||
|
||||
|
||||
# @pytest.fixture(scope="session")
|
||||
# def client(token: str) -> RAGFlow:
|
||||
# return RAGFlow(api_key=token, base_url=HOST_ADDRESS, version=VERSION)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def clear_datasets(request: FixtureRequest, WebApiAuth: RAGFlowWebApiAuth):
|
||||
def cleanup():
|
||||
@ -108,3 +117,35 @@ def add_dataset_func(request: FixtureRequest, WebApiAuth: RAGFlowWebApiAuth) ->
|
||||
|
||||
request.addfinalizer(cleanup)
|
||||
return batch_create_datasets(WebApiAuth, 1)[0]
|
||||
|
||||
|
||||
@pytest.fixture(scope="class")
def add_document(request, WebApiAuth, add_dataset, ragflow_tmp_dir):
    """Upload a single document into the class-scoped dataset.

    Returns (kb_id, document_id). No per-document finalizer is registered;
    cleanup is left to the dataset-level fixture teardown.
    """
    kb_id = add_dataset
    document_ids = bulk_upload_documents(WebApiAuth, kb_id, 1, ragflow_tmp_dir)
    return kb_id, document_ids[0]
|
||||
|
||||
|
||||
@pytest.fixture(scope="class")
def add_chunks(request, WebApiAuth, add_document):
    """Parse the uploaded document and seed it with 4 test chunks.

    Returns (kb_id, document_id, chunk_ids). A finalizer removes whatever
    chunks are still attached to the document when the class finishes.
    """
    kb_id, document_id = add_document

    def cleanup():
        listing = list_chunks(WebApiAuth, {"doc_id": document_id})
        # Only attempt deletion when the listing succeeded; otherwise the
        # document may already be gone and there is nothing to remove.
        if listing["code"] == 0:
            ids = [chunk["chunk_id"] for chunk in listing["data"]["chunks"]]
            delete_chunks(WebApiAuth, {"doc_id": document_id, "chunk_ids": ids})

    request.addfinalizer(cleanup)

    parse_documents(WebApiAuth, {"doc_ids": [document_id], "run": "1"})
    condition(WebApiAuth, kb_id)
    chunk_ids = batch_add_chunks(WebApiAuth, document_id, 4)
    # issues/6487: give the doc engine a moment to index the new chunks
    sleep(1)
    return kb_id, document_id, chunk_ids
|
||||
|
||||
49
test/testcases/test_web_api/test_chunk_app/conftest.py
Normal file
49
test/testcases/test_web_api/test_chunk_app/conftest.py
Normal file
@ -0,0 +1,49 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
|
||||
from time import sleep
|
||||
|
||||
import pytest
|
||||
from common import batch_add_chunks, delete_chunks, list_chunks, list_documents, parse_documents
|
||||
from utils import wait_for
|
||||
|
||||
|
||||
@wait_for(30, 1, "Document parsing timeout")
def condition(_auth, _kb_id):
    """Poll until every document in the KB reports run == "3" (parsed)."""
    listing = list_documents(_auth, {"kb_id": _kb_id})
    return all(doc["run"] == "3" for doc in listing["data"]["docs"])
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
def add_chunks_func(request, WebApiAuth, add_document):
    """Function-scoped variant of add_chunks: parse the document and seed 4 chunks.

    Returns (kb_id, document_id, chunk_ids) and registers a finalizer that
    deletes the chunks created during the test.
    """

    def cleanup():
        res = list_chunks(WebApiAuth, {"doc_id": document_id})
        # Guard on the response code before touching res["data"]: if the
        # listing fails (e.g. the document was deleted by the test) the
        # payload has no "data" key and the finalizer would raise KeyError.
        # This matches the class-scoped add_chunks fixture's cleanup.
        if res["code"] == 0:
            chunk_ids = [chunk["chunk_id"] for chunk in res["data"]["chunks"]]
            delete_chunks(WebApiAuth, {"doc_id": document_id, "chunk_ids": chunk_ids})

    request.addfinalizer(cleanup)

    kb_id, document_id = add_document
    parse_documents(WebApiAuth, {"doc_ids": [document_id], "run": "1"})
    condition(WebApiAuth, kb_id)
    chunk_ids = batch_add_chunks(WebApiAuth, document_id, 4)
    # issues/6487: give the doc engine a moment to index the new chunks
    sleep(1)
    return kb_id, document_id, chunk_ids
|
||||
223
test/testcases/test_web_api/test_chunk_app/test_create_chunk.py
Normal file
223
test/testcases/test_web_api/test_chunk_app/test_create_chunk.py
Normal file
@ -0,0 +1,223 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
import pytest
|
||||
from common import add_chunk, delete_document, get_chunk, list_chunks
|
||||
from configs import INVALID_API_TOKEN
|
||||
from libs.auth import RAGFlowWebApiAuth
|
||||
|
||||
|
||||
def validate_chunk_details(auth, kb_id, doc_id, payload, res):
    """Fetch the chunk just created (per *res*) and assert it matches *payload*."""
    fetched = get_chunk(auth, {"chunk_id": res["data"]["chunk_id"]})
    assert fetched["code"] == 0, fetched

    chunk = fetched["data"]
    assert chunk["doc_id"] == doc_id
    assert chunk["kb_id"] == kb_id
    assert chunk["content_with_weight"] == payload["content_with_weight"]
    if "important_kwd" in payload:
        assert chunk["important_kwd"] == payload["important_kwd"]
    if "question_kwd" in payload:
        # Question keywords are expected back stringified and stripped.
        stripped = [str(question).strip() for question in payload.get("question_kwd", [])]
        assert chunk["question_kwd"] == stripped
|
||||
|
||||
|
||||
@pytest.mark.p1
class TestAuthorization:
    """Requests without valid credentials must be rejected with HTTP 401."""

    _UNAUTHORIZED = "<Unauthorized '401: Unauthorized'>"

    @pytest.mark.parametrize(
        "invalid_auth, expected_code, expected_message",
        [
            (None, 401, _UNAUTHORIZED),
            (RAGFlowWebApiAuth(INVALID_API_TOKEN), 401, _UNAUTHORIZED),
        ],
    )
    def test_invalid_auth(self, invalid_auth, expected_code, expected_message):
        res = add_chunk(invalid_auth)
        assert res["code"] == expected_code, res
        assert res["message"] == expected_message, res
|
||||
|
||||
|
||||
class TestAddChunk:
    """End-to-end tests for the chunk create endpoint."""

    @staticmethod
    def _chunk_count(auth, doc_id):
        # Current chunk_num for the document, or 0 when listing fails
        # (e.g. the document has never been parsed yet).
        res = list_chunks(auth, {"doc_id": doc_id})
        return res["data"]["doc"]["chunk_num"] if res["code"] == 0 else 0

    def _add_and_verify(self, auth, kb_id, doc_id, payload, expected_code, expected_message):
        # Shared body for the parametrized field tests: add a chunk, then
        # either validate its details and the incremented count, or check
        # the error message.
        count_before = self._chunk_count(auth, doc_id)
        res = add_chunk(auth, {**payload, "doc_id": doc_id})
        assert res["code"] == expected_code, res
        if expected_code == 0:
            validate_chunk_details(auth, kb_id, doc_id, payload, res)
            listing = list_chunks(auth, {"doc_id": doc_id})
            assert listing["code"] == 0, listing
            assert listing["data"]["doc"]["chunk_num"] == count_before + 1, listing
        else:
            assert res["message"] == expected_message, res

    @pytest.mark.p1
    @pytest.mark.parametrize(
        "payload, expected_code, expected_message",
        [
            ({"content_with_weight": None}, 100, """TypeError("unsupported operand type(s) for +: 'NoneType' and 'str'")"""),
            ({"content_with_weight": ""}, 0, ""),
            pytest.param(
                {"content_with_weight": 1},
                100,
                """TypeError("unsupported operand type(s) for +: 'int' and 'str'")""",
                marks=pytest.mark.skip,
            ),
            ({"content_with_weight": "a"}, 0, ""),
            ({"content_with_weight": " "}, 0, ""),
            ({"content_with_weight": "\n!?。;!?\"'"}, 0, ""),
        ],
    )
    def test_content(self, WebApiAuth, add_document, payload, expected_code, expected_message):
        kb_id, doc_id = add_document
        self._add_and_verify(WebApiAuth, kb_id, doc_id, payload, expected_code, expected_message)

    @pytest.mark.p2
    @pytest.mark.parametrize(
        "payload, expected_code, expected_message",
        [
            ({"content_with_weight": "chunk test", "important_kwd": ["a", "b", "c"]}, 0, ""),
            ({"content_with_weight": "chunk test", "important_kwd": [""]}, 0, ""),
            (
                {"content_with_weight": "chunk test", "important_kwd": [1]},
                100,
                "TypeError('sequence item 0: expected str instance, int found')",
            ),
            ({"content_with_weight": "chunk test", "important_kwd": ["a", "a"]}, 0, ""),
            ({"content_with_weight": "chunk test", "important_kwd": "abc"}, 102, "`important_kwd` is required to be a list"),
            ({"content_with_weight": "chunk test", "important_kwd": 123}, 102, "`important_kwd` is required to be a list"),
        ],
    )
    def test_important_keywords(self, WebApiAuth, add_document, payload, expected_code, expected_message):
        kb_id, doc_id = add_document
        self._add_and_verify(WebApiAuth, kb_id, doc_id, payload, expected_code, expected_message)

    @pytest.mark.p2
    @pytest.mark.parametrize(
        "payload, expected_code, expected_message",
        [
            ({"content_with_weight": "chunk test", "question_kwd": ["a", "b", "c"]}, 0, ""),
            ({"content_with_weight": "chunk test", "question_kwd": [""]}, 0, ""),
            ({"content_with_weight": "chunk test", "question_kwd": [1]}, 100, "TypeError('sequence item 0: expected str instance, int found')"),
            ({"content_with_weight": "chunk test", "question_kwd": ["a", "a"]}, 0, ""),
            ({"content_with_weight": "chunk test", "question_kwd": "abc"}, 102, "`question_kwd` is required to be a list"),
            ({"content_with_weight": "chunk test", "question_kwd": 123}, 102, "`question_kwd` is required to be a list"),
        ],
    )
    def test_questions(self, WebApiAuth, add_document, payload, expected_code, expected_message):
        kb_id, doc_id = add_document
        self._add_and_verify(WebApiAuth, kb_id, doc_id, payload, expected_code, expected_message)

    @pytest.mark.p3
    @pytest.mark.parametrize(
        "doc_id, expected_code, expected_message",
        [
            ("", 102, "Document not found!"),
            ("invalid_document_id", 102, "Document not found!"),
        ],
    )
    def test_invalid_document_id(self, WebApiAuth, add_document, doc_id, expected_code, expected_message):
        _, _ = add_document  # fixture is only needed for its setup side effects
        res = add_chunk(WebApiAuth, {"doc_id": doc_id, "content_with_weight": "chunk test"})
        assert res["code"] == expected_code, res
        assert res["message"] == expected_message, res

    @pytest.mark.p3
    def test_repeated_add_chunk(self, WebApiAuth, add_document):
        """Adding identical content twice yields two distinct chunks."""
        payload = {"content_with_weight": "chunk test"}
        kb_id, doc_id = add_document
        res = list_chunks(WebApiAuth, {"doc_id": doc_id})
        assert res["code"] == 0, res
        count_before = res["data"]["doc"]["chunk_num"]

        for repetition in (1, 2):
            res = add_chunk(WebApiAuth, {**payload, "doc_id": doc_id})
            assert res["code"] == 0, res
            validate_chunk_details(WebApiAuth, kb_id, doc_id, payload, res)
            listing = list_chunks(WebApiAuth, {"doc_id": doc_id})
            assert listing["code"] == 0, listing
            assert listing["data"]["doc"]["chunk_num"] == count_before + repetition, listing

    @pytest.mark.p2
    def test_add_chunk_to_deleted_document(self, WebApiAuth, add_document):
        _, doc_id = add_document
        delete_document(WebApiAuth, {"doc_id": doc_id})
        res = add_chunk(WebApiAuth, {"doc_id": doc_id, "content_with_weight": "chunk test"})
        assert res["code"] == 102, res
        assert res["message"] == "Document not found!", res

    @pytest.mark.skip(reason="issues/6411")
    @pytest.mark.p3
    def test_concurrent_add_chunk(self, WebApiAuth, add_document):
        count = 50
        _, doc_id = add_document
        count_before = self._chunk_count(WebApiAuth, doc_id)

        with ThreadPoolExecutor(max_workers=5) as executor:
            futures = [
                executor.submit(
                    add_chunk,
                    WebApiAuth,
                    {"doc_id": doc_id, "content_with_weight": f"chunk test {i}"},
                )
                for i in range(count)
            ]
            responses = list(as_completed(futures))
        assert len(responses) == count, responses
        assert all(future.result()["code"] == 0 for future in futures)
        res = list_chunks(WebApiAuth, {"doc_id": doc_id})
        assert res["code"] == 0, res
        assert res["data"]["doc"]["chunk_num"] == count_before + count
|
||||
145
test/testcases/test_web_api/test_chunk_app/test_list_chunks.py
Normal file
145
test/testcases/test_web_api/test_chunk_app/test_list_chunks.py
Normal file
@ -0,0 +1,145 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import os
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
import pytest
|
||||
from common import batch_add_chunks, list_chunks
|
||||
from configs import INVALID_API_TOKEN
|
||||
from libs.auth import RAGFlowWebApiAuth
|
||||
|
||||
|
||||
@pytest.mark.p1
class TestAuthorization:
    """Requests without valid credentials must be rejected with HTTP 401."""

    _UNAUTHORIZED = "<Unauthorized '401: Unauthorized'>"

    @pytest.mark.parametrize(
        "invalid_auth, expected_code, expected_message",
        [
            (None, 401, _UNAUTHORIZED),
            (RAGFlowWebApiAuth(INVALID_API_TOKEN), 401, _UNAUTHORIZED),
        ],
    )
    def test_invalid_auth(self, invalid_auth, expected_code, expected_message):
        res = list_chunks(invalid_auth, {"doc_id": "document_id"})
        assert res["code"] == expected_code, res
        assert res["message"] == expected_message, res
|
||||
|
||||
|
||||
class TestChunksList:
    """Pagination, keyword filtering, and concurrency tests for /chunk/list."""

    @staticmethod
    def _list_with(auth, doc_id, params):
        # Merge optional request params into the base payload; a falsy
        # *params* (None or empty dict) leaves the payload untouched.
        return list_chunks(auth, {"doc_id": doc_id, **(params or {})})

    @pytest.mark.p1
    @pytest.mark.parametrize(
        "params, expected_code, expected_page_size, expected_message",
        [
            pytest.param({"page": None, "size": 2}, 100, 0, """TypeError("int() argument must be a string, a bytes-like object or a real number, not 'NoneType'")""", marks=pytest.mark.skip),
            pytest.param({"page": 0, "size": 2}, 100, 0, "ValueError('Search does not support negative slicing.')", marks=pytest.mark.skip),
            ({"page": 2, "size": 2}, 0, 2, ""),
            ({"page": 3, "size": 2}, 0, 1, ""),
            ({"page": "3", "size": 2}, 0, 1, ""),
            pytest.param({"page": -1, "size": 2}, 100, 0, "ValueError('Search does not support negative slicing.')", marks=pytest.mark.skip),
            pytest.param({"page": "a", "size": 2}, 100, 0, """ValueError("invalid literal for int() with base 10: \'a\'")""", marks=pytest.mark.skip),
        ],
    )
    def test_page(self, WebApiAuth, add_chunks, params, expected_code, expected_page_size, expected_message):
        _, doc_id, _ = add_chunks
        res = self._list_with(WebApiAuth, doc_id, params)
        assert res["code"] == expected_code, res
        if expected_code == 0:
            assert len(res["data"]["chunks"]) == expected_page_size, res
        else:
            assert res["message"] == expected_message, res

    @pytest.mark.p1
    @pytest.mark.parametrize(
        "params, expected_code, expected_page_size, expected_message",
        [
            ({"size": None}, 100, 0, """TypeError("int() argument must be a string, a bytes-like object or a real number, not 'NoneType'")"""),
            pytest.param({"size": 0}, 0, 5, "", marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="Infinity does not support page_size=0")),
            pytest.param({"size": 0}, 100, 0, "3013", marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in [None, "opensearch", "elasticsearch"], reason="Infinity does not support page_size=0")),
            ({"size": 1}, 0, 1, ""),
            ({"size": 6}, 0, 5, ""),
            ({"size": "1"}, 0, 1, ""),
            pytest.param({"size": -1}, 0, 5, "", marks=pytest.mark.skip),
            pytest.param({"size": "a"}, 100, 0, """ValueError("invalid literal for int() with base 10: \'a\'")""", marks=pytest.mark.skip),
        ],
    )
    def test_page_size(self, WebApiAuth, add_chunks, params, expected_code, expected_page_size, expected_message):
        _, doc_id, _ = add_chunks
        res = self._list_with(WebApiAuth, doc_id, params)
        assert res["code"] == expected_code, res
        if expected_code == 0:
            assert len(res["data"]["chunks"]) == expected_page_size, res
        else:
            assert res["message"] == expected_message, res

    @pytest.mark.p2
    @pytest.mark.parametrize(
        "params, expected_page_size",
        [
            ({"keywords": None}, 5),
            ({"keywords": ""}, 5),
            ({"keywords": "1"}, 1),
            pytest.param({"keywords": "chunk"}, 4, marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="issues/6509")),
            ({"keywords": "content"}, 1),
            ({"keywords": "unknown"}, 0),
        ],
    )
    def test_keywords(self, WebApiAuth, add_chunks, params, expected_page_size):
        _, doc_id, _ = add_chunks
        res = self._list_with(WebApiAuth, doc_id, params)
        assert res["code"] == 0, res
        assert len(res["data"]["chunks"]) == expected_page_size, res

    @pytest.mark.p3
    def test_invalid_params(self, WebApiAuth, add_chunks):
        """Unknown payload keys are silently ignored by the endpoint."""
        _, doc_id, _ = add_chunks
        res = list_chunks(WebApiAuth, {"doc_id": doc_id, "a": "b"})
        assert res["code"] == 0, res
        assert len(res["data"]["chunks"]) == 5, res

    @pytest.mark.p3
    def test_concurrent_list(self, WebApiAuth, add_chunks):
        _, doc_id, _ = add_chunks
        count = 100
        with ThreadPoolExecutor(max_workers=5) as executor:
            futures = [executor.submit(list_chunks, WebApiAuth, {"doc_id": doc_id}) for _ in range(count)]
            responses = list(as_completed(futures))
        assert len(responses) == count, responses
        assert all(len(future.result()["data"]["chunks"]) == 5 for future in futures)

    @pytest.mark.p1
    def test_default(self, WebApiAuth, add_document):
        """Without paging params the endpoint caps a page at 30 chunks."""
        _, doc_id = add_document

        res = list_chunks(WebApiAuth, {"doc_id": doc_id})
        count_before = res["data"]["doc"]["chunk_num"]
        batch_add_chunks(WebApiAuth, doc_id, 31)
        # issues/6487: indexing is asynchronous, so wait before re-listing
        from time import sleep

        sleep(3)
        res = list_chunks(WebApiAuth, {"doc_id": doc_id})
        assert res["code"] == 0
        assert len(res["data"]["chunks"]) == 30
        assert res["data"]["doc"]["chunk_num"] == count_before + 31
|
||||
@ -0,0 +1,308 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import os
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
import pytest
|
||||
from common import retrieval_chunks
|
||||
from configs import INVALID_API_TOKEN
|
||||
from libs.auth import RAGFlowWebApiAuth
|
||||
|
||||
|
||||
@pytest.mark.p1
class TestAuthorization:
    """Requests without valid credentials must be rejected with HTTP 401."""

    _UNAUTHORIZED = "<Unauthorized '401: Unauthorized'>"

    @pytest.mark.parametrize(
        "invalid_auth, expected_code, expected_message",
        [
            (None, 401, _UNAUTHORIZED),
            (RAGFlowWebApiAuth(INVALID_API_TOKEN), 401, _UNAUTHORIZED),
        ],
    )
    def test_invalid_auth(self, invalid_auth, expected_code, expected_message):
        res = retrieval_chunks(invalid_auth, {"kb_id": "dummy_kb_id", "question": "dummy question"})
        assert res["code"] == expected_code, res
        assert res["message"] == expected_message, res
|
||||
|
||||
|
||||
class TestChunksRetrieval:
|
||||
@pytest.mark.p1
|
||||
@pytest.mark.parametrize(
|
||||
"payload, expected_code, expected_page_size, expected_message",
|
||||
[
|
||||
({"question": "chunk", "kb_id": None}, 0, 4, ""),
|
||||
({"question": "chunk", "doc_ids": None}, 101, 0, "required argument are missing: kb_id; "),
|
||||
({"question": "chunk", "kb_id": None, "doc_ids": None}, 0, 4, ""),
|
||||
({"question": "chunk"}, 101, 0, "required argument are missing: kb_id; "),
|
||||
],
|
||||
)
|
||||
def test_basic_scenarios(self, WebApiAuth, add_chunks, payload, expected_code, expected_page_size, expected_message):
|
||||
dataset_id, document_id, _ = add_chunks
|
||||
if "kb_id" in payload:
|
||||
payload["kb_id"] = [dataset_id]
|
||||
if "doc_ids" in payload:
|
||||
payload["doc_ids"] = [document_id]
|
||||
res = retrieval_chunks(WebApiAuth, payload)
|
||||
assert res["code"] == expected_code, res
|
||||
if expected_code == 0:
|
||||
assert len(res["data"]["chunks"]) == expected_page_size, res
|
||||
else:
|
||||
assert res["message"] == expected_message, res
|
||||
|
||||
@pytest.mark.p2
@pytest.mark.parametrize(
    "payload, expected_code, expected_page_size, expected_message",
    [
        pytest.param({"page": None, "size": 2}, 100, 0, """TypeError("int() argument must be a string, a bytes-like object or a real number, not 'NoneType'")""", marks=pytest.mark.skip),
        pytest.param({"page": 0, "size": 2}, 100, 0, "ValueError('Search does not support negative slicing.')", marks=pytest.mark.skip),
        pytest.param({"page": 2, "size": 2}, 0, 2, "", marks=pytest.mark.skip(reason="issues/6646")),
        ({"page": 3, "size": 2}, 0, 0, ""),
        ({"page": "3", "size": 2}, 0, 0, ""),
        pytest.param({"page": -1, "size": 2}, 100, 0, "ValueError('Search does not support negative slicing.')", marks=pytest.mark.skip),
        pytest.param({"page": "a", "size": 2}, 100, 0, """ValueError("invalid literal for int() with base 10: 'a'")""", marks=pytest.mark.skip),
    ],
)
def test_page(self, WebApiAuth, add_chunks, payload, expected_code, expected_page_size, expected_message):
    """Page navigation: pages beyond the data return an empty slice; string page numbers are coerced."""
    dataset_id = add_chunks[0]
    payload.update({"question": "chunk", "kb_id": [dataset_id]})
    res = retrieval_chunks(WebApiAuth, payload)
    assert res["code"] == expected_code, res
    if expected_code != 0:
        assert res["message"] == expected_message, res
    else:
        assert len(res["data"]["chunks"]) == expected_page_size, res
||||
|
||||
@pytest.mark.p3
@pytest.mark.parametrize(
    "payload, expected_code, expected_page_size, expected_message",
    [
        pytest.param({"size": None}, 100, 0, """TypeError("int() argument must be a string, a bytes-like object or a real number, not 'NoneType'")""", marks=pytest.mark.skip),
        # ({"size": 0}, 0, 0, ""),
        ({"size": 1}, 0, 1, ""),
        ({"size": 5}, 0, 4, ""),
        ({"size": "1"}, 0, 1, ""),
        # ({"size": -1}, 0, 0, ""),
        pytest.param({"size": "a"}, 100, 0, """ValueError("invalid literal for int() with base 10: 'a'")""", marks=pytest.mark.skip),
    ],
)
def test_page_size(self, WebApiAuth, add_chunks, payload, expected_code, expected_page_size, expected_message):
    """Page-size capping: a size larger than the corpus returns every match (4); strings are coerced."""
    dataset_id = add_chunks[0]
    payload.update({"question": "chunk", "kb_id": [dataset_id]})
    res = retrieval_chunks(WebApiAuth, payload)
    assert res["code"] == expected_code, res
    if expected_code != 0:
        assert res["message"] == expected_message, res
    else:
        assert len(res["data"]["chunks"]) == expected_page_size, res
||||
|
||||
@pytest.mark.p3
@pytest.mark.parametrize(
    "payload, expected_code, expected_page_size, expected_message",
    [
        ({"vector_similarity_weight": 0}, 0, 4, ""),
        ({"vector_similarity_weight": 0.5}, 0, 4, ""),
        ({"vector_similarity_weight": 10}, 0, 4, ""),
        pytest.param({"vector_similarity_weight": "a"}, 100, 0, """ValueError("could not convert string to float: 'a'")""", marks=pytest.mark.skip),
    ],
)
def test_vector_similarity_weight(self, WebApiAuth, add_chunks, payload, expected_code, expected_page_size, expected_message):
    """Any numeric similarity weight (even out-of-range) still yields the full result set."""
    dataset_id = add_chunks[0]
    payload.update({"question": "chunk", "kb_id": [dataset_id]})
    res = retrieval_chunks(WebApiAuth, payload)
    assert res["code"] == expected_code, res
    if expected_code != 0:
        assert res["message"] == expected_message, res
    else:
        assert len(res["data"]["chunks"]) == expected_page_size, res
||||
|
||||
@pytest.mark.p2
@pytest.mark.parametrize(
    "payload, expected_code, expected_page_size, expected_message",
    [
        ({"top_k": 10}, 0, 4, ""),
        # top_k handling differs per engine, so each expectation is gated on DOC_ENGINE.
        pytest.param({"top_k": 1}, 0, 4, "", marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in ["infinity", "opensearch"], reason="Infinity")),
        pytest.param({"top_k": 1}, 0, 1, "", marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in [None, "opensearch", "elasticsearch"], reason="elasticsearch")),
        pytest.param({"top_k": -1}, 100, 4, "must be greater than 0", marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in ["infinity", "opensearch"], reason="Infinity")),
        pytest.param({"top_k": -1}, 100, 4, "3014", marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in [None, "opensearch", "elasticsearch"], reason="elasticsearch")),
        pytest.param({"top_k": "a"}, 100, 0, """ValueError("invalid literal for int() with base 10: 'a'")""", marks=pytest.mark.skip),
    ],
)
def test_top_k(self, WebApiAuth, add_chunks, payload, expected_code, expected_page_size, expected_message):
    """top_k validation and truncation; expected errors are substring-matched since wording varies by engine."""
    dataset_id = add_chunks[0]
    payload.update({"question": "chunk", "kb_id": [dataset_id]})
    res = retrieval_chunks(WebApiAuth, payload)
    assert res["code"] == expected_code, res
    if expected_code != 0:
        assert expected_message in res["message"], res
    else:
        assert len(res["data"]["chunks"]) == expected_page_size, res
||||
|
||||
@pytest.mark.skip
@pytest.mark.parametrize(
    "payload, expected_code, expected_message",
    [
        ({"rerank_id": "BAAI/bge-reranker-v2-m3"}, 0, ""),
        pytest.param({"rerank_id": "unknown"}, 100, "LookupError('Model(unknown) not authorized')", marks=pytest.mark.skip),
    ],
)
def test_rerank_id(self, WebApiAuth, add_chunks, payload, expected_code, expected_message):
    """A valid rerank model returns hits; an unknown model id is rejected (suite currently skipped)."""
    dataset_id = add_chunks[0]
    payload.update({"question": "chunk", "kb_id": [dataset_id]})
    res = retrieval_chunks(WebApiAuth, payload)
    assert res["code"] == expected_code, res
    if expected_code != 0:
        assert expected_message in res["message"], res
    else:
        assert len(res["data"]["chunks"]) > 0, res
||||
|
||||
@pytest.mark.skip
@pytest.mark.parametrize(
    "payload, expected_code, expected_page_size, expected_message",
    [
        ({"keyword": True}, 0, 5, ""),
        ({"keyword": "True"}, 0, 5, ""),
        ({"keyword": False}, 0, 5, ""),
        ({"keyword": "False"}, 0, 5, ""),
        ({"keyword": None}, 0, 5, ""),
    ],
)
def test_keyword(self, WebApiAuth, add_chunks, payload, expected_code, expected_page_size, expected_message):
    """Keyword-extraction toggle: result count is expected to be identical regardless of the flag."""
    dataset_id = add_chunks[0]
    payload.update({"question": "chunk test", "kb_id": [dataset_id]})
    res = retrieval_chunks(WebApiAuth, payload)
    assert res["code"] == expected_code, res
    if expected_code != 0:
        assert res["message"] == expected_message, res
    else:
        assert len(res["data"]["chunks"]) == expected_page_size, res
||||
|
||||
@pytest.mark.p3
@pytest.mark.parametrize(
    "payload, expected_code, expected_highlight, expected_message",
    [
        ({"highlight": True}, 0, True, ""),
        ({"highlight": "True"}, 0, True, ""),
        pytest.param({"highlight": False}, 0, False, "", marks=pytest.mark.skip(reason="issues/6648")),
        pytest.param({"highlight": "False"}, 0, False, "", marks=pytest.mark.skip(reason="issues/6648")),
        pytest.param({"highlight": None}, 0, False, "", marks=pytest.mark.skip(reason="issues/6648")),
    ],
)
def test_highlight(self, WebApiAuth, add_chunks, payload, expected_code, expected_highlight, expected_message):
    """The highlight flag toggles the presence of a `highlight` field on every returned chunk."""
    dataset_id = add_chunks[0]
    payload.update({"question": "chunk", "kb_id": [dataset_id]})
    res = retrieval_chunks(WebApiAuth, payload)
    assert res["code"] == expected_code, res
    # Single pass over the chunks; the expectation decides which membership check applies.
    for chunk in res["data"]["chunks"]:
        if expected_highlight:
            assert "highlight" in chunk, res
        else:
            assert "highlight" not in chunk, res

    if expected_code != 0:
        assert res["message"] == expected_message, res
||||
|
||||
@pytest.mark.p3
def test_invalid_params(self, WebApiAuth, add_chunks):
    """Unknown request keys are silently ignored and do not affect the result set."""
    dataset_id = add_chunks[0]
    request_body = {"question": "chunk", "kb_id": [dataset_id], "a": "b"}
    res = retrieval_chunks(WebApiAuth, request_body)
    assert res["code"] == 0, res
    assert len(res["data"]["chunks"]) == 4, res
||||
|
||||
@pytest.mark.p3
def test_concurrent_retrieval(self, WebApiAuth, add_chunks):
    """Fire 100 identical retrieval requests from a small thread pool; all must succeed."""
    dataset_id = add_chunks[0]
    count = 100
    payload = {"question": "chunk", "kb_id": [dataset_id]}

    with ThreadPoolExecutor(max_workers=5) as executor:
        futures = [executor.submit(retrieval_chunks, WebApiAuth, payload) for _ in range(count)]
    responses = list(as_completed(futures))
    assert len(responses) == count, responses
    assert all(future.result()["code"] == 0 for future in futures)
||||
161
test/testcases/test_web_api/test_chunk_app/test_rm_chunks.py
Normal file
161
test/testcases/test_web_api/test_chunk_app/test_rm_chunks.py
Normal file
@ -0,0 +1,161 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
import pytest
|
||||
from common import batch_add_chunks, delete_chunks, list_chunks
|
||||
from configs import INVALID_API_TOKEN
|
||||
from libs.auth import RAGFlowWebApiAuth
|
||||
|
||||
|
||||
@pytest.mark.p1
class TestAuthorization:
    """The delete-chunks endpoint must reject missing or invalid credentials."""

    @pytest.mark.parametrize(
        "invalid_auth, expected_code, expected_message",
        [
            (None, 401, "<Unauthorized '401: Unauthorized'>"),
            (RAGFlowWebApiAuth(INVALID_API_TOKEN), 401, "<Unauthorized '401: Unauthorized'>"),
        ],
    )
    def test_invalid_auth(self, invalid_auth, expected_code, expected_message):
        """Both no auth and a bogus token yield HTTP-401-shaped error payloads."""
        res = delete_chunks(invalid_auth, {"doc_id": "document_id", "chunk_ids": ["1"]})
        # `, res` added for failure context, consistent with the other test files' auth tests.
        assert res["code"] == expected_code, res
        assert res["message"] == expected_message, res
||||
|
||||
|
||||
class TestChunksDeletion:
    """Deletion semantics of the chunk-removal endpoint: invalid ids, batches, repeats, concurrency."""

    @pytest.mark.p3
    @pytest.mark.parametrize(
        "doc_id, expected_code, expected_message",
        [
            ("", 102, "Document not found!"),
            ("invalid_document_id", 102, "Document not found!"),
        ],
    )
    def test_invalid_document_id(self, WebApiAuth, add_chunks_func, doc_id, expected_code, expected_message):
        """Deleting chunks under a missing or unknown document id fails with code 102."""
        _, _, chunk_ids = add_chunks_func
        res = delete_chunks(WebApiAuth, {"doc_id": doc_id, "chunk_ids": chunk_ids})
        assert res["code"] == expected_code, res
        assert res["message"] == expected_message, res

    @pytest.mark.parametrize(
        "payload",
        [
            pytest.param(lambda r: {"chunk_ids": ["invalid_id"] + r}, marks=pytest.mark.p3),
            pytest.param(lambda r: {"chunk_ids": r[:1] + ["invalid_id"] + r[1:4]}, marks=pytest.mark.p1),
            pytest.param(lambda r: {"chunk_ids": r + ["invalid_id"]}, marks=pytest.mark.p3),
        ],
    )
    def test_delete_partial_invalid_id(self, WebApiAuth, add_chunks_func, payload):
        """An unknown id mixed into the batch must not prevent the valid chunks from being removed."""
        _, doc_id, chunk_ids = add_chunks_func
        if callable(payload):
            payload = payload(chunk_ids)  # parametrized lambdas build the id list from the fixture
        payload["doc_id"] = doc_id
        res = delete_chunks(WebApiAuth, payload)
        assert res["code"] == 0, res

        res = list_chunks(WebApiAuth, {"doc_id": doc_id})
        assert res["code"] == 0, res
        assert len(res["data"]["chunks"]) == 0, res
        assert res["data"]["total"] == 0, res

    @pytest.mark.p3
    def test_repeated_deletion(self, WebApiAuth, add_chunks_func):
        """Re-deleting already-deleted ids reports an index-update failure (102)."""
        _, doc_id, chunk_ids = add_chunks_func
        payload = {"chunk_ids": chunk_ids, "doc_id": doc_id}
        res = delete_chunks(WebApiAuth, payload)
        assert res["code"] == 0, res

        res = delete_chunks(WebApiAuth, payload)
        assert res["code"] == 102, res
        assert res["message"] == "Index updating failure", res

    @pytest.mark.p3
    def test_duplicate_deletion(self, WebApiAuth, add_chunks_func):
        """Duplicate ids within a single request are tolerated; all chunks end up removed."""
        _, doc_id, chunk_ids = add_chunks_func
        payload = {"chunk_ids": chunk_ids * 2, "doc_id": doc_id}
        res = delete_chunks(WebApiAuth, payload)
        assert res["code"] == 0, res

        res = list_chunks(WebApiAuth, {"doc_id": doc_id})
        assert res["code"] == 0, res
        assert len(res["data"]["chunks"]) == 0, res
        assert res["data"]["total"] == 0, res

    @pytest.mark.p3
    def test_concurrent_deletion(self, WebApiAuth, add_document):
        """One-chunk-per-request deletions issued concurrently must all succeed."""
        count = 100
        _, doc_id = add_document
        chunk_ids = batch_add_chunks(WebApiAuth, doc_id, count)

        with ThreadPoolExecutor(max_workers=5) as executor:
            futures = [
                executor.submit(
                    delete_chunks,
                    WebApiAuth,
                    {"doc_id": doc_id, "chunk_ids": chunk_ids[i : i + 1]},
                )
                for i in range(count)
            ]
        responses = list(as_completed(futures))
        assert len(responses) == count, responses
        assert all(future.result()["code"] == 0 for future in futures)

    @pytest.mark.p3
    def test_delete_1k(self, WebApiAuth, add_document):
        """Bulk-delete 1000 chunks with a single request, then verify the document is empty."""
        from time import sleep  # local import hoisted to method top; only this test needs it

        chunks_num = 1_000
        _, doc_id = add_document
        chunk_ids = batch_add_chunks(WebApiAuth, doc_id, chunks_num)
        sleep(1)  # give the search index time to absorb the new chunks before deleting

        res = delete_chunks(WebApiAuth, {"doc_id": doc_id, "chunk_ids": chunk_ids})
        assert res["code"] == 0, res

        res = list_chunks(WebApiAuth, {"doc_id": doc_id})
        # Direct assert replaces the former `if res["code"] != 0: assert False, res` anti-pattern.
        assert res["code"] == 0, res
        assert len(res["data"]["chunks"]) == 0, res
        assert res["data"]["total"] == 0, res

    @pytest.mark.parametrize(
        "payload, expected_code, expected_message, remaining",
        [
            pytest.param(None, 100, """TypeError("argument of type 'NoneType' is not iterable")""", 5, marks=pytest.mark.skip),
            pytest.param({"chunk_ids": ["invalid_id"]}, 102, "Index updating failure", 4, marks=pytest.mark.p3),
            pytest.param("not json", 100, """UnboundLocalError("local variable 'duplicate_messages' referenced before assignment")""", 5, marks=pytest.mark.skip(reason="pull/6376")),
            pytest.param(lambda r: {"chunk_ids": r[:1]}, 0, "", 3, marks=pytest.mark.p3),
            pytest.param(lambda r: {"chunk_ids": r}, 0, "", 0, marks=pytest.mark.p1),
            pytest.param({"chunk_ids": []}, 0, "", 0, marks=pytest.mark.p3),
        ],
    )
    def test_basic_scenarios(self, WebApiAuth, add_chunks_func, payload, expected_code, expected_message, remaining):
        """Matrix of payload shapes vs. expected status code and how many chunks must remain."""
        _, doc_id, chunk_ids = add_chunks_func
        if callable(payload):
            payload = payload(chunk_ids)
        payload["doc_id"] = doc_id
        res = delete_chunks(WebApiAuth, payload)
        assert res["code"] == expected_code, res
        if res["code"] != 0:
            assert res["message"] == expected_message, res

        res = list_chunks(WebApiAuth, {"doc_id": doc_id})
        # Direct assert replaces the former `if res["code"] != 0: assert False, res` anti-pattern.
        assert res["code"] == 0, res
        assert len(res["data"]["chunks"]) == remaining, res
        assert res["data"]["total"] == remaining, res
||||
232
test/testcases/test_web_api/test_chunk_app/test_update_chunk.py
Normal file
232
test/testcases/test_web_api/test_chunk_app/test_update_chunk.py
Normal file
@ -0,0 +1,232 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import os
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from random import randint
|
||||
from time import sleep
|
||||
|
||||
import pytest
|
||||
from common import delete_document, list_chunks, update_chunk
|
||||
from configs import INVALID_API_TOKEN
|
||||
from libs.auth import RAGFlowWebApiAuth
|
||||
|
||||
|
||||
@pytest.mark.p1
class TestAuthorization:
    """The update-chunk endpoint must reject missing or invalid credentials."""

    @pytest.mark.parametrize(
        "invalid_auth, expected_code, expected_message",
        [
            (None, 401, "<Unauthorized '401: Unauthorized'>"),
            (RAGFlowWebApiAuth(INVALID_API_TOKEN), 401, "<Unauthorized '401: Unauthorized'>"),
        ],
    )
    def test_invalid_auth(self, invalid_auth, expected_code, expected_message):
        """Both no auth and a bogus token yield HTTP-401-shaped error payloads."""
        payload = {"doc_id": "doc_id", "chunk_id": "chunk_id", "content_with_weight": "test"}
        res = update_chunk(invalid_auth, payload)
        assert res["code"] == expected_code, res
        assert res["message"] == expected_message, res
||||
|
||||
|
||||
class TestUpdateChunk:
    """Tests for the chunk-update endpoint: field validation, persistence, and edge cases.

    Each successful update is read back through `list_chunks` after a short
    `sleep(1)` so the search index has time to reflect the change.
    """

    @pytest.mark.p1
    @pytest.mark.parametrize(
        "payload, expected_code, expected_message",
        [
            ({"content_with_weight": None}, 100, "TypeError('expected string or bytes-like object')"),
            ({"content_with_weight": ""}, 0, ""),
            ({"content_with_weight": 1}, 100, "TypeError('expected string or bytes-like object')"),
            ({"content_with_weight": "update chunk"}, 0, ""),
            ({"content_with_weight": " "}, 0, ""),
            ({"content_with_weight": "\n!?。;!?\"'"}, 0, ""),
        ],
    )
    def test_content(self, WebApiAuth, add_chunks, payload, expected_code, expected_message):
        # content_with_weight must be a string; non-strings fail server-side with a TypeError message.
        _, doc_id, chunk_ids = add_chunks
        chunk_id = chunk_ids[0]
        update_payload = {"doc_id": doc_id, "chunk_id": chunk_id}
        if payload:
            update_payload.update(payload)
        res = update_chunk(WebApiAuth, update_payload)
        assert res["code"] == expected_code, res
        if expected_code != 0:
            assert res["message"] == expected_message, res
        else:
            sleep(1)  # let the index catch up before reading back
            res = list_chunks(WebApiAuth, {"doc_id": doc_id})
            for chunk in res["data"]["chunks"]:
                if chunk["chunk_id"] == chunk_id:
                    assert chunk["content_with_weight"] == payload["content_with_weight"]

    @pytest.mark.p2
    @pytest.mark.parametrize(
        "payload, expected_code, expected_message",
        [
            ({"important_kwd": ["a", "b", "c"]}, 0, ""),
            ({"important_kwd": [""]}, 0, ""),
            ({"important_kwd": [1]}, 100, "TypeError('sequence item 0: expected str instance, int found')"),
            ({"important_kwd": ["a", "a"]}, 0, ""),
            ({"important_kwd": "abc"}, 102, "`important_kwd` should be a list"),
            ({"important_kwd": 123}, 102, "`important_kwd` should be a list"),
        ],
    )
    def test_important_keywords(self, WebApiAuth, add_chunks, payload, expected_code, expected_message):
        # important_kwd must be a list; non-list values are rejected with code 102,
        # and non-string elements fail later with code 100.
        _, doc_id, chunk_ids = add_chunks
        chunk_id = chunk_ids[0]
        update_payload = {"doc_id": doc_id, "chunk_id": chunk_id, "content_with_weight": "unchanged content"}  # content_with_weight is required by the endpoint
        if payload:
            update_payload.update(payload)
        res = update_chunk(WebApiAuth, update_payload)
        assert res["code"] == expected_code, res
        if expected_code != 0:
            assert res["message"] == expected_message, res
        else:
            sleep(1)  # let the index catch up before reading back
            res = list_chunks(WebApiAuth, {"doc_id": doc_id})
            for chunk in res["data"]["chunks"]:
                if chunk["chunk_id"] == chunk_id:
                    assert chunk["important_kwd"] == payload["important_kwd"]

    @pytest.mark.p2
    @pytest.mark.parametrize(
        "payload, expected_code, expected_message",
        [
            ({"question_kwd": ["a", "b", "c"]}, 0, ""),
            ({"question_kwd": [""]}, 0, ""),
            ({"question_kwd": [1]}, 100, "TypeError('sequence item 0: expected str instance, int found')"),
            ({"question_kwd": ["a", "a"]}, 0, ""),
            ({"question_kwd": "abc"}, 102, "`question_kwd` should be a list"),
            ({"question_kwd": 123}, 102, "`question_kwd` should be a list"),
        ],
    )
    def test_questions(self, WebApiAuth, add_chunks, payload, expected_code, expected_message):
        # question_kwd mirrors important_kwd validation: must be a list of strings.
        _, doc_id, chunk_ids = add_chunks
        chunk_id = chunk_ids[0]
        update_payload = {"doc_id": doc_id, "chunk_id": chunk_id, "content_with_weight": "unchanged content"}  # content_with_weight is required by the endpoint
        if payload:
            update_payload.update(payload)

        res = update_chunk(WebApiAuth, update_payload)
        assert res["code"] == expected_code, res
        if expected_code != 0:
            assert res["message"] == expected_message, res
        else:
            sleep(1)  # let the index catch up before reading back
            res = list_chunks(WebApiAuth, {"doc_id": doc_id})
            for chunk in res["data"]["chunks"]:
                if chunk["chunk_id"] == chunk_id:
                    assert chunk["question_kwd"] == payload["question_kwd"]

    @pytest.mark.p2
    @pytest.mark.parametrize(
        "payload, expected_code, expected_message",
        [
            ({"available_int": 1}, 0, ""),
            ({"available_int": 0}, 0, ""),
        ],
    )
    def test_available(self, WebApiAuth, add_chunks, payload, expected_code, expected_message):
        # available_int (1 = enabled, 0 = disabled) round-trips through update + list.
        _, doc_id, chunk_ids = add_chunks
        chunk_id = chunk_ids[0]
        update_payload = {"doc_id": doc_id, "chunk_id": chunk_id, "content_with_weight": "unchanged content"}
        if payload:
            update_payload.update(payload)

        res = update_chunk(WebApiAuth, update_payload)
        assert res["code"] == expected_code, res
        if expected_code != 0:
            assert res["message"] == expected_message, res
        else:
            sleep(1)  # let the index catch up before reading back
            res = list_chunks(WebApiAuth, {"doc_id": doc_id})
            for chunk in res["data"]["chunks"]:
                if chunk["chunk_id"] == chunk_id:
                    assert chunk["available_int"] == payload["available_int"]

    @pytest.mark.p3
    @pytest.mark.parametrize(
        "doc_id_param, expected_code, expected_message",
        [
            ("", 102, "Tenant not found!"),
            ("invalid_doc_id", 102, "Tenant not found!"),
        ],
    )
    def test_invalid_document_id_for_update(self, WebApiAuth, add_chunks, doc_id_param, expected_code, expected_message):
        # Updating a chunk under a missing/unknown document fails tenant resolution (102).
        _, _, chunk_ids = add_chunks
        chunk_id = chunk_ids[0]

        payload = {"doc_id": doc_id_param, "chunk_id": chunk_id, "content_with_weight": "test content"}
        res = update_chunk(WebApiAuth, payload)
        assert res["code"] == expected_code
        assert expected_message in res["message"]

    @pytest.mark.p3
    def test_repeated_update_chunk(self, WebApiAuth, add_chunks):
        # Two consecutive updates to the same chunk must both succeed (no stale-state error).
        _, doc_id, chunk_ids = add_chunks
        payload1 = {"doc_id": doc_id, "chunk_id": chunk_ids[0], "content_with_weight": "chunk test 1"}
        res = update_chunk(WebApiAuth, payload1)
        assert res["code"] == 0

        payload2 = {"doc_id": doc_id, "chunk_id": chunk_ids[0], "content_with_weight": "chunk test 2"}
        res = update_chunk(WebApiAuth, payload2)
        assert res["code"] == 0

    @pytest.mark.p3
    @pytest.mark.parametrize(
        "payload, expected_code, expected_message",
        [
            ({"unknown_key": "unknown_value"}, 0, ""),
            ({}, 0, ""),
            pytest.param(None, 100, """TypeError("int() argument must be a string, a bytes-like object or a real number, not 'NoneType'")""", marks=pytest.mark.skip),
        ],
    )
    def test_invalid_params(self, WebApiAuth, add_chunks, payload, expected_code, expected_message):
        # Unknown extra keys in the request body are ignored by the endpoint.
        _, doc_id, chunk_ids = add_chunks
        chunk_id = chunk_ids[0]
        update_payload = {"doc_id": doc_id, "chunk_id": chunk_id, "content_with_weight": "unchanged content"}
        if payload is not None:
            update_payload.update(payload)

        res = update_chunk(WebApiAuth, update_payload)
        assert res["code"] == expected_code, res
        if expected_code != 0:
            assert res["message"] == expected_message, res

    @pytest.mark.p3
    @pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="issues/6554")
    def test_concurrent_update_chunk(self, WebApiAuth, add_chunks):
        # 50 concurrent updates over 4 chunks (randomly targeted) must all succeed.
        count = 50
        _, doc_id, chunk_ids = add_chunks

        with ThreadPoolExecutor(max_workers=5) as executor:
            futures = [
                executor.submit(
                    update_chunk,
                    WebApiAuth,
                    {"doc_id": doc_id, "chunk_id": chunk_ids[randint(0, 3)], "content_with_weight": f"update chunk test {i}"},
                )
                for i in range(count)
            ]
        responses = list(as_completed(futures))
        assert len(responses) == count, responses
        assert all(future.result()["code"] == 0 for future in futures)

    @pytest.mark.p3
    def test_update_chunk_to_deleted_document(self, WebApiAuth, add_chunks):
        # Once the parent document is deleted, updating its chunks fails tenant lookup (102).
        _, doc_id, chunk_ids = add_chunks
        delete_document(WebApiAuth, {"doc_id": doc_id})
        payload = {"doc_id": doc_id, "chunk_id": chunk_ids[0], "content_with_weight": "test content"}
        res = update_chunk(WebApiAuth, payload)
        assert res["code"] == 102, res
        assert res["message"] == "Tenant not found!", res
||||
Reference in New Issue
Block a user