refa: Optimize create dataset validation (#7451)

### What problem does this PR solve?

Optimize dataset validation and add function docs

### Type of change

- [x] Refactoring
This commit is contained in:
liu an
2025-05-06 17:38:06 +08:00
committed by GitHub
parent 2f768b96e8
commit c98933499a
6 changed files with 333 additions and 93 deletions

View File

@ -39,23 +39,23 @@ SESSION_WITH_CHAT_NAME_LIMIT = 255
# DATASET MANAGEMENT
def create_dataset(auth, payload=None):
res = requests.post(url=f"{HOST_ADDRESS}{DATASETS_API_URL}", headers=HEADERS, auth=auth, json=payload)
def create_dataset(auth, payload=None, headers=HEADERS, data=None):
res = requests.post(url=f"{HOST_ADDRESS}{DATASETS_API_URL}", headers=headers, auth=auth, json=payload, data=data)
return res.json()
def list_datasets(auth, params=None):
res = requests.get(url=f"{HOST_ADDRESS}{DATASETS_API_URL}", headers=HEADERS, auth=auth, params=params)
def list_datasets(auth, params=None, headers=HEADERS):
res = requests.get(url=f"{HOST_ADDRESS}{DATASETS_API_URL}", headers=headers, auth=auth, params=params)
return res.json()
def update_dataset(auth, dataset_id, payload=None):
res = requests.put(url=f"{HOST_ADDRESS}{DATASETS_API_URL}/{dataset_id}", headers=HEADERS, auth=auth, json=payload)
def update_dataset(auth, dataset_id, payload=None, headers=HEADERS):
res = requests.put(url=f"{HOST_ADDRESS}{DATASETS_API_URL}/{dataset_id}", headers=headers, auth=auth, json=payload)
return res.json()
def delete_datasets(auth, payload=None):
res = requests.delete(url=f"{HOST_ADDRESS}{DATASETS_API_URL}", headers=HEADERS, auth=auth, json=payload)
def delete_datasets(auth, payload=None, headers=HEADERS):
res = requests.delete(url=f"{HOST_ADDRESS}{DATASETS_API_URL}", headers=headers, auth=auth, json=payload)
return res.json()

View File

@ -98,6 +98,25 @@ class TestDatasetCreation:
assert res["code"] == 101, res
assert res["message"] == f"Dataset name '{name.lower()}' already exists", res
def test_bad_content_type(self, get_http_api_auth):
BAD_CONTENT_TYPE = "text/xml"
res = create_dataset(get_http_api_auth, {"name": "name"}, {"Content-Type": BAD_CONTENT_TYPE})
assert res["code"] == 101, res
assert res["message"] == f"Unsupported content type: Expected application/json, got {BAD_CONTENT_TYPE}", res
@pytest.mark.parametrize(
"payload, expected_message",
[
("a", "Malformed JSON syntax: Missing commas/brackets or invalid encoding"),
('"a"', "Invalid request payload: expected objec"),
],
ids=["malformed_json_syntax", "invalid_request_payload_type"],
)
def test_bad_payload(self, get_http_api_auth, payload, expected_message):
res = create_dataset(get_http_api_auth, data=payload)
assert res["code"] == 101, res
assert expected_message in res["message"], res
def test_avatar(self, get_http_api_auth, tmp_path):
fn = create_image_file(tmp_path / "ragflow_test.png")
payload = {
@ -158,7 +177,7 @@ class TestDatasetCreation:
("embedding-3@ZHIPU-AI", "embedding-3@ZHIPU-AI"),
("embedding_model_default", None),
],
ids=["builtin_baai", "builtin_youdao", "tenant__zhipu", "default"],
ids=["builtin_baai", "builtin_youdao", "tenant_zhipu", "default"],
)
def test_valid_embedding_model(self, get_http_api_auth, name, embedding_model):
if embedding_model is None:
@ -178,29 +197,39 @@ class TestDatasetCreation:
[
("unknown_llm_name", "unknown@ZHIPU-AI"),
("unknown_llm_factory", "embedding-3@unknown"),
("tenant_no_auth", "deepseek-chat@DeepSeek"),
("tenant_no_auth_default_tenant_llm", "text-embedding-v3@Tongyi-Qianwen"),
("tenant_no_auth", "text-embedding-3-small@OpenAI"),
],
ids=["unknown_llm_name", "unknown_llm_factory", "tenant_no_auth"],
ids=["unknown_llm_name", "unknown_llm_factory", "tenant_no_auth_default_tenant_llm", "tenant_no_auth"],
)
def test_invalid_embedding_model(self, get_http_api_auth, name, embedding_model):
payload = {"name": name, "embedding_model": embedding_model}
res = create_dataset(get_http_api_auth, payload)
assert res["code"] == 101, res
assert res["message"] == f"The embedding_model '{embedding_model}' is not supported", res
if "tenant_no_auth" in name:
assert res["message"] == f"Unauthorized model: <{embedding_model}>", res
else:
assert res["message"] == f"Unsupported model: <{embedding_model}>", res
@pytest.mark.parametrize(
"name, embedding_model",
[
("builtin_missing_at", "BAAI/bge-large-zh-v1.5"),
("tenant_missing_at", "embedding-3ZHIPU-AI"),
("missing_at", "BAAI/bge-large-zh-v1.5BAAI"),
("missing_model_name", "@BAAI"),
("missing_provider", "BAAI/bge-large-zh-v1.5@"),
("whitespace_only_model_name", " @BAAI"),
("whitespace_only_provider", "BAAI/bge-large-zh-v1.5@ "),
],
ids=["builtin_missing_at", "tenant_missing_at"],
ids=["missing_at", "empty_model_name", "empty_provider", "whitespace_only_model_name", "whitespace_only_provider"],
)
def test_embedding_model_missing_at(self, get_http_api_auth, name, embedding_model):
def test_embedding_model_format(self, get_http_api_auth, name, embedding_model):
payload = {"name": name, "embedding_model": embedding_model}
res = create_dataset(get_http_api_auth, payload)
assert res["code"] == 101, res
assert "Embedding model must be xxx@yyy" in res["message"], res
if name == "missing_at":
assert "Embedding model identifier must follow <model_name>@<provider> format" in res["message"], res
else:
assert "Both model_name and provider must be non-empty strings" in res["message"], res
@pytest.mark.parametrize(
"name, permission",
@ -485,7 +514,7 @@ class TestDatasetCreation:
("raptor_random_seed_min_limit", {"raptor": {"random_seed": -1}}, "Input should be greater than or equal to 0"),
("raptor_random_seed_float_not_allowed", {"raptor": {"random_seed": 3.14}}, "Input should be a valid integer, got a number with a fractional part"),
("raptor_random_seed_type_invalid", {"raptor": {"random_seed": "string"}}, "Input should be a valid integer, unable to parse string as an integer"),
("parser_config_type_invalid", {"delimiter": "a" * 65536}, "Parser config have at most 65535 characters"),
("parser_config_type_invalid", {"delimiter": "a" * 65536}, "Parser config exceeds size limit (max 65,535 characters)"),
],
ids=[
"auto_keywords_min_limit",