mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Feat: wrap search app (#8320)
### What problem does this PR solve?

Wrap search app

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
188
api/apps/search_app.py
Normal file
188
api/apps/search_app.py
Normal file
@ -0,0 +1,188 @@
|
|||||||
|
#
|
||||||
|
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
|
||||||
|
from flask import request
|
||||||
|
from flask_login import current_user, login_required
|
||||||
|
|
||||||
|
from api import settings
|
||||||
|
from api.constants import DATASET_NAME_LIMIT
|
||||||
|
from api.db import StatusEnum
|
||||||
|
from api.db.db_models import DB
|
||||||
|
from api.db.services import duplicate_name
|
||||||
|
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||||
|
from api.db.services.search_service import SearchService
|
||||||
|
from api.db.services.user_service import TenantService, UserTenantService
|
||||||
|
from api.utils import get_uuid
|
||||||
|
from api.utils.api_utils import get_data_error_result, get_json_result, not_allowed_parameters, server_error_response, validate_request
|
||||||
|
|
||||||
|
|
||||||
|
@manager.route("/create", methods=["post"])  # noqa: F821
@login_required
@validate_request("name")
def create():
    """Create a search app owned by the current user.

    Reads a JSON body with a ``name`` and an optional ``description``. The
    name must be a non-blank string whose UTF-8 encoding does not exceed
    ``DATASET_NAME_LIMIT`` bytes. The stripped name is passed through
    ``duplicate_name`` using ``SearchService.query`` as the lookup, scoped to
    the current user's tenant and valid rows, before the record is saved.

    Returns:
        A JSON result carrying ``{"search_id": <new id>}`` on success, a data
        error result when validation or saving fails, or a server error
        response when saving raises.
    """
    req = request.get_json()
    search_name = req["name"]
    description = req.get("description", "")
    if not isinstance(search_name, str):
        return get_data_error_result(message="Search name must be string.")
    if search_name.strip() == "":
        return get_data_error_result(message="Search name can't be empty.")
    name_byte_length = len(search_name.encode("utf-8"))
    if name_byte_length > DATASET_NAME_LIMIT:
        # Report the byte length that was actually compared against the limit;
        # the character count can be below the limit for multi-byte names.
        return get_data_error_result(message=f"Search name length is {name_byte_length} which is large than {DATASET_NAME_LIMIT}")
    e, _ = TenantService.get_by_id(current_user.id)
    if not e:
        return get_data_error_result(message="Authorizationd identity.")

    search_name = search_name.strip()
    # Check for name collisions among search apps (not knowledge bases).
    search_name = duplicate_name(SearchService.query, name=search_name, tenant_id=current_user.id, status=StatusEnum.VALID.value)

    req["id"] = get_uuid()
    req["name"] = search_name
    req["description"] = description
    req["tenant_id"] = current_user.id
    req["created_by"] = current_user.id
    with DB.atomic():
        try:
            if not SearchService.save(**req):
                return get_data_error_result()
            return get_json_result(data={"search_id": req["id"]})
        except Exception as e:
            return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
|
@manager.route("/update", methods=["post"])  # noqa: F821
@login_required
@validate_request("search_id", "name", "search_config", "tenant_id")
@not_allowed_parameters("id", "created_by", "create_time", "update_time", "create_date", "update_date", "created_by")
def update():
    """Update the name and configuration of an existing search app.

    Reads ``search_id``, ``name``, ``search_config`` and ``tenant_id`` from the
    JSON body. The name is validated like in creation (string, non-blank,
    UTF-8 length within ``DATASET_NAME_LIMIT``). The caller must pass
    ``SearchService.accessible4deletion`` for the target search app. The
    supplied ``search_config`` is shallow-merged over the stored one, so keys
    absent from the request keep their stored values.

    Returns:
        A JSON result with the updated record as a dict on success; a data or
        authentication error result when validation, lookup, the duplicate
        name check or the update fails; a server error response if an
        exception is raised while reading or writing the record.
    """
    req = request.get_json()
    if not isinstance(req["name"], str):
        return get_data_error_result(message="Search name must be string.")
    if req["name"].strip() == "":
        return get_data_error_result(message="Search name can't be empty.")
    name_byte_length = len(req["name"].encode("utf-8"))
    if name_byte_length > DATASET_NAME_LIMIT:
        # Report the byte length that was actually compared against the limit.
        return get_data_error_result(message=f"Search name length is {name_byte_length} which is large than {DATASET_NAME_LIMIT}")
    req["name"] = req["name"].strip()
    tenant_id = req["tenant_id"]
    e, _ = TenantService.get_by_id(tenant_id)
    if not e:
        return get_data_error_result(message="Authorizationd identity.")

    search_id = req["search_id"]
    if not SearchService.accessible4deletion(search_id, current_user.id):
        return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR)

    try:
        # Look at the result before indexing: an empty result must produce the
        # "Cannot find search" response rather than an IndexError.
        matches = SearchService.query(tenant_id=tenant_id, id=search_id)
        if not matches:
            return get_json_result(data=False, message=f"Cannot find search {search_id}", code=settings.RetCode.DATA_ERROR)
        search_app = matches[0]

        # Renaming is only rejected when the new name differs (case-insensitively)
        # from the current one and another valid search app already uses it.
        if req["name"].lower() != search_app.name.lower() and len(SearchService.query(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value)) >= 1:
            return get_data_error_result(message="Duplicated search name.")

        if "search_config" in req:
            current_config = search_app.search_config or {}
            new_config = req["search_config"]

            if not isinstance(new_config, dict):
                return get_data_error_result(message="search_config must be a JSON object")

            # Shallow merge: request values override stored values key by key.
            updated_config = {**current_config, **new_config}
            req["search_config"] = updated_config

        # These two keys identify the row and are not passed on as update data.
        req.pop("search_id", None)
        req.pop("tenant_id", None)

        updated = SearchService.update_by_id(search_id, req)
        if not updated:
            return get_data_error_result(message="Failed to update search")

        e, updated_search = SearchService.get_by_id(search_id)
        if not e:
            return get_data_error_result(message="Failed to fetch updated search")

        return get_json_result(data=updated_search.to_dict())

    except Exception as e:
        return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
|
@manager.route("/detail", methods=["GET"])  # noqa: F821
@login_required
def detail():
    """Return the detail record of one search app.

    The ``search_id`` query parameter selects the search app. Access is
    granted when the search app can be found under any tenant returned by
    ``UserTenantService.query`` for the current user.

    Returns:
        A JSON result with the detail data, an operating-error result when no
        tenant of the user holds the search app, a data error result when the
        detail lookup yields nothing, or a server error response on exception.
    """
    search_id = request.args["search_id"]
    try:
        memberships = UserTenantService.query(user_id=current_user.id)
        # Stops at the first tenant under which the search app is found.
        permitted = any(SearchService.query(tenant_id=membership.tenant_id, id=search_id) for membership in memberships)
        if not permitted:
            return get_json_result(data=False, message="Has no permission for this operation.", code=settings.RetCode.OPERATING_ERROR)

        found = SearchService.get_detail(search_id)
        if found:
            return get_json_result(data=found)
        return get_data_error_result(message="Can't find this Search App!")
    except Exception as e:
        return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
|
@manager.route("/list", methods=["POST"])  # noqa: F821
@login_required
def list_search_app():
    """List search apps visible to the current user.

    Query parameters: ``keywords`` (name filter), ``page`` and ``page_size``
    (both default to 0, in which case no pagination is applied), ``orderby``
    (default ``create_time``) and ``desc`` (anything other than ``"false"``
    means descending). The optional JSON body may carry ``owner_ids``; when
    given, results are restricted to those tenant ids and paginated in memory.

    Returns:
        A JSON result ``{"search_apps": [...], "total": <int>}``, a data error
        result when ``page`` or ``page_size`` is not an integer, or a server
        error response when the lookup raises.
    """
    keywords = request.args.get("keywords", "")
    try:
        page_number = int(request.args.get("page", 0))
        items_per_page = int(request.args.get("page_size", 0))
    except ValueError:
        # Malformed paging input is a client error, not an unhandled exception.
        return get_data_error_result(message="page and page_size must be integers.")
    orderby = request.args.get("orderby", "create_time")
    desc = request.args.get("desc", "true").lower() != "false"

    # The body is optional for this endpoint; treat a missing one as empty.
    req = request.get_json() or {}
    owner_ids = req.get("owner_ids", [])
    try:
        if not owner_ids:
            tenants = TenantService.get_joined_tenants_by_user_id(current_user.id)
            tenants = [m["tenant_id"] for m in tenants]
            search_apps, total = SearchService.get_by_tenant_ids(tenants, current_user.id, page_number, items_per_page, orderby, desc, keywords)
        else:
            # NOTE(review): owner_ids is taken from the request as-is; confirm
            # that it is restricted elsewhere to tenants the user may see.
            tenants = owner_ids
            # Fetch unpaginated (0, 0), then filter by owner and page in memory
            # so that "total" reflects the filtered set.
            search_apps, total = SearchService.get_by_tenant_ids(tenants, current_user.id, 0, 0, orderby, desc, keywords)
            search_apps = [search_app for search_app in search_apps if search_app["tenant_id"] in tenants]
            total = len(search_apps)
            if page_number and items_per_page:
                search_apps = search_apps[(page_number - 1) * items_per_page : page_number * items_per_page]
        return get_json_result(data={"search_apps": search_apps, "total": total})
    except Exception as e:
        return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
|
@manager.route("/rm", methods=["post"])  # noqa: F821
@login_required
@validate_request("search_id")
def rm():
    """Delete a search app identified by ``search_id`` in the JSON body.

    The request is rejected with an authentication error unless
    ``SearchService.accessible4deletion`` accepts the current user for that
    search app.

    Returns:
        A JSON result with ``data=True`` on success, a data error result when
        the delete reports failure, or a server error response on exception.
    """
    payload = request.get_json()
    search_id = payload["search_id"]
    allowed = SearchService.accessible4deletion(search_id, current_user.id)
    if not allowed:
        return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR)

    try:
        if SearchService.delete_by_id(search_id):
            return get_json_result(data=True)
        return get_data_error_result(message=f"Failed to delete search App {search_id}")
    except Exception as exc:
        return server_error_response(exc)
|
||||||
@ -13,16 +13,16 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
|
import hashlib
|
||||||
import inspect
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
import operator
|
import operator
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import typing
|
|
||||||
import time
|
import time
|
||||||
|
import typing
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
import hashlib
|
|
||||||
|
|
||||||
from flask_login import UserMixin
|
from flask_login import UserMixin
|
||||||
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
|
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
|
||||||
@ -272,6 +272,7 @@ def with_retry(max_retries=3, retry_delay=1.0):
|
|||||||
Returns:
|
Returns:
|
||||||
decorated function
|
decorated function
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def decorator(func):
|
def decorator(func):
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
def wrapper(*args, **kwargs):
|
def wrapper(*args, **kwargs):
|
||||||
@ -284,7 +285,7 @@ def with_retry(max_retries=3, retry_delay=1.0):
|
|||||||
# get self and method name for logging
|
# get self and method name for logging
|
||||||
self_obj = args[0] if args else None
|
self_obj = args[0] if args else None
|
||||||
func_name = func.__name__
|
func_name = func.__name__
|
||||||
lock_name = getattr(self_obj, 'lock_name', 'unknown') if self_obj else 'unknown'
|
lock_name = getattr(self_obj, "lock_name", "unknown") if self_obj else "unknown"
|
||||||
|
|
||||||
if retry < max_retries - 1:
|
if retry < max_retries - 1:
|
||||||
current_delay = retry_delay * (2**retry)
|
current_delay = retry_delay * (2**retry)
|
||||||
@ -296,7 +297,9 @@ def with_retry(max_retries=3, retry_delay=1.0):
|
|||||||
if last_exception:
|
if last_exception:
|
||||||
raise last_exception
|
raise last_exception
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
@ -796,6 +799,50 @@ class UserCanvasVersion(DataBaseModel):
|
|||||||
db_table = "user_canvas_version"
|
db_table = "user_canvas_version"
|
||||||
|
|
||||||
|
|
||||||
|
def _default_search_config():
    """Build a fresh default ``search_config`` dict for a new ``Search`` row.

    A new dict (including the nested lists and dict) is created on every call
    so that model instances never share one mutable default object.
    """
    return {
        "kb_ids": [],
        "doc_ids": [],
        "similarity_threshold": 0.0,
        "vector_similarity_weight": 0.3,
        "use_kg": False,
        # rerank settings
        "rerank_id": "",
        "top_k": 1024,
        # chat settings
        "summary": False,
        "chat_id": "",
        "llm_setting": {
            "temperature": 0.1,
            "top_p": 0.3,
            "frequency_penalty": 0.7,
            "presence_penalty": 0.4,
        },
        # NOTE(review): key name looks like two words fused together
        # ("chat_setting" + "cross_languages"); kept as-is because stored
        # rows and clients may already depend on it. Confirm intent.
        "chat_settingcross_languages": [],
        "highlight": False,
        "keyword": False,
        "web_search": False,
        "related_search": False,
        "query_mindmap": False,
    }


class Search(DataBaseModel):
    """Model for a search app, stored in the ``search`` table."""

    id = CharField(max_length=32, primary_key=True)
    avatar = TextField(null=True, help_text="avatar base64 string")
    tenant_id = CharField(max_length=32, null=False, index=True)
    name = CharField(max_length=128, null=False, help_text="Search name", index=True)
    description = TextField(null=True, help_text="KB description")
    created_by = CharField(max_length=32, null=False, index=True)
    # A callable default gives every instance its own dict instead of a
    # reference to one shared module-level literal.
    search_config = JSONField(null=False, default=_default_search_config)
    status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True)

    def __str__(self):
        return self.name

    class Meta:
        db_table = "search"
|
||||||
|
|
||||||
|
|
||||||
def migrate_db():
|
def migrate_db():
|
||||||
migrator = DatabaseMigrator[settings.DATABASE_TYPE.upper()].value(DB)
|
migrator = DatabaseMigrator[settings.DATABASE_TYPE.upper()].value(DB)
|
||||||
try:
|
try:
|
||||||
|
|||||||
@ -159,6 +159,7 @@ BAD_CITATION_PATTERNS = [
|
|||||||
re.compile(r"ref\s*(\d+)", flags=re.IGNORECASE), # ref12、REF 12
|
re.compile(r"ref\s*(\d+)", flags=re.IGNORECASE), # ref12、REF 12
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def repair_bad_citation_formats(answer: str, kbinfos: dict, idx: set):
|
def repair_bad_citation_formats(answer: str, kbinfos: dict, idx: set):
|
||||||
max_index = len(kbinfos["chunks"])
|
max_index = len(kbinfos["chunks"])
|
||||||
|
|
||||||
@ -555,7 +556,7 @@ def tts(tts_mdl, text):
|
|||||||
return binascii.hexlify(bin).decode("utf-8")
|
return binascii.hexlify(bin).decode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
def ask(question, kb_ids, tenant_id):
|
def ask(question, kb_ids, tenant_id, chat_llm_name=None):
|
||||||
kbs = KnowledgebaseService.get_by_ids(kb_ids)
|
kbs = KnowledgebaseService.get_by_ids(kb_ids)
|
||||||
embedding_list = list(set([kb.embd_id for kb in kbs]))
|
embedding_list = list(set([kb.embd_id for kb in kbs]))
|
||||||
|
|
||||||
@ -563,7 +564,7 @@ def ask(question, kb_ids, tenant_id):
|
|||||||
retriever = settings.retrievaler if not is_knowledge_graph else settings.kg_retrievaler
|
retriever = settings.retrievaler if not is_knowledge_graph else settings.kg_retrievaler
|
||||||
|
|
||||||
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embedding_list[0])
|
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embedding_list[0])
|
||||||
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT)
|
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, chat_llm_name)
|
||||||
max_tokens = chat_mdl.max_length
|
max_tokens = chat_mdl.max_length
|
||||||
tenant_ids = list(set([kb.tenant_id for kb in kbs]))
|
tenant_ids = list(set([kb.tenant_id for kb in kbs]))
|
||||||
kbinfos = retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids, 1, 12, 0.1, 0.3, aggs=False, rank_feature=label_question(question, kbs))
|
kbinfos = retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids, 1, 12, 0.1, 0.3, aggs=False, rank_feature=label_question(question, kbs))
|
||||||
|
|||||||
110
api/db/services/search_service.py
Normal file
110
api/db/services/search_service.py
Normal file
@ -0,0 +1,110 @@
|
|||||||
|
#
|
||||||
|
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from peewee import fn
|
||||||
|
|
||||||
|
from api.db import StatusEnum
|
||||||
|
from api.db.db_models import DB, Search, User
|
||||||
|
from api.db.services.common_service import CommonService
|
||||||
|
from api.utils import current_timestamp, datetime_format
|
||||||
|
|
||||||
|
|
||||||
|
class SearchService(CommonService):
    """Data-access helpers for ``Search`` rows."""

    model = Search

    @classmethod
    def save(cls, **kwargs):
        """Create a ``Search`` row from ``kwargs`` and return the model object.

        The ``create_time``, ``create_date``, ``update_time`` and
        ``update_date`` entries of ``kwargs`` are overwritten with the current
        time before the row is created.
        """
        # Take each clock reading once so the create_* and update_* values of
        # a new row are identical.
        now_timestamp = current_timestamp()
        now_date = datetime_format(datetime.now())
        kwargs["create_time"] = now_timestamp
        kwargs["create_date"] = now_date
        kwargs["update_time"] = now_timestamp
        kwargs["update_date"] = now_date
        obj = cls.model.create(**kwargs)
        return obj

    @classmethod
    @DB.connection_context()
    def accessible4deletion(cls, search_id, user_id) -> bool:
        """Return True when a valid search app ``search_id`` was created by ``user_id``."""
        search = (
            cls.model.select(cls.model.id)
            .where(
                cls.model.id == search_id,
                cls.model.created_by == user_id,
                cls.model.status == StatusEnum.VALID.value,
            )
            .first()
        )
        return search is not None

    @classmethod
    @DB.connection_context()
    def get_detail(cls, search_id):
        """Return the detail of a valid search app as a dict, or None.

        The search app is joined with the valid ``User`` row whose id equals
        the search app's ``tenant_id``. ``None`` is returned when no such row
        exists, which lets callers test the result for truthiness.
        """
        fields = [
            cls.model.id,
            cls.model.avatar,
            cls.model.tenant_id,
            cls.model.name,
            cls.model.description,
            cls.model.created_by,
            cls.model.search_config,
            cls.model.update_time,
            User.nickname,
            User.avatar.alias("tenant_avatar"),
        ]
        search = (
            cls.model.select(*fields)
            .join(User, on=((User.id == cls.model.tenant_id) & (User.status == StatusEnum.VALID.value)))
            .where((cls.model.id == search_id) & (cls.model.status == StatusEnum.VALID.value))
            .first()
        )
        # first() yields None when nothing matches; do not call to_dict() on it.
        if search is None:
            return None
        return search.to_dict()

    @classmethod
    @DB.connection_context()
    def get_by_tenant_ids(cls, joined_tenant_ids, user_id, page_number, items_per_page, orderby, desc, keywords):
        """List valid search apps of the given tenants or of ``user_id``.

        Args:
            joined_tenant_ids: tenant ids whose search apps are included.
            user_id: search apps whose ``tenant_id`` equals this id are
                included as well.
            page_number: page to return; pagination is skipped when this or
                ``items_per_page`` is falsy.
            items_per_page: page size.
            orderby: model attribute name used for ordering.
            desc: order descending when truthy, ascending otherwise.
            keywords: when non-empty, keep only rows whose lower-cased name
                contains the lower-cased keywords.

        Returns:
            A tuple ``(rows, count)`` where ``rows`` is a list of dicts for
            the requested page and ``count`` is the number of matching rows
            before pagination.
        """
        fields = [
            cls.model.id,
            cls.model.avatar,
            cls.model.tenant_id,
            cls.model.name,
            cls.model.description,
            cls.model.created_by,
            cls.model.status,
            cls.model.update_time,
            cls.model.create_time,
            User.nickname,
            User.avatar.alias("tenant_avatar"),
        ]
        query = (
            cls.model.select(*fields)
            .join(User, on=(cls.model.tenant_id == User.id))
            .where(((cls.model.tenant_id.in_(joined_tenant_ids)) | (cls.model.tenant_id == user_id)) & (cls.model.status == StatusEnum.VALID.value))
        )

        if keywords:
            query = query.where(fn.LOWER(cls.model.name).contains(keywords.lower()))
        if desc:
            query = query.order_by(cls.model.getter_by(orderby).desc())
        else:
            query = query.order_by(cls.model.getter_by(orderby).asc())

        # Count before paginating so the total covers every matching row.
        count = query.count()

        if page_number and items_per_page:
            query = query.paginate(page_number, items_per_page)

        return list(query.dicts()), count
|
||||||
Reference in New Issue
Block a user