Mirror of https://github.com/infiniflow/ragflow.git (synced 2025-12-08 12:32:30 +08:00)
# Fix: Move pagerank field from create to update dataset API (#8217)
### What problem does this PR solve?

- Remove `pagerank` from `CreateDatasetReq` and add it to `UpdateDatasetReq`.
- Add `pagerank` update logic to the dataset update endpoint.
- Update the API documentation to reflect the changes.
- Modify the related test cases and SDK references.

#8208

This change makes `pagerank` a mutable property that can only be set after dataset creation, and only when Elasticsearch is used as the doc engine.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
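The net effect for API consumers, sketched below with the Python SDK. This is a minimal sketch of the new two-step flow; the server URL and API key are placeholders, and the `RAGFlow`/`create_dataset`/`update` calls follow the SDK reference quoted in the diff further down:

```python
# Hedged sketch of the new flow: pagerank is no longer accepted at
# creation time and must be set afterwards via update.
from ragflow_sdk import RAGFlow

client = RAGFlow(api_key="<YOUR_API_KEY>", base_url="http://localhost:9380")

# Create the dataset without pagerank; the server-side default is 0.
dataset = client.create_dataset(name="kb_demo")
assert dataset.pagerank == 0

# Set pagerank after creation (Elasticsearch doc engine only).
dataset.update({"pagerank": 50})
assert dataset.pagerank == 50
```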
**Dataset API endpoint**

````diff
@@ -16,10 +16,12 @@
 import logging
+import os
 
 from flask import request
 from peewee import OperationalError
 
+from api import settings
 from api.db import FileSource, StatusEnum
 from api.db.db_models import File
 from api.db.services.document_service import DocumentService
@@ -48,6 +50,8 @@ from api.utils.validation_utils import (
     validate_and_parse_json_request,
     validate_and_parse_request_args,
 )
+from rag.nlp import search
+from rag.settings import PAGERANK_FLD
 
 
 @manager.route("/datasets", methods=["POST"])  # noqa: F821
@@ -97,9 +101,6 @@ def create(tenant_id):
                     "picture", "presentation", "qa", "table", "tag"
                 ]
                 description: Chunking method.
-            pagerank:
-              type: integer
-              description: Set page rank.
             parser_config:
               type: object
               description: Parser configuration.
@@ -352,6 +353,16 @@ def update(tenant_id, dataset_id):
     if not ok:
         return err
 
+    if "pagerank" in req and req["pagerank"] != kb.pagerank:
+        if os.environ.get("DOC_ENGINE", "elasticsearch") == "infinity":
+            return get_error_argument_result(message="'pagerank' can only be set when doc_engine is elasticsearch")
+
+        if req["pagerank"] > 0:
+            settings.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]}, search.index_name(kb.tenant_id), kb.id)
+        else:
+            # Elasticsearch requires PAGERANK_FLD be non-zero!
+            settings.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD}, search.index_name(kb.tenant_id), kb.id)
+
     if not KnowledgebaseService.update_by_id(kb.id, req):
         return get_error_data_result(message="Update dataset error.(Database error)")
````
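In the new `update` branch, a positive `pagerank` writes `PAGERANK_FLD` onto all chunks in the dataset's index, while `0` removes the field outright, since Elasticsearch requires the field, when present, to be non-zero. A hedged sketch of the equivalent raw HTTP call, assuming the documented `PUT /api/v1/datasets/{dataset_id}` endpoint and Bearer-token auth (the URL, key, and dataset ID are placeholders):

```python
# Hedged sketch: update a dataset's pagerank over the HTTP API.
import requests

BASE_URL = "http://localhost:9380"
API_KEY = "<YOUR_API_KEY>"
DATASET_ID = "<DATASET_ID>"

resp = requests.put(
    f"{BASE_URL}/api/v1/datasets/{DATASET_ID}",
    headers={"Authorization": f"Bearer {API_KEY}"},
    json={"pagerank": 50},
)
# Expect {"code": 0, ...} on success; with DOC_ENGINE=infinity the server
# returns code 101 with "'pagerank' can only be set when doc_engine is
# elasticsearch", as asserted in the tests below.
print(resp.json())
```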
**Request validation models**

````diff
@@ -383,7 +383,6 @@ class CreateDatasetReq(Base):
     embedding_model: Annotated[str, StringConstraints(strip_whitespace=True, max_length=255), Field(default="", serialization_alias="embd_id")]
     permission: PermissionEnum = Field(default=PermissionEnum.me, min_length=1, max_length=16)
     chunk_method: ChunkMethodnEnum = Field(default=ChunkMethodnEnum.naive, min_length=1, max_length=32, serialization_alias="parser_id")
-    pagerank: int = Field(default=0, ge=0, le=100)
     parser_config: ParserConfig | None = Field(default=None)
 
     @field_validator("avatar")
@@ -539,6 +538,7 @@ class CreateDatasetReq(Base):
 class UpdateDatasetReq(CreateDatasetReq):
     dataset_id: str = Field(...)
     name: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1, max_length=DATASET_NAME_LIMIT), Field(default="")]
+    pagerank: int = Field(default=0, ge=0, le=100)
 
     @field_validator("dataset_id", mode="before")
     @classmethod
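The `ge`/`le` bounds on the moved field are what produce the limit-violation messages asserted in the tests further down. A minimal standalone Pydantic v2 sketch, using a hypothetical stand-in model rather than RAGFlow's own class:

```python
# Standalone sketch of how Field(ge=0, le=100) behaves under Pydantic v2.
from pydantic import BaseModel, Field, ValidationError

class UpdateReq(BaseModel):  # hypothetical stand-in for UpdateDatasetReq
    pagerank: int = Field(default=0, ge=0, le=100)

UpdateReq(pagerank=50)  # within bounds: OK

try:
    UpdateReq(pagerank=101)
except ValidationError as e:
    # The message contains "Input should be less than or equal to 100",
    # matching what the test cases assert on.
    print(e)
```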
**HTTP API reference**

````diff
@@ -343,7 +343,6 @@ Creates a dataset.
 - `"embedding_model"`: `string`
 - `"permission"`: `string`
 - `"chunk_method"`: `string`
-- `"pagerank"`: `int`
 - `"parser_config"`: `object`
 
 ##### Request example
@@ -384,12 +383,6 @@ curl --request POST \
   - `"me"`: (Default) Only you can manage the dataset.
   - `"team"`: All team members can manage the dataset.
-
-- `"pagerank"`: (*Body parameter*), `int`
-  refer to [Set page rank](https://ragflow.io/docs/dev/set_page_rank)
-  - Default: `0`
-  - Minimum: `0`
-  - Maximum: `100`
 
 - `"chunk_method"`: (*Body parameter*), `enum<string>`
   The chunking method of the dataset to create. Available options:
   - `"naive"`: General (default)
````
**Python API reference**

````diff
@@ -100,7 +100,6 @@ RAGFlow.create_dataset(
     embedding_model: Optional[str] = "BAAI/bge-large-zh-v1.5@BAAI",
     permission: str = "me",
     chunk_method: str = "naive",
-    pagerank: int = 0,
     parser_config: DataSet.ParserConfig = None
 ) -> DataSet
 ```
@@ -148,10 +147,6 @@ The chunking method of the dataset to create. Available options:
 - `"one"`: One
 - `"email"`: Email
 
-##### pagerank, `int`
-
-The pagerank of the dataset to create. Defaults to `0`.
-
 ##### parser_config
 
 The parser configuration of the dataset. A `ParserConfig` object's attributes vary based on the selected `chunk_method`:
````
**Python SDK client**

````diff
@@ -56,7 +56,6 @@ class RAGFlow:
         embedding_model: Optional[str] = "BAAI/bge-large-zh-v1.5@BAAI",
         permission: str = "me",
         chunk_method: str = "naive",
-        pagerank: int = 0,
         parser_config: Optional[DataSet.ParserConfig] = None,
     ) -> DataSet:
         payload = {
@@ -66,7 +65,6 @@ class RAGFlow:
             "embedding_model": embedding_model,
             "permission": permission,
             "chunk_method": chunk_method,
-            "pagerank": pagerank,
         }
         if parser_config is not None:
             payload["parser_config"] = parser_config.to_json()
````
**HTTP API tests: dataset creation**

````diff
@@ -394,51 +394,6 @@ class TestDatasetCreate:
         assert res["code"] == 101, res
         assert "Input should be 'naive', 'book', 'email', 'laws', 'manual', 'one', 'paper', 'picture', 'presentation', 'qa', 'table' or 'tag'" in res["message"], res
 
-    @pytest.mark.p2
-    @pytest.mark.parametrize(
-        "name, pagerank",
-        [
-            ("pagerank_min", 0),
-            ("pagerank_mid", 50),
-            ("pagerank_max", 100),
-        ],
-        ids=["min", "mid", "max"],
-    )
-    def test_pagerank(self, HttpApiAuth, name, pagerank):
-        payload = {"name": name, "pagerank": pagerank}
-        res = create_dataset(HttpApiAuth, payload)
-        assert res["code"] == 0, res
-        assert res["data"]["pagerank"] == pagerank, res
-
-    @pytest.mark.p3
-    @pytest.mark.parametrize(
-        "name, pagerank, expected_message",
-        [
-            ("pagerank_min_limit", -1, "Input should be greater than or equal to 0"),
-            ("pagerank_max_limit", 101, "Input should be less than or equal to 100"),
-        ],
-        ids=["min_limit", "max_limit"],
-    )
-    def test_pagerank_invalid(self, HttpApiAuth, name, pagerank, expected_message):
-        payload = {"name": name, "pagerank": pagerank}
-        res = create_dataset(HttpApiAuth, payload)
-        assert res["code"] == 101, res
-        assert expected_message in res["message"], res
-
-    @pytest.mark.p3
-    def test_pagerank_unset(self, HttpApiAuth):
-        payload = {"name": "pagerank_unset"}
-        res = create_dataset(HttpApiAuth, payload)
-        assert res["code"] == 0, res
-        assert res["data"]["pagerank"] == 0, res
-
-    @pytest.mark.p3
-    def test_pagerank_none(self, HttpApiAuth):
-        payload = {"name": "pagerank_unset", "pagerank": None}
-        res = create_dataset(HttpApiAuth, payload)
-        assert res["code"] == 101, res
-        assert "Input should be a valid integer" in res["message"], res
-
     @pytest.mark.p1
     @pytest.mark.parametrize(
         "name, parser_config",
@@ -730,6 +685,7 @@ class TestDatasetCreate:
             {"name": "chunk_count", "chunk_count": 1},
             {"name": "token_num", "token_num": 1},
             {"name": "status", "status": "1"},
+            {"name": "pagerank", "pagerank": 50},
             {"name": "unknown_field", "unknown_field": "unknown_field"},
         ],
     )
````
**HTTP API tests: dataset update**

````diff
@@ -13,11 +13,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import os
 import uuid
 from concurrent.futures import ThreadPoolExecutor, as_completed
 
 import pytest
-from common import DATASET_NAME_LIMIT, INVALID_API_TOKEN, list_datasets, update_dataset
+from common import list_datasets, update_dataset
+from configs import DATASET_NAME_LIMIT, INVALID_API_TOKEN
 from hypothesis import HealthCheck, example, given, settings
 from libs.auth import RAGFlowHttpApiAuth
 from utils import encode_avatar
@@ -155,10 +157,10 @@ class TestDatasetUpdate:
 
     @pytest.mark.p3
     def test_name_duplicated(self, HttpApiAuth, add_datasets_func):
-        dataset_ids = add_datasets_func[0]
+        dataset_id = add_datasets_func[0]
         name = "dataset_1"
         payload = {"name": name}
-        res = update_dataset(HttpApiAuth, dataset_ids, payload)
+        res = update_dataset(HttpApiAuth, dataset_id, payload)
         assert res["code"] == 102, res
         assert res["message"] == f"Dataset name '{name}' already exists", res
 
@@ -425,6 +427,7 @@ class TestDatasetUpdate:
         assert res["code"] == 101, res
         assert "Input should be 'naive', 'book', 'email', 'laws', 'manual', 'one', 'paper', 'picture', 'presentation', 'qa', 'table' or 'tag'" in res["message"], res
 
+    @pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="#8208")
     @pytest.mark.p2
     @pytest.mark.parametrize("pagerank", [0, 50, 100], ids=["min", "mid", "max"])
     def test_pagerank(self, HttpApiAuth, add_dataset_func, pagerank):
@@ -437,6 +440,35 @@ class TestDatasetUpdate:
         assert res["code"] == 0, res
         assert res["data"][0]["pagerank"] == pagerank
 
+    @pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="#8208")
+    @pytest.mark.p2
+    def test_pagerank_set_to_0(self, HttpApiAuth, add_dataset_func):
+        dataset_id = add_dataset_func
+        payload = {"pagerank": 50}
+        res = update_dataset(HttpApiAuth, dataset_id, payload)
+        assert res["code"] == 0, res
+
+        res = list_datasets(HttpApiAuth, {"id": dataset_id})
+        assert res["code"] == 0, res
+        assert res["data"][0]["pagerank"] == 50, res
+
+        payload = {"pagerank": 0}
+        res = update_dataset(HttpApiAuth, dataset_id, payload)
+        assert res["code"] == 0
+
+        res = list_datasets(HttpApiAuth, {"id": dataset_id})
+        assert res["code"] == 0, res
+        assert res["data"][0]["pagerank"] == 0, res
+
+    @pytest.mark.skipif(os.getenv("DOC_ENGINE") != "infinity", reason="#8208")
+    @pytest.mark.p2
+    def test_pagerank_infinity(self, HttpApiAuth, add_dataset_func):
+        dataset_id = add_dataset_func
+        payload = {"pagerank": 50}
+        res = update_dataset(HttpApiAuth, dataset_id, payload)
+        assert res["code"] == 101, res
+        assert res["message"] == "'pagerank' can only be set when doc_engine is elasticsearch", res
+
     @pytest.mark.p2
     @pytest.mark.parametrize(
         "pagerank, expected_message",
````
**Python SDK tests: dataset creation**

````diff
@@ -344,49 +344,6 @@ class TestDatasetCreate:
             client.create_dataset(**payload)
         assert "not instance of" in str(excinfo.value), str(excinfo.value)
 
-    @pytest.mark.p2
-    @pytest.mark.parametrize(
-        "name, pagerank",
-        [
-            ("pagerank_min", 0),
-            ("pagerank_mid", 50),
-            ("pagerank_max", 100),
-        ],
-        ids=["min", "mid", "max"],
-    )
-    def test_pagerank(self, client, name, pagerank):
-        payload = {"name": name, "pagerank": pagerank}
-        dataset = client.create_dataset(**payload)
-        assert dataset.pagerank == pagerank, str(dataset)
-
-    @pytest.mark.p3
-    @pytest.mark.parametrize(
-        "name, pagerank, expected_message",
-        [
-            ("pagerank_min_limit", -1, "Input should be greater than or equal to 0"),
-            ("pagerank_max_limit", 101, "Input should be less than or equal to 100"),
-        ],
-        ids=["min_limit", "max_limit"],
-    )
-    def test_pagerank_invalid(self, client, name, pagerank, expected_message):
-        payload = {"name": name, "pagerank": pagerank}
-        with pytest.raises(Exception) as excinfo:
-            client.create_dataset(**payload)
-        assert expected_message in str(excinfo.value), str(excinfo.value)
-
-    @pytest.mark.p3
-    def test_pagerank_unset(self, client):
-        payload = {"name": "pagerank_unset"}
-        dataset = client.create_dataset(**payload)
-        assert dataset.pagerank == 0, str(dataset)
-
-    @pytest.mark.p3
-    def test_pagerank_none(self, client):
-        payload = {"name": "pagerank_unset", "pagerank": None}
-        with pytest.raises(Exception) as excinfo:
-            client.create_dataset(**payload)
-        assert "not instance of" in str(excinfo.value), str(excinfo.value)
-
     @pytest.mark.p1
     @pytest.mark.parametrize(
         "name, parser_config",
@@ -689,6 +646,7 @@ class TestDatasetCreate:
             {"name": "chunk_count", "chunk_count": 1},
             {"name": "token_num", "token_num": 1},
             {"name": "status", "status": "1"},
+            {"name": "pagerank", "pagerank": 50},
             {"name": "unknown_field", "unknown_field": "unknown_field"},
         ],
     )
````
**Python SDK tests: dataset update**

````diff
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import os
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from operator import attrgetter
 
@@ -324,6 +325,7 @@ class TestDatasetUpdate:
             dataset.update({"chunk_method": None})
         assert "Input should be 'naive', 'book', 'email', 'laws', 'manual', 'one', 'paper', 'picture', 'presentation', 'qa', 'table' or 'tag'" in str(excinfo.value), str(excinfo.value)
 
+    @pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="#8208")
     @pytest.mark.p2
     @pytest.mark.parametrize("pagerank", [0, 50, 100], ids=["min", "mid", "max"])
     def test_pagerank(self, client, add_dataset_func, pagerank):
@@ -334,6 +336,30 @@ class TestDatasetUpdate:
         retrieved_dataset = client.get_dataset(name=dataset.name)
         assert retrieved_dataset.pagerank == pagerank, str(retrieved_dataset)
 
+    @pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="#8208")
+    @pytest.mark.p2
+    def test_pagerank_set_to_0(self, client, add_dataset_func):
+        dataset = add_dataset_func
+        dataset.update({"pagerank": 50})
+        assert dataset.pagerank == 50, str(dataset)
+
+        retrieved_dataset = client.get_dataset(name=dataset.name)
+        assert retrieved_dataset.pagerank == 50, str(retrieved_dataset)
+
+        dataset.update({"pagerank": 0})
+        assert dataset.pagerank == 0, str(dataset)
+
+        retrieved_dataset = client.get_dataset(name=dataset.name)
+        assert retrieved_dataset.pagerank == 0, str(retrieved_dataset)
+
+    @pytest.mark.skipif(os.getenv("DOC_ENGINE") != "infinity", reason="#8208")
+    @pytest.mark.p2
+    def test_pagerank_infinity(self, client, add_dataset_func):
+        dataset = add_dataset_func
+        with pytest.raises(Exception) as excinfo:
+            dataset.update({"pagerank": 50})
+        assert "'pagerank' can only be set when doc_engine is elasticsearch" in str(excinfo.value), str(excinfo.value)
+
     @pytest.mark.p2
     @pytest.mark.parametrize(
         "pagerank, expected_message",
````