Refa: HTTP API update dataset / test cases / docs (#7564)

### What problem does this PR solve?

This PR introduces Pydantic-based validation for the update dataset HTTP
API, improving code clarity and robustness. Key changes include:
1. Pydantic Validation
2. Error Handling
3. Test Updates
4. Documentation Updates
5. fix bug: #5915

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] Documentation Update
- [x] Refactoring
This commit is contained in:
liu an
2025-05-09 19:17:08 +08:00
committed by GitHub
parent 31718581b5
commit 35e36cb945
12 changed files with 1283 additions and 552 deletions

View File

@ -59,21 +59,19 @@ class RAGFlow:
pagerank: int = 0,
parser_config: DataSet.ParserConfig = None,
) -> DataSet:
if parser_config:
parser_config = parser_config.to_json()
res = self.post(
"/datasets",
{
"name": name,
"avatar": avatar,
"description": description,
"embedding_model": embedding_model,
"permission": permission,
"chunk_method": chunk_method,
"pagerank": pagerank,
"parser_config": parser_config,
},
)
payload = {
"name": name,
"avatar": avatar,
"description": description,
"embedding_model": embedding_model,
"permission": permission,
"chunk_method": chunk_method,
"pagerank": pagerank,
}
if parser_config is not None:
payload["parser_config"] = parser_config.to_json()
res = self.post("/datasets", payload)
res = res.json()
if res.get("code") == 0:
return DataSet(self, res["data"])

View File

@ -0,0 +1,28 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import hypothesis.strategies as st
@st.composite
def valid_names(draw):
    """Strategy producing valid dataset names.

    Names consist of lowercase ASCII letters and underscores; the first
    character is drawn separately so the result is never empty, and the
    total length never exceeds the 128-character dataset-name limit.
    """
    alphabet = "abcdefghijklmnopqrstuvwxyz_"
    # Drawing the head explicitly guarantees a minimum length of 1.
    head = draw(st.sampled_from(alphabet))
    tail = draw(st.text(alphabet=st.sampled_from(alphabet), min_size=0, max_size=128 - 2))
    # Defensive truncation keeps the name at or below the limit.
    return (head + tail)[:128]

View File

@ -39,23 +39,23 @@ SESSION_WITH_CHAT_NAME_LIMIT = 255
# DATASET MANAGEMENT
def create_dataset(auth, payload=None, headers=HEADERS, data=None):
def create_dataset(auth, payload=None, *, headers=HEADERS, data=None):
res = requests.post(url=f"{HOST_ADDRESS}{DATASETS_API_URL}", headers=headers, auth=auth, json=payload, data=data)
return res.json()
def list_datasets(auth, params=None, headers=HEADERS):
def list_datasets(auth, params=None, *, headers=HEADERS):
res = requests.get(url=f"{HOST_ADDRESS}{DATASETS_API_URL}", headers=headers, auth=auth, params=params)
return res.json()
def update_dataset(auth, dataset_id, payload=None, headers=HEADERS):
res = requests.put(url=f"{HOST_ADDRESS}{DATASETS_API_URL}/{dataset_id}", headers=headers, auth=auth, json=payload)
def update_dataset(auth, dataset_id, payload=None, *, headers=HEADERS, data=None):
res = requests.put(url=f"{HOST_ADDRESS}{DATASETS_API_URL}/{dataset_id}", headers=headers, auth=auth, json=payload, data=data)
return res.json()
def delete_datasets(auth, payload=None, headers=HEADERS):
res = requests.delete(url=f"{HOST_ADDRESS}{DATASETS_API_URL}", headers=headers, auth=auth, json=payload)
def delete_datasets(auth, payload=None, *, headers=HEADERS, data=None):
res = requests.delete(url=f"{HOST_ADDRESS}{DATASETS_API_URL}", headers=headers, auth=auth, json=payload, data=data)
return res.json()

View File

@ -37,3 +37,13 @@ def add_datasets_func(get_http_api_auth, request):
request.addfinalizer(cleanup)
return batch_create_datasets(get_http_api_auth, 3)
@pytest.fixture(scope="function")
def add_dataset_func(get_http_api_auth, request):
    """Create one dataset for the test and delete all datasets on teardown.

    Returns the identifier of the freshly created dataset.
    """
    request.addfinalizer(lambda: delete_datasets(get_http_api_auth))
    return batch_create_datasets(get_http_api_auth, 1)[0]

View File

@ -13,30 +13,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from concurrent.futures import ThreadPoolExecutor
import hypothesis.strategies as st
import pytest
from common import DATASET_NAME_LIMIT, INVALID_API_TOKEN, create_dataset
from hypothesis import example, given, settings
from libs.auth import RAGFlowHttpApiAuth
from libs.utils import encode_avatar
from libs.utils.file_utils import create_image_file
from libs.utils.hypothesis_utils import valid_names
@st.composite
def valid_names(draw):
base_chars = "abcdefghijklmnopqrstuvwxyz_"
first_char = draw(st.sampled_from([c for c in base_chars if c.isalpha() or c == "_"]))
remaining = draw(st.text(alphabet=st.sampled_from(base_chars), min_size=0, max_size=DATASET_NAME_LIMIT - 2))
name = (first_char + remaining)[:128]
return name.encode("utf-8").decode("utf-8")
@pytest.mark.p1
@pytest.mark.usefixtures("clear_datasets")
class TestAuthorization:
@pytest.mark.p1
@pytest.mark.parametrize(
"auth, expected_code, expected_message",
[
@ -49,64 +39,17 @@ class TestAuthorization:
],
ids=["empty_auth", "invalid_api_token"],
)
def test_invalid_auth(self, auth, expected_code, expected_message):
def test_auth_invalid(self, auth, expected_code, expected_message):
res = create_dataset(auth, {"name": "auth_test"})
assert res["code"] == expected_code
assert res["message"] == expected_message
assert res["code"] == expected_code, res
assert res["message"] == expected_message, res
@pytest.mark.usefixtures("clear_datasets")
class TestDatasetCreation:
@pytest.mark.p1
@given(name=valid_names())
@example("a" * 128)
@settings(max_examples=20)
def test_valid_name(self, get_http_api_auth, name):
res = create_dataset(get_http_api_auth, {"name": name})
assert res["code"] == 0, res
assert res["data"]["name"] == name, res
@pytest.mark.p1
@pytest.mark.parametrize(
"name, expected_message",
[
("", "String should have at least 1 character"),
(" ", "String should have at least 1 character"),
("a" * (DATASET_NAME_LIMIT + 1), "String should have at most 128 characters"),
(0, "Input should be a valid string"),
],
ids=["empty_name", "space_name", "too_long_name", "invalid_name"],
)
def test_invalid_name(self, get_http_api_auth, name, expected_message):
res = create_dataset(get_http_api_auth, {"name": name})
assert res["code"] == 101, res
assert expected_message in res["message"], res
@pytest.mark.p2
def test_duplicated_name(self, get_http_api_auth):
name = "duplicated_name"
payload = {"name": name}
res = create_dataset(get_http_api_auth, payload)
assert res["code"] == 0, res
res = create_dataset(get_http_api_auth, payload)
assert res["code"] == 101, res
assert res["message"] == f"Dataset name '{name}' already exists", res
@pytest.mark.p2
def test_case_insensitive(self, get_http_api_auth):
name = "CaseInsensitive"
res = create_dataset(get_http_api_auth, {"name": name.upper()})
assert res["code"] == 0, res
res = create_dataset(get_http_api_auth, {"name": name.lower()})
assert res["code"] == 101, res
assert res["message"] == f"Dataset name '{name.lower()}' already exists", res
class TestRquest:
@pytest.mark.p3
def test_bad_content_type(self, get_http_api_auth):
def test_content_type_bad(self, get_http_api_auth):
BAD_CONTENT_TYPE = "text/xml"
res = create_dataset(get_http_api_auth, {"name": "name"}, {"Content-Type": BAD_CONTENT_TYPE})
res = create_dataset(get_http_api_auth, {"name": "bad_content_type"}, headers={"Content-Type": BAD_CONTENT_TYPE})
assert res["code"] == 101, res
assert res["message"] == f"Unsupported content type: Expected application/json, got {BAD_CONTENT_TYPE}", res
@ -115,15 +58,85 @@ class TestDatasetCreation:
"payload, expected_message",
[
("a", "Malformed JSON syntax: Missing commas/brackets or invalid encoding"),
('"a"', "Invalid request payload: expected objec"),
('"a"', "Invalid request payload: expected object, got str"),
],
ids=["malformed_json_syntax", "invalid_request_payload_type"],
)
def test_bad_payload(self, get_http_api_auth, payload, expected_message):
def test_payload_bad(self, get_http_api_auth, payload, expected_message):
res = create_dataset(get_http_api_auth, data=payload)
assert res["code"] == 101, res
assert res["message"] == expected_message, res
@pytest.mark.usefixtures("clear_datasets")
class TestCapability:
    """Volume and concurrency smoke tests for dataset creation."""

    @pytest.mark.p3
    def test_create_dataset_1k(self, get_http_api_auth):
        # Sequentially create 1000 datasets; abort on the first failure.
        for seq in range(1_000):
            res = create_dataset(get_http_api_auth, {"name": f"dataset_{seq}"})
            assert res["code"] == 0, f"Failed to create dataset {seq}"

    @pytest.mark.p3
    def test_create_dataset_concurrent(self, get_http_api_auth):
        # 100 creation requests spread over 5 worker threads must all succeed.
        with ThreadPoolExecutor(max_workers=5) as pool:
            tasks = [pool.submit(create_dataset, get_http_api_auth, {"name": f"dataset_{i}"}) for i in range(100)]
            responses = [task.result() for task in tasks]
        assert all(r["code"] == 0 for r in responses), responses
@pytest.mark.usefixtures("clear_datasets")
class TestDatasetCreate:
@pytest.mark.p1
@given(name=valid_names())
@example("a" * 128)
@settings(max_examples=20)
def test_name(self, get_http_api_auth, name):
    """Property-based check: every generated valid name is accepted and echoed back."""
    payload = {"name": name}
    res = create_dataset(get_http_api_auth, payload)
    assert res["code"] == 0, res
    assert res["data"]["name"] == name, res
@pytest.mark.p2
@pytest.mark.parametrize(
    "name, expected_message",
    [
        ("", "String should have at least 1 character"),
        (" ", "String should have at least 1 character"),
        ("a" * (DATASET_NAME_LIMIT + 1), "String should have at most 128 characters"),
        (0, "Input should be a valid string"),
        (None, "Input should be a valid string"),
    ],
    ids=["empty_name", "space_name", "too_long_name", "invalid_name", "None_name"],
)
def test_name_invalid(self, get_http_api_auth, name, expected_message):
    """Invalid names are rejected with code 101 and a Pydantic validation message."""
    res = create_dataset(get_http_api_auth, {"name": name})
    assert res["code"] == 101, res
    assert expected_message in res["message"], res
@pytest.mark.p3
def test_name_duplicated(self, get_http_api_auth):
    """A second dataset with an identical name must be rejected."""
    name = "duplicated_name"
    first = create_dataset(get_http_api_auth, {"name": name})
    assert first["code"] == 0, first
    second = create_dataset(get_http_api_auth, {"name": name})
    assert second["code"] == 102, second
    assert second["message"] == f"Dataset name '{name}' already exists", second

@pytest.mark.p3
def test_name_case_insensitive(self, get_http_api_auth):
    """Name uniqueness is enforced case-insensitively."""
    name = "CaseInsensitive"
    upper = create_dataset(get_http_api_auth, {"name": name.upper()})
    assert upper["code"] == 0, upper
    lower = create_dataset(get_http_api_auth, {"name": name.lower()})
    assert lower["code"] == 102, lower
    assert lower["message"] == f"Dataset name '{name.lower()}' already exists", lower
@pytest.mark.p2
def test_avatar(self, get_http_api_auth, tmp_path):
fn = create_image_file(tmp_path / "ragflow_test.png")
@ -134,16 +147,10 @@ class TestDatasetCreation:
res = create_dataset(get_http_api_auth, payload)
assert res["code"] == 0, res
@pytest.mark.p3
def test_avatar_none(self, get_http_api_auth, tmp_path):
payload = {"name": "test_avatar_none", "avatar": None}
res = create_dataset(get_http_api_auth, payload)
assert res["code"] == 0, res
assert res["data"]["avatar"] is None, res
@pytest.mark.p2
def test_avatar_exceeds_limit_length(self, get_http_api_auth):
res = create_dataset(get_http_api_auth, {"name": "exceeds_limit_length_avatar", "avatar": "a" * 65536})
payload = {"name": "exceeds_limit_length_avatar", "avatar": "a" * 65536}
res = create_dataset(get_http_api_auth, payload)
assert res["code"] == 101, res
assert "String should have at most 65535 characters" in res["message"], res
@ -158,7 +165,7 @@ class TestDatasetCreation:
],
ids=["empty_prefix", "missing_comma", "unsupported_mine_type", "invalid_mine_type"],
)
def test_invalid_avatar_prefix(self, get_http_api_auth, tmp_path, name, avatar_prefix, expected_message):
def test_avatar_invalid_prefix(self, get_http_api_auth, tmp_path, name, avatar_prefix, expected_message):
fn = create_image_file(tmp_path / "ragflow_test.png")
payload = {
"name": name,
@ -169,11 +176,25 @@ class TestDatasetCreation:
assert expected_message in res["message"], res
@pytest.mark.p3
def test_description_none(self, get_http_api_auth):
payload = {"name": "test_description_none", "description": None}
def test_avatar_unset(self, get_http_api_auth):
payload = {"name": "test_avatar_unset"}
res = create_dataset(get_http_api_auth, payload)
assert res["code"] == 0, res
assert res["data"]["description"] is None, res
assert res["data"]["avatar"] is None, res
@pytest.mark.p3
def test_avatar_none(self, get_http_api_auth):
    """An explicit null avatar is accepted and stored as null."""
    res = create_dataset(get_http_api_auth, {"name": "test_avatar_none", "avatar": None})
    assert res["code"] == 0, res
    assert res["data"]["avatar"] is None, res

@pytest.mark.p2
def test_description(self, get_http_api_auth):
    """A plain string description round-trips unchanged."""
    res = create_dataset(get_http_api_auth, {"name": "test_description", "description": "description"})
    assert res["code"] == 0, res
    assert res["data"]["description"] == "description", res
@pytest.mark.p2
def test_description_exceeds_limit_length(self, get_http_api_auth):
@ -182,6 +203,20 @@ class TestDatasetCreation:
assert res["code"] == 101, res
assert "String should have at most 65535 characters" in res["message"], res
@pytest.mark.p3
def test_description_unset(self, get_http_api_auth):
    """Omitting the description leaves it null."""
    res = create_dataset(get_http_api_auth, {"name": "test_description_unset"})
    assert res["code"] == 0, res
    assert res["data"]["description"] is None, res

@pytest.mark.p3
def test_description_none(self, get_http_api_auth):
    """An explicit null description is accepted and stored as null."""
    res = create_dataset(get_http_api_auth, {"name": "test_description_none", "description": None})
    assert res["code"] == 0, res
    assert res["data"]["description"] is None, res
@pytest.mark.p1
@pytest.mark.parametrize(
"name, embedding_model",
@ -189,22 +224,14 @@ class TestDatasetCreation:
("BAAI/bge-large-zh-v1.5@BAAI", "BAAI/bge-large-zh-v1.5@BAAI"),
("maidalun1020/bce-embedding-base_v1@Youdao", "maidalun1020/bce-embedding-base_v1@Youdao"),
("embedding-3@ZHIPU-AI", "embedding-3@ZHIPU-AI"),
("embedding_model_default", None),
],
ids=["builtin_baai", "builtin_youdao", "tenant_zhipu", "default"],
ids=["builtin_baai", "builtin_youdao", "tenant_zhipu"],
)
def test_valid_embedding_model(self, get_http_api_auth, name, embedding_model):
if embedding_model is None:
payload = {"name": name}
else:
payload = {"name": name, "embedding_model": embedding_model}
def test_embedding_model(self, get_http_api_auth, name, embedding_model):
payload = {"name": name, "embedding_model": embedding_model}
res = create_dataset(get_http_api_auth, payload)
assert res["code"] == 0, res
if embedding_model is None:
assert res["data"]["embedding_model"] == "BAAI/bge-large-zh-v1.5@BAAI", res
else:
assert res["data"]["embedding_model"] == embedding_model, res
assert res["data"]["embedding_model"] == embedding_model, res
@pytest.mark.p2
@pytest.mark.parametrize(
@ -217,7 +244,7 @@ class TestDatasetCreation:
],
ids=["unknown_llm_name", "unknown_llm_factory", "tenant_no_auth_default_tenant_llm", "tenant_no_auth"],
)
def test_invalid_embedding_model(self, get_http_api_auth, name, embedding_model):
def test_embedding_model_invalid(self, get_http_api_auth, name, embedding_model):
payload = {"name": name, "embedding_model": embedding_model}
res = create_dataset(get_http_api_auth, payload)
assert res["code"] == 101, res
@ -247,6 +274,20 @@ class TestDatasetCreation:
else:
assert "Both model_name and provider must be non-empty strings" in res["message"], res
@pytest.mark.p2
def test_embedding_model_unset(self, get_http_api_auth):
    """Omitting embedding_model falls back to the default model."""
    res = create_dataset(get_http_api_auth, {"name": "embedding_model_unset"})
    assert res["code"] == 0, res
    assert res["data"]["embedding_model"] == "BAAI/bge-large-zh-v1.5@BAAI", res

@pytest.mark.p2
def test_embedding_model_none(self, get_http_api_auth):
    """An explicit null embedding_model fails string validation."""
    res = create_dataset(get_http_api_auth, {"name": "test_embedding_model_none", "embedding_model": None})
    assert res["code"] == 101, res
    assert "Input should be a valid string" in res["message"], res
@pytest.mark.p1
@pytest.mark.parametrize(
"name, permission",
@ -255,21 +296,14 @@ class TestDatasetCreation:
("team", "team"),
("me_upercase", "ME"),
("team_upercase", "TEAM"),
("permission_default", None),
],
ids=["me", "team", "me_upercase", "team_upercase", "permission_default"],
ids=["me", "team", "me_upercase", "team_upercase"],
)
def test_valid_permission(self, get_http_api_auth, name, permission):
if permission is None:
payload = {"name": name}
else:
payload = {"name": name, "permission": permission}
def test_permission(self, get_http_api_auth, name, permission):
payload = {"name": name, "permission": permission}
res = create_dataset(get_http_api_auth, payload)
assert res["code"] == 0, res
if permission is None:
assert res["data"]["permission"] == "me", res
else:
assert res["data"]["permission"] == permission.lower(), res
assert res["data"]["permission"] == permission.lower(), res
@pytest.mark.p2
@pytest.mark.parametrize(
@ -279,13 +313,28 @@ class TestDatasetCreation:
("unknown", "unknown"),
("type_error", list()),
],
ids=["empty", "unknown", "type_error"],
)
def test_invalid_permission(self, get_http_api_auth, name, permission):
def test_permission_invalid(self, get_http_api_auth, name, permission):
payload = {"name": name, "permission": permission}
res = create_dataset(get_http_api_auth, payload)
assert res["code"] == 101
assert "Input should be 'me' or 'team'" in res["message"]
@pytest.mark.p2
def test_permission_unset(self, get_http_api_auth):
    """Omitting permission defaults it to 'me'."""
    res = create_dataset(get_http_api_auth, {"name": "test_permission_unset"})
    assert res["code"] == 0, res
    assert res["data"]["permission"] == "me", res

@pytest.mark.p3
def test_permission_none(self, get_http_api_auth):
    """An explicit null permission fails enum validation."""
    res = create_dataset(get_http_api_auth, {"name": "test_permission_none", "permission": None})
    assert res["code"] == 101, res
    assert "Input should be 'me' or 'team'" in res["message"], res
@pytest.mark.p1
@pytest.mark.parametrize(
"name, chunk_method",
@ -302,20 +351,14 @@ class TestDatasetCreation:
("qa", "qa"),
("table", "table"),
("tag", "tag"),
("chunk_method_default", None),
],
ids=["naive", "book", "email", "laws", "manual", "one", "paper", "picture", "presentation", "qa", "table", "tag"],
)
def test_valid_chunk_method(self, get_http_api_auth, name, chunk_method):
if chunk_method is None:
payload = {"name": name}
else:
payload = {"name": name, "chunk_method": chunk_method}
def test_chunk_method(self, get_http_api_auth, name, chunk_method):
payload = {"name": name, "chunk_method": chunk_method}
res = create_dataset(get_http_api_auth, payload)
assert res["code"] == 0, res
if chunk_method is None:
assert res["data"]["chunk_method"] == "naive", res
else:
assert res["data"]["chunk_method"] == chunk_method, res
assert res["data"]["chunk_method"] == chunk_method, res
@pytest.mark.p2
@pytest.mark.parametrize(
@ -325,19 +368,77 @@ class TestDatasetCreation:
("unknown", "unknown"),
("type_error", list()),
],
ids=["empty", "unknown", "type_error"],
)
def test_invalid_chunk_method(self, get_http_api_auth, name, chunk_method):
def test_chunk_method_invalid(self, get_http_api_auth, name, chunk_method):
payload = {"name": name, "chunk_method": chunk_method}
res = create_dataset(get_http_api_auth, payload)
assert res["code"] == 101, res
assert "Input should be 'naive', 'book', 'email', 'laws', 'manual', 'one', 'paper', 'picture', 'presentation', 'qa', 'table' or 'tag'" in res["message"], res
@pytest.mark.p2
def test_chunk_method_unset(self, get_http_api_auth):
    """Omitting chunk_method defaults it to 'naive'."""
    res = create_dataset(get_http_api_auth, {"name": "test_chunk_method_unset"})
    assert res["code"] == 0, res
    assert res["data"]["chunk_method"] == "naive", res

@pytest.mark.p3
def test_chunk_method_none(self, get_http_api_auth):
    """An explicit null chunk_method fails enum validation."""
    res = create_dataset(get_http_api_auth, {"name": "chunk_method_none", "chunk_method": None})
    assert res["code"] == 101, res
    assert "Input should be 'naive', 'book', 'email', 'laws', 'manual', 'one', 'paper', 'picture', 'presentation', 'qa', 'table' or 'tag'" in res["message"], res
@pytest.mark.p2
@pytest.mark.parametrize(
    "name, pagerank",
    [
        ("pagerank_min", 0),
        ("pagerank_mid", 50),
        ("pagerank_max", 100),
    ],
    ids=["min", "mid", "max"],
)
def test_pagerank(self, get_http_api_auth, name, pagerank):
    """Pagerank values at and inside the [0, 100] bounds are accepted verbatim."""
    res = create_dataset(get_http_api_auth, {"name": name, "pagerank": pagerank})
    assert res["code"] == 0, res
    assert res["data"]["pagerank"] == pagerank, res

@pytest.mark.p3
@pytest.mark.parametrize(
    "name, pagerank, expected_message",
    [
        ("pagerank_min_limit", -1, "Input should be greater than or equal to 0"),
        ("pagerank_max_limit", 101, "Input should be less than or equal to 100"),
    ],
    ids=["min_limit", "max_limit"],
)
def test_pagerank_invalid(self, get_http_api_auth, name, pagerank, expected_message):
    """Pagerank values outside [0, 100] are rejected with a range message."""
    res = create_dataset(get_http_api_auth, {"name": name, "pagerank": pagerank})
    assert res["code"] == 101, res
    assert expected_message in res["message"], res
@pytest.mark.p3
def test_pagerank_unset(self, get_http_api_auth):
    """Omitting pagerank defaults it to 0."""
    payload = {"name": "pagerank_unset"}
    res = create_dataset(get_http_api_auth, payload)
    assert res["code"] == 0, res
    assert res["data"]["pagerank"] == 0, res

@pytest.mark.p3
def test_pagerank_none(self, get_http_api_auth):
    """An explicit null pagerank fails integer validation."""
    # Fix: the dataset name previously read "pagerank_unset" — a copy-paste
    # slip from the test above; use a name matching this test so failures
    # are attributable to the right case.
    payload = {"name": "pagerank_none", "pagerank": None}
    res = create_dataset(get_http_api_auth, payload)
    assert res["code"] == 101, res
    assert "Input should be a valid integer" in res["message"], res
@pytest.mark.p1
@pytest.mark.parametrize(
"name, parser_config",
[
("default_none", None),
("default_empty", {}),
("auto_keywords_min", {"auto_keywords": 0}),
("auto_keywords_mid", {"auto_keywords": 16}),
("auto_keywords_max", {"auto_keywords": 32}),
@ -363,7 +464,7 @@ class TestDatasetCreation:
("task_page_size_min", {"task_page_size": 1}),
("task_page_size_None", {"task_page_size": None}),
("pages", {"pages": [[1, 100]]}),
("pages_none", None),
("pages_none", {"pages": None}),
("graphrag_true", {"graphrag": {"use_graphrag": True}}),
("graphrag_false", {"graphrag": {"use_graphrag": False}}),
("graphrag_entity_types", {"graphrag": {"entity_types": ["age", "sex", "height", "weight"]}}),
@ -388,8 +489,6 @@ class TestDatasetCreation:
("raptor_random_seed_min", {"raptor": {"random_seed": 0}}),
],
ids=[
"default_none",
"default_empty",
"auto_keywords_min",
"auto_keywords_mid",
"auto_keywords_max",
@ -440,44 +539,16 @@ class TestDatasetCreation:
"raptor_random_seed_min",
],
)
def test_valid_parser_config(self, get_http_api_auth, name, parser_config):
if parser_config is None:
payload = {"name": name}
else:
payload = {"name": name, "parser_config": parser_config}
def test_parser_config(self, get_http_api_auth, name, parser_config):
payload = {"name": name, "parser_config": parser_config}
res = create_dataset(get_http_api_auth, payload)
assert res["code"] == 0, res
if parser_config is None:
assert res["data"]["parser_config"] == {
"chunk_token_num": 128,
"delimiter": r"\n",
"html4excel": False,
"layout_recognize": "DeepDOC",
"raptor": {"use_raptor": False},
}
elif parser_config == {}:
assert res["data"]["parser_config"] == {
"auto_keywords": 0,
"auto_questions": 0,
"chunk_token_num": 128,
"delimiter": r"\n",
"filename_embd_weight": None,
"graphrag": None,
"html4excel": False,
"layout_recognize": "DeepDOC",
"pages": None,
"raptor": None,
"tag_kb_ids": [],
"task_page_size": None,
"topn_tags": 1,
}
else:
for k, v in parser_config.items():
if isinstance(v, dict):
for kk, vv in v.items():
assert res["data"]["parser_config"][k][kk] == vv
else:
assert res["data"]["parser_config"][k] == v
for k, v in parser_config.items():
if isinstance(v, dict):
for kk, vv in v.items():
assert res["data"]["parser_config"][k][kk] == vv, res
else:
assert res["data"]["parser_config"][k] == v, res
@pytest.mark.p2
@pytest.mark.parametrize(
@ -595,15 +666,72 @@ class TestDatasetCreation:
"parser_config_type_invalid",
],
)
def test_invalid_parser_config(self, get_http_api_auth, name, parser_config, expected_message):
def test_parser_config_invalid(self, get_http_api_auth, name, parser_config, expected_message):
payload = {"name": name, "parser_config": parser_config}
res = create_dataset(get_http_api_auth, payload)
assert res["code"] == 101, res
assert expected_message in res["message"], res
@pytest.mark.p2
def test_parser_config_empty(self, get_http_api_auth):
    """An empty parser_config object is filled with the full set of defaults."""
    res = create_dataset(get_http_api_auth, {"name": "default_empty", "parser_config": {}})
    assert res["code"] == 0, res
    # Every field is materialized when an (empty) config is supplied explicitly.
    assert res["data"]["parser_config"] == {
        "auto_keywords": 0,
        "auto_questions": 0,
        "chunk_token_num": 128,
        "delimiter": r"\n",
        "filename_embd_weight": None,
        "graphrag": None,
        "html4excel": False,
        "layout_recognize": "DeepDOC",
        "pages": None,
        "raptor": None,
        "tag_kb_ids": [],
        "task_page_size": None,
        "topn_tags": 1,
    }

@pytest.mark.p2
def test_parser_config_unset(self, get_http_api_auth):
    """Omitting parser_config yields the compact default config."""
    res = create_dataset(get_http_api_auth, {"name": "default_unset"})
    assert res["code"] == 0, res
    # Note: fewer keys than the empty-object case above — server-side default.
    assert res["data"]["parser_config"] == {
        "chunk_token_num": 128,
        "delimiter": r"\n",
        "html4excel": False,
        "layout_recognize": "DeepDOC",
        "raptor": {"use_raptor": False},
    }, res
@pytest.mark.p3
def test_dataset_10k(self, get_http_api_auth):
for i in range(10_000):
payload = {"name": f"dataset_{i}"}
res = create_dataset(get_http_api_auth, payload)
assert res["code"] == 0, f"Failed to create dataset {i}"
def test_parser_config_none(self, get_http_api_auth):
    """An explicit null parser_config is rejected by Pydantic."""
    res = create_dataset(get_http_api_auth, {"name": "default_none", "parser_config": None})
    assert res["code"] == 101, res
    assert "Input should be a valid dictionary or instance of ParserConfig" in res["message"], res
@pytest.mark.p2
@pytest.mark.parametrize(
    "payload",
    [
        {"name": "id", "id": "id"},
        {"name": "tenant_id", "tenant_id": "e57c1966f99211efb41e9e45646e0111"},
        {"name": "created_by", "created_by": "created_by"},
        {"name": "create_date", "create_date": "Tue, 11 Mar 2025 13:37:23 GMT"},
        {"name": "create_time", "create_time": 1741671443322},
        {"name": "update_date", "update_date": "Tue, 11 Mar 2025 13:37:23 GMT"},
        {"name": "update_time", "update_time": 1741671443339},
        {"name": "document_count", "document_count": 1},
        {"name": "chunk_count", "chunk_count": 1},
        {"name": "token_num", "token_num": 1},
        {"name": "status", "status": "1"},
        {"name": "unknown_field", "unknown_field": "unknown_field"},
    ],
)
def test_unsupported_field(self, get_http_api_auth, payload):
    """Server-managed or unknown fields are rejected by the strict model."""
    response = create_dataset(get_http_api_auth, payload)
    assert response["code"] == 101, response
    assert "Extra inputs are not permitted" in response["message"], response

View File

@ -16,21 +16,18 @@
from concurrent.futures import ThreadPoolExecutor
import pytest
from common import (
DATASET_NAME_LIMIT,
INVALID_API_TOKEN,
list_datasets,
update_dataset,
)
from common import DATASET_NAME_LIMIT, INVALID_API_TOKEN, list_datasets, update_dataset
from hypothesis import HealthCheck, example, given, settings
from libs.auth import RAGFlowHttpApiAuth
from libs.utils import encode_avatar
from libs.utils.file_utils import create_image_file
from libs.utils.hypothesis_utils import valid_names
# TODO: Missing scenario for updating embedding_model with chunk_count != 0
@pytest.mark.p1
class TestAuthorization:
@pytest.mark.p1
@pytest.mark.parametrize(
"auth, expected_code, expected_message",
[
@ -41,111 +38,178 @@ class TestAuthorization:
"Authentication error: API key is invalid!",
),
],
ids=["empty_auth", "invalid_api_token"],
)
def test_invalid_auth(self, auth, expected_code, expected_message):
def test_auth_invalid(self, auth, expected_code, expected_message):
res = update_dataset(auth, "dataset_id")
assert res["code"] == expected_code
assert res["message"] == expected_message
assert res["code"] == expected_code, res
assert res["message"] == expected_message, res
class TestRquest:
@pytest.mark.p3
def test_bad_content_type(self, get_http_api_auth, add_dataset_func):
    """Non-JSON content types are rejected with a descriptive error."""
    BAD_CONTENT_TYPE = "text/xml"
    res = update_dataset(get_http_api_auth, add_dataset_func, {"name": "bad_content_type"}, headers={"Content-Type": BAD_CONTENT_TYPE})
    assert res["code"] == 101, res
    assert res["message"] == f"Unsupported content type: Expected application/json, got {BAD_CONTENT_TYPE}", res
@pytest.mark.p3
@pytest.mark.parametrize(
    "payload, expected_message",
    [
        ("a", "Malformed JSON syntax: Missing commas/brackets or invalid encoding"),
        ('"a"', "Invalid request payload: expected object, got str"),
    ],
    ids=["malformed_json_syntax", "invalid_request_payload_type"],
)
def test_payload_bad(self, get_http_api_auth, add_dataset_func, payload, expected_message):
    """Syntactically bad or non-object bodies are rejected before validation."""
    res = update_dataset(get_http_api_auth, add_dataset_func, data=payload)
    assert res["code"] == 101, res
    assert res["message"] == expected_message, res

@pytest.mark.p2
def test_payload_empty(self, get_http_api_auth, add_dataset_func):
    """An empty update object is refused — nothing would change."""
    res = update_dataset(get_http_api_auth, add_dataset_func, {})
    assert res["code"] == 101, res
    assert res["message"] == "No properties were modified", res
class TestCapability:
    """Concurrency smoke test for dataset update."""

    @pytest.mark.p3
    def test_update_dateset_concurrent(self, get_http_api_auth, add_dataset_func):
        # 100 concurrent renames of the same dataset must all succeed.
        dataset_id = add_dataset_func
        with ThreadPoolExecutor(max_workers=5) as pool:
            tasks = [pool.submit(update_dataset, get_http_api_auth, dataset_id, {"name": f"dataset_{i}"}) for i in range(100)]
            responses = [task.result() for task in tasks]
        assert all(r["code"] == 0 for r in responses), responses
@pytest.mark.p1
class TestDatasetUpdate:
@pytest.mark.parametrize(
"name, expected_code, expected_message",
[
("valid_name", 0, ""),
(
"a" * (DATASET_NAME_LIMIT + 1),
102,
"Dataset name should not be longer than 128 characters.",
),
(0, 100, """AttributeError("\'int\' object has no attribute \'strip\'")"""),
(
None,
100,
"""AttributeError("\'NoneType\' object has no attribute \'strip\'")""",
),
pytest.param("", 102, "", marks=pytest.mark.skip(reason="issue/5915")),
("dataset_1", 102, "Duplicated dataset name in updating dataset."),
("DATASET_1", 102, "Duplicated dataset name in updating dataset."),
],
)
def test_name(self, get_http_api_auth, add_datasets_func, name, expected_code, expected_message):
dataset_ids = add_datasets_func
res = update_dataset(get_http_api_auth, dataset_ids[0], {"name": name})
assert res["code"] == expected_code
if expected_code == 0:
res = list_datasets(get_http_api_auth, {"id": dataset_ids[0]})
assert res["data"][0]["name"] == name
else:
assert res["message"] == expected_message
@pytest.mark.p3
def test_dataset_id_not_uuid(self, get_http_api_auth):
    """A non-UUID dataset id fails path validation with code 101."""
    res = update_dataset(get_http_api_auth, "not_uuid", {"name": "dataset_id_not_uuid"})
    assert res["code"] == 101, res
    assert "Input should be a valid UUID" in res["message"], res
@pytest.mark.parametrize(
"embedding_model, expected_code, expected_message",
[
("BAAI/bge-large-zh-v1.5", 0, ""),
("maidalun1020/bce-embedding-base_v1", 0, ""),
(
"other_embedding_model",
102,
"`embedding_model` other_embedding_model doesn't exist",
),
(None, 102, "`embedding_model` can't be empty"),
],
)
def test_embedding_model(self, get_http_api_auth, add_dataset_func, embedding_model, expected_code, expected_message):
@pytest.mark.p3
def test_dataset_id_wrong_uuid(self, get_http_api_auth):
    """A well-formed UUID that matches no owned dataset yields a permission error."""
    res = update_dataset(get_http_api_auth, "d94a8dc02c9711f0930f7fbc369eab6d", {"name": "wrong_uuid"})
    assert res["code"] == 102, res
    assert "lacks permission for dataset" in res["message"], res
@pytest.mark.p1
@given(name=valid_names())
@example("a" * 128)
@settings(max_examples=20, suppress_health_check=[HealthCheck.function_scoped_fixture])
def test_name(self, get_http_api_auth, add_dataset_func, name):
dataset_id = add_dataset_func
res = update_dataset(get_http_api_auth, dataset_id, {"embedding_model": embedding_model})
assert res["code"] == expected_code
if expected_code == 0:
res = list_datasets(get_http_api_auth, {"id": dataset_id})
assert res["data"][0]["embedding_model"] == embedding_model
else:
assert res["message"] == expected_message
payload = {"name": name}
res = update_dataset(get_http_api_auth, dataset_id, payload)
assert res["code"] == 0, res
res = list_datasets(get_http_api_auth)
assert res["code"] == 0, res
assert res["data"][0]["name"] == name, res
@pytest.mark.p2
@pytest.mark.parametrize(
    "name, expected_message",
    [
        ("", "String should have at least 1 character"),
        (" ", "String should have at least 1 character"),
        ("a" * (DATASET_NAME_LIMIT + 1), "String should have at most 128 characters"),
        (0, "Input should be a valid string"),
        (None, "Input should be a valid string"),
    ],
    ids=["empty_name", "space_name", "too_long_name", "invalid_name", "None_name"],
)
def test_name_invalid(self, get_http_api_auth, add_dataset_func, name, expected_message):
    """Invalid names are rejected by Pydantic validation with code 101.

    Reconstructed from diff-merge residue: the removed `test_chunk_method`
    parametrize table and def header had been fused with this method's.
    """
    dataset_id = add_dataset_func
    payload = {"name": name}
    res = update_dataset(get_http_api_auth, dataset_id, payload)
    assert res["code"] == 101, res
    assert expected_message in res["message"], res
@pytest.mark.p3
def test_name_duplicated(self, get_http_api_auth, add_datasets_func):
    """Renaming a dataset to another existing dataset's exact name is rejected."""
    dataset_id = add_datasets_func[0]
    duplicate = "dataset_1"
    res = update_dataset(get_http_api_auth, dataset_id, {"name": duplicate})
    assert res["code"] == 102, res
    assert res["message"] == f"Dataset name '{duplicate}' already exists", res
@pytest.mark.p3
def test_name_case_insensitive(self, get_http_api_auth, add_datasets_func):
    """Duplicate-name detection ignores case: 'DATASET_1' collides with 'dataset_1'."""
    dataset_id = add_datasets_func[0]
    duplicate = "DATASET_1"
    res = update_dataset(get_http_api_auth, dataset_id, {"name": duplicate})
    assert res["code"] == 102, res
    assert res["message"] == f"Dataset name '{duplicate}' already exists", res
@pytest.mark.p2
def test_avatar(self, get_http_api_auth, add_dataset_func, tmp_path):
    """A data-URI avatar (MIME prefix + base64 payload) is accepted and stored verbatim.

    Fix: removed a dead leftover assignment (the pre-refactor payload without
    the MIME prefix) and a duplicated `assert res["code"] == 0` line.
    """
    dataset_id = add_dataset_func
    fn = create_image_file(tmp_path / "ragflow_test.png")
    avatar = f"data:image/png;base64,{encode_avatar(fn)}"
    res = update_dataset(get_http_api_auth, dataset_id, {"avatar": avatar})
    assert res["code"] == 0, res
    res = list_datasets(get_http_api_auth)
    assert res["code"] == 0, res
    assert res["data"][0]["avatar"] == avatar, res
@pytest.mark.p2
def test_avatar_exceeds_limit_length(self, get_http_api_auth, add_dataset_func):
    """An avatar string longer than 65535 characters fails length validation."""
    dataset_id = add_dataset_func
    oversized = "a" * 65536
    res = update_dataset(get_http_api_auth, dataset_id, {"avatar": oversized})
    assert res["code"] == 101, res
    assert "String should have at most 65535 characters" in res["message"], res
@pytest.mark.p3
@pytest.mark.parametrize(
    "name, avatar_prefix, expected_message",
    [
        ("empty_prefix", "", "Missing MIME prefix. Expected format: data:<mime>;base64,<data>"),
        ("missing_comma", "data:image/png;base64", "Missing MIME prefix. Expected format: data:<mime>;base64,<data>"),
        ("unsupported_mine_type", "invalid_mine_prefix:image/png;base64,", "Invalid MIME prefix format. Must start with 'data:'"),
        ("invalid_mine_type", "data:unsupported_mine_type;base64,", "Unsupported MIME type. Allowed: ['image/jpeg', 'image/png']"),
    ],
    ids=["empty_prefix", "missing_comma", "unsupported_mine_type", "invalid_mine_type"],
)
def test_avatar_invalid_prefix(self, get_http_api_auth, add_dataset_func, tmp_path, name, avatar_prefix, expected_message):
    """An avatar whose data-URI prefix is malformed is rejected with a prefix-specific error."""
    dataset_id = add_dataset_func
    image_path = create_image_file(tmp_path / "ragflow_test.png")
    payload = {
        "name": name,
        "avatar": f"{avatar_prefix}{encode_avatar(image_path)}",
    }
    res = update_dataset(get_http_api_auth, dataset_id, payload)
    assert res["code"] == 101, res
    assert expected_message in res["message"], res
@pytest.mark.p3
def test_avatar_none(self, get_http_api_auth, add_dataset_func):
    """Setting the avatar to None is allowed and clears the stored value."""
    dataset_id = add_dataset_func
    res = update_dataset(get_http_api_auth, dataset_id, {"avatar": None})
    assert res["code"] == 0, res
    listing = list_datasets(get_http_api_auth)
    assert listing["code"] == 0, listing
    assert listing["data"][0]["avatar"] is None, listing
@pytest.mark.p2
def test_description(self, get_http_api_auth, add_dataset_func):
    """A plain description update succeeds and is persisted.

    Fix: a leaked diff hunk header had replaced the `res = update_dataset(...)`
    call, so `res` was used before assignment; the call is restored here.
    """
    dataset_id = add_dataset_func
    payload = {"description": "description"}
    res = update_dataset(get_http_api_auth, dataset_id, payload)
    assert res["code"] == 0, res
    res = list_datasets(get_http_api_auth, {"id": dataset_id})
    assert res["code"] == 0, res
    assert res["data"][0]["description"] == "description", res
@pytest.mark.p2
def test_description_exceeds_limit_length(self, get_http_api_auth, add_dataset_func):
    """A description longer than 65535 characters fails length validation.

    Reconstructed from diff-merge residue: the removed `test_pagerank` and
    `test_similarity_threshold` bodies had been interleaved with these two
    description tests.
    """
    dataset_id = add_dataset_func
    payload = {"description": "a" * 65536}
    res = update_dataset(get_http_api_auth, dataset_id, payload)
    assert res["code"] == 101, res
    assert "String should have at most 65535 characters" in res["message"], res

@pytest.mark.p3
def test_description_none(self, get_http_api_auth, add_dataset_func):
    """Setting the description to None is allowed and clears the stored value."""
    dataset_id = add_dataset_func
    payload = {"description": None}
    res = update_dataset(get_http_api_auth, dataset_id, payload)
    assert res["code"] == 0, res
    res = list_datasets(get_http_api_auth, {"id": dataset_id})
    assert res["code"] == 0, res
    assert res["data"][0]["description"] is None, res
@pytest.mark.p1
@pytest.mark.parametrize(
    "embedding_model",
    [
        "BAAI/bge-large-zh-v1.5@BAAI",
        "maidalun1020/bce-embedding-base_v1@Youdao",
        "embedding-3@ZHIPU-AI",
    ],
    ids=["builtin_baai", "builtin_youdao", "tenant_zhipu"],
)
def test_embedding_model(self, get_http_api_auth, add_dataset_func, embedding_model):
    """Known `<model_name>@<provider>` identifiers are accepted and persisted.

    Reconstructed from diff-merge residue: the removed `test_permission`
    parametrize table had been fused into this method's decorator.
    """
    dataset_id = add_dataset_func
    payload = {"embedding_model": embedding_model}
    res = update_dataset(get_http_api_auth, dataset_id, payload)
    assert res["code"] == 0, res
    res = list_datasets(get_http_api_auth)
    assert res["code"] == 0, res
    assert res["data"][0]["embedding_model"] == embedding_model, res
@pytest.mark.p2
@pytest.mark.parametrize(
    "name, embedding_model",
    [
        ("unknown_llm_name", "unknown@ZHIPU-AI"),
        ("unknown_llm_factory", "embedding-3@unknown"),
        ("tenant_no_auth_default_tenant_llm", "text-embedding-v3@Tongyi-Qianwen"),
        ("tenant_no_auth", "text-embedding-3-small@OpenAI"),
    ],
    ids=["unknown_llm_name", "unknown_llm_factory", "tenant_no_auth_default_tenant_llm", "tenant_no_auth"],
)
def test_embedding_model_invalid(self, get_http_api_auth, add_dataset_func, name, embedding_model):
    """Unknown or unauthorized embedding models are rejected with a model-specific message."""
    dataset_id = add_dataset_func
    payload = {"name": name, "embedding_model": embedding_model}
    res = update_dataset(get_http_api_auth, dataset_id, payload)
    assert res["code"] == 101, res
    # Cases whose id contains "tenant_no_auth" hit the authorization branch;
    # the rest hit the unsupported-model branch.
    prefix = "Unauthorized model" if "tenant_no_auth" in name else "Unsupported model"
    assert res["message"] == f"{prefix}: <{embedding_model}>", res
@pytest.mark.p2
@pytest.mark.parametrize(
    "name, embedding_model",
    [
        ("missing_at", "BAAI/bge-large-zh-v1.5BAAI"),
        ("missing_model_name", "@BAAI"),
        ("missing_provider", "BAAI/bge-large-zh-v1.5@"),
        ("whitespace_only_model_name", " @BAAI"),
        ("whitespace_only_provider", "BAAI/bge-large-zh-v1.5@ "),
    ],
    ids=["missing_at", "empty_model_name", "empty_provider", "whitespace_only_model_name", "whitespace_only_provider"],
)
def test_embedding_model_format(self, get_http_api_auth, add_dataset_func, name, embedding_model):
    """Identifiers that break the `<model_name>@<provider>` shape are rejected."""
    dataset_id = add_dataset_func
    payload = {"name": name, "embedding_model": embedding_model}
    res = update_dataset(get_http_api_auth, dataset_id, payload)
    assert res["code"] == 101, res
    if name == "missing_at":
        expected = "Embedding model identifier must follow <model_name>@<provider> format"
    else:
        expected = "Both model_name and provider must be non-empty strings"
    assert expected in res["message"], res
@pytest.mark.p2
def test_embedding_model_none(self, get_http_api_auth, add_dataset_func):
    """None is not a valid embedding_model value."""
    dataset_id = add_dataset_func
    res = update_dataset(get_http_api_auth, dataset_id, {"embedding_model": None})
    assert res["code"] == 101, res
    assert "Input should be a valid string" in res["message"], res
@pytest.mark.p1
@pytest.mark.parametrize(
    "name, permission",
    [
        ("me", "me"),
        ("team", "team"),
        ("me_upercase", "ME"),
        ("team_upercase", "TEAM"),
    ],
    ids=["me", "team", "me_upercase", "team_upercase"],
)
def test_permission(self, get_http_api_auth, add_dataset_func, name, permission):
    """Valid permissions are accepted case-insensitively and stored lower-cased."""
    dataset_id = add_dataset_func
    payload = {"name": name, "permission": permission}
    res = update_dataset(get_http_api_auth, dataset_id, payload)
    assert res["code"] == 0, res
    listing = list_datasets(get_http_api_auth)
    assert listing["code"] == 0, listing
    assert listing["data"][0]["permission"] == permission.lower(), listing
@pytest.mark.p2
@pytest.mark.parametrize(
    "permission",
    [
        "",
        "unknown",
        list(),
    ],
    ids=["empty", "unknown", "type_error"],
)
def test_permission_invalid(self, get_http_api_auth, add_dataset_func, permission):
    """Values outside the {'me', 'team'} enum are rejected with code 101.

    Fix: dropped trailing pre-refactor lines that referenced the undefined
    `expected_code` variable (diff-merge residue from the removed version).
    """
    dataset_id = add_dataset_func
    payload = {"permission": permission}
    res = update_dataset(get_http_api_auth, dataset_id, payload)
    assert res["code"] == 101, res
    assert "Input should be 'me' or 'team'" in res["message"], res
@pytest.mark.p3
def test_permission_none(self, get_http_api_auth, add_dataset_func):
    """None is not a valid permission value.

    Fix: removed the fused header of the deleted `test_vector_similarity_weight`
    and its dead payload line (diff-merge residue).
    """
    dataset_id = add_dataset_func
    payload = {"name": "test_permission_none", "permission": None}
    res = update_dataset(get_http_api_auth, dataset_id, payload)
    assert res["code"] == 101, res
    assert "Input should be 'me' or 'team'" in res["message"], res
@pytest.mark.p1
@pytest.mark.parametrize(
    "chunk_method",
    [
        "naive",
        "book",
        "email",
        "laws",
        "manual",
        "one",
        "paper",
        "picture",
        "presentation",
        "qa",
        "table",
        "tag",
    ],
    ids=["naive", "book", "email", "laws", "manual", "one", "paper", "picture", "presentation", "qa", "table", "tag"],
)
def test_chunk_method(self, get_http_api_auth, add_dataset_func, chunk_method):
    """Every supported chunk_method value is accepted and persisted."""
    dataset_id = add_dataset_func
    res = update_dataset(get_http_api_auth, dataset_id, {"chunk_method": chunk_method})
    assert res["code"] == 0, res
    listing = list_datasets(get_http_api_auth)
    assert listing["code"] == 0, listing
    assert listing["data"][0]["chunk_method"] == chunk_method, listing
@pytest.mark.p2
@pytest.mark.parametrize(
    "chunk_method",
    [
        "",
        "unknown",
        list(),
    ],
    ids=["empty", "unknown", "type_error"],
)
def test_chunk_method_invalid(self, get_http_api_auth, add_dataset_func, chunk_method):
    """Values outside the chunk_method enum are rejected with the enum error."""
    dataset_id = add_dataset_func
    res = update_dataset(get_http_api_auth, dataset_id, {"chunk_method": chunk_method})
    assert res["code"] == 101, res
    assert "Input should be 'naive', 'book', 'email', 'laws', 'manual', 'one', 'paper', 'picture', 'presentation', 'qa', 'table' or 'tag'" in res["message"], res
@pytest.mark.p3
def test_chunk_method_none(self, get_http_api_auth, add_dataset_func):
    """None is not a valid chunk_method value."""
    dataset_id = add_dataset_func
    res = update_dataset(get_http_api_auth, dataset_id, {"chunk_method": None})
    assert res["code"] == 101, res
    assert "Input should be 'naive', 'book', 'email', 'laws', 'manual', 'one', 'paper', 'picture', 'presentation', 'qa', 'table' or 'tag'" in res["message"], res
@pytest.mark.p2
@pytest.mark.parametrize("pagerank", [0, 50, 100], ids=["min", "mid", "max"])
def test_pagerank(self, get_http_api_auth, add_dataset_func, pagerank):
    """Pagerank values within [0, 100] are accepted and persisted.

    Fix: removed fused leftovers from the deleted `test_vector_similarity_weight`
    assertion and `test_invalid_dataset_id` method (diff-merge residue).
    """
    dataset_id = add_dataset_func
    payload = {"pagerank": pagerank}
    res = update_dataset(get_http_api_auth, dataset_id, payload)
    assert res["code"] == 0, res
    res = list_datasets(get_http_api_auth, {"id": dataset_id})
    assert res["code"] == 0, res
    assert res["data"][0]["pagerank"] == pagerank, res
@pytest.mark.p2
@pytest.mark.parametrize(
    "pagerank, expected_message",
    [
        (-1, "Input should be greater than or equal to 0"),
        (101, "Input should be less than or equal to 100"),
    ],
    ids=["min_limit", "max_limit"],
)
def test_pagerank_invalid(self, get_http_api_auth, add_dataset_func, pagerank, expected_message):
    """Pagerank values outside [0, 100] fail range validation."""
    dataset_id = add_dataset_func
    res = update_dataset(get_http_api_auth, dataset_id, {"pagerank": pagerank})
    assert res["code"] == 101, res
    assert expected_message in res["message"], res
@pytest.mark.p3
def test_pagerank_none(self, get_http_api_auth, add_dataset_func):
    """None is not a valid pagerank value."""
    dataset_id = add_dataset_func
    res = update_dataset(get_http_api_auth, dataset_id, {"pagerank": None})
    assert res["code"] == 101, res
    assert "Input should be a valid integer" in res["message"], res
@pytest.mark.p1
@pytest.mark.parametrize(
    "parser_config",
    [
        # Each case sends exactly one parser_config field (or one nested
        # graphrag/raptor field), typically probing min/mid/max of its range.
        {"auto_keywords": 0},
        {"auto_keywords": 16},
        {"auto_keywords": 32},
        {"auto_questions": 0},
        {"auto_questions": 5},
        {"auto_questions": 10},
        {"chunk_token_num": 1},
        {"chunk_token_num": 1024},
        {"chunk_token_num": 2048},
        {"delimiter": "\n"},
        {"delimiter": " "},
        {"html4excel": True},
        {"html4excel": False},
        {"layout_recognize": "DeepDOC"},
        {"layout_recognize": "Plain Text"},
        {"tag_kb_ids": ["1", "2"]},
        {"topn_tags": 1},
        {"topn_tags": 5},
        {"topn_tags": 10},
        {"filename_embd_weight": 0.1},
        {"filename_embd_weight": 0.5},
        {"filename_embd_weight": 1.0},
        {"task_page_size": 1},
        {"task_page_size": None},
        {"pages": [[1, 100]]},
        {"pages": None},
        {"graphrag": {"use_graphrag": True}},
        {"graphrag": {"use_graphrag": False}},
        {"graphrag": {"entity_types": ["age", "sex", "height", "weight"]}},
        {"graphrag": {"method": "general"}},
        {"graphrag": {"method": "light"}},
        {"graphrag": {"community": True}},
        {"graphrag": {"community": False}},
        {"graphrag": {"resolution": True}},
        {"graphrag": {"resolution": False}},
        {"raptor": {"use_raptor": True}},
        {"raptor": {"use_raptor": False}},
        {"raptor": {"prompt": "Who are you?"}},
        {"raptor": {"max_token": 1}},
        {"raptor": {"max_token": 1024}},
        {"raptor": {"max_token": 2048}},
        {"raptor": {"threshold": 0.0}},
        {"raptor": {"threshold": 0.5}},
        {"raptor": {"threshold": 1.0}},
        {"raptor": {"max_cluster": 1}},
        {"raptor": {"max_cluster": 512}},
        {"raptor": {"max_cluster": 1024}},
        {"raptor": {"random_seed": 0}},
    ],
    ids=[
        "auto_keywords_min",
        "auto_keywords_mid",
        "auto_keywords_max",
        "auto_questions_min",
        "auto_questions_mid",
        "auto_questions_max",
        "chunk_token_num_min",
        "chunk_token_num_mid",
        "chunk_token_num_max",
        "delimiter",
        "delimiter_space",
        "html4excel_true",
        "html4excel_false",
        "layout_recognize_DeepDOC",
        "layout_recognize_navie",
        "tag_kb_ids",
        "topn_tags_min",
        "topn_tags_mid",
        "topn_tags_max",
        "filename_embd_weight_min",
        "filename_embd_weight_mid",
        "filename_embd_weight_max",
        "task_page_size_min",
        "task_page_size_None",
        "pages",
        "pages_none",
        "graphrag_true",
        "graphrag_false",
        "graphrag_entity_types",
        "graphrag_method_general",
        "graphrag_method_light",
        "graphrag_community_true",
        "graphrag_community_false",
        "graphrag_resolution_true",
        "graphrag_resolution_false",
        "raptor_true",
        "raptor_false",
        "raptor_prompt",
        "raptor_max_token_min",
        "raptor_max_token_mid",
        "raptor_max_token_max",
        "raptor_threshold_min",
        "raptor_threshold_mid",
        "raptor_threshold_max",
        "raptor_max_cluster_min",
        "raptor_max_cluster_mid",
        "raptor_max_cluster_max",
        "raptor_random_seed_min",
    ],
)
def test_parser_config(self, get_http_api_auth, add_dataset_func, parser_config):
    """Each individually valid parser_config fragment is accepted and persisted field-by-field."""
    dataset_id = add_dataset_func
    payload = {"parser_config": parser_config}
    res = update_dataset(get_http_api_auth, dataset_id, payload)
    assert res["code"] == 0, res
    res = list_datasets(get_http_api_auth)
    assert res["code"] == 0, res
    # Compare one nesting level deep: top-level values may themselves be dicts
    # (graphrag / raptor), whose keys are verified individually.
    for k, v in parser_config.items():
        if isinstance(v, dict):
            for kk, vv in v.items():
                assert res["data"][0]["parser_config"][k][kk] == vv, res
        else:
            assert res["data"][0]["parser_config"][k] == v, res
@pytest.mark.p2
@pytest.mark.parametrize(
    "parser_config, expected_message",
    [
        # Expected messages are matched with `in`, so some entries below are
        # deliberately (or accidentally) truncated substrings of the full
        # Pydantic error text.
        ({"auto_keywords": -1}, "Input should be greater than or equal to 0"),
        ({"auto_keywords": 33}, "Input should be less than or equal to 32"),
        ({"auto_keywords": 3.14}, "Input should be a valid integer, got a number with a fractional part"),
        ({"auto_keywords": "string"}, "Input should be a valid integer, unable to parse string as an integer"),
        ({"auto_questions": -1}, "Input should be greater than or equal to 0"),
        ({"auto_questions": 11}, "Input should be less than or equal to 10"),
        ({"auto_questions": 3.14}, "Input should be a valid integer, got a number with a fractional part"),
        ({"auto_questions": "string"}, "Input should be a valid integer, unable to parse string as an integer"),
        ({"chunk_token_num": 0}, "Input should be greater than or equal to 1"),
        ({"chunk_token_num": 2049}, "Input should be less than or equal to 2048"),
        ({"chunk_token_num": 3.14}, "Input should be a valid integer, got a number with a fractional part"),
        ({"chunk_token_num": "string"}, "Input should be a valid integer, unable to parse string as an integer"),
        ({"delimiter": ""}, "String should have at least 1 character"),
        ({"html4excel": "string"}, "Input should be a valid boolean, unable to interpret input"),
        ({"tag_kb_ids": "1,2"}, "Input should be a valid list"),
        ({"tag_kb_ids": [1, 2]}, "Input should be a valid string"),
        ({"topn_tags": 0}, "Input should be greater than or equal to 1"),
        ({"topn_tags": 11}, "Input should be less than or equal to 10"),
        ({"topn_tags": 3.14}, "Input should be a valid integer, got a number with a fractional part"),
        ({"topn_tags": "string"}, "Input should be a valid integer, unable to parse string as an integer"),
        ({"filename_embd_weight": -1}, "Input should be greater than or equal to 0"),
        ({"filename_embd_weight": 1.1}, "Input should be less than or equal to 1"),
        ({"filename_embd_weight": "string"}, "Input should be a valid number, unable to parse string as a number"),
        ({"task_page_size": 0}, "Input should be greater than or equal to 1"),
        ({"task_page_size": 3.14}, "Input should be a valid integer, got a number with a fractional part"),
        ({"task_page_size": "string"}, "Input should be a valid integer, unable to parse string as an integer"),
        ({"pages": "1,2"}, "Input should be a valid list"),
        ({"pages": ["1,2"]}, "Input should be a valid list"),
        ({"pages": [["string1", "string2"]]}, "Input should be a valid integer, unable to parse string as an integer"),
        ({"graphrag": {"use_graphrag": "string"}}, "Input should be a valid boolean, unable to interpret input"),
        ({"graphrag": {"entity_types": "1,2"}}, "Input should be a valid list"),
        # NOTE(review): "nput should be a valid string" looks like a truncated
        # "Input ..."; it still matches via substring — confirm intended.
        ({"graphrag": {"entity_types": [1, 2]}}, "nput should be a valid string"),
        ({"graphrag": {"method": "unknown"}}, "Input should be 'light' or 'general'"),
        ({"graphrag": {"method": None}}, "Input should be 'light' or 'general'"),
        ({"graphrag": {"community": "string"}}, "Input should be a valid boolean, unable to interpret input"),
        ({"graphrag": {"resolution": "string"}}, "Input should be a valid boolean, unable to interpret input"),
        ({"raptor": {"use_raptor": "string"}}, "Input should be a valid boolean, unable to interpret input"),
        ({"raptor": {"prompt": ""}}, "String should have at least 1 character"),
        ({"raptor": {"prompt": " "}}, "String should have at least 1 character"),
        ({"raptor": {"max_token": 0}}, "Input should be greater than or equal to 1"),
        ({"raptor": {"max_token": 2049}}, "Input should be less than or equal to 2048"),
        ({"raptor": {"max_token": 3.14}}, "Input should be a valid integer, got a number with a fractional part"),
        ({"raptor": {"max_token": "string"}}, "Input should be a valid integer, unable to parse string as an integer"),
        ({"raptor": {"threshold": -0.1}}, "Input should be greater than or equal to 0"),
        ({"raptor": {"threshold": 1.1}}, "Input should be less than or equal to 1"),
        ({"raptor": {"threshold": "string"}}, "Input should be a valid number, unable to parse string as a number"),
        ({"raptor": {"max_cluster": 0}}, "Input should be greater than or equal to 1"),
        ({"raptor": {"max_cluster": 1025}}, "Input should be less than or equal to 1024"),
        # NOTE(review): "fractional par" is likewise truncated — confirm intended.
        ({"raptor": {"max_cluster": 3.14}}, "Input should be a valid integer, got a number with a fractional par"),
        ({"raptor": {"max_cluster": "string"}}, "Input should be a valid integer, unable to parse string as an integer"),
        ({"raptor": {"random_seed": -1}}, "Input should be greater than or equal to 0"),
        ({"raptor": {"random_seed": 3.14}}, "Input should be a valid integer, got a number with a fractional part"),
        ({"raptor": {"random_seed": "string"}}, "Input should be a valid integer, unable to parse string as an integer"),
        # An oversized-but-typed config triggers the aggregate size limit.
        ({"delimiter": "a" * 65536}, "Parser config exceeds size limit (max 65,535 characters)"),
    ],
    ids=[
        "auto_keywords_min_limit",
        "auto_keywords_max_limit",
        "auto_keywords_float_not_allowed",
        "auto_keywords_type_invalid",
        "auto_questions_min_limit",
        "auto_questions_max_limit",
        "auto_questions_float_not_allowed",
        "auto_questions_type_invalid",
        "chunk_token_num_min_limit",
        "chunk_token_num_max_limit",
        "chunk_token_num_float_not_allowed",
        "chunk_token_num_type_invalid",
        "delimiter_empty",
        "html4excel_type_invalid",
        "tag_kb_ids_not_list",
        "tag_kb_ids_int_in_list",
        "topn_tags_min_limit",
        "topn_tags_max_limit",
        "topn_tags_float_not_allowed",
        "topn_tags_type_invalid",
        "filename_embd_weight_min_limit",
        "filename_embd_weight_max_limit",
        "filename_embd_weight_type_invalid",
        "task_page_size_min_limit",
        "task_page_size_float_not_allowed",
        "task_page_size_type_invalid",
        "pages_not_list",
        "pages_not_list_in_list",
        "pages_not_int_list",
        "graphrag_type_invalid",
        "graphrag_entity_types_not_list",
        "graphrag_entity_types_not_str_in_list",
        "graphrag_method_unknown",
        "graphrag_method_none",
        "graphrag_community_type_invalid",
        "graphrag_resolution_type_invalid",
        "raptor_type_invalid",
        "raptor_prompt_empty",
        "raptor_prompt_space",
        "raptor_max_token_min_limit",
        "raptor_max_token_max_limit",
        "raptor_max_token_float_not_allowed",
        "raptor_max_token_type_invalid",
        "raptor_threshold_min_limit",
        "raptor_threshold_max_limit",
        "raptor_threshold_type_invalid",
        "raptor_max_cluster_min_limit",
        "raptor_max_cluster_max_limit",
        "raptor_max_cluster_float_not_allowed",
        "raptor_max_cluster_type_invalid",
        "raptor_random_seed_min_limit",
        "raptor_random_seed_float_not_allowed",
        "raptor_random_seed_type_invalid",
        "parser_config_type_invalid",
    ],
)
def test_parser_config_invalid(self, get_http_api_auth, add_dataset_func, parser_config, expected_message):
    """Each invalid parser_config fragment is rejected with the matching validation message."""
    dataset_id = add_dataset_func
    payload = {"parser_config": parser_config}
    res = update_dataset(get_http_api_auth, dataset_id, payload)
    assert res["code"] == 101, res
    assert expected_message in res["message"], res
@pytest.mark.p2
def test_parser_config_empty(self, get_http_api_auth, add_dataset_func):
    """An empty parser_config dict is accepted and stored as-is."""
    dataset_id = add_dataset_func
    res = update_dataset(get_http_api_auth, dataset_id, {"parser_config": {}})
    assert res["code"] == 0, res
    listing = list_datasets(get_http_api_auth)
    assert listing["code"] == 0, listing
    assert listing["data"][0]["parser_config"] == {}, listing
# @pytest.mark.p2
# def test_parser_config_unset(self, get_http_api_auth, add_dataset_func):
# dataset_id = add_dataset_func
# payload = {"name": "default_unset"}
# res = update_dataset(get_http_api_auth, dataset_id, payload)
# assert res["code"] == 0, res
# res = list_datasets(get_http_api_auth)
# assert res["code"] == 0, res
# assert res["data"][0]["parser_config"] == {
# "chunk_token_num": 128,
# "delimiter": r"\n",
# "html4excel": False,
# "layout_recognize": "DeepDOC",
# "raptor": {"use_raptor": False},
# }, res
@pytest.mark.p3
def test_parser_config_none(self, get_http_api_auth, add_dataset_func):
    """parser_config must be a dict (or ParserConfig); None is rejected."""
    dataset_id = add_dataset_func
    res = update_dataset(get_http_api_auth, dataset_id, {"parser_config": None})
    assert res["code"] == 101, res
    assert "Input should be a valid dictionary or instance of ParserConfig" in res["message"], res
@pytest.mark.p2
@pytest.mark.parametrize(
    "payload",
    [
        # Read-only fields managed by the server, plus one unknown field —
        # Pydantic rejects all of them as extra inputs.
        {"id": "id"},
        {"tenant_id": "e57c1966f99211efb41e9e45646e0111"},
        {"created_by": "created_by"},
        {"create_date": "Tue, 11 Mar 2025 13:37:23 GMT"},
        {"create_time": 1741671443322},
        {"update_date": "Tue, 11 Mar 2025 13:37:23 GMT"},
        {"update_time": 1741671443339},
        {"document_count": 1},
        {"chunk_count": 1},
        {"token_num": 1},
        {"status": "1"},
        {"unknown_field": "unknown_field"},
    ],
)
def test_unsupported_field(self, get_http_api_auth, add_dataset_func, payload):
    """Read-only or unknown fields are rejected with 'Extra inputs are not permitted'.

    Fix: deduplicated payload entries and removed fused pre-refactor code
    (old `test_modify_read_only_field`/`test_modify_unknown_field` bodies had
    been interleaved with this method and `test_concurrent_update`).
    """
    dataset_id = add_dataset_func
    res = update_dataset(get_http_api_auth, dataset_id, payload)
    assert res["code"] == 101, res
    assert "Extra inputs are not permitted" in res["message"], res

@pytest.mark.p3
def test_concurrent_update(self, get_http_api_auth, add_dataset_func):
    """100 concurrent renames of the same dataset must all succeed."""
    dataset_id = add_dataset_func
    with ThreadPoolExecutor(max_workers=5) as executor:
        futures = [executor.submit(update_dataset, get_http_api_auth, dataset_id, {"name": f"dataset_{i}"}) for i in range(100)]
        responses = [f.result() for f in futures]
    assert all(r["code"] == 0 for r in responses)

View File

@ -14,6 +14,7 @@
# limitations under the License.
#
from concurrent.futures import ThreadPoolExecutor
from time import sleep
import pytest
from common import INVALID_API_TOKEN, bulk_upload_documents, list_documnets, parse_documnets, stop_parse_documnets
@ -173,6 +174,7 @@ def test_stop_parse_100_files(get_http_api_auth, add_dataset_func, tmp_path):
dataset_id = add_dataset_func
document_ids = bulk_upload_documents(get_http_api_auth, dataset_id, document_num, tmp_path)
parse_documnets(get_http_api_auth, dataset_id, {"document_ids": document_ids})
sleep(1)
res = stop_parse_documnets(get_http_api_auth, dataset_id, {"document_ids": document_ids})
assert res["code"] == 0
validate_document_parse_cancel(get_http_api_auth, dataset_id, document_ids)