Move some vars to globals (#11017)

### What problem does this PR solve?

As title: move mutable runtime state (`EMBEDDING_MDL`, `DOC_ENGINE`, `docStoreConn`, and `retriever`) from `api/settings.py` into `common/globals.py`, joining the previously moved `EMBEDDING_CFG`, and switch every call site from `settings.*` to `globals.*`.
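The pattern at every touched call site, as a before/after sketch (the `chunk_list` call is lifted from the diff below and is not runnable outside the repo):

```python
# Before: shared runtime objects were attributes of api.settings
from api import settings

chunks = settings.retriever.chunk_list(doc_id, tenant_id, kb_ids)

# After: the same objects live in common.globals; only the module prefix changes
from common import globals

chunks = globals.retriever.chunk_list(doc_id, tenant_id, kb_ids)
```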

### Type of change

- [x] Refactoring

---------

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
Author: Jin Hai
Date: 2025-11-05 14:14:38 +08:00
Committed by: GitHub
Parent: cf9611c96f
Commit: 1a9215bc6f
35 changed files with 185 additions and 164 deletions

View File

@@ -25,6 +25,7 @@ from api.db.services.dialog_service import meta_filter
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.llm_service import LLMBundle
 from api import settings
+from common import globals
 from common.connection_utils import timeout
 from rag.app.tag import label_question
 from rag.prompts.generator import cross_languages, kb_prompt, gen_meta_filter
@@ -170,7 +171,7 @@ class Retrieval(ToolBase, ABC):
         if kbs:
             query = re.sub(r"^user[:\s]*", "", query, flags=re.IGNORECASE)
-            kbinfos = settings.retriever.retrieval(
+            kbinfos = globals.retriever.retrieval(
                 query,
                 embd_mdl,
                 [kb.tenant_id for kb in kbs],
@@ -186,7 +187,7 @@
             )
             if self._param.toc_enhance:
                 chat_mdl = LLMBundle(self._canvas._tenant_id, LLMType.CHAT)
-                cks = settings.retriever.retrieval_by_toc(query, kbinfos["chunks"], [kb.tenant_id for kb in kbs], chat_mdl, self._param.top_n)
+                cks = globals.retriever.retrieval_by_toc(query, kbinfos["chunks"], [kb.tenant_id for kb in kbs], chat_mdl, self._param.top_n)
                 if cks:
                     kbinfos["chunks"] = cks
             if self._param.use_kg:

View File

@@ -32,7 +32,6 @@ from api.db.services.file_service import FileService
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.task_service import queue_tasks, TaskService
 from api.db.services.user_service import UserTenantService
-from api import settings
 from common.misc_utils import get_uuid
 from common.constants import RetCode, VALID_TASK_STATUS, LLMType, ParserType, FileSource
 from api.utils.api_utils import server_error_response, get_data_error_result, get_json_result, validate_request, \
@@ -48,6 +47,7 @@ from api.db.services.canvas_service import UserCanvasService
 from agent.canvas import Canvas
 from functools import partial
 from pathlib import Path
+from common import globals
 @manager.route('/new_token', methods=['POST']) # noqa: F821
@@ -538,7 +538,7 @@ def list_chunks():
         )
     kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
-    res = settings.retriever.chunk_list(doc_id, tenant_id, kb_ids)
+    res = globals.retriever.chunk_list(doc_id, tenant_id, kb_ids)
     res = [
         {
             "content": res_item["content_with_weight"],
@@ -564,7 +564,7 @@ def get_chunk(chunk_id):
     try:
         tenant_id = objs[0].tenant_id
         kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
-        chunk = settings.docStoreConn.get(chunk_id, search.index_name(tenant_id), kb_ids)
+        chunk = globals.docStoreConn.get(chunk_id, search.index_name(tenant_id), kb_ids)
         if chunk is None:
             return server_error_response(Exception("Chunk not found"))
         k = []
@@ -886,7 +886,7 @@ def retrieval():
         if req.get("keyword", False):
             chat_mdl = LLMBundle(kbs[0].tenant_id, LLMType.CHAT)
             question += keyword_extraction(chat_mdl, question)
-        ranks = settings.retriever.retrieval(question, embd_mdl, kbs[0].tenant_id, kb_ids, page, size,
+        ranks = globals.retriever.retrieval(question, embd_mdl, kbs[0].tenant_id, kb_ids, page, size,
                                             similarity_threshold, vector_similarity_weight, top,
                                             doc_ids, rerank_mdl=rerank_mdl, highlight= highlight,
                                             rank_feature=label_question(question, kbs))

View File

@@ -25,7 +25,6 @@ from flask import request, Response
 from flask_login import login_required, current_user
 from agent.component import LLM
-from api import settings
 from api.db import CanvasCategory, FileType
 from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService, API4ConversationService
 from api.db.services.document_service import DocumentService
@@ -46,6 +45,7 @@ from api.utils.file_utils import filename_type, read_potential_broken_pdf
 from rag.flow.pipeline import Pipeline
 from rag.nlp import search
 from rag.utils.redis_conn import REDIS_CONN
+from common import globals
 @manager.route('/templates', methods=['GET']) # noqa: F821
@@ -192,8 +192,8 @@ def rerun():
         if 0 < doc["progress"] < 1:
             return get_data_error_result(message=f"`{doc['name']}` is processing...")
-        if settings.docStoreConn.indexExist(search.index_name(current_user.id), doc["kb_id"]):
-            settings.docStoreConn.delete({"doc_id": doc["id"]}, search.index_name(current_user.id), doc["kb_id"])
+        if globals.docStoreConn.indexExist(search.index_name(current_user.id), doc["kb_id"]):
+            globals.docStoreConn.delete({"doc_id": doc["id"]}, search.index_name(current_user.id), doc["kb_id"])
         doc["progress_msg"] = ""
         doc["chunk_num"] = 0
         doc["token_num"] = 0

View File

@@ -36,6 +36,7 @@ from rag.prompts.generator import gen_meta_filter, cross_languages, keyword_extr
 from rag.settings import PAGERANK_FLD
 from common.string_utils import remove_redundant_spaces
 from common.constants import RetCode, LLMType, ParserType
+from common import globals
 @manager.route('/list', methods=['POST']) # noqa: F821
@@ -60,7 +61,7 @@ def list_chunk():
         }
         if "available_int" in req:
             query["available_int"] = int(req["available_int"])
-        sres = settings.retriever.search(query, search.index_name(tenant_id), kb_ids, highlight=["content_ltks"])
+        sres = globals.retriever.search(query, search.index_name(tenant_id), kb_ids, highlight=["content_ltks"])
         res = {"total": sres.total, "chunks": [], "doc": doc.to_dict()}
         for id in sres.ids:
             d = {
@@ -98,7 +99,7 @@ def get():
             return get_data_error_result(message="Tenant not found!")
         for tenant in tenants:
             kb_ids = KnowledgebaseService.get_kb_ids(tenant.tenant_id)
-            chunk = settings.docStoreConn.get(chunk_id, search.index_name(tenant.tenant_id), kb_ids)
+            chunk = globals.docStoreConn.get(chunk_id, search.index_name(tenant.tenant_id), kb_ids)
             if chunk:
                 break
        if chunk is None:
@@ -170,7 +171,7 @@ def set():
         v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not d.get("question_kwd") else "\n".join(d["question_kwd"])])
         v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
         d["q_%d_vec" % len(v)] = v.tolist()
-        settings.docStoreConn.update({"id": req["chunk_id"]}, d, search.index_name(tenant_id), doc.kb_id)
+        globals.docStoreConn.update({"id": req["chunk_id"]}, d, search.index_name(tenant_id), doc.kb_id)
         return get_json_result(data=True)
     except Exception as e:
         return server_error_response(e)
@@ -186,7 +187,7 @@ def switch():
         if not e:
             return get_data_error_result(message="Document not found!")
         for cid in req["chunk_ids"]:
-            if not settings.docStoreConn.update({"id": cid},
+            if not globals.docStoreConn.update({"id": cid},
                                                 {"available_int": int(req["available_int"])},
                                                 search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
                                                 doc.kb_id):
@@ -206,7 +207,7 @@ def rm():
         e, doc = DocumentService.get_by_id(req["doc_id"])
         if not e:
             return get_data_error_result(message="Document not found!")
-        if not settings.docStoreConn.delete({"id": req["chunk_ids"]},
+        if not globals.docStoreConn.delete({"id": req["chunk_ids"]},
                                             search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
                                             doc.kb_id):
             return get_data_error_result(message="Chunk deleting failure")
@@ -270,7 +271,7 @@ def create():
         v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not d["question_kwd"] else "\n".join(d["question_kwd"])])
         v = 0.1 * v[0] + 0.9 * v[1]
         d["q_%d_vec" % len(v)] = v.tolist()
-        settings.docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
+        globals.docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
         DocumentService.increment_chunk_num(
             doc.id, doc.kb_id, c, 1, 0)
@@ -346,7 +347,7 @@ def retrieval_test():
             question += keyword_extraction(chat_mdl, question)
         labels = label_question(question, [kb])
-        ranks = settings.retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids, page, size,
+        ranks = globals.retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids, page, size,
                                             float(req.get("similarity_threshold", 0.0)),
                                             float(req.get("vector_similarity_weight", 0.3)),
                                             top,
@@ -385,7 +386,7 @@ def knowledge_graph():
         "doc_ids": [doc_id],
         "knowledge_graph_kwd": ["graph", "mind_map"]
     }
-    sres = settings.retriever.search(req, search.index_name(tenant_id), kb_ids)
+    sres = globals.retriever.search(req, search.index_name(tenant_id), kb_ids)
     obj = {"graph": {}, "mind_map": {}}
     for id in sres.ids[:2]:
         ty = sres.field[id]["knowledge_graph_kwd"]

View File

@@ -23,7 +23,6 @@ import flask
 from flask import request
 from flask_login import current_user, login_required
-from api import settings
 from api.common.check_team_permission import check_kb_team_permission
 from api.constants import FILE_NAME_LEN_LIMIT, IMG_BASE64_PREFIX
 from api.db import VALID_FILE_TYPES, FileType
@@ -49,6 +48,7 @@ from api.utils.web_utils import CONTENT_TYPE_MAP, html2pdf, is_valid_url
 from deepdoc.parser.html_parser import RAGFlowHtmlParser
 from rag.nlp import search, rag_tokenizer
 from rag.utils.storage_factory import STORAGE_IMPL
+from common import globals
 @manager.route("/upload", methods=["POST"]) # noqa: F821
@@ -367,7 +367,7 @@ def change_status():
                 continue
             status_int = int(status)
-            if not settings.docStoreConn.update({"doc_id": doc_id}, {"available_int": status_int}, search.index_name(kb.tenant_id), doc.kb_id):
+            if not globals.docStoreConn.update({"doc_id": doc_id}, {"available_int": status_int}, search.index_name(kb.tenant_id), doc.kb_id):
                 result[doc_id] = {"error": "Database error (docStore update)!"}
             result[doc_id] = {"status": status}
         except Exception as e:
@@ -432,8 +432,8 @@ def run():
             DocumentService.update_by_id(id, info)
             if req.get("delete", False):
                 TaskService.filter_delete([Task.doc_id == id])
-                if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
-                    settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id)
+                if globals.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
+                    globals.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id)
             if str(req["run"]) == TaskStatus.RUNNING.value:
                 doc = doc.to_dict()
@@ -479,8 +479,8 @@ def rename():
             "title_tks": title_tks,
             "title_sm_tks": rag_tokenizer.fine_grained_tokenize(title_tks),
         }
-        if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
-            settings.docStoreConn.update(
+        if globals.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
+            globals.docStoreConn.update(
                 {"doc_id": req["doc_id"]},
                 es_body,
                 search.index_name(tenant_id),
@@ -541,8 +541,8 @@ def change_parser():
             tenant_id = DocumentService.get_tenant_id(req["doc_id"])
             if not tenant_id:
                 return get_data_error_result(message="Tenant not found!")
-            if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
-                settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
+            if globals.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
+                globals.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
         try:
             if "pipeline_id" in req and req["pipeline_id"] != "":

View File

@@ -35,7 +35,6 @@ from api.db import VALID_FILE_TYPES
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.db_models import File
 from api.utils.api_utils import get_json_result
-from api import settings
 from rag.nlp import search
 from api.constants import DATASET_NAME_LIMIT
 from rag.settings import PAGERANK_FLD
@@ -43,7 +42,7 @@ from rag.utils.redis_conn import REDIS_CONN
 from rag.utils.storage_factory import STORAGE_IMPL
 from rag.utils.doc_store_conn import OrderByExpr
 from common.constants import RetCode, PipelineTaskType, StatusEnum, VALID_TASK_STATUS, FileSource, LLMType
+from common import globals
 @manager.route('/create', methods=['post']) # noqa: F821
 @login_required
@@ -110,11 +109,11 @@ def update():
         if kb.pagerank != req.get("pagerank", 0):
             if req.get("pagerank", 0) > 0:
-                settings.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]},
+                globals.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]},
                                              search.index_name(kb.tenant_id), kb.id)
             else:
                 # Elasticsearch requires PAGERANK_FLD be non-zero!
-                settings.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD},
+                globals.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD},
                                              search.index_name(kb.tenant_id), kb.id)
         e, kb = KnowledgebaseService.get_by_id(kb.id)
@@ -226,8 +225,8 @@ def rm():
             return get_data_error_result(
                 message="Database error (Knowledgebase removal)!")
         for kb in kbs:
-            settings.docStoreConn.delete({"kb_id": kb.id}, search.index_name(kb.tenant_id), kb.id)
-            settings.docStoreConn.deleteIdx(search.index_name(kb.tenant_id), kb.id)
+            globals.docStoreConn.delete({"kb_id": kb.id}, search.index_name(kb.tenant_id), kb.id)
+            globals.docStoreConn.deleteIdx(search.index_name(kb.tenant_id), kb.id)
             if hasattr(STORAGE_IMPL, 'remove_bucket'):
                 STORAGE_IMPL.remove_bucket(kb.id)
         return get_json_result(data=True)
@@ -248,7 +247,7 @@ def list_tags(kb_id):
     tenants = UserTenantService.get_tenants_by_user_id(current_user.id)
     tags = []
     for tenant in tenants:
-        tags += settings.retriever.all_tags(tenant["tenant_id"], [kb_id])
+        tags += globals.retriever.all_tags(tenant["tenant_id"], [kb_id])
     return get_json_result(data=tags)
@@ -267,7 +266,7 @@ def list_tags_from_kbs():
     tenants = UserTenantService.get_tenants_by_user_id(current_user.id)
     tags = []
     for tenant in tenants:
-        tags += settings.retriever.all_tags(tenant["tenant_id"], kb_ids)
+        tags += globals.retriever.all_tags(tenant["tenant_id"], kb_ids)
     return get_json_result(data=tags)
@@ -284,7 +283,7 @@ def rm_tags(kb_id):
     e, kb = KnowledgebaseService.get_by_id(kb_id)
     for t in req["tags"]:
-        settings.docStoreConn.update({"tag_kwd": t, "kb_id": [kb_id]},
+        globals.docStoreConn.update({"tag_kwd": t, "kb_id": [kb_id]},
                                      {"remove": {"tag_kwd": t}},
                                      search.index_name(kb.tenant_id),
                                      kb_id)
@@ -303,7 +302,7 @@ def rename_tags(kb_id):
         )
     e, kb = KnowledgebaseService.get_by_id(kb_id)
-    settings.docStoreConn.update({"tag_kwd": req["from_tag"], "kb_id": [kb_id]},
+    globals.docStoreConn.update({"tag_kwd": req["from_tag"], "kb_id": [kb_id]},
                                  {"remove": {"tag_kwd": req["from_tag"].strip()}, "add": {"tag_kwd": req["to_tag"]}},
                                  search.index_name(kb.tenant_id),
                                  kb_id)
@@ -326,9 +325,9 @@ def knowledge_graph(kb_id):
     }
     obj = {"graph": {}, "mind_map": {}}
-    if not settings.docStoreConn.indexExist(search.index_name(kb.tenant_id), kb_id):
+    if not globals.docStoreConn.indexExist(search.index_name(kb.tenant_id), kb_id):
         return get_json_result(data=obj)
-    sres = settings.retriever.search(req, search.index_name(kb.tenant_id), [kb_id])
+    sres = globals.retriever.search(req, search.index_name(kb.tenant_id), [kb_id])
     if not len(sres.ids):
         return get_json_result(data=obj)
@@ -360,7 +359,7 @@ def delete_knowledge_graph(kb_id):
             code=RetCode.AUTHENTICATION_ERROR
         )
     _, kb = KnowledgebaseService.get_by_id(kb_id)
-    settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]}, search.index_name(kb.tenant_id), kb_id)
+    globals.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]}, search.index_name(kb.tenant_id), kb_id)
     return get_json_result(data=True)
@@ -732,13 +731,13 @@ def delete_kb_task():
            task_id = kb.graphrag_task_id
            kb_task_finish_at = "graphrag_task_finish_at"
            cancel_task(task_id)
-           settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]}, search.index_name(kb.tenant_id), kb_id)
+           globals.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]}, search.index_name(kb.tenant_id), kb_id)
        case PipelineTaskType.RAPTOR:
            kb_task_id_field = "raptor_task_id"
            task_id = kb.raptor_task_id
            kb_task_finish_at = "raptor_task_finish_at"
            cancel_task(task_id)
-           settings.docStoreConn.delete({"raptor_kwd": ["raptor"]}, search.index_name(kb.tenant_id), kb_id)
+           globals.docStoreConn.delete({"raptor_kwd": ["raptor"]}, search.index_name(kb.tenant_id), kb_id)
        case PipelineTaskType.MINDMAP:
            kb_task_id_field = "mindmap_task_id"
            task_id = kb.mindmap_task_id
@@ -850,7 +849,7 @@ def check_embedding():
     tenant_id = kb.tenant_id
     emb_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_id)
-    samples = sample_random_chunks_with_vectors(settings.docStoreConn, tenant_id=tenant_id, kb_id=kb_id, n=n)
+    samples = sample_random_chunks_with_vectors(globals.docStoreConn, tenant_id=tenant_id, kb_id=kb_id, n=n)
     results, eff_sims = [], []
     for ck in samples:

View File

@@ -24,7 +24,7 @@ from api.utils.api_utils import server_error_response, get_data_error_result, va
 from common.constants import StatusEnum, LLMType
 from api.db.db_models import TenantLLM
 from api.utils.api_utils import get_json_result, get_allowed_llm_factories
-from common.base64_image import test_image
+from rag.utils.base64_image import test_image
 from rag.llm import EmbeddingModel, ChatModel, RerankModel, CvModel, TTSModel

View File

@@ -20,7 +20,6 @@ import os
 import json
 from flask import request
 from peewee import OperationalError
-from api import settings
 from api.db.db_models import File
 from api.db.services.document_service import DocumentService
 from api.db.services.file2document_service import File2DocumentService
@@ -49,6 +48,7 @@ from api.utils.validation_utils import (
 )
 from rag.nlp import search
 from rag.settings import PAGERANK_FLD
+from common import globals
 @manager.route("/datasets", methods=["POST"]) # noqa: F821
@@ -360,11 +360,11 @@ def update(tenant_id, dataset_id):
             return get_error_argument_result(message="'pagerank' can only be set when doc_engine is elasticsearch")
         if req["pagerank"] > 0:
-            settings.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]},
+            globals.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]},
                                          search.index_name(kb.tenant_id), kb.id)
         else:
             # Elasticsearch requires PAGERANK_FLD be non-zero!
-            settings.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD},
+            globals.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD},
                                          search.index_name(kb.tenant_id), kb.id)
     if not KnowledgebaseService.update_by_id(kb.id, req):
@@ -493,9 +493,9 @@ def knowledge_graph(tenant_id, dataset_id):
     }
     obj = {"graph": {}, "mind_map": {}}
-    if not settings.docStoreConn.indexExist(search.index_name(kb.tenant_id), dataset_id):
+    if not globals.docStoreConn.indexExist(search.index_name(kb.tenant_id), dataset_id):
         return get_result(data=obj)
-    sres = settings.retriever.search(req, search.index_name(kb.tenant_id), [dataset_id])
+    sres = globals.retriever.search(req, search.index_name(kb.tenant_id), [dataset_id])
     if not len(sres.ids):
         return get_result(data=obj)
@@ -528,7 +528,7 @@ def delete_knowledge_graph(tenant_id, dataset_id):
             code=RetCode.AUTHENTICATION_ERROR
         )
     _, kb = KnowledgebaseService.get_by_id(dataset_id)
-    settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]},
+    globals.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]},
                                  search.index_name(kb.tenant_id), dataset_id)
     return get_result(data=True)

View File

@@ -25,6 +25,7 @@ from api.utils.api_utils import validate_request, build_error_result, apikey_req
 from rag.app.tag import label_question
 from api.db.services.dialog_service import meta_filter, convert_conditions
 from common.constants import RetCode, LLMType
+from common import globals
 @manager.route('/dify/retrieval', methods=['POST']) # noqa: F821
 @apikey_required
@@ -137,7 +138,7 @@ def retrieval(tenant_id):
         # print("doc_ids", doc_ids)
         if not doc_ids and metadata_condition is not None:
             doc_ids = ['-999']
-        ranks = settings.retriever.retrieval(
+        ranks = globals.retriever.retrieval(
             question,
             embd_mdl,
             kb.tenant_id,

View File

@@ -44,6 +44,7 @@ from rag.prompts.generator import cross_languages, keyword_extraction
 from rag.utils.storage_factory import STORAGE_IMPL
 from common.string_utils import remove_redundant_spaces
 from common.constants import RetCode, LLMType, ParserType, TaskStatus, FileSource
+from common import globals
 MAXIMUM_OF_UPLOADING_FILES = 256
@@ -307,7 +308,7 @@ def update_doc(tenant_id, dataset_id, document_id):
             )
         if not e:
             return get_error_data_result(message="Document not found!")
-        settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), dataset_id)
+        globals.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), dataset_id)
         if "enabled" in req:
             status = int(req["enabled"])
@@ -316,7 +317,7 @@
             if not DocumentService.update_by_id(doc.id, {"status": str(status)}):
                 return get_error_data_result(message="Database error (Document update)!")
-            settings.docStoreConn.update({"doc_id": doc.id}, {"available_int": status}, search.index_name(kb.tenant_id), doc.kb_id)
+            globals.docStoreConn.update({"doc_id": doc.id}, {"available_int": status}, search.index_name(kb.tenant_id), doc.kb_id)
             return get_result(data=True)
     except Exception as e:
         return server_error_response(e)
@@ -755,7 +756,7 @@ def parse(tenant_id, dataset_id):
             return get_error_data_result("Can't parse document that is currently being processed")
         info = {"run": "1", "progress": 0, "progress_msg": "", "chunk_num": 0, "token_num": 0}
         DocumentService.update_by_id(id, info)
-        settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), dataset_id)
+        globals.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), dataset_id)
         TaskService.filter_delete([Task.doc_id == id])
         e, doc = DocumentService.get_by_id(id)
         doc = doc.to_dict()
@@ -835,7 +836,7 @@ def stop_parsing(tenant_id, dataset_id):
             return get_error_data_result("Can't stop parsing document with progress at 0 or 1")
         info = {"run": "2", "progress": 0, "chunk_num": 0}
         DocumentService.update_by_id(id, info)
-        settings.docStoreConn.delete({"doc_id": doc[0].id}, search.index_name(tenant_id), dataset_id)
+        globals.docStoreConn.delete({"doc_id": doc[0].id}, search.index_name(tenant_id), dataset_id)
         success_count += 1
     if duplicate_messages:
         if success_count > 0:
@@ -968,7 +969,7 @@ def list_chunks(tenant_id, dataset_id, document_id):
     res = {"total": 0, "chunks": [], "doc": renamed_doc}
     if req.get("id"):
-        chunk = settings.docStoreConn.get(req.get("id"), search.index_name(tenant_id), [dataset_id])
+        chunk = globals.docStoreConn.get(req.get("id"), search.index_name(tenant_id), [dataset_id])
         if not chunk:
             return get_result(message=f"Chunk not found: {dataset_id}/{req.get('id')}", code=RetCode.NOT_FOUND)
         k = []
@@ -995,8 +996,8 @@
         res["chunks"].append(final_chunk)
         _ = Chunk(**final_chunk)
-    elif settings.docStoreConn.indexExist(search.index_name(tenant_id), dataset_id):
-        sres = settings.retriever.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None, highlight=True)
+    elif globals.docStoreConn.indexExist(search.index_name(tenant_id), dataset_id):
+        sres = globals.retriever.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None, highlight=True)
         res["total"] = sres.total
         for id in sres.ids:
             d = {
@@ -1120,7 +1121,7 @@ def add_chunk(tenant_id, dataset_id, document_id):
     v, c = embd_mdl.encode([doc.name, req["content"] if not d["question_kwd"] else "\n".join(d["question_kwd"])])
     v = 0.1 * v[0] + 0.9 * v[1]
     d["q_%d_vec" % len(v)] = v.tolist()
-    settings.docStoreConn.insert([d], search.index_name(tenant_id), dataset_id)
+    globals.docStoreConn.insert([d], search.index_name(tenant_id), dataset_id)
     DocumentService.increment_chunk_num(doc.id, doc.kb_id, c, 1, 0)
     # rename keys
@@ -1201,7 +1202,7 @@ def rm_chunk(tenant_id, dataset_id, document_id):
     if "chunk_ids" in req:
         unique_chunk_ids, duplicate_messages = check_duplicate_ids(req["chunk_ids"], "chunk")
         condition["id"] = unique_chunk_ids
-    chunk_number = settings.docStoreConn.delete(condition, search.index_name(tenant_id), dataset_id)
+    chunk_number = globals.docStoreConn.delete(condition, search.index_name(tenant_id), dataset_id)
     if chunk_number != 0:
         DocumentService.decrement_chunk_num(document_id, dataset_id, 1, chunk_number, 0)
     if "chunk_ids" in req and chunk_number != len(unique_chunk_ids):
@@ -1273,7 +1274,7 @@ def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
         schema:
           type: object
     """
-    chunk = settings.docStoreConn.get(chunk_id, search.index_name(tenant_id), [dataset_id])
+    chunk = globals.docStoreConn.get(chunk_id, search.index_name(tenant_id), [dataset_id])
     if chunk is None:
         return get_error_data_result(f"Can't find this chunk {chunk_id}")
     if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
@@ -1318,7 +1319,7 @@
         v, c = embd_mdl.encode([doc.name, d["content_with_weight"] if not d.get("question_kwd") else "\n".join(d["question_kwd"])])
         v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
         d["q_%d_vec" % len(v)] = v.tolist()
-    settings.docStoreConn.update({"id": chunk_id}, d, search.index_name(tenant_id), dataset_id)
+    globals.docStoreConn.update({"id": chunk_id}, d, search.index_name(tenant_id), dataset_id)
     return get_result()
@@ -1464,7 +1465,7 @@ def retrieval_test(tenant_id):
         chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
         question += keyword_extraction(chat_mdl, question)
-    ranks = settings.retriever.retrieval(
+    ranks = globals.retriever.retrieval(
         question,
         embd_mdl,
         tenant_ids,

View File

@@ -41,6 +41,7 @@ from rag.app.tag import label_question
 from rag.prompts.template import load_prompt
 from rag.prompts.generator import cross_languages, gen_meta_filter, keyword_extraction, chunks_format
 from common.constants import RetCode, LLMType, StatusEnum
+from common import globals
 @manager.route("/chats/<chat_id>/sessions", methods=["POST"]) # noqa: F821
 @token_required
@@ -1015,7 +1016,7 @@ def retrieval_test_embedded():
         question += keyword_extraction(chat_mdl, question)
     labels = label_question(question, [kb])
-    ranks = settings.retriever.retrieval(
+    ranks = globals.retriever.retrieval(
         question, embd_mdl, tenant_ids, kb_ids, page, size, similarity_threshold, vector_similarity_weight, top,
         doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"), rank_feature=labels
     )

View File

@@ -38,6 +38,7 @@ from timeit import default_timer as timer
 from rag.utils.redis_conn import REDIS_CONN
 from flask import jsonify
 from api.utils.health_utils import run_health_checks
+from common import globals
 @manager.route("/version", methods=["GET"]) # noqa: F821
@@ -100,7 +101,7 @@ def status():
     res = {}
     st = timer()
     try:
-        res["doc_engine"] = settings.docStoreConn.health()
+        res["doc_engine"] = globals.docStoreConn.health()
         res["doc_engine"]["elapsed"] = "{:.1f}".format((timer() - st) * 1000.0)
     except Exception as e:
         res["doc_engine"] = {

View File

@@ -58,6 +58,7 @@ from api.utils.web_utils import (
     hash_code,
     captcha_key,
 )
+from common import globals
 @manager.route("/login", methods=["POST", "GET"]) # noqa: F821
@@ -623,7 +624,7 @@ def user_register(user_id, user):
         "id": user_id,
         "name": user["nickname"] + "s Kingdom",
         "llm_id": settings.CHAT_MDL,
-        "embd_id": settings.EMBEDDING_MDL,
+        "embd_id": globals.EMBEDDING_MDL,
         "asr_id": settings.ASR_MDL,
         "parser_ids": settings.PARSERS,
         "img2txt_id": settings.IMAGE2TEXT_MDL,

View File

@@ -32,6 +32,7 @@ from api.db.services.user_service import TenantService, UserTenantService
 from api import settings
 from common.constants import LLMType
 from common.file_utils import get_project_base_directory
+from common import globals
 from api.common.base64 import encode_to_base64
@@ -49,7 +50,7 @@ def init_superuser():
         "id": user_info["id"],
         "name": user_info["nickname"] + "s Kingdom",
         "llm_id": settings.CHAT_MDL,
-        "embd_id": settings.EMBEDDING_MDL,
+        "embd_id": globals.EMBEDDING_MDL,
         "asr_id": settings.ASR_MDL,
         "parser_ids": settings.PARSERS,
         "img2txt_id": settings.IMAGE2TEXT_MDL

View File

@@ -38,6 +38,7 @@ from api.db.services.user_service import TenantService, UserService, UserTenantS
 from rag.utils.storage_factory import STORAGE_IMPL
 from rag.nlp import search
 from common.constants import ActiveEnum
+from common import globals
 def create_new_user(user_info: dict) -> dict:
     """
@@ -63,7 +64,7 @@ def create_new_user(user_info: dict) -> dict:
         "id": user_id,
         "name": user_info["nickname"] + "s Kingdom",
         "llm_id": settings.CHAT_MDL,
-        "embd_id": settings.EMBEDDING_MDL,
+        "embd_id": globals.EMBEDDING_MDL,
         "asr_id": settings.ASR_MDL,
         "parser_ids": settings.PARSERS,
         "img2txt_id": settings.IMAGE2TEXT_MDL,
@@ -179,7 +180,7 @@ def delete_user_data(user_id: str) -> dict:
             )
             done_msg += f"- Deleted {file2doc_delete_res} document-file relation records.\n"
             # step1.1.3 delete chunk in es
-            r = settings.docStoreConn.delete({"kb_id": kb_ids},
+            r = globals.docStoreConn.delete({"kb_id": kb_ids},
                                              search.index_name(tenant_id), kb_ids)
             done_msg += f"- Deleted {r} chunk records.\n"
             kb_delete_res = KnowledgebaseService.delete_by_ids(kb_ids)
@@ -237,7 +238,7 @@
         kb_doc_info = {}
         for _tenant_id, kb_doc in kb_grouped_doc.items():
             for _kb_id, docs in kb_doc.items():
-                chunk_delete_res += settings.docStoreConn.delete(
+                chunk_delete_res += globals.docStoreConn.delete(
                     {"doc_id": [d["id"] for d in docs]},
                     search.index_name(_tenant_id), _kb_id
                 )

View File

@@ -44,6 +44,7 @@ from rag.prompts.generator import chunks_format, citation_prompt, cross_language
 from common.token_utils import num_tokens_from_string
 from rag.utils.tavily_conn import Tavily
 from common.string_utils import remove_redundant_spaces
+from common import globals
 class DialogService(CommonService):
@@ -371,7 +372,7 @@ def chat(dialog, messages, stream=True, **kwargs):
         chat_mdl.bind_tools(toolcall_session, tools)
     bind_models_ts = timer()
-    retriever = settings.retriever
+    retriever = globals.retriever
     questions = [m["content"] for m in messages if m["role"] == "user"][-3:]
     attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else []
     if "doc_ids" in messages[-1]:
@@ -663,7 +664,7 @@ Please write the SQL, only SQL, without any other explanations or text.
         logging.debug(f"{question} get SQL(refined): {sql}")
         tried_times += 1
-        return settings.retriever.sql_retrieval(sql, format="json"), sql
+        return globals.retriever.sql_retrieval(sql, format="json"), sql
     tbl, sql = get_table()
     if tbl is None:
@@ -757,7 +758,7 @@ def ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
     embedding_list = list(set([kb.embd_id for kb in kbs]))
     is_knowledge_graph = all([kb.parser_id == ParserType.KG for kb in kbs])
-    retriever = settings.retriever if not is_knowledge_graph else settings.kg_retriever
+    retriever = globals.retriever if not is_knowledge_graph else settings.kg_retriever
     embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embedding_list[0])
     chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, chat_llm_name)
@@ -853,7 +854,7 @@ def gen_mindmap(question, kb_ids, tenant_id, search_config={}):
     if not doc_ids:
         doc_ids = None
-    ranks = settings.retriever.retrieval(
+    ranks = globals.retriever.retrieval(
         question=question,
         embd_mdl=embd_mdl,
         tenant_ids=tenant_ids,

View File

@@ -26,7 +26,6 @@ import trio
 import xxhash
 from peewee import fn, Case, JOIN
-from api import settings
 from api.constants import IMG_BASE64_PREFIX, FILE_NAME_LEN_LIMIT
 from api.db import FileType, UserTenantRole, CanvasCategory
 from api.db.db_models import DB, Document, Knowledgebase, Task, Tenant, UserTenant, File2Document, File, UserCanvas, \
@@ -42,7 +41,7 @@ from rag.settings import get_svr_queue_name, SVR_CONSUMER_GROUP_NAME
 from rag.utils.redis_conn import REDIS_CONN
 from rag.utils.storage_factory import STORAGE_IMPL
 from rag.utils.doc_store_conn import OrderByExpr
+from common import globals
 class DocumentService(CommonService):
     model = Document
@@ -309,10 +308,10 @@
         page_size = 1000
         all_chunk_ids = []
         while True:
-            chunks = settings.docStoreConn.search(["img_id"], [], {"doc_id": doc.id}, [], OrderByExpr(),
+            chunks = globals.docStoreConn.search(["img_id"], [], {"doc_id": doc.id}, [], OrderByExpr(),
                                                   page * page_size, page_size, search.index_name(tenant_id),
                                                   [doc.kb_id])
-            chunk_ids = settings.docStoreConn.getChunkIds(chunks)
+            chunk_ids = globals.docStoreConn.getChunkIds(chunks)
             if not chunk_ids:
                 break
             all_chunk_ids.extend(chunk_ids)
@@ -323,19 +322,19 @@
             if doc.thumbnail and not doc.thumbnail.startswith(IMG_BASE64_PREFIX):
                 if STORAGE_IMPL.obj_exist(doc.kb_id, doc.thumbnail):
                     STORAGE_IMPL.rm(doc.kb_id, doc.thumbnail)
-            settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
-            graph_source = settings.docStoreConn.getFields(
-                settings.docStoreConn.search(["source_id"], [], {"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]}, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [doc.kb_id]), ["source_id"]
+            globals.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
+            graph_source = globals.docStoreConn.getFields(
+                globals.docStoreConn.search(["source_id"], [], {"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]}, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [doc.kb_id]), ["source_id"]
             )
             if len(graph_source) > 0 and doc.id in list(graph_source.values())[0]["source_id"]:
-                settings.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "source_id": doc.id},
+                globals.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "source_id": doc.id},
                                              {"remove": {"source_id": doc.id}},
                                              search.index_name(tenant_id), doc.kb_id)
-                settings.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]},
+                globals.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]},
                                              {"removed_kwd": "Y"},
                                              search.index_name(tenant_id), doc.kb_id)
-                settings.docStoreConn.delete({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "must_not": {"exists": "source_id"}},
+                globals.docStoreConn.delete({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "must_not": {"exists": "source_id"}},
                                             search.index_name(tenant_id), doc.kb_id)
         except Exception:
             pass
@@ -996,10 +995,10 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
             d["q_%d_vec" % len(v)] = v
         for b in range(0, len(cks), es_bulk_size):
             if try_create_idx:
-                if not settings.docStoreConn.indexExist(idxnm, kb_id):
-                    settings.docStoreConn.createIdx(idxnm, kb_id, len(vects[0]))
+                if not globals.docStoreConn.indexExist(idxnm, kb_id):
+                    globals.docStoreConn.createIdx(idxnm, kb_id, len(vects[0]))
                 try_create_idx = False
-            settings.docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id)
+            globals.docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id)
         DocumentService.increment_chunk_num(
             doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0)

View File

@@ -34,7 +34,7 @@ from deepdoc.parser.excel_parser import RAGFlowExcelParser
 from rag.settings import get_svr_queue_name
 from rag.utils.storage_factory import STORAGE_IMPL
 from rag.utils.redis_conn import REDIS_CONN
-from api import settings
+from common import globals
 from rag.nlp import search
 CANVAS_DEBUG_DOC_ID = "dataflow_x"
@@ -418,7 +418,7 @@ def queue_tasks(doc: dict, bucket: str, name: str, priority: int):
         if pre_task["chunk_ids"]:
             pre_chunk_ids.extend(pre_task["chunk_ids"].split())
     if pre_chunk_ids:
-        settings.docStoreConn.delete({"id": pre_chunk_ids}, search.index_name(chunking_config["tenant_id"]),
+        globals.docStoreConn.delete({"id": pre_chunk_ids}, search.index_name(chunking_config["tenant_id"]),
                                      chunking_config["kb_id"])
     DocumentService.update_by_id(doc["id"], {"chunk_num": ck_num})

View File

@ -32,12 +32,12 @@ LLM = None
LLM_FACTORY = None LLM_FACTORY = None
LLM_BASE_URL = None LLM_BASE_URL = None
CHAT_MDL = "" CHAT_MDL = ""
EMBEDDING_MDL = "" # EMBEDDING_MDL = "" has been moved to common/globals.py
RERANK_MDL = "" RERANK_MDL = ""
ASR_MDL = "" ASR_MDL = ""
IMAGE2TEXT_MDL = "" IMAGE2TEXT_MDL = ""
CHAT_CFG = "" CHAT_CFG = ""
# EMBEDDING_CFG = "" has been moved to common/globals.py
RERANK_CFG = "" RERANK_CFG = ""
ASR_CFG = "" ASR_CFG = ""
IMAGE2TEXT_CFG = "" IMAGE2TEXT_CFG = ""
@ -61,10 +61,10 @@ HTTP_APP_KEY = None
GITHUB_OAUTH = None GITHUB_OAUTH = None
FEISHU_OAUTH = None FEISHU_OAUTH = None
OAUTH_CONFIG = None OAUTH_CONFIG = None
DOC_ENGINE = None # DOC_ENGINE = None has been moved to common/globals.py
docStoreConn = None # docStoreConn = None has been moved to common/globals.py
retriever = None #retriever = None has been moved to common/globals.py
kg_retriever = None kg_retriever = None
# user registration switch # user registration switch
@ -125,7 +125,7 @@ def init_settings():
except Exception: except Exception:
     FACTORY_LLM_INFOS = []
-    global CHAT_MDL, EMBEDDING_MDL, RERANK_MDL, ASR_MDL, IMAGE2TEXT_MDL
+    global CHAT_MDL, RERANK_MDL, ASR_MDL, IMAGE2TEXT_MDL
     global CHAT_CFG, RERANK_CFG, ASR_CFG, IMAGE2TEXT_CFG
     global API_KEY, PARSERS, HOST_IP, HOST_PORT, SECRET_KEY
@@ -135,7 +135,7 @@ def init_settings():
     )
     chat_entry = _parse_model_entry(LLM_DEFAULT_MODELS.get("chat_model", CHAT_MDL))
-    embedding_entry = _parse_model_entry(LLM_DEFAULT_MODELS.get("embedding_model", EMBEDDING_MDL))
+    embedding_entry = _parse_model_entry(LLM_DEFAULT_MODELS.get("embedding_model", globals.EMBEDDING_MDL))
     rerank_entry = _parse_model_entry(LLM_DEFAULT_MODELS.get("rerank_model", RERANK_MDL))
     asr_entry = _parse_model_entry(LLM_DEFAULT_MODELS.get("asr_model", ASR_MDL))
     image2text_entry = _parse_model_entry(LLM_DEFAULT_MODELS.get("image2text_model", IMAGE2TEXT_MDL))
@@ -147,7 +147,7 @@ def init_settings():
     IMAGE2TEXT_CFG = _resolve_per_model_config(image2text_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL)
     CHAT_MDL = CHAT_CFG.get("model", "") or ""
-    EMBEDDING_MDL = os.getenv("TEI_MODEL", "BAAI/bge-small-en-v1.5") if "tei-" in os.getenv("COMPOSE_PROFILES", "") else ""
+    globals.EMBEDDING_MDL = os.getenv("TEI_MODEL", "BAAI/bge-small-en-v1.5") if "tei-" in os.getenv("COMPOSE_PROFILES", "") else ""
     RERANK_MDL = RERANK_CFG.get("model", "") or ""
     ASR_MDL = ASR_CFG.get("model", "") or ""
     IMAGE2TEXT_MDL = IMAGE2TEXT_CFG.get("model", "") or ""
@@ -169,23 +169,23 @@ def init_settings():
     OAUTH_CONFIG = get_base_config("oauth", {})
-    global DOC_ENGINE, docStoreConn, retriever, kg_retriever
-    DOC_ENGINE = os.environ.get("DOC_ENGINE", "elasticsearch")
-    # DOC_ENGINE = os.environ.get('DOC_ENGINE', "opensearch")
-    lower_case_doc_engine = DOC_ENGINE.lower()
+    global kg_retriever
+    globals.DOC_ENGINE = os.environ.get("DOC_ENGINE", "elasticsearch")
+    # globals.DOC_ENGINE = os.environ.get('DOC_ENGINE', "opensearch")
+    lower_case_doc_engine = globals.DOC_ENGINE.lower()
     if lower_case_doc_engine == "elasticsearch":
-        docStoreConn = rag.utils.es_conn.ESConnection()
+        globals.docStoreConn = rag.utils.es_conn.ESConnection()
     elif lower_case_doc_engine == "infinity":
-        docStoreConn = rag.utils.infinity_conn.InfinityConnection()
+        globals.docStoreConn = rag.utils.infinity_conn.InfinityConnection()
     elif lower_case_doc_engine == "opensearch":
-        docStoreConn = rag.utils.opensearch_conn.OSConnection()
+        globals.docStoreConn = rag.utils.opensearch_conn.OSConnection()
     else:
-        raise Exception(f"Not supported doc engine: {DOC_ENGINE}")
-    retriever = search.Dealer(docStoreConn)
+        raise Exception(f"Not supported doc engine: {globals.DOC_ENGINE}")
+    globals.retriever = search.Dealer(globals.docStoreConn)

     from graphrag import search as kg_search
-    kg_retriever = kg_search.KGSearch(docStoreConn)
+    kg_retriever = kg_search.KGSearch(globals.docStoreConn)

     if int(os.environ.get("SANDBOX_ENABLED", "0")):
         global SANDBOX_HOST
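Aside for reviewers (not part of the diff): the refactor depends on every consumer reading these values as attributes of the shared module at call time. A minimal, runnable Python sketch of that semantics; app_globals is a fabricated stand-in for common/globals.py, and the assigned string stands in for search.Dealer(docStoreConn):

import sys
import types

# Fabricate a tiny module in-process so the sketch is self-contained;
# in RAGFlow this role is played by common/globals.py.
g = types.ModuleType("app_globals")
g.retriever = None
sys.modules["app_globals"] = g

import app_globals                 # what consumers do: import the module
from app_globals import retriever  # anti-pattern: copies the current binding once

app_globals.retriever = "Dealer(docStoreConn)"  # init_settings-style assignment

print(app_globals.retriever)  # "Dealer(docStoreConn)" -- attribute reads see the update
print(retriever)              # None -- the imported snapshot went stale

This is why the call sites below are rewritten to globals.retriever / globals.docStoreConn instead of importing those names directly.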
View File

@@ -24,6 +24,7 @@ from rag.utils.redis_conn import REDIS_CONN
 from rag.utils.storage_factory import STORAGE_IMPL
 from rag.utils.es_conn import ESConnection
 from rag.utils.infinity_conn import InfinityConnection
+from common import globals


 def _ok_nok(ok: bool) -> str:
@@ -52,7 +53,7 @@ def check_redis() -> tuple[bool, dict]:
 def check_doc_engine() -> tuple[bool, dict]:
     st = timer()
     try:
-        meta = settings.docStoreConn.health()
+        meta = globals.docStoreConn.health()
         # treat any successful call as ok
         return True, {"elapsed": f"{(timer() - st) * 1000.0:.1f}", **(meta or {})}
     except Exception as e:
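The check above follows a small reusable shape: time the probe, fold any returned metadata into the payload, and convert exceptions into a failed status. A hedged sketch of that shape; check_service and probe are illustrative names, not RAGFlow APIs:

from timeit import default_timer as timer

def check_service(probe) -> tuple[bool, dict]:
    # Mirrors check_doc_engine above: any successful call counts as healthy,
    # and failures are reported rather than raised.
    st = timer()
    try:
        meta = probe()  # e.g. lambda: globals.docStoreConn.health(), after init_settings()
        return True, {"elapsed": f"{(timer() - st) * 1000.0:.1f}", **(meta or {})}
    except Exception as e:
        return False, {"elapsed": f"{(timer() - st) * 1000.0:.1f}", "error": str(e)}

print(check_service(lambda: {"status": "green"}))  # (True, {...})
print(check_service(lambda: 1 / 0))                # (False, {..., 'error': 'division by zero'})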
View File

@@ -14,4 +14,12 @@
 # limitations under the License.
 #

+EMBEDDING_MDL = ""
 EMBEDDING_CFG = ""
+
+DOC_ENGINE = None
+docStoreConn = None
+retriever = None
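These slots are deliberately unset placeholders: they hold None or "" until api.settings.init_settings() (or the rag.settings import further down) assigns them. A hypothetical fail-fast guard a call site could add; require_doc_store is illustrative and not part of this commit:

from common import globals

def require_doc_store():
    # Assumption: init_settings() is the writer of globals.docStoreConn.
    if globals.docStoreConn is None:
        raise RuntimeError("globals.docStoreConn is unset; run api.settings.init_settings() first")
    return globals.docStoreConn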
View File

@@ -20,7 +20,6 @@ import os
 import networkx as nx
 import trio

-from api import settings
 from api.db.services.document_service import DocumentService
 from common.misc_utils import get_uuid
 from common.connection_utils import timeout
@@ -40,6 +39,7 @@ from graphrag.utils import (
 )
 from rag.nlp import rag_tokenizer, search
 from rag.utils.redis_conn import RedisDistributedLock
+from common import globals


 async def run_graphrag(
@@ -55,7 +55,7 @@ async def run_graphrag(
     start = trio.current_time()
     tenant_id, kb_id, doc_id = row["tenant_id"], str(row["kb_id"]), row["doc_id"]
     chunks = []
-    for d in settings.retriever.chunk_list(doc_id, tenant_id, [kb_id], fields=["content_with_weight", "doc_id"], sort_by_position=True):
+    for d in globals.retriever.chunk_list(doc_id, tenant_id, [kb_id], fields=["content_with_weight", "doc_id"], sort_by_position=True):
         chunks.append(d["content_with_weight"])

     with trio.fail_after(max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000):
@@ -170,7 +170,7 @@ async def run_graphrag_for_kb(
     chunks = []
     current_chunk = ""

-    for d in settings.retriever.chunk_list(
+    for d in globals.retriever.chunk_list(
         doc_id,
         tenant_id,
         [kb_id],
@@ -387,8 +387,8 @@ async def generate_subgraph(
         "removed_kwd": "N",
     }
     cid = chunk_id(chunk)
-    await trio.to_thread.run_sync(settings.docStoreConn.delete, {"knowledge_graph_kwd": "subgraph", "source_id": doc_id}, search.index_name(tenant_id), kb_id)
-    await trio.to_thread.run_sync(settings.docStoreConn.insert, [{"id": cid, **chunk}], search.index_name(tenant_id), kb_id)
+    await trio.to_thread.run_sync(globals.docStoreConn.delete, {"knowledge_graph_kwd": "subgraph", "source_id": doc_id}, search.index_name(tenant_id), kb_id)
+    await trio.to_thread.run_sync(globals.docStoreConn.insert, [{"id": cid, **chunk}], search.index_name(tenant_id), kb_id)
     now = trio.current_time()
     callback(msg=f"generated subgraph for doc {doc_id} in {now - start:.2f} seconds.")
     return subgraph
@@ -496,7 +496,7 @@ async def extract_community(
         chunks.append(chunk)

     await trio.to_thread.run_sync(
-        lambda: settings.docStoreConn.delete(
+        lambda: globals.docStoreConn.delete(
             {"knowledge_graph_kwd": "community_report", "kb_id": kb_id},
             search.index_name(tenant_id),
             kb_id,
@@ -504,7 +504,7 @@ async def extract_community(
     )
     es_bulk_size = 4
     for b in range(0, len(chunks), es_bulk_size):
-        doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b : b + es_bulk_size], search.index_name(tenant_id), kb_id))
+        doc_store_result = await trio.to_thread.run_sync(lambda: globals.docStoreConn.insert(chunks[b : b + es_bulk_size], search.index_name(tenant_id), kb_id))
         if doc_store_result:
             error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
             raise Exception(error_message)
View File

@@ -28,6 +28,7 @@ from api.db.services.llm_service import LLMBundle
 from api.db.services.user_service import TenantService
 from graphrag.general.graph_extractor import GraphExtractor
 from graphrag.general.index import update_graph, with_resolution, with_community
+from common import globals

 settings.init_settings()
@@ -62,7 +63,7 @@ async def main():
         chunks = [
             d["content_with_weight"]
-            for d in settings.retriever.chunk_list(
+            for d in globals.retriever.chunk_list(
                 args.doc_id,
                 args.tenant_id,
                 [kb_id],
View File

@@ -28,6 +28,7 @@ from api.db.services.llm_service import LLMBundle
 from api.db.services.user_service import TenantService
 from graphrag.general.index import update_graph
 from graphrag.light.graph_extractor import GraphExtractor
+from common import globals

 settings.init_settings()
@@ -63,7 +64,7 @@ async def main():
         chunks = [
             d["content_with_weight"]
-            for d in settings.retriever.chunk_list(
+            for d in globals.retriever.chunk_list(
                 args.doc_id,
                 args.tenant_id,
                 [kb_id],
View File

@@ -29,6 +29,7 @@ from rag.utils.doc_store_conn import OrderByExpr
 from rag.nlp.search import Dealer, index_name
 from common.float_utils import get_float
+from common import globals


 class KGSearch(Dealer):
@@ -334,6 +335,6 @@ if __name__ == "__main__":
     _, kb = KnowledgebaseService.get_by_id(kb_id)
     embed_bdl = LLMBundle(args.tenant_id, LLMType.EMBEDDING, kb.embd_id)

-    kg = KGSearch(settings.docStoreConn)
+    kg = KGSearch(globals.docStoreConn)
     print(kg.retrieval({"question": args.question, "kb_ids": [kb_id]},
                        search.index_name(kb.tenant_id), [kb_id], embed_bdl, llm_bdl))
View File

@@ -23,12 +23,12 @@ import trio
 import xxhash
 from networkx.readwrite import json_graph

-from api import settings
 from common.misc_utils import get_uuid
 from common.connection_utils import timeout
 from rag.nlp import rag_tokenizer, search
 from rag.utils.doc_store_conn import OrderByExpr
 from rag.utils.redis_conn import REDIS_CONN
+from common import globals

 GRAPH_FIELD_SEP = "<SEP>"
@@ -334,7 +334,7 @@ def get_relation(tenant_id, kb_id, from_ent_name, to_ent_name, size=1):
     ents = list(set(ents))
     conds = {"fields": ["content_with_weight"], "size": size, "from_entity_kwd": ents, "to_entity_kwd": ents, "knowledge_graph_kwd": ["relation"]}
     res = []
-    es_res = settings.retriever.search(conds, search.index_name(tenant_id), [kb_id] if isinstance(kb_id, str) else kb_id)
+    es_res = globals.retriever.search(conds, search.index_name(tenant_id), [kb_id] if isinstance(kb_id, str) else kb_id)
     for id in es_res.ids:
         try:
             if size == 1:
@@ -381,8 +381,8 @@ async def does_graph_contains(tenant_id, kb_id, doc_id):
         "knowledge_graph_kwd": ["graph"],
         "removed_kwd": "N",
     }
-    res = await trio.to_thread.run_sync(lambda: settings.docStoreConn.search(fields, [], condition, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [kb_id]))
-    fields2 = settings.docStoreConn.getFields(res, fields)
+    res = await trio.to_thread.run_sync(lambda: globals.docStoreConn.search(fields, [], condition, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [kb_id]))
+    fields2 = globals.docStoreConn.getFields(res, fields)
     graph_doc_ids = set()
     for chunk_id in fields2.keys():
         graph_doc_ids = set(fields2[chunk_id]["source_id"])
@@ -391,7 +391,7 @@ async def does_graph_contains(tenant_id, kb_id, doc_id):
 async def get_graph_doc_ids(tenant_id, kb_id) -> list[str]:
     conds = {"fields": ["source_id"], "removed_kwd": "N", "size": 1, "knowledge_graph_kwd": ["graph"]}
-    res = await trio.to_thread.run_sync(lambda: settings.retriever.search(conds, search.index_name(tenant_id), [kb_id]))
+    res = await trio.to_thread.run_sync(lambda: globals.retriever.search(conds, search.index_name(tenant_id), [kb_id]))
     doc_ids = []
     if res.total == 0:
         return doc_ids
@@ -402,7 +402,7 @@ async def get_graph_doc_ids(tenant_id, kb_id) -> list[str]:
 async def get_graph(tenant_id, kb_id, exclude_rebuild=None):
     conds = {"fields": ["content_with_weight", "removed_kwd", "source_id"], "size": 1, "knowledge_graph_kwd": ["graph"]}
-    res = await trio.to_thread.run_sync(settings.retriever.search, conds, search.index_name(tenant_id), [kb_id])
+    res = await trio.to_thread.run_sync(globals.retriever.search, conds, search.index_name(tenant_id), [kb_id])
     if not res.total == 0:
         for id in res.ids:
             try:
@@ -423,17 +423,17 @@ async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, change
     global chat_limiter
     start = trio.current_time()

-    await trio.to_thread.run_sync(settings.docStoreConn.delete, {"knowledge_graph_kwd": ["graph", "subgraph"]}, search.index_name(tenant_id), kb_id)
+    await trio.to_thread.run_sync(globals.docStoreConn.delete, {"knowledge_graph_kwd": ["graph", "subgraph"]}, search.index_name(tenant_id), kb_id)
     if change.removed_nodes:
-        await trio.to_thread.run_sync(settings.docStoreConn.delete, {"knowledge_graph_kwd": ["entity"], "entity_kwd": sorted(change.removed_nodes)}, search.index_name(tenant_id), kb_id)
+        await trio.to_thread.run_sync(globals.docStoreConn.delete, {"knowledge_graph_kwd": ["entity"], "entity_kwd": sorted(change.removed_nodes)}, search.index_name(tenant_id), kb_id)
     if change.removed_edges:
         async def del_edges(from_node, to_node):
             async with chat_limiter:
                 await trio.to_thread.run_sync(
-                    settings.docStoreConn.delete, {"knowledge_graph_kwd": ["relation"], "from_entity_kwd": from_node, "to_entity_kwd": to_node}, search.index_name(tenant_id), kb_id
+                    globals.docStoreConn.delete, {"knowledge_graph_kwd": ["relation"], "from_entity_kwd": from_node, "to_entity_kwd": to_node}, search.index_name(tenant_id), kb_id
                 )

         async with trio.open_nursery() as nursery:
@@ -501,7 +501,7 @@ async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, change
     es_bulk_size = 4
     for b in range(0, len(chunks), es_bulk_size):
         with trio.fail_after(3 if enable_timeout_assertion else 30000000):
-            doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b : b + es_bulk_size], search.index_name(tenant_id), kb_id))
+            doc_store_result = await trio.to_thread.run_sync(lambda: globals.docStoreConn.insert(chunks[b : b + es_bulk_size], search.index_name(tenant_id), kb_id))
         if b % 100 == es_bulk_size and callback:
             callback(msg=f"Insert chunks: {b}/{len(chunks)}")
         if doc_store_result:
@@ -555,7 +555,7 @@ def merge_tuples(list1, list2):

 async def get_entity_type2samples(idxnms, kb_ids: list):
-    es_res = await trio.to_thread.run_sync(lambda: settings.retriever.search({"knowledge_graph_kwd": "ty2ents", "kb_id": kb_ids, "size": 10000, "fields": ["content_with_weight"]}, idxnms, kb_ids))
+    es_res = await trio.to_thread.run_sync(lambda: globals.retriever.search({"knowledge_graph_kwd": "ty2ents", "kb_id": kb_ids, "size": 10000, "fields": ["content_with_weight"]}, idxnms, kb_ids))

     res = defaultdict(list)
     for id in es_res.ids:
@@ -589,10 +589,10 @@ async def rebuild_graph(tenant_id, kb_id, exclude_rebuild=None):
     bs = 256
     for i in range(0, 1024 * bs, bs):
         es_res = await trio.to_thread.run_sync(
-            lambda: settings.docStoreConn.search(flds, [], {"kb_id": kb_id, "knowledge_graph_kwd": ["subgraph"]}, [], OrderByExpr(), i, bs, search.index_name(tenant_id), [kb_id])
+            lambda: globals.docStoreConn.search(flds, [], {"kb_id": kb_id, "knowledge_graph_kwd": ["subgraph"]}, [], OrderByExpr(), i, bs, search.index_name(tenant_id), [kb_id])
         )
-        # tot = settings.docStoreConn.getTotal(es_res)
-        es_res = settings.docStoreConn.getFields(es_res, flds)
+        # tot = globals.docStoreConn.getTotal(es_res)
+        es_res = globals.docStoreConn.getFields(es_res, flds)
         if len(es_res) == 0:
             break
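A note on the recurring await trio.to_thread.run_sync(lambda: globals.docStoreConn...) shape in this file: run_sync forwards positional arguments only, so calls that need slicing or keywords get wrapped in a lambda, and capturing the loop variable is safe because each lambda is awaited before the loop advances. A self-contained sketch under those assumptions; blocking_insert stands in for the doc store's insert:

import trio

def blocking_insert(batch):
    # Stand-in for globals.docStoreConn.insert(...): it blocks on network I/O,
    # hence the worker thread; a falsy result means success, as in the code above.
    return None

async def insert_all(chunks, bulk_size=4):
    for b in range(0, len(chunks), bulk_size):
        # Awaiting immediately means the captured `b` cannot change underneath the lambda.
        result = await trio.to_thread.run_sync(lambda: blocking_insert(chunks[b:b + bulk_size]))
        if result:
            raise Exception(f"Insert chunk error: {result}")

trio.run(insert_all, list(range(10)))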
View File

@@ -21,6 +21,7 @@ from copy import deepcopy
 from deepdoc.parser.utils import get_text
 from rag.app.qa import Excel
 from rag.nlp import rag_tokenizer
+from common import globals


 def beAdoc(d, q, a, eng, row_num=-1):
@@ -124,7 +125,6 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
 def label_question(question, kbs):
     from api.db.services.knowledgebase_service import KnowledgebaseService
     from graphrag.utils import get_tags_from_cache, set_tags_to_cache
-    from api import settings

     tags = None
     tag_kb_ids = []
     for kb in kbs:
@@ -133,14 +133,14 @@ def label_question(question, kbs):
     if tag_kb_ids:
         all_tags = get_tags_from_cache(tag_kb_ids)
         if not all_tags:
-            all_tags = settings.retriever.all_tags_in_portion(kb.tenant_id, tag_kb_ids)
+            all_tags = globals.retriever.all_tags_in_portion(kb.tenant_id, tag_kb_ids)
             set_tags_to_cache(tags=all_tags, kb_ids=tag_kb_ids)
         else:
             all_tags = json.loads(all_tags)
         tag_kbs = KnowledgebaseService.get_by_ids(tag_kb_ids)
         if not tag_kbs:
            return tags
-        tags = settings.retriever.tag_query(question,
+        tags = globals.retriever.tag_query(question,
                                            list(set([kb.tenant_id for kb in tag_kbs])),
                                            tag_kb_ids,
                                            all_tags,
View File

@@ -20,10 +20,10 @@ import time
 import argparse
 from collections import defaultdict

+from common import globals
 from common.constants import LLMType
 from api.db.services.llm_service import LLMBundle
 from api.db.services.knowledgebase_service import KnowledgebaseService
-from api import settings
 from common.misc_utils import get_uuid
 from rag.nlp import tokenize, search
 from ranx import evaluate
@@ -52,7 +52,7 @@ class Benchmark:
         run = defaultdict(dict)
         query_list = list(qrels.keys())
         for query in query_list:
-            ranks = settings.retriever.retrieval(query, self.embd_mdl, self.tenant_id, [self.kb.id], 1, 30,
+            ranks = globals.retriever.retrieval(query, self.embd_mdl, self.tenant_id, [self.kb.id], 1, 30,
                                                 0.0, self.vector_similarity_weight)
             if len(ranks["chunks"]) == 0:
                 print(f"deleted query: {query}")
@@ -77,9 +77,9 @@ class Benchmark:
     def init_index(self, vector_size: int):
         if self.initialized_index:
             return
-        if settings.docStoreConn.indexExist(self.index_name, self.kb_id):
-            settings.docStoreConn.deleteIdx(self.index_name, self.kb_id)
-        settings.docStoreConn.createIdx(self.index_name, self.kb_id, vector_size)
+        if globals.docStoreConn.indexExist(self.index_name, self.kb_id):
+            globals.docStoreConn.deleteIdx(self.index_name, self.kb_id)
+        globals.docStoreConn.createIdx(self.index_name, self.kb_id, vector_size)
         self.initialized_index = True

     def ms_marco_index(self, file_path, index_name):
@@ -114,13 +114,13 @@ class Benchmark:
                 docs_count += len(docs)
                 docs, vector_size = self.embedding(docs)
                 self.init_index(vector_size)
-                settings.docStoreConn.insert(docs, self.index_name, self.kb_id)
+                globals.docStoreConn.insert(docs, self.index_name, self.kb_id)
                 docs = []
         if docs:
             docs, vector_size = self.embedding(docs)
             self.init_index(vector_size)
-            settings.docStoreConn.insert(docs, self.index_name, self.kb_id)
+            globals.docStoreConn.insert(docs, self.index_name, self.kb_id)
         return qrels, texts

     def trivia_qa_index(self, file_path, index_name):
@@ -155,12 +155,12 @@ class Benchmark:
                 docs_count += len(docs)
                 docs, vector_size = self.embedding(docs)
                 self.init_index(vector_size)
-                settings.docStoreConn.insert(docs,self.index_name)
+                globals.docStoreConn.insert(docs,self.index_name)
                 docs = []
         docs, vector_size = self.embedding(docs)
         self.init_index(vector_size)
-        settings.docStoreConn.insert(docs, self.index_name)
+        globals.docStoreConn.insert(docs, self.index_name)
         return qrels, texts

     def miracl_index(self, file_path, corpus_path, index_name):
@@ -210,12 +210,12 @@ class Benchmark:
                 docs_count += len(docs)
                 docs, vector_size = self.embedding(docs)
                 self.init_index(vector_size)
-                settings.docStoreConn.insert(docs, self.index_name)
+                globals.docStoreConn.insert(docs, self.index_name)
                 docs = []
         docs, vector_size = self.embedding(docs)
         self.init_index(vector_size)
-        settings.docStoreConn.insert(docs, self.index_name)
+        globals.docStoreConn.insert(docs, self.index_name)
         return qrels, texts

     def save_results(self, qrels, run, texts, dataset, file_path):
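The init_index pattern above (drop-and-recreate, guarded by a once-flag, so repeated benchmark runs never mix stale vectors) as an isolated sketch; store stands in for globals.docStoreConn, and the three method names match the calls above:

class IndexInitializer:
    # Illustrative extraction of the init_index pattern; not part of this commit.
    def __init__(self, store, index_name, kb_id):
        self.store, self.index_name, self.kb_id = store, index_name, kb_id
        self.initialized = False

    def ensure(self, vector_size: int):
        if self.initialized:
            return
        # Drop-and-recreate so each run starts from a clean index.
        if self.store.indexExist(self.index_name, self.kb_id):
            self.store.deleteIdx(self.index_name, self.kb_id)
        self.store.createIdx(self.index_name, self.kb_id, vector_size)
        self.initialized = True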
View File

@@ -21,7 +21,7 @@ from functools import partial
 import trio

 from common.misc_utils import get_uuid
-from common.base64_image import id2image, image2id
+from rag.utils.base64_image import id2image, image2id
 from deepdoc.parser.pdf_parser import RAGFlowPdfParser
 from rag.flow.base import ProcessBase, ProcessParamBase
 from rag.flow.hierarchical_merger.schema import HierarchicalMergerFromUpstream
View File

@@ -27,7 +27,7 @@ from api.db.services.file2document_service import File2DocumentService
 from api.db.services.file_service import FileService
 from api.db.services.llm_service import LLMBundle
 from common.misc_utils import get_uuid
-from common.base64_image import image2id
+from rag.utils.base64_image import image2id
 from deepdoc.parser import ExcelParser
 from deepdoc.parser.mineru_parser import MinerUParser
 from deepdoc.parser.pdf_parser import PlainParser, RAGFlowPdfParser, VisionParser
View File

@@ -18,7 +18,7 @@ from functools import partial
 import trio

 from common.misc_utils import get_uuid
-from common.base64_image import id2image, image2id
+from rag.utils.base64_image import id2image, image2id
 from deepdoc.parser.pdf_parser import RAGFlowPdfParser
 from rag.flow.base import ProcessBase, ProcessParamBase
 from rag.flow.splitter.schema import SplitterFromUpstream
View File

@@ -30,7 +30,6 @@ from zhipuai import ZhipuAI
 from common.log_utils import log_exception
 from common.token_utils import num_tokens_from_string, truncate
 from common import globals
-from api import settings

 import logging
@@ -74,9 +73,9 @@ class BuiltinEmbed(Base):
         embedding_cfg = globals.EMBEDDING_CFG
         if not BuiltinEmbed._model and "tei-" in os.getenv("COMPOSE_PROFILES", ""):
             with BuiltinEmbed._model_lock:
-                BuiltinEmbed._model_name = settings.EMBEDDING_MDL
-                BuiltinEmbed._max_tokens = BuiltinEmbed.MAX_TOKENS.get(settings.EMBEDDING_MDL, 500)
-                BuiltinEmbed._model = HuggingFaceEmbed(embedding_cfg["api_key"], settings.EMBEDDING_MDL, base_url=embedding_cfg["base_url"])
+                BuiltinEmbed._model_name = globals.EMBEDDING_MDL
+                BuiltinEmbed._max_tokens = BuiltinEmbed.MAX_TOKENS.get(globals.EMBEDDING_MDL, 500)
+                BuiltinEmbed._model = HuggingFaceEmbed(embedding_cfg["api_key"], globals.EMBEDDING_MDL, base_url=embedding_cfg["base_url"])
         self._model = BuiltinEmbed._model
         self._model_name = BuiltinEmbed._model_name
         self._max_tokens = BuiltinEmbed._max_tokens
View File

@@ -18,13 +18,14 @@ import logging
 from common.config_utils import get_base_config, decrypt_database_config
 from common.file_utils import get_project_base_directory
 from common.misc_utils import pip_install_torch
+from common import globals

 # Server
 RAG_CONF_PATH = os.path.join(get_project_base_directory(), "conf")

 # Get storage type and document engine from system environment variables
 STORAGE_IMPL_TYPE = os.getenv('STORAGE_IMPL', 'MINIO')
-DOC_ENGINE = os.getenv('DOC_ENGINE', 'elasticsearch')
+globals.DOC_ENGINE = os.getenv('DOC_ENGINE', 'elasticsearch')

 ES = {}
 INFINITY = {}
@@ -35,11 +36,11 @@ OSS = {}
 OS = {}

 # Initialize the selected configuration data based on environment variables to solve the problem of initialization errors due to lack of configuration
-if DOC_ENGINE == 'elasticsearch':
+if globals.DOC_ENGINE == 'elasticsearch':
     ES = get_base_config("es", {})
-elif DOC_ENGINE == 'opensearch':
+elif globals.DOC_ENGINE == 'opensearch':
     OS = get_base_config("os", {})
-elif DOC_ENGINE == 'infinity':
+elif globals.DOC_ENGINE == 'infinity':
     INFINITY = get_base_config("infinity", {"uri": "infinity:23817"})

 if STORAGE_IMPL_TYPE in ['AZURE_SPN', 'AZURE_SAS']:
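Since globals.DOC_ENGINE is now populated as an import-time side effect of rag.settings (and again by api.settings.init_settings()), readers see None until one of those has run. An illustrative check, not part of this commit:

from common import globals

print(globals.DOC_ENGINE)  # None: no initializer has run yet

import rag.settings  # noqa: F401 -- importing assigns globals.DOC_ENGINE

print(globals.DOC_ENGINE)  # e.g. "elasticsearch", taken from the DOC_ENGINE env var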
View File

@@ -27,7 +27,7 @@ from api.db.services.canvas_service import UserCanvasService
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
 from common.connection_utils import timeout
-from common.base64_image import image2id
+from rag.utils.base64_image import image2id
 from common.log_utils import init_root_logger
 from common.config_utils import show_configs
 from graphrag.general.index import run_graphrag_for_kb
@@ -68,6 +68,7 @@ from rag.utils.redis_conn import REDIS_CONN, RedisDistributedLock
 from rag.utils.storage_factory import STORAGE_IMPL
 from graphrag.utils import chat_limiter
 from common.signal_utils import start_tracemalloc_and_snapshot, stop_tracemalloc
+from common import globals

 BATCH_SIZE = 64
@@ -349,7 +350,7 @@ async def build_chunks(task, progress_callback):
             examples = []
             all_tags = get_tags_from_cache(kb_ids)
             if not all_tags:
-                all_tags = settings.retriever.all_tags_in_portion(tenant_id, kb_ids, S)
+                all_tags = globals.retriever.all_tags_in_portion(tenant_id, kb_ids, S)
                 set_tags_to_cache(kb_ids, all_tags)
             else:
                 all_tags = json.loads(all_tags)
@@ -362,7 +363,7 @@ async def build_chunks(task, progress_callback):
                 if task_canceled:
                     progress_callback(-1, msg="Task has been canceled.")
                     return
-                if settings.retriever.tag_content(tenant_id, kb_ids, d, all_tags, topn_tags=topn_tags, S=S) and len(d[TAG_FLD]) > 0:
+                if globals.retriever.tag_content(tenant_id, kb_ids, d, all_tags, topn_tags=topn_tags, S=S) and len(d[TAG_FLD]) > 0:
                     examples.append({"content": d["content_with_weight"], TAG_FLD: d[TAG_FLD]})
                 else:
                     docs_to_tag.append(d)
@@ -423,7 +424,7 @@ def build_TOC(task, docs, progress_callback):
 def init_kb(row, vector_size: int):
     idxnm = search.index_name(row["tenant_id"])
-    return settings.docStoreConn.createIdx(idxnm, row.get("kb_id", ""), vector_size)
+    return globals.docStoreConn.createIdx(idxnm, row.get("kb_id", ""), vector_size)


 async def embedding(docs, mdl, parser_config=None, callback=None):
@@ -647,7 +648,7 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_size
     chunks = []
     vctr_nm = "q_%d_vec"%vector_size
     for doc_id in doc_ids:
-        for d in settings.retriever.chunk_list(doc_id, row["tenant_id"], [str(row["kb_id"])],
+        for d in globals.retriever.chunk_list(doc_id, row["tenant_id"], [str(row["kb_id"])],
                                               fields=["content_with_weight", vctr_nm],
                                               sort_by_position=True):
            chunks.append((d["content_with_weight"], np.array(d[vctr_nm])))
@@ -698,7 +699,7 @@ async def delete_image(kb_id, chunk_id):
 async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_callback):
     for b in range(0, len(chunks), DOC_BULK_SIZE):
-        doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b:b + DOC_BULK_SIZE], search.index_name(task_tenant_id), task_dataset_id))
+        doc_store_result = await trio.to_thread.run_sync(lambda: globals.docStoreConn.insert(chunks[b:b + DOC_BULK_SIZE], search.index_name(task_tenant_id), task_dataset_id))
         task_canceled = has_canceled(task_id)
         if task_canceled:
             progress_callback(-1, msg="Task has been canceled.")
@@ -715,7 +716,7 @@ async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_callback):
                 TaskService.update_chunk_ids(task_id, chunk_ids_str)
             except DoesNotExist:
                 logging.warning(f"do_handle_task update_chunk_ids failed since task {task_id} is unknown.")
-                doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(task_tenant_id), task_dataset_id))
+                doc_store_result = await trio.to_thread.run_sync(lambda: globals.docStoreConn.delete({"id": chunk_ids}, search.index_name(task_tenant_id), task_dataset_id))
                 async with trio.open_nursery() as nursery:
                     for chunk_id in chunk_ids:
                         nursery.start_soon(delete_image, task_dataset_id, chunk_id)
@@ -751,7 +752,7 @@ async def do_handle_task(task):
     progress_callback = partial(set_progress, task_id, task_from_page, task_to_page)

     # FIXME: workaround, Infinity doesn't support table parsing method, this check is to notify user
-    lower_case_doc_engine = settings.DOC_ENGINE.lower()
+    lower_case_doc_engine = globals.DOC_ENGINE.lower()
     if lower_case_doc_engine == 'infinity' and task['parser_id'].lower() == 'table':
         error_message = "Table parsing method is not supported by Infinity, please use other parsing methods or use Elasticsearch as the document engine."
         progress_callback(-1, msg=error_message)