mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 12:32:30 +08:00
Fix: Move pagerank field from create to update dataset API (#8217)
### What problem does this PR solve? - Remove pagerank from CreateDatasetReq and add to UpdateDatasetReq - Add pagerank update logic in dataset update endpoint - Update API documentation to reflect changes - Modify related test cases and SDK references #8208 This change makes pagerank a mutable property that can only be set after dataset creation, and only when using elasticsearch as the doc engine. ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
@ -16,10 +16,12 @@
|
||||
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
from flask import request
|
||||
from peewee import OperationalError
|
||||
|
||||
from api import settings
|
||||
from api.db import FileSource, StatusEnum
|
||||
from api.db.db_models import File
|
||||
from api.db.services.document_service import DocumentService
|
||||
@ -48,6 +50,8 @@ from api.utils.validation_utils import (
|
||||
validate_and_parse_json_request,
|
||||
validate_and_parse_request_args,
|
||||
)
|
||||
from rag.nlp import search
|
||||
from rag.settings import PAGERANK_FLD
|
||||
|
||||
|
||||
@manager.route("/datasets", methods=["POST"]) # noqa: F821
|
||||
@ -97,9 +101,6 @@ def create(tenant_id):
|
||||
"picture", "presentation", "qa", "table", "tag"
|
||||
]
|
||||
description: Chunking method.
|
||||
pagerank:
|
||||
type: integer
|
||||
description: Set page rank.
|
||||
parser_config:
|
||||
type: object
|
||||
description: Parser configuration.
|
||||
@ -352,6 +353,16 @@ def update(tenant_id, dataset_id):
|
||||
if not ok:
|
||||
return err
|
||||
|
||||
if "pagerank" in req and req["pagerank"] != kb.pagerank:
|
||||
if os.environ.get("DOC_ENGINE", "elasticsearch") == "infinity":
|
||||
return get_error_argument_result(message="'pagerank' can only be set when doc_engine is elasticsearch")
|
||||
|
||||
if req["pagerank"] > 0:
|
||||
settings.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]}, search.index_name(kb.tenant_id), kb.id)
|
||||
else:
|
||||
# Elasticsearch requires PAGERANK_FLD be non-zero!
|
||||
settings.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD}, search.index_name(kb.tenant_id), kb.id)
|
||||
|
||||
if not KnowledgebaseService.update_by_id(kb.id, req):
|
||||
return get_error_data_result(message="Update dataset error.(Database error)")
|
||||
|
||||
|
||||
@ -383,7 +383,6 @@ class CreateDatasetReq(Base):
|
||||
embedding_model: Annotated[str, StringConstraints(strip_whitespace=True, max_length=255), Field(default="", serialization_alias="embd_id")]
|
||||
permission: PermissionEnum = Field(default=PermissionEnum.me, min_length=1, max_length=16)
|
||||
chunk_method: ChunkMethodnEnum = Field(default=ChunkMethodnEnum.naive, min_length=1, max_length=32, serialization_alias="parser_id")
|
||||
pagerank: int = Field(default=0, ge=0, le=100)
|
||||
parser_config: ParserConfig | None = Field(default=None)
|
||||
|
||||
@field_validator("avatar")
|
||||
@ -539,6 +538,7 @@ class CreateDatasetReq(Base):
|
||||
class UpdateDatasetReq(CreateDatasetReq):
|
||||
dataset_id: str = Field(...)
|
||||
name: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1, max_length=DATASET_NAME_LIMIT), Field(default="")]
|
||||
pagerank: int = Field(default=0, ge=0, le=100)
|
||||
|
||||
@field_validator("dataset_id", mode="before")
|
||||
@classmethod
|
||||
|
||||
@ -343,7 +343,6 @@ Creates a dataset.
|
||||
- `"embedding_model"`: `string`
|
||||
- `"permission"`: `string`
|
||||
- `"chunk_method"`: `string`
|
||||
- `"pagerank"`: `int`
|
||||
- `"parser_config"`: `object`
|
||||
|
||||
##### Request example
|
||||
@ -384,12 +383,6 @@ curl --request POST \
|
||||
- `"me"`: (Default) Only you can manage the dataset.
|
||||
- `"team"`: All team members can manage the dataset.
|
||||
|
||||
- `"pagerank"`: (*Body parameter*), `int`
|
||||
refer to [Set page rank](https://ragflow.io/docs/dev/set_page_rank)
|
||||
- Default: `0`
|
||||
- Minimum: `0`
|
||||
- Maximum: `100`
|
||||
|
||||
- `"chunk_method"`: (*Body parameter*), `enum<string>`
|
||||
The chunking method of the dataset to create. Available options:
|
||||
- `"naive"`: General (default)
|
||||
|
||||
@ -100,7 +100,6 @@ RAGFlow.create_dataset(
|
||||
embedding_model: Optional[str] = "BAAI/bge-large-zh-v1.5@BAAI",
|
||||
permission: str = "me",
|
||||
chunk_method: str = "naive",
|
||||
pagerank: int = 0,
|
||||
parser_config: DataSet.ParserConfig = None
|
||||
) -> DataSet
|
||||
```
|
||||
@ -148,10 +147,6 @@ The chunking method of the dataset to create. Available options:
|
||||
- `"one"`: One
|
||||
- `"email"`: Email
|
||||
|
||||
##### pagerank, `int`
|
||||
|
||||
The pagerank of the dataset to create. Defaults to `0`.
|
||||
|
||||
##### parser_config
|
||||
|
||||
The parser configuration of the dataset. A `ParserConfig` object's attributes vary based on the selected `chunk_method`:
|
||||
|
||||
@ -56,7 +56,6 @@ class RAGFlow:
|
||||
embedding_model: Optional[str] = "BAAI/bge-large-zh-v1.5@BAAI",
|
||||
permission: str = "me",
|
||||
chunk_method: str = "naive",
|
||||
pagerank: int = 0,
|
||||
parser_config: Optional[DataSet.ParserConfig] = None,
|
||||
) -> DataSet:
|
||||
payload = {
|
||||
@ -66,7 +65,6 @@ class RAGFlow:
|
||||
"embedding_model": embedding_model,
|
||||
"permission": permission,
|
||||
"chunk_method": chunk_method,
|
||||
"pagerank": pagerank,
|
||||
}
|
||||
if parser_config is not None:
|
||||
payload["parser_config"] = parser_config.to_json()
|
||||
|
||||
@ -394,51 +394,6 @@ class TestDatasetCreate:
|
||||
assert res["code"] == 101, res
|
||||
assert "Input should be 'naive', 'book', 'email', 'laws', 'manual', 'one', 'paper', 'picture', 'presentation', 'qa', 'table' or 'tag'" in res["message"], res
|
||||
|
||||
@pytest.mark.p2
|
||||
@pytest.mark.parametrize(
|
||||
"name, pagerank",
|
||||
[
|
||||
("pagerank_min", 0),
|
||||
("pagerank_mid", 50),
|
||||
("pagerank_max", 100),
|
||||
],
|
||||
ids=["min", "mid", "max"],
|
||||
)
|
||||
def test_pagerank(self, HttpApiAuth, name, pagerank):
|
||||
payload = {"name": name, "pagerank": pagerank}
|
||||
res = create_dataset(HttpApiAuth, payload)
|
||||
assert res["code"] == 0, res
|
||||
assert res["data"]["pagerank"] == pagerank, res
|
||||
|
||||
@pytest.mark.p3
|
||||
@pytest.mark.parametrize(
|
||||
"name, pagerank, expected_message",
|
||||
[
|
||||
("pagerank_min_limit", -1, "Input should be greater than or equal to 0"),
|
||||
("pagerank_max_limit", 101, "Input should be less than or equal to 100"),
|
||||
],
|
||||
ids=["min_limit", "max_limit"],
|
||||
)
|
||||
def test_pagerank_invalid(self, HttpApiAuth, name, pagerank, expected_message):
|
||||
payload = {"name": name, "pagerank": pagerank}
|
||||
res = create_dataset(HttpApiAuth, payload)
|
||||
assert res["code"] == 101, res
|
||||
assert expected_message in res["message"], res
|
||||
|
||||
@pytest.mark.p3
|
||||
def test_pagerank_unset(self, HttpApiAuth):
|
||||
payload = {"name": "pagerank_unset"}
|
||||
res = create_dataset(HttpApiAuth, payload)
|
||||
assert res["code"] == 0, res
|
||||
assert res["data"]["pagerank"] == 0, res
|
||||
|
||||
@pytest.mark.p3
|
||||
def test_pagerank_none(self, HttpApiAuth):
|
||||
payload = {"name": "pagerank_unset", "pagerank": None}
|
||||
res = create_dataset(HttpApiAuth, payload)
|
||||
assert res["code"] == 101, res
|
||||
assert "Input should be a valid integer" in res["message"], res
|
||||
|
||||
@pytest.mark.p1
|
||||
@pytest.mark.parametrize(
|
||||
"name, parser_config",
|
||||
@ -730,6 +685,7 @@ class TestDatasetCreate:
|
||||
{"name": "chunk_count", "chunk_count": 1},
|
||||
{"name": "token_num", "token_num": 1},
|
||||
{"name": "status", "status": "1"},
|
||||
{"name": "pagerank", "pagerank": 50},
|
||||
{"name": "unknown_field", "unknown_field": "unknown_field"},
|
||||
],
|
||||
)
|
||||
|
||||
@ -13,11 +13,13 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import os
|
||||
import uuid
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
import pytest
|
||||
from common import DATASET_NAME_LIMIT, INVALID_API_TOKEN, list_datasets, update_dataset
|
||||
from common import list_datasets, update_dataset
|
||||
from configs import DATASET_NAME_LIMIT, INVALID_API_TOKEN
|
||||
from hypothesis import HealthCheck, example, given, settings
|
||||
from libs.auth import RAGFlowHttpApiAuth
|
||||
from utils import encode_avatar
|
||||
@ -155,10 +157,10 @@ class TestDatasetUpdate:
|
||||
|
||||
@pytest.mark.p3
|
||||
def test_name_duplicated(self, HttpApiAuth, add_datasets_func):
|
||||
dataset_ids = add_datasets_func[0]
|
||||
dataset_id = add_datasets_func[0]
|
||||
name = "dataset_1"
|
||||
payload = {"name": name}
|
||||
res = update_dataset(HttpApiAuth, dataset_ids, payload)
|
||||
res = update_dataset(HttpApiAuth, dataset_id, payload)
|
||||
assert res["code"] == 102, res
|
||||
assert res["message"] == f"Dataset name '{name}' already exists", res
|
||||
|
||||
@ -425,6 +427,7 @@ class TestDatasetUpdate:
|
||||
assert res["code"] == 101, res
|
||||
assert "Input should be 'naive', 'book', 'email', 'laws', 'manual', 'one', 'paper', 'picture', 'presentation', 'qa', 'table' or 'tag'" in res["message"], res
|
||||
|
||||
@pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="#8208")
|
||||
@pytest.mark.p2
|
||||
@pytest.mark.parametrize("pagerank", [0, 50, 100], ids=["min", "mid", "max"])
|
||||
def test_pagerank(self, HttpApiAuth, add_dataset_func, pagerank):
|
||||
@ -437,6 +440,35 @@ class TestDatasetUpdate:
|
||||
assert res["code"] == 0, res
|
||||
assert res["data"][0]["pagerank"] == pagerank
|
||||
|
||||
@pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="#8208")
|
||||
@pytest.mark.p2
|
||||
def test_pagerank_set_to_0(self, HttpApiAuth, add_dataset_func):
|
||||
dataset_id = add_dataset_func
|
||||
payload = {"pagerank": 50}
|
||||
res = update_dataset(HttpApiAuth, dataset_id, payload)
|
||||
assert res["code"] == 0, res
|
||||
|
||||
res = list_datasets(HttpApiAuth, {"id": dataset_id})
|
||||
assert res["code"] == 0, res
|
||||
assert res["data"][0]["pagerank"] == 50, res
|
||||
|
||||
payload = {"pagerank": 0}
|
||||
res = update_dataset(HttpApiAuth, dataset_id, payload)
|
||||
assert res["code"] == 0
|
||||
|
||||
res = list_datasets(HttpApiAuth, {"id": dataset_id})
|
||||
assert res["code"] == 0, res
|
||||
assert res["data"][0]["pagerank"] == 0, res
|
||||
|
||||
@pytest.mark.skipif(os.getenv("DOC_ENGINE") != "infinity", reason="#8208")
|
||||
@pytest.mark.p2
|
||||
def test_pagerank_infinity(self, HttpApiAuth, add_dataset_func):
|
||||
dataset_id = add_dataset_func
|
||||
payload = {"pagerank": 50}
|
||||
res = update_dataset(HttpApiAuth, dataset_id, payload)
|
||||
assert res["code"] == 101, res
|
||||
assert res["message"] == "'pagerank' can only be set when doc_engine is elasticsearch", res
|
||||
|
||||
@pytest.mark.p2
|
||||
@pytest.mark.parametrize(
|
||||
"pagerank, expected_message",
|
||||
|
||||
@ -344,49 +344,6 @@ class TestDatasetCreate:
|
||||
client.create_dataset(**payload)
|
||||
assert "not instance of" in str(excinfo.value), str(excinfo.value)
|
||||
|
||||
@pytest.mark.p2
|
||||
@pytest.mark.parametrize(
|
||||
"name, pagerank",
|
||||
[
|
||||
("pagerank_min", 0),
|
||||
("pagerank_mid", 50),
|
||||
("pagerank_max", 100),
|
||||
],
|
||||
ids=["min", "mid", "max"],
|
||||
)
|
||||
def test_pagerank(self, client, name, pagerank):
|
||||
payload = {"name": name, "pagerank": pagerank}
|
||||
dataset = client.create_dataset(**payload)
|
||||
assert dataset.pagerank == pagerank, str(dataset)
|
||||
|
||||
@pytest.mark.p3
|
||||
@pytest.mark.parametrize(
|
||||
"name, pagerank, expected_message",
|
||||
[
|
||||
("pagerank_min_limit", -1, "Input should be greater than or equal to 0"),
|
||||
("pagerank_max_limit", 101, "Input should be less than or equal to 100"),
|
||||
],
|
||||
ids=["min_limit", "max_limit"],
|
||||
)
|
||||
def test_pagerank_invalid(self, client, name, pagerank, expected_message):
|
||||
payload = {"name": name, "pagerank": pagerank}
|
||||
with pytest.raises(Exception) as excinfo:
|
||||
client.create_dataset(**payload)
|
||||
assert expected_message in str(excinfo.value), str(excinfo.value)
|
||||
|
||||
@pytest.mark.p3
|
||||
def test_pagerank_unset(self, client):
|
||||
payload = {"name": "pagerank_unset"}
|
||||
dataset = client.create_dataset(**payload)
|
||||
assert dataset.pagerank == 0, str(dataset)
|
||||
|
||||
@pytest.mark.p3
|
||||
def test_pagerank_none(self, client):
|
||||
payload = {"name": "pagerank_unset", "pagerank": None}
|
||||
with pytest.raises(Exception) as excinfo:
|
||||
client.create_dataset(**payload)
|
||||
assert "not instance of" in str(excinfo.value), str(excinfo.value)
|
||||
|
||||
@pytest.mark.p1
|
||||
@pytest.mark.parametrize(
|
||||
"name, parser_config",
|
||||
@ -689,6 +646,7 @@ class TestDatasetCreate:
|
||||
{"name": "chunk_count", "chunk_count": 1},
|
||||
{"name": "token_num", "token_num": 1},
|
||||
{"name": "status", "status": "1"},
|
||||
{"name": "pagerank", "pagerank": 50},
|
||||
{"name": "unknown_field", "unknown_field": "unknown_field"},
|
||||
],
|
||||
)
|
||||
|
||||
@ -13,6 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import os
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from operator import attrgetter
|
||||
|
||||
@ -324,6 +325,7 @@ class TestDatasetUpdate:
|
||||
dataset.update({"chunk_method": None})
|
||||
assert "Input should be 'naive', 'book', 'email', 'laws', 'manual', 'one', 'paper', 'picture', 'presentation', 'qa', 'table' or 'tag'" in str(excinfo.value), str(excinfo.value)
|
||||
|
||||
@pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="#8208")
|
||||
@pytest.mark.p2
|
||||
@pytest.mark.parametrize("pagerank", [0, 50, 100], ids=["min", "mid", "max"])
|
||||
def test_pagerank(self, client, add_dataset_func, pagerank):
|
||||
@ -334,6 +336,30 @@ class TestDatasetUpdate:
|
||||
retrieved_dataset = client.get_dataset(name=dataset.name)
|
||||
assert retrieved_dataset.pagerank == pagerank, str(retrieved_dataset)
|
||||
|
||||
@pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="#8208")
|
||||
@pytest.mark.p2
|
||||
def test_pagerank_set_to_0(self, client, add_dataset_func):
|
||||
dataset = add_dataset_func
|
||||
dataset.update({"pagerank": 50})
|
||||
assert dataset.pagerank == 50, str(dataset)
|
||||
|
||||
retrieved_dataset = client.get_dataset(name=dataset.name)
|
||||
assert retrieved_dataset.pagerank == 50, str(retrieved_dataset)
|
||||
|
||||
dataset.update({"pagerank": 0})
|
||||
assert dataset.pagerank == 0, str(dataset)
|
||||
|
||||
retrieved_dataset = client.get_dataset(name=dataset.name)
|
||||
assert retrieved_dataset.pagerank == 0, str(retrieved_dataset)
|
||||
|
||||
@pytest.mark.skipif(os.getenv("DOC_ENGINE") != "infinity", reason="#8208")
|
||||
@pytest.mark.p2
|
||||
def test_pagerank_infinity(self, client, add_dataset_func):
|
||||
dataset = add_dataset_func
|
||||
with pytest.raises(Exception) as excinfo:
|
||||
dataset.update({"pagerank": 50})
|
||||
assert "'pagerank' can only be set when doc_engine is elasticsearch" in str(excinfo.value), str(excinfo.value)
|
||||
|
||||
@pytest.mark.p2
|
||||
@pytest.mark.parametrize(
|
||||
"pagerank, expected_message",
|
||||
|
||||
Reference in New Issue
Block a user