#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from flask import request
from flask_login import current_user, login_required

from api import settings
from api.constants import DATASET_NAME_LIMIT
from api.db import StatusEnum
from api.db.db_models import DB
from api.db.services import duplicate_name
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.search_service import SearchService
from api.db.services.user_service import TenantService, UserTenantService
from api.utils import get_uuid
from api.utils.api_utils import get_data_error_result, get_json_result, not_allowed_parameters, server_error_response, validate_request


def _validate_search_name(name):
    """Return an error message for an unusable search name, or None if it is valid.

    The name must be a string, must not be blank, and its UTF-8 encoding must
    not exceed DATASET_NAME_LIMIT bytes.
    """
    if not isinstance(name, str):
        return "Search name must be string."
    if name.strip() == "":
        return "Search name can't be empty."
    # The limit is enforced on encoded bytes, so report the byte length as well;
    # reporting len(name) could print a number smaller than the limit.
    byte_length = len(name.encode("utf-8"))
    if byte_length > DATASET_NAME_LIMIT:
        return f"Search name length is {byte_length} which is large than {DATASET_NAME_LIMIT}"
    return None


@manager.route("/create", methods=["post"])  # noqa: F821
@login_required
@validate_request("name")
def create():
    """Create a search app for the logged-in user.

    Request JSON: ``name`` (required) and ``description`` (optional, defaults
    to ""). The whole request dict, extended with ``id``, ``tenant_id`` and
    ``created_by``, is handed to ``SearchService.save``.

    Returns a JSON result carrying ``{"search_id": <new id>}`` on success, or
    an error result when validation or saving fails.
    """
    req = request.get_json()
    search_name = req["name"]
    description = req.get("description", "")

    name_error = _validate_search_name(search_name)
    if name_error:
        return get_data_error_result(message=name_error)

    e, _ = TenantService.get_by_id(current_user.id)
    if not e:
        return get_data_error_result(message="Authorizationd identity.")

    search_name = search_name.strip()
    # Collisions must be looked up among search apps, so the search service's
    # own query is passed in (the knowledge-base query checked the wrong table).
    search_name = duplicate_name(SearchService.query, name=search_name, tenant_id=current_user.id, status=StatusEnum.VALID.value)

    req["id"] = get_uuid()
    req["name"] = search_name
    req["description"] = description
    req["tenant_id"] = current_user.id
    req["created_by"] = current_user.id
    with DB.atomic():
        try:
            if not SearchService.save(**req):
                return get_data_error_result()
            return get_json_result(data={"search_id": req["id"]})
        except Exception as e:
            return server_error_response(e)


@manager.route("/update", methods=["post"])  # noqa: F821
@login_required
@validate_request("search_id", "name", "search_config", "tenant_id")
@not_allowed_parameters("id", "created_by", "create_time", "update_time", "create_date", "update_date")
def update():
    """Update the name and configuration of an existing search app.

    Request JSON must contain ``search_id``, ``name``, ``search_config`` and
    ``tenant_id``. The incoming ``search_config`` is shallow-merged over the
    stored one, so keys that are not sent keep their stored values.

    Returns the updated record as a dict, or an error result when validation,
    authorization, lookup or the update itself fails.
    """
    req = request.get_json()

    name_error = _validate_search_name(req["name"])
    if name_error:
        return get_data_error_result(message=name_error)
    req["name"] = req["name"].strip()

    tenant_id = req["tenant_id"]
    e, _ = TenantService.get_by_id(tenant_id)
    if not e:
        return get_data_error_result(message="Authorizationd identity.")

    search_id = req["search_id"]
    if not SearchService.accessible4deletion(search_id, current_user.id):
        return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR)

    try:
        # Check for an empty result before indexing; indexing first raised
        # IndexError and surfaced as a server error instead of "not found".
        matches = list(SearchService.query(tenant_id=tenant_id, id=search_id))
        if not matches:
            return get_json_result(data=False, message=f"Cannot find search {search_id}", code=settings.RetCode.DATA_ERROR)
        search_app = matches[0]

        # Only a real rename (case-insensitive) needs the uniqueness check.
        renamed = req["name"].lower() != search_app.name.lower()
        if renamed and len(SearchService.query(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value)) >= 1:
            return get_data_error_result(message="Duplicated search name.")

        if "search_config" in req:
            current_config = search_app.search_config or {}
            new_config = req["search_config"]

            if not isinstance(new_config, dict):
                return get_data_error_result(message="search_config must be a JSON object")

            # Shallow merge: keys from the request override stored keys.
            req["search_config"] = {**current_config, **new_config}

        # These two identify the row and are not part of the update payload.
        req.pop("search_id", None)
        req.pop("tenant_id", None)

        updated = SearchService.update_by_id(search_id, req)
        if not updated:
            return get_data_error_result(message="Failed to update search")

        e, updated_search = SearchService.get_by_id(search_id)
        if not e:
            return get_data_error_result(message="Failed to fetch updated search")

        return get_json_result(data=updated_search.to_dict())

    except Exception as e:
        return server_error_response(e)


@manager.route("/detail", methods=["GET"])  # noqa: F821
@login_required
def detail():
    """Return the detail of one search app.

    Query string: ``search_id``. Access is granted when the search app belongs
    to any tenant returned by ``UserTenantService.query`` for the current user.
    """
    search_id = request.args["search_id"]
    try:
        tenants = UserTenantService.query(user_id=current_user.id)
        # any() stops at the first tenant that owns the search app.
        permitted = any(SearchService.query(tenant_id=tenant.tenant_id, id=search_id) for tenant in tenants)
        if not permitted:
            return get_json_result(data=False, message="Has no permission for this operation.", code=settings.RetCode.OPERATING_ERROR)

        search = SearchService.get_detail(search_id)
        if not search:
            return get_data_error_result(message="Can't find this Search App!")
        return get_json_result(data=search)
    except Exception as e:
        return server_error_response(e)


@manager.route("/list", methods=["POST"])  # noqa: F821
@login_required
def list_search_app():
    """List search apps visible to the current user.

    Query string: ``keywords``, ``page``, ``page_size``, ``orderby`` (default
    ``create_time``) and ``desc`` (anything but "false" means descending).
    Request JSON: optional ``owner_ids`` restricting results to those tenants.

    Returns ``{"search_apps": [...], "total": <count>}``.
    """
    keywords = request.args.get("keywords", "")
    page_number = int(request.args.get("page", 0))
    items_per_page = int(request.args.get("page_size", 0))
    orderby = request.args.get("orderby", "create_time")
    desc = request.args.get("desc", "true").lower() != "false"

    # get_json() can yield None (for example a JSON ``null`` body); treat that
    # as "no filter" instead of failing on ``None.get``.
    req = request.get_json() or {}
    owner_ids = req.get("owner_ids", [])
    try:
        if not owner_ids:
            tenants = TenantService.get_joined_tenants_by_user_id(current_user.id)
            tenants = [m["tenant_id"] for m in tenants]
            search_apps, total = SearchService.get_by_tenant_ids(tenants, current_user.id, page_number, items_per_page, orderby, desc, keywords)
        else:
            tenants = owner_ids
            # Fetch unpaginated, narrow to the requested owners, then page here
            # so ``total`` reflects the filtered set.
            search_apps, total = SearchService.get_by_tenant_ids(tenants, current_user.id, 0, 0, orderby, desc, keywords)
            search_apps = [search_app for search_app in search_apps if search_app["tenant_id"] in tenants]
            total = len(search_apps)
            if page_number and items_per_page:
                search_apps = search_apps[(page_number - 1) * items_per_page : page_number * items_per_page]
        return get_json_result(data={"search_apps": search_apps, "total": total})
    except Exception as e:
        return server_error_response(e)


@manager.route("/rm", methods=["post"])  # noqa: F821
@login_required
@validate_request("search_id")
def rm():
    """Delete a search app that the current user is allowed to delete.

    Request JSON: ``search_id``. Returns ``True`` as data on success.
    """
    req = request.get_json()
    search_id = req["search_id"]
    if not SearchService.accessible4deletion(search_id, current_user.id):
        return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR)

    try:
        if not SearchService.delete_by_id(search_id):
            return get_data_error_result(message=f"Failed to delete search App {search_id}")
        return get_json_result(data=True)
    except Exception as e:
        return server_error_response(e)
# +import hashlib import inspect import logging import operator import os import sys -import typing import time +import typing from enum import Enum from functools import wraps -import hashlib from flask_login import UserMixin from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer @@ -264,14 +264,15 @@ class BaseDataBase: def with_retry(max_retries=3, retry_delay=1.0): """Decorator: Add retry mechanism to database operations - + Args: max_retries (int): maximum number of retries retry_delay (float): initial retry delay (seconds), will increase exponentially - + Returns: decorated function """ + def decorator(func): @wraps(func) def wrapper(*args, **kwargs): @@ -284,26 +285,28 @@ def with_retry(max_retries=3, retry_delay=1.0): # get self and method name for logging self_obj = args[0] if args else None func_name = func.__name__ - lock_name = getattr(self_obj, 'lock_name', 'unknown') if self_obj else 'unknown' - + lock_name = getattr(self_obj, "lock_name", "unknown") if self_obj else "unknown" + if retry < max_retries - 1: - current_delay = retry_delay * (2 ** retry) - logging.warning(f"{func_name} {lock_name} failed: {str(e)}, retrying ({retry+1}/{max_retries})") + current_delay = retry_delay * (2**retry) + logging.warning(f"{func_name} {lock_name} failed: {str(e)}, retrying ({retry + 1}/{max_retries})") time.sleep(current_delay) else: logging.error(f"{func_name} {lock_name} failed after all attempts: {str(e)}") - + if last_exception: raise last_exception return False + return wrapper + return decorator class PostgresDatabaseLock: def __init__(self, lock_name, timeout=10, db=None): self.lock_name = lock_name - self.lock_id = int(hashlib.md5(lock_name.encode()).hexdigest(), 16) % (2**31-1) + self.lock_id = int(hashlib.md5(lock_name.encode()).hexdigest(), 16) % (2**31 - 1) self.timeout = int(timeout) self.db = db if db else DB @@ -542,7 +545,7 @@ class LLM(DataBaseModel): max_tokens = IntegerField(default=0) tags = CharField(max_length=255, null=False, 
def _default_search_config():
    """Build a fresh default ``search_config`` dict for a new Search row.

    A factory is used rather than a dict literal so that model instances do
    not share (and cannot accidentally mutate) one default object.
    """
    return {
        "kb_ids": [],
        "doc_ids": [],
        "similarity_threshold": 0.0,
        "vector_similarity_weight": 0.3,
        "use_kg": False,
        # rerank settings
        "rerank_id": "",
        "top_k": 1024,
        # chat settings
        "summary": False,
        "chat_id": "",
        "llm_setting": {
            "temperature": 0.1,
            "top_p": 0.3,
            "frequency_penalty": 0.7,
            "presence_penalty": 0.4,
        },
        # NOTE(review): this key looks like two names fused together
        # ("chat_setting" + "cross_languages"); kept byte-for-byte because
        # stored rows and clients may already use it - confirm before renaming.
        "chat_settingcross_languages": [],
        "highlight": False,
        "keyword": False,
        "web_search": False,
        "related_search": False,
        "query_mindmap": False,
    }


class Search(DataBaseModel):
    """Peewee model for a saved search app, stored in the ``search`` table."""

    id = CharField(max_length=32, primary_key=True)
    avatar = TextField(null=True, help_text="avatar base64 string")
    tenant_id = CharField(max_length=32, null=False, index=True)
    name = CharField(max_length=128, null=False, help_text="Search name", index=True)
    # NOTE(review): help_text says "KB description"; presumably copied from the
    # knowledge-base model - left unchanged here.
    description = TextField(null=True, help_text="KB description")
    created_by = CharField(max_length=32, null=False, index=True)
    # Callable default: every new instance receives its own config dict.
    search_config = JSONField(null=False, default=_default_search_config)
    status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True)

    def __str__(self):
        return self.name

    class Meta:
        db_table = "search"
re.compile(r"ref\s*(\d+)", flags=re.IGNORECASE), # ref12、REF 12 ] + def repair_bad_citation_formats(answer: str, kbinfos: dict, idx: set): max_index = len(kbinfos["chunks"]) @@ -555,7 +556,7 @@ def tts(tts_mdl, text): return binascii.hexlify(bin).decode("utf-8") -def ask(question, kb_ids, tenant_id): +def ask(question, kb_ids, tenant_id, chat_llm_name=None): kbs = KnowledgebaseService.get_by_ids(kb_ids) embedding_list = list(set([kb.embd_id for kb in kbs])) @@ -563,7 +564,7 @@ def ask(question, kb_ids, tenant_id): retriever = settings.retrievaler if not is_knowledge_graph else settings.kg_retrievaler embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embedding_list[0]) - chat_mdl = LLMBundle(tenant_id, LLMType.CHAT) + chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, chat_llm_name) max_tokens = chat_mdl.max_length tenant_ids = list(set([kb.tenant_id for kb in kbs])) kbinfos = retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids, 1, 12, 0.1, 0.3, aggs=False, rank_feature=label_question(question, kbs)) diff --git a/api/db/services/search_service.py b/api/db/services/search_service.py new file mode 100644 index 000000000..c5c812cc9 --- /dev/null +++ b/api/db/services/search_service.py @@ -0,0 +1,110 @@ +# +# Copyright 2024 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
#
from datetime import datetime

from peewee import fn

from api.db import StatusEnum
from api.db.db_models import DB, Search, User
from api.db.services.common_service import CommonService
from api.utils import current_timestamp, datetime_format


class SearchService(CommonService):
    """Database access helpers for ``Search`` rows."""

    model = Search

    @classmethod
    def save(cls, **kwargs):
        """Insert a new Search row and return the created model instance.

        ``kwargs`` are the column values; the four create/update time columns
        are filled in here and override any values passed by the caller.
        """
        # Sample the clock once so the create and update stamps of a new row
        # are identical.
        now_ts = current_timestamp()
        now_date = datetime_format(datetime.now())
        kwargs["create_time"] = now_ts
        kwargs["create_date"] = now_date
        kwargs["update_time"] = now_ts
        kwargs["update_date"] = now_date
        return cls.model.create(**kwargs)

    @classmethod
    @DB.connection_context()
    def accessible4deletion(cls, search_id, user_id) -> bool:
        """Return True when a valid search ``search_id`` was created by ``user_id``."""
        search = (
            cls.model.select(cls.model.id)
            .where(
                cls.model.id == search_id,
                cls.model.created_by == user_id,
                cls.model.status == StatusEnum.VALID.value,
            )
            .first()
        )
        return search is not None

    @classmethod
    @DB.connection_context()
    def get_detail(cls, search_id):
        """Return the detail dict of a valid search app, or None if none matches.

        The row is joined with the valid ``User`` whose id equals the search's
        ``tenant_id`` to pick up ``nickname`` and the user's avatar (aliased
        ``tenant_avatar``).
        """
        fields = [
            cls.model.id,
            cls.model.avatar,
            cls.model.tenant_id,
            cls.model.name,
            cls.model.description,
            cls.model.created_by,
            cls.model.search_config,
            cls.model.update_time,
            User.nickname,
            User.avatar.alias("tenant_avatar"),
        ]
        row = (
            cls.model.select(*fields)
            .join(User, on=((User.id == cls.model.tenant_id) & (User.status == StatusEnum.VALID.value)))
            .where((cls.model.id == search_id) & (cls.model.status == StatusEnum.VALID.value))
            .first()
        )
        # .first() gives None when nothing matches; calling .to_dict() on it
        # raised AttributeError and hid the "not found" case from callers.
        if row is None:
            return None
        return row.to_dict()

    @classmethod
    @DB.connection_context()
    def get_by_tenant_ids(cls, joined_tenant_ids, user_id, page_number, items_per_page, orderby, desc, keywords):
        """List valid search apps owned by the given tenants or by ``user_id``.

        Args:
            joined_tenant_ids: tenant ids whose search apps are included.
            user_id: search apps whose ``tenant_id`` equals this id are included too.
            page_number, items_per_page: pagination is applied only when both are truthy.
            orderby: name passed to ``model.getter_by`` to pick the sort column.
            desc: sort descending when truthy, ascending otherwise.
            keywords: optional case-insensitive substring filter on the name.

        Returns:
            ``(rows, count)`` where ``rows`` is a list of dicts for the requested
            page and ``count`` is the number of matches before pagination.
        """
        fields = [
            cls.model.id,
            cls.model.avatar,
            cls.model.tenant_id,
            cls.model.name,
            cls.model.description,
            cls.model.created_by,
            cls.model.status,
            cls.model.update_time,
            cls.model.create_time,
            User.nickname,
            User.avatar.alias("tenant_avatar"),
        ]
        query = (
            cls.model.select(*fields)
            .join(User, on=(cls.model.tenant_id == User.id))
            .where(((cls.model.tenant_id.in_(joined_tenant_ids)) | (cls.model.tenant_id == user_id)) & (cls.model.status == StatusEnum.VALID.value))
        )

        if keywords:
            query = query.where(fn.LOWER(cls.model.name).contains(keywords.lower()))

        sort_column = cls.model.getter_by(orderby)
        query = query.order_by(sort_column.desc() if desc else sort_column.asc())

        # Count before paginating so the total covers every match.
        count = query.count()

        if page_number and items_per_page:
            query = query.paginate(page_number, items_per_page)

        return list(query.dicts()), count