Refa: Move HTTP API tests to top-level test directory (#8042)

### What problem does this PR solve?

Move test cases only - CI still runs tests under sdk/python

### Type of change

- [x] Refactoring
This commit is contained in:
Liu An
2025-06-04 13:16:17 +08:00
committed by GitHub
parent b832372c98
commit 52c814b89d
39 changed files with 7934 additions and 6 deletions

View File

@ -0,0 +1,257 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
from pathlib import Path
import requests
from requests_toolbelt import MultipartEncoder
from utils.file_utils import create_txt_file
HEADERS = {"Content-Type": "application/json"}
HOST_ADDRESS = os.getenv("HOST_ADDRESS", "http://127.0.0.1:9380")
DATASETS_API_URL = "/api/v1/datasets"
FILE_API_URL = "/api/v1/datasets/{dataset_id}/documents"
FILE_CHUNK_API_URL = "/api/v1/datasets/{dataset_id}/chunks"
CHUNK_API_URL = "/api/v1/datasets/{dataset_id}/documents/{document_id}/chunks"
CHAT_ASSISTANT_API_URL = "/api/v1/chats"
SESSION_WITH_CHAT_ASSISTANT_API_URL = "/api/v1/chats/{chat_id}/sessions"
SESSION_WITH_AGENT_API_URL = "/api/v1/agents/{agent_id}/sessions"
INVALID_API_TOKEN = "invalid_key_123"
DATASET_NAME_LIMIT = 128
DOCUMENT_NAME_LIMIT = 128
CHAT_ASSISTANT_NAME_LIMIT = 255
SESSION_WITH_CHAT_NAME_LIMIT = 255
# DATASET MANAGEMENT
def create_dataset(auth, payload=None, *, headers=HEADERS, data=None):
res = requests.post(url=f"{HOST_ADDRESS}{DATASETS_API_URL}", headers=headers, auth=auth, json=payload, data=data)
return res.json()
def list_datasets(auth, params=None, *, headers=HEADERS):
res = requests.get(url=f"{HOST_ADDRESS}{DATASETS_API_URL}", headers=headers, auth=auth, params=params)
return res.json()
def update_dataset(auth, dataset_id, payload=None, *, headers=HEADERS, data=None):
res = requests.put(url=f"{HOST_ADDRESS}{DATASETS_API_URL}/{dataset_id}", headers=headers, auth=auth, json=payload, data=data)
return res.json()
def delete_datasets(auth, payload=None, *, headers=HEADERS, data=None):
res = requests.delete(url=f"{HOST_ADDRESS}{DATASETS_API_URL}", headers=headers, auth=auth, json=payload, data=data)
return res.json()
def batch_create_datasets(auth, num):
ids = []
for i in range(num):
res = create_dataset(auth, {"name": f"dataset_{i}"})
ids.append(res["data"]["id"])
return ids
# FILE MANAGEMENT WITHIN DATASET
def upload_documnets(auth, dataset_id, files_path=None):
url = f"{HOST_ADDRESS}{FILE_API_URL}".format(dataset_id=dataset_id)
if files_path is None:
files_path = []
fields = []
file_objects = []
try:
for fp in files_path:
p = Path(fp)
f = p.open("rb")
fields.append(("file", (p.name, f)))
file_objects.append(f)
m = MultipartEncoder(fields=fields)
res = requests.post(
url=url,
headers={"Content-Type": m.content_type},
auth=auth,
data=m,
)
return res.json()
finally:
for f in file_objects:
f.close()
def download_document(auth, dataset_id, document_id, save_path):
url = f"{HOST_ADDRESS}{FILE_API_URL}/{document_id}".format(dataset_id=dataset_id)
res = requests.get(url=url, auth=auth, stream=True)
try:
if res.status_code == 200:
with open(save_path, "wb") as f:
for chunk in res.iter_content(chunk_size=8192):
f.write(chunk)
finally:
res.close()
return res
def list_documnets(auth, dataset_id, params=None):
url = f"{HOST_ADDRESS}{FILE_API_URL}".format(dataset_id=dataset_id)
res = requests.get(url=url, headers=HEADERS, auth=auth, params=params)
return res.json()
def update_documnet(auth, dataset_id, document_id, payload=None):
url = f"{HOST_ADDRESS}{FILE_API_URL}/{document_id}".format(dataset_id=dataset_id)
res = requests.put(url=url, headers=HEADERS, auth=auth, json=payload)
return res.json()
def delete_documnets(auth, dataset_id, payload=None):
url = f"{HOST_ADDRESS}{FILE_API_URL}".format(dataset_id=dataset_id)
res = requests.delete(url=url, headers=HEADERS, auth=auth, json=payload)
return res.json()
def parse_documnets(auth, dataset_id, payload=None):
url = f"{HOST_ADDRESS}{FILE_CHUNK_API_URL}".format(dataset_id=dataset_id)
res = requests.post(url=url, headers=HEADERS, auth=auth, json=payload)
return res.json()
def stop_parse_documnets(auth, dataset_id, payload=None):
url = f"{HOST_ADDRESS}{FILE_CHUNK_API_URL}".format(dataset_id=dataset_id)
res = requests.delete(url=url, headers=HEADERS, auth=auth, json=payload)
return res.json()
def bulk_upload_documents(auth, dataset_id, num, tmp_path):
fps = []
for i in range(num):
fp = create_txt_file(tmp_path / f"ragflow_test_upload_{i}.txt")
fps.append(fp)
res = upload_documnets(auth, dataset_id, fps)
document_ids = []
for document in res["data"]:
document_ids.append(document["id"])
return document_ids
# CHUNK MANAGEMENT WITHIN DATASET
def add_chunk(auth, dataset_id, document_id, payload=None):
url = f"{HOST_ADDRESS}{CHUNK_API_URL}".format(dataset_id=dataset_id, document_id=document_id)
res = requests.post(url=url, headers=HEADERS, auth=auth, json=payload)
return res.json()
def list_chunks(auth, dataset_id, document_id, params=None):
url = f"{HOST_ADDRESS}{CHUNK_API_URL}".format(dataset_id=dataset_id, document_id=document_id)
res = requests.get(url=url, headers=HEADERS, auth=auth, params=params)
return res.json()
def update_chunk(auth, dataset_id, document_id, chunk_id, payload=None):
url = f"{HOST_ADDRESS}{CHUNK_API_URL}/{chunk_id}".format(dataset_id=dataset_id, document_id=document_id)
res = requests.put(url=url, headers=HEADERS, auth=auth, json=payload)
return res.json()
def delete_chunks(auth, dataset_id, document_id, payload=None):
url = f"{HOST_ADDRESS}{CHUNK_API_URL}".format(dataset_id=dataset_id, document_id=document_id)
res = requests.delete(url=url, headers=HEADERS, auth=auth, json=payload)
return res.json()
def retrieval_chunks(auth, payload=None):
url = f"{HOST_ADDRESS}/api/v1/retrieval"
res = requests.post(url=url, headers=HEADERS, auth=auth, json=payload)
return res.json()
def batch_add_chunks(auth, dataset_id, document_id, num):
chunk_ids = []
for i in range(num):
res = add_chunk(auth, dataset_id, document_id, {"content": f"chunk test {i}"})
chunk_ids.append(res["data"]["chunk"]["id"])
return chunk_ids
# CHAT ASSISTANT MANAGEMENT
def create_chat_assistant(auth, payload=None):
url = f"{HOST_ADDRESS}{CHAT_ASSISTANT_API_URL}"
res = requests.post(url=url, headers=HEADERS, auth=auth, json=payload)
return res.json()
def list_chat_assistants(auth, params=None):
url = f"{HOST_ADDRESS}{CHAT_ASSISTANT_API_URL}"
res = requests.get(url=url, headers=HEADERS, auth=auth, params=params)
return res.json()
def update_chat_assistant(auth, chat_assistant_id, payload=None):
url = f"{HOST_ADDRESS}{CHAT_ASSISTANT_API_URL}/{chat_assistant_id}"
res = requests.put(url=url, headers=HEADERS, auth=auth, json=payload)
return res.json()
def delete_chat_assistants(auth, payload=None):
url = f"{HOST_ADDRESS}{CHAT_ASSISTANT_API_URL}"
res = requests.delete(url=url, headers=HEADERS, auth=auth, json=payload)
return res.json()
def batch_create_chat_assistants(auth, num):
chat_assistant_ids = []
for i in range(num):
res = create_chat_assistant(auth, {"name": f"test_chat_assistant_{i}", "dataset_ids": []})
chat_assistant_ids.append(res["data"]["id"])
return chat_assistant_ids
# SESSION MANAGEMENT
def create_session_with_chat_assistant(auth, chat_assistant_id, payload=None):
url = f"{HOST_ADDRESS}{SESSION_WITH_CHAT_ASSISTANT_API_URL}".format(chat_id=chat_assistant_id)
res = requests.post(url=url, headers=HEADERS, auth=auth, json=payload)
return res.json()
def list_session_with_chat_assistants(auth, chat_assistant_id, params=None):
url = f"{HOST_ADDRESS}{SESSION_WITH_CHAT_ASSISTANT_API_URL}".format(chat_id=chat_assistant_id)
res = requests.get(url=url, headers=HEADERS, auth=auth, params=params)
return res.json()
def update_session_with_chat_assistant(auth, chat_assistant_id, session_id, payload=None):
url = f"{HOST_ADDRESS}{SESSION_WITH_CHAT_ASSISTANT_API_URL}/{session_id}".format(chat_id=chat_assistant_id)
res = requests.put(url=url, headers=HEADERS, auth=auth, json=payload)
return res.json()
def delete_session_with_chat_assistants(auth, chat_assistant_id, payload=None):
url = f"{HOST_ADDRESS}{SESSION_WITH_CHAT_ASSISTANT_API_URL}".format(chat_id=chat_assistant_id)
res = requests.delete(url=url, headers=HEADERS, auth=auth, json=payload)
return res.json()
def batch_add_sessions_with_chat_assistant(auth, chat_assistant_id, num):
session_ids = []
for i in range(num):
res = create_session_with_chat_assistant(auth, chat_assistant_id, {"name": f"session_with_chat_assistant_{i}"})
session_ids.append(res["data"]["id"])
return session_ids

View File

@ -0,0 +1,177 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import pytest
from common import (
add_chunk,
batch_create_datasets,
bulk_upload_documents,
create_chat_assistant,
delete_chat_assistants,
delete_datasets,
delete_session_with_chat_assistants,
list_documnets,
parse_documnets,
)
from libs.auth import RAGFlowHttpApiAuth
from utils import wait_for
from utils.file_utils import (
create_docx_file,
create_eml_file,
create_excel_file,
create_html_file,
create_image_file,
create_json_file,
create_md_file,
create_pdf_file,
create_ppt_file,
create_txt_file,
)
@wait_for(30, 1, "Document parsing timeout")
def condition(_auth, _dataset_id):
res = list_documnets(_auth, _dataset_id)
for doc in res["data"]["docs"]:
if doc["run"] != "DONE":
return False
return True
@pytest.fixture(scope="session")
def api_key(token):
return RAGFlowHttpApiAuth(token)
@pytest.fixture(scope="function")
def clear_datasets(request, api_key):
def cleanup():
delete_datasets(api_key, {"ids": None})
request.addfinalizer(cleanup)
@pytest.fixture(scope="function")
def clear_chat_assistants(request, api_key):
def cleanup():
delete_chat_assistants(api_key)
request.addfinalizer(cleanup)
@pytest.fixture(scope="function")
def clear_session_with_chat_assistants(request, api_key, add_chat_assistants):
_, _, chat_assistant_ids = add_chat_assistants
def cleanup():
for chat_assistant_id in chat_assistant_ids:
delete_session_with_chat_assistants(api_key, chat_assistant_id)
request.addfinalizer(cleanup)
@pytest.fixture
def generate_test_files(request, tmp_path):
file_creators = {
"docx": (tmp_path / "ragflow_test.docx", create_docx_file),
"excel": (tmp_path / "ragflow_test.xlsx", create_excel_file),
"ppt": (tmp_path / "ragflow_test.pptx", create_ppt_file),
"image": (tmp_path / "ragflow_test.png", create_image_file),
"pdf": (tmp_path / "ragflow_test.pdf", create_pdf_file),
"txt": (tmp_path / "ragflow_test.txt", create_txt_file),
"md": (tmp_path / "ragflow_test.md", create_md_file),
"json": (tmp_path / "ragflow_test.json", create_json_file),
"eml": (tmp_path / "ragflow_test.eml", create_eml_file),
"html": (tmp_path / "ragflow_test.html", create_html_file),
}
files = {}
for file_type, (file_path, creator_func) in file_creators.items():
if request.param in ["", file_type]:
creator_func(file_path)
files[file_type] = file_path
return files
@pytest.fixture(scope="class")
def ragflow_tmp_dir(request, tmp_path_factory):
class_name = request.cls.__name__
return tmp_path_factory.mktemp(class_name)
@pytest.fixture(scope="class")
def add_dataset(request, api_key):
def cleanup():
delete_datasets(api_key, {"ids": None})
request.addfinalizer(cleanup)
dataset_ids = batch_create_datasets(api_key, 1)
return dataset_ids[0]
@pytest.fixture(scope="function")
def add_dataset_func(request, api_key):
def cleanup():
delete_datasets(api_key, {"ids": None})
request.addfinalizer(cleanup)
return batch_create_datasets(api_key, 1)[0]
@pytest.fixture(scope="class")
def add_document(api_key, add_dataset, ragflow_tmp_dir):
dataset_id = add_dataset
document_ids = bulk_upload_documents(api_key, dataset_id, 1, ragflow_tmp_dir)
return dataset_id, document_ids[0]
@pytest.fixture(scope="class")
def add_chunks(api_key, add_document):
dataset_id, document_id = add_document
parse_documnets(api_key, dataset_id, {"document_ids": [document_id]})
condition(api_key, dataset_id)
chunk_ids = []
for i in range(4):
res = add_chunk(api_key, dataset_id, document_id, {"content": f"chunk test {i}"})
chunk_ids.append(res["data"]["chunk"]["id"])
# issues/6487
from time import sleep
sleep(1)
return dataset_id, document_id, chunk_ids
@pytest.fixture(scope="class")
def add_chat_assistants(request, api_key, add_document):
def cleanup():
delete_chat_assistants(api_key)
request.addfinalizer(cleanup)
dataset_id, document_id = add_document
parse_documnets(api_key, dataset_id, {"document_ids": [document_id]})
condition(api_key, dataset_id)
chat_assistant_ids = []
for i in range(5):
res = create_chat_assistant(api_key, {"name": f"test_chat_assistant_{i}", "dataset_ids": [dataset_id]})
chat_assistant_ids.append(res["data"]["id"])
return dataset_id, document_id, chat_assistant_ids

View File

@ -0,0 +1,46 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import pytest
from common import create_chat_assistant, delete_chat_assistants, list_documnets, parse_documnets
from utils import wait_for
@wait_for(30, 1, "Document parsing timeout")
def condition(_auth, _dataset_id):
res = list_documnets(_auth, _dataset_id)
for doc in res["data"]["docs"]:
if doc["run"] != "DONE":
return False
return True
@pytest.fixture(scope="function")
def add_chat_assistants_func(request, api_key, add_document):
def cleanup():
delete_chat_assistants(api_key)
request.addfinalizer(cleanup)
dataset_id, document_id = add_document
parse_documnets(api_key, dataset_id, {"document_ids": [document_id]})
condition(api_key, dataset_id)
chat_assistant_ids = []
for i in range(5):
res = create_chat_assistant(api_key, {"name": f"test_chat_assistant_{i}", "dataset_ids": [dataset_id]})
chat_assistant_ids.append(res["data"]["id"])
return dataset_id, document_id, chat_assistant_ids

View File

@ -0,0 +1,241 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import pytest
from common import CHAT_ASSISTANT_NAME_LIMIT, INVALID_API_TOKEN, create_chat_assistant
from libs.auth import RAGFlowHttpApiAuth
from utils import encode_avatar
from utils.file_utils import create_image_file
@pytest.mark.p1
class TestAuthorization:
@pytest.mark.parametrize(
"invalid_auth, expected_code, expected_message",
[
(None, 0, "`Authorization` can't be empty"),
(
RAGFlowHttpApiAuth(INVALID_API_TOKEN),
109,
"Authentication error: API key is invalid!",
),
],
)
def test_invalid_auth(self, invalid_auth, expected_code, expected_message):
res = create_chat_assistant(invalid_auth)
assert res["code"] == expected_code
assert res["message"] == expected_message
@pytest.mark.usefixtures("clear_chat_assistants")
class TestChatAssistantCreate:
@pytest.mark.p1
@pytest.mark.parametrize(
"payload, expected_code, expected_message",
[
({"name": "valid_name"}, 0, ""),
pytest.param({"name": "a" * (CHAT_ASSISTANT_NAME_LIMIT + 1)}, 102, "", marks=pytest.mark.skip(reason="issues/")),
pytest.param({"name": 1}, 100, "", marks=pytest.mark.skip(reason="issues/")),
({"name": ""}, 102, "`name` is required."),
({"name": "duplicated_name"}, 102, "Duplicated chat name in creating chat."),
({"name": "case insensitive"}, 102, "Duplicated chat name in creating chat."),
],
)
def test_name(self, api_key, add_chunks, payload, expected_code, expected_message):
payload["dataset_ids"] = [] # issues/
if payload["name"] == "duplicated_name":
create_chat_assistant(api_key, payload)
elif payload["name"] == "case insensitive":
create_chat_assistant(api_key, {"name": payload["name"].upper()})
res = create_chat_assistant(api_key, payload)
assert res["code"] == expected_code, res
if expected_code == 0:
assert res["data"]["name"] == payload["name"]
else:
assert res["message"] == expected_message
@pytest.mark.p1
@pytest.mark.parametrize(
"dataset_ids, expected_code, expected_message",
[
([], 0, ""),
(lambda r: [r], 0, ""),
(["invalid_dataset_id"], 102, "You don't own the dataset invalid_dataset_id"),
("invalid_dataset_id", 102, "You don't own the dataset i"),
],
)
def test_dataset_ids(self, api_key, add_chunks, dataset_ids, expected_code, expected_message):
dataset_id, _, _ = add_chunks
payload = {"name": "ragflow test"}
if callable(dataset_ids):
payload["dataset_ids"] = dataset_ids(dataset_id)
else:
payload["dataset_ids"] = dataset_ids
res = create_chat_assistant(api_key, payload)
assert res["code"] == expected_code, res
if expected_code == 0:
assert res["data"]["name"] == payload["name"]
else:
assert res["message"] == expected_message
@pytest.mark.p3
def test_avatar(self, api_key, tmp_path):
fn = create_image_file(tmp_path / "ragflow_test.png")
payload = {"name": "avatar_test", "avatar": encode_avatar(fn), "dataset_ids": []}
res = create_chat_assistant(api_key, payload)
assert res["code"] == 0
@pytest.mark.p2
@pytest.mark.parametrize(
"llm, expected_code, expected_message",
[
({}, 0, ""),
({"model_name": "glm-4"}, 0, ""),
({"model_name": "unknown"}, 102, "`model_name` unknown doesn't exist"),
({"temperature": 0}, 0, ""),
({"temperature": 1}, 0, ""),
pytest.param({"temperature": -1}, 0, "", marks=pytest.mark.skip),
pytest.param({"temperature": 10}, 0, "", marks=pytest.mark.skip),
pytest.param({"temperature": "a"}, 0, "", marks=pytest.mark.skip),
({"top_p": 0}, 0, ""),
({"top_p": 1}, 0, ""),
pytest.param({"top_p": -1}, 0, "", marks=pytest.mark.skip),
pytest.param({"top_p": 10}, 0, "", marks=pytest.mark.skip),
pytest.param({"top_p": "a"}, 0, "", marks=pytest.mark.skip),
({"presence_penalty": 0}, 0, ""),
({"presence_penalty": 1}, 0, ""),
pytest.param({"presence_penalty": -1}, 0, "", marks=pytest.mark.skip),
pytest.param({"presence_penalty": 10}, 0, "", marks=pytest.mark.skip),
pytest.param({"presence_penalty": "a"}, 0, "", marks=pytest.mark.skip),
({"frequency_penalty": 0}, 0, ""),
({"frequency_penalty": 1}, 0, ""),
pytest.param({"frequency_penalty": -1}, 0, "", marks=pytest.mark.skip),
pytest.param({"frequency_penalty": 10}, 0, "", marks=pytest.mark.skip),
pytest.param({"frequency_penalty": "a"}, 0, "", marks=pytest.mark.skip),
({"max_token": 0}, 0, ""),
({"max_token": 1024}, 0, ""),
pytest.param({"max_token": -1}, 0, "", marks=pytest.mark.skip),
pytest.param({"max_token": 10}, 0, "", marks=pytest.mark.skip),
pytest.param({"max_token": "a"}, 0, "", marks=pytest.mark.skip),
pytest.param({"unknown": "unknown"}, 0, "", marks=pytest.mark.skip),
],
)
def test_llm(self, api_key, add_chunks, llm, expected_code, expected_message):
dataset_id, _, _ = add_chunks
payload = {"name": "llm_test", "dataset_ids": [dataset_id], "llm": llm}
res = create_chat_assistant(api_key, payload)
assert res["code"] == expected_code
if expected_code == 0:
if llm:
for k, v in llm.items():
assert res["data"]["llm"][k] == v
else:
assert res["data"]["llm"]["model_name"] == "glm-4-flash@ZHIPU-AI"
assert res["data"]["llm"]["temperature"] == 0.1
assert res["data"]["llm"]["top_p"] == 0.3
assert res["data"]["llm"]["presence_penalty"] == 0.4
assert res["data"]["llm"]["frequency_penalty"] == 0.7
assert res["data"]["llm"]["max_tokens"] == 512
else:
assert res["message"] == expected_message
@pytest.mark.p2
@pytest.mark.parametrize(
"prompt, expected_code, expected_message",
[
({}, 0, ""),
({"similarity_threshold": 0}, 0, ""),
({"similarity_threshold": 1}, 0, ""),
pytest.param({"similarity_threshold": -1}, 0, "", marks=pytest.mark.skip),
pytest.param({"similarity_threshold": 10}, 0, "", marks=pytest.mark.skip),
pytest.param({"similarity_threshold": "a"}, 0, "", marks=pytest.mark.skip),
({"keywords_similarity_weight": 0}, 0, ""),
({"keywords_similarity_weight": 1}, 0, ""),
pytest.param({"keywords_similarity_weight": -1}, 0, "", marks=pytest.mark.skip),
pytest.param({"keywords_similarity_weight": 10}, 0, "", marks=pytest.mark.skip),
pytest.param({"keywords_similarity_weight": "a"}, 0, "", marks=pytest.mark.skip),
({"variables": []}, 0, ""),
({"top_n": 0}, 0, ""),
({"top_n": 1}, 0, ""),
pytest.param({"top_n": -1}, 0, "", marks=pytest.mark.skip),
pytest.param({"top_n": 10}, 0, "", marks=pytest.mark.skip),
pytest.param({"top_n": "a"}, 0, "", marks=pytest.mark.skip),
({"empty_response": "Hello World"}, 0, ""),
({"empty_response": ""}, 0, ""),
({"empty_response": "!@#$%^&*()"}, 0, ""),
({"empty_response": "中文测试"}, 0, ""),
pytest.param({"empty_response": 123}, 0, "", marks=pytest.mark.skip),
pytest.param({"empty_response": True}, 0, "", marks=pytest.mark.skip),
pytest.param({"empty_response": " "}, 0, "", marks=pytest.mark.skip),
({"opener": "Hello World"}, 0, ""),
({"opener": ""}, 0, ""),
({"opener": "!@#$%^&*()"}, 0, ""),
({"opener": "中文测试"}, 0, ""),
pytest.param({"opener": 123}, 0, "", marks=pytest.mark.skip),
pytest.param({"opener": True}, 0, "", marks=pytest.mark.skip),
pytest.param({"opener": " "}, 0, "", marks=pytest.mark.skip),
({"show_quote": True}, 0, ""),
({"show_quote": False}, 0, ""),
({"prompt": "Hello World {knowledge}"}, 0, ""),
({"prompt": "{knowledge}"}, 0, ""),
({"prompt": "!@#$%^&*() {knowledge}"}, 0, ""),
({"prompt": "中文测试 {knowledge}"}, 0, ""),
({"prompt": "Hello World"}, 102, "Parameter 'knowledge' is not used"),
({"prompt": "Hello World", "variables": []}, 0, ""),
pytest.param({"prompt": 123}, 100, """AttributeError("\'int\' object has no attribute \'find\'")""", marks=pytest.mark.skip),
pytest.param({"prompt": True}, 100, """AttributeError("\'int\' object has no attribute \'find\'")""", marks=pytest.mark.skip),
pytest.param({"unknown": "unknown"}, 0, "", marks=pytest.mark.skip),
],
)
def test_prompt(self, api_key, add_chunks, prompt, expected_code, expected_message):
dataset_id, _, _ = add_chunks
payload = {"name": "prompt_test", "dataset_ids": [dataset_id], "prompt": prompt}
res = create_chat_assistant(api_key, payload)
assert res["code"] == expected_code
if expected_code == 0:
if prompt:
for k, v in prompt.items():
if k == "keywords_similarity_weight":
assert res["data"]["prompt"][k] == 1 - v
else:
assert res["data"]["prompt"][k] == v
else:
assert res["data"]["prompt"]["similarity_threshold"] == 0.2
assert res["data"]["prompt"]["keywords_similarity_weight"] == 0.7
assert res["data"]["prompt"]["top_n"] == 6
assert res["data"]["prompt"]["variables"] == [{"key": "knowledge", "optional": False}]
assert res["data"]["prompt"]["rerank_model"] == ""
assert res["data"]["prompt"]["empty_response"] == "Sorry! No relevant content was found in the knowledge base!"
assert res["data"]["prompt"]["opener"] == "Hi! I'm your assistant, what can I do for you?"
assert res["data"]["prompt"]["show_quote"] is True
assert (
res["data"]["prompt"]["prompt"]
== 'You are an intelligent assistant. Please summarize the content of the knowledge base to answer the question. Please list the data in the knowledge base and answer in detail. When all knowledge base content is irrelevant to the question, your answer must include the sentence "The answer you are looking for is not found in the knowledge base!" Answers need to consider chat history.\n Here is the knowledge base:\n {knowledge}\n The above is the knowledge base.'
)
else:
assert res["message"] == expected_message
class TestChatAssistantCreate2:
@pytest.mark.p2
def test_unparsed_document(self, api_key, add_document):
dataset_id, _ = add_document
payload = {"name": "prompt_test", "dataset_ids": [dataset_id]}
res = create_chat_assistant(api_key, payload)
assert res["code"] == 102
assert "doesn't own parsed file" in res["message"]

View File

@ -0,0 +1,124 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from concurrent.futures import ThreadPoolExecutor
import pytest
from common import INVALID_API_TOKEN, batch_create_chat_assistants, delete_chat_assistants, list_chat_assistants
from libs.auth import RAGFlowHttpApiAuth
@pytest.mark.p1
class TestAuthorization:
@pytest.mark.parametrize(
"invalid_auth, expected_code, expected_message",
[
(None, 0, "`Authorization` can't be empty"),
(
RAGFlowHttpApiAuth(INVALID_API_TOKEN),
109,
"Authentication error: API key is invalid!",
),
],
)
def test_invalid_auth(self, invalid_auth, expected_code, expected_message):
res = delete_chat_assistants(invalid_auth)
assert res["code"] == expected_code
assert res["message"] == expected_message
class TestChatAssistantsDelete:
@pytest.mark.parametrize(
"payload, expected_code, expected_message, remaining",
[
pytest.param(None, 0, "", 0, marks=pytest.mark.p3),
pytest.param({"ids": []}, 0, "", 0, marks=pytest.mark.p3),
pytest.param({"ids": ["invalid_id"]}, 102, "Assistant(invalid_id) not found.", 5, marks=pytest.mark.p3),
pytest.param({"ids": ["\n!?。;!?\"'"]}, 102, """Assistant(\n!?。;!?"\') not found.""", 5, marks=pytest.mark.p3),
pytest.param("not json", 100, "AttributeError(\"'str' object has no attribute 'get'\")", 5, marks=pytest.mark.p3),
pytest.param(lambda r: {"ids": r[:1]}, 0, "", 4, marks=pytest.mark.p3),
pytest.param(lambda r: {"ids": r}, 0, "", 0, marks=pytest.mark.p1),
],
)
def test_basic_scenarios(self, api_key, add_chat_assistants_func, payload, expected_code, expected_message, remaining):
_, _, chat_assistant_ids = add_chat_assistants_func
if callable(payload):
payload = payload(chat_assistant_ids)
res = delete_chat_assistants(api_key, payload)
assert res["code"] == expected_code
if res["code"] != 0:
assert res["message"] == expected_message
res = list_chat_assistants(api_key)
assert len(res["data"]) == remaining
@pytest.mark.parametrize(
"payload",
[
pytest.param(lambda r: {"ids": ["invalid_id"] + r}, marks=pytest.mark.p3),
pytest.param(lambda r: {"ids": r[:1] + ["invalid_id"] + r[1:5]}, marks=pytest.mark.p1),
pytest.param(lambda r: {"ids": r + ["invalid_id"]}, marks=pytest.mark.p3),
],
)
def test_delete_partial_invalid_id(self, api_key, add_chat_assistants_func, payload):
_, _, chat_assistant_ids = add_chat_assistants_func
if callable(payload):
payload = payload(chat_assistant_ids)
res = delete_chat_assistants(api_key, payload)
assert res["code"] == 0
assert res["data"]["errors"][0] == "Assistant(invalid_id) not found."
assert res["data"]["success_count"] == 5
res = list_chat_assistants(api_key)
assert len(res["data"]) == 0
@pytest.mark.p3
def test_repeated_deletion(self, api_key, add_chat_assistants_func):
_, _, chat_assistant_ids = add_chat_assistants_func
res = delete_chat_assistants(api_key, {"ids": chat_assistant_ids})
assert res["code"] == 0
res = delete_chat_assistants(api_key, {"ids": chat_assistant_ids})
assert res["code"] == 102
assert "not found" in res["message"]
@pytest.mark.p3
def test_duplicate_deletion(self, api_key, add_chat_assistants_func):
_, _, chat_assistant_ids = add_chat_assistants_func
res = delete_chat_assistants(api_key, {"ids": chat_assistant_ids + chat_assistant_ids})
assert res["code"] == 0
assert "Duplicate assistant ids" in res["data"]["errors"][0]
assert res["data"]["success_count"] == 5
res = list_chat_assistants(api_key)
assert res["code"] == 0
@pytest.mark.p3
def test_concurrent_deletion(self, api_key):
ids = batch_create_chat_assistants(api_key, 100)
with ThreadPoolExecutor(max_workers=5) as executor:
futures = [executor.submit(delete_chat_assistants, api_key, {"ids": ids[i : i + 1]}) for i in range(100)]
responses = [f.result() for f in futures]
assert all(r["code"] == 0 for r in responses)
@pytest.mark.p3
def test_delete_10k(self, api_key):
ids = batch_create_chat_assistants(api_key, 10_000)
res = delete_chat_assistants(api_key, {"ids": ids})
assert res["code"] == 0
res = list_chat_assistants(api_key)
assert len(res["data"]) == 0

View File

@ -0,0 +1,311 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from concurrent.futures import ThreadPoolExecutor
import pytest
from common import INVALID_API_TOKEN, delete_datasets, list_chat_assistants
from libs.auth import RAGFlowHttpApiAuth
from utils import is_sorted
@pytest.mark.p1
class TestAuthorization:
@pytest.mark.parametrize(
"invalid_auth, expected_code, expected_message",
[
(None, 0, "`Authorization` can't be empty"),
(
RAGFlowHttpApiAuth(INVALID_API_TOKEN),
109,
"Authentication error: API key is invalid!",
),
],
)
def test_invalid_auth(self, invalid_auth, expected_code, expected_message):
res = list_chat_assistants(invalid_auth)
assert res["code"] == expected_code
assert res["message"] == expected_message
@pytest.mark.usefixtures("add_chat_assistants")
class TestChatAssistantsList:
@pytest.mark.p1
def test_default(self, api_key):
res = list_chat_assistants(api_key)
assert res["code"] == 0
assert len(res["data"]) == 5
@pytest.mark.p1
@pytest.mark.parametrize(
"params, expected_code, expected_page_size, expected_message",
[
({"page": None, "page_size": 2}, 0, 2, ""),
({"page": 0, "page_size": 2}, 0, 2, ""),
({"page": 2, "page_size": 2}, 0, 2, ""),
({"page": 3, "page_size": 2}, 0, 1, ""),
({"page": "3", "page_size": 2}, 0, 1, ""),
pytest.param(
{"page": -1, "page_size": 2},
100,
0,
"1064",
marks=pytest.mark.skip(reason="issues/5851"),
),
pytest.param(
{"page": "a", "page_size": 2},
100,
0,
"""ValueError("invalid literal for int() with base 10: \'a\'")""",
marks=pytest.mark.skip(reason="issues/5851"),
),
],
)
def test_page(self, api_key, params, expected_code, expected_page_size, expected_message):
res = list_chat_assistants(api_key, params=params)
assert res["code"] == expected_code
if expected_code == 0:
assert len(res["data"]) == expected_page_size
else:
assert res["message"] == expected_message
@pytest.mark.p1
@pytest.mark.parametrize(
"params, expected_code, expected_page_size, expected_message",
[
({"page_size": None}, 0, 5, ""),
({"page_size": 0}, 0, 0, ""),
({"page_size": 1}, 0, 1, ""),
({"page_size": 6}, 0, 5, ""),
({"page_size": "1"}, 0, 1, ""),
pytest.param(
{"page_size": -1},
100,
0,
"1064",
marks=pytest.mark.skip(reason="issues/5851"),
),
pytest.param(
{"page_size": "a"},
100,
0,
"""ValueError("invalid literal for int() with base 10: \'a\'")""",
marks=pytest.mark.skip(reason="issues/5851"),
),
],
)
def test_page_size(
self,
api_key,
params,
expected_code,
expected_page_size,
expected_message,
):
res = list_chat_assistants(api_key, params=params)
assert res["code"] == expected_code
if expected_code == 0:
assert len(res["data"]) == expected_page_size
else:
assert res["message"] == expected_message
@pytest.mark.p3
@pytest.mark.parametrize(
"params, expected_code, assertions, expected_message",
[
({"orderby": None}, 0, lambda r: (is_sorted(r["data"], "create_time", True)), ""),
({"orderby": "create_time"}, 0, lambda r: (is_sorted(r["data"], "create_time", True)), ""),
({"orderby": "update_time"}, 0, lambda r: (is_sorted(r["data"], "update_time", True)), ""),
pytest.param(
{"orderby": "name", "desc": "False"},
0,
lambda r: (is_sorted(r["data"], "name", False)),
"",
marks=pytest.mark.skip(reason="issues/5851"),
),
pytest.param(
{"orderby": "unknown"},
102,
0,
"orderby should be create_time or update_time",
marks=pytest.mark.skip(reason="issues/5851"),
),
],
)
def test_orderby(
self,
api_key,
params,
expected_code,
assertions,
expected_message,
):
res = list_chat_assistants(api_key, params=params)
assert res["code"] == expected_code
if expected_code == 0:
if callable(assertions):
assert assertions(res)
else:
assert res["message"] == expected_message
@pytest.mark.p3
@pytest.mark.parametrize(
"params, expected_code, assertions, expected_message",
[
({"desc": None}, 0, lambda r: (is_sorted(r["data"], "create_time", True)), ""),
({"desc": "true"}, 0, lambda r: (is_sorted(r["data"], "create_time", True)), ""),
({"desc": "True"}, 0, lambda r: (is_sorted(r["data"], "create_time", True)), ""),
({"desc": True}, 0, lambda r: (is_sorted(r["data"], "create_time", True)), ""),
({"desc": "false"}, 0, lambda r: (is_sorted(r["data"], "create_time", False)), ""),
({"desc": "False"}, 0, lambda r: (is_sorted(r["data"], "create_time", False)), ""),
({"desc": False}, 0, lambda r: (is_sorted(r["data"], "create_time", False)), ""),
({"desc": "False", "orderby": "update_time"}, 0, lambda r: (is_sorted(r["data"], "update_time", False)), ""),
pytest.param(
{"desc": "unknown"},
102,
0,
"desc should be true or false",
marks=pytest.mark.skip(reason="issues/5851"),
),
],
)
def test_desc(
self,
api_key,
params,
expected_code,
assertions,
expected_message,
):
res = list_chat_assistants(api_key, params=params)
assert res["code"] == expected_code
if expected_code == 0:
if callable(assertions):
assert assertions(res)
else:
assert res["message"] == expected_message
@pytest.mark.p1
@pytest.mark.parametrize(
"params, expected_code, expected_num, expected_message",
[
({"name": None}, 0, 5, ""),
({"name": ""}, 0, 5, ""),
({"name": "test_chat_assistant_1"}, 0, 1, ""),
({"name": "unknown"}, 102, 0, "The chat doesn't exist"),
],
)
def test_name(self, api_key, params, expected_code, expected_num, expected_message):
res = list_chat_assistants(api_key, params=params)
assert res["code"] == expected_code
if expected_code == 0:
if params["name"] in [None, ""]:
assert len(res["data"]) == expected_num
else:
assert res["data"][0]["name"] == params["name"]
else:
assert res["message"] == expected_message
@pytest.mark.p1
@pytest.mark.parametrize(
"chat_assistant_id, expected_code, expected_num, expected_message",
[
(None, 0, 5, ""),
("", 0, 5, ""),
(lambda r: r[0], 0, 1, ""),
("unknown", 102, 0, "The chat doesn't exist"),
],
)
def test_id(
self,
api_key,
add_chat_assistants,
chat_assistant_id,
expected_code,
expected_num,
expected_message,
):
_, _, chat_assistant_ids = add_chat_assistants
if callable(chat_assistant_id):
params = {"id": chat_assistant_id(chat_assistant_ids)}
else:
params = {"id": chat_assistant_id}
res = list_chat_assistants(api_key, params=params)
assert res["code"] == expected_code
if expected_code == 0:
if params["id"] in [None, ""]:
assert len(res["data"]) == expected_num
else:
assert res["data"][0]["id"] == params["id"]
else:
assert res["message"] == expected_message
@pytest.mark.p3
@pytest.mark.parametrize(
"chat_assistant_id, name, expected_code, expected_num, expected_message",
[
(lambda r: r[0], "test_chat_assistant_0", 0, 1, ""),
(lambda r: r[0], "test_chat_assistant_1", 102, 0, "The chat doesn't exist"),
(lambda r: r[0], "unknown", 102, 0, "The chat doesn't exist"),
("id", "chat_assistant_0", 102, 0, "The chat doesn't exist"),
],
)
def test_name_and_id(
self,
api_key,
add_chat_assistants,
chat_assistant_id,
name,
expected_code,
expected_num,
expected_message,
):
_, _, chat_assistant_ids = add_chat_assistants
if callable(chat_assistant_id):
params = {"id": chat_assistant_id(chat_assistant_ids), "name": name}
else:
params = {"id": chat_assistant_id, "name": name}
res = list_chat_assistants(api_key, params=params)
assert res["code"] == expected_code
if expected_code == 0:
assert len(res["data"]) == expected_num
else:
assert res["message"] == expected_message
@pytest.mark.p3
def test_concurrent_list(self, api_key):
with ThreadPoolExecutor(max_workers=5) as executor:
futures = [executor.submit(list_chat_assistants, api_key) for i in range(100)]
responses = [f.result() for f in futures]
assert all(r["code"] == 0 for r in responses)
@pytest.mark.p3
def test_invalid_params(self, api_key):
params = {"a": "b"}
res = list_chat_assistants(api_key, params=params)
assert res["code"] == 0
assert len(res["data"]) == 5
@pytest.mark.p2
def test_list_chats_after_deleting_associated_dataset(self, api_key, add_chat_assistants):
dataset_id, _, _ = add_chat_assistants
res = delete_datasets(api_key, {"ids": [dataset_id]})
assert res["code"] == 0
res = list_chat_assistants(api_key)
assert res["code"] == 0
assert len(res["data"]) == 5

View File

@ -0,0 +1,228 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import pytest
from common import CHAT_ASSISTANT_NAME_LIMIT, INVALID_API_TOKEN, list_chat_assistants, update_chat_assistant
from libs.auth import RAGFlowHttpApiAuth
from utils import encode_avatar
from utils.file_utils import create_image_file
@pytest.mark.p1
class TestAuthorization:
@pytest.mark.parametrize(
"invalid_auth, expected_code, expected_message",
[
(None, 0, "`Authorization` can't be empty"),
(
RAGFlowHttpApiAuth(INVALID_API_TOKEN),
109,
"Authentication error: API key is invalid!",
),
],
)
def test_invalid_auth(self, invalid_auth, expected_code, expected_message):
res = update_chat_assistant(invalid_auth, "chat_assistant_id")
assert res["code"] == expected_code
assert res["message"] == expected_message
class TestChatAssistantUpdate:
@pytest.mark.parametrize(
"payload, expected_code, expected_message",
[
pytest.param({"name": "valid_name"}, 0, "", marks=pytest.mark.p1),
pytest.param({"name": "a" * (CHAT_ASSISTANT_NAME_LIMIT + 1)}, 102, "", marks=pytest.mark.skip(reason="issues/")),
pytest.param({"name": 1}, 100, "", marks=pytest.mark.skip(reason="issues/")),
pytest.param({"name": ""}, 102, "`name` cannot be empty.", marks=pytest.mark.p3),
pytest.param({"name": "test_chat_assistant_1"}, 102, "Duplicated chat name in updating chat.", marks=pytest.mark.p3),
pytest.param({"name": "TEST_CHAT_ASSISTANT_1"}, 102, "Duplicated chat name in updating chat.", marks=pytest.mark.p3),
],
)
def test_name(self, api_key, add_chat_assistants_func, payload, expected_code, expected_message):
_, _, chat_assistant_ids = add_chat_assistants_func
res = update_chat_assistant(api_key, chat_assistant_ids[0], payload)
assert res["code"] == expected_code, res
if expected_code == 0:
res = list_chat_assistants(api_key, {"id": chat_assistant_ids[0]})
assert res["data"][0]["name"] == payload.get("name")
else:
assert res["message"] == expected_message
@pytest.mark.parametrize(
"dataset_ids, expected_code, expected_message",
[
pytest.param([], 0, "", marks=pytest.mark.skip(reason="issues/")),
pytest.param(lambda r: [r], 0, "", marks=pytest.mark.p1),
pytest.param(["invalid_dataset_id"], 102, "You don't own the dataset invalid_dataset_id", marks=pytest.mark.p3),
pytest.param("invalid_dataset_id", 102, "You don't own the dataset i", marks=pytest.mark.p3),
],
)
def test_dataset_ids(self, api_key, add_chat_assistants_func, dataset_ids, expected_code, expected_message):
dataset_id, _, chat_assistant_ids = add_chat_assistants_func
payload = {"name": "ragflow test"}
if callable(dataset_ids):
payload["dataset_ids"] = dataset_ids(dataset_id)
else:
payload["dataset_ids"] = dataset_ids
res = update_chat_assistant(api_key, chat_assistant_ids[0], payload)
assert res["code"] == expected_code, res
if expected_code == 0:
res = list_chat_assistants(api_key, {"id": chat_assistant_ids[0]})
assert res["data"][0]["name"] == payload.get("name")
else:
assert res["message"] == expected_message
@pytest.mark.p3
def test_avatar(self, api_key, add_chat_assistants_func, tmp_path):
dataset_id, _, chat_assistant_ids = add_chat_assistants_func
fn = create_image_file(tmp_path / "ragflow_test.png")
payload = {"name": "avatar_test", "avatar": encode_avatar(fn), "dataset_ids": [dataset_id]}
res = update_chat_assistant(api_key, chat_assistant_ids[0], payload)
assert res["code"] == 0
@pytest.mark.p3
@pytest.mark.parametrize(
"llm, expected_code, expected_message",
[
({}, 100, "ValueError"),
({"model_name": "glm-4"}, 0, ""),
({"model_name": "unknown"}, 102, "`model_name` unknown doesn't exist"),
({"temperature": 0}, 0, ""),
({"temperature": 1}, 0, ""),
pytest.param({"temperature": -1}, 0, "", marks=pytest.mark.skip),
pytest.param({"temperature": 10}, 0, "", marks=pytest.mark.skip),
pytest.param({"temperature": "a"}, 0, "", marks=pytest.mark.skip),
({"top_p": 0}, 0, ""),
({"top_p": 1}, 0, ""),
pytest.param({"top_p": -1}, 0, "", marks=pytest.mark.skip),
pytest.param({"top_p": 10}, 0, "", marks=pytest.mark.skip),
pytest.param({"top_p": "a"}, 0, "", marks=pytest.mark.skip),
({"presence_penalty": 0}, 0, ""),
({"presence_penalty": 1}, 0, ""),
pytest.param({"presence_penalty": -1}, 0, "", marks=pytest.mark.skip),
pytest.param({"presence_penalty": 10}, 0, "", marks=pytest.mark.skip),
pytest.param({"presence_penalty": "a"}, 0, "", marks=pytest.mark.skip),
({"frequency_penalty": 0}, 0, ""),
({"frequency_penalty": 1}, 0, ""),
pytest.param({"frequency_penalty": -1}, 0, "", marks=pytest.mark.skip),
pytest.param({"frequency_penalty": 10}, 0, "", marks=pytest.mark.skip),
pytest.param({"frequency_penalty": "a"}, 0, "", marks=pytest.mark.skip),
({"max_token": 0}, 0, ""),
({"max_token": 1024}, 0, ""),
pytest.param({"max_token": -1}, 0, "", marks=pytest.mark.skip),
pytest.param({"max_token": 10}, 0, "", marks=pytest.mark.skip),
pytest.param({"max_token": "a"}, 0, "", marks=pytest.mark.skip),
pytest.param({"unknown": "unknown"}, 0, "", marks=pytest.mark.skip),
],
)
def test_llm(self, api_key, add_chat_assistants_func, llm, expected_code, expected_message):
dataset_id, _, chat_assistant_ids = add_chat_assistants_func
payload = {"name": "llm_test", "dataset_ids": [dataset_id], "llm": llm}
res = update_chat_assistant(api_key, chat_assistant_ids[0], payload)
assert res["code"] == expected_code
if expected_code == 0:
res = list_chat_assistants(api_key, {"id": chat_assistant_ids[0]})
if llm:
for k, v in llm.items():
assert res["data"][0]["llm"][k] == v
else:
assert res["data"][0]["llm"]["model_name"] == "glm-4-flash@ZHIPU-AI"
assert res["data"][0]["llm"]["temperature"] == 0.1
assert res["data"][0]["llm"]["top_p"] == 0.3
assert res["data"][0]["llm"]["presence_penalty"] == 0.4
assert res["data"][0]["llm"]["frequency_penalty"] == 0.7
assert res["data"][0]["llm"]["max_tokens"] == 512
else:
assert expected_message in res["message"]
@pytest.mark.p3
@pytest.mark.parametrize(
"prompt, expected_code, expected_message",
[
({}, 100, "ValueError"),
({"similarity_threshold": 0}, 0, ""),
({"similarity_threshold": 1}, 0, ""),
pytest.param({"similarity_threshold": -1}, 0, "", marks=pytest.mark.skip),
pytest.param({"similarity_threshold": 10}, 0, "", marks=pytest.mark.skip),
pytest.param({"similarity_threshold": "a"}, 0, "", marks=pytest.mark.skip),
({"keywords_similarity_weight": 0}, 0, ""),
({"keywords_similarity_weight": 1}, 0, ""),
pytest.param({"keywords_similarity_weight": -1}, 0, "", marks=pytest.mark.skip),
pytest.param({"keywords_similarity_weight": 10}, 0, "", marks=pytest.mark.skip),
pytest.param({"keywords_similarity_weight": "a"}, 0, "", marks=pytest.mark.skip),
({"variables": []}, 0, ""),
({"top_n": 0}, 0, ""),
({"top_n": 1}, 0, ""),
pytest.param({"top_n": -1}, 0, "", marks=pytest.mark.skip),
pytest.param({"top_n": 10}, 0, "", marks=pytest.mark.skip),
pytest.param({"top_n": "a"}, 0, "", marks=pytest.mark.skip),
({"empty_response": "Hello World"}, 0, ""),
({"empty_response": ""}, 0, ""),
({"empty_response": "!@#$%^&*()"}, 0, ""),
({"empty_response": "中文测试"}, 0, ""),
pytest.param({"empty_response": 123}, 0, "", marks=pytest.mark.skip),
pytest.param({"empty_response": True}, 0, "", marks=pytest.mark.skip),
pytest.param({"empty_response": " "}, 0, "", marks=pytest.mark.skip),
({"opener": "Hello World"}, 0, ""),
({"opener": ""}, 0, ""),
({"opener": "!@#$%^&*()"}, 0, ""),
({"opener": "中文测试"}, 0, ""),
pytest.param({"opener": 123}, 0, "", marks=pytest.mark.skip),
pytest.param({"opener": True}, 0, "", marks=pytest.mark.skip),
pytest.param({"opener": " "}, 0, "", marks=pytest.mark.skip),
({"show_quote": True}, 0, ""),
({"show_quote": False}, 0, ""),
({"prompt": "Hello World {knowledge}"}, 0, ""),
({"prompt": "{knowledge}"}, 0, ""),
({"prompt": "!@#$%^&*() {knowledge}"}, 0, ""),
({"prompt": "中文测试 {knowledge}"}, 0, ""),
({"prompt": "Hello World"}, 102, "Parameter 'knowledge' is not used"),
({"prompt": "Hello World", "variables": []}, 0, ""),
pytest.param({"prompt": 123}, 100, """AttributeError("\'int\' object has no attribute \'find\'")""", marks=pytest.mark.skip),
pytest.param({"prompt": True}, 100, """AttributeError("\'int\' object has no attribute \'find\'")""", marks=pytest.mark.skip),
pytest.param({"unknown": "unknown"}, 0, "", marks=pytest.mark.skip),
],
)
def test_prompt(self, api_key, add_chat_assistants_func, prompt, expected_code, expected_message):
dataset_id, _, chat_assistant_ids = add_chat_assistants_func
payload = {"name": "prompt_test", "dataset_ids": [dataset_id], "prompt": prompt}
res = update_chat_assistant(api_key, chat_assistant_ids[0], payload)
assert res["code"] == expected_code
if expected_code == 0:
res = list_chat_assistants(api_key, {"id": chat_assistant_ids[0]})
if prompt:
for k, v in prompt.items():
if k == "keywords_similarity_weight":
assert res["data"][0]["prompt"][k] == 1 - v
else:
assert res["data"][0]["prompt"][k] == v
else:
assert res["data"]["prompt"][0]["similarity_threshold"] == 0.2
assert res["data"]["prompt"][0]["keywords_similarity_weight"] == 0.7
assert res["data"]["prompt"][0]["top_n"] == 6
assert res["data"]["prompt"][0]["variables"] == [{"key": "knowledge", "optional": False}]
assert res["data"]["prompt"][0]["rerank_model"] == ""
assert res["data"]["prompt"][0]["empty_response"] == "Sorry! No relevant content was found in the knowledge base!"
assert res["data"]["prompt"][0]["opener"] == "Hi! I'm your assistant, what can I do for you?"
assert res["data"]["prompt"][0]["show_quote"] is True
assert (
res["data"]["prompt"][0]["prompt"]
== 'You are an intelligent assistant. Please summarize the content of the knowledge base to answer the question. Please list the data in the knowledge base and answer in detail. When all knowledge base content is irrelevant to the question, your answer must include the sentence "The answer you are looking for is not found in the knowledge base!" Answers need to consider chat history.\n Here is the knowledge base:\n {knowledge}\n The above is the knowledge base.'
)
else:
assert expected_message in res["message"]

View File

@ -0,0 +1,52 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import pytest
from common import add_chunk, delete_chunks, list_documnets, parse_documnets
from utils import wait_for
@wait_for(30, 1, "Document parsing timeout")
def condition(_auth, _dataset_id):
res = list_documnets(_auth, _dataset_id)
for doc in res["data"]["docs"]:
if doc["run"] != "DONE":
return False
return True
@pytest.fixture(scope="function")
def add_chunks_func(request, api_key, add_document):
dataset_id, document_id = add_document
parse_documnets(api_key, dataset_id, {"document_ids": [document_id]})
condition(api_key, dataset_id)
chunk_ids = []
for i in range(4):
res = add_chunk(api_key, dataset_id, document_id, {"content": f"chunk test {i}"})
chunk_ids.append(res["data"]["chunk"]["id"])
# issues/6487
from time import sleep
sleep(1)
def cleanup():
delete_chunks(api_key, dataset_id, document_id, {"chunk_ids": chunk_ids})
request.addfinalizer(cleanup)
return dataset_id, document_id, chunk_ids

View File

@ -0,0 +1,250 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from concurrent.futures import ThreadPoolExecutor
import pytest
from common import INVALID_API_TOKEN, add_chunk, delete_documnets, list_chunks
from libs.auth import RAGFlowHttpApiAuth
def validate_chunk_details(dataset_id, document_id, payload, res):
chunk = res["data"]["chunk"]
assert chunk["dataset_id"] == dataset_id
assert chunk["document_id"] == document_id
assert chunk["content"] == payload["content"]
if "important_keywords" in payload:
assert chunk["important_keywords"] == payload["important_keywords"]
if "questions" in payload:
assert chunk["questions"] == [str(q).strip() for q in payload.get("questions", []) if str(q).strip()]
@pytest.mark.p1
class TestAuthorization:
@pytest.mark.parametrize(
"invalid_auth, expected_code, expected_message",
[
(None, 0, "`Authorization` can't be empty"),
(
RAGFlowHttpApiAuth(INVALID_API_TOKEN),
109,
"Authentication error: API key is invalid!",
),
],
)
def test_invalid_auth(self, invalid_auth, expected_code, expected_message):
res = add_chunk(invalid_auth, "dataset_id", "document_id")
assert res["code"] == expected_code
assert res["message"] == expected_message
class TestAddChunk:
@pytest.mark.p1
@pytest.mark.parametrize(
"payload, expected_code, expected_message",
[
({"content": None}, 100, """TypeError("unsupported operand type(s) for +: \'NoneType\' and \'str\'")"""),
({"content": ""}, 102, "`content` is required"),
pytest.param(
{"content": 1},
100,
"""TypeError("unsupported operand type(s) for +: \'int\' and \'str\'")""",
marks=pytest.mark.skip,
),
({"content": "a"}, 0, ""),
({"content": " "}, 102, "`content` is required"),
({"content": "\n!?。;!?\"'"}, 0, ""),
],
)
def test_content(self, api_key, add_document, payload, expected_code, expected_message):
dataset_id, document_id = add_document
res = list_chunks(api_key, dataset_id, document_id)
if res["code"] != 0:
assert False, res
chunks_count = res["data"]["doc"]["chunk_count"]
res = add_chunk(api_key, dataset_id, document_id, payload)
assert res["code"] == expected_code
if expected_code == 0:
validate_chunk_details(dataset_id, document_id, payload, res)
res = list_chunks(api_key, dataset_id, document_id)
if res["code"] != 0:
assert False, res
assert res["data"]["doc"]["chunk_count"] == chunks_count + 1
else:
assert res["message"] == expected_message
@pytest.mark.p2
@pytest.mark.parametrize(
"payload, expected_code, expected_message",
[
({"content": "chunk test", "important_keywords": ["a", "b", "c"]}, 0, ""),
({"content": "chunk test", "important_keywords": [""]}, 0, ""),
(
{"content": "chunk test", "important_keywords": [1]},
100,
"TypeError('sequence item 0: expected str instance, int found')",
),
({"content": "chunk test", "important_keywords": ["a", "a"]}, 0, ""),
({"content": "chunk test", "important_keywords": "abc"}, 102, "`important_keywords` is required to be a list"),
({"content": "chunk test", "important_keywords": 123}, 102, "`important_keywords` is required to be a list"),
],
)
def test_important_keywords(self, api_key, add_document, payload, expected_code, expected_message):
dataset_id, document_id = add_document
res = list_chunks(api_key, dataset_id, document_id)
if res["code"] != 0:
assert False, res
chunks_count = res["data"]["doc"]["chunk_count"]
res = add_chunk(api_key, dataset_id, document_id, payload)
assert res["code"] == expected_code
if expected_code == 0:
validate_chunk_details(dataset_id, document_id, payload, res)
res = list_chunks(api_key, dataset_id, document_id)
if res["code"] != 0:
assert False, res
assert res["data"]["doc"]["chunk_count"] == chunks_count + 1
else:
assert res["message"] == expected_message
@pytest.mark.p2
@pytest.mark.parametrize(
"payload, expected_code, expected_message",
[
({"content": "chunk test", "questions": ["a", "b", "c"]}, 0, ""),
({"content": "chunk test", "questions": [""]}, 0, ""),
({"content": "chunk test", "questions": [1]}, 100, "TypeError('sequence item 0: expected str instance, int found')"),
({"content": "chunk test", "questions": ["a", "a"]}, 0, ""),
({"content": "chunk test", "questions": "abc"}, 102, "`questions` is required to be a list"),
({"content": "chunk test", "questions": 123}, 102, "`questions` is required to be a list"),
],
)
def test_questions(self, api_key, add_document, payload, expected_code, expected_message):
dataset_id, document_id = add_document
res = list_chunks(api_key, dataset_id, document_id)
if res["code"] != 0:
assert False, res
chunks_count = res["data"]["doc"]["chunk_count"]
res = add_chunk(api_key, dataset_id, document_id, payload)
assert res["code"] == expected_code
if expected_code == 0:
validate_chunk_details(dataset_id, document_id, payload, res)
if res["code"] != 0:
assert False, res
res = list_chunks(api_key, dataset_id, document_id)
assert res["data"]["doc"]["chunk_count"] == chunks_count + 1
else:
assert res["message"] == expected_message
@pytest.mark.p3
@pytest.mark.parametrize(
"dataset_id, expected_code, expected_message",
[
("", 100, "<NotFound '404: Not Found'>"),
(
"invalid_dataset_id",
102,
"You don't own the dataset invalid_dataset_id.",
),
],
)
def test_invalid_dataset_id(
self,
api_key,
add_document,
dataset_id,
expected_code,
expected_message,
):
_, document_id = add_document
res = add_chunk(api_key, dataset_id, document_id, {"content": "a"})
assert res["code"] == expected_code
assert res["message"] == expected_message
@pytest.mark.p3
@pytest.mark.parametrize(
"document_id, expected_code, expected_message",
[
("", 100, "<MethodNotAllowed '405: Method Not Allowed'>"),
(
"invalid_document_id",
102,
"You don't own the document invalid_document_id.",
),
],
)
def test_invalid_document_id(self, api_key, add_document, document_id, expected_code, expected_message):
dataset_id, _ = add_document
res = add_chunk(api_key, dataset_id, document_id, {"content": "chunk test"})
assert res["code"] == expected_code
assert res["message"] == expected_message
@pytest.mark.p3
def test_repeated_add_chunk(self, api_key, add_document):
payload = {"content": "chunk test"}
dataset_id, document_id = add_document
res = list_chunks(api_key, dataset_id, document_id)
if res["code"] != 0:
assert False, res
chunks_count = res["data"]["doc"]["chunk_count"]
res = add_chunk(api_key, dataset_id, document_id, payload)
assert res["code"] == 0
validate_chunk_details(dataset_id, document_id, payload, res)
res = list_chunks(api_key, dataset_id, document_id)
if res["code"] != 0:
assert False, res
assert res["data"]["doc"]["chunk_count"] == chunks_count + 1
res = add_chunk(api_key, dataset_id, document_id, payload)
assert res["code"] == 0
validate_chunk_details(dataset_id, document_id, payload, res)
res = list_chunks(api_key, dataset_id, document_id)
if res["code"] != 0:
assert False, res
assert res["data"]["doc"]["chunk_count"] == chunks_count + 2
@pytest.mark.p2
def test_add_chunk_to_deleted_document(self, api_key, add_document):
dataset_id, document_id = add_document
delete_documnets(api_key, dataset_id, {"ids": [document_id]})
res = add_chunk(api_key, dataset_id, document_id, {"content": "chunk test"})
assert res["code"] == 102
assert res["message"] == f"You don't own the document {document_id}."
@pytest.mark.skip(reason="issues/6411")
def test_concurrent_add_chunk(self, api_key, add_document):
chunk_num = 50
dataset_id, document_id = add_document
res = list_chunks(api_key, dataset_id, document_id)
if res["code"] != 0:
assert False, res
chunks_count = res["data"]["doc"]["chunk_count"]
with ThreadPoolExecutor(max_workers=5) as executor:
futures = [
executor.submit(
add_chunk,
api_key,
dataset_id,
document_id,
{"content": f"chunk test {i}"},
)
for i in range(chunk_num)
]
responses = [f.result() for f in futures]
assert all(r["code"] == 0 for r in responses)
res = list_chunks(api_key, dataset_id, document_id)
if res["code"] != 0:
assert False, res
assert res["data"]["doc"]["chunk_count"] == chunks_count + chunk_num

View File

@ -0,0 +1,194 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from concurrent.futures import ThreadPoolExecutor
import pytest
from common import INVALID_API_TOKEN, batch_add_chunks, delete_chunks, list_chunks
from libs.auth import RAGFlowHttpApiAuth
@pytest.mark.p1
class TestAuthorization:
@pytest.mark.parametrize(
"invalid_auth, expected_code, expected_message",
[
(None, 0, "`Authorization` can't be empty"),
(
RAGFlowHttpApiAuth(INVALID_API_TOKEN),
109,
"Authentication error: API key is invalid!",
),
],
)
def test_invalid_auth(self, invalid_auth, expected_code, expected_message):
res = delete_chunks(invalid_auth, "dataset_id", "document_id")
assert res["code"] == expected_code
assert res["message"] == expected_message
class TestChunksDeletion:
@pytest.mark.p3
@pytest.mark.parametrize(
"dataset_id, expected_code, expected_message",
[
("", 100, "<NotFound '404: Not Found'>"),
(
"invalid_dataset_id",
102,
"You don't own the dataset invalid_dataset_id.",
),
],
)
def test_invalid_dataset_id(self, api_key, add_chunks_func, dataset_id, expected_code, expected_message):
_, document_id, chunk_ids = add_chunks_func
res = delete_chunks(api_key, dataset_id, document_id, {"chunk_ids": chunk_ids})
assert res["code"] == expected_code
assert res["message"] == expected_message
@pytest.mark.p3
@pytest.mark.parametrize(
"document_id, expected_code, expected_message",
[
("", 100, "<MethodNotAllowed '405: Method Not Allowed'>"),
("invalid_document_id", 100, """LookupError("Can't find the document with ID invalid_document_id!")"""),
],
)
def test_invalid_document_id(self, api_key, add_chunks_func, document_id, expected_code, expected_message):
dataset_id, _, chunk_ids = add_chunks_func
res = delete_chunks(api_key, dataset_id, document_id, {"chunk_ids": chunk_ids})
assert res["code"] == expected_code
assert res["message"] == expected_message
@pytest.mark.parametrize(
"payload",
[
pytest.param(lambda r: {"chunk_ids": ["invalid_id"] + r}, marks=pytest.mark.p3),
pytest.param(lambda r: {"chunk_ids": r[:1] + ["invalid_id"] + r[1:4]}, marks=pytest.mark.p1),
pytest.param(lambda r: {"chunk_ids": r + ["invalid_id"]}, marks=pytest.mark.p3),
],
)
def test_delete_partial_invalid_id(self, api_key, add_chunks_func, payload):
dataset_id, document_id, chunk_ids = add_chunks_func
if callable(payload):
payload = payload(chunk_ids)
res = delete_chunks(api_key, dataset_id, document_id, payload)
assert res["code"] == 102
assert res["message"] == "rm_chunk deleted chunks 4, expect 5"
res = list_chunks(api_key, dataset_id, document_id)
if res["code"] != 0:
assert False, res
assert len(res["data"]["chunks"]) == 1
assert res["data"]["total"] == 1
@pytest.mark.p3
def test_repeated_deletion(self, api_key, add_chunks_func):
dataset_id, document_id, chunk_ids = add_chunks_func
payload = {"chunk_ids": chunk_ids}
res = delete_chunks(api_key, dataset_id, document_id, payload)
assert res["code"] == 0
res = delete_chunks(api_key, dataset_id, document_id, payload)
assert res["code"] == 102
assert res["message"] == "rm_chunk deleted chunks 0, expect 4"
@pytest.mark.p3
def test_duplicate_deletion(self, api_key, add_chunks_func):
dataset_id, document_id, chunk_ids = add_chunks_func
res = delete_chunks(api_key, dataset_id, document_id, {"chunk_ids": chunk_ids * 2})
assert res["code"] == 0
assert "Duplicate chunk ids" in res["data"]["errors"][0]
assert res["data"]["success_count"] == 4
res = list_chunks(api_key, dataset_id, document_id)
if res["code"] != 0:
assert False, res
assert len(res["data"]["chunks"]) == 1
assert res["data"]["total"] == 1
@pytest.mark.p3
def test_concurrent_deletion(self, api_key, add_document):
chunks_num = 100
dataset_id, document_id = add_document
chunk_ids = batch_add_chunks(api_key, dataset_id, document_id, chunks_num)
with ThreadPoolExecutor(max_workers=5) as executor:
futures = [
executor.submit(
delete_chunks,
api_key,
dataset_id,
document_id,
{"chunk_ids": chunk_ids[i : i + 1]},
)
for i in range(chunks_num)
]
responses = [f.result() for f in futures]
assert all(r["code"] == 0 for r in responses)
@pytest.mark.p3
def test_delete_1k(self, api_key, add_document):
chunks_num = 1_000
dataset_id, document_id = add_document
chunk_ids = batch_add_chunks(api_key, dataset_id, document_id, chunks_num)
# issues/6487
from time import sleep
sleep(1)
res = delete_chunks(api_key, dataset_id, document_id, {"chunk_ids": chunk_ids})
assert res["code"] == 0
res = list_chunks(api_key, dataset_id, document_id)
if res["code"] != 0:
assert False, res
assert len(res["data"]["chunks"]) == 1
assert res["data"]["total"] == 1
@pytest.mark.parametrize(
"payload, expected_code, expected_message, remaining",
[
pytest.param(None, 100, """TypeError("argument of type \'NoneType\' is not iterable")""", 5, marks=pytest.mark.skip),
pytest.param({"chunk_ids": ["invalid_id"]}, 102, "rm_chunk deleted chunks 0, expect 1", 5, marks=pytest.mark.p3),
pytest.param("not json", 100, """UnboundLocalError("local variable \'duplicate_messages\' referenced before assignment")""", 5, marks=pytest.mark.skip(reason="pull/6376")),
pytest.param(lambda r: {"chunk_ids": r[:1]}, 0, "", 4, marks=pytest.mark.p3),
pytest.param(lambda r: {"chunk_ids": r}, 0, "", 1, marks=pytest.mark.p1),
pytest.param({"chunk_ids": []}, 0, "", 0, marks=pytest.mark.p3),
],
)
def test_basic_scenarios(
self,
api_key,
add_chunks_func,
payload,
expected_code,
expected_message,
remaining,
):
dataset_id, document_id, chunk_ids = add_chunks_func
if callable(payload):
payload = payload(chunk_ids)
res = delete_chunks(api_key, dataset_id, document_id, payload)
assert res["code"] == expected_code
if res["code"] != 0:
assert res["message"] == expected_message
res = list_chunks(api_key, dataset_id, document_id)
if res["code"] != 0:
assert False, res
assert len(res["data"]["chunks"]) == remaining
assert res["data"]["total"] == remaining

View File

@ -0,0 +1,209 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
from concurrent.futures import ThreadPoolExecutor
import pytest
from common import INVALID_API_TOKEN, batch_add_chunks, list_chunks
from libs.auth import RAGFlowHttpApiAuth
@pytest.mark.p1
class TestAuthorization:
@pytest.mark.parametrize(
"invalid_auth, expected_code, expected_message",
[
(None, 0, "`Authorization` can't be empty"),
(
RAGFlowHttpApiAuth(INVALID_API_TOKEN),
109,
"Authentication error: API key is invalid!",
),
],
)
def test_invalid_auth(self, invalid_auth, expected_code, expected_message):
res = list_chunks(invalid_auth, "dataset_id", "document_id")
assert res["code"] == expected_code
assert res["message"] == expected_message
class TestChunksList:
@pytest.mark.p1
@pytest.mark.parametrize(
"params, expected_code, expected_page_size, expected_message",
[
({"page": None, "page_size": 2}, 0, 2, ""),
pytest.param({"page": 0, "page_size": 2}, 100, 0, "ValueError('Search does not support negative slicing.')", marks=pytest.mark.skip),
({"page": 2, "page_size": 2}, 0, 2, ""),
({"page": 3, "page_size": 2}, 0, 1, ""),
({"page": "3", "page_size": 2}, 0, 1, ""),
pytest.param({"page": -1, "page_size": 2}, 100, 0, "ValueError('Search does not support negative slicing.')", marks=pytest.mark.skip),
pytest.param({"page": "a", "page_size": 2}, 100, 0, """ValueError("invalid literal for int() with base 10: \'a\'")""", marks=pytest.mark.skip),
],
)
def test_page(self, api_key, add_chunks, params, expected_code, expected_page_size, expected_message):
dataset_id, document_id, _ = add_chunks
res = list_chunks(api_key, dataset_id, document_id, params=params)
assert res["code"] == expected_code
if expected_code == 0:
assert len(res["data"]["chunks"]) == expected_page_size
else:
assert res["message"] == expected_message
@pytest.mark.p1
@pytest.mark.parametrize(
"params, expected_code, expected_page_size, expected_message",
[
({"page_size": None}, 0, 5, ""),
pytest.param({"page_size": 0}, 0, 5, "", marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="Infinity does not support page_size=0")),
pytest.param({"page_size": 0}, 100, 0, "3013", marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in [None, "opensearch", "elasticsearch"], reason="Infinity does not support page_size=0")),
({"page_size": 1}, 0, 1, ""),
({"page_size": 6}, 0, 5, ""),
({"page_size": "1"}, 0, 1, ""),
pytest.param({"page_size": -1}, 0, 5, "", marks=pytest.mark.skip),
pytest.param({"page_size": "a"}, 100, 0, """ValueError("invalid literal for int() with base 10: \'a\'")""", marks=pytest.mark.skip),
],
)
def test_page_size(self, api_key, add_chunks, params, expected_code, expected_page_size, expected_message):
dataset_id, document_id, _ = add_chunks
res = list_chunks(api_key, dataset_id, document_id, params=params)
assert res["code"] == expected_code
if expected_code == 0:
assert len(res["data"]["chunks"]) == expected_page_size
else:
assert res["message"] == expected_message
@pytest.mark.p2
@pytest.mark.parametrize(
"params, expected_page_size",
[
({"keywords": None}, 5),
({"keywords": ""}, 5),
({"keywords": "1"}, 1),
pytest.param({"keywords": "chunk"}, 4, marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="issues/6509")),
({"keywords": "ragflow"}, 1),
({"keywords": "unknown"}, 0),
],
)
def test_keywords(self, api_key, add_chunks, params, expected_page_size):
dataset_id, document_id, _ = add_chunks
res = list_chunks(api_key, dataset_id, document_id, params=params)
assert res["code"] == 0
assert len(res["data"]["chunks"]) == expected_page_size
@pytest.mark.p1
@pytest.mark.parametrize(
"chunk_id, expected_code, expected_page_size, expected_message",
[
(None, 0, 5, ""),
("", 0, 5, ""),
pytest.param(lambda r: r[0], 0, 1, "", marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="issues/6499")),
pytest.param("unknown", 100, 0, """AttributeError("\'NoneType\' object has no attribute \'keys\'")""", marks=pytest.mark.skip),
],
)
def test_id(
self,
api_key,
add_chunks,
chunk_id,
expected_code,
expected_page_size,
expected_message,
):
dataset_id, document_id, chunk_ids = add_chunks
if callable(chunk_id):
params = {"id": chunk_id(chunk_ids)}
else:
params = {"id": chunk_id}
res = list_chunks(api_key, dataset_id, document_id, params=params)
assert res["code"] == expected_code
if expected_code == 0:
if params["id"] in [None, ""]:
assert len(res["data"]["chunks"]) == expected_page_size
else:
assert res["data"]["chunks"][0]["id"] == params["id"]
else:
assert res["message"] == expected_message
@pytest.mark.p3
def test_invalid_params(self, api_key, add_chunks):
dataset_id, document_id, _ = add_chunks
params = {"a": "b"}
res = list_chunks(api_key, dataset_id, document_id, params=params)
assert res["code"] == 0
assert len(res["data"]["chunks"]) == 5
@pytest.mark.p3
def test_concurrent_list(self, api_key, add_chunks):
dataset_id, document_id, _ = add_chunks
with ThreadPoolExecutor(max_workers=5) as executor:
futures = [executor.submit(list_chunks, api_key, dataset_id, document_id) for i in range(100)]
responses = [f.result() for f in futures]
assert all(r["code"] == 0 for r in responses)
assert all(len(r["data"]["chunks"]) == 5 for r in responses)
@pytest.mark.p1
def test_default(self, api_key, add_document):
dataset_id, document_id = add_document
res = list_chunks(api_key, dataset_id, document_id)
chunks_count = res["data"]["doc"]["chunk_count"]
batch_add_chunks(api_key, dataset_id, document_id, 31)
# issues/6487
from time import sleep
sleep(3)
res = list_chunks(api_key, dataset_id, document_id)
assert res["code"] == 0
assert len(res["data"]["chunks"]) == 30
assert res["data"]["doc"]["chunk_count"] == chunks_count + 31
@pytest.mark.p3
@pytest.mark.parametrize(
"dataset_id, expected_code, expected_message",
[
("", 100, "<NotFound '404: Not Found'>"),
(
"invalid_dataset_id",
102,
"You don't own the dataset invalid_dataset_id.",
),
],
)
def test_invalid_dataset_id(self, api_key, add_chunks, dataset_id, expected_code, expected_message):
_, document_id, _ = add_chunks
res = list_chunks(api_key, dataset_id, document_id)
assert res["code"] == expected_code
assert res["message"] == expected_message
@pytest.mark.p3
@pytest.mark.parametrize(
"document_id, expected_code, expected_message",
[
("", 102, "The dataset not own the document chunks."),
(
"invalid_document_id",
102,
"You don't own the document invalid_document_id.",
),
],
)
def test_invalid_document_id(self, api_key, add_chunks, document_id, expected_code, expected_message):
dataset_id, _, _ = add_chunks
res = list_chunks(api_key, dataset_id, document_id)
assert res["code"] == expected_code
assert res["message"] == expected_message

View File

@ -0,0 +1,313 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import pytest
from common import (
INVALID_API_TOKEN,
retrieval_chunks,
)
from libs.auth import RAGFlowHttpApiAuth
@pytest.mark.p1
class TestAuthorization:
@pytest.mark.parametrize(
"invalid_auth, expected_code, expected_message",
[
(None, 0, "`Authorization` can't be empty"),
(
RAGFlowHttpApiAuth(INVALID_API_TOKEN),
109,
"Authentication error: API key is invalid!",
),
],
)
def test_invalid_auth(self, invalid_auth, expected_code, expected_message):
res = retrieval_chunks(invalid_auth)
assert res["code"] == expected_code
assert res["message"] == expected_message
class TestChunksRetrieval:
@pytest.mark.p1
@pytest.mark.parametrize(
"payload, expected_code, expected_page_size, expected_message",
[
({"question": "chunk", "dataset_ids": None}, 0, 4, ""),
({"question": "chunk", "document_ids": None}, 102, 0, "`dataset_ids` is required."),
({"question": "chunk", "dataset_ids": None, "document_ids": None}, 0, 4, ""),
({"question": "chunk"}, 102, 0, "`dataset_ids` is required."),
],
)
def test_basic_scenarios(self, api_key, add_chunks, payload, expected_code, expected_page_size, expected_message):
dataset_id, document_id, _ = add_chunks
if "dataset_ids" in payload:
payload["dataset_ids"] = [dataset_id]
if "document_ids" in payload:
payload["document_ids"] = [document_id]
res = retrieval_chunks(api_key, payload)
assert res["code"] == expected_code
if expected_code == 0:
assert len(res["data"]["chunks"]) == expected_page_size
else:
assert res["message"] == expected_message
@pytest.mark.p2
@pytest.mark.parametrize(
"payload, expected_code, expected_page_size, expected_message",
[
pytest.param(
{"page": None, "page_size": 2},
100,
2,
"""TypeError("int() argument must be a string, a bytes-like object or a real number, not \'NoneType\'")""",
marks=pytest.mark.skip,
),
pytest.param(
{"page": 0, "page_size": 2},
100,
0,
"ValueError('Search does not support negative slicing.')",
marks=pytest.mark.skip,
),
pytest.param({"page": 2, "page_size": 2}, 0, 2, "", marks=pytest.mark.skip(reason="issues/6646")),
({"page": 3, "page_size": 2}, 0, 0, ""),
({"page": "3", "page_size": 2}, 0, 0, ""),
pytest.param(
{"page": -1, "page_size": 2},
100,
0,
"ValueError('Search does not support negative slicing.')",
marks=pytest.mark.skip,
),
pytest.param(
{"page": "a", "page_size": 2},
100,
0,
"""ValueError("invalid literal for int() with base 10: \'a\'")""",
marks=pytest.mark.skip,
),
],
)
def test_page(self, api_key, add_chunks, payload, expected_code, expected_page_size, expected_message):
dataset_id, _, _ = add_chunks
payload.update({"question": "chunk", "dataset_ids": [dataset_id]})
res = retrieval_chunks(api_key, payload)
assert res["code"] == expected_code
if expected_code == 0:
assert len(res["data"]["chunks"]) == expected_page_size
else:
assert res["message"] == expected_message
@pytest.mark.p3
@pytest.mark.parametrize(
"payload, expected_code, expected_page_size, expected_message",
[
pytest.param(
{"page_size": None},
100,
0,
"""TypeError("int() argument must be a string, a bytes-like object or a real number, not \'NoneType\'")""",
marks=pytest.mark.skip,
),
# ({"page_size": 0}, 0, 0, ""),
({"page_size": 1}, 0, 1, ""),
({"page_size": 5}, 0, 4, ""),
({"page_size": "1"}, 0, 1, ""),
# ({"page_size": -1}, 0, 0, ""),
pytest.param(
{"page_size": "a"},
100,
0,
"""ValueError("invalid literal for int() with base 10: \'a\'")""",
marks=pytest.mark.skip,
),
],
)
def test_page_size(self, api_key, add_chunks, payload, expected_code, expected_page_size, expected_message):
dataset_id, _, _ = add_chunks
payload.update({"question": "chunk", "dataset_ids": [dataset_id]})
res = retrieval_chunks(api_key, payload)
assert res["code"] == expected_code
if expected_code == 0:
assert len(res["data"]["chunks"]) == expected_page_size
else:
assert res["message"] == expected_message
@pytest.mark.p3
@pytest.mark.parametrize(
"payload, expected_code, expected_page_size, expected_message",
[
({"vector_similarity_weight": 0}, 0, 4, ""),
({"vector_similarity_weight": 0.5}, 0, 4, ""),
({"vector_similarity_weight": 10}, 0, 4, ""),
pytest.param(
{"vector_similarity_weight": "a"},
100,
0,
"""ValueError("could not convert string to float: \'a\'")""",
marks=pytest.mark.skip,
),
],
)
def test_vector_similarity_weight(self, api_key, add_chunks, payload, expected_code, expected_page_size, expected_message):
dataset_id, _, _ = add_chunks
payload.update({"question": "chunk", "dataset_ids": [dataset_id]})
res = retrieval_chunks(api_key, payload)
assert res["code"] == expected_code
if expected_code == 0:
assert len(res["data"]["chunks"]) == expected_page_size
else:
assert res["message"] == expected_message
@pytest.mark.p2
@pytest.mark.parametrize(
"payload, expected_code, expected_page_size, expected_message",
[
({"top_k": 10}, 0, 4, ""),
pytest.param(
{"top_k": 1},
0,
4,
"",
marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in ["infinity", "opensearch"], reason="Infinity"),
),
pytest.param(
{"top_k": 1},
0,
1,
"",
marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in [None, "opensearch", "elasticsearch"], reason="elasticsearch"),
),
pytest.param(
{"top_k": -1},
100,
4,
"must be greater than 0",
marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in ["infinity", "opensearch"], reason="Infinity"),
),
pytest.param(
{"top_k": -1},
100,
4,
"3014",
marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in [None, "opensearch", "elasticsearch"], reason="elasticsearch"),
),
pytest.param(
{"top_k": "a"},
100,
0,
"""ValueError("invalid literal for int() with base 10: \'a\'")""",
marks=pytest.mark.skip,
),
],
)
def test_top_k(self, api_key, add_chunks, payload, expected_code, expected_page_size, expected_message):
dataset_id, _, _ = add_chunks
payload.update({"question": "chunk", "dataset_ids": [dataset_id]})
res = retrieval_chunks(api_key, payload)
assert res["code"] == expected_code
if expected_code == 0:
assert len(res["data"]["chunks"]) == expected_page_size
else:
assert expected_message in res["message"]
@pytest.mark.skip
@pytest.mark.parametrize(
"payload, expected_code, expected_message",
[
({"rerank_id": "BAAI/bge-reranker-v2-m3"}, 0, ""),
pytest.param({"rerank_id": "unknown"}, 100, "LookupError('Model(unknown) not authorized')", marks=pytest.mark.skip),
],
)
def test_rerank_id(self, api_key, add_chunks, payload, expected_code, expected_message):
dataset_id, _, _ = add_chunks
payload.update({"question": "chunk", "dataset_ids": [dataset_id]})
res = retrieval_chunks(api_key, payload)
assert res["code"] == expected_code
if expected_code == 0:
assert len(res["data"]["chunks"]) > 0
else:
assert expected_message in res["message"]
@pytest.mark.skip
@pytest.mark.parametrize(
"payload, expected_code, expected_page_size, expected_message",
[
({"keyword": True}, 0, 5, ""),
({"keyword": "True"}, 0, 5, ""),
({"keyword": False}, 0, 5, ""),
({"keyword": "False"}, 0, 5, ""),
({"keyword": None}, 0, 5, ""),
],
)
def test_keyword(self, api_key, add_chunks, payload, expected_code, expected_page_size, expected_message):
dataset_id, _, _ = add_chunks
payload.update({"question": "chunk test", "dataset_ids": [dataset_id]})
res = retrieval_chunks(api_key, payload)
assert res["code"] == expected_code
if expected_code == 0:
assert len(res["data"]["chunks"]) == expected_page_size
else:
assert res["message"] == expected_message
@pytest.mark.p3
@pytest.mark.parametrize(
"payload, expected_code, expected_highlight, expected_message",
[
({"highlight": True}, 0, True, ""),
({"highlight": "True"}, 0, True, ""),
pytest.param({"highlight": False}, 0, False, "", marks=pytest.mark.skip(reason="issues/6648")),
({"highlight": "False"}, 0, False, ""),
pytest.param({"highlight": None}, 0, False, "", marks=pytest.mark.skip(reason="issues/6648")),
],
)
def test_highlight(self, api_key, add_chunks, payload, expected_code, expected_highlight, expected_message):
dataset_id, _, _ = add_chunks
payload.update({"question": "chunk", "dataset_ids": [dataset_id]})
res = retrieval_chunks(api_key, payload)
assert res["code"] == expected_code
if expected_highlight:
for chunk in res["data"]["chunks"]:
assert "highlight" in chunk
else:
for chunk in res["data"]["chunks"]:
assert "highlight" not in chunk
if expected_code != 0:
assert res["message"] == expected_message
@pytest.mark.p3
def test_invalid_params(self, api_key, add_chunks):
dataset_id, _, _ = add_chunks
payload = {"question": "chunk", "dataset_ids": [dataset_id], "a": "b"}
res = retrieval_chunks(api_key, payload)
assert res["code"] == 0
assert len(res["data"]["chunks"]) == 4
@pytest.mark.p3
def test_concurrent_retrieval(self, api_key, add_chunks):
from concurrent.futures import ThreadPoolExecutor
dataset_id, _, _ = add_chunks
payload = {"question": "chunk", "dataset_ids": [dataset_id]}
with ThreadPoolExecutor(max_workers=5) as executor:
futures = [executor.submit(retrieval_chunks, api_key, payload) for i in range(100)]
responses = [f.result() for f in futures]
assert all(r["code"] == 0 for r in responses)

View File

@ -0,0 +1,246 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
from concurrent.futures import ThreadPoolExecutor
from random import randint
import pytest
from common import INVALID_API_TOKEN, delete_documnets, update_chunk
from libs.auth import RAGFlowHttpApiAuth
@pytest.mark.p1
class TestAuthorization:
@pytest.mark.parametrize(
"invalid_auth, expected_code, expected_message",
[
(None, 0, "`Authorization` can't be empty"),
(
RAGFlowHttpApiAuth(INVALID_API_TOKEN),
109,
"Authentication error: API key is invalid!",
),
],
)
def test_invalid_auth(self, invalid_auth, expected_code, expected_message):
res = update_chunk(invalid_auth, "dataset_id", "document_id", "chunk_id")
assert res["code"] == expected_code
assert res["message"] == expected_message
class TestUpdatedChunk:
@pytest.mark.p1
@pytest.mark.parametrize(
"payload, expected_code, expected_message",
[
({"content": None}, 100, "TypeError('expected string or bytes-like object')"),
pytest.param(
{"content": ""},
100,
"""APIRequestFailedError(\'Error code: 400, with error text {"error":{"code":"1213","message":"未正常接收到prompt参数。"}}\')""",
marks=pytest.mark.skip(reason="issues/6541"),
),
pytest.param(
{"content": 1},
100,
"TypeError('expected string or bytes-like object')",
marks=pytest.mark.skip,
),
({"content": "update chunk"}, 0, ""),
pytest.param(
{"content": " "},
100,
"""APIRequestFailedError(\'Error code: 400, with error text {"error":{"code":"1213","message":"未正常接收到prompt参数。"}}\')""",
marks=pytest.mark.skip(reason="issues/6541"),
),
({"content": "\n!?。;!?\"'"}, 0, ""),
],
)
def test_content(self, api_key, add_chunks, payload, expected_code, expected_message):
dataset_id, document_id, chunk_ids = add_chunks
res = update_chunk(api_key, dataset_id, document_id, chunk_ids[0], payload)
assert res["code"] == expected_code
if expected_code != 0:
assert res["message"] == expected_message
@pytest.mark.p2
@pytest.mark.parametrize(
"payload, expected_code, expected_message",
[
({"important_keywords": ["a", "b", "c"]}, 0, ""),
({"important_keywords": [""]}, 0, ""),
({"important_keywords": [1]}, 100, "TypeError('sequence item 0: expected str instance, int found')"),
({"important_keywords": ["a", "a"]}, 0, ""),
({"important_keywords": "abc"}, 102, "`important_keywords` should be a list"),
({"important_keywords": 123}, 102, "`important_keywords` should be a list"),
],
)
def test_important_keywords(self, api_key, add_chunks, payload, expected_code, expected_message):
dataset_id, document_id, chunk_ids = add_chunks
res = update_chunk(api_key, dataset_id, document_id, chunk_ids[0], payload)
assert res["code"] == expected_code
if expected_code != 0:
assert res["message"] == expected_message
@pytest.mark.p2
@pytest.mark.parametrize(
"payload, expected_code, expected_message",
[
({"questions": ["a", "b", "c"]}, 0, ""),
({"questions": [""]}, 0, ""),
({"questions": [1]}, 100, "TypeError('sequence item 0: expected str instance, int found')"),
({"questions": ["a", "a"]}, 0, ""),
({"questions": "abc"}, 102, "`questions` should be a list"),
({"questions": 123}, 102, "`questions` should be a list"),
],
)
def test_questions(self, api_key, add_chunks, payload, expected_code, expected_message):
dataset_id, document_id, chunk_ids = add_chunks
res = update_chunk(api_key, dataset_id, document_id, chunk_ids[0], payload)
assert res["code"] == expected_code
if expected_code != 0:
assert res["message"] == expected_message
@pytest.mark.p2
@pytest.mark.parametrize(
"payload, expected_code, expected_message",
[
({"available": True}, 0, ""),
pytest.param({"available": "True"}, 100, """ValueError("invalid literal for int() with base 10: \'True\'")""", marks=pytest.mark.skip),
({"available": 1}, 0, ""),
({"available": False}, 0, ""),
pytest.param({"available": "False"}, 100, """ValueError("invalid literal for int() with base 10: \'False\'")""", marks=pytest.mark.skip),
({"available": 0}, 0, ""),
],
)
def test_available(
self,
api_key,
add_chunks,
payload,
expected_code,
expected_message,
):
dataset_id, document_id, chunk_ids = add_chunks
res = update_chunk(api_key, dataset_id, document_id, chunk_ids[0], payload)
assert res["code"] == expected_code
if expected_code != 0:
assert res["message"] == expected_message
@pytest.mark.p3
@pytest.mark.parametrize(
"dataset_id, expected_code, expected_message",
[
("", 100, "<NotFound '404: Not Found'>"),
pytest.param("invalid_dataset_id", 102, "You don't own the dataset invalid_dataset_id.", marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="infinity")),
pytest.param("invalid_dataset_id", 102, "Can't find this chunk", marks=pytest.mark.skipif(os.getenv("DOC_ENGINE") in [None, "opensearch", "elasticsearch"], reason="elasticsearch")),
],
)
def test_invalid_dataset_id(self, api_key, add_chunks, dataset_id, expected_code, expected_message):
_, document_id, chunk_ids = add_chunks
res = update_chunk(api_key, dataset_id, document_id, chunk_ids[0])
assert res["code"] == expected_code
assert expected_message in res["message"]
@pytest.mark.p3
@pytest.mark.parametrize(
"document_id, expected_code, expected_message",
[
("", 100, "<NotFound '404: Not Found'>"),
(
"invalid_document_id",
102,
"You don't own the document invalid_document_id.",
),
],
)
def test_invalid_document_id(self, api_key, add_chunks, document_id, expected_code, expected_message):
dataset_id, _, chunk_ids = add_chunks
res = update_chunk(api_key, dataset_id, document_id, chunk_ids[0])
assert res["code"] == expected_code
assert res["message"] == expected_message
@pytest.mark.p3
@pytest.mark.parametrize(
"chunk_id, expected_code, expected_message",
[
("", 100, "<MethodNotAllowed '405: Method Not Allowed'>"),
(
"invalid_document_id",
102,
"Can't find this chunk invalid_document_id",
),
],
)
def test_invalid_chunk_id(self, api_key, add_chunks, chunk_id, expected_code, expected_message):
dataset_id, document_id, _ = add_chunks
res = update_chunk(api_key, dataset_id, document_id, chunk_id)
assert res["code"] == expected_code
assert res["message"] == expected_message
@pytest.mark.p3
def test_repeated_update_chunk(self, api_key, add_chunks):
dataset_id, document_id, chunk_ids = add_chunks
res = update_chunk(api_key, dataset_id, document_id, chunk_ids[0], {"content": "chunk test 1"})
assert res["code"] == 0
res = update_chunk(api_key, dataset_id, document_id, chunk_ids[0], {"content": "chunk test 2"})
assert res["code"] == 0
@pytest.mark.p3
@pytest.mark.parametrize(
"payload, expected_code, expected_message",
[
({"unknown_key": "unknown_value"}, 0, ""),
({}, 0, ""),
pytest.param(None, 100, """TypeError("argument of type \'NoneType\' is not iterable")""", marks=pytest.mark.skip),
],
)
def test_invalid_params(self, api_key, add_chunks, payload, expected_code, expected_message):
dataset_id, document_id, chunk_ids = add_chunks
res = update_chunk(api_key, dataset_id, document_id, chunk_ids[0], payload)
assert res["code"] == expected_code
if expected_code != 0:
assert res["message"] == expected_message
@pytest.mark.p3
@pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="issues/6554")
def test_concurrent_update_chunk(self, api_key, add_chunks):
chunk_num = 50
dataset_id, document_id, chunk_ids = add_chunks
with ThreadPoolExecutor(max_workers=5) as executor:
futures = [
executor.submit(
update_chunk,
api_key,
dataset_id,
document_id,
chunk_ids[randint(0, 3)],
{"content": f"update chunk test {i}"},
)
for i in range(chunk_num)
]
responses = [f.result() for f in futures]
assert all(r["code"] == 0 for r in responses)
@pytest.mark.p3
def test_update_chunk_to_deleted_document(self, api_key, add_chunks):
dataset_id, document_id, chunk_ids = add_chunks
delete_documnets(api_key, dataset_id, {"ids": [document_id]})
res = update_chunk(api_key, dataset_id, document_id, chunk_ids[0])
assert res["code"] == 102
assert res["message"] == f"Can't find this chunk {chunk_ids[0]}"

View File

@ -0,0 +1,39 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import pytest
from common import batch_create_datasets, delete_datasets
@pytest.fixture(scope="class")
def add_datasets(api_key, request):
def cleanup():
delete_datasets(api_key, {"ids": None})
request.addfinalizer(cleanup)
return batch_create_datasets(api_key, 5)
@pytest.fixture(scope="function")
def add_datasets_func(api_key, request):
def cleanup():
delete_datasets(api_key, {"ids": None})
request.addfinalizer(cleanup)
return batch_create_datasets(api_key, 3)

View File

@ -0,0 +1,737 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from concurrent.futures import ThreadPoolExecutor, as_completed
import pytest
from common import create_dataset
from configs import DATASET_NAME_LIMIT, INVALID_API_TOKEN
from hypothesis import example, given, settings
from libs.auth import RAGFlowHttpApiAuth
from utils import encode_avatar
from utils.file_utils import create_image_file
from utils.hypothesis_utils import valid_names
@pytest.mark.usefixtures("clear_datasets")
class TestAuthorization:
@pytest.mark.p1
@pytest.mark.parametrize(
"invalid_auth, expected_code, expected_message",
[
(None, 0, "`Authorization` can't be empty"),
(
RAGFlowHttpApiAuth(INVALID_API_TOKEN),
109,
"Authentication error: API key is invalid!",
),
],
ids=["empty_auth", "invalid_api_token"],
)
def test_auth_invalid(self, invalid_auth, expected_code, expected_message):
res = create_dataset(invalid_auth, {"name": "auth_test"})
assert res["code"] == expected_code, res
assert res["message"] == expected_message, res
class TestRquest:
@pytest.mark.p3
def test_content_type_bad(self, api_key):
BAD_CONTENT_TYPE = "text/xml"
res = create_dataset(api_key, {"name": "bad_content_type"}, headers={"Content-Type": BAD_CONTENT_TYPE})
assert res["code"] == 101, res
assert res["message"] == f"Unsupported content type: Expected application/json, got {BAD_CONTENT_TYPE}", res
@pytest.mark.p3
@pytest.mark.parametrize(
"payload, expected_message",
[
("a", "Malformed JSON syntax: Missing commas/brackets or invalid encoding"),
('"a"', "Invalid request payload: expected object, got str"),
],
ids=["malformed_json_syntax", "invalid_request_payload_type"],
)
def test_payload_bad(self, api_key, payload, expected_message):
res = create_dataset(api_key, data=payload)
assert res["code"] == 101, res
assert res["message"] == expected_message, res
@pytest.mark.usefixtures("clear_datasets")
class TestCapability:
@pytest.mark.p3
def test_create_dataset_1k(self, api_key):
for i in range(1_000):
payload = {"name": f"dataset_{i}"}
res = create_dataset(api_key, payload)
assert res["code"] == 0, f"Failed to create dataset {i}"
@pytest.mark.p3
def test_create_dataset_concurrent(self, api_key):
with ThreadPoolExecutor(max_workers=5) as executor:
futures = [executor.submit(create_dataset, api_key, {"name": f"dataset_{i}"}) for i in range(100)]
responses = list(as_completed(futures))
assert all(r["code"] == 0 for r in responses), responses
@pytest.mark.usefixtures("clear_datasets")
class TestDatasetCreate:
@pytest.mark.p1
@given(name=valid_names())
@example("a" * 128)
@settings(max_examples=20)
def test_name(self, api_key, name):
res = create_dataset(api_key, {"name": name})
assert res["code"] == 0, res
assert res["data"]["name"] == name, res
@pytest.mark.p2
@pytest.mark.parametrize(
"name, expected_message",
[
("", "String should have at least 1 character"),
(" ", "String should have at least 1 character"),
("a" * (DATASET_NAME_LIMIT + 1), "String should have at most 128 characters"),
(0, "Input should be a valid string"),
(None, "Input should be a valid string"),
],
ids=["empty_name", "space_name", "too_long_name", "invalid_name", "None_name"],
)
def test_name_invalid(self, api_key, name, expected_message):
payload = {"name": name}
res = create_dataset(api_key, payload)
assert res["code"] == 101, res
assert expected_message in res["message"], res
@pytest.mark.p3
def test_name_duplicated(self, api_key):
name = "duplicated_name"
payload = {"name": name}
res = create_dataset(api_key, payload)
assert res["code"] == 0, res
res = create_dataset(api_key, payload)
assert res["code"] == 103, res
assert res["message"] == f"Dataset name '{name}' already exists", res
@pytest.mark.p3
def test_name_case_insensitive(self, api_key):
name = "CaseInsensitive"
payload = {"name": name.upper()}
res = create_dataset(api_key, payload)
assert res["code"] == 0, res
payload = {"name": name.lower()}
res = create_dataset(api_key, payload)
assert res["code"] == 103, res
assert res["message"] == f"Dataset name '{name.lower()}' already exists", res
@pytest.mark.p2
def test_avatar(self, api_key, tmp_path):
fn = create_image_file(tmp_path / "ragflow_test.png")
payload = {
"name": "avatar",
"avatar": f"data:image/png;base64,{encode_avatar(fn)}",
}
res = create_dataset(api_key, payload)
assert res["code"] == 0, res
@pytest.mark.p2
def test_avatar_exceeds_limit_length(self, api_key):
payload = {"name": "avatar_exceeds_limit_length", "avatar": "a" * 65536}
res = create_dataset(api_key, payload)
assert res["code"] == 101, res
assert "String should have at most 65535 characters" in res["message"], res
@pytest.mark.p3
@pytest.mark.parametrize(
"name, prefix, expected_message",
[
("empty_prefix", "", "Missing MIME prefix. Expected format: data:<mime>;base64,<data>"),
("missing_comma", "data:image/png;base64", "Missing MIME prefix. Expected format: data:<mime>;base64,<data>"),
("unsupported_mine_type", "invalid_mine_prefix:image/png;base64,", "Invalid MIME prefix format. Must start with 'data:'"),
("invalid_mine_type", "data:unsupported_mine_type;base64,", "Unsupported MIME type. Allowed: ['image/jpeg', 'image/png']"),
],
ids=["empty_prefix", "missing_comma", "unsupported_mine_type", "invalid_mine_type"],
)
def test_avatar_invalid_prefix(self, api_key, tmp_path, name, prefix, expected_message):
fn = create_image_file(tmp_path / "ragflow_test.png")
payload = {
"name": name,
"avatar": f"{prefix}{encode_avatar(fn)}",
}
res = create_dataset(api_key, payload)
assert res["code"] == 101, res
assert expected_message in res["message"], res
@pytest.mark.p3
def test_avatar_unset(self, api_key):
payload = {"name": "avatar_unset"}
res = create_dataset(api_key, payload)
assert res["code"] == 0, res
assert res["data"]["avatar"] is None, res
@pytest.mark.p3
def test_avatar_none(self, api_key):
payload = {"name": "avatar_none", "avatar": None}
res = create_dataset(api_key, payload)
assert res["code"] == 0, res
assert res["data"]["avatar"] is None, res
@pytest.mark.p2
def test_description(self, api_key):
payload = {"name": "description", "description": "description"}
res = create_dataset(api_key, payload)
assert res["code"] == 0, res
assert res["data"]["description"] == "description", res
@pytest.mark.p2
def test_description_exceeds_limit_length(self, api_key):
payload = {"name": "description_exceeds_limit_length", "description": "a" * 65536}
res = create_dataset(api_key, payload)
assert res["code"] == 101, res
assert "String should have at most 65535 characters" in res["message"], res
@pytest.mark.p3
def test_description_unset(self, api_key):
payload = {"name": "description_unset"}
res = create_dataset(api_key, payload)
assert res["code"] == 0, res
assert res["data"]["description"] is None, res
@pytest.mark.p3
def test_description_none(self, api_key):
payload = {"name": "description_none", "description": None}
res = create_dataset(api_key, payload)
assert res["code"] == 0, res
assert res["data"]["description"] is None, res
@pytest.mark.p1
@pytest.mark.parametrize(
"name, embedding_model",
[
("BAAI/bge-large-zh-v1.5@BAAI", "BAAI/bge-large-zh-v1.5@BAAI"),
("maidalun1020/bce-embedding-base_v1@Youdao", "maidalun1020/bce-embedding-base_v1@Youdao"),
("embedding-3@ZHIPU-AI", "embedding-3@ZHIPU-AI"),
],
ids=["builtin_baai", "builtin_youdao", "tenant_zhipu"],
)
def test_embedding_model(self, api_key, name, embedding_model):
payload = {"name": name, "embedding_model": embedding_model}
res = create_dataset(api_key, payload)
assert res["code"] == 0, res
assert res["data"]["embedding_model"] == embedding_model, res
@pytest.mark.p2
@pytest.mark.parametrize(
"name, embedding_model",
[
("unknown_llm_name", "unknown@ZHIPU-AI"),
("unknown_llm_factory", "embedding-3@unknown"),
("tenant_no_auth_default_tenant_llm", "text-embedding-v3@Tongyi-Qianwen"),
("tenant_no_auth", "text-embedding-3-small@OpenAI"),
],
ids=["unknown_llm_name", "unknown_llm_factory", "tenant_no_auth_default_tenant_llm", "tenant_no_auth"],
)
def test_embedding_model_invalid(self, api_key, name, embedding_model):
payload = {"name": name, "embedding_model": embedding_model}
res = create_dataset(api_key, payload)
assert res["code"] == 101, res
if "tenant_no_auth" in name:
assert res["message"] == f"Unauthorized model: <{embedding_model}>", res
else:
assert res["message"] == f"Unsupported model: <{embedding_model}>", res
@pytest.mark.p2
@pytest.mark.parametrize(
"name, embedding_model",
[
("missing_at", "BAAI/bge-large-zh-v1.5BAAI"),
("missing_model_name", "@BAAI"),
("missing_provider", "BAAI/bge-large-zh-v1.5@"),
("whitespace_only_model_name", " @BAAI"),
("whitespace_only_provider", "BAAI/bge-large-zh-v1.5@ "),
],
ids=["missing_at", "empty_model_name", "empty_provider", "whitespace_only_model_name", "whitespace_only_provider"],
)
def test_embedding_model_format(self, api_key, name, embedding_model):
payload = {"name": name, "embedding_model": embedding_model}
res = create_dataset(api_key, payload)
assert res["code"] == 101, res
if name == "missing_at":
assert "Embedding model identifier must follow <model_name>@<provider> format" in res["message"], res
else:
assert "Both model_name and provider must be non-empty strings" in res["message"], res
@pytest.mark.p2
def test_embedding_model_unset(self, api_key):
payload = {"name": "embedding_model_unset"}
res = create_dataset(api_key, payload)
assert res["code"] == 0, res
assert res["data"]["embedding_model"] == "BAAI/bge-large-zh-v1.5@BAAI", res
@pytest.mark.p2
def test_embedding_model_none(self, api_key):
payload = {"name": "embedding_model_none", "embedding_model": None}
res = create_dataset(api_key, payload)
assert res["code"] == 101, res
assert "Input should be a valid string" in res["message"], res
@pytest.mark.p1
@pytest.mark.parametrize(
"name, permission",
[
("me", "me"),
("team", "team"),
("me_upercase", "ME"),
("team_upercase", "TEAM"),
("whitespace", " ME "),
],
ids=["me", "team", "me_upercase", "team_upercase", "whitespace"],
)
def test_permission(self, api_key, name, permission):
payload = {"name": name, "permission": permission}
res = create_dataset(api_key, payload)
assert res["code"] == 0, res
assert res["data"]["permission"] == permission.lower().strip(), res
@pytest.mark.p2
@pytest.mark.parametrize(
"name, permission",
[
("empty", ""),
("unknown", "unknown"),
("type_error", list()),
],
ids=["empty", "unknown", "type_error"],
)
def test_permission_invalid(self, api_key, name, permission):
payload = {"name": name, "permission": permission}
res = create_dataset(api_key, payload)
assert res["code"] == 101
assert "Input should be 'me' or 'team'" in res["message"]
@pytest.mark.p2
def test_permission_unset(self, api_key):
payload = {"name": "permission_unset"}
res = create_dataset(api_key, payload)
assert res["code"] == 0, res
assert res["data"]["permission"] == "me", res
@pytest.mark.p3
def test_permission_none(self, api_key):
payload = {"name": "permission_none", "permission": None}
res = create_dataset(api_key, payload)
assert res["code"] == 101, res
assert "Input should be 'me' or 'team'" in res["message"], res
@pytest.mark.p1
@pytest.mark.parametrize(
"name, chunk_method",
[
("naive", "naive"),
("book", "book"),
("email", "email"),
("laws", "laws"),
("manual", "manual"),
("one", "one"),
("paper", "paper"),
("picture", "picture"),
("presentation", "presentation"),
("qa", "qa"),
("table", "table"),
("tag", "tag"),
],
ids=["naive", "book", "email", "laws", "manual", "one", "paper", "picture", "presentation", "qa", "table", "tag"],
)
def test_chunk_method(self, api_key, name, chunk_method):
payload = {"name": name, "chunk_method": chunk_method}
res = create_dataset(api_key, payload)
assert res["code"] == 0, res
assert res["data"]["chunk_method"] == chunk_method, res
@pytest.mark.p2
@pytest.mark.parametrize(
"name, chunk_method",
[
("empty", ""),
("unknown", "unknown"),
("type_error", list()),
],
ids=["empty", "unknown", "type_error"],
)
def test_chunk_method_invalid(self, api_key, name, chunk_method):
payload = {"name": name, "chunk_method": chunk_method}
res = create_dataset(api_key, payload)
assert res["code"] == 101, res
assert "Input should be 'naive', 'book', 'email', 'laws', 'manual', 'one', 'paper', 'picture', 'presentation', 'qa', 'table' or 'tag'" in res["message"], res
@pytest.mark.p2
def test_chunk_method_unset(self, api_key):
payload = {"name": "chunk_method_unset"}
res = create_dataset(api_key, payload)
assert res["code"] == 0, res
assert res["data"]["chunk_method"] == "naive", res
@pytest.mark.p3
def test_chunk_method_none(self, api_key):
payload = {"name": "chunk_method_none", "chunk_method": None}
res = create_dataset(api_key, payload)
assert res["code"] == 101, res
assert "Input should be 'naive', 'book', 'email', 'laws', 'manual', 'one', 'paper', 'picture', 'presentation', 'qa', 'table' or 'tag'" in res["message"], res
@pytest.mark.p2
@pytest.mark.parametrize(
"name, pagerank",
[
("pagerank_min", 0),
("pagerank_mid", 50),
("pagerank_max", 100),
],
ids=["min", "mid", "max"],
)
def test_pagerank(self, api_key, name, pagerank):
payload = {"name": name, "pagerank": pagerank}
res = create_dataset(api_key, payload)
assert res["code"] == 0, res
assert res["data"]["pagerank"] == pagerank, res
@pytest.mark.p3
@pytest.mark.parametrize(
"name, pagerank, expected_message",
[
("pagerank_min_limit", -1, "Input should be greater than or equal to 0"),
("pagerank_max_limit", 101, "Input should be less than or equal to 100"),
],
ids=["min_limit", "max_limit"],
)
def test_pagerank_invalid(self, api_key, name, pagerank, expected_message):
payload = {"name": name, "pagerank": pagerank}
res = create_dataset(api_key, payload)
assert res["code"] == 101, res
assert expected_message in res["message"], res
@pytest.mark.p3
def test_pagerank_unset(self, api_key):
payload = {"name": "pagerank_unset"}
res = create_dataset(api_key, payload)
assert res["code"] == 0, res
assert res["data"]["pagerank"] == 0, res
@pytest.mark.p3
def test_pagerank_none(self, api_key):
payload = {"name": "pagerank_unset", "pagerank": None}
res = create_dataset(api_key, payload)
assert res["code"] == 101, res
assert "Input should be a valid integer" in res["message"], res
@pytest.mark.p1
@pytest.mark.parametrize(
"name, parser_config",
[
("auto_keywords_min", {"auto_keywords": 0}),
("auto_keywords_mid", {"auto_keywords": 16}),
("auto_keywords_max", {"auto_keywords": 32}),
("auto_questions_min", {"auto_questions": 0}),
("auto_questions_mid", {"auto_questions": 5}),
("auto_questions_max", {"auto_questions": 10}),
("chunk_token_num_min", {"chunk_token_num": 1}),
("chunk_token_num_mid", {"chunk_token_num": 1024}),
("chunk_token_num_max", {"chunk_token_num": 2048}),
("delimiter", {"delimiter": "\n"}),
("delimiter_space", {"delimiter": " "}),
("html4excel_true", {"html4excel": True}),
("html4excel_false", {"html4excel": False}),
("layout_recognize_DeepDOC", {"layout_recognize": "DeepDOC"}),
("layout_recognize_navie", {"layout_recognize": "Plain Text"}),
("tag_kb_ids", {"tag_kb_ids": ["1", "2"]}),
("topn_tags_min", {"topn_tags": 1}),
("topn_tags_mid", {"topn_tags": 5}),
("topn_tags_max", {"topn_tags": 10}),
("filename_embd_weight_min", {"filename_embd_weight": 0.1}),
("filename_embd_weight_mid", {"filename_embd_weight": 0.5}),
("filename_embd_weight_max", {"filename_embd_weight": 1.0}),
("task_page_size_min", {"task_page_size": 1}),
("task_page_size_None", {"task_page_size": None}),
("pages", {"pages": [[1, 100]]}),
("pages_none", {"pages": None}),
("graphrag_true", {"graphrag": {"use_graphrag": True}}),
("graphrag_false", {"graphrag": {"use_graphrag": False}}),
("graphrag_entity_types", {"graphrag": {"entity_types": ["age", "sex", "height", "weight"]}}),
("graphrag_method_general", {"graphrag": {"method": "general"}}),
("graphrag_method_light", {"graphrag": {"method": "light"}}),
("graphrag_community_true", {"graphrag": {"community": True}}),
("graphrag_community_false", {"graphrag": {"community": False}}),
("graphrag_resolution_true", {"graphrag": {"resolution": True}}),
("graphrag_resolution_false", {"graphrag": {"resolution": False}}),
("raptor_true", {"raptor": {"use_raptor": True}}),
("raptor_false", {"raptor": {"use_raptor": False}}),
("raptor_prompt", {"raptor": {"prompt": "Who are you?"}}),
("raptor_max_token_min", {"raptor": {"max_token": 1}}),
("raptor_max_token_mid", {"raptor": {"max_token": 1024}}),
("raptor_max_token_max", {"raptor": {"max_token": 2048}}),
("raptor_threshold_min", {"raptor": {"threshold": 0.0}}),
("raptor_threshold_mid", {"raptor": {"threshold": 0.5}}),
("raptor_threshold_max", {"raptor": {"threshold": 1.0}}),
("raptor_max_cluster_min", {"raptor": {"max_cluster": 1}}),
("raptor_max_cluster_mid", {"raptor": {"max_cluster": 512}}),
("raptor_max_cluster_max", {"raptor": {"max_cluster": 1024}}),
("raptor_random_seed_min", {"raptor": {"random_seed": 0}}),
],
ids=[
"auto_keywords_min",
"auto_keywords_mid",
"auto_keywords_max",
"auto_questions_min",
"auto_questions_mid",
"auto_questions_max",
"chunk_token_num_min",
"chunk_token_num_mid",
"chunk_token_num_max",
"delimiter",
"delimiter_space",
"html4excel_true",
"html4excel_false",
"layout_recognize_DeepDOC",
"layout_recognize_navie",
"tag_kb_ids",
"topn_tags_min",
"topn_tags_mid",
"topn_tags_max",
"filename_embd_weight_min",
"filename_embd_weight_mid",
"filename_embd_weight_max",
"task_page_size_min",
"task_page_size_None",
"pages",
"pages_none",
"graphrag_true",
"graphrag_false",
"graphrag_entity_types",
"graphrag_method_general",
"graphrag_method_light",
"graphrag_community_true",
"graphrag_community_false",
"graphrag_resolution_true",
"graphrag_resolution_false",
"raptor_true",
"raptor_false",
"raptor_prompt",
"raptor_max_token_min",
"raptor_max_token_mid",
"raptor_max_token_max",
"raptor_threshold_min",
"raptor_threshold_mid",
"raptor_threshold_max",
"raptor_max_cluster_min",
"raptor_max_cluster_mid",
"raptor_max_cluster_max",
"raptor_random_seed_min",
],
)
def test_parser_config(self, api_key, name, parser_config):
payload = {"name": name, "parser_config": parser_config}
res = create_dataset(api_key, payload)
assert res["code"] == 0, res
for k, v in parser_config.items():
if isinstance(v, dict):
for kk, vv in v.items():
assert res["data"]["parser_config"][k][kk] == vv, res
else:
assert res["data"]["parser_config"][k] == v, res
@pytest.mark.p2
@pytest.mark.parametrize(
"name, parser_config, expected_message",
[
("auto_keywords_min_limit", {"auto_keywords": -1}, "Input should be greater than or equal to 0"),
("auto_keywords_max_limit", {"auto_keywords": 33}, "Input should be less than or equal to 32"),
("auto_keywords_float_not_allowed", {"auto_keywords": 3.14}, "Input should be a valid integer, got a number with a fractional part"),
("auto_keywords_type_invalid", {"auto_keywords": "string"}, "Input should be a valid integer, unable to parse string as an integer"),
("auto_questions_min_limit", {"auto_questions": -1}, "Input should be greater than or equal to 0"),
("auto_questions_max_limit", {"auto_questions": 11}, "Input should be less than or equal to 10"),
("auto_questions_float_not_allowed", {"auto_questions": 3.14}, "Input should be a valid integer, got a number with a fractional part"),
("auto_questions_type_invalid", {"auto_questions": "string"}, "Input should be a valid integer, unable to parse string as an integer"),
("chunk_token_num_min_limit", {"chunk_token_num": 0}, "Input should be greater than or equal to 1"),
("chunk_token_num_max_limit", {"chunk_token_num": 2049}, "Input should be less than or equal to 2048"),
("chunk_token_num_float_not_allowed", {"chunk_token_num": 3.14}, "Input should be a valid integer, got a number with a fractional part"),
("chunk_token_num_type_invalid", {"chunk_token_num": "string"}, "Input should be a valid integer, unable to parse string as an integer"),
("delimiter_empty", {"delimiter": ""}, "String should have at least 1 character"),
("html4excel_type_invalid", {"html4excel": "string"}, "Input should be a valid boolean, unable to interpret input"),
("tag_kb_ids_not_list", {"tag_kb_ids": "1,2"}, "Input should be a valid list"),
("tag_kb_ids_int_in_list", {"tag_kb_ids": [1, 2]}, "Input should be a valid string"),
("topn_tags_min_limit", {"topn_tags": 0}, "Input should be greater than or equal to 1"),
("topn_tags_max_limit", {"topn_tags": 11}, "Input should be less than or equal to 10"),
("topn_tags_float_not_allowed", {"topn_tags": 3.14}, "Input should be a valid integer, got a number with a fractional part"),
("topn_tags_type_invalid", {"topn_tags": "string"}, "Input should be a valid integer, unable to parse string as an integer"),
("filename_embd_weight_min_limit", {"filename_embd_weight": -1}, "Input should be greater than or equal to 0"),
("filename_embd_weight_max_limit", {"filename_embd_weight": 1.1}, "Input should be less than or equal to 1"),
("filename_embd_weight_type_invalid", {"filename_embd_weight": "string"}, "Input should be a valid number, unable to parse string as a number"),
("task_page_size_min_limit", {"task_page_size": 0}, "Input should be greater than or equal to 1"),
("task_page_size_float_not_allowed", {"task_page_size": 3.14}, "Input should be a valid integer, got a number with a fractional part"),
("task_page_size_type_invalid", {"task_page_size": "string"}, "Input should be a valid integer, unable to parse string as an integer"),
("pages_not_list", {"pages": "1,2"}, "Input should be a valid list"),
("pages_not_list_in_list", {"pages": ["1,2"]}, "Input should be a valid list"),
("pages_not_int_list", {"pages": [["string1", "string2"]]}, "Input should be a valid integer, unable to parse string as an integer"),
("graphrag_type_invalid", {"graphrag": {"use_graphrag": "string"}}, "Input should be a valid boolean, unable to interpret input"),
("graphrag_entity_types_not_list", {"graphrag": {"entity_types": "1,2"}}, "Input should be a valid list"),
("graphrag_entity_types_not_str_in_list", {"graphrag": {"entity_types": [1, 2]}}, "nput should be a valid string"),
("graphrag_method_unknown", {"graphrag": {"method": "unknown"}}, "Input should be 'light' or 'general'"),
("graphrag_method_none", {"graphrag": {"method": None}}, "Input should be 'light' or 'general'"),
("graphrag_community_type_invalid", {"graphrag": {"community": "string"}}, "Input should be a valid boolean, unable to interpret input"),
("graphrag_resolution_type_invalid", {"graphrag": {"resolution": "string"}}, "Input should be a valid boolean, unable to interpret input"),
("raptor_type_invalid", {"raptor": {"use_raptor": "string"}}, "Input should be a valid boolean, unable to interpret input"),
("raptor_prompt_empty", {"raptor": {"prompt": ""}}, "String should have at least 1 character"),
("raptor_prompt_space", {"raptor": {"prompt": " "}}, "String should have at least 1 character"),
("raptor_max_token_min_limit", {"raptor": {"max_token": 0}}, "Input should be greater than or equal to 1"),
("raptor_max_token_max_limit", {"raptor": {"max_token": 2049}}, "Input should be less than or equal to 2048"),
("raptor_max_token_float_not_allowed", {"raptor": {"max_token": 3.14}}, "Input should be a valid integer, got a number with a fractional part"),
("raptor_max_token_type_invalid", {"raptor": {"max_token": "string"}}, "Input should be a valid integer, unable to parse string as an integer"),
("raptor_threshold_min_limit", {"raptor": {"threshold": -0.1}}, "Input should be greater than or equal to 0"),
("raptor_threshold_max_limit", {"raptor": {"threshold": 1.1}}, "Input should be less than or equal to 1"),
("raptor_threshold_type_invalid", {"raptor": {"threshold": "string"}}, "Input should be a valid number, unable to parse string as a number"),
("raptor_max_cluster_min_limit", {"raptor": {"max_cluster": 0}}, "Input should be greater than or equal to 1"),
("raptor_max_cluster_max_limit", {"raptor": {"max_cluster": 1025}}, "Input should be less than or equal to 1024"),
("raptor_max_cluster_float_not_allowed", {"raptor": {"max_cluster": 3.14}}, "Input should be a valid integer, got a number with a fractional par"),
("raptor_max_cluster_type_invalid", {"raptor": {"max_cluster": "string"}}, "Input should be a valid integer, unable to parse string as an integer"),
("raptor_random_seed_min_limit", {"raptor": {"random_seed": -1}}, "Input should be greater than or equal to 0"),
("raptor_random_seed_float_not_allowed", {"raptor": {"random_seed": 3.14}}, "Input should be a valid integer, got a number with a fractional part"),
("raptor_random_seed_type_invalid", {"raptor": {"random_seed": "string"}}, "Input should be a valid integer, unable to parse string as an integer"),
("parser_config_type_invalid", {"delimiter": "a" * 65536}, "Parser config exceeds size limit (max 65,535 characters)"),
],
ids=[
"auto_keywords_min_limit",
"auto_keywords_max_limit",
"auto_keywords_float_not_allowed",
"auto_keywords_type_invalid",
"auto_questions_min_limit",
"auto_questions_max_limit",
"auto_questions_float_not_allowed",
"auto_questions_type_invalid",
"chunk_token_num_min_limit",
"chunk_token_num_max_limit",
"chunk_token_num_float_not_allowed",
"chunk_token_num_type_invalid",
"delimiter_empty",
"html4excel_type_invalid",
"tag_kb_ids_not_list",
"tag_kb_ids_int_in_list",
"topn_tags_min_limit",
"topn_tags_max_limit",
"topn_tags_float_not_allowed",
"topn_tags_type_invalid",
"filename_embd_weight_min_limit",
"filename_embd_weight_max_limit",
"filename_embd_weight_type_invalid",
"task_page_size_min_limit",
"task_page_size_float_not_allowed",
"task_page_size_type_invalid",
"pages_not_list",
"pages_not_list_in_list",
"pages_not_int_list",
"graphrag_type_invalid",
"graphrag_entity_types_not_list",
"graphrag_entity_types_not_str_in_list",
"graphrag_method_unknown",
"graphrag_method_none",
"graphrag_community_type_invalid",
"graphrag_resolution_type_invalid",
"raptor_type_invalid",
"raptor_prompt_empty",
"raptor_prompt_space",
"raptor_max_token_min_limit",
"raptor_max_token_max_limit",
"raptor_max_token_float_not_allowed",
"raptor_max_token_type_invalid",
"raptor_threshold_min_limit",
"raptor_threshold_max_limit",
"raptor_threshold_type_invalid",
"raptor_max_cluster_min_limit",
"raptor_max_cluster_max_limit",
"raptor_max_cluster_float_not_allowed",
"raptor_max_cluster_type_invalid",
"raptor_random_seed_min_limit",
"raptor_random_seed_float_not_allowed",
"raptor_random_seed_type_invalid",
"parser_config_type_invalid",
],
)
def test_parser_config_invalid(self, api_key, name, parser_config, expected_message):
payload = {"name": name, "parser_config": parser_config}
res = create_dataset(api_key, payload)
assert res["code"] == 101, res
assert expected_message in res["message"], res
@pytest.mark.p2
def test_parser_config_empty(self, api_key):
payload = {"name": "parser_config_empty", "parser_config": {}}
res = create_dataset(api_key, payload)
assert res["code"] == 0, res
assert res["data"]["parser_config"] == {
"chunk_token_num": 128,
"delimiter": r"\n",
"html4excel": False,
"layout_recognize": "DeepDOC",
"raptor": {"use_raptor": False},
}, res
@pytest.mark.p2
def test_parser_config_unset(self, api_key):
payload = {"name": "parser_config_unset"}
res = create_dataset(api_key, payload)
assert res["code"] == 0, res
assert res["data"]["parser_config"] == {
"chunk_token_num": 128,
"delimiter": r"\n",
"html4excel": False,
"layout_recognize": "DeepDOC",
"raptor": {"use_raptor": False},
}, res
@pytest.mark.p3
def test_parser_config_none(self, api_key):
payload = {"name": "parser_config_none", "parser_config": None}
res = create_dataset(api_key, payload)
assert res["code"] == 0, res
assert res["data"]["parser_config"] == {
"chunk_token_num": 128,
"delimiter": "\\n",
"html4excel": False,
"layout_recognize": "DeepDOC",
"raptor": {"use_raptor": False},
}, res
@pytest.mark.p2
@pytest.mark.parametrize(
"payload",
[
{"name": "id", "id": "id"},
{"name": "tenant_id", "tenant_id": "e57c1966f99211efb41e9e45646e0111"},
{"name": "created_by", "created_by": "created_by"},
{"name": "create_date", "create_date": "Tue, 11 Mar 2025 13:37:23 GMT"},
{"name": "create_time", "create_time": 1741671443322},
{"name": "update_date", "update_date": "Tue, 11 Mar 2025 13:37:23 GMT"},
{"name": "update_time", "update_time": 1741671443339},
{"name": "document_count", "document_count": 1},
{"name": "chunk_count", "chunk_count": 1},
{"name": "token_num", "token_num": 1},
{"name": "status", "status": "1"},
{"name": "unknown_field", "unknown_field": "unknown_field"},
],
)
def test_unsupported_field(self, api_key, payload):
res = create_dataset(api_key, payload)
assert res["code"] == 101, res
assert "Extra inputs are not permitted" in res["message"], res

View File

@ -0,0 +1,219 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import uuid
from concurrent.futures import ThreadPoolExecutor
import pytest
from common import (
INVALID_API_TOKEN,
batch_create_datasets,
delete_datasets,
list_datasets,
)
from libs.auth import RAGFlowHttpApiAuth
class TestAuthorization:
@pytest.mark.p1
@pytest.mark.parametrize(
"invalid_auth, expected_code, expected_message",
[
(None, 0, "`Authorization` can't be empty"),
(
RAGFlowHttpApiAuth(INVALID_API_TOKEN),
109,
"Authentication error: API key is invalid!",
),
],
)
def test_auth_invalid(self, invalid_auth, expected_code, expected_message):
res = delete_datasets(invalid_auth)
assert res["code"] == expected_code, res
assert res["message"] == expected_message, res
class TestRquest:
@pytest.mark.p3
def test_content_type_bad(self, api_key):
BAD_CONTENT_TYPE = "text/xml"
res = delete_datasets(api_key, headers={"Content-Type": BAD_CONTENT_TYPE})
assert res["code"] == 101, res
assert res["message"] == f"Unsupported content type: Expected application/json, got {BAD_CONTENT_TYPE}", res
@pytest.mark.p3
@pytest.mark.parametrize(
"payload, expected_message",
[
("a", "Malformed JSON syntax: Missing commas/brackets or invalid encoding"),
('"a"', "Invalid request payload: expected object, got str"),
],
ids=["malformed_json_syntax", "invalid_request_payload_type"],
)
def test_payload_bad(self, api_key, payload, expected_message):
res = delete_datasets(api_key, data=payload)
assert res["code"] == 101, res
assert res["message"] == expected_message, res
@pytest.mark.p3
def test_payload_unset(self, api_key):
res = delete_datasets(api_key, None)
assert res["code"] == 101, res
assert res["message"] == "Malformed JSON syntax: Missing commas/brackets or invalid encoding", res
class TestCapability:
@pytest.mark.p3
def test_delete_dataset_1k(self, api_key):
ids = batch_create_datasets(api_key, 1_000)
res = delete_datasets(api_key, {"ids": ids})
assert res["code"] == 0, res
res = list_datasets(api_key)
assert len(res["data"]) == 0, res
@pytest.mark.p3
def test_concurrent_deletion(self, api_key):
dataset_num = 1_000
ids = batch_create_datasets(api_key, dataset_num)
with ThreadPoolExecutor(max_workers=5) as executor:
futures = [executor.submit(delete_datasets, api_key, {"ids": ids[i : i + 1]}) for i in range(dataset_num)]
responses = [f.result() for f in futures]
assert all(r["code"] == 0 for r in responses), responses
class TestDatasetsDelete:
@pytest.mark.p1
@pytest.mark.parametrize(
"func, expected_code, expected_message, remaining",
[
(lambda r: {"ids": r[:1]}, 0, "", 2),
(lambda r: {"ids": r}, 0, "", 0),
],
ids=["single_dataset", "multiple_datasets"],
)
def test_ids(self, api_key, add_datasets_func, func, expected_code, expected_message, remaining):
dataset_ids = add_datasets_func
if callable(func):
payload = func(dataset_ids)
res = delete_datasets(api_key, payload)
assert res["code"] == expected_code, res
res = list_datasets(api_key)
assert len(res["data"]) == remaining, res
@pytest.mark.p1
@pytest.mark.usefixtures("add_dataset_func")
def test_ids_empty(self, api_key):
payload = {"ids": []}
res = delete_datasets(api_key, payload)
assert res["code"] == 0, res
res = list_datasets(api_key)
assert len(res["data"]) == 1, res
@pytest.mark.p1
@pytest.mark.usefixtures("add_datasets_func")
def test_ids_none(self, api_key):
payload = {"ids": None}
res = delete_datasets(api_key, payload)
assert res["code"] == 0, res
res = list_datasets(api_key)
assert len(res["data"]) == 0, res
@pytest.mark.p2
@pytest.mark.usefixtures("add_dataset_func")
def test_id_not_uuid(self, api_key):
payload = {"ids": ["not_uuid"]}
res = delete_datasets(api_key, payload)
assert res["code"] == 101, res
assert "Invalid UUID1 format" in res["message"], res
res = list_datasets(api_key)
assert len(res["data"]) == 1, res
@pytest.mark.p3
@pytest.mark.usefixtures("add_dataset_func")
def test_id_not_uuid1(self, api_key):
payload = {"ids": [uuid.uuid4().hex]}
res = delete_datasets(api_key, payload)
assert res["code"] == 101, res
assert "Invalid UUID1 format" in res["message"], res
@pytest.mark.p2
@pytest.mark.usefixtures("add_dataset_func")
def test_id_wrong_uuid(self, api_key):
payload = {"ids": ["d94a8dc02c9711f0930f7fbc369eab6d"]}
res = delete_datasets(api_key, payload)
assert res["code"] == 108, res
assert "lacks permission for dataset" in res["message"], res
res = list_datasets(api_key)
assert len(res["data"]) == 1, res
@pytest.mark.p2
@pytest.mark.parametrize(
"func",
[
lambda r: {"ids": ["d94a8dc02c9711f0930f7fbc369eab6d"] + r},
lambda r: {"ids": r[:1] + ["d94a8dc02c9711f0930f7fbc369eab6d"] + r[1:3]},
lambda r: {"ids": r + ["d94a8dc02c9711f0930f7fbc369eab6d"]},
],
)
def test_ids_partial_invalid(self, api_key, add_datasets_func, func):
dataset_ids = add_datasets_func
if callable(func):
payload = func(dataset_ids)
res = delete_datasets(api_key, payload)
assert res["code"] == 108, res
assert "lacks permission for dataset" in res["message"], res
res = list_datasets(api_key)
assert len(res["data"]) == 3, res
@pytest.mark.p2
def test_ids_duplicate(self, api_key, add_datasets_func):
dataset_ids = add_datasets_func
payload = {"ids": dataset_ids + dataset_ids}
res = delete_datasets(api_key, payload)
assert res["code"] == 101, res
assert "Duplicate ids:" in res["message"], res
res = list_datasets(api_key)
assert len(res["data"]) == 3, res
@pytest.mark.p2
def test_repeated_delete(self, api_key, add_datasets_func):
dataset_ids = add_datasets_func
payload = {"ids": dataset_ids}
res = delete_datasets(api_key, payload)
assert res["code"] == 0, res
res = delete_datasets(api_key, payload)
assert res["code"] == 108, res
assert "lacks permission for dataset" in res["message"], res
@pytest.mark.p2
@pytest.mark.usefixtures("add_dataset_func")
def test_field_unsupported(self, api_key):
payload = {"unknown_field": "unknown_field"}
res = delete_datasets(api_key, payload)
assert res["code"] == 101, res
assert "Extra inputs are not permitted" in res["message"], res
res = list_datasets(api_key)
assert len(res["data"]) == 1, res

View File

@ -0,0 +1,339 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import uuid
from concurrent.futures import ThreadPoolExecutor
import pytest
from common import INVALID_API_TOKEN, list_datasets
from libs.auth import RAGFlowHttpApiAuth
from utils import is_sorted
class TestAuthorization:
@pytest.mark.p1
@pytest.mark.parametrize(
"invalid_auth, expected_code, expected_message",
[
(None, 0, "`Authorization` can't be empty"),
(
RAGFlowHttpApiAuth(INVALID_API_TOKEN),
109,
"Authentication error: API key is invalid!",
),
],
)
def test_auth_invalid(self, invalid_auth, expected_code, expected_message):
res = list_datasets(invalid_auth)
assert res["code"] == expected_code, res
assert res["message"] == expected_message, res
class TestCapability:
@pytest.mark.p3
def test_concurrent_list(self, api_key):
with ThreadPoolExecutor(max_workers=5) as executor:
futures = [executor.submit(list_datasets, api_key) for i in range(100)]
responses = [f.result() for f in futures]
assert all(r["code"] == 0 for r in responses), responses
@pytest.mark.usefixtures("add_datasets")
class TestDatasetsList:
@pytest.mark.p1
def test_params_unset(self, api_key):
res = list_datasets(api_key, None)
assert res["code"] == 0, res
assert len(res["data"]) == 5, res
@pytest.mark.p2
def test_params_empty(self, api_key):
res = list_datasets(api_key, {})
assert res["code"] == 0, res
assert len(res["data"]) == 5, res
@pytest.mark.p1
@pytest.mark.parametrize(
"params, expected_page_size",
[
({"page": 2, "page_size": 2}, 2),
({"page": 3, "page_size": 2}, 1),
({"page": 4, "page_size": 2}, 0),
({"page": "2", "page_size": 2}, 2),
({"page": 1, "page_size": 10}, 5),
],
ids=["normal_middle_page", "normal_last_partial_page", "beyond_max_page", "string_page_number", "full_data_single_page"],
)
def test_page(self, api_key, params, expected_page_size):
res = list_datasets(api_key, params)
assert res["code"] == 0, res
assert len(res["data"]) == expected_page_size, res
@pytest.mark.p2
@pytest.mark.parametrize(
"params, expected_code, expected_message",
[
({"page": 0}, 101, "Input should be greater than or equal to 1"),
({"page": "a"}, 101, "Input should be a valid integer, unable to parse string as an integer"),
],
ids=["page_0", "page_a"],
)
def test_page_invalid(self, api_key, params, expected_code, expected_message):
res = list_datasets(api_key, params=params)
assert res["code"] == expected_code, res
assert expected_message in res["message"], res
@pytest.mark.p2
def test_page_none(self, api_key):
params = {"page": None}
res = list_datasets(api_key, params)
assert res["code"] == 0, res
assert len(res["data"]) == 5, res
@pytest.mark.p1
@pytest.mark.parametrize(
"params, expected_page_size",
[
({"page_size": 1}, 1),
({"page_size": 3}, 3),
({"page_size": 5}, 5),
({"page_size": 6}, 5),
({"page_size": "1"}, 1),
],
ids=["min_valid_page_size", "medium_page_size", "page_size_equals_total", "page_size_exceeds_total", "string_type_page_size"],
)
def test_page_size(self, api_key, params, expected_page_size):
res = list_datasets(api_key, params)
assert res["code"] == 0, res
assert len(res["data"]) == expected_page_size, res
@pytest.mark.p2
@pytest.mark.parametrize(
"params, expected_code, expected_message",
[
({"page_size": 0}, 101, "Input should be greater than or equal to 1"),
({"page_size": "a"}, 101, "Input should be a valid integer, unable to parse string as an integer"),
],
)
def test_page_size_invalid(self, api_key, params, expected_code, expected_message):
res = list_datasets(api_key, params)
assert res["code"] == expected_code, res
assert expected_message in res["message"], res
@pytest.mark.p2
def test_page_size_none(self, api_key):
params = {"page_size": None}
res = list_datasets(api_key, params)
assert res["code"] == 0, res
assert len(res["data"]) == 5, res
@pytest.mark.p2
@pytest.mark.parametrize(
"params, assertions",
[
({"orderby": "create_time"}, lambda r: (is_sorted(r["data"], "create_time", True))),
({"orderby": "update_time"}, lambda r: (is_sorted(r["data"], "update_time", True))),
({"orderby": "CREATE_TIME"}, lambda r: (is_sorted(r["data"], "create_time", True))),
({"orderby": "UPDATE_TIME"}, lambda r: (is_sorted(r["data"], "update_time", True))),
({"orderby": " create_time "}, lambda r: (is_sorted(r["data"], "update_time", True))),
],
ids=["orderby_create_time", "orderby_update_time", "orderby_create_time_upper", "orderby_update_time_upper", "whitespace"],
)
def test_orderby(self, api_key, params, assertions):
res = list_datasets(api_key, params)
assert res["code"] == 0, res
if callable(assertions):
assert assertions(res), res
@pytest.mark.p3
@pytest.mark.parametrize(
"params",
[
{"orderby": ""},
{"orderby": "unknown"},
],
ids=["empty", "unknown"],
)
def test_orderby_invalid(self, api_key, params):
res = list_datasets(api_key, params)
assert res["code"] == 101, res
assert "Input should be 'create_time' or 'update_time'" in res["message"], res
@pytest.mark.p3
def test_orderby_none(self, api_key):
params = {"order_by": None}
res = list_datasets(api_key, params)
assert res["code"] == 0, res
assert is_sorted(res["data"], "create_time", True), res
@pytest.mark.p2
@pytest.mark.parametrize(
"params, assertions",
[
({"desc": True}, lambda r: (is_sorted(r["data"], "create_time", True))),
({"desc": False}, lambda r: (is_sorted(r["data"], "create_time", False))),
({"desc": "true"}, lambda r: (is_sorted(r["data"], "create_time", True))),
({"desc": "false"}, lambda r: (is_sorted(r["data"], "create_time", False))),
({"desc": 1}, lambda r: (is_sorted(r["data"], "create_time", True))),
({"desc": 0}, lambda r: (is_sorted(r["data"], "create_time", False))),
({"desc": "yes"}, lambda r: (is_sorted(r["data"], "create_time", True))),
({"desc": "no"}, lambda r: (is_sorted(r["data"], "create_time", False))),
({"desc": "y"}, lambda r: (is_sorted(r["data"], "create_time", True))),
({"desc": "n"}, lambda r: (is_sorted(r["data"], "create_time", False))),
],
ids=["desc=True", "desc=False", "desc=true", "desc=false", "desc=1", "desc=0", "desc=yes", "desc=no", "desc=y", "desc=n"],
)
def test_desc(self, api_key, params, assertions):
res = list_datasets(api_key, params)
assert res["code"] == 0, res
if callable(assertions):
assert assertions(res), res
@pytest.mark.p3
@pytest.mark.parametrize(
"params",
[
{"desc": 3.14},
{"desc": "unknown"},
],
ids=["empty", "unknown"],
)
def test_desc_invalid(self, api_key, params):
res = list_datasets(api_key, params)
assert res["code"] == 101, res
assert "Input should be a valid boolean, unable to interpret input" in res["message"], res
@pytest.mark.p3
def test_desc_none(self, api_key):
params = {"desc": None}
res = list_datasets(api_key, params)
assert res["code"] == 0, res
assert is_sorted(res["data"], "create_time", True), res
@pytest.mark.p1
def test_name(self, api_key):
params = {"name": "dataset_1"}
res = list_datasets(api_key, params)
assert res["code"] == 0, res
assert len(res["data"]) == 1, res
assert res["data"][0]["name"] == "dataset_1", res
@pytest.mark.p2
def test_name_wrong(self, api_key):
params = {"name": "wrong name"}
res = list_datasets(api_key, params)
assert res["code"] == 108, res
assert "lacks permission for dataset" in res["message"], res
@pytest.mark.p2
def test_name_empty(self, api_key):
params = {"name": ""}
res = list_datasets(api_key, params)
assert res["code"] == 0, res
assert len(res["data"]) == 5, res
@pytest.mark.p2
def test_name_none(self, api_key):
params = {"name": None}
res = list_datasets(api_key, params)
assert res["code"] == 0, res
assert len(res["data"]) == 5, res
@pytest.mark.p1
def test_id(self, api_key, add_datasets):
dataset_ids = add_datasets
params = {"id": dataset_ids[0]}
res = list_datasets(api_key, params)
assert res["code"] == 0
assert len(res["data"]) == 1
assert res["data"][0]["id"] == dataset_ids[0]
@pytest.mark.p2
def test_id_not_uuid(self, api_key):
params = {"id": "not_uuid"}
res = list_datasets(api_key, params)
assert res["code"] == 101, res
assert "Invalid UUID1 format" in res["message"], res
@pytest.mark.p2
def test_id_not_uuid1(self, api_key):
params = {"id": uuid.uuid4().hex}
res = list_datasets(api_key, params)
assert res["code"] == 101, res
assert "Invalid UUID1 format" in res["message"], res
@pytest.mark.p2
def test_id_wrong_uuid(self, api_key):
params = {"id": "d94a8dc02c9711f0930f7fbc369eab6d"}
res = list_datasets(api_key, params)
assert res["code"] == 108, res
assert "lacks permission for dataset" in res["message"], res
@pytest.mark.p2
def test_id_empty(self, api_key):
params = {"id": ""}
res = list_datasets(api_key, params)
assert res["code"] == 101, res
assert "Invalid UUID1 format" in res["message"], res
@pytest.mark.p2
def test_id_none(self, api_key):
params = {"id": None}
res = list_datasets(api_key, params)
assert res["code"] == 0, res
assert len(res["data"]) == 5, res
@pytest.mark.p2
@pytest.mark.parametrize(
"func, name, expected_num",
[
(lambda r: r[0], "dataset_0", 1),
(lambda r: r[0], "dataset_1", 0),
],
ids=["name_and_id_match", "name_and_id_mismatch"],
)
def test_name_and_id(self, api_key, add_datasets, func, name, expected_num):
dataset_ids = add_datasets
if callable(func):
params = {"id": func(dataset_ids), "name": name}
res = list_datasets(api_key, params)
assert res["code"] == 0, res
assert len(res["data"]) == expected_num, res
@pytest.mark.p3
@pytest.mark.parametrize(
"dataset_id, name",
[
(lambda r: r[0], "wrong_name"),
(uuid.uuid1().hex, "dataset_0"),
],
ids=["name", "id"],
)
def test_name_and_id_wrong(self, api_key, add_datasets, dataset_id, name):
dataset_ids = add_datasets
if callable(dataset_id):
params = {"id": dataset_id(dataset_ids), "name": name}
else:
params = {"id": dataset_id, "name": name}
res = list_datasets(api_key, params)
assert res["code"] == 108, res
assert "lacks permission for dataset" in res["message"], res
@pytest.mark.p2
def test_field_unsupported(self, api_key):
params = {"unknown_field": "unknown_field"}
res = list_datasets(api_key, params)
assert res["code"] == 101, res
assert "Extra inputs are not permitted" in res["message"], res

View File

@ -0,0 +1,820 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import uuid
from concurrent.futures import ThreadPoolExecutor
import pytest
from common import DATASET_NAME_LIMIT, INVALID_API_TOKEN, list_datasets, update_dataset
from hypothesis import HealthCheck, example, given, settings
from libs.auth import RAGFlowHttpApiAuth
from utils import encode_avatar
from utils.file_utils import create_image_file
from utils.hypothesis_utils import valid_names
# TODO: Missing scenario for updating embedding_model with chunk_count != 0
class TestAuthorization:
@pytest.mark.p1
@pytest.mark.parametrize(
"invalid_auth, expected_code, expected_message",
[
(None, 0, "`Authorization` can't be empty"),
(
RAGFlowHttpApiAuth(INVALID_API_TOKEN),
109,
"Authentication error: API key is invalid!",
),
],
ids=["empty_auth", "invalid_api_token"],
)
def test_auth_invalid(self, invalid_auth, expected_code, expected_message):
res = update_dataset(invalid_auth, "dataset_id")
assert res["code"] == expected_code, res
assert res["message"] == expected_message, res
class TestRquest:
@pytest.mark.p3
def test_bad_content_type(self, api_key, add_dataset_func):
dataset_id = add_dataset_func
BAD_CONTENT_TYPE = "text/xml"
res = update_dataset(api_key, dataset_id, {"name": "bad_content_type"}, headers={"Content-Type": BAD_CONTENT_TYPE})
assert res["code"] == 101, res
assert res["message"] == f"Unsupported content type: Expected application/json, got {BAD_CONTENT_TYPE}", res
@pytest.mark.p3
@pytest.mark.parametrize(
"payload, expected_message",
[
("a", "Malformed JSON syntax: Missing commas/brackets or invalid encoding"),
('"a"', "Invalid request payload: expected object, got str"),
],
ids=["malformed_json_syntax", "invalid_request_payload_type"],
)
def test_payload_bad(self, api_key, add_dataset_func, payload, expected_message):
dataset_id = add_dataset_func
res = update_dataset(api_key, dataset_id, data=payload)
assert res["code"] == 101, res
assert res["message"] == expected_message, res
@pytest.mark.p2
def test_payload_empty(self, api_key, add_dataset_func):
dataset_id = add_dataset_func
res = update_dataset(api_key, dataset_id, {})
assert res["code"] == 101, res
assert res["message"] == "No properties were modified", res
@pytest.mark.p3
def test_payload_unset(self, api_key, add_dataset_func):
dataset_id = add_dataset_func
res = update_dataset(api_key, dataset_id, None)
assert res["code"] == 101, res
assert res["message"] == "Malformed JSON syntax: Missing commas/brackets or invalid encoding", res
class TestCapability:
@pytest.mark.p3
def test_update_dateset_concurrent(self, api_key, add_dataset_func):
dataset_id = add_dataset_func
with ThreadPoolExecutor(max_workers=5) as executor:
futures = [executor.submit(update_dataset, api_key, dataset_id, {"name": f"dataset_{i}"}) for i in range(100)]
responses = [f.result() for f in futures]
assert all(r["code"] == 0 for r in responses), responses
class TestDatasetUpdate:
@pytest.mark.p3
def test_dataset_id_not_uuid(self, api_key):
payload = {"name": "not uuid"}
res = update_dataset(api_key, "not_uuid", payload)
assert res["code"] == 101, res
assert "Invalid UUID1 format" in res["message"], res
@pytest.mark.p3
def test_dataset_id_not_uuid1(self, api_key):
payload = {"name": "not uuid1"}
res = update_dataset(api_key, uuid.uuid4().hex, payload)
assert res["code"] == 101, res
assert "Invalid UUID1 format" in res["message"], res
@pytest.mark.p3
def test_dataset_id_wrong_uuid(self, api_key):
payload = {"name": "wrong uuid"}
res = update_dataset(api_key, "d94a8dc02c9711f0930f7fbc369eab6d", payload)
assert res["code"] == 108, res
assert "lacks permission for dataset" in res["message"], res
@pytest.mark.p1
@given(name=valid_names())
@example("a" * 128)
@settings(max_examples=20, suppress_health_check=[HealthCheck.function_scoped_fixture])
def test_name(self, api_key, add_dataset_func, name):
dataset_id = add_dataset_func
payload = {"name": name}
res = update_dataset(api_key, dataset_id, payload)
assert res["code"] == 0, res
res = list_datasets(api_key)
assert res["code"] == 0, res
assert res["data"][0]["name"] == name, res
@pytest.mark.p2
@pytest.mark.parametrize(
"name, expected_message",
[
("", "String should have at least 1 character"),
(" ", "String should have at least 1 character"),
("a" * (DATASET_NAME_LIMIT + 1), "String should have at most 128 characters"),
(0, "Input should be a valid string"),
(None, "Input should be a valid string"),
],
ids=["empty_name", "space_name", "too_long_name", "invalid_name", "None_name"],
)
def test_name_invalid(self, api_key, add_dataset_func, name, expected_message):
dataset_id = add_dataset_func
payload = {"name": name}
res = update_dataset(api_key, dataset_id, payload)
assert res["code"] == 101, res
assert expected_message in res["message"], res
@pytest.mark.p3
def test_name_duplicated(self, api_key, add_datasets_func):
dataset_ids = add_datasets_func[0]
name = "dataset_1"
payload = {"name": name}
res = update_dataset(api_key, dataset_ids, payload)
assert res["code"] == 102, res
assert res["message"] == f"Dataset name '{name}' already exists", res
@pytest.mark.p3
def test_name_case_insensitive(self, api_key, add_datasets_func):
dataset_id = add_datasets_func[0]
name = "DATASET_1"
payload = {"name": name}
res = update_dataset(api_key, dataset_id, payload)
assert res["code"] == 102, res
assert res["message"] == f"Dataset name '{name}' already exists", res
@pytest.mark.p2
def test_avatar(self, api_key, add_dataset_func, tmp_path):
dataset_id = add_dataset_func
fn = create_image_file(tmp_path / "ragflow_test.png")
payload = {
"avatar": f"data:image/png;base64,{encode_avatar(fn)}",
}
res = update_dataset(api_key, dataset_id, payload)
assert res["code"] == 0, res
res = list_datasets(api_key)
assert res["code"] == 0, res
assert res["data"][0]["avatar"] == f"data:image/png;base64,{encode_avatar(fn)}", res
@pytest.mark.p2
def test_avatar_exceeds_limit_length(self, api_key, add_dataset_func):
dataset_id = add_dataset_func
payload = {"avatar": "a" * 65536}
res = update_dataset(api_key, dataset_id, payload)
assert res["code"] == 101, res
assert "String should have at most 65535 characters" in res["message"], res
@pytest.mark.p3
@pytest.mark.parametrize(
"avatar_prefix, expected_message",
[
("", "Missing MIME prefix. Expected format: data:<mime>;base64,<data>"),
("data:image/png;base64", "Missing MIME prefix. Expected format: data:<mime>;base64,<data>"),
("invalid_mine_prefix:image/png;base64,", "Invalid MIME prefix format. Must start with 'data:'"),
("data:unsupported_mine_type;base64,", "Unsupported MIME type. Allowed: ['image/jpeg', 'image/png']"),
],
ids=["empty_prefix", "missing_comma", "unsupported_mine_type", "invalid_mine_type"],
)
def test_avatar_invalid_prefix(self, api_key, add_dataset_func, tmp_path, avatar_prefix, expected_message):
dataset_id = add_dataset_func
fn = create_image_file(tmp_path / "ragflow_test.png")
payload = {"avatar": f"{avatar_prefix}{encode_avatar(fn)}"}
res = update_dataset(api_key, dataset_id, payload)
assert res["code"] == 101, res
assert expected_message in res["message"], res
@pytest.mark.p3
def test_avatar_none(self, api_key, add_dataset_func):
dataset_id = add_dataset_func
payload = {"avatar": None}
res = update_dataset(api_key, dataset_id, payload)
assert res["code"] == 0, res
res = list_datasets(api_key)
assert res["code"] == 0, res
assert res["data"][0]["avatar"] is None, res
@pytest.mark.p2
def test_description(self, api_key, add_dataset_func):
dataset_id = add_dataset_func
payload = {"description": "description"}
res = update_dataset(api_key, dataset_id, payload)
assert res["code"] == 0
res = list_datasets(api_key, {"id": dataset_id})
assert res["code"] == 0, res
assert res["data"][0]["description"] == "description"
@pytest.mark.p2
def test_description_exceeds_limit_length(self, api_key, add_dataset_func):
dataset_id = add_dataset_func
payload = {"description": "a" * 65536}
res = update_dataset(api_key, dataset_id, payload)
assert res["code"] == 101, res
assert "String should have at most 65535 characters" in res["message"], res
@pytest.mark.p3
def test_description_none(self, api_key, add_dataset_func):
dataset_id = add_dataset_func
payload = {"description": None}
res = update_dataset(api_key, dataset_id, payload)
assert res["code"] == 0, res
res = list_datasets(api_key, {"id": dataset_id})
assert res["code"] == 0, res
assert res["data"][0]["description"] is None
@pytest.mark.p1
@pytest.mark.parametrize(
"embedding_model",
[
"BAAI/bge-large-zh-v1.5@BAAI",
"maidalun1020/bce-embedding-base_v1@Youdao",
"embedding-3@ZHIPU-AI",
],
ids=["builtin_baai", "builtin_youdao", "tenant_zhipu"],
)
def test_embedding_model(self, api_key, add_dataset_func, embedding_model):
dataset_id = add_dataset_func
payload = {"embedding_model": embedding_model}
res = update_dataset(api_key, dataset_id, payload)
assert res["code"] == 0, res
res = list_datasets(api_key)
assert res["code"] == 0, res
assert res["data"][0]["embedding_model"] == embedding_model, res
@pytest.mark.p2
@pytest.mark.parametrize(
"name, embedding_model",
[
("unknown_llm_name", "unknown@ZHIPU-AI"),
("unknown_llm_factory", "embedding-3@unknown"),
("tenant_no_auth_default_tenant_llm", "text-embedding-v3@Tongyi-Qianwen"),
("tenant_no_auth", "text-embedding-3-small@OpenAI"),
],
ids=["unknown_llm_name", "unknown_llm_factory", "tenant_no_auth_default_tenant_llm", "tenant_no_auth"],
)
def test_embedding_model_invalid(self, api_key, add_dataset_func, name, embedding_model):
dataset_id = add_dataset_func
payload = {"name": name, "embedding_model": embedding_model}
res = update_dataset(api_key, dataset_id, payload)
assert res["code"] == 101, res
if "tenant_no_auth" in name:
assert res["message"] == f"Unauthorized model: <{embedding_model}>", res
else:
assert res["message"] == f"Unsupported model: <{embedding_model}>", res
@pytest.mark.p2
@pytest.mark.parametrize(
"name, embedding_model",
[
("missing_at", "BAAI/bge-large-zh-v1.5BAAI"),
("missing_model_name", "@BAAI"),
("missing_provider", "BAAI/bge-large-zh-v1.5@"),
("whitespace_only_model_name", " @BAAI"),
("whitespace_only_provider", "BAAI/bge-large-zh-v1.5@ "),
],
ids=["missing_at", "empty_model_name", "empty_provider", "whitespace_only_model_name", "whitespace_only_provider"],
)
def test_embedding_model_format(self, api_key, add_dataset_func, name, embedding_model):
dataset_id = add_dataset_func
payload = {"name": name, "embedding_model": embedding_model}
res = update_dataset(api_key, dataset_id, payload)
assert res["code"] == 101, res
if name == "missing_at":
assert "Embedding model identifier must follow <model_name>@<provider> format" in res["message"], res
else:
assert "Both model_name and provider must be non-empty strings" in res["message"], res
@pytest.mark.p2
def test_embedding_model_none(self, api_key, add_dataset_func):
dataset_id = add_dataset_func
payload = {"embedding_model": None}
res = update_dataset(api_key, dataset_id, payload)
assert res["code"] == 101, res
assert "Input should be a valid string" in res["message"], res
@pytest.mark.p1
@pytest.mark.parametrize(
"permission",
[
"me",
"team",
"ME",
"TEAM",
" ME ",
],
ids=["me", "team", "me_upercase", "team_upercase", "whitespace"],
)
def test_permission(self, api_key, add_dataset_func, permission):
dataset_id = add_dataset_func
payload = {"permission": permission}
res = update_dataset(api_key, dataset_id, payload)
assert res["code"] == 0, res
res = list_datasets(api_key)
assert res["code"] == 0, res
assert res["data"][0]["permission"] == permission.lower().strip(), res
@pytest.mark.p2
@pytest.mark.parametrize(
"permission",
[
"",
"unknown",
list(),
],
ids=["empty", "unknown", "type_error"],
)
def test_permission_invalid(self, api_key, add_dataset_func, permission):
dataset_id = add_dataset_func
payload = {"permission": permission}
res = update_dataset(api_key, dataset_id, payload)
assert res["code"] == 101
assert "Input should be 'me' or 'team'" in res["message"]
@pytest.mark.p3
def test_permission_none(self, api_key, add_dataset_func):
dataset_id = add_dataset_func
payload = {"permission": None}
res = update_dataset(api_key, dataset_id, payload)
assert res["code"] == 101, res
assert "Input should be 'me' or 'team'" in res["message"], res
@pytest.mark.p1
@pytest.mark.parametrize(
"chunk_method",
[
"naive",
"book",
"email",
"laws",
"manual",
"one",
"paper",
"picture",
"presentation",
"qa",
"table",
"tag",
],
ids=["naive", "book", "email", "laws", "manual", "one", "paper", "picture", "presentation", "qa", "table", "tag"],
)
def test_chunk_method(self, api_key, add_dataset_func, chunk_method):
dataset_id = add_dataset_func
payload = {"chunk_method": chunk_method}
res = update_dataset(api_key, dataset_id, payload)
assert res["code"] == 0, res
res = list_datasets(api_key)
assert res["code"] == 0, res
assert res["data"][0]["chunk_method"] == chunk_method, res
@pytest.mark.p2
@pytest.mark.parametrize(
"chunk_method",
[
"",
"unknown",
list(),
],
ids=["empty", "unknown", "type_error"],
)
def test_chunk_method_invalid(self, api_key, add_dataset_func, chunk_method):
dataset_id = add_dataset_func
payload = {"chunk_method": chunk_method}
res = update_dataset(api_key, dataset_id, payload)
assert res["code"] == 101, res
assert "Input should be 'naive', 'book', 'email', 'laws', 'manual', 'one', 'paper', 'picture', 'presentation', 'qa', 'table' or 'tag'" in res["message"], res
@pytest.mark.p3
def test_chunk_method_none(self, api_key, add_dataset_func):
dataset_id = add_dataset_func
payload = {"chunk_method": None}
res = update_dataset(api_key, dataset_id, payload)
assert res["code"] == 101, res
assert "Input should be 'naive', 'book', 'email', 'laws', 'manual', 'one', 'paper', 'picture', 'presentation', 'qa', 'table' or 'tag'" in res["message"], res
@pytest.mark.p2
@pytest.mark.parametrize("pagerank", [0, 50, 100], ids=["min", "mid", "max"])
def test_pagerank(self, api_key, add_dataset_func, pagerank):
dataset_id = add_dataset_func
payload = {"pagerank": pagerank}
res = update_dataset(api_key, dataset_id, payload)
assert res["code"] == 0
res = list_datasets(api_key, {"id": dataset_id})
assert res["code"] == 0, res
assert res["data"][0]["pagerank"] == pagerank
@pytest.mark.p2
@pytest.mark.parametrize(
"pagerank, expected_message",
[
(-1, "Input should be greater than or equal to 0"),
(101, "Input should be less than or equal to 100"),
],
ids=["min_limit", "max_limit"],
)
def test_pagerank_invalid(self, api_key, add_dataset_func, pagerank, expected_message):
dataset_id = add_dataset_func
payload = {"pagerank": pagerank}
res = update_dataset(api_key, dataset_id, payload)
assert res["code"] == 101, res
assert expected_message in res["message"], res
@pytest.mark.p3
def test_pagerank_none(self, api_key, add_dataset_func):
dataset_id = add_dataset_func
payload = {"pagerank": None}
res = update_dataset(api_key, dataset_id, payload)
assert res["code"] == 101, res
assert "Input should be a valid integer" in res["message"], res
@pytest.mark.p1
@pytest.mark.parametrize(
"parser_config",
[
{"auto_keywords": 0},
{"auto_keywords": 16},
{"auto_keywords": 32},
{"auto_questions": 0},
{"auto_questions": 5},
{"auto_questions": 10},
{"chunk_token_num": 1},
{"chunk_token_num": 1024},
{"chunk_token_num": 2048},
{"delimiter": "\n"},
{"delimiter": " "},
{"html4excel": True},
{"html4excel": False},
{"layout_recognize": "DeepDOC"},
{"layout_recognize": "Plain Text"},
{"tag_kb_ids": ["1", "2"]},
{"topn_tags": 1},
{"topn_tags": 5},
{"topn_tags": 10},
{"filename_embd_weight": 0.1},
{"filename_embd_weight": 0.5},
{"filename_embd_weight": 1.0},
{"task_page_size": 1},
{"task_page_size": None},
{"pages": [[1, 100]]},
{"pages": None},
{"graphrag": {"use_graphrag": True}},
{"graphrag": {"use_graphrag": False}},
{"graphrag": {"entity_types": ["age", "sex", "height", "weight"]}},
{"graphrag": {"method": "general"}},
{"graphrag": {"method": "light"}},
{"graphrag": {"community": True}},
{"graphrag": {"community": False}},
{"graphrag": {"resolution": True}},
{"graphrag": {"resolution": False}},
{"raptor": {"use_raptor": True}},
{"raptor": {"use_raptor": False}},
{"raptor": {"prompt": "Who are you?"}},
{"raptor": {"max_token": 1}},
{"raptor": {"max_token": 1024}},
{"raptor": {"max_token": 2048}},
{"raptor": {"threshold": 0.0}},
{"raptor": {"threshold": 0.5}},
{"raptor": {"threshold": 1.0}},
{"raptor": {"max_cluster": 1}},
{"raptor": {"max_cluster": 512}},
{"raptor": {"max_cluster": 1024}},
{"raptor": {"random_seed": 0}},
],
ids=[
"auto_keywords_min",
"auto_keywords_mid",
"auto_keywords_max",
"auto_questions_min",
"auto_questions_mid",
"auto_questions_max",
"chunk_token_num_min",
"chunk_token_num_mid",
"chunk_token_num_max",
"delimiter",
"delimiter_space",
"html4excel_true",
"html4excel_false",
"layout_recognize_DeepDOC",
"layout_recognize_navie",
"tag_kb_ids",
"topn_tags_min",
"topn_tags_mid",
"topn_tags_max",
"filename_embd_weight_min",
"filename_embd_weight_mid",
"filename_embd_weight_max",
"task_page_size_min",
"task_page_size_None",
"pages",
"pages_none",
"graphrag_true",
"graphrag_false",
"graphrag_entity_types",
"graphrag_method_general",
"graphrag_method_light",
"graphrag_community_true",
"graphrag_community_false",
"graphrag_resolution_true",
"graphrag_resolution_false",
"raptor_true",
"raptor_false",
"raptor_prompt",
"raptor_max_token_min",
"raptor_max_token_mid",
"raptor_max_token_max",
"raptor_threshold_min",
"raptor_threshold_mid",
"raptor_threshold_max",
"raptor_max_cluster_min",
"raptor_max_cluster_mid",
"raptor_max_cluster_max",
"raptor_random_seed_min",
],
)
def test_parser_config(self, api_key, add_dataset_func, parser_config):
dataset_id = add_dataset_func
payload = {"parser_config": parser_config}
res = update_dataset(api_key, dataset_id, payload)
assert res["code"] == 0, res
res = list_datasets(api_key)
assert res["code"] == 0, res
for k, v in parser_config.items():
if isinstance(v, dict):
for kk, vv in v.items():
assert res["data"][0]["parser_config"][k][kk] == vv, res
else:
assert res["data"][0]["parser_config"][k] == v, res
@pytest.mark.p2
@pytest.mark.parametrize(
"parser_config, expected_message",
[
({"auto_keywords": -1}, "Input should be greater than or equal to 0"),
({"auto_keywords": 33}, "Input should be less than or equal to 32"),
({"auto_keywords": 3.14}, "Input should be a valid integer, got a number with a fractional part"),
({"auto_keywords": "string"}, "Input should be a valid integer, unable to parse string as an integer"),
({"auto_questions": -1}, "Input should be greater than or equal to 0"),
({"auto_questions": 11}, "Input should be less than or equal to 10"),
({"auto_questions": 3.14}, "Input should be a valid integer, got a number with a fractional part"),
({"auto_questions": "string"}, "Input should be a valid integer, unable to parse string as an integer"),
({"chunk_token_num": 0}, "Input should be greater than or equal to 1"),
({"chunk_token_num": 2049}, "Input should be less than or equal to 2048"),
({"chunk_token_num": 3.14}, "Input should be a valid integer, got a number with a fractional part"),
({"chunk_token_num": "string"}, "Input should be a valid integer, unable to parse string as an integer"),
({"delimiter": ""}, "String should have at least 1 character"),
({"html4excel": "string"}, "Input should be a valid boolean, unable to interpret input"),
({"tag_kb_ids": "1,2"}, "Input should be a valid list"),
({"tag_kb_ids": [1, 2]}, "Input should be a valid string"),
({"topn_tags": 0}, "Input should be greater than or equal to 1"),
({"topn_tags": 11}, "Input should be less than or equal to 10"),
({"topn_tags": 3.14}, "Input should be a valid integer, got a number with a fractional part"),
({"topn_tags": "string"}, "Input should be a valid integer, unable to parse string as an integer"),
({"filename_embd_weight": -1}, "Input should be greater than or equal to 0"),
({"filename_embd_weight": 1.1}, "Input should be less than or equal to 1"),
({"filename_embd_weight": "string"}, "Input should be a valid number, unable to parse string as a number"),
({"task_page_size": 0}, "Input should be greater than or equal to 1"),
({"task_page_size": 3.14}, "Input should be a valid integer, got a number with a fractional part"),
({"task_page_size": "string"}, "Input should be a valid integer, unable to parse string as an integer"),
({"pages": "1,2"}, "Input should be a valid list"),
({"pages": ["1,2"]}, "Input should be a valid list"),
({"pages": [["string1", "string2"]]}, "Input should be a valid integer, unable to parse string as an integer"),
({"graphrag": {"use_graphrag": "string"}}, "Input should be a valid boolean, unable to interpret input"),
({"graphrag": {"entity_types": "1,2"}}, "Input should be a valid list"),
({"graphrag": {"entity_types": [1, 2]}}, "nput should be a valid string"),
({"graphrag": {"method": "unknown"}}, "Input should be 'light' or 'general'"),
({"graphrag": {"method": None}}, "Input should be 'light' or 'general'"),
({"graphrag": {"community": "string"}}, "Input should be a valid boolean, unable to interpret input"),
({"graphrag": {"resolution": "string"}}, "Input should be a valid boolean, unable to interpret input"),
({"raptor": {"use_raptor": "string"}}, "Input should be a valid boolean, unable to interpret input"),
({"raptor": {"prompt": ""}}, "String should have at least 1 character"),
({"raptor": {"prompt": " "}}, "String should have at least 1 character"),
({"raptor": {"max_token": 0}}, "Input should be greater than or equal to 1"),
({"raptor": {"max_token": 2049}}, "Input should be less than or equal to 2048"),
({"raptor": {"max_token": 3.14}}, "Input should be a valid integer, got a number with a fractional part"),
({"raptor": {"max_token": "string"}}, "Input should be a valid integer, unable to parse string as an integer"),
({"raptor": {"threshold": -0.1}}, "Input should be greater than or equal to 0"),
({"raptor": {"threshold": 1.1}}, "Input should be less than or equal to 1"),
({"raptor": {"threshold": "string"}}, "Input should be a valid number, unable to parse string as a number"),
({"raptor": {"max_cluster": 0}}, "Input should be greater than or equal to 1"),
({"raptor": {"max_cluster": 1025}}, "Input should be less than or equal to 1024"),
({"raptor": {"max_cluster": 3.14}}, "Input should be a valid integer, got a number with a fractional par"),
({"raptor": {"max_cluster": "string"}}, "Input should be a valid integer, unable to parse string as an integer"),
({"raptor": {"random_seed": -1}}, "Input should be greater than or equal to 0"),
({"raptor": {"random_seed": 3.14}}, "Input should be a valid integer, got a number with a fractional part"),
({"raptor": {"random_seed": "string"}}, "Input should be a valid integer, unable to parse string as an integer"),
({"delimiter": "a" * 65536}, "Parser config exceeds size limit (max 65,535 characters)"),
],
ids=[
"auto_keywords_min_limit",
"auto_keywords_max_limit",
"auto_keywords_float_not_allowed",
"auto_keywords_type_invalid",
"auto_questions_min_limit",
"auto_questions_max_limit",
"auto_questions_float_not_allowed",
"auto_questions_type_invalid",
"chunk_token_num_min_limit",
"chunk_token_num_max_limit",
"chunk_token_num_float_not_allowed",
"chunk_token_num_type_invalid",
"delimiter_empty",
"html4excel_type_invalid",
"tag_kb_ids_not_list",
"tag_kb_ids_int_in_list",
"topn_tags_min_limit",
"topn_tags_max_limit",
"topn_tags_float_not_allowed",
"topn_tags_type_invalid",
"filename_embd_weight_min_limit",
"filename_embd_weight_max_limit",
"filename_embd_weight_type_invalid",
"task_page_size_min_limit",
"task_page_size_float_not_allowed",
"task_page_size_type_invalid",
"pages_not_list",
"pages_not_list_in_list",
"pages_not_int_list",
"graphrag_type_invalid",
"graphrag_entity_types_not_list",
"graphrag_entity_types_not_str_in_list",
"graphrag_method_unknown",
"graphrag_method_none",
"graphrag_community_type_invalid",
"graphrag_resolution_type_invalid",
"raptor_type_invalid",
"raptor_prompt_empty",
"raptor_prompt_space",
"raptor_max_token_min_limit",
"raptor_max_token_max_limit",
"raptor_max_token_float_not_allowed",
"raptor_max_token_type_invalid",
"raptor_threshold_min_limit",
"raptor_threshold_max_limit",
"raptor_threshold_type_invalid",
"raptor_max_cluster_min_limit",
"raptor_max_cluster_max_limit",
"raptor_max_cluster_float_not_allowed",
"raptor_max_cluster_type_invalid",
"raptor_random_seed_min_limit",
"raptor_random_seed_float_not_allowed",
"raptor_random_seed_type_invalid",
"parser_config_type_invalid",
],
)
def test_parser_config_invalid(self, api_key, add_dataset_func, parser_config, expected_message):
dataset_id = add_dataset_func
payload = {"parser_config": parser_config}
res = update_dataset(api_key, dataset_id, payload)
assert res["code"] == 101, res
assert expected_message in res["message"], res
@pytest.mark.p2
def test_parser_config_empty(self, api_key, add_dataset_func):
dataset_id = add_dataset_func
payload = {"parser_config": {}}
res = update_dataset(api_key, dataset_id, payload)
assert res["code"] == 0, res
res = list_datasets(api_key)
assert res["code"] == 0, res
assert res["data"][0]["parser_config"] == {
"chunk_token_num": 128,
"delimiter": r"\n",
"html4excel": False,
"layout_recognize": "DeepDOC",
"raptor": {"use_raptor": False},
}, res
@pytest.mark.p3
def test_parser_config_none(self, api_key, add_dataset_func):
dataset_id = add_dataset_func
payload = {"parser_config": None}
res = update_dataset(api_key, dataset_id, payload)
assert res["code"] == 0, res
res = list_datasets(api_key, {"id": dataset_id})
assert res["code"] == 0, res
assert res["data"][0]["parser_config"] == {
"chunk_token_num": 128,
"delimiter": r"\n",
"html4excel": False,
"layout_recognize": "DeepDOC",
"raptor": {"use_raptor": False},
}, res
@pytest.mark.p3
def test_parser_config_empty_with_chunk_method_change(self, api_key, add_dataset_func):
dataset_id = add_dataset_func
payload = {"chunk_method": "qa", "parser_config": {}}
res = update_dataset(api_key, dataset_id, payload)
assert res["code"] == 0, res
res = list_datasets(api_key)
assert res["code"] == 0, res
assert res["data"][0]["parser_config"] == {"raptor": {"use_raptor": False}}, res
@pytest.mark.p3
def test_parser_config_unset_with_chunk_method_change(self, api_key, add_dataset_func):
dataset_id = add_dataset_func
payload = {"chunk_method": "qa"}
res = update_dataset(api_key, dataset_id, payload)
assert res["code"] == 0, res
res = list_datasets(api_key)
assert res["code"] == 0, res
assert res["data"][0]["parser_config"] == {"raptor": {"use_raptor": False}}, res
@pytest.mark.p3
def test_parser_config_none_with_chunk_method_change(self, api_key, add_dataset_func):
dataset_id = add_dataset_func
payload = {"chunk_method": "qa", "parser_config": None}
res = update_dataset(api_key, dataset_id, payload)
assert res["code"] == 0, res
res = list_datasets(api_key, {"id": dataset_id})
assert res["code"] == 0, res
assert res["data"][0]["parser_config"] == {"raptor": {"use_raptor": False}}, res
@pytest.mark.p2
@pytest.mark.parametrize(
"payload",
[
{"id": "id"},
{"tenant_id": "e57c1966f99211efb41e9e45646e0111"},
{"created_by": "created_by"},
{"create_date": "Tue, 11 Mar 2025 13:37:23 GMT"},
{"create_time": 1741671443322},
{"update_date": "Tue, 11 Mar 2025 13:37:23 GMT"},
{"update_time": 1741671443339},
{"document_count": 1},
{"chunk_count": 1},
{"token_num": 1},
{"status": "1"},
{"unknown_field": "unknown_field"},
],
)
def test_field_unsupported(self, api_key, add_dataset_func, payload):
dataset_id = add_dataset_func
res = update_dataset(api_key, dataset_id, payload)
assert res["code"] == 101, res
assert "Extra inputs are not permitted" in res["message"], res
@pytest.mark.p2
def test_field_unset(self, api_key, add_dataset_func):
dataset_id = add_dataset_func
res = list_datasets(api_key)
assert res["code"] == 0, res
original_data = res["data"][0]
payload = {"name": "default_unset"}
res = update_dataset(api_key, dataset_id, payload)
assert res["code"] == 0, res
res = list_datasets(api_key)
assert res["code"] == 0, res
assert res["data"][0]["avatar"] == original_data["avatar"], res
assert res["data"][0]["description"] == original_data["description"], res
assert res["data"][0]["embedding_model"] == original_data["embedding_model"], res
assert res["data"][0]["permission"] == original_data["permission"], res
assert res["data"][0]["chunk_method"] == original_data["chunk_method"], res
assert res["data"][0]["pagerank"] == original_data["pagerank"], res
assert res["data"][0]["parser_config"] == {
"chunk_token_num": 128,
"delimiter": r"\n",
"html4excel": False,
"layout_recognize": "DeepDOC",
"raptor": {"use_raptor": False},
}, res

View File

@ -0,0 +1,51 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import pytest
from common import bulk_upload_documents, delete_documnets
@pytest.fixture(scope="function")
def add_document_func(request, api_key, add_dataset, ragflow_tmp_dir):
dataset_id = add_dataset
document_ids = bulk_upload_documents(api_key, dataset_id, 1, ragflow_tmp_dir)
def cleanup():
delete_documnets(api_key, dataset_id, {"ids": document_ids})
request.addfinalizer(cleanup)
return dataset_id, document_ids[0]
@pytest.fixture(scope="class")
def add_documents(request, api_key, add_dataset, ragflow_tmp_dir):
dataset_id = add_dataset
document_ids = bulk_upload_documents(api_key, dataset_id, 5, ragflow_tmp_dir)
def cleanup():
delete_documnets(api_key, dataset_id, {"ids": document_ids})
request.addfinalizer(cleanup)
return dataset_id, document_ids
@pytest.fixture(scope="function")
def add_documents_func(api_key, add_dataset_func, ragflow_tmp_dir):
dataset_id = add_dataset_func
document_ids = bulk_upload_documents(api_key, dataset_id, 3, ragflow_tmp_dir)
return dataset_id, document_ids

View File

@ -0,0 +1,181 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from concurrent.futures import ThreadPoolExecutor
import pytest
from common import INVALID_API_TOKEN, bulk_upload_documents, delete_documnets, list_documnets
from libs.auth import RAGFlowHttpApiAuth
@pytest.mark.p1
class TestAuthorization:
@pytest.mark.parametrize(
"invalid_auth, expected_code, expected_message",
[
(None, 0, "`Authorization` can't be empty"),
(
RAGFlowHttpApiAuth(INVALID_API_TOKEN),
109,
"Authentication error: API key is invalid!",
),
],
)
def test_invalid_auth(self, invalid_auth, expected_code, expected_message):
res = delete_documnets(invalid_auth, "dataset_id")
assert res["code"] == expected_code
assert res["message"] == expected_message
class TestDocumentsDeletion:
@pytest.mark.p1
@pytest.mark.parametrize(
"payload, expected_code, expected_message, remaining",
[
(None, 0, "", 0),
({"ids": []}, 0, "", 0),
({"ids": ["invalid_id"]}, 102, "Documents not found: ['invalid_id']", 3),
(
{"ids": ["\n!?。;!?\"'"]},
102,
"""Documents not found: [\'\\n!?。;!?"\\\'\']""",
3,
),
(
"not json",
100,
"AttributeError(\"'str' object has no attribute 'get'\")",
3,
),
(lambda r: {"ids": r[:1]}, 0, "", 2),
(lambda r: {"ids": r}, 0, "", 0),
],
)
def test_basic_scenarios(
self,
api_key,
add_documents_func,
payload,
expected_code,
expected_message,
remaining,
):
dataset_id, document_ids = add_documents_func
if callable(payload):
payload = payload(document_ids)
res = delete_documnets(api_key, dataset_id, payload)
assert res["code"] == expected_code
if res["code"] != 0:
assert res["message"] == expected_message
res = list_documnets(api_key, dataset_id)
assert len(res["data"]["docs"]) == remaining
assert res["data"]["total"] == remaining
@pytest.mark.p3
@pytest.mark.parametrize(
"dataset_id, expected_code, expected_message",
[
("", 100, "<MethodNotAllowed '405: Method Not Allowed'>"),
(
"invalid_dataset_id",
102,
"You don't own the dataset invalid_dataset_id. ",
),
],
)
def test_invalid_dataset_id(self, api_key, add_documents_func, dataset_id, expected_code, expected_message):
_, document_ids = add_documents_func
res = delete_documnets(api_key, dataset_id, {"ids": document_ids[:1]})
assert res["code"] == expected_code
assert res["message"] == expected_message
@pytest.mark.p2
@pytest.mark.parametrize(
"payload",
[
lambda r: {"ids": ["invalid_id"] + r},
lambda r: {"ids": r[:1] + ["invalid_id"] + r[1:3]},
lambda r: {"ids": r + ["invalid_id"]},
],
)
def test_delete_partial_invalid_id(self, api_key, add_documents_func, payload):
dataset_id, document_ids = add_documents_func
if callable(payload):
payload = payload(document_ids)
res = delete_documnets(api_key, dataset_id, payload)
assert res["code"] == 102
assert res["message"] == "Documents not found: ['invalid_id']"
res = list_documnets(api_key, dataset_id)
assert len(res["data"]["docs"]) == 0
assert res["data"]["total"] == 0
@pytest.mark.p2
def test_repeated_deletion(self, api_key, add_documents_func):
dataset_id, document_ids = add_documents_func
res = delete_documnets(api_key, dataset_id, {"ids": document_ids})
assert res["code"] == 0
res = delete_documnets(api_key, dataset_id, {"ids": document_ids})
assert res["code"] == 102
assert "Documents not found" in res["message"]
@pytest.mark.p2
def test_duplicate_deletion(self, api_key, add_documents_func):
dataset_id, document_ids = add_documents_func
res = delete_documnets(api_key, dataset_id, {"ids": document_ids + document_ids})
assert res["code"] == 0
assert "Duplicate document ids" in res["data"]["errors"][0]
assert res["data"]["success_count"] == 3
res = list_documnets(api_key, dataset_id)
assert len(res["data"]["docs"]) == 0
assert res["data"]["total"] == 0
@pytest.mark.p3
def test_concurrent_deletion(api_key, add_dataset, tmp_path):
documnets_num = 100
dataset_id = add_dataset
document_ids = bulk_upload_documents(api_key, dataset_id, documnets_num, tmp_path)
with ThreadPoolExecutor(max_workers=5) as executor:
futures = [
executor.submit(
delete_documnets,
api_key,
dataset_id,
{"ids": document_ids[i : i + 1]},
)
for i in range(documnets_num)
]
responses = [f.result() for f in futures]
assert all(r["code"] == 0 for r in responses)
@pytest.mark.p3
def test_delete_1k(api_key, add_dataset, tmp_path):
documnets_num = 1_000
dataset_id = add_dataset
document_ids = bulk_upload_documents(api_key, dataset_id, documnets_num, tmp_path)
res = list_documnets(api_key, dataset_id)
assert res["data"]["total"] == documnets_num
res = delete_documnets(api_key, dataset_id, {"ids": document_ids})
assert res["code"] == 0
res = list_documnets(api_key, dataset_id)
assert res["data"]["total"] == 0

View File

@ -0,0 +1,178 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
from concurrent.futures import ThreadPoolExecutor
import pytest
from common import INVALID_API_TOKEN, bulk_upload_documents, download_document, upload_documnets
from libs.auth import RAGFlowHttpApiAuth
from requests import codes
from utils import compare_by_hash
@pytest.mark.p1
class TestAuthorization:
@pytest.mark.parametrize(
"invalid_auth, expected_code, expected_message",
[
(None, 0, "`Authorization` can't be empty"),
(
RAGFlowHttpApiAuth(INVALID_API_TOKEN),
109,
"Authentication error: API key is invalid!",
),
],
)
def test_invalid_auth(self, invalid_auth, tmp_path, expected_code, expected_message):
res = download_document(invalid_auth, "dataset_id", "document_id", tmp_path / "ragflow_tes.txt")
assert res.status_code == codes.ok
with (tmp_path / "ragflow_tes.txt").open("r") as f:
response_json = json.load(f)
assert response_json["code"] == expected_code
assert response_json["message"] == expected_message
@pytest.mark.p1
@pytest.mark.parametrize(
"generate_test_files",
[
"docx",
"excel",
"ppt",
"image",
"pdf",
"txt",
"md",
"json",
"eml",
"html",
],
indirect=True,
)
def test_file_type_validation(api_key, add_dataset, generate_test_files, request):
dataset_id = add_dataset
fp = generate_test_files[request.node.callspec.params["generate_test_files"]]
res = upload_documnets(api_key, dataset_id, [fp])
document_id = res["data"][0]["id"]
res = download_document(
api_key,
dataset_id,
document_id,
fp.with_stem("ragflow_test_download"),
)
assert res.status_code == codes.ok
assert compare_by_hash(
fp,
fp.with_stem("ragflow_test_download"),
)
class TestDocumentDownload:
@pytest.mark.p3
@pytest.mark.parametrize(
"document_id, expected_code, expected_message",
[
(
"invalid_document_id",
102,
"The dataset not own the document invalid_document_id.",
),
],
)
def test_invalid_document_id(self, api_key, add_documents, tmp_path, document_id, expected_code, expected_message):
dataset_id, _ = add_documents
res = download_document(
api_key,
dataset_id,
document_id,
tmp_path / "ragflow_test_download_1.txt",
)
assert res.status_code == codes.ok
with (tmp_path / "ragflow_test_download_1.txt").open("r") as f:
response_json = json.load(f)
assert response_json["code"] == expected_code
assert response_json["message"] == expected_message
@pytest.mark.p3
@pytest.mark.parametrize(
"dataset_id, expected_code, expected_message",
[
("", 100, "<NotFound '404: Not Found'>"),
(
"invalid_dataset_id",
102,
"You do not own the dataset invalid_dataset_id.",
),
],
)
def test_invalid_dataset_id(self, api_key, add_documents, tmp_path, dataset_id, expected_code, expected_message):
_, document_ids = add_documents
res = download_document(
api_key,
dataset_id,
document_ids[0],
tmp_path / "ragflow_test_download_1.txt",
)
assert res.status_code == codes.ok
with (tmp_path / "ragflow_test_download_1.txt").open("r") as f:
response_json = json.load(f)
assert response_json["code"] == expected_code
assert response_json["message"] == expected_message
@pytest.mark.p3
def test_same_file_repeat(self, api_key, add_documents, tmp_path, ragflow_tmp_dir):
num = 5
dataset_id, document_ids = add_documents
for i in range(num):
res = download_document(
api_key,
dataset_id,
document_ids[0],
tmp_path / f"ragflow_test_download_{i}.txt",
)
assert res.status_code == codes.ok
assert compare_by_hash(
ragflow_tmp_dir / "ragflow_test_upload_0.txt",
tmp_path / f"ragflow_test_download_{i}.txt",
)
@pytest.mark.p3
def test_concurrent_download(api_key, add_dataset, tmp_path):
document_count = 20
dataset_id = add_dataset
document_ids = bulk_upload_documents(api_key, dataset_id, document_count, tmp_path)
with ThreadPoolExecutor(max_workers=5) as executor:
futures = [
executor.submit(
download_document,
api_key,
dataset_id,
document_ids[i],
tmp_path / f"ragflow_test_download_{i}.txt",
)
for i in range(document_count)
]
responses = [f.result() for f in futures]
assert all(r.status_code == codes.ok for r in responses)
for i in range(document_count):
assert compare_by_hash(
tmp_path / f"ragflow_test_upload_{i}.txt",
tmp_path / f"ragflow_test_download_{i}.txt",
)

View File

@ -0,0 +1,357 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from concurrent.futures import ThreadPoolExecutor
import pytest
from common import INVALID_API_TOKEN, list_documnets
from libs.auth import RAGFlowHttpApiAuth
from utils import is_sorted
@pytest.mark.p1
class TestAuthorization:
@pytest.mark.parametrize(
"invalid_auth, expected_code, expected_message",
[
(None, 0, "`Authorization` can't be empty"),
(
RAGFlowHttpApiAuth(INVALID_API_TOKEN),
109,
"Authentication error: API key is invalid!",
),
],
)
def test_invalid_auth(self, invalid_auth, expected_code, expected_message):
res = list_documnets(invalid_auth, "dataset_id")
assert res["code"] == expected_code
assert res["message"] == expected_message
class TestDocumentsList:
@pytest.mark.p1
def test_default(self, api_key, add_documents):
dataset_id, _ = add_documents
res = list_documnets(api_key, dataset_id)
assert res["code"] == 0
assert len(res["data"]["docs"]) == 5
assert res["data"]["total"] == 5
@pytest.mark.p3
@pytest.mark.parametrize(
"dataset_id, expected_code, expected_message",
[
("", 100, "<MethodNotAllowed '405: Method Not Allowed'>"),
(
"invalid_dataset_id",
102,
"You don't own the dataset invalid_dataset_id. ",
),
],
)
def test_invalid_dataset_id(self, api_key, dataset_id, expected_code, expected_message):
res = list_documnets(api_key, dataset_id)
assert res["code"] == expected_code
assert res["message"] == expected_message
@pytest.mark.p1
@pytest.mark.parametrize(
"params, expected_code, expected_page_size, expected_message",
[
({"page": None, "page_size": 2}, 0, 2, ""),
({"page": 0, "page_size": 2}, 0, 2, ""),
({"page": 2, "page_size": 2}, 0, 2, ""),
({"page": 3, "page_size": 2}, 0, 1, ""),
({"page": "3", "page_size": 2}, 0, 1, ""),
pytest.param(
{"page": -1, "page_size": 2},
100,
0,
"1064",
marks=pytest.mark.skip(reason="issues/5851"),
),
pytest.param(
{"page": "a", "page_size": 2},
100,
0,
"""ValueError("invalid literal for int() with base 10: \'a\'")""",
marks=pytest.mark.skip(reason="issues/5851"),
),
],
)
def test_page(
self,
api_key,
add_documents,
params,
expected_code,
expected_page_size,
expected_message,
):
dataset_id, _ = add_documents
res = list_documnets(api_key, dataset_id, params=params)
assert res["code"] == expected_code
if expected_code == 0:
assert len(res["data"]["docs"]) == expected_page_size
assert res["data"]["total"] == 5
else:
assert res["message"] == expected_message
@pytest.mark.p1
@pytest.mark.parametrize(
"params, expected_code, expected_page_size, expected_message",
[
({"page_size": None}, 0, 5, ""),
({"page_size": 0}, 0, 0, ""),
({"page_size": 1}, 0, 1, ""),
({"page_size": 6}, 0, 5, ""),
({"page_size": "1"}, 0, 1, ""),
pytest.param(
{"page_size": -1},
100,
0,
"1064",
marks=pytest.mark.skip(reason="issues/5851"),
),
pytest.param(
{"page_size": "a"},
100,
0,
"""ValueError("invalid literal for int() with base 10: \'a\'")""",
marks=pytest.mark.skip(reason="issues/5851"),
),
],
)
def test_page_size(
self,
api_key,
add_documents,
params,
expected_code,
expected_page_size,
expected_message,
):
dataset_id, _ = add_documents
res = list_documnets(api_key, dataset_id, params=params)
assert res["code"] == expected_code
if expected_code == 0:
assert len(res["data"]["docs"]) == expected_page_size
else:
assert res["message"] == expected_message
@pytest.mark.p3
@pytest.mark.parametrize(
"params, expected_code, assertions, expected_message",
[
({"orderby": None}, 0, lambda r: (is_sorted(r["data"]["docs"], "create_time", True)), ""),
({"orderby": "create_time"}, 0, lambda r: (is_sorted(r["data"]["docs"], "create_time", True)), ""),
({"orderby": "update_time"}, 0, lambda r: (is_sorted(r["data"]["docs"], "update_time", True)), ""),
pytest.param({"orderby": "name", "desc": "False"}, 0, lambda r: (is_sorted(r["data"]["docs"], "name", False)), "", marks=pytest.mark.skip(reason="issues/5851")),
pytest.param({"orderby": "unknown"}, 102, 0, "orderby should be create_time or update_time", marks=pytest.mark.skip(reason="issues/5851")),
],
)
def test_orderby(
self,
api_key,
add_documents,
params,
expected_code,
assertions,
expected_message,
):
dataset_id, _ = add_documents
res = list_documnets(api_key, dataset_id, params=params)
assert res["code"] == expected_code
if expected_code == 0:
if callable(assertions):
assert assertions(res)
else:
assert res["message"] == expected_message
@pytest.mark.p3
@pytest.mark.parametrize(
"params, expected_code, assertions, expected_message",
[
({"desc": None}, 0, lambda r: (is_sorted(r["data"]["docs"], "create_time", True)), ""),
({"desc": "true"}, 0, lambda r: (is_sorted(r["data"]["docs"], "create_time", True)), ""),
({"desc": "True"}, 0, lambda r: (is_sorted(r["data"]["docs"], "create_time", True)), ""),
({"desc": True}, 0, lambda r: (is_sorted(r["data"]["docs"], "create_time", True)), ""),
pytest.param({"desc": "false"}, 0, lambda r: (is_sorted(r["data"]["docs"], "create_time", False)), "", marks=pytest.mark.skip(reason="issues/5851")),
({"desc": "False"}, 0, lambda r: (is_sorted(r["data"]["docs"], "create_time", False)), ""),
({"desc": False}, 0, lambda r: (is_sorted(r["data"]["docs"], "create_time", False)), ""),
({"desc": "False", "orderby": "update_time"}, 0, lambda r: (is_sorted(r["data"]["docs"], "update_time", False)), ""),
pytest.param({"desc": "unknown"}, 102, 0, "desc should be true or false", marks=pytest.mark.skip(reason="issues/5851")),
],
)
def test_desc(
self,
api_key,
add_documents,
params,
expected_code,
assertions,
expected_message,
):
dataset_id, _ = add_documents
res = list_documnets(api_key, dataset_id, params=params)
assert res["code"] == expected_code
if expected_code == 0:
if callable(assertions):
assert assertions(res)
else:
assert res["message"] == expected_message
@pytest.mark.p2
@pytest.mark.parametrize(
"params, expected_num",
[
({"keywords": None}, 5),
({"keywords": ""}, 5),
({"keywords": "0"}, 1),
({"keywords": "ragflow_test_upload"}, 5),
({"keywords": "unknown"}, 0),
],
)
def test_keywords(self, api_key, add_documents, params, expected_num):
dataset_id, _ = add_documents
res = list_documnets(api_key, dataset_id, params=params)
assert res["code"] == 0
assert len(res["data"]["docs"]) == expected_num
assert res["data"]["total"] == expected_num
@pytest.mark.p1
@pytest.mark.parametrize(
"params, expected_code, expected_num, expected_message",
[
({"name": None}, 0, 5, ""),
({"name": ""}, 0, 5, ""),
({"name": "ragflow_test_upload_0.txt"}, 0, 1, ""),
(
{"name": "unknown.txt"},
102,
0,
"You don't own the document unknown.txt.",
),
],
)
def test_name(
self,
api_key,
add_documents,
params,
expected_code,
expected_num,
expected_message,
):
dataset_id, _ = add_documents
res = list_documnets(api_key, dataset_id, params=params)
assert res["code"] == expected_code
if expected_code == 0:
if params["name"] in [None, ""]:
assert len(res["data"]["docs"]) == expected_num
else:
assert res["data"]["docs"][0]["name"] == params["name"]
else:
assert res["message"] == expected_message
@pytest.mark.p1
@pytest.mark.parametrize(
"document_id, expected_code, expected_num, expected_message",
[
(None, 0, 5, ""),
("", 0, 5, ""),
(lambda r: r[0], 0, 1, ""),
("unknown.txt", 102, 0, "You don't own the document unknown.txt."),
],
)
def test_id(
self,
api_key,
add_documents,
document_id,
expected_code,
expected_num,
expected_message,
):
dataset_id, document_ids = add_documents
if callable(document_id):
params = {"id": document_id(document_ids)}
else:
params = {"id": document_id}
res = list_documnets(api_key, dataset_id, params=params)
assert res["code"] == expected_code
if expected_code == 0:
if params["id"] in [None, ""]:
assert len(res["data"]["docs"]) == expected_num
else:
assert res["data"]["docs"][0]["id"] == params["id"]
else:
assert res["message"] == expected_message
@pytest.mark.p3
@pytest.mark.parametrize(
"document_id, name, expected_code, expected_num, expected_message",
[
(lambda r: r[0], "ragflow_test_upload_0.txt", 0, 1, ""),
(lambda r: r[0], "ragflow_test_upload_1.txt", 0, 0, ""),
(lambda r: r[0], "unknown", 102, 0, "You don't own the document unknown."),
(
"id",
"ragflow_test_upload_0.txt",
102,
0,
"You don't own the document id.",
),
],
)
def test_name_and_id(
self,
api_key,
add_documents,
document_id,
name,
expected_code,
expected_num,
expected_message,
):
dataset_id, document_ids = add_documents
if callable(document_id):
params = {"id": document_id(document_ids), "name": name}
else:
params = {"id": document_id, "name": name}
res = list_documnets(api_key, dataset_id, params=params)
if expected_code == 0:
assert len(res["data"]["docs"]) == expected_num
else:
assert res["message"] == expected_message
@pytest.mark.p3
def test_concurrent_list(self, api_key, add_documents):
dataset_id, _ = add_documents
with ThreadPoolExecutor(max_workers=5) as executor:
futures = [executor.submit(list_documnets, api_key, dataset_id) for i in range(100)]
responses = [f.result() for f in futures]
assert all(r["code"] == 0 for r in responses)
@pytest.mark.p3
def test_invalid_params(self, api_key, add_documents):
dataset_id, _ = add_documents
params = {"a": "b"}
res = list_documnets(api_key, dataset_id, params=params)
assert res["code"] == 0
assert len(res["data"]["docs"]) == 5

View File

@ -0,0 +1,217 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from concurrent.futures import ThreadPoolExecutor
import pytest
from common import INVALID_API_TOKEN, bulk_upload_documents, list_documnets, parse_documnets
from libs.auth import RAGFlowHttpApiAuth
from utils import wait_for
@wait_for(30, 1, "Document parsing timeout")
def condition(_auth, _dataset_id, _document_ids=None):
res = list_documnets(_auth, _dataset_id)
target_docs = res["data"]["docs"]
if _document_ids is None:
for doc in target_docs:
if doc["run"] != "DONE":
return False
return True
target_ids = set(_document_ids)
for doc in target_docs:
if doc["id"] in target_ids:
if doc.get("run") != "DONE":
return False
return True
def validate_document_details(auth, dataset_id, document_ids):
for document_id in document_ids:
res = list_documnets(auth, dataset_id, params={"id": document_id})
doc = res["data"]["docs"][0]
assert doc["run"] == "DONE"
assert len(doc["process_begin_at"]) > 0
assert doc["process_duation"] > 0
assert doc["progress"] > 0
assert "Task done" in doc["progress_msg"]
@pytest.mark.p1
class TestAuthorization:
@pytest.mark.parametrize(
"invalid_auth, expected_code, expected_message",
[
(None, 0, "`Authorization` can't be empty"),
(
RAGFlowHttpApiAuth(INVALID_API_TOKEN),
109,
"Authentication error: API key is invalid!",
),
],
)
def test_invalid_auth(self, invalid_auth, expected_code, expected_message):
res = parse_documnets(invalid_auth, "dataset_id")
assert res["code"] == expected_code
assert res["message"] == expected_message
class TestDocumentsParse:
@pytest.mark.parametrize(
"payload, expected_code, expected_message",
[
pytest.param(None, 102, """AttributeError("\'NoneType\' object has no attribute \'get\'")""", marks=pytest.mark.skip),
pytest.param({"document_ids": []}, 102, "`document_ids` is required", marks=pytest.mark.p1),
pytest.param({"document_ids": ["invalid_id"]}, 102, "Documents not found: ['invalid_id']", marks=pytest.mark.p3),
pytest.param({"document_ids": ["\n!?。;!?\"'"]}, 102, """Documents not found: [\'\\n!?。;!?"\\\'\']""", marks=pytest.mark.p3),
pytest.param("not json", 102, "AttributeError(\"'str' object has no attribute 'get'\")", marks=pytest.mark.skip),
pytest.param(lambda r: {"document_ids": r[:1]}, 0, "", marks=pytest.mark.p1),
pytest.param(lambda r: {"document_ids": r}, 0, "", marks=pytest.mark.p1),
],
)
def test_basic_scenarios(self, api_key, add_documents_func, payload, expected_code, expected_message):
dataset_id, document_ids = add_documents_func
if callable(payload):
payload = payload(document_ids)
res = parse_documnets(api_key, dataset_id, payload)
assert res["code"] == expected_code
if expected_code != 0:
assert res["message"] == expected_message
if expected_code == 0:
condition(api_key, dataset_id, payload["document_ids"])
validate_document_details(api_key, dataset_id, payload["document_ids"])
@pytest.mark.p3
@pytest.mark.parametrize(
"dataset_id, expected_code, expected_message",
[
("", 100, "<MethodNotAllowed '405: Method Not Allowed'>"),
(
"invalid_dataset_id",
102,
"You don't own the dataset invalid_dataset_id.",
),
],
)
def test_invalid_dataset_id(
self,
api_key,
add_documents_func,
dataset_id,
expected_code,
expected_message,
):
_, document_ids = add_documents_func
res = parse_documnets(api_key, dataset_id, {"document_ids": document_ids})
assert res["code"] == expected_code
assert res["message"] == expected_message
@pytest.mark.parametrize(
"payload",
[
pytest.param(lambda r: {"document_ids": ["invalid_id"] + r}, marks=pytest.mark.p3),
pytest.param(lambda r: {"document_ids": r[:1] + ["invalid_id"] + r[1:3]}, marks=pytest.mark.p1),
pytest.param(lambda r: {"document_ids": r + ["invalid_id"]}, marks=pytest.mark.p3),
],
)
def test_parse_partial_invalid_document_id(self, api_key, add_documents_func, payload):
dataset_id, document_ids = add_documents_func
if callable(payload):
payload = payload(document_ids)
res = parse_documnets(api_key, dataset_id, payload)
assert res["code"] == 102
assert res["message"] == "Documents not found: ['invalid_id']"
condition(api_key, dataset_id)
validate_document_details(api_key, dataset_id, document_ids)
@pytest.mark.p3
def test_repeated_parse(self, api_key, add_documents_func):
dataset_id, document_ids = add_documents_func
res = parse_documnets(api_key, dataset_id, {"document_ids": document_ids})
assert res["code"] == 0
condition(api_key, dataset_id)
res = parse_documnets(api_key, dataset_id, {"document_ids": document_ids})
assert res["code"] == 0
@pytest.mark.p3
def test_duplicate_parse(self, api_key, add_documents_func):
dataset_id, document_ids = add_documents_func
res = parse_documnets(api_key, dataset_id, {"document_ids": document_ids + document_ids})
assert res["code"] == 0
assert "Duplicate document ids" in res["data"]["errors"][0]
assert res["data"]["success_count"] == 3
condition(api_key, dataset_id)
validate_document_details(api_key, dataset_id, document_ids)
@pytest.mark.p3
def test_parse_100_files(api_key, add_dataset_func, tmp_path):
@wait_for(100, 1, "Document parsing timeout")
def condition(_auth, _dataset_id, _document_num):
res = list_documnets(_auth, _dataset_id, {"page_size": _document_num})
for doc in res["data"]["docs"]:
if doc["run"] != "DONE":
return False
return True
document_num = 100
dataset_id = add_dataset_func
document_ids = bulk_upload_documents(api_key, dataset_id, document_num, tmp_path)
res = parse_documnets(api_key, dataset_id, {"document_ids": document_ids})
assert res["code"] == 0
condition(api_key, dataset_id, document_num)
validate_document_details(api_key, dataset_id, document_ids)
@pytest.mark.p3
def test_concurrent_parse(api_key, add_dataset_func, tmp_path):
@wait_for(120, 1, "Document parsing timeout")
def condition(_auth, _dataset_id, _document_num):
res = list_documnets(_auth, _dataset_id, {"page_size": _document_num})
for doc in res["data"]["docs"]:
if doc["run"] != "DONE":
return False
return True
document_num = 100
dataset_id = add_dataset_func
document_ids = bulk_upload_documents(api_key, dataset_id, document_num, tmp_path)
with ThreadPoolExecutor(max_workers=5) as executor:
futures = [
executor.submit(
parse_documnets,
api_key,
dataset_id,
{"document_ids": document_ids[i : i + 1]},
)
for i in range(document_num)
]
responses = [f.result() for f in futures]
assert all(r["code"] == 0 for r in responses)
condition(api_key, dataset_id, document_num)
validate_document_details(api_key, dataset_id, document_ids)

View File

@ -0,0 +1,202 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from concurrent.futures import ThreadPoolExecutor
from time import sleep
import pytest
from common import INVALID_API_TOKEN, bulk_upload_documents, list_documnets, parse_documnets, stop_parse_documnets
from libs.auth import RAGFlowHttpApiAuth
from utils import wait_for
def validate_document_parse_done(auth, dataset_id, document_ids):
for document_id in document_ids:
res = list_documnets(auth, dataset_id, params={"id": document_id})
doc = res["data"]["docs"][0]
assert doc["run"] == "DONE"
assert len(doc["process_begin_at"]) > 0
assert doc["process_duation"] > 0
assert doc["progress"] > 0
assert "Task done" in doc["progress_msg"]
def validate_document_parse_cancel(auth, dataset_id, document_ids):
for document_id in document_ids:
res = list_documnets(auth, dataset_id, params={"id": document_id})
doc = res["data"]["docs"][0]
assert doc["run"] == "CANCEL"
assert len(doc["process_begin_at"]) > 0
assert doc["progress"] == 0.0
@pytest.mark.p1
class TestAuthorization:
@pytest.mark.parametrize(
"invalid_auth, expected_code, expected_message",
[
(None, 0, "`Authorization` can't be empty"),
(
RAGFlowHttpApiAuth(INVALID_API_TOKEN),
109,
"Authentication error: API key is invalid!",
),
],
)
def test_invalid_auth(self, invalid_auth, expected_code, expected_message):
res = stop_parse_documnets(invalid_auth, "dataset_id")
assert res["code"] == expected_code
assert res["message"] == expected_message
@pytest.mark.skip
class TestDocumentsParseStop:
@pytest.mark.parametrize(
"payload, expected_code, expected_message",
[
pytest.param(None, 102, """AttributeError("\'NoneType\' object has no attribute \'get\'")""", marks=pytest.mark.skip),
pytest.param({"document_ids": []}, 102, "`document_ids` is required", marks=pytest.mark.p1),
pytest.param({"document_ids": ["invalid_id"]}, 102, "You don't own the document invalid_id.", marks=pytest.mark.p3),
pytest.param({"document_ids": ["\n!?。;!?\"'"]}, 102, """You don\'t own the document \n!?。;!?"\'.""", marks=pytest.mark.p3),
pytest.param("not json", 102, "AttributeError(\"'str' object has no attribute 'get'\")", marks=pytest.mark.skip),
pytest.param(lambda r: {"document_ids": r[:1]}, 0, "", marks=pytest.mark.p1),
pytest.param(lambda r: {"document_ids": r}, 0, "", marks=pytest.mark.p1),
],
)
def test_basic_scenarios(self, api_key, add_documents_func, payload, expected_code, expected_message):
@wait_for(10, 1, "Document parsing timeout")
def condition(_auth, _dataset_id, _document_ids):
for _document_id in _document_ids:
res = list_documnets(_auth, _dataset_id, {"id": _document_id})
if res["data"]["docs"][0]["run"] != "DONE":
return False
return True
dataset_id, document_ids = add_documents_func
parse_documnets(api_key, dataset_id, {"document_ids": document_ids})
if callable(payload):
payload = payload(document_ids)
res = stop_parse_documnets(api_key, dataset_id, payload)
assert res["code"] == expected_code
if expected_code != 0:
assert res["message"] == expected_message
else:
completed_document_ids = list(set(document_ids) - set(payload["document_ids"]))
condition(api_key, dataset_id, completed_document_ids)
validate_document_parse_cancel(api_key, dataset_id, payload["document_ids"])
validate_document_parse_done(api_key, dataset_id, completed_document_ids)
@pytest.mark.p3
@pytest.mark.parametrize(
"invalid_dataset_id, expected_code, expected_message",
[
("", 100, "<MethodNotAllowed '405: Method Not Allowed'>"),
(
"invalid_dataset_id",
102,
"You don't own the dataset invalid_dataset_id.",
),
],
)
def test_invalid_dataset_id(
self,
api_key,
add_documents_func,
invalid_dataset_id,
expected_code,
expected_message,
):
dataset_id, document_ids = add_documents_func
parse_documnets(api_key, dataset_id, {"document_ids": document_ids})
res = stop_parse_documnets(api_key, invalid_dataset_id, {"document_ids": document_ids})
assert res["code"] == expected_code
assert res["message"] == expected_message
@pytest.mark.skip
@pytest.mark.parametrize(
"payload",
[
lambda r: {"document_ids": ["invalid_id"] + r},
lambda r: {"document_ids": r[:1] + ["invalid_id"] + r[1:3]},
lambda r: {"document_ids": r + ["invalid_id"]},
],
)
def test_stop_parse_partial_invalid_document_id(self, api_key, add_documents_func, payload):
dataset_id, document_ids = add_documents_func
parse_documnets(api_key, dataset_id, {"document_ids": document_ids})
if callable(payload):
payload = payload(document_ids)
res = stop_parse_documnets(api_key, dataset_id, payload)
assert res["code"] == 102
assert res["message"] == "You don't own the document invalid_id."
validate_document_parse_cancel(api_key, dataset_id, document_ids)
@pytest.mark.p3
def test_repeated_stop_parse(self, api_key, add_documents_func):
dataset_id, document_ids = add_documents_func
parse_documnets(api_key, dataset_id, {"document_ids": document_ids})
res = stop_parse_documnets(api_key, dataset_id, {"document_ids": document_ids})
assert res["code"] == 0
res = stop_parse_documnets(api_key, dataset_id, {"document_ids": document_ids})
assert res["code"] == 102
assert res["message"] == "Can't stop parsing document with progress at 0 or 1"
@pytest.mark.p3
def test_duplicate_stop_parse(self, api_key, add_documents_func):
dataset_id, document_ids = add_documents_func
parse_documnets(api_key, dataset_id, {"document_ids": document_ids})
res = stop_parse_documnets(api_key, dataset_id, {"document_ids": document_ids + document_ids})
assert res["code"] == 0
assert res["data"]["success_count"] == 3
assert f"Duplicate document ids: {document_ids[0]}" in res["data"]["errors"]
@pytest.mark.skip(reason="unstable")
def test_stop_parse_100_files(api_key, add_dataset_func, tmp_path):
document_num = 100
dataset_id = add_dataset_func
document_ids = bulk_upload_documents(api_key, dataset_id, document_num, tmp_path)
parse_documnets(api_key, dataset_id, {"document_ids": document_ids})
sleep(1)
res = stop_parse_documnets(api_key, dataset_id, {"document_ids": document_ids})
assert res["code"] == 0
validate_document_parse_cancel(api_key, dataset_id, document_ids)
@pytest.mark.skip(reason="unstable")
def test_concurrent_parse(api_key, add_dataset_func, tmp_path):
document_num = 50
dataset_id = add_dataset_func
document_ids = bulk_upload_documents(api_key, dataset_id, document_num, tmp_path)
parse_documnets(api_key, dataset_id, {"document_ids": document_ids})
with ThreadPoolExecutor(max_workers=5) as executor:
futures = [
executor.submit(
stop_parse_documnets,
api_key,
dataset_id,
{"document_ids": document_ids[i : i + 1]},
)
for i in range(document_num)
]
responses = [f.result() for f in futures]
assert all(r["code"] == 0 for r in responses)
validate_document_parse_cancel(api_key, dataset_id, document_ids)

View File

@ -0,0 +1,547 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import pytest
from common import DOCUMENT_NAME_LIMIT, INVALID_API_TOKEN, list_documnets, update_documnet
from libs.auth import RAGFlowHttpApiAuth
@pytest.mark.p1
class TestAuthorization:
@pytest.mark.parametrize(
"invalid_auth, expected_code, expected_message",
[
(None, 0, "`Authorization` can't be empty"),
(
RAGFlowHttpApiAuth(INVALID_API_TOKEN),
109,
"Authentication error: API key is invalid!",
),
],
)
def test_invalid_auth(self, invalid_auth, expected_code, expected_message):
res = update_documnet(invalid_auth, "dataset_id", "document_id")
assert res["code"] == expected_code
assert res["message"] == expected_message
class TestDocumentsUpdated:
@pytest.mark.p1
@pytest.mark.parametrize(
"name, expected_code, expected_message",
[
("new_name.txt", 0, ""),
(
f"{'a' * (DOCUMENT_NAME_LIMIT - 3)}.txt",
101,
"The name should be less than 128 bytes.",
),
(
0,
100,
"""AttributeError("\'int\' object has no attribute \'encode\'")""",
),
(
None,
100,
"""AttributeError("\'NoneType\' object has no attribute \'encode\'")""",
),
(
"",
101,
"The extension of file can't be changed",
),
(
"ragflow_test_upload_0",
101,
"The extension of file can't be changed",
),
(
"ragflow_test_upload_1.txt",
102,
"Duplicated document name in the same dataset.",
),
(
"RAGFLOW_TEST_UPLOAD_1.TXT",
0,
"",
),
],
)
def test_name(self, api_key, add_documents, name, expected_code, expected_message):
dataset_id, document_ids = add_documents
res = update_documnet(api_key, dataset_id, document_ids[0], {"name": name})
assert res["code"] == expected_code
if expected_code == 0:
res = list_documnets(api_key, dataset_id, {"id": document_ids[0]})
assert res["data"]["docs"][0]["name"] == name
else:
assert res["message"] == expected_message
@pytest.mark.p3
@pytest.mark.parametrize(
"document_id, expected_code, expected_message",
[
("", 100, "<MethodNotAllowed '405: Method Not Allowed'>"),
(
"invalid_document_id",
102,
"The dataset doesn't own the document.",
),
],
)
def test_invalid_document_id(self, api_key, add_documents, document_id, expected_code, expected_message):
dataset_id, _ = add_documents
res = update_documnet(api_key, dataset_id, document_id, {"name": "new_name.txt"})
assert res["code"] == expected_code
assert res["message"] == expected_message
@pytest.mark.p3
@pytest.mark.parametrize(
"dataset_id, expected_code, expected_message",
[
("", 100, "<NotFound '404: Not Found'>"),
(
"invalid_dataset_id",
102,
"You don't own the dataset.",
),
],
)
def test_invalid_dataset_id(self, api_key, add_documents, dataset_id, expected_code, expected_message):
_, document_ids = add_documents
res = update_documnet(api_key, dataset_id, document_ids[0], {"name": "new_name.txt"})
assert res["code"] == expected_code
assert res["message"] == expected_message
@pytest.mark.p3
@pytest.mark.parametrize(
"meta_fields, expected_code, expected_message",
[({"test": "test"}, 0, ""), ("test", 102, "meta_fields must be a dictionary")],
)
def test_meta_fields(self, api_key, add_documents, meta_fields, expected_code, expected_message):
dataset_id, document_ids = add_documents
res = update_documnet(api_key, dataset_id, document_ids[0], {"meta_fields": meta_fields})
if expected_code == 0:
res = list_documnets(api_key, dataset_id, {"id": document_ids[0]})
assert res["data"]["docs"][0]["meta_fields"] == meta_fields
else:
assert res["message"] == expected_message
@pytest.mark.p2
@pytest.mark.parametrize(
"chunk_method, expected_code, expected_message",
[
("naive", 0, ""),
("manual", 0, ""),
("qa", 0, ""),
("table", 0, ""),
("paper", 0, ""),
("book", 0, ""),
("laws", 0, ""),
("presentation", 0, ""),
("picture", 0, ""),
("one", 0, ""),
("knowledge_graph", 0, ""),
("email", 0, ""),
("tag", 0, ""),
("", 102, "`chunk_method` doesn't exist"),
(
"other_chunk_method",
102,
"`chunk_method` other_chunk_method doesn't exist",
),
],
)
def test_chunk_method(self, api_key, add_documents, chunk_method, expected_code, expected_message):
dataset_id, document_ids = add_documents
res = update_documnet(api_key, dataset_id, document_ids[0], {"chunk_method": chunk_method})
assert res["code"] == expected_code
if expected_code == 0:
res = list_documnets(api_key, dataset_id, {"id": document_ids[0]})
if chunk_method != "":
assert res["data"]["docs"][0]["chunk_method"] == chunk_method
else:
assert res["data"]["docs"][0]["chunk_method"] == "naive"
else:
assert res["message"] == expected_message
@pytest.mark.p3
@pytest.mark.parametrize(
"payload, expected_code, expected_message",
[
({"chunk_count": 1}, 102, "Can't change `chunk_count`."),
pytest.param(
{"create_date": "Fri, 14 Mar 2025 16:53:42 GMT"},
102,
"The input parameters are invalid.",
marks=pytest.mark.skip(reason="issues/6104"),
),
pytest.param(
{"create_time": 1},
102,
"The input parameters are invalid.",
marks=pytest.mark.skip(reason="issues/6104"),
),
pytest.param(
{"created_by": "ragflow_test"},
102,
"The input parameters are invalid.",
marks=pytest.mark.skip(reason="issues/6104"),
),
pytest.param(
{"dataset_id": "ragflow_test"},
102,
"The input parameters are invalid.",
marks=pytest.mark.skip(reason="issues/6104"),
),
pytest.param(
{"id": "ragflow_test"},
102,
"The input parameters are invalid.",
marks=pytest.mark.skip(reason="issues/6104"),
),
pytest.param(
{"location": "ragflow_test.txt"},
102,
"The input parameters are invalid.",
marks=pytest.mark.skip(reason="issues/6104"),
),
pytest.param(
{"process_begin_at": 1},
102,
"The input parameters are invalid.",
marks=pytest.mark.skip(reason="issues/6104"),
),
pytest.param(
{"process_duation": 1.0},
102,
"The input parameters are invalid.",
marks=pytest.mark.skip(reason="issues/6104"),
),
pytest.param({"progress": 1.0}, 102, "Can't change `progress`."),
pytest.param(
{"progress_msg": "ragflow_test"},
102,
"The input parameters are invalid.",
marks=pytest.mark.skip(reason="issues/6104"),
),
pytest.param(
{"run": "ragflow_test"},
102,
"The input parameters are invalid.",
marks=pytest.mark.skip(reason="issues/6104"),
),
pytest.param(
{"size": 1},
102,
"The input parameters are invalid.",
marks=pytest.mark.skip(reason="issues/6104"),
),
pytest.param(
{"source_type": "ragflow_test"},
102,
"The input parameters are invalid.",
marks=pytest.mark.skip(reason="issues/6104"),
),
pytest.param(
{"thumbnail": "ragflow_test"},
102,
"The input parameters are invalid.",
marks=pytest.mark.skip(reason="issues/6104"),
),
({"token_count": 1}, 102, "Can't change `token_count`."),
pytest.param(
{"type": "ragflow_test"},
102,
"The input parameters are invalid.",
marks=pytest.mark.skip(reason="issues/6104"),
),
pytest.param(
{"update_date": "Fri, 14 Mar 2025 16:33:17 GMT"},
102,
"The input parameters are invalid.",
marks=pytest.mark.skip(reason="issues/6104"),
),
pytest.param(
{"update_time": 1},
102,
"The input parameters are invalid.",
marks=pytest.mark.skip(reason="issues/6104"),
),
],
)
def test_invalid_field(
self,
api_key,
add_documents,
payload,
expected_code,
expected_message,
):
dataset_id, document_ids = add_documents
res = update_documnet(api_key, dataset_id, document_ids[0], payload)
assert res["code"] == expected_code
assert res["message"] == expected_message
class TestUpdateDocumentParserConfig:
@pytest.mark.p2
@pytest.mark.parametrize(
"chunk_method, parser_config, expected_code, expected_message",
[
("naive", {}, 0, ""),
(
"naive",
{
"chunk_token_num": 128,
"layout_recognize": "DeepDOC",
"html4excel": False,
"delimiter": r"\n",
"task_page_size": 12,
"raptor": {"use_raptor": False},
},
0,
"",
),
pytest.param(
"naive",
{"chunk_token_num": -1},
100,
"AssertionError('chunk_token_num should be in range from 1 to 100000000')",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"chunk_token_num": 0},
100,
"AssertionError('chunk_token_num should be in range from 1 to 100000000')",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"chunk_token_num": 100000000},
100,
"AssertionError('chunk_token_num should be in range from 1 to 100000000')",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"chunk_token_num": 3.14},
102,
"",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"chunk_token_num": "1024"},
100,
"",
marks=pytest.mark.skip(reason="issues/6098"),
),
(
"naive",
{"layout_recognize": "DeepDOC"},
0,
"",
),
(
"naive",
{"layout_recognize": "Naive"},
0,
"",
),
("naive", {"html4excel": True}, 0, ""),
("naive", {"html4excel": False}, 0, ""),
pytest.param(
"naive",
{"html4excel": 1},
100,
"AssertionError('html4excel should be True or False')",
marks=pytest.mark.skip(reason="issues/6098"),
),
("naive", {"delimiter": ""}, 0, ""),
("naive", {"delimiter": "`##`"}, 0, ""),
pytest.param(
"naive",
{"delimiter": 1},
100,
"",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"task_page_size": -1},
100,
"AssertionError('task_page_size should be in range from 1 to 100000000')",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"task_page_size": 0},
100,
"AssertionError('task_page_size should be in range from 1 to 100000000')",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"task_page_size": 100000000},
100,
"AssertionError('task_page_size should be in range from 1 to 100000000')",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"task_page_size": 3.14},
100,
"",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"task_page_size": "1024"},
100,
"",
marks=pytest.mark.skip(reason="issues/6098"),
),
("naive", {"raptor": {"use_raptor": True}}, 0, ""),
("naive", {"raptor": {"use_raptor": False}}, 0, ""),
pytest.param(
"naive",
{"invalid_key": "invalid_value"},
100,
"""AssertionError("Abnormal \'parser_config\'. Invalid key: invalid_key")""",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"auto_keywords": -1},
100,
"AssertionError('auto_keywords should be in range from 0 to 32')",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"auto_keywords": 32},
100,
"AssertionError('auto_keywords should be in range from 0 to 32')",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"auto_questions": 3.14},
100,
"",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"auto_keywords": "1024"},
100,
"",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"auto_questions": -1},
100,
"AssertionError('auto_questions should be in range from 0 to 10')",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"auto_questions": 10},
100,
"AssertionError('auto_questions should be in range from 0 to 10')",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"auto_questions": 3.14},
100,
"",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"auto_questions": "1024"},
100,
"",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"topn_tags": -1},
100,
"AssertionError('topn_tags should be in range from 0 to 10')",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"topn_tags": 10},
100,
"AssertionError('topn_tags should be in range from 0 to 10')",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"topn_tags": 3.14},
100,
"",
marks=pytest.mark.skip(reason="issues/6098"),
),
pytest.param(
"naive",
{"topn_tags": "1024"},
100,
"",
marks=pytest.mark.skip(reason="issues/6098"),
),
],
)
def test_parser_config(
self,
api_key,
add_documents,
chunk_method,
parser_config,
expected_code,
expected_message,
):
dataset_id, document_ids = add_documents
res = update_documnet(
api_key,
dataset_id,
document_ids[0],
{"chunk_method": chunk_method, "parser_config": parser_config},
)
assert res["code"] == expected_code
if expected_code == 0:
res = list_documnets(api_key, dataset_id, {"id": document_ids[0]})
if parser_config != {}:
for k, v in parser_config.items():
assert res["data"]["docs"][0]["parser_config"][k] == v
else:
assert res["data"]["docs"][0]["parser_config"] == {
"chunk_token_num": 128,
"delimiter": r"\n",
"html4excel": False,
"layout_recognize": "DeepDOC",
"raptor": {"use_raptor": False},
}
if expected_code != 0 or expected_message:
assert res["message"] == expected_message

View File

@ -0,0 +1,218 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import string
from concurrent.futures import ThreadPoolExecutor
import pytest
import requests
from common import DOCUMENT_NAME_LIMIT, FILE_API_URL, HOST_ADDRESS, INVALID_API_TOKEN, list_datasets, upload_documnets
from libs.auth import RAGFlowHttpApiAuth
from requests_toolbelt import MultipartEncoder
from utils.file_utils import create_txt_file
@pytest.mark.p1
@pytest.mark.usefixtures("clear_datasets")
class TestAuthorization:
@pytest.mark.parametrize(
"invalid_auth, expected_code, expected_message",
[
(None, 0, "`Authorization` can't be empty"),
(
RAGFlowHttpApiAuth(INVALID_API_TOKEN),
109,
"Authentication error: API key is invalid!",
),
],
)
def test_invalid_auth(self, invalid_auth, expected_code, expected_message):
res = upload_documnets(invalid_auth, "dataset_id")
assert res["code"] == expected_code
assert res["message"] == expected_message
class TestDocumentsUpload:
@pytest.mark.p1
def test_valid_single_upload(self, api_key, add_dataset_func, tmp_path):
dataset_id = add_dataset_func
fp = create_txt_file(tmp_path / "ragflow_test.txt")
res = upload_documnets(api_key, dataset_id, [fp])
assert res["code"] == 0
assert res["data"][0]["dataset_id"] == dataset_id
assert res["data"][0]["name"] == fp.name
@pytest.mark.p1
@pytest.mark.parametrize(
"generate_test_files",
[
"docx",
"excel",
"ppt",
"image",
"pdf",
"txt",
"md",
"json",
"eml",
"html",
],
indirect=True,
)
def test_file_type_validation(self, api_key, add_dataset_func, generate_test_files, request):
dataset_id = add_dataset_func
fp = generate_test_files[request.node.callspec.params["generate_test_files"]]
res = upload_documnets(api_key, dataset_id, [fp])
assert res["code"] == 0
assert res["data"][0]["dataset_id"] == dataset_id
assert res["data"][0]["name"] == fp.name
@pytest.mark.p2
@pytest.mark.parametrize(
"file_type",
["exe", "unknown"],
)
def test_unsupported_file_type(self, api_key, add_dataset_func, tmp_path, file_type):
dataset_id = add_dataset_func
fp = tmp_path / f"ragflow_test.{file_type}"
fp.touch()
res = upload_documnets(api_key, dataset_id, [fp])
assert res["code"] == 500
assert res["message"] == f"ragflow_test.{file_type}: This type of file has not been supported yet!"
@pytest.mark.p2
def test_missing_file(self, api_key, add_dataset_func):
dataset_id = add_dataset_func
res = upload_documnets(api_key, dataset_id)
assert res["code"] == 101
assert res["message"] == "No file part!"
@pytest.mark.p3
def test_empty_file(self, api_key, add_dataset_func, tmp_path):
dataset_id = add_dataset_func
fp = tmp_path / "empty.txt"
fp.touch()
res = upload_documnets(api_key, dataset_id, [fp])
assert res["code"] == 0
assert res["data"][0]["size"] == 0
@pytest.mark.p3
def test_filename_empty(self, api_key, add_dataset_func, tmp_path):
dataset_id = add_dataset_func
fp = create_txt_file(tmp_path / "ragflow_test.txt")
url = f"{HOST_ADDRESS}{FILE_API_URL}".format(dataset_id=dataset_id)
fields = (("file", ("", fp.open("rb"))),)
m = MultipartEncoder(fields=fields)
res = requests.post(
url=url,
headers={"Content-Type": m.content_type},
auth=api_key,
data=m,
)
assert res.json()["code"] == 101
assert res.json()["message"] == "No file selected!"
@pytest.mark.p2
def test_filename_exceeds_max_length(self, api_key, add_dataset_func, tmp_path):
dataset_id = add_dataset_func
# filename_length = 129
fp = create_txt_file(tmp_path / f"{'a' * (DOCUMENT_NAME_LIMIT - 3)}.txt")
res = upload_documnets(api_key, dataset_id, [fp])
assert res["code"] == 101
assert res["message"] == "File name should be less than 128 bytes."
@pytest.mark.p2
def test_invalid_dataset_id(self, api_key, tmp_path):
fp = create_txt_file(tmp_path / "ragflow_test.txt")
res = upload_documnets(api_key, "invalid_dataset_id", [fp])
assert res["code"] == 100
assert res["message"] == """LookupError("Can\'t find the dataset with ID invalid_dataset_id!")"""
@pytest.mark.p2
def test_duplicate_files(self, api_key, add_dataset_func, tmp_path):
dataset_id = add_dataset_func
fp = create_txt_file(tmp_path / "ragflow_test.txt")
res = upload_documnets(api_key, dataset_id, [fp, fp])
assert res["code"] == 0
assert len(res["data"]) == 2
for i in range(len(res["data"])):
assert res["data"][i]["dataset_id"] == dataset_id
expected_name = fp.name
if i != 0:
expected_name = f"{fp.stem}({i}){fp.suffix}"
assert res["data"][i]["name"] == expected_name
@pytest.mark.p2
def test_same_file_repeat(self, api_key, add_dataset_func, tmp_path):
dataset_id = add_dataset_func
fp = create_txt_file(tmp_path / "ragflow_test.txt")
for i in range(10):
res = upload_documnets(api_key, dataset_id, [fp])
assert res["code"] == 0
assert len(res["data"]) == 1
assert res["data"][0]["dataset_id"] == dataset_id
expected_name = fp.name
if i != 0:
expected_name = f"{fp.stem}({i}){fp.suffix}"
assert res["data"][0]["name"] == expected_name
@pytest.mark.p3
def test_filename_special_characters(self, api_key, add_dataset_func, tmp_path):
dataset_id = add_dataset_func
illegal_chars = '<>:"/\\|?*'
translation_table = str.maketrans({char: "_" for char in illegal_chars})
safe_filename = string.punctuation.translate(translation_table)
fp = tmp_path / f"{safe_filename}.txt"
fp.write_text("Sample text content")
res = upload_documnets(api_key, dataset_id, [fp])
assert res["code"] == 0
assert len(res["data"]) == 1
assert res["data"][0]["dataset_id"] == dataset_id
assert res["data"][0]["name"] == fp.name
@pytest.mark.p1
def test_multiple_files(self, api_key, add_dataset_func, tmp_path):
dataset_id = add_dataset_func
expected_document_count = 20
fps = []
for i in range(expected_document_count):
fp = create_txt_file(tmp_path / f"ragflow_test_{i}.txt")
fps.append(fp)
res = upload_documnets(api_key, dataset_id, fps)
assert res["code"] == 0
res = list_datasets(api_key, {"id": dataset_id})
assert res["data"][0]["document_count"] == expected_document_count
@pytest.mark.p3
def test_concurrent_upload(self, api_key, add_dataset_func, tmp_path):
dataset_id = add_dataset_func
expected_document_count = 20
fps = []
for i in range(expected_document_count):
fp = create_txt_file(tmp_path / f"ragflow_test_{i}.txt")
fps.append(fp)
with ThreadPoolExecutor(max_workers=5) as executor:
futures = [executor.submit(upload_documnets, api_key, dataset_id, fps[i : i + 1]) for i in range(expected_document_count)]
responses = [f.result() for f in futures]
assert all(r["code"] == 0 for r in responses)
res = list_datasets(api_key, {"id": dataset_id})
assert res["data"][0]["document_count"] == expected_document_count

View File

@ -0,0 +1,53 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import pytest
from common import create_session_with_chat_assistant, delete_session_with_chat_assistants
@pytest.fixture(scope="class")
def add_sessions_with_chat_assistant(request, api_key, add_chat_assistants):
_, _, chat_assistant_ids = add_chat_assistants
def cleanup():
for chat_assistant_id in chat_assistant_ids:
delete_session_with_chat_assistants(api_key, chat_assistant_id)
request.addfinalizer(cleanup)
session_ids = []
for i in range(5):
res = create_session_with_chat_assistant(api_key, chat_assistant_ids[0], {"name": f"session_with_chat_assistant_{i}"})
session_ids.append(res["data"]["id"])
return chat_assistant_ids[0], session_ids
@pytest.fixture(scope="function")
def add_sessions_with_chat_assistant_func(request, api_key, add_chat_assistants):
_, _, chat_assistant_ids = add_chat_assistants
def cleanup():
for chat_assistant_id in chat_assistant_ids:
delete_session_with_chat_assistants(api_key, chat_assistant_id)
request.addfinalizer(cleanup)
session_ids = []
for i in range(5):
res = create_session_with_chat_assistant(api_key, chat_assistant_ids[0], {"name": f"session_with_chat_assistant_{i}"})
session_ids.append(res["data"]["id"])
return chat_assistant_ids[0], session_ids

View File

@ -0,0 +1,117 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from concurrent.futures import ThreadPoolExecutor
import pytest
from common import INVALID_API_TOKEN, SESSION_WITH_CHAT_NAME_LIMIT, create_session_with_chat_assistant, delete_chat_assistants, list_session_with_chat_assistants
from libs.auth import RAGFlowHttpApiAuth
@pytest.mark.p1
class TestAuthorization:
@pytest.mark.parametrize(
"invalid_auth, expected_code, expected_message",
[
(None, 0, "`Authorization` can't be empty"),
(
RAGFlowHttpApiAuth(INVALID_API_TOKEN),
109,
"Authentication error: API key is invalid!",
),
],
)
def test_invalid_auth(self, invalid_auth, expected_code, expected_message):
res = create_session_with_chat_assistant(invalid_auth, "chat_assistant_id")
assert res["code"] == expected_code
assert res["message"] == expected_message
@pytest.mark.usefixtures("clear_session_with_chat_assistants")
class TestSessionWithChatAssistantCreate:
@pytest.mark.p1
@pytest.mark.parametrize(
"payload, expected_code, expected_message",
[
({"name": "valid_name"}, 0, ""),
pytest.param({"name": "a" * (SESSION_WITH_CHAT_NAME_LIMIT + 1)}, 102, "", marks=pytest.mark.skip(reason="issues/")),
pytest.param({"name": 1}, 100, "", marks=pytest.mark.skip(reason="issues/")),
({"name": ""}, 102, "`name` can not be empty."),
({"name": "duplicated_name"}, 0, ""),
({"name": "case insensitive"}, 0, ""),
],
)
def test_name(self, api_key, add_chat_assistants, payload, expected_code, expected_message):
_, _, chat_assistant_ids = add_chat_assistants
if payload["name"] == "duplicated_name":
create_session_with_chat_assistant(api_key, chat_assistant_ids[0], payload)
elif payload["name"] == "case insensitive":
create_session_with_chat_assistant(api_key, chat_assistant_ids[0], {"name": payload["name"].upper()})
res = create_session_with_chat_assistant(api_key, chat_assistant_ids[0], payload)
assert res["code"] == expected_code, res
if expected_code == 0:
assert res["data"]["name"] == payload["name"]
assert res["data"]["chat_id"] == chat_assistant_ids[0]
else:
assert res["message"] == expected_message
@pytest.mark.p3
@pytest.mark.parametrize(
"chat_assistant_id, expected_code, expected_message",
[
("", 100, "<MethodNotAllowed '405: Method Not Allowed'>"),
("invalid_chat_assistant_id", 102, "You do not own the assistant."),
],
)
def test_invalid_chat_assistant_id(self, api_key, chat_assistant_id, expected_code, expected_message):
res = create_session_with_chat_assistant(api_key, chat_assistant_id, {"name": "valid_name"})
assert res["code"] == expected_code
assert res["message"] == expected_message
@pytest.mark.p3
def test_concurrent_create_session(self, api_key, add_chat_assistants):
chunk_num = 1000
_, _, chat_assistant_ids = add_chat_assistants
res = list_session_with_chat_assistants(api_key, chat_assistant_ids[0])
if res["code"] != 0:
assert False, res
chunks_count = len(res["data"])
with ThreadPoolExecutor(max_workers=5) as executor:
futures = [
executor.submit(
create_session_with_chat_assistant,
api_key,
chat_assistant_ids[0],
{"name": f"session with chat assistant test {i}"},
)
for i in range(chunk_num)
]
responses = [f.result() for f in futures]
assert all(r["code"] == 0 for r in responses)
res = list_session_with_chat_assistants(api_key, chat_assistant_ids[0], {"page_size": chunk_num})
if res["code"] != 0:
assert False, res
assert len(res["data"]) == chunks_count + chunk_num
@pytest.mark.p3
def test_add_session_to_deleted_chat_assistant(self, api_key, add_chat_assistants):
_, _, chat_assistant_ids = add_chat_assistants
res = delete_chat_assistants(api_key, {"ids": [chat_assistant_ids[0]]})
assert res["code"] == 0
res = create_session_with_chat_assistant(api_key, chat_assistant_ids[0], {"name": "valid_name"})
assert res["code"] == 102
assert res["message"] == "You do not own the assistant."

View File

@ -0,0 +1,170 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from concurrent.futures import ThreadPoolExecutor
import pytest
from common import INVALID_API_TOKEN, batch_add_sessions_with_chat_assistant, delete_session_with_chat_assistants, list_session_with_chat_assistants
from libs.auth import RAGFlowHttpApiAuth
@pytest.mark.p1
class TestAuthorization:
@pytest.mark.parametrize(
"invalid_auth, expected_code, expected_message",
[
(None, 0, "`Authorization` can't be empty"),
(
RAGFlowHttpApiAuth(INVALID_API_TOKEN),
109,
"Authentication error: API key is invalid!",
),
],
)
def test_invalid_auth(self, invalid_auth, expected_code, expected_message):
res = delete_session_with_chat_assistants(invalid_auth, "chat_assistant_id")
assert res["code"] == expected_code
assert res["message"] == expected_message
class TestSessionWithChatAssistantDelete:
@pytest.mark.p3
@pytest.mark.parametrize(
"chat_assistant_id, expected_code, expected_message",
[
("", 100, "<MethodNotAllowed '405: Method Not Allowed'>"),
(
"invalid_chat_assistant_id",
102,
"You don't own the chat",
),
],
)
def test_invalid_chat_assistant_id(self, api_key, add_sessions_with_chat_assistant_func, chat_assistant_id, expected_code, expected_message):
_, session_ids = add_sessions_with_chat_assistant_func
res = delete_session_with_chat_assistants(api_key, chat_assistant_id, {"ids": session_ids})
assert res["code"] == expected_code
assert res["message"] == expected_message
@pytest.mark.parametrize(
"payload",
[
pytest.param(lambda r: {"ids": ["invalid_id"] + r}, marks=pytest.mark.p3),
pytest.param(lambda r: {"ids": r[:1] + ["invalid_id"] + r[1:5]}, marks=pytest.mark.p1),
pytest.param(lambda r: {"ids": r + ["invalid_id"]}, marks=pytest.mark.p3),
],
)
def test_delete_partial_invalid_id(self, api_key, add_sessions_with_chat_assistant_func, payload):
chat_assistant_id, session_ids = add_sessions_with_chat_assistant_func
if callable(payload):
payload = payload(session_ids)
res = delete_session_with_chat_assistants(api_key, chat_assistant_id, payload)
assert res["code"] == 0
assert res["data"]["errors"][0] == "The chat doesn't own the session invalid_id"
res = list_session_with_chat_assistants(api_key, chat_assistant_id)
if res["code"] != 0:
assert False, res
assert len(res["data"]) == 0
@pytest.mark.p3
def test_repeated_deletion(self, api_key, add_sessions_with_chat_assistant_func):
chat_assistant_id, session_ids = add_sessions_with_chat_assistant_func
payload = {"ids": session_ids}
res = delete_session_with_chat_assistants(api_key, chat_assistant_id, payload)
assert res["code"] == 0
res = delete_session_with_chat_assistants(api_key, chat_assistant_id, payload)
assert res["code"] == 102
assert "The chat doesn't own the session" in res["message"]
@pytest.mark.p3
def test_duplicate_deletion(self, api_key, add_sessions_with_chat_assistant_func):
chat_assistant_id, session_ids = add_sessions_with_chat_assistant_func
res = delete_session_with_chat_assistants(api_key, chat_assistant_id, {"ids": session_ids * 2})
assert res["code"] == 0
assert "Duplicate session ids" in res["data"]["errors"][0]
assert res["data"]["success_count"] == 5
res = list_session_with_chat_assistants(api_key, chat_assistant_id)
if res["code"] != 0:
assert False, res
assert len(res["data"]) == 0
@pytest.mark.p3
def test_concurrent_deletion(self, api_key, add_chat_assistants):
sessions_num = 100
_, _, chat_assistant_ids = add_chat_assistants
session_ids = batch_add_sessions_with_chat_assistant(api_key, chat_assistant_ids[0], sessions_num)
with ThreadPoolExecutor(max_workers=5) as executor:
futures = [
executor.submit(
delete_session_with_chat_assistants,
api_key,
chat_assistant_ids[0],
{"ids": session_ids[i : i + 1]},
)
for i in range(sessions_num)
]
responses = [f.result() for f in futures]
assert all(r["code"] == 0 for r in responses)
@pytest.mark.p3
def test_delete_1k(self, api_key, add_chat_assistants):
sessions_num = 1_000
_, _, chat_assistant_ids = add_chat_assistants
session_ids = batch_add_sessions_with_chat_assistant(api_key, chat_assistant_ids[0], sessions_num)
res = delete_session_with_chat_assistants(api_key, chat_assistant_ids[0], {"ids": session_ids})
assert res["code"] == 0
res = list_session_with_chat_assistants(api_key, chat_assistant_ids[0])
if res["code"] != 0:
assert False, res
assert len(res["data"]) == 0
@pytest.mark.parametrize(
"payload, expected_code, expected_message, remaining",
[
pytest.param(None, 0, """TypeError("argument of type \'NoneType\' is not iterable")""", 0, marks=pytest.mark.skip),
pytest.param({"ids": ["invalid_id"]}, 102, "The chat doesn't own the session invalid_id", 5, marks=pytest.mark.p3),
pytest.param("not json", 100, """AttributeError("\'str\' object has no attribute \'get\'")""", 5, marks=pytest.mark.skip),
pytest.param(lambda r: {"ids": r[:1]}, 0, "", 4, marks=pytest.mark.p3),
pytest.param(lambda r: {"ids": r}, 0, "", 0, marks=pytest.mark.p1),
pytest.param({"ids": []}, 0, "", 0, marks=pytest.mark.p3),
],
)
def test_basic_scenarios(
self,
api_key,
add_sessions_with_chat_assistant_func,
payload,
expected_code,
expected_message,
remaining,
):
chat_assistant_id, session_ids = add_sessions_with_chat_assistant_func
if callable(payload):
payload = payload(session_ids)
res = delete_session_with_chat_assistants(api_key, chat_assistant_id, payload)
assert res["code"] == expected_code
if res["code"] != 0:
assert res["message"] == expected_message
res = list_session_with_chat_assistants(api_key, chat_assistant_id)
if res["code"] != 0:
assert False, res
assert len(res["data"]) == remaining

View File

@ -0,0 +1,247 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from concurrent.futures import ThreadPoolExecutor
import pytest
from common import INVALID_API_TOKEN, delete_chat_assistants, list_session_with_chat_assistants
from libs.auth import RAGFlowHttpApiAuth
from utils import is_sorted
@pytest.mark.p1
class TestAuthorization:
@pytest.mark.parametrize(
"invalid_auth, expected_code, expected_message",
[
(None, 0, "`Authorization` can't be empty"),
(
RAGFlowHttpApiAuth(INVALID_API_TOKEN),
109,
"Authentication error: API key is invalid!",
),
],
)
def test_invalid_auth(self, invalid_auth, expected_code, expected_message):
res = list_session_with_chat_assistants(invalid_auth, "chat_assistant_id")
assert res["code"] == expected_code
assert res["message"] == expected_message
class TestSessionsWithChatAssistantList:
@pytest.mark.p1
@pytest.mark.parametrize(
"params, expected_code, expected_page_size, expected_message",
[
({"page": None, "page_size": 2}, 0, 2, ""),
pytest.param({"page": 0, "page_size": 2}, 100, 0, "ValueError('Search does not support negative slicing.')", marks=pytest.mark.skip),
({"page": 2, "page_size": 2}, 0, 2, ""),
({"page": 3, "page_size": 2}, 0, 1, ""),
({"page": "3", "page_size": 2}, 0, 1, ""),
pytest.param({"page": -1, "page_size": 2}, 100, 0, "ValueError('Search does not support negative slicing.')", marks=pytest.mark.skip),
pytest.param({"page": "a", "page_size": 2}, 100, 0, """ValueError("invalid literal for int() with base 10: \'a\'")""", marks=pytest.mark.skip),
],
)
def test_page(self, api_key, add_sessions_with_chat_assistant, params, expected_code, expected_page_size, expected_message):
chat_assistant_id, _ = add_sessions_with_chat_assistant
res = list_session_with_chat_assistants(api_key, chat_assistant_id, params=params)
assert res["code"] == expected_code
if expected_code == 0:
assert len(res["data"]) == expected_page_size
else:
assert res["message"] == expected_message
@pytest.mark.p1
@pytest.mark.parametrize(
"params, expected_code, expected_page_size, expected_message",
[
({"page_size": None}, 0, 5, ""),
({"page_size": 0}, 0, 0, ""),
({"page_size": 1}, 0, 1, ""),
({"page_size": 6}, 0, 5, ""),
({"page_size": "1"}, 0, 1, ""),
pytest.param({"page_size": -1}, 0, 5, "", marks=pytest.mark.skip),
pytest.param({"page_size": "a"}, 100, 0, """ValueError("invalid literal for int() with base 10: \'a\'")""", marks=pytest.mark.skip),
],
)
def test_page_size(self, api_key, add_sessions_with_chat_assistant, params, expected_code, expected_page_size, expected_message):
chat_assistant_id, _ = add_sessions_with_chat_assistant
res = list_session_with_chat_assistants(api_key, chat_assistant_id, params=params)
assert res["code"] == expected_code
if expected_code == 0:
assert len(res["data"]) == expected_page_size
else:
assert res["message"] == expected_message
@pytest.mark.p3
@pytest.mark.parametrize(
"params, expected_code, assertions, expected_message",
[
({"orderby": None}, 0, lambda r: (is_sorted(r["data"], "create_time", True)), ""),
({"orderby": "create_time"}, 0, lambda r: (is_sorted(r["data"], "create_time", True)), ""),
({"orderby": "update_time"}, 0, lambda r: (is_sorted(r["data"], "update_time", True)), ""),
({"orderby": "name", "desc": "False"}, 0, lambda r: (is_sorted(r["data"], "name", False)), ""),
pytest.param({"orderby": "unknown"}, 102, 0, "orderby should be create_time or update_time", marks=pytest.mark.skip(reason="issues/")),
],
)
def test_orderby(
self,
api_key,
add_sessions_with_chat_assistant,
params,
expected_code,
assertions,
expected_message,
):
chat_assistant_id, _ = add_sessions_with_chat_assistant
res = list_session_with_chat_assistants(api_key, chat_assistant_id, params=params)
assert res["code"] == expected_code
if expected_code == 0:
if callable(assertions):
assert assertions(res)
else:
assert res["message"] == expected_message
@pytest.mark.p3
@pytest.mark.parametrize(
"params, expected_code, assertions, expected_message",
[
({"desc": None}, 0, lambda r: (is_sorted(r["data"], "create_time", True)), ""),
({"desc": "true"}, 0, lambda r: (is_sorted(r["data"], "create_time", True)), ""),
({"desc": "True"}, 0, lambda r: (is_sorted(r["data"], "create_time", True)), ""),
({"desc": True}, 0, lambda r: (is_sorted(r["data"], "create_time", True)), ""),
({"desc": "false"}, 0, lambda r: (is_sorted(r["data"], "create_time", False)), ""),
({"desc": "False"}, 0, lambda r: (is_sorted(r["data"], "create_time", False)), ""),
({"desc": False}, 0, lambda r: (is_sorted(r["data"], "create_time", False)), ""),
({"desc": "False", "orderby": "update_time"}, 0, lambda r: (is_sorted(r["data"], "update_time", False)), ""),
pytest.param({"desc": "unknown"}, 102, 0, "desc should be true or false", marks=pytest.mark.skip(reason="issues/")),
],
)
def test_desc(
self,
api_key,
add_sessions_with_chat_assistant,
params,
expected_code,
assertions,
expected_message,
):
chat_assistant_id, _ = add_sessions_with_chat_assistant
res = list_session_with_chat_assistants(api_key, chat_assistant_id, params=params)
assert res["code"] == expected_code
if expected_code == 0:
if callable(assertions):
assert assertions(res)
else:
assert res["message"] == expected_message
@pytest.mark.p1
@pytest.mark.parametrize(
"params, expected_code, expected_num, expected_message",
[
({"name": None}, 0, 5, ""),
({"name": ""}, 0, 5, ""),
({"name": "session_with_chat_assistant_1"}, 0, 1, ""),
({"name": "unknown"}, 0, 0, ""),
],
)
def test_name(self, api_key, add_sessions_with_chat_assistant, params, expected_code, expected_num, expected_message):
chat_assistant_id, _ = add_sessions_with_chat_assistant
res = list_session_with_chat_assistants(api_key, chat_assistant_id, params=params)
assert res["code"] == expected_code
if expected_code == 0:
if params["name"] != "session_with_chat_assistant_1":
assert len(res["data"]) == expected_num
else:
assert res["data"][0]["name"] == params["name"]
else:
assert res["message"] == expected_message
@pytest.mark.p1
@pytest.mark.parametrize(
"session_id, expected_code, expected_num, expected_message",
[
(None, 0, 5, ""),
("", 0, 5, ""),
(lambda r: r[0], 0, 1, ""),
("unknown", 0, 0, "The chat doesn't exist"),
],
)
def test_id(self, api_key, add_sessions_with_chat_assistant, session_id, expected_code, expected_num, expected_message):
chat_assistant_id, session_ids = add_sessions_with_chat_assistant
if callable(session_id):
params = {"id": session_id(session_ids)}
else:
params = {"id": session_id}
res = list_session_with_chat_assistants(api_key, chat_assistant_id, params=params)
assert res["code"] == expected_code
if expected_code == 0:
if params["id"] != session_ids[0]:
assert len(res["data"]) == expected_num
else:
assert res["data"][0]["id"] == params["id"]
else:
assert res["message"] == expected_message
@pytest.mark.p3
@pytest.mark.parametrize(
"session_id, name, expected_code, expected_num, expected_message",
[
(lambda r: r[0], "session_with_chat_assistant_0", 0, 1, ""),
(lambda r: r[0], "session_with_chat_assistant_100", 0, 0, ""),
(lambda r: r[0], "unknown", 0, 0, ""),
("id", "session_with_chat_assistant_0", 0, 0, ""),
],
)
def test_name_and_id(self, api_key, add_sessions_with_chat_assistant, session_id, name, expected_code, expected_num, expected_message):
chat_assistant_id, session_ids = add_sessions_with_chat_assistant
if callable(session_id):
params = {"id": session_id(session_ids), "name": name}
else:
params = {"id": session_id, "name": name}
res = list_session_with_chat_assistants(api_key, chat_assistant_id, params=params)
assert res["code"] == expected_code
if expected_code == 0:
assert len(res["data"]) == expected_num
else:
assert res["message"] == expected_message
@pytest.mark.p3
def test_concurrent_list(self, api_key, add_sessions_with_chat_assistant):
chat_assistant_id, _ = add_sessions_with_chat_assistant
with ThreadPoolExecutor(max_workers=5) as executor:
futures = [executor.submit(list_session_with_chat_assistants, api_key, chat_assistant_id) for i in range(100)]
responses = [f.result() for f in futures]
assert all(r["code"] == 0 for r in responses)
@pytest.mark.p3
def test_invalid_params(self, api_key, add_sessions_with_chat_assistant):
chat_assistant_id, _ = add_sessions_with_chat_assistant
params = {"a": "b"}
res = list_session_with_chat_assistants(api_key, chat_assistant_id, params=params)
assert res["code"] == 0
assert len(res["data"]) == 5
@pytest.mark.p3
def test_list_chats_after_deleting_associated_chat_assistant(self, api_key, add_sessions_with_chat_assistant):
chat_assistant_id, _ = add_sessions_with_chat_assistant
res = delete_chat_assistants(api_key, {"ids": [chat_assistant_id]})
assert res["code"] == 0
res = list_session_with_chat_assistants(api_key, chat_assistant_id)
assert res["code"] == 102
assert "You don't own the assistant" in res["message"]

View File

@ -0,0 +1,148 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from concurrent.futures import ThreadPoolExecutor
from random import randint
import pytest
from common import INVALID_API_TOKEN, SESSION_WITH_CHAT_NAME_LIMIT, delete_chat_assistants, list_session_with_chat_assistants, update_session_with_chat_assistant
from libs.auth import RAGFlowHttpApiAuth
@pytest.mark.p1
class TestAuthorization:
@pytest.mark.parametrize(
"invalid_auth, expected_code, expected_message",
[
(None, 0, "`Authorization` can't be empty"),
(
RAGFlowHttpApiAuth(INVALID_API_TOKEN),
109,
"Authentication error: API key is invalid!",
),
],
)
def test_invalid_auth(self, invalid_auth, expected_code, expected_message):
res = update_session_with_chat_assistant(invalid_auth, "chat_assistant_id", "session_id")
assert res["code"] == expected_code
assert res["message"] == expected_message
class TestSessionWithChatAssistantUpdate:
@pytest.mark.parametrize(
"payload, expected_code, expected_message",
[
pytest.param({"name": "valid_name"}, 0, "", marks=pytest.mark.p1),
pytest.param({"name": "a" * (SESSION_WITH_CHAT_NAME_LIMIT + 1)}, 102, "", marks=pytest.mark.skip(reason="issues/")),
pytest.param({"name": 1}, 100, "", marks=pytest.mark.skip(reason="issues/")),
pytest.param({"name": ""}, 102, "`name` can not be empty.", marks=pytest.mark.p3),
pytest.param({"name": "duplicated_name"}, 0, "", marks=pytest.mark.p3),
pytest.param({"name": "case insensitive"}, 0, "", marks=pytest.mark.p3),
],
)
def test_name(self, api_key, add_sessions_with_chat_assistant_func, payload, expected_code, expected_message):
chat_assistant_id, session_ids = add_sessions_with_chat_assistant_func
if payload["name"] == "duplicated_name":
update_session_with_chat_assistant(api_key, chat_assistant_id, session_ids[0], payload)
elif payload["name"] == "case insensitive":
update_session_with_chat_assistant(api_key, chat_assistant_id, session_ids[0], {"name": payload["name"].upper()})
res = update_session_with_chat_assistant(api_key, chat_assistant_id, session_ids[0], payload)
assert res["code"] == expected_code, res
if expected_code == 0:
res = list_session_with_chat_assistants(api_key, chat_assistant_id, {"id": session_ids[0]})
assert res["data"][0]["name"] == payload["name"]
else:
assert res["message"] == expected_message
@pytest.mark.p3
@pytest.mark.parametrize(
"chat_assistant_id, expected_code, expected_message",
[
("", 100, "<NotFound '404: Not Found'>"),
pytest.param("invalid_chat_assistant_id", 102, "Session does not exist", marks=pytest.mark.skip(reason="issues/")),
],
)
def test_invalid_chat_assistant_id(self, api_key, add_sessions_with_chat_assistant_func, chat_assistant_id, expected_code, expected_message):
_, session_ids = add_sessions_with_chat_assistant_func
res = update_session_with_chat_assistant(api_key, chat_assistant_id, session_ids[0], {"name": "valid_name"})
assert res["code"] == expected_code
assert res["message"] == expected_message
@pytest.mark.p3
@pytest.mark.parametrize(
"session_id, expected_code, expected_message",
[
("", 100, "<MethodNotAllowed '405: Method Not Allowed'>"),
("invalid_session_id", 102, "Session does not exist"),
],
)
def test_invalid_session_id(self, api_key, add_sessions_with_chat_assistant_func, session_id, expected_code, expected_message):
chat_assistant_id, _ = add_sessions_with_chat_assistant_func
res = update_session_with_chat_assistant(api_key, chat_assistant_id, session_id, {"name": "valid_name"})
assert res["code"] == expected_code
assert res["message"] == expected_message
@pytest.mark.p3
def test_repeated_update_session(self, api_key, add_sessions_with_chat_assistant_func):
chat_assistant_id, session_ids = add_sessions_with_chat_assistant_func
res = update_session_with_chat_assistant(api_key, chat_assistant_id, session_ids[0], {"name": "valid_name_1"})
assert res["code"] == 0
res = update_session_with_chat_assistant(api_key, chat_assistant_id, session_ids[0], {"name": "valid_name_2"})
assert res["code"] == 0
@pytest.mark.p3
@pytest.mark.parametrize(
"payload, expected_code, expected_message",
[
pytest.param({"unknown_key": "unknown_value"}, 100, "ValueError", marks=pytest.mark.skip),
({}, 0, ""),
pytest.param(None, 100, "TypeError", marks=pytest.mark.skip),
],
)
def test_invalid_params(self, api_key, add_sessions_with_chat_assistant_func, payload, expected_code, expected_message):
chat_assistant_id, session_ids = add_sessions_with_chat_assistant_func
res = update_session_with_chat_assistant(api_key, chat_assistant_id, session_ids[0], payload)
assert res["code"] == expected_code
if expected_code != 0:
assert expected_message in res["message"]
@pytest.mark.p3
def test_concurrent_update_session(self, api_key, add_sessions_with_chat_assistant_func):
chunk_num = 50
chat_assistant_id, session_ids = add_sessions_with_chat_assistant_func
with ThreadPoolExecutor(max_workers=5) as executor:
futures = [
executor.submit(
update_session_with_chat_assistant,
api_key,
chat_assistant_id,
session_ids[randint(0, 4)],
{"name": f"update session test {i}"},
)
for i in range(chunk_num)
]
responses = [f.result() for f in futures]
assert all(r["code"] == 0 for r in responses)
@pytest.mark.p3
def test_update_session_to_deleted_chat_assistant(self, api_key, add_sessions_with_chat_assistant_func):
chat_assistant_id, session_ids = add_sessions_with_chat_assistant_func
delete_chat_assistants(api_key, {"ids": [chat_assistant_id]})
res = update_session_with_chat_assistant(api_key, chat_assistant_id, session_ids[0], {"name": "valid_name"})
assert res["code"] == 102
assert res["message"] == "You do not own the session"