Test: Add document app tests (#8456)

### What problem does this PR solve?

- Add a new test suite for the document app, covering create/list/parse/upload/remove.
- Update API URLs to use the version variable from config in the HTTP and web API tests.

### Type of change

- [x] Add test cases
@@ -13,22 +13,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #

 from pathlib import Path

 import requests
-from configs import HOST_ADDRESS
+from configs import HOST_ADDRESS, VERSION
 from requests_toolbelt import MultipartEncoder
 from utils.file_utils import create_txt_file

 HEADERS = {"Content-Type": "application/json"}
-DATASETS_API_URL = "/api/v1/datasets"
-FILE_API_URL = "/api/v1/datasets/{dataset_id}/documents"
-FILE_CHUNK_API_URL = "/api/v1/datasets/{dataset_id}/chunks"
-CHUNK_API_URL = "/api/v1/datasets/{dataset_id}/documents/{document_id}/chunks"
-CHAT_ASSISTANT_API_URL = "/api/v1/chats"
-SESSION_WITH_CHAT_ASSISTANT_API_URL = "/api/v1/chats/{chat_id}/sessions"
-SESSION_WITH_AGENT_API_URL = "/api/v1/agents/{agent_id}/sessions"
+DATASETS_API_URL = f"/api/{VERSION}/datasets"
+FILE_API_URL = f"/api/{VERSION}/datasets/{{dataset_id}}/documents"
+FILE_CHUNK_API_URL = f"/api/{VERSION}/datasets/{{dataset_id}}/chunks"
+CHUNK_API_URL = f"/api/{VERSION}/datasets/{{dataset_id}}/documents/{{document_id}}/chunks"
+CHAT_ASSISTANT_API_URL = f"/api/{VERSION}/chats"
+SESSION_WITH_CHAT_ASSISTANT_API_URL = f"/api/{VERSION}/chats/{{chat_id}}/sessions"
+SESSION_WITH_AGENT_API_URL = f"/api/{VERSION}/agents/{{agent_id}}/sessions"


 # DATASET MANAGEMENT
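Note on the new constants: `VERSION` is interpolated once at import time, while the doubled braces escape the per-request placeholders so they survive into the resulting string. A minimal sketch of the intended two-stage substitution (the `.format(...)` call site and the concrete `VERSION` value are assumptions, mirroring the pre-existing `{dataset_id}`-style placeholders):

```python
VERSION = "v1"  # assumed value from configs

# f-string pass: VERSION is substituted, {{dataset_id}} collapses to {dataset_id}
FILE_API_URL = f"/api/{VERSION}/datasets/{{dataset_id}}/documents"

# hypothetical call site: the remaining placeholder is filled per request
url = FILE_API_URL.format(dataset_id="abc123")
assert url == "/api/v1/datasets/abc123/documents"
```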
@@ -346,7 +346,7 @@ class TestDocumentsList:
         count = 100

         with ThreadPoolExecutor(max_workers=5) as executor:
-            futures = [executor.submit(list_documents, HttpApiAuth, dataset_id) for i in range(count)]
+            futures = [executor.submit(list_documents, HttpApiAuth, dataset_id) for _ in range(count)]
         responses = list(as_completed(futures))
         assert len(responses) == count, responses
         assert all(future.result()["code"] == 0 for future in futures)
@@ -209,7 +209,7 @@ class TestDocumentsUpload:
             fps.append(fp)

         with ThreadPoolExecutor(max_workers=5) as executor:
-            futures = [executor.submit(upload_documents, HttpApiAuth, dataset_id, fps[i : i + 1]) for i in range(count)]
+            futures = [executor.submit(upload_documents, HttpApiAuth, dataset_id, [fp]) for fp in fps]
         responses = list(as_completed(futures))
         assert len(responses) == count, responses
         assert all(future.result()["code"] == 0 for future in futures)
@@ -13,12 +13,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
 from pathlib import Path

 import requests
-from configs import HOST_ADDRESS
+from configs import HOST_ADDRESS, VERSION
 from requests_toolbelt import MultipartEncoder
 from utils.file_utils import create_txt_file

 HEADERS = {"Content-Type": "application/json"}

-KB_APP_URL = "/v1/kb"
+KB_APP_URL = f"/{VERSION}/kb"
+DOCUMENT_APP_URL = f"/{VERSION}/document"
 # FILE_API_URL = "/api/v1/datasets/{dataset_id}/documents"
 # FILE_CHUNK_API_URL = "/api/v1/datasets/{dataset_id}/chunks"
 # CHUNK_API_URL = "/api/v1/datasets/{dataset_id}/documents/{document_id}/chunks"
@@ -27,7 +32,7 @@ KB_APP_URL = "/v1/kb"
 # SESSION_WITH_AGENT_API_URL = "/api/v1/agents/{agent_id}/sessions"


-# DATASET MANAGEMENT
+# KB APP
 def create_kb(auth, payload=None, *, headers=HEADERS, data=None):
     res = requests.post(url=f"{HOST_ADDRESS}{KB_APP_URL}/create", headers=headers, auth=auth, json=payload, data=data)
     return res.json()
@@ -91,3 +96,71 @@ def batch_create_datasets(auth, num):
         res = create_kb(auth, {"name": f"kb_{i}"})
         ids.append(res["data"]["kb_id"])
     return ids
+
+
+# DOCUMENT APP
+def upload_documents(auth, payload=None, files_path=None):
+    url = f"{HOST_ADDRESS}{DOCUMENT_APP_URL}/upload"
+
+    if files_path is None:
+        files_path = []
+
+    fields = []
+    file_objects = []
+    try:
+        if payload:
+            for k, v in payload.items():
+                fields.append((k, str(v)))
+
+        for fp in files_path:
+            p = Path(fp)
+            f = p.open("rb")
+            fields.append(("file", (p.name, f)))
+            file_objects.append(f)
+        m = MultipartEncoder(fields=fields)
+
+        res = requests.post(
+            url=url,
+            headers={"Content-Type": m.content_type},
+            auth=auth,
+            data=m,
+        )
+        return res.json()
+    finally:
+        for f in file_objects:
+            f.close()
+
+
+def create_document(auth, payload=None, *, headers=HEADERS, data=None):
+    res = requests.post(url=f"{HOST_ADDRESS}{DOCUMENT_APP_URL}/create", headers=headers, auth=auth, json=payload, data=data)
+    return res.json()
+
+
+def list_documents(auth, params=None, payload=None, *, headers=HEADERS, data=None):
+    if payload is None:
+        payload = {}
+    res = requests.post(url=f"{HOST_ADDRESS}{DOCUMENT_APP_URL}/list", headers=headers, auth=auth, params=params, json=payload, data=data)
+    return res.json()
+
+
+def delete_document(auth, payload=None, *, headers=HEADERS, data=None):
+    res = requests.post(url=f"{HOST_ADDRESS}{DOCUMENT_APP_URL}/rm", headers=headers, auth=auth, json=payload, data=data)
+    return res.json()
+
+
+def parse_documents(auth, payload=None, *, headers=HEADERS, data=None):
+    res = requests.post(url=f"{HOST_ADDRESS}{DOCUMENT_APP_URL}/run", headers=headers, auth=auth, json=payload, data=data)
+    return res.json()
+
+
+def bulk_upload_documents(auth, kb_id, num, tmp_path):
+    fps = []
+    for i in range(num):
+        fp = create_txt_file(tmp_path / f"ragflow_test_upload_{i}.txt")
+        fps.append(fp)
+
+    res = upload_documents(auth, {"kb_id": kb_id}, fps)
+    document_ids = []
+    for document in res["data"]:
+        document_ids.append(document["id"])
+    return document_ids
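For orientation, the new helpers compose roughly as sketched below inside a test. The payload keys and response shapes are taken from the assertions in the test files that follow; `WebApiAuth` and `tmp_path` stand in for the usual pytest fixtures, and the concrete values are illustrative only:

```python
# Sketch only: a typical create -> upload -> parse -> list -> delete round trip.
kb_id = create_kb(WebApiAuth, {"name": "demo_kb"})["data"]["kb_id"]
doc_ids = bulk_upload_documents(WebApiAuth, kb_id, 2, tmp_path)

parse_documents(WebApiAuth, {"doc_ids": doc_ids, "run": "1"})  # "1" starts parsing
res = list_documents(WebApiAuth, {"kb_id": kb_id})
assert res["data"]["total"] == 2

for doc_id in doc_ids:
    delete_document(WebApiAuth, {"doc_id": doc_id})
```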
@@ -16,11 +16,15 @@
 import pytest
+from common import (
+    batch_create_datasets,
+    list_kbs,
+    rm_kb,
+)
-from configs import HOST_ADDRESS, VERSION
+
+# from configs import HOST_ADDRESS, VERSION
 from libs.auth import RAGFlowWebApiAuth
 from pytest import FixtureRequest
-from ragflow_sdk import RAGFlow
+
+# from ragflow_sdk import RAGFlow
 from utils.file_utils import (
     create_docx_file,
     create_eml_file,
@@ -69,32 +73,38 @@ def WebApiAuth(auth):
     return RAGFlowWebApiAuth(auth)


-@pytest.fixture(scope="session")
-def client(token: str) -> RAGFlow:
-    return RAGFlow(api_key=token, base_url=HOST_ADDRESS, version=VERSION)
+# @pytest.fixture(scope="session")
+# def client(token: str) -> RAGFlow:
+#     return RAGFlow(api_key=token, base_url=HOST_ADDRESS, version=VERSION)


 @pytest.fixture(scope="function")
-def clear_datasets(request: FixtureRequest, client: RAGFlow):
+def clear_datasets(request: FixtureRequest, WebApiAuth: RAGFlowWebApiAuth):
     def cleanup():
-        client.delete_datasets(ids=None)
+        res = list_kbs(WebApiAuth, params={"page_size": 1000})
+        for kb in res["data"]["kbs"]:
+            rm_kb(WebApiAuth, {"kb_id": kb["id"]})

     request.addfinalizer(cleanup)


 @pytest.fixture(scope="class")
-def add_dataset(request: FixtureRequest, client: RAGFlow, WebApiAuth: RAGFlowWebApiAuth) -> str:
+def add_dataset(request: FixtureRequest, WebApiAuth: RAGFlowWebApiAuth) -> str:
     def cleanup():
-        client.delete_datasets(ids=None)
+        res = list_kbs(WebApiAuth, params={"page_size": 1000})
+        for kb in res["data"]["kbs"]:
+            rm_kb(WebApiAuth, {"kb_id": kb["id"]})

     request.addfinalizer(cleanup)
     return batch_create_datasets(WebApiAuth, 1)[0]


 @pytest.fixture(scope="function")
-def add_dataset_func(request: FixtureRequest, client: RAGFlow, WebApiAuth: RAGFlowWebApiAuth) -> str:
+def add_dataset_func(request: FixtureRequest, WebApiAuth: RAGFlowWebApiAuth) -> str:
     def cleanup():
-        client.delete_datasets(ids=None)
+        res = list_kbs(WebApiAuth, params={"page_size": 1000})
+        for kb in res["data"]["kbs"]:
+            rm_kb(WebApiAuth, {"kb_id": kb["id"]})

     request.addfinalizer(cleanup)
     return batch_create_datasets(WebApiAuth, 1)[0]
test/testcases/test_web_api/test_document_app/conftest.py (new file, 58 lines)
@@ -0,0 +1,58 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


import pytest
from common import bulk_upload_documents, delete_document, list_documents


@pytest.fixture(scope="function")
def add_document_func(request, WebApiAuth, add_dataset, ragflow_tmp_dir):
    def cleanup():
        res = list_documents(WebApiAuth, {"kb_id": dataset_id})
        for doc in res["data"]["docs"]:
            delete_document(WebApiAuth, {"doc_id": doc["id"]})

    request.addfinalizer(cleanup)

    dataset_id = add_dataset
    return dataset_id, bulk_upload_documents(WebApiAuth, dataset_id, 1, ragflow_tmp_dir)[0]


@pytest.fixture(scope="class")
def add_documents(request, WebApiAuth, add_dataset, ragflow_tmp_dir):
    def cleanup():
        res = list_documents(WebApiAuth, {"kb_id": dataset_id})
        for doc in res["data"]["docs"]:
            delete_document(WebApiAuth, {"doc_id": doc["id"]})

    request.addfinalizer(cleanup)

    dataset_id = add_dataset
    return dataset_id, bulk_upload_documents(WebApiAuth, dataset_id, 5, ragflow_tmp_dir)


@pytest.fixture(scope="function")
def add_documents_func(request, WebApiAuth, add_dataset_func, ragflow_tmp_dir):
    def cleanup():
        res = list_documents(WebApiAuth, {"kb_id": dataset_id})
        for doc in res["data"]["docs"]:
            delete_document(WebApiAuth, {"doc_id": doc["id"]})

    request.addfinalizer(cleanup)

    dataset_id = add_dataset_func
    return dataset_id, bulk_upload_documents(WebApiAuth, dataset_id, 3, ragflow_tmp_dir)
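One subtlety in these fixtures: `cleanup` closes over `dataset_id` before it is assigned, which is safe because Python resolves a closure's free variables when the function is called, not when it is defined. A self-contained sketch of that behavior, independent of the fixtures above:

```python
def make_fixture():
    def cleanup():
        # dataset_id is looked up at call time, so the assignment below is visible
        print(f"cleaning up {dataset_id}")

    dataset_id = "kb_42"  # assigned after cleanup was defined
    cleanup()  # prints: cleaning up kb_42


make_fixture()
```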
@@ -0,0 +1,92 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import string
from concurrent.futures import ThreadPoolExecutor, as_completed

import pytest
from common import create_document, list_kbs
from configs import DOCUMENT_NAME_LIMIT, INVALID_API_TOKEN
from libs.auth import RAGFlowWebApiAuth
from utils.file_utils import create_txt_file


@pytest.mark.p1
@pytest.mark.usefixtures("clear_datasets")
class TestAuthorization:
    @pytest.mark.parametrize(
        "invalid_auth, expected_code, expected_message",
        [
            (None, 401, "<Unauthorized '401: Unauthorized'>"),
            (RAGFlowWebApiAuth(INVALID_API_TOKEN), 401, "<Unauthorized '401: Unauthorized'>"),
        ],
    )
    def test_invalid_auth(self, invalid_auth, expected_code, expected_message):
        res = create_document(invalid_auth)
        assert res["code"] == expected_code, res
        assert res["message"] == expected_message, res


class TestDocumentCreate:
    @pytest.mark.p3
    def test_filename_empty(self, WebApiAuth, add_dataset_func):
        kb_id = add_dataset_func
        payload = {"name": "", "kb_id": kb_id}
        res = create_document(WebApiAuth, payload)
        assert res["code"] == 101, res
        assert res["message"] == "File name can't be empty.", res

    @pytest.mark.p2
    def test_filename_max_length(self, WebApiAuth, add_dataset_func, tmp_path):
        kb_id = add_dataset_func
        fp = create_txt_file(tmp_path / f"{'a' * (DOCUMENT_NAME_LIMIT - 4)}.txt")
        res = create_document(WebApiAuth, {"name": fp.name, "kb_id": kb_id})
        assert res["code"] == 0, res
        assert res["data"]["name"] == fp.name, res

    @pytest.mark.p2
    def test_invalid_kb_id(self, WebApiAuth):
        res = create_document(WebApiAuth, {"name": "ragflow_test.txt", "kb_id": "invalid_kb_id"})
        assert res["code"] == 102, res
        assert res["message"] == "Can't find this knowledgebase!", res

    @pytest.mark.p3
    def test_filename_special_characters(self, WebApiAuth, add_dataset_func):
        kb_id = add_dataset_func
        illegal_chars = '<>:"/\\|?*'
        translation_table = str.maketrans({char: "_" for char in illegal_chars})
        safe_filename = string.punctuation.translate(translation_table)
        filename = f"{safe_filename}.txt"

        res = create_document(WebApiAuth, {"name": filename, "kb_id": kb_id})
        assert res["code"] == 0, res
        assert res["data"]["kb_id"] == kb_id, res
        assert res["data"]["name"] == filename, f"Expected: {filename}, Got: {res['data']['name']}"

    @pytest.mark.p3
    def test_concurrent_upload(self, WebApiAuth, add_dataset_func):
        kb_id = add_dataset_func

        count = 20
        filenames = [f"ragflow_test_{i}.txt" for i in range(count)]

        with ThreadPoolExecutor(max_workers=5) as executor:
            futures = [executor.submit(create_document, WebApiAuth, {"name": name, "kb_id": kb_id}) for name in filenames]
        responses = list(as_completed(futures))
        assert len(responses) == count, responses
        assert all(future.result()["code"] == 0 for future in futures), responses

        res = list_kbs(WebApiAuth, {"id": kb_id})
        assert res["data"]["kbs"][0]["doc_num"] == count, res
@@ -0,0 +1,180 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from concurrent.futures import ThreadPoolExecutor, as_completed

import pytest
from common import list_documents
from configs import INVALID_API_TOKEN
from libs.auth import RAGFlowWebApiAuth
from utils import is_sorted


@pytest.mark.p1
class TestAuthorization:
    @pytest.mark.parametrize(
        "invalid_auth, expected_code, expected_message",
        [
            (None, 401, "<Unauthorized '401: Unauthorized'>"),
            (RAGFlowWebApiAuth(INVALID_API_TOKEN), 401, "<Unauthorized '401: Unauthorized'>"),
        ],
    )
    def test_invalid_auth(self, invalid_auth, expected_code, expected_message):
        res = list_documents(invalid_auth, {"kb_id": "dataset_id"})
        assert res["code"] == expected_code
        assert res["message"] == expected_message


class TestDocumentsList:
    @pytest.mark.p1
    def test_default(self, WebApiAuth, add_documents):
        kb_id, _ = add_documents
        res = list_documents(WebApiAuth, {"kb_id": kb_id})
        assert res["code"] == 0
        assert len(res["data"]["docs"]) == 5
        assert res["data"]["total"] == 5

    @pytest.mark.p3
    @pytest.mark.parametrize(
        "kb_id, expected_code, expected_message",
        [
            ("", 101, 'Lack of "KB ID"'),
            ("invalid_dataset_id", 103, "Only owner of knowledgebase authorized for this operation."),
        ],
    )
    def test_invalid_dataset_id(self, WebApiAuth, kb_id, expected_code, expected_message):
        res = list_documents(WebApiAuth, {"kb_id": kb_id})
        assert res["code"] == expected_code
        assert res["message"] == expected_message

    @pytest.mark.p1
    @pytest.mark.parametrize(
        "params, expected_code, expected_page_size, expected_message",
        [
            ({"page": None, "page_size": 2}, 0, 5, ""),
            ({"page": 0, "page_size": 2}, 0, 5, ""),
            ({"page": 2, "page_size": 2}, 0, 2, ""),
            ({"page": 3, "page_size": 2}, 0, 1, ""),
            ({"page": "3", "page_size": 2}, 0, 1, ""),
            pytest.param({"page": -1, "page_size": 2}, 100, 0, "1064", marks=pytest.mark.skip(reason="issues/5851")),
            pytest.param({"page": "a", "page_size": 2}, 100, 0, """ValueError("invalid literal for int() with base 10: 'a'")""", marks=pytest.mark.skip(reason="issues/5851")),
        ],
    )
    def test_page(self, WebApiAuth, add_documents, params, expected_code, expected_page_size, expected_message):
        kb_id, _ = add_documents
        res = list_documents(WebApiAuth, {"kb_id": kb_id, **params})
        assert res["code"] == expected_code, res
        if expected_code == 0:
            assert len(res["data"]["docs"]) == expected_page_size, res
            assert res["data"]["total"] == 5, res
        else:
            assert res["message"] == expected_message, res

    @pytest.mark.p1
    @pytest.mark.parametrize(
        "params, expected_code, expected_page_size, expected_message",
        [
            ({"page_size": None}, 0, 5, ""),
            ({"page_size": 0}, 0, 5, ""),
            ({"page_size": 1}, 0, 5, ""),
            ({"page_size": 6}, 0, 5, ""),
            ({"page_size": "1"}, 0, 5, ""),
            pytest.param({"page_size": -1}, 100, 0, "1064", marks=pytest.mark.skip(reason="issues/5851")),
            pytest.param({"page_size": "a"}, 100, 0, """ValueError("invalid literal for int() with base 10: 'a'")""", marks=pytest.mark.skip(reason="issues/5851")),
        ],
    )
    def test_page_size(self, WebApiAuth, add_documents, params, expected_code, expected_page_size, expected_message):
        kb_id, _ = add_documents
        res = list_documents(WebApiAuth, {"kb_id": kb_id, **params})
        assert res["code"] == expected_code, res
        if expected_code == 0:
            assert len(res["data"]["docs"]) == expected_page_size, res
        else:
            assert res["message"] == expected_message, res

    @pytest.mark.p3
    @pytest.mark.parametrize(
        "params, expected_code, assertions, expected_message",
        [
            ({"orderby": None}, 0, lambda r: (is_sorted(r["data"]["docs"], "create_time", True)), ""),
            ({"orderby": "create_time"}, 0, lambda r: (is_sorted(r["data"]["docs"], "create_time", True)), ""),
            ({"orderby": "update_time"}, 0, lambda r: (is_sorted(r["data"]["docs"], "update_time", True)), ""),
            pytest.param({"orderby": "name", "desc": "False"}, 0, lambda r: (is_sorted(r["data"]["docs"], "name", False)), "", marks=pytest.mark.skip(reason="issues/5851")),
            pytest.param({"orderby": "unknown"}, 102, 0, "orderby should be create_time or update_time", marks=pytest.mark.skip(reason="issues/5851")),
        ],
    )
    def test_orderby(self, WebApiAuth, add_documents, params, expected_code, assertions, expected_message):
        kb_id, _ = add_documents
        res = list_documents(WebApiAuth, {"kb_id": kb_id, **params})
        assert res["code"] == expected_code, res
        if expected_code == 0:
            if callable(assertions):
                assert assertions(res)
        else:
            assert res["message"] == expected_message, res

    @pytest.mark.p3
    @pytest.mark.parametrize(
        "params, expected_code, assertions, expected_message",
        [
            ({"desc": None}, 0, lambda r: (is_sorted(r["data"]["docs"], "create_time", True)), ""),
            ({"desc": "true"}, 0, lambda r: (is_sorted(r["data"]["docs"], "create_time", True)), ""),
            ({"desc": "True"}, 0, lambda r: (is_sorted(r["data"]["docs"], "create_time", True)), ""),
            ({"desc": True}, 0, lambda r: (is_sorted(r["data"]["docs"], "create_time", True)), ""),
            pytest.param({"desc": "false"}, 0, lambda r: (is_sorted(r["data"]["docs"], "create_time", False)), "", marks=pytest.mark.skip(reason="issues/5851")),
            ({"desc": "False"}, 0, lambda r: (is_sorted(r["data"]["docs"], "create_time", False)), ""),
            ({"desc": False}, 0, lambda r: (is_sorted(r["data"]["docs"], "create_time", False)), ""),
            ({"desc": "False", "orderby": "update_time"}, 0, lambda r: (is_sorted(r["data"]["docs"], "update_time", False)), ""),
            pytest.param({"desc": "unknown"}, 102, 0, "desc should be true or false", marks=pytest.mark.skip(reason="issues/5851")),
        ],
    )
    def test_desc(self, WebApiAuth, add_documents, params, expected_code, assertions, expected_message):
        kb_id, _ = add_documents
        res = list_documents(WebApiAuth, {"kb_id": kb_id, **params})
        assert res["code"] == expected_code, res
        if expected_code == 0:
            if callable(assertions):
                assert assertions(res)
        else:
            assert res["message"] == expected_message, res

    @pytest.mark.p2
    @pytest.mark.parametrize(
        "params, expected_num",
        [
            ({"keywords": None}, 5),
            ({"keywords": ""}, 5),
            ({"keywords": "0"}, 1),
            ({"keywords": "ragflow_test_upload"}, 5),
            ({"keywords": "unknown"}, 0),
        ],
    )
    def test_keywords(self, WebApiAuth, add_documents, params, expected_num):
        kb_id, _ = add_documents
        res = list_documents(WebApiAuth, {"kb_id": kb_id, **params})
        assert res["code"] == 0, res
        assert len(res["data"]["docs"]) == expected_num, res
        assert res["data"]["total"] == expected_num, res

    @pytest.mark.p3
    def test_concurrent_list(self, WebApiAuth, add_documents):
        kb_id, _ = add_documents
        count = 100

        with ThreadPoolExecutor(max_workers=5) as executor:
            futures = [executor.submit(list_documents, WebApiAuth, {"kb_id": kb_id}) for i in range(count)]
        responses = list(as_completed(futures))
        assert len(responses) == count, responses
        assert all(future.result()["code"] == 0 for future in futures), responses
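The orderby/desc cases delegate to an `is_sorted` helper imported from utils; its implementation is not part of this diff. A plausible sketch consistent with how the lambdas above call it (docs, key, descending flag):

```python
# Hypothetical stand-in for utils.is_sorted as used above: checks that each
# adjacent pair of docs is ordered on `key`, descending when the flag is True.
def is_sorted(docs, key, descending=True):
    values = [doc[key] for doc in docs]
    pairs = list(zip(values, values[1:]))
    if descending:
        return all(a >= b for a, b in pairs)
    return all(a <= b for a, b in pairs)
```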
@@ -0,0 +1,256 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from concurrent.futures import ThreadPoolExecutor, as_completed

import pytest
from common import bulk_upload_documents, list_documents, parse_documents
from configs import INVALID_API_TOKEN
from libs.auth import RAGFlowWebApiAuth
from utils import wait_for


@wait_for(30, 1, "Document parsing timeout")
def condition(_auth, _kb_id, _document_ids=None):
    res = list_documents(_auth, {"kb_id": _kb_id})
    target_docs = res["data"]["docs"]

    if _document_ids is None:
        for doc in target_docs:
            if doc["run"] != "3":
                return False
        return True

    target_ids = set(_document_ids)
    for doc in target_docs:
        if doc["id"] in target_ids:
            if doc.get("run") != "3":
                return False
    return True


def validate_document_parse_done(auth, _kb_id, _document_ids):
    res = list_documents(auth, {"kb_id": _kb_id})
    for doc in res["data"]["docs"]:
        if doc["id"] not in _document_ids:
            continue
        assert doc["run"] == "3"
        assert len(doc["process_begin_at"]) > 0
        assert doc["process_duation"] > 0
        assert doc["progress"] > 0
        assert "Task done" in doc["progress_msg"]


def validate_document_parse_cancel(auth, _kb_id, _document_ids):
    res = list_documents(auth, {"kb_id": _kb_id})
    for doc in res["data"]["docs"]:
        if doc["id"] not in _document_ids:
            continue
        assert doc["run"] == "2"
        assert len(doc["process_begin_at"]) > 0
        assert doc["progress"] == 0.0


@pytest.mark.p1
class TestAuthorization:
    @pytest.mark.parametrize(
        "invalid_auth, expected_code, expected_message",
        [
            (None, 401, "<Unauthorized '401: Unauthorized'>"),
            (RAGFlowWebApiAuth(INVALID_API_TOKEN), 401, "<Unauthorized '401: Unauthorized'>"),
        ],
    )
    def test_invalid_auth(self, invalid_auth, expected_code, expected_message):
        res = parse_documents(invalid_auth)
        assert res["code"] == expected_code, res
        assert res["message"] == expected_message, res


class TestDocumentsParse:
    @pytest.mark.parametrize(
        "payload, expected_code, expected_message",
        [
            pytest.param(None, 101, "required argument are missing: doc_ids, run; ", marks=pytest.mark.skip),
            pytest.param({"doc_ids": [], "run": "1"}, 0, "", marks=pytest.mark.p1),
            pytest.param({"doc_ids": ["invalid_id"], "run": "1"}, 109, "No authorization.", marks=pytest.mark.p3),
            pytest.param({"doc_ids": ["\n!?。;!?\"'"], "run": "1"}, 109, "No authorization.", marks=pytest.mark.p3),
            pytest.param("not json", 101, "required argument are missing: doc_ids, run; ", marks=pytest.mark.skip),
            pytest.param(lambda r: {"doc_ids": r[:1], "run": "1"}, 0, "", marks=pytest.mark.p1),
            pytest.param(lambda r: {"doc_ids": r, "run": "1"}, 0, "", marks=pytest.mark.p1),
        ],
    )
    def test_basic_scenarios(self, WebApiAuth, add_documents_func, payload, expected_code, expected_message):
        kb_id, document_ids = add_documents_func
        if callable(payload):
            payload = payload(document_ids)
        res = parse_documents(WebApiAuth, payload)
        assert res["code"] == expected_code, res
        if expected_code == 0:
            condition(WebApiAuth, kb_id, payload["doc_ids"])
            validate_document_parse_done(WebApiAuth, kb_id, payload["doc_ids"])
        else:
            assert res["message"] == expected_message, res

    @pytest.mark.parametrize(
        "payload",
        [
            pytest.param(lambda r: {"doc_ids": ["invalid_id"] + r, "run": "1"}, marks=pytest.mark.p3),
            pytest.param(lambda r: {"doc_ids": r[:1] + ["invalid_id"] + r[1:3], "run": "1"}, marks=pytest.mark.p1),
            pytest.param(lambda r: {"doc_ids": r + ["invalid_id"], "run": "1"}, marks=pytest.mark.p3),
        ],
    )
    def test_parse_partial_invalid_document_id(self, WebApiAuth, add_documents_func, payload):
        _, document_ids = add_documents_func
        if callable(payload):
            payload = payload(document_ids)
        res = parse_documents(WebApiAuth, payload)
        assert res["code"] == 109, res
        assert res["message"] == "No authorization.", res

    @pytest.mark.p3
    def test_repeated_parse(self, WebApiAuth, add_documents_func):
        kb_id, document_ids = add_documents_func
        res = parse_documents(WebApiAuth, {"doc_ids": document_ids, "run": "1"})
        assert res["code"] == 0, res

        condition(WebApiAuth, kb_id, document_ids)

        res = parse_documents(WebApiAuth, {"doc_ids": document_ids, "run": "1"})
        assert res["code"] == 0, res

    @pytest.mark.p3
    def test_duplicate_parse(self, WebApiAuth, add_documents_func):
        kb_id, document_ids = add_documents_func
        res = parse_documents(WebApiAuth, {"doc_ids": document_ids + document_ids, "run": "1"})
        assert res["code"] == 0, res
        assert res["message"] == "success", res

        condition(WebApiAuth, kb_id, document_ids)
        validate_document_parse_done(WebApiAuth, kb_id, document_ids)


@pytest.mark.p3
def test_parse_100_files(WebApiAuth, add_dataset_func, tmp_path):
    @wait_for(100, 1, "Document parsing timeout")
    def condition(_auth, _kb_id, _document_num):
        res = list_documents(_auth, {"kb_id": _kb_id, "page_size": _document_num})
        for doc in res["data"]["docs"]:
            if doc["run"] != "3":
                return False
        return True

    document_num = 100
    kb_id = add_dataset_func
    document_ids = bulk_upload_documents(WebApiAuth, kb_id, document_num, tmp_path)
    res = parse_documents(WebApiAuth, {"doc_ids": document_ids, "run": "1"})
    assert res["code"] == 0, res

    condition(WebApiAuth, kb_id, document_num)

    validate_document_parse_done(WebApiAuth, kb_id, document_ids)


@pytest.mark.p3
def test_concurrent_parse(WebApiAuth, add_dataset_func, tmp_path):
    @wait_for(120, 1, "Document parsing timeout")
    def condition(_auth, _kb_id, _document_num):
        res = list_documents(_auth, {"kb_id": _kb_id, "page_size": _document_num})
        for doc in res["data"]["docs"]:
            if doc["run"] != "3":
                return False
        return True

    count = 100
    kb_id = add_dataset_func
    document_ids = bulk_upload_documents(WebApiAuth, kb_id, count, tmp_path)

    with ThreadPoolExecutor(max_workers=5) as executor:
        futures = [
            executor.submit(
                parse_documents,
                WebApiAuth,
                {"doc_ids": [document_ids[i]], "run": "1"},
            )
            for i in range(count)
        ]
    responses = list(as_completed(futures))
    assert len(responses) == count, responses
    assert all(future.result()["code"] == 0 for future in futures)

    condition(WebApiAuth, kb_id, count)

    validate_document_parse_done(WebApiAuth, kb_id, document_ids)


# @pytest.mark.skip
class TestDocumentsParseStop:
    @pytest.mark.parametrize(
        "payload, expected_code, expected_message",
        [
            pytest.param(None, 101, "required argument are missing: doc_ids, run; ", marks=pytest.mark.skip),
            pytest.param({"doc_ids": [], "run": "2"}, 0, "", marks=pytest.mark.p1),
            pytest.param({"doc_ids": ["invalid_id"], "run": "2"}, 109, "No authorization.", marks=pytest.mark.p3),
            pytest.param({"doc_ids": ["\n!?。;!?\"'"], "run": "2"}, 109, "No authorization.", marks=pytest.mark.p3),
            pytest.param("not json", 101, "required argument are missing: doc_ids, run; ", marks=pytest.mark.skip),
            pytest.param(lambda r: {"doc_ids": r[:1], "run": "2"}, 0, "", marks=pytest.mark.p1),
            pytest.param(lambda r: {"doc_ids": r, "run": "2"}, 0, "", marks=pytest.mark.p1),
        ],
    )
    def test_basic_scenarios(self, WebApiAuth, add_documents_func, payload, expected_code, expected_message):
        @wait_for(10, 1, "Document parsing timeout")
        def condition(_auth, _kb_id, _doc_ids):
            res = list_documents(_auth, {"kb_id": _kb_id})
            for doc in res["data"]["docs"]:
                if doc["id"] in _doc_ids:
                    if doc["run"] != "3":
                        return False
            return True

        kb_id, document_ids = add_documents_func
        parse_documents(WebApiAuth, {"doc_ids": document_ids, "run": "1"})

        if callable(payload):
            payload = payload(document_ids)

        res = parse_documents(WebApiAuth, payload)
        assert res["code"] == expected_code, res
        if expected_code == 0:
            completed_document_ids = list(set(document_ids) - set(payload["doc_ids"]))
            condition(WebApiAuth, kb_id, completed_document_ids)
            validate_document_parse_cancel(WebApiAuth, kb_id, payload["doc_ids"])
            validate_document_parse_done(WebApiAuth, kb_id, completed_document_ids)
        else:
            assert res["message"] == expected_message, res

    @pytest.mark.skip
    @pytest.mark.parametrize(
        "payload",
        [
            lambda r: {"doc_ids": ["invalid_id"] + r, "run": "2"},
            lambda r: {"doc_ids": r[:1] + ["invalid_id"] + r[1:3], "run": "2"},
            lambda r: {"doc_ids": r + ["invalid_id"], "run": "2"},
        ],
    )
    def test_stop_parse_partial_invalid_document_id(self, WebApiAuth, add_documents_func, payload):
        kb_id, document_ids = add_documents_func
        parse_documents(WebApiAuth, {"doc_ids": document_ids, "run": "1"})

        if callable(payload):
            payload = payload(document_ids)
        res = parse_documents(WebApiAuth, payload)
        assert res["code"] == 109, res
        assert res["message"] == "No authorization.", res

        validate_document_parse_cancel(WebApiAuth, kb_id, document_ids)
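`wait_for` also comes from the test utils and is not shown in this diff. Judging by its use here (a timeout, a poll interval, and an error message wrapping a boolean predicate), it presumably looks something like this sketch:

```python
import time
from functools import wraps


# Hypothetical sketch of utils.wait_for(timeout, interval, error_msg): poll the
# decorated predicate until it returns truthy, sleeping `interval` seconds
# between attempts, and fail once `timeout` seconds have elapsed.
def wait_for(timeout, interval, error_msg):
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            deadline = time.monotonic() + timeout
            while time.monotonic() < deadline:
                if func(*args, **kwargs):
                    return True
                time.sleep(interval)
            raise TimeoutError(error_msg)

        return wrapper

    return decorator
```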
@@ -0,0 +1,104 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from concurrent.futures import ThreadPoolExecutor, as_completed

import pytest
from common import bulk_upload_documents, delete_document, list_documents
from configs import INVALID_API_TOKEN
from libs.auth import RAGFlowWebApiAuth


@pytest.mark.p1
class TestAuthorization:
    @pytest.mark.parametrize(
        "invalid_auth, expected_code, expected_message",
        [
            (None, 401, "<Unauthorized '401: Unauthorized'>"),
            (RAGFlowWebApiAuth(INVALID_API_TOKEN), 401, "<Unauthorized '401: Unauthorized'>"),
        ],
    )
    def test_invalid_auth(self, invalid_auth, expected_code, expected_message):
        res = delete_document(invalid_auth)
        assert res["code"] == expected_code, res
        assert res["message"] == expected_message, res


class TestDocumentsDeletion:
    @pytest.mark.p1
    @pytest.mark.parametrize(
        "payload, expected_code, expected_message, remaining",
        [
            (None, 101, "required argument are missing: doc_id; ", 3),
            ({"doc_id": ""}, 109, "No authorization.", 3),
            ({"doc_id": "invalid_id"}, 109, "No authorization.", 3),
            ({"doc_id": "\n!?。;!?\"'"}, 109, "No authorization.", 3),
            ("not json", 101, "required argument are missing: doc_id; ", 3),
            (lambda r: {"doc_id": r[0]}, 0, "", 2),
        ],
    )
    def test_basic_scenarios(self, WebApiAuth, add_documents_func, payload, expected_code, expected_message, remaining):
        kb_id, document_ids = add_documents_func
        if callable(payload):
            payload = payload(document_ids)
        res = delete_document(WebApiAuth, payload)
        assert res["code"] == expected_code, res
        if res["code"] != 0:
            assert res["message"] == expected_message, res

        res = list_documents(WebApiAuth, {"kb_id": kb_id})
        assert len(res["data"]["docs"]) == remaining, res
        assert res["data"]["total"] == remaining, res

    @pytest.mark.p2
    def test_repeated_deletion(self, WebApiAuth, add_documents_func):
        _, document_ids = add_documents_func
        for doc_id in document_ids:
            res = delete_document(WebApiAuth, {"doc_id": doc_id})
            assert res["code"] == 0, res

        for doc_id in document_ids:
            res = delete_document(WebApiAuth, {"doc_id": doc_id})
            assert res["code"] == 109, res
            assert res["message"] == "No authorization.", res


@pytest.mark.p3
def test_concurrent_deletion(WebApiAuth, add_dataset, tmp_path):
    count = 100
    kb_id = add_dataset
    document_ids = bulk_upload_documents(WebApiAuth, kb_id, count, tmp_path)

    with ThreadPoolExecutor(max_workers=5) as executor:
        futures = [executor.submit(delete_document, WebApiAuth, {"doc_id": document_ids[i]}) for i in range(count)]
    responses = list(as_completed(futures))
    assert len(responses) == count, responses
    assert all(future.result()["code"] == 0 for future in futures), responses


@pytest.mark.p3
def test_delete_100(WebApiAuth, add_dataset, tmp_path):
    documents_num = 100
    kb_id = add_dataset
    document_ids = bulk_upload_documents(WebApiAuth, kb_id, documents_num, tmp_path)
    res = list_documents(WebApiAuth, {"kb_id": kb_id})
    assert res["data"]["total"] == documents_num, res

    for doc_id in document_ids:
        res = delete_document(WebApiAuth, {"doc_id": doc_id})
        assert res["code"] == 0, res

    res = list_documents(WebApiAuth, {"kb_id": kb_id})
    assert res["data"]["total"] == 0, res
@@ -0,0 +1,201 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import string
from concurrent.futures import ThreadPoolExecutor, as_completed

import pytest
import requests
from common import DOCUMENT_APP_URL, list_kbs, upload_documents
from configs import DOCUMENT_NAME_LIMIT, HOST_ADDRESS, INVALID_API_TOKEN
from libs.auth import RAGFlowWebApiAuth
from requests_toolbelt import MultipartEncoder
from utils.file_utils import create_txt_file


@pytest.mark.p1
@pytest.mark.usefixtures("clear_datasets")
class TestAuthorization:
    @pytest.mark.parametrize(
        "invalid_auth, expected_code, expected_message",
        [
            (None, 401, "<Unauthorized '401: Unauthorized'>"),
            (RAGFlowWebApiAuth(INVALID_API_TOKEN), 401, "<Unauthorized '401: Unauthorized'>"),
        ],
    )
    def test_invalid_auth(self, invalid_auth, expected_code, expected_message):
        res = upload_documents(invalid_auth)
        assert res["code"] == expected_code, res
        assert res["message"] == expected_message, res


class TestDocumentsUpload:
    @pytest.mark.p1
    def test_valid_single_upload(self, WebApiAuth, add_dataset_func, tmp_path):
        kb_id = add_dataset_func
        fp = create_txt_file(tmp_path / "ragflow_test.txt")
        res = upload_documents(WebApiAuth, {"kb_id": kb_id}, [fp])
        assert res["code"] == 0, res
        assert res["data"][0]["kb_id"] == kb_id, res
        assert res["data"][0]["name"] == fp.name, res

    @pytest.mark.p1
    @pytest.mark.parametrize(
        "generate_test_files",
        [
            "docx",
            "excel",
            "ppt",
            "image",
            "pdf",
            "txt",
            "md",
            "json",
            "eml",
            "html",
        ],
        indirect=True,
    )
    def test_file_type_validation(self, WebApiAuth, add_dataset_func, generate_test_files, request):
        kb_id = add_dataset_func
        fp = generate_test_files[request.node.callspec.params["generate_test_files"]]
        res = upload_documents(WebApiAuth, {"kb_id": kb_id}, [fp])
        assert res["code"] == 0, res
        assert res["data"][0]["kb_id"] == kb_id, res
        assert res["data"][0]["name"] == fp.name, res

    @pytest.mark.p2
    @pytest.mark.parametrize(
        "file_type",
        ["exe", "unknown"],
    )
    def test_unsupported_file_type(self, WebApiAuth, add_dataset_func, tmp_path, file_type):
        kb_id = add_dataset_func
        fp = tmp_path / f"ragflow_test.{file_type}"
        fp.touch()
        res = upload_documents(WebApiAuth, {"kb_id": kb_id}, [fp])
        assert res["code"] == 500, res
        assert res["message"] == f"ragflow_test.{file_type}: This type of file has not been supported yet!", res

    @pytest.mark.p2
    def test_missing_file(self, WebApiAuth, add_dataset_func):
        kb_id = add_dataset_func
        res = upload_documents(WebApiAuth, {"kb_id": kb_id})
        assert res["code"] == 101, res
        assert res["message"] == "No file part!", res

    @pytest.mark.p3
    def test_empty_file(self, WebApiAuth, add_dataset_func, tmp_path):
        kb_id = add_dataset_func
        fp = tmp_path / "empty.txt"
        fp.touch()

        res = upload_documents(WebApiAuth, {"kb_id": kb_id}, [fp])
        assert res["code"] == 0, res
        assert res["data"][0]["size"] == 0, res

    @pytest.mark.p3
    def test_filename_empty(self, WebApiAuth, add_dataset_func, tmp_path):
        kb_id = add_dataset_func

        fp = create_txt_file(tmp_path / "ragflow_test.txt")
        url = f"{HOST_ADDRESS}{DOCUMENT_APP_URL}/upload"
        fields = [("file", ("", fp.open("rb"))), ("kb_id", kb_id)]
        m = MultipartEncoder(fields=fields)
        res = requests.post(
            url=url,
            headers={"Content-Type": m.content_type},
            auth=WebApiAuth,
            data=m,
        )
        assert res.json()["code"] == 101, res
        assert res.json()["message"] == "No file selected!", res

    @pytest.mark.p2
    def test_filename_exceeds_max_length(self, WebApiAuth, add_dataset_func, tmp_path):
        kb_id = add_dataset_func
        fp = create_txt_file(tmp_path / f"{'a' * (DOCUMENT_NAME_LIMIT - 4)}.txt")
        res = upload_documents(WebApiAuth, {"kb_id": kb_id}, [fp])
        assert res["code"] == 0, res
        assert res["data"][0]["name"] == fp.name, res

    @pytest.mark.p2
    def test_invalid_kb_id(self, WebApiAuth, tmp_path):
        fp = create_txt_file(tmp_path / "ragflow_test.txt")
        res = upload_documents(WebApiAuth, {"kb_id": "invalid_kb_id"}, [fp])
        assert res["code"] == 100, res
        assert res["message"] == """LookupError("Can't find this knowledgebase!")""", res

    @pytest.mark.p2
    def test_duplicate_files(self, WebApiAuth, add_dataset_func, tmp_path):
        kb_id = add_dataset_func
        fp = create_txt_file(tmp_path / "ragflow_test.txt")
        res = upload_documents(WebApiAuth, {"kb_id": kb_id}, [fp, fp])
        assert res["code"] == 0, res
        assert len(res["data"]) == 2, res
        for i in range(len(res["data"])):
            assert res["data"][i]["kb_id"] == kb_id, res
            expected_name = fp.name
            if i != 0:
                expected_name = f"{fp.stem}({i}){fp.suffix}"
            assert res["data"][i]["name"] == expected_name, res

    @pytest.mark.p3
    def test_filename_special_characters(self, WebApiAuth, add_dataset_func, tmp_path):
        kb_id = add_dataset_func
        illegal_chars = '<>:"/\\|?*'
        translation_table = str.maketrans({char: "_" for char in illegal_chars})
        safe_filename = string.punctuation.translate(translation_table)
        fp = tmp_path / f"{safe_filename}.txt"
        fp.write_text("Sample text content")

        res = upload_documents(WebApiAuth, {"kb_id": kb_id}, [fp])
        assert res["code"] == 0, res
        assert len(res["data"]) == 1, res
        assert res["data"][0]["kb_id"] == kb_id, res
        assert res["data"][0]["name"] == fp.name, res

    @pytest.mark.p1
    def test_multiple_files(self, WebApiAuth, add_dataset_func, tmp_path):
        kb_id = add_dataset_func
        expected_document_count = 20
        fps = []
        for i in range(expected_document_count):
            fp = create_txt_file(tmp_path / f"ragflow_test_{i}.txt")
            fps.append(fp)
        res = upload_documents(WebApiAuth, {"kb_id": kb_id}, fps)
        assert res["code"] == 0, res

        res = list_kbs(WebApiAuth)
        assert res["data"]["kbs"][0]["doc_num"] == expected_document_count, res

    @pytest.mark.p3
    def test_concurrent_upload(self, WebApiAuth, add_dataset_func, tmp_path):
        kb_id = add_dataset_func

        count = 20
        fps = []
        for i in range(count):
            fp = create_txt_file(tmp_path / f"ragflow_test_{i}.txt")
            fps.append(fp)

        with ThreadPoolExecutor(max_workers=5) as executor:
            futures = [executor.submit(upload_documents, WebApiAuth, {"kb_id": kb_id}, fps[i : i + 1]) for i in range(count)]
        responses = list(as_completed(futures))
        assert len(responses) == count, responses
        assert all(future.result()["code"] == 0 for future in futures), responses

        res = list_kbs(WebApiAuth)
        assert res["data"]["kbs"][0]["doc_num"] == count, res
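test_duplicate_files above encodes the server's collision-renaming convention. For illustration, the names it expects for repeated uploads of one file, computed locally with the same pattern the test asserts:

```python
from pathlib import Path

# Expected server-side names for three uploads of the same file, following the
# f"{stem}({i}){suffix}" pattern asserted in test_duplicate_files.
fp = Path("ragflow_test.txt")
names = [fp.name if i == 0 else f"{fp.stem}({i}){fp.suffix}" for i in range(3)]
print(names)  # ['ragflow_test.txt', 'ragflow_test(1).txt', 'ragflow_test(2).txt']
```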