diff --git a/api/apps/sdk/dataset.py b/api/apps/sdk/dataset.py index ea6540155..e3675b8cd 100644 --- a/api/apps/sdk/dataset.py +++ b/api/apps/sdk/dataset.py @@ -16,10 +16,12 @@ import logging +import os from flask import request from peewee import OperationalError +from api import settings from api.db import FileSource, StatusEnum from api.db.db_models import File from api.db.services.document_service import DocumentService @@ -48,6 +50,8 @@ from api.utils.validation_utils import ( validate_and_parse_json_request, validate_and_parse_request_args, ) +from rag.nlp import search +from rag.settings import PAGERANK_FLD @manager.route("/datasets", methods=["POST"]) # noqa: F821 @@ -97,9 +101,6 @@ def create(tenant_id): "picture", "presentation", "qa", "table", "tag" ] description: Chunking method. - pagerank: - type: integer - description: Set page rank. parser_config: type: object description: Parser configuration. @@ -352,6 +353,16 @@ def update(tenant_id, dataset_id): if not ok: return err + if "pagerank" in req and req["pagerank"] != kb.pagerank: + if os.environ.get("DOC_ENGINE", "elasticsearch") == "infinity": + return get_error_argument_result(message="'pagerank' can only be set when doc_engine is elasticsearch") + + if req["pagerank"] > 0: + settings.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]}, search.index_name(kb.tenant_id), kb.id) + else: + # Elasticsearch requires PAGERANK_FLD be non-zero! + settings.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD}, search.index_name(kb.tenant_id), kb.id) + if not KnowledgebaseService.update_by_id(kb.id, req): return get_error_data_result(message="Update dataset error.(Database error)") diff --git a/api/utils/validation_utils.py b/api/utils/validation_utils.py index 206a91f12..21c731aa0 100644 --- a/api/utils/validation_utils.py +++ b/api/utils/validation_utils.py @@ -383,7 +383,6 @@ class CreateDatasetReq(Base): embedding_model: Annotated[str, StringConstraints(strip_whitespace=True, max_length=255), Field(default="", serialization_alias="embd_id")] permission: PermissionEnum = Field(default=PermissionEnum.me, min_length=1, max_length=16) chunk_method: ChunkMethodnEnum = Field(default=ChunkMethodnEnum.naive, min_length=1, max_length=32, serialization_alias="parser_id") - pagerank: int = Field(default=0, ge=0, le=100) parser_config: ParserConfig | None = Field(default=None) @field_validator("avatar") @@ -539,6 +538,7 @@ class CreateDatasetReq(Base): class UpdateDatasetReq(CreateDatasetReq): dataset_id: str = Field(...) name: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1, max_length=DATASET_NAME_LIMIT), Field(default="")] + pagerank: int = Field(default=0, ge=0, le=100) @field_validator("dataset_id", mode="before") @classmethod diff --git a/docs/references/http_api_reference.md b/docs/references/http_api_reference.md index 95f47ade9..a4ef8fa73 100644 --- a/docs/references/http_api_reference.md +++ b/docs/references/http_api_reference.md @@ -343,7 +343,6 @@ Creates a dataset. - `"embedding_model"`: `string` - `"permission"`: `string` - `"chunk_method"`: `string` - - `"pagerank"`: `int` - `"parser_config"`: `object` ##### Request example @@ -384,12 +383,6 @@ curl --request POST \ - `"me"`: (Default) Only you can manage the dataset. - `"team"`: All team members can manage the dataset. -- `"pagerank"`: (*Body parameter*), `int` - refer to [Set page rank](https://ragflow.io/docs/dev/set_page_rank) - - Default: `0` - - Minimum: `0` - - Maximum: `100` - - `"chunk_method"`: (*Body parameter*), `enum` The chunking method of the dataset to create. Available options: - `"naive"`: General (default) diff --git a/docs/references/python_api_reference.md b/docs/references/python_api_reference.md index ec7c965d4..af66c0e49 100644 --- a/docs/references/python_api_reference.md +++ b/docs/references/python_api_reference.md @@ -100,7 +100,6 @@ RAGFlow.create_dataset( embedding_model: Optional[str] = "BAAI/bge-large-zh-v1.5@BAAI", permission: str = "me", chunk_method: str = "naive", - pagerank: int = 0, parser_config: DataSet.ParserConfig = None ) -> DataSet ``` @@ -148,10 +147,6 @@ The chunking method of the dataset to create. Available options: - `"one"`: One - `"email"`: Email -##### pagerank, `int` - -The pagerank of the dataset to create. Defaults to `0`. - ##### parser_config The parser configuration of the dataset. A `ParserConfig` object's attributes vary based on the selected `chunk_method`: diff --git a/sdk/python/ragflow_sdk/ragflow.py b/sdk/python/ragflow_sdk/ragflow.py index c3f52e480..5b65d6201 100644 --- a/sdk/python/ragflow_sdk/ragflow.py +++ b/sdk/python/ragflow_sdk/ragflow.py @@ -56,7 +56,6 @@ class RAGFlow: embedding_model: Optional[str] = "BAAI/bge-large-zh-v1.5@BAAI", permission: str = "me", chunk_method: str = "naive", - pagerank: int = 0, parser_config: Optional[DataSet.ParserConfig] = None, ) -> DataSet: payload = { @@ -66,7 +65,6 @@ class RAGFlow: "embedding_model": embedding_model, "permission": permission, "chunk_method": chunk_method, - "pagerank": pagerank, } if parser_config is not None: payload["parser_config"] = parser_config.to_json() diff --git a/test/testcases/test_http_api/test_dataset_mangement/test_create_dataset.py b/test/testcases/test_http_api/test_dataset_mangement/test_create_dataset.py index d91b4450d..b3b3f9b8a 100644 --- a/test/testcases/test_http_api/test_dataset_mangement/test_create_dataset.py +++ b/test/testcases/test_http_api/test_dataset_mangement/test_create_dataset.py @@ -394,51 +394,6 @@ class TestDatasetCreate: assert res["code"] == 101, res assert "Input should be 'naive', 'book', 'email', 'laws', 'manual', 'one', 'paper', 'picture', 'presentation', 'qa', 'table' or 'tag'" in res["message"], res - @pytest.mark.p2 - @pytest.mark.parametrize( - "name, pagerank", - [ - ("pagerank_min", 0), - ("pagerank_mid", 50), - ("pagerank_max", 100), - ], - ids=["min", "mid", "max"], - ) - def test_pagerank(self, HttpApiAuth, name, pagerank): - payload = {"name": name, "pagerank": pagerank} - res = create_dataset(HttpApiAuth, payload) - assert res["code"] == 0, res - assert res["data"]["pagerank"] == pagerank, res - - @pytest.mark.p3 - @pytest.mark.parametrize( - "name, pagerank, expected_message", - [ - ("pagerank_min_limit", -1, "Input should be greater than or equal to 0"), - ("pagerank_max_limit", 101, "Input should be less than or equal to 100"), - ], - ids=["min_limit", "max_limit"], - ) - def test_pagerank_invalid(self, HttpApiAuth, name, pagerank, expected_message): - payload = {"name": name, "pagerank": pagerank} - res = create_dataset(HttpApiAuth, payload) - assert res["code"] == 101, res - assert expected_message in res["message"], res - - @pytest.mark.p3 - def test_pagerank_unset(self, HttpApiAuth): - payload = {"name": "pagerank_unset"} - res = create_dataset(HttpApiAuth, payload) - assert res["code"] == 0, res - assert res["data"]["pagerank"] == 0, res - - @pytest.mark.p3 - def test_pagerank_none(self, HttpApiAuth): - payload = {"name": "pagerank_unset", "pagerank": None} - res = create_dataset(HttpApiAuth, payload) - assert res["code"] == 101, res - assert "Input should be a valid integer" in res["message"], res - @pytest.mark.p1 @pytest.mark.parametrize( "name, parser_config", @@ -730,6 +685,7 @@ class TestDatasetCreate: {"name": "chunk_count", "chunk_count": 1}, {"name": "token_num", "token_num": 1}, {"name": "status", "status": "1"}, + {"name": "pagerank", "pagerank": 50}, {"name": "unknown_field", "unknown_field": "unknown_field"}, ], ) diff --git a/test/testcases/test_http_api/test_dataset_mangement/test_update_dataset.py b/test/testcases/test_http_api/test_dataset_mangement/test_update_dataset.py index 7d186c365..152788000 100644 --- a/test/testcases/test_http_api/test_dataset_mangement/test_update_dataset.py +++ b/test/testcases/test_http_api/test_dataset_mangement/test_update_dataset.py @@ -13,11 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import os import uuid from concurrent.futures import ThreadPoolExecutor, as_completed import pytest -from common import DATASET_NAME_LIMIT, INVALID_API_TOKEN, list_datasets, update_dataset +from common import list_datasets, update_dataset +from configs import DATASET_NAME_LIMIT, INVALID_API_TOKEN from hypothesis import HealthCheck, example, given, settings from libs.auth import RAGFlowHttpApiAuth from utils import encode_avatar @@ -155,10 +157,10 @@ class TestDatasetUpdate: @pytest.mark.p3 def test_name_duplicated(self, HttpApiAuth, add_datasets_func): - dataset_ids = add_datasets_func[0] + dataset_id = add_datasets_func[0] name = "dataset_1" payload = {"name": name} - res = update_dataset(HttpApiAuth, dataset_ids, payload) + res = update_dataset(HttpApiAuth, dataset_id, payload) assert res["code"] == 102, res assert res["message"] == f"Dataset name '{name}' already exists", res @@ -425,6 +427,7 @@ class TestDatasetUpdate: assert res["code"] == 101, res assert "Input should be 'naive', 'book', 'email', 'laws', 'manual', 'one', 'paper', 'picture', 'presentation', 'qa', 'table' or 'tag'" in res["message"], res + @pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="#8208") @pytest.mark.p2 @pytest.mark.parametrize("pagerank", [0, 50, 100], ids=["min", "mid", "max"]) def test_pagerank(self, HttpApiAuth, add_dataset_func, pagerank): @@ -437,6 +440,35 @@ class TestDatasetUpdate: assert res["code"] == 0, res assert res["data"][0]["pagerank"] == pagerank + @pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="#8208") + @pytest.mark.p2 + def test_pagerank_set_to_0(self, HttpApiAuth, add_dataset_func): + dataset_id = add_dataset_func + payload = {"pagerank": 50} + res = update_dataset(HttpApiAuth, dataset_id, payload) + assert res["code"] == 0, res + + res = list_datasets(HttpApiAuth, {"id": dataset_id}) + assert res["code"] == 0, res + assert res["data"][0]["pagerank"] == 50, res + + payload = {"pagerank": 0} + res = update_dataset(HttpApiAuth, dataset_id, payload) + assert res["code"] == 0 + + res = list_datasets(HttpApiAuth, {"id": dataset_id}) + assert res["code"] == 0, res + assert res["data"][0]["pagerank"] == 0, res + + @pytest.mark.skipif(os.getenv("DOC_ENGINE") != "infinity", reason="#8208") + @pytest.mark.p2 + def test_pagerank_infinity(self, HttpApiAuth, add_dataset_func): + dataset_id = add_dataset_func + payload = {"pagerank": 50} + res = update_dataset(HttpApiAuth, dataset_id, payload) + assert res["code"] == 101, res + assert res["message"] == "'pagerank' can only be set when doc_engine is elasticsearch", res + @pytest.mark.p2 @pytest.mark.parametrize( "pagerank, expected_message", diff --git a/test/testcases/test_sdk_api/test_dataset_mangement/test_create_dataset.py b/test/testcases/test_sdk_api/test_dataset_mangement/test_create_dataset.py index 342a3eb84..4ba269648 100644 --- a/test/testcases/test_sdk_api/test_dataset_mangement/test_create_dataset.py +++ b/test/testcases/test_sdk_api/test_dataset_mangement/test_create_dataset.py @@ -344,49 +344,6 @@ class TestDatasetCreate: client.create_dataset(**payload) assert "not instance of" in str(excinfo.value), str(excinfo.value) - @pytest.mark.p2 - @pytest.mark.parametrize( - "name, pagerank", - [ - ("pagerank_min", 0), - ("pagerank_mid", 50), - ("pagerank_max", 100), - ], - ids=["min", "mid", "max"], - ) - def test_pagerank(self, client, name, pagerank): - payload = {"name": name, "pagerank": pagerank} - dataset = client.create_dataset(**payload) - assert dataset.pagerank == pagerank, str(dataset) - - @pytest.mark.p3 - @pytest.mark.parametrize( - "name, pagerank, expected_message", - [ - ("pagerank_min_limit", -1, "Input should be greater than or equal to 0"), - ("pagerank_max_limit", 101, "Input should be less than or equal to 100"), - ], - ids=["min_limit", "max_limit"], - ) - def test_pagerank_invalid(self, client, name, pagerank, expected_message): - payload = {"name": name, "pagerank": pagerank} - with pytest.raises(Exception) as excinfo: - client.create_dataset(**payload) - assert expected_message in str(excinfo.value), str(excinfo.value) - - @pytest.mark.p3 - def test_pagerank_unset(self, client): - payload = {"name": "pagerank_unset"} - dataset = client.create_dataset(**payload) - assert dataset.pagerank == 0, str(dataset) - - @pytest.mark.p3 - def test_pagerank_none(self, client): - payload = {"name": "pagerank_unset", "pagerank": None} - with pytest.raises(Exception) as excinfo: - client.create_dataset(**payload) - assert "not instance of" in str(excinfo.value), str(excinfo.value) - @pytest.mark.p1 @pytest.mark.parametrize( "name, parser_config", @@ -689,6 +646,7 @@ class TestDatasetCreate: {"name": "chunk_count", "chunk_count": 1}, {"name": "token_num", "token_num": 1}, {"name": "status", "status": "1"}, + {"name": "pagerank", "pagerank": 50}, {"name": "unknown_field", "unknown_field": "unknown_field"}, ], ) diff --git a/test/testcases/test_sdk_api/test_dataset_mangement/test_update_dataset.py b/test/testcases/test_sdk_api/test_dataset_mangement/test_update_dataset.py index 8632a7088..f4a0a9163 100644 --- a/test/testcases/test_sdk_api/test_dataset_mangement/test_update_dataset.py +++ b/test/testcases/test_sdk_api/test_dataset_mangement/test_update_dataset.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import os from concurrent.futures import ThreadPoolExecutor, as_completed from operator import attrgetter @@ -324,6 +325,7 @@ class TestDatasetUpdate: dataset.update({"chunk_method": None}) assert "Input should be 'naive', 'book', 'email', 'laws', 'manual', 'one', 'paper', 'picture', 'presentation', 'qa', 'table' or 'tag'" in str(excinfo.value), str(excinfo.value) + @pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="#8208") @pytest.mark.p2 @pytest.mark.parametrize("pagerank", [0, 50, 100], ids=["min", "mid", "max"]) def test_pagerank(self, client, add_dataset_func, pagerank): @@ -334,6 +336,30 @@ class TestDatasetUpdate: retrieved_dataset = client.get_dataset(name=dataset.name) assert retrieved_dataset.pagerank == pagerank, str(retrieved_dataset) + @pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="#8208") + @pytest.mark.p2 + def test_pagerank_set_to_0(self, client, add_dataset_func): + dataset = add_dataset_func + dataset.update({"pagerank": 50}) + assert dataset.pagerank == 50, str(dataset) + + retrieved_dataset = client.get_dataset(name=dataset.name) + assert retrieved_dataset.pagerank == 50, str(retrieved_dataset) + + dataset.update({"pagerank": 0}) + assert dataset.pagerank == 0, str(dataset) + + retrieved_dataset = client.get_dataset(name=dataset.name) + assert retrieved_dataset.pagerank == 0, str(retrieved_dataset) + + @pytest.mark.skipif(os.getenv("DOC_ENGINE") != "infinity", reason="#8208") + @pytest.mark.p2 + def test_pagerank_infinity(self, client, add_dataset_func): + dataset = add_dataset_func + with pytest.raises(Exception) as excinfo: + dataset.update({"pagerank": 50}) + assert "'pagerank' can only be set when doc_engine is elasticsearch" in str(excinfo.value), str(excinfo.value) + @pytest.mark.p2 @pytest.mark.parametrize( "pagerank, expected_message",