Mirror of https://github.com/infiniflow/ragflow.git (synced 2025-12-08 20:42:30 +08:00)
# Move some vars to globals (#11017)
### What problem does this PR solve?

As title.

### Type of change

- [x] Refactoring

---------

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
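The change is mechanical: module-level singletons that previously lived on `api.settings` (`retriever`, `docStoreConn`, and the default embedding settings `EMBEDDING_MDL` / `EMBEDDING_CFG`) now live in a shared `common.globals` module, and call sites import `from common import globals`. Below is a minimal sketch of what `common/globals.py` presumably holds, inferred only from the call sites in this diff; the actual module contents and its initialization code may differ, and the `init()` helper is an assumption.

```python
# Hypothetical sketch of common/globals.py, inferred from the call sites in this PR.
# retriever, docStoreConn, EMBEDDING_MDL and EMBEDDING_CFG appear in the diff;
# the init() helper and the None/"" defaults are assumptions, not the real code.

retriever = None      # full-text / vector retriever used by chat, agents and the SDK routes
docStoreConn = None   # document-store connection (Elasticsearch, Infinity or OpenSearch)
EMBEDDING_MDL = ""    # default embedding model id, formerly api.settings.EMBEDDING_MDL
EMBEDDING_CFG = ""    # default embedding model config, formerly api.settings.EMBEDDING_CFG


def init(retriever_impl, doc_store_conn, embedding_mdl="", embedding_cfg=""):
    """Assumed one-time initializer called during service startup."""
    global retriever, docStoreConn, EMBEDDING_MDL, EMBEDDING_CFG
    retriever = retriever_impl
    docStoreConn = doc_store_conn
    EMBEDDING_MDL = embedding_mdl
    EMBEDDING_CFG = embedding_cfg
```

Call sites then change one-for-one, e.g. `settings.retriever.retrieval(...)` becomes `globals.retriever.retrieval(...)` and `settings.docStoreConn.delete(...)` becomes `globals.docStoreConn.delete(...)`, while values that are not moved (`settings.CHAT_MDL`, `settings.ASR_MDL`, `settings.kg_retriever`, etc.) keep reading from `api.settings`, as the hunks below show.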
@@ -25,6 +25,7 @@ from api.db.services.dialog_service import meta_filter
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
from api import settings
+from common import globals
from common.connection_utils import timeout
from rag.app.tag import label_question
from rag.prompts.generator import cross_languages, kb_prompt, gen_meta_filter
@@ -170,7 +171,7 @@ class Retrieval(ToolBase, ABC):

if kbs:
query = re.sub(r"^user[::\s]*", "", query, flags=re.IGNORECASE)
-kbinfos = settings.retriever.retrieval(
+kbinfos = globals.retriever.retrieval(
query,
embd_mdl,
[kb.tenant_id for kb in kbs],
@@ -186,7 +187,7 @@ class Retrieval(ToolBase, ABC):
)
if self._param.toc_enhance:
chat_mdl = LLMBundle(self._canvas._tenant_id, LLMType.CHAT)
-cks = settings.retriever.retrieval_by_toc(query, kbinfos["chunks"], [kb.tenant_id for kb in kbs], chat_mdl, self._param.top_n)
+cks = globals.retriever.retrieval_by_toc(query, kbinfos["chunks"], [kb.tenant_id for kb in kbs], chat_mdl, self._param.top_n)
if cks:
kbinfos["chunks"] = cks
if self._param.use_kg:
@@ -32,7 +32,6 @@ from api.db.services.file_service import FileService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.task_service import queue_tasks, TaskService
from api.db.services.user_service import UserTenantService
-from api import settings
from common.misc_utils import get_uuid
from common.constants import RetCode, VALID_TASK_STATUS, LLMType, ParserType, FileSource
from api.utils.api_utils import server_error_response, get_data_error_result, get_json_result, validate_request, \
@@ -48,6 +47,7 @@ from api.db.services.canvas_service import UserCanvasService
from agent.canvas import Canvas
from functools import partial
from pathlib import Path
+from common import globals


@manager.route('/new_token', methods=['POST']) # noqa: F821
@@ -538,7 +538,7 @@ def list_chunks():
)
kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)

-res = settings.retriever.chunk_list(doc_id, tenant_id, kb_ids)
+res = globals.retriever.chunk_list(doc_id, tenant_id, kb_ids)
res = [
{
"content": res_item["content_with_weight"],
@@ -564,7 +564,7 @@ def get_chunk(chunk_id):
try:
tenant_id = objs[0].tenant_id
kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
-chunk = settings.docStoreConn.get(chunk_id, search.index_name(tenant_id), kb_ids)
+chunk = globals.docStoreConn.get(chunk_id, search.index_name(tenant_id), kb_ids)
if chunk is None:
return server_error_response(Exception("Chunk not found"))
k = []
@@ -886,7 +886,7 @@ def retrieval():
if req.get("keyword", False):
chat_mdl = LLMBundle(kbs[0].tenant_id, LLMType.CHAT)
question += keyword_extraction(chat_mdl, question)
-ranks = settings.retriever.retrieval(question, embd_mdl, kbs[0].tenant_id, kb_ids, page, size,
+ranks = globals.retriever.retrieval(question, embd_mdl, kbs[0].tenant_id, kb_ids, page, size,
similarity_threshold, vector_similarity_weight, top,
doc_ids, rerank_mdl=rerank_mdl, highlight= highlight,
rank_feature=label_question(question, kbs))
@@ -25,7 +25,6 @@ from flask import request, Response
from flask_login import login_required, current_user

from agent.component import LLM
-from api import settings
from api.db import CanvasCategory, FileType
from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService, API4ConversationService
from api.db.services.document_service import DocumentService
@@ -46,6 +45,7 @@ from api.utils.file_utils import filename_type, read_potential_broken_pdf
from rag.flow.pipeline import Pipeline
from rag.nlp import search
from rag.utils.redis_conn import REDIS_CONN
+from common import globals


@manager.route('/templates', methods=['GET']) # noqa: F821
@@ -192,8 +192,8 @@ def rerun():
if 0 < doc["progress"] < 1:
return get_data_error_result(message=f"`{doc['name']}` is processing...")

-if settings.docStoreConn.indexExist(search.index_name(current_user.id), doc["kb_id"]):
-settings.docStoreConn.delete({"doc_id": doc["id"]}, search.index_name(current_user.id), doc["kb_id"])
+if globals.docStoreConn.indexExist(search.index_name(current_user.id), doc["kb_id"]):
+globals.docStoreConn.delete({"doc_id": doc["id"]}, search.index_name(current_user.id), doc["kb_id"])
doc["progress_msg"] = ""
doc["chunk_num"] = 0
doc["token_num"] = 0
@@ -36,6 +36,7 @@ from rag.prompts.generator import gen_meta_filter, cross_languages, keyword_extr
from rag.settings import PAGERANK_FLD
from common.string_utils import remove_redundant_spaces
from common.constants import RetCode, LLMType, ParserType
+from common import globals


@manager.route('/list', methods=['POST']) # noqa: F821
@@ -60,7 +61,7 @@ def list_chunk():
}
if "available_int" in req:
query["available_int"] = int(req["available_int"])
-sres = settings.retriever.search(query, search.index_name(tenant_id), kb_ids, highlight=["content_ltks"])
+sres = globals.retriever.search(query, search.index_name(tenant_id), kb_ids, highlight=["content_ltks"])
res = {"total": sres.total, "chunks": [], "doc": doc.to_dict()}
for id in sres.ids:
d = {
@@ -98,7 +99,7 @@ def get():
return get_data_error_result(message="Tenant not found!")
for tenant in tenants:
kb_ids = KnowledgebaseService.get_kb_ids(tenant.tenant_id)
-chunk = settings.docStoreConn.get(chunk_id, search.index_name(tenant.tenant_id), kb_ids)
+chunk = globals.docStoreConn.get(chunk_id, search.index_name(tenant.tenant_id), kb_ids)
if chunk:
break
if chunk is None:
@@ -170,7 +171,7 @@ def set():
v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not d.get("question_kwd") else "\n".join(d["question_kwd"])])
v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
d["q_%d_vec" % len(v)] = v.tolist()
-settings.docStoreConn.update({"id": req["chunk_id"]}, d, search.index_name(tenant_id), doc.kb_id)
+globals.docStoreConn.update({"id": req["chunk_id"]}, d, search.index_name(tenant_id), doc.kb_id)
return get_json_result(data=True)
except Exception as e:
return server_error_response(e)
@@ -186,7 +187,7 @@ def switch():
if not e:
return get_data_error_result(message="Document not found!")
for cid in req["chunk_ids"]:
-if not settings.docStoreConn.update({"id": cid},
+if not globals.docStoreConn.update({"id": cid},
{"available_int": int(req["available_int"])},
search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
doc.kb_id):
@@ -206,7 +207,7 @@ def rm():
e, doc = DocumentService.get_by_id(req["doc_id"])
if not e:
return get_data_error_result(message="Document not found!")
-if not settings.docStoreConn.delete({"id": req["chunk_ids"]},
+if not globals.docStoreConn.delete({"id": req["chunk_ids"]},
search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
doc.kb_id):
return get_data_error_result(message="Chunk deleting failure")
@@ -270,7 +271,7 @@ def create():
v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not d["question_kwd"] else "\n".join(d["question_kwd"])])
v = 0.1 * v[0] + 0.9 * v[1]
d["q_%d_vec" % len(v)] = v.tolist()
-settings.docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
+globals.docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)

DocumentService.increment_chunk_num(
doc.id, doc.kb_id, c, 1, 0)
@@ -346,7 +347,7 @@ def retrieval_test():
question += keyword_extraction(chat_mdl, question)

labels = label_question(question, [kb])
-ranks = settings.retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids, page, size,
+ranks = globals.retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids, page, size,
float(req.get("similarity_threshold", 0.0)),
float(req.get("vector_similarity_weight", 0.3)),
top,
@@ -385,7 +386,7 @@ def knowledge_graph():
"doc_ids": [doc_id],
"knowledge_graph_kwd": ["graph", "mind_map"]
}
-sres = settings.retriever.search(req, search.index_name(tenant_id), kb_ids)
+sres = globals.retriever.search(req, search.index_name(tenant_id), kb_ids)
obj = {"graph": {}, "mind_map": {}}
for id in sres.ids[:2]:
ty = sres.field[id]["knowledge_graph_kwd"]
@@ -23,7 +23,6 @@ import flask
from flask import request
from flask_login import current_user, login_required

-from api import settings
from api.common.check_team_permission import check_kb_team_permission
from api.constants import FILE_NAME_LEN_LIMIT, IMG_BASE64_PREFIX
from api.db import VALID_FILE_TYPES, FileType
@@ -49,6 +48,7 @@ from api.utils.web_utils import CONTENT_TYPE_MAP, html2pdf, is_valid_url
from deepdoc.parser.html_parser import RAGFlowHtmlParser
from rag.nlp import search, rag_tokenizer
from rag.utils.storage_factory import STORAGE_IMPL
+from common import globals


@manager.route("/upload", methods=["POST"]) # noqa: F821
@@ -367,7 +367,7 @@ def change_status():
continue

status_int = int(status)
-if not settings.docStoreConn.update({"doc_id": doc_id}, {"available_int": status_int}, search.index_name(kb.tenant_id), doc.kb_id):
+if not globals.docStoreConn.update({"doc_id": doc_id}, {"available_int": status_int}, search.index_name(kb.tenant_id), doc.kb_id):
result[doc_id] = {"error": "Database error (docStore update)!"}
result[doc_id] = {"status": status}
except Exception as e:
@@ -432,8 +432,8 @@ def run():
DocumentService.update_by_id(id, info)
if req.get("delete", False):
TaskService.filter_delete([Task.doc_id == id])
-if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
-settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id)
+if globals.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
+globals.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id)

if str(req["run"]) == TaskStatus.RUNNING.value:
doc = doc.to_dict()
@@ -479,8 +479,8 @@ def rename():
"title_tks": title_tks,
"title_sm_tks": rag_tokenizer.fine_grained_tokenize(title_tks),
}
-if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
-settings.docStoreConn.update(
+if globals.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
+globals.docStoreConn.update(
{"doc_id": req["doc_id"]},
es_body,
search.index_name(tenant_id),
@@ -541,8 +541,8 @@ def change_parser():
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
if not tenant_id:
return get_data_error_result(message="Tenant not found!")
-if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
-settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
+if globals.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
+globals.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)

try:
if "pipeline_id" in req and req["pipeline_id"] != "":
@@ -35,7 +35,6 @@ from api.db import VALID_FILE_TYPES
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.db_models import File
from api.utils.api_utils import get_json_result
-from api import settings
from rag.nlp import search
from api.constants import DATASET_NAME_LIMIT
from rag.settings import PAGERANK_FLD
@@ -43,7 +42,7 @@ from rag.utils.redis_conn import REDIS_CONN
from rag.utils.storage_factory import STORAGE_IMPL
from rag.utils.doc_store_conn import OrderByExpr
from common.constants import RetCode, PipelineTaskType, StatusEnum, VALID_TASK_STATUS, FileSource, LLMType
+from common import globals

@manager.route('/create', methods=['post']) # noqa: F821
@login_required
@@ -110,11 +109,11 @@ def update():

if kb.pagerank != req.get("pagerank", 0):
if req.get("pagerank", 0) > 0:
-settings.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]},
+globals.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]},
search.index_name(kb.tenant_id), kb.id)
else:
# Elasticsearch requires PAGERANK_FLD be non-zero!
-settings.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD},
+globals.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD},
search.index_name(kb.tenant_id), kb.id)

e, kb = KnowledgebaseService.get_by_id(kb.id)
@@ -226,8 +225,8 @@ def rm():
return get_data_error_result(
message="Database error (Knowledgebase removal)!")
for kb in kbs:
-settings.docStoreConn.delete({"kb_id": kb.id}, search.index_name(kb.tenant_id), kb.id)
-settings.docStoreConn.deleteIdx(search.index_name(kb.tenant_id), kb.id)
+globals.docStoreConn.delete({"kb_id": kb.id}, search.index_name(kb.tenant_id), kb.id)
+globals.docStoreConn.deleteIdx(search.index_name(kb.tenant_id), kb.id)
if hasattr(STORAGE_IMPL, 'remove_bucket'):
STORAGE_IMPL.remove_bucket(kb.id)
return get_json_result(data=True)
@@ -248,7 +247,7 @@ def list_tags(kb_id):
tenants = UserTenantService.get_tenants_by_user_id(current_user.id)
tags = []
for tenant in tenants:
-tags += settings.retriever.all_tags(tenant["tenant_id"], [kb_id])
+tags += globals.retriever.all_tags(tenant["tenant_id"], [kb_id])
return get_json_result(data=tags)


@@ -267,7 +266,7 @@ def list_tags_from_kbs():
tenants = UserTenantService.get_tenants_by_user_id(current_user.id)
tags = []
for tenant in tenants:
-tags += settings.retriever.all_tags(tenant["tenant_id"], kb_ids)
+tags += globals.retriever.all_tags(tenant["tenant_id"], kb_ids)
return get_json_result(data=tags)


@@ -284,7 +283,7 @@ def rm_tags(kb_id):
e, kb = KnowledgebaseService.get_by_id(kb_id)

for t in req["tags"]:
-settings.docStoreConn.update({"tag_kwd": t, "kb_id": [kb_id]},
+globals.docStoreConn.update({"tag_kwd": t, "kb_id": [kb_id]},
{"remove": {"tag_kwd": t}},
search.index_name(kb.tenant_id),
kb_id)
@@ -303,7 +302,7 @@ def rename_tags(kb_id):
)
e, kb = KnowledgebaseService.get_by_id(kb_id)

-settings.docStoreConn.update({"tag_kwd": req["from_tag"], "kb_id": [kb_id]},
+globals.docStoreConn.update({"tag_kwd": req["from_tag"], "kb_id": [kb_id]},
{"remove": {"tag_kwd": req["from_tag"].strip()}, "add": {"tag_kwd": req["to_tag"]}},
search.index_name(kb.tenant_id),
kb_id)
@@ -326,9 +325,9 @@ def knowledge_graph(kb_id):
}

obj = {"graph": {}, "mind_map": {}}
-if not settings.docStoreConn.indexExist(search.index_name(kb.tenant_id), kb_id):
+if not globals.docStoreConn.indexExist(search.index_name(kb.tenant_id), kb_id):
return get_json_result(data=obj)
-sres = settings.retriever.search(req, search.index_name(kb.tenant_id), [kb_id])
+sres = globals.retriever.search(req, search.index_name(kb.tenant_id), [kb_id])
if not len(sres.ids):
return get_json_result(data=obj)

@@ -360,7 +359,7 @@ def delete_knowledge_graph(kb_id):
code=RetCode.AUTHENTICATION_ERROR
)
_, kb = KnowledgebaseService.get_by_id(kb_id)
-settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]}, search.index_name(kb.tenant_id), kb_id)
+globals.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]}, search.index_name(kb.tenant_id), kb_id)

return get_json_result(data=True)

@@ -732,13 +731,13 @@ def delete_kb_task():
task_id = kb.graphrag_task_id
kb_task_finish_at = "graphrag_task_finish_at"
cancel_task(task_id)
-settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]}, search.index_name(kb.tenant_id), kb_id)
+globals.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]}, search.index_name(kb.tenant_id), kb_id)
case PipelineTaskType.RAPTOR:
kb_task_id_field = "raptor_task_id"
task_id = kb.raptor_task_id
kb_task_finish_at = "raptor_task_finish_at"
cancel_task(task_id)
-settings.docStoreConn.delete({"raptor_kwd": ["raptor"]}, search.index_name(kb.tenant_id), kb_id)
+globals.docStoreConn.delete({"raptor_kwd": ["raptor"]}, search.index_name(kb.tenant_id), kb_id)
case PipelineTaskType.MINDMAP:
kb_task_id_field = "mindmap_task_id"
task_id = kb.mindmap_task_id
@@ -850,7 +849,7 @@ def check_embedding():
tenant_id = kb.tenant_id

emb_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_id)
-samples = sample_random_chunks_with_vectors(settings.docStoreConn, tenant_id=tenant_id, kb_id=kb_id, n=n)
+samples = sample_random_chunks_with_vectors(globals.docStoreConn, tenant_id=tenant_id, kb_id=kb_id, n=n)

results, eff_sims = [], []
for ck in samples:
@@ -24,7 +24,7 @@ from api.utils.api_utils import server_error_response, get_data_error_result, va
from common.constants import StatusEnum, LLMType
from api.db.db_models import TenantLLM
from api.utils.api_utils import get_json_result, get_allowed_llm_factories
-from common.base64_image import test_image
+from rag.utils.base64_image import test_image
from rag.llm import EmbeddingModel, ChatModel, RerankModel, CvModel, TTSModel


@@ -20,7 +20,6 @@ import os
import json
from flask import request
from peewee import OperationalError
-from api import settings
from api.db.db_models import File
from api.db.services.document_service import DocumentService
from api.db.services.file2document_service import File2DocumentService
@@ -49,6 +48,7 @@ from api.utils.validation_utils import (
)
from rag.nlp import search
from rag.settings import PAGERANK_FLD
+from common import globals


@manager.route("/datasets", methods=["POST"]) # noqa: F821
@@ -360,11 +360,11 @@ def update(tenant_id, dataset_id):
return get_error_argument_result(message="'pagerank' can only be set when doc_engine is elasticsearch")

if req["pagerank"] > 0:
-settings.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]},
+globals.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]},
search.index_name(kb.tenant_id), kb.id)
else:
# Elasticsearch requires PAGERANK_FLD be non-zero!
-settings.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD},
+globals.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD},
search.index_name(kb.tenant_id), kb.id)

if not KnowledgebaseService.update_by_id(kb.id, req):
@@ -493,9 +493,9 @@ def knowledge_graph(tenant_id, dataset_id):
}

obj = {"graph": {}, "mind_map": {}}
-if not settings.docStoreConn.indexExist(search.index_name(kb.tenant_id), dataset_id):
+if not globals.docStoreConn.indexExist(search.index_name(kb.tenant_id), dataset_id):
return get_result(data=obj)
-sres = settings.retriever.search(req, search.index_name(kb.tenant_id), [dataset_id])
+sres = globals.retriever.search(req, search.index_name(kb.tenant_id), [dataset_id])
if not len(sres.ids):
return get_result(data=obj)

@@ -528,7 +528,7 @@ def delete_knowledge_graph(tenant_id, dataset_id):
code=RetCode.AUTHENTICATION_ERROR
)
_, kb = KnowledgebaseService.get_by_id(dataset_id)
-settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]},
+globals.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]},
search.index_name(kb.tenant_id), dataset_id)

return get_result(data=True)
@@ -25,6 +25,7 @@ from api.utils.api_utils import validate_request, build_error_result, apikey_req
from rag.app.tag import label_question
from api.db.services.dialog_service import meta_filter, convert_conditions
from common.constants import RetCode, LLMType
+from common import globals

@manager.route('/dify/retrieval', methods=['POST']) # noqa: F821
@apikey_required
@@ -137,7 +138,7 @@ def retrieval(tenant_id):
# print("doc_ids", doc_ids)
if not doc_ids and metadata_condition is not None:
doc_ids = ['-999']
-ranks = settings.retriever.retrieval(
+ranks = globals.retriever.retrieval(
question,
embd_mdl,
kb.tenant_id,
@@ -44,6 +44,7 @@ from rag.prompts.generator import cross_languages, keyword_extraction
from rag.utils.storage_factory import STORAGE_IMPL
from common.string_utils import remove_redundant_spaces
from common.constants import RetCode, LLMType, ParserType, TaskStatus, FileSource
+from common import globals

MAXIMUM_OF_UPLOADING_FILES = 256

@@ -307,7 +308,7 @@ def update_doc(tenant_id, dataset_id, document_id):
)
if not e:
return get_error_data_result(message="Document not found!")
-settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), dataset_id)
+globals.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), dataset_id)

if "enabled" in req:
status = int(req["enabled"])
@@ -316,7 +317,7 @@ def update_doc(tenant_id, dataset_id, document_id):
if not DocumentService.update_by_id(doc.id, {"status": str(status)}):
return get_error_data_result(message="Database error (Document update)!")

-settings.docStoreConn.update({"doc_id": doc.id}, {"available_int": status}, search.index_name(kb.tenant_id), doc.kb_id)
+globals.docStoreConn.update({"doc_id": doc.id}, {"available_int": status}, search.index_name(kb.tenant_id), doc.kb_id)
return get_result(data=True)
except Exception as e:
return server_error_response(e)
@@ -755,7 +756,7 @@ def parse(tenant_id, dataset_id):
return get_error_data_result("Can't parse document that is currently being processed")
info = {"run": "1", "progress": 0, "progress_msg": "", "chunk_num": 0, "token_num": 0}
DocumentService.update_by_id(id, info)
-settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), dataset_id)
+globals.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), dataset_id)
TaskService.filter_delete([Task.doc_id == id])
e, doc = DocumentService.get_by_id(id)
doc = doc.to_dict()
@@ -835,7 +836,7 @@ def stop_parsing(tenant_id, dataset_id):
return get_error_data_result("Can't stop parsing document with progress at 0 or 1")
info = {"run": "2", "progress": 0, "chunk_num": 0}
DocumentService.update_by_id(id, info)
-settings.docStoreConn.delete({"doc_id": doc[0].id}, search.index_name(tenant_id), dataset_id)
+globals.docStoreConn.delete({"doc_id": doc[0].id}, search.index_name(tenant_id), dataset_id)
success_count += 1
if duplicate_messages:
if success_count > 0:
@@ -968,7 +969,7 @@ def list_chunks(tenant_id, dataset_id, document_id):

res = {"total": 0, "chunks": [], "doc": renamed_doc}
if req.get("id"):
-chunk = settings.docStoreConn.get(req.get("id"), search.index_name(tenant_id), [dataset_id])
+chunk = globals.docStoreConn.get(req.get("id"), search.index_name(tenant_id), [dataset_id])
if not chunk:
return get_result(message=f"Chunk not found: {dataset_id}/{req.get('id')}", code=RetCode.NOT_FOUND)
k = []
@@ -995,8 +996,8 @@ def list_chunks(tenant_id, dataset_id, document_id):
res["chunks"].append(final_chunk)
_ = Chunk(**final_chunk)

-elif settings.docStoreConn.indexExist(search.index_name(tenant_id), dataset_id):
-sres = settings.retriever.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None, highlight=True)
+elif globals.docStoreConn.indexExist(search.index_name(tenant_id), dataset_id):
+sres = globals.retriever.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None, highlight=True)
res["total"] = sres.total
for id in sres.ids:
d = {
@@ -1120,7 +1121,7 @@ def add_chunk(tenant_id, dataset_id, document_id):
v, c = embd_mdl.encode([doc.name, req["content"] if not d["question_kwd"] else "\n".join(d["question_kwd"])])
v = 0.1 * v[0] + 0.9 * v[1]
d["q_%d_vec" % len(v)] = v.tolist()
-settings.docStoreConn.insert([d], search.index_name(tenant_id), dataset_id)
+globals.docStoreConn.insert([d], search.index_name(tenant_id), dataset_id)

DocumentService.increment_chunk_num(doc.id, doc.kb_id, c, 1, 0)
# rename keys
@@ -1201,7 +1202,7 @@ def rm_chunk(tenant_id, dataset_id, document_id):
if "chunk_ids" in req:
unique_chunk_ids, duplicate_messages = check_duplicate_ids(req["chunk_ids"], "chunk")
condition["id"] = unique_chunk_ids
-chunk_number = settings.docStoreConn.delete(condition, search.index_name(tenant_id), dataset_id)
+chunk_number = globals.docStoreConn.delete(condition, search.index_name(tenant_id), dataset_id)
if chunk_number != 0:
DocumentService.decrement_chunk_num(document_id, dataset_id, 1, chunk_number, 0)
if "chunk_ids" in req and chunk_number != len(unique_chunk_ids):
@@ -1273,7 +1274,7 @@ def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
schema:
type: object
"""
-chunk = settings.docStoreConn.get(chunk_id, search.index_name(tenant_id), [dataset_id])
+chunk = globals.docStoreConn.get(chunk_id, search.index_name(tenant_id), [dataset_id])
if chunk is None:
return get_error_data_result(f"Can't find this chunk {chunk_id}")
if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
@@ -1318,7 +1319,7 @@ def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
v, c = embd_mdl.encode([doc.name, d["content_with_weight"] if not d.get("question_kwd") else "\n".join(d["question_kwd"])])
v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
d["q_%d_vec" % len(v)] = v.tolist()
-settings.docStoreConn.update({"id": chunk_id}, d, search.index_name(tenant_id), dataset_id)
+globals.docStoreConn.update({"id": chunk_id}, d, search.index_name(tenant_id), dataset_id)
return get_result()


@@ -1464,7 +1465,7 @@ def retrieval_test(tenant_id):
chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
question += keyword_extraction(chat_mdl, question)

-ranks = settings.retriever.retrieval(
+ranks = globals.retriever.retrieval(
question,
embd_mdl,
tenant_ids,
@@ -41,6 +41,7 @@ from rag.app.tag import label_question
from rag.prompts.template import load_prompt
from rag.prompts.generator import cross_languages, gen_meta_filter, keyword_extraction, chunks_format
from common.constants import RetCode, LLMType, StatusEnum
+from common import globals

@manager.route("/chats/<chat_id>/sessions", methods=["POST"]) # noqa: F821
@token_required
@@ -1015,7 +1016,7 @@ def retrieval_test_embedded():
question += keyword_extraction(chat_mdl, question)

labels = label_question(question, [kb])
-ranks = settings.retriever.retrieval(
+ranks = globals.retriever.retrieval(
question, embd_mdl, tenant_ids, kb_ids, page, size, similarity_threshold, vector_similarity_weight, top,
doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"), rank_feature=labels
)
@@ -38,6 +38,7 @@ from timeit import default_timer as timer
from rag.utils.redis_conn import REDIS_CONN
from flask import jsonify
from api.utils.health_utils import run_health_checks
+from common import globals


@manager.route("/version", methods=["GET"]) # noqa: F821
@@ -100,7 +101,7 @@ def status():
res = {}
st = timer()
try:
-res["doc_engine"] = settings.docStoreConn.health()
+res["doc_engine"] = globals.docStoreConn.health()
res["doc_engine"]["elapsed"] = "{:.1f}".format((timer() - st) * 1000.0)
except Exception as e:
res["doc_engine"] = {
@@ -58,6 +58,7 @@ from api.utils.web_utils import (
hash_code,
captcha_key,
)
+from common import globals


@manager.route("/login", methods=["POST", "GET"]) # noqa: F821
@@ -623,7 +624,7 @@ def user_register(user_id, user):
"id": user_id,
"name": user["nickname"] + "‘s Kingdom",
"llm_id": settings.CHAT_MDL,
-"embd_id": settings.EMBEDDING_MDL,
+"embd_id": globals.EMBEDDING_MDL,
"asr_id": settings.ASR_MDL,
"parser_ids": settings.PARSERS,
"img2txt_id": settings.IMAGE2TEXT_MDL,
@@ -32,6 +32,7 @@ from api.db.services.user_service import TenantService, UserTenantService
from api import settings
from common.constants import LLMType
from common.file_utils import get_project_base_directory
+from common import globals
from api.common.base64 import encode_to_base64


@@ -49,7 +50,7 @@ def init_superuser():
"id": user_info["id"],
"name": user_info["nickname"] + "‘s Kingdom",
"llm_id": settings.CHAT_MDL,
-"embd_id": settings.EMBEDDING_MDL,
+"embd_id": globals.EMBEDDING_MDL,
"asr_id": settings.ASR_MDL,
"parser_ids": settings.PARSERS,
"img2txt_id": settings.IMAGE2TEXT_MDL
@@ -38,6 +38,7 @@ from api.db.services.user_service import TenantService, UserService, UserTenantS
from rag.utils.storage_factory import STORAGE_IMPL
from rag.nlp import search
from common.constants import ActiveEnum
+from common import globals

def create_new_user(user_info: dict) -> dict:
"""
@@ -63,7 +64,7 @@ def create_new_user(user_info: dict) -> dict:
"id": user_id,
"name": user_info["nickname"] + "‘s Kingdom",
"llm_id": settings.CHAT_MDL,
-"embd_id": settings.EMBEDDING_MDL,
+"embd_id": globals.EMBEDDING_MDL,
"asr_id": settings.ASR_MDL,
"parser_ids": settings.PARSERS,
"img2txt_id": settings.IMAGE2TEXT_MDL,
@@ -179,7 +180,7 @@ def delete_user_data(user_id: str) -> dict:
)
done_msg += f"- Deleted {file2doc_delete_res} document-file relation records.\n"
# step1.1.3 delete chunk in es
-r = settings.docStoreConn.delete({"kb_id": kb_ids},
+r = globals.docStoreConn.delete({"kb_id": kb_ids},
search.index_name(tenant_id), kb_ids)
done_msg += f"- Deleted {r} chunk records.\n"
kb_delete_res = KnowledgebaseService.delete_by_ids(kb_ids)
@@ -237,7 +238,7 @@ def delete_user_data(user_id: str) -> dict:
kb_doc_info = {}
for _tenant_id, kb_doc in kb_grouped_doc.items():
for _kb_id, docs in kb_doc.items():
-chunk_delete_res += settings.docStoreConn.delete(
+chunk_delete_res += globals.docStoreConn.delete(
{"doc_id": [d["id"] for d in docs]},
search.index_name(_tenant_id), _kb_id
)
@@ -44,6 +44,7 @@ from rag.prompts.generator import chunks_format, citation_prompt, cross_language
from common.token_utils import num_tokens_from_string
from rag.utils.tavily_conn import Tavily
from common.string_utils import remove_redundant_spaces
+from common import globals


class DialogService(CommonService):
@@ -371,7 +372,7 @@ def chat(dialog, messages, stream=True, **kwargs):
chat_mdl.bind_tools(toolcall_session, tools)
bind_models_ts = timer()

-retriever = settings.retriever
+retriever = globals.retriever
questions = [m["content"] for m in messages if m["role"] == "user"][-3:]
attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else []
if "doc_ids" in messages[-1]:
@@ -663,7 +664,7 @@ Please write the SQL, only SQL, without any other explanations or text.

logging.debug(f"{question} get SQL(refined): {sql}")
tried_times += 1
-return settings.retriever.sql_retrieval(sql, format="json"), sql
+return globals.retriever.sql_retrieval(sql, format="json"), sql

tbl, sql = get_table()
if tbl is None:
@@ -757,7 +758,7 @@ def ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
embedding_list = list(set([kb.embd_id for kb in kbs]))

is_knowledge_graph = all([kb.parser_id == ParserType.KG for kb in kbs])
-retriever = settings.retriever if not is_knowledge_graph else settings.kg_retriever
+retriever = globals.retriever if not is_knowledge_graph else settings.kg_retriever

embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embedding_list[0])
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, chat_llm_name)
@@ -853,7 +854,7 @@ def gen_mindmap(question, kb_ids, tenant_id, search_config={}):
if not doc_ids:
doc_ids = None

-ranks = settings.retriever.retrieval(
+ranks = globals.retriever.retrieval(
question=question,
embd_mdl=embd_mdl,
tenant_ids=tenant_ids,
@@ -26,7 +26,6 @@ import trio
import xxhash
from peewee import fn, Case, JOIN

-from api import settings
from api.constants import IMG_BASE64_PREFIX, FILE_NAME_LEN_LIMIT
from api.db import FileType, UserTenantRole, CanvasCategory
from api.db.db_models import DB, Document, Knowledgebase, Task, Tenant, UserTenant, File2Document, File, UserCanvas, \
@@ -42,7 +41,7 @@ from rag.settings import get_svr_queue_name, SVR_CONSUMER_GROUP_NAME
from rag.utils.redis_conn import REDIS_CONN
from rag.utils.storage_factory import STORAGE_IMPL
from rag.utils.doc_store_conn import OrderByExpr
+from common import globals

class DocumentService(CommonService):
model = Document
@@ -309,10 +308,10 @@ class DocumentService(CommonService):
page_size = 1000
all_chunk_ids = []
while True:
-chunks = settings.docStoreConn.search(["img_id"], [], {"doc_id": doc.id}, [], OrderByExpr(),
+chunks = globals.docStoreConn.search(["img_id"], [], {"doc_id": doc.id}, [], OrderByExpr(),
page * page_size, page_size, search.index_name(tenant_id),
[doc.kb_id])
-chunk_ids = settings.docStoreConn.getChunkIds(chunks)
+chunk_ids = globals.docStoreConn.getChunkIds(chunks)
if not chunk_ids:
break
all_chunk_ids.extend(chunk_ids)
@@ -323,19 +322,19 @@ class DocumentService(CommonService):
if doc.thumbnail and not doc.thumbnail.startswith(IMG_BASE64_PREFIX):
if STORAGE_IMPL.obj_exist(doc.kb_id, doc.thumbnail):
STORAGE_IMPL.rm(doc.kb_id, doc.thumbnail)
-settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
+globals.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)

-graph_source = settings.docStoreConn.getFields(
-settings.docStoreConn.search(["source_id"], [], {"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]}, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [doc.kb_id]), ["source_id"]
+graph_source = globals.docStoreConn.getFields(
+globals.docStoreConn.search(["source_id"], [], {"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]}, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [doc.kb_id]), ["source_id"]
)
if len(graph_source) > 0 and doc.id in list(graph_source.values())[0]["source_id"]:
-settings.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "source_id": doc.id},
+globals.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "source_id": doc.id},
{"remove": {"source_id": doc.id}},
search.index_name(tenant_id), doc.kb_id)
-settings.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]},
+globals.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]},
{"removed_kwd": "Y"},
search.index_name(tenant_id), doc.kb_id)
-settings.docStoreConn.delete({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "must_not": {"exists": "source_id"}},
+globals.docStoreConn.delete({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "must_not": {"exists": "source_id"}},
search.index_name(tenant_id), doc.kb_id)
except Exception:
pass
@@ -996,10 +995,10 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
d["q_%d_vec" % len(v)] = v
for b in range(0, len(cks), es_bulk_size):
if try_create_idx:
-if not settings.docStoreConn.indexExist(idxnm, kb_id):
-settings.docStoreConn.createIdx(idxnm, kb_id, len(vects[0]))
+if not globals.docStoreConn.indexExist(idxnm, kb_id):
+globals.docStoreConn.createIdx(idxnm, kb_id, len(vects[0]))
try_create_idx = False
-settings.docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id)
+globals.docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id)

DocumentService.increment_chunk_num(
doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0)
@@ -34,7 +34,7 @@ from deepdoc.parser.excel_parser import RAGFlowExcelParser
 from rag.settings import get_svr_queue_name
 from rag.utils.storage_factory import STORAGE_IMPL
 from rag.utils.redis_conn import REDIS_CONN
-from api import settings
+from common import globals
 from rag.nlp import search

 CANVAS_DEBUG_DOC_ID = "dataflow_x"
@@ -418,7 +418,7 @@ def queue_tasks(doc: dict, bucket: str, name: str, priority: int):
 if pre_task["chunk_ids"]:
 pre_chunk_ids.extend(pre_task["chunk_ids"].split())
 if pre_chunk_ids:
-settings.docStoreConn.delete({"id": pre_chunk_ids}, search.index_name(chunking_config["tenant_id"]),
+globals.docStoreConn.delete({"id": pre_chunk_ids}, search.index_name(chunking_config["tenant_id"]),
 chunking_config["kb_id"])
 DocumentService.update_by_id(doc["id"], {"chunk_num": ck_num})

@@ -32,12 +32,12 @@ LLM = None
 LLM_FACTORY = None
 LLM_BASE_URL = None
 CHAT_MDL = ""
-EMBEDDING_MDL = ""
+# EMBEDDING_MDL = "" has been moved to common/globals.py
 RERANK_MDL = ""
 ASR_MDL = ""
 IMAGE2TEXT_MDL = ""
 CHAT_CFG = ""
-EMBEDDING_CFG = ""
+# EMBEDDING_CFG = "" has been moved to common/globals.py
 RERANK_CFG = ""
 ASR_CFG = ""
 IMAGE2TEXT_CFG = ""
@@ -61,10 +61,10 @@ HTTP_APP_KEY = None
 GITHUB_OAUTH = None
 FEISHU_OAUTH = None
 OAUTH_CONFIG = None
-DOC_ENGINE = None
+# DOC_ENGINE = None has been moved to common/globals.py
-docStoreConn = None
+# docStoreConn = None has been moved to common/globals.py

-retriever = None
+#retriever = None has been moved to common/globals.py
 kg_retriever = None

 # user registration switch
@@ -125,7 +125,7 @@ def init_settings():
 except Exception:
 FACTORY_LLM_INFOS = []

-global CHAT_MDL, EMBEDDING_MDL, RERANK_MDL, ASR_MDL, IMAGE2TEXT_MDL
+global CHAT_MDL, RERANK_MDL, ASR_MDL, IMAGE2TEXT_MDL
 global CHAT_CFG, RERANK_CFG, ASR_CFG, IMAGE2TEXT_CFG

 global API_KEY, PARSERS, HOST_IP, HOST_PORT, SECRET_KEY
@@ -135,7 +135,7 @@ def init_settings():
 )

 chat_entry = _parse_model_entry(LLM_DEFAULT_MODELS.get("chat_model", CHAT_MDL))
-embedding_entry = _parse_model_entry(LLM_DEFAULT_MODELS.get("embedding_model", EMBEDDING_MDL))
+embedding_entry = _parse_model_entry(LLM_DEFAULT_MODELS.get("embedding_model", globals.EMBEDDING_MDL))
 rerank_entry = _parse_model_entry(LLM_DEFAULT_MODELS.get("rerank_model", RERANK_MDL))
 asr_entry = _parse_model_entry(LLM_DEFAULT_MODELS.get("asr_model", ASR_MDL))
 image2text_entry = _parse_model_entry(LLM_DEFAULT_MODELS.get("image2text_model", IMAGE2TEXT_MDL))
@@ -147,7 +147,7 @@ def init_settings():
 IMAGE2TEXT_CFG = _resolve_per_model_config(image2text_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL)

 CHAT_MDL = CHAT_CFG.get("model", "") or ""
-EMBEDDING_MDL = os.getenv("TEI_MODEL", "BAAI/bge-small-en-v1.5") if "tei-" in os.getenv("COMPOSE_PROFILES", "") else ""
+globals.EMBEDDING_MDL = os.getenv("TEI_MODEL", "BAAI/bge-small-en-v1.5") if "tei-" in os.getenv("COMPOSE_PROFILES", "") else ""
 RERANK_MDL = RERANK_CFG.get("model", "") or ""
 ASR_MDL = ASR_CFG.get("model", "") or ""
 IMAGE2TEXT_MDL = IMAGE2TEXT_CFG.get("model", "") or ""
@@ -169,23 +169,23 @@ def init_settings():

 OAUTH_CONFIG = get_base_config("oauth", {})

-global DOC_ENGINE, docStoreConn, retriever, kg_retriever
+global kg_retriever
-DOC_ENGINE = os.environ.get("DOC_ENGINE", "elasticsearch")
+globals.DOC_ENGINE = os.environ.get("DOC_ENGINE", "elasticsearch")
-# DOC_ENGINE = os.environ.get('DOC_ENGINE', "opensearch")
+# globals.DOC_ENGINE = os.environ.get('DOC_ENGINE', "opensearch")
-lower_case_doc_engine = DOC_ENGINE.lower()
+lower_case_doc_engine = globals.DOC_ENGINE.lower()
 if lower_case_doc_engine == "elasticsearch":
-docStoreConn = rag.utils.es_conn.ESConnection()
+globals.docStoreConn = rag.utils.es_conn.ESConnection()
 elif lower_case_doc_engine == "infinity":
-docStoreConn = rag.utils.infinity_conn.InfinityConnection()
+globals.docStoreConn = rag.utils.infinity_conn.InfinityConnection()
 elif lower_case_doc_engine == "opensearch":
-docStoreConn = rag.utils.opensearch_conn.OSConnection()
+globals.docStoreConn = rag.utils.opensearch_conn.OSConnection()
 else:
-raise Exception(f"Not supported doc engine: {DOC_ENGINE}")
+raise Exception(f"Not supported doc engine: {globals.DOC_ENGINE}")

-retriever = search.Dealer(docStoreConn)
+globals.retriever = search.Dealer(globals.docStoreConn)
 from graphrag import search as kg_search

-kg_retriever = kg_search.KGSearch(docStoreConn)
+kg_retriever = kg_search.KGSearch(globals.docStoreConn)

 if int(os.environ.get("SANDBOX_ENABLED", "0")):
 global SANDBOX_HOST
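The `global` list inside init_settings() shrinks because the moved variables are now rebound as attributes of the imported common.globals module object rather than as module-level names of api/settings.py. A minimal sketch of that distinction, with the bodies reduced to placeholders and only the names taken from the hunks above:

from common import globals

kg_retriever = None  # still a module-level name of api/settings.py

def init_settings():
    global kg_retriever                    # required: rebinds a module-level name
    globals.DOC_ENGINE = "elasticsearch"   # attribute assignment needs no `global`
    kg_retriever = object()                # placeholder for kg_search.KGSearch(...)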
@@ -24,6 +24,7 @@ from rag.utils.redis_conn import REDIS_CONN
 from rag.utils.storage_factory import STORAGE_IMPL
 from rag.utils.es_conn import ESConnection
 from rag.utils.infinity_conn import InfinityConnection
+from common import globals


 def _ok_nok(ok: bool) -> str:
@@ -52,7 +53,7 @@ def check_redis() -> tuple[bool, dict]:
 def check_doc_engine() -> tuple[bool, dict]:
 st = timer()
 try:
-meta = settings.docStoreConn.health()
+meta = globals.docStoreConn.health()
 # treat any successful call as ok
 return True, {"elapsed": f"{(timer() - st) * 1000.0:.1f}", **(meta or {})}
 except Exception as e:
@@ -14,4 +14,12 @@
 # limitations under the License.
 #

+EMBEDDING_MDL = ""
+
+EMBEDDING_CFG = ""
+
+DOC_ENGINE = None
+
+docStoreConn = None
+
+retriever = None
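For orientation, a minimal sketch of the shared-globals pattern the rest of the diff relies on: init_settings() publishes the document-store connection and retriever once through common.globals, and call sites read the attributes at call time. This sketch hard-wires the Elasticsearch branch, elides the real configuration handling, and uses illustrative retrieval weights; the module and class names come from the hunks above.

from common import globals
from rag.nlp import search
import rag.utils.es_conn

def wire_up_doc_store():
    # Assumed simplification of the init_settings() hunk above.
    globals.docStoreConn = rag.utils.es_conn.ESConnection()
    globals.retriever = search.Dealer(globals.docStoreConn)

def lookup(query, embd_mdl, tenant_id, kb_id):
    # Reading the attribute at call time sees the instance created above,
    # not the `None` placeholder that common/globals.py starts with.
    # Argument order mirrors the benchmark hunk; 0.0 / 0.3 are illustrative.
    return globals.retriever.retrieval(query, embd_mdl, tenant_id, [kb_id], 1, 30, 0.0, 0.3)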
@@ -20,7 +20,6 @@ import os
 import networkx as nx
 import trio

-from api import settings
 from api.db.services.document_service import DocumentService
 from common.misc_utils import get_uuid
 from common.connection_utils import timeout
@@ -40,6 +39,7 @@ from graphrag.utils import (
 )
 from rag.nlp import rag_tokenizer, search
 from rag.utils.redis_conn import RedisDistributedLock
+from common import globals


 async def run_graphrag(
@@ -55,7 +55,7 @@ async def run_graphrag(
 start = trio.current_time()
 tenant_id, kb_id, doc_id = row["tenant_id"], str(row["kb_id"]), row["doc_id"]
 chunks = []
-for d in settings.retriever.chunk_list(doc_id, tenant_id, [kb_id], fields=["content_with_weight", "doc_id"], sort_by_position=True):
+for d in globals.retriever.chunk_list(doc_id, tenant_id, [kb_id], fields=["content_with_weight", "doc_id"], sort_by_position=True):
 chunks.append(d["content_with_weight"])

 with trio.fail_after(max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000):
@@ -170,7 +170,7 @@ async def run_graphrag_for_kb(
 chunks = []
 current_chunk = ""

-for d in settings.retriever.chunk_list(
+for d in globals.retriever.chunk_list(
 doc_id,
 tenant_id,
 [kb_id],
@@ -387,8 +387,8 @@ async def generate_subgraph(
 "removed_kwd": "N",
 }
 cid = chunk_id(chunk)
-await trio.to_thread.run_sync(settings.docStoreConn.delete, {"knowledge_graph_kwd": "subgraph", "source_id": doc_id}, search.index_name(tenant_id), kb_id)
+await trio.to_thread.run_sync(globals.docStoreConn.delete, {"knowledge_graph_kwd": "subgraph", "source_id": doc_id}, search.index_name(tenant_id), kb_id)
-await trio.to_thread.run_sync(settings.docStoreConn.insert, [{"id": cid, **chunk}], search.index_name(tenant_id), kb_id)
+await trio.to_thread.run_sync(globals.docStoreConn.insert, [{"id": cid, **chunk}], search.index_name(tenant_id), kb_id)
 now = trio.current_time()
 callback(msg=f"generated subgraph for doc {doc_id} in {now - start:.2f} seconds.")
 return subgraph
@@ -496,7 +496,7 @@ async def extract_community(
 chunks.append(chunk)

 await trio.to_thread.run_sync(
-lambda: settings.docStoreConn.delete(
+lambda: globals.docStoreConn.delete(
 {"knowledge_graph_kwd": "community_report", "kb_id": kb_id},
 search.index_name(tenant_id),
 kb_id,
@@ -504,7 +504,7 @@ async def extract_community(
 )
 es_bulk_size = 4
 for b in range(0, len(chunks), es_bulk_size):
-doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b : b + es_bulk_size], search.index_name(tenant_id), kb_id))
+doc_store_result = await trio.to_thread.run_sync(lambda: globals.docStoreConn.insert(chunks[b : b + es_bulk_size], search.index_name(tenant_id), kb_id))
 if doc_store_result:
 error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
 raise Exception(error_message)
@@ -28,6 +28,7 @@ from api.db.services.llm_service import LLMBundle
 from api.db.services.user_service import TenantService
 from graphrag.general.graph_extractor import GraphExtractor
 from graphrag.general.index import update_graph, with_resolution, with_community
+from common import globals

 settings.init_settings()

@@ -62,7 +63,7 @@ async def main():

 chunks = [
 d["content_with_weight"]
-for d in settings.retriever.chunk_list(
+for d in globals.retriever.chunk_list(
 args.doc_id,
 args.tenant_id,
 [kb_id],
@@ -28,6 +28,7 @@ from api.db.services.llm_service import LLMBundle
 from api.db.services.user_service import TenantService
 from graphrag.general.index import update_graph
 from graphrag.light.graph_extractor import GraphExtractor
+from common import globals

 settings.init_settings()

@@ -63,7 +64,7 @@ async def main():

 chunks = [
 d["content_with_weight"]
-for d in settings.retriever.chunk_list(
+for d in globals.retriever.chunk_list(
 args.doc_id,
 args.tenant_id,
 [kb_id],
@@ -29,6 +29,7 @@ from rag.utils.doc_store_conn import OrderByExpr

 from rag.nlp.search import Dealer, index_name
 from common.float_utils import get_float
+from common import globals


 class KGSearch(Dealer):
@@ -334,6 +335,6 @@ if __name__ == "__main__":
 _, kb = KnowledgebaseService.get_by_id(kb_id)
 embed_bdl = LLMBundle(args.tenant_id, LLMType.EMBEDDING, kb.embd_id)

-kg = KGSearch(settings.docStoreConn)
+kg = KGSearch(globals.docStoreConn)
 print(kg.retrieval({"question": args.question, "kb_ids": [kb_id]},
 search.index_name(kb.tenant_id), [kb_id], embed_bdl, llm_bdl))
@@ -23,12 +23,12 @@ import trio
 import xxhash
 from networkx.readwrite import json_graph

-from api import settings
 from common.misc_utils import get_uuid
 from common.connection_utils import timeout
 from rag.nlp import rag_tokenizer, search
 from rag.utils.doc_store_conn import OrderByExpr
 from rag.utils.redis_conn import REDIS_CONN
+from common import globals

 GRAPH_FIELD_SEP = "<SEP>"

@@ -334,7 +334,7 @@ def get_relation(tenant_id, kb_id, from_ent_name, to_ent_name, size=1):
 ents = list(set(ents))
 conds = {"fields": ["content_with_weight"], "size": size, "from_entity_kwd": ents, "to_entity_kwd": ents, "knowledge_graph_kwd": ["relation"]}
 res = []
-es_res = settings.retriever.search(conds, search.index_name(tenant_id), [kb_id] if isinstance(kb_id, str) else kb_id)
+es_res = globals.retriever.search(conds, search.index_name(tenant_id), [kb_id] if isinstance(kb_id, str) else kb_id)
 for id in es_res.ids:
 try:
 if size == 1:
@@ -381,8 +381,8 @@ async def does_graph_contains(tenant_id, kb_id, doc_id):
 "knowledge_graph_kwd": ["graph"],
 "removed_kwd": "N",
 }
-res = await trio.to_thread.run_sync(lambda: settings.docStoreConn.search(fields, [], condition, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [kb_id]))
+res = await trio.to_thread.run_sync(lambda: globals.docStoreConn.search(fields, [], condition, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [kb_id]))
-fields2 = settings.docStoreConn.getFields(res, fields)
+fields2 = globals.docStoreConn.getFields(res, fields)
 graph_doc_ids = set()
 for chunk_id in fields2.keys():
 graph_doc_ids = set(fields2[chunk_id]["source_id"])
@@ -391,7 +391,7 @@ async def does_graph_contains(tenant_id, kb_id, doc_id):

 async def get_graph_doc_ids(tenant_id, kb_id) -> list[str]:
 conds = {"fields": ["source_id"], "removed_kwd": "N", "size": 1, "knowledge_graph_kwd": ["graph"]}
-res = await trio.to_thread.run_sync(lambda: settings.retriever.search(conds, search.index_name(tenant_id), [kb_id]))
+res = await trio.to_thread.run_sync(lambda: globals.retriever.search(conds, search.index_name(tenant_id), [kb_id]))
 doc_ids = []
 if res.total == 0:
 return doc_ids
@@ -402,7 +402,7 @@ async def get_graph_doc_ids(tenant_id, kb_id) -> list[str]:

 async def get_graph(tenant_id, kb_id, exclude_rebuild=None):
 conds = {"fields": ["content_with_weight", "removed_kwd", "source_id"], "size": 1, "knowledge_graph_kwd": ["graph"]}
-res = await trio.to_thread.run_sync(settings.retriever.search, conds, search.index_name(tenant_id), [kb_id])
+res = await trio.to_thread.run_sync(globals.retriever.search, conds, search.index_name(tenant_id), [kb_id])
 if not res.total == 0:
 for id in res.ids:
 try:
@@ -423,17 +423,17 @@ async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, chang
 global chat_limiter
 start = trio.current_time()

-await trio.to_thread.run_sync(settings.docStoreConn.delete, {"knowledge_graph_kwd": ["graph", "subgraph"]}, search.index_name(tenant_id), kb_id)
+await trio.to_thread.run_sync(globals.docStoreConn.delete, {"knowledge_graph_kwd": ["graph", "subgraph"]}, search.index_name(tenant_id), kb_id)

 if change.removed_nodes:
-await trio.to_thread.run_sync(settings.docStoreConn.delete, {"knowledge_graph_kwd": ["entity"], "entity_kwd": sorted(change.removed_nodes)}, search.index_name(tenant_id), kb_id)
+await trio.to_thread.run_sync(globals.docStoreConn.delete, {"knowledge_graph_kwd": ["entity"], "entity_kwd": sorted(change.removed_nodes)}, search.index_name(tenant_id), kb_id)

 if change.removed_edges:

 async def del_edges(from_node, to_node):
 async with chat_limiter:
 await trio.to_thread.run_sync(
-settings.docStoreConn.delete, {"knowledge_graph_kwd": ["relation"], "from_entity_kwd": from_node, "to_entity_kwd": to_node}, search.index_name(tenant_id), kb_id
+globals.docStoreConn.delete, {"knowledge_graph_kwd": ["relation"], "from_entity_kwd": from_node, "to_entity_kwd": to_node}, search.index_name(tenant_id), kb_id
 )

 async with trio.open_nursery() as nursery:
@@ -501,7 +501,7 @@ async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, chang
 es_bulk_size = 4
 for b in range(0, len(chunks), es_bulk_size):
 with trio.fail_after(3 if enable_timeout_assertion else 30000000):
-doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b : b + es_bulk_size], search.index_name(tenant_id), kb_id))
+doc_store_result = await trio.to_thread.run_sync(lambda: globals.docStoreConn.insert(chunks[b : b + es_bulk_size], search.index_name(tenant_id), kb_id))
 if b % 100 == es_bulk_size and callback:
 callback(msg=f"Insert chunks: {b}/{len(chunks)}")
 if doc_store_result:
@@ -555,7 +555,7 @@ def merge_tuples(list1, list2):


 async def get_entity_type2samples(idxnms, kb_ids: list):
-es_res = await trio.to_thread.run_sync(lambda: settings.retriever.search({"knowledge_graph_kwd": "ty2ents", "kb_id": kb_ids, "size": 10000, "fields": ["content_with_weight"]}, idxnms, kb_ids))
+es_res = await trio.to_thread.run_sync(lambda: globals.retriever.search({"knowledge_graph_kwd": "ty2ents", "kb_id": kb_ids, "size": 10000, "fields": ["content_with_weight"]}, idxnms, kb_ids))

 res = defaultdict(list)
 for id in es_res.ids:
@@ -589,10 +589,10 @@ async def rebuild_graph(tenant_id, kb_id, exclude_rebuild=None):
 bs = 256
 for i in range(0, 1024 * bs, bs):
 es_res = await trio.to_thread.run_sync(
-lambda: settings.docStoreConn.search(flds, [], {"kb_id": kb_id, "knowledge_graph_kwd": ["subgraph"]}, [], OrderByExpr(), i, bs, search.index_name(tenant_id), [kb_id])
+lambda: globals.docStoreConn.search(flds, [], {"kb_id": kb_id, "knowledge_graph_kwd": ["subgraph"]}, [], OrderByExpr(), i, bs, search.index_name(tenant_id), [kb_id])
 )
-# tot = settings.docStoreConn.getTotal(es_res)
+# tot = globals.docStoreConn.getTotal(es_res)
-es_res = settings.docStoreConn.getFields(es_res, flds)
+es_res = globals.docStoreConn.getFields(es_res, flds)

 if len(es_res) == 0:
 break
@@ -21,6 +21,7 @@ from copy import deepcopy
 from deepdoc.parser.utils import get_text
 from rag.app.qa import Excel
 from rag.nlp import rag_tokenizer
+from common import globals


 def beAdoc(d, q, a, eng, row_num=-1):
@@ -124,7 +125,6 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
 def label_question(question, kbs):
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from graphrag.utils import get_tags_from_cache, set_tags_to_cache
-from api import settings
 tags = None
 tag_kb_ids = []
 for kb in kbs:
@@ -133,14 +133,14 @@ def label_question(question, kbs):
 if tag_kb_ids:
 all_tags = get_tags_from_cache(tag_kb_ids)
 if not all_tags:
-all_tags = settings.retriever.all_tags_in_portion(kb.tenant_id, tag_kb_ids)
+all_tags = globals.retriever.all_tags_in_portion(kb.tenant_id, tag_kb_ids)
 set_tags_to_cache(tags=all_tags, kb_ids=tag_kb_ids)
 else:
 all_tags = json.loads(all_tags)
 tag_kbs = KnowledgebaseService.get_by_ids(tag_kb_ids)
 if not tag_kbs:
 return tags
-tags = settings.retriever.tag_query(question,
+tags = globals.retriever.tag_query(question,
 list(set([kb.tenant_id for kb in tag_kbs])),
 tag_kb_ids,
 all_tags,
@@ -20,10 +20,10 @@ import time
 import argparse
 from collections import defaultdict

+from common import globals
 from common.constants import LLMType
 from api.db.services.llm_service import LLMBundle
 from api.db.services.knowledgebase_service import KnowledgebaseService
-from api import settings
 from common.misc_utils import get_uuid
 from rag.nlp import tokenize, search
 from ranx import evaluate
@@ -52,7 +52,7 @@ class Benchmark:
 run = defaultdict(dict)
 query_list = list(qrels.keys())
 for query in query_list:
-ranks = settings.retriever.retrieval(query, self.embd_mdl, self.tenant_id, [self.kb.id], 1, 30,
+ranks = globals.retriever.retrieval(query, self.embd_mdl, self.tenant_id, [self.kb.id], 1, 30,
 0.0, self.vector_similarity_weight)
 if len(ranks["chunks"]) == 0:
 print(f"deleted query: {query}")
@@ -77,9 +77,9 @@ class Benchmark:
 def init_index(self, vector_size: int):
 if self.initialized_index:
 return
-if settings.docStoreConn.indexExist(self.index_name, self.kb_id):
+if globals.docStoreConn.indexExist(self.index_name, self.kb_id):
-settings.docStoreConn.deleteIdx(self.index_name, self.kb_id)
+globals.docStoreConn.deleteIdx(self.index_name, self.kb_id)
-settings.docStoreConn.createIdx(self.index_name, self.kb_id, vector_size)
+globals.docStoreConn.createIdx(self.index_name, self.kb_id, vector_size)
 self.initialized_index = True

 def ms_marco_index(self, file_path, index_name):
@@ -114,13 +114,13 @@ class Benchmark:
 docs_count += len(docs)
 docs, vector_size = self.embedding(docs)
 self.init_index(vector_size)
-settings.docStoreConn.insert(docs, self.index_name, self.kb_id)
+globals.docStoreConn.insert(docs, self.index_name, self.kb_id)
 docs = []

 if docs:
 docs, vector_size = self.embedding(docs)
 self.init_index(vector_size)
-settings.docStoreConn.insert(docs, self.index_name, self.kb_id)
+globals.docStoreConn.insert(docs, self.index_name, self.kb_id)
 return qrels, texts

 def trivia_qa_index(self, file_path, index_name):
@@ -155,12 +155,12 @@ class Benchmark:
 docs_count += len(docs)
 docs, vector_size = self.embedding(docs)
 self.init_index(vector_size)
-settings.docStoreConn.insert(docs,self.index_name)
+globals.docStoreConn.insert(docs,self.index_name)
 docs = []

 docs, vector_size = self.embedding(docs)
 self.init_index(vector_size)
-settings.docStoreConn.insert(docs, self.index_name)
+globals.docStoreConn.insert(docs, self.index_name)
 return qrels, texts

 def miracl_index(self, file_path, corpus_path, index_name):
@@ -210,12 +210,12 @@ class Benchmark:
 docs_count += len(docs)
 docs, vector_size = self.embedding(docs)
 self.init_index(vector_size)
-settings.docStoreConn.insert(docs, self.index_name)
+globals.docStoreConn.insert(docs, self.index_name)
 docs = []

 docs, vector_size = self.embedding(docs)
 self.init_index(vector_size)
-settings.docStoreConn.insert(docs, self.index_name)
+globals.docStoreConn.insert(docs, self.index_name)
 return qrels, texts

 def save_results(self, qrels, run, texts, dataset, file_path):
@@ -21,7 +21,7 @@ from functools import partial
 import trio

 from common.misc_utils import get_uuid
-from common.base64_image import id2image, image2id
+from rag.utils.base64_image import id2image, image2id
 from deepdoc.parser.pdf_parser import RAGFlowPdfParser
 from rag.flow.base import ProcessBase, ProcessParamBase
 from rag.flow.hierarchical_merger.schema import HierarchicalMergerFromUpstream
@@ -27,7 +27,7 @@ from api.db.services.file2document_service import File2DocumentService
 from api.db.services.file_service import FileService
 from api.db.services.llm_service import LLMBundle
 from common.misc_utils import get_uuid
-from common.base64_image import image2id
+from rag.utils.base64_image import image2id
 from deepdoc.parser import ExcelParser
 from deepdoc.parser.mineru_parser import MinerUParser
 from deepdoc.parser.pdf_parser import PlainParser, RAGFlowPdfParser, VisionParser
@@ -18,7 +18,7 @@ from functools import partial
 import trio

 from common.misc_utils import get_uuid
-from common.base64_image import id2image, image2id
+from rag.utils.base64_image import id2image, image2id
 from deepdoc.parser.pdf_parser import RAGFlowPdfParser
 from rag.flow.base import ProcessBase, ProcessParamBase
 from rag.flow.splitter.schema import SplitterFromUpstream
@@ -30,7 +30,6 @@ from zhipuai import ZhipuAI
 from common.log_utils import log_exception
 from common.token_utils import num_tokens_from_string, truncate
 from common import globals
-from api import settings
 import logging


@@ -74,9 +73,9 @@ class BuiltinEmbed(Base):
 embedding_cfg = globals.EMBEDDING_CFG
 if not BuiltinEmbed._model and "tei-" in os.getenv("COMPOSE_PROFILES", ""):
 with BuiltinEmbed._model_lock:
-BuiltinEmbed._model_name = settings.EMBEDDING_MDL
+BuiltinEmbed._model_name = globals.EMBEDDING_MDL
-BuiltinEmbed._max_tokens = BuiltinEmbed.MAX_TOKENS.get(settings.EMBEDDING_MDL, 500)
+BuiltinEmbed._max_tokens = BuiltinEmbed.MAX_TOKENS.get(globals.EMBEDDING_MDL, 500)
-BuiltinEmbed._model = HuggingFaceEmbed(embedding_cfg["api_key"], settings.EMBEDDING_MDL, base_url=embedding_cfg["base_url"])
+BuiltinEmbed._model = HuggingFaceEmbed(embedding_cfg["api_key"], globals.EMBEDDING_MDL, base_url=embedding_cfg["base_url"])
 self._model = BuiltinEmbed._model
 self._model_name = BuiltinEmbed._model_name
 self._max_tokens = BuiltinEmbed._max_tokens
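A small sketch of the gating the BuiltinEmbed hunk relies on, under the assumption that globals.EMBEDDING_MDL is only non-empty when a tei- compose profile is active (as set in the init_settings() hunk earlier); the helper name is hypothetical and not part of the diff.

import os
from common import globals

def builtin_embedding_enabled() -> bool:
    # globals.EMBEDDING_MDL is populated from TEI_MODEL only when a "tei-"
    # profile is listed in COMPOSE_PROFILES, so the two checks agree.
    return bool(globals.EMBEDDING_MDL) and "tei-" in os.getenv("COMPOSE_PROFILES", "")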
@@ -18,13 +18,14 @@ import logging
 from common.config_utils import get_base_config, decrypt_database_config
 from common.file_utils import get_project_base_directory
 from common.misc_utils import pip_install_torch
+from common import globals

 # Server
 RAG_CONF_PATH = os.path.join(get_project_base_directory(), "conf")

 # Get storage type and document engine from system environment variables
 STORAGE_IMPL_TYPE = os.getenv('STORAGE_IMPL', 'MINIO')
-DOC_ENGINE = os.getenv('DOC_ENGINE', 'elasticsearch')
+globals.DOC_ENGINE = os.getenv('DOC_ENGINE', 'elasticsearch')

 ES = {}
 INFINITY = {}
@@ -35,11 +36,11 @@ OSS = {}
 OS = {}

 # Initialize the selected configuration data based on environment variables to solve the problem of initialization errors due to lack of configuration
-if DOC_ENGINE == 'elasticsearch':
+if globals.DOC_ENGINE == 'elasticsearch':
 ES = get_base_config("es", {})
-elif DOC_ENGINE == 'opensearch':
+elif globals.DOC_ENGINE == 'opensearch':
 OS = get_base_config("os", {})
-elif DOC_ENGINE == 'infinity':
+elif globals.DOC_ENGINE == 'infinity':
 INFINITY = get_base_config("infinity", {"uri": "infinity:23817"})

 if STORAGE_IMPL_TYPE in ['AZURE_SPN', 'AZURE_SAS']:
@@ -27,7 +27,7 @@ from api.db.services.canvas_service import UserCanvasService
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
 from common.connection_utils import timeout
-from common.base64_image import image2id
+from rag.utils.base64_image import image2id
 from common.log_utils import init_root_logger
 from common.config_utils import show_configs
 from graphrag.general.index import run_graphrag_for_kb
@@ -68,6 +68,7 @@ from rag.utils.redis_conn import REDIS_CONN, RedisDistributedLock
 from rag.utils.storage_factory import STORAGE_IMPL
 from graphrag.utils import chat_limiter
 from common.signal_utils import start_tracemalloc_and_snapshot, stop_tracemalloc
+from common import globals

 BATCH_SIZE = 64

@@ -349,7 +350,7 @@ async def build_chunks(task, progress_callback):
 examples = []
 all_tags = get_tags_from_cache(kb_ids)
 if not all_tags:
-all_tags = settings.retriever.all_tags_in_portion(tenant_id, kb_ids, S)
+all_tags = globals.retriever.all_tags_in_portion(tenant_id, kb_ids, S)
 set_tags_to_cache(kb_ids, all_tags)
 else:
 all_tags = json.loads(all_tags)
@@ -362,7 +363,7 @@ async def build_chunks(task, progress_callback):
 if task_canceled:
 progress_callback(-1, msg="Task has been canceled.")
 return
-if settings.retriever.tag_content(tenant_id, kb_ids, d, all_tags, topn_tags=topn_tags, S=S) and len(d[TAG_FLD]) > 0:
+if globals.retriever.tag_content(tenant_id, kb_ids, d, all_tags, topn_tags=topn_tags, S=S) and len(d[TAG_FLD]) > 0:
 examples.append({"content": d["content_with_weight"], TAG_FLD: d[TAG_FLD]})
 else:
 docs_to_tag.append(d)
@@ -423,7 +424,7 @@ def build_TOC(task, docs, progress_callback):

 def init_kb(row, vector_size: int):
 idxnm = search.index_name(row["tenant_id"])
-return settings.docStoreConn.createIdx(idxnm, row.get("kb_id", ""), vector_size)
+return globals.docStoreConn.createIdx(idxnm, row.get("kb_id", ""), vector_size)


 async def embedding(docs, mdl, parser_config=None, callback=None):
@@ -647,7 +648,7 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si
 chunks = []
 vctr_nm = "q_%d_vec"%vector_size
 for doc_id in doc_ids:
-for d in settings.retriever.chunk_list(doc_id, row["tenant_id"], [str(row["kb_id"])],
+for d in globals.retriever.chunk_list(doc_id, row["tenant_id"], [str(row["kb_id"])],
 fields=["content_with_weight", vctr_nm],
 sort_by_position=True):
 chunks.append((d["content_with_weight"], np.array(d[vctr_nm])))
@@ -698,7 +699,7 @@ async def delete_image(kb_id, chunk_id):

 async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_callback):
 for b in range(0, len(chunks), DOC_BULK_SIZE):
-doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b:b + DOC_BULK_SIZE], search.index_name(task_tenant_id), task_dataset_id))
+doc_store_result = await trio.to_thread.run_sync(lambda: globals.docStoreConn.insert(chunks[b:b + DOC_BULK_SIZE], search.index_name(task_tenant_id), task_dataset_id))
 task_canceled = has_canceled(task_id)
 if task_canceled:
 progress_callback(-1, msg="Task has been canceled.")
@@ -715,7 +716,7 @@ async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_c
 TaskService.update_chunk_ids(task_id, chunk_ids_str)
 except DoesNotExist:
 logging.warning(f"do_handle_task update_chunk_ids failed since task {task_id} is unknown.")
-doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(task_tenant_id), task_dataset_id))
+doc_store_result = await trio.to_thread.run_sync(lambda: globals.docStoreConn.delete({"id": chunk_ids}, search.index_name(task_tenant_id), task_dataset_id))
 async with trio.open_nursery() as nursery:
 for chunk_id in chunk_ids:
 nursery.start_soon(delete_image, task_dataset_id, chunk_id)
@@ -751,7 +752,7 @@ async def do_handle_task(task):
 progress_callback = partial(set_progress, task_id, task_from_page, task_to_page)

 # FIXME: workaround, Infinity doesn't support table parsing method, this check is to notify user
-lower_case_doc_engine = settings.DOC_ENGINE.lower()
+lower_case_doc_engine = globals.DOC_ENGINE.lower()
 if lower_case_doc_engine == 'infinity' and task['parser_id'].lower() == 'table':
 error_message = "Table parsing method is not supported by Infinity, please use other parsing methods or use Elasticsearch as the document engine."
 progress_callback(-1, msg=error_message)