Fix: Move pagerank field from create to update dataset API (#8217)

### What problem does this PR solve?

- Remove pagerank from CreateDatasetReq and add to UpdateDatasetReq
- Add pagerank update logic in dataset update endpoint
- Update API documentation to reflect changes
- Modify related test cases and SDK references

#8208

This change makes `pagerank` a mutable property that can only be set after a dataset has been created, and only when `elasticsearch` is used as the doc engine.
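
To make the new flow concrete, here is a minimal Python SDK sketch (the `ragflow_sdk` import path and the `RAGFlow` constructor arguments are assumptions for illustration, not taken from this diff): `pagerank` is no longer passed to `create_dataset` and is applied afterwards through `update`, which only succeeds when Elasticsearch is the doc engine.

```python
from ragflow_sdk import RAGFlow  # assumed import path

# Assumed constructor arguments; adjust to your deployment.
client = RAGFlow(api_key="<YOUR_API_KEY>", base_url="http://127.0.0.1:9380")

# After this change, create_dataset() no longer accepts a pagerank argument.
dataset = client.create_dataset(name="example_dataset")

# pagerank is set only via the update API, and only when DOC_ENGINE is elasticsearch.
dataset.update({"pagerank": 50})
assert dataset.pagerank == 50

# Setting it back to 0 clears the pagerank field in the doc store.
dataset.update({"pagerank": 0})
```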

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
Liu An
2025-06-12 15:47:49 +08:00
committed by GitHub
parent d0c5ff04a6
commit 7fbbc9650d
9 changed files with 78 additions and 109 deletions

View File

@@ -16,10 +16,12 @@
 import logging
+import os
 from flask import request
 from peewee import OperationalError
+from api import settings
 from api.db import FileSource, StatusEnum
 from api.db.db_models import File
 from api.db.services.document_service import DocumentService

@@ -48,6 +50,8 @@ from api.utils.validation_utils import (
     validate_and_parse_json_request,
     validate_and_parse_request_args,
 )
+from rag.nlp import search
+from rag.settings import PAGERANK_FLD

 @manager.route("/datasets", methods=["POST"])  # noqa: F821

@@ -97,9 +101,6 @@ def create(tenant_id):
                 "picture", "presentation", "qa", "table", "tag"
               ]
               description: Chunking method.
-            pagerank:
-              type: integer
-              description: Set page rank.
             parser_config:
               type: object
               description: Parser configuration.

@@ -352,6 +353,16 @@ def update(tenant_id, dataset_id):
     if not ok:
         return err
+    if "pagerank" in req and req["pagerank"] != kb.pagerank:
+        if os.environ.get("DOC_ENGINE", "elasticsearch") == "infinity":
+            return get_error_argument_result(message="'pagerank' can only be set when doc_engine is elasticsearch")
+        if req["pagerank"] > 0:
+            settings.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]}, search.index_name(kb.tenant_id), kb.id)
+        else:
+            # Elasticsearch requires PAGERANK_FLD be non-zero!
+            settings.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD}, search.index_name(kb.tenant_id), kb.id)
     if not KnowledgebaseService.update_by_id(kb.id, req):
         return get_error_data_result(message="Update dataset error.(Database error)")

View File

@@ -383,7 +383,6 @@ class CreateDatasetReq(Base):
     embedding_model: Annotated[str, StringConstraints(strip_whitespace=True, max_length=255), Field(default="", serialization_alias="embd_id")]
     permission: PermissionEnum = Field(default=PermissionEnum.me, min_length=1, max_length=16)
     chunk_method: ChunkMethodnEnum = Field(default=ChunkMethodnEnum.naive, min_length=1, max_length=32, serialization_alias="parser_id")
-    pagerank: int = Field(default=0, ge=0, le=100)
     parser_config: ParserConfig | None = Field(default=None)

     @field_validator("avatar")

@@ -539,6 +538,7 @@ class CreateDatasetReq(Base):
 class UpdateDatasetReq(CreateDatasetReq):
     dataset_id: str = Field(...)
     name: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1, max_length=DATASET_NAME_LIMIT), Field(default="")]
+    pagerank: int = Field(default=0, ge=0, le=100)

     @field_validator("dataset_id", mode="before")
     @classmethod

View File

@@ -343,7 +343,6 @@ Creates a dataset.
   - `"embedding_model"`: `string`
   - `"permission"`: `string`
   - `"chunk_method"`: `string`
-  - `"pagerank"`: `int`
   - `"parser_config"`: `object`

 ##### Request example

@@ -384,12 +383,6 @@ curl --request POST \
   - `"me"`: (Default) Only you can manage the dataset.
   - `"team"`: All team members can manage the dataset.
-- `"pagerank"`: (*Body parameter*), `int`
-  refer to [Set page rank](https://ragflow.io/docs/dev/set_page_rank)
-  - Default: `0`
-  - Minimum: `0`
-  - Maximum: `100`
 - `"chunk_method"`: (*Body parameter*), `enum<string>`
   The chunking method of the dataset to create. Available options:
   - `"naive"`: General (default)

View File

@@ -100,7 +100,6 @@ RAGFlow.create_dataset(
     embedding_model: Optional[str] = "BAAI/bge-large-zh-v1.5@BAAI",
     permission: str = "me",
     chunk_method: str = "naive",
-    pagerank: int = 0,
     parser_config: DataSet.ParserConfig = None
 ) -> DataSet
 ```

@@ -148,10 +147,6 @@ The chunking method of the dataset to create. Available options:
 - `"one"`: One
 - `"email"`: Email
-##### pagerank, `int`
-The pagerank of the dataset to create. Defaults to `0`.
 ##### parser_config
 The parser configuration of the dataset. A `ParserConfig` object's attributes vary based on the selected `chunk_method`:

View File

@@ -56,7 +56,6 @@ class RAGFlow:
         embedding_model: Optional[str] = "BAAI/bge-large-zh-v1.5@BAAI",
         permission: str = "me",
         chunk_method: str = "naive",
-        pagerank: int = 0,
         parser_config: Optional[DataSet.ParserConfig] = None,
     ) -> DataSet:
         payload = {

@@ -66,7 +65,6 @@ class RAGFlow:
             "embedding_model": embedding_model,
             "permission": permission,
             "chunk_method": chunk_method,
-            "pagerank": pagerank,
         }
         if parser_config is not None:
             payload["parser_config"] = parser_config.to_json()

View File

@@ -394,51 +394,6 @@ class TestDatasetCreate:
         assert res["code"] == 101, res
         assert "Input should be 'naive', 'book', 'email', 'laws', 'manual', 'one', 'paper', 'picture', 'presentation', 'qa', 'table' or 'tag'" in res["message"], res

-    @pytest.mark.p2
-    @pytest.mark.parametrize(
-        "name, pagerank",
-        [
-            ("pagerank_min", 0),
-            ("pagerank_mid", 50),
-            ("pagerank_max", 100),
-        ],
-        ids=["min", "mid", "max"],
-    )
-    def test_pagerank(self, HttpApiAuth, name, pagerank):
-        payload = {"name": name, "pagerank": pagerank}
-        res = create_dataset(HttpApiAuth, payload)
-        assert res["code"] == 0, res
-        assert res["data"]["pagerank"] == pagerank, res
-
-    @pytest.mark.p3
-    @pytest.mark.parametrize(
-        "name, pagerank, expected_message",
-        [
-            ("pagerank_min_limit", -1, "Input should be greater than or equal to 0"),
-            ("pagerank_max_limit", 101, "Input should be less than or equal to 100"),
-        ],
-        ids=["min_limit", "max_limit"],
-    )
-    def test_pagerank_invalid(self, HttpApiAuth, name, pagerank, expected_message):
-        payload = {"name": name, "pagerank": pagerank}
-        res = create_dataset(HttpApiAuth, payload)
-        assert res["code"] == 101, res
-        assert expected_message in res["message"], res
-
-    @pytest.mark.p3
-    def test_pagerank_unset(self, HttpApiAuth):
-        payload = {"name": "pagerank_unset"}
-        res = create_dataset(HttpApiAuth, payload)
-        assert res["code"] == 0, res
-        assert res["data"]["pagerank"] == 0, res
-
-    @pytest.mark.p3
-    def test_pagerank_none(self, HttpApiAuth):
-        payload = {"name": "pagerank_unset", "pagerank": None}
-        res = create_dataset(HttpApiAuth, payload)
-        assert res["code"] == 101, res
-        assert "Input should be a valid integer" in res["message"], res
-
     @pytest.mark.p1
     @pytest.mark.parametrize(
         "name, parser_config",

@@ -730,6 +685,7 @@ class TestDatasetCreate:
             {"name": "chunk_count", "chunk_count": 1},
             {"name": "token_num", "token_num": 1},
             {"name": "status", "status": "1"},
+            {"name": "pagerank", "pagerank": 50},
             {"name": "unknown_field", "unknown_field": "unknown_field"},
         ],
     )

View File

@@ -13,11 +13,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import os
 import uuid
 from concurrent.futures import ThreadPoolExecutor, as_completed

 import pytest
-from common import DATASET_NAME_LIMIT, INVALID_API_TOKEN, list_datasets, update_dataset
+from common import list_datasets, update_dataset
+from configs import DATASET_NAME_LIMIT, INVALID_API_TOKEN
 from hypothesis import HealthCheck, example, given, settings
 from libs.auth import RAGFlowHttpApiAuth
 from utils import encode_avatar

@@ -155,10 +157,10 @@ class TestDatasetUpdate:
     @pytest.mark.p3
     def test_name_duplicated(self, HttpApiAuth, add_datasets_func):
-        dataset_ids = add_datasets_func[0]
+        dataset_id = add_datasets_func[0]
         name = "dataset_1"
         payload = {"name": name}
-        res = update_dataset(HttpApiAuth, dataset_ids, payload)
+        res = update_dataset(HttpApiAuth, dataset_id, payload)
         assert res["code"] == 102, res
         assert res["message"] == f"Dataset name '{name}' already exists", res

@@ -425,6 +427,7 @@ class TestDatasetUpdate:
         assert res["code"] == 101, res
         assert "Input should be 'naive', 'book', 'email', 'laws', 'manual', 'one', 'paper', 'picture', 'presentation', 'qa', 'table' or 'tag'" in res["message"], res

+    @pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="#8208")
     @pytest.mark.p2
     @pytest.mark.parametrize("pagerank", [0, 50, 100], ids=["min", "mid", "max"])
     def test_pagerank(self, HttpApiAuth, add_dataset_func, pagerank):

@@ -437,6 +440,35 @@ class TestDatasetUpdate:
         assert res["code"] == 0, res
         assert res["data"][0]["pagerank"] == pagerank

+    @pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="#8208")
+    @pytest.mark.p2
+    def test_pagerank_set_to_0(self, HttpApiAuth, add_dataset_func):
+        dataset_id = add_dataset_func
+        payload = {"pagerank": 50}
+        res = update_dataset(HttpApiAuth, dataset_id, payload)
+        assert res["code"] == 0, res
+        res = list_datasets(HttpApiAuth, {"id": dataset_id})
+        assert res["code"] == 0, res
+        assert res["data"][0]["pagerank"] == 50, res
+        payload = {"pagerank": 0}
+        res = update_dataset(HttpApiAuth, dataset_id, payload)
+        assert res["code"] == 0
+        res = list_datasets(HttpApiAuth, {"id": dataset_id})
+        assert res["code"] == 0, res
+        assert res["data"][0]["pagerank"] == 0, res
+
+    @pytest.mark.skipif(os.getenv("DOC_ENGINE") != "infinity", reason="#8208")
+    @pytest.mark.p2
+    def test_pagerank_infinity(self, HttpApiAuth, add_dataset_func):
+        dataset_id = add_dataset_func
+        payload = {"pagerank": 50}
+        res = update_dataset(HttpApiAuth, dataset_id, payload)
+        assert res["code"] == 101, res
+        assert res["message"] == "'pagerank' can only be set when doc_engine is elasticsearch", res
+
     @pytest.mark.p2
     @pytest.mark.parametrize(
         "pagerank, expected_message",

View File

@@ -344,49 +344,6 @@ class TestDatasetCreate:
             client.create_dataset(**payload)
         assert "not instance of" in str(excinfo.value), str(excinfo.value)

-    @pytest.mark.p2
-    @pytest.mark.parametrize(
-        "name, pagerank",
-        [
-            ("pagerank_min", 0),
-            ("pagerank_mid", 50),
-            ("pagerank_max", 100),
-        ],
-        ids=["min", "mid", "max"],
-    )
-    def test_pagerank(self, client, name, pagerank):
-        payload = {"name": name, "pagerank": pagerank}
-        dataset = client.create_dataset(**payload)
-        assert dataset.pagerank == pagerank, str(dataset)
-
-    @pytest.mark.p3
-    @pytest.mark.parametrize(
-        "name, pagerank, expected_message",
-        [
-            ("pagerank_min_limit", -1, "Input should be greater than or equal to 0"),
-            ("pagerank_max_limit", 101, "Input should be less than or equal to 100"),
-        ],
-        ids=["min_limit", "max_limit"],
-    )
-    def test_pagerank_invalid(self, client, name, pagerank, expected_message):
-        payload = {"name": name, "pagerank": pagerank}
-        with pytest.raises(Exception) as excinfo:
-            client.create_dataset(**payload)
-        assert expected_message in str(excinfo.value), str(excinfo.value)
-
-    @pytest.mark.p3
-    def test_pagerank_unset(self, client):
-        payload = {"name": "pagerank_unset"}
-        dataset = client.create_dataset(**payload)
-        assert dataset.pagerank == 0, str(dataset)
-
-    @pytest.mark.p3
-    def test_pagerank_none(self, client):
-        payload = {"name": "pagerank_unset", "pagerank": None}
-        with pytest.raises(Exception) as excinfo:
-            client.create_dataset(**payload)
-        assert "not instance of" in str(excinfo.value), str(excinfo.value)
-
     @pytest.mark.p1
     @pytest.mark.parametrize(
         "name, parser_config",

@@ -689,6 +646,7 @@ class TestDatasetCreate:
             {"name": "chunk_count", "chunk_count": 1},
             {"name": "token_num", "token_num": 1},
             {"name": "status", "status": "1"},
+            {"name": "pagerank", "pagerank": 50},
             {"name": "unknown_field", "unknown_field": "unknown_field"},
         ],
     )

View File

@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import os
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from operator import attrgetter

@@ -324,6 +325,7 @@ class TestDatasetUpdate:
             dataset.update({"chunk_method": None})
         assert "Input should be 'naive', 'book', 'email', 'laws', 'manual', 'one', 'paper', 'picture', 'presentation', 'qa', 'table' or 'tag'" in str(excinfo.value), str(excinfo.value)

+    @pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="#8208")
     @pytest.mark.p2
     @pytest.mark.parametrize("pagerank", [0, 50, 100], ids=["min", "mid", "max"])
     def test_pagerank(self, client, add_dataset_func, pagerank):

@@ -334,6 +336,30 @@ class TestDatasetUpdate:
         retrieved_dataset = client.get_dataset(name=dataset.name)
         assert retrieved_dataset.pagerank == pagerank, str(retrieved_dataset)

+    @pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="#8208")
+    @pytest.mark.p2
+    def test_pagerank_set_to_0(self, client, add_dataset_func):
+        dataset = add_dataset_func
+        dataset.update({"pagerank": 50})
+        assert dataset.pagerank == 50, str(dataset)
+        retrieved_dataset = client.get_dataset(name=dataset.name)
+        assert retrieved_dataset.pagerank == 50, str(retrieved_dataset)
+        dataset.update({"pagerank": 0})
+        assert dataset.pagerank == 0, str(dataset)
+        retrieved_dataset = client.get_dataset(name=dataset.name)
+        assert retrieved_dataset.pagerank == 0, str(retrieved_dataset)
+
+    @pytest.mark.skipif(os.getenv("DOC_ENGINE") != "infinity", reason="#8208")
+    @pytest.mark.p2
+    def test_pagerank_infinity(self, client, add_dataset_func):
+        dataset = add_dataset_func
+        with pytest.raises(Exception) as excinfo:
+            dataset.update({"pagerank": 50})
+        assert "'pagerank' can only be set when doc_engine is elasticsearch" in str(excinfo.value), str(excinfo.value)
+
     @pytest.mark.p2
     @pytest.mark.parametrize(
         "pagerank, expected_message",