diff --git a/agent/tools/retrieval.py b/agent/tools/retrieval.py
index 22909a28a..93a50c0c3 100644
--- a/agent/tools/retrieval.py
+++ b/agent/tools/retrieval.py
@@ -25,6 +25,7 @@ from api.db.services.dialog_service import meta_filter
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.llm_service import LLMBundle
 from api import settings
+from common import globals
 from common.connection_utils import timeout
 from rag.app.tag import label_question
 from rag.prompts.generator import cross_languages, kb_prompt, gen_meta_filter
@@ -170,7 +171,7 @@ class Retrieval(ToolBase, ABC):
         if kbs:
             query = re.sub(r"^user[::\s]*", "", query, flags=re.IGNORECASE)
-            kbinfos = settings.retriever.retrieval(
+            kbinfos = globals.retriever.retrieval(
                 query,
                 embd_mdl,
                 [kb.tenant_id for kb in kbs],
@@ -186,7 +187,7 @@ class Retrieval(ToolBase, ABC):
             )
             if self._param.toc_enhance:
                 chat_mdl = LLMBundle(self._canvas._tenant_id, LLMType.CHAT)
-                cks = settings.retriever.retrieval_by_toc(query, kbinfos["chunks"], [kb.tenant_id for kb in kbs], chat_mdl, self._param.top_n)
+                cks = globals.retriever.retrieval_by_toc(query, kbinfos["chunks"], [kb.tenant_id for kb in kbs], chat_mdl, self._param.top_n)
                 if cks:
                     kbinfos["chunks"] = cks
             if self._param.use_kg:
diff --git a/api/apps/api_app.py b/api/apps/api_app.py
index 306b54268..40341c5f7 100644
--- a/api/apps/api_app.py
+++ b/api/apps/api_app.py
@@ -32,7 +32,6 @@ from api.db.services.file_service import FileService
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.task_service import queue_tasks, TaskService
 from api.db.services.user_service import UserTenantService
-from api import settings
 from common.misc_utils import get_uuid
 from common.constants import RetCode, VALID_TASK_STATUS, LLMType, ParserType, FileSource
 from api.utils.api_utils import server_error_response, get_data_error_result, get_json_result, validate_request, \
@@ -48,6 +47,7 @@ from api.db.services.canvas_service import UserCanvasService
 from agent.canvas import Canvas
 from functools import partial
 from pathlib import Path
+from common import globals
 @manager.route('/new_token', methods=['POST'])  # noqa: F821
@@ -538,7 +538,7 @@ def list_chunks():
         )
     kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
-    res = settings.retriever.chunk_list(doc_id, tenant_id, kb_ids)
+    res = globals.retriever.chunk_list(doc_id, tenant_id, kb_ids)
     res = [
         {
             "content": res_item["content_with_weight"],
@@ -564,7 +564,7 @@ def get_chunk(chunk_id):
     try:
         tenant_id = objs[0].tenant_id
         kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
-        chunk = settings.docStoreConn.get(chunk_id, search.index_name(tenant_id), kb_ids)
+        chunk = globals.docStoreConn.get(chunk_id, search.index_name(tenant_id), kb_ids)
         if chunk is None:
             return server_error_response(Exception("Chunk not found"))
         k = []
@@ -886,7 +886,7 @@ def retrieval():
         if req.get("keyword", False):
             chat_mdl = LLMBundle(kbs[0].tenant_id, LLMType.CHAT)
             question += keyword_extraction(chat_mdl, question)
-        ranks = settings.retriever.retrieval(question, embd_mdl, kbs[0].tenant_id, kb_ids, page, size,
+        ranks = globals.retriever.retrieval(question, embd_mdl, kbs[0].tenant_id, kb_ids, page, size,
                                              similarity_threshold, vector_similarity_weight, top,
                                              doc_ids, rerank_mdl=rerank_mdl, highlight= highlight,
                                              rank_feature=label_question(question, kbs))
diff --git a/api/apps/canvas_app.py b/api/apps/canvas_app.py
index 96147d65c..eb3ee956a 100644
--- a/api/apps/canvas_app.py
+++ b/api/apps/canvas_app.py
@@ -25,7 +25,6 @@
 from flask import request, Response
 from flask_login import login_required, current_user
 from agent.component import LLM
-from api import settings
 from api.db import CanvasCategory, FileType
 from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService, API4ConversationService
 from api.db.services.document_service import DocumentService
@@ -46,6 +45,7 @@ from api.utils.file_utils import filename_type, read_potential_broken_pdf
 from rag.flow.pipeline import Pipeline
 from rag.nlp import search
 from rag.utils.redis_conn import REDIS_CONN
+from common import globals
 @manager.route('/templates', methods=['GET'])  # noqa: F821
@@ -192,8 +192,8 @@ def rerun():
         if 0 < doc["progress"] < 1:
             return get_data_error_result(message=f"`{doc['name']}` is processing...")
-        if settings.docStoreConn.indexExist(search.index_name(current_user.id), doc["kb_id"]):
-            settings.docStoreConn.delete({"doc_id": doc["id"]}, search.index_name(current_user.id), doc["kb_id"])
+        if globals.docStoreConn.indexExist(search.index_name(current_user.id), doc["kb_id"]):
+            globals.docStoreConn.delete({"doc_id": doc["id"]}, search.index_name(current_user.id), doc["kb_id"])
         doc["progress_msg"] = ""
         doc["chunk_num"] = 0
         doc["token_num"] = 0
diff --git a/api/apps/chunk_app.py b/api/apps/chunk_app.py
index 89e59400b..0f5211240 100644
--- a/api/apps/chunk_app.py
+++ b/api/apps/chunk_app.py
@@ -36,6 +36,7 @@ from rag.prompts.generator import gen_meta_filter, cross_languages, keyword_extr
 from rag.settings import PAGERANK_FLD
 from common.string_utils import remove_redundant_spaces
 from common.constants import RetCode, LLMType, ParserType
+from common import globals
 @manager.route('/list', methods=['POST'])  # noqa: F821
@@ -60,7 +61,7 @@ def list_chunk():
         }
         if "available_int" in req:
             query["available_int"] = int(req["available_int"])
-        sres = settings.retriever.search(query, search.index_name(tenant_id), kb_ids, highlight=["content_ltks"])
+        sres = globals.retriever.search(query, search.index_name(tenant_id), kb_ids, highlight=["content_ltks"])
         res = {"total": sres.total, "chunks": [], "doc": doc.to_dict()}
         for id in sres.ids:
             d = {
@@ -98,7 +99,7 @@ def get():
             return get_data_error_result(message="Tenant not found!")
         for tenant in tenants:
             kb_ids = KnowledgebaseService.get_kb_ids(tenant.tenant_id)
-            chunk = settings.docStoreConn.get(chunk_id, search.index_name(tenant.tenant_id), kb_ids)
+            chunk = globals.docStoreConn.get(chunk_id, search.index_name(tenant.tenant_id), kb_ids)
             if chunk:
                 break
         if chunk is None:
@@ -170,7 +171,7 @@ def set():
         v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not d.get("question_kwd") else "\n".join(d["question_kwd"])])
         v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
         d["q_%d_vec" % len(v)] = v.tolist()
-        settings.docStoreConn.update({"id": req["chunk_id"]}, d, search.index_name(tenant_id), doc.kb_id)
+        globals.docStoreConn.update({"id": req["chunk_id"]}, d, search.index_name(tenant_id), doc.kb_id)
         return get_json_result(data=True)
     except Exception as e:
         return server_error_response(e)
@@ -186,7 +187,7 @@ def switch():
         if not e:
             return get_data_error_result(message="Document not found!")
         for cid in req["chunk_ids"]:
-            if not settings.docStoreConn.update({"id": cid},
+            if not globals.docStoreConn.update({"id": cid},
                                                 {"available_int": int(req["available_int"])},
                                                 search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
                                                 doc.kb_id):
@@ -206,7 +207,7 @@ def rm():
         e, doc = DocumentService.get_by_id(req["doc_id"])
         if not e:
             return get_data_error_result(message="Document not found!")
-        if not settings.docStoreConn.delete({"id": req["chunk_ids"]},
+        if not globals.docStoreConn.delete({"id": req["chunk_ids"]},
                                             search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
                                             doc.kb_id):
             return get_data_error_result(message="Chunk deleting failure")
@@ -270,7 +271,7 @@ def create():
         v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not d["question_kwd"] else "\n".join(d["question_kwd"])])
         v = 0.1 * v[0] + 0.9 * v[1]
         d["q_%d_vec" % len(v)] = v.tolist()
-        settings.docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
+        globals.docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
         DocumentService.increment_chunk_num(
             doc.id, doc.kb_id, c, 1, 0)
@@ -346,7 +347,7 @@ def retrieval_test():
             question += keyword_extraction(chat_mdl, question)
         labels = label_question(question, [kb])
-        ranks = settings.retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids, page, size,
+        ranks = globals.retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids, page, size,
                                              float(req.get("similarity_threshold", 0.0)),
                                              float(req.get("vector_similarity_weight", 0.3)),
                                              top,
@@ -385,7 +386,7 @@ def knowledge_graph():
             "doc_ids": [doc_id],
             "knowledge_graph_kwd": ["graph", "mind_map"]
         }
-        sres = settings.retriever.search(req, search.index_name(tenant_id), kb_ids)
+        sres = globals.retriever.search(req, search.index_name(tenant_id), kb_ids)
         obj = {"graph": {}, "mind_map": {}}
         for id in sres.ids[:2]:
             ty = sres.field[id]["knowledge_graph_kwd"]
diff --git a/api/apps/document_app.py b/api/apps/document_app.py
index 3d6a02710..b31535efd 100644
--- a/api/apps/document_app.py
+++ b/api/apps/document_app.py
@@ -23,7 +23,6 @@ import flask
 from flask import request
 from flask_login import current_user, login_required
-from api import settings
 from api.common.check_team_permission import check_kb_team_permission
 from api.constants import FILE_NAME_LEN_LIMIT, IMG_BASE64_PREFIX
 from api.db import VALID_FILE_TYPES, FileType
@@ -49,6 +48,7 @@ from api.utils.web_utils import CONTENT_TYPE_MAP, html2pdf, is_valid_url
 from deepdoc.parser.html_parser import RAGFlowHtmlParser
 from rag.nlp import search, rag_tokenizer
 from rag.utils.storage_factory import STORAGE_IMPL
+from common import globals
 @manager.route("/upload", methods=["POST"])  # noqa: F821
@@ -367,7 +367,7 @@ def change_status():
                 continue
             status_int = int(status)
-            if not settings.docStoreConn.update({"doc_id": doc_id}, {"available_int": status_int}, search.index_name(kb.tenant_id), doc.kb_id):
+            if not globals.docStoreConn.update({"doc_id": doc_id}, {"available_int": status_int}, search.index_name(kb.tenant_id), doc.kb_id):
                 result[doc_id] = {"error": "Database error (docStore update)!"}
             result[doc_id] = {"status": status}
         except Exception as e:
@@ -432,8 +432,8 @@ def run():
             DocumentService.update_by_id(id, info)
             if req.get("delete", False):
                 TaskService.filter_delete([Task.doc_id == id])
-                if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
-                    settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id)
+                if globals.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
+                    globals.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id)
             if str(req["run"]) == TaskStatus.RUNNING.value:
                 doc = doc.to_dict()
@@ -479,8 +479,8 @@ def rename():
                 "title_tks": title_tks,
                 "title_sm_tks": rag_tokenizer.fine_grained_tokenize(title_tks),
             }
-            if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
-                settings.docStoreConn.update(
+            if globals.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
+                globals.docStoreConn.update(
                     {"doc_id": req["doc_id"]},
                     es_body,
                     search.index_name(tenant_id),
@@ -541,8 +541,8 @@ def change_parser():
         tenant_id = DocumentService.get_tenant_id(req["doc_id"])
         if not tenant_id:
             return get_data_error_result(message="Tenant not found!")
-        if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
-            settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
+        if globals.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
+            globals.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
         try:
             if "pipeline_id" in req and req["pipeline_id"] != "":
diff --git a/api/apps/kb_app.py b/api/apps/kb_app.py
index 19f5fc8fa..28c2fa31e 100644
--- a/api/apps/kb_app.py
+++ b/api/apps/kb_app.py
@@ -35,7 +35,6 @@ from api.db import VALID_FILE_TYPES
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.db_models import File
 from api.utils.api_utils import get_json_result
-from api import settings
 from rag.nlp import search
 from api.constants import DATASET_NAME_LIMIT
 from rag.settings import PAGERANK_FLD
@@ -43,7 +42,7 @@ from rag.utils.redis_conn import REDIS_CONN
 from rag.utils.storage_factory import STORAGE_IMPL
 from rag.utils.doc_store_conn import OrderByExpr
 from common.constants import RetCode, PipelineTaskType, StatusEnum, VALID_TASK_STATUS, FileSource, LLMType
-
+from common import globals
 @manager.route('/create', methods=['post'])  # noqa: F821
 @login_required
@@ -110,11 +109,11 @@ def update():
         if kb.pagerank != req.get("pagerank", 0):
             if req.get("pagerank", 0) > 0:
-                settings.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]},
+                globals.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]},
                                              search.index_name(kb.tenant_id), kb.id)
             else:
                 # Elasticsearch requires PAGERANK_FLD be non-zero!
-                settings.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD},
+                globals.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD},
                                              search.index_name(kb.tenant_id), kb.id)
         e, kb = KnowledgebaseService.get_by_id(kb.id)
@@ -226,8 +225,8 @@ def rm():
             return get_data_error_result(
                 message="Database error (Knowledgebase removal)!")
         for kb in kbs:
-            settings.docStoreConn.delete({"kb_id": kb.id}, search.index_name(kb.tenant_id), kb.id)
-            settings.docStoreConn.deleteIdx(search.index_name(kb.tenant_id), kb.id)
+            globals.docStoreConn.delete({"kb_id": kb.id}, search.index_name(kb.tenant_id), kb.id)
+            globals.docStoreConn.deleteIdx(search.index_name(kb.tenant_id), kb.id)
             if hasattr(STORAGE_IMPL, 'remove_bucket'):
                 STORAGE_IMPL.remove_bucket(kb.id)
         return get_json_result(data=True)
@@ -248,7 +247,7 @@ def list_tags(kb_id):
     tenants = UserTenantService.get_tenants_by_user_id(current_user.id)
     tags = []
     for tenant in tenants:
-        tags += settings.retriever.all_tags(tenant["tenant_id"], [kb_id])
+        tags += globals.retriever.all_tags(tenant["tenant_id"], [kb_id])
     return get_json_result(data=tags)
@@ -267,7 +266,7 @@ def list_tags_from_kbs():
     tenants = UserTenantService.get_tenants_by_user_id(current_user.id)
     tags = []
     for tenant in tenants:
-        tags += settings.retriever.all_tags(tenant["tenant_id"], kb_ids)
+        tags += globals.retriever.all_tags(tenant["tenant_id"], kb_ids)
     return get_json_result(data=tags)
@@ -284,7 +283,7 @@ def rm_tags(kb_id):
     e, kb = KnowledgebaseService.get_by_id(kb_id)
     for t in req["tags"]:
-        settings.docStoreConn.update({"tag_kwd": t, "kb_id": [kb_id]},
+        globals.docStoreConn.update({"tag_kwd": t, "kb_id": [kb_id]},
                                      {"remove": {"tag_kwd": t}},
                                      search.index_name(kb.tenant_id),
                                      kb_id)
@@ -303,7 +302,7 @@ def rename_tags(kb_id):
         )
     e, kb = KnowledgebaseService.get_by_id(kb_id)
-    settings.docStoreConn.update({"tag_kwd": req["from_tag"], "kb_id": [kb_id]},
+    globals.docStoreConn.update({"tag_kwd": req["from_tag"], "kb_id": [kb_id]},
                                  {"remove": {"tag_kwd": req["from_tag"].strip()}, "add": {"tag_kwd": req["to_tag"]}},
                                  search.index_name(kb.tenant_id),
                                  kb_id)
@@ -326,9 +325,9 @@ def knowledge_graph(kb_id):
     }
     obj = {"graph": {}, "mind_map": {}}
-    if not settings.docStoreConn.indexExist(search.index_name(kb.tenant_id), kb_id):
+    if not globals.docStoreConn.indexExist(search.index_name(kb.tenant_id), kb_id):
         return get_json_result(data=obj)
-    sres = settings.retriever.search(req, search.index_name(kb.tenant_id), [kb_id])
+    sres = globals.retriever.search(req, search.index_name(kb.tenant_id), [kb_id])
     if not len(sres.ids):
         return get_json_result(data=obj)
@@ -360,7 +359,7 @@ def delete_knowledge_graph(kb_id):
             code=RetCode.AUTHENTICATION_ERROR
         )
     _, kb = KnowledgebaseService.get_by_id(kb_id)
-    settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]}, search.index_name(kb.tenant_id), kb_id)
+    globals.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]}, search.index_name(kb.tenant_id), kb_id)
     return get_json_result(data=True)
@@ -732,13 +731,13 @@ def delete_kb_task():
             task_id = kb.graphrag_task_id
             kb_task_finish_at = "graphrag_task_finish_at"
             cancel_task(task_id)
-            settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]}, search.index_name(kb.tenant_id), kb_id)
+            globals.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]}, search.index_name(kb.tenant_id), kb_id)
         case PipelineTaskType.RAPTOR:
             kb_task_id_field = "raptor_task_id"
             task_id = kb.raptor_task_id
             kb_task_finish_at = "raptor_task_finish_at"
             cancel_task(task_id)
-            settings.docStoreConn.delete({"raptor_kwd": ["raptor"]}, search.index_name(kb.tenant_id), kb_id)
+            globals.docStoreConn.delete({"raptor_kwd": ["raptor"]}, search.index_name(kb.tenant_id), kb_id)
         case PipelineTaskType.MINDMAP:
             kb_task_id_field = "mindmap_task_id"
             task_id = kb.mindmap_task_id
@@ -850,7 +849,7 @@ def check_embedding():
         tenant_id = kb.tenant_id
         emb_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_id)
-        samples = sample_random_chunks_with_vectors(settings.docStoreConn, tenant_id=tenant_id, kb_id=kb_id, n=n)
+        samples = sample_random_chunks_with_vectors(globals.docStoreConn, tenant_id=tenant_id, kb_id=kb_id, n=n)
         results, eff_sims = [], []
         for ck in samples:
diff --git a/api/apps/llm_app.py b/api/apps/llm_app.py
index c49097fac..3d9be95f7 100644
--- a/api/apps/llm_app.py
+++ b/api/apps/llm_app.py
@@ -24,7 +24,7 @@ from api.utils.api_utils import server_error_response, get_data_error_result, va
 from common.constants import StatusEnum, LLMType
 from api.db.db_models import TenantLLM
 from api.utils.api_utils import get_json_result, get_allowed_llm_factories
-from common.base64_image import test_image
+from rag.utils.base64_image import test_image
 from rag.llm import EmbeddingModel, ChatModel, RerankModel, CvModel, TTSModel
diff --git a/api/apps/sdk/dataset.py b/api/apps/sdk/dataset.py
index 20a7578c9..19ae9a60e 100644
--- a/api/apps/sdk/dataset.py
+++ b/api/apps/sdk/dataset.py
@@ -20,7 +20,6 @@ import os
 import json
 from flask import request
 from peewee import OperationalError
-from api import settings
 from api.db.db_models import File
 from api.db.services.document_service import DocumentService
 from api.db.services.file2document_service import File2DocumentService
@@ -49,6 +48,7 @@ from api.utils.validation_utils import (
 )
 from rag.nlp import search
 from rag.settings import PAGERANK_FLD
+from common import globals
 @manager.route("/datasets", methods=["POST"])  # noqa: F821
@@ -360,11 +360,11 @@ def update(tenant_id, dataset_id):
             return get_error_argument_result(message="'pagerank' can only be set when doc_engine is elasticsearch")
         if req["pagerank"] > 0:
-            settings.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]},
+            globals.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]},
                                          search.index_name(kb.tenant_id), kb.id)
         else:
             # Elasticsearch requires PAGERANK_FLD be non-zero!
-            settings.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD},
+            globals.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD},
                                          search.index_name(kb.tenant_id), kb.id)
     if not KnowledgebaseService.update_by_id(kb.id, req):
@@ -493,9 +493,9 @@ def knowledge_graph(tenant_id, dataset_id):
     }
     obj = {"graph": {}, "mind_map": {}}
-    if not settings.docStoreConn.indexExist(search.index_name(kb.tenant_id), dataset_id):
+    if not globals.docStoreConn.indexExist(search.index_name(kb.tenant_id), dataset_id):
         return get_result(data=obj)
-    sres = settings.retriever.search(req, search.index_name(kb.tenant_id), [dataset_id])
+    sres = globals.retriever.search(req, search.index_name(kb.tenant_id), [dataset_id])
     if not len(sres.ids):
         return get_result(data=obj)
@@ -528,7 +528,7 @@ def delete_knowledge_graph(tenant_id, dataset_id):
             code=RetCode.AUTHENTICATION_ERROR
         )
     _, kb = KnowledgebaseService.get_by_id(dataset_id)
-    settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]},
+    globals.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]},
                                  search.index_name(kb.tenant_id), dataset_id)
     return get_result(data=True)
diff --git a/api/apps/sdk/dify_retrieval.py b/api/apps/sdk/dify_retrieval.py
index 9380cc4f0..b91e5faaa 100644
--- a/api/apps/sdk/dify_retrieval.py
+++ b/api/apps/sdk/dify_retrieval.py
@@ -25,6 +25,7 @@ from api.utils.api_utils import validate_request, build_error_result, apikey_req
 from rag.app.tag import label_question
 from api.db.services.dialog_service import meta_filter, convert_conditions
 from common.constants import RetCode, LLMType
+from common import globals
 @manager.route('/dify/retrieval', methods=['POST'])  # noqa: F821
 @apikey_required
@@ -137,7 +138,7 @@ def retrieval(tenant_id):
         # print("doc_ids", doc_ids)
         if not doc_ids and metadata_condition is not None:
             doc_ids = ['-999']
-        ranks = settings.retriever.retrieval(
+        ranks = globals.retriever.retrieval(
             question,
             embd_mdl,
             kb.tenant_id,
diff --git a/api/apps/sdk/doc.py b/api/apps/sdk/doc.py
index f415db0cd..30d9f877c 100644
--- a/api/apps/sdk/doc.py
+++ b/api/apps/sdk/doc.py
@@ -44,6 +44,7 @@ from rag.prompts.generator import cross_languages, keyword_extraction
 from rag.utils.storage_factory import STORAGE_IMPL
 from common.string_utils import remove_redundant_spaces
 from common.constants import RetCode, LLMType, ParserType, TaskStatus, FileSource
+from common import globals
 MAXIMUM_OF_UPLOADING_FILES = 256
@@ -307,7 +308,7 @@ def update_doc(tenant_id, dataset_id, document_id):
         )
         if not e:
             return get_error_data_result(message="Document not found!")
-        settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), dataset_id)
+        globals.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), dataset_id)
     if "enabled" in req:
         status = int(req["enabled"])
@@ -316,7 +317,7 @@ def update_doc(tenant_id, dataset_id, document_id):
             if not DocumentService.update_by_id(doc.id, {"status": str(status)}):
                 return get_error_data_result(message="Database error (Document update)!")
-            settings.docStoreConn.update({"doc_id": doc.id}, {"available_int": status}, search.index_name(kb.tenant_id), doc.kb_id)
+            globals.docStoreConn.update({"doc_id": doc.id}, {"available_int": status}, search.index_name(kb.tenant_id), doc.kb_id)
             return get_result(data=True)
         except Exception as e:
             return server_error_response(e)
@@ -755,7 +756,7 @@ def parse(tenant_id, dataset_id):
             return get_error_data_result("Can't parse document that is currently being processed")
         info = {"run": "1", "progress": 0, "progress_msg": "", "chunk_num": 0, "token_num": 0}
         DocumentService.update_by_id(id, info)
-        settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), dataset_id)
+        globals.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), dataset_id)
         TaskService.filter_delete([Task.doc_id == id])
         e, doc = DocumentService.get_by_id(id)
         doc = doc.to_dict()
@@ -835,7 +836,7 @@ def stop_parsing(tenant_id, dataset_id):
             return get_error_data_result("Can't stop parsing document with progress at 0 or 1")
         info = {"run": "2", "progress": 0, "chunk_num": 0}
         DocumentService.update_by_id(id, info)
-        settings.docStoreConn.delete({"doc_id": doc[0].id}, search.index_name(tenant_id), dataset_id)
+        globals.docStoreConn.delete({"doc_id": doc[0].id}, search.index_name(tenant_id), dataset_id)
         success_count += 1
     if duplicate_messages:
         if success_count > 0:
@@ -968,7 +969,7 @@ def list_chunks(tenant_id, dataset_id, document_id):
     res = {"total": 0, "chunks": [], "doc": renamed_doc}
     if req.get("id"):
-        chunk = settings.docStoreConn.get(req.get("id"), search.index_name(tenant_id), [dataset_id])
+        chunk = globals.docStoreConn.get(req.get("id"), search.index_name(tenant_id), [dataset_id])
         if not chunk:
             return get_result(message=f"Chunk not found: {dataset_id}/{req.get('id')}", code=RetCode.NOT_FOUND)
         k = []
@@ -995,8 +996,8 @@ def list_chunks(tenant_id, dataset_id, document_id):
         res["chunks"].append(final_chunk)
         _ = Chunk(**final_chunk)
-    elif settings.docStoreConn.indexExist(search.index_name(tenant_id), dataset_id):
-        sres = settings.retriever.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None, highlight=True)
+    elif globals.docStoreConn.indexExist(search.index_name(tenant_id), dataset_id):
+        sres = globals.retriever.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None, highlight=True)
         res["total"] = sres.total
         for id in sres.ids:
             d = {
@@ -1120,7 +1121,7 @@ def add_chunk(tenant_id, dataset_id, document_id):
     v, c = embd_mdl.encode([doc.name, req["content"] if not d["question_kwd"] else "\n".join(d["question_kwd"])])
     v = 0.1 * v[0] + 0.9 * v[1]
     d["q_%d_vec" % len(v)] = v.tolist()
-    settings.docStoreConn.insert([d], search.index_name(tenant_id), dataset_id)
+    globals.docStoreConn.insert([d], search.index_name(tenant_id), dataset_id)
     DocumentService.increment_chunk_num(doc.id, doc.kb_id, c, 1, 0)
     # rename keys
@@ -1201,7 +1202,7 @@ def rm_chunk(tenant_id, dataset_id, document_id):
     if "chunk_ids" in req:
         unique_chunk_ids, duplicate_messages = check_duplicate_ids(req["chunk_ids"], "chunk")
         condition["id"] = unique_chunk_ids
-    chunk_number = settings.docStoreConn.delete(condition, search.index_name(tenant_id), dataset_id)
+    chunk_number = globals.docStoreConn.delete(condition, search.index_name(tenant_id), dataset_id)
     if chunk_number != 0:
         DocumentService.decrement_chunk_num(document_id, dataset_id, 1, chunk_number, 0)
     if "chunk_ids" in req and chunk_number != len(unique_chunk_ids):
@@ -1273,7 +1274,7 @@ def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
             schema:
                 type: object
     """
-    chunk = settings.docStoreConn.get(chunk_id, search.index_name(tenant_id), [dataset_id])
+    chunk = globals.docStoreConn.get(chunk_id, search.index_name(tenant_id), [dataset_id])
    if chunk is None:
        return get_error_data_result(f"Can't find this chunk {chunk_id}")
    if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
@@ -1318,7 +1319,7 @@ def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
        v, c = embd_mdl.encode([doc.name, d["content_with_weight"] if not d.get("question_kwd") else "\n".join(d["question_kwd"])])
        v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
        d["q_%d_vec" % len(v)] = v.tolist()
-    settings.docStoreConn.update({"id": chunk_id}, d, search.index_name(tenant_id), dataset_id)
+    globals.docStoreConn.update({"id": chunk_id}, d, search.index_name(tenant_id), dataset_id)
    return get_result()
@@ -1464,7 +1465,7 @@ def retrieval_test(tenant_id):
            chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
            question += keyword_extraction(chat_mdl, question)
-        ranks = settings.retriever.retrieval(
+        ranks = globals.retriever.retrieval(
            question,
            embd_mdl,
            tenant_ids,
diff --git a/api/apps/sdk/session.py b/api/apps/sdk/session.py
index 2997e1f00..3d3043bc5 100644
--- a/api/apps/sdk/session.py
+++ b/api/apps/sdk/session.py
@@ -41,6 +41,7 @@ from rag.app.tag import label_question
 from rag.prompts.template import load_prompt
 from rag.prompts.generator import cross_languages, gen_meta_filter, keyword_extraction, chunks_format
 from common.constants import RetCode, LLMType, StatusEnum
+from common import globals
 @manager.route("/chats//sessions", methods=["POST"])  # noqa: F821
 @token_required
@@ -1015,7 +1016,7 @@ def retrieval_test_embedded():
             question += keyword_extraction(chat_mdl, question)
         labels = label_question(question, [kb])
-        ranks = settings.retriever.retrieval(
+        ranks = globals.retriever.retrieval(
             question, embd_mdl, tenant_ids, kb_ids, page, size, similarity_threshold, vector_similarity_weight, top, doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"), rank_feature=labels
         )
diff --git a/api/apps/system_app.py b/api/apps/system_app.py
index c1467f3d1..10b94102d 100644
--- a/api/apps/system_app.py
+++ b/api/apps/system_app.py
@@ -38,6 +38,7 @@ from timeit import default_timer as timer
 from rag.utils.redis_conn import REDIS_CONN
 from flask import jsonify
 from api.utils.health_utils import run_health_checks
+from common import globals
 @manager.route("/version", methods=["GET"])  # noqa: F821
@@ -100,7 +101,7 @@ def status():
     res = {}
     st = timer()
     try:
-        res["doc_engine"] = settings.docStoreConn.health()
+        res["doc_engine"] = globals.docStoreConn.health()
         res["doc_engine"]["elapsed"] = "{:.1f}".format((timer() - st) * 1000.0)
     except Exception as e:
         res["doc_engine"] = {
diff --git a/api/apps/user_app.py b/api/apps/user_app.py
index f8151f887..8dd4ff10e 100644
--- a/api/apps/user_app.py
+++ b/api/apps/user_app.py
@@ -58,6 +58,7 @@ from api.utils.web_utils import (
     hash_code,
     captcha_key,
 )
+from common import globals
 @manager.route("/login", methods=["POST", "GET"])  # noqa: F821
@@ -623,7 +624,7 @@ def user_register(user_id, user):
         "id": user_id,
         "name": user["nickname"] + "‘s Kingdom",
         "llm_id": settings.CHAT_MDL,
-        "embd_id": settings.EMBEDDING_MDL,
+        "embd_id": globals.EMBEDDING_MDL,
         "asr_id": settings.ASR_MDL,
         "parser_ids": settings.PARSERS,
         "img2txt_id": settings.IMAGE2TEXT_MDL,
diff --git a/api/db/init_data.py b/api/db/init_data.py
index a520f2f41..ffce4fe53 100644
--- a/api/db/init_data.py
+++ b/api/db/init_data.py
@@ -32,6 +32,7 @@ from api.db.services.user_service import TenantService, UserTenantService
 from api import settings
 from common.constants import LLMType
 from common.file_utils import get_project_base_directory
+from common import globals
 from api.common.base64 import encode_to_base64
@@ -49,7 +50,7 @@ def init_superuser():
         "id": user_info["id"],
         "name": user_info["nickname"] + "‘s Kingdom",
         "llm_id": settings.CHAT_MDL,
-        "embd_id": settings.EMBEDDING_MDL,
+        "embd_id": globals.EMBEDDING_MDL,
         "asr_id": settings.ASR_MDL,
         "parser_ids": settings.PARSERS,
         "img2txt_id": settings.IMAGE2TEXT_MDL
diff --git a/api/db/joint_services/user_account_service.py b/api/db/joint_services/user_account_service.py
index 733cb518d..79ff70714 100644
--- a/api/db/joint_services/user_account_service.py
+++ b/api/db/joint_services/user_account_service.py
@@ -38,6 +38,7 @@ from api.db.services.user_service import TenantService, UserService, UserTenantS
 from rag.utils.storage_factory import STORAGE_IMPL
 from rag.nlp import search
 from common.constants import ActiveEnum
+from common import globals
 def create_new_user(user_info: dict) -> dict:
     """
@@ -63,7 +64,7 @@ def create_new_user(user_info: dict) -> dict:
         "id": user_id,
         "name": user_info["nickname"] + "‘s Kingdom",
         "llm_id": settings.CHAT_MDL,
-        "embd_id": settings.EMBEDDING_MDL,
+        "embd_id": globals.EMBEDDING_MDL,
         "asr_id": settings.ASR_MDL,
         "parser_ids": settings.PARSERS,
         "img2txt_id": settings.IMAGE2TEXT_MDL,
@@ -179,7 +180,7 @@ def delete_user_data(user_id: str) -> dict:
             )
             done_msg += f"- Deleted {file2doc_delete_res} document-file relation records.\n"
             # step1.1.3 delete chunk in es
-            r = settings.docStoreConn.delete({"kb_id": kb_ids},
+            r = globals.docStoreConn.delete({"kb_id": kb_ids},
                                              search.index_name(tenant_id), kb_ids)
             done_msg += f"- Deleted {r} chunk records.\n"
             kb_delete_res = KnowledgebaseService.delete_by_ids(kb_ids)
@@ -237,7 +238,7 @@ def delete_user_data(user_id: str) -> dict:
         kb_doc_info = {}
         for _tenant_id, kb_doc in kb_grouped_doc.items():
             for _kb_id, docs in kb_doc.items():
-                chunk_delete_res += settings.docStoreConn.delete(
+                chunk_delete_res += globals.docStoreConn.delete(
                     {"doc_id": [d["id"] for d in docs]},
                     search.index_name(_tenant_id), _kb_id
                 )
diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py
index ec5dc523f..6e5bab0da 100644
--- a/api/db/services/dialog_service.py
+++ b/api/db/services/dialog_service.py
@@ -44,6 +44,7 @@ from rag.prompts.generator import chunks_format, citation_prompt, cross_language
 from common.token_utils import num_tokens_from_string
 from rag.utils.tavily_conn import Tavily
 from common.string_utils import remove_redundant_spaces
+from common import globals
 class DialogService(CommonService):
@@ -371,7 +372,7 @@ def chat(dialog, messages, stream=True, **kwargs):
         chat_mdl.bind_tools(toolcall_session, tools)
     bind_models_ts = timer()
-    retriever = settings.retriever
+    retriever = globals.retriever
     questions = [m["content"] for m in messages if m["role"] == "user"][-3:]
     attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else []
     if "doc_ids" in messages[-1]:
@@ -663,7 +664,7 @@ Please write the SQL, only SQL, without any other explanations or text.
         logging.debug(f"{question} get SQL(refined): {sql}")
         tried_times += 1
-        return settings.retriever.sql_retrieval(sql, format="json"), sql
+        return globals.retriever.sql_retrieval(sql, format="json"), sql
     tbl, sql = get_table()
     if tbl is None:
@@ -757,7 +758,7 @@ def ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
     embedding_list = list(set([kb.embd_id for kb in kbs]))
     is_knowledge_graph = all([kb.parser_id == ParserType.KG for kb in kbs])
-    retriever = settings.retriever if not is_knowledge_graph else settings.kg_retriever
+    retriever = globals.retriever if not is_knowledge_graph else settings.kg_retriever
     embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embedding_list[0])
     chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, chat_llm_name)
@@ -853,7 +854,7 @@ def gen_mindmap(question, kb_ids, tenant_id, search_config={}):
     if not doc_ids:
         doc_ids = None
-    ranks = settings.retriever.retrieval(
+    ranks = globals.retriever.retrieval(
         question=question,
         embd_mdl=embd_mdl,
         tenant_ids=tenant_ids,
diff --git a/api/db/services/document_service.py b/api/db/services/document_service.py
index 5351e7d92..b8f3f1580 100644
--- a/api/db/services/document_service.py
+++ b/api/db/services/document_service.py
@@ -26,7 +26,6 @@ import trio
 import xxhash
 from peewee import fn, Case, JOIN
-from api import settings
 from api.constants import IMG_BASE64_PREFIX, FILE_NAME_LEN_LIMIT
 from api.db import FileType, UserTenantRole, CanvasCategory
 from api.db.db_models import DB, Document, Knowledgebase, Task, Tenant, UserTenant, File2Document, File, UserCanvas, \
@@ -42,7 +41,7 @@ from rag.settings import get_svr_queue_name, SVR_CONSUMER_GROUP_NAME
 from rag.utils.redis_conn import REDIS_CONN
 from rag.utils.storage_factory import STORAGE_IMPL
 from rag.utils.doc_store_conn import OrderByExpr
-
+from common import globals
 class DocumentService(CommonService):
     model = Document
@@ -309,10 +308,10 @@ class DocumentService(CommonService):
             page_size = 1000
             all_chunk_ids = []
             while True:
-                chunks = settings.docStoreConn.search(["img_id"], [], {"doc_id": doc.id}, [], OrderByExpr(),
+                chunks = globals.docStoreConn.search(["img_id"], [], {"doc_id": doc.id}, [], OrderByExpr(),
                                                       page * page_size, page_size, search.index_name(tenant_id),
                                                       [doc.kb_id])
-                chunk_ids = settings.docStoreConn.getChunkIds(chunks)
+                chunk_ids = globals.docStoreConn.getChunkIds(chunks)
                 if not chunk_ids:
                     break
                 all_chunk_ids.extend(chunk_ids)
@@ -323,19 +322,19 @@ class DocumentService(CommonService):
             if doc.thumbnail and not doc.thumbnail.startswith(IMG_BASE64_PREFIX):
                 if STORAGE_IMPL.obj_exist(doc.kb_id, doc.thumbnail):
                     STORAGE_IMPL.rm(doc.kb_id, doc.thumbnail)
-            settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
+            globals.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
-            graph_source = settings.docStoreConn.getFields(
-                settings.docStoreConn.search(["source_id"], [], {"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]}, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [doc.kb_id]), ["source_id"]
+            graph_source = globals.docStoreConn.getFields(
+                globals.docStoreConn.search(["source_id"], [], {"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]}, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [doc.kb_id]), ["source_id"]
             )
             if len(graph_source) > 0 and doc.id in list(graph_source.values())[0]["source_id"]:
-                settings.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "source_id": doc.id},
+                globals.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "source_id": doc.id},
                                              {"remove": {"source_id": doc.id}},
                                              search.index_name(tenant_id), doc.kb_id)
-                settings.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]},
+                globals.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]},
                                              {"removed_kwd": "Y"},
                                              search.index_name(tenant_id), doc.kb_id)
-                settings.docStoreConn.delete({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "must_not": {"exists": "source_id"}},
+                globals.docStoreConn.delete({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "must_not": {"exists": "source_id"}},
                                             search.index_name(tenant_id), doc.kb_id)
         except Exception:
             pass
@@ -996,10 +995,10 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
             d["q_%d_vec" % len(v)] = v
         for b in range(0, len(cks), es_bulk_size):
             if try_create_idx:
-                if not settings.docStoreConn.indexExist(idxnm, kb_id):
-                    settings.docStoreConn.createIdx(idxnm, kb_id, len(vects[0]))
+                if not globals.docStoreConn.indexExist(idxnm, kb_id):
+                    globals.docStoreConn.createIdx(idxnm, kb_id, len(vects[0]))
                 try_create_idx = False
-            settings.docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id)
+            globals.docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id)
         DocumentService.increment_chunk_num(
             doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0)
diff --git a/api/db/services/task_service.py b/api/db/services/task_service.py
index 0cc0b0a93..1fcb8e7dc 100644
--- a/api/db/services/task_service.py
+++ b/api/db/services/task_service.py
@@ -34,7 +34,7 @@ from deepdoc.parser.excel_parser import RAGFlowExcelParser
 from rag.settings import get_svr_queue_name
 from rag.utils.storage_factory import STORAGE_IMPL
 from rag.utils.redis_conn import REDIS_CONN
-from api import settings
+from common import globals
 from rag.nlp import search
 CANVAS_DEBUG_DOC_ID = "dataflow_x"
@@ -418,7 +418,7 @@ def queue_tasks(doc: dict, bucket: str, name: str, priority: int):
             if pre_task["chunk_ids"]:
                 pre_chunk_ids.extend(pre_task["chunk_ids"].split())
         if pre_chunk_ids:
-            settings.docStoreConn.delete({"id": pre_chunk_ids}, search.index_name(chunking_config["tenant_id"]),
+            globals.docStoreConn.delete({"id": pre_chunk_ids}, search.index_name(chunking_config["tenant_id"]),
                                          chunking_config["kb_id"])
         DocumentService.update_by_id(doc["id"], {"chunk_num": ck_num})
diff --git a/api/settings.py b/api/settings.py
index 12b260ae4..c82a7cfff 100644
--- a/api/settings.py
+++ b/api/settings.py
@@ -32,12 +32,12 @@ LLM = None
 LLM_FACTORY = None
 LLM_BASE_URL = None
 CHAT_MDL = ""
-EMBEDDING_MDL = ""
+# EMBEDDING_MDL = "" has been moved to common/globals.py
 RERANK_MDL = ""
 ASR_MDL = ""
 IMAGE2TEXT_MDL = ""
 CHAT_CFG = ""
-
+# EMBEDDING_CFG = "" has been moved to common/globals.py
 RERANK_CFG = ""
 ASR_CFG = ""
 IMAGE2TEXT_CFG = ""
@@ -61,10 +61,10 @@ HTTP_APP_KEY = None
 GITHUB_OAUTH = None
 FEISHU_OAUTH = None
 OAUTH_CONFIG = None
-DOC_ENGINE = None
-docStoreConn = None
+# DOC_ENGINE = None has been moved to common/globals.py
+# docStoreConn = None has been moved to common/globals.py
-retriever = None
+#retriever = None has been moved to common/globals.py
 kg_retriever = None
 # user registration switch
@@ -125,7 +125,7 @@ def init_settings():
     except Exception:
         FACTORY_LLM_INFOS = []
-    global CHAT_MDL, EMBEDDING_MDL, RERANK_MDL, ASR_MDL, IMAGE2TEXT_MDL
+    global CHAT_MDL, RERANK_MDL, ASR_MDL, IMAGE2TEXT_MDL
     global CHAT_CFG, RERANK_CFG, ASR_CFG, IMAGE2TEXT_CFG
     global API_KEY, PARSERS, HOST_IP, HOST_PORT, SECRET_KEY
@@ -135,7 +135,7 @@ def init_settings():
     )
     chat_entry = _parse_model_entry(LLM_DEFAULT_MODELS.get("chat_model", CHAT_MDL))
-    embedding_entry = _parse_model_entry(LLM_DEFAULT_MODELS.get("embedding_model", EMBEDDING_MDL))
+    embedding_entry = _parse_model_entry(LLM_DEFAULT_MODELS.get("embedding_model", globals.EMBEDDING_MDL))
     rerank_entry = _parse_model_entry(LLM_DEFAULT_MODELS.get("rerank_model", RERANK_MDL))
     asr_entry = _parse_model_entry(LLM_DEFAULT_MODELS.get("asr_model", ASR_MDL))
     image2text_entry = _parse_model_entry(LLM_DEFAULT_MODELS.get("image2text_model", IMAGE2TEXT_MDL))
@@ -147,7 +147,7 @@ def init_settings():
     IMAGE2TEXT_CFG = _resolve_per_model_config(image2text_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL)
     CHAT_MDL = CHAT_CFG.get("model", "") or ""
-    EMBEDDING_MDL = os.getenv("TEI_MODEL", "BAAI/bge-small-en-v1.5") if "tei-" in os.getenv("COMPOSE_PROFILES", "") else ""
+    globals.EMBEDDING_MDL = os.getenv("TEI_MODEL", "BAAI/bge-small-en-v1.5") if "tei-" in os.getenv("COMPOSE_PROFILES", "") else ""
     RERANK_MDL = RERANK_CFG.get("model", "") or ""
     ASR_MDL = ASR_CFG.get("model", "") or ""
     IMAGE2TEXT_MDL = IMAGE2TEXT_CFG.get("model", "") or ""
@@ -169,23 +169,23 @@ def init_settings():
     OAUTH_CONFIG = get_base_config("oauth", {})
-    global DOC_ENGINE, docStoreConn, retriever, kg_retriever
-    DOC_ENGINE = os.environ.get("DOC_ENGINE", "elasticsearch")
-    # DOC_ENGINE = os.environ.get('DOC_ENGINE', "opensearch")
-    lower_case_doc_engine = DOC_ENGINE.lower()
+    global kg_retriever
+    globals.DOC_ENGINE = os.environ.get("DOC_ENGINE", "elasticsearch")
+    # globals.DOC_ENGINE = os.environ.get('DOC_ENGINE', "opensearch")
+    lower_case_doc_engine = globals.DOC_ENGINE.lower()
     if lower_case_doc_engine == "elasticsearch":
-        docStoreConn = rag.utils.es_conn.ESConnection()
+        globals.docStoreConn = rag.utils.es_conn.ESConnection()
     elif lower_case_doc_engine == "infinity":
-        docStoreConn = rag.utils.infinity_conn.InfinityConnection()
+        globals.docStoreConn = rag.utils.infinity_conn.InfinityConnection()
     elif lower_case_doc_engine == "opensearch":
-        docStoreConn = rag.utils.opensearch_conn.OSConnection()
+        globals.docStoreConn = rag.utils.opensearch_conn.OSConnection()
     else:
-        raise Exception(f"Not supported doc engine: {DOC_ENGINE}")
+        raise Exception(f"Not supported doc engine: {globals.DOC_ENGINE}")
-    retriever = search.Dealer(docStoreConn)
+    globals.retriever = search.Dealer(globals.docStoreConn)
     from graphrag import search as kg_search
-    kg_retriever = kg_search.KGSearch(docStoreConn)
+    kg_retriever = kg_search.KGSearch(globals.docStoreConn)
     if int(os.environ.get("SANDBOX_ENABLED", "0")):
         global SANDBOX_HOST
diff --git a/api/utils/health_utils.py b/api/utils/health_utils.py
index 3a97fc572..f86a0d728 100644
--- a/api/utils/health_utils.py
+++ b/api/utils/health_utils.py
@@ -24,6 +24,7 @@ from rag.utils.redis_conn import REDIS_CONN
 from rag.utils.storage_factory import STORAGE_IMPL
 from rag.utils.es_conn import ESConnection
 from rag.utils.infinity_conn import InfinityConnection
+from common import globals
 def _ok_nok(ok: bool) -> str:
@@ -52,7 +53,7 @@ def check_redis() -> tuple[bool, dict]:
 def check_doc_engine() -> tuple[bool, dict]:
     st = timer()
     try:
-        meta = settings.docStoreConn.health()
+        meta = globals.docStoreConn.health()
         # treat any successful call as ok
         return True, {"elapsed": f"{(timer() - st) * 1000.0:.1f}", **(meta or {})}
     except Exception as e:
diff --git a/common/globals.py b/common/globals.py
index 7e9879ef1..3f2859c74 100644
--- a/common/globals.py
+++ b/common/globals.py
@@ -14,4 +14,12 @@
 # limitations under the License.
 #
-EMBEDDING_CFG = ""
\ No newline at end of file
+EMBEDDING_MDL = ""
+
+EMBEDDING_CFG = ""
+
+DOC_ENGINE = None
+
+docStoreConn = None
+
+retriever = None
\ No newline at end of file
diff --git a/graphrag/general/index.py b/graphrag/general/index.py
index 51f79f57f..9247df687 100644
--- a/graphrag/general/index.py
+++ b/graphrag/general/index.py
@@ -20,7 +20,6 @@ import os
 import networkx as nx
 import trio
-from api import settings
 from api.db.services.document_service import DocumentService
 from common.misc_utils import get_uuid
 from common.connection_utils import timeout
@@ -40,6 +39,7 @@ from graphrag.utils import (
 )
 from rag.nlp import rag_tokenizer, search
 from rag.utils.redis_conn import RedisDistributedLock
+from common import globals
 async def run_graphrag(
@@ -55,7 +55,7 @@ async def run_graphrag(
     start = trio.current_time()
     tenant_id, kb_id, doc_id = row["tenant_id"], str(row["kb_id"]), row["doc_id"]
     chunks = []
-    for d in settings.retriever.chunk_list(doc_id, tenant_id, [kb_id], fields=["content_with_weight", "doc_id"], sort_by_position=True):
+    for d in globals.retriever.chunk_list(doc_id, tenant_id, [kb_id], fields=["content_with_weight", "doc_id"], sort_by_position=True):
         chunks.append(d["content_with_weight"])
     with trio.fail_after(max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000):
@@ -170,7 +170,7 @@ async def run_graphrag_for_kb(
     chunks = []
     current_chunk = ""
-    for d in settings.retriever.chunk_list(
+    for d in globals.retriever.chunk_list(
         doc_id,
         tenant_id,
         [kb_id],
@@ -387,8 +387,8 @@ async def generate_subgraph(
         "removed_kwd": "N",
     }
     cid = chunk_id(chunk)
-    await trio.to_thread.run_sync(settings.docStoreConn.delete, {"knowledge_graph_kwd": "subgraph", "source_id": doc_id}, search.index_name(tenant_id), kb_id)
-    await trio.to_thread.run_sync(settings.docStoreConn.insert, [{"id": cid, **chunk}], search.index_name(tenant_id), kb_id)
+    await trio.to_thread.run_sync(globals.docStoreConn.delete, {"knowledge_graph_kwd": "subgraph", "source_id": doc_id}, search.index_name(tenant_id), kb_id)
+    await trio.to_thread.run_sync(globals.docStoreConn.insert, [{"id": cid, **chunk}], search.index_name(tenant_id), kb_id)
     now = trio.current_time()
     callback(msg=f"generated subgraph for doc {doc_id} in {now - start:.2f} seconds.")
     return subgraph
@@ -496,7 +496,7 @@ async def extract_community(
         chunks.append(chunk)
     await trio.to_thread.run_sync(
-        lambda: settings.docStoreConn.delete(
+        lambda: globals.docStoreConn.delete(
             {"knowledge_graph_kwd": "community_report", "kb_id": kb_id},
             search.index_name(tenant_id),
            kb_id,
        )
    )
    es_bulk_size = 4
    for b in range(0, len(chunks), es_bulk_size):
-        doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b : b + es_bulk_size], search.index_name(tenant_id), kb_id))
+        doc_store_result = await trio.to_thread.run_sync(lambda: globals.docStoreConn.insert(chunks[b : b + es_bulk_size], search.index_name(tenant_id), kb_id))
        if doc_store_result:
            error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
            raise Exception(error_message)
diff --git a/graphrag/general/smoke.py b/graphrag/general/smoke.py
index 7023df412..56b4863f9 100644
--- a/graphrag/general/smoke.py
+++ b/graphrag/general/smoke.py
@@ -28,6 +28,7 @@ from api.db.services.llm_service import LLMBundle
 from api.db.services.user_service import TenantService
 from graphrag.general.graph_extractor import GraphExtractor
 from graphrag.general.index import update_graph, with_resolution, with_community
+from common import globals
 settings.init_settings()
@@ -62,7 +63,7 @@ async def main():
     chunks = [
         d["content_with_weight"]
-        for d in settings.retriever.chunk_list(
+        for d in globals.retriever.chunk_list(
             args.doc_id,
             args.tenant_id,
             [kb_id],
diff --git a/graphrag/light/smoke.py b/graphrag/light/smoke.py
index ec83ab6cf..c83dc8a91 100644
--- a/graphrag/light/smoke.py
+++ b/graphrag/light/smoke.py
@@ -28,6 +28,7 @@ from api.db.services.llm_service import LLMBundle
 from api.db.services.user_service import TenantService
 from graphrag.general.index import update_graph
 from graphrag.light.graph_extractor import GraphExtractor
+from common import globals
 settings.init_settings()
@@ -63,7 +64,7 @@ async def main():
     chunks = [
         d["content_with_weight"]
-        for d in settings.retriever.chunk_list(
+        for d in globals.retriever.chunk_list(
             args.doc_id,
             args.tenant_id,
             [kb_id],
diff --git a/graphrag/search.py b/graphrag/search.py
index e54c70167..51ec23013 100644
--- a/graphrag/search.py
+++ b/graphrag/search.py
@@ -29,6 +29,7 @@ from rag.utils.doc_store_conn import OrderByExpr
 from rag.nlp.search import Dealer, index_name
 from common.float_utils import get_float
+from common import globals
 class KGSearch(Dealer):
@@ -334,6 +335,6 @@ if __name__ == "__main__":
     _, kb = KnowledgebaseService.get_by_id(kb_id)
     embed_bdl = LLMBundle(args.tenant_id, LLMType.EMBEDDING, kb.embd_id)
-    kg = KGSearch(settings.docStoreConn)
+    kg = KGSearch(globals.docStoreConn)
     print(kg.retrieval({"question": args.question, "kb_ids": [kb_id]}, search.index_name(kb.tenant_id), [kb_id], embed_bdl, llm_bdl))
diff --git a/graphrag/utils.py b/graphrag/utils.py
index b64a12265..c880eab9c 100644
--- a/graphrag/utils.py
+++ b/graphrag/utils.py
@@ -23,12 +23,12 @@ import trio
 import xxhash
 from networkx.readwrite import json_graph
-from api import settings
 from common.misc_utils import get_uuid
 from common.connection_utils import timeout
 from rag.nlp import rag_tokenizer, search
 from rag.utils.doc_store_conn import OrderByExpr
 from rag.utils.redis_conn import REDIS_CONN
+from common import globals
 GRAPH_FIELD_SEP = ""
@@ -334,7 +334,7 @@ def get_relation(tenant_id, kb_id, from_ent_name, to_ent_name, size=1):
     ents = list(set(ents))
     conds = {"fields": ["content_with_weight"], "size": size, "from_entity_kwd": ents, "to_entity_kwd": ents, "knowledge_graph_kwd": ["relation"]}
     res = []
-    es_res = settings.retriever.search(conds, search.index_name(tenant_id), [kb_id] if isinstance(kb_id, str) else kb_id)
+    es_res = globals.retriever.search(conds, search.index_name(tenant_id), [kb_id] if isinstance(kb_id, str) else kb_id)
     for id in es_res.ids:
         try:
             if size == 1:
@@ -381,8 +381,8 @@ async def does_graph_contains(tenant_id, kb_id, doc_id):
         "knowledge_graph_kwd": ["graph"],
         "removed_kwd": "N",
     }
-    res = await trio.to_thread.run_sync(lambda: settings.docStoreConn.search(fields, [], condition, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [kb_id]))
-    fields2 = settings.docStoreConn.getFields(res, fields)
+    res = await trio.to_thread.run_sync(lambda: globals.docStoreConn.search(fields, [], condition, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [kb_id]))
+    fields2 = globals.docStoreConn.getFields(res, fields)
     graph_doc_ids = set()
     for chunk_id in fields2.keys():
         graph_doc_ids = set(fields2[chunk_id]["source_id"])
@@ -391,7 +391,7 @@ async def does_graph_contains(tenant_id, kb_id, doc_id):
 async def get_graph_doc_ids(tenant_id, kb_id) -> list[str]:
     conds = {"fields": ["source_id"], "removed_kwd": "N", "size": 1, "knowledge_graph_kwd": ["graph"]}
-    res = await trio.to_thread.run_sync(lambda: settings.retriever.search(conds, search.index_name(tenant_id), [kb_id]))
+    res = await trio.to_thread.run_sync(lambda: globals.retriever.search(conds, search.index_name(tenant_id), [kb_id]))
     doc_ids = []
     if res.total == 0:
         return doc_ids
@@ -402,7 +402,7 @@ async def get_graph_doc_ids(tenant_id, kb_id) -> list[str]:
 async def get_graph(tenant_id, kb_id, exclude_rebuild=None):
     conds = {"fields": ["content_with_weight", "removed_kwd", "source_id"], "size": 1, "knowledge_graph_kwd": ["graph"]}
-    res = await trio.to_thread.run_sync(settings.retriever.search, conds, search.index_name(tenant_id), [kb_id])
+    res = await trio.to_thread.run_sync(globals.retriever.search, conds, search.index_name(tenant_id), [kb_id])
     if not res.total == 0:
         for id in res.ids:
             try:
@@ -423,17 +423,17 @@ async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, chang
     global chat_limiter
     start = trio.current_time()
-    await trio.to_thread.run_sync(settings.docStoreConn.delete, {"knowledge_graph_kwd": ["graph", "subgraph"]}, search.index_name(tenant_id), kb_id)
+    await trio.to_thread.run_sync(globals.docStoreConn.delete, {"knowledge_graph_kwd": ["graph", "subgraph"]}, search.index_name(tenant_id), kb_id)
     if change.removed_nodes:
-        await trio.to_thread.run_sync(settings.docStoreConn.delete, {"knowledge_graph_kwd": ["entity"], "entity_kwd": sorted(change.removed_nodes)}, search.index_name(tenant_id), kb_id)
+        await trio.to_thread.run_sync(globals.docStoreConn.delete, {"knowledge_graph_kwd": ["entity"], "entity_kwd": sorted(change.removed_nodes)}, search.index_name(tenant_id), kb_id)
     if change.removed_edges:
         async def del_edges(from_node, to_node):
             async with chat_limiter:
                 await trio.to_thread.run_sync(
-                    settings.docStoreConn.delete, {"knowledge_graph_kwd": ["relation"], "from_entity_kwd": from_node, "to_entity_kwd": to_node}, search.index_name(tenant_id), kb_id
+                    globals.docStoreConn.delete, {"knowledge_graph_kwd": ["relation"], "from_entity_kwd": from_node, "to_entity_kwd": to_node}, search.index_name(tenant_id), kb_id
                 )
         async with trio.open_nursery() as nursery:
@@ -501,7 +501,7 @@ async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, chang
     es_bulk_size = 4
     for b in range(0, len(chunks), es_bulk_size):
         with trio.fail_after(3 if enable_timeout_assertion else 30000000):
-            doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b : b + es_bulk_size], search.index_name(tenant_id), kb_id))
+            doc_store_result = await trio.to_thread.run_sync(lambda: globals.docStoreConn.insert(chunks[b : b + es_bulk_size], search.index_name(tenant_id), kb_id))
         if b % 100 == es_bulk_size and callback:
             callback(msg=f"Insert chunks: {b}/{len(chunks)}")
         if doc_store_result:
@@ -555,7 +555,7 @@ def merge_tuples(list1, list2):
 async def get_entity_type2samples(idxnms, kb_ids: list):
-    es_res = await trio.to_thread.run_sync(lambda: settings.retriever.search({"knowledge_graph_kwd": "ty2ents", "kb_id": kb_ids, "size": 10000, "fields": ["content_with_weight"]}, idxnms, kb_ids))
+    es_res = await trio.to_thread.run_sync(lambda: globals.retriever.search({"knowledge_graph_kwd": "ty2ents", "kb_id": kb_ids, "size": 10000, "fields": ["content_with_weight"]}, idxnms, kb_ids))
     res = defaultdict(list)
     for id in es_res.ids:
@@ -589,10 +589,10 @@ async def rebuild_graph(tenant_id, kb_id, exclude_rebuild=None):
     bs = 256
     for i in range(0, 1024 * bs, bs):
         es_res = await trio.to_thread.run_sync(
-            lambda: settings.docStoreConn.search(flds, [], {"kb_id": kb_id, "knowledge_graph_kwd": ["subgraph"]}, [], OrderByExpr(), i, bs, search.index_name(tenant_id), [kb_id])
+            lambda: globals.docStoreConn.search(flds, [], {"kb_id": kb_id, "knowledge_graph_kwd": ["subgraph"]}, [], OrderByExpr(), i, bs, search.index_name(tenant_id), [kb_id])
         )
-        # tot = settings.docStoreConn.getTotal(es_res)
-        es_res = settings.docStoreConn.getFields(es_res, flds)
+        # tot = globals.docStoreConn.getTotal(es_res)
+        es_res = globals.docStoreConn.getFields(es_res, flds)
         if len(es_res) == 0:
             break
diff --git a/rag/app/tag.py b/rag/app/tag.py
index e1a675652..dbe1aac55 100644
--- a/rag/app/tag.py
+++ b/rag/app/tag.py
@@ -21,6 +21,7 @@ from copy import deepcopy
 from deepdoc.parser.utils import get_text
 from rag.app.qa import Excel
 from rag.nlp import rag_tokenizer
+from common import globals
 def beAdoc(d, q, a, eng, row_num=-1):
@@ -124,7 +125,6 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
 def label_question(question, kbs):
     from api.db.services.knowledgebase_service import KnowledgebaseService
     from graphrag.utils import get_tags_from_cache, set_tags_to_cache
-    from api import settings
     tags = None
     tag_kb_ids = []
     for kb in kbs:
@@ -133,14 +133,14 @@ def label_question(question, kbs):
     if tag_kb_ids:
         all_tags = get_tags_from_cache(tag_kb_ids)
         if not all_tags:
-            all_tags = settings.retriever.all_tags_in_portion(kb.tenant_id, tag_kb_ids)
+            all_tags = globals.retriever.all_tags_in_portion(kb.tenant_id, tag_kb_ids)
             set_tags_to_cache(tags=all_tags, kb_ids=tag_kb_ids)
         else:
             all_tags = json.loads(all_tags)
         tag_kbs = KnowledgebaseService.get_by_ids(tag_kb_ids)
         if not tag_kbs:
             return tags
-        tags = settings.retriever.tag_query(question,
+        tags = globals.retriever.tag_query(question,
                                             list(set([kb.tenant_id for kb in tag_kbs])),
                                             tag_kb_ids,
                                             all_tags,
diff --git a/rag/benchmark.py b/rag/benchmark.py
index 173fcd140..031dbfb64 100644
--- a/rag/benchmark.py
+++ b/rag/benchmark.py
@@ -20,10 +20,10 @@ import time
 import argparse
 from collections import defaultdict
+from common import globals
 from common.constants import LLMType
 from api.db.services.llm_service import LLMBundle
 from api.db.services.knowledgebase_service import KnowledgebaseService
-from api import settings
 from common.misc_utils import get_uuid
 from rag.nlp import tokenize, search
 from ranx import evaluate
@@ -52,7 +52,7 @@ class Benchmark:
         run = defaultdict(dict)
         query_list = list(qrels.keys())
         for query in query_list:
-            ranks = settings.retriever.retrieval(query, self.embd_mdl, self.tenant_id, [self.kb.id], 1, 30,
+            ranks = globals.retriever.retrieval(query, self.embd_mdl, self.tenant_id, [self.kb.id], 1, 30,
                                                  0.0, self.vector_similarity_weight)
             if len(ranks["chunks"]) == 0:
                 print(f"deleted query: {query}")
@@ -77,9 +77,9 @@ class Benchmark:
     def init_index(self, vector_size: int):
         if self.initialized_index:
             return
-        if settings.docStoreConn.indexExist(self.index_name, self.kb_id):
-            settings.docStoreConn.deleteIdx(self.index_name, self.kb_id)
-        settings.docStoreConn.createIdx(self.index_name, self.kb_id, vector_size)
+        if globals.docStoreConn.indexExist(self.index_name, self.kb_id):
+            globals.docStoreConn.deleteIdx(self.index_name, self.kb_id)
+        globals.docStoreConn.createIdx(self.index_name, self.kb_id, vector_size)
         self.initialized_index = True
     def ms_marco_index(self, file_path, index_name):
@@ -114,13 +114,13 @@ class Benchmark:
                     docs_count += len(docs)
                     docs, vector_size = self.embedding(docs)
                     self.init_index(vector_size)
-                    settings.docStoreConn.insert(docs, self.index_name, self.kb_id)
+                    globals.docStoreConn.insert(docs, self.index_name, self.kb_id)
                     docs = []
         if docs:
             docs, vector_size = self.embedding(docs)
             self.init_index(vector_size)
-            settings.docStoreConn.insert(docs, self.index_name, self.kb_id)
+            globals.docStoreConn.insert(docs, self.index_name, self.kb_id)
         return qrels, texts
     def trivia_qa_index(self, file_path, index_name):
@@ -155,12 +155,12 @@ class Benchmark:
                     docs_count += len(docs)
                     docs, vector_size = self.embedding(docs)
                     self.init_index(vector_size)
-                    settings.docStoreConn.insert(docs,self.index_name)
+                    globals.docStoreConn.insert(docs,self.index_name)
                     docs = []
         docs, vector_size = self.embedding(docs)
         self.init_index(vector_size)
-        settings.docStoreConn.insert(docs, self.index_name)
+        globals.docStoreConn.insert(docs, self.index_name)
         return qrels, texts
     def miracl_index(self, file_path, corpus_path, index_name):
@@ -210,12 +210,12 @@ class Benchmark:
                     docs_count += len(docs)
                     docs, vector_size = self.embedding(docs)
                     self.init_index(vector_size)
-                    settings.docStoreConn.insert(docs, self.index_name)
+                    globals.docStoreConn.insert(docs, self.index_name)
                     docs = []
         docs, vector_size = self.embedding(docs)
         self.init_index(vector_size)
-        settings.docStoreConn.insert(docs, self.index_name)
+        globals.docStoreConn.insert(docs, self.index_name)
         return qrels, texts
     def save_results(self, qrels, run, texts, dataset, file_path):
diff --git a/rag/flow/hierarchical_merger/hierarchical_merger.py b/rag/flow/hierarchical_merger/hierarchical_merger.py
index ded3cbead..2cc794f4b 100644
--- a/rag/flow/hierarchical_merger/hierarchical_merger.py
+++ b/rag/flow/hierarchical_merger/hierarchical_merger.py
@@ -21,7 +21,7 @@ from functools import partial
 import trio
 from common.misc_utils import get_uuid
-from common.base64_image import id2image, image2id
+from rag.utils.base64_image import id2image, image2id
 from deepdoc.parser.pdf_parser import RAGFlowPdfParser
 from rag.flow.base import ProcessBase, ProcessParamBase
 from rag.flow.hierarchical_merger.schema import HierarchicalMergerFromUpstream
diff --git a/rag/flow/parser/parser.py b/rag/flow/parser/parser.py
index 28aaf600c..f147d738a 100644
--- a/rag/flow/parser/parser.py
+++ b/rag/flow/parser/parser.py
@@ -27,7 +27,7 @@ from api.db.services.file2document_service import File2DocumentService
 from api.db.services.file_service import FileService
 from api.db.services.llm_service import LLMBundle
 from common.misc_utils import get_uuid
-from common.base64_image import image2id
+from rag.utils.base64_image import image2id
 from deepdoc.parser import ExcelParser
 from deepdoc.parser.mineru_parser import MinerUParser
 from deepdoc.parser.pdf_parser import PlainParser, RAGFlowPdfParser, VisionParser
diff --git a/rag/flow/splitter/splitter.py b/rag/flow/splitter/splitter.py
index a62c44580..4a944d050 100644
--- a/rag/flow/splitter/splitter.py
+++ b/rag/flow/splitter/splitter.py
@@ -18,7 +18,7 @@ from functools import partial
 import trio
 from common.misc_utils import get_uuid
-from common.base64_image import id2image, image2id
+from rag.utils.base64_image import id2image, image2id
 from deepdoc.parser.pdf_parser import RAGFlowPdfParser
 from rag.flow.base import ProcessBase, ProcessParamBase
 from rag.flow.splitter.schema import SplitterFromUpstream
diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py
index a4f5edf0d..fa5a0d21a 100644
--- a/rag/llm/embedding_model.py
+++ b/rag/llm/embedding_model.py
@@ -30,7 +30,6 @@ from zhipuai import ZhipuAI
 from common.log_utils import log_exception
 from common.token_utils import num_tokens_from_string, truncate
 from common import globals
-from api import settings
 import logging
@@ -74,9 +73,9 @@ class BuiltinEmbed(Base):
         embedding_cfg = globals.EMBEDDING_CFG
         if not BuiltinEmbed._model and "tei-" in os.getenv("COMPOSE_PROFILES", ""):
             with BuiltinEmbed._model_lock:
-                BuiltinEmbed._model_name = settings.EMBEDDING_MDL
-                BuiltinEmbed._max_tokens = BuiltinEmbed.MAX_TOKENS.get(settings.EMBEDDING_MDL, 500)
-                BuiltinEmbed._model = HuggingFaceEmbed(embedding_cfg["api_key"], settings.EMBEDDING_MDL, base_url=embedding_cfg["base_url"])
+                BuiltinEmbed._model_name = globals.EMBEDDING_MDL
+                BuiltinEmbed._max_tokens = BuiltinEmbed.MAX_TOKENS.get(globals.EMBEDDING_MDL, 500)
+                BuiltinEmbed._model = HuggingFaceEmbed(embedding_cfg["api_key"], globals.EMBEDDING_MDL, base_url=embedding_cfg["base_url"])
         self._model = BuiltinEmbed._model
         self._model_name = BuiltinEmbed._model_name
         self._max_tokens = BuiltinEmbed._max_tokens
diff --git a/rag/settings.py b/rag/settings.py
index 57df2ee14..dbfc36880 100644
--- a/rag/settings.py
+++ b/rag/settings.py
@@ -18,13 +18,14 @@ import logging
 from common.config_utils import get_base_config, decrypt_database_config
 from common.file_utils import get_project_base_directory
 from common.misc_utils import pip_install_torch
+from common import globals
 # Server
 RAG_CONF_PATH = os.path.join(get_project_base_directory(), "conf")
 # Get storage type and document engine from system environment variables
 STORAGE_IMPL_TYPE = os.getenv('STORAGE_IMPL', 'MINIO')
-DOC_ENGINE = os.getenv('DOC_ENGINE', 'elasticsearch')
+globals.DOC_ENGINE = os.getenv('DOC_ENGINE', 'elasticsearch')
 ES = {}
 INFINITY = {}
@@ -35,11 +36,11 @@ OSS = {}
 OS = {}
 # Initialize the selected configuration data based on environment variables to solve the problem of initialization errors due to lack of configuration
-if DOC_ENGINE == 'elasticsearch':
+if globals.DOC_ENGINE == 'elasticsearch':
     ES = get_base_config("es", {})
-elif DOC_ENGINE == 'opensearch':
+elif globals.DOC_ENGINE == 'opensearch':
     OS = get_base_config("os", {})
-elif DOC_ENGINE == 'infinity':
+elif globals.DOC_ENGINE == 'infinity':
     INFINITY = get_base_config("infinity", {"uri": "infinity:23817"})
 if STORAGE_IMPL_TYPE in ['AZURE_SPN', 'AZURE_SAS']:
diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py
index 834ff59ee..e35306d4e 100644
--- a/rag/svr/task_executor.py
+++ b/rag/svr/task_executor.py
@@ -27,7 +27,7 @@ from api.db.services.canvas_service import UserCanvasService
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
 from common.connection_utils import timeout
-from common.base64_image import image2id
+from rag.utils.base64_image import image2id
 from common.log_utils import init_root_logger
 from common.config_utils import show_configs
 from graphrag.general.index import run_graphrag_for_kb
@@ -68,6 +68,7 @@ from rag.utils.redis_conn import REDIS_CONN, RedisDistributedLock
 from rag.utils.storage_factory import STORAGE_IMPL
 from graphrag.utils import chat_limiter
 from common.signal_utils import start_tracemalloc_and_snapshot, stop_tracemalloc
+from common import globals
 BATCH_SIZE = 64
@@ -349,7 +350,7 @@ async def build_chunks(task, progress_callback):
             examples = []
             all_tags = get_tags_from_cache(kb_ids)
             if not all_tags:
-                all_tags = settings.retriever.all_tags_in_portion(tenant_id, kb_ids, S)
+                all_tags = globals.retriever.all_tags_in_portion(tenant_id, kb_ids, S)
                 set_tags_to_cache(kb_ids, all_tags)
             else:
                 all_tags = json.loads(all_tags)
@@ -362,7 +363,7 @@ async def build_chunks(task, progress_callback):
                 if task_canceled:
                     progress_callback(-1, msg="Task has been canceled.")
                     return
-                if settings.retriever.tag_content(tenant_id, kb_ids, d, all_tags, topn_tags=topn_tags, S=S) and len(d[TAG_FLD]) > 0:
+                if globals.retriever.tag_content(tenant_id, kb_ids, d, all_tags, topn_tags=topn_tags, S=S) and len(d[TAG_FLD]) > 0:
                     examples.append({"content": d["content_with_weight"], TAG_FLD: d[TAG_FLD]})
                 else:
                     docs_to_tag.append(d)
@@ -423,7 +424,7 @@ def build_TOC(task, docs, progress_callback):
 def init_kb(row, vector_size: int):
     idxnm = search.index_name(row["tenant_id"])
-    return settings.docStoreConn.createIdx(idxnm, row.get("kb_id", ""), vector_size)
+    return globals.docStoreConn.createIdx(idxnm, row.get("kb_id", ""), vector_size)
 async def embedding(docs, mdl, parser_config=None, callback=None):
@@ -647,7 +648,7 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si
     chunks = []
     vctr_nm = "q_%d_vec"%vector_size
     for doc_id in doc_ids:
-        for d in settings.retriever.chunk_list(doc_id, row["tenant_id"], [str(row["kb_id"])],
+        for d in globals.retriever.chunk_list(doc_id, row["tenant_id"], [str(row["kb_id"])],
                                                fields=["content_with_weight", vctr_nm], sort_by_position=True):
             chunks.append((d["content_with_weight"], np.array(d[vctr_nm])))
@@ -698,7 +699,7 @@ async def delete_image(kb_id, chunk_id):
 async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_callback):
     for b in range(0, len(chunks), DOC_BULK_SIZE):
-        doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b:b + DOC_BULK_SIZE], search.index_name(task_tenant_id), task_dataset_id))
+        doc_store_result = await trio.to_thread.run_sync(lambda: globals.docStoreConn.insert(chunks[b:b + DOC_BULK_SIZE], search.index_name(task_tenant_id), task_dataset_id))
         task_canceled = has_canceled(task_id)
         if task_canceled:
             progress_callback(-1, msg="Task has been canceled.")
@@ -715,7 +716,7 @@ async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_c
             TaskService.update_chunk_ids(task_id, chunk_ids_str)
         except DoesNotExist:
             logging.warning(f"do_handle_task update_chunk_ids failed since task {task_id} is unknown.")
            doc_store_result = await trio.to_thread.run_sync(lambda: globals.docStoreConn.delete({"id": chunk_ids}, search.index_name(task_tenant_id), task_dataset_id))
            async with trio.open_nursery() as nursery:
                for chunk_id in chunk_ids:
                    nursery.start_soon(delete_image, task_dataset_id, chunk_id)
@@ -751,7 +752,7 @@ async def do_handle_task(task):
     progress_callback = partial(set_progress, task_id, task_from_page, task_to_page)
     # FIXME: workaround, Infinity doesn't support table parsing method, this check is to notify user
-    lower_case_doc_engine = settings.DOC_ENGINE.lower()
+    lower_case_doc_engine = globals.DOC_ENGINE.lower()
     if
lower_case_doc_engine == 'infinity' and task['parser_id'].lower() == 'table': error_message = "Table parsing method is not supported by Infinity, please use other parsing methods or use Elasticsearch as the document engine." progress_callback(-1, msg=error_message) diff --git a/common/base64_image.py b/rag/utils/base64_image.py similarity index 100% rename from common/base64_image.py rename to rag/utils/base64_image.py
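Reviewer note (not part of the patch): every call site above now resolves retriever, docStoreConn, DOC_ENGINE, EMBEDDING_MDL, and EMBEDDING_CFG through `from common import globals` instead of `api.settings`. The `common/globals.py` module itself is not shown in this diff; the sketch below is only an assumed minimal shape, with attribute names taken from the call sites in this patch and actual initialization left to application startup code.

    # common/globals.py -- hypothetical sketch, not part of this patch.
    # Module-level holders so callers can do `from common import globals`
    # and read e.g. globals.retriever once startup code has assigned them.
    from typing import Any, Optional

    # Document-store connection (Elasticsearch / OpenSearch / Infinity wrapper),
    # assigned during application startup (assumption).
    docStoreConn: Optional[Any] = None

    # Shared retriever built on top of docStoreConn, assigned during startup (assumption).
    retriever: Optional[Any] = None

    # Selected document engine; rag/settings.py overwrites this from the DOC_ENGINE env var.
    DOC_ENGINE: str = "elasticsearch"

    # Default embedding model name and its service configuration,
    # consumed by rag/llm/embedding_model.py (assumed defaults).
    EMBEDDING_MDL: str = ""
    EMBEDDING_CFG: dict = {}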