Feat: wrap search app (#8320)

### What problem does this PR solve?

Wrap search app

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
Yongteng Lei
2025-06-18 16:45:42 +08:00
committed by GitHub
parent 311e20599f
commit 1b022116d5
4 changed files with 359 additions and 13 deletions

188
api/apps/search_app.py Normal file
View File

@ -0,0 +1,188 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from flask import request
from flask_login import current_user, login_required
from api import settings
from api.constants import DATASET_NAME_LIMIT
from api.db import StatusEnum
from api.db.db_models import DB
from api.db.services import duplicate_name
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.search_service import SearchService
from api.db.services.user_service import TenantService, UserTenantService
from api.utils import get_uuid
from api.utils.api_utils import get_data_error_result, get_json_result, not_allowed_parameters, server_error_response, validate_request
@manager.route("/create", methods=["post"]) # noqa: F821
@login_required
@validate_request("name")
def create():
req = request.get_json()
search_name = req["name"]
description = req.get("description", "")
if not isinstance(search_name, str):
return get_data_error_result(message="Search name must be string.")
if search_name.strip() == "":
return get_data_error_result(message="Search name can't be empty.")
if len(search_name.encode("utf-8")) > DATASET_NAME_LIMIT:
return get_data_error_result(message=f"Search name length is {len(search_name)} which is large than {DATASET_NAME_LIMIT}")
e, _ = TenantService.get_by_id(current_user.id)
if not e:
return get_data_error_result(message="Authorizationd identity.")
search_name = search_name.strip()
search_name = duplicate_name(KnowledgebaseService.query, name=search_name, tenant_id=current_user.id, status=StatusEnum.VALID.value)
req["id"] = get_uuid()
req["name"] = search_name
req["description"] = description
req["tenant_id"] = current_user.id
req["created_by"] = current_user.id
with DB.atomic():
try:
if not SearchService.save(**req):
return get_data_error_result()
return get_json_result(data={"search_id": req["id"]})
except Exception as e:
return server_error_response(e)
@manager.route("/update", methods=["post"]) # noqa: F821
@login_required
@validate_request("search_id", "name", "search_config", "tenant_id")
@not_allowed_parameters("id", "created_by", "create_time", "update_time", "create_date", "update_date", "created_by")
def update():
req = request.get_json()
if not isinstance(req["name"], str):
return get_data_error_result(message="Search name must be string.")
if req["name"].strip() == "":
return get_data_error_result(message="Search name can't be empty.")
if len(req["name"].encode("utf-8")) > DATASET_NAME_LIMIT:
return get_data_error_result(message=f"Search name length is {len(req['name'])} which is large than {DATASET_NAME_LIMIT}")
req["name"] = req["name"].strip()
tenant_id = req["tenant_id"]
e, _ = TenantService.get_by_id(tenant_id)
if not e:
return get_data_error_result(message="Authorizationd identity.")
search_id = req["search_id"]
if not SearchService.accessible4deletion(search_id, current_user.id):
return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR)
try:
search_app = SearchService.query(tenant_id=tenant_id, id=search_id)[0]
if not search_app:
return get_json_result(data=False, message=f"Cannot find search {search_id}", code=settings.RetCode.DATA_ERROR)
if req["name"].lower() != search_app.name.lower() and len(SearchService.query(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value)) >= 1:
return get_data_error_result(message="Duplicated search name.")
if "search_config" in req:
current_config = search_app.search_config or {}
new_config = req["search_config"]
if not isinstance(new_config, dict):
return get_data_error_result(message="search_config must be a JSON object")
updated_config = {**current_config, **new_config}
req["search_config"] = updated_config
req.pop("search_id", None)
req.pop("tenant_id", None)
updated = SearchService.update_by_id(search_id, req)
if not updated:
return get_data_error_result(message="Failed to update search")
e, updated_search = SearchService.get_by_id(search_id)
if not e:
return get_data_error_result(message="Failed to fetch updated search")
return get_json_result(data=updated_search.to_dict())
except Exception as e:
return server_error_response(e)
@manager.route("/detail", methods=["GET"]) # noqa: F821
@login_required
def detail():
search_id = request.args["search_id"]
try:
tenants = UserTenantService.query(user_id=current_user.id)
for tenant in tenants:
if SearchService.query(tenant_id=tenant.tenant_id, id=search_id):
break
else:
return get_json_result(data=False, message="Has no permission for this operation.", code=settings.RetCode.OPERATING_ERROR)
search = SearchService.get_detail(search_id)
if not search:
return get_data_error_result(message="Can't find this Search App!")
return get_json_result(data=search)
except Exception as e:
return server_error_response(e)
@manager.route("/list", methods=["POST"]) # noqa: F821
@login_required
def list_search_app():
keywords = request.args.get("keywords", "")
page_number = int(request.args.get("page", 0))
items_per_page = int(request.args.get("page_size", 0))
orderby = request.args.get("orderby", "create_time")
if request.args.get("desc", "true").lower() == "false":
desc = False
else:
desc = True
req = request.get_json()
owner_ids = req.get("owner_ids", [])
try:
if not owner_ids:
tenants = TenantService.get_joined_tenants_by_user_id(current_user.id)
tenants = [m["tenant_id"] for m in tenants]
search_apps, total = SearchService.get_by_tenant_ids(tenants, current_user.id, page_number, items_per_page, orderby, desc, keywords)
else:
tenants = owner_ids
search_apps, total = SearchService.get_by_tenant_ids(tenants, current_user.id, 0, 0, orderby, desc, keywords)
search_apps = [search_app for search_app in search_apps if search_app["tenant_id"] in tenants]
total = len(search_apps)
if page_number and items_per_page:
search_apps = search_apps[(page_number - 1) * items_per_page : page_number * items_per_page]
return get_json_result(data={"search_apps": search_apps, "total": total})
except Exception as e:
return server_error_response(e)
@manager.route("/rm", methods=["post"]) # noqa: F821
@login_required
@validate_request("search_id")
def rm():
req = request.get_json()
search_id = req["search_id"]
if not SearchService.accessible4deletion(search_id, current_user.id):
return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR)
try:
if not SearchService.delete_by_id(search_id):
return get_data_error_result(message=f"Failed to delete search App {search_id}")
return get_json_result(data=True)
except Exception as e:
return server_error_response(e)

View File

@ -13,16 +13,16 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
import hashlib
import inspect import inspect
import logging import logging
import operator import operator
import os import os
import sys import sys
import typing
import time import time
import typing
from enum import Enum from enum import Enum
from functools import wraps from functools import wraps
import hashlib
from flask_login import UserMixin from flask_login import UserMixin
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
@ -272,6 +272,7 @@ def with_retry(max_retries=3, retry_delay=1.0):
Returns: Returns:
decorated function decorated function
""" """
def decorator(func): def decorator(func):
@wraps(func) @wraps(func)
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
@ -284,7 +285,7 @@ def with_retry(max_retries=3, retry_delay=1.0):
# get self and method name for logging # get self and method name for logging
self_obj = args[0] if args else None self_obj = args[0] if args else None
func_name = func.__name__ func_name = func.__name__
lock_name = getattr(self_obj, 'lock_name', 'unknown') if self_obj else 'unknown' lock_name = getattr(self_obj, "lock_name", "unknown") if self_obj else "unknown"
if retry < max_retries - 1: if retry < max_retries - 1:
current_delay = retry_delay * (2**retry) current_delay = retry_delay * (2**retry)
@ -296,7 +297,9 @@ def with_retry(max_retries=3, retry_delay=1.0):
if last_exception: if last_exception:
raise last_exception raise last_exception
return False return False
return wrapper return wrapper
return decorator return decorator
@ -796,6 +799,50 @@ class UserCanvasVersion(DataBaseModel):
db_table = "user_canvas_version" db_table = "user_canvas_version"
class Search(DataBaseModel):
id = CharField(max_length=32, primary_key=True)
avatar = TextField(null=True, help_text="avatar base64 string")
tenant_id = CharField(max_length=32, null=False, index=True)
name = CharField(max_length=128, null=False, help_text="Search name", index=True)
description = TextField(null=True, help_text="KB description")
created_by = CharField(max_length=32, null=False, index=True)
search_config = JSONField(
null=False,
default={
"kb_ids": [],
"doc_ids": [],
"similarity_threshold": 0.0,
"vector_similarity_weight": 0.3,
"use_kg": False,
# rerank settings
"rerank_id": "",
"top_k": 1024,
# chat settings
"summary": False,
"chat_id": "",
"llm_setting": {
"temperature": 0.1,
"top_p": 0.3,
"frequency_penalty": 0.7,
"presence_penalty": 0.4,
},
"chat_settingcross_languages": [],
"highlight": False,
"keyword": False,
"web_search": False,
"related_search": False,
"query_mindmap": False,
},
)
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True)
def __str__(self):
return self.name
class Meta:
db_table = "search"
def migrate_db(): def migrate_db():
migrator = DatabaseMigrator[settings.DATABASE_TYPE.upper()].value(DB) migrator = DatabaseMigrator[settings.DATABASE_TYPE.upper()].value(DB)
try: try:

View File

@ -159,6 +159,7 @@ BAD_CITATION_PATTERNS = [
re.compile(r"ref\s*(\d+)", flags=re.IGNORECASE), # ref12、REF 12 re.compile(r"ref\s*(\d+)", flags=re.IGNORECASE), # ref12、REF 12
] ]
def repair_bad_citation_formats(answer: str, kbinfos: dict, idx: set): def repair_bad_citation_formats(answer: str, kbinfos: dict, idx: set):
max_index = len(kbinfos["chunks"]) max_index = len(kbinfos["chunks"])
@ -555,7 +556,7 @@ def tts(tts_mdl, text):
return binascii.hexlify(bin).decode("utf-8") return binascii.hexlify(bin).decode("utf-8")
def ask(question, kb_ids, tenant_id): def ask(question, kb_ids, tenant_id, chat_llm_name=None):
kbs = KnowledgebaseService.get_by_ids(kb_ids) kbs = KnowledgebaseService.get_by_ids(kb_ids)
embedding_list = list(set([kb.embd_id for kb in kbs])) embedding_list = list(set([kb.embd_id for kb in kbs]))
@ -563,7 +564,7 @@ def ask(question, kb_ids, tenant_id):
retriever = settings.retrievaler if not is_knowledge_graph else settings.kg_retrievaler retriever = settings.retrievaler if not is_knowledge_graph else settings.kg_retrievaler
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embedding_list[0]) embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embedding_list[0])
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT) chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, chat_llm_name)
max_tokens = chat_mdl.max_length max_tokens = chat_mdl.max_length
tenant_ids = list(set([kb.tenant_id for kb in kbs])) tenant_ids = list(set([kb.tenant_id for kb in kbs]))
kbinfos = retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids, 1, 12, 0.1, 0.3, aggs=False, rank_feature=label_question(question, kbs)) kbinfos = retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids, 1, 12, 0.1, 0.3, aggs=False, rank_feature=label_question(question, kbs))

View File

@ -0,0 +1,110 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from datetime import datetime
from peewee import fn
from api.db import StatusEnum
from api.db.db_models import DB, Search, User
from api.db.services.common_service import CommonService
from api.utils import current_timestamp, datetime_format
class SearchService(CommonService):
model = Search
@classmethod
def save(cls, **kwargs):
kwargs["create_time"] = current_timestamp()
kwargs["create_date"] = datetime_format(datetime.now())
kwargs["update_time"] = current_timestamp()
kwargs["update_date"] = datetime_format(datetime.now())
obj = cls.model.create(**kwargs)
return obj
@classmethod
@DB.connection_context()
def accessible4deletion(cls, search_id, user_id) -> bool:
search = (
cls.model.select(cls.model.id)
.where(
cls.model.id == search_id,
cls.model.created_by == user_id,
cls.model.status == StatusEnum.VALID.value,
)
.first()
)
return search is not None
@classmethod
@DB.connection_context()
def get_detail(cls, search_id):
fields = [
cls.model.id,
cls.model.avatar,
cls.model.tenant_id,
cls.model.name,
cls.model.description,
cls.model.created_by,
cls.model.search_config,
cls.model.update_time,
User.nickname,
User.avatar.alias("tenant_avatar"),
]
search = (
cls.model.select(*fields)
.join(User, on=((User.id == cls.model.tenant_id) & (User.status == StatusEnum.VALID.value)))
.where((cls.model.id == search_id) & (cls.model.status == StatusEnum.VALID.value))
.first()
.to_dict()
)
return search
@classmethod
@DB.connection_context()
def get_by_tenant_ids(cls, joined_tenant_ids, user_id, page_number, items_per_page, orderby, desc, keywords):
fields = [
cls.model.id,
cls.model.avatar,
cls.model.tenant_id,
cls.model.name,
cls.model.description,
cls.model.created_by,
cls.model.status,
cls.model.update_time,
cls.model.create_time,
User.nickname,
User.avatar.alias("tenant_avatar"),
]
query = (
cls.model.select(*fields)
.join(User, on=(cls.model.tenant_id == User.id))
.where(((cls.model.tenant_id.in_(joined_tenant_ids)) | (cls.model.tenant_id == user_id)) & (cls.model.status == StatusEnum.VALID.value))
)
if keywords:
query = query.where(fn.LOWER(cls.model.name).contains(keywords.lower()))
if desc:
query = query.order_by(cls.model.getter_by(orderby).desc())
else:
query = query.order_by(cls.model.getter_by(orderby).asc())
count = query.count()
if page_number and items_per_page:
query = query.paginate(page_number, items_per_page)
return list(query.dicts()), count