diff --git a/admin/admin_client.py b/admin/admin_client.py
index 007b73e29..399d173ab 100644
--- a/admin/admin_client.py
+++ b/admin/admin_client.py
@@ -1,3 +1,19 @@
+#
+# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
 import argparse
 import base64
 
diff --git a/admin/admin_server.py b/admin/admin_server.py
index 27ee0c72a..e76b38642 100644
--- a/admin/admin_server.py
+++ b/admin/admin_server.py
@@ -1,3 +1,18 @@
+#
+# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
 import os
 import signal
 
diff --git a/admin/auth.py b/admin/auth.py
index 001ba5940..5160912a4 100644
--- a/admin/auth.py
+++ b/admin/auth.py
@@ -1,3 +1,20 @@
+#
+# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
 import logging
 import uuid
 from functools import wraps
diff --git a/admin/config.py b/admin/config.py
index 94147de8e..609c321a7 100644
--- a/admin/config.py
+++ b/admin/config.py
@@ -1,3 +1,20 @@
+#
+# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
 import logging
 import threading
 from enum import Enum
@@ -35,7 +52,8 @@ class BaseConfig(BaseModel):
     detail_func_name: str
 
     def to_dict(self) -> dict[str, Any]:
-        return {'id': self.id, 'name': self.name, 'host': self.host, 'port': self.port, 'service_type': self.service_type}
+        return {'id': self.id, 'name': self.name, 'host': self.host, 'port': self.port,
+                'service_type': self.service_type}
 
 
 class MetaConfig(BaseConfig):
diff --git a/admin/models.py b/admin/models.py
index e69de29bb..177b91dd0 100644
--- a/admin/models.py
+++ b/admin/models.py
@@ -0,0 +1,15 @@
+#
+# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
diff --git a/admin/responses.py b/admin/responses.py
index 00cee7038..54f841a83 100644
--- a/admin/responses.py
+++ b/admin/responses.py
@@ -1,15 +1,34 @@
+#
+# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
 from flask import jsonify
 
-def success_response(data=None, message="Success", code = 0):
+
+def success_response(data=None, message="Success", code=0):
     return jsonify({
         "code": code,
         "message": message,
         "data": data
     }), 200
 
+
 def error_response(message="Error", code=-1, data=None):
     return jsonify({
         "code": code,
         "message": message,
         "data": data
-    }), 400
\ No newline at end of file
+    }), 400
diff --git a/admin/routes.py b/admin/routes.py
index a737305de..afc82bc9d 100644
--- a/admin/routes.py
+++ b/admin/routes.py
@@ -1,3 +1,20 @@
+#
+# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
 from flask import Blueprint, request
 from auth import login_verify
 
@@ -42,7 +59,7 @@ def create_user():
     res = UserMgr.create_user(username, password, role)
     if res["success"]:
         user_info = res["user_info"]
-        user_info.pop("password") # do not return password
+        user_info.pop("password")  # do not return password
         return success_response(user_info, "User created successfully")
     else:
         return error_response("create user failed")
@@ -102,6 +119,7 @@ def alter_user_activate_status(username):
     except Exception as e:
         return error_response(str(e), 500)
 
+
 @admin_bp.route('/users/<username>', methods=['GET'])
 @login_verify
 def get_user_details(username):
@@ -114,6 +132,7 @@ def get_user_details(username):
     except Exception as e:
         return error_response(str(e), 500)
 
+
 @admin_bp.route('/users/<username>/datasets', methods=['GET'])
 @login_verify
 def get_user_datasets(username):
diff --git a/admin/services.py b/admin/services.py
index 2c8eaaf7c..3aa738fd5 100644
--- a/admin/services.py
+++ b/admin/services.py
@@ -1,3 +1,20 @@
+#
+# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
 import re
 from werkzeug.security import check_password_hash
 from api.db import ActiveEnum
@@ -12,13 +29,15 @@ from api.utils import health_utils
 from api.common.exceptions import AdminException, UserAlreadyExistsError, UserNotFoundError
 from config import SERVICE_CONFIGS
 
+
 class UserMgr:
     @staticmethod
     def get_all_users():
         users = UserService.get_all_users()
         result = []
         for user in users:
-            result.append({'email': user.email, 'nickname': user.nickname, 'create_date': user.create_date, 'is_active': user.is_active})
+            result.append({'email': user.email, 'nickname': user.nickname, 'create_date': user.create_date,
+                           'is_active': user.is_active})
         return result
 
     @staticmethod
@@ -112,6 +131,7 @@ class UserMgr:
         UserService.update_user(usr.id, {"is_active": target_status})
         return f"Turn {_activate_status} user activate status successfully!"
 
+
 class UserServiceMgr:
 
     @staticmethod
@@ -150,6 +170,7 @@ class UserServiceMgr:
             'canvas_category': r['canvas_category']
         } for r in res]
 
+
 class ServiceMgr:
 
     @staticmethod
diff --git a/agent/tools/retrieval.py b/agent/tools/retrieval.py
index 24370f1ca..07c16d97d 100644
--- a/agent/tools/retrieval.py
+++ b/agent/tools/retrieval.py
@@ -121,7 +121,7 @@ class Retrieval(ToolBase, ABC):
 
         if kbs:
             query = re.sub(r"^user[::\s]*", "", query, flags=re.IGNORECASE)
-            kbinfos = settings.retrievaler.retrieval(
+            kbinfos = settings.retriever.retrieval(
                 query,
                 embd_mdl,
                 [kb.tenant_id for kb in kbs],
@@ -135,7 +135,7 @@ class Retrieval(ToolBase, ABC):
                 rank_feature=label_question(query, kbs),
             )
             if self._param.use_kg:
-                ck = settings.kg_retrievaler.retrieval(query,
+                ck = settings.kg_retriever.retrieval(query,
                                                      [kb.tenant_id for kb in kbs],
                                                      kb_ids,
                                                      embd_mdl,
@@ -146,7 +146,7 @@ class Retrieval(ToolBase, ABC):
             kbinfos = {"chunks": [], "doc_aggs": []}
 
         if self._param.use_kg and kbs:
-            ck = settings.kg_retrievaler.retrieval(query, [kb.tenant_id for kb in kbs], filtered_kb_ids, embd_mdl, LLMBundle(kbs[0].tenant_id, LLMType.CHAT))
+            ck = settings.kg_retriever.retrieval(query, [kb.tenant_id for kb in kbs], filtered_kb_ids, embd_mdl, LLMBundle(kbs[0].tenant_id, LLMType.CHAT))
             if ck["content_with_weight"]:
                 ck["content"] = ck["content_with_weight"]
                 del ck["content_with_weight"]
diff --git a/api/apps/api_app.py b/api/apps/api_app.py
index 1bdb7c2f8..4637009d6 100644
--- a/api/apps/api_app.py
+++ b/api/apps/api_app.py
@@ -536,7 +536,7 @@ def list_chunks():
         )
 
     kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
-    res = settings.retrievaler.chunk_list(doc_id, tenant_id, kb_ids)
+    res = settings.retriever.chunk_list(doc_id, tenant_id, kb_ids)
     res = [
         {
             "content": res_item["content_with_weight"],
@@ -884,7 +884,7 @@ def retrieval():
         if req.get("keyword", False):
             chat_mdl = LLMBundle(kbs[0].tenant_id, LLMType.CHAT)
             question += keyword_extraction(chat_mdl, question)
-        ranks = settings.retrievaler.retrieval(question, embd_mdl, kbs[0].tenant_id, kb_ids, page, size,
+        ranks = settings.retriever.retrieval(question, embd_mdl, kbs[0].tenant_id, kb_ids, page, size,
                                              similarity_threshold, vector_similarity_weight, top,
                                              doc_ids, rerank_mdl=rerank_mdl, highlight= highlight,
                                              rank_feature=label_question(question, kbs))
diff --git a/api/apps/chunk_app.py b/api/apps/chunk_app.py
index bfd80ea9f..5c3c30d65 100644
--- a/api/apps/chunk_app.py
+++ b/api/apps/chunk_app.py
@@ -60,7 +60,7 @@ def list_chunk():
         }
         if "available_int" in req:
             query["available_int"] = int(req["available_int"])
-        sres = settings.retrievaler.search(query, search.index_name(tenant_id), kb_ids, highlight=True)
+        sres = settings.retriever.search(query, search.index_name(tenant_id), kb_ids, highlight=True)
         res = {"total": sres.total, "chunks": [], "doc": doc.to_dict()}
         for id in sres.ids:
             d = {
@@ -346,7 +346,7 @@ def retrieval_test():
             question += keyword_extraction(chat_mdl, question)
 
         labels = label_question(question, [kb])
-        ranks = settings.retrievaler.retrieval(question, embd_mdl, tenant_ids, kb_ids, page, size,
+        ranks = settings.retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids, page, size,
                                              float(req.get("similarity_threshold", 0.0)),
                                              float(req.get("vector_similarity_weight", 0.3)),
                                              top,
@@ -354,7 +354,7 @@ def retrieval_test():
                                              rank_feature=labels
                                              )
         if use_kg:
-            ck = settings.kg_retrievaler.retrieval(question,
+            ck = settings.kg_retriever.retrieval(question,
                                                  tenant_ids,
                                                  kb_ids,
                                                  embd_mdl,
@@ -384,7 +384,7 @@ def knowledge_graph():
         "doc_ids": [doc_id],
"knowledge_graph_kwd": ["graph", "mind_map"] } - sres = settings.retrievaler.search(req, search.index_name(tenant_id), kb_ids) + sres = settings.retriever.search(req, search.index_name(tenant_id), kb_ids) obj = {"graph": {}, "mind_map": {}} for id in sres.ids[:2]: ty = sres.field[id]["knowledge_graph_kwd"] diff --git a/api/apps/kb_app.py b/api/apps/kb_app.py index 0c56b15ad..bca28fb6f 100644 --- a/api/apps/kb_app.py +++ b/api/apps/kb_app.py @@ -282,7 +282,7 @@ def list_tags(kb_id): tenants = UserTenantService.get_tenants_by_user_id(current_user.id) tags = [] for tenant in tenants: - tags += settings.retrievaler.all_tags(tenant["tenant_id"], [kb_id]) + tags += settings.retriever.all_tags(tenant["tenant_id"], [kb_id]) return get_json_result(data=tags) @@ -301,7 +301,7 @@ def list_tags_from_kbs(): tenants = UserTenantService.get_tenants_by_user_id(current_user.id) tags = [] for tenant in tenants: - tags += settings.retrievaler.all_tags(tenant["tenant_id"], kb_ids) + tags += settings.retriever.all_tags(tenant["tenant_id"], kb_ids) return get_json_result(data=tags) @@ -362,7 +362,7 @@ def knowledge_graph(kb_id): obj = {"graph": {}, "mind_map": {}} if not settings.docStoreConn.indexExist(search.index_name(kb.tenant_id), kb_id): return get_json_result(data=obj) - sres = settings.retrievaler.search(req, search.index_name(kb.tenant_id), [kb_id]) + sres = settings.retriever.search(req, search.index_name(kb.tenant_id), [kb_id]) if not len(sres.ids): return get_json_result(data=obj) diff --git a/api/apps/plugin_app.py b/api/apps/plugin_app.py index dcd209daa..9ca04416d 100644 --- a/api/apps/plugin_app.py +++ b/api/apps/plugin_app.py @@ -1,8 +1,26 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + + from flask import Response from flask_login import login_required from api.utils.api_utils import get_json_result from plugin import GlobalPluginManager + @manager.route('/llm_tools', methods=['GET']) # noqa: F821 @login_required def llm_tools() -> Response: diff --git a/api/apps/sdk/agent.py b/api/apps/sdk/agent.py index 704a3ffcf..b41328365 100644 --- a/api/apps/sdk/agent.py +++ b/api/apps/sdk/agent.py @@ -25,6 +25,7 @@ from api.utils.api_utils import get_data_error_result, get_error_data_result, ge from api.utils.api_utils import get_result from flask import request + @manager.route('/agents', methods=['GET']) # noqa: F821 @token_required def list_agents(tenant_id): @@ -41,7 +42,7 @@ def list_agents(tenant_id): desc = False else: desc = True - canvas = UserCanvasService.get_list(tenant_id,page_number,items_per_page,orderby,desc,id,title) + canvas = UserCanvasService.get_list(tenant_id, page_number, items_per_page, orderby, desc, id, title) return get_result(data=canvas) @@ -93,7 +94,7 @@ def update_agent(tenant_id: str, agent_id: str): req["dsl"] = json.dumps(req["dsl"], ensure_ascii=False) req["dsl"] = json.loads(req["dsl"]) - + if req.get("title") is not None: req["title"] = req["title"].strip() diff --git a/api/apps/sdk/dataset.py b/api/apps/sdk/dataset.py index 7b25f1d8b..ff446055c 100644 --- a/api/apps/sdk/dataset.py +++ b/api/apps/sdk/dataset.py @@ -215,7 +215,8 @@ def delete(tenant_id): continue kb_id_instance_pairs.append((kb_id, kb)) if len(error_kb_ids) > 0: - return get_error_permission_result(message=f"""User '{tenant_id}' lacks permission for datasets: '{", ".join(error_kb_ids)}'""") + return get_error_permission_result( + message=f"""User '{tenant_id}' lacks permission for datasets: '{", ".join(error_kb_ids)}'""") errors = [] success_count = 0 @@ -232,7 +233,8 @@ def delete(tenant_id): ] ) File2DocumentService.delete_by_document_id(doc.id) - FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.type == "folder", File.name == kb.name]) + FileService.filter_delete( + [File.source_type == FileSource.KNOWLEDGEBASE, File.type == "folder", File.name == kb.name]) if not KnowledgebaseService.delete_by_id(kb_id): errors.append(f"Delete dataset error for {kb_id}") continue @@ -329,7 +331,8 @@ def update(tenant_id, dataset_id): try: kb = KnowledgebaseService.get_or_none(id=dataset_id, tenant_id=tenant_id) if kb is None: - return get_error_permission_result(message=f"User '{tenant_id}' lacks permission for dataset '{dataset_id}'") + return get_error_permission_result( + message=f"User '{tenant_id}' lacks permission for dataset '{dataset_id}'") if req.get("parser_config"): req["parser_config"] = deep_merge(kb.parser_config, req["parser_config"]) @@ -341,7 +344,8 @@ def update(tenant_id, dataset_id): del req["parser_config"] if "name" in req and req["name"].lower() != kb.name.lower(): - exists = KnowledgebaseService.get_or_none(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value) + exists = KnowledgebaseService.get_or_none(name=req["name"], tenant_id=tenant_id, + status=StatusEnum.VALID.value) if exists: return get_error_data_result(message=f"Dataset name '{req['name']}' already exists") @@ -349,7 +353,8 @@ def update(tenant_id, dataset_id): if not req["embd_id"]: req["embd_id"] = kb.embd_id if kb.chunk_num != 0 and req["embd_id"] != kb.embd_id: - return get_error_data_result(message=f"When chunk_num ({kb.chunk_num}) > 0, embedding_model must remain {kb.embd_id}") + return get_error_data_result( + message=f"When chunk_num ({kb.chunk_num}) 
> 0, embedding_model must remain {kb.embd_id}") ok, err = verify_embedding_availability(req["embd_id"], tenant_id) if not ok: return err @@ -359,10 +364,12 @@ def update(tenant_id, dataset_id): return get_error_argument_result(message="'pagerank' can only be set when doc_engine is elasticsearch") if req["pagerank"] > 0: - settings.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]}, search.index_name(kb.tenant_id), kb.id) + settings.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]}, + search.index_name(kb.tenant_id), kb.id) else: # Elasticsearch requires PAGERANK_FLD be non-zero! - settings.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD}, search.index_name(kb.tenant_id), kb.id) + settings.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD}, + search.index_name(kb.tenant_id), kb.id) if not KnowledgebaseService.update_by_id(kb.id, req): return get_error_data_result(message="Update dataset error.(Database error)") @@ -473,9 +480,10 @@ def list_datasets(tenant_id): logging.exception(e) return get_error_data_result(message="Database operation failed") + @manager.route('/datasets//knowledge_graph', methods=['GET']) # noqa: F821 @token_required -def knowledge_graph(tenant_id,dataset_id): +def knowledge_graph(tenant_id, dataset_id): if not KnowledgebaseService.accessible(dataset_id, tenant_id): return get_result( data=False, @@ -491,7 +499,7 @@ def knowledge_graph(tenant_id,dataset_id): obj = {"graph": {}, "mind_map": {}} if not settings.docStoreConn.indexExist(search.index_name(kb.tenant_id), dataset_id): return get_result(data=obj) - sres = settings.retrievaler.search(req, search.index_name(kb.tenant_id), [dataset_id]) + sres = settings.retriever.search(req, search.index_name(kb.tenant_id), [dataset_id]) if not len(sres.ids): return get_result(data=obj) @@ -507,14 +515,16 @@ def knowledge_graph(tenant_id,dataset_id): if "nodes" in obj["graph"]: obj["graph"]["nodes"] = sorted(obj["graph"]["nodes"], key=lambda x: x.get("pagerank", 0), reverse=True)[:256] if "edges" in obj["graph"]: - node_id_set = { o["id"] for o in obj["graph"]["nodes"] } - filtered_edges = [o for o in obj["graph"]["edges"] if o["source"] != o["target"] and o["source"] in node_id_set and o["target"] in node_id_set] + node_id_set = {o["id"] for o in obj["graph"]["nodes"]} + filtered_edges = [o for o in obj["graph"]["edges"] if + o["source"] != o["target"] and o["source"] in node_id_set and o["target"] in node_id_set] obj["graph"]["edges"] = sorted(filtered_edges, key=lambda x: x.get("weight", 0), reverse=True)[:128] return get_result(data=obj) + @manager.route('/datasets//knowledge_graph', methods=['DELETE']) # noqa: F821 @token_required -def delete_knowledge_graph(tenant_id,dataset_id): +def delete_knowledge_graph(tenant_id, dataset_id): if not KnowledgebaseService.accessible(dataset_id, tenant_id): return get_result( data=False, @@ -522,6 +532,7 @@ def delete_knowledge_graph(tenant_id,dataset_id): code=settings.RetCode.AUTHENTICATION_ERROR ) _, kb = KnowledgebaseService.get_by_id(dataset_id) - settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]}, search.index_name(kb.tenant_id), dataset_id) + settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]}, + search.index_name(kb.tenant_id), dataset_id) return get_result(data=True) diff --git a/api/apps/sdk/dify_retrieval.py b/api/apps/sdk/dify_retrieval.py index 446d4d74a..dc5476f34 100644 --- 
a/api/apps/sdk/dify_retrieval.py +++ b/api/apps/sdk/dify_retrieval.py @@ -1,4 +1,4 @@ - # +# # Copyright 2024 The InfiniFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -38,9 +38,9 @@ def retrieval(tenant_id): retrieval_setting = req.get("retrieval_setting", {}) similarity_threshold = float(retrieval_setting.get("score_threshold", 0.0)) top = int(retrieval_setting.get("top_k", 1024)) - metadata_condition = req.get("metadata_condition",{}) + metadata_condition = req.get("metadata_condition", {}) metas = DocumentService.get_meta_by_kbs([kb_id]) - + doc_ids = [] try: @@ -50,12 +50,12 @@ def retrieval(tenant_id): embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id) print(metadata_condition) - print("after",convert_conditions(metadata_condition)) + # print("after", convert_conditions(metadata_condition)) doc_ids.extend(meta_filter(metas, convert_conditions(metadata_condition))) - print("doc_ids",doc_ids) + # print("doc_ids", doc_ids) if not doc_ids and metadata_condition is not None: doc_ids = ['-999'] - ranks = settings.retrievaler.retrieval( + ranks = settings.retriever.retrieval( question, embd_mdl, kb.tenant_id, @@ -70,17 +70,17 @@ def retrieval(tenant_id): ) if use_kg: - ck = settings.kg_retrievaler.retrieval(question, - [tenant_id], - [kb_id], - embd_mdl, - LLMBundle(kb.tenant_id, LLMType.CHAT)) + ck = settings.kg_retriever.retrieval(question, + [tenant_id], + [kb_id], + embd_mdl, + LLMBundle(kb.tenant_id, LLMType.CHAT)) if ck["content_with_weight"]: ranks["chunks"].insert(0, ck) records = [] for c in ranks["chunks"]: - e, doc = DocumentService.get_by_id( c["doc_id"]) + e, doc = DocumentService.get_by_id(c["doc_id"]) c.pop("vector", None) meta = getattr(doc, 'meta_fields', {}) meta["doc_id"] = c["doc_id"] @@ -100,5 +100,3 @@ def retrieval(tenant_id): ) logging.exception(e) return build_error_result(message=str(e), code=settings.RetCode.SERVER_ERROR) - - diff --git a/api/apps/sdk/doc.py b/api/apps/sdk/doc.py index 8d5a413b0..2cc2926df 100644 --- a/api/apps/sdk/doc.py +++ b/api/apps/sdk/doc.py @@ -982,7 +982,7 @@ def list_chunks(tenant_id, dataset_id, document_id): _ = Chunk(**final_chunk) elif settings.docStoreConn.indexExist(search.index_name(tenant_id), dataset_id): - sres = settings.retrievaler.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None, highlight=True) + sres = settings.retriever.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None, highlight=True) res["total"] = sres.total for id in sres.ids: d = { @@ -1446,7 +1446,7 @@ def retrieval_test(tenant_id): chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT) question += keyword_extraction(chat_mdl, question) - ranks = settings.retrievaler.retrieval( + ranks = settings.retriever.retrieval( question, embd_mdl, tenant_ids, @@ -1462,7 +1462,7 @@ def retrieval_test(tenant_id): rank_feature=label_question(question, kbs), ) if use_kg: - ck = settings.kg_retrievaler.retrieval(question, [k.tenant_id for k in kbs], kb_ids, embd_mdl, LLMBundle(kb.tenant_id, LLMType.CHAT)) + ck = settings.kg_retriever.retrieval(question, [k.tenant_id for k in kbs], kb_ids, embd_mdl, LLMBundle(kb.tenant_id, LLMType.CHAT)) if ck["content_with_weight"]: ranks["chunks"].insert(0, ck) diff --git a/api/apps/sdk/files.py b/api/apps/sdk/files.py index 96efe208d..17af87dc0 100644 --- a/api/apps/sdk/files.py +++ b/api/apps/sdk/files.py @@ -1,3 +1,20 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + import pathlib import re @@ -17,7 +34,8 @@ from api.utils.api_utils import get_json_result from api.utils.file_utils import filename_type from rag.utils.storage_factory import STORAGE_IMPL -@manager.route('/file/upload', methods=['POST']) # noqa: F821 + +@manager.route('/file/upload', methods=['POST']) # noqa: F821 @token_required def upload(tenant_id): """ @@ -97,12 +115,14 @@ def upload(tenant_id): e, file = FileService.get_by_id(file_id_list[len_id_list - 1]) if not e: return get_json_result(data=False, message="Folder not found!", code=404) - last_folder = FileService.create_folder(file, file_id_list[len_id_list - 1], file_obj_names, len_id_list) + last_folder = FileService.create_folder(file, file_id_list[len_id_list - 1], file_obj_names, + len_id_list) else: e, file = FileService.get_by_id(file_id_list[len_id_list - 2]) if not e: return get_json_result(data=False, message="Folder not found!", code=404) - last_folder = FileService.create_folder(file, file_id_list[len_id_list - 2], file_obj_names, len_id_list) + last_folder = FileService.create_folder(file, file_id_list[len_id_list - 2], file_obj_names, + len_id_list) filetype = filename_type(file_obj_names[file_len - 1]) location = file_obj_names[file_len - 1] @@ -129,7 +149,7 @@ def upload(tenant_id): return server_error_response(e) -@manager.route('/file/create', methods=['POST']) # noqa: F821 +@manager.route('/file/create', methods=['POST']) # noqa: F821 @token_required def create(tenant_id): """ @@ -207,7 +227,7 @@ def create(tenant_id): return server_error_response(e) -@manager.route('/file/list', methods=['GET']) # noqa: F821 +@manager.route('/file/list', methods=['GET']) # noqa: F821 @token_required def list_files(tenant_id): """ @@ -299,7 +319,7 @@ def list_files(tenant_id): return server_error_response(e) -@manager.route('/file/root_folder', methods=['GET']) # noqa: F821 +@manager.route('/file/root_folder', methods=['GET']) # noqa: F821 @token_required def get_root_folder(tenant_id): """ @@ -335,7 +355,7 @@ def get_root_folder(tenant_id): return server_error_response(e) -@manager.route('/file/parent_folder', methods=['GET']) # noqa: F821 +@manager.route('/file/parent_folder', methods=['GET']) # noqa: F821 @token_required def get_parent_folder(): """ @@ -380,7 +400,7 @@ def get_parent_folder(): return server_error_response(e) -@manager.route('/file/all_parent_folder', methods=['GET']) # noqa: F821 +@manager.route('/file/all_parent_folder', methods=['GET']) # noqa: F821 @token_required def get_all_parent_folders(tenant_id): """ @@ -428,7 +448,7 @@ def get_all_parent_folders(tenant_id): return server_error_response(e) -@manager.route('/file/rm', methods=['POST']) # noqa: F821 +@manager.route('/file/rm', methods=['POST']) # noqa: F821 @token_required def rm(tenant_id): """ @@ -502,7 +522,7 @@ def rm(tenant_id): return server_error_response(e) -@manager.route('/file/rename', methods=['POST']) # noqa: F821 +@manager.route('/file/rename', methods=['POST']) # noqa: F821 
@token_required def rename(tenant_id): """ @@ -542,7 +562,8 @@ def rename(tenant_id): if not e: return get_json_result(message="File not found!", code=404) - if file.type != FileType.FOLDER.value and pathlib.Path(req["name"].lower()).suffix != pathlib.Path(file.name.lower()).suffix: + if file.type != FileType.FOLDER.value and pathlib.Path(req["name"].lower()).suffix != pathlib.Path( + file.name.lower()).suffix: return get_json_result(data=False, message="The extension of file can't be changed", code=400) for existing_file in FileService.query(name=req["name"], pf_id=file.parent_id): @@ -562,9 +583,9 @@ def rename(tenant_id): return server_error_response(e) -@manager.route('/file/get/', methods=['GET']) # noqa: F821 +@manager.route('/file/get/', methods=['GET']) # noqa: F821 @token_required -def get(tenant_id,file_id): +def get(tenant_id, file_id): """ Download a file. --- @@ -610,7 +631,7 @@ def get(tenant_id,file_id): return server_error_response(e) -@manager.route('/file/mv', methods=['POST']) # noqa: F821 +@manager.route('/file/mv', methods=['POST']) # noqa: F821 @token_required def move(tenant_id): """ @@ -669,6 +690,7 @@ def move(tenant_id): except Exception as e: return server_error_response(e) + @manager.route('/file/convert', methods=['POST']) # noqa: F821 @token_required def convert(tenant_id): @@ -735,4 +757,4 @@ def convert(tenant_id): file2documents.append(file2document.to_json()) return get_json_result(data=file2documents) except Exception as e: - return server_error_response(e) \ No newline at end of file + return server_error_response(e) diff --git a/api/apps/sdk/session.py b/api/apps/sdk/session.py index 10b6e9752..684d00928 100644 --- a/api/apps/sdk/session.py +++ b/api/apps/sdk/session.py @@ -36,7 +36,8 @@ from api.db.services.llm_service import LLMBundle from api.db.services.search_service import SearchService from api.db.services.user_service import UserTenantService from api.utils import get_uuid -from api.utils.api_utils import check_duplicate_ids, get_data_openai, get_error_data_result, get_json_result, get_result, server_error_response, token_required, validate_request +from api.utils.api_utils import check_duplicate_ids, get_data_openai, get_error_data_result, get_json_result, \ + get_result, server_error_response, token_required, validate_request from rag.app.tag import label_question from rag.prompts.template import load_prompt from rag.prompts.generator import cross_languages, gen_meta_filter, keyword_extraction, chunks_format @@ -88,7 +89,8 @@ def create_agent_session(tenant_id, agent_id): canvas.reset() cvs.dsl = json.loads(str(canvas)) - conv = {"id": session_id, "dialog_id": cvs.id, "user_id": user_id, "message": [{"role": "assistant", "content": canvas.get_prologue()}], "source": "agent", "dsl": cvs.dsl} + conv = {"id": session_id, "dialog_id": cvs.id, "user_id": user_id, + "message": [{"role": "assistant", "content": canvas.get_prologue()}], "source": "agent", "dsl": cvs.dsl} API4ConversationService.save(**conv) conv["agent_id"] = conv.pop("dialog_id") return get_result(data=conv) @@ -279,7 +281,7 @@ def chat_completion_openai_like(tenant_id, chat_id): reasoning_match = re.search(r"(.*?)", answer, flags=re.DOTALL) if reasoning_match: reasoning_part = reasoning_match.group(1) - content_part = answer[reasoning_match.end() :] + content_part = answer[reasoning_match.end():] else: reasoning_part = "" content_part = answer @@ -324,7 +326,8 @@ def chat_completion_openai_like(tenant_id, chat_id): response["choices"][0]["delta"]["content"] = None 
response["choices"][0]["delta"]["reasoning_content"] = None response["choices"][0]["finish_reason"] = "stop" - response["usage"] = {"prompt_tokens": len(prompt), "completion_tokens": token_used, "total_tokens": len(prompt) + token_used} + response["usage"] = {"prompt_tokens": len(prompt), "completion_tokens": token_used, + "total_tokens": len(prompt) + token_used} if need_reference: response["choices"][0]["delta"]["reference"] = chunks_format(last_ans.get("reference", [])) response["choices"][0]["delta"]["final_content"] = last_ans.get("answer", "") @@ -559,7 +562,8 @@ def list_agent_session(tenant_id, agent_id): desc = True # dsl defaults to True in all cases except for False and false include_dsl = request.args.get("dsl") != "False" and request.args.get("dsl") != "false" - total, convs = API4ConversationService.get_list(agent_id, tenant_id, page_number, items_per_page, orderby, desc, id, user_id, include_dsl) + total, convs = API4ConversationService.get_list(agent_id, tenant_id, page_number, items_per_page, orderby, desc, id, + user_id, include_dsl) if not convs: return get_result(data=[]) for conv in convs: @@ -581,7 +585,8 @@ def list_agent_session(tenant_id, agent_id): if message_num != 0 and messages[message_num]["role"] != "user": chunk_list = [] # Add boundary and type checks to prevent KeyError - if chunk_num < len(conv["reference"]) and conv["reference"][chunk_num] is not None and isinstance(conv["reference"][chunk_num], dict) and "chunks" in conv["reference"][chunk_num]: + if chunk_num < len(conv["reference"]) and conv["reference"][chunk_num] is not None and isinstance( + conv["reference"][chunk_num], dict) and "chunks" in conv["reference"][chunk_num]: chunks = conv["reference"][chunk_num]["chunks"] for chunk in chunks: # Ensure chunk is a dictionary before calling get method @@ -639,13 +644,16 @@ def delete(tenant_id, chat_id): if errors: if success_count > 0: - return get_result(data={"success_count": success_count, "errors": errors}, message=f"Partially deleted {success_count} sessions with {len(errors)} errors") + return get_result(data={"success_count": success_count, "errors": errors}, + message=f"Partially deleted {success_count} sessions with {len(errors)} errors") else: return get_error_data_result(message="; ".join(errors)) if duplicate_messages: if success_count > 0: - return get_result(message=f"Partially deleted {success_count} sessions with {len(duplicate_messages)} errors", data={"success_count": success_count, "errors": duplicate_messages}) + return get_result( + message=f"Partially deleted {success_count} sessions with {len(duplicate_messages)} errors", + data={"success_count": success_count, "errors": duplicate_messages}) else: return get_error_data_result(message=";".join(duplicate_messages)) @@ -691,13 +699,16 @@ def delete_agent_session(tenant_id, agent_id): if errors: if success_count > 0: - return get_result(data={"success_count": success_count, "errors": errors}, message=f"Partially deleted {success_count} sessions with {len(errors)} errors") + return get_result(data={"success_count": success_count, "errors": errors}, + message=f"Partially deleted {success_count} sessions with {len(errors)} errors") else: return get_error_data_result(message="; ".join(errors)) if duplicate_messages: if success_count > 0: - return get_result(message=f"Partially deleted {success_count} sessions with {len(duplicate_messages)} errors", data={"success_count": success_count, "errors": duplicate_messages}) + return get_result( + message=f"Partially deleted {success_count} 
sessions with {len(duplicate_messages)} errors", + data={"success_count": success_count, "errors": duplicate_messages}) else: return get_error_data_result(message=";".join(duplicate_messages)) @@ -730,7 +741,9 @@ def ask_about(tenant_id): for ans in ask(req["question"], req["kb_ids"], uid): yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n" except Exception as e: - yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n" + yield "data:" + json.dumps( + {"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, + ensure_ascii=False) + "\n\n" yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n" resp = Response(stream(), mimetype="text/event-stream") @@ -882,7 +895,9 @@ def begin_inputs(agent_id): return get_error_data_result(f"Can't find agent by ID: {agent_id}") canvas = Canvas(json.dumps(cvs.dsl), objs[0].tenant_id) - return get_result(data={"title": cvs.title, "avatar": cvs.avatar, "inputs": canvas.get_component_input_form("begin"), "prologue": canvas.get_prologue(), "mode": canvas.get_mode()}) + return get_result( + data={"title": cvs.title, "avatar": cvs.avatar, "inputs": canvas.get_component_input_form("begin"), + "prologue": canvas.get_prologue(), "mode": canvas.get_mode()}) @manager.route("/searchbots/ask", methods=["POST"]) # noqa: F821 @@ -911,7 +926,9 @@ def ask_about_embedded(): for ans in ask(req["question"], req["kb_ids"], uid, search_config=search_config): yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n" except Exception as e: - yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n" + yield "data:" + json.dumps( + {"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, + ensure_ascii=False) + "\n\n" yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n" resp = Response(stream(), mimetype="text/event-stream") @@ -978,7 +995,8 @@ def retrieval_test_embedded(): tenant_ids.append(tenant.tenant_id) break else: - return get_json_result(data=False, message="Only owner of knowledgebase authorized for this operation.", code=settings.RetCode.OPERATING_ERROR) + return get_json_result(data=False, message="Only owner of knowledgebase authorized for this operation.", + code=settings.RetCode.OPERATING_ERROR) e, kb = KnowledgebaseService.get_by_id(kb_ids[0]) if not e: @@ -998,11 +1016,13 @@ def retrieval_test_embedded(): question += keyword_extraction(chat_mdl, question) labels = label_question(question, [kb]) - ranks = settings.retrievaler.retrieval( - question, embd_mdl, tenant_ids, kb_ids, page, size, similarity_threshold, vector_similarity_weight, top, doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"), rank_feature=labels + ranks = settings.retriever.retrieval( + question, embd_mdl, tenant_ids, kb_ids, page, size, similarity_threshold, vector_similarity_weight, top, + doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"), rank_feature=labels ) if use_kg: - ck = settings.kg_retrievaler.retrieval(question, tenant_ids, kb_ids, embd_mdl, LLMBundle(kb.tenant_id, LLMType.CHAT)) + ck = settings.kg_retriever.retrieval(question, tenant_ids, kb_ids, embd_mdl, + LLMBundle(kb.tenant_id, LLMType.CHAT)) if ck["content_with_weight"]: 
ranks["chunks"].insert(0, ck) @@ -1013,7 +1033,8 @@ def retrieval_test_embedded(): return get_json_result(data=ranks) except Exception as e: if str(e).find("not_found") > 0: - return get_json_result(data=False, message="No chunk found! Check the chunk status please!", code=settings.RetCode.DATA_ERROR) + return get_json_result(data=False, message="No chunk found! Check the chunk status please!", + code=settings.RetCode.DATA_ERROR) return server_error_response(e) @@ -1082,7 +1103,8 @@ def detail_share_embedded(): if SearchService.query(tenant_id=tenant.tenant_id, id=search_id): break else: - return get_json_result(data=False, message="Has no permission for this operation.", code=settings.RetCode.OPERATING_ERROR) + return get_json_result(data=False, message="Has no permission for this operation.", + code=settings.RetCode.OPERATING_ERROR) search = SearchService.get_detail(search_id) if not search: diff --git a/api/apps/system_app.py b/api/apps/system_app.py index fa2b5f116..4302813e8 100644 --- a/api/apps/system_app.py +++ b/api/apps/system_app.py @@ -39,6 +39,7 @@ from rag.utils.redis_conn import REDIS_CONN from flask import jsonify from api.utils.health_utils import run_health_checks + @manager.route("/version", methods=["GET"]) # noqa: F821 @login_required def version(): @@ -161,7 +162,7 @@ def status(): task_executors = REDIS_CONN.smembers("TASKEXE") now = datetime.now().timestamp() for task_executor_id in task_executors: - heartbeats = REDIS_CONN.zrangebyscore(task_executor_id, now - 60*30, now) + heartbeats = REDIS_CONN.zrangebyscore(task_executor_id, now - 60 * 30, now) heartbeats = [json.loads(heartbeat) for heartbeat in heartbeats] task_executor_heartbeats[task_executor_id] = heartbeats except Exception: @@ -273,7 +274,8 @@ def token_list(): objs = [o.to_dict() for o in objs] for o in objs: if not o["beta"]: - o["beta"] = generate_confirmation_token(generate_confirmation_token(tenants[0].tenant_id)).replace("ragflow-", "")[:32] + o["beta"] = generate_confirmation_token(generate_confirmation_token(tenants[0].tenant_id)).replace( + "ragflow-", "")[:32] APITokenService.filter_update([APIToken.tenant_id == tenant_id, APIToken.token == o["token"]], o) return get_json_result(data=objs) except Exception as e: diff --git a/api/apps/tenant_app.py b/api/apps/tenant_app.py index 63c7f74b7..10668491e 100644 --- a/api/apps/tenant_app.py +++ b/api/apps/tenant_app.py @@ -70,7 +70,8 @@ def create(tenant_id): return get_data_error_result(message=f"{invite_user_email} is already in the team.") if user_tenant_role == UserTenantRole.OWNER: return get_data_error_result(message=f"{invite_user_email} is the owner of the team.") - return get_data_error_result(message=f"{invite_user_email} is in the team, but the role: {user_tenant_role} is invalid.") + return get_data_error_result( + message=f"{invite_user_email} is in the team, but the role: {user_tenant_role} is invalid.") UserTenantService.save( id=get_uuid(), @@ -132,7 +133,8 @@ def tenant_list(): @login_required def agree(tenant_id): try: - UserTenantService.filter_update([UserTenant.tenant_id == tenant_id, UserTenant.user_id == current_user.id], {"role": UserTenantRole.NORMAL}) + UserTenantService.filter_update([UserTenant.tenant_id == tenant_id, UserTenant.user_id == current_user.id], + {"role": UserTenantRole.NORMAL}) return get_json_result(data=True) except Exception as e: return server_error_response(e) diff --git a/api/common/exceptions.py b/api/common/exceptions.py index 5ce0e0bc2..0790ff4cc 100644 --- a/api/common/exceptions.py +++ 
b/api/common/exceptions.py @@ -1,3 +1,20 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + class AdminException(Exception): def __init__(self, message, code=400): super().__init__(message) @@ -18,4 +35,4 @@ class UserAlreadyExistsError(AdminException): class CannotDeleteAdminError(AdminException): def __init__(self): - super().__init__("Cannot delete admin account", 403) \ No newline at end of file + super().__init__("Cannot delete admin account", 403) diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py index 673000ff9..a8ddf178d 100644 --- a/api/db/services/dialog_service.py +++ b/api/db/services/dialog_service.py @@ -370,7 +370,7 @@ def chat(dialog, messages, stream=True, **kwargs): chat_mdl.bind_tools(toolcall_session, tools) bind_models_ts = timer() - retriever = settings.retrievaler + retriever = settings.retriever questions = [m["content"] for m in messages if m["role"] == "user"][-3:] attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else [] if "doc_ids" in messages[-1]: @@ -472,7 +472,7 @@ def chat(dialog, messages, stream=True, **kwargs): kbinfos["chunks"].extend(tav_res["chunks"]) kbinfos["doc_aggs"].extend(tav_res["doc_aggs"]) if prompt_config.get("use_kg"): - ck = settings.kg_retrievaler.retrieval(" ".join(questions), tenant_ids, dialog.kb_ids, embd_mdl, + ck = settings.kg_retriever.retrieval(" ".join(questions), tenant_ids, dialog.kb_ids, embd_mdl, LLMBundle(dialog.tenant_id, LLMType.CHAT)) if ck["content_with_weight"]: kbinfos["chunks"].insert(0, ck) @@ -658,7 +658,7 @@ Please write the SQL, only SQL, without any other explanations or text. 
logging.debug(f"{question} get SQL(refined): {sql}") tried_times += 1 - return settings.retrievaler.sql_retrieval(sql, format="json"), sql + return settings.retriever.sql_retrieval(sql, format="json"), sql tbl, sql = get_table() if tbl is None: @@ -752,7 +752,7 @@ def ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}): embedding_list = list(set([kb.embd_id for kb in kbs])) is_knowledge_graph = all([kb.parser_id == ParserType.KG for kb in kbs]) - retriever = settings.retrievaler if not is_knowledge_graph else settings.kg_retrievaler + retriever = settings.retriever if not is_knowledge_graph else settings.kg_retriever embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embedding_list[0]) chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, chat_llm_name) @@ -848,7 +848,7 @@ def gen_mindmap(question, kb_ids, tenant_id, search_config={}): if not doc_ids: doc_ids = None - ranks = settings.retrievaler.retrieval( + ranks = settings.retriever.retrieval( question=question, embd_mdl=embd_mdl, tenant_ids=tenant_ids, diff --git a/api/db/services/mcp_server_service.py b/api/db/services/mcp_server_service.py index 101555f4b..1eae882d6 100644 --- a/api/db/services/mcp_server_service.py +++ b/api/db/services/mcp_server_service.py @@ -33,7 +33,8 @@ class MCPServerService(CommonService): @classmethod @DB.connection_context() - def get_servers(cls, tenant_id: str, id_list: list[str] | None, page_number, items_per_page, orderby, desc, keywords): + def get_servers(cls, tenant_id: str, id_list: list[str] | None, page_number, items_per_page, orderby, desc, + keywords): """Retrieve all MCP servers associated with a tenant. This method fetches all MCP servers for a given tenant, ordered by creation time. diff --git a/api/db/services/search_service.py b/api/db/services/search_service.py index acb07da57..de69f2837 100644 --- a/api/db/services/search_service.py +++ b/api/db/services/search_service.py @@ -94,7 +94,8 @@ class SearchService(CommonService): query = ( cls.model.select(*fields) .join(User, on=(cls.model.tenant_id == User.id)) - .where(((cls.model.tenant_id.in_(joined_tenant_ids)) | (cls.model.tenant_id == user_id)) & (cls.model.status == StatusEnum.VALID.value)) + .where(((cls.model.tenant_id.in_(joined_tenant_ids)) | (cls.model.tenant_id == user_id)) & ( + cls.model.status == StatusEnum.VALID.value)) ) if keywords: diff --git a/api/db/services/task_service.py b/api/db/services/task_service.py index f31494b0e..416417339 100644 --- a/api/db/services/task_service.py +++ b/api/db/services/task_service.py @@ -165,7 +165,7 @@ class TaskService(CommonService): ] tasks = ( cls.model.select(*fields).order_by(cls.model.from_page.asc(), cls.model.create_time.desc()) - .where(cls.model.doc_id == doc_id) + .where(cls.model.doc_id == doc_id) ) tasks = list(tasks.dicts()) if not tasks: @@ -205,18 +205,18 @@ class TaskService(CommonService): cls.model.select( *[Document.id, Document.kb_id, Document.location, File.parent_id] ) - .join(Document, on=(cls.model.doc_id == Document.id)) - .join( + .join(Document, on=(cls.model.doc_id == Document.id)) + .join( File2Document, on=(File2Document.document_id == Document.id), join_type=JOIN.LEFT_OUTER, ) - .join( + .join( File, on=(File2Document.file_id == File.id), join_type=JOIN.LEFT_OUTER, ) - .where( + .where( Document.status == StatusEnum.VALID.value, Document.run == TaskStatus.RUNNING.value, ~(Document.type == FileType.VIRTUAL.value), @@ -294,8 +294,8 @@ class TaskService(CommonService): cls.model.update(progress=prog).where( (cls.model.id == id) & ( - 
(cls.model.progress != -1) & - ((prog == -1) | (prog > cls.model.progress)) + (cls.model.progress != -1) & + ((prog == -1) | (prog > cls.model.progress)) ) ).execute() else: @@ -343,6 +343,7 @@ def queue_tasks(doc: dict, bucket: str, name: str, priority: int): - Task digests are calculated for optimization and reuse - Previous task chunks may be reused if available """ + def new_task(): return { "id": get_uuid(), @@ -515,7 +516,7 @@ def queue_dataflow(tenant_id:str, flow_id:str, task_id:str, doc_id:str=CANVAS_DE task["file"] = file if not REDIS_CONN.queue_product( - get_svr_queue_name(priority), message=task + get_svr_queue_name(priority), message=task ): return False, "Can't access Redis. Please check the Redis' status." diff --git a/api/db/services/tenant_llm_service.py b/api/db/services/tenant_llm_service.py index 4eca970ec..6e826ebae 100644 --- a/api/db/services/tenant_llm_service.py +++ b/api/db/services/tenant_llm_service.py @@ -57,8 +57,10 @@ class TenantLLMService(CommonService): @classmethod @DB.connection_context() def get_my_llms(cls, tenant_id): - fields = [cls.model.llm_factory, LLMFactories.logo, LLMFactories.tags, cls.model.model_type, cls.model.llm_name, cls.model.used_tokens] - objs = cls.model.select(*fields).join(LLMFactories, on=(cls.model.llm_factory == LLMFactories.name)).where(cls.model.tenant_id == tenant_id, ~cls.model.api_key.is_null()).dicts() + fields = [cls.model.llm_factory, LLMFactories.logo, LLMFactories.tags, cls.model.model_type, cls.model.llm_name, + cls.model.used_tokens] + objs = cls.model.select(*fields).join(LLMFactories, on=(cls.model.llm_factory == LLMFactories.name)).where( + cls.model.tenant_id == tenant_id, ~cls.model.api_key.is_null()).dicts() return list(objs) @@ -122,7 +124,8 @@ class TenantLLMService(CommonService): model_config = {"llm_factory": llm[0].fid, "api_key": "", "llm_name": mdlnm, "api_base": ""} if not model_config: if mdlnm == "flag-embedding": - model_config = {"llm_factory": "Tongyi-Qianwen", "api_key": "", "llm_name": llm_name, "api_base": ""} + model_config = {"llm_factory": "Tongyi-Qianwen", "api_key": "", "llm_name": llm_name, + "api_base": ""} else: if not mdlnm: raise LookupError(f"Type of {llm_type} model is not set.") @@ -137,27 +140,33 @@ class TenantLLMService(CommonService): if llm_type == LLMType.EMBEDDING.value: if model_config["llm_factory"] not in EmbeddingModel: return - return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"]) + return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], + base_url=model_config["api_base"]) if llm_type == LLMType.RERANK: if model_config["llm_factory"] not in RerankModel: return - return RerankModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"]) + return RerankModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], + base_url=model_config["api_base"]) if llm_type == LLMType.IMAGE2TEXT.value: if model_config["llm_factory"] not in CvModel: return - return CvModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], lang, base_url=model_config["api_base"], **kwargs) + return CvModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], lang, + base_url=model_config["api_base"], **kwargs) if llm_type == LLMType.CHAT.value: if model_config["llm_factory"] not in ChatModel: return - return 
ChatModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"], **kwargs) + return ChatModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], + base_url=model_config["api_base"], **kwargs) if llm_type == LLMType.SPEECH2TEXT: if model_config["llm_factory"] not in Seq2txtModel: return - return Seq2txtModel[model_config["llm_factory"]](key=model_config["api_key"], model_name=model_config["llm_name"], lang=lang, base_url=model_config["api_base"]) + return Seq2txtModel[model_config["llm_factory"]](key=model_config["api_key"], + model_name=model_config["llm_name"], lang=lang, + base_url=model_config["api_base"]) if llm_type == LLMType.TTS: if model_config["llm_factory"] not in TTSModel: return @@ -194,11 +203,14 @@ class TenantLLMService(CommonService): try: num = ( cls.model.update(used_tokens=cls.model.used_tokens + used_tokens) - .where(cls.model.tenant_id == tenant_id, cls.model.llm_name == llm_name, cls.model.llm_factory == llm_factory if llm_factory else True) + .where(cls.model.tenant_id == tenant_id, cls.model.llm_name == llm_name, + cls.model.llm_factory == llm_factory if llm_factory else True) .execute() ) except Exception: - logging.exception("TenantLLMService.increase_usage got exception,Failed to update used_tokens for tenant_id=%s, llm_name=%s", tenant_id, llm_name) + logging.exception( + "TenantLLMService.increase_usage got exception,Failed to update used_tokens for tenant_id=%s, llm_name=%s", + tenant_id, llm_name) return 0 return num @@ -206,7 +218,9 @@ class TenantLLMService(CommonService): @classmethod @DB.connection_context() def get_openai_models(cls): - objs = cls.model.select().where((cls.model.llm_factory == "OpenAI"), ~(cls.model.llm_name == "text-embedding-3-small"), ~(cls.model.llm_name == "text-embedding-3-large")).dicts() + objs = cls.model.select().where((cls.model.llm_factory == "OpenAI"), + ~(cls.model.llm_name == "text-embedding-3-small"), + ~(cls.model.llm_name == "text-embedding-3-large")).dicts() return list(objs) @classmethod @@ -250,8 +264,9 @@ class LLM4Tenant: langfuse_keys = TenantLangfuseService.filter_by_tenant(tenant_id=tenant_id) self.langfuse = None if langfuse_keys: - langfuse = Langfuse(public_key=langfuse_keys.public_key, secret_key=langfuse_keys.secret_key, host=langfuse_keys.host) + langfuse = Langfuse(public_key=langfuse_keys.public_key, secret_key=langfuse_keys.secret_key, + host=langfuse_keys.host) if langfuse.auth_check(): self.langfuse = langfuse trace_id = self.langfuse.create_trace_id() - self.trace_context = {"trace_id": trace_id} \ No newline at end of file + self.trace_context = {"trace_id": trace_id} diff --git a/api/db/services/user_canvas_version.py b/api/db/services/user_canvas_version.py index 9696a7834..89f73264f 100644 --- a/api/db/services/user_canvas_version.py +++ b/api/db/services/user_canvas_version.py @@ -2,22 +2,22 @@ from api.db.db_models import UserCanvasVersion, DB from api.db.services.common_service import CommonService from peewee import DoesNotExist + class UserCanvasVersionService(CommonService): model = UserCanvasVersion - - + @classmethod @DB.connection_context() def list_by_canvas_id(cls, user_canvas_id): try: user_canvas_version = cls.model.select( - *[cls.model.id, - cls.model.create_time, - cls.model.title, - cls.model.create_date, - cls.model.update_date, - cls.model.user_canvas_id, - cls.model.update_time] + *[cls.model.id, + cls.model.create_time, + cls.model.title, + cls.model.create_date, + cls.model.update_date, 
+ cls.model.user_canvas_id, + cls.model.update_time] ).where(cls.model.user_canvas_id == user_canvas_id) return user_canvas_version except DoesNotExist: @@ -46,18 +46,16 @@ class UserCanvasVersionService(CommonService): @DB.connection_context() def delete_all_versions(cls, user_canvas_id): try: - user_canvas_version = cls.model.select().where(cls.model.user_canvas_id == user_canvas_id).order_by(cls.model.create_time.desc()) + user_canvas_version = cls.model.select().where(cls.model.user_canvas_id == user_canvas_id).order_by( + cls.model.create_time.desc()) if user_canvas_version.count() > 20: delete_ids = [] for i in range(20, user_canvas_version.count()): delete_ids.append(user_canvas_version[i].id) - + cls.delete_by_ids(delete_ids) return True except DoesNotExist: return None except Exception: return None - - - diff --git a/api/settings.py b/api/settings.py index e6763d8a2..9c003d27b 100644 --- a/api/settings.py +++ b/api/settings.py @@ -65,8 +65,8 @@ OAUTH_CONFIG = None DOC_ENGINE = None docStoreConn = None -retrievaler = None -kg_retrievaler = None +retriever = None +kg_retriever = None # user registration switch REGISTER_ENABLED = 1 @@ -174,7 +174,7 @@ def init_settings(): OAUTH_CONFIG = get_base_config("oauth", {}) - global DOC_ENGINE, docStoreConn, retrievaler, kg_retrievaler + global DOC_ENGINE, docStoreConn, retriever, kg_retriever DOC_ENGINE = os.environ.get("DOC_ENGINE", "elasticsearch") # DOC_ENGINE = os.environ.get('DOC_ENGINE', "opensearch") lower_case_doc_engine = DOC_ENGINE.lower() @@ -187,10 +187,10 @@ def init_settings(): else: raise Exception(f"Not supported doc engine: {DOC_ENGINE}") - retrievaler = search.Dealer(docStoreConn) + retriever = search.Dealer(docStoreConn) from graphrag import search as kg_search - kg_retrievaler = kg_search.KGSearch(docStoreConn) + kg_retriever = kg_search.KGSearch(docStoreConn) if int(os.environ.get("SANDBOX_ENABLED", "0")): global SANDBOX_HOST diff --git a/api/utils/api_utils.py b/api/utils/api_utils.py index 821ec6d31..c9e7ae455 100644 --- a/api/utils/api_utils.py +++ b/api/utils/api_utils.py @@ -60,6 +60,7 @@ from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_ requests.models.complexjson.dumps = functools.partial(json.dumps, cls=CustomJSONEncoder) + def serialize_for_json(obj): """ Recursively serialize objects to make them JSON serializable. 
@@ -68,8 +69,8 @@ def serialize_for_json(obj):
     if hasattr(obj, '__dict__'):
         # For objects with __dict__, try to serialize their attributes
         try:
-            return {key: serialize_for_json(value) for key, value in obj.__dict__.items()
-                    if not key.startswith('_')}
+            return {key: serialize_for_json(value) for key, value in obj.__dict__.items()
+                    if not key.startswith('_')}
         except (AttributeError, TypeError):
             return str(obj)
     elif hasattr(obj, '__name__'):
@@ -85,6 +86,7 @@ def serialize_for_json(obj):
     # Fallback: convert to string representation
     return str(obj)

+
 def request(**kwargs):
     sess = requests.Session()
     stream = kwargs.pop("stream", sess.stream)
@@ -105,7 +107,8 @@ def request(**kwargs):
                     settings.HTTP_APP_KEY.encode("ascii"),
                     prepped.path_url.encode("ascii"),
                     prepped.body if kwargs.get("json") else b"",
-                    urlencode(sorted(kwargs["data"].items()), quote_via=quote, safe="-._~").encode("ascii") if kwargs.get("data") and isinstance(kwargs["data"], dict) else b"",
+                    urlencode(sorted(kwargs["data"].items()), quote_via=quote, safe="-._~").encode(
+                        "ascii") if kwargs.get("data") and isinstance(kwargs["data"], dict) else b"",
                 ]
             ),
             "sha1",
@@ -127,7 +130,7 @@ def request(**kwargs):
 def get_exponential_backoff_interval(retries, full_jitter=False):
     """Calculate the exponential backoff wait time."""
     # Will be zero if factor equals 0
-    countdown = min(REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC * (2**retries))
+    countdown = min(REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC * (2 ** retries))
     # Full jitter according to
     # https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/
     if full_jitter:
@@ -158,11 +161,12 @@ def server_error_response(e):
     if len(e.args) > 1:
         try:
             serialized_data = serialize_for_json(e.args[1])
-            return get_json_result(code= settings.RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=serialized_data)
+            return get_json_result(code=settings.RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=serialized_data)
         except Exception:
             return get_json_result(code=settings.RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=None)
     if repr(e).find("index_not_found_exception") >= 0:
-        return get_json_result(code=settings.RetCode.EXCEPTION_ERROR, message="No chunk found, please upload file and parse it.")
+        return get_json_result(code=settings.RetCode.EXCEPTION_ERROR,
+                               message="No chunk found, please upload file and parse it.")

     return get_json_result(code=settings.RetCode.EXCEPTION_ERROR, message=repr(e))

@@ -207,7 +211,8 @@ def validate_request(*args, **kwargs):
             if no_arguments:
                 error_string += "required argument are missing: {}; ".format(",".join(no_arguments))
             if error_arguments:
-                error_string += "required argument values: {}".format(",".join(["{}={}".format(a[0], a[1]) for a in error_arguments]))
+                error_string += "required argument values: {}".format(
+                    ",".join(["{}={}".format(a[0], a[1]) for a in error_arguments]))
             return get_json_result(code=settings.RetCode.ARGUMENT_ERROR, message=error_string)
         return func(*_args, **_kwargs)

@@ -222,7 +227,8 @@ def not_allowed_parameters(*params):
         input_arguments = flask_request.json or flask_request.form.to_dict()
         for param in params:
             if param in input_arguments:
-                return get_json_result(code=settings.RetCode.ARGUMENT_ERROR, message=f"Parameter {param} isn't allowed")
+                return get_json_result(code=settings.RetCode.ARGUMENT_ERROR,
+                                       message=f"Parameter {param} isn't allowed")
         return f(*args, **kwargs)

     return wrapper
@@ -239,6 +245,7 @@ def active_required(f):
         if not usr or not usr.is_active == ActiveEnum.ACTIVE.value:
            return get_json_result(code=settings.RetCode.FORBIDDEN, message="User isn't active, please activate first.")
         return f(*args, **kwargs)
+
     return wrapper

@@ -259,7 +266,7 @@ def send_file_in_mem(data, filename):
     return send_file(f, as_attachment=True, attachment_filename=filename)


-def get_json_result(code=settings.RetCode.SUCCESS, message="success", data=None):
+def get_json_result(code: settings.RetCode = settings.RetCode.SUCCESS, message="success", data=None):
     response = {"code": code, "message": message, "data": data}
     return jsonify(response)

@@ -314,7 +321,7 @@ def construct_result(code=settings.RetCode.DATA_ERROR, message="data is missing"
     return jsonify(response)


-def construct_json_result(code=settings.RetCode.SUCCESS, message="success", data=None):
+def construct_json_result(code: settings.RetCode = settings.RetCode.SUCCESS, message="success", data=None):
     if data is None:
         return jsonify({"code": code, "message": message})
     else:
@@ -347,14 +354,15 @@ def token_required(func):
         token = authorization_list[1]
         objs = APIToken.query(token=token)
         if not objs:
-            return get_json_result(data=False, message="Authentication error: API key is invalid!", code=settings.RetCode.AUTHENTICATION_ERROR)
+            return get_json_result(data=False, message="Authentication error: API key is invalid!",
+                                   code=settings.RetCode.AUTHENTICATION_ERROR)
         kwargs["tenant_id"] = objs[0].tenant_id
         return func(*args, **kwargs)

     return decorated_function


-def get_result(code=settings.RetCode.SUCCESS, message="", data=None):
+def get_result(code: settings.RetCode = settings.RetCode.SUCCESS, message="", data=None):
     if code == 0:
         if data is not None:
             response = {"code": code, "data": data}
@@ -366,8 +374,8 @@ def get_result(code=settings.RetCode.SUCCESS, message="", data=None):


 def get_error_data_result(
-    message="Sorry! Data missing!",
-    code=settings.RetCode.DATA_ERROR,
+        message="Sorry! Data missing!",
Data missing!", + code=settings.RetCode.DATA_ERROR, ): result_dict = {"code": code, "message": message} response = {} @@ -402,7 +410,8 @@ def get_parser_config(chunk_method, parser_config): # Define default configurations for each chunking method key_mapping = { - "naive": {"chunk_token_num": 512, "delimiter": r"\n", "html4excel": False, "layout_recognize": "DeepDOC", "raptor": {"use_raptor": False}, "graphrag": {"use_graphrag": False}}, + "naive": {"chunk_token_num": 512, "delimiter": r"\n", "html4excel": False, "layout_recognize": "DeepDOC", + "raptor": {"use_raptor": False}, "graphrag": {"use_graphrag": False}}, "qa": {"raptor": {"use_raptor": False}, "graphrag": {"use_graphrag": False}}, "tag": None, "resume": None, @@ -441,16 +450,16 @@ def get_parser_config(chunk_method, parser_config): def get_data_openai( - id=None, - created=None, - model=None, - prompt_tokens=0, - completion_tokens=0, - content=None, - finish_reason=None, - object="chat.completion", - param=None, - stream=False + id=None, + created=None, + model=None, + prompt_tokens=0, + completion_tokens=0, + content=None, + finish_reason=None, + object="chat.completion", + param=None, + stream=False ): total_tokens = prompt_tokens + completion_tokens @@ -562,7 +571,9 @@ def verify_embedding_availability(embd_id: str, tenant_id: str) -> tuple[bool, R in_llm_service = bool(LLMService.query(llm_name=llm_name, fid=llm_factory, model_type="embedding")) tenant_llms = TenantLLMService.get_my_llms(tenant_id=tenant_id) - is_tenant_model = any(llm["llm_name"] == llm_name and llm["llm_factory"] == llm_factory and llm["model_type"] == "embedding" for llm in tenant_llms) + is_tenant_model = any( + llm["llm_name"] == llm_name and llm["llm_factory"] == llm_factory and llm["model_type"] == "embedding" for + llm in tenant_llms) is_builtin_model = embd_id in settings.BUILTIN_EMBEDDING_MODELS if not (is_builtin_model or is_tenant_model or in_llm_service): @@ -793,7 +804,9 @@ async def is_strong_enough(chat_model, embedding_model): _ = await trio.to_thread.run_sync(lambda: embedding_model.encode(["Are you strong enough!?"])) if chat_model: with trio.fail_after(30): - res = await trio.to_thread.run_sync(lambda: chat_model.chat("Nothing special.", [{"role": "user", "content": "Are you strong enough!?"}], {})) + res = await trio.to_thread.run_sync(lambda: chat_model.chat("Nothing special.", [{"role": "user", + "content": "Are you strong enough!?"}], + {})) if res.find("**ERROR**") >= 0: raise Exception(res) diff --git a/graphrag/general/index.py b/graphrag/general/index.py index 6d0df65bb..7cb47de12 100644 --- a/graphrag/general/index.py +++ b/graphrag/general/index.py @@ -55,7 +55,7 @@ async def run_graphrag( start = trio.current_time() tenant_id, kb_id, doc_id = row["tenant_id"], str(row["kb_id"]), row["doc_id"] chunks = [] - for d in settings.retrievaler.chunk_list(doc_id, tenant_id, [kb_id], fields=["content_with_weight", "doc_id"], sort_by_position=True): + for d in settings.retriever.chunk_list(doc_id, tenant_id, [kb_id], fields=["content_with_weight", "doc_id"], sort_by_position=True): chunks.append(d["content_with_weight"]) with trio.fail_after(max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000): @@ -170,7 +170,7 @@ async def run_graphrag_for_kb( chunks = [] current_chunk = "" - for d in settings.retrievaler.chunk_list( + for d in settings.retriever.chunk_list( doc_id, tenant_id, [kb_id], diff --git a/graphrag/general/smoke.py b/graphrag/general/smoke.py index 3f282fb07..5f9fe1437 100644 --- 
+++ b/graphrag/general/smoke.py
@@ -62,7 +62,7 @@ async def main():

     chunks = [
         d["content_with_weight"]
-        for d in settings.retrievaler.chunk_list(
+        for d in settings.retriever.chunk_list(
             args.doc_id,
             args.tenant_id,
             [kb_id],
diff --git a/graphrag/light/smoke.py b/graphrag/light/smoke.py
index 504f09ce7..f8f505f65 100644
--- a/graphrag/light/smoke.py
+++ b/graphrag/light/smoke.py
@@ -63,7 +63,7 @@ async def main():

     chunks = [
         d["content_with_weight"]
-        for d in settings.retrievaler.chunk_list(
+        for d in settings.retriever.chunk_list(
             args.doc_id,
             args.tenant_id,
             [kb_id],
diff --git a/graphrag/utils.py b/graphrag/utils.py
index 6abe5f9a9..5e64cdb11 100644
--- a/graphrag/utils.py
+++ b/graphrag/utils.py
@@ -341,7 +341,7 @@ def get_relation(tenant_id, kb_id, from_ent_name, to_ent_name, size=1):
     ents = list(set(ents))
     conds = {"fields": ["content_with_weight"], "size": size, "from_entity_kwd": ents, "to_entity_kwd": ents, "knowledge_graph_kwd": ["relation"]}
     res = []
-    es_res = settings.retrievaler.search(conds, search.index_name(tenant_id), [kb_id] if isinstance(kb_id, str) else kb_id)
+    es_res = settings.retriever.search(conds, search.index_name(tenant_id), [kb_id] if isinstance(kb_id, str) else kb_id)
     for id in es_res.ids:
         try:
             if size == 1:
@@ -398,7 +398,7 @@ async def does_graph_contains(tenant_id, kb_id, doc_id):

 async def get_graph_doc_ids(tenant_id, kb_id) -> list[str]:
     conds = {"fields": ["source_id"], "removed_kwd": "N", "size": 1, "knowledge_graph_kwd": ["graph"]}
-    res = await trio.to_thread.run_sync(lambda: settings.retrievaler.search(conds, search.index_name(tenant_id), [kb_id]))
+    res = await trio.to_thread.run_sync(lambda: settings.retriever.search(conds, search.index_name(tenant_id), [kb_id]))
     doc_ids = []
     if res.total == 0:
         return doc_ids
@@ -409,7 +409,7 @@ async def get_graph_doc_ids(tenant_id, kb_id) -> list[str]:

 async def get_graph(tenant_id, kb_id, exclude_rebuild=None):
     conds = {"fields": ["content_with_weight", "removed_kwd", "source_id"], "size": 1, "knowledge_graph_kwd": ["graph"]}
-    res = await trio.to_thread.run_sync(settings.retrievaler.search, conds, search.index_name(tenant_id), [kb_id])
+    res = await trio.to_thread.run_sync(settings.retriever.search, conds, search.index_name(tenant_id), [kb_id])
     if not res.total == 0:
         for id in res.ids:
             try:
@@ -562,7 +562,7 @@ def merge_tuples(list1, list2):


 async def get_entity_type2samples(idxnms, kb_ids: list):
-    es_res = await trio.to_thread.run_sync(lambda: settings.retrievaler.search({"knowledge_graph_kwd": "ty2ents", "kb_id": kb_ids, "size": 10000, "fields": ["content_with_weight"]}, idxnms, kb_ids))
+    es_res = await trio.to_thread.run_sync(lambda: settings.retriever.search({"knowledge_graph_kwd": "ty2ents", "kb_id": kb_ids, "size": 10000, "fields": ["content_with_weight"]}, idxnms, kb_ids))

     res = defaultdict(list)
     for id in es_res.ids:
diff --git a/rag/app/tag.py b/rag/app/tag.py
index de2ce0fa6..e1a675652 100644
--- a/rag/app/tag.py
+++ b/rag/app/tag.py
@@ -133,14 +133,14 @@ def label_question(question, kbs):
         if tag_kb_ids:
             all_tags = get_tags_from_cache(tag_kb_ids)
             if not all_tags:
-                all_tags = settings.retrievaler.all_tags_in_portion(kb.tenant_id, tag_kb_ids)
+                all_tags = settings.retriever.all_tags_in_portion(kb.tenant_id, tag_kb_ids)
                 set_tags_to_cache(tags=all_tags, kb_ids=tag_kb_ids)
             else:
                 all_tags = json.loads(all_tags)
             tag_kbs = KnowledgebaseService.get_by_ids(tag_kb_ids)
             if not tag_kbs:
                 return tags
-            tags = settings.retrievaler.tag_query(question,
+            tags = settings.retriever.tag_query(question,
                                                  list(set([kb.tenant_id for kb in tag_kbs])),
                                                  tag_kb_ids,
                                                  all_tags,
diff --git a/rag/benchmark.py b/rag/benchmark.py
index 31a6c92ff..b73830073 100644
--- a/rag/benchmark.py
+++ b/rag/benchmark.py
@@ -52,7 +52,7 @@ class Benchmark:
         run = defaultdict(dict)
         query_list = list(qrels.keys())
         for query in query_list:
-            ranks = settings.retrievaler.retrieval(query, self.embd_mdl, self.tenant_id, [self.kb.id], 1, 30,
+            ranks = settings.retriever.retrieval(query, self.embd_mdl, self.tenant_id, [self.kb.id], 1, 30,
                                                    0.0, self.vector_similarity_weight)
             if len(ranks["chunks"]) == 0:
                 print(f"deleted query: {query}")
diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py
index eb03a56b8..ae1021486 100644
--- a/rag/svr/task_executor.py
+++ b/rag/svr/task_executor.py
@@ -380,7 +380,7 @@ async def build_chunks(task, progress_callback):
         examples = []
         all_tags = get_tags_from_cache(kb_ids)
         if not all_tags:
-            all_tags = settings.retrievaler.all_tags_in_portion(tenant_id, kb_ids, S)
+            all_tags = settings.retriever.all_tags_in_portion(tenant_id, kb_ids, S)
             set_tags_to_cache(kb_ids, all_tags)
         else:
             all_tags = json.loads(all_tags)
@@ -393,7 +393,7 @@ async def build_chunks(task, progress_callback):
             if task_canceled:
                 progress_callback(-1, msg="Task has been canceled.")
                 return
-            if settings.retrievaler.tag_content(tenant_id, kb_ids, d, all_tags, topn_tags=topn_tags, S=S) and len(d[TAG_FLD]) > 0:
+            if settings.retriever.tag_content(tenant_id, kb_ids, d, all_tags, topn_tags=topn_tags, S=S) and len(d[TAG_FLD]) > 0:
                 examples.append({"content": d["content_with_weight"], TAG_FLD: d[TAG_FLD]})
             else:
                 docs_to_tag.append(d)
@@ -645,7 +645,7 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si

     chunks = []
     vctr_nm = "q_%d_vec"%vector_size
     for doc_id in doc_ids:
-        for d in settings.retrievaler.chunk_list(doc_id, row["tenant_id"], [str(row["kb_id"])],
+        for d in settings.retriever.chunk_list(doc_id, row["tenant_id"], [str(row["kb_id"])],
                                                  fields=["content_with_weight", vctr_nm], sort_by_position=True):
             chunks.append((d["content_with_weight"], np.array(d[vctr_nm])))