Compare commits


12 Commits

Author SHA1 Message Date
4e76220e25 Feat: Submit clean data operations form data to the backend. #10427 (#11030)
### What problem does this PR solve?

Feat: Submit clean data operations form data to the backend. #10427

### Type of change


- [x] New Feature (non-breaking change which adds functionality)
2025-11-05 17:32:35 +08:00
24335485bf Fix: get_allowed_llm_factories() return type (#11031)
### What problem does this PR solve?

Fix: get_allowed_llm_factories() return type #11003

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

2025-11-05 17:32:12 +08:00
121c51661d Fix: Markdown table extractor (#11018)
### What problem does this PR solve?

The markdown table extractor now supports HTML `<table ...>` elements. #10966

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-11-05 16:10:21 +08:00
02d10f8eda Move var from rag.settings to common.globals (#11022)
### What problem does this PR solve?

As title.

### Type of change

- [x] Refactoring

---------

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
2025-11-05 15:48:50 +08:00
dddf766470 Feat: start data sync service. (#11026)
### What problem does this PR solve?

#10953 

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-11-05 15:43:15 +08:00
8584d4b642 Fix: missed numeric string transformation. (#11025)
### What problem does this PR solve?

#11024

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-11-05 15:14:30 +08:00
b86e07088b Fix: escape multi-steps issues. (#11016)
### What problem does this PR solve?


### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-11-05 14:51:00 +08:00
1a9215bc6f Move some vars to globals (#11017)
### What problem does this PR solve?

As title.

### Type of change

- [x] Refactoring

---------

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
2025-11-05 14:14:38 +08:00
cf9611c96f Feat: Support more chunking methods (#11000)
### What problem does this PR solve?

Feat: Support more chunking methods #10772 

This PR enables multiple chunking methods — including books, laws,
naive, one, and presentation — to be used with all existing PDF parsers
(DeepDOC, MinerU, Docling, TCADP, Plain Text, and Vision modes).

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-11-05 13:00:42 +08:00
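
The commit above pairs the listed chunking methods with all existing PDF parsers. As a rough illustration only (the `rag.app` module names exist in RAGFlow, but this dispatch wiring is an assumption, not the PR's actual code):

```python
# Illustrative only: route a parsed document to one of several chunking methods.
# Each rag.app chunker accepts the file name/binary plus keyword options.
from rag.app import book, laws, naive, one, presentation

CHUNKERS = {
    "book": book.chunk,
    "laws": laws.chunk,
    "naive": naive.chunk,
    "one": one.chunk,
    "presentation": presentation.chunk,
}


def chunk_with(method: str, filename: str, binary: bytes, **kwargs):
    """Dispatch to the selected chunking method (hypothetical helper)."""
    if method not in CHUNKERS:
        raise ValueError(f"unsupported chunking method: {method}")
    return CHUNKERS[method](filename, binary=binary, **kwargs)
```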
f126875ec6 Apply some tweaks on Admin UI (#11011)
### What problem does this PR solve?

- Fix selected radio button text misaligned with radio button dot
- Fix `<ScrollArea>` scrollbar z-index issue
- Add backdrop blur effect on scrollbar thumbs
- Adjust some styles to match the design 


### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-11-05 12:58:43 +08:00
89410d2381 Fix: api /factories wrong return (#11015)
### What problem does this PR solve?

Fixes the wrong return value of the `/factories` API.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-11-05 12:50:11 +08:00
96c015fb85 Fix and refactor imports (#11010)
### What problem does this PR solve?

1. Move EMBEDDING_CFG to common.globals
2. Fix error imports
3. Move signal handles to common/signal_utils.py

### Type of change

- [x] Refactoring

---------

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
2025-11-05 11:07:54 +08:00
90 changed files with 1041 additions and 716 deletions
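
Most of the diffs below mechanically replace `settings.docStoreConn` / `settings.retriever` with `globals.docStoreConn` / `globals.retriever`, following the two "Move ... to common.globals" refactoring commits above. For orientation, a minimal sketch of what such a `common/globals.py` could look like; the member names are inferred from the diffs (`docStoreConn`, `retriever`, `kg_retriever`, `EMBEDDING_MDL`), while the `init()` bootstrap API is assumed, not taken from the PR:

```python
# Hypothetical sketch of common/globals.py; names inferred from the diffs,
# the init() function is an assumption.
from typing import Any, Optional

# Process-wide singletons, populated once at service bootstrap.
docStoreConn: Optional[Any] = None   # document-store connection (ES/Infinity/...)
retriever: Optional[Any] = None      # default retriever used by the apps
kg_retriever: Optional[Any] = None   # knowledge-graph retriever
EMBEDDING_MDL: str = ""              # default embedding model id


def init(doc_store_conn: Any, default_retriever: Any, embedding_mdl: str = "") -> None:
    """Install the shared singletons; call exactly once during startup."""
    global docStoreConn, retriever, EMBEDDING_MDL
    docStoreConn = doc_store_conn
    retriever = default_retriever
    EMBEDDING_MDL = embedding_mdl
```

Keeping runtime handles in a dedicated module, separate from configuration in `api.settings`, is consistent with the "Fix and refactor imports" commit above.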

View File

@@ -281,12 +281,21 @@ class Canvas(Graph):
         def _run_batch(f, t):
             with ThreadPoolExecutor(max_workers=5) as executor:
                 thr = []
-                for i in range(f, t):
+                i = f
+                while i < t:
                     cpn = self.get_component_obj(self.path[i])
                     if cpn.component_name.lower() in ["begin", "userfillup"]:
                         thr.append(executor.submit(cpn.invoke, inputs=kwargs.get("inputs", {})))
+                        i += 1
                     else:
-                        thr.append(executor.submit(cpn.invoke, **cpn.get_input()))
+                        for _, ele in cpn.get_input_elements().items():
+                            if isinstance(ele, dict) and ele.get("_cpn_id") and ele.get("_cpn_id") not in self.path[:i]:
+                                self.path.pop(i)
+                                t -= 1
+                                break
+                        else:
+                            thr.append(executor.submit(cpn.invoke, **cpn.get_input()))
+                            i += 1
                 for t in thr:
                     t.result()
@@ -316,6 +325,7 @@ class Canvas(Graph):
                 "thoughts": self.get_component_thoughts(self.path[i])
             })
             _run_batch(idx, to)
+            to = len(self.path)
             # post processing of components invocation
             for i in range(idx, to):
                 cpn = self.get_component(self.path[i])
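
The rewritten batch loop relies on Python's `for ... else`: the `else` branch runs only when the loop finishes without `break`, so a component is submitted to the executor only if none of its input elements refers to a component missing from the already-executed path. A minimal illustration of the idiom:

```python
# for/else: the else block executes only if the loop was never broken out of.
deps = ["begin", "retrieval"]
executed = ["begin", "retrieval"]

for dep in deps:
    if dep not in executed:
        break  # a dependency has not run yet -> skip submission
else:
    print("all dependencies satisfied; safe to submit")
```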

View File

@@ -16,6 +16,13 @@
 from abc import ABC
 from agent.component.base import ComponentBase, ComponentParamBase

+"""
+class VariableModel(BaseModel):
+    data_type: Annotated[Literal["string", "number", "Object", "Boolean", "Array<string>", "Array<number>", "Array<object>", "Array<boolean>"], Field(default="Array<string>")]
+    input_mode: Annotated[Literal["constant", "variable"], Field(default="constant")]
+    value: Annotated[Any, Field(default=None)]
+    model_config = ConfigDict(extra="forbid")
+"""

 class IterationParam(ComponentParamBase):
     """

View File

@@ -216,7 +216,7 @@ class LLM(ComponentBase):
         error: str = ""
         output_structure=None
         try:
-            output_structure=self._param.outputs['structured']
+            output_structure = None#self._param.outputs['structured']
         except Exception:
             pass
         if output_structure:

View File

@@ -49,6 +49,9 @@ class MessageParam(ComponentParamBase):
 class Message(ComponentBase):
     component_name = "Message"

+    def get_input_elements(self) -> dict[str, Any]:
+        return self.get_input_elements_from_text("".join(self._param.content))
+
     def get_kwargs(self, script:str, kwargs:dict = {}, delimiter:str=None) -> tuple[str, dict[str, str | list | Any]]:
         for k,v in self.get_input_elements_from_text(script).items():
             if k in kwargs:

View File

@@ -16,6 +16,8 @@
 import os
 import re
 from abc import ABC
+from typing import Any
+
 from jinja2 import Template as Jinja2Template
 from agent.component.base import ComponentParamBase
 from common.connection_utils import timeout
@@ -43,6 +45,9 @@ class StringTransformParam(ComponentParamBase):
 class StringTransform(Message, ABC):
     component_name = "StringTransform"

+    def get_input_elements(self) -> dict[str, Any]:
+        return self.get_input_elements_from_text(self._param.script)
+
     def get_input_form(self) -> dict[str, dict]:
         if self._param.method == "split":
             return {
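
Both new overrides funnel their template text through `get_input_elements_from_text`, whose results the rewritten `_run_batch` inspects for `_cpn_id` entries. As a toy model only (the placeholder syntax and return shape here are guesses, not RAGFlow's implementation):

```python
import re


def extract_input_elements(text: str) -> dict[str, dict]:
    """Toy extractor: collect `{component_id@field}` placeholders from a template."""
    elements: dict[str, dict] = {}
    for m in re.finditer(r"\{([A-Za-z0-9_:]+)@([A-Za-z0-9_]+)\}", text):
        elements[m.group(0)] = {"_cpn_id": m.group(1), "field": m.group(2)}
    return elements


print(extract_input_elements("Result: {Agent:one@answer}"))
# {'{Agent:one@answer}': {'_cpn_id': 'Agent:one', 'field': 'answer'}}
```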

View File

@@ -25,6 +25,7 @@ from api.db.services.dialog_service import meta_filter
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.llm_service import LLMBundle
 from api import settings
+from common import globals
 from common.connection_utils import timeout
 from rag.app.tag import label_question
 from rag.prompts.generator import cross_languages, kb_prompt, gen_meta_filter
@@ -170,7 +171,7 @@ class Retrieval(ToolBase, ABC):
         if kbs:
             query = re.sub(r"^user[:\s]*", "", query, flags=re.IGNORECASE)
-            kbinfos = settings.retriever.retrieval(
+            kbinfos = globals.retriever.retrieval(
                 query,
                 embd_mdl,
                 [kb.tenant_id for kb in kbs],
@@ -186,7 +187,7 @@ class Retrieval(ToolBase, ABC):
             )
             if self._param.toc_enhance:
                 chat_mdl = LLMBundle(self._canvas._tenant_id, LLMType.CHAT)
-                cks = settings.retriever.retrieval_by_toc(query, kbinfos["chunks"], [kb.tenant_id for kb in kbs], chat_mdl, self._param.top_n)
+                cks = globals.retriever.retrieval_by_toc(query, kbinfos["chunks"], [kb.tenant_id for kb in kbs], chat_mdl, self._param.top_n)
                 if cks:
                     kbinfos["chunks"] = cks
             if self._param.use_kg:

View File

@@ -32,7 +32,6 @@ from api.db.services.file_service import FileService
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.task_service import queue_tasks, TaskService
 from api.db.services.user_service import UserTenantService
-from api import settings
 from common.misc_utils import get_uuid
 from common.constants import RetCode, VALID_TASK_STATUS, LLMType, ParserType, FileSource
 from api.utils.api_utils import server_error_response, get_data_error_result, get_json_result, validate_request, \
@@ -48,6 +47,7 @@ from api.db.services.canvas_service import UserCanvasService
 from agent.canvas import Canvas
 from functools import partial
 from pathlib import Path
+from common import globals

 @manager.route('/new_token', methods=['POST'])  # noqa: F821
@@ -538,7 +538,7 @@ def list_chunks():
         )
     kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
-    res = settings.retriever.chunk_list(doc_id, tenant_id, kb_ids)
+    res = globals.retriever.chunk_list(doc_id, tenant_id, kb_ids)
     res = [
         {
             "content": res_item["content_with_weight"],
@@ -564,7 +564,7 @@ def get_chunk(chunk_id):
     try:
         tenant_id = objs[0].tenant_id
         kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
-        chunk = settings.docStoreConn.get(chunk_id, search.index_name(tenant_id), kb_ids)
+        chunk = globals.docStoreConn.get(chunk_id, search.index_name(tenant_id), kb_ids)
         if chunk is None:
             return server_error_response(Exception("Chunk not found"))
         k = []
@@ -886,7 +886,7 @@ def retrieval():
         if req.get("keyword", False):
             chat_mdl = LLMBundle(kbs[0].tenant_id, LLMType.CHAT)
             question += keyword_extraction(chat_mdl, question)
-        ranks = settings.retriever.retrieval(question, embd_mdl, kbs[0].tenant_id, kb_ids, page, size,
+        ranks = globals.retriever.retrieval(question, embd_mdl, kbs[0].tenant_id, kb_ids, page, size,
                                             similarity_threshold, vector_similarity_weight, top,
                                             doc_ids, rerank_mdl=rerank_mdl, highlight= highlight,
                                             rank_feature=label_question(question, kbs))

View File

@@ -25,7 +25,6 @@ from flask import request, Response
 from flask_login import login_required, current_user
 from agent.component import LLM
-from api import settings
 from api.db import CanvasCategory, FileType
 from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService, API4ConversationService
 from api.db.services.document_service import DocumentService
@@ -46,6 +45,7 @@ from api.utils.file_utils import filename_type, read_potential_broken_pdf
 from rag.flow.pipeline import Pipeline
 from rag.nlp import search
 from rag.utils.redis_conn import REDIS_CONN
+from common import globals

 @manager.route('/templates', methods=['GET'])  # noqa: F821
@@ -192,8 +192,8 @@ def rerun():
         if 0 < doc["progress"] < 1:
             return get_data_error_result(message=f"`{doc['name']}` is processing...")
-        if settings.docStoreConn.indexExist(search.index_name(current_user.id), doc["kb_id"]):
-            settings.docStoreConn.delete({"doc_id": doc["id"]}, search.index_name(current_user.id), doc["kb_id"])
+        if globals.docStoreConn.indexExist(search.index_name(current_user.id), doc["kb_id"]):
+            globals.docStoreConn.delete({"doc_id": doc["id"]}, search.index_name(current_user.id), doc["kb_id"])
         doc["progress_msg"] = ""
         doc["chunk_num"] = 0
         doc["token_num"] = 0

View File

@@ -36,6 +36,7 @@ from rag.prompts.generator import gen_meta_filter, cross_languages, keyword_extr
 from rag.settings import PAGERANK_FLD
 from common.string_utils import remove_redundant_spaces
 from common.constants import RetCode, LLMType, ParserType
+from common import globals

 @manager.route('/list', methods=['POST'])  # noqa: F821
@@ -60,7 +61,7 @@ def list_chunk():
         }
         if "available_int" in req:
             query["available_int"] = int(req["available_int"])
-        sres = settings.retriever.search(query, search.index_name(tenant_id), kb_ids, highlight=["content_ltks"])
+        sres = globals.retriever.search(query, search.index_name(tenant_id), kb_ids, highlight=["content_ltks"])
         res = {"total": sres.total, "chunks": [], "doc": doc.to_dict()}
         for id in sres.ids:
             d = {
@@ -98,7 +99,7 @@ def get():
             return get_data_error_result(message="Tenant not found!")
         for tenant in tenants:
             kb_ids = KnowledgebaseService.get_kb_ids(tenant.tenant_id)
-            chunk = settings.docStoreConn.get(chunk_id, search.index_name(tenant.tenant_id), kb_ids)
+            chunk = globals.docStoreConn.get(chunk_id, search.index_name(tenant.tenant_id), kb_ids)
             if chunk:
                 break
         if chunk is None:
@@ -170,7 +171,7 @@ def set():
            v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not d.get("question_kwd") else "\n".join(d["question_kwd"])])
            v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
            d["q_%d_vec" % len(v)] = v.tolist()
-           settings.docStoreConn.update({"id": req["chunk_id"]}, d, search.index_name(tenant_id), doc.kb_id)
+           globals.docStoreConn.update({"id": req["chunk_id"]}, d, search.index_name(tenant_id), doc.kb_id)
            return get_json_result(data=True)
        except Exception as e:
            return server_error_response(e)
@@ -186,7 +187,7 @@ def switch():
         if not e:
             return get_data_error_result(message="Document not found!")
         for cid in req["chunk_ids"]:
-            if not settings.docStoreConn.update({"id": cid},
+            if not globals.docStoreConn.update({"id": cid},
                                                 {"available_int": int(req["available_int"])},
                                                 search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
                                                 doc.kb_id):
@@ -206,7 +207,7 @@ def rm():
         e, doc = DocumentService.get_by_id(req["doc_id"])
         if not e:
             return get_data_error_result(message="Document not found!")
-        if not settings.docStoreConn.delete({"id": req["chunk_ids"]},
+        if not globals.docStoreConn.delete({"id": req["chunk_ids"]},
                                             search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
                                             doc.kb_id):
             return get_data_error_result(message="Chunk deleting failure")
@@ -270,7 +271,7 @@ def create():
         v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not d["question_kwd"] else "\n".join(d["question_kwd"])])
         v = 0.1 * v[0] + 0.9 * v[1]
         d["q_%d_vec" % len(v)] = v.tolist()
-        settings.docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
+        globals.docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
         DocumentService.increment_chunk_num(
             doc.id, doc.kb_id, c, 1, 0)
@@ -346,7 +347,7 @@ def retrieval_test():
             question += keyword_extraction(chat_mdl, question)
         labels = label_question(question, [kb])
-        ranks = settings.retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids, page, size,
+        ranks = globals.retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids, page, size,
                                             float(req.get("similarity_threshold", 0.0)),
                                             float(req.get("vector_similarity_weight", 0.3)),
                                             top,
@@ -385,7 +386,7 @@ def knowledge_graph():
         "doc_ids": [doc_id],
         "knowledge_graph_kwd": ["graph", "mind_map"]
     }
-    sres = settings.retriever.search(req, search.index_name(tenant_id), kb_ids)
+    sres = globals.retriever.search(req, search.index_name(tenant_id), kb_ids)
     obj = {"graph": {}, "mind_map": {}}
    for id in sres.ids[:2]:
        ty = sres.field[id]["knowledge_graph_kwd"]

View File

@@ -23,7 +23,6 @@ import flask
 from flask import request
 from flask_login import current_user, login_required
-from api import settings
 from api.common.check_team_permission import check_kb_team_permission
 from api.constants import FILE_NAME_LEN_LIMIT, IMG_BASE64_PREFIX
 from api.db import VALID_FILE_TYPES, FileType
@@ -49,6 +48,7 @@ from api.utils.web_utils import CONTENT_TYPE_MAP, html2pdf, is_valid_url
 from deepdoc.parser.html_parser import RAGFlowHtmlParser
 from rag.nlp import search, rag_tokenizer
 from rag.utils.storage_factory import STORAGE_IMPL
+from common import globals

 @manager.route("/upload", methods=["POST"])  # noqa: F821
@@ -367,7 +367,7 @@ def change_status():
                 continue
             status_int = int(status)
-            if not settings.docStoreConn.update({"doc_id": doc_id}, {"available_int": status_int}, search.index_name(kb.tenant_id), doc.kb_id):
+            if not globals.docStoreConn.update({"doc_id": doc_id}, {"available_int": status_int}, search.index_name(kb.tenant_id), doc.kb_id):
                 result[doc_id] = {"error": "Database error (docStore update)!"}
             result[doc_id] = {"status": status}
         except Exception as e:
@@ -432,8 +432,8 @@ def run():
             DocumentService.update_by_id(id, info)
             if req.get("delete", False):
                 TaskService.filter_delete([Task.doc_id == id])
-                if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
-                    settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id)
+                if globals.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
+                    globals.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id)
             if str(req["run"]) == TaskStatus.RUNNING.value:
                 doc = doc.to_dict()
@@ -479,8 +479,8 @@ def rename():
                 "title_tks": title_tks,
                 "title_sm_tks": rag_tokenizer.fine_grained_tokenize(title_tks),
             }
-            if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
-                settings.docStoreConn.update(
+            if globals.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
+                globals.docStoreConn.update(
                     {"doc_id": req["doc_id"]},
                     es_body,
                     search.index_name(tenant_id),
@@ -541,8 +541,8 @@ def change_parser():
         tenant_id = DocumentService.get_tenant_id(req["doc_id"])
         if not tenant_id:
             return get_data_error_result(message="Tenant not found!")
-        if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
-            settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
+        if globals.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
+            globals.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
         try:
             if "pipeline_id" in req and req["pipeline_id"] != "":

View File

@@ -35,7 +35,6 @@ from api.db import VALID_FILE_TYPES
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.db_models import File
 from api.utils.api_utils import get_json_result
-from api import settings
 from rag.nlp import search
 from api.constants import DATASET_NAME_LIMIT
 from rag.settings import PAGERANK_FLD
@@ -43,7 +42,7 @@ from rag.utils.redis_conn import REDIS_CONN
 from rag.utils.storage_factory import STORAGE_IMPL
 from rag.utils.doc_store_conn import OrderByExpr
 from common.constants import RetCode, PipelineTaskType, StatusEnum, VALID_TASK_STATUS, FileSource, LLMType
-
+from common import globals

 @manager.route('/create', methods=['post'])  # noqa: F821
 @login_required
@@ -110,11 +109,11 @@ def update():
         if kb.pagerank != req.get("pagerank", 0):
             if req.get("pagerank", 0) > 0:
-                settings.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]},
+                globals.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]},
                                              search.index_name(kb.tenant_id), kb.id)
             else:
                 # Elasticsearch requires PAGERANK_FLD be non-zero!
-                settings.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD},
+                globals.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD},
                                              search.index_name(kb.tenant_id), kb.id)

         e, kb = KnowledgebaseService.get_by_id(kb.id)
@@ -226,8 +225,8 @@ def rm():
                 return get_data_error_result(
                     message="Database error (Knowledgebase removal)!")
         for kb in kbs:
-            settings.docStoreConn.delete({"kb_id": kb.id}, search.index_name(kb.tenant_id), kb.id)
-            settings.docStoreConn.deleteIdx(search.index_name(kb.tenant_id), kb.id)
+            globals.docStoreConn.delete({"kb_id": kb.id}, search.index_name(kb.tenant_id), kb.id)
+            globals.docStoreConn.deleteIdx(search.index_name(kb.tenant_id), kb.id)
             if hasattr(STORAGE_IMPL, 'remove_bucket'):
                 STORAGE_IMPL.remove_bucket(kb.id)
         return get_json_result(data=True)
@@ -248,7 +247,7 @@ def list_tags(kb_id):
     tenants = UserTenantService.get_tenants_by_user_id(current_user.id)
     tags = []
     for tenant in tenants:
-        tags += settings.retriever.all_tags(tenant["tenant_id"], [kb_id])
+        tags += globals.retriever.all_tags(tenant["tenant_id"], [kb_id])
     return get_json_result(data=tags)
@@ -267,7 +266,7 @@ def list_tags_from_kbs():
     tenants = UserTenantService.get_tenants_by_user_id(current_user.id)
     tags = []
     for tenant in tenants:
-        tags += settings.retriever.all_tags(tenant["tenant_id"], kb_ids)
+        tags += globals.retriever.all_tags(tenant["tenant_id"], kb_ids)
     return get_json_result(data=tags)
@@ -284,7 +283,7 @@ def rm_tags(kb_id):
     e, kb = KnowledgebaseService.get_by_id(kb_id)
     for t in req["tags"]:
-        settings.docStoreConn.update({"tag_kwd": t, "kb_id": [kb_id]},
+        globals.docStoreConn.update({"tag_kwd": t, "kb_id": [kb_id]},
                                      {"remove": {"tag_kwd": t}},
                                      search.index_name(kb.tenant_id),
                                      kb_id)
@@ -303,7 +302,7 @@ def rename_tags(kb_id):
         )
     e, kb = KnowledgebaseService.get_by_id(kb_id)
-    settings.docStoreConn.update({"tag_kwd": req["from_tag"], "kb_id": [kb_id]},
+    globals.docStoreConn.update({"tag_kwd": req["from_tag"], "kb_id": [kb_id]},
                                  {"remove": {"tag_kwd": req["from_tag"].strip()}, "add": {"tag_kwd": req["to_tag"]}},
                                  search.index_name(kb.tenant_id),
                                  kb_id)
@@ -326,9 +325,9 @@ def knowledge_graph(kb_id):
     }
     obj = {"graph": {}, "mind_map": {}}
-    if not settings.docStoreConn.indexExist(search.index_name(kb.tenant_id), kb_id):
+    if not globals.docStoreConn.indexExist(search.index_name(kb.tenant_id), kb_id):
         return get_json_result(data=obj)
-    sres = settings.retriever.search(req, search.index_name(kb.tenant_id), [kb_id])
+    sres = globals.retriever.search(req, search.index_name(kb.tenant_id), [kb_id])
     if not len(sres.ids):
         return get_json_result(data=obj)
@@ -360,7 +359,7 @@ def delete_knowledge_graph(kb_id):
             code=RetCode.AUTHENTICATION_ERROR
         )
     _, kb = KnowledgebaseService.get_by_id(kb_id)
-    settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]}, search.index_name(kb.tenant_id), kb_id)
+    globals.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]}, search.index_name(kb.tenant_id), kb_id)
     return get_json_result(data=True)
@@ -732,13 +731,13 @@ def delete_kb_task():
             task_id = kb.graphrag_task_id
             kb_task_finish_at = "graphrag_task_finish_at"
             cancel_task(task_id)
-            settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]}, search.index_name(kb.tenant_id), kb_id)
+            globals.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]}, search.index_name(kb.tenant_id), kb_id)
         case PipelineTaskType.RAPTOR:
             kb_task_id_field = "raptor_task_id"
             task_id = kb.raptor_task_id
             kb_task_finish_at = "raptor_task_finish_at"
             cancel_task(task_id)
-            settings.docStoreConn.delete({"raptor_kwd": ["raptor"]}, search.index_name(kb.tenant_id), kb_id)
+            globals.docStoreConn.delete({"raptor_kwd": ["raptor"]}, search.index_name(kb.tenant_id), kb_id)
         case PipelineTaskType.MINDMAP:
             kb_task_id_field = "mindmap_task_id"
             task_id = kb.mindmap_task_id
@@ -850,7 +849,7 @@ def check_embedding():
     tenant_id = kb.tenant_id
     emb_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_id)
-    samples = sample_random_chunks_with_vectors(settings.docStoreConn, tenant_id=tenant_id, kb_id=kb_id, n=n)
+    samples = sample_random_chunks_with_vectors(globals.docStoreConn, tenant_id=tenant_id, kb_id=kb_id, n=n)
     results, eff_sims = [], []
     for ck in samples:

View File

@@ -24,7 +24,7 @@ from api.utils.api_utils import server_error_response, get_data_error_result, va
 from common.constants import StatusEnum, LLMType
 from api.db.db_models import TenantLLM
 from api.utils.api_utils import get_json_result, get_allowed_llm_factories
-from common.base64_image import test_image
+from rag.utils.base64_image import test_image
 from rag.llm import EmbeddingModel, ChatModel, RerankModel, CvModel, TTSModel

View File

@@ -20,7 +20,6 @@ import os
 import json
 from flask import request
 from peewee import OperationalError
-from api import settings
 from api.db.db_models import File
 from api.db.services.document_service import DocumentService
 from api.db.services.file2document_service import File2DocumentService
@@ -49,6 +48,7 @@ from api.utils.validation_utils import (
 )
 from rag.nlp import search
 from rag.settings import PAGERANK_FLD
+from common import globals

 @manager.route("/datasets", methods=["POST"])  # noqa: F821
@@ -360,11 +360,11 @@ def update(tenant_id, dataset_id):
             return get_error_argument_result(message="'pagerank' can only be set when doc_engine is elasticsearch")
         if req["pagerank"] > 0:
-            settings.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]},
+            globals.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]},
                                          search.index_name(kb.tenant_id), kb.id)
         else:
             # Elasticsearch requires PAGERANK_FLD be non-zero!
-            settings.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD},
+            globals.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD},
                                          search.index_name(kb.tenant_id), kb.id)

     if not KnowledgebaseService.update_by_id(kb.id, req):
@@ -493,9 +493,9 @@ def knowledge_graph(tenant_id, dataset_id):
     }
     obj = {"graph": {}, "mind_map": {}}
-    if not settings.docStoreConn.indexExist(search.index_name(kb.tenant_id), dataset_id):
+    if not globals.docStoreConn.indexExist(search.index_name(kb.tenant_id), dataset_id):
         return get_result(data=obj)
-    sres = settings.retriever.search(req, search.index_name(kb.tenant_id), [dataset_id])
+    sres = globals.retriever.search(req, search.index_name(kb.tenant_id), [dataset_id])
     if not len(sres.ids):
         return get_result(data=obj)
@@ -528,7 +528,7 @@ def delete_knowledge_graph(tenant_id, dataset_id):
             code=RetCode.AUTHENTICATION_ERROR
         )
     _, kb = KnowledgebaseService.get_by_id(dataset_id)
-    settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]},
+    globals.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]},
                                  search.index_name(kb.tenant_id), dataset_id)
     return get_result(data=True)

View File

@@ -25,6 +25,7 @@ from api.utils.api_utils import validate_request, build_error_result, apikey_req
 from rag.app.tag import label_question
 from api.db.services.dialog_service import meta_filter, convert_conditions
 from common.constants import RetCode, LLMType
+from common import globals

 @manager.route('/dify/retrieval', methods=['POST'])  # noqa: F821
 @apikey_required
@@ -137,7 +138,7 @@ def retrieval(tenant_id):
         # print("doc_ids", doc_ids)
         if not doc_ids and metadata_condition is not None:
             doc_ids = ['-999']
-        ranks = settings.retriever.retrieval(
+        ranks = globals.retriever.retrieval(
             question,
             embd_mdl,
             kb.tenant_id,

View File

@@ -44,6 +44,7 @@ from rag.prompts.generator import cross_languages, keyword_extraction
 from rag.utils.storage_factory import STORAGE_IMPL
 from common.string_utils import remove_redundant_spaces
 from common.constants import RetCode, LLMType, ParserType, TaskStatus, FileSource
+from common import globals

 MAXIMUM_OF_UPLOADING_FILES = 256
@@ -307,7 +308,7 @@ def update_doc(tenant_id, dataset_id, document_id):
             )
             if not e:
                 return get_error_data_result(message="Document not found!")
-            settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), dataset_id)
+            globals.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), dataset_id)
     if "enabled" in req:
         status = int(req["enabled"])
@@ -316,7 +317,7 @@ def update_doc(tenant_id, dataset_id, document_id):
             if not DocumentService.update_by_id(doc.id, {"status": str(status)}):
                 return get_error_data_result(message="Database error (Document update)!")
-            settings.docStoreConn.update({"doc_id": doc.id}, {"available_int": status}, search.index_name(kb.tenant_id), doc.kb_id)
+            globals.docStoreConn.update({"doc_id": doc.id}, {"available_int": status}, search.index_name(kb.tenant_id), doc.kb_id)
             return get_result(data=True)
         except Exception as e:
             return server_error_response(e)
@@ -755,7 +756,7 @@ def parse(tenant_id, dataset_id):
             return get_error_data_result("Can't parse document that is currently being processed")
         info = {"run": "1", "progress": 0, "progress_msg": "", "chunk_num": 0, "token_num": 0}
         DocumentService.update_by_id(id, info)
-        settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), dataset_id)
+        globals.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), dataset_id)
         TaskService.filter_delete([Task.doc_id == id])
         e, doc = DocumentService.get_by_id(id)
         doc = doc.to_dict()
@@ -835,7 +836,7 @@ def stop_parsing(tenant_id, dataset_id):
             return get_error_data_result("Can't stop parsing document with progress at 0 or 1")
         info = {"run": "2", "progress": 0, "chunk_num": 0}
         DocumentService.update_by_id(id, info)
-        settings.docStoreConn.delete({"doc_id": doc[0].id}, search.index_name(tenant_id), dataset_id)
+        globals.docStoreConn.delete({"doc_id": doc[0].id}, search.index_name(tenant_id), dataset_id)
         success_count += 1
     if duplicate_messages:
         if success_count > 0:
@@ -968,7 +969,7 @@ def list_chunks(tenant_id, dataset_id, document_id):
     res = {"total": 0, "chunks": [], "doc": renamed_doc}
     if req.get("id"):
-        chunk = settings.docStoreConn.get(req.get("id"), search.index_name(tenant_id), [dataset_id])
+        chunk = globals.docStoreConn.get(req.get("id"), search.index_name(tenant_id), [dataset_id])
         if not chunk:
             return get_result(message=f"Chunk not found: {dataset_id}/{req.get('id')}", code=RetCode.NOT_FOUND)
         k = []
@@ -995,8 +996,8 @@ def list_chunks(tenant_id, dataset_id, document_id):
         res["chunks"].append(final_chunk)
         _ = Chunk(**final_chunk)
-    elif settings.docStoreConn.indexExist(search.index_name(tenant_id), dataset_id):
-        sres = settings.retriever.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None, highlight=True)
+    elif globals.docStoreConn.indexExist(search.index_name(tenant_id), dataset_id):
+        sres = globals.retriever.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None, highlight=True)
         res["total"] = sres.total
         for id in sres.ids:
             d = {
@@ -1120,7 +1121,7 @@ def add_chunk(tenant_id, dataset_id, document_id):
     v, c = embd_mdl.encode([doc.name, req["content"] if not d["question_kwd"] else "\n".join(d["question_kwd"])])
     v = 0.1 * v[0] + 0.9 * v[1]
     d["q_%d_vec" % len(v)] = v.tolist()
-    settings.docStoreConn.insert([d], search.index_name(tenant_id), dataset_id)
+    globals.docStoreConn.insert([d], search.index_name(tenant_id), dataset_id)
     DocumentService.increment_chunk_num(doc.id, doc.kb_id, c, 1, 0)
     # rename keys
@@ -1201,7 +1202,7 @@ def rm_chunk(tenant_id, dataset_id, document_id):
     if "chunk_ids" in req:
         unique_chunk_ids, duplicate_messages = check_duplicate_ids(req["chunk_ids"], "chunk")
         condition["id"] = unique_chunk_ids
-    chunk_number = settings.docStoreConn.delete(condition, search.index_name(tenant_id), dataset_id)
+    chunk_number = globals.docStoreConn.delete(condition, search.index_name(tenant_id), dataset_id)
     if chunk_number != 0:
         DocumentService.decrement_chunk_num(document_id, dataset_id, 1, chunk_number, 0)
     if "chunk_ids" in req and chunk_number != len(unique_chunk_ids):
@@ -1273,7 +1274,7 @@ def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
             schema:
                 type: object
     """
-    chunk = settings.docStoreConn.get(chunk_id, search.index_name(tenant_id), [dataset_id])
+    chunk = globals.docStoreConn.get(chunk_id, search.index_name(tenant_id), [dataset_id])
     if chunk is None:
         return get_error_data_result(f"Can't find this chunk {chunk_id}")
     if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
@@ -1318,7 +1319,7 @@ def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
     v, c = embd_mdl.encode([doc.name, d["content_with_weight"] if not d.get("question_kwd") else "\n".join(d["question_kwd"])])
     v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
     d["q_%d_vec" % len(v)] = v.tolist()
-    settings.docStoreConn.update({"id": chunk_id}, d, search.index_name(tenant_id), dataset_id)
+    globals.docStoreConn.update({"id": chunk_id}, d, search.index_name(tenant_id), dataset_id)
     return get_result()
@@ -1464,7 +1465,7 @@ def retrieval_test(tenant_id):
         chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
         question += keyword_extraction(chat_mdl, question)
-    ranks = settings.retriever.retrieval(
+    ranks = globals.retriever.retrieval(
         question,
         embd_mdl,
         tenant_ids,

View File

@@ -41,6 +41,7 @@ from rag.app.tag import label_question
 from rag.prompts.template import load_prompt
 from rag.prompts.generator import cross_languages, gen_meta_filter, keyword_extraction, chunks_format
 from common.constants import RetCode, LLMType, StatusEnum
+from common import globals

 @manager.route("/chats/<chat_id>/sessions", methods=["POST"])  # noqa: F821
 @token_required
@@ -1015,7 +1016,7 @@ def retrieval_test_embedded():
             question += keyword_extraction(chat_mdl, question)
         labels = label_question(question, [kb])
-        ranks = settings.retriever.retrieval(
+        ranks = globals.retriever.retrieval(
             question, embd_mdl, tenant_ids, kb_ids, page, size, similarity_threshold, vector_similarity_weight, top,
             doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"), rank_feature=labels
         )

View File

@@ -38,6 +38,7 @@ from timeit import default_timer as timer
 from rag.utils.redis_conn import REDIS_CONN
 from flask import jsonify
 from api.utils.health_utils import run_health_checks
+from common import globals

 @manager.route("/version", methods=["GET"])  # noqa: F821
@@ -100,7 +101,7 @@ def status():
     res = {}
     st = timer()
     try:
-        res["doc_engine"] = settings.docStoreConn.health()
+        res["doc_engine"] = globals.docStoreConn.health()
         res["doc_engine"]["elapsed"] = "{:.1f}".format((timer() - st) * 1000.0)
     except Exception as e:
         res["doc_engine"] = {

View File

@@ -58,6 +58,7 @@ from api.utils.web_utils import (
     hash_code,
     captcha_key,
 )
+from common import globals

 @manager.route("/login", methods=["POST", "GET"])  # noqa: F821
@@ -623,7 +624,7 @@ def user_register(user_id, user):
         "id": user_id,
         "name": user["nickname"] + "'s Kingdom",
         "llm_id": settings.CHAT_MDL,
-        "embd_id": settings.EMBEDDING_MDL,
+        "embd_id": globals.EMBEDDING_MDL,
         "asr_id": settings.ASR_MDL,
         "parser_ids": settings.PARSERS,
         "img2txt_id": settings.IMAGE2TEXT_MDL,

View File

@@ -32,6 +32,7 @@ from api.db.services.user_service import TenantService, UserTenantService
 from api import settings
 from common.constants import LLMType
 from common.file_utils import get_project_base_directory
+from common import globals
 from api.common.base64 import encode_to_base64
@@ -49,7 +50,7 @@ def init_superuser():
         "id": user_info["id"],
         "name": user_info["nickname"] + "'s Kingdom",
         "llm_id": settings.CHAT_MDL,
-        "embd_id": settings.EMBEDDING_MDL,
+        "embd_id": globals.EMBEDDING_MDL,
         "asr_id": settings.ASR_MDL,
         "parser_ids": settings.PARSERS,
         "img2txt_id": settings.IMAGE2TEXT_MDL

View File

@ -38,6 +38,7 @@ from api.db.services.user_service import TenantService, UserService, UserTenantS
from rag.utils.storage_factory import STORAGE_IMPL from rag.utils.storage_factory import STORAGE_IMPL
from rag.nlp import search from rag.nlp import search
from common.constants import ActiveEnum from common.constants import ActiveEnum
from common import globals
def create_new_user(user_info: dict) -> dict: def create_new_user(user_info: dict) -> dict:
""" """
@ -63,7 +64,7 @@ def create_new_user(user_info: dict) -> dict:
"id": user_id, "id": user_id,
"name": user_info["nickname"] + "s Kingdom", "name": user_info["nickname"] + "s Kingdom",
"llm_id": settings.CHAT_MDL, "llm_id": settings.CHAT_MDL,
"embd_id": settings.EMBEDDING_MDL, "embd_id": globals.EMBEDDING_MDL,
"asr_id": settings.ASR_MDL, "asr_id": settings.ASR_MDL,
"parser_ids": settings.PARSERS, "parser_ids": settings.PARSERS,
"img2txt_id": settings.IMAGE2TEXT_MDL, "img2txt_id": settings.IMAGE2TEXT_MDL,
@ -179,7 +180,7 @@ def delete_user_data(user_id: str) -> dict:
) )
done_msg += f"- Deleted {file2doc_delete_res} document-file relation records.\n" done_msg += f"- Deleted {file2doc_delete_res} document-file relation records.\n"
# step1.1.3 delete chunk in es # step1.1.3 delete chunk in es
r = settings.docStoreConn.delete({"kb_id": kb_ids}, r = globals.docStoreConn.delete({"kb_id": kb_ids},
search.index_name(tenant_id), kb_ids) search.index_name(tenant_id), kb_ids)
done_msg += f"- Deleted {r} chunk records.\n" done_msg += f"- Deleted {r} chunk records.\n"
kb_delete_res = KnowledgebaseService.delete_by_ids(kb_ids) kb_delete_res = KnowledgebaseService.delete_by_ids(kb_ids)
@ -237,7 +238,7 @@ def delete_user_data(user_id: str) -> dict:
kb_doc_info = {} kb_doc_info = {}
for _tenant_id, kb_doc in kb_grouped_doc.items(): for _tenant_id, kb_doc in kb_grouped_doc.items():
for _kb_id, docs in kb_doc.items(): for _kb_id, docs in kb_doc.items():
chunk_delete_res += settings.docStoreConn.delete( chunk_delete_res += globals.docStoreConn.delete(
{"doc_id": [d["id"] for d in docs]}, {"doc_id": [d["id"] for d in docs]},
search.index_name(_tenant_id), _kb_id search.index_name(_tenant_id), _kb_id
) )

View File

@@ -111,12 +111,14 @@ class SyncLogsService(CommonService):
         return list(query.dicts())

     @classmethod
-    def start(cls, id):
+    def start(cls, id, connector_id):
         cls.update_by_id(id, {"status": TaskStatus.RUNNING, "time_started": datetime.now().strftime('%Y-%m-%d %H:%M:%S') })
+        ConnectorService.update_by_id(connector_id, {"status": TaskStatus.RUNNING})

     @classmethod
-    def done(cls, id):
+    def done(cls, id, connector_id):
         cls.update_by_id(id, {"status": TaskStatus.DONE})
+        ConnectorService.update_by_id(connector_id, {"status": TaskStatus.DONE})

     @classmethod
     def schedule(cls, connector_id, kb_id, poll_range_start=None, reindex=False, total_docs_indexed=0):
@@ -126,6 +128,7 @@ class SyncLogsService(CommonService):
             logging.warning(f"{kb_id}--{connector_id} has already had a scheduling sync task which is abnormal.")
             return None
         reindex = "1" if reindex else "0"
+        ConnectorService.update_by_id(connector_id, {"status": TaskStatus.SCHEDULE})
         return cls.save(**{
             "id": get_uuid(),
             "kb_id": kb_id, "status": TaskStatus.SCHEDULE, "connector_id": connector_id,
@@ -142,6 +145,7 @@ class SyncLogsService(CommonService):
                 full_exception_trace=cls.model.full_exception_trace + str(e)
             ) \
             .where(cls.model.id == task.id).execute()
+        ConnectorService.update_by_id(connector_id, {"status": TaskStatus.SCHEDULE})

View File

@@ -44,6 +44,7 @@ from rag.prompts.generator import chunks_format, citation_prompt, cross_language
 from common.token_utils import num_tokens_from_string
 from rag.utils.tavily_conn import Tavily
 from common.string_utils import remove_redundant_spaces
+from common import globals

 class DialogService(CommonService):
@@ -293,12 +294,13 @@ def meta_filter(metas: dict, filters: list[dict]):
     def filter_out(v2docs, operator, value):
         ids = []
         for input, docids in v2docs.items():
-            try:
-                input = float(input)
-                value = float(value)
-            except Exception:
-                input = str(input)
-                value = str(value)
+            if operator in ["=", "≠", ">", "<", "≥", "≤"]:
+                try:
+                    input = float(input)
+                    value = float(value)
+                except Exception:
+                    input = str(input)
+                    value = str(value)

             for conds in [
                 (operator == "contains", str(value).lower() in str(input).lower()),
@@ -371,7 +373,7 @@ def chat(dialog, messages, stream=True, **kwargs):
         chat_mdl.bind_tools(toolcall_session, tools)
     bind_models_ts = timer()
-    retriever = settings.retriever
+    retriever = globals.retriever
     questions = [m["content"] for m in messages if m["role"] == "user"][-3:]
     attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else []
     if "doc_ids" in messages[-1]:
@@ -663,7 +665,7 @@ Please write the SQL, only SQL, without any other explanations or text.
         logging.debug(f"{question} get SQL(refined): {sql}")
         tried_times += 1
-        return settings.retriever.sql_retrieval(sql, format="json"), sql
+        return globals.retriever.sql_retrieval(sql, format="json"), sql
     tbl, sql = get_table()
     if tbl is None:
@@ -757,7 +759,7 @@ def ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
     embedding_list = list(set([kb.embd_id for kb in kbs]))
     is_knowledge_graph = all([kb.parser_id == ParserType.KG for kb in kbs])
-    retriever = settings.retriever if not is_knowledge_graph else settings.kg_retriever
+    retriever = globals.retriever if not is_knowledge_graph else settings.kg_retriever
     embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embedding_list[0])
     chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, chat_llm_name)
@@ -853,7 +855,7 @@ def gen_mindmap(question, kb_ids, tenant_id, search_config={}):
     if not doc_ids:
         doc_ids = None
-    ranks = settings.retriever.retrieval(
+    ranks = globals.retriever.retrieval(
         question=question,
         embd_mdl=embd_mdl,
         tenant_ids=tenant_ids,
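
The `meta_filter` hunk above stops coercing values to `float` unless the operator is a numeric comparison. The point is that coercion changes substring semantics for operators like `contains`; a small repro of the idea (illustrative, not the project's test code):

```python
# With "contains", numeric coercion would alter the result:
input_, value = "73", "7"
print(str(value) in str(input_))                # True  -> "7" is a substring of "73"
print(str(float(value)) in str(float(input_)))  # False -> "7.0" not in "73.0"
```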

View File

@@ -26,7 +26,6 @@ import trio
 import xxhash
 from peewee import fn, Case, JOIN
-from api import settings
 from api.constants import IMG_BASE64_PREFIX, FILE_NAME_LEN_LIMIT
 from api.db import FileType, UserTenantRole, CanvasCategory
 from api.db.db_models import DB, Document, Knowledgebase, Task, Tenant, UserTenant, File2Document, File, UserCanvas, \
@@ -42,7 +41,7 @@ from rag.settings import get_svr_queue_name, SVR_CONSUMER_GROUP_NAME
 from rag.utils.redis_conn import REDIS_CONN
 from rag.utils.storage_factory import STORAGE_IMPL
 from rag.utils.doc_store_conn import OrderByExpr
+from common import globals

 class DocumentService(CommonService):
     model = Document
@@ -309,10 +308,10 @@ class DocumentService(CommonService):
         page_size = 1000
         all_chunk_ids = []
         while True:
-            chunks = settings.docStoreConn.search(["img_id"], [], {"doc_id": doc.id}, [], OrderByExpr(),
+            chunks = globals.docStoreConn.search(["img_id"], [], {"doc_id": doc.id}, [], OrderByExpr(),
                                                   page * page_size, page_size, search.index_name(tenant_id),
                                                   [doc.kb_id])
-            chunk_ids = settings.docStoreConn.getChunkIds(chunks)
+            chunk_ids = globals.docStoreConn.getChunkIds(chunks)
             if not chunk_ids:
                 break
             all_chunk_ids.extend(chunk_ids)
@@ -323,19 +322,19 @@ class DocumentService(CommonService):
             if doc.thumbnail and not doc.thumbnail.startswith(IMG_BASE64_PREFIX):
                 if STORAGE_IMPL.obj_exist(doc.kb_id, doc.thumbnail):
                     STORAGE_IMPL.rm(doc.kb_id, doc.thumbnail)
-            settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
+            globals.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
-            graph_source = settings.docStoreConn.getFields(
-                settings.docStoreConn.search(["source_id"], [], {"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]}, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [doc.kb_id]), ["source_id"]
+            graph_source = globals.docStoreConn.getFields(
+                globals.docStoreConn.search(["source_id"], [], {"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]}, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [doc.kb_id]), ["source_id"]
             )
             if len(graph_source) > 0 and doc.id in list(graph_source.values())[0]["source_id"]:
-                settings.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "source_id": doc.id},
+                globals.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "source_id": doc.id},
                                              {"remove": {"source_id": doc.id}},
                                              search.index_name(tenant_id), doc.kb_id)
-                settings.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]},
+                globals.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]},
                                              {"removed_kwd": "Y"},
                                              search.index_name(tenant_id), doc.kb_id)
-                settings.docStoreConn.delete({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "must_not": {"exists": "source_id"}},
+                globals.docStoreConn.delete({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "must_not": {"exists": "source_id"}},
                                             search.index_name(tenant_id), doc.kb_id)
         except Exception:
             pass
@@ -996,10 +995,10 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
             d["q_%d_vec" % len(v)] = v
         for b in range(0, len(cks), es_bulk_size):
             if try_create_idx:
-                if not settings.docStoreConn.indexExist(idxnm, kb_id):
-                    settings.docStoreConn.createIdx(idxnm, kb_id, len(vects[0]))
+                if not globals.docStoreConn.indexExist(idxnm, kb_id):
+                    globals.docStoreConn.createIdx(idxnm, kb_id, len(vects[0]))
                 try_create_idx = False
-            settings.docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id)
+            globals.docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id)
             DocumentService.increment_chunk_num(
                 doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0)

View File

@@ -30,13 +30,14 @@ class LLMService(CommonService):
     def get_init_tenant_llm(user_id):
         from api import settings
+        from common import globals
         tenant_llm = []
         seen = set()
         factory_configs = []
         for factory_config in [
             settings.CHAT_CFG,
-            settings.EMBEDDING_CFG,
+            globals.EMBEDDING_CFG,
             settings.ASR_CFG,
             settings.IMAGE2TEXT_CFG,
             settings.RERANK_CFG,

View File

@@ -34,7 +34,7 @@ from deepdoc.parser.excel_parser import RAGFlowExcelParser
 from rag.settings import get_svr_queue_name
 from rag.utils.storage_factory import STORAGE_IMPL
 from rag.utils.redis_conn import REDIS_CONN
-from api import settings
+from common import globals
 from rag.nlp import search

 CANVAS_DEBUG_DOC_ID = "dataflow_x"
@@ -418,7 +418,7 @@ def queue_tasks(doc: dict, bucket: str, name: str, priority: int):
         if pre_task["chunk_ids"]:
             pre_chunk_ids.extend(pre_task["chunk_ids"].split())
     if pre_chunk_ids:
-        settings.docStoreConn.delete({"id": pre_chunk_ids}, search.index_name(chunking_config["tenant_id"]),
+        globals.docStoreConn.delete({"id": pre_chunk_ids}, search.index_name(chunking_config["tenant_id"]),
                                      chunking_config["kb_id"])
     DocumentService.update_by_id(doc["id"], {"chunk_num": ck_num})

View File

@@ -17,6 +17,7 @@ import os
 import logging
 from langfuse import Langfuse
 from api import settings
+from common import globals
 from common.constants import LLMType
 from api.db.db_models import DB, LLMFactories, TenantLLM
 from api.db.services.common_service import CommonService
@@ -114,7 +115,7 @@ class TenantLLMService(CommonService):
         if model_config:
             model_config = model_config.to_dict()
         elif llm_type == LLMType.EMBEDDING and fid == 'Builtin' and "tei-" in os.getenv("COMPOSE_PROFILES", "") and mdlnm == os.getenv('TEI_MODEL', ''):
-            embedding_cfg = settings.EMBEDDING_CFG
+            embedding_cfg = globals.EMBEDDING_CFG
             model_config = {"llm_factory": 'Builtin', "api_key": embedding_cfg["api_key"], "llm_name": mdlnm, "api_base": embedding_cfg["base_url"]}
         else:
             raise LookupError(f"Model({mdlnm}@{fid}) not authorized")

View File

@@ -27,7 +27,7 @@ from api.db.services.common_service import CommonService
 from common.misc_utils import get_uuid
 from common.time_utils import current_timestamp, datetime_format
 from common.constants import StatusEnum
-from rag.settings import MINIO
+from common import globals

 class UserService(CommonService):
@@ -221,7 +221,7 @@ class TenantService(CommonService):
     @DB.connection_context()
     def user_gateway(cls, tenant_id):
         hash_obj = hashlib.sha256(tenant_id.encode("utf-8"))
-        return int(hash_obj.hexdigest(), 16) % len(MINIO)
+        return int(hash_obj.hexdigest(), 16) % len(globals.MINIO)

 class UserTenantService(CommonService):
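For context, the mapping this import change touches is a deterministic hash. A hedged standalone sketch (hypothetical helper mirroring `TenantService.user_gateway`):

```python
# Hedged sketch: SHA-256 of the tenant id, reduced modulo len(globals.MINIO),
# yields a stable index, so a tenant always resolves to the same gateway slot
# as long as the MINIO config keeps the same size.
import hashlib

def gateway_index(tenant_id: str, n_slots: int) -> int:
    hash_obj = hashlib.sha256(tenant_id.encode("utf-8"))
    return int(hash_obj.hexdigest(), 16) % n_slots

assert gateway_index("tenant-42", 4) == gateway_index("tenant-42", 4)  # deterministic
```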

View File

@@ -25,18 +25,19 @@ import rag.utils.opensearch_conn
 from api.constants import RAG_FLOW_SERVICE_NAME
 from common.config_utils import decrypt_database_config, get_base_config
 from common.file_utils import get_project_base_directory
+from common import globals
 from rag.nlp import search

 LLM = None
 LLM_FACTORY = None
 LLM_BASE_URL = None
 CHAT_MDL = ""
-EMBEDDING_MDL = ""
+# EMBEDDING_MDL = "" has been moved to common/globals.py
 RERANK_MDL = ""
 ASR_MDL = ""
 IMAGE2TEXT_MDL = ""
 CHAT_CFG = ""
-EMBEDDING_CFG = ""
+# EMBEDDING_CFG = "" has been moved to common/globals.py
 RERANK_CFG = ""
 ASR_CFG = ""
 IMAGE2TEXT_CFG = ""
@@ -60,10 +61,10 @@ HTTP_APP_KEY = None
 GITHUB_OAUTH = None
 FEISHU_OAUTH = None
 OAUTH_CONFIG = None
-DOC_ENGINE = None
-docStoreConn = None
-retriever = None
+# DOC_ENGINE = None has been moved to common/globals.py
+# docStoreConn = None has been moved to common/globals.py
+# retriever = None has been moved to common/globals.py
 kg_retriever = None

 # user registration switch
@@ -124,8 +125,8 @@ def init_settings():
     except Exception:
         FACTORY_LLM_INFOS = []
-    global CHAT_MDL, EMBEDDING_MDL, RERANK_MDL, ASR_MDL, IMAGE2TEXT_MDL
-    global CHAT_CFG, EMBEDDING_CFG, RERANK_CFG, ASR_CFG, IMAGE2TEXT_CFG
+    global CHAT_MDL, RERANK_MDL, ASR_MDL, IMAGE2TEXT_MDL
+    global CHAT_CFG, RERANK_CFG, ASR_CFG, IMAGE2TEXT_CFG
     global API_KEY, PARSERS, HOST_IP, HOST_PORT, SECRET_KEY
     API_KEY = LLM.get("api_key")
@@ -134,19 +135,19 @@ def init_settings():
     )
     chat_entry = _parse_model_entry(LLM_DEFAULT_MODELS.get("chat_model", CHAT_MDL))
-    embedding_entry = _parse_model_entry(LLM_DEFAULT_MODELS.get("embedding_model", EMBEDDING_MDL))
+    embedding_entry = _parse_model_entry(LLM_DEFAULT_MODELS.get("embedding_model", globals.EMBEDDING_MDL))
     rerank_entry = _parse_model_entry(LLM_DEFAULT_MODELS.get("rerank_model", RERANK_MDL))
     asr_entry = _parse_model_entry(LLM_DEFAULT_MODELS.get("asr_model", ASR_MDL))
     image2text_entry = _parse_model_entry(LLM_DEFAULT_MODELS.get("image2text_model", IMAGE2TEXT_MDL))
     CHAT_CFG = _resolve_per_model_config(chat_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL)
-    EMBEDDING_CFG = _resolve_per_model_config(embedding_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL)
+    globals.EMBEDDING_CFG = _resolve_per_model_config(embedding_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL)
     RERANK_CFG = _resolve_per_model_config(rerank_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL)
     ASR_CFG = _resolve_per_model_config(asr_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL)
     IMAGE2TEXT_CFG = _resolve_per_model_config(image2text_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL)
     CHAT_MDL = CHAT_CFG.get("model", "") or ""
-    EMBEDDING_MDL = os.getenv("TEI_MODEL", "BAAI/bge-small-en-v1.5") if "tei-" in os.getenv("COMPOSE_PROFILES", "") else ""
+    globals.EMBEDDING_MDL = os.getenv("TEI_MODEL", "BAAI/bge-small-en-v1.5") if "tei-" in os.getenv("COMPOSE_PROFILES", "") else ""
     RERANK_MDL = RERANK_CFG.get("model", "") or ""
     ASR_MDL = ASR_CFG.get("model", "") or ""
     IMAGE2TEXT_MDL = IMAGE2TEXT_CFG.get("model", "") or ""
@@ -168,23 +169,23 @@ def init_settings():
     OAUTH_CONFIG = get_base_config("oauth", {})
-    global DOC_ENGINE, docStoreConn, retriever, kg_retriever
-    DOC_ENGINE = os.environ.get("DOC_ENGINE", "elasticsearch")
-    # DOC_ENGINE = os.environ.get('DOC_ENGINE', "opensearch")
-    lower_case_doc_engine = DOC_ENGINE.lower()
+    global kg_retriever
+    globals.DOC_ENGINE = os.environ.get("DOC_ENGINE", "elasticsearch")
+    # globals.DOC_ENGINE = os.environ.get('DOC_ENGINE', "opensearch")
+    lower_case_doc_engine = globals.DOC_ENGINE.lower()
     if lower_case_doc_engine == "elasticsearch":
-        docStoreConn = rag.utils.es_conn.ESConnection()
+        globals.docStoreConn = rag.utils.es_conn.ESConnection()
     elif lower_case_doc_engine == "infinity":
-        docStoreConn = rag.utils.infinity_conn.InfinityConnection()
+        globals.docStoreConn = rag.utils.infinity_conn.InfinityConnection()
     elif lower_case_doc_engine == "opensearch":
-        docStoreConn = rag.utils.opensearch_conn.OSConnection()
+        globals.docStoreConn = rag.utils.opensearch_conn.OSConnection()
     else:
-        raise Exception(f"Not supported doc engine: {DOC_ENGINE}")
+        raise Exception(f"Not supported doc engine: {globals.DOC_ENGINE}")
-    retriever = search.Dealer(docStoreConn)
+    globals.retriever = search.Dealer(globals.docStoreConn)
     from graphrag import search as kg_search
-    kg_retriever = kg_search.KGSearch(docStoreConn)
+    kg_retriever = kg_search.KGSearch(globals.docStoreConn)
     if int(os.environ.get("SANDBOX_ENABLED", "0")):
         global SANDBOX_HOST
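The net effect of this refactor: `init_settings()` now publishes the doc-engine singletons through `common.globals`, so every consumer sees the same objects. A hedged sketch of the assumed boot sequence:

```python
# Hedged sketch (assumed usage): after init_settings(), shared runtime state
# lives in common.globals rather than on api.settings.
from api import settings
from common import globals

settings.init_settings()
print(globals.DOC_ENGINE)          # "elasticsearch", "infinity" or "opensearch"
print(type(globals.docStoreConn))  # the matching connection class
print(type(globals.retriever))     # search.Dealer wrapping that connection
```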

View File

@@ -626,7 +626,7 @@ async def is_strong_enough(chat_model, embedding_model):

 def get_allowed_llm_factories() -> list:
-    factories = LLMFactoriesService.get_all()
+    factories = list(LLMFactoriesService.get_all())
     if settings.ALLOWED_LLM_FACTORIES is None:
         return factories
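Why the `list(...)` wrapper matters here: assuming `get_all()` hands back a lazy query object (the peewee imports elsewhere in this diff suggest a `ModelSelect`), the bare result is iterable but is not a `list`, so it breaks the `-> list` annotation and can behave inconsistently for callers that index, slice, or iterate it more than once; materializing it once makes the return type honest.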

View File

@@ -19,11 +19,11 @@ from timeit import default_timer as timer
 from api import settings
 from api.db.db_models import DB
-from rag import settings as rag_settings
 from rag.utils.redis_conn import REDIS_CONN
 from rag.utils.storage_factory import STORAGE_IMPL
 from rag.utils.es_conn import ESConnection
 from rag.utils.infinity_conn import InfinityConnection
+from common import globals

 def _ok_nok(ok: bool) -> str:
@@ -52,7 +52,7 @@ def check_redis() -> tuple[bool, dict]:
 def check_doc_engine() -> tuple[bool, dict]:
     st = timer()
     try:
-        meta = settings.docStoreConn.health()
+        meta = globals.docStoreConn.health()
         # treat any successful call as ok
         return True, {"elapsed": f"{(timer() - st) * 1000.0:.1f}", **(meta or {})}
     except Exception as e:
@@ -120,7 +120,7 @@ def get_mysql_status():
 def check_minio_alive():
     start_time = timer()
     try:
-        response = requests.get(f'http://{rag_settings.MINIO["host"]}/minio/health/live')
+        response = requests.get(f'http://{globals.MINIO["host"]}/minio/health/live')
         if response.status_code == 200:
             return {"status": "alive", "message": f"Confirm elapsed: {(timer() - start_time) * 1000.0:.1f} ms."}
         else:

common/globals.py (new file, 64 lines)

View File

@@ -0,0 +1,64 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os

from common.config_utils import get_base_config, decrypt_database_config

EMBEDDING_MDL = ""
EMBEDDING_CFG = ""

DOC_ENGINE = os.getenv('DOC_ENGINE', 'elasticsearch')
docStoreConn = None
retriever = None

# Moved from rag.settings
ES = {}
INFINITY = {}
AZURE = {}
S3 = {}
MINIO = {}
OSS = {}
OS = {}
REDIS = {}

STORAGE_IMPL_TYPE = os.getenv('STORAGE_IMPL', 'MINIO')

# Load only the config sections selected by the environment variables, so a
# missing config for an unused engine or storage backend cannot break initialization.
if DOC_ENGINE == 'elasticsearch':
    ES = get_base_config("es", {})
elif DOC_ENGINE == 'opensearch':
    OS = get_base_config("os", {})
elif DOC_ENGINE == 'infinity':
    INFINITY = get_base_config("infinity", {"uri": "infinity:23817"})

if STORAGE_IMPL_TYPE in ['AZURE_SPN', 'AZURE_SAS']:
    AZURE = get_base_config("azure", {})
elif STORAGE_IMPL_TYPE == 'AWS_S3':
    S3 = get_base_config("s3", {})
elif STORAGE_IMPL_TYPE == 'MINIO':
    MINIO = decrypt_database_config(name="minio")
elif STORAGE_IMPL_TYPE == 'OSS':
    OSS = get_base_config("oss", {})

try:
    REDIS = decrypt_database_config(name="redis")
except Exception:
    try:
        REDIS = get_base_config("redis", {})
    except Exception:
        REDIS = {}
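
A hedged usage sketch (hypothetical helper, reusing only calls that appear elsewhere in this diff) of how callers consume this module once `api.settings.init_settings()` has populated `docStoreConn`:

```python
# Hypothetical helper, assuming init_settings() already ran and filled
# common.globals: fetch a document's chunk ids through the shared connection.
from common import globals
from rag.nlp import search
from rag.utils.doc_store_conn import OrderByExpr

def first_chunk_ids(tenant_id: str, kb_id: str, doc_id: str) -> list:
    res = globals.docStoreConn.search(
        ["img_id"], [], {"doc_id": doc_id}, [], OrderByExpr(),
        0, 1000, search.index_name(tenant_id), [kb_id],
    )
    return globals.docStoreConn.getChunkIds(res)
```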

common/signal_utils.py (new file, 55 lines)

View File

@@ -0,0 +1,55 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import sys
from datetime import datetime
import logging
import tracemalloc

from common.log_utils import get_project_base_directory


# SIGUSR1 handler: start tracemalloc and take a snapshot
def start_tracemalloc_and_snapshot(signum, frame):
    if not tracemalloc.is_tracing():
        logging.info("start tracemalloc")
        tracemalloc.start()
    else:
        logging.info("tracemalloc is already running")

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    snapshot_file = os.path.abspath(os.path.join(get_project_base_directory(), "logs", f"{os.getpid()}_snapshot_{timestamp}.trace"))

    snapshot = tracemalloc.take_snapshot()
    snapshot.dump(snapshot_file)
    current, peak = tracemalloc.get_traced_memory()
    if sys.platform == "win32":
        import psutil
        process = psutil.Process()
        max_rss = process.memory_info().rss / 1024
    else:
        import resource
        max_rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    logging.info(f"taken snapshot {snapshot_file}. max RSS={max_rss / 1000:.2f} MB, current memory usage: {current / 10**6:.2f} MB, Peak memory usage: {peak / 10**6:.2f} MB")


# SIGUSR2 handler: stop tracemalloc
def stop_tracemalloc(signum, frame):
    if tracemalloc.is_tracing():
        logging.info("stop tracemalloc")
        tracemalloc.stop()
    else:
        logging.info("tracemalloc not running")
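A minimal sketch of the wiring these handlers expect (assumed, not shown in the diff; SIGUSR1/SIGUSR2 exist only on POSIX):

```python
# Assumed registration: `kill -USR1 <pid>` starts tracing and dumps a snapshot
# under logs/, and `kill -USR2 <pid>` stops tracing again.
import signal
from common.signal_utils import start_tracemalloc_and_snapshot, stop_tracemalloc

if hasattr(signal, "SIGUSR1"):  # guard for non-POSIX platforms
    signal.signal(signal.SIGUSR1, start_tracemalloc_and_snapshot)
    signal.signal(signal.SIGUSR2, stop_tracemalloc)
```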

View File

@@ -70,6 +70,17 @@ class RAGFlowMarkdownParser:
             )
             working_text = replace_tables_with_rendered_html(no_border_table_pattern, tables)

+        # Normalize any of TAGS with attributes, e.g. <table border="1">, to the bare tag <table>
+        TAGS = ["table", "td", "tr", "th", "tbody", "thead", "div"]
+        table_with_attributes_pattern = re.compile(
+            rf"<(?:{'|'.join(TAGS)})[^>]*>", re.IGNORECASE
+        )
+
+        def replace_tag(m):
+            tag_name = re.match(r"<(\w+)", m.group()).group(1)
+            return "<{}>".format(tag_name)
+
+        working_text = re.sub(table_with_attributes_pattern, replace_tag, working_text)
+
         if "<table>" in working_text.lower():  # for optimize performance
             # HTML table extraction - handle possible html/body wrapper tags
             html_table_pattern = re.compile(
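What the added normalization buys: attribute-laden openings such as `<table border="1">` collapse to bare tags, so the cheap `"<table>" in working_text.lower()` guard that follows still fires. A standalone re-creation:

```python
# Standalone re-creation of the added regex (not the parser class itself).
# Closing tags like </table> are untouched: after "<" they start with "/",
# which none of the TAGS alternatives match.
import re

TAGS = ["table", "td", "tr", "th", "tbody", "thead", "div"]
pattern = re.compile(rf"<(?:{'|'.join(TAGS)})[^>]*>", re.IGNORECASE)

def normalize(text: str) -> str:
    return pattern.sub(lambda m: "<{}>".format(re.match(r"<(\w+)", m.group()).group(1)), text)

print(normalize('<table border="1" class="x"><tr><td>1</td></tr></table>'))
# -> <table><tr><td>1</td></tr></table>
```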

View File

@@ -6,10 +6,11 @@ set -e
 # Usage and command-line argument parsing
 # -----------------------------------------------------------------------------
 function usage() {
-    echo "Usage: $0 [--disable-webserver] [--disable-taskexecutor] [--consumer-no-beg=<num>] [--consumer-no-end=<num>] [--workers=<num>] [--host-id=<string>]"
+    echo "Usage: $0 [--disable-webserver] [--disable-taskexecutor] [--disable-datasync] [--consumer-no-beg=<num>] [--consumer-no-end=<num>] [--workers=<num>] [--host-id=<string>]"
     echo
     echo "  --disable-webserver      Disables the web server (nginx + ragflow_server)."
     echo "  --disable-taskexecutor   Disables task executor workers."
+    echo "  --disable-datasync       Disables the datasource sync workers."
     echo "  --enable-mcpserver       Enables the MCP server."
     echo "  --enable-adminserver     Enables the Admin server."
     echo "  --consumer-no-beg=<num>  Start range for consumers (if using range-based)."
@@ -28,6 +29,7 @@ function usage() {
 ENABLE_WEBSERVER=1 # Default to enable web server
 ENABLE_TASKEXECUTOR=1 # Default to enable task executor
+ENABLE_DATASYNC=1
 ENABLE_MCP_SERVER=0
 ENABLE_ADMIN_SERVER=0 # Default close admin server
 CONSUMER_NO_BEG=0
@@ -69,6 +71,10 @@ for arg in "$@"; do
       ENABLE_TASKEXECUTOR=0
       shift
       ;;
+    --disable-datasync)
+      ENABLE_DATASYNC=0
+      shift
+      ;;
     --enable-mcpserver)
       ENABLE_MCP_SERVER=1
       shift
@@ -236,6 +242,13 @@ if [[ "${ENABLE_WEBSERVER}" -eq 1 ]]; then
     done &
 fi

+if [[ "${ENABLE_DATASYNC}" -eq 1 ]]; then
+    echo "Starting data sync..."
+    while true; do
+        "$PY" rag/svr/sync_data_source.py
+    done &
+fi
+
 if [[ "${ENABLE_ADMIN_SERVER}" -eq 1 ]]; then
     echo "Starting admin_server..."
     while true; do
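The launch pattern mirrors the other background jobs in this script: `rag/svr/sync_data_source.py` runs inside a `while true` wrapper so it is relaunched if it ever exits, the whole loop is backgrounded with `&`, and the job can be skipped entirely with `--disable-datasync`.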

View File

@@ -20,7 +20,6 @@ import os
 import networkx as nx
 import trio

-from api import settings
 from api.db.services.document_service import DocumentService
 from common.misc_utils import get_uuid
 from common.connection_utils import timeout
@@ -40,6 +39,7 @@ from graphrag.utils import (
 )
 from rag.nlp import rag_tokenizer, search
 from rag.utils.redis_conn import RedisDistributedLock
+from common import globals

 async def run_graphrag(
@@ -55,7 +55,7 @@ async def run_graphrag(
     start = trio.current_time()
     tenant_id, kb_id, doc_id = row["tenant_id"], str(row["kb_id"]), row["doc_id"]
     chunks = []
-    for d in settings.retriever.chunk_list(doc_id, tenant_id, [kb_id], fields=["content_with_weight", "doc_id"], sort_by_position=True):
+    for d in globals.retriever.chunk_list(doc_id, tenant_id, [kb_id], fields=["content_with_weight", "doc_id"], sort_by_position=True):
         chunks.append(d["content_with_weight"])

     with trio.fail_after(max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000):
@@ -170,7 +170,7 @@ async def run_graphrag_for_kb(
     chunks = []
     current_chunk = ""

-    for d in settings.retriever.chunk_list(
+    for d in globals.retriever.chunk_list(
         doc_id,
         tenant_id,
         [kb_id],
@@ -387,8 +387,8 @@ async def generate_subgraph(
         "removed_kwd": "N",
     }
     cid = chunk_id(chunk)
-    await trio.to_thread.run_sync(settings.docStoreConn.delete, {"knowledge_graph_kwd": "subgraph", "source_id": doc_id}, search.index_name(tenant_id), kb_id)
-    await trio.to_thread.run_sync(settings.docStoreConn.insert, [{"id": cid, **chunk}], search.index_name(tenant_id), kb_id)
+    await trio.to_thread.run_sync(globals.docStoreConn.delete, {"knowledge_graph_kwd": "subgraph", "source_id": doc_id}, search.index_name(tenant_id), kb_id)
+    await trio.to_thread.run_sync(globals.docStoreConn.insert, [{"id": cid, **chunk}], search.index_name(tenant_id), kb_id)
     now = trio.current_time()
     callback(msg=f"generated subgraph for doc {doc_id} in {now - start:.2f} seconds.")
     return subgraph
@@ -496,7 +496,7 @@ async def extract_community(
         chunks.append(chunk)

     await trio.to_thread.run_sync(
-        lambda: settings.docStoreConn.delete(
+        lambda: globals.docStoreConn.delete(
             {"knowledge_graph_kwd": "community_report", "kb_id": kb_id},
             search.index_name(tenant_id),
             kb_id,
@@ -504,7 +504,7 @@ async def extract_community(
     )
     es_bulk_size = 4
     for b in range(0, len(chunks), es_bulk_size):
-        doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b : b + es_bulk_size], search.index_name(tenant_id), kb_id))
+        doc_store_result = await trio.to_thread.run_sync(lambda: globals.docStoreConn.insert(chunks[b : b + es_bulk_size], search.index_name(tenant_id), kb_id))
         if doc_store_result:
             error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
             raise Exception(error_message)
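A pattern worth naming in this file: every doc-store call is blocking, so it is pushed onto a worker thread with `trio.to_thread.run_sync` to keep the event loop responsive, and a lambda is used whenever the call needs its arguments baked in. A hedged sketch of the recurring shape:

```python
# Hedged sketch (assumes this module's own imports): run a blocking delete off
# the trio event loop so concurrent tasks keep making progress.
import trio
from common import globals
from rag.nlp import search

async def delete_subgraph(doc_id, tenant_id, kb_id):
    await trio.to_thread.run_sync(
        lambda: globals.docStoreConn.delete(
            {"knowledge_graph_kwd": "subgraph", "source_id": doc_id},
            search.index_name(tenant_id),
            kb_id,
        )
    )
```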

View File

@@ -28,6 +28,7 @@ from api.db.services.llm_service import LLMBundle
 from api.db.services.user_service import TenantService
 from graphrag.general.graph_extractor import GraphExtractor
 from graphrag.general.index import update_graph, with_resolution, with_community
+from common import globals

 settings.init_settings()
@@ -62,7 +63,7 @@ async def main():
     chunks = [
         d["content_with_weight"]
-        for d in settings.retriever.chunk_list(
+        for d in globals.retriever.chunk_list(
             args.doc_id,
             args.tenant_id,
             [kb_id],

View File

@@ -28,6 +28,7 @@ from api.db.services.llm_service import LLMBundle
 from api.db.services.user_service import TenantService
 from graphrag.general.index import update_graph
 from graphrag.light.graph_extractor import GraphExtractor
+from common import globals

 settings.init_settings()
@@ -63,7 +64,7 @@ async def main():
     chunks = [
         d["content_with_weight"]
-        for d in settings.retriever.chunk_list(
+        for d in globals.retriever.chunk_list(
            args.doc_id,
            args.tenant_id,
            [kb_id],

View File

@@ -29,6 +29,7 @@ from rag.utils.doc_store_conn import OrderByExpr
 from rag.nlp.search import Dealer, index_name
 from common.float_utils import get_float
+from common import globals

 class KGSearch(Dealer):
@@ -334,6 +335,6 @@ if __name__ == "__main__":
     _, kb = KnowledgebaseService.get_by_id(kb_id)
     embed_bdl = LLMBundle(args.tenant_id, LLMType.EMBEDDING, kb.embd_id)
-    kg = KGSearch(settings.docStoreConn)
+    kg = KGSearch(globals.docStoreConn)
     print(kg.retrieval({"question": args.question, "kb_ids": [kb_id]},
                        search.index_name(kb.tenant_id), [kb_id], embed_bdl, llm_bdl))

View File

@@ -23,12 +23,12 @@ import trio
 import xxhash
 from networkx.readwrite import json_graph

-from api import settings
 from common.misc_utils import get_uuid
 from common.connection_utils import timeout
 from rag.nlp import rag_tokenizer, search
 from rag.utils.doc_store_conn import OrderByExpr
 from rag.utils.redis_conn import REDIS_CONN
+from common import globals

 GRAPH_FIELD_SEP = "<SEP>"
@@ -334,7 +334,7 @@ def get_relation(tenant_id, kb_id, from_ent_name, to_ent_name, size=1):
     ents = list(set(ents))
     conds = {"fields": ["content_with_weight"], "size": size, "from_entity_kwd": ents, "to_entity_kwd": ents, "knowledge_graph_kwd": ["relation"]}
     res = []
-    es_res = settings.retriever.search(conds, search.index_name(tenant_id), [kb_id] if isinstance(kb_id, str) else kb_id)
+    es_res = globals.retriever.search(conds, search.index_name(tenant_id), [kb_id] if isinstance(kb_id, str) else kb_id)
     for id in es_res.ids:
         try:
             if size == 1:
@@ -381,8 +381,8 @@ async def does_graph_contains(tenant_id, kb_id, doc_id):
         "knowledge_graph_kwd": ["graph"],
         "removed_kwd": "N",
     }
-    res = await trio.to_thread.run_sync(lambda: settings.docStoreConn.search(fields, [], condition, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [kb_id]))
-    fields2 = settings.docStoreConn.getFields(res, fields)
+    res = await trio.to_thread.run_sync(lambda: globals.docStoreConn.search(fields, [], condition, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [kb_id]))
+    fields2 = globals.docStoreConn.getFields(res, fields)
     graph_doc_ids = set()
     for chunk_id in fields2.keys():
         graph_doc_ids = set(fields2[chunk_id]["source_id"])
@@ -391,7 +391,7 @@ async def does_graph_contains(tenant_id, kb_id, doc_id):
 async def get_graph_doc_ids(tenant_id, kb_id) -> list[str]:
     conds = {"fields": ["source_id"], "removed_kwd": "N", "size": 1, "knowledge_graph_kwd": ["graph"]}
-    res = await trio.to_thread.run_sync(lambda: settings.retriever.search(conds, search.index_name(tenant_id), [kb_id]))
+    res = await trio.to_thread.run_sync(lambda: globals.retriever.search(conds, search.index_name(tenant_id), [kb_id]))
     doc_ids = []
     if res.total == 0:
         return doc_ids
@@ -402,7 +402,7 @@ async def get_graph_doc_ids(tenant_id, kb_id) -> list[str]:
 async def get_graph(tenant_id, kb_id, exclude_rebuild=None):
     conds = {"fields": ["content_with_weight", "removed_kwd", "source_id"], "size": 1, "knowledge_graph_kwd": ["graph"]}
-    res = await trio.to_thread.run_sync(settings.retriever.search, conds, search.index_name(tenant_id), [kb_id])
+    res = await trio.to_thread.run_sync(globals.retriever.search, conds, search.index_name(tenant_id), [kb_id])
     if not res.total == 0:
         for id in res.ids:
             try:
@@ -423,17 +423,17 @@ async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, chang
     global chat_limiter
     start = trio.current_time()

-    await trio.to_thread.run_sync(settings.docStoreConn.delete, {"knowledge_graph_kwd": ["graph", "subgraph"]}, search.index_name(tenant_id), kb_id)
+    await trio.to_thread.run_sync(globals.docStoreConn.delete, {"knowledge_graph_kwd": ["graph", "subgraph"]}, search.index_name(tenant_id), kb_id)

     if change.removed_nodes:
-        await trio.to_thread.run_sync(settings.docStoreConn.delete, {"knowledge_graph_kwd": ["entity"], "entity_kwd": sorted(change.removed_nodes)}, search.index_name(tenant_id), kb_id)
+        await trio.to_thread.run_sync(globals.docStoreConn.delete, {"knowledge_graph_kwd": ["entity"], "entity_kwd": sorted(change.removed_nodes)}, search.index_name(tenant_id), kb_id)

     if change.removed_edges:
         async def del_edges(from_node, to_node):
             async with chat_limiter:
                 await trio.to_thread.run_sync(
-                    settings.docStoreConn.delete, {"knowledge_graph_kwd": ["relation"], "from_entity_kwd": from_node, "to_entity_kwd": to_node}, search.index_name(tenant_id), kb_id
+                    globals.docStoreConn.delete, {"knowledge_graph_kwd": ["relation"], "from_entity_kwd": from_node, "to_entity_kwd": to_node}, search.index_name(tenant_id), kb_id
                 )
         async with trio.open_nursery() as nursery:
@@ -501,7 +501,7 @@ async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, chang
     es_bulk_size = 4
     for b in range(0, len(chunks), es_bulk_size):
         with trio.fail_after(3 if enable_timeout_assertion else 30000000):
-            doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b : b + es_bulk_size], search.index_name(tenant_id), kb_id))
+            doc_store_result = await trio.to_thread.run_sync(lambda: globals.docStoreConn.insert(chunks[b : b + es_bulk_size], search.index_name(tenant_id), kb_id))
         if b % 100 == es_bulk_size and callback:
             callback(msg=f"Insert chunks: {b}/{len(chunks)}")
         if doc_store_result:
@@ -555,7 +555,7 @@ def merge_tuples(list1, list2):
 async def get_entity_type2samples(idxnms, kb_ids: list):
-    es_res = await trio.to_thread.run_sync(lambda: settings.retriever.search({"knowledge_graph_kwd": "ty2ents", "kb_id": kb_ids, "size": 10000, "fields": ["content_with_weight"]}, idxnms, kb_ids))
+    es_res = await trio.to_thread.run_sync(lambda: globals.retriever.search({"knowledge_graph_kwd": "ty2ents", "kb_id": kb_ids, "size": 10000, "fields": ["content_with_weight"]}, idxnms, kb_ids))

     res = defaultdict(list)
     for id in es_res.ids:
@@ -589,10 +589,10 @@ async def rebuild_graph(tenant_id, kb_id, exclude_rebuild=None):
     bs = 256
     for i in range(0, 1024 * bs, bs):
         es_res = await trio.to_thread.run_sync(
-            lambda: settings.docStoreConn.search(flds, [], {"kb_id": kb_id, "knowledge_graph_kwd": ["subgraph"]}, [], OrderByExpr(), i, bs, search.index_name(tenant_id), [kb_id])
+            lambda: globals.docStoreConn.search(flds, [], {"kb_id": kb_id, "knowledge_graph_kwd": ["subgraph"]}, [], OrderByExpr(), i, bs, search.index_name(tenant_id), [kb_id])
         )
-        # tot = settings.docStoreConn.getTotal(es_res)
-        es_res = settings.docStoreConn.getFields(es_res, flds)
+        # tot = globals.docStoreConn.getTotal(es_res)
+        es_res = globals.docStoreConn.getFields(es_res, flds)
         if len(es_res) == 0:
             break

View File

@@ -15,18 +15,18 @@
 #
 import logging
-from tika import parser
 import re
 from io import BytesIO

 from deepdoc.parser.utils import get_text
 from rag.app import naive
+from rag.app.naive import plaintext_parser, PARSERS
 from rag.nlp import bullets_category, is_english,remove_contents_table, \
     hierarchical_merge, make_colon_as_title, naive_merge, random_choices, tokenize_table, \
     tokenize_chunks
 from rag.nlp import rag_tokenizer
-from deepdoc.parser import PdfParser, PlainParser, HtmlParser
-from deepdoc.parser.figure_parser import vision_figure_parser_pdf_wrapper,vision_figure_parser_docx_wrapper
+from deepdoc.parser import PdfParser, HtmlParser
+from deepdoc.parser.figure_parser import vision_figure_parser_docx_wrapper
 from PIL import Image
@@ -96,13 +96,33 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
         callback(0.8, "Finish parsing.")

     elif re.search(r"\.pdf$", filename, re.IGNORECASE):
-        pdf_parser = Pdf()
-        if parser_config.get("layout_recognize", "DeepDOC") == "Plain Text":
-            pdf_parser = PlainParser()
-        sections, tbls = pdf_parser(filename if not binary else binary,
-                                    from_page=from_page, to_page=to_page, callback=callback)
-        tbls = vision_figure_parser_pdf_wrapper(tbls=tbls, callback=callback, **kwargs)
+        layout_recognizer = parser_config.get("layout_recognize", "DeepDOC")
+        if isinstance(layout_recognizer, bool):
+            layout_recognizer = "DeepDOC" if layout_recognizer else "Plain Text"
+        name = layout_recognizer.strip().lower()
+        parser = PARSERS.get(name, plaintext_parser)
+        callback(0.1, "Start to parse.")
+        sections, tables, _ = parser(
+            filename=filename,
+            binary=binary,
+            from_page=from_page,
+            to_page=to_page,
+            lang=lang,
+            callback=callback,
+            pdf_cls=Pdf,
+            **kwargs
+        )
+        if not sections and not tables:
+            return []
+        if name in ["tcadp", "docling", "mineru"]:
+            parser_config["chunk_token_num"] = 0
+        callback(0.8, "Finish parsing.")

     elif re.search(r"\.txt$", filename, re.IGNORECASE):
         callback(0.1, "Start to parse.")
         txt = get_text(filename, binary)

View File

@@ -15,7 +15,6 @@
 #
 import logging
-from tika import parser
 import re
 from io import BytesIO
 from docx import Document
@@ -25,8 +24,8 @@ from deepdoc.parser.utils import get_text
 from rag.nlp import bullets_category, remove_contents_table, \
     make_colon_as_title, tokenize_chunks, docx_question_level, tree_merge
 from rag.nlp import rag_tokenizer, Node
-from deepdoc.parser import PdfParser, DocxParser, PlainParser, HtmlParser
+from deepdoc.parser import PdfParser, DocxParser, HtmlParser
+from rag.app.naive import plaintext_parser, PARSERS

@@ -156,13 +155,36 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
         return tokenize_chunks(chunks, doc, eng, None)

     elif re.search(r"\.pdf$", filename, re.IGNORECASE):
-        pdf_parser = Pdf()
-        if parser_config.get("layout_recognize", "DeepDOC") == "Plain Text":
-            pdf_parser = PlainParser()
-        for txt, poss in pdf_parser(filename if not binary else binary,
-                                    from_page=from_page, to_page=to_page, callback=callback)[0]:
+        layout_recognizer = parser_config.get("layout_recognize", "DeepDOC")
+        if isinstance(layout_recognizer, bool):
+            layout_recognizer = "DeepDOC" if layout_recognizer else "Plain Text"
+        name = layout_recognizer.strip().lower()
+        parser = PARSERS.get(name, plaintext_parser)
+        callback(0.1, "Start to parse.")
+        raw_sections, tables, _ = parser(
+            filename=filename,
+            binary=binary,
+            from_page=from_page,
+            to_page=to_page,
+            lang=lang,
+            callback=callback,
+            pdf_cls=Pdf,
+            **kwargs
+        )
+        if not raw_sections and not tables:
+            return []
+        if name in ["tcadp", "docling", "mineru"]:
+            parser_config["chunk_token_num"] = 0
+        for txt, poss in raw_sections:
             sections.append(txt + poss)
+        callback(0.8, "Finish parsing.")

     elif re.search(r"\.(txt|md|markdown|mdx)$", filename, re.IGNORECASE):
         callback(0.1, "Start to parse.")
         txt = get_text(filename, binary)

View File

@@ -22,11 +22,11 @@ from common.constants import ParserType
 from io import BytesIO
 from rag.nlp import rag_tokenizer, tokenize, tokenize_table, bullets_category, title_frequency, tokenize_chunks, docx_question_level
 from common.token_utils import num_tokens_from_string
-from deepdoc.parser import PdfParser, PlainParser, DocxParser
+from deepdoc.parser import PdfParser, DocxParser
 from deepdoc.parser.figure_parser import vision_figure_parser_pdf_wrapper,vision_figure_parser_docx_wrapper
 from docx import Document
 from PIL import Image
+from rag.app.naive import plaintext_parser, PARSERS

 class Pdf(PdfParser):
     def __init__(self):
@@ -196,15 +196,34 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
     # is it English
     eng = lang.lower() == "english"  # pdf_parser.is_english
     if re.search(r"\.pdf$", filename, re.IGNORECASE):
-        pdf_parser = Pdf()
-        if parser_config.get("layout_recognize", "DeepDOC") == "Plain Text":
-            pdf_parser = PlainParser()
-        sections, tbls = pdf_parser(filename if not binary else binary,
-                                    from_page=from_page, to_page=to_page, callback=callback)
-        if sections and len(sections[0]) < 3:
-            sections = [(t, lvl, [[0] * 5]) for t, lvl in sections]
-        # set pivot using the most frequent type of title,
-        # then merge between 2 pivot
+        layout_recognizer = parser_config.get("layout_recognize", "DeepDOC")
+        if isinstance(layout_recognizer, bool):
+            layout_recognizer = "DeepDOC" if layout_recognizer else "Plain Text"
+        name = layout_recognizer.strip().lower()
+        pdf_parser = PARSERS.get(name, plaintext_parser)
+        callback(0.1, "Start to parse.")
+        sections, tbls, pdf_parser = pdf_parser(
+            filename=filename,
+            binary=binary,
+            from_page=from_page,
+            to_page=to_page,
+            lang=lang,
+            callback=callback,
+            pdf_cls=Pdf,
+            **kwargs
+        )
+        if not sections and not tbls:
+            return []
+        if name in ["tcadp", "docling", "mineru"]:
+            parser_config["chunk_token_num"] = 0
+        callback(0.8, "Finish parsing.")
     if len(sections) > 0 and len(pdf_parser.outlines) / len(sections) > 0.03:
         max_lvl = max([lvl for _, lvl in pdf_parser.outlines])
         most_level = max(0, max_lvl - 1)

View File

@@ -26,7 +26,6 @@ from docx.opc.pkgreader import _SerializedRelationships, _SerializedRelationship
 from docx.opc.oxml import parse_xml
 from markdown import markdown
 from PIL import Image
-from tika import parser

 from common.constants import LLMType
 from api.db.services.llm_service import LLMBundle
@@ -39,6 +38,100 @@ from deepdoc.parser.docling_parser import DoclingParser
 from deepdoc.parser.tcadp_parser import TCADPParser
 from rag.nlp import concat_img, find_codec, naive_merge, naive_merge_with_images, naive_merge_docx, rag_tokenizer, tokenize_chunks, tokenize_chunks_with_images, tokenize_table

+def DeepDOC_parser(filename, binary=None, from_page=0, to_page=100000, callback=None, pdf_cls=None, **kwargs):
+    pdf_parser = pdf_cls() if pdf_cls else Pdf()
+    sections, tables = pdf_parser(
+        filename if not binary else binary,
+        from_page=from_page,
+        to_page=to_page,
+        callback=callback
+    )
+    tables = vision_figure_parser_pdf_wrapper(tbls=tables,
+                                              callback=callback,
+                                              **kwargs)
+    return sections, tables, pdf_parser
+
+
+def MinerU_parser(filename, binary=None, callback=None, **kwargs):
+    mineru_executable = os.environ.get("MINERU_EXECUTABLE", "mineru")
+    mineru_api = os.environ.get("MINERU_APISERVER", "http://host.docker.internal:9987")
+    pdf_parser = MinerUParser(mineru_path=mineru_executable, mineru_api=mineru_api)
+    if not pdf_parser.check_installation():
+        callback(-1, "MinerU not found.")
+        return None, None, None
+    sections, tables = pdf_parser.parse_pdf(
+        filepath=filename,
+        binary=binary,
+        callback=callback,
+        output_dir=os.environ.get("MINERU_OUTPUT_DIR", ""),
+        backend=os.environ.get("MINERU_BACKEND", "pipeline"),
+        delete_output=bool(int(os.environ.get("MINERU_DELETE_OUTPUT", 1))),
+    )
+    return sections, tables, pdf_parser
+
+
+def Docling_parser(filename, binary=None, callback=None, **kwargs):
+    pdf_parser = DoclingParser()
+    if not pdf_parser.check_installation():
+        callback(-1, "Docling not found.")
+        return None, None, None
+    sections, tables = pdf_parser.parse_pdf(
+        filepath=filename,
+        binary=binary,
+        callback=callback,
+        output_dir=os.environ.get("MINERU_OUTPUT_DIR", ""),
+        delete_output=bool(int(os.environ.get("MINERU_DELETE_OUTPUT", 1))),
+    )
+    return sections, tables, pdf_parser
+
+
+def TCADP_parser(filename, binary=None, callback=None, **kwargs):
+    tcadp_parser = TCADPParser()
+    if not tcadp_parser.check_installation():
+        callback(-1, "TCADP parser not available. Please check Tencent Cloud API configuration.")
+        return None, None, None
+    sections, tables = tcadp_parser.parse_pdf(
+        filepath=filename,
+        binary=binary,
+        callback=callback,
+        output_dir=os.environ.get("TCADP_OUTPUT_DIR", ""),
+        file_type="PDF"
+    )
+    return sections, tables, tcadp_parser
+
+
+def plaintext_parser(filename, binary=None, from_page=0, to_page=100000, callback=None, **kwargs):
+    if kwargs.get("layout_recognizer", "") == "Plain Text":
+        pdf_parser = PlainParser()
+    else:
+        vision_model = LLMBundle(kwargs["tenant_id"], LLMType.IMAGE2TEXT, llm_name=kwargs.get("layout_recognizer", ""), lang=kwargs.get("lang", "Chinese"))
+        pdf_parser = VisionParser(vision_model=vision_model, **kwargs)
+    sections, tables = pdf_parser(
+        filename if not binary else binary,
+        from_page=from_page,
+        to_page=to_page,
+        callback=callback
+    )
+    return sections, tables, pdf_parser
+
+
+PARSERS = {
+    "deepdoc": DeepDOC_parser,
+    "mineru": MinerU_parser,
+    "docling": Docling_parser,
+    "tcadp": TCADP_parser,
+    "plaintext": plaintext_parser,  # default
+}

 class Docx(DocxParser):
     def __init__(self):
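Note how the registry keys line up with the dispatch sites across this diff: the configured recognizer name is normalized with `.strip().lower()` before the `PARSERS` lookup, and anything that is not a registry key, including a vision model name, falls through to `plaintext_parser`, which in turn picks `PlainParser` or `VisionParser` based on `kwargs`. Each parser returns a uniform `(sections, tables, parser)` triple (on a failed installation check, a `(None, None, None)` triple, so the three-way unpack at the call sites still succeeds).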
@@ -365,7 +458,7 @@ class Markdown(MarkdownParser):
         html_content = markdown(text)
         soup = BeautifulSoup(html_content, 'html.parser')
         return soup

     def get_picture_urls(self, soup):
         if soup:
             return [img.get('src') for img in soup.find_all('img') if img.get('src')]
@@ -375,7 +468,7 @@ class Markdown(MarkdownParser):
         if soup:
             return set([a.get('href') for a in soup.find_all('a') if a.get('href')])
         return []

     def get_pictures(self, text):
         """Download and open all images from markdown text."""
         import requests
@@ -416,11 +509,11 @@ class Markdown(MarkdownParser):
             txt = f.read()
         remainder, tables = self.extract_tables_and_remainder(f'{txt}\n', separate_tables=separate_tables)
+        # To eliminate duplicate tables in the chunking result, uncomment the code below and set separate_tables to True in line 410.
+        # extractor = MarkdownElementExtractor(remainder)
         extractor = MarkdownElementExtractor(txt)
         element_sections = extractor.extract_elements(delimiter)
         sections = [(element, "") for element in element_sections]

         tbls = []
         for table in tables:
             tbls.append(((None, markdown(table, extensions=['markdown.extensions.tables'])), ""))
@@ -535,82 +628,29 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
         if isinstance(layout_recognizer, bool):
             layout_recognizer = "DeepDOC" if layout_recognizer else "Plain Text"
+        name = layout_recognizer.strip().lower()
+        parser = PARSERS.get(name, plaintext_parser)
         callback(0.1, "Start to parse.")
-        if layout_recognizer == "DeepDOC":
-            pdf_parser = Pdf()
-            sections, tables = pdf_parser(filename if not binary else binary, from_page=from_page, to_page=to_page, callback=callback)
-            tables = vision_figure_parser_pdf_wrapper(tbls=tables, callback=callback, **kwargs)
-            res = tokenize_table(tables, doc, is_english)
-            callback(0.8, "Finish parsing.")
-        elif layout_recognizer == "MinerU":
-            mineru_executable = os.environ.get("MINERU_EXECUTABLE", "mineru")
-            mineru_api = os.environ.get("MINERU_APISERVER", "http://host.docker.internal:9987")
-            mineru_server_url = os.environ.get("MINERU_SERVER_URL", "")
-            mineru_backend = os.environ.get("MINERU_BACKEND", "pipeline")
-            pdf_parser = MinerUParser(mineru_path=mineru_executable, mineru_api=mineru_api, mineru_server_url=mineru_server_url)
-            ok, reason = pdf_parser.check_installation(backend=mineru_backend)
-            if not ok:
-                callback(-1, f"MinerU not found or server not accessible: {reason}")
-                return res
-            sections, tables = pdf_parser.parse_pdf(
-                filepath=filename,
-                binary=binary,
-                callback=callback,
-                output_dir=os.environ.get("MINERU_OUTPUT_DIR", ""),
-                backend=mineru_backend,
-                server_url=mineru_server_url,
-                delete_output=bool(int(os.environ.get("MINERU_DELETE_OUTPUT", 1))),
-            )
-            parser_config["chunk_token_num"] = 0
-            callback(0.8, "Finish parsing.")
-            res = tokenize_table(tables, doc, is_english)
-        elif layout_recognizer == "Docling":
-            pdf_parser = DoclingParser()
-            if not pdf_parser.check_installation():
-                callback(-1, "Docling not found.")
-                return res
-            sections, tables = pdf_parser.parse_pdf(
-                filepath=filename,
-                binary=binary,
-                callback=callback,
-                output_dir=os.environ.get("MINERU_OUTPUT_DIR", ""),
-                delete_output=bool(int(os.environ.get("MINERU_DELETE_OUTPUT", 1))),
-            )
-            parser_config["chunk_token_num"] = 0
-            res = tokenize_table(tables, doc, is_english)
-            callback(0.8, "Finish parsing.")
-        elif layout_recognizer == "TCADP Parser":
-            tcadp_parser = TCADPParser()
-            if not tcadp_parser.check_installation():
-                callback(-1, "TCADP parser not available. Please check Tencent Cloud API configuration.")
-                return res
-            sections, tables = tcadp_parser.parse_pdf(
-                filepath=filename,
-                binary=binary,
-                callback=callback,
-                output_dir=os.environ.get("TCADP_OUTPUT_DIR", ""),
-                file_type="PDF"
-            )
-            parser_config["chunk_token_num"] = 0
-            callback(0.8, "Finish parsing.")
-        else:
-            if layout_recognizer == "Plain Text":
-                pdf_parser = PlainParser()
-            else:
-                vision_model = LLMBundle(kwargs["tenant_id"], LLMType.IMAGE2TEXT, llm_name=layout_recognizer, lang=lang)
-                pdf_parser = VisionParser(vision_model=vision_model, **kwargs)
-            sections, tables = pdf_parser(filename if not binary else binary, from_page=from_page, to_page=to_page,
-                                          callback=callback)
-            res = tokenize_table(tables, doc, is_english)
-            callback(0.8, "Finish parsing.")
+        sections, tables, _ = parser(
+            filename=filename,
+            binary=binary,
+            from_page=from_page,
+            to_page=to_page,
+            lang=lang,
+            callback=callback,
+            **kwargs
+        )
+        if not sections and not tables:
+            return []
+        if name in ["tcadp", "docling", "mineru"]:
+            parser_config["chunk_token_num"] = 0
+        callback(0.8, "Finish parsing.")

     elif re.search(r"\.(csv|xlsx?)$", filename, re.IGNORECASE):
         callback(0.1, "Start to parse.")
@@ -735,9 +775,9 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
                 logging.info(f"Failed to chunk url in registered file type {url}: {e}")
                 sub_url_res = chunk(f"{index}.html", html_bytes, callback=callback, lang=lang, is_root=False, **kwargs)
                 url_res.extend(sub_url_res)

     logging.info("naive_merge({}): {}".format(filename, timer() - st))
     if embed_res:
         res.extend(embed_res)
     if url_res:

View File

@ -15,16 +15,15 @@
# #
import logging import logging
from tika import parser
from io import BytesIO from io import BytesIO
import re import re
from deepdoc.parser.utils import get_text from deepdoc.parser.utils import get_text
from rag.app import naive from rag.app import naive
from rag.nlp import rag_tokenizer, tokenize from rag.nlp import rag_tokenizer, tokenize
from deepdoc.parser import PdfParser, ExcelParser, PlainParser, HtmlParser from deepdoc.parser import PdfParser, ExcelParser, HtmlParser
from deepdoc.parser.figure_parser import vision_figure_parser_pdf_wrapper,vision_figure_parser_docx_wrapper from deepdoc.parser.figure_parser import vision_figure_parser_docx_wrapper
from rag.app.naive import plaintext_parser, PARSERS
class Pdf(PdfParser): class Pdf(PdfParser):
def __call__(self, filename, binary=None, from_page=0, def __call__(self, filename, binary=None, from_page=0,
@@ -83,12 +82,34 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
        callback(0.8, "Finish parsing.")
    elif re.search(r"\.pdf$", filename, re.IGNORECASE):
-        pdf_parser = Pdf()
-        if parser_config.get("layout_recognize", "DeepDOC") == "Plain Text":
-            pdf_parser = PlainParser()
-        sections, tbls = pdf_parser(
-            filename if not binary else binary, to_page=to_page, callback=callback)
-        tbls=vision_figure_parser_pdf_wrapper(tbls=tbls,callback=callback,**kwargs)
+        layout_recognizer = parser_config.get("layout_recognize", "DeepDOC")
+        if isinstance(layout_recognizer, bool):
+            layout_recognizer = "DeepDOC" if layout_recognizer else "Plain Text"
+        name = layout_recognizer.strip().lower()
+        parser = PARSERS.get(name, plaintext_parser)
+        callback(0.1, "Start to parse.")
+        sections, tbls, _ = parser(
+            filename=filename,
+            binary=binary,
+            from_page=from_page,
+            to_page=to_page,
+            lang=lang,
+            callback=callback,
+            pdf_cls=Pdf,
+            **kwargs
+        )
+        if not sections and not tbls:
+            return []
+        if name in ["tcadp", "docling", "mineru"]:
+            parser_config["chunk_token_num"] = 0
        callback(0.8, "Finish parsing.")
    for (img, rows), poss in tbls:
        if not rows:
            continue
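
The rewritten branch replaces the inline if/elif chain with a table lookup: PARSERS maps a lowercased recognizer name to a parser callable, and unknown names fall back to plaintext_parser. A sketch of the pattern under assumed shapes; the real PARSERS lives in rag.app.naive, and the callables below are stand-ins that only mimic the (sections, tables, extra) return contract:

# Stand-in fallback parser: one section per non-empty line, no tables.
def plaintext_parser(filename, binary=None, **kwargs):
    text = binary.decode("utf-8", "ignore") if binary else open(filename, encoding="utf-8").read()
    return [(line, "") for line in text.splitlines() if line.strip()], [], None

# Registry keyed by normalized recognizer name; every entry here is a stand-in.
PARSERS = {"deepdoc": plaintext_parser, "mineru": plaintext_parser, "docling": plaintext_parser}

name = "MinerU".strip().lower()                 # "mineru"
parser = PARSERS.get(name, plaintext_parser)    # unknown names degrade to plain text
sections, tbls, _ = parser("doc.pdf", binary=b"hello\nworld")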

View File

@@ -20,14 +20,11 @@ from io import BytesIO
from PIL import Image
-from common.constants import LLMType
-from api.db.services.llm_service import LLMBundle
-from deepdoc.parser.pdf_parser import VisionParser
from rag.nlp import tokenize, is_english
from rag.nlp import rag_tokenizer
from deepdoc.parser import PdfParser, PptParser, PlainParser
from PyPDF2 import PdfReader as pdf2_read
+from rag.app.naive import plaintext_parser, PARSERS

class Ppt(PptParser):
    def __call__(self, fnm, from_page, to_page, callback=None):
@@ -54,7 +51,6 @@ class Ppt(PptParser):
        self.is_english = is_english(txts)
        return [(txts[i], imgs[i]) for i in range(len(txts))]

class Pdf(PdfParser):
    def __init__(self):
        super().__init__()
@@ -84,7 +80,7 @@ class Pdf(PdfParser):
            res.append((lines, self.page_images[i]))
        callback(0.9, "Page {}~{}: Parsing finished".format(
            from_page, min(to_page, self.total_page)))
-        return res
+        return res, []

class PlainPdf(PlainParser):
@@ -95,7 +91,7 @@ class PlainPdf(PlainParser):
        for page in self.pdf.pages[from_page: to_page]:
            page_txt.append(page.extract_text())
        callback(0.9, "Parsing finished")
-        return [(txt, None) for txt in page_txt]
+        return [(txt, None) for txt in page_txt], []

def chunk(filename, binary=None, from_page=0, to_page=100000,
@@ -130,20 +126,33 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
        return res
    elif re.search(r"\.pdf$", filename, re.IGNORECASE):
        layout_recognizer = parser_config.get("layout_recognize", "DeepDOC")
-        if layout_recognizer == "DeepDOC":
-            pdf_parser = Pdf()
-            sections = pdf_parser(filename, binary, from_page=from_page, to_page=to_page, callback=callback)
-        elif layout_recognizer == "Plain Text":
-            pdf_parser = PlainParser()
-            sections, _ = pdf_parser(filename if not binary else binary, from_page=from_page, to_page=to_page,
-                                     callback=callback)
-        else:
-            vision_model = LLMBundle(kwargs["tenant_id"], LLMType.IMAGE2TEXT, llm_name=layout_recognizer, lang=lang)
-            pdf_parser = VisionParser(vision_model=vision_model, **kwargs)
-            sections, _ = pdf_parser(filename if not binary else binary, from_page=from_page, to_page=to_page,
-                                     callback=callback)
+        if isinstance(layout_recognizer, bool):
+            layout_recognizer = "DeepDOC" if layout_recognizer else "Plain Text"
+        name = layout_recognizer.strip().lower()
+        parser = PARSERS.get(name, plaintext_parser)
+        callback(0.1, "Start to parse.")
+        sections, _, _ = parser(
+            filename=filename,
+            binary=binary,
+            from_page=from_page,
+            to_page=to_page,
+            lang=lang,
+            callback=callback,
+            pdf_cls=Pdf,
+            **kwargs
+        )
+        if not sections:
+            return []
+        if name in ["tcadp", "docling", "mineru"]:
+            parser_config["chunk_token_num"] = 0
        callback(0.8, "Finish parsing.")
        for pn, (txt, img) in enumerate(sections):
            d = copy.deepcopy(doc)
            pn += from_page
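
Older parser configs stored layout_recognize as a boolean, which is why the new code normalizes it before the lowercase lookup: True means the legacy DeepDOC path and False means plain text. A small sketch of that mapping:

def normalize_layout_recognizer(value):
    # Legacy configs: True -> "DeepDOC", False -> "Plain Text"; new configs
    # already store a parser name.
    if isinstance(value, bool):
        return "DeepDOC" if value else "Plain Text"
    return value

assert normalize_layout_recognizer(True) == "DeepDOC"
assert normalize_layout_recognizer(False) == "Plain Text"
assert normalize_layout_recognizer("MinerU").strip().lower() == "mineru"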

View File

@@ -21,6 +21,7 @@ from copy import deepcopy
from deepdoc.parser.utils import get_text
from rag.app.qa import Excel
from rag.nlp import rag_tokenizer
+from common import globals

def beAdoc(d, q, a, eng, row_num=-1):
@@ -124,7 +125,6 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
def label_question(question, kbs):
    from api.db.services.knowledgebase_service import KnowledgebaseService
    from graphrag.utils import get_tags_from_cache, set_tags_to_cache
-    from api import settings
    tags = None
    tag_kb_ids = []
    for kb in kbs:
@@ -133,14 +133,14 @@ def label_question(question, kbs):
    if tag_kb_ids:
        all_tags = get_tags_from_cache(tag_kb_ids)
        if not all_tags:
-            all_tags = settings.retriever.all_tags_in_portion(kb.tenant_id, tag_kb_ids)
+            all_tags = globals.retriever.all_tags_in_portion(kb.tenant_id, tag_kb_ids)
            set_tags_to_cache(tags=all_tags, kb_ids=tag_kb_ids)
        else:
            all_tags = json.loads(all_tags)
        tag_kbs = KnowledgebaseService.get_by_ids(tag_kb_ids)
        if not tag_kbs:
            return tags
-        tags = settings.retriever.tag_query(question,
+        tags = globals.retriever.tag_query(question,
                                            list(set([kb.tenant_id for kb in tag_kbs])),
                                            tag_kb_ids,
                                            all_tags,
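
label_question keeps using a read-through cache around the retriever: serve the tag set from the cache when present, otherwise compute it and write it back. A sketch of that pattern with the fetch call abstracted away; only get_tags_from_cache/set_tags_to_cache are real names here:

import json

def load_all_tags(tag_kb_ids, cache_get, cache_set, fetch):
    cached = cache_get(tag_kb_ids)
    if not cached:
        all_tags = fetch(tag_kb_ids)           # e.g. globals.retriever.all_tags_in_portion(...)
        cache_set(tags=all_tags, kb_ids=tag_kb_ids)
        return all_tags
    return json.loads(cached)                  # the cache stores a JSON string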

View File

@@ -20,10 +20,10 @@ import time
import argparse
from collections import defaultdict
+from common import globals
from common.constants import LLMType
from api.db.services.llm_service import LLMBundle
from api.db.services.knowledgebase_service import KnowledgebaseService
-from api import settings
from common.misc_utils import get_uuid
from rag.nlp import tokenize, search
from ranx import evaluate
@@ -52,7 +52,7 @@ class Benchmark:
        run = defaultdict(dict)
        query_list = list(qrels.keys())
        for query in query_list:
-            ranks = settings.retriever.retrieval(query, self.embd_mdl, self.tenant_id, [self.kb.id], 1, 30,
+            ranks = globals.retriever.retrieval(query, self.embd_mdl, self.tenant_id, [self.kb.id], 1, 30,
                                                 0.0, self.vector_similarity_weight)
            if len(ranks["chunks"]) == 0:
                print(f"deleted query: {query}")
@@ -77,9 +77,9 @@ class Benchmark:
    def init_index(self, vector_size: int):
        if self.initialized_index:
            return
-        if settings.docStoreConn.indexExist(self.index_name, self.kb_id):
-            settings.docStoreConn.deleteIdx(self.index_name, self.kb_id)
-        settings.docStoreConn.createIdx(self.index_name, self.kb_id, vector_size)
+        if globals.docStoreConn.indexExist(self.index_name, self.kb_id):
+            globals.docStoreConn.deleteIdx(self.index_name, self.kb_id)
+        globals.docStoreConn.createIdx(self.index_name, self.kb_id, vector_size)
        self.initialized_index = True

    def ms_marco_index(self, file_path, index_name):
@@ -114,13 +114,13 @@ class Benchmark:
                docs_count += len(docs)
                docs, vector_size = self.embedding(docs)
                self.init_index(vector_size)
-                settings.docStoreConn.insert(docs, self.index_name, self.kb_id)
+                globals.docStoreConn.insert(docs, self.index_name, self.kb_id)
                docs = []
        if docs:
            docs, vector_size = self.embedding(docs)
            self.init_index(vector_size)
-            settings.docStoreConn.insert(docs, self.index_name, self.kb_id)
+            globals.docStoreConn.insert(docs, self.index_name, self.kb_id)
        return qrels, texts

    def trivia_qa_index(self, file_path, index_name):
@@ -155,12 +155,12 @@ class Benchmark:
                docs_count += len(docs)
                docs, vector_size = self.embedding(docs)
                self.init_index(vector_size)
-                settings.docStoreConn.insert(docs,self.index_name)
+                globals.docStoreConn.insert(docs,self.index_name)
                docs = []
        docs, vector_size = self.embedding(docs)
        self.init_index(vector_size)
-        settings.docStoreConn.insert(docs, self.index_name)
+        globals.docStoreConn.insert(docs, self.index_name)
        return qrels, texts

    def miracl_index(self, file_path, corpus_path, index_name):
@@ -210,12 +210,12 @@ class Benchmark:
                docs_count += len(docs)
                docs, vector_size = self.embedding(docs)
                self.init_index(vector_size)
-                settings.docStoreConn.insert(docs, self.index_name)
+                globals.docStoreConn.insert(docs, self.index_name)
                docs = []
        docs, vector_size = self.embedding(docs)
        self.init_index(vector_size)
-        settings.docStoreConn.insert(docs, self.index_name)
+        globals.docStoreConn.insert(docs, self.index_name)
        return qrels, texts

    def save_results(self, qrels, run, texts, dataset, file_path):
View File

@@ -21,7 +21,7 @@ from functools import partial
import trio

from common.misc_utils import get_uuid
-from common.base64_image import id2image, image2id
+from rag.utils.base64_image import id2image, image2id
from deepdoc.parser.pdf_parser import RAGFlowPdfParser
from rag.flow.base import ProcessBase, ProcessParamBase
from rag.flow.hierarchical_merger.schema import HierarchicalMergerFromUpstream

View File

@@ -27,7 +27,7 @@ from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService
from api.db.services.llm_service import LLMBundle
from common.misc_utils import get_uuid
-from common.base64_image import image2id
+from rag.utils.base64_image import image2id
from deepdoc.parser import ExcelParser
from deepdoc.parser.mineru_parser import MinerUParser
from deepdoc.parser.pdf_parser import PlainParser, RAGFlowPdfParser, VisionParser

View File

@@ -18,7 +18,7 @@ from functools import partial
import trio

from common.misc_utils import get_uuid
-from common.base64_image import id2image, image2id
+from rag.utils.base64_image import id2image, image2id
from deepdoc.parser.pdf_parser import RAGFlowPdfParser
from rag.flow.base import ProcessBase, ProcessParamBase
from rag.flow.splitter.schema import SplitterFromUpstream

View File

@@ -29,7 +29,7 @@ from zhipuai import ZhipuAI
from common.log_utils import log_exception
from common.token_utils import num_tokens_from_string, truncate
-from api import settings
+from common import globals

import logging
@@ -69,13 +69,13 @@ class BuiltinEmbed(Base):
    _model_lock = threading.Lock()

    def __init__(self, key, model_name, **kwargs):
-        logging.info(f"Initialize BuiltinEmbed according to settings.EMBEDDING_CFG: {settings.EMBEDDING_CFG}")
-        embedding_cfg = settings.EMBEDDING_CFG
+        logging.info(f"Initialize BuiltinEmbed according to globals.EMBEDDING_CFG: {globals.EMBEDDING_CFG}")
+        embedding_cfg = globals.EMBEDDING_CFG
        if not BuiltinEmbed._model and "tei-" in os.getenv("COMPOSE_PROFILES", ""):
            with BuiltinEmbed._model_lock:
-                BuiltinEmbed._model_name = settings.EMBEDDING_MDL
-                BuiltinEmbed._max_tokens = BuiltinEmbed.MAX_TOKENS.get(settings.EMBEDDING_MDL, 500)
-                BuiltinEmbed._model = HuggingFaceEmbed(embedding_cfg["api_key"], settings.EMBEDDING_MDL, base_url=embedding_cfg["base_url"])
+                BuiltinEmbed._model_name = globals.EMBEDDING_MDL
+                BuiltinEmbed._max_tokens = BuiltinEmbed.MAX_TOKENS.get(globals.EMBEDDING_MDL, 500)
+                BuiltinEmbed._model = HuggingFaceEmbed(embedding_cfg["api_key"], globals.EMBEDDING_MDL, base_url=embedding_cfg["base_url"])
        self._model = BuiltinEmbed._model
        self._model_name = BuiltinEmbed._model_name
        self._max_tokens = BuiltinEmbed._max_tokens
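
BuiltinEmbed guards its process-wide model behind a class-level lock so that concurrent first constructions build it only once. A sketch of the usual double-checked variant of that pattern; the model construction is a stand-in:

import threading

class LazyModel:
    _model = None
    _lock = threading.Lock()

    def __init__(self):
        if LazyModel._model is None:             # fast path without the lock
            with LazyModel._lock:
                if LazyModel._model is None:     # re-check once the lock is held
                    LazyModel._model = object()  # stand-in for the real model
        self._model = LazyModel._model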

View File

@@ -15,49 +15,12 @@
#
import os
import logging
-from common.config_utils import get_base_config, decrypt_database_config
from common.file_utils import get_project_base_directory
from common.misc_utils import pip_install_torch

# Server
RAG_CONF_PATH = os.path.join(get_project_base_directory(), "conf")

-# Get storage type and document engine from system environment variables
-STORAGE_IMPL_TYPE = os.getenv('STORAGE_IMPL', 'MINIO')
-DOC_ENGINE = os.getenv('DOC_ENGINE', 'elasticsearch')
-
-ES = {}
-INFINITY = {}
-AZURE = {}
-S3 = {}
-MINIO = {}
-OSS = {}
-OS = {}
-
-# Initialize the selected configuration data based on environment variables to solve the problem of initialization errors due to lack of configuration
-if DOC_ENGINE == 'elasticsearch':
-    ES = get_base_config("es", {})
-elif DOC_ENGINE == 'opensearch':
-    OS = get_base_config("os", {})
-elif DOC_ENGINE == 'infinity':
-    INFINITY = get_base_config("infinity", {"uri": "infinity:23817"})
-
-if STORAGE_IMPL_TYPE in ['AZURE_SPN', 'AZURE_SAS']:
-    AZURE = get_base_config("azure", {})
-elif STORAGE_IMPL_TYPE == 'AWS_S3':
-    S3 = get_base_config("s3", {})
-elif STORAGE_IMPL_TYPE == 'MINIO':
-    MINIO = decrypt_database_config(name="minio")
-elif STORAGE_IMPL_TYPE == 'OSS':
-    OSS = get_base_config("oss", {})
-
-try:
-    REDIS = decrypt_database_config(name="redis")
-except Exception:
-    try:
-        REDIS = get_base_config("redis", {})
-    except Exception:
-        REDIS = {}
-
DOC_MAXIMUM_SIZE = int(os.environ.get("MAX_CONTENT_LENGTH", 128 * 1024 * 1024))
DOC_BULK_SIZE = int(os.environ.get("DOC_BULK_SIZE", 4))
EMBEDDING_BATCH_SIZE = int(os.environ.get("EMBEDDING_BATCH_SIZE", 16))
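
With that block gone, rag/settings.py no longer probes DOC_ENGINE/STORAGE_IMPL itself; per the refactoring commits, the selection now lives in common.globals, which the storage and doc-engine connectors below read. An illustrative sketch of the selection logic that moved; the loader argument is a stand-in for get_base_config:

import os

def select_doc_engine_config(load):
    # Load only the selected engine's section so a missing config for an
    # unused engine cannot break startup.
    doc_engine = os.getenv("DOC_ENGINE", "elasticsearch")
    if doc_engine == "elasticsearch":
        return {"ES": load("es", {})}
    if doc_engine == "opensearch":
        return {"OS": load("os", {})}
    if doc_engine == "infinity":
        return {"INFINITY": load("infinity", {"uri": "infinity:23817"})}
    return {}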

View File

@@ -26,13 +26,12 @@ import traceback
from api.db.services.connector_service import SyncLogsService
from api.db.services.knowledgebase_service import KnowledgebaseService
-from api.utils.log_utils import init_root_logger, get_project_base_directory
-from api.utils.configs import show_configs
+from common.log_utils import init_root_logger
+from common.config_utils import show_configs
from common.data_source import BlobStorageConnector
import logging
import os
from datetime import datetime, timezone
-import tracemalloc
import signal
import trio
import faulthandler
@@ -41,6 +40,7 @@ from api import settings
from api.versions import get_ragflow_version
from common.data_source.confluence_connector import ConfluenceConnector
from common.data_source.utils import load_all_docs_from_checkpoint_connector
+from common.signal_utils import start_tracemalloc_and_snapshot, stop_tracemalloc

MAX_CONCURRENT_TASKS = int(os.environ.get('MAX_CONCURRENT_TASKS', "5"))
task_limiter = trio.Semaphore(MAX_CONCURRENT_TASKS)
@@ -51,11 +51,39 @@ class SyncBase:
        self.conf = conf

    async def __call__(self, task: dict):
-        SyncLogsService.start(task["id"])
+        SyncLogsService.start(task["id"], task["connector_id"])
        try:
            async with task_limiter:
                with trio.fail_after(task["timeout_secs"]):
-                    task["poll_range_start"] = await self._run(task)
+                    document_batch_generator = await self._generate(task)
+                    doc_num = 0
+                    next_update = datetime(1970, 1, 1, tzinfo=timezone.utc)
+                    if task["poll_range_start"]:
+                        next_update = task["poll_range_start"]
+                    for document_batch in document_batch_generator:
+                        min_update = min([doc.doc_updated_at for doc in document_batch])
+                        max_update = max([doc.doc_updated_at for doc in document_batch])
+                        next_update = max([next_update, max_update])
+                        docs = [{
+                            "id": doc.id,
+                            "connector_id": task["connector_id"],
+                            "source": FileSource.S3,
+                            "semantic_identifier": doc.semantic_identifier,
+                            "extension": doc.extension,
+                            "size_bytes": doc.size_bytes,
+                            "doc_updated_at": doc.doc_updated_at,
+                            "blob": doc.blob
+                        } for doc in document_batch]
+                        e, kb = KnowledgebaseService.get_by_id(task["kb_id"])
+                        err, dids = SyncLogsService.duplicate_and_parse(kb, docs, task["tenant_id"], f"{FileSource.S3}/{task['connector_id']}")
+                        SyncLogsService.increase_docs(task["id"], min_update, max_update, len(docs), "\n".join(err), len(err))
+                        doc_num += len(docs)
+                        logging.info("{} docs synchronized till {}".format(doc_num, next_update))
+                    SyncLogsService.done(task["id"], task["connector_id"])
+                    task["poll_range_start"] = next_update
        except Exception as ex:
            msg = '\n'.join([
                ''.join(traceback.format_exception_only(None, ex)).strip(),
@@ -65,12 +93,12 @@
            SyncLogsService.schedule(task["connector_id"], task["kb_id"], task["poll_range_start"])

-    async def _run(self, task: dict):
+    async def _generate(self, task: dict):
        raise NotImplementedError

class S3(SyncBase):
-    async def _run(self, task: dict):
+    async def _generate(self, task: dict):
        self.connector = BlobStorageConnector(
            bucket_type=self.conf.get("bucket_type", "s3"),
            bucket_name=self.conf["bucket_name"],
@@ -85,40 +113,11 @@ class S3(SyncBase):
                     self.conf["bucket_name"],
                     begin_info
                     ))
-        doc_num = 0
-        next_update = datetime(1970, 1, 1, tzinfo=timezone.utc)
-        if task["poll_range_start"]:
-            next_update = task["poll_range_start"]
-        for document_batch in document_batch_generator:
-            min_update = min([doc.doc_updated_at for doc in document_batch])
-            max_update = max([doc.doc_updated_at for doc in document_batch])
-            next_update = max([next_update, max_update])
-            docs = [{
-                "id": doc.id,
-                "connector_id": task["connector_id"],
-                "source": FileSource.S3,
-                "semantic_identifier": doc.semantic_identifier,
-                "extension": doc.extension,
-                "size_bytes": doc.size_bytes,
-                "doc_updated_at": doc.doc_updated_at,
-                "blob": doc.blob
-            } for doc in document_batch]
-            e, kb = KnowledgebaseService.get_by_id(task["kb_id"])
-            err, dids = SyncLogsService.duplicate_and_parse(kb, docs, task["tenant_id"], f"{FileSource.S3}/{task['connector_id']}")
-            SyncLogsService.increase_docs(task["id"], min_update, max_update, len(docs), "\n".join(err), len(err))
-            doc_num += len(docs)
-            logging.info("{} docs synchronized from {}: {} {}".format(doc_num, self.conf.get("bucket_type", "s3"),
-                         self.conf["bucket_name"],
-                         begin_info
-                         ))
-        SyncLogsService.done(task["id"])
-        return next_update
+        return document_batch_generator

class Confluence(SyncBase):
-    async def _run(self, task: dict):
+    async def _generate(self, task: dict):
        from common.data_source.interfaces import StaticCredentialsProvider
        from common.data_source.config import DocumentSource
@@ -156,85 +155,57 @@ class Confluence(SyncBase):
        )
        logging.info("Connect to Confluence: {} {}".format(self.conf["wiki_base"], begin_info))
+        return document_generator
-        doc_num = 0
-        next_update = datetime(1970, 1, 1, tzinfo=timezone.utc)
-        if task["poll_range_start"]:
-            next_update = task["poll_range_start"]
-        for doc in document_generator:
-            min_update = doc.doc_updated_at if doc.doc_updated_at else next_update
-            max_update = doc.doc_updated_at if doc.doc_updated_at else next_update
-            next_update = max([next_update, max_update])
-            docs = [{
-                "id": doc.id,
-                "connector_id": task["connector_id"],
-                "source": FileSource.CONFLUENCE,
-                "semantic_identifier": doc.semantic_identifier,
-                "extension": doc.extension,
-                "size_bytes": doc.size_bytes,
-                "doc_updated_at": doc.doc_updated_at,
-                "blob": doc.blob
-            }]
-            e, kb = KnowledgebaseService.get_by_id(task["kb_id"])
-            err, dids = SyncLogsService.duplicate_and_parse(kb, docs, task["tenant_id"], f"{FileSource.CONFLUENCE}/{task['connector_id']}")
-            SyncLogsService.increase_docs(task["id"], min_update, max_update, len(docs), "\n".join(err), len(err))
-            doc_num += len(docs)
-            logging.info("{} docs synchronized from Confluence: {} {}".format(doc_num, self.conf["wiki_base"], begin_info))
-        SyncLogsService.done(task["id"])
-        return next_update

class Notion(SyncBase):
-    async def __call__(self, task: dict):
+    async def _generate(self, task: dict):
        pass

class Discord(SyncBase):
-    async def __call__(self, task: dict):
+    async def _generate(self, task: dict):
        pass

class Gmail(SyncBase):
-    async def __call__(self, task: dict):
+    async def _generate(self, task: dict):
        pass

class GoogleDriver(SyncBase):
-    async def __call__(self, task: dict):
+    async def _generate(self, task: dict):
        pass

class Jira(SyncBase):
-    async def __call__(self, task: dict):
+    async def _generate(self, task: dict):
        pass

class SharePoint(SyncBase):
-    async def __call__(self, task: dict):
+    async def _generate(self, task: dict):
        pass

class Slack(SyncBase):
-    async def __call__(self, task: dict):
+    async def _generate(self, task: dict):
        pass

class Teams(SyncBase):
-    async def __call__(self, task: dict):
+    async def _generate(self, task: dict):
        pass

func_factory = {
    FileSource.S3: S3,
    FileSource.NOTION: Notion,
@@ -263,41 +234,6 @@ async def dispatch_tasks():

stop_event = threading.Event()

-# SIGUSR1 handler: start tracemalloc and take snapshot
-def start_tracemalloc_and_snapshot(signum, frame):
-    if not tracemalloc.is_tracing():
-        logging.info("start tracemalloc")
-        tracemalloc.start()
-    else:
-        logging.info("tracemalloc is already running")
-
-    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
-    snapshot_file = f"snapshot_{timestamp}.trace"
-    snapshot_file = os.path.abspath(os.path.join(get_project_base_directory(), "logs", f"{os.getpid()}_snapshot_{timestamp}.trace"))
-
-    snapshot = tracemalloc.take_snapshot()
-    snapshot.dump(snapshot_file)
-    current, peak = tracemalloc.get_traced_memory()
-    if sys.platform == "win32":
-        import psutil
-        process = psutil.Process()
-        max_rss = process.memory_info().rss / 1024
-    else:
-        import resource
-        max_rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
-    logging.info(f"taken snapshot {snapshot_file}. max RSS={max_rss / 1000:.2f} MB, current memory usage: {current / 10**6:.2f} MB, Peak memory usage: {peak / 10**6:.2f} MB")
-
-# SIGUSR2 handler: stop tracemalloc
-def stop_tracemalloc(signum, frame):
-    if tracemalloc.is_tracing():
-        logging.info("stop tracemalloc")
-        tracemalloc.stop()
-    else:
-        logging.info("tracemalloc not running")
-
def signal_handler(sig, frame):
    logging.info("Received interrupt signal, shutting down...")
    stop_event.set()
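
The refactor inverts control: each connector's _generate now only builds and returns a document-batch generator, while SyncBase.__call__ owns the consume loop, the poll watermark, and the log bookkeeping, which is why the per-connector copies of that loop could be deleted. A stripped-down sketch of the shape; the documents and the store step are stand-ins:

from datetime import datetime, timezone

class SyncBaseSketch:
    async def _generate(self, task):
        raise NotImplementedError          # subclasses return an iterable of batches

    async def __call__(self, task):
        batches = await self._generate(task)
        watermark = task.get("poll_range_start") or datetime(1970, 1, 1, tzinfo=timezone.utc)
        for batch in batches:
            # Advance the watermark to the newest document seen so far.
            watermark = max(watermark, max(doc["updated_at"] for doc in batch))
            self.store(batch)              # stand-in for duplicate_and_parse + log updates
        task["poll_range_start"] = watermark

    def store(self, batch):
        pass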

View File

@@ -27,9 +27,8 @@ from api.db.services.canvas_service import UserCanvasService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
from common.connection_utils import timeout
-from common.base64_image import image2id
+from rag.utils.base64_image import image2id
from common.log_utils import init_root_logger
-from common.file_utils import get_project_base_directory
from common.config_utils import show_configs
from graphrag.general.index import run_graphrag_for_kb
from graphrag.utils import get_llm_cache, set_llm_cache, get_tags_from_cache, set_tags_to_cache
@@ -45,7 +44,6 @@ import re
from functools import partial
from multiprocessing.context import TimeoutError
from timeit import default_timer as timer
-import tracemalloc
import signal
import trio
import exceptiongroup
@@ -69,6 +67,8 @@ from common.token_utils import num_tokens_from_string, truncate
from rag.utils.redis_conn import REDIS_CONN, RedisDistributedLock
from rag.utils.storage_factory import STORAGE_IMPL
from graphrag.utils import chat_limiter
+from common.signal_utils import start_tracemalloc_and_snapshot, stop_tracemalloc
+from common import globals

BATCH_SIZE = 64
@@ -129,40 +129,6 @@ def signal_handler(sig, frame):
    sys.exit(0)

-# SIGUSR1 handler: start tracemalloc and take snapshot
-def start_tracemalloc_and_snapshot(signum, frame):
-    if not tracemalloc.is_tracing():
-        logging.info("start tracemalloc")
-        tracemalloc.start()
-    else:
-        logging.info("tracemalloc is already running")
-
-    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
-    snapshot_file = f"snapshot_{timestamp}.trace"
-    snapshot_file = os.path.abspath(os.path.join(get_project_base_directory(), "logs", f"{os.getpid()}_snapshot_{timestamp}.trace"))
-
-    snapshot = tracemalloc.take_snapshot()
-    snapshot.dump(snapshot_file)
-    current, peak = tracemalloc.get_traced_memory()
-    if sys.platform == "win32":
-        import psutil
-        process = psutil.Process()
-        max_rss = process.memory_info().rss / 1024
-    else:
-        import resource
-        max_rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
-    logging.info(f"taken snapshot {snapshot_file}. max RSS={max_rss / 1000:.2f} MB, current memory usage: {current / 10**6:.2f} MB, Peak memory usage: {peak / 10**6:.2f} MB")
-
-# SIGUSR2 handler: stop tracemalloc
-def stop_tracemalloc(signum, frame):
-    if tracemalloc.is_tracing():
-        logging.info("stop tracemalloc")
-        tracemalloc.stop()
-    else:
-        logging.info("tracemalloc not running")
-
class TaskCanceledException(Exception):
    def __init__(self, msg):
        self.msg = msg
@@ -384,7 +350,7 @@ async def build_chunks(task, progress_callback):
            examples = []
            all_tags = get_tags_from_cache(kb_ids)
            if not all_tags:
-                all_tags = settings.retriever.all_tags_in_portion(tenant_id, kb_ids, S)
+                all_tags = globals.retriever.all_tags_in_portion(tenant_id, kb_ids, S)
                set_tags_to_cache(kb_ids, all_tags)
            else:
                all_tags = json.loads(all_tags)
@@ -397,7 +363,7 @@ async def build_chunks(task, progress_callback):
                if task_canceled:
                    progress_callback(-1, msg="Task has been canceled.")
                    return
-                if settings.retriever.tag_content(tenant_id, kb_ids, d, all_tags, topn_tags=topn_tags, S=S) and len(d[TAG_FLD]) > 0:
+                if globals.retriever.tag_content(tenant_id, kb_ids, d, all_tags, topn_tags=topn_tags, S=S) and len(d[TAG_FLD]) > 0:
                    examples.append({"content": d["content_with_weight"], TAG_FLD: d[TAG_FLD]})
                else:
                    docs_to_tag.append(d)
@@ -458,7 +424,7 @@ def build_TOC(task, docs, progress_callback):

def init_kb(row, vector_size: int):
    idxnm = search.index_name(row["tenant_id"])
-    return settings.docStoreConn.createIdx(idxnm, row.get("kb_id", ""), vector_size)
+    return globals.docStoreConn.createIdx(idxnm, row.get("kb_id", ""), vector_size)

async def embedding(docs, mdl, parser_config=None, callback=None):
@@ -682,7 +648,7 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si
    chunks = []
    vctr_nm = "q_%d_vec"%vector_size
    for doc_id in doc_ids:
-        for d in settings.retriever.chunk_list(doc_id, row["tenant_id"], [str(row["kb_id"])],
+        for d in globals.retriever.chunk_list(doc_id, row["tenant_id"], [str(row["kb_id"])],
                                               fields=["content_with_weight", vctr_nm],
                                               sort_by_position=True):
            chunks.append((d["content_with_weight"], np.array(d[vctr_nm])))
@@ -733,7 +699,7 @@ async def delete_image(kb_id, chunk_id):

async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_callback):
    for b in range(0, len(chunks), DOC_BULK_SIZE):
-        doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b:b + DOC_BULK_SIZE], search.index_name(task_tenant_id), task_dataset_id))
+        doc_store_result = await trio.to_thread.run_sync(lambda: globals.docStoreConn.insert(chunks[b:b + DOC_BULK_SIZE], search.index_name(task_tenant_id), task_dataset_id))
        task_canceled = has_canceled(task_id)
        if task_canceled:
            progress_callback(-1, msg="Task has been canceled.")
@@ -750,7 +716,7 @@ async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_c
                TaskService.update_chunk_ids(task_id, chunk_ids_str)
            except DoesNotExist:
                logging.warning(f"do_handle_task update_chunk_ids failed since task {task_id} is unknown.")
-                doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(task_tenant_id), task_dataset_id))
+                doc_store_result = await trio.to_thread.run_sync(lambda: globals.docStoreConn.delete({"id": chunk_ids}, search.index_name(task_tenant_id), task_dataset_id))
                async with trio.open_nursery() as nursery:
                    for chunk_id in chunk_ids:
                        nursery.start_soon(delete_image, task_dataset_id, chunk_id)
@@ -786,7 +752,7 @@ async def do_handle_task(task):
    progress_callback = partial(set_progress, task_id, task_from_page, task_to_page)

    # FIXME: workaround, Infinity doesn't support table parsing method, this check is to notify user
-    lower_case_doc_engine = settings.DOC_ENGINE.lower()
+    lower_case_doc_engine = globals.DOC_ENGINE.lower()
    if lower_case_doc_engine == 'infinity' and task['parser_id'].lower() == 'table':
        error_message = "Table parsing method is not supported by Infinity, please use other parsing methods or use Elasticsearch as the document engine."
        progress_callback(-1, msg=error_message)
@@ -1067,8 +1033,8 @@ async def main():
    logging.info(f'RAGFlow version: {get_ragflow_version()}')
    show_configs()
    settings.init_settings()
-    from api.settings import EMBEDDING_CFG
-    logging.info(f'api.settings.EMBEDDING_CFG: {EMBEDDING_CFG}')
+    from common import globals
+    logging.info(f'globals.EMBEDDING_CFG: {globals.EMBEDDING_CFG}')
    print_rag_settings()
    if sys.platform != "win32":
        signal.signal(signal.SIGUSR1, start_tracemalloc_and_snapshot)
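
After the move to common/signal_utils.py, each service only registers the shared handlers, as main() does above. Minimal equivalents of the two handlers, for reference; the snapshot path scheme here is illustrative, and SIGUSR1/SIGUSR2 are POSIX-only:

import logging
import signal
import sys
import tracemalloc

def start_tracemalloc_and_snapshot(signum, frame):
    # SIGUSR1: start tracing (if needed) and dump a snapshot.
    if not tracemalloc.is_tracing():
        tracemalloc.start()
    tracemalloc.take_snapshot().dump(f"snapshot_{signum}.trace")  # illustrative path
    current, peak = tracemalloc.get_traced_memory()
    logging.info("current=%.2f MB, peak=%.2f MB", current / 10**6, peak / 10**6)

def stop_tracemalloc(signum, frame):
    # SIGUSR2: stop tracing.
    if tracemalloc.is_tracing():
        tracemalloc.stop()

if sys.platform != "win32":
    signal.signal(signal.SIGUSR1, start_tracemalloc_and_snapshot)
    signal.signal(signal.SIGUSR2, stop_tracemalloc)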

View File

@@ -18,17 +18,17 @@ import logging
import os
import time
from io import BytesIO
-from rag import settings
from common.decorator import singleton
from azure.storage.blob import ContainerClient
+from common import globals

@singleton
class RAGFlowAzureSasBlob:
    def __init__(self):
        self.conn = None
-        self.container_url = os.getenv('CONTAINER_URL', settings.AZURE["container_url"])
-        self.sas_token = os.getenv('SAS_TOKEN', settings.AZURE["sas_token"])
+        self.container_url = os.getenv('CONTAINER_URL', globals.AZURE["container_url"])
+        self.sas_token = os.getenv('SAS_TOKEN', globals.AZURE["sas_token"])
        self.__open__()

    def __open__(self):

View File

@@ -17,21 +17,21 @@
import logging
import os
import time
-from rag import settings
from common.decorator import singleton
from azure.identity import ClientSecretCredential, AzureAuthorityHosts
from azure.storage.filedatalake import FileSystemClient
+from common import globals

@singleton
class RAGFlowAzureSpnBlob:
    def __init__(self):
        self.conn = None
-        self.account_url = os.getenv('ACCOUNT_URL', settings.AZURE["account_url"])
-        self.client_id = os.getenv('CLIENT_ID', settings.AZURE["client_id"])
-        self.secret = os.getenv('SECRET', settings.AZURE["secret"])
-        self.tenant_id = os.getenv('TENANT_ID', settings.AZURE["tenant_id"])
-        self.container_name = os.getenv('CONTAINER_NAME', settings.AZURE["container_name"])
+        self.account_url = os.getenv('ACCOUNT_URL', globals.AZURE["account_url"])
+        self.client_id = os.getenv('CLIENT_ID', globals.AZURE["client_id"])
+        self.secret = os.getenv('SECRET', globals.AZURE["secret"])
+        self.tenant_id = os.getenv('TENANT_ID', globals.AZURE["tenant_id"])
+        self.container_name = os.getenv('CONTAINER_NAME', globals.AZURE["container_name"])
        self.__open__()

    def __open__(self):

View File

@@ -24,7 +24,6 @@ import copy
from elasticsearch import Elasticsearch, NotFoundError
from elasticsearch_dsl import UpdateByQuery, Q, Search, Index
from elastic_transport import ConnectionTimeout
-from rag import settings
from rag.settings import TAG_FLD, PAGERANK_FLD
from common.decorator import singleton
from common.file_utils import get_project_base_directory
@@ -33,6 +32,7 @@ from rag.utils.doc_store_conn import DocStoreConnection, MatchExpr, OrderByExpr,
    FusionExpr
from rag.nlp import is_english, rag_tokenizer
from common.float_utils import get_float
+from common import globals

ATTEMPT_TIME = 2
@@ -43,17 +43,17 @@ logger = logging.getLogger('ragflow.es_conn')
class ESConnection(DocStoreConnection):
    def __init__(self):
        self.info = {}
-        logger.info(f"Use Elasticsearch {settings.ES['hosts']} as the doc engine.")
+        logger.info(f"Use Elasticsearch {globals.ES['hosts']} as the doc engine.")
        for _ in range(ATTEMPT_TIME):
            try:
                if self._connect():
                    break
            except Exception as e:
-                logger.warning(f"{str(e)}. Waiting Elasticsearch {settings.ES['hosts']} to be healthy.")
+                logger.warning(f"{str(e)}. Waiting Elasticsearch {globals.ES['hosts']} to be healthy.")
                time.sleep(5)
        if not self.es.ping():
-            msg = f"Elasticsearch {settings.ES['hosts']} is unhealthy in 120s."
+            msg = f"Elasticsearch {globals.ES['hosts']} is unhealthy in 120s."
            logger.error(msg)
            raise Exception(msg)
        v = self.info.get("version", {"number": "8.11.3"})
@@ -68,14 +68,14 @@ class ESConnection(DocStoreConnection):
            logger.error(msg)
            raise Exception(msg)
        self.mapping = json.load(open(fp_mapping, "r"))
-        logger.info(f"Elasticsearch {settings.ES['hosts']} is healthy.")
+        logger.info(f"Elasticsearch {globals.ES['hosts']} is healthy.")

    def _connect(self):
        self.es = Elasticsearch(
-            settings.ES["hosts"].split(","),
-            basic_auth=(settings.ES["username"], settings.ES[
-                "password"]) if "username" in settings.ES and "password" in settings.ES else None,
-            verify_certs= settings.ES.get("verify_certs", False),
-            timeout=600 )
+            globals.ES["hosts"].split(","),
+            basic_auth=(globals.ES["username"], globals.ES[
+                "password"]) if "username" in globals.ES and "password" in globals.ES else None,
+            verify_certs= globals.ES.get("verify_certs", False),
+            timeout=600 )
        if self.es:
            self.info = self.es.info()
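
The connection logic is unchanged apart from the config source: try to connect up to ATTEMPT_TIME times, sleep between failures, then fail hard if the engine still does not answer a ping. The same loop, sketched with the client calls abstracted out:

import logging
import time

ATTEMPT_TIME = 2

def wait_until_healthy(connect, ping, hosts):
    for _ in range(ATTEMPT_TIME):
        try:
            if connect():
                break
        except Exception as e:
            logging.warning("%s. Waiting %s to be healthy.", e, hosts)
            time.sleep(5)
    if not ping():
        raise Exception(f"{hosts} is unhealthy.")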

View File

@@ -25,11 +25,11 @@ from infinity.common import ConflictType, InfinityException, SortType
from infinity.index import IndexInfo, IndexType
from infinity.connection_pool import ConnectionPool
from infinity.errors import ErrorCode
-from rag import settings
from rag.settings import PAGERANK_FLD, TAG_FLD
from common.decorator import singleton
import pandas as pd
from common.file_utils import get_project_base_directory
+from common import globals

from rag.nlp import is_english
from rag.utils.doc_store_conn import (
@@ -130,8 +130,8 @@ def concat_dataframes(df_list: list[pd.DataFrame], selectFields: list[str]) -> p
@singleton
class InfinityConnection(DocStoreConnection):
    def __init__(self):
-        self.dbName = settings.INFINITY.get("db_name", "default_db")
-        infinity_uri = settings.INFINITY["uri"]
+        self.dbName = globals.INFINITY.get("db_name", "default_db")
+        infinity_uri = globals.INFINITY["uri"]
        if ":" in infinity_uri:
            host, port = infinity_uri.split(":")
            infinity_uri = infinity.common.NetworkAddress(host, int(port))
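
The Infinity URI may be either a host:port string or a value usable as-is; the constructor splits on the colon and builds a network address. The same handling, sketched with a tuple standing in for infinity.common.NetworkAddress:

def parse_infinity_uri(uri):
    if ":" in uri:
        host, port = uri.split(":")
        return (host, int(port))   # stands in for infinity.common.NetworkAddress(host, int(port))
    return uri

assert parse_infinity_uri("infinity:23817") == ("infinity", 23817)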

View File

@@ -20,8 +20,8 @@ from minio import Minio
from minio.commonconfig import CopySource
from minio.error import S3Error
from io import BytesIO
-from rag import settings
from common.decorator import singleton
+from common import globals

@singleton
@@ -38,14 +38,14 @@ class RAGFlowMinio:
            pass

        try:
-            self.conn = Minio(settings.MINIO["host"],
-                              access_key=settings.MINIO["user"],
-                              secret_key=settings.MINIO["password"],
+            self.conn = Minio(globals.MINIO["host"],
+                              access_key=globals.MINIO["user"],
+                              secret_key=globals.MINIO["password"],
                              secure=False
                              )
        except Exception:
            logging.exception(
-                "Fail to connect %s " % settings.MINIO["host"])
+                "Fail to connect %s " % globals.MINIO["host"])

    def __close__(self):
        del self.conn

View File

@@ -24,13 +24,13 @@ import copy
from opensearchpy import OpenSearch, NotFoundError
from opensearchpy import UpdateByQuery, Q, Search, Index
from opensearchpy import ConnectionTimeout
-from rag import settings
from rag.settings import TAG_FLD, PAGERANK_FLD
from common.decorator import singleton
from common.file_utils import get_project_base_directory
from rag.utils.doc_store_conn import DocStoreConnection, MatchExpr, OrderByExpr, MatchTextExpr, MatchDenseExpr, \
    FusionExpr
from rag.nlp import is_english, rag_tokenizer
+from common import globals

ATTEMPT_TIME = 2
@@ -41,13 +41,13 @@ logger = logging.getLogger('ragflow.opensearch_conn')
class OSConnection(DocStoreConnection):
    def __init__(self):
        self.info = {}
-        logger.info(f"Use OpenSearch {settings.OS['hosts']} as the doc engine.")
+        logger.info(f"Use OpenSearch {globals.OS['hosts']} as the doc engine.")
        for _ in range(ATTEMPT_TIME):
            try:
                self.os = OpenSearch(
-                    settings.OS["hosts"].split(","),
-                    http_auth=(settings.OS["username"], settings.OS[
-                        "password"]) if "username" in settings.OS and "password" in settings.OS else None,
+                    globals.OS["hosts"].split(","),
+                    http_auth=(globals.OS["username"], globals.OS[
+                        "password"]) if "username" in globals.OS and "password" in globals.OS else None,
                    verify_certs=False,
                    timeout=600
                )
@@ -55,10 +55,10 @@ class OSConnection(DocStoreConnection):
                self.info = self.os.info()
                break
            except Exception as e:
-                logger.warning(f"{str(e)}. Waiting OpenSearch {settings.OS['hosts']} to be healthy.")
+                logger.warning(f"{str(e)}. Waiting OpenSearch {globals.OS['hosts']} to be healthy.")
                time.sleep(5)
        if not self.os.ping():
-            msg = f"OpenSearch {settings.OS['hosts']} is unhealthy in 120s."
+            msg = f"OpenSearch {globals.OS['hosts']} is unhealthy in 120s."
            logger.error(msg)
            raise Exception(msg)
        v = self.info.get("version", {"number": "2.18.0"})
@@ -73,7 +73,7 @@ class OSConnection(DocStoreConnection):
            logger.error(msg)
            raise Exception(msg)
        self.mapping = json.load(open(fp_mapping, "r"))
-        logger.info(f"OpenSearch {settings.OS['hosts']} is healthy.")
+        logger.info(f"OpenSearch {globals.OS['hosts']} is healthy.")

    """
    Database operations

View File

@@ -20,14 +20,14 @@ from botocore.config import Config
import time
from io import BytesIO
from common.decorator import singleton
-from rag import settings
+from common import globals

@singleton
class RAGFlowOSS:
    def __init__(self):
        self.conn = None
-        self.oss_config = settings.OSS
+        self.oss_config = globals.OSS
        self.access_key = self.oss_config.get('access_key', None)
        self.secret_key = self.oss_config.get('secret_key', None)
        self.endpoint_url = self.oss_config.get('endpoint_url', None)

View File

@@ -19,8 +19,8 @@ import json
import uuid

import valkey as redis
-from rag import settings
from common.decorator import singleton
+from common import globals
from valkey.lock import Lock
import trio
@@ -61,7 +61,7 @@ class RedisDB:
    def __init__(self):
        self.REDIS = None
-        self.config = settings.REDIS
+        self.config = globals.REDIS
        self.__open__()

    def register_scripts(self) -> None:
def register_scripts(self) -> None: def register_scripts(self) -> None:

View File

@@ -21,13 +21,14 @@ from botocore.config import Config
import time
from io import BytesIO
from common.decorator import singleton
-from rag import settings
+from common import globals

@singleton
class RAGFlowS3:
    def __init__(self):
        self.conn = None
-        self.s3_config = settings.S3
+        self.s3_config = globals.S3
        self.access_key = self.s3_config.get('access_key', None)
        self.secret_key = self.s3_config.get('secret_key', None)
        self.session_token = self.s3_config.get('session_token', None)

View File

@@ -15,10 +15,10 @@ import {
} from '@/components/ui/form';
import { Input } from '@/components/ui/input';
import { Separator } from '@/components/ui/separator';
+import { SwitchOperatorOptions } from '@/constants/agent';
+import { useBuildSwitchOperatorOptions } from '@/hooks/logic-hooks/use-build-operator-options';
import { useFetchKnowledgeMetadata } from '@/hooks/use-knowledge-request';
-import { SwitchOperatorOptions } from '@/pages/agent/constant';
import { PromptEditor } from '@/pages/agent/form/components/prompt-editor';
-import { useBuildSwitchOperatorOptions } from '@/pages/agent/form/switch-form';
import { Plus, X } from 'lucide-react';
import { useCallback } from 'react';
import { useFieldArray, useFormContext } from 'react-hook-form';

View File

@@ -13,6 +13,7 @@ const ScrollBar = React.forwardRef<
    ref={ref}
    orientation={orientation}
    className={cn(
+      'z-[100]',
      'flex touch-none select-none transition-colors',
      orientation === 'vertical' &&
        'h-full w-2.5 border-l border-l-transparent p-[1px]',
@@ -22,7 +23,7 @@
    )}
    {...props}
  >
-    <ScrollAreaPrimitive.ScrollAreaThumb className="relative flex-1 rounded-full bg-border" />
+    <ScrollAreaPrimitive.ScrollAreaThumb className="relative flex-1 rounded-full bg-border backdrop-blur-md" />
  </ScrollAreaPrimitive.ScrollAreaScrollbar>
));
ScrollBar.displayName = ScrollAreaPrimitive.ScrollAreaScrollbar.displayName;

View File

@@ -1,4 +1,5 @@
import { setInitialChatVariableEnabledFieldValue } from '@/utils/chat';
+import { Circle, CircleSlash2 } from 'lucide-react';
import { ChatVariableEnabledField, variableEnabledFieldMap } from './chat';

export enum ProgrammingLanguage {
@@ -117,3 +118,53 @@ export enum Operator {
  HierarchicalMerger = 'HierarchicalMerger',
  Extractor = 'Extractor',
}
+
+export enum ComparisonOperator {
+  Equal = '=',
+  NotEqual = '≠',
+  GreatThan = '>',
+  GreatEqual = '≥',
+  LessThan = '<',
+  LessEqual = '≤',
+  Contains = 'contains',
+  NotContains = 'not contains',
+  StartWith = 'start with',
+  EndWith = 'end with',
+  Empty = 'empty',
+  NotEmpty = 'not empty',
+}
+
+export const SwitchOperatorOptions = [
+  { value: ComparisonOperator.Equal, label: 'equal', icon: 'equal' },
+  { value: ComparisonOperator.NotEqual, label: 'notEqual', icon: 'not-equals' },
+  { value: ComparisonOperator.GreatThan, label: 'gt', icon: 'Less' },
+  {
+    value: ComparisonOperator.GreatEqual,
+    label: 'ge',
+    icon: 'Greater-or-equal',
+  },
+  { value: ComparisonOperator.LessThan, label: 'lt', icon: 'Less' },
+  { value: ComparisonOperator.LessEqual, label: 'le', icon: 'less-or-equal' },
+  { value: ComparisonOperator.Contains, label: 'contains', icon: 'Contains' },
+  {
+    value: ComparisonOperator.NotContains,
+    label: 'notContains',
+    icon: 'not-contains',
+  },
+  {
+    value: ComparisonOperator.StartWith,
+    label: 'startWith',
+    icon: 'list-start',
+  },
+  { value: ComparisonOperator.EndWith, label: 'endWith', icon: 'list-end' },
+  {
+    value: ComparisonOperator.Empty,
+    label: 'empty',
+    icon: <Circle className="size-4" />,
+  },
+  {
+    value: ComparisonOperator.NotEmpty,
+    label: 'notEmpty',
+    icon: <CircleSlash2 className="size-4" />,
+  },
+];

View File

@@ -0,0 +1,45 @@
import { IconFont } from '@/components/icon-font';
import { ComparisonOperator, SwitchOperatorOptions } from '@/constants/agent';
import { cn } from '@/lib/utils';
import { useMemo } from 'react';
import { useTranslation } from 'react-i18next';
export const LogicalOperatorIcon = function OperatorIcon({
icon,
value,
}: Omit<(typeof SwitchOperatorOptions)[0], 'label'>) {
if (typeof icon === 'string') {
return (
<IconFont
name={icon}
className={cn('size-4', {
'rotate-180': value === '>',
})}
></IconFont>
);
}
return icon;
};
export function useBuildSwitchOperatorOptions(
subset: ComparisonOperator[] = [],
) {
const { t } = useTranslation();
const switchOperatorOptions = useMemo(() => {
return SwitchOperatorOptions.filter((x) =>
subset.some((y) => y === x.value),
).map((x) => ({
value: x.value,
icon: (
<LogicalOperatorIcon
icon={x.icon}
value={x.value}
></LogicalOperatorIcon>
),
label: t(`flow.switchOperatorOptions.${x.label}`),
}));
}, [t]);
return switchOperatorOptions;
}

View File

@@ -1980,6 +1980,10 @@ Important structured information may include: names, dates, locations, events, k
      deleteRole: 'Delete role',
      deleteRoleConfirmation:
        'Are you sure you want to delete this role? This action cannot be undone.',
+      alive: 'Alive',
+      timeout: 'Timeout',
+      fail: 'Fail',
    },
  },
};

View File

@@ -1,8 +1,11 @@
+import Spotlight from '@/components/spotlight';
import { Card, CardContent } from '@/components/ui/card';

function AdminMonitoring() {
  return (
-    <Card className="!shadow-none h-full border border-border-button bg-transparent rounded-xl overflow-x-hidden overflow-y-auto">
+    <Card className="!shadow-none relative h-full border border-border-button bg-transparent rounded-xl overflow-x-hidden overflow-y-auto">
+      <Spotlight />
      <CardContent className="size-full p-0">
        <iframe />
      </CardContent>

View File

@@ -5,6 +5,7 @@ import { useTranslation } from 'react-i18next';
import { useMutation, useQuery, useQueryClient } from '@tanstack/react-query';

+import Spotlight from '@/components/spotlight';
import { Button } from '@/components/ui/button';
import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card';
import {
@@ -29,7 +30,6 @@ import {
import { LucideEdit3, LucideTrash2, LucideUserPlus } from 'lucide-react';
import {
-  AdminService,
  assignRolePermissions,
  createRole,
  deleteRole,
@@ -149,7 +149,9 @@ function AdminRoles() {
  return (
    <>
-      <Card className="!shadow-none w-full h-full border border-border-button bg-transparent rounded-xl">
+      <Card className="!shadow-none relative w-full h-full border border-border-button bg-transparent rounded-xl">
+        <Spotlight />
        <ScrollArea className="size-full">
          <CardHeader className="space-y-0 flex flex-row justify-between items-center">
            <CardTitle>{t('admin.roles')}</CardTitle>

View File

@@ -23,6 +23,7 @@ import { useQuery } from '@tanstack/react-query';
 import { cn } from '@/lib/utils';
 
+import Spotlight from '@/components/spotlight';
 import { TableEmpty } from '@/components/table-skeleton';
 import { Badge } from '@/components/ui/badge';
 import { Button } from '@/components/ui/button';
@@ -151,11 +152,11 @@ function AdminServiceStatus() {
                 alive: 'bg-state-success-5 text-state-success',
                 timeout: 'bg-state-error-5 text-state-error',
                 fail: 'bg-gray-500/5 text-text-disable',
-              }[cell.getValue<string>()],
+              }[cell.getValue()],
             )}
           >
             <LucideDot className="size-[1em] stroke-[8] mr-1" />
-            {cell.getValue()}
+            {t(`admin.${cell.getValue()}`)}
           </Badge>
         ),
         enableSorting: false,
@@ -215,7 +216,9 @@ function AdminServiceStatus() {
   return (
     <>
-      <Card className="!shadow-none h-full border border-border-button bg-transparent rounded-xl">
+      <Card className="!shadow-none relative h-full border border-border-button bg-transparent rounded-xl">
+        <Spotlight />
         <ScrollArea className="size-full">
           <CardHeader className="space-y-0 flex flex-row justify-between items-center">
             <CardTitle>{t('admin.serviceStatus')}</CardTitle>
@@ -252,7 +255,7 @@ function AdminServiceStatus() {
                   table.getColumn('service_type')!?.setFilterValue
                 }
               >
-                <Label className="space-x-2">
+                <Label className="flex items-center space-x-2">
                   <RadioGroupItem
                     className="bg-bg-input border-border-button"
                     value=""
@@ -261,7 +264,10 @@ function AdminServiceStatus() {
                 </Label>
                 {SERVICE_TYPE_FILTER_OPTIONS.map(({ label, value }) => (
-                  <Label key={value} className="space-x-2">
+                  <Label
+                    key={value}
+                    className="flex items-center space-x-2"
+                  >
                     <RadioGroupItem
                       className="bg-bg-input border-border-button"
                       value={value}

View File

@@ -16,6 +16,7 @@ import {
 import { cn } from '@/lib/utils';
 import { Routes } from '@/routes';
 
+import Spotlight from '@/components/spotlight';
 import { Avatar } from '@/components/ui/avatar';
 import { Badge } from '@/components/ui/badge';
 import { Button } from '@/components/ui/button';
@@ -322,7 +323,9 @@ function AdminUserDetail() {
         </Button>
       </nav>
-      <Card className="!shadow-none h-0 basis-0 grow flex flex-col bg-transparent border dark:border-border-button overflow-hidden">
+      <Card className="!shadow-none relative h-0 basis-0 grow flex flex-col bg-transparent border dark:border-border-button overflow-hidden">
+        <Spotlight />
         <CardHeader className="pb-10 border-b dark:border-border-button space-y-8">
           <section className="flex items-center gap-4 text-base">
             <Avatar className="justify-center items-center bg-bg-group uppercase">

View File

@@ -25,6 +25,7 @@ import {
 import { cn } from '@/lib/utils';
 import { rsaPsw } from '@/utils';
 
+import Spotlight from '@/components/spotlight';
 import { TableEmpty } from '@/components/table-skeleton';
 import { Badge } from '@/components/ui/badge';
 import { Button } from '@/components/ui/button';
@@ -357,7 +358,9 @@ function AdminUserManagement() {
   return (
     <>
-      <Card className="!shadow-none h-full border border-border-button bg-transparent rounded-xl overflow-x-hidden overflow-y-auto">
+      <Card className="!shadow-none relative h-full border border-border-button bg-transparent rounded-xl overflow-x-hidden overflow-y-auto">
+        <Spotlight />
         <ScrollArea className="size-full">
           <CardHeader className="space-y-0 flex flex-row justify-between items-center">
             <CardTitle>{t('admin.userManagement')}</CardTitle>
@@ -396,13 +399,16 @@ function AdminUserManagement() {
                     table.getColumn('role')?.setFilterValue(value)
                   }
                 >
-                  <Label className="space-x-2">
+                  <Label className="flex items-center space-x-2">
                     <RadioGroupItem value="" />
                     <span>{t('admin.all')}</span>
                   </Label>
                   {roleList?.map(({ id, role_name }) => (
-                    <Label key={id} className="space-x-2">
+                    <Label
+                      key={id}
+                      className="flex items-center space-x-2"
+                    >
                       <RadioGroupItem
                         className="bg-bg-input border-border-button"
                         value={role_name}
@@ -429,7 +435,10 @@
                   }
                 >
                   {STATUS_FILTER_OPTIONS.map(({ label, value }) => (
-                    <Label key={value} className="space-x-2">
+                    <Label
+                      key={value}
+                      className="flex items-center space-x-2"
+                    >
                       <RadioGroupItem
                         className="bg-bg-input border-border-button"
                         value={value}

View File

@@ -22,6 +22,7 @@ import {
   LucideUserPen,
 } from 'lucide-react';
 
+import Spotlight from '@/components/spotlight';
 import { TableEmpty } from '@/components/table-skeleton';
 import { Button } from '@/components/ui/button';
 import {
@@ -58,7 +59,6 @@ import {
   importWhitelistFromExcel,
   listWhitelist,
   updateWhitelistEntry,
-  type AdminService,
 } from '@/services/admin-service';
 
 import { EMPTY_DATA, createFuzzySearchFn, getSortIcon } from './utils';
@@ -68,8 +68,6 @@ import useImportExcelForm, {
   ImportExcelFormData,
 } from './forms/import-excel-form';
 
-// #endregion
-
 const columnHelper = createColumnHelper<AdminService.ListWhitelistItem>();
 const globalFilterFn = createFuzzySearchFn<AdminService.ListWhitelistItem>([
   'email',
@@ -233,7 +231,9 @@
   return (
     <>
-      <Card className="!shadow-none h-full border border-border-button bg-transparent rounded-xl overflow-x-hidden overflow-y-auto">
+      <Card className="!shadow-none relative h-full border border-border-button bg-transparent rounded-xl overflow-x-hidden overflow-y-auto">
+        <Spotlight />
         <ScrollArea className="size-full">
           <CardHeader className="space-y-0 flex flex-row justify-between items-center">
             <CardTitle>{t('admin.whitelistManagement')}</CardTitle>

View File

@@ -1,9 +1,9 @@
 import { Card, CardContent } from '@/components/ui/card';
+import { SwitchOperatorOptions } from '@/constants/agent';
+import { LogicalOperatorIcon } from '@/hooks/logic-hooks/use-build-operator-options';
 import { ISwitchCondition, ISwitchNode } from '@/interfaces/database/flow';
 import { NodeProps, Position } from '@xyflow/react';
 import { memo, useCallback } from 'react';
-import { SwitchOperatorOptions } from '../../constant';
-import { LogicalOperatorIcon } from '../../form/switch-form';
 import { useGetVariableLabelByValue } from '../../hooks/use-get-begin-query';
 import { CommonHandle, LeftEndHandle } from './handle';
 import { RightHandleStyle } from './handle-icon';

View File

@@ -6,8 +6,10 @@ import {
   AgentGlobals,
   AgentGlobalsSysQueryWithBrace,
   CodeTemplateStrMap,
+  ComparisonOperator,
   Operator,
   ProgrammingLanguage,
+  SwitchOperatorOptions,
   initialLlmBaseValues,
 } from '@/constants/agent';
 
 export { Operator } from '@/constants/agent';
@@ -35,8 +37,6 @@ export enum PromptRole {
 }
 
 import {
-  Circle,
-  CircleSlash2,
   CloudUpload,
   ListOrdered,
   OptionIcon,
@@ -166,27 +166,12 @@ export const componentMenuList = [
   },
 ];
 
-export const SwitchOperatorOptions = [
-  { value: '=', label: 'equal', icon: 'equal' },
-  { value: '≠', label: 'notEqual', icon: 'not-equals' },
-  { value: '>', label: 'gt', icon: 'Less' },
-  { value: '≥', label: 'ge', icon: 'Greater-or-equal' },
-  { value: '<', label: 'lt', icon: 'Less' },
-  { value: '≤', label: 'le', icon: 'less-or-equal' },
-  { value: 'contains', label: 'contains', icon: 'Contains' },
-  { value: 'not contains', label: 'notContains', icon: 'not-contains' },
-  { value: 'start with', label: 'startWith', icon: 'list-start' },
-  { value: 'end with', label: 'endWith', icon: 'list-end' },
-  {
-    value: 'empty',
-    label: 'empty',
-    icon: <Circle className="size-4" />,
-  },
-  {
-    value: 'not empty',
-    label: 'notEmpty',
-    icon: <CircleSlash2 className="size-4" />,
-  },
-];
+export const DataOperationsOperatorOptions = [
+  ComparisonOperator.Equal,
+  ComparisonOperator.NotEqual,
+  ComparisonOperator.Contains,
+  ComparisonOperator.StartWith,
+  ComparisonOperator.EndWith,
+];
 
 export const SwitchElseTo = 'end_cpn_ids';
@@ -716,16 +701,17 @@ export const initialPlaceholderValues = {
 };
 
 export enum Operations {
-  SelectKeys = 'select keys',
-  LiteralEval = 'literal eval',
+  SelectKeys = 'select_keys',
+  LiteralEval = 'literal_eval',
   Combine = 'combine',
-  FilterValues = 'filter values',
-  AppendOrUpdate = 'append or update',
-  RemoveKeys = 'remove keys',
-  RenameKeys = 'rename keys',
+  FilterValues = 'filter_values',
+  AppendOrUpdate = 'append_or_update',
+  RemoveKeys = 'remove_keys',
+  RenameKeys = 'rename_keys',
 }
 
 export const initialDataOperationsValues = {
+  query: [],
   operations: Operations.SelectKeys,
   outputs: {
     result: {

View File

@@ -0,0 +1,25 @@
import { Button } from '@/components/ui/button';
import { FormLabel } from '@/components/ui/form';
import { Plus } from 'lucide-react';
import { ReactNode } from 'react';
export type FormListHeaderProps = {
label: ReactNode;
tooltip?: string;
onClick?: () => void;
};
export function DynamicFormHeader({
label,
tooltip,
onClick,
}: FormListHeaderProps) {
return (
<div className="flex items-center justify-between">
<FormLabel tooltip={tooltip}>{label}</FormLabel>
<Button variant={'ghost'} type="button" onClick={onClick}>
<Plus />
</Button>
</div>
);
}
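
This header consolidates the add-row trigger that each dynamic list previously rendered as a trailing `<BlockButton>{t('common.add')}</BlockButton>`. A sketch of how a list wires it up (the `KeyList` component and the `keys` field name are hypothetical):

```tsx
import { useFieldArray, useFormContext } from 'react-hook-form';
import { DynamicFormHeader } from './dynamic-fom-header';

// Hypothetical dynamic list: the header's plus button appends a blank row.
function KeyList() {
  const form = useFormContext();
  const { fields, append } = useFieldArray({
    name: 'keys',
    control: form.control,
  });

  return (
    <section className="space-y-2">
      <DynamicFormHeader label="Keys" onClick={() => append({ name: '' })} />
      {fields.map((field) => (
        <div key={field.id}>{/* row editor goes here */}</div>
      ))}
    </section>
  );
}
```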

View File

@@ -1,4 +1,7 @@
+import { RAGFlowFormItem } from '@/components/ragflow-form';
+import { Input } from '@/components/ui/input';
 import { t } from 'i18next';
+import { z } from 'zod';
 
 export type OutputType = {
   title: string;
@@ -7,6 +10,7 @@ export type OutputType = {
 
 type OutputProps = {
   list: Array<OutputType>;
+  isFormRequired?: boolean;
 };
 
 export function transferOutputs(outputs: Record<string, any>) {
@@ -16,7 +20,11 @@
   }));
 }
 
-export function Output({ list }: OutputProps) {
+export const OutputSchema = {
+  outputs: z.record(z.any()),
+};
+
+export function Output({ list, isFormRequired = false }: OutputProps) {
   return (
     <section className="space-y-2">
       <div className="text-sm">{t('flow.output')}</div>
@@ -30,6 +38,11 @@
           </li>
         ))}
       </ul>
+      {isFormRequired && (
+        <RAGFlowFormItem name="outputs" className="hidden">
+          <Input></Input>
+        </RAGFlowFormItem>
+      )}
     </section>
   );
 }
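
`OutputSchema` is spread into node form schemas so `outputs` survives zod validation, and the hidden `RAGFlowFormItem` registers the field with react-hook-form, presumably so it is not dropped when the enclosing form uses `shouldUnregister: true`. How the spread composes (a sketch; only the field names come from this diff):

```ts
import { z } from 'zod';

const OutputSchema = { outputs: z.record(z.any()) };

const FormSchema = z.object({
  operations: z.string(),
  ...OutputSchema, // contributes `outputs: Record<string, any>`
});

type FormValues = z.infer<typeof FormSchema>;
// => { operations: string; outputs: Record<string, any> }
```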

View File

@@ -1,17 +1,20 @@
-import { BlockButton, Button } from '@/components/ui/button';
+import { Button } from '@/components/ui/button';
 import { X } from 'lucide-react';
 import { useFieldArray, useFormContext } from 'react-hook-form';
-import { useTranslation } from 'react-i18next';
 import { JsonSchemaDataType } from '../../constant';
+import { DynamicFormHeader, FormListHeaderProps } from './dynamic-fom-header';
 import { QueryVariable } from './query-variable';
 
 type QueryVariableListProps = {
   types?: JsonSchemaDataType[];
-};
+} & FormListHeaderProps;
 
-export function QueryVariableList({ types }: QueryVariableListProps) {
-  const { t } = useTranslation();
+export function QueryVariableList({
+  types,
+  label,
+  tooltip,
+}: QueryVariableListProps) {
   const form = useFormContext();
 
-  const name = 'inputs';
+  const name = 'query';
 
   const { fields, remove, append } = useFieldArray({
     name: name,
@@ -19,28 +22,31 @@
   });
 
   return (
-    <div className="space-y-5">
-      {fields.map((field, index) => {
-        const nameField = `${name}.${index}.input`;
-
-        return (
-          <div key={field.id} className="flex items-center gap-2">
-            <QueryVariable
-              name={nameField}
-              hideLabel
-              className="flex-1"
-              types={types}
-            ></QueryVariable>
-            <Button variant={'ghost'} onClick={() => remove(index)}>
-              <X className="text-text-sub-title-invert " />
-            </Button>
-          </div>
-        );
-      })}
-      <BlockButton onClick={() => append({ input: '' })}>
-        {t('common.add')}
-      </BlockButton>
-    </div>
+    <section className="space-y-2">
+      <DynamicFormHeader
+        label={label}
+        tooltip={tooltip}
+        onClick={() => append({ input: '' })}
+      ></DynamicFormHeader>
+      <div className="space-y-5">
+        {fields.map((field, index) => {
+          const nameField = `${name}.${index}.input`;
+
+          return (
+            <div key={field.id} className="flex items-center gap-2">
+              <QueryVariable
+                name={nameField}
+                hideLabel
+                className="flex-1"
+                types={types}
+              ></QueryVariable>
+              <Button variant={'ghost'} onClick={() => remove(index)}>
+                <X className="text-text-sub-title-invert " />
+              </Button>
+            </div>
+          );
+        })}
+      </div>
+    </section>
   );
 }

View File

@@ -1,13 +1,14 @@
 import { SelectWithSearch } from '@/components/originui/select-with-search';
 import { RAGFlowFormItem } from '@/components/ragflow-form';
-import { BlockButton, Button } from '@/components/ui/button';
-import { FormLabel } from '@/components/ui/form';
+import { Button } from '@/components/ui/button';
 import { Input } from '@/components/ui/input';
 import { Separator } from '@/components/ui/separator';
+import { useBuildSwitchOperatorOptions } from '@/hooks/logic-hooks/use-build-operator-options';
 import { X } from 'lucide-react';
 import { ReactNode } from 'react';
 import { useFieldArray, useFormContext } from 'react-hook-form';
-import { useTranslation } from 'react-i18next';
+import { DataOperationsOperatorOptions } from '../../constant';
+import { DynamicFormHeader } from '../components/dynamic-fom-header';
 
 type SelectKeysProps = {
   name: string;
@@ -25,7 +26,6 @@ export function FilterValues({
   valueField = 'value',
   operatorField = 'operator',
 }: SelectKeysProps) {
-  const { t } = useTranslation();
   const form = useFormContext();
 
   const { fields, remove, append } = useFieldArray({
@@ -33,9 +33,18 @@
     control: form.control,
   });
 
+  const operatorOptions = useBuildSwitchOperatorOptions(
+    DataOperationsOperatorOptions,
+  );
+
   return (
     <section className="space-y-2">
-      <FormLabel tooltip={tooltip}>{label}</FormLabel>
+      <DynamicFormHeader
+        label={label}
+        tooltip={tooltip}
+        onClick={() => append({ [keyField]: '', [valueField]: '' })}
+      ></DynamicFormHeader>
       <div className="space-y-5">
         {fields.map((field, index) => {
           const keyFieldAlias = `${name}.${index}.${keyField}`;
@@ -47,12 +56,15 @@
               <RAGFlowFormItem name={keyFieldAlias} className="flex-1">
                 <Input></Input>
               </RAGFlowFormItem>
-              <Separator orientation="vertical" className="h-2.5" />
+              <Separator className="w-2" />
               <RAGFlowFormItem name={operatorFieldAlias} className="flex-1">
-                <SelectWithSearch {...field} options={[]}></SelectWithSearch>
+                <SelectWithSearch
+                  {...field}
+                  options={operatorOptions}
+                ></SelectWithSearch>
               </RAGFlowFormItem>
-              <Separator orientation="vertical" className="h-2.5" />
+              <Separator className="w-2" />
               <RAGFlowFormItem name={valueFieldAlias} className="flex-1">
                 <Input></Input>
@@ -64,10 +76,6 @@
           );
         })}
       </div>
-      <BlockButton onClick={() => append({ [keyField]: '', [valueField]: '' })}>
-        {t('common.add')}
-      </BlockButton>
     </section>
   );
 }

View File

@@ -1,6 +1,7 @@
 import { SelectWithSearch } from '@/components/originui/select-with-search';
 import { RAGFlowFormItem } from '@/components/ragflow-form';
-import { Form, FormLabel } from '@/components/ui/form';
+import { Form } from '@/components/ui/form';
+import { Separator } from '@/components/ui/separator';
 import { buildOptions } from '@/utils/form';
 import { zodResolver } from '@hookform/resolvers/zod';
 import { memo } from 'react';
@@ -17,14 +18,14 @@ import { useWatchFormChange } from '../../hooks/use-watch-form-change';
 import { INextOperatorForm } from '../../interface';
 import { buildOutputList } from '../../utils/build-output-list';
 import { FormWrapper } from '../components/form-wrapper';
-import { Output } from '../components/output';
+import { Output, OutputSchema } from '../components/output';
 import { QueryVariableList } from '../components/query-variable-list';
 import { FilterValues } from './filter-values';
 import { SelectKeys } from './select-keys';
 import { Updates } from './updates';
 
 export const RetrievalPartialSchema = {
-  inputs: z.array(z.object({ input: z.string().optional() })),
+  query: z.array(z.object({ input: z.string().optional() })),
   operations: z.string(),
   select_keys: z.array(z.object({ name: z.string().optional() })).optional(),
   remove_keys: z.array(z.object({ name: z.string().optional() })).optional(),
@@ -50,6 +51,7 @@
       }),
     )
     .optional(),
+  ...OutputSchema,
 };
 
 export const FormSchema = z.object(RetrievalPartialSchema);
@@ -65,6 +67,7 @@ function DataOperationsForm({ node }: INextOperatorForm) {
   const form = useForm<DataOperationsFormSchemaType>({
     defaultValues: defaultValues,
+    mode: 'onChange',
     resolver: zodResolver(FormSchema),
     shouldUnregister: true,
   });
@@ -78,17 +81,17 @@
     true,
   );
 
-  useWatchFormChange(node?.id, form);
+  useWatchFormChange(node?.id, form, true);
 
   return (
     <Form {...form}>
       <FormWrapper>
-        <div className="space-y-2">
-          <FormLabel tooltip={t('flow.queryTip')}>{t('flow.query')}</FormLabel>
-          <QueryVariableList
-            types={[JsonSchemaDataType.Array, JsonSchemaDataType.Object]}
-          ></QueryVariableList>
-        </div>
+        <QueryVariableList
+          tooltip={t('flow.queryTip')}
+          label={t('flow.query')}
+          types={[JsonSchemaDataType.Array, JsonSchemaDataType.Object]}
+        ></QueryVariableList>
+        <Separator />
         <RAGFlowFormItem name="operations" label={t('flow.operations')}>
           <SelectWithSearch options={OperationsOptions} allowClear />
         </RAGFlowFormItem>
@@ -107,7 +110,7 @@ function DataOperationsForm({ node }: INextOperatorForm) {
         {operations === Operations.AppendOrUpdate && (
           <Updates
             name="updates"
-            label={t('flow.operationsOptions.updates')}
+            label={t('flow.operationsOptions.appendOrUpdate')}
             keyField="key"
             valueField="value"
           ></Updates>
@@ -126,7 +129,7 @@ function DataOperationsForm({ node }: INextOperatorForm) {
             label={t('flow.operationsOptions.filterValues')}
           ></FilterValues>
         )}
-        <Output list={outputList}></Output>
+        <Output list={outputList} isFormRequired></Output>
       </FormWrapper>
     </Form>
   );
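
The form mounts exactly one operation-specific editor based on the watched `operations` value; combined with `shouldUnregister: true`, fields from unmounted sections drop out of the submitted values. A stripped-down sketch of that pattern (operation strings match the renamed enum above; wiring simplified):

```tsx
import { useFormContext, useWatch } from 'react-hook-form';

// Only the editor for the chosen operation is mounted; with
// shouldUnregister: true its fields vanish from values when it unmounts.
function OperationSections() {
  const { control } = useFormContext();
  const operations = useWatch({ control, name: 'operations' });

  return (
    <>
      {operations === 'select_keys' && <div>{/* <SelectKeys /> */}</div>}
      {operations === 'filter_values' && <div>{/* <FilterValues /> */}</div>}
      {operations === 'append_or_update' && <div>{/* <Updates /> */}</div>}
    </>
  );
}
```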

View File

@@ -1,11 +1,10 @@
 import { RAGFlowFormItem } from '@/components/ragflow-form';
-import { BlockButton, Button } from '@/components/ui/button';
-import { FormLabel } from '@/components/ui/form';
+import { Button } from '@/components/ui/button';
 import { Input } from '@/components/ui/input';
 import { X } from 'lucide-react';
 import { ReactNode } from 'react';
 import { useFieldArray, useFormContext } from 'react-hook-form';
-import { useTranslation } from 'react-i18next';
+import { DynamicFormHeader } from '../components/dynamic-fom-header';
 
 type SelectKeysProps = {
   name: string;
@@ -13,7 +12,6 @@
   tooltip?: string;
 };
 
 export function SelectKeys({ name, label, tooltip }: SelectKeysProps) {
-  const { t } = useTranslation();
   const form = useFormContext();
 
   const { fields, remove, append } = useFieldArray({
@@ -23,7 +21,11 @@
   return (
     <section className="space-y-2">
-      <FormLabel tooltip={tooltip}>{label}</FormLabel>
+      <DynamicFormHeader
+        label={label}
+        tooltip={tooltip}
+        onClick={() => append({ name: '' })}
+      ></DynamicFormHeader>
       <div className="space-y-5">
         {fields.map((field, index) => {
           const nameField = `${name}.${index}.name`;
@@ -40,10 +42,6 @@
           );
         })}
       </div>
-      <BlockButton onClick={() => append({ name: '' })}>
-        {t('common.add')}
-      </BlockButton>
     </section>
   );
 }

View File

@@ -1,11 +1,11 @@
 import { RAGFlowFormItem } from '@/components/ragflow-form';
-import { BlockButton, Button } from '@/components/ui/button';
-import { FormLabel } from '@/components/ui/form';
+import { Button } from '@/components/ui/button';
 import { Input } from '@/components/ui/input';
+import { Separator } from '@/components/ui/separator';
 import { X } from 'lucide-react';
 import { ReactNode } from 'react';
 import { useFieldArray, useFormContext } from 'react-hook-form';
-import { useTranslation } from 'react-i18next';
+import { DynamicFormHeader } from '../components/dynamic-fom-header';
 
 type SelectKeysProps = {
   name: string;
@@ -21,7 +21,6 @@ export function Updates({
   keyField,
   valueField,
 }: SelectKeysProps) {
-  const { t } = useTranslation();
   const form = useFormContext();
 
   const { fields, remove, append } = useFieldArray({
@@ -31,7 +30,11 @@
   return (
     <section className="space-y-2">
-      <FormLabel tooltip={tooltip}>{label}</FormLabel>
+      <DynamicFormHeader
+        label={label}
+        tooltip={tooltip}
+        onClick={() => append({ [keyField]: '', [valueField]: '' })}
+      ></DynamicFormHeader>
       <div className="space-y-5">
         {fields.map((field, index) => {
           const keyFieldAlias = `${name}.${index}.${keyField}`;
@@ -42,6 +45,7 @@
               <RAGFlowFormItem name={keyFieldAlias} className="flex-1">
                 <Input></Input>
               </RAGFlowFormItem>
+              <Separator className="w-2" />
               <RAGFlowFormItem name={valueFieldAlias} className="flex-1">
                 <Input></Input>
               </RAGFlowFormItem>
@@ -52,10 +56,6 @@
           );
         })}
       </div>
-      <BlockButton onClick={() => append({ [keyField]: '', [valueField]: '' })}>
-        {t('common.add')}
-      </BlockButton>
     </section>
   );
 }

View File

@@ -1,5 +1,4 @@
 import { FormContainer } from '@/components/form-container';
-import { IconFont } from '@/components/icon-font';
 import { SelectWithSearch } from '@/components/originui/select-with-search';
 import { BlockButton, Button } from '@/components/ui/button';
 import { Card, CardContent } from '@/components/ui/card';
@@ -13,6 +12,7 @@ import {
 import { RAGFlowSelect } from '@/components/ui/select';
 import { Separator } from '@/components/ui/separator';
 import { Textarea } from '@/components/ui/textarea';
+import { useBuildSwitchOperatorOptions } from '@/hooks/logic-hooks/use-build-operator-options';
 import { cn } from '@/lib/utils';
 import { zodResolver } from '@hookform/resolvers/zod';
 import { t } from 'i18next';
@@ -22,11 +22,7 @@ import { memo, useCallback, useMemo } from 'react';
 import { useFieldArray, useForm, useFormContext } from 'react-hook-form';
 import { useTranslation } from 'react-i18next';
 import { z } from 'zod';
-import {
-  SwitchLogicOperatorOptions,
-  SwitchOperatorOptions,
-  VariableType,
-} from '../../constant';
+import { SwitchLogicOperatorOptions, VariableType } from '../../constant';
 import { useBuildQueryVariableOptions } from '../../hooks/use-get-begin-query';
 import { IOperatorForm } from '../../interface';
 import { FormWrapper } from '../components/form-wrapper';
@@ -43,42 +39,6 @@ type ConditionCardsProps = {
   parentLength: number;
 } & IOperatorForm;
 
-export const LogicalOperatorIcon = function OperatorIcon({
-  icon,
-  value,
-}: Omit<(typeof SwitchOperatorOptions)[0], 'label'>) {
-  if (typeof icon === 'string') {
-    return (
-      <IconFont
-        name={icon}
-        className={cn('size-4', {
-          'rotate-180': value === '>',
-        })}
-      ></IconFont>
-    );
-  }
-  return icon;
-};
-
-export function useBuildSwitchOperatorOptions() {
-  const { t } = useTranslation();
-
-  const switchOperatorOptions = useMemo(() => {
-    return SwitchOperatorOptions.map((x) => ({
-      value: x.value,
-      icon: (
-        <LogicalOperatorIcon
-          icon={x.icon}
-          value={x.value}
-        ></LogicalOperatorIcon>
-      ),
-      label: t(`flow.switchOperatorOptions.${x.label}`),
-    }));
-  }, [t]);
-
-  return switchOperatorOptions;
-}
-
 function ConditionCards({
   name: parentName,
   parentIndex,
View File

@@ -2,9 +2,13 @@ import { useEffect } from 'react';
 import { UseFormReturn, useWatch } from 'react-hook-form';
 import useGraphStore from '../store';
 
-export function useWatchFormChange(id?: string, form?: UseFormReturn<any>) {
+export function useWatchFormChange(
+  id?: string,
+  form?: UseFormReturn<any>,
+  enableReplacement = false,
+) {
   let values = useWatch({ control: form?.control });
-  const updateNodeForm = useGraphStore((state) => state.updateNodeForm);
+  const { updateNodeForm, replaceNodeForm } = useGraphStore((state) => state);
 
   useEffect(() => {
     // Manually triggered form updates are synchronized to the canvas
@@ -12,7 +16,7 @@ export function useWatchFormChange(id?: string, form?: UseFormReturn<any>) {
       values = form?.getValues() || {};
 
       let nextValues: any = values;
-      updateNodeForm(id, nextValues);
+      (enableReplacement ? replaceNodeForm : updateNodeForm)(id, nextValues);
     }
   }, [form?.formState.isDirty, id, updateNodeForm, values]);
 }
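
The third argument decides between merging and replacing the node's form data. The distinction matters for `useFieldArray`-backed fields: an index-wise deep merge cannot shrink an array, so deleting a row in the form would leave stale items on the canvas node. Illustrative only (the merge semantics of `updateNodeForm` are assumed from its lodash-based implementation):

```ts
import { merge } from 'lodash';

const previous = { query: [{ input: 'a' }, { input: 'b' }, { input: 'c' }] };
const next = { query: [{ input: 'a' }] };

// An index-wise deep merge keeps the tail the form just deleted:
const merged = merge({}, previous, next);
console.log(merged.query.length); // 3: the deleted rows survive

// replaceNodeForm(nodeId, next) stores `next` wholesale instead.
```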

View File

@@ -14,7 +14,7 @@ import {
   applyEdgeChanges,
   applyNodeChanges,
 } from '@xyflow/react';
-import { omit } from 'lodash';
+import { cloneDeep, omit } from 'lodash';
 import differenceWith from 'lodash/differenceWith';
 import intersectionWith from 'lodash/intersectionWith';
 import lodashSet from 'lodash/set';
@@ -53,6 +53,7 @@
     values: any,
     path?: (string | number)[],
   ) => RAGFlowNodeType[];
+  replaceNodeForm: (nodeId: string, values: any) => void;
   onSelectionChange: OnSelectionChangeFunc;
   addNode: (nodes: RAGFlowNodeType) => void;
   getNode: (id?: string | null) => RAGFlowNodeType | undefined;
@@ -433,6 +434,19 @@
       return nextNodes;
     },
+    replaceNodeForm(nodeId, values) {
+      if (nodeId) {
+        set((state) => {
+          for (const node of state.nodes) {
+            if (node.id === nodeId) {
+              // cloneDeep avoids a react-hook-form error:
+              // TypeError: Cannot assign to read only property '0' of object '[object Array]'
+              node.data.form = cloneDeep(values);
+              break;
+            }
+          }
+        });
+      }
+    },
     updateSwitchFormData: (source, sourceHandle, target, isConnecting) => {
       const { updateNodeForm, edges } = get();
       if (sourceHandle) {
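
The `cloneDeep` guards against frozen values: objects coming out of `useWatch` can contain read-only arrays, and writing such a reference into the store later throws the `TypeError` quoted in the comment. A minimal reproduction of the failure mode (illustrative; not code from the repo):

```ts
import { cloneDeep } from 'lodash';

const values = { query: Object.freeze(['a', 'b']) as string[] };

// values.query[0] = 'c'; // TypeError: Cannot assign to read only property '0'

const safe = cloneDeep(values); // plain, mutable copies all the way down
safe.query[0] = 'c'; // fine
```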

View File

@@ -273,7 +273,7 @@ function transformDataOperationsParams(params: DataOperationsFormSchemaType) {
     ...params,
     select_keys: params?.select_keys?.map((x) => x.name),
     remove_keys: params?.remove_keys?.map((x) => x.name),
-    inputs: params.inputs.map((x) => x.input),
+    query: params.query.map((x) => x.input),
   };
 }
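
With the rename, the flat payload sent to the backend now carries `query` instead of `inputs`. A self-contained sketch of the same mapping (types reduced to the fields shown here; values hypothetical):

```ts
type DataOpsFormValues = {
  query: { input: string }[];
  select_keys?: { name?: string }[];
  remove_keys?: { name?: string }[];
};

// Mirrors transformDataOperationsParams: unwrap the useFieldArray row
// objects into the flat arrays the backend expects.
function toPayload(params: DataOpsFormValues) {
  return {
    ...params,
    select_keys: params.select_keys?.map((x) => x.name),
    remove_keys: params.remove_keys?.map((x) => x.name),
    query: params.query.map((x) => x.input),
  };
}

console.log(
  toPayload({ query: [{ input: 'begin@file' }], select_keys: [{ name: 'title' }] }),
);
// => { query: ['begin@file'], select_keys: ['title'], remove_keys: undefined }
```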

View File

@@ -1,6 +1,6 @@
+import { SwitchOperatorOptions } from '@/constants/agent';
+import { useBuildSwitchOperatorOptions } from '@/hooks/logic-hooks/use-build-operator-options';
 import { useFetchKnowledgeMetadata } from '@/hooks/use-knowledge-request';
-import { SwitchOperatorOptions } from '@/pages/agent/constant';
-import { useBuildSwitchOperatorOptions } from '@/pages/agent/form/switch-form';
 import { MinusCircleOutlined, PlusOutlined } from '@ant-design/icons';
 import {
   Button,
View File

@@ -15,9 +15,9 @@ import {
 } from '@/components/ui/form';
 import { Input } from '@/components/ui/input';
 import { Separator } from '@/components/ui/separator';
+import { SwitchOperatorOptions } from '@/constants/agent';
+import { useBuildSwitchOperatorOptions } from '@/hooks/logic-hooks/use-build-operator-options';
 import { useFetchKnowledgeMetadata } from '@/hooks/use-knowledge-request';
-import { SwitchOperatorOptions } from '@/pages/agent/constant';
-import { useBuildSwitchOperatorOptions } from '@/pages/agent/form/switch-form';
 import { Plus, X } from 'lucide-react';
 import { useCallback } from 'react';
 import { useFieldArray, useFormContext } from 'react-hook-form';