Test: Update test cases to reduce execution time (#6470)

### What problem does this PR solve?

_Briefly describe what this PR aims to solve. Include background context
that will help reviewers understand the purpose of the PR._

### Type of change

- [x] update test cases
This commit is contained in:
liu an
2025-03-25 09:17:05 +08:00
committed by GitHub
parent 390086c6ab
commit b6f3242c6c
17 changed files with 704 additions and 695 deletions

View File

@ -0,0 +1,26 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import pytest
from common import batch_create_datasets, delete_dataset
@pytest.fixture(scope="class")
def get_dataset_ids(get_http_api_auth):
ids = batch_create_datasets(get_http_api_auth, 5)
yield ids
delete_dataset(get_http_api_auth)

View File

@ -21,6 +21,7 @@ from libs.utils import encode_avatar
from libs.utils.file_utils import create_image_file
@pytest.mark.usefixtures("clear_datasets")
class TestAuthorization:
@pytest.mark.parametrize(
"auth, expected_code, expected_message",
@ -39,6 +40,7 @@ class TestAuthorization:
assert res["message"] == expected_message
@pytest.mark.usefixtures("clear_datasets")
class TestDatasetCreation:
@pytest.mark.parametrize(
"payload, expected_code",
@ -74,6 +76,7 @@ class TestDatasetCreation:
assert res["code"] == 0, f"Failed to create dataset {i}"
@pytest.mark.usefixtures("clear_datasets")
class TestAdvancedConfigurations:
def test_avatar(self, get_http_api_auth, tmp_path):
fn = create_image_file(tmp_path / "ragflow_test.png")
@ -172,9 +175,7 @@ class TestAdvancedConfigurations:
("other_embedding_model", "other_embedding_model", 102),
],
)
def test_embedding_model(
self, get_http_api_auth, name, embedding_model, expected_code
):
def test_embedding_model(self, get_http_api_auth, name, embedding_model, expected_code):
payload = {"name": name, "embedding_model": embedding_model}
res = create_dataset(get_http_api_auth, payload)
assert res["code"] == expected_code

View File

@ -19,13 +19,14 @@ from concurrent.futures import ThreadPoolExecutor
import pytest
from common import (
INVALID_API_TOKEN,
create_datasets,
batch_create_datasets,
delete_dataset,
list_dataset,
)
from libs.auth import RAGFlowHttpApiAuth
@pytest.mark.usefixtures("clear_datasets")
class TestAuthorization:
@pytest.mark.parametrize(
"auth, expected_code, expected_message",
@ -39,7 +40,7 @@ class TestAuthorization:
],
)
def test_invalid_auth(self, get_http_api_auth, auth, expected_code, expected_message):
ids = create_datasets(get_http_api_auth, 1)
ids = batch_create_datasets(get_http_api_auth, 1)
res = delete_dataset(auth, {"ids": ids})
assert res["code"] == expected_code
assert res["message"] == expected_message
@ -48,6 +49,7 @@ class TestAuthorization:
assert len(res["data"]) == 1
@pytest.mark.usefixtures("clear_datasets")
class TestDatasetDeletion:
@pytest.mark.parametrize(
"payload, expected_code, expected_message, remaining",
@ -72,7 +74,7 @@ class TestDatasetDeletion:
],
)
def test_basic_scenarios(self, get_http_api_auth, payload, expected_code, expected_message, remaining):
ids = create_datasets(get_http_api_auth, 3)
ids = batch_create_datasets(get_http_api_auth, 3)
if callable(payload):
payload = payload(ids)
res = delete_dataset(get_http_api_auth, payload)
@ -92,7 +94,7 @@ class TestDatasetDeletion:
],
)
def test_delete_partial_invalid_id(self, get_http_api_auth, payload):
ids = create_datasets(get_http_api_auth, 3)
ids = batch_create_datasets(get_http_api_auth, 3)
if callable(payload):
payload = payload(ids)
res = delete_dataset(get_http_api_auth, payload)
@ -104,7 +106,7 @@ class TestDatasetDeletion:
assert len(res["data"]) == 0
def test_repeated_deletion(self, get_http_api_auth):
ids = create_datasets(get_http_api_auth, 1)
ids = batch_create_datasets(get_http_api_auth, 1)
res = delete_dataset(get_http_api_auth, {"ids": ids})
assert res["code"] == 0
@ -113,7 +115,7 @@ class TestDatasetDeletion:
assert res["message"] == f"You don't own the dataset {ids[0]}"
def test_duplicate_deletion(self, get_http_api_auth):
ids = create_datasets(get_http_api_auth, 1)
ids = batch_create_datasets(get_http_api_auth, 1)
res = delete_dataset(get_http_api_auth, {"ids": ids + ids})
assert res["code"] == 0
assert res["data"]["errors"][0] == f"Duplicate dataset ids: {ids[0]}"
@ -123,7 +125,7 @@ class TestDatasetDeletion:
assert len(res["data"]) == 0
def test_concurrent_deletion(self, get_http_api_auth):
ids = create_datasets(get_http_api_auth, 100)
ids = batch_create_datasets(get_http_api_auth, 100)
with ThreadPoolExecutor(max_workers=5) as executor:
futures = [executor.submit(delete_dataset, get_http_api_auth, {"ids": ids[i : i + 1]}) for i in range(100)]
@ -132,7 +134,7 @@ class TestDatasetDeletion:
@pytest.mark.slow
def test_delete_10k(self, get_http_api_auth):
ids = create_datasets(get_http_api_auth, 10_000)
ids = batch_create_datasets(get_http_api_auth, 10_000)
res = delete_dataset(get_http_api_auth, {"ids": ids})
assert res["code"] == 0

View File

@ -16,19 +16,16 @@
from concurrent.futures import ThreadPoolExecutor
import pytest
from common import INVALID_API_TOKEN, create_datasets, list_dataset
from common import INVALID_API_TOKEN, list_dataset
from libs.auth import RAGFlowHttpApiAuth
def is_sorted(data, field, descending=True):
timestamps = [ds[field] for ds in data]
return (
all(a >= b for a, b in zip(timestamps, timestamps[1:]))
if descending
else all(a <= b for a, b in zip(timestamps, timestamps[1:]))
)
return all(a >= b for a, b in zip(timestamps, timestamps[1:])) if descending else all(a <= b for a, b in zip(timestamps, timestamps[1:]))
@pytest.mark.usefixtures("clear_datasets")
class TestAuthorization:
@pytest.mark.parametrize(
"auth, expected_code, expected_message",
@ -47,13 +44,13 @@ class TestAuthorization:
assert res["message"] == expected_message
@pytest.mark.usefixtures("get_dataset_ids")
class TestDatasetList:
def test_default(self, get_http_api_auth):
create_datasets(get_http_api_auth, 31)
res = list_dataset(get_http_api_auth, params={})
assert res["code"] == 0
assert len(res["data"]) == 30
assert len(res["data"]) == 5
@pytest.mark.parametrize(
"params, expected_code, expected_page_size, expected_message",
@ -79,15 +76,7 @@ class TestDatasetList:
),
],
)
def test_page(
self,
get_http_api_auth,
params,
expected_code,
expected_page_size,
expected_message,
):
create_datasets(get_http_api_auth, 5)
def test_page(self, get_http_api_auth, params, expected_code, expected_page_size, expected_message):
res = list_dataset(get_http_api_auth, params=params)
assert res["code"] == expected_code
if expected_code == 0:
@ -98,10 +87,10 @@ class TestDatasetList:
@pytest.mark.parametrize(
"params, expected_code, expected_page_size, expected_message",
[
({"page_size": None}, 0, 30, ""),
({"page_size": None}, 0, 5, ""),
({"page_size": 0}, 0, 0, ""),
({"page_size": 1}, 0, 1, ""),
({"page_size": 32}, 0, 31, ""),
({"page_size": 6}, 0, 5, ""),
({"page_size": "1"}, 0, 1, ""),
pytest.param(
{"page_size": -1},
@ -127,7 +116,6 @@ class TestDatasetList:
expected_page_size,
expected_message,
):
create_datasets(get_http_api_auth, 31)
res = list_dataset(get_http_api_auth, params=params)
assert res["code"] == expected_code
if expected_code == 0:
@ -180,7 +168,6 @@ class TestDatasetList:
assertions,
expected_message,
):
create_datasets(get_http_api_auth, 3)
res = list_dataset(get_http_api_auth, params=params)
assert res["code"] == expected_code
if expected_code == 0:
@ -257,7 +244,6 @@ class TestDatasetList:
assertions,
expected_message,
):
create_datasets(get_http_api_auth, 3)
res = list_dataset(get_http_api_auth, params=params)
assert res["code"] == expected_code
if expected_code == 0:
@ -269,16 +255,13 @@ class TestDatasetList:
@pytest.mark.parametrize(
"params, expected_code, expected_num, expected_message",
[
({"name": None}, 0, 3, ""),
({"name": ""}, 0, 3, ""),
({"name": None}, 0, 5, ""),
({"name": ""}, 0, 5, ""),
({"name": "dataset_1"}, 0, 1, ""),
({"name": "unknown"}, 102, 0, "You don't own the dataset unknown"),
],
)
def test_name(
self, get_http_api_auth, params, expected_code, expected_num, expected_message
):
create_datasets(get_http_api_auth, 3)
def test_name(self, get_http_api_auth, params, expected_code, expected_num, expected_message):
res = list_dataset(get_http_api_auth, params=params)
assert res["code"] == expected_code
if expected_code == 0:
@ -292,8 +275,8 @@ class TestDatasetList:
@pytest.mark.parametrize(
"dataset_id, expected_code, expected_num, expected_message",
[
(None, 0, 3, ""),
("", 0, 3, ""),
(None, 0, 5, ""),
("", 0, 5, ""),
(lambda r: r[0], 0, 1, ""),
("unknown", 102, 0, "You don't own the dataset unknown"),
],
@ -301,14 +284,15 @@ class TestDatasetList:
def test_id(
self,
get_http_api_auth,
get_dataset_ids,
dataset_id,
expected_code,
expected_num,
expected_message,
):
ids = create_datasets(get_http_api_auth, 3)
dataset_ids = get_dataset_ids
if callable(dataset_id):
params = {"id": dataset_id(ids)}
params = {"id": dataset_id(dataset_ids)}
else:
params = {"id": dataset_id}
@ -334,15 +318,16 @@ class TestDatasetList:
def test_name_and_id(
self,
get_http_api_auth,
get_dataset_ids,
dataset_id,
name,
expected_code,
expected_num,
expected_message,
):
ids = create_datasets(get_http_api_auth, 3)
dataset_ids = get_dataset_ids
if callable(dataset_id):
params = {"id": dataset_id(ids), "name": name}
params = {"id": dataset_id(dataset_ids), "name": name}
else:
params = {"id": dataset_id, "name": name}
@ -353,12 +338,8 @@ class TestDatasetList:
assert res["message"] == expected_message
def test_concurrent_list(self, get_http_api_auth):
create_datasets(get_http_api_auth, 3)
with ThreadPoolExecutor(max_workers=5) as executor:
futures = [
executor.submit(list_dataset, get_http_api_auth) for i in range(100)
]
futures = [executor.submit(list_dataset, get_http_api_auth) for i in range(100)]
responses = [f.result() for f in futures]
assert all(r["code"] == 0 for r in responses)
@ -366,4 +347,4 @@ class TestDatasetList:
params = {"a": "b"}
res = list_dataset(get_http_api_auth, params=params)
assert res["code"] == 0
assert len(res["data"]) == 0
assert len(res["data"]) == 5

View File

@ -19,7 +19,7 @@ import pytest
from common import (
DATASET_NAME_LIMIT,
INVALID_API_TOKEN,
create_datasets,
batch_create_datasets,
list_dataset,
update_dataset,
)
@ -30,6 +30,7 @@ from libs.utils.file_utils import create_image_file
# TODO: Missing scenario for updating embedding_model with chunk_count != 0
@pytest.mark.usefixtures("clear_datasets")
class TestAuthorization:
@pytest.mark.parametrize(
"auth, expected_code, expected_message",
@ -42,15 +43,14 @@ class TestAuthorization:
),
],
)
def test_invalid_auth(
self, get_http_api_auth, auth, expected_code, expected_message
):
ids = create_datasets(get_http_api_auth, 1)
def test_invalid_auth(self, get_http_api_auth, auth, expected_code, expected_message):
ids = batch_create_datasets(get_http_api_auth, 1)
res = update_dataset(auth, ids[0], {"name": "new_name"})
assert res["code"] == expected_code
assert res["message"] == expected_message
@pytest.mark.usefixtures("clear_datasets")
class TestDatasetUpdate:
@pytest.mark.parametrize(
"name, expected_code, expected_message",
@ -73,7 +73,7 @@ class TestDatasetUpdate:
],
)
def test_name(self, get_http_api_auth, name, expected_code, expected_message):
ids = create_datasets(get_http_api_auth, 2)
ids = batch_create_datasets(get_http_api_auth, 2)
res = update_dataset(get_http_api_auth, ids[0], {"name": name})
assert res["code"] == expected_code
if expected_code == 0:
@ -105,13 +105,9 @@ class TestDatasetUpdate:
(None, 102, "`embedding_model` can't be empty"),
],
)
def test_embedding_model(
self, get_http_api_auth, embedding_model, expected_code, expected_message
):
ids = create_datasets(get_http_api_auth, 1)
res = update_dataset(
get_http_api_auth, ids[0], {"embedding_model": embedding_model}
)
def test_embedding_model(self, get_http_api_auth, embedding_model, expected_code, expected_message):
ids = batch_create_datasets(get_http_api_auth, 1)
res = update_dataset(get_http_api_auth, ids[0], {"embedding_model": embedding_model})
assert res["code"] == expected_code
if expected_code == 0:
res = list_dataset(get_http_api_auth, {"id": ids[0]})
@ -139,16 +135,12 @@ class TestDatasetUpdate:
(
"other_chunk_method",
102,
"'other_chunk_method' is not in ['naive', 'manual', 'qa', 'table',"
" 'paper', 'book', 'laws', 'presentation', 'picture', 'one', "
"'knowledge_graph', 'email', 'tag']",
"'other_chunk_method' is not in ['naive', 'manual', 'qa', 'table', 'paper', 'book', 'laws', 'presentation', 'picture', 'one', 'knowledge_graph', 'email', 'tag']",
),
],
)
def test_chunk_method(
self, get_http_api_auth, chunk_method, expected_code, expected_message
):
ids = create_datasets(get_http_api_auth, 1)
def test_chunk_method(self, get_http_api_auth, chunk_method, expected_code, expected_message):
ids = batch_create_datasets(get_http_api_auth, 1)
res = update_dataset(get_http_api_auth, ids[0], {"chunk_method": chunk_method})
assert res["code"] == expected_code
if expected_code == 0:
@ -161,14 +153,14 @@ class TestDatasetUpdate:
assert res["message"] == expected_message
def test_avatar(self, get_http_api_auth, tmp_path):
ids = create_datasets(get_http_api_auth, 1)
ids = batch_create_datasets(get_http_api_auth, 1)
fn = create_image_file(tmp_path / "ragflow_test.png")
payload = {"avatar": encode_avatar(fn)}
res = update_dataset(get_http_api_auth, ids[0], payload)
assert res["code"] == 0
def test_description(self, get_http_api_auth):
ids = create_datasets(get_http_api_auth, 1)
ids = batch_create_datasets(get_http_api_auth, 1)
payload = {"description": "description"}
res = update_dataset(get_http_api_auth, ids[0], payload)
assert res["code"] == 0
@ -177,7 +169,7 @@ class TestDatasetUpdate:
assert res["data"][0]["description"] == "description"
def test_pagerank(self, get_http_api_auth):
ids = create_datasets(get_http_api_auth, 1)
ids = batch_create_datasets(get_http_api_auth, 1)
payload = {"pagerank": 1}
res = update_dataset(get_http_api_auth, ids[0], payload)
assert res["code"] == 0
@ -186,7 +178,7 @@ class TestDatasetUpdate:
assert res["data"][0]["pagerank"] == 1
def test_similarity_threshold(self, get_http_api_auth):
ids = create_datasets(get_http_api_auth, 1)
ids = batch_create_datasets(get_http_api_auth, 1)
payload = {"similarity_threshold": 1}
res = update_dataset(get_http_api_auth, ids[0], payload)
assert res["code"] == 0
@ -206,7 +198,7 @@ class TestDatasetUpdate:
],
)
def test_permission(self, get_http_api_auth, permission, expected_code):
ids = create_datasets(get_http_api_auth, 1)
ids = batch_create_datasets(get_http_api_auth, 1)
payload = {"permission": permission}
res = update_dataset(get_http_api_auth, ids[0], payload)
assert res["code"] == expected_code
@ -218,7 +210,7 @@ class TestDatasetUpdate:
assert res["data"][0]["permission"] == "me"
def test_vector_similarity_weight(self, get_http_api_auth):
ids = create_datasets(get_http_api_auth, 1)
ids = batch_create_datasets(get_http_api_auth, 1)
payload = {"vector_similarity_weight": 1}
res = update_dataset(get_http_api_auth, ids[0], payload)
assert res["code"] == 0
@ -227,10 +219,8 @@ class TestDatasetUpdate:
assert res["data"][0]["vector_similarity_weight"] == 1
def test_invalid_dataset_id(self, get_http_api_auth):
create_datasets(get_http_api_auth, 1)
res = update_dataset(
get_http_api_auth, "invalid_dataset_id", {"name": "invalid_dataset_id"}
)
batch_create_datasets(get_http_api_auth, 1)
res = update_dataset(get_http_api_auth, "invalid_dataset_id", {"name": "invalid_dataset_id"})
assert res["code"] == 102
assert res["message"] == "You don't own the dataset"
@ -251,25 +241,20 @@ class TestDatasetUpdate:
],
)
def test_modify_read_only_field(self, get_http_api_auth, payload):
ids = create_datasets(get_http_api_auth, 1)
ids = batch_create_datasets(get_http_api_auth, 1)
res = update_dataset(get_http_api_auth, ids[0], payload)
assert res["code"] == 101
assert "is readonly" in res["message"]
def test_modify_unknown_field(self, get_http_api_auth):
ids = create_datasets(get_http_api_auth, 1)
ids = batch_create_datasets(get_http_api_auth, 1)
res = update_dataset(get_http_api_auth, ids[0], {"unknown_field": 0})
assert res["code"] == 100
def test_concurrent_update(self, get_http_api_auth):
ids = create_datasets(get_http_api_auth, 1)
ids = batch_create_datasets(get_http_api_auth, 1)
with ThreadPoolExecutor(max_workers=5) as executor:
futures = [
executor.submit(
update_dataset, get_http_api_auth, ids[0], {"name": f"dataset_{i}"}
)
for i in range(100)
]
futures = [executor.submit(update_dataset, get_http_api_auth, ids[0], {"name": f"dataset_{i}"}) for i in range(100)]
responses = [f.result() for f in futures]
assert all(r["code"] == 0 for r in responses)