Test: Refactor test fixtures and test cases (#6709)

### What problem does this PR solve?

Refactor test fixtures and test cases for the HTTP API dataset tests: rename the `delete_dataset`/`list_dataset` helpers to the plural `delete_datasets`/`list_datasets`, replace per-test `batch_create_datasets` setup calls with shared `add_datasets` (class-scoped, 5 datasets) and `add_datasets_func` (function-scoped, 3 datasets) fixtures, drop the now-redundant `clear_datasets` fixture from most test classes, and loosen brittle exact-message assertions to substring checks.

### Type of change

- [x] Refactoring test cases
Commit 58e6e7b668 (parent 20b8ccd1e9), authored by liu an on 2025-04-01 13:39:07 +08:00 and committed by GitHub.
22 changed files with 881 additions and 837 deletions.

View File

@@ -16,14 +16,24 @@
 import pytest
-from common import batch_create_datasets, delete_dataset
+from common import batch_create_datasets, delete_datasets


 @pytest.fixture(scope="class")
-def get_dataset_ids(get_http_api_auth, request):
+def add_datasets(get_http_api_auth, request):
     def cleanup():
-        delete_dataset(get_http_api_auth)
+        delete_datasets(get_http_api_auth)

     request.addfinalizer(cleanup)
     return batch_create_datasets(get_http_api_auth, 5)
+
+
+@pytest.fixture(scope="function")
+def add_datasets_func(get_http_api_auth, request):
+    def cleanup():
+        delete_datasets(get_http_api_auth)
+
+    request.addfinalizer(cleanup)
+    return batch_create_datasets(get_http_api_auth, 3)
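The two fixtures differ only in scope: `add_datasets` builds 5 datasets once per test class, which suits the read-only listing tests, while `add_datasets_func` rebuilds 3 datasets before every single test so destructive tests always start from a known state. A minimal sketch of how a test consumes them (the class and assertions below are illustrative, not part of this PR):

```python
class TestFixtureScopes:
    def test_shared_datasets(self, add_datasets):
        # Class-scoped: the same 5 dataset ids are reused by every test
        # method in this class; the cleanup finalizer runs once, afterwards.
        assert len(add_datasets) == 5

    def test_fresh_datasets(self, add_datasets_func):
        # Function-scoped: 3 brand-new datasets per test, deleted again by
        # the finalizer after each test, so deletions here cannot leak state.
        assert len(add_datasets_func) == 3
```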

View File

@@ -75,9 +75,6 @@ class TestDatasetCreation:
             res = create_dataset(get_http_api_auth, payload)
             assert res["code"] == 0, f"Failed to create dataset {i}"


-@pytest.mark.usefixtures("clear_datasets")
 class TestAdvancedConfigurations:
     def test_avatar(self, get_http_api_auth, tmp_path):
         fn = create_image_file(tmp_path / "ragflow_test.png")
         payload = {
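This decorator removal is the pattern repeated across the files below: the class-level `clear_datasets` fixture becomes redundant once setup and teardown ride on the `add_*` fixtures' finalizers. For orientation, the retired fixture presumably looked roughly like this (a reconstruction; its definition is not part of this diff):

```python
import pytest
from common import delete_datasets  # helper as renamed by this PR


# Reconstructed sketch of the retired fixture, for illustration only.
@pytest.fixture(scope="class")
def clear_datasets(get_http_api_auth):
    delete_datasets(get_http_api_auth)  # start the class from an empty state
    yield
    delete_datasets(get_http_api_auth)  # and leave none behind
```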

View File

@@ -20,13 +20,12 @@ import pytest
 from common import (
     INVALID_API_TOKEN,
     batch_create_datasets,
-    delete_dataset,
-    list_dataset,
+    delete_datasets,
+    list_datasets,
 )
 from libs.auth import RAGFlowHttpApiAuth


-@pytest.mark.usefixtures("clear_datasets")
 class TestAuthorization:
     @pytest.mark.parametrize(
         "auth, expected_code, expected_message",
@@ -39,18 +38,13 @@ class TestAuthorization:
             ),
         ],
     )
-    def test_invalid_auth(self, get_http_api_auth, auth, expected_code, expected_message):
-        ids = batch_create_datasets(get_http_api_auth, 1)
-        res = delete_dataset(auth, {"ids": ids})
+    def test_invalid_auth(self, auth, expected_code, expected_message):
+        res = delete_datasets(auth)
         assert res["code"] == expected_code
         assert res["message"] == expected_message
-        res = list_dataset(get_http_api_auth)
-        assert len(res["data"]) == 1


-@pytest.mark.usefixtures("clear_datasets")
-class TestDatasetDeletion:
+class TestDatasetsDeletion:
     @pytest.mark.parametrize(
         "payload, expected_code, expected_message, remaining",
         [
@@ -73,16 +67,16 @@ class TestDatasetDeletion:
             (lambda r: {"ids": r}, 0, "", 0),
         ],
     )
-    def test_basic_scenarios(self, get_http_api_auth, payload, expected_code, expected_message, remaining):
-        ids = batch_create_datasets(get_http_api_auth, 3)
+    def test_basic_scenarios(self, get_http_api_auth, add_datasets_func, payload, expected_code, expected_message, remaining):
+        dataset_ids = add_datasets_func
         if callable(payload):
-            payload = payload(ids)
-        res = delete_dataset(get_http_api_auth, payload)
+            payload = payload(dataset_ids)
+        res = delete_datasets(get_http_api_auth, payload)
         assert res["code"] == expected_code
         if res["code"] != 0:
             assert res["message"] == expected_message

-        res = list_dataset(get_http_api_auth)
+        res = list_datasets(get_http_api_auth)
         assert len(res["data"]) == remaining

     @pytest.mark.parametrize(
@@ -93,50 +87,50 @@ class TestDatasetDeletion:
             lambda r: {"ids": r + ["invalid_id"]},
         ],
     )
-    def test_delete_partial_invalid_id(self, get_http_api_auth, payload):
-        ids = batch_create_datasets(get_http_api_auth, 3)
+    def test_delete_partial_invalid_id(self, get_http_api_auth, add_datasets_func, payload):
+        dataset_ids = add_datasets_func
         if callable(payload):
-            payload = payload(ids)
-        res = delete_dataset(get_http_api_auth, payload)
+            payload = payload(dataset_ids)
+        res = delete_datasets(get_http_api_auth, payload)
         assert res["code"] == 0
         assert res["data"]["errors"][0] == "You don't own the dataset invalid_id"
         assert res["data"]["success_count"] == 3

-        res = list_dataset(get_http_api_auth)
+        res = list_datasets(get_http_api_auth)
         assert len(res["data"]) == 0

-    def test_repeated_deletion(self, get_http_api_auth):
-        ids = batch_create_datasets(get_http_api_auth, 1)
-        res = delete_dataset(get_http_api_auth, {"ids": ids})
+    def test_repeated_deletion(self, get_http_api_auth, add_datasets_func):
+        dataset_ids = add_datasets_func
+        res = delete_datasets(get_http_api_auth, {"ids": dataset_ids})
         assert res["code"] == 0

-        res = delete_dataset(get_http_api_auth, {"ids": ids})
+        res = delete_datasets(get_http_api_auth, {"ids": dataset_ids})
         assert res["code"] == 102
-        assert res["message"] == f"You don't own the dataset {ids[0]}"
+        assert "You don't own the dataset" in res["message"]

-    def test_duplicate_deletion(self, get_http_api_auth):
-        ids = batch_create_datasets(get_http_api_auth, 1)
-        res = delete_dataset(get_http_api_auth, {"ids": ids + ids})
+    def test_duplicate_deletion(self, get_http_api_auth, add_datasets_func):
+        dataset_ids = add_datasets_func
+        res = delete_datasets(get_http_api_auth, {"ids": dataset_ids + dataset_ids})
         assert res["code"] == 0
-        assert res["data"]["errors"][0] == f"Duplicate dataset ids: {ids[0]}"
-        assert res["data"]["success_count"] == 1
+        assert "Duplicate dataset ids" in res["data"]["errors"][0]
+        assert res["data"]["success_count"] == 3

-        res = list_dataset(get_http_api_auth)
+        res = list_datasets(get_http_api_auth)
         assert len(res["data"]) == 0

     def test_concurrent_deletion(self, get_http_api_auth):
         ids = batch_create_datasets(get_http_api_auth, 100)

         with ThreadPoolExecutor(max_workers=5) as executor:
-            futures = [executor.submit(delete_dataset, get_http_api_auth, {"ids": ids[i : i + 1]}) for i in range(100)]
+            futures = [executor.submit(delete_datasets, get_http_api_auth, {"ids": ids[i : i + 1]}) for i in range(100)]
         responses = [f.result() for f in futures]
         assert all(r["code"] == 0 for r in responses)

     @pytest.mark.slow
     def test_delete_10k(self, get_http_api_auth):
         ids = batch_create_datasets(get_http_api_auth, 10_000)
-        res = delete_dataset(get_http_api_auth, {"ids": ids})
+        res = delete_datasets(get_http_api_auth, {"ids": ids})
         assert res["code"] == 0

-        res = list_dataset(get_http_api_auth)
+        res = list_datasets(get_http_api_auth)
         assert len(res["data"]) == 0
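All of these tests go through the renamed plural helpers. Their exact request shapes live in the suite's `common.py`, which this diff does not include; a plausible sketch, assuming the documented `DELETE /api/v1/datasets` and `GET /api/v1/datasets` HTTP endpoints (the host address is an assumption, not shown in this diff):

```python
import requests

HOST_ADDRESS = "http://127.0.0.1:9380"  # assumed default


def delete_datasets(auth, payload=None):
    # DELETE with an optional {"ids": [...]} body; calling it with no body
    # is exactly what the refactored test_invalid_auth exercises.
    res = requests.delete(f"{HOST_ADDRESS}/api/v1/datasets", auth=auth, json=payload)
    return res.json()


def list_datasets(auth, params=None):
    res = requests.get(f"{HOST_ADDRESS}/api/v1/datasets", auth=auth, params=params)
    return res.json()
```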

View File

@@ -16,7 +16,7 @@
 from concurrent.futures import ThreadPoolExecutor

 import pytest
-from common import INVALID_API_TOKEN, list_dataset
+from common import INVALID_API_TOKEN, list_datasets
 from libs.auth import RAGFlowHttpApiAuth
@@ -25,7 +25,6 @@ def is_sorted(data, field, descending=True):
     return all(a >= b for a, b in zip(timestamps, timestamps[1:])) if descending else all(a <= b for a, b in zip(timestamps, timestamps[1:]))


-@pytest.mark.usefixtures("clear_datasets")
 class TestAuthorization:
     @pytest.mark.parametrize(
         "auth, expected_code, expected_message",
@@ -39,15 +38,15 @@ class TestAuthorization:
         ],
     )
     def test_invalid_auth(self, auth, expected_code, expected_message):
-        res = list_dataset(auth)
+        res = list_datasets(auth)
         assert res["code"] == expected_code
         assert res["message"] == expected_message


-@pytest.mark.usefixtures("get_dataset_ids")
-class TestDatasetList:
+@pytest.mark.usefixtures("add_datasets")
+class TestDatasetsList:
     def test_default(self, get_http_api_auth):
-        res = list_dataset(get_http_api_auth, params={})
+        res = list_datasets(get_http_api_auth, params={})
         assert res["code"] == 0
         assert len(res["data"]) == 5
@@ -77,7 +76,7 @@ class TestDatasetList:
         ],
     )
     def test_page(self, get_http_api_auth, params, expected_code, expected_page_size, expected_message):
-        res = list_dataset(get_http_api_auth, params=params)
+        res = list_datasets(get_http_api_auth, params=params)
         assert res["code"] == expected_code
         if expected_code == 0:
             assert len(res["data"]) == expected_page_size
@@ -116,7 +115,7 @@ class TestDatasetList:
         expected_page_size,
         expected_message,
     ):
-        res = list_dataset(get_http_api_auth, params=params)
+        res = list_datasets(get_http_api_auth, params=params)
         assert res["code"] == expected_code
         if expected_code == 0:
             assert len(res["data"]) == expected_page_size
@@ -168,7 +167,7 @@ class TestDatasetList:
         assertions,
         expected_message,
     ):
-        res = list_dataset(get_http_api_auth, params=params)
+        res = list_datasets(get_http_api_auth, params=params)
         assert res["code"] == expected_code
         if expected_code == 0:
             if callable(assertions):
@@ -244,7 +243,7 @@ class TestDatasetList:
         assertions,
         expected_message,
     ):
-        res = list_dataset(get_http_api_auth, params=params)
+        res = list_datasets(get_http_api_auth, params=params)
         assert res["code"] == expected_code
         if expected_code == 0:
             if callable(assertions):
@@ -262,7 +261,7 @@ class TestDatasetList:
         ],
     )
     def test_name(self, get_http_api_auth, params, expected_code, expected_num, expected_message):
-        res = list_dataset(get_http_api_auth, params=params)
+        res = list_datasets(get_http_api_auth, params=params)
         assert res["code"] == expected_code
         if expected_code == 0:
             if params["name"] in [None, ""]:
@@ -284,19 +283,19 @@ class TestDatasetList:
     def test_id(
         self,
         get_http_api_auth,
-        get_dataset_ids,
+        add_datasets,
         dataset_id,
         expected_code,
         expected_num,
         expected_message,
     ):
-        dataset_ids = get_dataset_ids
+        dataset_ids = add_datasets
         if callable(dataset_id):
             params = {"id": dataset_id(dataset_ids)}
         else:
             params = {"id": dataset_id}
-        res = list_dataset(get_http_api_auth, params=params)
+        res = list_datasets(get_http_api_auth, params=params)
         assert res["code"] == expected_code
         if expected_code == 0:
             if params["id"] in [None, ""]:
@@ -318,20 +317,20 @@ class TestDatasetList:
     def test_name_and_id(
         self,
         get_http_api_auth,
-        get_dataset_ids,
+        add_datasets,
         dataset_id,
         name,
         expected_code,
         expected_num,
         expected_message,
     ):
-        dataset_ids = get_dataset_ids
+        dataset_ids = add_datasets
         if callable(dataset_id):
             params = {"id": dataset_id(dataset_ids), "name": name}
         else:
             params = {"id": dataset_id, "name": name}
-        res = list_dataset(get_http_api_auth, params=params)
+        res = list_datasets(get_http_api_auth, params=params)
         if expected_code == 0:
             assert len(res["data"]) == expected_num
         else:
@@ -339,12 +338,12 @@ class TestDatasetList:
     def test_concurrent_list(self, get_http_api_auth):
         with ThreadPoolExecutor(max_workers=5) as executor:
-            futures = [executor.submit(list_dataset, get_http_api_auth) for i in range(100)]
+            futures = [executor.submit(list_datasets, get_http_api_auth) for i in range(100)]
         responses = [f.result() for f in futures]
         assert all(r["code"] == 0 for r in responses)

     def test_invalid_params(self, get_http_api_auth):
         params = {"a": "b"}
-        res = list_dataset(get_http_api_auth, params=params)
+        res = list_datasets(get_http_api_auth, params=params)
         assert res["code"] == 0
         assert len(res["data"]) == 5
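Note the two ways the class-scoped fixture is used in this file: `@pytest.mark.usefixtures("add_datasets")` applies it to every test in the class without exposing its return value, while `test_id` and `test_name_and_id` also request `add_datasets` as an argument to get the created ids back. A self-contained illustration of that difference (the fixture below is a stand-in, not this repo's code):

```python
import pytest


@pytest.fixture(scope="class")
def add_datasets():
    # Stand-in for the real fixture: returns the created dataset ids.
    return [f"dataset_id_{i}" for i in range(5)]


@pytest.mark.usefixtures("add_datasets")  # run the fixture, discard its value
class TestIllustration:
    def test_value_not_needed(self):
        pass  # the datasets exist, but this test never touches their ids

    def test_value_needed(self, add_datasets):  # same fixture, now as an argument
        assert len(add_datasets) == 5
```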

View File

@@ -19,8 +19,7 @@ import pytest
 from common import (
     DATASET_NAME_LIMIT,
     INVALID_API_TOKEN,
-    batch_create_datasets,
-    list_dataset,
+    list_datasets,
     update_dataset,
 )
 from libs.auth import RAGFlowHttpApiAuth
@@ -30,7 +29,6 @@ from libs.utils.file_utils import create_image_file
 # TODO: Missing scenario for updating embedding_model with chunk_count != 0


-@pytest.mark.usefixtures("clear_datasets")
 class TestAuthorization:
     @pytest.mark.parametrize(
         "auth, expected_code, expected_message",
@@ -43,14 +41,12 @@ class TestAuthorization:
             ),
         ],
     )
-    def test_invalid_auth(self, get_http_api_auth, auth, expected_code, expected_message):
-        ids = batch_create_datasets(get_http_api_auth, 1)
-        res = update_dataset(auth, ids[0], {"name": "new_name"})
+    def test_invalid_auth(self, auth, expected_code, expected_message):
+        res = update_dataset(auth, "dataset_id")
         assert res["code"] == expected_code
         assert res["message"] == expected_message


-@pytest.mark.usefixtures("clear_datasets")
 class TestDatasetUpdate:
     @pytest.mark.parametrize(
         "name, expected_code, expected_message",
@@ -72,12 +68,12 @@ class TestDatasetUpdate:
             ("DATASET_1", 102, "Duplicated dataset name in updating dataset."),
         ],
     )
-    def test_name(self, get_http_api_auth, name, expected_code, expected_message):
-        ids = batch_create_datasets(get_http_api_auth, 2)
-        res = update_dataset(get_http_api_auth, ids[0], {"name": name})
+    def test_name(self, get_http_api_auth, add_datasets_func, name, expected_code, expected_message):
+        dataset_ids = add_datasets_func
+        res = update_dataset(get_http_api_auth, dataset_ids[0], {"name": name})
         assert res["code"] == expected_code
         if expected_code == 0:
-            res = list_dataset(get_http_api_auth, {"id": ids[0]})
+            res = list_datasets(get_http_api_auth, {"id": dataset_ids[0]})
             assert res["data"][0]["name"] == name
         else:
             assert res["message"] == expected_message
@@ -95,12 +91,12 @@ class TestDatasetUpdate:
             (None, 102, "`embedding_model` can't be empty"),
         ],
     )
-    def test_embedding_model(self, get_http_api_auth, embedding_model, expected_code, expected_message):
-        ids = batch_create_datasets(get_http_api_auth, 1)
-        res = update_dataset(get_http_api_auth, ids[0], {"embedding_model": embedding_model})
+    def test_embedding_model(self, get_http_api_auth, add_dataset_func, embedding_model, expected_code, expected_message):
+        dataset_id = add_dataset_func
+        res = update_dataset(get_http_api_auth, dataset_id, {"embedding_model": embedding_model})
         assert res["code"] == expected_code
         if expected_code == 0:
-            res = list_dataset(get_http_api_auth, {"id": ids[0]})
+            res = list_datasets(get_http_api_auth, {"id": dataset_id})
             assert res["data"][0]["embedding_model"] == embedding_model
         else:
             assert res["message"] == expected_message
@@ -129,12 +125,12 @@ class TestDatasetUpdate:
             ),
         ],
     )
-    def test_chunk_method(self, get_http_api_auth, chunk_method, expected_code, expected_message):
-        ids = batch_create_datasets(get_http_api_auth, 1)
-        res = update_dataset(get_http_api_auth, ids[0], {"chunk_method": chunk_method})
+    def test_chunk_method(self, get_http_api_auth, add_dataset_func, chunk_method, expected_code, expected_message):
+        dataset_id = add_dataset_func
+        res = update_dataset(get_http_api_auth, dataset_id, {"chunk_method": chunk_method})
         assert res["code"] == expected_code
         if expected_code == 0:
-            res = list_dataset(get_http_api_auth, {"id": ids[0]})
+            res = list_datasets(get_http_api_auth, {"id": dataset_id})
             if chunk_method != "":
                 assert res["data"][0]["chunk_method"] == chunk_method
             else:
@@ -142,38 +138,38 @@ class TestDatasetUpdate:
         else:
             assert res["message"] == expected_message

-    def test_avatar(self, get_http_api_auth, tmp_path):
-        ids = batch_create_datasets(get_http_api_auth, 1)
+    def test_avatar(self, get_http_api_auth, add_dataset_func, tmp_path):
+        dataset_id = add_dataset_func
         fn = create_image_file(tmp_path / "ragflow_test.png")
         payload = {"avatar": encode_avatar(fn)}
-        res = update_dataset(get_http_api_auth, ids[0], payload)
+        res = update_dataset(get_http_api_auth, dataset_id, payload)
         assert res["code"] == 0

-    def test_description(self, get_http_api_auth):
-        ids = batch_create_datasets(get_http_api_auth, 1)
+    def test_description(self, get_http_api_auth, add_dataset_func):
+        dataset_id = add_dataset_func
         payload = {"description": "description"}
-        res = update_dataset(get_http_api_auth, ids[0], payload)
+        res = update_dataset(get_http_api_auth, dataset_id, payload)
         assert res["code"] == 0

-        res = list_dataset(get_http_api_auth, {"id": ids[0]})
+        res = list_datasets(get_http_api_auth, {"id": dataset_id})
         assert res["data"][0]["description"] == "description"

-    def test_pagerank(self, get_http_api_auth):
-        ids = batch_create_datasets(get_http_api_auth, 1)
+    def test_pagerank(self, get_http_api_auth, add_dataset_func):
+        dataset_id = add_dataset_func
         payload = {"pagerank": 1}
-        res = update_dataset(get_http_api_auth, ids[0], payload)
+        res = update_dataset(get_http_api_auth, dataset_id, payload)
         assert res["code"] == 0

-        res = list_dataset(get_http_api_auth, {"id": ids[0]})
+        res = list_datasets(get_http_api_auth, {"id": dataset_id})
         assert res["data"][0]["pagerank"] == 1

-    def test_similarity_threshold(self, get_http_api_auth):
-        ids = batch_create_datasets(get_http_api_auth, 1)
+    def test_similarity_threshold(self, get_http_api_auth, add_dataset_func):
+        dataset_id = add_dataset_func
         payload = {"similarity_threshold": 1}
-        res = update_dataset(get_http_api_auth, ids[0], payload)
+        res = update_dataset(get_http_api_auth, dataset_id, payload)
         assert res["code"] == 0

-        res = list_dataset(get_http_api_auth, {"id": ids[0]})
+        res = list_datasets(get_http_api_auth, {"id": dataset_id})
         assert res["data"][0]["similarity_threshold"] == 1

     @pytest.mark.parametrize(
@@ -187,29 +183,28 @@ class TestDatasetUpdate:
             ("other_permission", 102),
         ],
     )
-    def test_permission(self, get_http_api_auth, permission, expected_code):
-        ids = batch_create_datasets(get_http_api_auth, 1)
+    def test_permission(self, get_http_api_auth, add_dataset_func, permission, expected_code):
+        dataset_id = add_dataset_func
         payload = {"permission": permission}
-        res = update_dataset(get_http_api_auth, ids[0], payload)
+        res = update_dataset(get_http_api_auth, dataset_id, payload)
         assert res["code"] == expected_code

-        res = list_dataset(get_http_api_auth, {"id": ids[0]})
+        res = list_datasets(get_http_api_auth, {"id": dataset_id})
         if expected_code == 0 and permission != "":
             assert res["data"][0]["permission"] == permission
         if permission == "":
             assert res["data"][0]["permission"] == "me"

-    def test_vector_similarity_weight(self, get_http_api_auth):
-        ids = batch_create_datasets(get_http_api_auth, 1)
+    def test_vector_similarity_weight(self, get_http_api_auth, add_dataset_func):
+        dataset_id = add_dataset_func
         payload = {"vector_similarity_weight": 1}
-        res = update_dataset(get_http_api_auth, ids[0], payload)
+        res = update_dataset(get_http_api_auth, dataset_id, payload)
         assert res["code"] == 0

-        res = list_dataset(get_http_api_auth, {"id": ids[0]})
+        res = list_datasets(get_http_api_auth, {"id": dataset_id})
         assert res["data"][0]["vector_similarity_weight"] == 1

     def test_invalid_dataset_id(self, get_http_api_auth):
-        batch_create_datasets(get_http_api_auth, 1)
         res = update_dataset(get_http_api_auth, "invalid_dataset_id", {"name": "invalid_dataset_id"})
         assert res["code"] == 102
         assert res["message"] == "You don't own the dataset"
@@ -230,21 +225,21 @@ class TestDatasetUpdate:
             {"update_time": 1741671443339},
         ],
     )
-    def test_modify_read_only_field(self, get_http_api_auth, payload):
-        ids = batch_create_datasets(get_http_api_auth, 1)
-        res = update_dataset(get_http_api_auth, ids[0], payload)
+    def test_modify_read_only_field(self, get_http_api_auth, add_dataset_func, payload):
+        dataset_id = add_dataset_func
+        res = update_dataset(get_http_api_auth, dataset_id, payload)
         assert res["code"] == 101
         assert "is readonly" in res["message"]

-    def test_modify_unknown_field(self, get_http_api_auth):
-        ids = batch_create_datasets(get_http_api_auth, 1)
-        res = update_dataset(get_http_api_auth, ids[0], {"unknown_field": 0})
+    def test_modify_unknown_field(self, get_http_api_auth, add_dataset_func):
+        dataset_id = add_dataset_func
+        res = update_dataset(get_http_api_auth, dataset_id, {"unknown_field": 0})
         assert res["code"] == 100

-    def test_concurrent_update(self, get_http_api_auth):
-        ids = batch_create_datasets(get_http_api_auth, 1)
+    def test_concurrent_update(self, get_http_api_auth, add_dataset_func):
+        dataset_id = add_dataset_func
         with ThreadPoolExecutor(max_workers=5) as executor:
-            futures = [executor.submit(update_dataset, get_http_api_auth, ids[0], {"name": f"dataset_{i}"}) for i in range(100)]
+            futures = [executor.submit(update_dataset, get_http_api_auth, dataset_id, {"name": f"dataset_{i}"}) for i in range(100)]
         responses = [f.result() for f in futures]
         assert all(r["code"] == 0 for r in responses)
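The update tests lean on an `add_dataset_func` fixture (singular) whose definition is not in this diff's conftest hunk; by analogy with `add_datasets_func` above, it is presumably function-scoped and returns a single dataset id. A sketch under that assumption:

```python
import pytest
from common import batch_create_datasets, delete_datasets


# Assumed shape, by analogy with add_datasets_func; the real definition
# lives in a conftest.py not shown in this diff.
@pytest.fixture(scope="function")
def add_dataset_func(get_http_api_auth, request):
    def cleanup():
        delete_datasets(get_http_api_auth)

    request.addfinalizer(cleanup)
    return batch_create_datasets(get_http_api_auth, 1)[0]  # a single dataset id
```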