Mirror of https://github.com/infiniflow/ragflow.git (synced 2026-01-04 03:25:30 +08:00)

Compare commits: 12 commits, ca40b56839 ... 4e76220e25
Commits (newest first): 4e76220e25, 24335485bf, 121c51661d, 02d10f8eda, dddf766470, 8584d4b642, b86e07088b, 1a9215bc6f, cf9611c96f, f126875ec6, 89410d2381, 96c015fb85
@@ -281,12 +281,21 @@ class Canvas(Graph):
         def _run_batch(f, t):
             with ThreadPoolExecutor(max_workers=5) as executor:
                 thr = []
-                for i in range(f, t):
+                i = f
+                while i < t:
                     cpn = self.get_component_obj(self.path[i])
                     if cpn.component_name.lower() in ["begin", "userfillup"]:
                         thr.append(executor.submit(cpn.invoke, inputs=kwargs.get("inputs", {})))
+                        i += 1
                     else:
-                        thr.append(executor.submit(cpn.invoke, **cpn.get_input()))
+                        for _, ele in cpn.get_input_elements().items():
+                            if isinstance(ele, dict) and ele.get("_cpn_id") and ele.get("_cpn_id") not in self.path[:i]:
+                                self.path.pop(i)
+                                t -= 1
+                                break
+                        else:
+                            thr.append(executor.submit(cpn.invoke, **cpn.get_input()))
+                            i += 1
                 for t in thr:
                     t.result()
 
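The switch from `for i in range(f, t)` to an index-driven `while` loop is what lets the batch runner call `self.path.pop(i)` without skipping the element that slides into slot `i`. A standalone sketch of that pattern (illustrative names, not RAGFlow code):

```python
# Drop items whose dependency has not run yet, advancing the index only
# when the current item is kept - popping shifts the next item into slot i.
def prune_pending(path, finished):
    i, t = 0, len(path)
    while i < t:
        if path[i] not in finished:
            path.pop(i)   # list shrinks; do NOT advance i
            t -= 1
        else:
            i += 1
    return path

assert prune_pending(["a", "b", "c", "d"], {"a", "c"}) == ["a", "c"]
```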
@@ -316,6 +325,7 @@ class Canvas(Graph):
                     "thoughts": self.get_component_thoughts(self.path[i])
                 })
             _run_batch(idx, to)
+            to = len(self.path)
             # post processing of components invocation
             for i in range(idx, to):
                 cpn = self.get_component(self.path[i])
@@ -16,6 +16,13 @@
 from abc import ABC
 from agent.component.base import ComponentBase, ComponentParamBase
 
+"""
+class VariableModel(BaseModel):
+    data_type: Annotated[Literal["string", "number", "Object", "Boolean", "Array<string>", "Array<number>", "Array<object>", "Array<boolean>"], Field(default="Array<string>")]
+    input_mode: Annotated[Literal["constant", "variable"], Field(default="constant")]
+    value: Annotated[Any, Field(default=None)]
+    model_config = ConfigDict(extra="forbid")
+"""
 
 class IterationParam(ComponentParamBase):
     """
@@ -216,7 +216,7 @@ class LLM(ComponentBase):
         error: str = ""
         output_structure=None
         try:
-            output_structure=self._param.outputs['structured']
+            output_structure = None#self._param.outputs['structured']
         except Exception:
             pass
         if output_structure:
@@ -49,6 +49,9 @@ class MessageParam(ComponentParamBase):
 class Message(ComponentBase):
     component_name = "Message"
 
+    def get_input_elements(self) -> dict[str, Any]:
+        return self.get_input_elements_from_text("".join(self._param.content))
+
     def get_kwargs(self, script:str, kwargs:dict = {}, delimiter:str=None) -> tuple[str, dict[str, str | list | Any]]:
         for k,v in self.get_input_elements_from_text(script).items():
             if k in kwargs:
@@ -16,6 +16,8 @@
 import os
 import re
 from abc import ABC
 from typing import Any
 
 from jinja2 import Template as Jinja2Template
 from agent.component.base import ComponentParamBase
 from common.connection_utils import timeout
@@ -43,6 +45,9 @@ class StringTransformParam(ComponentParamBase):
 class StringTransform(Message, ABC):
     component_name = "StringTransform"
 
+    def get_input_elements(self) -> dict[str, Any]:
+        return self.get_input_elements_from_text(self._param.script)
+
     def get_input_form(self) -> dict[str, dict]:
         if self._param.method == "split":
             return {
@@ -25,6 +25,7 @@ from api.db.services.dialog_service import meta_filter
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.llm_service import LLMBundle
 from api import settings
+from common import globals
 from common.connection_utils import timeout
 from rag.app.tag import label_question
 from rag.prompts.generator import cross_languages, kb_prompt, gen_meta_filter
@@ -170,7 +171,7 @@ class Retrieval(ToolBase, ABC):
 
         if kbs:
             query = re.sub(r"^user[::\s]*", "", query, flags=re.IGNORECASE)
-            kbinfos = settings.retriever.retrieval(
+            kbinfos = globals.retriever.retrieval(
                 query,
                 embd_mdl,
                 [kb.tenant_id for kb in kbs],
@@ -186,7 +187,7 @@ class Retrieval(ToolBase, ABC):
             )
             if self._param.toc_enhance:
                 chat_mdl = LLMBundle(self._canvas._tenant_id, LLMType.CHAT)
-                cks = settings.retriever.retrieval_by_toc(query, kbinfos["chunks"], [kb.tenant_id for kb in kbs], chat_mdl, self._param.top_n)
+                cks = globals.retriever.retrieval_by_toc(query, kbinfos["chunks"], [kb.tenant_id for kb in kbs], chat_mdl, self._param.top_n)
                 if cks:
                     kbinfos["chunks"] = cks
             if self._param.use_kg:
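The `re.sub(r"^user[::\s]*", ...)` guard in the hunk above strips a leading role prefix from the query before retrieval. A quick standalone check of what that pattern removes (example strings are made up):

```python
import re

for q in ["User: what is RAGFlow?", "USER   where is the config", "plain question"]:
    print(re.sub(r"^user[::\s]*", "", q, flags=re.IGNORECASE))
# -> "what is RAGFlow?", "where is the config", "plain question"
```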
@ -32,7 +32,6 @@ from api.db.services.file_service import FileService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.task_service import queue_tasks, TaskService
|
||||
from api.db.services.user_service import UserTenantService
|
||||
from api import settings
|
||||
from common.misc_utils import get_uuid
|
||||
from common.constants import RetCode, VALID_TASK_STATUS, LLMType, ParserType, FileSource
|
||||
from api.utils.api_utils import server_error_response, get_data_error_result, get_json_result, validate_request, \
|
||||
@ -48,6 +47,7 @@ from api.db.services.canvas_service import UserCanvasService
|
||||
from agent.canvas import Canvas
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from common import globals
|
||||
|
||||
|
||||
@manager.route('/new_token', methods=['POST']) # noqa: F821
|
||||
@ -538,7 +538,7 @@ def list_chunks():
|
||||
)
|
||||
kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
|
||||
|
||||
res = settings.retriever.chunk_list(doc_id, tenant_id, kb_ids)
|
||||
res = globals.retriever.chunk_list(doc_id, tenant_id, kb_ids)
|
||||
res = [
|
||||
{
|
||||
"content": res_item["content_with_weight"],
|
||||
@ -564,7 +564,7 @@ def get_chunk(chunk_id):
|
||||
try:
|
||||
tenant_id = objs[0].tenant_id
|
||||
kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
|
||||
chunk = settings.docStoreConn.get(chunk_id, search.index_name(tenant_id), kb_ids)
|
||||
chunk = globals.docStoreConn.get(chunk_id, search.index_name(tenant_id), kb_ids)
|
||||
if chunk is None:
|
||||
return server_error_response(Exception("Chunk not found"))
|
||||
k = []
|
||||
@ -886,7 +886,7 @@ def retrieval():
|
||||
if req.get("keyword", False):
|
||||
chat_mdl = LLMBundle(kbs[0].tenant_id, LLMType.CHAT)
|
||||
question += keyword_extraction(chat_mdl, question)
|
||||
ranks = settings.retriever.retrieval(question, embd_mdl, kbs[0].tenant_id, kb_ids, page, size,
|
||||
ranks = globals.retriever.retrieval(question, embd_mdl, kbs[0].tenant_id, kb_ids, page, size,
|
||||
similarity_threshold, vector_similarity_weight, top,
|
||||
doc_ids, rerank_mdl=rerank_mdl, highlight= highlight,
|
||||
rank_feature=label_question(question, kbs))
|
||||
|
||||
@ -25,7 +25,6 @@ from flask import request, Response
|
||||
from flask_login import login_required, current_user
|
||||
|
||||
from agent.component import LLM
|
||||
from api import settings
|
||||
from api.db import CanvasCategory, FileType
|
||||
from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService, API4ConversationService
|
||||
from api.db.services.document_service import DocumentService
|
||||
@ -46,6 +45,7 @@ from api.utils.file_utils import filename_type, read_potential_broken_pdf
|
||||
from rag.flow.pipeline import Pipeline
|
||||
from rag.nlp import search
|
||||
from rag.utils.redis_conn import REDIS_CONN
|
||||
from common import globals
|
||||
|
||||
|
||||
@manager.route('/templates', methods=['GET']) # noqa: F821
|
||||
@ -192,8 +192,8 @@ def rerun():
|
||||
if 0 < doc["progress"] < 1:
|
||||
return get_data_error_result(message=f"`{doc['name']}` is processing...")
|
||||
|
||||
if settings.docStoreConn.indexExist(search.index_name(current_user.id), doc["kb_id"]):
|
||||
settings.docStoreConn.delete({"doc_id": doc["id"]}, search.index_name(current_user.id), doc["kb_id"])
|
||||
if globals.docStoreConn.indexExist(search.index_name(current_user.id), doc["kb_id"]):
|
||||
globals.docStoreConn.delete({"doc_id": doc["id"]}, search.index_name(current_user.id), doc["kb_id"])
|
||||
doc["progress_msg"] = ""
|
||||
doc["chunk_num"] = 0
|
||||
doc["token_num"] = 0
|
||||
|
||||
@ -36,6 +36,7 @@ from rag.prompts.generator import gen_meta_filter, cross_languages, keyword_extr
|
||||
from rag.settings import PAGERANK_FLD
|
||||
from common.string_utils import remove_redundant_spaces
|
||||
from common.constants import RetCode, LLMType, ParserType
|
||||
from common import globals
|
||||
|
||||
|
||||
@manager.route('/list', methods=['POST']) # noqa: F821
|
||||
@ -60,7 +61,7 @@ def list_chunk():
|
||||
}
|
||||
if "available_int" in req:
|
||||
query["available_int"] = int(req["available_int"])
|
||||
sres = settings.retriever.search(query, search.index_name(tenant_id), kb_ids, highlight=["content_ltks"])
|
||||
sres = globals.retriever.search(query, search.index_name(tenant_id), kb_ids, highlight=["content_ltks"])
|
||||
res = {"total": sres.total, "chunks": [], "doc": doc.to_dict()}
|
||||
for id in sres.ids:
|
||||
d = {
|
||||
@ -98,7 +99,7 @@ def get():
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
for tenant in tenants:
|
||||
kb_ids = KnowledgebaseService.get_kb_ids(tenant.tenant_id)
|
||||
chunk = settings.docStoreConn.get(chunk_id, search.index_name(tenant.tenant_id), kb_ids)
|
||||
chunk = globals.docStoreConn.get(chunk_id, search.index_name(tenant.tenant_id), kb_ids)
|
||||
if chunk:
|
||||
break
|
||||
if chunk is None:
|
||||
@ -170,7 +171,7 @@ def set():
|
||||
v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not d.get("question_kwd") else "\n".join(d["question_kwd"])])
|
||||
v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
|
||||
d["q_%d_vec" % len(v)] = v.tolist()
|
||||
settings.docStoreConn.update({"id": req["chunk_id"]}, d, search.index_name(tenant_id), doc.kb_id)
|
||||
globals.docStoreConn.update({"id": req["chunk_id"]}, d, search.index_name(tenant_id), doc.kb_id)
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
@ -186,7 +187,7 @@ def switch():
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
for cid in req["chunk_ids"]:
|
||||
if not settings.docStoreConn.update({"id": cid},
|
||||
if not globals.docStoreConn.update({"id": cid},
|
||||
{"available_int": int(req["available_int"])},
|
||||
search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
|
||||
doc.kb_id):
|
||||
@ -206,7 +207,7 @@ def rm():
|
||||
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
if not settings.docStoreConn.delete({"id": req["chunk_ids"]},
|
||||
if not globals.docStoreConn.delete({"id": req["chunk_ids"]},
|
||||
search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
|
||||
doc.kb_id):
|
||||
return get_data_error_result(message="Chunk deleting failure")
|
||||
@ -270,7 +271,7 @@ def create():
|
||||
v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not d["question_kwd"] else "\n".join(d["question_kwd"])])
|
||||
v = 0.1 * v[0] + 0.9 * v[1]
|
||||
d["q_%d_vec" % len(v)] = v.tolist()
|
||||
settings.docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
|
||||
globals.docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
|
||||
|
||||
DocumentService.increment_chunk_num(
|
||||
doc.id, doc.kb_id, c, 1, 0)
|
||||
@ -346,7 +347,7 @@ def retrieval_test():
|
||||
question += keyword_extraction(chat_mdl, question)
|
||||
|
||||
labels = label_question(question, [kb])
|
||||
ranks = settings.retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids, page, size,
|
||||
ranks = globals.retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids, page, size,
|
||||
float(req.get("similarity_threshold", 0.0)),
|
||||
float(req.get("vector_similarity_weight", 0.3)),
|
||||
top,
|
||||
@ -385,7 +386,7 @@ def knowledge_graph():
|
||||
"doc_ids": [doc_id],
|
||||
"knowledge_graph_kwd": ["graph", "mind_map"]
|
||||
}
|
||||
sres = settings.retriever.search(req, search.index_name(tenant_id), kb_ids)
|
||||
sres = globals.retriever.search(req, search.index_name(tenant_id), kb_ids)
|
||||
obj = {"graph": {}, "mind_map": {}}
|
||||
for id in sres.ids[:2]:
|
||||
ty = sres.field[id]["knowledge_graph_kwd"]
|
||||
|
||||
@ -23,7 +23,6 @@ import flask
|
||||
from flask import request
|
||||
from flask_login import current_user, login_required
|
||||
|
||||
from api import settings
|
||||
from api.common.check_team_permission import check_kb_team_permission
|
||||
from api.constants import FILE_NAME_LEN_LIMIT, IMG_BASE64_PREFIX
|
||||
from api.db import VALID_FILE_TYPES, FileType
|
||||
@ -49,6 +48,7 @@ from api.utils.web_utils import CONTENT_TYPE_MAP, html2pdf, is_valid_url
|
||||
from deepdoc.parser.html_parser import RAGFlowHtmlParser
|
||||
from rag.nlp import search, rag_tokenizer
|
||||
from rag.utils.storage_factory import STORAGE_IMPL
|
||||
from common import globals
|
||||
|
||||
|
||||
@manager.route("/upload", methods=["POST"]) # noqa: F821
|
||||
@ -367,7 +367,7 @@ def change_status():
|
||||
continue
|
||||
|
||||
status_int = int(status)
|
||||
if not settings.docStoreConn.update({"doc_id": doc_id}, {"available_int": status_int}, search.index_name(kb.tenant_id), doc.kb_id):
|
||||
if not globals.docStoreConn.update({"doc_id": doc_id}, {"available_int": status_int}, search.index_name(kb.tenant_id), doc.kb_id):
|
||||
result[doc_id] = {"error": "Database error (docStore update)!"}
|
||||
result[doc_id] = {"status": status}
|
||||
except Exception as e:
|
||||
@ -432,8 +432,8 @@ def run():
|
||||
DocumentService.update_by_id(id, info)
|
||||
if req.get("delete", False):
|
||||
TaskService.filter_delete([Task.doc_id == id])
|
||||
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
|
||||
settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id)
|
||||
if globals.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
|
||||
globals.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id)
|
||||
|
||||
if str(req["run"]) == TaskStatus.RUNNING.value:
|
||||
doc = doc.to_dict()
|
||||
@ -479,8 +479,8 @@ def rename():
|
||||
"title_tks": title_tks,
|
||||
"title_sm_tks": rag_tokenizer.fine_grained_tokenize(title_tks),
|
||||
}
|
||||
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
|
||||
settings.docStoreConn.update(
|
||||
if globals.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
|
||||
globals.docStoreConn.update(
|
||||
{"doc_id": req["doc_id"]},
|
||||
es_body,
|
||||
search.index_name(tenant_id),
|
||||
@ -541,8 +541,8 @@ def change_parser():
|
||||
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
|
||||
if not tenant_id:
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
|
||||
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
|
||||
if globals.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
|
||||
globals.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
|
||||
|
||||
try:
|
||||
if "pipeline_id" in req and req["pipeline_id"] != "":
|
||||
|
||||
@ -35,7 +35,6 @@ from api.db import VALID_FILE_TYPES
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.db_models import File
|
||||
from api.utils.api_utils import get_json_result
|
||||
from api import settings
|
||||
from rag.nlp import search
|
||||
from api.constants import DATASET_NAME_LIMIT
|
||||
from rag.settings import PAGERANK_FLD
|
||||
@ -43,7 +42,7 @@ from rag.utils.redis_conn import REDIS_CONN
|
||||
from rag.utils.storage_factory import STORAGE_IMPL
|
||||
from rag.utils.doc_store_conn import OrderByExpr
|
||||
from common.constants import RetCode, PipelineTaskType, StatusEnum, VALID_TASK_STATUS, FileSource, LLMType
|
||||
|
||||
from common import globals
|
||||
|
||||
@manager.route('/create', methods=['post']) # noqa: F821
|
||||
@login_required
|
||||
@ -110,11 +109,11 @@ def update():
|
||||
|
||||
if kb.pagerank != req.get("pagerank", 0):
|
||||
if req.get("pagerank", 0) > 0:
|
||||
settings.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]},
|
||||
globals.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]},
|
||||
search.index_name(kb.tenant_id), kb.id)
|
||||
else:
|
||||
# Elasticsearch requires PAGERANK_FLD be non-zero!
|
||||
settings.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD},
|
||||
globals.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD},
|
||||
search.index_name(kb.tenant_id), kb.id)
|
||||
|
||||
e, kb = KnowledgebaseService.get_by_id(kb.id)
|
||||
@ -226,8 +225,8 @@ def rm():
|
||||
return get_data_error_result(
|
||||
message="Database error (Knowledgebase removal)!")
|
||||
for kb in kbs:
|
||||
settings.docStoreConn.delete({"kb_id": kb.id}, search.index_name(kb.tenant_id), kb.id)
|
||||
settings.docStoreConn.deleteIdx(search.index_name(kb.tenant_id), kb.id)
|
||||
globals.docStoreConn.delete({"kb_id": kb.id}, search.index_name(kb.tenant_id), kb.id)
|
||||
globals.docStoreConn.deleteIdx(search.index_name(kb.tenant_id), kb.id)
|
||||
if hasattr(STORAGE_IMPL, 'remove_bucket'):
|
||||
STORAGE_IMPL.remove_bucket(kb.id)
|
||||
return get_json_result(data=True)
|
||||
@ -248,7 +247,7 @@ def list_tags(kb_id):
|
||||
tenants = UserTenantService.get_tenants_by_user_id(current_user.id)
|
||||
tags = []
|
||||
for tenant in tenants:
|
||||
tags += settings.retriever.all_tags(tenant["tenant_id"], [kb_id])
|
||||
tags += globals.retriever.all_tags(tenant["tenant_id"], [kb_id])
|
||||
return get_json_result(data=tags)
|
||||
|
||||
|
||||
@ -267,7 +266,7 @@ def list_tags_from_kbs():
|
||||
tenants = UserTenantService.get_tenants_by_user_id(current_user.id)
|
||||
tags = []
|
||||
for tenant in tenants:
|
||||
tags += settings.retriever.all_tags(tenant["tenant_id"], kb_ids)
|
||||
tags += globals.retriever.all_tags(tenant["tenant_id"], kb_ids)
|
||||
return get_json_result(data=tags)
|
||||
|
||||
|
||||
@ -284,7 +283,7 @@ def rm_tags(kb_id):
|
||||
e, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||
|
||||
for t in req["tags"]:
|
||||
settings.docStoreConn.update({"tag_kwd": t, "kb_id": [kb_id]},
|
||||
globals.docStoreConn.update({"tag_kwd": t, "kb_id": [kb_id]},
|
||||
{"remove": {"tag_kwd": t}},
|
||||
search.index_name(kb.tenant_id),
|
||||
kb_id)
|
||||
@ -303,7 +302,7 @@ def rename_tags(kb_id):
|
||||
)
|
||||
e, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||
|
||||
settings.docStoreConn.update({"tag_kwd": req["from_tag"], "kb_id": [kb_id]},
|
||||
globals.docStoreConn.update({"tag_kwd": req["from_tag"], "kb_id": [kb_id]},
|
||||
{"remove": {"tag_kwd": req["from_tag"].strip()}, "add": {"tag_kwd": req["to_tag"]}},
|
||||
search.index_name(kb.tenant_id),
|
||||
kb_id)
|
||||
@ -326,9 +325,9 @@ def knowledge_graph(kb_id):
|
||||
}
|
||||
|
||||
obj = {"graph": {}, "mind_map": {}}
|
||||
if not settings.docStoreConn.indexExist(search.index_name(kb.tenant_id), kb_id):
|
||||
if not globals.docStoreConn.indexExist(search.index_name(kb.tenant_id), kb_id):
|
||||
return get_json_result(data=obj)
|
||||
sres = settings.retriever.search(req, search.index_name(kb.tenant_id), [kb_id])
|
||||
sres = globals.retriever.search(req, search.index_name(kb.tenant_id), [kb_id])
|
||||
if not len(sres.ids):
|
||||
return get_json_result(data=obj)
|
||||
|
||||
@ -360,7 +359,7 @@ def delete_knowledge_graph(kb_id):
|
||||
code=RetCode.AUTHENTICATION_ERROR
|
||||
)
|
||||
_, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||
settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]}, search.index_name(kb.tenant_id), kb_id)
|
||||
globals.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]}, search.index_name(kb.tenant_id), kb_id)
|
||||
|
||||
return get_json_result(data=True)
|
||||
|
||||
@ -732,13 +731,13 @@ def delete_kb_task():
|
||||
task_id = kb.graphrag_task_id
|
||||
kb_task_finish_at = "graphrag_task_finish_at"
|
||||
cancel_task(task_id)
|
||||
settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]}, search.index_name(kb.tenant_id), kb_id)
|
||||
globals.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]}, search.index_name(kb.tenant_id), kb_id)
|
||||
case PipelineTaskType.RAPTOR:
|
||||
kb_task_id_field = "raptor_task_id"
|
||||
task_id = kb.raptor_task_id
|
||||
kb_task_finish_at = "raptor_task_finish_at"
|
||||
cancel_task(task_id)
|
||||
settings.docStoreConn.delete({"raptor_kwd": ["raptor"]}, search.index_name(kb.tenant_id), kb_id)
|
||||
globals.docStoreConn.delete({"raptor_kwd": ["raptor"]}, search.index_name(kb.tenant_id), kb_id)
|
||||
case PipelineTaskType.MINDMAP:
|
||||
kb_task_id_field = "mindmap_task_id"
|
||||
task_id = kb.mindmap_task_id
|
||||
@ -850,7 +849,7 @@ def check_embedding():
|
||||
tenant_id = kb.tenant_id
|
||||
|
||||
emb_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_id)
|
||||
samples = sample_random_chunks_with_vectors(settings.docStoreConn, tenant_id=tenant_id, kb_id=kb_id, n=n)
|
||||
samples = sample_random_chunks_with_vectors(globals.docStoreConn, tenant_id=tenant_id, kb_id=kb_id, n=n)
|
||||
|
||||
results, eff_sims = [], []
|
||||
for ck in samples:
|
||||
|
||||
@ -24,7 +24,7 @@ from api.utils.api_utils import server_error_response, get_data_error_result, va
|
||||
from common.constants import StatusEnum, LLMType
|
||||
from api.db.db_models import TenantLLM
|
||||
from api.utils.api_utils import get_json_result, get_allowed_llm_factories
|
||||
from common.base64_image import test_image
|
||||
from rag.utils.base64_image import test_image
|
||||
from rag.llm import EmbeddingModel, ChatModel, RerankModel, CvModel, TTSModel
|
||||
|
||||
|
||||
|
||||
@ -20,7 +20,6 @@ import os
|
||||
import json
|
||||
from flask import request
|
||||
from peewee import OperationalError
|
||||
from api import settings
|
||||
from api.db.db_models import File
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.file2document_service import File2DocumentService
|
||||
@ -49,6 +48,7 @@ from api.utils.validation_utils import (
|
||||
)
|
||||
from rag.nlp import search
|
||||
from rag.settings import PAGERANK_FLD
|
||||
from common import globals
|
||||
|
||||
|
||||
@manager.route("/datasets", methods=["POST"]) # noqa: F821
|
||||
@ -360,11 +360,11 @@ def update(tenant_id, dataset_id):
|
||||
return get_error_argument_result(message="'pagerank' can only be set when doc_engine is elasticsearch")
|
||||
|
||||
if req["pagerank"] > 0:
|
||||
settings.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]},
|
||||
globals.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]},
|
||||
search.index_name(kb.tenant_id), kb.id)
|
||||
else:
|
||||
# Elasticsearch requires PAGERANK_FLD be non-zero!
|
||||
settings.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD},
|
||||
globals.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD},
|
||||
search.index_name(kb.tenant_id), kb.id)
|
||||
|
||||
if not KnowledgebaseService.update_by_id(kb.id, req):
|
||||
@ -493,9 +493,9 @@ def knowledge_graph(tenant_id, dataset_id):
|
||||
}
|
||||
|
||||
obj = {"graph": {}, "mind_map": {}}
|
||||
if not settings.docStoreConn.indexExist(search.index_name(kb.tenant_id), dataset_id):
|
||||
if not globals.docStoreConn.indexExist(search.index_name(kb.tenant_id), dataset_id):
|
||||
return get_result(data=obj)
|
||||
sres = settings.retriever.search(req, search.index_name(kb.tenant_id), [dataset_id])
|
||||
sres = globals.retriever.search(req, search.index_name(kb.tenant_id), [dataset_id])
|
||||
if not len(sres.ids):
|
||||
return get_result(data=obj)
|
||||
|
||||
@ -528,7 +528,7 @@ def delete_knowledge_graph(tenant_id, dataset_id):
|
||||
code=RetCode.AUTHENTICATION_ERROR
|
||||
)
|
||||
_, kb = KnowledgebaseService.get_by_id(dataset_id)
|
||||
settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]},
|
||||
globals.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]},
|
||||
search.index_name(kb.tenant_id), dataset_id)
|
||||
|
||||
return get_result(data=True)
|
||||
|
||||
@ -25,6 +25,7 @@ from api.utils.api_utils import validate_request, build_error_result, apikey_req
|
||||
from rag.app.tag import label_question
|
||||
from api.db.services.dialog_service import meta_filter, convert_conditions
|
||||
from common.constants import RetCode, LLMType
|
||||
from common import globals
|
||||
|
||||
@manager.route('/dify/retrieval', methods=['POST']) # noqa: F821
|
||||
@apikey_required
|
||||
@ -137,7 +138,7 @@ def retrieval(tenant_id):
|
||||
# print("doc_ids", doc_ids)
|
||||
if not doc_ids and metadata_condition is not None:
|
||||
doc_ids = ['-999']
|
||||
ranks = settings.retriever.retrieval(
|
||||
ranks = globals.retriever.retrieval(
|
||||
question,
|
||||
embd_mdl,
|
||||
kb.tenant_id,
|
||||
|
||||
@ -44,6 +44,7 @@ from rag.prompts.generator import cross_languages, keyword_extraction
|
||||
from rag.utils.storage_factory import STORAGE_IMPL
|
||||
from common.string_utils import remove_redundant_spaces
|
||||
from common.constants import RetCode, LLMType, ParserType, TaskStatus, FileSource
|
||||
from common import globals
|
||||
|
||||
MAXIMUM_OF_UPLOADING_FILES = 256
|
||||
|
||||
@ -307,7 +308,7 @@ def update_doc(tenant_id, dataset_id, document_id):
|
||||
)
|
||||
if not e:
|
||||
return get_error_data_result(message="Document not found!")
|
||||
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), dataset_id)
|
||||
globals.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), dataset_id)
|
||||
|
||||
if "enabled" in req:
|
||||
status = int(req["enabled"])
|
||||
@ -316,7 +317,7 @@ def update_doc(tenant_id, dataset_id, document_id):
|
||||
if not DocumentService.update_by_id(doc.id, {"status": str(status)}):
|
||||
return get_error_data_result(message="Database error (Document update)!")
|
||||
|
||||
settings.docStoreConn.update({"doc_id": doc.id}, {"available_int": status}, search.index_name(kb.tenant_id), doc.kb_id)
|
||||
globals.docStoreConn.update({"doc_id": doc.id}, {"available_int": status}, search.index_name(kb.tenant_id), doc.kb_id)
|
||||
return get_result(data=True)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
@ -755,7 +756,7 @@ def parse(tenant_id, dataset_id):
|
||||
return get_error_data_result("Can't parse document that is currently being processed")
|
||||
info = {"run": "1", "progress": 0, "progress_msg": "", "chunk_num": 0, "token_num": 0}
|
||||
DocumentService.update_by_id(id, info)
|
||||
settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), dataset_id)
|
||||
globals.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), dataset_id)
|
||||
TaskService.filter_delete([Task.doc_id == id])
|
||||
e, doc = DocumentService.get_by_id(id)
|
||||
doc = doc.to_dict()
|
||||
@ -835,7 +836,7 @@ def stop_parsing(tenant_id, dataset_id):
|
||||
return get_error_data_result("Can't stop parsing document with progress at 0 or 1")
|
||||
info = {"run": "2", "progress": 0, "chunk_num": 0}
|
||||
DocumentService.update_by_id(id, info)
|
||||
settings.docStoreConn.delete({"doc_id": doc[0].id}, search.index_name(tenant_id), dataset_id)
|
||||
globals.docStoreConn.delete({"doc_id": doc[0].id}, search.index_name(tenant_id), dataset_id)
|
||||
success_count += 1
|
||||
if duplicate_messages:
|
||||
if success_count > 0:
|
||||
@ -968,7 +969,7 @@ def list_chunks(tenant_id, dataset_id, document_id):
|
||||
|
||||
res = {"total": 0, "chunks": [], "doc": renamed_doc}
|
||||
if req.get("id"):
|
||||
chunk = settings.docStoreConn.get(req.get("id"), search.index_name(tenant_id), [dataset_id])
|
||||
chunk = globals.docStoreConn.get(req.get("id"), search.index_name(tenant_id), [dataset_id])
|
||||
if not chunk:
|
||||
return get_result(message=f"Chunk not found: {dataset_id}/{req.get('id')}", code=RetCode.NOT_FOUND)
|
||||
k = []
|
||||
@ -995,8 +996,8 @@ def list_chunks(tenant_id, dataset_id, document_id):
|
||||
res["chunks"].append(final_chunk)
|
||||
_ = Chunk(**final_chunk)
|
||||
|
||||
elif settings.docStoreConn.indexExist(search.index_name(tenant_id), dataset_id):
|
||||
sres = settings.retriever.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None, highlight=True)
|
||||
elif globals.docStoreConn.indexExist(search.index_name(tenant_id), dataset_id):
|
||||
sres = globals.retriever.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None, highlight=True)
|
||||
res["total"] = sres.total
|
||||
for id in sres.ids:
|
||||
d = {
|
||||
@ -1120,7 +1121,7 @@ def add_chunk(tenant_id, dataset_id, document_id):
|
||||
v, c = embd_mdl.encode([doc.name, req["content"] if not d["question_kwd"] else "\n".join(d["question_kwd"])])
|
||||
v = 0.1 * v[0] + 0.9 * v[1]
|
||||
d["q_%d_vec" % len(v)] = v.tolist()
|
||||
settings.docStoreConn.insert([d], search.index_name(tenant_id), dataset_id)
|
||||
globals.docStoreConn.insert([d], search.index_name(tenant_id), dataset_id)
|
||||
|
||||
DocumentService.increment_chunk_num(doc.id, doc.kb_id, c, 1, 0)
|
||||
# rename keys
|
||||
@ -1201,7 +1202,7 @@ def rm_chunk(tenant_id, dataset_id, document_id):
|
||||
if "chunk_ids" in req:
|
||||
unique_chunk_ids, duplicate_messages = check_duplicate_ids(req["chunk_ids"], "chunk")
|
||||
condition["id"] = unique_chunk_ids
|
||||
chunk_number = settings.docStoreConn.delete(condition, search.index_name(tenant_id), dataset_id)
|
||||
chunk_number = globals.docStoreConn.delete(condition, search.index_name(tenant_id), dataset_id)
|
||||
if chunk_number != 0:
|
||||
DocumentService.decrement_chunk_num(document_id, dataset_id, 1, chunk_number, 0)
|
||||
if "chunk_ids" in req and chunk_number != len(unique_chunk_ids):
|
||||
@ -1273,7 +1274,7 @@ def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
|
||||
schema:
|
||||
type: object
|
||||
"""
|
||||
chunk = settings.docStoreConn.get(chunk_id, search.index_name(tenant_id), [dataset_id])
|
||||
chunk = globals.docStoreConn.get(chunk_id, search.index_name(tenant_id), [dataset_id])
|
||||
if chunk is None:
|
||||
return get_error_data_result(f"Can't find this chunk {chunk_id}")
|
||||
if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
|
||||
@ -1318,7 +1319,7 @@ def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
|
||||
v, c = embd_mdl.encode([doc.name, d["content_with_weight"] if not d.get("question_kwd") else "\n".join(d["question_kwd"])])
|
||||
v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
|
||||
d["q_%d_vec" % len(v)] = v.tolist()
|
||||
settings.docStoreConn.update({"id": chunk_id}, d, search.index_name(tenant_id), dataset_id)
|
||||
globals.docStoreConn.update({"id": chunk_id}, d, search.index_name(tenant_id), dataset_id)
|
||||
return get_result()
|
||||
|
||||
|
||||
@ -1464,7 +1465,7 @@ def retrieval_test(tenant_id):
|
||||
chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
|
||||
question += keyword_extraction(chat_mdl, question)
|
||||
|
||||
ranks = settings.retriever.retrieval(
|
||||
ranks = globals.retriever.retrieval(
|
||||
question,
|
||||
embd_mdl,
|
||||
tenant_ids,
|
||||
|
||||
@ -41,6 +41,7 @@ from rag.app.tag import label_question
|
||||
from rag.prompts.template import load_prompt
|
||||
from rag.prompts.generator import cross_languages, gen_meta_filter, keyword_extraction, chunks_format
|
||||
from common.constants import RetCode, LLMType, StatusEnum
|
||||
from common import globals
|
||||
|
||||
@manager.route("/chats/<chat_id>/sessions", methods=["POST"]) # noqa: F821
|
||||
@token_required
|
||||
@ -1015,7 +1016,7 @@ def retrieval_test_embedded():
|
||||
question += keyword_extraction(chat_mdl, question)
|
||||
|
||||
labels = label_question(question, [kb])
|
||||
ranks = settings.retriever.retrieval(
|
||||
ranks = globals.retriever.retrieval(
|
||||
question, embd_mdl, tenant_ids, kb_ids, page, size, similarity_threshold, vector_similarity_weight, top,
|
||||
doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"), rank_feature=labels
|
||||
)
|
||||
|
||||
@ -38,6 +38,7 @@ from timeit import default_timer as timer
|
||||
from rag.utils.redis_conn import REDIS_CONN
|
||||
from flask import jsonify
|
||||
from api.utils.health_utils import run_health_checks
|
||||
from common import globals
|
||||
|
||||
|
||||
@manager.route("/version", methods=["GET"]) # noqa: F821
|
||||
@ -100,7 +101,7 @@ def status():
|
||||
res = {}
|
||||
st = timer()
|
||||
try:
|
||||
res["doc_engine"] = settings.docStoreConn.health()
|
||||
res["doc_engine"] = globals.docStoreConn.health()
|
||||
res["doc_engine"]["elapsed"] = "{:.1f}".format((timer() - st) * 1000.0)
|
||||
except Exception as e:
|
||||
res["doc_engine"] = {
|
||||
|
||||
@ -58,6 +58,7 @@ from api.utils.web_utils import (
|
||||
hash_code,
|
||||
captcha_key,
|
||||
)
|
||||
from common import globals
|
||||
|
||||
|
||||
@manager.route("/login", methods=["POST", "GET"]) # noqa: F821
|
||||
@ -623,7 +624,7 @@ def user_register(user_id, user):
|
||||
"id": user_id,
|
||||
"name": user["nickname"] + "‘s Kingdom",
|
||||
"llm_id": settings.CHAT_MDL,
|
||||
"embd_id": settings.EMBEDDING_MDL,
|
||||
"embd_id": globals.EMBEDDING_MDL,
|
||||
"asr_id": settings.ASR_MDL,
|
||||
"parser_ids": settings.PARSERS,
|
||||
"img2txt_id": settings.IMAGE2TEXT_MDL,
|
||||
|
||||
@ -32,6 +32,7 @@ from api.db.services.user_service import TenantService, UserTenantService
|
||||
from api import settings
|
||||
from common.constants import LLMType
|
||||
from common.file_utils import get_project_base_directory
|
||||
from common import globals
|
||||
from api.common.base64 import encode_to_base64
|
||||
|
||||
|
||||
@ -49,7 +50,7 @@ def init_superuser():
|
||||
"id": user_info["id"],
|
||||
"name": user_info["nickname"] + "‘s Kingdom",
|
||||
"llm_id": settings.CHAT_MDL,
|
||||
"embd_id": settings.EMBEDDING_MDL,
|
||||
"embd_id": globals.EMBEDDING_MDL,
|
||||
"asr_id": settings.ASR_MDL,
|
||||
"parser_ids": settings.PARSERS,
|
||||
"img2txt_id": settings.IMAGE2TEXT_MDL
|
||||
|
||||
@ -38,6 +38,7 @@ from api.db.services.user_service import TenantService, UserService, UserTenantS
|
||||
from rag.utils.storage_factory import STORAGE_IMPL
|
||||
from rag.nlp import search
|
||||
from common.constants import ActiveEnum
|
||||
from common import globals
|
||||
|
||||
def create_new_user(user_info: dict) -> dict:
|
||||
"""
|
||||
@ -63,7 +64,7 @@ def create_new_user(user_info: dict) -> dict:
|
||||
"id": user_id,
|
||||
"name": user_info["nickname"] + "‘s Kingdom",
|
||||
"llm_id": settings.CHAT_MDL,
|
||||
"embd_id": settings.EMBEDDING_MDL,
|
||||
"embd_id": globals.EMBEDDING_MDL,
|
||||
"asr_id": settings.ASR_MDL,
|
||||
"parser_ids": settings.PARSERS,
|
||||
"img2txt_id": settings.IMAGE2TEXT_MDL,
|
||||
@ -179,7 +180,7 @@ def delete_user_data(user_id: str) -> dict:
|
||||
)
|
||||
done_msg += f"- Deleted {file2doc_delete_res} document-file relation records.\n"
|
||||
# step1.1.3 delete chunk in es
|
||||
r = settings.docStoreConn.delete({"kb_id": kb_ids},
|
||||
r = globals.docStoreConn.delete({"kb_id": kb_ids},
|
||||
search.index_name(tenant_id), kb_ids)
|
||||
done_msg += f"- Deleted {r} chunk records.\n"
|
||||
kb_delete_res = KnowledgebaseService.delete_by_ids(kb_ids)
|
||||
@ -237,7 +238,7 @@ def delete_user_data(user_id: str) -> dict:
|
||||
kb_doc_info = {}
|
||||
for _tenant_id, kb_doc in kb_grouped_doc.items():
|
||||
for _kb_id, docs in kb_doc.items():
|
||||
chunk_delete_res += settings.docStoreConn.delete(
|
||||
chunk_delete_res += globals.docStoreConn.delete(
|
||||
{"doc_id": [d["id"] for d in docs]},
|
||||
search.index_name(_tenant_id), _kb_id
|
||||
)
|
||||
|
||||
@ -111,12 +111,14 @@ class SyncLogsService(CommonService):
|
||||
return list(query.dicts())
|
||||
|
||||
@classmethod
|
||||
def start(cls, id):
|
||||
def start(cls, id, connector_id):
|
||||
cls.update_by_id(id, {"status": TaskStatus.RUNNING, "time_started": datetime.now().strftime('%Y-%m-%d %H:%M:%S') })
|
||||
ConnectorService.update_by_id(connector_id, {"status": TaskStatus.RUNNING})
|
||||
|
||||
@classmethod
|
||||
def done(cls, id):
|
||||
def done(cls, id, connector_id):
|
||||
cls.update_by_id(id, {"status": TaskStatus.DONE})
|
||||
ConnectorService.update_by_id(connector_id, {"status": TaskStatus.DONE})
|
||||
|
||||
@classmethod
|
||||
def schedule(cls, connector_id, kb_id, poll_range_start=None, reindex=False, total_docs_indexed=0):
|
||||
@ -126,6 +128,7 @@ class SyncLogsService(CommonService):
|
||||
logging.warning(f"{kb_id}--{connector_id} has already had a scheduling sync task which is abnormal.")
|
||||
return None
|
||||
reindex = "1" if reindex else "0"
|
||||
ConnectorService.update_by_id(connector_id, {"status": TaskStatus.SCHEDUL})
|
||||
return cls.save(**{
|
||||
"id": get_uuid(),
|
||||
"kb_id": kb_id, "status": TaskStatus.SCHEDULE, "connector_id": connector_id,
|
||||
@ -142,6 +145,7 @@ class SyncLogsService(CommonService):
|
||||
full_exception_trace=cls.model.full_exception_trace + str(e)
|
||||
) \
|
||||
.where(cls.model.id == task.id).execute()
|
||||
ConnectorService.update_by_id(connector_id, {"status": TaskStatus.SCHEDUL})
|
||||
|
||||
@classmethod
|
||||
def increase_docs(cls, id, min_update, max_update, doc_num, err_msg="", error_count=0):
|
||||
|
||||
@@ -44,6 +44,7 @@ from rag.prompts.generator import chunks_format, citation_prompt, cross_language
 from common.token_utils import num_tokens_from_string
 from rag.utils.tavily_conn import Tavily
 from common.string_utils import remove_redundant_spaces
+from common import globals
 
 
 class DialogService(CommonService):
@@ -293,12 +294,13 @@ def meta_filter(metas: dict, filters: list[dict]):
     def filter_out(v2docs, operator, value):
         ids = []
         for input, docids in v2docs.items():
-            try:
-                input = float(input)
-                value = float(value)
-            except Exception:
-                input = str(input)
-                value = str(value)
+            if operator in ["=", "≠", ">", "<", "≥", "≤"]:
+                try:
+                    input = float(input)
+                    value = float(value)
+                except Exception:
+                    input = str(input)
+                    value = str(value)
 
             for conds in [
                 (operator == "contains", str(value).lower() in str(input).lower()),
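Coercing both sides to float only for the comparison operators matters because lexicographic and numeric ordering disagree; a small standalone illustration (not RAGFlow code):

```python
# String comparison orders character by character, so "9" > "10"; numbers do not.
pairs = [("10", "9"), ("2", "10")]
for a, b in pairs:
    as_str = a > b
    try:
        as_num = float(a) > float(b)
    except ValueError:
        as_num = None
    print(f"{a!r} > {b!r}: as strings {as_str}, as numbers {as_num}")
```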
@ -371,7 +373,7 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||
chat_mdl.bind_tools(toolcall_session, tools)
|
||||
bind_models_ts = timer()
|
||||
|
||||
retriever = settings.retriever
|
||||
retriever = globals.retriever
|
||||
questions = [m["content"] for m in messages if m["role"] == "user"][-3:]
|
||||
attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else []
|
||||
if "doc_ids" in messages[-1]:
|
||||
@ -663,7 +665,7 @@ Please write the SQL, only SQL, without any other explanations or text.
|
||||
|
||||
logging.debug(f"{question} get SQL(refined): {sql}")
|
||||
tried_times += 1
|
||||
return settings.retriever.sql_retrieval(sql, format="json"), sql
|
||||
return globals.retriever.sql_retrieval(sql, format="json"), sql
|
||||
|
||||
tbl, sql = get_table()
|
||||
if tbl is None:
|
||||
@ -757,7 +759,7 @@ def ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
|
||||
embedding_list = list(set([kb.embd_id for kb in kbs]))
|
||||
|
||||
is_knowledge_graph = all([kb.parser_id == ParserType.KG for kb in kbs])
|
||||
retriever = settings.retriever if not is_knowledge_graph else settings.kg_retriever
|
||||
retriever = globals.retriever if not is_knowledge_graph else settings.kg_retriever
|
||||
|
||||
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embedding_list[0])
|
||||
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, chat_llm_name)
|
||||
@ -853,7 +855,7 @@ def gen_mindmap(question, kb_ids, tenant_id, search_config={}):
|
||||
if not doc_ids:
|
||||
doc_ids = None
|
||||
|
||||
ranks = settings.retriever.retrieval(
|
||||
ranks = globals.retriever.retrieval(
|
||||
question=question,
|
||||
embd_mdl=embd_mdl,
|
||||
tenant_ids=tenant_ids,
|
||||
|
||||
@ -26,7 +26,6 @@ import trio
|
||||
import xxhash
|
||||
from peewee import fn, Case, JOIN
|
||||
|
||||
from api import settings
|
||||
from api.constants import IMG_BASE64_PREFIX, FILE_NAME_LEN_LIMIT
|
||||
from api.db import FileType, UserTenantRole, CanvasCategory
|
||||
from api.db.db_models import DB, Document, Knowledgebase, Task, Tenant, UserTenant, File2Document, File, UserCanvas, \
|
||||
@ -42,7 +41,7 @@ from rag.settings import get_svr_queue_name, SVR_CONSUMER_GROUP_NAME
|
||||
from rag.utils.redis_conn import REDIS_CONN
|
||||
from rag.utils.storage_factory import STORAGE_IMPL
|
||||
from rag.utils.doc_store_conn import OrderByExpr
|
||||
|
||||
from common import globals
|
||||
|
||||
class DocumentService(CommonService):
|
||||
model = Document
|
||||
@ -309,10 +308,10 @@ class DocumentService(CommonService):
|
||||
page_size = 1000
|
||||
all_chunk_ids = []
|
||||
while True:
|
||||
chunks = settings.docStoreConn.search(["img_id"], [], {"doc_id": doc.id}, [], OrderByExpr(),
|
||||
chunks = globals.docStoreConn.search(["img_id"], [], {"doc_id": doc.id}, [], OrderByExpr(),
|
||||
page * page_size, page_size, search.index_name(tenant_id),
|
||||
[doc.kb_id])
|
||||
chunk_ids = settings.docStoreConn.getChunkIds(chunks)
|
||||
chunk_ids = globals.docStoreConn.getChunkIds(chunks)
|
||||
if not chunk_ids:
|
||||
break
|
||||
all_chunk_ids.extend(chunk_ids)
|
||||
@ -323,19 +322,19 @@ class DocumentService(CommonService):
|
||||
if doc.thumbnail and not doc.thumbnail.startswith(IMG_BASE64_PREFIX):
|
||||
if STORAGE_IMPL.obj_exist(doc.kb_id, doc.thumbnail):
|
||||
STORAGE_IMPL.rm(doc.kb_id, doc.thumbnail)
|
||||
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
|
||||
globals.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
|
||||
|
||||
graph_source = settings.docStoreConn.getFields(
|
||||
settings.docStoreConn.search(["source_id"], [], {"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]}, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [doc.kb_id]), ["source_id"]
|
||||
graph_source = globals.docStoreConn.getFields(
|
||||
globals.docStoreConn.search(["source_id"], [], {"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]}, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [doc.kb_id]), ["source_id"]
|
||||
)
|
||||
if len(graph_source) > 0 and doc.id in list(graph_source.values())[0]["source_id"]:
|
||||
settings.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "source_id": doc.id},
|
||||
globals.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "source_id": doc.id},
|
||||
{"remove": {"source_id": doc.id}},
|
||||
search.index_name(tenant_id), doc.kb_id)
|
||||
settings.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]},
|
||||
globals.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]},
|
||||
{"removed_kwd": "Y"},
|
||||
search.index_name(tenant_id), doc.kb_id)
|
||||
settings.docStoreConn.delete({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "must_not": {"exists": "source_id"}},
|
||||
globals.docStoreConn.delete({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "must_not": {"exists": "source_id"}},
|
||||
search.index_name(tenant_id), doc.kb_id)
|
||||
except Exception:
|
||||
pass
|
||||
@ -996,10 +995,10 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
||||
d["q_%d_vec" % len(v)] = v
|
||||
for b in range(0, len(cks), es_bulk_size):
|
||||
if try_create_idx:
|
||||
if not settings.docStoreConn.indexExist(idxnm, kb_id):
|
||||
settings.docStoreConn.createIdx(idxnm, kb_id, len(vects[0]))
|
||||
if not globals.docStoreConn.indexExist(idxnm, kb_id):
|
||||
globals.docStoreConn.createIdx(idxnm, kb_id, len(vects[0]))
|
||||
try_create_idx = False
|
||||
settings.docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id)
|
||||
globals.docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id)
|
||||
|
||||
DocumentService.increment_chunk_num(
|
||||
doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0)
|
||||
|
||||
@ -30,13 +30,14 @@ class LLMService(CommonService):
|
||||
|
||||
def get_init_tenant_llm(user_id):
|
||||
from api import settings
|
||||
from common import globals
|
||||
tenant_llm = []
|
||||
|
||||
seen = set()
|
||||
factory_configs = []
|
||||
for factory_config in [
|
||||
settings.CHAT_CFG,
|
||||
settings.EMBEDDING_CFG,
|
||||
globals.EMBEDDING_CFG,
|
||||
settings.ASR_CFG,
|
||||
settings.IMAGE2TEXT_CFG,
|
||||
settings.RERANK_CFG,
|
||||
|
||||
@ -34,7 +34,7 @@ from deepdoc.parser.excel_parser import RAGFlowExcelParser
|
||||
from rag.settings import get_svr_queue_name
|
||||
from rag.utils.storage_factory import STORAGE_IMPL
|
||||
from rag.utils.redis_conn import REDIS_CONN
|
||||
from api import settings
|
||||
from common import globals
|
||||
from rag.nlp import search
|
||||
|
||||
CANVAS_DEBUG_DOC_ID = "dataflow_x"
|
||||
@ -418,7 +418,7 @@ def queue_tasks(doc: dict, bucket: str, name: str, priority: int):
|
||||
if pre_task["chunk_ids"]:
|
||||
pre_chunk_ids.extend(pre_task["chunk_ids"].split())
|
||||
if pre_chunk_ids:
|
||||
settings.docStoreConn.delete({"id": pre_chunk_ids}, search.index_name(chunking_config["tenant_id"]),
|
||||
globals.docStoreConn.delete({"id": pre_chunk_ids}, search.index_name(chunking_config["tenant_id"]),
|
||||
chunking_config["kb_id"])
|
||||
DocumentService.update_by_id(doc["id"], {"chunk_num": ck_num})
|
||||
|
||||
|
||||
@ -17,6 +17,7 @@ import os
|
||||
import logging
|
||||
from langfuse import Langfuse
|
||||
from api import settings
|
||||
from common import globals
|
||||
from common.constants import LLMType
|
||||
from api.db.db_models import DB, LLMFactories, TenantLLM
|
||||
from api.db.services.common_service import CommonService
|
||||
@ -114,7 +115,7 @@ class TenantLLMService(CommonService):
|
||||
if model_config:
|
||||
model_config = model_config.to_dict()
|
||||
elif llm_type == LLMType.EMBEDDING and fid == 'Builtin' and "tei-" in os.getenv("COMPOSE_PROFILES", "") and mdlnm == os.getenv('TEI_MODEL', ''):
|
||||
embedding_cfg = settings.EMBEDDING_CFG
|
||||
embedding_cfg = globals.EMBEDDING_CFG
|
||||
model_config = {"llm_factory": 'Builtin', "api_key": embedding_cfg["api_key"], "llm_name": mdlnm, "api_base": embedding_cfg["base_url"]}
|
||||
else:
|
||||
raise LookupError(f"Model({mdlnm}@{fid}) not authorized")
|
||||
|
||||
@ -27,7 +27,7 @@ from api.db.services.common_service import CommonService
|
||||
from common.misc_utils import get_uuid
|
||||
from common.time_utils import current_timestamp, datetime_format
|
||||
from common.constants import StatusEnum
|
||||
from rag.settings import MINIO
|
||||
from common import globals
|
||||
|
||||
|
||||
class UserService(CommonService):
|
||||
@ -221,7 +221,7 @@ class TenantService(CommonService):
|
||||
@DB.connection_context()
|
||||
def user_gateway(cls, tenant_id):
|
||||
hash_obj = hashlib.sha256(tenant_id.encode("utf-8"))
|
||||
return int(hash_obj.hexdigest(), 16)%len(MINIO)
|
||||
return int(hash_obj.hexdigest(), 16)%len(globals.MINIO)
|
||||
|
||||
|
||||
class UserTenantService(CommonService):
|
||||
|
||||
@@ -25,18 +25,19 @@ import rag.utils.opensearch_conn
 from api.constants import RAG_FLOW_SERVICE_NAME
 from common.config_utils import decrypt_database_config, get_base_config
 from common.file_utils import get_project_base_directory
+from common import globals
 from rag.nlp import search
 
 LLM = None
 LLM_FACTORY = None
 LLM_BASE_URL = None
 CHAT_MDL = ""
-EMBEDDING_MDL = ""
+# EMBEDDING_MDL = "" has been moved to common/globals.py
 RERANK_MDL = ""
 ASR_MDL = ""
 IMAGE2TEXT_MDL = ""
 CHAT_CFG = ""
-EMBEDDING_CFG = ""
+# EMBEDDING_CFG = "" has been moved to common/globals.py
 RERANK_CFG = ""
 ASR_CFG = ""
 IMAGE2TEXT_CFG = ""
@@ -60,10 +61,10 @@ HTTP_APP_KEY = None
 GITHUB_OAUTH = None
 FEISHU_OAUTH = None
 OAUTH_CONFIG = None
-DOC_ENGINE = None
-docStoreConn = None
+# DOC_ENGINE = None has been moved to common/globals.py
+# docStoreConn = None has been moved to common/globals.py
 
-retriever = None
+#retriever = None has been moved to common/globals.py
 kg_retriever = None
 
 # user registration switch
@@ -124,8 +125,8 @@ def init_settings():
     except Exception:
         FACTORY_LLM_INFOS = []
 
-    global CHAT_MDL, EMBEDDING_MDL, RERANK_MDL, ASR_MDL, IMAGE2TEXT_MDL
-    global CHAT_CFG, EMBEDDING_CFG, RERANK_CFG, ASR_CFG, IMAGE2TEXT_CFG
+    global CHAT_MDL, RERANK_MDL, ASR_MDL, IMAGE2TEXT_MDL
+    global CHAT_CFG, RERANK_CFG, ASR_CFG, IMAGE2TEXT_CFG
 
     global API_KEY, PARSERS, HOST_IP, HOST_PORT, SECRET_KEY
     API_KEY = LLM.get("api_key")
@@ -134,19 +135,19 @@ def init_settings():
     )
 
     chat_entry = _parse_model_entry(LLM_DEFAULT_MODELS.get("chat_model", CHAT_MDL))
-    embedding_entry = _parse_model_entry(LLM_DEFAULT_MODELS.get("embedding_model", EMBEDDING_MDL))
+    embedding_entry = _parse_model_entry(LLM_DEFAULT_MODELS.get("embedding_model", globals.EMBEDDING_MDL))
     rerank_entry = _parse_model_entry(LLM_DEFAULT_MODELS.get("rerank_model", RERANK_MDL))
     asr_entry = _parse_model_entry(LLM_DEFAULT_MODELS.get("asr_model", ASR_MDL))
     image2text_entry = _parse_model_entry(LLM_DEFAULT_MODELS.get("image2text_model", IMAGE2TEXT_MDL))
 
     CHAT_CFG = _resolve_per_model_config(chat_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL)
-    EMBEDDING_CFG = _resolve_per_model_config(embedding_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL)
+    globals.EMBEDDING_CFG = _resolve_per_model_config(embedding_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL)
     RERANK_CFG = _resolve_per_model_config(rerank_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL)
     ASR_CFG = _resolve_per_model_config(asr_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL)
     IMAGE2TEXT_CFG = _resolve_per_model_config(image2text_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL)
 
     CHAT_MDL = CHAT_CFG.get("model", "") or ""
-    EMBEDDING_MDL = os.getenv("TEI_MODEL", "BAAI/bge-small-en-v1.5") if "tei-" in os.getenv("COMPOSE_PROFILES", "") else ""
+    globals.EMBEDDING_MDL = os.getenv("TEI_MODEL", "BAAI/bge-small-en-v1.5") if "tei-" in os.getenv("COMPOSE_PROFILES", "") else ""
     RERANK_MDL = RERANK_CFG.get("model", "") or ""
     ASR_MDL = ASR_CFG.get("model", "") or ""
     IMAGE2TEXT_MDL = IMAGE2TEXT_CFG.get("model", "") or ""
@@ -168,23 +169,23 @@ def init_settings():
 
     OAUTH_CONFIG = get_base_config("oauth", {})
 
-    global DOC_ENGINE, docStoreConn, retriever, kg_retriever
-    DOC_ENGINE = os.environ.get("DOC_ENGINE", "elasticsearch")
-    # DOC_ENGINE = os.environ.get('DOC_ENGINE', "opensearch")
-    lower_case_doc_engine = DOC_ENGINE.lower()
+    global kg_retriever
+    globals.DOC_ENGINE = os.environ.get("DOC_ENGINE", "elasticsearch")
+    # globals.DOC_ENGINE = os.environ.get('DOC_ENGINE', "opensearch")
+    lower_case_doc_engine = globals.DOC_ENGINE.lower()
     if lower_case_doc_engine == "elasticsearch":
-        docStoreConn = rag.utils.es_conn.ESConnection()
+        globals.docStoreConn = rag.utils.es_conn.ESConnection()
     elif lower_case_doc_engine == "infinity":
-        docStoreConn = rag.utils.infinity_conn.InfinityConnection()
+        globals.docStoreConn = rag.utils.infinity_conn.InfinityConnection()
    elif lower_case_doc_engine == "opensearch":
-        docStoreConn = rag.utils.opensearch_conn.OSConnection()
+        globals.docStoreConn = rag.utils.opensearch_conn.OSConnection()
     else:
-        raise Exception(f"Not supported doc engine: {DOC_ENGINE}")
+        raise Exception(f"Not supported doc engine: {globals.DOC_ENGINE}")
 
-    retriever = search.Dealer(docStoreConn)
+    globals.retriever = search.Dealer(globals.docStoreConn)
     from graphrag import search as kg_search
 
-    kg_retriever = kg_search.KGSearch(docStoreConn)
+    kg_retriever = kg_search.KGSearch(globals.docStoreConn)
 
     if int(os.environ.get("SANDBOX_ENABLED", "0")):
         global SANDBOX_HOST
@@ -626,7 +626,7 @@ async def is_strong_enough(chat_model, embedding_model):
 
 
 def get_allowed_llm_factories() -> list:
-    factories = LLMFactoriesService.get_all()
+    factories = list(LLMFactoriesService.get_all())
     if settings.ALLOWED_LLM_FACTORIES is None:
         return factories
 
@ -19,11 +19,11 @@ from timeit import default_timer as timer
|
||||
|
||||
from api import settings
|
||||
from api.db.db_models import DB
|
||||
from rag import settings as rag_settings
|
||||
from rag.utils.redis_conn import REDIS_CONN
|
||||
from rag.utils.storage_factory import STORAGE_IMPL
|
||||
from rag.utils.es_conn import ESConnection
|
||||
from rag.utils.infinity_conn import InfinityConnection
|
||||
from common import globals
|
||||
|
||||
|
||||
def _ok_nok(ok: bool) -> str:
|
||||
@ -52,7 +52,7 @@ def check_redis() -> tuple[bool, dict]:
|
||||
def check_doc_engine() -> tuple[bool, dict]:
|
||||
st = timer()
|
||||
try:
|
||||
meta = settings.docStoreConn.health()
|
||||
meta = globals.docStoreConn.health()
|
||||
# treat any successful call as ok
|
||||
return True, {"elapsed": f"{(timer() - st) * 1000.0:.1f}", **(meta or {})}
|
||||
except Exception as e:
|
||||
@ -120,7 +120,7 @@ def get_mysql_status():
|
||||
def check_minio_alive():
|
||||
start_time = timer()
|
||||
try:
|
||||
response = requests.get(f'http://{rag_settings.MINIO["host"]}/minio/health/live')
|
||||
response = requests.get(f'http://{globals.MINIO["host"]}/minio/health/live')
|
||||
if response.status_code == 200:
|
||||
return {"status": "alive", "message": f"Confirm elapsed: {(timer() - start_time) * 1000.0:.1f} ms."}
|
||||
else:
|
||||
|
||||
64	common/globals.py	Normal file
@ -0,0 +1,64 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
from common.config_utils import get_base_config, decrypt_database_config

EMBEDDING_MDL = ""

EMBEDDING_CFG = ""

DOC_ENGINE = os.getenv('DOC_ENGINE', 'elasticsearch')

docStoreConn = None

retriever = None

# move from rag.settings
ES = {}
INFINITY = {}
AZURE = {}
S3 = {}
MINIO = {}
OSS = {}
OS = {}
REDIS = {}

STORAGE_IMPL_TYPE = os.getenv('STORAGE_IMPL', 'MINIO')

# Initialize the selected configuration data based on environment variables to solve the problem of initialization errors due to lack of configuration
if DOC_ENGINE == 'elasticsearch':
    ES = get_base_config("es", {})
elif DOC_ENGINE == 'opensearch':
    OS = get_base_config("os", {})
elif DOC_ENGINE == 'infinity':
    INFINITY = get_base_config("infinity", {"uri": "infinity:23817"})

if STORAGE_IMPL_TYPE in ['AZURE_SPN', 'AZURE_SAS']:
    AZURE = get_base_config("azure", {})
elif STORAGE_IMPL_TYPE == 'AWS_S3':
    S3 = get_base_config("s3", {})
elif STORAGE_IMPL_TYPE == 'MINIO':
    MINIO = decrypt_database_config(name="minio")
elif STORAGE_IMPL_TYPE == 'OSS':
    OSS = get_base_config("oss", {})

try:
    REDIS = decrypt_database_config(name="redis")
except Exception:
    try:
        REDIS = get_base_config("redis", {})
    except Exception:
        REDIS = {}
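Note: common/globals.py only declares placeholders (docStoreConn = None, retriever = None); they must be populated once at startup before any consumer touches them. A minimal sketch of how a caller might wire this up, assuming an Elasticsearch connection class and the search.Dealer retriever shown above; the init_doc_store() helper name is hypothetical, not the project's actual entry point.

# Hypothetical startup wiring; names other than globals.docStoreConn /
# globals.retriever are illustrative only.
from common import globals
from rag.nlp import search
from rag.utils.es_conn import ESConnection


def init_doc_store():
    # Pick the engine configured in common.globals (elasticsearch by default).
    if globals.DOC_ENGINE == "elasticsearch":
        globals.docStoreConn = ESConnection()
    else:
        raise NotImplementedError(f"unsupported DOC_ENGINE: {globals.DOC_ENGINE}")
    # Every module then shares the same retriever instance.
    globals.retriever = search.Dealer(globals.docStoreConn)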
55	common/signal_utils.py	Normal file
@ -0,0 +1,55 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os
import sys
from datetime import datetime
import logging
import tracemalloc
from common.log_utils import get_project_base_directory


# SIGUSR1 handler: start tracemalloc and take snapshot
def start_tracemalloc_and_snapshot(signum, frame):
    if not tracemalloc.is_tracing():
        logging.info("start tracemalloc")
        tracemalloc.start()
    else:
        logging.info("tracemalloc is already running")

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    snapshot_file = f"snapshot_{timestamp}.trace"
    snapshot_file = os.path.abspath(os.path.join(get_project_base_directory(), "logs", f"{os.getpid()}_snapshot_{timestamp}.trace"))

    snapshot = tracemalloc.take_snapshot()
    snapshot.dump(snapshot_file)
    current, peak = tracemalloc.get_traced_memory()
    if sys.platform == "win32":
        import psutil
        process = psutil.Process()
        max_rss = process.memory_info().rss / 1024
    else:
        import resource
        max_rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    logging.info(f"taken snapshot {snapshot_file}. max RSS={max_rss / 1000:.2f} MB, current memory usage: {current / 10**6:.2f} MB, Peak memory usage: {peak / 10**6:.2f} MB")


# SIGUSR2 handler: stop tracemalloc
def stop_tracemalloc(signum, frame):
    if tracemalloc.is_tracing():
        logging.info("stop tracemalloc")
        tracemalloc.stop()
    else:
        logging.info("tracemalloc not running")
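Note: these handlers only take effect once they are bound to signals; the data-sync service below imports them for that purpose. A minimal registration sketch using the standard signal module (SIGUSR1/SIGUSR2 do not exist on Windows, so the guard is required); this is a sketch, not the project's exact wiring.

import signal
import sys

from common.signal_utils import start_tracemalloc_and_snapshot, stop_tracemalloc

# Sketch: bind the handlers at process startup, POSIX only.
if sys.platform != "win32":
    signal.signal(signal.SIGUSR1, start_tracemalloc_and_snapshot)  # kill -USR1 <pid> -> take snapshot
    signal.signal(signal.SIGUSR2, stop_tracemalloc)                # kill -USR2 <pid> -> stop tracing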
@ -70,6 +70,17 @@ class RAGFlowMarkdownParser:
)
working_text = replace_tables_with_rendered_html(no_border_table_pattern, tables)

# Replace any TAGS e.g. <table ...> to <table>
TAGS = ["table", "td", "tr", "th", "tbody", "thead", "div"]
table_with_attributes_pattern = re.compile(
rf"<(?:{'|'.join(TAGS)})[^>]*>", re.IGNORECASE
)
def replace_tag(m):
tag_name = re.match(r"<(\w+)", m.group()).group(1)
return "<{}>".format(tag_name)

working_text = re.sub(table_with_attributes_pattern, replace_tag, working_text)

if "<table>" in working_text.lower(): # for optimize performance
# HTML table extraction - handle possible html/body wrapper tags
html_table_pattern = re.compile(
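Note: the new regex strips attributes from a fixed set of table-related tags so that the later "<table>" check and table extraction also catch tags such as <table border="1">. A small self-contained example of the same technique (standalone sketch, not the parser class itself):

import re

TAGS = ["table", "td", "tr", "th", "tbody", "thead", "div"]
tag_pattern = re.compile(rf"<(?:{'|'.join(TAGS)})[^>]*>", re.IGNORECASE)

def normalize_tags(html: str) -> str:
    # '<table border="1">' -> '<table>', '<TD colspan=2>' -> '<TD>'; closing tags are untouched.
    return tag_pattern.sub(lambda m: "<{}>".format(re.match(r"<(\w+)", m.group()).group(1)), html)

print(normalize_tags('<table border="1"><tr><td colspan="2">x</td></tr></table>'))
# <table><tr><td>x</td></tr></table>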
@ -6,10 +6,11 @@ set -e
# Usage and command-line argument parsing
# -----------------------------------------------------------------------------
function usage() {
echo "Usage: $0 [--disable-webserver] [--disable-taskexecutor] [--consumer-no-beg=<num>] [--consumer-no-end=<num>] [--workers=<num>] [--host-id=<string>]"
echo "Usage: $0 [--disable-webserver] [--disable-taskexecutor] [--disable-datasync] [--consumer-no-beg=<num>] [--consumer-no-end=<num>] [--workers=<num>] [--host-id=<string>]"
echo
echo " --disable-webserver Disables the web server (nginx + ragflow_server)."
echo " --disable-taskexecutor Disables task executor workers."
echo " --disable-datasync Disables synchronization of datasource workers."
echo " --enable-mcpserver Enables the MCP server."
echo " --enable-adminserver Enables the Admin server."
echo " --consumer-no-beg=<num> Start range for consumers (if using range-based)."
@ -28,6 +29,7 @@ function usage() {

ENABLE_WEBSERVER=1 # Default to enable web server
ENABLE_TASKEXECUTOR=1 # Default to enable task executor
ENABLE_DATASYNC=1
ENABLE_MCP_SERVER=0
ENABLE_ADMIN_SERVER=0 # Default close admin server
CONSUMER_NO_BEG=0
@ -69,6 +71,10 @@ for arg in "$@"; do
ENABLE_TASKEXECUTOR=0
shift
;;
--disable-datasyn)
ENABLE_DATASYNC=0
shift
;;
--enable-mcpserver)
ENABLE_MCP_SERVER=1
shift
@ -236,6 +242,13 @@ if [[ "${ENABLE_WEBSERVER}" -eq 1 ]]; then
done &
fi

if [[ "${ENABLE_DATASYNC}" -eq 1 ]]; then
echo "Starting data sync..."
while true; do
"$PY" rag/svr/sync_data_source.py
done &
fi

if [[ "${ENABLE_ADMIN_SERVER}" -eq 1 ]]; then
echo "Starting admin_server..."
while true; do
@ -20,7 +20,6 @@ import os
import networkx as nx
import trio

from api import settings
from api.db.services.document_service import DocumentService
from common.misc_utils import get_uuid
from common.connection_utils import timeout
@ -40,6 +39,7 @@ from graphrag.utils import (
)
from rag.nlp import rag_tokenizer, search
from rag.utils.redis_conn import RedisDistributedLock
from common import globals


async def run_graphrag(
@ -55,7 +55,7 @@ async def run_graphrag(
start = trio.current_time()
tenant_id, kb_id, doc_id = row["tenant_id"], str(row["kb_id"]), row["doc_id"]
chunks = []
for d in settings.retriever.chunk_list(doc_id, tenant_id, [kb_id], fields=["content_with_weight", "doc_id"], sort_by_position=True):
for d in globals.retriever.chunk_list(doc_id, tenant_id, [kb_id], fields=["content_with_weight", "doc_id"], sort_by_position=True):
chunks.append(d["content_with_weight"])

with trio.fail_after(max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000):
@ -170,7 +170,7 @@ async def run_graphrag_for_kb(
chunks = []
current_chunk = ""

for d in settings.retriever.chunk_list(
for d in globals.retriever.chunk_list(
doc_id,
tenant_id,
[kb_id],
@ -387,8 +387,8 @@ async def generate_subgraph(
"removed_kwd": "N",
}
cid = chunk_id(chunk)
await trio.to_thread.run_sync(settings.docStoreConn.delete, {"knowledge_graph_kwd": "subgraph", "source_id": doc_id}, search.index_name(tenant_id), kb_id)
await trio.to_thread.run_sync(settings.docStoreConn.insert, [{"id": cid, **chunk}], search.index_name(tenant_id), kb_id)
await trio.to_thread.run_sync(globals.docStoreConn.delete, {"knowledge_graph_kwd": "subgraph", "source_id": doc_id}, search.index_name(tenant_id), kb_id)
await trio.to_thread.run_sync(globals.docStoreConn.insert, [{"id": cid, **chunk}], search.index_name(tenant_id), kb_id)
now = trio.current_time()
callback(msg=f"generated subgraph for doc {doc_id} in {now - start:.2f} seconds.")
return subgraph
@ -496,7 +496,7 @@ async def extract_community(
chunks.append(chunk)

await trio.to_thread.run_sync(
lambda: settings.docStoreConn.delete(
lambda: globals.docStoreConn.delete(
{"knowledge_graph_kwd": "community_report", "kb_id": kb_id},
search.index_name(tenant_id),
kb_id,
@ -504,7 +504,7 @@ async def extract_community(
)
es_bulk_size = 4
for b in range(0, len(chunks), es_bulk_size):
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b : b + es_bulk_size], search.index_name(tenant_id), kb_id))
doc_store_result = await trio.to_thread.run_sync(lambda: globals.docStoreConn.insert(chunks[b : b + es_bulk_size], search.index_name(tenant_id), kb_id))
if doc_store_result:
error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
raise Exception(error_message)
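Note: the recurring pattern in these hunks is to keep the blocking doc-store client off the trio event loop by pushing every call through trio.to_thread.run_sync, and to insert in small batches. A minimal standalone sketch of that pattern (the doc_store object and its insert signature are placeholders, not the project's exact API):

import trio

ES_BULK_SIZE = 4  # small batches, matching the hunks above

async def bulk_insert(doc_store, chunks, index_name, kb_id):
    # Run the blocking insert in a worker thread, batch by batch,
    # and surface the first error the store reports.
    for b in range(0, len(chunks), ES_BULK_SIZE):
        batch = chunks[b: b + ES_BULK_SIZE]
        err = await trio.to_thread.run_sync(lambda: doc_store.insert(batch, index_name, kb_id))
        if err:
            raise RuntimeError(f"Insert chunk error: {err}")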
@ -28,6 +28,7 @@ from api.db.services.llm_service import LLMBundle
from api.db.services.user_service import TenantService
from graphrag.general.graph_extractor import GraphExtractor
from graphrag.general.index import update_graph, with_resolution, with_community
from common import globals

settings.init_settings()

@ -62,7 +63,7 @@ async def main():

chunks = [
d["content_with_weight"]
for d in settings.retriever.chunk_list(
for d in globals.retriever.chunk_list(
args.doc_id,
args.tenant_id,
[kb_id],

@ -28,6 +28,7 @@ from api.db.services.llm_service import LLMBundle
from api.db.services.user_service import TenantService
from graphrag.general.index import update_graph
from graphrag.light.graph_extractor import GraphExtractor
from common import globals

settings.init_settings()

@ -63,7 +64,7 @@ async def main():

chunks = [
d["content_with_weight"]
for d in settings.retriever.chunk_list(
for d in globals.retriever.chunk_list(
args.doc_id,
args.tenant_id,
[kb_id],
@ -29,6 +29,7 @@ from rag.utils.doc_store_conn import OrderByExpr

from rag.nlp.search import Dealer, index_name
from common.float_utils import get_float
from common import globals


class KGSearch(Dealer):
@ -334,6 +335,6 @@ if __name__ == "__main__":
_, kb = KnowledgebaseService.get_by_id(kb_id)
embed_bdl = LLMBundle(args.tenant_id, LLMType.EMBEDDING, kb.embd_id)

kg = KGSearch(settings.docStoreConn)
kg = KGSearch(globals.docStoreConn)
print(kg.retrieval({"question": args.question, "kb_ids": [kb_id]},
search.index_name(kb.tenant_id), [kb_id], embed_bdl, llm_bdl))
@ -23,12 +23,12 @@ import trio
import xxhash
from networkx.readwrite import json_graph

from api import settings
from common.misc_utils import get_uuid
from common.connection_utils import timeout
from rag.nlp import rag_tokenizer, search
from rag.utils.doc_store_conn import OrderByExpr
from rag.utils.redis_conn import REDIS_CONN
from common import globals

GRAPH_FIELD_SEP = "<SEP>"

@ -334,7 +334,7 @@ def get_relation(tenant_id, kb_id, from_ent_name, to_ent_name, size=1):
ents = list(set(ents))
conds = {"fields": ["content_with_weight"], "size": size, "from_entity_kwd": ents, "to_entity_kwd": ents, "knowledge_graph_kwd": ["relation"]}
res = []
es_res = settings.retriever.search(conds, search.index_name(tenant_id), [kb_id] if isinstance(kb_id, str) else kb_id)
es_res = globals.retriever.search(conds, search.index_name(tenant_id), [kb_id] if isinstance(kb_id, str) else kb_id)
for id in es_res.ids:
try:
if size == 1:
@ -381,8 +381,8 @@ async def does_graph_contains(tenant_id, kb_id, doc_id):
"knowledge_graph_kwd": ["graph"],
"removed_kwd": "N",
}
res = await trio.to_thread.run_sync(lambda: settings.docStoreConn.search(fields, [], condition, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [kb_id]))
fields2 = settings.docStoreConn.getFields(res, fields)
res = await trio.to_thread.run_sync(lambda: globals.docStoreConn.search(fields, [], condition, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [kb_id]))
fields2 = globals.docStoreConn.getFields(res, fields)
graph_doc_ids = set()
for chunk_id in fields2.keys():
graph_doc_ids = set(fields2[chunk_id]["source_id"])
@ -391,7 +391,7 @@ async def does_graph_contains(tenant_id, kb_id, doc_id):

async def get_graph_doc_ids(tenant_id, kb_id) -> list[str]:
conds = {"fields": ["source_id"], "removed_kwd": "N", "size": 1, "knowledge_graph_kwd": ["graph"]}
res = await trio.to_thread.run_sync(lambda: settings.retriever.search(conds, search.index_name(tenant_id), [kb_id]))
res = await trio.to_thread.run_sync(lambda: globals.retriever.search(conds, search.index_name(tenant_id), [kb_id]))
doc_ids = []
if res.total == 0:
return doc_ids
@ -402,7 +402,7 @@ async def get_graph_doc_ids(tenant_id, kb_id) -> list[str]:

async def get_graph(tenant_id, kb_id, exclude_rebuild=None):
conds = {"fields": ["content_with_weight", "removed_kwd", "source_id"], "size": 1, "knowledge_graph_kwd": ["graph"]}
res = await trio.to_thread.run_sync(settings.retriever.search, conds, search.index_name(tenant_id), [kb_id])
res = await trio.to_thread.run_sync(globals.retriever.search, conds, search.index_name(tenant_id), [kb_id])
if not res.total == 0:
for id in res.ids:
try:
@ -423,17 +423,17 @@ async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, chang
global chat_limiter
start = trio.current_time()

await trio.to_thread.run_sync(settings.docStoreConn.delete, {"knowledge_graph_kwd": ["graph", "subgraph"]}, search.index_name(tenant_id), kb_id)
await trio.to_thread.run_sync(globals.docStoreConn.delete, {"knowledge_graph_kwd": ["graph", "subgraph"]}, search.index_name(tenant_id), kb_id)

if change.removed_nodes:
await trio.to_thread.run_sync(settings.docStoreConn.delete, {"knowledge_graph_kwd": ["entity"], "entity_kwd": sorted(change.removed_nodes)}, search.index_name(tenant_id), kb_id)
await trio.to_thread.run_sync(globals.docStoreConn.delete, {"knowledge_graph_kwd": ["entity"], "entity_kwd": sorted(change.removed_nodes)}, search.index_name(tenant_id), kb_id)

if change.removed_edges:

async def del_edges(from_node, to_node):
async with chat_limiter:
await trio.to_thread.run_sync(
settings.docStoreConn.delete, {"knowledge_graph_kwd": ["relation"], "from_entity_kwd": from_node, "to_entity_kwd": to_node}, search.index_name(tenant_id), kb_id
globals.docStoreConn.delete, {"knowledge_graph_kwd": ["relation"], "from_entity_kwd": from_node, "to_entity_kwd": to_node}, search.index_name(tenant_id), kb_id
)

async with trio.open_nursery() as nursery:
@ -501,7 +501,7 @@ async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, chang
es_bulk_size = 4
for b in range(0, len(chunks), es_bulk_size):
with trio.fail_after(3 if enable_timeout_assertion else 30000000):
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b : b + es_bulk_size], search.index_name(tenant_id), kb_id))
doc_store_result = await trio.to_thread.run_sync(lambda: globals.docStoreConn.insert(chunks[b : b + es_bulk_size], search.index_name(tenant_id), kb_id))
if b % 100 == es_bulk_size and callback:
callback(msg=f"Insert chunks: {b}/{len(chunks)}")
if doc_store_result:
@ -555,7 +555,7 @@ def merge_tuples(list1, list2):


async def get_entity_type2samples(idxnms, kb_ids: list):
es_res = await trio.to_thread.run_sync(lambda: settings.retriever.search({"knowledge_graph_kwd": "ty2ents", "kb_id": kb_ids, "size": 10000, "fields": ["content_with_weight"]}, idxnms, kb_ids))
es_res = await trio.to_thread.run_sync(lambda: globals.retriever.search({"knowledge_graph_kwd": "ty2ents", "kb_id": kb_ids, "size": 10000, "fields": ["content_with_weight"]}, idxnms, kb_ids))

res = defaultdict(list)
for id in es_res.ids:
@ -589,10 +589,10 @@ async def rebuild_graph(tenant_id, kb_id, exclude_rebuild=None):
bs = 256
for i in range(0, 1024 * bs, bs):
es_res = await trio.to_thread.run_sync(
lambda: settings.docStoreConn.search(flds, [], {"kb_id": kb_id, "knowledge_graph_kwd": ["subgraph"]}, [], OrderByExpr(), i, bs, search.index_name(tenant_id), [kb_id])
lambda: globals.docStoreConn.search(flds, [], {"kb_id": kb_id, "knowledge_graph_kwd": ["subgraph"]}, [], OrderByExpr(), i, bs, search.index_name(tenant_id), [kb_id])
)
# tot = settings.docStoreConn.getTotal(es_res)
es_res = settings.docStoreConn.getFields(es_res, flds)
# tot = globals.docStoreConn.getTotal(es_res)
es_res = globals.docStoreConn.getFields(es_res, flds)

if len(es_res) == 0:
break
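Note: set_graph above fans the per-edge deletes out through a trio nursery while chat_limiter caps concurrency. A minimal standalone sketch of that fan-out pattern (using a local CapacityLimiter; the real module reuses its global chat_limiter):

import trio

limiter = trio.CapacityLimiter(8)  # stand-in for the module's chat_limiter

async def delete_edge(doc_store, from_node, to_node):
    async with limiter:
        # Blocking doc-store call moved off the event loop.
        await trio.to_thread.run_sync(doc_store.delete, {"from_entity_kwd": from_node, "to_entity_kwd": to_node})

async def delete_edges(doc_store, removed_edges):
    # Start one task per removed edge; the nursery waits for all of them.
    async with trio.open_nursery() as nursery:
        for from_node, to_node in removed_edges:
            nursery.start_soon(delete_edge, doc_store, from_node, to_node)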
@ -15,18 +15,18 @@
#

import logging
from tika import parser
import re
from io import BytesIO

from deepdoc.parser.utils import get_text
from rag.app import naive
from rag.app.naive import plaintext_parser, PARSERS
from rag.nlp import bullets_category, is_english,remove_contents_table, \
hierarchical_merge, make_colon_as_title, naive_merge, random_choices, tokenize_table, \
tokenize_chunks
from rag.nlp import rag_tokenizer
from deepdoc.parser import PdfParser, PlainParser, HtmlParser
from deepdoc.parser.figure_parser import vision_figure_parser_pdf_wrapper,vision_figure_parser_docx_wrapper
from deepdoc.parser import PdfParser, HtmlParser
from deepdoc.parser.figure_parser import vision_figure_parser_docx_wrapper
from PIL import Image


@ -96,13 +96,33 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
callback(0.8, "Finish parsing.")

elif re.search(r"\.pdf$", filename, re.IGNORECASE):
pdf_parser = Pdf()
if parser_config.get("layout_recognize", "DeepDOC") == "Plain Text":
pdf_parser = PlainParser()
sections, tbls = pdf_parser(filename if not binary else binary,
from_page=from_page, to_page=to_page, callback=callback)
tbls=vision_figure_parser_pdf_wrapper(tbls=tbls,callback=callback,**kwargs)
layout_recognizer = parser_config.get("layout_recognize", "DeepDOC")

if isinstance(layout_recognizer, bool):
layout_recognizer = "DeepDOC" if layout_recognizer else "Plain Text"

name = layout_recognizer.strip().lower()
parser = PARSERS.get(name, plaintext_parser)
callback(0.1, "Start to parse.")

sections, tables, _ = parser(
filename = filename,
binary = binary,
from_page = from_page,
to_page = to_page,
lang = lang,
callback = callback,
pdf_cls = Pdf,
**kwargs
)

if not sections and not tables:
return []

if name in ["tcadp", "docling", "mineru"]:
parser_config["chunk_token_num"] = 0

callback(0.8, "Finish parsing.")
elif re.search(r"\.txt$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.")
txt = get_text(filename, binary)
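Note: the repeated refactor in these chunk() functions is the same three-step dispatch: normalize the configured layout_recognize value (older configs may still store a boolean), map it to a parser callable via the PARSERS registry defined in rag/app/naive.py, and fall back to plaintext_parser. A condensed sketch of just that dispatch, with a dummy config:

from rag.app.naive import plaintext_parser, PARSERS

def pick_pdf_parser(parser_config: dict):
    layout_recognizer = parser_config.get("layout_recognize", "DeepDOC")
    # Older configs stored a boolean; normalize it to the two legacy names.
    if isinstance(layout_recognizer, bool):
        layout_recognizer = "DeepDOC" if layout_recognizer else "Plain Text"
    name = layout_recognizer.strip().lower()
    # Unknown names (e.g. a vision model id) fall back to plaintext_parser.
    return name, PARSERS.get(name, plaintext_parser)

name, parser = pick_pdf_parser({"layout_recognize": "MinerU"})
print(name)   # "mineru"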
@ -15,7 +15,6 @@
#

import logging
from tika import parser
import re
from io import BytesIO
from docx import Document
@ -25,8 +24,8 @@ from deepdoc.parser.utils import get_text
from rag.nlp import bullets_category, remove_contents_table, \
make_colon_as_title, tokenize_chunks, docx_question_level, tree_merge
from rag.nlp import rag_tokenizer, Node
from deepdoc.parser import PdfParser, DocxParser, PlainParser, HtmlParser

from deepdoc.parser import PdfParser, DocxParser, HtmlParser
from rag.app.naive import plaintext_parser, PARSERS



@ -156,13 +155,36 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
return tokenize_chunks(chunks, doc, eng, None)

elif re.search(r"\.pdf$", filename, re.IGNORECASE):
pdf_parser = Pdf()
if parser_config.get("layout_recognize", "DeepDOC") == "Plain Text":
pdf_parser = PlainParser()
for txt, poss in pdf_parser(filename if not binary else binary,
from_page=from_page, to_page=to_page, callback=callback)[0]:
layout_recognizer = parser_config.get("layout_recognize", "DeepDOC")

if isinstance(layout_recognizer, bool):
layout_recognizer = "DeepDOC" if layout_recognizer else "Plain Text"

name = layout_recognizer.strip().lower()
parser = PARSERS.get(name, plaintext_parser)
callback(0.1, "Start to parse.")

raw_sections, tables, _ = parser(
filename = filename,
binary = binary,
from_page = from_page,
to_page = to_page,
lang = lang,
callback = callback,
pdf_cls = Pdf,
**kwargs
)

if not raw_sections and not tables:
return []

if name in ["tcadp", "docling", "mineru"]:
parser_config["chunk_token_num"] = 0

for txt, poss in raw_sections:
sections.append(txt + poss)

callback(0.8, "Finish parsing.")
elif re.search(r"\.(txt|md|markdown|mdx)$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.")
txt = get_text(filename, binary)
@ -22,11 +22,11 @@ from common.constants import ParserType
from io import BytesIO
from rag.nlp import rag_tokenizer, tokenize, tokenize_table, bullets_category, title_frequency, tokenize_chunks, docx_question_level
from common.token_utils import num_tokens_from_string
from deepdoc.parser import PdfParser, PlainParser, DocxParser
from deepdoc.parser import PdfParser, DocxParser
from deepdoc.parser.figure_parser import vision_figure_parser_pdf_wrapper,vision_figure_parser_docx_wrapper
from docx import Document
from PIL import Image

from rag.app.naive import plaintext_parser, PARSERS

class Pdf(PdfParser):
def __init__(self):
@ -196,15 +196,34 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
# is it English
eng = lang.lower() == "english" # pdf_parser.is_english
if re.search(r"\.pdf$", filename, re.IGNORECASE):
pdf_parser = Pdf()
if parser_config.get("layout_recognize", "DeepDOC") == "Plain Text":
pdf_parser = PlainParser()
sections, tbls = pdf_parser(filename if not binary else binary,
from_page=from_page, to_page=to_page, callback=callback)
if sections and len(sections[0]) < 3:
sections = [(t, lvl, [[0] * 5]) for t, lvl in sections]
# set pivot using the most frequent type of title,
# then merge between 2 pivot
layout_recognizer = parser_config.get("layout_recognize", "DeepDOC")

if isinstance(layout_recognizer, bool):
layout_recognizer = "DeepDOC" if layout_recognizer else "Plain Text"

name = layout_recognizer.strip().lower()
pdf_parser = PARSERS.get(name, plaintext_parser)
callback(0.1, "Start to parse.")

sections, tbls, pdf_parser = pdf_parser(
filename = filename,
binary = binary,
from_page = from_page,
to_page = to_page,
lang = lang,
callback = callback,
pdf_cls = Pdf,
**kwargs
)

if not sections and not tbls:
return []

if name in ["tcadp", "docling", "mineru"]:
parser_config["chunk_token_num"] = 0

callback(0.8, "Finish parsing.")

if len(sections) > 0 and len(pdf_parser.outlines) / len(sections) > 0.03:
max_lvl = max([lvl for _, lvl in pdf_parser.outlines])
most_level = max(0, max_lvl - 1)
196	rag/app/naive.py
@ -26,7 +26,6 @@ from docx.opc.pkgreader import _SerializedRelationships, _SerializedRelationship
from docx.opc.oxml import parse_xml
from markdown import markdown
from PIL import Image
from tika import parser

from common.constants import LLMType
from api.db.services.llm_service import LLMBundle
@ -39,6 +38,100 @@ from deepdoc.parser.docling_parser import DoclingParser
from deepdoc.parser.tcadp_parser import TCADPParser
from rag.nlp import concat_img, find_codec, naive_merge, naive_merge_with_images, naive_merge_docx, rag_tokenizer, tokenize_chunks, tokenize_chunks_with_images, tokenize_table

def DeepDOC_parser(filename, binary=None, from_page=0, to_page=100000, callback=None, pdf_cls = None ,**kwargs):
callback = callback
binary = binary
pdf_parser = pdf_cls() if pdf_cls else Pdf()
sections, tables = pdf_parser(
filename if not binary else binary,
from_page=from_page,
to_page=to_page,
callback=callback
)
tables = vision_figure_parser_pdf_wrapper(tbls=tables,
callback=callback,
**kwargs)
return sections, tables, pdf_parser


def MinerU_parser(filename, binary=None, callback=None, **kwargs):
mineru_executable = os.environ.get("MINERU_EXECUTABLE", "mineru")
mineru_api = os.environ.get("MINERU_APISERVER", "http://host.docker.internal:9987")
pdf_parser = MinerUParser(mineru_path=mineru_executable, mineru_api=mineru_api)

if not pdf_parser.check_installation():
callback(-1, "MinerU not found.")
return None, None

sections, tables = pdf_parser.parse_pdf(
filepath=filename,
binary=binary,
callback=callback,
output_dir=os.environ.get("MINERU_OUTPUT_DIR", ""),
backend=os.environ.get("MINERU_BACKEND", "pipeline"),
delete_output=bool(int(os.environ.get("MINERU_DELETE_OUTPUT", 1))),
)
return sections, tables, pdf_parser


def Docling_parser(filename, binary=None, callback=None, **kwargs):
pdf_parser = DoclingParser()

if not pdf_parser.check_installation():
callback(-1, "Docling not found.")
return None, None

sections, tables = pdf_parser.parse_pdf(
filepath=filename,
binary=binary,
callback=callback,
output_dir=os.environ.get("MINERU_OUTPUT_DIR", ""),
delete_output=bool(int(os.environ.get("MINERU_DELETE_OUTPUT", 1))),
)
return sections, tables, pdf_parser


def TCADP_parser(filename, binary=None, callback=None, **kwargs):
tcadp_parser = TCADPParser()

if not tcadp_parser.check_installation():
callback(-1, "TCADP parser not available. Please check Tencent Cloud API configuration.")
return None, None

sections, tables = tcadp_parser.parse_pdf(
filepath=filename,
binary=binary,
callback=callback,
output_dir=os.environ.get("TCADP_OUTPUT_DIR", ""),
file_type="PDF"
)
return sections, tables, tcadp_parser


def plaintext_parser(filename, binary=None, from_page=0, to_page=100000, callback=None, **kwargs):
if kwargs.get("layout_recognizer", "") == "Plain Text":
pdf_parser = PlainParser()
else:
vision_model = LLMBundle(kwargs["tenant_id"], LLMType.IMAGE2TEXT, llm_name=kwargs.get("layout_recognizer", ""), lang=kwargs.get("lang", "Chinese"))
pdf_parser = VisionParser(vision_model=vision_model, **kwargs)

sections, tables = pdf_parser(
filename if not binary else binary,
from_page=from_page,
to_page=to_page,
callback=callback
)
return sections, tables, pdf_parser


PARSERS = {
"deepdoc": DeepDOC_parser,
"mineru": MinerU_parser,
"docling": Docling_parser,
"tcadp": TCADP_parser,
"plaintext": plaintext_parser, # default
}


class Docx(DocxParser):
def __init__(self):
@ -365,7 +458,7 @@ class Markdown(MarkdownParser):
html_content = markdown(text)
soup = BeautifulSoup(html_content, 'html.parser')
return soup

def get_picture_urls(self, soup):
if soup:
return [img.get('src') for img in soup.find_all('img') if img.get('src')]
@ -375,7 +468,7 @@ class Markdown(MarkdownParser):
if soup:
return set([a.get('href') for a in soup.find_all('a') if a.get('href')])
return []

def get_pictures(self, text):
"""Download and open all images from markdown text."""
import requests
@ -416,11 +509,11 @@ class Markdown(MarkdownParser):
txt = f.read()

remainder, tables = self.extract_tables_and_remainder(f'{txt}\n', separate_tables=separate_tables)

# To eliminate duplicate tables in chunking result, uncomment code below and set separate_tables to True in line 410.
# extractor = MarkdownElementExtractor(remainder)
extractor = MarkdownElementExtractor(txt)
element_sections = extractor.extract_elements(delimiter)
sections = [(element, "") for element in element_sections]

tbls = []
for table in tables:
tbls.append(((None, markdown(table, extensions=['markdown.extensions.tables'])), ""))
@ -535,82 +628,29 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,

if isinstance(layout_recognizer, bool):
layout_recognizer = "DeepDOC" if layout_recognizer else "Plain Text"

name = layout_recognizer.strip().lower()
parser = PARSERS.get(name, plaintext_parser)
callback(0.1, "Start to parse.")

if layout_recognizer == "DeepDOC":
pdf_parser = Pdf()
sections, tables = pdf_parser(filename if not binary else binary, from_page=from_page, to_page=to_page, callback=callback)
tables=vision_figure_parser_pdf_wrapper(tbls=tables,callback=callback,**kwargs)
sections, tables, _ = parser(
filename = filename,
binary = binary,
from_page = from_page,
to_page = to_page,
lang = lang,
callback = callback,
**kwargs
)

res = tokenize_table(tables, doc, is_english)
callback(0.8, "Finish parsing.")
if not sections and not tables:
return []

elif layout_recognizer == "MinerU":
mineru_executable = os.environ.get("MINERU_EXECUTABLE", "mineru")
mineru_api = os.environ.get("MINERU_APISERVER", "http://host.docker.internal:9987")
mineru_server_url = os.environ.get("MINERU_SERVER_URL", "")
mineru_backend = os.environ.get("MINERU_BACKEND", "pipeline")
pdf_parser = MinerUParser(mineru_path=mineru_executable, mineru_api=mineru_api, mineru_server_url=mineru_server_url)
ok, reason = pdf_parser.check_installation(backend=mineru_backend)
if not ok:
callback(-1, f"MinerU not found or server not accessible: {reason}")
return res

sections, tables = pdf_parser.parse_pdf(
filepath=filename,
binary=binary,
callback=callback,
output_dir=os.environ.get("MINERU_OUTPUT_DIR", ""),
backend=mineru_backend,
server_url=mineru_server_url,
delete_output=bool(int(os.environ.get("MINERU_DELETE_OUTPUT", 1))),
)
if name in ["tcadp", "docling", "mineru"]:
parser_config["chunk_token_num"] = 0
callback(0.8, "Finish parsing.")

elif layout_recognizer == "Docling":
pdf_parser = DoclingParser()
if not pdf_parser.check_installation():
callback(-1, "Docling not found.")
return res

sections, tables = pdf_parser.parse_pdf(
filepath=filename,
binary=binary,
callback=callback,
output_dir=os.environ.get("MINERU_OUTPUT_DIR", ""),
delete_output=bool(int(os.environ.get("MINERU_DELETE_OUTPUT", 1))),
)
parser_config["chunk_token_num"] = 0
res = tokenize_table(tables, doc, is_english)
callback(0.8, "Finish parsing.")

elif layout_recognizer == "TCADP Parser":
tcadp_parser = TCADPParser()
if not tcadp_parser.check_installation():
callback(-1, "TCADP parser not available. Please check Tencent Cloud API configuration.")
return res

sections, tables = tcadp_parser.parse_pdf(
filepath=filename,
binary=binary,
callback=callback,
output_dir=os.environ.get("TCADP_OUTPUT_DIR", ""),
file_type="PDF"
)
parser_config["chunk_token_num"] = 0
callback(0.8, "Finish parsing.")
else:
if layout_recognizer == "Plain Text":
pdf_parser = PlainParser()
else:
vision_model = LLMBundle(kwargs["tenant_id"], LLMType.IMAGE2TEXT, llm_name=layout_recognizer, lang=lang)
pdf_parser = VisionParser(vision_model=vision_model, **kwargs)

sections, tables = pdf_parser(filename if not binary else binary, from_page=from_page, to_page=to_page,
callback=callback)
res = tokenize_table(tables, doc, is_english)
callback(0.8, "Finish parsing.")

res = tokenize_table(tables, doc, is_english)
callback(0.8, "Finish parsing.")

elif re.search(r"\.(csv|xlsx?)$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.")
@ -735,9 +775,9 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
logging.info(f"Failed to chunk url in registered file type {url}: {e}")
sub_url_res = chunk(f"{index}.html", html_bytes, callback=callback, lang=lang, is_root=False, **kwargs)
url_res.extend(sub_url_res)

logging.info("naive_merge({}): {}".format(filename, timer() - st))

if embed_res:
res.extend(embed_res)
if url_res:
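Note: because PARSERS is a plain module-level dict keyed by the lower-cased layout_recognize name, new engines can be plugged in without touching every chunk() function. A hedged sketch of registering a custom engine (my_engine_parser and its internals are hypothetical; only the (sections, tables, parser) return convention comes from the code above):

from rag.app import naive

def my_engine_parser(filename, binary=None, from_page=0, to_page=100000, callback=None, **kwargs):
    # Hypothetical engine: must follow the same convention as the built-in
    # parsers and return (sections, tables, parser_instance).
    sections, tables = [], []
    if callback:
        callback(0.5, "my_engine parsed the document.")
    return sections, tables, None

# Any chunker that does PARSERS.get(name, plaintext_parser) will now pick this
# up when layout_recognize is set to "my_engine".
naive.PARSERS["my_engine"] = my_engine_parser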
@ -15,16 +15,15 @@
#

import logging
from tika import parser
from io import BytesIO
import re

from deepdoc.parser.utils import get_text
from rag.app import naive
from rag.nlp import rag_tokenizer, tokenize
from deepdoc.parser import PdfParser, ExcelParser, PlainParser, HtmlParser
from deepdoc.parser.figure_parser import vision_figure_parser_pdf_wrapper,vision_figure_parser_docx_wrapper

from deepdoc.parser import PdfParser, ExcelParser, HtmlParser
from deepdoc.parser.figure_parser import vision_figure_parser_docx_wrapper
from rag.app.naive import plaintext_parser, PARSERS

class Pdf(PdfParser):
def __call__(self, filename, binary=None, from_page=0,
@ -83,12 +82,34 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
callback(0.8, "Finish parsing.")

elif re.search(r"\.pdf$", filename, re.IGNORECASE):
pdf_parser = Pdf()
if parser_config.get("layout_recognize", "DeepDOC") == "Plain Text":
pdf_parser = PlainParser()
sections, tbls = pdf_parser(
filename if not binary else binary, to_page=to_page, callback=callback)
tbls=vision_figure_parser_pdf_wrapper(tbls=tbls,callback=callback,**kwargs)
layout_recognizer = parser_config.get("layout_recognize", "DeepDOC")

if isinstance(layout_recognizer, bool):
layout_recognizer = "DeepDOC" if layout_recognizer else "Plain Text"

name = layout_recognizer.strip().lower()
parser = PARSERS.get(name, plaintext_parser)
callback(0.1, "Start to parse.")

sections, tbls, _ = parser(
filename = filename,
binary = binary,
from_page = from_page,
to_page = to_page,
lang = lang,
callback = callback,
pdf_cls = Pdf,
**kwargs
)

if not sections and not tbls:
return []

if name in ["tcadp", "docling", "mineru"]:
parser_config["chunk_token_num"] = 0

callback(0.8, "Finish parsing.")

for (img, rows), poss in tbls:
if not rows:
continue

@ -20,14 +20,11 @@ from io import BytesIO

from PIL import Image

from common.constants import LLMType
from api.db.services.llm_service import LLMBundle
from deepdoc.parser.pdf_parser import VisionParser
from rag.nlp import tokenize, is_english
from rag.nlp import rag_tokenizer
from deepdoc.parser import PdfParser, PptParser, PlainParser
from PyPDF2 import PdfReader as pdf2_read

from rag.app.naive import plaintext_parser, PARSERS

class Ppt(PptParser):
def __call__(self, fnm, from_page, to_page, callback=None):
@ -54,7 +51,6 @@ class Ppt(PptParser):
self.is_english = is_english(txts)
return [(txts[i], imgs[i]) for i in range(len(txts))]


class Pdf(PdfParser):
def __init__(self):
super().__init__()
@ -84,7 +80,7 @@ class Pdf(PdfParser):
res.append((lines, self.page_images[i]))
callback(0.9, "Page {}~{}: Parsing finished".format(
from_page, min(to_page, self.total_page)))
return res
return res, []


class PlainPdf(PlainParser):
@ -95,7 +91,7 @@ class PlainPdf(PlainParser):
for page in self.pdf.pages[from_page: to_page]:
page_txt.append(page.extract_text())
callback(0.9, "Parsing finished")
return [(txt, None) for txt in page_txt]
return [(txt, None) for txt in page_txt], []


def chunk(filename, binary=None, from_page=0, to_page=100000,
@ -130,20 +126,33 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
return res
elif re.search(r"\.pdf$", filename, re.IGNORECASE):
layout_recognizer = parser_config.get("layout_recognize", "DeepDOC")
if layout_recognizer == "DeepDOC":
pdf_parser = Pdf()
sections = pdf_parser(filename, binary, from_page=from_page, to_page=to_page, callback=callback)
elif layout_recognizer == "Plain Text":
pdf_parser = PlainParser()
sections, _ = pdf_parser(filename if not binary else binary, from_page=from_page, to_page=to_page,
callback=callback)
else:
vision_model = LLMBundle(kwargs["tenant_id"], LLMType.IMAGE2TEXT, llm_name=layout_recognizer, lang=lang)
pdf_parser = VisionParser(vision_model=vision_model, **kwargs)
sections, _ = pdf_parser(filename if not binary else binary, from_page=from_page, to_page=to_page,
callback=callback)

if isinstance(layout_recognizer, bool):
layout_recognizer = "DeepDOC" if layout_recognizer else "Plain Text"

name = layout_recognizer.strip().lower()
parser = PARSERS.get(name, plaintext_parser)
callback(0.1, "Start to parse.")

sections, _, _ = parser(
filename = filename,
binary = binary,
from_page = from_page,
to_page = to_page,
lang = lang,
callback = callback,
pdf_cls = Pdf,
**kwargs
)

if not sections:
return []

if name in ["tcadp", "docling", "mineru"]:
parser_config["chunk_token_num"] = 0

callback(0.8, "Finish parsing.")

for pn, (txt, img) in enumerate(sections):
d = copy.deepcopy(doc)
pn += from_page

@ -21,6 +21,7 @@ from copy import deepcopy
from deepdoc.parser.utils import get_text
from rag.app.qa import Excel
from rag.nlp import rag_tokenizer
from common import globals


def beAdoc(d, q, a, eng, row_num=-1):
@ -124,7 +125,6 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
def label_question(question, kbs):
from api.db.services.knowledgebase_service import KnowledgebaseService
from graphrag.utils import get_tags_from_cache, set_tags_to_cache
from api import settings
tags = None
tag_kb_ids = []
for kb in kbs:
@ -133,14 +133,14 @@ def label_question(question, kbs):
if tag_kb_ids:
all_tags = get_tags_from_cache(tag_kb_ids)
if not all_tags:
all_tags = settings.retriever.all_tags_in_portion(kb.tenant_id, tag_kb_ids)
all_tags = globals.retriever.all_tags_in_portion(kb.tenant_id, tag_kb_ids)
set_tags_to_cache(tags=all_tags, kb_ids=tag_kb_ids)
else:
all_tags = json.loads(all_tags)
tag_kbs = KnowledgebaseService.get_by_ids(tag_kb_ids)
if not tag_kbs:
return tags
tags = settings.retriever.tag_query(question,
tags = globals.retriever.tag_query(question,
list(set([kb.tenant_id for kb in tag_kbs])),
tag_kb_ids,
all_tags,
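Note: the label_question hunk follows a cache-aside pattern: read the tag distribution from cache, compute it from the retriever only on a miss, then write it back. A generic sketch of that pattern (the cache object and its get/set calls are placeholders, not the project's exact helpers):

import json

def get_or_compute_tags(cache, retriever, tenant_id, kb_ids):
    cached = cache.get(str(kb_ids))
    if cached:
        return json.loads(cached)                              # cache hit: reuse stored JSON
    tags = retriever.all_tags_in_portion(tenant_id, kb_ids)    # miss: recompute from the doc store
    cache.set(str(kb_ids), json.dumps(tags))
    return tags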
@ -20,10 +20,10 @@ import time
import argparse
from collections import defaultdict

from common import globals
from common.constants import LLMType
from api.db.services.llm_service import LLMBundle
from api.db.services.knowledgebase_service import KnowledgebaseService
from api import settings
from common.misc_utils import get_uuid
from rag.nlp import tokenize, search
from ranx import evaluate
@ -52,7 +52,7 @@ class Benchmark:
run = defaultdict(dict)
query_list = list(qrels.keys())
for query in query_list:
ranks = settings.retriever.retrieval(query, self.embd_mdl, self.tenant_id, [self.kb.id], 1, 30,
ranks = globals.retriever.retrieval(query, self.embd_mdl, self.tenant_id, [self.kb.id], 1, 30,
0.0, self.vector_similarity_weight)
if len(ranks["chunks"]) == 0:
print(f"deleted query: {query}")
@ -77,9 +77,9 @@ class Benchmark:
def init_index(self, vector_size: int):
if self.initialized_index:
return
if settings.docStoreConn.indexExist(self.index_name, self.kb_id):
settings.docStoreConn.deleteIdx(self.index_name, self.kb_id)
settings.docStoreConn.createIdx(self.index_name, self.kb_id, vector_size)
if globals.docStoreConn.indexExist(self.index_name, self.kb_id):
globals.docStoreConn.deleteIdx(self.index_name, self.kb_id)
globals.docStoreConn.createIdx(self.index_name, self.kb_id, vector_size)
self.initialized_index = True

def ms_marco_index(self, file_path, index_name):
@ -114,13 +114,13 @@ class Benchmark:
docs_count += len(docs)
docs, vector_size = self.embedding(docs)
self.init_index(vector_size)
settings.docStoreConn.insert(docs, self.index_name, self.kb_id)
globals.docStoreConn.insert(docs, self.index_name, self.kb_id)
docs = []

if docs:
docs, vector_size = self.embedding(docs)
self.init_index(vector_size)
settings.docStoreConn.insert(docs, self.index_name, self.kb_id)
globals.docStoreConn.insert(docs, self.index_name, self.kb_id)
return qrels, texts

def trivia_qa_index(self, file_path, index_name):
@ -155,12 +155,12 @@ class Benchmark:
docs_count += len(docs)
docs, vector_size = self.embedding(docs)
self.init_index(vector_size)
settings.docStoreConn.insert(docs,self.index_name)
globals.docStoreConn.insert(docs,self.index_name)
docs = []

docs, vector_size = self.embedding(docs)
self.init_index(vector_size)
settings.docStoreConn.insert(docs, self.index_name)
globals.docStoreConn.insert(docs, self.index_name)
return qrels, texts

def miracl_index(self, file_path, corpus_path, index_name):
@ -210,12 +210,12 @@ class Benchmark:
docs_count += len(docs)
docs, vector_size = self.embedding(docs)
self.init_index(vector_size)
settings.docStoreConn.insert(docs, self.index_name)
globals.docStoreConn.insert(docs, self.index_name)
docs = []

docs, vector_size = self.embedding(docs)
self.init_index(vector_size)
settings.docStoreConn.insert(docs, self.index_name)
globals.docStoreConn.insert(docs, self.index_name)
return qrels, texts

def save_results(self, qrels, run, texts, dataset, file_path):

@ -21,7 +21,7 @@ from functools import partial
import trio

from common.misc_utils import get_uuid
from common.base64_image import id2image, image2id
from rag.utils.base64_image import id2image, image2id
from deepdoc.parser.pdf_parser import RAGFlowPdfParser
from rag.flow.base import ProcessBase, ProcessParamBase
from rag.flow.hierarchical_merger.schema import HierarchicalMergerFromUpstream

@ -27,7 +27,7 @@ from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService
from api.db.services.llm_service import LLMBundle
from common.misc_utils import get_uuid
from common.base64_image import image2id
from rag.utils.base64_image import image2id
from deepdoc.parser import ExcelParser
from deepdoc.parser.mineru_parser import MinerUParser
from deepdoc.parser.pdf_parser import PlainParser, RAGFlowPdfParser, VisionParser

@ -18,7 +18,7 @@ from functools import partial
import trio

from common.misc_utils import get_uuid
from common.base64_image import id2image, image2id
from rag.utils.base64_image import id2image, image2id
from deepdoc.parser.pdf_parser import RAGFlowPdfParser
from rag.flow.base import ProcessBase, ProcessParamBase
from rag.flow.splitter.schema import SplitterFromUpstream

@ -29,7 +29,7 @@ from zhipuai import ZhipuAI

from common.log_utils import log_exception
from common.token_utils import num_tokens_from_string, truncate
from api import settings
from common import globals
import logging


@ -69,13 +69,13 @@ class BuiltinEmbed(Base):
_model_lock = threading.Lock()

def __init__(self, key, model_name, **kwargs):
logging.info(f"Initialize BuiltinEmbed according to settings.EMBEDDING_CFG: {settings.EMBEDDING_CFG}")
embedding_cfg = settings.EMBEDDING_CFG
logging.info(f"Initialize BuiltinEmbed according to globals.EMBEDDING_CFG: {globals.EMBEDDING_CFG}")
embedding_cfg = globals.EMBEDDING_CFG
if not BuiltinEmbed._model and "tei-" in os.getenv("COMPOSE_PROFILES", ""):
with BuiltinEmbed._model_lock:
BuiltinEmbed._model_name = settings.EMBEDDING_MDL
BuiltinEmbed._max_tokens = BuiltinEmbed.MAX_TOKENS.get(settings.EMBEDDING_MDL, 500)
BuiltinEmbed._model = HuggingFaceEmbed(embedding_cfg["api_key"], settings.EMBEDDING_MDL, base_url=embedding_cfg["base_url"])
BuiltinEmbed._model_name = globals.EMBEDDING_MDL
BuiltinEmbed._max_tokens = BuiltinEmbed.MAX_TOKENS.get(globals.EMBEDDING_MDL, 500)
BuiltinEmbed._model = HuggingFaceEmbed(embedding_cfg["api_key"], globals.EMBEDDING_MDL, base_url=embedding_cfg["base_url"])
self._model = BuiltinEmbed._model
self._model_name = BuiltinEmbed._model_name
self._max_tokens = BuiltinEmbed._max_tokens
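Note: BuiltinEmbed keeps one shared embedding client per process, guarded by a class-level lock so concurrent first calls do not build it twice. A generic sketch of that lazily initialized, lock-protected singleton (the build_client() factory is a placeholder):

import threading

class LazyClient:
    _client = None
    _lock = threading.Lock()

    @classmethod
    def get(cls, build_client):
        if cls._client is None:            # fast path once the client exists
            with cls._lock:
                if cls._client is None:    # re-check inside the lock
                    cls._client = build_client()
        return cls._client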
@ -15,49 +15,12 @@
|
||||
#
|
||||
import os
|
||||
import logging
|
||||
from common.config_utils import get_base_config, decrypt_database_config
|
||||
from common.file_utils import get_project_base_directory
|
||||
from common.misc_utils import pip_install_torch
|
||||
|
||||
# Server
|
||||
RAG_CONF_PATH = os.path.join(get_project_base_directory(), "conf")
|
||||
|
||||
# Get storage type and document engine from system environment variables
|
||||
STORAGE_IMPL_TYPE = os.getenv('STORAGE_IMPL', 'MINIO')
|
||||
DOC_ENGINE = os.getenv('DOC_ENGINE', 'elasticsearch')
|
||||
|
||||
ES = {}
|
||||
INFINITY = {}
|
||||
AZURE = {}
|
||||
S3 = {}
|
||||
MINIO = {}
|
||||
OSS = {}
|
||||
OS = {}
|
||||
|
||||
# Initialize the selected configuration data based on environment variables to solve the problem of initialization errors due to lack of configuration
|
||||
if DOC_ENGINE == 'elasticsearch':
|
||||
ES = get_base_config("es", {})
|
||||
elif DOC_ENGINE == 'opensearch':
|
||||
OS = get_base_config("os", {})
|
||||
elif DOC_ENGINE == 'infinity':
|
||||
INFINITY = get_base_config("infinity", {"uri": "infinity:23817"})
|
||||
|
||||
if STORAGE_IMPL_TYPE in ['AZURE_SPN', 'AZURE_SAS']:
|
||||
AZURE = get_base_config("azure", {})
|
||||
elif STORAGE_IMPL_TYPE == 'AWS_S3':
|
||||
S3 = get_base_config("s3", {})
|
||||
elif STORAGE_IMPL_TYPE == 'MINIO':
|
||||
MINIO = decrypt_database_config(name="minio")
|
||||
elif STORAGE_IMPL_TYPE == 'OSS':
|
||||
OSS = get_base_config("oss", {})
|
||||
|
||||
try:
|
||||
REDIS = decrypt_database_config(name="redis")
|
||||
except Exception:
|
||||
try:
|
||||
REDIS = get_base_config("redis", {})
|
||||
except Exception:
|
||||
REDIS = {}
|
||||
DOC_MAXIMUM_SIZE = int(os.environ.get("MAX_CONTENT_LENGTH", 128 * 1024 * 1024))
|
||||
DOC_BULK_SIZE = int(os.environ.get("DOC_BULK_SIZE", 4))
|
||||
EMBEDDING_BATCH_SIZE = int(os.environ.get("EMBEDDING_BATCH_SIZE", 16))
|
||||
|
||||
@ -26,13 +26,12 @@ import traceback
|
||||
|
||||
from api.db.services.connector_service import SyncLogsService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.utils.log_utils import init_root_logger, get_project_base_directory
|
||||
from api.utils.configs import show_configs
|
||||
from common.log_utils import init_root_logger
|
||||
from common.config_utils import show_configs
|
||||
from common.data_source import BlobStorageConnector
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime, timezone
|
||||
import tracemalloc
|
||||
import signal
|
||||
import trio
|
||||
import faulthandler
|
||||
@ -41,6 +40,7 @@ from api import settings
|
||||
from api.versions import get_ragflow_version
|
||||
from common.data_source.confluence_connector import ConfluenceConnector
|
||||
from common.data_source.utils import load_all_docs_from_checkpoint_connector
|
||||
from common.signal_utils import start_tracemalloc_and_snapshot, stop_tracemalloc
|
||||
|
||||
MAX_CONCURRENT_TASKS = int(os.environ.get('MAX_CONCURRENT_TASKS', "5"))
|
||||
task_limiter = trio.Semaphore(MAX_CONCURRENT_TASKS)
|
||||
@ -51,11 +51,39 @@ class SyncBase:
|
||||
self.conf = conf
|
||||
|
||||
async def __call__(self, task: dict):
|
||||
SyncLogsService.start(task["id"])
|
||||
SyncLogsService.start(task["id"], task["connector_id"])
|
||||
try:
|
||||
async with task_limiter:
|
||||
with trio.fail_after(task["timeout_secs"]):
|
||||
task["poll_range_start"] = await self._run(task)
|
||||
document_batch_generator = await self._generate(task)
|
||||
doc_num = 0
|
||||
next_update = datetime(1970, 1, 1, tzinfo=timezone.utc)
|
||||
if task["poll_range_start"]:
|
||||
next_update = task["poll_range_start"]
|
||||
for document_batch in document_batch_generator:
|
||||
min_update = min([doc.doc_updated_at for doc in document_batch])
|
||||
max_update = max([doc.doc_updated_at for doc in document_batch])
|
||||
next_update = max([next_update, max_update])
|
||||
docs = [{
|
||||
"id": doc.id,
|
||||
"connector_id": task["connector_id"],
|
||||
"source": FileSource.S3,
|
||||
"semantic_identifier": doc.semantic_identifier,
|
||||
"extension": doc.extension,
|
||||
"size_bytes": doc.size_bytes,
|
||||
"doc_updated_at": doc.doc_updated_at,
|
||||
"blob": doc.blob
|
||||
} for doc in document_batch]
|
||||
|
||||
e, kb = KnowledgebaseService.get_by_id(task["kb_id"])
|
||||
err, dids = SyncLogsService.duplicate_and_parse(kb, docs, task["tenant_id"], f"{FileSource.S3}/{task['connector_id']}")
|
||||
SyncLogsService.increase_docs(task["id"], min_update, max_update, len(docs), "\n".join(err), len(err))
|
||||
doc_num += len(docs)
|
||||
|
||||
logging.info("{} docs synchronized till {}".format(doc_num, next_update))
|
||||
SyncLogsService.done(task["id"], task["connector_id"])
|
||||
task["poll_range_start"] = next_update
|
||||
|
||||
except Exception as ex:
|
||||
msg = '\n'.join([
|
||||
''.join(traceback.format_exception_only(None, ex)).strip(),
|
||||
@ -65,12 +93,12 @@ class SyncBase:
|
||||
|
||||
SyncLogsService.schedule(task["connector_id"], task["kb_id"], task["poll_range_start"])
|
||||
|
||||
async def _run(self, task: dict):
|
||||
async def _generate(self, task: dict):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class S3(SyncBase):
|
||||
async def _run(self, task: dict):
|
||||
async def _generate(self, task: dict):
|
||||
self.connector = BlobStorageConnector(
|
||||
bucket_type=self.conf.get("bucket_type", "s3"),
|
||||
bucket_name=self.conf["bucket_name"],
|
||||
@ -85,40 +113,11 @@ class S3(SyncBase):
|
||||
self.conf["bucket_name"],
|
||||
begin_info
|
||||
))
|
||||
doc_num = 0
|
||||
next_update = datetime(1970, 1, 1, tzinfo=timezone.utc)
|
||||
if task["poll_range_start"]:
|
||||
next_update = task["poll_range_start"]
|
||||
for document_batch in document_batch_generator:
|
||||
min_update = min([doc.doc_updated_at for doc in document_batch])
|
||||
max_update = max([doc.doc_updated_at for doc in document_batch])
|
||||
next_update = max([next_update, max_update])
|
||||
docs = [{
|
||||
"id": doc.id,
|
||||
"connector_id": task["connector_id"],
|
||||
"source": FileSource.S3,
|
||||
"semantic_identifier": doc.semantic_identifier,
|
||||
"extension": doc.extension,
|
||||
"size_bytes": doc.size_bytes,
|
||||
"doc_updated_at": doc.doc_updated_at,
|
||||
"blob": doc.blob
|
||||
} for doc in document_batch]
|
||||
|
||||
e, kb = KnowledgebaseService.get_by_id(task["kb_id"])
|
||||
err, dids = SyncLogsService.duplicate_and_parse(kb, docs, task["tenant_id"], f"{FileSource.S3}/{task['connector_id']}")
|
||||
SyncLogsService.increase_docs(task["id"], min_update, max_update, len(docs), "\n".join(err), len(err))
|
||||
doc_num += len(docs)
|
||||
|
||||
logging.info("{} docs synchronized from {}: {} {}".format(doc_num, self.conf.get("bucket_type", "s3"),
|
||||
self.conf["bucket_name"],
|
||||
begin_info
|
||||
))
|
||||
SyncLogsService.done(task["id"])
|
||||
return next_update
|
||||
return document_batch_generator
|
||||
|
||||
|
||||
class Confluence(SyncBase):
    async def _run(self, task: dict):
    async def _generate(self, task: dict):
        from common.data_source.interfaces import StaticCredentialsProvider
        from common.data_source.config import DocumentSource

@ -156,85 +155,57 @@ class Confluence(SyncBase):
        )

        logging.info("Connect to Confluence: {} {}".format(self.conf["wiki_base"], begin_info))

        doc_num = 0
        next_update = datetime(1970, 1, 1, tzinfo=timezone.utc)
        if task["poll_range_start"]:
            next_update = task["poll_range_start"]

        for doc in document_generator:
            min_update = doc.doc_updated_at if doc.doc_updated_at else next_update
            max_update = doc.doc_updated_at if doc.doc_updated_at else next_update
            next_update = max([next_update, max_update])

            docs = [{
                "id": doc.id,
                "connector_id": task["connector_id"],
                "source": FileSource.CONFLUENCE,
                "semantic_identifier": doc.semantic_identifier,
                "extension": doc.extension,
                "size_bytes": doc.size_bytes,
                "doc_updated_at": doc.doc_updated_at,
                "blob": doc.blob
            }]

            e, kb = KnowledgebaseService.get_by_id(task["kb_id"])
            err, dids = SyncLogsService.duplicate_and_parse(kb, docs, task["tenant_id"], f"{FileSource.CONFLUENCE}/{task['connector_id']}")
            SyncLogsService.increase_docs(task["id"], min_update, max_update, len(docs), "\n".join(err), len(err))
            doc_num += len(docs)

        logging.info("{} docs synchronized from Confluence: {} {}".format(doc_num, self.conf["wiki_base"], begin_info))
        SyncLogsService.done(task["id"])
        return next_update
        return document_generator

class Notion(SyncBase):

    async def __call__(self, task: dict):
    async def _generate(self, task: dict):
        pass


class Discord(SyncBase):

    async def __call__(self, task: dict):
    async def _generate(self, task: dict):
        pass


class Gmail(SyncBase):

    async def __call__(self, task: dict):
    async def _generate(self, task: dict):
        pass


class GoogleDriver(SyncBase):

    async def __call__(self, task: dict):
    async def _generate(self, task: dict):
        pass


class Jira(SyncBase):

    async def __call__(self, task: dict):
    async def _generate(self, task: dict):
        pass


class SharePoint(SyncBase):

    async def __call__(self, task: dict):
    async def _generate(self, task: dict):
        pass


class Slack(SyncBase):

    async def __call__(self, task: dict):
    async def _generate(self, task: dict):
        pass


class Teams(SyncBase):

    async def __call__(self, task: dict):
    async def _generate(self, task: dict):
        pass

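Note: every connector above now overrides `_generate` instead of `__call__`, so the shared orchestration (logging, cursor handling, `SyncLogsService.done`) presumably stays in the base class and only document production is delegated to subclasses. A hedged sketch of that template-method shape; the `SyncBase` internals shown here are assumptions, not taken from the diff:

    class SyncBase:
        def __init__(self, conf: dict):
            self.conf = conf

        async def __call__(self, task: dict):
            # Shared driver: subclasses only decide what documents to produce.
            document_generator = await self._generate(task)
            return document_generator

        async def _generate(self, task: dict):
            raise NotImplementedError


    class Notion(SyncBase):
        async def _generate(self, task: dict):
            pass  # placeholder until the Notion connector is implemented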
func_factory = {
    FileSource.S3: S3,
    FileSource.NOTION: Notion,
@ -263,41 +234,6 @@ async def dispatch_tasks():
stop_event = threading.Event()

# SIGUSR1 handler: start tracemalloc and take snapshot
|
||||
def start_tracemalloc_and_snapshot(signum, frame):
|
||||
if not tracemalloc.is_tracing():
|
||||
logging.info("start tracemalloc")
|
||||
tracemalloc.start()
|
||||
else:
|
||||
logging.info("tracemalloc is already running")
|
||||
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
snapshot_file = f"snapshot_{timestamp}.trace"
|
||||
snapshot_file = os.path.abspath(os.path.join(get_project_base_directory(), "logs", f"{os.getpid()}_snapshot_{timestamp}.trace"))
|
||||
|
||||
snapshot = tracemalloc.take_snapshot()
|
||||
snapshot.dump(snapshot_file)
|
||||
current, peak = tracemalloc.get_traced_memory()
|
||||
if sys.platform == "win32":
|
||||
import psutil
|
||||
process = psutil.Process()
|
||||
max_rss = process.memory_info().rss / 1024
|
||||
else:
|
||||
import resource
|
||||
max_rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
|
||||
logging.info(f"taken snapshot {snapshot_file}. max RSS={max_rss / 1000:.2f} MB, current memory usage: {current / 10**6:.2f} MB, Peak memory usage: {peak / 10**6:.2f} MB")
|
||||
|
||||
|
||||
# SIGUSR2 handler: stop tracemalloc
|
||||
def stop_tracemalloc(signum, frame):
|
||||
if tracemalloc.is_tracing():
|
||||
logging.info("stop tracemalloc")
|
||||
tracemalloc.stop()
|
||||
else:
|
||||
logging.info("tracemalloc not running")
|
||||
|
||||
|
||||
|
||||
def signal_handler(sig, frame):
|
||||
logging.info("Received interrupt signal, shutting down...")
|
||||
stop_event.set()
|
||||
|
||||
@ -27,9 +27,8 @@ from api.db.services.canvas_service import UserCanvasService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
|
||||
from common.connection_utils import timeout
|
||||
from common.base64_image import image2id
|
||||
from rag.utils.base64_image import image2id
|
||||
from common.log_utils import init_root_logger
|
||||
from common.file_utils import get_project_base_directory
|
||||
from common.config_utils import show_configs
|
||||
from graphrag.general.index import run_graphrag_for_kb
|
||||
from graphrag.utils import get_llm_cache, set_llm_cache, get_tags_from_cache, set_tags_to_cache
|
||||
@ -45,7 +44,6 @@ import re
|
||||
from functools import partial
|
||||
from multiprocessing.context import TimeoutError
|
||||
from timeit import default_timer as timer
|
||||
import tracemalloc
|
||||
import signal
|
||||
import trio
|
||||
import exceptiongroup
|
||||
@ -69,6 +67,8 @@ from common.token_utils import num_tokens_from_string, truncate
|
||||
from rag.utils.redis_conn import REDIS_CONN, RedisDistributedLock
|
||||
from rag.utils.storage_factory import STORAGE_IMPL
|
||||
from graphrag.utils import chat_limiter
|
||||
from common.signal_utils import start_tracemalloc_and_snapshot, stop_tracemalloc
|
||||
from common import globals
|
||||
|
||||
BATCH_SIZE = 64
|
||||
|
||||
@ -129,40 +129,6 @@ def signal_handler(sig, frame):
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
# SIGUSR1 handler: start tracemalloc and take snapshot
|
||||
def start_tracemalloc_and_snapshot(signum, frame):
|
||||
if not tracemalloc.is_tracing():
|
||||
logging.info("start tracemalloc")
|
||||
tracemalloc.start()
|
||||
else:
|
||||
logging.info("tracemalloc is already running")
|
||||
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
snapshot_file = f"snapshot_{timestamp}.trace"
|
||||
snapshot_file = os.path.abspath(os.path.join(get_project_base_directory(), "logs", f"{os.getpid()}_snapshot_{timestamp}.trace"))
|
||||
|
||||
snapshot = tracemalloc.take_snapshot()
|
||||
snapshot.dump(snapshot_file)
|
||||
current, peak = tracemalloc.get_traced_memory()
|
||||
if sys.platform == "win32":
|
||||
import psutil
|
||||
process = psutil.Process()
|
||||
max_rss = process.memory_info().rss / 1024
|
||||
else:
|
||||
import resource
|
||||
max_rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
|
||||
logging.info(f"taken snapshot {snapshot_file}. max RSS={max_rss / 1000:.2f} MB, current memory usage: {current / 10**6:.2f} MB, Peak memory usage: {peak / 10**6:.2f} MB")
|
||||
|
||||
|
||||
# SIGUSR2 handler: stop tracemalloc
|
||||
def stop_tracemalloc(signum, frame):
|
||||
if tracemalloc.is_tracing():
|
||||
logging.info("stop tracemalloc")
|
||||
tracemalloc.stop()
|
||||
else:
|
||||
logging.info("tracemalloc not running")
|
||||
|
||||
|
||||
class TaskCanceledException(Exception):
|
||||
def __init__(self, msg):
|
||||
self.msg = msg
|
||||
@ -384,7 +350,7 @@ async def build_chunks(task, progress_callback):
|
||||
examples = []
|
||||
all_tags = get_tags_from_cache(kb_ids)
|
||||
if not all_tags:
|
||||
all_tags = settings.retriever.all_tags_in_portion(tenant_id, kb_ids, S)
|
||||
all_tags = globals.retriever.all_tags_in_portion(tenant_id, kb_ids, S)
|
||||
set_tags_to_cache(kb_ids, all_tags)
|
||||
else:
|
||||
all_tags = json.loads(all_tags)
|
||||
@ -397,7 +363,7 @@ async def build_chunks(task, progress_callback):
|
||||
if task_canceled:
|
||||
progress_callback(-1, msg="Task has been canceled.")
|
||||
return
|
||||
if settings.retriever.tag_content(tenant_id, kb_ids, d, all_tags, topn_tags=topn_tags, S=S) and len(d[TAG_FLD]) > 0:
|
||||
if globals.retriever.tag_content(tenant_id, kb_ids, d, all_tags, topn_tags=topn_tags, S=S) and len(d[TAG_FLD]) > 0:
|
||||
examples.append({"content": d["content_with_weight"], TAG_FLD: d[TAG_FLD]})
|
||||
else:
|
||||
docs_to_tag.append(d)
|
||||
@ -458,7 +424,7 @@ def build_TOC(task, docs, progress_callback):
|
||||
|
||||
def init_kb(row, vector_size: int):
|
||||
idxnm = search.index_name(row["tenant_id"])
|
||||
return settings.docStoreConn.createIdx(idxnm, row.get("kb_id", ""), vector_size)
|
||||
return globals.docStoreConn.createIdx(idxnm, row.get("kb_id", ""), vector_size)
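Note: the recurring `settings.X` → `globals.X` replacements in this file point at process-wide singletons (retriever, docStoreConn, DOC_ENGINE) that now live in `common.globals`. A hedged sketch of what such a module might look like; the attribute names come from the diff, the initialization helper is an assumption:

    # common/globals.py (sketch)
    # Process-wide handles that used to hang off rag.settings.
    retriever = None        # populated during startup, e.g. by settings.init_settings()
    docStoreConn = None     # Elasticsearch / Infinity / OpenSearch connection
    DOC_ENGINE = "elasticsearch"

    def init(retriever_impl, doc_store_conn, doc_engine: str):
        """Populate the module-level handles once at process start."""
        global retriever, docStoreConn, DOC_ENGINE
        retriever = retriever_impl
        docStoreConn = doc_store_conn
        DOC_ENGINE = doc_engine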
|
||||
|
||||
|
||||
async def embedding(docs, mdl, parser_config=None, callback=None):
|
||||
@ -682,7 +648,7 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si
|
||||
chunks = []
|
||||
vctr_nm = "q_%d_vec"%vector_size
|
||||
for doc_id in doc_ids:
|
||||
for d in settings.retriever.chunk_list(doc_id, row["tenant_id"], [str(row["kb_id"])],
|
||||
for d in globals.retriever.chunk_list(doc_id, row["tenant_id"], [str(row["kb_id"])],
|
||||
fields=["content_with_weight", vctr_nm],
|
||||
sort_by_position=True):
|
||||
chunks.append((d["content_with_weight"], np.array(d[vctr_nm])))
|
||||
@ -733,7 +699,7 @@ async def delete_image(kb_id, chunk_id):
|
||||
|
||||
async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_callback):
|
||||
for b in range(0, len(chunks), DOC_BULK_SIZE):
|
||||
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b:b + DOC_BULK_SIZE], search.index_name(task_tenant_id), task_dataset_id))
|
||||
doc_store_result = await trio.to_thread.run_sync(lambda: globals.docStoreConn.insert(chunks[b:b + DOC_BULK_SIZE], search.index_name(task_tenant_id), task_dataset_id))
|
||||
task_canceled = has_canceled(task_id)
|
||||
if task_canceled:
|
||||
progress_callback(-1, msg="Task has been canceled.")
|
||||
@ -750,7 +716,7 @@ async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_c
|
||||
TaskService.update_chunk_ids(task_id, chunk_ids_str)
|
||||
except DoesNotExist:
|
||||
logging.warning(f"do_handle_task update_chunk_ids failed since task {task_id} is unknown.")
|
||||
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(task_tenant_id), task_dataset_id))
|
||||
doc_store_result = await trio.to_thread.run_sync(lambda: globals.docStoreConn.delete({"id": chunk_ids}, search.index_name(task_tenant_id), task_dataset_id))
|
||||
async with trio.open_nursery() as nursery:
|
||||
for chunk_id in chunk_ids:
|
||||
nursery.start_soon(delete_image, task_dataset_id, chunk_id)
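Note: `insert_es` writes chunks in `DOC_BULK_SIZE` slices, pushes each blocking doc-store call onto a worker thread, and checks for cancellation between slices. A simplified sketch of that loop; the `insert(batch, index_name, dataset_id)` call shape is taken from the diff, everything else is illustrative:

    import trio

    DOC_BULK_SIZE = 4  # the real value lives in the executor's settings

    async def insert_in_batches(chunks, index_name, dataset_id, doc_store, is_canceled):
        for start in range(0, len(chunks), DOC_BULK_SIZE):
            batch = chunks[start:start + DOC_BULK_SIZE]
            # Run the blocking insert off the event loop.
            result = await trio.to_thread.run_sync(
                lambda: doc_store.insert(batch, index_name, dataset_id)
            )
            if is_canceled():
                return result  # stop early; the caller reports the cancellation
        return None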
|
||||
@ -786,7 +752,7 @@ async def do_handle_task(task):
|
||||
progress_callback = partial(set_progress, task_id, task_from_page, task_to_page)
|
||||
|
||||
# FIXME: workaround, Infinity doesn't support table parsing method, this check is to notify user
|
||||
lower_case_doc_engine = settings.DOC_ENGINE.lower()
|
||||
lower_case_doc_engine = globals.DOC_ENGINE.lower()
|
||||
if lower_case_doc_engine == 'infinity' and task['parser_id'].lower() == 'table':
|
||||
error_message = "Table parsing method is not supported by Infinity, please use other parsing methods or use Elasticsearch as the document engine."
|
||||
progress_callback(-1, msg=error_message)
|
||||
@ -1067,8 +1033,8 @@ async def main():
|
||||
logging.info(f'RAGFlow version: {get_ragflow_version()}')
|
||||
show_configs()
|
||||
settings.init_settings()
|
||||
from api.settings import EMBEDDING_CFG
|
||||
logging.info(f'api.settings.EMBEDDING_CFG: {EMBEDDING_CFG}')
|
||||
from common import globals
|
||||
logging.info(f'globals.EMBEDDING_CFG: {globals.EMBEDDING_CFG}')
|
||||
print_rag_settings()
|
||||
if sys.platform != "win32":
|
||||
signal.signal(signal.SIGUSR1, start_tracemalloc_and_snapshot)
|
||||
|
||||
@ -18,17 +18,17 @@ import logging
|
||||
import os
|
||||
import time
|
||||
from io import BytesIO
|
||||
from rag import settings
|
||||
from common.decorator import singleton
|
||||
from azure.storage.blob import ContainerClient
|
||||
from common import globals
|
||||
|
||||
|
||||
@singleton
|
||||
class RAGFlowAzureSasBlob:
|
||||
def __init__(self):
|
||||
self.conn = None
|
||||
self.container_url = os.getenv('CONTAINER_URL', settings.AZURE["container_url"])
|
||||
self.sas_token = os.getenv('SAS_TOKEN', settings.AZURE["sas_token"])
|
||||
self.container_url = os.getenv('CONTAINER_URL', globals.AZURE["container_url"])
|
||||
self.sas_token = os.getenv('SAS_TOKEN', globals.AZURE["sas_token"])
|
||||
self.__open__()
|
||||
|
||||
def __open__(self):
|
||||
|
||||
@ -17,21 +17,21 @@
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from rag import settings
|
||||
from common.decorator import singleton
|
||||
from azure.identity import ClientSecretCredential, AzureAuthorityHosts
|
||||
from azure.storage.filedatalake import FileSystemClient
|
||||
from common import globals
|
||||
|
||||
|
||||
@singleton
|
||||
class RAGFlowAzureSpnBlob:
|
||||
def __init__(self):
|
||||
self.conn = None
|
||||
self.account_url = os.getenv('ACCOUNT_URL', settings.AZURE["account_url"])
|
||||
self.client_id = os.getenv('CLIENT_ID', settings.AZURE["client_id"])
|
||||
self.secret = os.getenv('SECRET', settings.AZURE["secret"])
|
||||
self.tenant_id = os.getenv('TENANT_ID', settings.AZURE["tenant_id"])
|
||||
self.container_name = os.getenv('CONTAINER_NAME', settings.AZURE["container_name"])
|
||||
self.account_url = os.getenv('ACCOUNT_URL', globals.AZURE["account_url"])
|
||||
self.client_id = os.getenv('CLIENT_ID', globals.AZURE["client_id"])
|
||||
self.secret = os.getenv('SECRET', globals.AZURE["secret"])
|
||||
self.tenant_id = os.getenv('TENANT_ID', globals.AZURE["tenant_id"])
|
||||
self.container_name = os.getenv('CONTAINER_NAME', globals.AZURE["container_name"])
|
||||
self.__open__()
|
||||
|
||||
def __open__(self):
|
||||
|
||||
@ -24,7 +24,6 @@ import copy
|
||||
from elasticsearch import Elasticsearch, NotFoundError
|
||||
from elasticsearch_dsl import UpdateByQuery, Q, Search, Index
|
||||
from elastic_transport import ConnectionTimeout
|
||||
from rag import settings
|
||||
from rag.settings import TAG_FLD, PAGERANK_FLD
|
||||
from common.decorator import singleton
|
||||
from common.file_utils import get_project_base_directory
|
||||
@ -33,6 +32,7 @@ from rag.utils.doc_store_conn import DocStoreConnection, MatchExpr, OrderByExpr,
|
||||
FusionExpr
|
||||
from rag.nlp import is_english, rag_tokenizer
|
||||
from common.float_utils import get_float
|
||||
from common import globals
|
||||
|
||||
ATTEMPT_TIME = 2
|
||||
|
||||
@ -43,17 +43,17 @@ logger = logging.getLogger('ragflow.es_conn')
|
||||
class ESConnection(DocStoreConnection):
|
||||
def __init__(self):
|
||||
self.info = {}
|
||||
logger.info(f"Use Elasticsearch {settings.ES['hosts']} as the doc engine.")
|
||||
logger.info(f"Use Elasticsearch {globals.ES['hosts']} as the doc engine.")
|
||||
for _ in range(ATTEMPT_TIME):
|
||||
try:
|
||||
if self._connect():
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning(f"{str(e)}. Waiting Elasticsearch {settings.ES['hosts']} to be healthy.")
|
||||
logger.warning(f"{str(e)}. Waiting Elasticsearch {globals.ES['hosts']} to be healthy.")
|
||||
time.sleep(5)
|
||||
|
||||
if not self.es.ping():
|
||||
msg = f"Elasticsearch {settings.ES['hosts']} is unhealthy in 120s."
|
||||
msg = f"Elasticsearch {globals.ES['hosts']} is unhealthy in 120s."
|
||||
logger.error(msg)
|
||||
raise Exception(msg)
|
||||
v = self.info.get("version", {"number": "8.11.3"})
|
||||
@ -68,14 +68,14 @@ class ESConnection(DocStoreConnection):
|
||||
logger.error(msg)
|
||||
raise Exception(msg)
|
||||
self.mapping = json.load(open(fp_mapping, "r"))
|
||||
logger.info(f"Elasticsearch {settings.ES['hosts']} is healthy.")
|
||||
logger.info(f"Elasticsearch {globals.ES['hosts']} is healthy.")
|
||||
|
||||
def _connect(self):
|
||||
self.es = Elasticsearch(
|
||||
settings.ES["hosts"].split(","),
|
||||
basic_auth=(settings.ES["username"], settings.ES[
|
||||
"password"]) if "username" in settings.ES and "password" in settings.ES else None,
|
||||
verify_certs= settings.ES.get("verify_certs", False),
|
||||
globals.ES["hosts"].split(","),
|
||||
basic_auth=(globals.ES["username"], globals.ES[
|
||||
"password"]) if "username" in globals.ES and "password" in globals.ES else None,
|
||||
verify_certs= globals.ES.get("verify_certs", False),
|
||||
timeout=600 )
|
||||
if self.es:
|
||||
self.info = self.es.info()
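Note: the connection classes in these hunks share a retry-until-healthy pattern: attempt to connect a fixed number of times, sleep between failures, then fail loudly if a final ping still does not succeed. A condensed sketch assuming a client object exposing `ping()` (the 2 attempts and 5-second sleep match the constants visible in the diff; the helper itself is illustrative):

    import logging
    import time

    ATTEMPT_TIME = 2

    def connect_with_retry(make_client, hosts: str):
        client = None
        for _ in range(ATTEMPT_TIME):
            try:
                client = make_client()
                if client.ping():
                    break
            except Exception as e:
                logging.warning(f"{e}. Waiting for {hosts} to be healthy.")
            time.sleep(5)
        if client is None or not client.ping():
            raise Exception(f"{hosts} is unhealthy.")
        return client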
|
||||
|
||||
@ -25,11 +25,11 @@ from infinity.common import ConflictType, InfinityException, SortType
|
||||
from infinity.index import IndexInfo, IndexType
|
||||
from infinity.connection_pool import ConnectionPool
|
||||
from infinity.errors import ErrorCode
|
||||
from rag import settings
|
||||
from rag.settings import PAGERANK_FLD, TAG_FLD
|
||||
from common.decorator import singleton
|
||||
import pandas as pd
|
||||
from common.file_utils import get_project_base_directory
|
||||
from common import globals
|
||||
from rag.nlp import is_english
|
||||
|
||||
from rag.utils.doc_store_conn import (
|
||||
@ -130,8 +130,8 @@ def concat_dataframes(df_list: list[pd.DataFrame], selectFields: list[str]) -> p
|
||||
@singleton
|
||||
class InfinityConnection(DocStoreConnection):
|
||||
def __init__(self):
|
||||
self.dbName = settings.INFINITY.get("db_name", "default_db")
|
||||
infinity_uri = settings.INFINITY["uri"]
|
||||
self.dbName = globals.INFINITY.get("db_name", "default_db")
|
||||
infinity_uri = globals.INFINITY["uri"]
|
||||
if ":" in infinity_uri:
|
||||
host, port = infinity_uri.split(":")
|
||||
infinity_uri = infinity.common.NetworkAddress(host, int(port))
|
||||
|
||||
@ -20,8 +20,8 @@ from minio import Minio
|
||||
from minio.commonconfig import CopySource
|
||||
from minio.error import S3Error
|
||||
from io import BytesIO
|
||||
from rag import settings
|
||||
from common.decorator import singleton
|
||||
from common import globals
|
||||
|
||||
|
||||
@singleton
|
||||
@ -38,14 +38,14 @@ class RAGFlowMinio:
|
||||
pass
|
||||
|
||||
try:
|
||||
self.conn = Minio(settings.MINIO["host"],
|
||||
access_key=settings.MINIO["user"],
|
||||
secret_key=settings.MINIO["password"],
|
||||
self.conn = Minio(globals.MINIO["host"],
|
||||
access_key=globals.MINIO["user"],
|
||||
secret_key=globals.MINIO["password"],
|
||||
secure=False
|
||||
)
|
||||
except Exception:
|
||||
logging.exception(
|
||||
"Fail to connect %s " % settings.MINIO["host"])
|
||||
"Fail to connect %s " % globals.MINIO["host"])
|
||||
|
||||
def __close__(self):
|
||||
del self.conn
|
||||
|
||||
@ -24,13 +24,13 @@ import copy
|
||||
from opensearchpy import OpenSearch, NotFoundError
|
||||
from opensearchpy import UpdateByQuery, Q, Search, Index
|
||||
from opensearchpy import ConnectionTimeout
|
||||
from rag import settings
|
||||
from rag.settings import TAG_FLD, PAGERANK_FLD
|
||||
from common.decorator import singleton
|
||||
from common.file_utils import get_project_base_directory
|
||||
from rag.utils.doc_store_conn import DocStoreConnection, MatchExpr, OrderByExpr, MatchTextExpr, MatchDenseExpr, \
|
||||
FusionExpr
|
||||
from rag.nlp import is_english, rag_tokenizer
|
||||
from common import globals
|
||||
|
||||
ATTEMPT_TIME = 2
|
||||
|
||||
@ -41,13 +41,13 @@ logger = logging.getLogger('ragflow.opensearch_conn')
|
||||
class OSConnection(DocStoreConnection):
|
||||
def __init__(self):
|
||||
self.info = {}
|
||||
logger.info(f"Use OpenSearch {settings.OS['hosts']} as the doc engine.")
|
||||
logger.info(f"Use OpenSearch {globals.OS['hosts']} as the doc engine.")
|
||||
for _ in range(ATTEMPT_TIME):
|
||||
try:
|
||||
self.os = OpenSearch(
|
||||
settings.OS["hosts"].split(","),
|
||||
http_auth=(settings.OS["username"], settings.OS[
|
||||
"password"]) if "username" in settings.OS and "password" in settings.OS else None,
|
||||
globals.OS["hosts"].split(","),
|
||||
http_auth=(globals.OS["username"], globals.OS[
|
||||
"password"]) if "username" in globals.OS and "password" in globals.OS else None,
|
||||
verify_certs=False,
|
||||
timeout=600
|
||||
)
|
||||
@ -55,10 +55,10 @@ class OSConnection(DocStoreConnection):
|
||||
self.info = self.os.info()
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning(f"{str(e)}. Waiting OpenSearch {settings.OS['hosts']} to be healthy.")
|
||||
logger.warning(f"{str(e)}. Waiting OpenSearch {globals.OS['hosts']} to be healthy.")
|
||||
time.sleep(5)
|
||||
if not self.os.ping():
|
||||
msg = f"OpenSearch {settings.OS['hosts']} is unhealthy in 120s."
|
||||
msg = f"OpenSearch {globals.OS['hosts']} is unhealthy in 120s."
|
||||
logger.error(msg)
|
||||
raise Exception(msg)
|
||||
v = self.info.get("version", {"number": "2.18.0"})
|
||||
@ -73,7 +73,7 @@ class OSConnection(DocStoreConnection):
|
||||
logger.error(msg)
|
||||
raise Exception(msg)
|
||||
self.mapping = json.load(open(fp_mapping, "r"))
|
||||
logger.info(f"OpenSearch {settings.OS['hosts']} is healthy.")
|
||||
logger.info(f"OpenSearch {globals.OS['hosts']} is healthy.")
|
||||
|
||||
"""
|
||||
Database operations
|
||||
|
||||
@ -20,14 +20,14 @@ from botocore.config import Config
|
||||
import time
|
||||
from io import BytesIO
|
||||
from common.decorator import singleton
|
||||
from rag import settings
|
||||
from common import globals
|
||||
|
||||
|
||||
@singleton
|
||||
class RAGFlowOSS:
|
||||
def __init__(self):
|
||||
self.conn = None
|
||||
self.oss_config = settings.OSS
|
||||
self.oss_config = globals.OSS
|
||||
self.access_key = self.oss_config.get('access_key', None)
|
||||
self.secret_key = self.oss_config.get('secret_key', None)
|
||||
self.endpoint_url = self.oss_config.get('endpoint_url', None)
|
||||
|
||||
@ -19,8 +19,8 @@ import json
|
||||
import uuid
|
||||
|
||||
import valkey as redis
|
||||
from rag import settings
|
||||
from common.decorator import singleton
|
||||
from common import globals
|
||||
from valkey.lock import Lock
|
||||
import trio
|
||||
|
||||
@ -61,7 +61,7 @@ class RedisDB:
|
||||
|
||||
def __init__(self):
|
||||
self.REDIS = None
|
||||
self.config = settings.REDIS
|
||||
self.config = globals.REDIS
|
||||
self.__open__()
|
||||
|
||||
def register_scripts(self) -> None:
|
||||
|
||||
@ -21,13 +21,14 @@ from botocore.config import Config
|
||||
import time
|
||||
from io import BytesIO
|
||||
from common.decorator import singleton
|
||||
from rag import settings
|
||||
from common import globals
|
||||
|
||||
|
||||
@singleton
|
||||
class RAGFlowS3:
|
||||
def __init__(self):
|
||||
self.conn = None
|
||||
self.s3_config = settings.S3
|
||||
self.s3_config = globals.S3
|
||||
self.access_key = self.s3_config.get('access_key', None)
|
||||
self.secret_key = self.s3_config.get('secret_key', None)
|
||||
self.session_token = self.s3_config.get('session_token', None)
|
||||
|
||||
@ -15,10 +15,10 @@ import {
|
||||
} from '@/components/ui/form';
|
||||
import { Input } from '@/components/ui/input';
|
||||
import { Separator } from '@/components/ui/separator';
|
||||
import { SwitchOperatorOptions } from '@/constants/agent';
|
||||
import { useBuildSwitchOperatorOptions } from '@/hooks/logic-hooks/use-build-operator-options';
|
||||
import { useFetchKnowledgeMetadata } from '@/hooks/use-knowledge-request';
|
||||
import { SwitchOperatorOptions } from '@/pages/agent/constant';
|
||||
import { PromptEditor } from '@/pages/agent/form/components/prompt-editor';
|
||||
import { useBuildSwitchOperatorOptions } from '@/pages/agent/form/switch-form';
|
||||
import { Plus, X } from 'lucide-react';
|
||||
import { useCallback } from 'react';
|
||||
import { useFieldArray, useFormContext } from 'react-hook-form';
|
||||
|
||||
@ -13,6 +13,7 @@ const ScrollBar = React.forwardRef<
|
||||
ref={ref}
|
||||
orientation={orientation}
|
||||
className={cn(
|
||||
'z-[100]',
|
||||
'flex touch-none select-none transition-colors',
|
||||
orientation === 'vertical' &&
|
||||
'h-full w-2.5 border-l border-l-transparent p-[1px]',
|
||||
@ -22,7 +23,7 @@ const ScrollBar = React.forwardRef<
|
||||
)}
|
||||
{...props}
|
||||
>
|
||||
<ScrollAreaPrimitive.ScrollAreaThumb className="relative flex-1 rounded-full bg-border" />
|
||||
<ScrollAreaPrimitive.ScrollAreaThumb className="relative flex-1 rounded-full bg-border backdrop-blur-md" />
|
||||
</ScrollAreaPrimitive.ScrollAreaScrollbar>
|
||||
));
|
||||
ScrollBar.displayName = ScrollAreaPrimitive.ScrollAreaScrollbar.displayName;
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import { setInitialChatVariableEnabledFieldValue } from '@/utils/chat';
|
||||
import { Circle, CircleSlash2 } from 'lucide-react';
|
||||
import { ChatVariableEnabledField, variableEnabledFieldMap } from './chat';
|
||||
|
||||
export enum ProgrammingLanguage {
|
||||
@ -117,3 +118,53 @@ export enum Operator {
|
||||
HierarchicalMerger = 'HierarchicalMerger',
|
||||
Extractor = 'Extractor',
|
||||
}
|
||||
|
||||
export enum ComparisonOperator {
|
||||
Equal = '=',
|
||||
NotEqual = '≠',
|
||||
GreatThan = '>',
|
||||
GreatEqual = '≥',
|
||||
LessThan = '<',
|
||||
LessEqual = '≤',
|
||||
Contains = 'contains',
|
||||
NotContains = 'not contains',
|
||||
StartWith = 'start with',
|
||||
EndWith = 'end with',
|
||||
Empty = 'empty',
|
||||
NotEmpty = 'not empty',
|
||||
}
|
||||
|
||||
export const SwitchOperatorOptions = [
|
||||
{ value: ComparisonOperator.Equal, label: 'equal', icon: 'equal' },
|
||||
{ value: ComparisonOperator.NotEqual, label: 'notEqual', icon: 'not-equals' },
|
||||
{ value: ComparisonOperator.GreatThan, label: 'gt', icon: 'Less' },
|
||||
{
|
||||
value: ComparisonOperator.GreatEqual,
|
||||
label: 'ge',
|
||||
icon: 'Greater-or-equal',
|
||||
},
|
||||
{ value: ComparisonOperator.LessThan, label: 'lt', icon: 'Less' },
|
||||
{ value: ComparisonOperator.LessEqual, label: 'le', icon: 'less-or-equal' },
|
||||
{ value: ComparisonOperator.Contains, label: 'contains', icon: 'Contains' },
|
||||
{
|
||||
value: ComparisonOperator.NotContains,
|
||||
label: 'notContains',
|
||||
icon: 'not-contains',
|
||||
},
|
||||
{
|
||||
value: ComparisonOperator.StartWith,
|
||||
label: 'startWith',
|
||||
icon: 'list-start',
|
||||
},
|
||||
{ value: ComparisonOperator.EndWith, label: 'endWith', icon: 'list-end' },
|
||||
{
|
||||
value: ComparisonOperator.Empty,
|
||||
label: 'empty',
|
||||
icon: <Circle className="size-4" />,
|
||||
},
|
||||
{
|
||||
value: ComparisonOperator.NotEmpty,
|
||||
label: 'notEmpty',
|
||||
icon: <CircleSlash2 className="size-4" />,
|
||||
},
|
||||
];
|
||||
web/src/hooks/logic-hooks/use-build-operator-options.tsx (new file, 45 lines)
@ -0,0 +1,45 @@
|
||||
import { IconFont } from '@/components/icon-font';
|
||||
import { ComparisonOperator, SwitchOperatorOptions } from '@/constants/agent';
|
||||
import { cn } from '@/lib/utils';
|
||||
import { useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
export const LogicalOperatorIcon = function OperatorIcon({
|
||||
icon,
|
||||
value,
|
||||
}: Omit<(typeof SwitchOperatorOptions)[0], 'label'>) {
|
||||
if (typeof icon === 'string') {
|
||||
return (
|
||||
<IconFont
|
||||
name={icon}
|
||||
className={cn('size-4', {
|
||||
'rotate-180': value === '>',
|
||||
})}
|
||||
></IconFont>
|
||||
);
|
||||
}
|
||||
return icon;
|
||||
};
|
||||
|
||||
export function useBuildSwitchOperatorOptions(
|
||||
subset: ComparisonOperator[] = [],
|
||||
) {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const switchOperatorOptions = useMemo(() => {
|
||||
return SwitchOperatorOptions.filter((x) =>
|
||||
subset.some((y) => y === x.value),
|
||||
).map((x) => ({
|
||||
value: x.value,
|
||||
icon: (
|
||||
<LogicalOperatorIcon
|
||||
icon={x.icon}
|
||||
value={x.value}
|
||||
></LogicalOperatorIcon>
|
||||
),
|
||||
label: t(`flow.switchOperatorOptions.${x.label}`),
|
||||
}));
|
||||
}, [t]);
|
||||
|
||||
return switchOperatorOptions;
|
||||
}
|
||||
@ -1980,6 +1980,10 @@ Important structured information may include: names, dates, locations, events, k
|
||||
deleteRole: 'Delete role',
|
||||
deleteRoleConfirmation:
|
||||
'Are you sure you want to delete this role? This action cannot be undone.',
|
||||
|
||||
alive: 'Alive',
|
||||
timeout: 'Timeout',
|
||||
fail: 'Fail',
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
@ -1,8 +1,11 @@
|
||||
import Spotlight from '@/components/spotlight';
|
||||
import { Card, CardContent } from '@/components/ui/card';
|
||||
|
||||
function AdminMonitoring() {
|
||||
return (
|
||||
<Card className="!shadow-none h-full border border-border-button bg-transparent rounded-xl overflow-x-hidden overflow-y-auto">
|
||||
<Card className="!shadow-none relative h-full border border-border-button bg-transparent rounded-xl overflow-x-hidden overflow-y-auto">
|
||||
<Spotlight />
|
||||
|
||||
<CardContent className="size-full p-0">
|
||||
<iframe />
|
||||
</CardContent>
|
||||
|
||||
@ -5,6 +5,7 @@ import { useTranslation } from 'react-i18next';
|
||||
|
||||
import { useMutation, useQuery, useQueryClient } from '@tanstack/react-query';
|
||||
|
||||
import Spotlight from '@/components/spotlight';
|
||||
import { Button } from '@/components/ui/button';
|
||||
import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card';
|
||||
import {
|
||||
@ -29,7 +30,6 @@ import {
|
||||
import { LucideEdit3, LucideTrash2, LucideUserPlus } from 'lucide-react';
|
||||
|
||||
import {
|
||||
AdminService,
|
||||
assignRolePermissions,
|
||||
createRole,
|
||||
deleteRole,
|
||||
@ -149,7 +149,9 @@ function AdminRoles() {
|
||||
|
||||
return (
|
||||
<>
|
||||
<Card className="!shadow-none w-full h-full border border-border-button bg-transparent rounded-xl">
|
||||
<Card className="!shadow-none relative w-full h-full border border-border-button bg-transparent rounded-xl">
|
||||
<Spotlight />
|
||||
|
||||
<ScrollArea className="size-full">
|
||||
<CardHeader className="space-y-0 flex flex-row justify-between items-center">
|
||||
<CardTitle>{t('admin.roles')}</CardTitle>
|
||||
|
||||
@ -23,6 +23,7 @@ import { useQuery } from '@tanstack/react-query';
|
||||
|
||||
import { cn } from '@/lib/utils';
|
||||
|
||||
import Spotlight from '@/components/spotlight';
|
||||
import { TableEmpty } from '@/components/table-skeleton';
|
||||
import { Badge } from '@/components/ui/badge';
|
||||
import { Button } from '@/components/ui/button';
|
||||
@ -151,11 +152,11 @@ function AdminServiceStatus() {
|
||||
alive: 'bg-state-success-5 text-state-success',
|
||||
timeout: 'bg-state-error-5 text-state-error',
|
||||
fail: 'bg-gray-500/5 text-text-disable',
|
||||
}[cell.getValue<string>()],
|
||||
}[cell.getValue()],
|
||||
)}
|
||||
>
|
||||
<LucideDot className="size-[1em] stroke-[8] mr-1" />
|
||||
{cell.getValue()}
|
||||
{t(`admin.${cell.getValue()}`)}
|
||||
</Badge>
|
||||
),
|
||||
enableSorting: false,
|
||||
@ -215,7 +216,9 @@ function AdminServiceStatus() {
|
||||
|
||||
return (
|
||||
<>
|
||||
<Card className="!shadow-none h-full border border-border-button bg-transparent rounded-xl">
|
||||
<Card className="!shadow-none relative h-full border border-border-button bg-transparent rounded-xl">
|
||||
<Spotlight />
|
||||
|
||||
<ScrollArea className="size-full">
|
||||
<CardHeader className="space-y-0 flex flex-row justify-between items-center">
|
||||
<CardTitle>{t('admin.serviceStatus')}</CardTitle>
|
||||
@ -252,7 +255,7 @@ function AdminServiceStatus() {
|
||||
table.getColumn('service_type')!?.setFilterValue
|
||||
}
|
||||
>
|
||||
<Label className="space-x-2">
|
||||
<Label className="flex items-center space-x-2">
|
||||
<RadioGroupItem
|
||||
className="bg-bg-input border-border-button"
|
||||
value=""
|
||||
@ -261,7 +264,10 @@ function AdminServiceStatus() {
|
||||
</Label>
|
||||
|
||||
{SERVICE_TYPE_FILTER_OPTIONS.map(({ label, value }) => (
|
||||
<Label key={value} className="space-x-2">
|
||||
<Label
|
||||
key={value}
|
||||
className="flex items-center space-x-2"
|
||||
>
|
||||
<RadioGroupItem
|
||||
className="bg-bg-input border-border-button"
|
||||
value={value}
|
||||
|
||||
@ -16,6 +16,7 @@ import {
|
||||
import { cn } from '@/lib/utils';
|
||||
import { Routes } from '@/routes';
|
||||
|
||||
import Spotlight from '@/components/spotlight';
|
||||
import { Avatar } from '@/components/ui/avatar';
|
||||
import { Badge } from '@/components/ui/badge';
|
||||
import { Button } from '@/components/ui/button';
|
||||
@ -322,7 +323,9 @@ function AdminUserDetail() {
|
||||
</Button>
|
||||
</nav>
|
||||
|
||||
<Card className="!shadow-none h-0 basis-0 grow flex flex-col bg-transparent border dark:border-border-button overflow-hidden">
|
||||
<Card className="!shadow-none relative h-0 basis-0 grow flex flex-col bg-transparent border dark:border-border-button overflow-hidden">
|
||||
<Spotlight />
|
||||
|
||||
<CardHeader className="pb-10 border-b dark:border-border-button space-y-8">
|
||||
<section className="flex items-center gap-4 text-base">
|
||||
<Avatar className="justify-center items-center bg-bg-group uppercase">
|
||||
|
||||
@ -25,6 +25,7 @@ import {
|
||||
import { cn } from '@/lib/utils';
|
||||
import { rsaPsw } from '@/utils';
|
||||
|
||||
import Spotlight from '@/components/spotlight';
|
||||
import { TableEmpty } from '@/components/table-skeleton';
|
||||
import { Badge } from '@/components/ui/badge';
|
||||
import { Button } from '@/components/ui/button';
|
||||
@ -357,7 +358,9 @@ function AdminUserManagement() {
|
||||
|
||||
return (
|
||||
<>
|
||||
<Card className="!shadow-none h-full border border-border-button bg-transparent rounded-xl overflow-x-hidden overflow-y-auto">
|
||||
<Card className="!shadow-none relative h-full border border-border-button bg-transparent rounded-xl overflow-x-hidden overflow-y-auto">
|
||||
<Spotlight />
|
||||
|
||||
<ScrollArea className="size-full">
|
||||
<CardHeader className="space-y-0 flex flex-row justify-between items-center">
|
||||
<CardTitle>{t('admin.userManagement')}</CardTitle>
|
||||
@ -396,13 +399,16 @@ function AdminUserManagement() {
|
||||
table.getColumn('role')?.setFilterValue(value)
|
||||
}
|
||||
>
|
||||
<Label className="space-x-2">
|
||||
<Label className="flex items-center space-x-2">
|
||||
<RadioGroupItem value="" />
|
||||
<span>{t('admin.all')}</span>
|
||||
</Label>
|
||||
|
||||
{roleList?.map(({ id, role_name }) => (
|
||||
<Label key={id} className="space-x-2">
|
||||
<Label
|
||||
key={id}
|
||||
className="flex items-center space-x-2"
|
||||
>
|
||||
<RadioGroupItem
|
||||
className="bg-bg-input border-border-button"
|
||||
value={role_name}
|
||||
@ -429,7 +435,10 @@ function AdminUserManagement() {
|
||||
}
|
||||
>
|
||||
{STATUS_FILTER_OPTIONS.map(({ label, value }) => (
|
||||
<Label key={value} className="space-x-2">
|
||||
<Label
|
||||
key={value}
|
||||
className="flex items-center space-x-2"
|
||||
>
|
||||
<RadioGroupItem
|
||||
className="bg-bg-input border-border-button"
|
||||
value={value}
|
||||
|
||||
@ -22,6 +22,7 @@ import {
|
||||
LucideUserPen,
|
||||
} from 'lucide-react';
|
||||
|
||||
import Spotlight from '@/components/spotlight';
|
||||
import { TableEmpty } from '@/components/table-skeleton';
|
||||
import { Button } from '@/components/ui/button';
|
||||
import {
|
||||
@ -58,7 +59,6 @@ import {
|
||||
importWhitelistFromExcel,
|
||||
listWhitelist,
|
||||
updateWhitelistEntry,
|
||||
type AdminService,
|
||||
} from '@/services/admin-service';
|
||||
|
||||
import { EMPTY_DATA, createFuzzySearchFn, getSortIcon } from './utils';
|
||||
@ -68,8 +68,6 @@ import useImportExcelForm, {
|
||||
ImportExcelFormData,
|
||||
} from './forms/import-excel-form';
|
||||
|
||||
// #endregion
|
||||
|
||||
const columnHelper = createColumnHelper<AdminService.ListWhitelistItem>();
|
||||
const globalFilterFn = createFuzzySearchFn<AdminService.ListWhitelistItem>([
|
||||
'email',
|
||||
@ -233,7 +231,9 @@ function AdminWhitelist() {
|
||||
|
||||
return (
|
||||
<>
|
||||
<Card className="!shadow-none h-full border border-border-button bg-transparent rounded-xl overflow-x-hidden overflow-y-auto">
|
||||
<Card className="!shadow-none relative h-full border border-border-button bg-transparent rounded-xl overflow-x-hidden overflow-y-auto">
|
||||
<Spotlight />
|
||||
|
||||
<ScrollArea className="size-full">
|
||||
<CardHeader className="space-y-0 flex flex-row justify-between items-center">
|
||||
<CardTitle>{t('admin.whitelistManagement')}</CardTitle>
|
||||
|
||||
@ -1,9 +1,9 @@
|
||||
import { Card, CardContent } from '@/components/ui/card';
|
||||
import { SwitchOperatorOptions } from '@/constants/agent';
|
||||
import { LogicalOperatorIcon } from '@/hooks/logic-hooks/use-build-operator-options';
|
||||
import { ISwitchCondition, ISwitchNode } from '@/interfaces/database/flow';
|
||||
import { NodeProps, Position } from '@xyflow/react';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { SwitchOperatorOptions } from '../../constant';
|
||||
import { LogicalOperatorIcon } from '../../form/switch-form';
|
||||
import { useGetVariableLabelByValue } from '../../hooks/use-get-begin-query';
|
||||
import { CommonHandle, LeftEndHandle } from './handle';
|
||||
import { RightHandleStyle } from './handle-icon';
|
||||
|
||||
@ -6,8 +6,10 @@ import {
|
||||
AgentGlobals,
|
||||
AgentGlobalsSysQueryWithBrace,
|
||||
CodeTemplateStrMap,
|
||||
ComparisonOperator,
|
||||
Operator,
|
||||
ProgrammingLanguage,
|
||||
SwitchOperatorOptions,
|
||||
initialLlmBaseValues,
|
||||
} from '@/constants/agent';
|
||||
export { Operator } from '@/constants/agent';
|
||||
@ -35,8 +37,6 @@ export enum PromptRole {
|
||||
}
|
||||
|
||||
import {
|
||||
Circle,
|
||||
CircleSlash2,
|
||||
CloudUpload,
|
||||
ListOrdered,
|
||||
OptionIcon,
|
||||
@ -166,27 +166,12 @@ export const componentMenuList = [
|
||||
},
|
||||
];
|
||||
|
||||
export const SwitchOperatorOptions = [
|
||||
{ value: '=', label: 'equal', icon: 'equal' },
|
||||
{ value: '≠', label: 'notEqual', icon: 'not-equals' },
|
||||
{ value: '>', label: 'gt', icon: 'Less' },
|
||||
{ value: '≥', label: 'ge', icon: 'Greater-or-equal' },
|
||||
{ value: '<', label: 'lt', icon: 'Less' },
|
||||
{ value: '≤', label: 'le', icon: 'less-or-equal' },
|
||||
{ value: 'contains', label: 'contains', icon: 'Contains' },
|
||||
{ value: 'not contains', label: 'notContains', icon: 'not-contains' },
|
||||
{ value: 'start with', label: 'startWith', icon: 'list-start' },
|
||||
{ value: 'end with', label: 'endWith', icon: 'list-end' },
|
||||
{
|
||||
value: 'empty',
|
||||
label: 'empty',
|
||||
icon: <Circle className="size-4" />,
|
||||
},
|
||||
{
|
||||
value: 'not empty',
|
||||
label: 'notEmpty',
|
||||
icon: <CircleSlash2 className="size-4" />,
|
||||
},
|
||||
export const DataOperationsOperatorOptions = [
|
||||
ComparisonOperator.Equal,
|
||||
ComparisonOperator.NotEqual,
|
||||
ComparisonOperator.Contains,
|
||||
ComparisonOperator.StartWith,
|
||||
ComparisonOperator.EndWith,
|
||||
];
|
||||
|
||||
export const SwitchElseTo = 'end_cpn_ids';
|
||||
@ -716,16 +701,17 @@ export const initialPlaceholderValues = {
|
||||
};
|
||||
|
||||
export enum Operations {
|
||||
SelectKeys = 'select keys',
|
||||
LiteralEval = 'literal eval',
|
||||
SelectKeys = 'select_keys',
|
||||
LiteralEval = 'literal_eval',
|
||||
Combine = 'combine',
|
||||
FilterValues = 'filter values',
|
||||
AppendOrUpdate = 'append or update',
|
||||
RemoveKeys = 'remove keys',
|
||||
RenameKeys = 'rename keys',
|
||||
FilterValues = 'filter_values',
|
||||
AppendOrUpdate = 'append_or_update',
|
||||
RemoveKeys = 'remove_keys',
|
||||
RenameKeys = 'rename_keys',
|
||||
}
|
||||
|
||||
export const initialDataOperationsValues = {
|
||||
query: [],
|
||||
operations: Operations.SelectKeys,
|
||||
outputs: {
|
||||
result: {
|
||||
|
||||
web/src/pages/agent/form/components/dynamic-fom-header.tsx (new file, 25 lines)
@ -0,0 +1,25 @@
|
||||
import { Button } from '@/components/ui/button';
|
||||
import { FormLabel } from '@/components/ui/form';
|
||||
import { Plus } from 'lucide-react';
|
||||
import { ReactNode } from 'react';
|
||||
|
||||
export type FormListHeaderProps = {
|
||||
label: ReactNode;
|
||||
tooltip?: string;
|
||||
onClick?: () => void;
|
||||
};
|
||||
|
||||
export function DynamicFormHeader({
|
||||
label,
|
||||
tooltip,
|
||||
onClick,
|
||||
}: FormListHeaderProps) {
|
||||
return (
|
||||
<div className="flex items-center justify-between">
|
||||
<FormLabel tooltip={tooltip}>{label}</FormLabel>
|
||||
<Button variant={'ghost'} type="button" onClick={onClick}>
|
||||
<Plus />
|
||||
</Button>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@ -1,4 +1,7 @@
|
||||
import { RAGFlowFormItem } from '@/components/ragflow-form';
|
||||
import { Input } from '@/components/ui/input';
|
||||
import { t } from 'i18next';
|
||||
import { z } from 'zod';
|
||||
|
||||
export type OutputType = {
|
||||
title: string;
|
||||
@ -7,6 +10,7 @@ export type OutputType = {
|
||||
|
||||
type OutputProps = {
|
||||
list: Array<OutputType>;
|
||||
isFormRequired?: boolean;
|
||||
};
|
||||
|
||||
export function transferOutputs(outputs: Record<string, any>) {
|
||||
@ -16,7 +20,11 @@ export function transferOutputs(outputs: Record<string, any>) {
|
||||
}));
|
||||
}
|
||||
|
||||
export function Output({ list }: OutputProps) {
|
||||
export const OutputSchema = {
|
||||
outputs: z.record(z.any()),
|
||||
};
|
||||
|
||||
export function Output({ list, isFormRequired = false }: OutputProps) {
|
||||
return (
|
||||
<section className="space-y-2">
|
||||
<div className="text-sm">{t('flow.output')}</div>
|
||||
@ -30,6 +38,11 @@ export function Output({ list }: OutputProps) {
|
||||
</li>
|
||||
))}
|
||||
</ul>
|
||||
{isFormRequired && (
|
||||
<RAGFlowFormItem name="outputs" className="hidden">
|
||||
<Input></Input>
|
||||
</RAGFlowFormItem>
|
||||
)}
|
||||
</section>
|
||||
);
|
||||
}
|
||||
|
||||
@ -1,17 +1,20 @@
|
||||
import { BlockButton, Button } from '@/components/ui/button';
|
||||
import { Button } from '@/components/ui/button';
|
||||
import { X } from 'lucide-react';
|
||||
import { useFieldArray, useFormContext } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { JsonSchemaDataType } from '../../constant';
|
||||
import { DynamicFormHeader, FormListHeaderProps } from './dynamic-fom-header';
|
||||
import { QueryVariable } from './query-variable';
|
||||
|
||||
type QueryVariableListProps = {
|
||||
types?: JsonSchemaDataType[];
|
||||
};
|
||||
export function QueryVariableList({ types }: QueryVariableListProps) {
|
||||
const { t } = useTranslation();
|
||||
} & FormListHeaderProps;
|
||||
export function QueryVariableList({
|
||||
types,
|
||||
label,
|
||||
tooltip,
|
||||
}: QueryVariableListProps) {
|
||||
const form = useFormContext();
|
||||
const name = 'inputs';
|
||||
const name = 'query';
|
||||
|
||||
const { fields, remove, append } = useFieldArray({
|
||||
name: name,
|
||||
@ -19,28 +22,31 @@ export function QueryVariableList({ types }: QueryVariableListProps) {
|
||||
});
|
||||
|
||||
return (
|
||||
<div className="space-y-5">
|
||||
{fields.map((field, index) => {
|
||||
const nameField = `${name}.${index}.input`;
|
||||
<section className="space-y-2">
|
||||
<DynamicFormHeader
|
||||
label={label}
|
||||
tooltip={tooltip}
|
||||
onClick={() => append({ input: '' })}
|
||||
></DynamicFormHeader>
|
||||
<div className="space-y-5">
|
||||
{fields.map((field, index) => {
|
||||
const nameField = `${name}.${index}.input`;
|
||||
|
||||
return (
|
||||
<div key={field.id} className="flex items-center gap-2">
|
||||
<QueryVariable
|
||||
name={nameField}
|
||||
hideLabel
|
||||
className="flex-1"
|
||||
types={types}
|
||||
></QueryVariable>
|
||||
<Button variant={'ghost'} onClick={() => remove(index)}>
|
||||
<X className="text-text-sub-title-invert " />
|
||||
</Button>
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
|
||||
<BlockButton onClick={() => append({ input: '' })}>
|
||||
{t('common.add')}
|
||||
</BlockButton>
|
||||
</div>
|
||||
return (
|
||||
<div key={field.id} className="flex items-center gap-2">
|
||||
<QueryVariable
|
||||
name={nameField}
|
||||
hideLabel
|
||||
className="flex-1"
|
||||
types={types}
|
||||
></QueryVariable>
|
||||
<Button variant={'ghost'} onClick={() => remove(index)}>
|
||||
<X className="text-text-sub-title-invert " />
|
||||
</Button>
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
</section>
|
||||
);
|
||||
}
|
||||
|
||||
@ -1,13 +1,14 @@
|
||||
import { SelectWithSearch } from '@/components/originui/select-with-search';
|
||||
import { RAGFlowFormItem } from '@/components/ragflow-form';
|
||||
import { BlockButton, Button } from '@/components/ui/button';
|
||||
import { FormLabel } from '@/components/ui/form';
|
||||
import { Button } from '@/components/ui/button';
|
||||
import { Input } from '@/components/ui/input';
|
||||
import { Separator } from '@/components/ui/separator';
|
||||
import { useBuildSwitchOperatorOptions } from '@/hooks/logic-hooks/use-build-operator-options';
|
||||
import { X } from 'lucide-react';
|
||||
import { ReactNode } from 'react';
|
||||
import { useFieldArray, useFormContext } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { DataOperationsOperatorOptions } from '../../constant';
|
||||
import { DynamicFormHeader } from '../components/dynamic-fom-header';
|
||||
|
||||
type SelectKeysProps = {
|
||||
name: string;
|
||||
@ -25,7 +26,6 @@ export function FilterValues({
|
||||
valueField = 'value',
|
||||
operatorField = 'operator',
|
||||
}: SelectKeysProps) {
|
||||
const { t } = useTranslation();
|
||||
const form = useFormContext();
|
||||
|
||||
const { fields, remove, append } = useFieldArray({
|
||||
@ -33,9 +33,18 @@ export function FilterValues({
|
||||
control: form.control,
|
||||
});
|
||||
|
||||
const operatorOptions = useBuildSwitchOperatorOptions(
|
||||
DataOperationsOperatorOptions,
|
||||
);
|
||||
|
||||
return (
|
||||
<section className="space-y-2">
|
||||
<FormLabel tooltip={tooltip}>{label}</FormLabel>
|
||||
<DynamicFormHeader
|
||||
label={label}
|
||||
tooltip={tooltip}
|
||||
onClick={() => append({ [keyField]: '', [valueField]: '' })}
|
||||
></DynamicFormHeader>
|
||||
|
||||
<div className="space-y-5">
|
||||
{fields.map((field, index) => {
|
||||
const keyFieldAlias = `${name}.${index}.${keyField}`;
|
||||
@ -47,12 +56,15 @@ export function FilterValues({
|
||||
<RAGFlowFormItem name={keyFieldAlias} className="flex-1">
|
||||
<Input></Input>
|
||||
</RAGFlowFormItem>
|
||||
<Separator orientation="vertical" className="h-2.5" />
|
||||
<Separator className="w-2" />
|
||||
|
||||
<RAGFlowFormItem name={operatorFieldAlias} className="flex-1">
|
||||
<SelectWithSearch {...field} options={[]}></SelectWithSearch>
|
||||
<SelectWithSearch
|
||||
{...field}
|
||||
options={operatorOptions}
|
||||
></SelectWithSearch>
|
||||
</RAGFlowFormItem>
|
||||
<Separator orientation="vertical" className="h-2.5" />
|
||||
<Separator className="w-2" />
|
||||
|
||||
<RAGFlowFormItem name={valueFieldAlias} className="flex-1">
|
||||
<Input></Input>
|
||||
@ -64,10 +76,6 @@ export function FilterValues({
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
|
||||
<BlockButton onClick={() => append({ [keyField]: '', [valueField]: '' })}>
|
||||
{t('common.add')}
|
||||
</BlockButton>
|
||||
</section>
|
||||
);
|
||||
}
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import { SelectWithSearch } from '@/components/originui/select-with-search';
|
||||
import { RAGFlowFormItem } from '@/components/ragflow-form';
|
||||
import { Form, FormLabel } from '@/components/ui/form';
|
||||
import { Form } from '@/components/ui/form';
|
||||
import { Separator } from '@/components/ui/separator';
|
||||
import { buildOptions } from '@/utils/form';
|
||||
import { zodResolver } from '@hookform/resolvers/zod';
|
||||
import { memo } from 'react';
|
||||
@ -17,14 +18,14 @@ import { useWatchFormChange } from '../../hooks/use-watch-form-change';
|
||||
import { INextOperatorForm } from '../../interface';
|
||||
import { buildOutputList } from '../../utils/build-output-list';
|
||||
import { FormWrapper } from '../components/form-wrapper';
|
||||
import { Output } from '../components/output';
|
||||
import { Output, OutputSchema } from '../components/output';
|
||||
import { QueryVariableList } from '../components/query-variable-list';
|
||||
import { FilterValues } from './filter-values';
|
||||
import { SelectKeys } from './select-keys';
|
||||
import { Updates } from './updates';
|
||||
|
||||
export const RetrievalPartialSchema = {
|
||||
inputs: z.array(z.object({ input: z.string().optional() })),
|
||||
query: z.array(z.object({ input: z.string().optional() })),
|
||||
operations: z.string(),
|
||||
select_keys: z.array(z.object({ name: z.string().optional() })).optional(),
|
||||
remove_keys: z.array(z.object({ name: z.string().optional() })).optional(),
|
||||
@ -50,6 +51,7 @@ export const RetrievalPartialSchema = {
|
||||
}),
|
||||
)
|
||||
.optional(),
|
||||
...OutputSchema,
|
||||
};
|
||||
|
||||
export const FormSchema = z.object(RetrievalPartialSchema);
|
||||
@ -65,6 +67,7 @@ function DataOperationsForm({ node }: INextOperatorForm) {
|
||||
|
||||
const form = useForm<DataOperationsFormSchemaType>({
|
||||
defaultValues: defaultValues,
|
||||
mode: 'onChange',
|
||||
resolver: zodResolver(FormSchema),
|
||||
shouldUnregister: true,
|
||||
});
|
||||
@ -78,17 +81,17 @@ function DataOperationsForm({ node }: INextOperatorForm) {
|
||||
true,
|
||||
);
|
||||
|
||||
useWatchFormChange(node?.id, form);
|
||||
useWatchFormChange(node?.id, form, true);
|
||||
|
||||
return (
|
||||
<Form {...form}>
|
||||
<FormWrapper>
|
||||
<div className="space-y-2">
|
||||
<FormLabel tooltip={t('flow.queryTip')}>{t('flow.query')}</FormLabel>
|
||||
<QueryVariableList
|
||||
types={[JsonSchemaDataType.Array, JsonSchemaDataType.Object]}
|
||||
></QueryVariableList>
|
||||
</div>
|
||||
<QueryVariableList
|
||||
tooltip={t('flow.queryTip')}
|
||||
label={t('flow.query')}
|
||||
types={[JsonSchemaDataType.Array, JsonSchemaDataType.Object]}
|
||||
></QueryVariableList>
|
||||
<Separator />
|
||||
<RAGFlowFormItem name="operations" label={t('flow.operations')}>
|
||||
<SelectWithSearch options={OperationsOptions} allowClear />
|
||||
</RAGFlowFormItem>
|
||||
@ -107,7 +110,7 @@ function DataOperationsForm({ node }: INextOperatorForm) {
|
||||
{operations === Operations.AppendOrUpdate && (
|
||||
<Updates
|
||||
name="updates"
|
||||
label={t('flow.operationsOptions.updates')}
|
||||
label={t('flow.operationsOptions.appendOrUpdate')}
|
||||
keyField="key"
|
||||
valueField="value"
|
||||
></Updates>
|
||||
@ -126,7 +129,7 @@ function DataOperationsForm({ node }: INextOperatorForm) {
|
||||
label={t('flow.operationsOptions.filterValues')}
|
||||
></FilterValues>
|
||||
)}
|
||||
<Output list={outputList}></Output>
|
||||
<Output list={outputList} isFormRequired></Output>
|
||||
</FormWrapper>
|
||||
</Form>
|
||||
);
|
||||
|
||||
@ -1,11 +1,10 @@
|
||||
import { RAGFlowFormItem } from '@/components/ragflow-form';
|
||||
import { BlockButton, Button } from '@/components/ui/button';
|
||||
import { FormLabel } from '@/components/ui/form';
|
||||
import { Button } from '@/components/ui/button';
|
||||
import { Input } from '@/components/ui/input';
|
||||
import { X } from 'lucide-react';
|
||||
import { ReactNode } from 'react';
|
||||
import { useFieldArray, useFormContext } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { DynamicFormHeader } from '../components/dynamic-fom-header';
|
||||
|
||||
type SelectKeysProps = {
|
||||
name: string;
|
||||
@ -13,7 +12,6 @@ type SelectKeysProps = {
|
||||
tooltip?: string;
|
||||
};
|
||||
export function SelectKeys({ name, label, tooltip }: SelectKeysProps) {
|
||||
const { t } = useTranslation();
|
||||
const form = useFormContext();
|
||||
|
||||
const { fields, remove, append } = useFieldArray({
|
||||
@ -23,7 +21,11 @@ export function SelectKeys({ name, label, tooltip }: SelectKeysProps) {
|
||||
|
||||
return (
|
||||
<section className="space-y-2">
|
||||
<FormLabel tooltip={tooltip}>{label}</FormLabel>
|
||||
<DynamicFormHeader
|
||||
label={label}
|
||||
tooltip={tooltip}
|
||||
onClick={() => append({ name: '' })}
|
||||
></DynamicFormHeader>
|
||||
<div className="space-y-5">
|
||||
{fields.map((field, index) => {
|
||||
const nameField = `${name}.${index}.name`;
|
||||
@ -40,10 +42,6 @@ export function SelectKeys({ name, label, tooltip }: SelectKeysProps) {
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
|
||||
<BlockButton onClick={() => append({ name: '' })}>
|
||||
{t('common.add')}
|
||||
</BlockButton>
|
||||
</section>
|
||||
);
|
||||
}
|
||||
|
||||
@ -1,11 +1,11 @@
|
||||
import { RAGFlowFormItem } from '@/components/ragflow-form';
|
||||
import { BlockButton, Button } from '@/components/ui/button';
|
||||
import { FormLabel } from '@/components/ui/form';
|
||||
import { Button } from '@/components/ui/button';
|
||||
import { Input } from '@/components/ui/input';
|
||||
import { Separator } from '@/components/ui/separator';
|
||||
import { X } from 'lucide-react';
|
||||
import { ReactNode } from 'react';
|
||||
import { useFieldArray, useFormContext } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { DynamicFormHeader } from '../components/dynamic-fom-header';
|
||||
|
||||
type SelectKeysProps = {
|
||||
name: string;
|
||||
@ -21,7 +21,6 @@ export function Updates({
|
||||
keyField,
|
||||
valueField,
|
||||
}: SelectKeysProps) {
|
||||
const { t } = useTranslation();
|
||||
const form = useFormContext();
|
||||
|
||||
const { fields, remove, append } = useFieldArray({
|
||||
@ -31,7 +30,11 @@ export function Updates({
|
||||
|
||||
return (
|
||||
<section className="space-y-2">
|
||||
<FormLabel tooltip={tooltip}>{label}</FormLabel>
|
||||
<DynamicFormHeader
|
||||
label={label}
|
||||
tooltip={tooltip}
|
||||
onClick={() => append({ [keyField]: '', [valueField]: '' })}
|
||||
></DynamicFormHeader>
|
||||
<div className="space-y-5">
|
||||
{fields.map((field, index) => {
|
||||
const keyFieldAlias = `${name}.${index}.${keyField}`;
|
||||
@ -42,6 +45,7 @@ export function Updates({
|
||||
<RAGFlowFormItem name={keyFieldAlias} className="flex-1">
|
||||
<Input></Input>
|
||||
</RAGFlowFormItem>
|
||||
<Separator className="w-2" />
|
||||
<RAGFlowFormItem name={valueFieldAlias} className="flex-1">
|
||||
<Input></Input>
|
||||
</RAGFlowFormItem>
|
||||
@ -52,10 +56,6 @@ export function Updates({
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
|
||||
<BlockButton onClick={() => append({ [keyField]: '', [valueField]: '' })}>
|
||||
{t('common.add')}
|
||||
</BlockButton>
|
||||
</section>
|
||||
);
|
||||
}
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
import { FormContainer } from '@/components/form-container';
|
||||
import { IconFont } from '@/components/icon-font';
|
||||
import { SelectWithSearch } from '@/components/originui/select-with-search';
|
||||
import { BlockButton, Button } from '@/components/ui/button';
|
||||
import { Card, CardContent } from '@/components/ui/card';
|
||||
@ -13,6 +12,7 @@ import {
|
||||
import { RAGFlowSelect } from '@/components/ui/select';
|
||||
import { Separator } from '@/components/ui/separator';
|
||||
import { Textarea } from '@/components/ui/textarea';
|
||||
import { useBuildSwitchOperatorOptions } from '@/hooks/logic-hooks/use-build-operator-options';
|
||||
import { cn } from '@/lib/utils';
|
||||
import { zodResolver } from '@hookform/resolvers/zod';
|
||||
import { t } from 'i18next';
|
||||
@ -22,11 +22,7 @@ import { memo, useCallback, useMemo } from 'react';
|
||||
import { useFieldArray, useForm, useFormContext } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { z } from 'zod';
|
||||
import {
|
||||
SwitchLogicOperatorOptions,
|
||||
SwitchOperatorOptions,
|
||||
VariableType,
|
||||
} from '../../constant';
|
||||
import { SwitchLogicOperatorOptions, VariableType } from '../../constant';
|
||||
import { useBuildQueryVariableOptions } from '../../hooks/use-get-begin-query';
|
||||
import { IOperatorForm } from '../../interface';
|
||||
import { FormWrapper } from '../components/form-wrapper';
|
||||
@ -43,42 +39,6 @@ type ConditionCardsProps = {
|
||||
parentLength: number;
|
||||
} & IOperatorForm;
|
||||
|
||||
export const LogicalOperatorIcon = function OperatorIcon({
|
||||
icon,
|
||||
value,
|
||||
}: Omit<(typeof SwitchOperatorOptions)[0], 'label'>) {
|
||||
if (typeof icon === 'string') {
|
||||
return (
|
||||
<IconFont
|
||||
name={icon}
|
||||
className={cn('size-4', {
|
||||
'rotate-180': value === '>',
|
||||
})}
|
||||
></IconFont>
|
||||
);
|
||||
}
|
||||
return icon;
|
||||
};
|
||||
|
||||
export function useBuildSwitchOperatorOptions() {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const switchOperatorOptions = useMemo(() => {
|
||||
return SwitchOperatorOptions.map((x) => ({
|
||||
value: x.value,
|
||||
icon: (
|
||||
<LogicalOperatorIcon
|
||||
icon={x.icon}
|
||||
value={x.value}
|
||||
></LogicalOperatorIcon>
|
||||
),
|
||||
label: t(`flow.switchOperatorOptions.${x.label}`),
|
||||
}));
|
||||
}, [t]);
|
||||
|
||||
return switchOperatorOptions;
|
||||
}
|
||||
|
||||
function ConditionCards({
|
||||
name: parentName,
|
||||
parentIndex,
|
||||
|

@ -2,9 +2,13 @@ import { useEffect } from 'react';
import { UseFormReturn, useWatch } from 'react-hook-form';
import useGraphStore from '../store';

export function useWatchFormChange(id?: string, form?: UseFormReturn<any>) {
export function useWatchFormChange(
  id?: string,
  form?: UseFormReturn<any>,
  enableReplacement = false,
) {
  let values = useWatch({ control: form?.control });
  const updateNodeForm = useGraphStore((state) => state.updateNodeForm);
  const { updateNodeForm, replaceNodeForm } = useGraphStore((state) => state);

  useEffect(() => {
    // Manually triggered form updates are synchronized to the canvas
@ -12,7 +16,7 @@ export function useWatchFormChange(id?: string, form?: UseFormReturn<any>) {
      values = form?.getValues() || {};
      let nextValues: any = values;

      updateNodeForm(id, nextValues);
      (enableReplacement ? replaceNodeForm : updateNodeForm)(id, nextValues);
    }
  }, [form?.formState.isDirty, id, updateNodeForm, values]);
}
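
For context, a minimal sketch of a node form opting into the new behaviour; the component, form schema, and import path are assumptions, and only the (id, form, enableReplacement) signature comes from the hunk above. Passing true presumably routes writes through replaceNodeForm, so the stored form object is swapped wholesale rather than merged field by field.

import { useForm } from 'react-hook-form';
import { useWatchFormChange } from './use-watch-form-change'; // assumed path

function ExampleNodeForm({ nodeId }: { nodeId: string }) {
  const form = useForm({ defaultValues: { conditions: [] as string[] } });

  // Third argument opts into replaceNodeForm instead of updateNodeForm.
  useWatchFormChange(nodeId, form, true);

  return null; // form fields omitted
}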

@ -14,7 +14,7 @@ import {
  applyEdgeChanges,
  applyNodeChanges,
} from '@xyflow/react';
import { omit } from 'lodash';
import { cloneDeep, omit } from 'lodash';
import differenceWith from 'lodash/differenceWith';
import intersectionWith from 'lodash/intersectionWith';
import lodashSet from 'lodash/set';
@ -53,6 +53,7 @@ export type RFState = {
    values: any,
    path?: (string | number)[],
  ) => RAGFlowNodeType[];
  replaceNodeForm: (nodeId: string, values: any) => void;
  onSelectionChange: OnSelectionChangeFunc;
  addNode: (nodes: RAGFlowNodeType) => void;
  getNode: (id?: string | null) => RAGFlowNodeType | undefined;
@ -433,6 +434,19 @@ const useGraphStore = create<RFState>()(

      return nextNodes;
    },
    replaceNodeForm(nodeId, values) {
      if (nodeId) {
        set((state) => {
          for (const node of state.nodes) {
            if (node.id === nodeId) {
              // cloneDeep avoids a react-hook-form error when the stored values are mutated later
              node.data.form = cloneDeep(values); // TypeError: Cannot assign to read only property '0' of object '[object Array]'
              break;
            }
          }
        });
      }
    },
    updateSwitchFormData: (source, sourceHandle, target, isConnecting) => {
      const { updateNodeForm, edges } = get();
      if (sourceHandle) {
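
The cloneDeep above is worth a note: the store mutates state inside set(), which points at an immer-style middleware, and auto-freezing of the stored tree (assumed here) makes it read-only. If that tree shared array instances with react-hook-form, the form's next in-place mutation would throw the TypeError quoted in the comment. A simplified repro, with Object.freeze standing in for the middleware's freezing:

import { cloneDeep } from 'lodash';

const formValues = { conditions: ['a', 'b'] }; // what react-hook-form holds internally

// Storing the live reference and freezing the state tree breaks the form library:
const storedByReference = formValues;
Object.freeze(storedByReference.conditions);
// formValues.conditions[0] = 'c'; // TypeError: Cannot assign to read only property '0'

// Storing a deep clone keeps the form's own arrays mutable:
const freshValues = { conditions: ['a', 'b'] };
const storedByValue = cloneDeep(freshValues);
Object.freeze(storedByValue.conditions); // freezing the store's copy...
freshValues.conditions[0] = 'c';         // ...leaves the form's copy writable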

@ -273,7 +273,7 @@ function transformDataOperationsParams(params: DataOperationsFormSchemaType) {
    ...params,
    select_keys: params?.select_keys?.map((x) => x.name),
    remove_keys: params?.remove_keys?.map((x) => x.name),
    inputs: params.inputs.map((x) => x.input),
    query: params.query.map((x) => x.input),
  };
}
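
For context, the mapping above flattens the form's row objects into the plain arrays the component parameters expect. A hedged before/after sketch with placeholder values; only the field names come from the diff.

// Form state as collected by the UI (row objects):
const formValues = {
  select_keys: [{ name: 'title' }],
  remove_keys: [{ name: 'draft' }],
  inputs: [{ input: 'Some text' }],
  query: [{ input: 'Another text' }],
};

// After the transform, each row collapses to its single field:
const params = {
  ...formValues,
  select_keys: formValues.select_keys?.map((x) => x.name), // ['title']
  remove_keys: formValues.remove_keys?.map((x) => x.name), // ['draft']
  inputs: formValues.inputs.map((x) => x.input),           // ['Some text']
  query: formValues.query.map((x) => x.input),             // ['Another text']
};
console.log(params);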

@ -1,6 +1,6 @@
import { SwitchOperatorOptions } from '@/constants/agent';
import { useBuildSwitchOperatorOptions } from '@/hooks/logic-hooks/use-build-operator-options';
import { useFetchKnowledgeMetadata } from '@/hooks/use-knowledge-request';
import { SwitchOperatorOptions } from '@/pages/agent/constant';
import { useBuildSwitchOperatorOptions } from '@/pages/agent/form/switch-form';
import { MinusCircleOutlined, PlusOutlined } from '@ant-design/icons';
import {
  Button,

@ -15,9 +15,9 @@ import {
} from '@/components/ui/form';
import { Input } from '@/components/ui/input';
import { Separator } from '@/components/ui/separator';
import { SwitchOperatorOptions } from '@/constants/agent';
import { useBuildSwitchOperatorOptions } from '@/hooks/logic-hooks/use-build-operator-options';
import { useFetchKnowledgeMetadata } from '@/hooks/use-knowledge-request';
import { SwitchOperatorOptions } from '@/pages/agent/constant';
import { useBuildSwitchOperatorOptions } from '@/pages/agent/form/switch-form';
import { Plus, X } from 'lucide-react';
import { useCallback } from 'react';
import { useFieldArray, useFormContext } from 'react-hook-form';