Compare commits


12 Commits

Author SHA1 Message Date
4e76220e25 Feat: Submit clean data operations form data to the backend. #10427 (#11030)
### What problem does this PR solve?

Feat: Submit clean data operations form data to the backend. #10427

### Type of change


- [x] New Feature (non-breaking change which adds functionality)
2025-11-05 17:32:35 +08:00
24335485bf Fix: get_allowed_llm_factories() return type (#11031)
### What problem does this PR solve?

Fix: get_allowed_llm_factories() return type #11003

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

2025-11-05 17:32:12 +08:00
121c51661d Fix: Markdown table extractor (#11018)
### What problem does this PR solve?

The Markdown table extractor now supports `<table ...>` tags with attributes. #10966

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-11-05 16:10:21 +08:00
02d10f8eda Move var from rag.settings to common.globals (#11022)
### What problem does this PR solve?

As title.

### Type of change

- [x] Refactoring

---------

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
2025-11-05 15:48:50 +08:00
dddf766470 Feat: start data sync service. (#11026)
### What problem does this PR solve?

#10953 

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-11-05 15:43:15 +08:00
8584d4b642 Fix: numeric string missing transformation. (#11025)
### What problem does this PR solve?

#11024

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-11-05 15:14:30 +08:00
b86e07088b Fix: multi-step escape issues. (#11016)
### What problem does this PR solve?


### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-11-05 14:51:00 +08:00
1a9215bc6f Move some vars to globals (#11017)
### What problem does this PR solve?

As title.

### Type of change

- [x] Refactoring

---------

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
2025-11-05 14:14:38 +08:00
cf9611c96f Feat: Support more chunking methods (#11000)
### What problem does this PR solve?

Feat: Support more chunking methods #10772 

This PR enables multiple chunking methods — including books, laws,
naive, one, and presentation — to be used with all existing PDF parsers
(DeepDOC, MinerU, Docling, TCADP, Plain Text, and Vision modes).

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-11-05 13:00:42 +08:00
f126875ec6 Apply some tweaks to the Admin UI (#11011)
### What problem does this PR solve?

- Fix selected radio button text misaligned with radio button dot
- Fix `<ScrollArea>` scrollbar z-index issue
- Add backdrop blur effect on scrollbar thumbs
- Adjust some styles to match the design 


### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-11-05 12:58:43 +08:00
89410d2381 Fix: /factories API wrong return (#11015)
### What problem does this PR solve?

Fix the wrong return value of the `/factories` API.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-11-05 12:50:11 +08:00
96c015fb85 Fix and refactor imports (#11010)
### What problem does this PR solve?

1. Move EMBEDDING_CFG to common.globals
2. Fix error imports
3. Move signal handles to common/signal_utils.py

### Type of change

- [x] Refactoring

---------

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
2025-11-05 11:07:54 +08:00
90 changed files with 1041 additions and 716 deletions

View File

@ -281,12 +281,21 @@ class Canvas(Graph):
def _run_batch(f, t):
with ThreadPoolExecutor(max_workers=5) as executor:
thr = []
for i in range(f, t):
i = f
while i < t:
cpn = self.get_component_obj(self.path[i])
if cpn.component_name.lower() in ["begin", "userfillup"]:
thr.append(executor.submit(cpn.invoke, inputs=kwargs.get("inputs", {})))
i += 1
else:
thr.append(executor.submit(cpn.invoke, **cpn.get_input()))
for _, ele in cpn.get_input_elements().items():
if isinstance(ele, dict) and ele.get("_cpn_id") and ele.get("_cpn_id") not in self.path[:i]:
self.path.pop(i)
t -= 1
break
else:
thr.append(executor.submit(cpn.invoke, **cpn.get_input()))
i += 1
for t in thr:
t.result()
@ -316,6 +325,7 @@ class Canvas(Graph):
"thoughts": self.get_component_thoughts(self.path[i])
})
_run_batch(idx, to)
to = len(self.path)
# post processing of components invocation
for i in range(idx, to):
cpn = self.get_component(self.path[i])
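The core of this change: `_run_batch` switches from `for i in range(f, t)` to an index-based `while` loop so that path entries whose upstream input component never ran can be dropped from `self.path` mid-iteration (pop the entry, shrink the bound, do not advance the index); the second hunk then re-reads `len(self.path)` because the batch may have shortened the path. A minimal standalone sketch of that removal pattern, with hypothetical (name, dependency) pairs standing in for canvas components:

```python
# Sketch: iterate a list by index while removing entries whose
# dependency has not run yet. Popping shrinks the list, so the
# upper bound is decremented and the index is NOT advanced.
def prune_and_run(path, executed):
    i, t = 0, len(path)
    while i < t:
        name, dependency = path[i]
        if dependency is not None and dependency not in executed:
            path.pop(i)     # drop the entry; the next item slides into slot i
            t -= 1          # the list is one shorter now
            continue
        executed.add(name)  # "run" the component
        i += 1
    return path

# ("b" depends on "x", which never ran, so it is pruned.)
print(prune_and_run([("a", None), ("b", "x"), ("c", "a")], set()))
# [('a', None), ('c', 'a')]
```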

View File

@ -16,6 +16,13 @@
from abc import ABC
from agent.component.base import ComponentBase, ComponentParamBase
"""
class VariableModel(BaseModel):
data_type: Annotated[Literal["string", "number", "Object", "Boolean", "Array<string>", "Array<number>", "Array<object>", "Array<boolean>"], Field(default="Array<string>")]
input_mode: Annotated[Literal["constant", "variable"], Field(default="constant")]
value: Annotated[Any, Field(default=None)]
model_config = ConfigDict(extra="forbid")
"""
class IterationParam(ComponentParamBase):
"""

View File

@ -216,7 +216,7 @@ class LLM(ComponentBase):
error: str = ""
output_structure=None
try:
output_structure=self._param.outputs['structured']
output_structure = None#self._param.outputs['structured']
except Exception:
pass
if output_structure:

View File

@ -49,6 +49,9 @@ class MessageParam(ComponentParamBase):
class Message(ComponentBase):
component_name = "Message"
def get_input_elements(self) -> dict[str, Any]:
return self.get_input_elements_from_text("".join(self._param.content))
def get_kwargs(self, script:str, kwargs:dict = {}, delimiter:str=None) -> tuple[str, dict[str, str | list | Any]]:
for k,v in self.get_input_elements_from_text(script).items():
if k in kwargs:

View File

@ -16,6 +16,8 @@
import os
import re
from abc import ABC
from typing import Any
from jinja2 import Template as Jinja2Template
from agent.component.base import ComponentParamBase
from common.connection_utils import timeout
@ -43,6 +45,9 @@ class StringTransformParam(ComponentParamBase):
class StringTransform(Message, ABC):
component_name = "StringTransform"
def get_input_elements(self) -> dict[str, Any]:
return self.get_input_elements_from_text(self._param.script)
def get_input_form(self) -> dict[str, dict]:
if self._param.method == "split":
return {
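Both `Message` and `StringTransform` now override `get_input_elements` to derive their inputs by scanning their own content or script via `get_input_elements_from_text`. A rough, hypothetical illustration of that kind of scan, assuming a `{component_id@field}` placeholder syntax (the actual syntax and element shape are defined in `ComponentBase` and may differ):

```python
import re
from typing import Any

# Hypothetical stand-in for ComponentBase.get_input_elements_from_text:
# find {component_id@field} placeholders and key them by reference.
def get_input_elements_from_text(text: str) -> dict[str, Any]:
    elements: dict[str, Any] = {}
    for match in re.finditer(r"\{([A-Za-z_][\w-]*@[\w-]+)\}", text):
        ref = match.group(1)
        cpn_id, field = ref.split("@", 1)
        elements[ref] = {"_cpn_id": cpn_id, "field": field}
    return elements

print(get_input_elements_from_text("Answer: {LLM_0@content} (source: {Retrieval_0@chunks})"))
```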

View File

@ -25,6 +25,7 @@ from api.db.services.dialog_service import meta_filter
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
from api import settings
from common import globals
from common.connection_utils import timeout
from rag.app.tag import label_question
from rag.prompts.generator import cross_languages, kb_prompt, gen_meta_filter
@ -170,7 +171,7 @@ class Retrieval(ToolBase, ABC):
if kbs:
query = re.sub(r"^user[:\s]*", "", query, flags=re.IGNORECASE)
kbinfos = settings.retriever.retrieval(
kbinfos = globals.retriever.retrieval(
query,
embd_mdl,
[kb.tenant_id for kb in kbs],
@ -186,7 +187,7 @@ class Retrieval(ToolBase, ABC):
)
if self._param.toc_enhance:
chat_mdl = LLMBundle(self._canvas._tenant_id, LLMType.CHAT)
cks = settings.retriever.retrieval_by_toc(query, kbinfos["chunks"], [kb.tenant_id for kb in kbs], chat_mdl, self._param.top_n)
cks = globals.retriever.retrieval_by_toc(query, kbinfos["chunks"], [kb.tenant_id for kb in kbs], chat_mdl, self._param.top_n)
if cks:
kbinfos["chunks"] = cks
if self._param.use_kg:

View File

@ -32,7 +32,6 @@ from api.db.services.file_service import FileService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.task_service import queue_tasks, TaskService
from api.db.services.user_service import UserTenantService
from api import settings
from common.misc_utils import get_uuid
from common.constants import RetCode, VALID_TASK_STATUS, LLMType, ParserType, FileSource
from api.utils.api_utils import server_error_response, get_data_error_result, get_json_result, validate_request, \
@ -48,6 +47,7 @@ from api.db.services.canvas_service import UserCanvasService
from agent.canvas import Canvas
from functools import partial
from pathlib import Path
from common import globals
@manager.route('/new_token', methods=['POST']) # noqa: F821
@ -538,7 +538,7 @@ def list_chunks():
)
kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
res = settings.retriever.chunk_list(doc_id, tenant_id, kb_ids)
res = globals.retriever.chunk_list(doc_id, tenant_id, kb_ids)
res = [
{
"content": res_item["content_with_weight"],
@ -564,7 +564,7 @@ def get_chunk(chunk_id):
try:
tenant_id = objs[0].tenant_id
kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
chunk = settings.docStoreConn.get(chunk_id, search.index_name(tenant_id), kb_ids)
chunk = globals.docStoreConn.get(chunk_id, search.index_name(tenant_id), kb_ids)
if chunk is None:
return server_error_response(Exception("Chunk not found"))
k = []
@ -886,7 +886,7 @@ def retrieval():
if req.get("keyword", False):
chat_mdl = LLMBundle(kbs[0].tenant_id, LLMType.CHAT)
question += keyword_extraction(chat_mdl, question)
ranks = settings.retriever.retrieval(question, embd_mdl, kbs[0].tenant_id, kb_ids, page, size,
ranks = globals.retriever.retrieval(question, embd_mdl, kbs[0].tenant_id, kb_ids, page, size,
similarity_threshold, vector_similarity_weight, top,
doc_ids, rerank_mdl=rerank_mdl, highlight= highlight,
rank_feature=label_question(question, kbs))

View File

@ -25,7 +25,6 @@ from flask import request, Response
from flask_login import login_required, current_user
from agent.component import LLM
from api import settings
from api.db import CanvasCategory, FileType
from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService, API4ConversationService
from api.db.services.document_service import DocumentService
@ -46,6 +45,7 @@ from api.utils.file_utils import filename_type, read_potential_broken_pdf
from rag.flow.pipeline import Pipeline
from rag.nlp import search
from rag.utils.redis_conn import REDIS_CONN
from common import globals
@manager.route('/templates', methods=['GET']) # noqa: F821
@ -192,8 +192,8 @@ def rerun():
if 0 < doc["progress"] < 1:
return get_data_error_result(message=f"`{doc['name']}` is processing...")
if settings.docStoreConn.indexExist(search.index_name(current_user.id), doc["kb_id"]):
settings.docStoreConn.delete({"doc_id": doc["id"]}, search.index_name(current_user.id), doc["kb_id"])
if globals.docStoreConn.indexExist(search.index_name(current_user.id), doc["kb_id"]):
globals.docStoreConn.delete({"doc_id": doc["id"]}, search.index_name(current_user.id), doc["kb_id"])
doc["progress_msg"] = ""
doc["chunk_num"] = 0
doc["token_num"] = 0

View File

@ -36,6 +36,7 @@ from rag.prompts.generator import gen_meta_filter, cross_languages, keyword_extr
from rag.settings import PAGERANK_FLD
from common.string_utils import remove_redundant_spaces
from common.constants import RetCode, LLMType, ParserType
from common import globals
@manager.route('/list', methods=['POST']) # noqa: F821
@ -60,7 +61,7 @@ def list_chunk():
}
if "available_int" in req:
query["available_int"] = int(req["available_int"])
sres = settings.retriever.search(query, search.index_name(tenant_id), kb_ids, highlight=["content_ltks"])
sres = globals.retriever.search(query, search.index_name(tenant_id), kb_ids, highlight=["content_ltks"])
res = {"total": sres.total, "chunks": [], "doc": doc.to_dict()}
for id in sres.ids:
d = {
@ -98,7 +99,7 @@ def get():
return get_data_error_result(message="Tenant not found!")
for tenant in tenants:
kb_ids = KnowledgebaseService.get_kb_ids(tenant.tenant_id)
chunk = settings.docStoreConn.get(chunk_id, search.index_name(tenant.tenant_id), kb_ids)
chunk = globals.docStoreConn.get(chunk_id, search.index_name(tenant.tenant_id), kb_ids)
if chunk:
break
if chunk is None:
@ -170,7 +171,7 @@ def set():
v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not d.get("question_kwd") else "\n".join(d["question_kwd"])])
v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
d["q_%d_vec" % len(v)] = v.tolist()
settings.docStoreConn.update({"id": req["chunk_id"]}, d, search.index_name(tenant_id), doc.kb_id)
globals.docStoreConn.update({"id": req["chunk_id"]}, d, search.index_name(tenant_id), doc.kb_id)
return get_json_result(data=True)
except Exception as e:
return server_error_response(e)
@ -186,7 +187,7 @@ def switch():
if not e:
return get_data_error_result(message="Document not found!")
for cid in req["chunk_ids"]:
if not settings.docStoreConn.update({"id": cid},
if not globals.docStoreConn.update({"id": cid},
{"available_int": int(req["available_int"])},
search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
doc.kb_id):
@ -206,7 +207,7 @@ def rm():
e, doc = DocumentService.get_by_id(req["doc_id"])
if not e:
return get_data_error_result(message="Document not found!")
if not settings.docStoreConn.delete({"id": req["chunk_ids"]},
if not globals.docStoreConn.delete({"id": req["chunk_ids"]},
search.index_name(DocumentService.get_tenant_id(req["doc_id"])),
doc.kb_id):
return get_data_error_result(message="Chunk deleting failure")
@ -270,7 +271,7 @@ def create():
v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not d["question_kwd"] else "\n".join(d["question_kwd"])])
v = 0.1 * v[0] + 0.9 * v[1]
d["q_%d_vec" % len(v)] = v.tolist()
settings.docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
globals.docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
DocumentService.increment_chunk_num(
doc.id, doc.kb_id, c, 1, 0)
@ -346,7 +347,7 @@ def retrieval_test():
question += keyword_extraction(chat_mdl, question)
labels = label_question(question, [kb])
ranks = settings.retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids, page, size,
ranks = globals.retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids, page, size,
float(req.get("similarity_threshold", 0.0)),
float(req.get("vector_similarity_weight", 0.3)),
top,
@ -385,7 +386,7 @@ def knowledge_graph():
"doc_ids": [doc_id],
"knowledge_graph_kwd": ["graph", "mind_map"]
}
sres = settings.retriever.search(req, search.index_name(tenant_id), kb_ids)
sres = globals.retriever.search(req, search.index_name(tenant_id), kb_ids)
obj = {"graph": {}, "mind_map": {}}
for id in sres.ids[:2]:
ty = sres.field[id]["knowledge_graph_kwd"]

View File

@ -23,7 +23,6 @@ import flask
from flask import request
from flask_login import current_user, login_required
from api import settings
from api.common.check_team_permission import check_kb_team_permission
from api.constants import FILE_NAME_LEN_LIMIT, IMG_BASE64_PREFIX
from api.db import VALID_FILE_TYPES, FileType
@ -49,6 +48,7 @@ from api.utils.web_utils import CONTENT_TYPE_MAP, html2pdf, is_valid_url
from deepdoc.parser.html_parser import RAGFlowHtmlParser
from rag.nlp import search, rag_tokenizer
from rag.utils.storage_factory import STORAGE_IMPL
from common import globals
@manager.route("/upload", methods=["POST"]) # noqa: F821
@ -367,7 +367,7 @@ def change_status():
continue
status_int = int(status)
if not settings.docStoreConn.update({"doc_id": doc_id}, {"available_int": status_int}, search.index_name(kb.tenant_id), doc.kb_id):
if not globals.docStoreConn.update({"doc_id": doc_id}, {"available_int": status_int}, search.index_name(kb.tenant_id), doc.kb_id):
result[doc_id] = {"error": "Database error (docStore update)!"}
result[doc_id] = {"status": status}
except Exception as e:
@ -432,8 +432,8 @@ def run():
DocumentService.update_by_id(id, info)
if req.get("delete", False):
TaskService.filter_delete([Task.doc_id == id])
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id)
if globals.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
globals.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id)
if str(req["run"]) == TaskStatus.RUNNING.value:
doc = doc.to_dict()
@ -479,8 +479,8 @@ def rename():
"title_tks": title_tks,
"title_sm_tks": rag_tokenizer.fine_grained_tokenize(title_tks),
}
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
settings.docStoreConn.update(
if globals.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
globals.docStoreConn.update(
{"doc_id": req["doc_id"]},
es_body,
search.index_name(tenant_id),
@ -541,8 +541,8 @@ def change_parser():
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
if not tenant_id:
return get_data_error_result(message="Tenant not found!")
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
if globals.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
globals.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
try:
if "pipeline_id" in req and req["pipeline_id"] != "":

View File

@ -35,7 +35,6 @@ from api.db import VALID_FILE_TYPES
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.db_models import File
from api.utils.api_utils import get_json_result
from api import settings
from rag.nlp import search
from api.constants import DATASET_NAME_LIMIT
from rag.settings import PAGERANK_FLD
@ -43,7 +42,7 @@ from rag.utils.redis_conn import REDIS_CONN
from rag.utils.storage_factory import STORAGE_IMPL
from rag.utils.doc_store_conn import OrderByExpr
from common.constants import RetCode, PipelineTaskType, StatusEnum, VALID_TASK_STATUS, FileSource, LLMType
from common import globals
@manager.route('/create', methods=['post']) # noqa: F821
@login_required
@ -110,11 +109,11 @@ def update():
if kb.pagerank != req.get("pagerank", 0):
if req.get("pagerank", 0) > 0:
settings.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]},
globals.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]},
search.index_name(kb.tenant_id), kb.id)
else:
# Elasticsearch requires PAGERANK_FLD be non-zero!
settings.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD},
globals.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD},
search.index_name(kb.tenant_id), kb.id)
e, kb = KnowledgebaseService.get_by_id(kb.id)
@ -226,8 +225,8 @@ def rm():
return get_data_error_result(
message="Database error (Knowledgebase removal)!")
for kb in kbs:
settings.docStoreConn.delete({"kb_id": kb.id}, search.index_name(kb.tenant_id), kb.id)
settings.docStoreConn.deleteIdx(search.index_name(kb.tenant_id), kb.id)
globals.docStoreConn.delete({"kb_id": kb.id}, search.index_name(kb.tenant_id), kb.id)
globals.docStoreConn.deleteIdx(search.index_name(kb.tenant_id), kb.id)
if hasattr(STORAGE_IMPL, 'remove_bucket'):
STORAGE_IMPL.remove_bucket(kb.id)
return get_json_result(data=True)
@ -248,7 +247,7 @@ def list_tags(kb_id):
tenants = UserTenantService.get_tenants_by_user_id(current_user.id)
tags = []
for tenant in tenants:
tags += settings.retriever.all_tags(tenant["tenant_id"], [kb_id])
tags += globals.retriever.all_tags(tenant["tenant_id"], [kb_id])
return get_json_result(data=tags)
@ -267,7 +266,7 @@ def list_tags_from_kbs():
tenants = UserTenantService.get_tenants_by_user_id(current_user.id)
tags = []
for tenant in tenants:
tags += settings.retriever.all_tags(tenant["tenant_id"], kb_ids)
tags += globals.retriever.all_tags(tenant["tenant_id"], kb_ids)
return get_json_result(data=tags)
@ -284,7 +283,7 @@ def rm_tags(kb_id):
e, kb = KnowledgebaseService.get_by_id(kb_id)
for t in req["tags"]:
settings.docStoreConn.update({"tag_kwd": t, "kb_id": [kb_id]},
globals.docStoreConn.update({"tag_kwd": t, "kb_id": [kb_id]},
{"remove": {"tag_kwd": t}},
search.index_name(kb.tenant_id),
kb_id)
@ -303,7 +302,7 @@ def rename_tags(kb_id):
)
e, kb = KnowledgebaseService.get_by_id(kb_id)
settings.docStoreConn.update({"tag_kwd": req["from_tag"], "kb_id": [kb_id]},
globals.docStoreConn.update({"tag_kwd": req["from_tag"], "kb_id": [kb_id]},
{"remove": {"tag_kwd": req["from_tag"].strip()}, "add": {"tag_kwd": req["to_tag"]}},
search.index_name(kb.tenant_id),
kb_id)
@ -326,9 +325,9 @@ def knowledge_graph(kb_id):
}
obj = {"graph": {}, "mind_map": {}}
if not settings.docStoreConn.indexExist(search.index_name(kb.tenant_id), kb_id):
if not globals.docStoreConn.indexExist(search.index_name(kb.tenant_id), kb_id):
return get_json_result(data=obj)
sres = settings.retriever.search(req, search.index_name(kb.tenant_id), [kb_id])
sres = globals.retriever.search(req, search.index_name(kb.tenant_id), [kb_id])
if not len(sres.ids):
return get_json_result(data=obj)
@ -360,7 +359,7 @@ def delete_knowledge_graph(kb_id):
code=RetCode.AUTHENTICATION_ERROR
)
_, kb = KnowledgebaseService.get_by_id(kb_id)
settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]}, search.index_name(kb.tenant_id), kb_id)
globals.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]}, search.index_name(kb.tenant_id), kb_id)
return get_json_result(data=True)
@ -732,13 +731,13 @@ def delete_kb_task():
task_id = kb.graphrag_task_id
kb_task_finish_at = "graphrag_task_finish_at"
cancel_task(task_id)
settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]}, search.index_name(kb.tenant_id), kb_id)
globals.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]}, search.index_name(kb.tenant_id), kb_id)
case PipelineTaskType.RAPTOR:
kb_task_id_field = "raptor_task_id"
task_id = kb.raptor_task_id
kb_task_finish_at = "raptor_task_finish_at"
cancel_task(task_id)
settings.docStoreConn.delete({"raptor_kwd": ["raptor"]}, search.index_name(kb.tenant_id), kb_id)
globals.docStoreConn.delete({"raptor_kwd": ["raptor"]}, search.index_name(kb.tenant_id), kb_id)
case PipelineTaskType.MINDMAP:
kb_task_id_field = "mindmap_task_id"
task_id = kb.mindmap_task_id
@ -850,7 +849,7 @@ def check_embedding():
tenant_id = kb.tenant_id
emb_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_id)
samples = sample_random_chunks_with_vectors(settings.docStoreConn, tenant_id=tenant_id, kb_id=kb_id, n=n)
samples = sample_random_chunks_with_vectors(globals.docStoreConn, tenant_id=tenant_id, kb_id=kb_id, n=n)
results, eff_sims = [], []
for ck in samples:

View File

@ -24,7 +24,7 @@ from api.utils.api_utils import server_error_response, get_data_error_result, va
from common.constants import StatusEnum, LLMType
from api.db.db_models import TenantLLM
from api.utils.api_utils import get_json_result, get_allowed_llm_factories
from common.base64_image import test_image
from rag.utils.base64_image import test_image
from rag.llm import EmbeddingModel, ChatModel, RerankModel, CvModel, TTSModel

View File

@ -20,7 +20,6 @@ import os
import json
from flask import request
from peewee import OperationalError
from api import settings
from api.db.db_models import File
from api.db.services.document_service import DocumentService
from api.db.services.file2document_service import File2DocumentService
@ -49,6 +48,7 @@ from api.utils.validation_utils import (
)
from rag.nlp import search
from rag.settings import PAGERANK_FLD
from common import globals
@manager.route("/datasets", methods=["POST"]) # noqa: F821
@ -360,11 +360,11 @@ def update(tenant_id, dataset_id):
return get_error_argument_result(message="'pagerank' can only be set when doc_engine is elasticsearch")
if req["pagerank"] > 0:
settings.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]},
globals.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]},
search.index_name(kb.tenant_id), kb.id)
else:
# Elasticsearch requires PAGERANK_FLD be non-zero!
settings.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD},
globals.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD},
search.index_name(kb.tenant_id), kb.id)
if not KnowledgebaseService.update_by_id(kb.id, req):
@ -493,9 +493,9 @@ def knowledge_graph(tenant_id, dataset_id):
}
obj = {"graph": {}, "mind_map": {}}
if not settings.docStoreConn.indexExist(search.index_name(kb.tenant_id), dataset_id):
if not globals.docStoreConn.indexExist(search.index_name(kb.tenant_id), dataset_id):
return get_result(data=obj)
sres = settings.retriever.search(req, search.index_name(kb.tenant_id), [dataset_id])
sres = globals.retriever.search(req, search.index_name(kb.tenant_id), [dataset_id])
if not len(sres.ids):
return get_result(data=obj)
@ -528,7 +528,7 @@ def delete_knowledge_graph(tenant_id, dataset_id):
code=RetCode.AUTHENTICATION_ERROR
)
_, kb = KnowledgebaseService.get_by_id(dataset_id)
settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]},
globals.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]},
search.index_name(kb.tenant_id), dataset_id)
return get_result(data=True)

View File

@ -25,6 +25,7 @@ from api.utils.api_utils import validate_request, build_error_result, apikey_req
from rag.app.tag import label_question
from api.db.services.dialog_service import meta_filter, convert_conditions
from common.constants import RetCode, LLMType
from common import globals
@manager.route('/dify/retrieval', methods=['POST']) # noqa: F821
@apikey_required
@ -137,7 +138,7 @@ def retrieval(tenant_id):
# print("doc_ids", doc_ids)
if not doc_ids and metadata_condition is not None:
doc_ids = ['-999']
ranks = settings.retriever.retrieval(
ranks = globals.retriever.retrieval(
question,
embd_mdl,
kb.tenant_id,

View File

@ -44,6 +44,7 @@ from rag.prompts.generator import cross_languages, keyword_extraction
from rag.utils.storage_factory import STORAGE_IMPL
from common.string_utils import remove_redundant_spaces
from common.constants import RetCode, LLMType, ParserType, TaskStatus, FileSource
from common import globals
MAXIMUM_OF_UPLOADING_FILES = 256
@ -307,7 +308,7 @@ def update_doc(tenant_id, dataset_id, document_id):
)
if not e:
return get_error_data_result(message="Document not found!")
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), dataset_id)
globals.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), dataset_id)
if "enabled" in req:
status = int(req["enabled"])
@ -316,7 +317,7 @@ def update_doc(tenant_id, dataset_id, document_id):
if not DocumentService.update_by_id(doc.id, {"status": str(status)}):
return get_error_data_result(message="Database error (Document update)!")
settings.docStoreConn.update({"doc_id": doc.id}, {"available_int": status}, search.index_name(kb.tenant_id), doc.kb_id)
globals.docStoreConn.update({"doc_id": doc.id}, {"available_int": status}, search.index_name(kb.tenant_id), doc.kb_id)
return get_result(data=True)
except Exception as e:
return server_error_response(e)
@ -755,7 +756,7 @@ def parse(tenant_id, dataset_id):
return get_error_data_result("Can't parse document that is currently being processed")
info = {"run": "1", "progress": 0, "progress_msg": "", "chunk_num": 0, "token_num": 0}
DocumentService.update_by_id(id, info)
settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), dataset_id)
globals.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), dataset_id)
TaskService.filter_delete([Task.doc_id == id])
e, doc = DocumentService.get_by_id(id)
doc = doc.to_dict()
@ -835,7 +836,7 @@ def stop_parsing(tenant_id, dataset_id):
return get_error_data_result("Can't stop parsing document with progress at 0 or 1")
info = {"run": "2", "progress": 0, "chunk_num": 0}
DocumentService.update_by_id(id, info)
settings.docStoreConn.delete({"doc_id": doc[0].id}, search.index_name(tenant_id), dataset_id)
globals.docStoreConn.delete({"doc_id": doc[0].id}, search.index_name(tenant_id), dataset_id)
success_count += 1
if duplicate_messages:
if success_count > 0:
@ -968,7 +969,7 @@ def list_chunks(tenant_id, dataset_id, document_id):
res = {"total": 0, "chunks": [], "doc": renamed_doc}
if req.get("id"):
chunk = settings.docStoreConn.get(req.get("id"), search.index_name(tenant_id), [dataset_id])
chunk = globals.docStoreConn.get(req.get("id"), search.index_name(tenant_id), [dataset_id])
if not chunk:
return get_result(message=f"Chunk not found: {dataset_id}/{req.get('id')}", code=RetCode.NOT_FOUND)
k = []
@ -995,8 +996,8 @@ def list_chunks(tenant_id, dataset_id, document_id):
res["chunks"].append(final_chunk)
_ = Chunk(**final_chunk)
elif settings.docStoreConn.indexExist(search.index_name(tenant_id), dataset_id):
sres = settings.retriever.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None, highlight=True)
elif globals.docStoreConn.indexExist(search.index_name(tenant_id), dataset_id):
sres = globals.retriever.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None, highlight=True)
res["total"] = sres.total
for id in sres.ids:
d = {
@ -1120,7 +1121,7 @@ def add_chunk(tenant_id, dataset_id, document_id):
v, c = embd_mdl.encode([doc.name, req["content"] if not d["question_kwd"] else "\n".join(d["question_kwd"])])
v = 0.1 * v[0] + 0.9 * v[1]
d["q_%d_vec" % len(v)] = v.tolist()
settings.docStoreConn.insert([d], search.index_name(tenant_id), dataset_id)
globals.docStoreConn.insert([d], search.index_name(tenant_id), dataset_id)
DocumentService.increment_chunk_num(doc.id, doc.kb_id, c, 1, 0)
# rename keys
@ -1201,7 +1202,7 @@ def rm_chunk(tenant_id, dataset_id, document_id):
if "chunk_ids" in req:
unique_chunk_ids, duplicate_messages = check_duplicate_ids(req["chunk_ids"], "chunk")
condition["id"] = unique_chunk_ids
chunk_number = settings.docStoreConn.delete(condition, search.index_name(tenant_id), dataset_id)
chunk_number = globals.docStoreConn.delete(condition, search.index_name(tenant_id), dataset_id)
if chunk_number != 0:
DocumentService.decrement_chunk_num(document_id, dataset_id, 1, chunk_number, 0)
if "chunk_ids" in req and chunk_number != len(unique_chunk_ids):
@ -1273,7 +1274,7 @@ def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
schema:
type: object
"""
chunk = settings.docStoreConn.get(chunk_id, search.index_name(tenant_id), [dataset_id])
chunk = globals.docStoreConn.get(chunk_id, search.index_name(tenant_id), [dataset_id])
if chunk is None:
return get_error_data_result(f"Can't find this chunk {chunk_id}")
if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
@ -1318,7 +1319,7 @@ def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
v, c = embd_mdl.encode([doc.name, d["content_with_weight"] if not d.get("question_kwd") else "\n".join(d["question_kwd"])])
v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
d["q_%d_vec" % len(v)] = v.tolist()
settings.docStoreConn.update({"id": chunk_id}, d, search.index_name(tenant_id), dataset_id)
globals.docStoreConn.update({"id": chunk_id}, d, search.index_name(tenant_id), dataset_id)
return get_result()
@ -1464,7 +1465,7 @@ def retrieval_test(tenant_id):
chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
question += keyword_extraction(chat_mdl, question)
ranks = settings.retriever.retrieval(
ranks = globals.retriever.retrieval(
question,
embd_mdl,
tenant_ids,

View File

@ -41,6 +41,7 @@ from rag.app.tag import label_question
from rag.prompts.template import load_prompt
from rag.prompts.generator import cross_languages, gen_meta_filter, keyword_extraction, chunks_format
from common.constants import RetCode, LLMType, StatusEnum
from common import globals
@manager.route("/chats/<chat_id>/sessions", methods=["POST"]) # noqa: F821
@token_required
@ -1015,7 +1016,7 @@ def retrieval_test_embedded():
question += keyword_extraction(chat_mdl, question)
labels = label_question(question, [kb])
ranks = settings.retriever.retrieval(
ranks = globals.retriever.retrieval(
question, embd_mdl, tenant_ids, kb_ids, page, size, similarity_threshold, vector_similarity_weight, top,
doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"), rank_feature=labels
)

View File

@ -38,6 +38,7 @@ from timeit import default_timer as timer
from rag.utils.redis_conn import REDIS_CONN
from flask import jsonify
from api.utils.health_utils import run_health_checks
from common import globals
@manager.route("/version", methods=["GET"]) # noqa: F821
@ -100,7 +101,7 @@ def status():
res = {}
st = timer()
try:
res["doc_engine"] = settings.docStoreConn.health()
res["doc_engine"] = globals.docStoreConn.health()
res["doc_engine"]["elapsed"] = "{:.1f}".format((timer() - st) * 1000.0)
except Exception as e:
res["doc_engine"] = {

View File

@ -58,6 +58,7 @@ from api.utils.web_utils import (
hash_code,
captcha_key,
)
from common import globals
@manager.route("/login", methods=["POST", "GET"]) # noqa: F821
@ -623,7 +624,7 @@ def user_register(user_id, user):
"id": user_id,
"name": user["nickname"] + "s Kingdom",
"llm_id": settings.CHAT_MDL,
"embd_id": settings.EMBEDDING_MDL,
"embd_id": globals.EMBEDDING_MDL,
"asr_id": settings.ASR_MDL,
"parser_ids": settings.PARSERS,
"img2txt_id": settings.IMAGE2TEXT_MDL,

View File

@ -32,6 +32,7 @@ from api.db.services.user_service import TenantService, UserTenantService
from api import settings
from common.constants import LLMType
from common.file_utils import get_project_base_directory
from common import globals
from api.common.base64 import encode_to_base64
@ -49,7 +50,7 @@ def init_superuser():
"id": user_info["id"],
"name": user_info["nickname"] + "s Kingdom",
"llm_id": settings.CHAT_MDL,
"embd_id": settings.EMBEDDING_MDL,
"embd_id": globals.EMBEDDING_MDL,
"asr_id": settings.ASR_MDL,
"parser_ids": settings.PARSERS,
"img2txt_id": settings.IMAGE2TEXT_MDL

View File

@ -38,6 +38,7 @@ from api.db.services.user_service import TenantService, UserService, UserTenantS
from rag.utils.storage_factory import STORAGE_IMPL
from rag.nlp import search
from common.constants import ActiveEnum
from common import globals
def create_new_user(user_info: dict) -> dict:
"""
@ -63,7 +64,7 @@ def create_new_user(user_info: dict) -> dict:
"id": user_id,
"name": user_info["nickname"] + "s Kingdom",
"llm_id": settings.CHAT_MDL,
"embd_id": settings.EMBEDDING_MDL,
"embd_id": globals.EMBEDDING_MDL,
"asr_id": settings.ASR_MDL,
"parser_ids": settings.PARSERS,
"img2txt_id": settings.IMAGE2TEXT_MDL,
@ -179,7 +180,7 @@ def delete_user_data(user_id: str) -> dict:
)
done_msg += f"- Deleted {file2doc_delete_res} document-file relation records.\n"
# step1.1.3 delete chunk in es
r = settings.docStoreConn.delete({"kb_id": kb_ids},
r = globals.docStoreConn.delete({"kb_id": kb_ids},
search.index_name(tenant_id), kb_ids)
done_msg += f"- Deleted {r} chunk records.\n"
kb_delete_res = KnowledgebaseService.delete_by_ids(kb_ids)
@ -237,7 +238,7 @@ def delete_user_data(user_id: str) -> dict:
kb_doc_info = {}
for _tenant_id, kb_doc in kb_grouped_doc.items():
for _kb_id, docs in kb_doc.items():
chunk_delete_res += settings.docStoreConn.delete(
chunk_delete_res += globals.docStoreConn.delete(
{"doc_id": [d["id"] for d in docs]},
search.index_name(_tenant_id), _kb_id
)

View File

@ -111,12 +111,14 @@ class SyncLogsService(CommonService):
return list(query.dicts())
@classmethod
def start(cls, id):
def start(cls, id, connector_id):
cls.update_by_id(id, {"status": TaskStatus.RUNNING, "time_started": datetime.now().strftime('%Y-%m-%d %H:%M:%S') })
ConnectorService.update_by_id(connector_id, {"status": TaskStatus.RUNNING})
@classmethod
def done(cls, id):
def done(cls, id, connector_id):
cls.update_by_id(id, {"status": TaskStatus.DONE})
ConnectorService.update_by_id(connector_id, {"status": TaskStatus.DONE})
@classmethod
def schedule(cls, connector_id, kb_id, poll_range_start=None, reindex=False, total_docs_indexed=0):
@ -126,6 +128,7 @@ class SyncLogsService(CommonService):
logging.warning(f"{kb_id}--{connector_id} has already had a scheduling sync task which is abnormal.")
return None
reindex = "1" if reindex else "0"
ConnectorService.update_by_id(connector_id, {"status": TaskStatus.SCHEDULE})
return cls.save(**{
"id": get_uuid(),
"kb_id": kb_id, "status": TaskStatus.SCHEDULE, "connector_id": connector_id,
@ -142,6 +145,7 @@ class SyncLogsService(CommonService):
full_exception_trace=cls.model.full_exception_trace + str(e)
) \
.where(cls.model.id == task.id).execute()
ConnectorService.update_by_id(connector_id, {"status": TaskStatus.SCHEDULE})
@classmethod
def increase_docs(cls, id, min_update, max_update, doc_num, err_msg="", error_count=0):

View File

@ -44,6 +44,7 @@ from rag.prompts.generator import chunks_format, citation_prompt, cross_language
from common.token_utils import num_tokens_from_string
from rag.utils.tavily_conn import Tavily
from common.string_utils import remove_redundant_spaces
from common import globals
class DialogService(CommonService):
@ -293,12 +294,13 @@ def meta_filter(metas: dict, filters: list[dict]):
def filter_out(v2docs, operator, value):
ids = []
for input, docids in v2docs.items():
try:
input = float(input)
value = float(value)
except Exception:
input = str(input)
value = str(value)
if operator in ["=", "", ">", "<", "", ""]:
try:
input = float(input)
value = float(value)
except Exception:
input = str(input)
value = str(value)
for conds in [
(operator == "contains", str(value).lower() in str(input).lower()),
@ -371,7 +373,7 @@ def chat(dialog, messages, stream=True, **kwargs):
chat_mdl.bind_tools(toolcall_session, tools)
bind_models_ts = timer()
retriever = settings.retriever
retriever = globals.retriever
questions = [m["content"] for m in messages if m["role"] == "user"][-3:]
attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else []
if "doc_ids" in messages[-1]:
@ -663,7 +665,7 @@ Please write the SQL, only SQL, without any other explanations or text.
logging.debug(f"{question} get SQL(refined): {sql}")
tried_times += 1
return settings.retriever.sql_retrieval(sql, format="json"), sql
return globals.retriever.sql_retrieval(sql, format="json"), sql
tbl, sql = get_table()
if tbl is None:
@ -757,7 +759,7 @@ def ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
embedding_list = list(set([kb.embd_id for kb in kbs]))
is_knowledge_graph = all([kb.parser_id == ParserType.KG for kb in kbs])
retriever = settings.retriever if not is_knowledge_graph else settings.kg_retriever
retriever = globals.retriever if not is_knowledge_graph else settings.kg_retriever
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embedding_list[0])
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, chat_llm_name)
@ -853,7 +855,7 @@ def gen_mindmap(question, kb_ids, tenant_id, search_config={}):
if not doc_ids:
doc_ids = None
ranks = settings.retriever.retrieval(
ranks = globals.retriever.retrieval(
question=question,
embd_mdl=embd_mdl,
tenant_ids=tenant_ids,

View File

@ -26,7 +26,6 @@ import trio
import xxhash
from peewee import fn, Case, JOIN
from api import settings
from api.constants import IMG_BASE64_PREFIX, FILE_NAME_LEN_LIMIT
from api.db import FileType, UserTenantRole, CanvasCategory
from api.db.db_models import DB, Document, Knowledgebase, Task, Tenant, UserTenant, File2Document, File, UserCanvas, \
@ -42,7 +41,7 @@ from rag.settings import get_svr_queue_name, SVR_CONSUMER_GROUP_NAME
from rag.utils.redis_conn import REDIS_CONN
from rag.utils.storage_factory import STORAGE_IMPL
from rag.utils.doc_store_conn import OrderByExpr
from common import globals
class DocumentService(CommonService):
model = Document
@ -309,10 +308,10 @@ class DocumentService(CommonService):
page_size = 1000
all_chunk_ids = []
while True:
chunks = settings.docStoreConn.search(["img_id"], [], {"doc_id": doc.id}, [], OrderByExpr(),
chunks = globals.docStoreConn.search(["img_id"], [], {"doc_id": doc.id}, [], OrderByExpr(),
page * page_size, page_size, search.index_name(tenant_id),
[doc.kb_id])
chunk_ids = settings.docStoreConn.getChunkIds(chunks)
chunk_ids = globals.docStoreConn.getChunkIds(chunks)
if not chunk_ids:
break
all_chunk_ids.extend(chunk_ids)
@ -323,19 +322,19 @@ class DocumentService(CommonService):
if doc.thumbnail and not doc.thumbnail.startswith(IMG_BASE64_PREFIX):
if STORAGE_IMPL.obj_exist(doc.kb_id, doc.thumbnail):
STORAGE_IMPL.rm(doc.kb_id, doc.thumbnail)
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
globals.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
graph_source = settings.docStoreConn.getFields(
settings.docStoreConn.search(["source_id"], [], {"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]}, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [doc.kb_id]), ["source_id"]
graph_source = globals.docStoreConn.getFields(
globals.docStoreConn.search(["source_id"], [], {"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]}, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [doc.kb_id]), ["source_id"]
)
if len(graph_source) > 0 and doc.id in list(graph_source.values())[0]["source_id"]:
settings.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "source_id": doc.id},
globals.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "source_id": doc.id},
{"remove": {"source_id": doc.id}},
search.index_name(tenant_id), doc.kb_id)
settings.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]},
globals.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]},
{"removed_kwd": "Y"},
search.index_name(tenant_id), doc.kb_id)
settings.docStoreConn.delete({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "must_not": {"exists": "source_id"}},
globals.docStoreConn.delete({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "must_not": {"exists": "source_id"}},
search.index_name(tenant_id), doc.kb_id)
except Exception:
pass
@ -996,10 +995,10 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
d["q_%d_vec" % len(v)] = v
for b in range(0, len(cks), es_bulk_size):
if try_create_idx:
if not settings.docStoreConn.indexExist(idxnm, kb_id):
settings.docStoreConn.createIdx(idxnm, kb_id, len(vects[0]))
if not globals.docStoreConn.indexExist(idxnm, kb_id):
globals.docStoreConn.createIdx(idxnm, kb_id, len(vects[0]))
try_create_idx = False
settings.docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id)
globals.docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id)
DocumentService.increment_chunk_num(
doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0)
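The loop above creates the index lazily, at most once, before the first bulk insert, then writes the chunks in fixed-size batches. A compact sketch of the same pattern against a hypothetical store object exposing the `indexExist`/`createIdx`/`insert` calls used above:

```python
# Sketch: create the index at most once, then insert in fixed-size batches.
def bulk_insert(store, chunks, index_name, kb_id, dims, batch_size=4):
    index_checked = False
    for start in range(0, len(chunks), batch_size):
        if not index_checked:
            if not store.indexExist(index_name, kb_id):
                store.createIdx(index_name, kb_id, dims)
            index_checked = True  # skip the existence check on later batches
        store.insert(chunks[start:start + batch_size], index_name, kb_id)
```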

View File

@ -30,13 +30,14 @@ class LLMService(CommonService):
def get_init_tenant_llm(user_id):
from api import settings
from common import globals
tenant_llm = []
seen = set()
factory_configs = []
for factory_config in [
settings.CHAT_CFG,
settings.EMBEDDING_CFG,
globals.EMBEDDING_CFG,
settings.ASR_CFG,
settings.IMAGE2TEXT_CFG,
settings.RERANK_CFG,

View File

@ -34,7 +34,7 @@ from deepdoc.parser.excel_parser import RAGFlowExcelParser
from rag.settings import get_svr_queue_name
from rag.utils.storage_factory import STORAGE_IMPL
from rag.utils.redis_conn import REDIS_CONN
from api import settings
from common import globals
from rag.nlp import search
CANVAS_DEBUG_DOC_ID = "dataflow_x"
@ -418,7 +418,7 @@ def queue_tasks(doc: dict, bucket: str, name: str, priority: int):
if pre_task["chunk_ids"]:
pre_chunk_ids.extend(pre_task["chunk_ids"].split())
if pre_chunk_ids:
settings.docStoreConn.delete({"id": pre_chunk_ids}, search.index_name(chunking_config["tenant_id"]),
globals.docStoreConn.delete({"id": pre_chunk_ids}, search.index_name(chunking_config["tenant_id"]),
chunking_config["kb_id"])
DocumentService.update_by_id(doc["id"], {"chunk_num": ck_num})

View File

@ -17,6 +17,7 @@ import os
import logging
from langfuse import Langfuse
from api import settings
from common import globals
from common.constants import LLMType
from api.db.db_models import DB, LLMFactories, TenantLLM
from api.db.services.common_service import CommonService
@ -114,7 +115,7 @@ class TenantLLMService(CommonService):
if model_config:
model_config = model_config.to_dict()
elif llm_type == LLMType.EMBEDDING and fid == 'Builtin' and "tei-" in os.getenv("COMPOSE_PROFILES", "") and mdlnm == os.getenv('TEI_MODEL', ''):
embedding_cfg = settings.EMBEDDING_CFG
embedding_cfg = globals.EMBEDDING_CFG
model_config = {"llm_factory": 'Builtin', "api_key": embedding_cfg["api_key"], "llm_name": mdlnm, "api_base": embedding_cfg["base_url"]}
else:
raise LookupError(f"Model({mdlnm}@{fid}) not authorized")

View File

@ -27,7 +27,7 @@ from api.db.services.common_service import CommonService
from common.misc_utils import get_uuid
from common.time_utils import current_timestamp, datetime_format
from common.constants import StatusEnum
from rag.settings import MINIO
from common import globals
class UserService(CommonService):
@ -221,7 +221,7 @@ class TenantService(CommonService):
@DB.connection_context()
def user_gateway(cls, tenant_id):
hash_obj = hashlib.sha256(tenant_id.encode("utf-8"))
return int(hash_obj.hexdigest(), 16)%len(MINIO)
return int(hash_obj.hexdigest(), 16)%len(globals.MINIO)
class UserTenantService(CommonService):

View File

@ -25,18 +25,19 @@ import rag.utils.opensearch_conn
from api.constants import RAG_FLOW_SERVICE_NAME
from common.config_utils import decrypt_database_config, get_base_config
from common.file_utils import get_project_base_directory
from common import globals
from rag.nlp import search
LLM = None
LLM_FACTORY = None
LLM_BASE_URL = None
CHAT_MDL = ""
EMBEDDING_MDL = ""
# EMBEDDING_MDL = "" has been moved to common/globals.py
RERANK_MDL = ""
ASR_MDL = ""
IMAGE2TEXT_MDL = ""
CHAT_CFG = ""
EMBEDDING_CFG = ""
# EMBEDDING_CFG = "" has been moved to common/globals.py
RERANK_CFG = ""
ASR_CFG = ""
IMAGE2TEXT_CFG = ""
@ -60,10 +61,10 @@ HTTP_APP_KEY = None
GITHUB_OAUTH = None
FEISHU_OAUTH = None
OAUTH_CONFIG = None
DOC_ENGINE = None
docStoreConn = None
# DOC_ENGINE = None has been moved to common/globals.py
# docStoreConn = None has been moved to common/globals.py
retriever = None
#retriever = None has been moved to common/globals.py
kg_retriever = None
# user registration switch
@ -124,8 +125,8 @@ def init_settings():
except Exception:
FACTORY_LLM_INFOS = []
global CHAT_MDL, EMBEDDING_MDL, RERANK_MDL, ASR_MDL, IMAGE2TEXT_MDL
global CHAT_CFG, EMBEDDING_CFG, RERANK_CFG, ASR_CFG, IMAGE2TEXT_CFG
global CHAT_MDL, RERANK_MDL, ASR_MDL, IMAGE2TEXT_MDL
global CHAT_CFG, RERANK_CFG, ASR_CFG, IMAGE2TEXT_CFG
global API_KEY, PARSERS, HOST_IP, HOST_PORT, SECRET_KEY
API_KEY = LLM.get("api_key")
@ -134,19 +135,19 @@ def init_settings():
)
chat_entry = _parse_model_entry(LLM_DEFAULT_MODELS.get("chat_model", CHAT_MDL))
embedding_entry = _parse_model_entry(LLM_DEFAULT_MODELS.get("embedding_model", EMBEDDING_MDL))
embedding_entry = _parse_model_entry(LLM_DEFAULT_MODELS.get("embedding_model", globals.EMBEDDING_MDL))
rerank_entry = _parse_model_entry(LLM_DEFAULT_MODELS.get("rerank_model", RERANK_MDL))
asr_entry = _parse_model_entry(LLM_DEFAULT_MODELS.get("asr_model", ASR_MDL))
image2text_entry = _parse_model_entry(LLM_DEFAULT_MODELS.get("image2text_model", IMAGE2TEXT_MDL))
CHAT_CFG = _resolve_per_model_config(chat_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL)
EMBEDDING_CFG = _resolve_per_model_config(embedding_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL)
globals.EMBEDDING_CFG = _resolve_per_model_config(embedding_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL)
RERANK_CFG = _resolve_per_model_config(rerank_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL)
ASR_CFG = _resolve_per_model_config(asr_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL)
IMAGE2TEXT_CFG = _resolve_per_model_config(image2text_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL)
CHAT_MDL = CHAT_CFG.get("model", "") or ""
EMBEDDING_MDL = os.getenv("TEI_MODEL", "BAAI/bge-small-en-v1.5") if "tei-" in os.getenv("COMPOSE_PROFILES", "") else ""
globals.EMBEDDING_MDL = os.getenv("TEI_MODEL", "BAAI/bge-small-en-v1.5") if "tei-" in os.getenv("COMPOSE_PROFILES", "") else ""
RERANK_MDL = RERANK_CFG.get("model", "") or ""
ASR_MDL = ASR_CFG.get("model", "") or ""
IMAGE2TEXT_MDL = IMAGE2TEXT_CFG.get("model", "") or ""
@ -168,23 +169,23 @@ def init_settings():
OAUTH_CONFIG = get_base_config("oauth", {})
global DOC_ENGINE, docStoreConn, retriever, kg_retriever
DOC_ENGINE = os.environ.get("DOC_ENGINE", "elasticsearch")
# DOC_ENGINE = os.environ.get('DOC_ENGINE', "opensearch")
lower_case_doc_engine = DOC_ENGINE.lower()
global kg_retriever
globals.DOC_ENGINE = os.environ.get("DOC_ENGINE", "elasticsearch")
# globals.DOC_ENGINE = os.environ.get('DOC_ENGINE', "opensearch")
lower_case_doc_engine = globals.DOC_ENGINE.lower()
if lower_case_doc_engine == "elasticsearch":
docStoreConn = rag.utils.es_conn.ESConnection()
globals.docStoreConn = rag.utils.es_conn.ESConnection()
elif lower_case_doc_engine == "infinity":
docStoreConn = rag.utils.infinity_conn.InfinityConnection()
globals.docStoreConn = rag.utils.infinity_conn.InfinityConnection()
elif lower_case_doc_engine == "opensearch":
docStoreConn = rag.utils.opensearch_conn.OSConnection()
globals.docStoreConn = rag.utils.opensearch_conn.OSConnection()
else:
raise Exception(f"Not supported doc engine: {DOC_ENGINE}")
raise Exception(f"Not supported doc engine: {globals.DOC_ENGINE}")
retriever = search.Dealer(docStoreConn)
globals.retriever = search.Dealer(globals.docStoreConn)
from graphrag import search as kg_search
kg_retriever = kg_search.KGSearch(docStoreConn)
kg_retriever = kg_search.KGSearch(globals.docStoreConn)
if int(os.environ.get("SANDBOX_ENABLED", "0")):
global SANDBOX_HOST

View File

@ -626,7 +626,7 @@ async def is_strong_enough(chat_model, embedding_model):
def get_allowed_llm_factories() -> list:
factories = LLMFactoriesService.get_all()
factories = list(LLMFactoriesService.get_all())
if settings.ALLOWED_LLM_FACTORIES is None:
return factories
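The one-line fix matters because the underlying service call presumably returns a lazy ORM query (peewee's `select()` yields a `ModelSelect`, not a `list`), so the annotated `-> list` return type was violated and list operations on the result could misbehave; `list(...)` materializes the rows. A generic sketch of the pitfall, with a stand-in query class:

```python
from typing import Iterator

class LazyQuery:
    """Stand-in for a lazy ORM query (e.g. peewee's ModelSelect)."""
    def __init__(self, rows):
        self._rows = rows
    def __iter__(self) -> Iterator:
        return iter(self._rows)

def get_all_lazy() -> list:                    # annotation promises a list...
    return LazyQuery(["openai", "ollama"])     # ...but a query object escapes

factories = get_all_lazy()
print(isinstance(factories, list))             # False: not actually a list
print(list(get_all_lazy()) + ["xinference"])   # materializing restores list semantics
```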

View File

@ -19,11 +19,11 @@ from timeit import default_timer as timer
from api import settings
from api.db.db_models import DB
from rag import settings as rag_settings
from rag.utils.redis_conn import REDIS_CONN
from rag.utils.storage_factory import STORAGE_IMPL
from rag.utils.es_conn import ESConnection
from rag.utils.infinity_conn import InfinityConnection
from common import globals
def _ok_nok(ok: bool) -> str:
@ -52,7 +52,7 @@ def check_redis() -> tuple[bool, dict]:
def check_doc_engine() -> tuple[bool, dict]:
st = timer()
try:
meta = settings.docStoreConn.health()
meta = globals.docStoreConn.health()
# treat any successful call as ok
return True, {"elapsed": f"{(timer() - st) * 1000.0:.1f}", **(meta or {})}
except Exception as e:
@ -120,7 +120,7 @@ def get_mysql_status():
def check_minio_alive():
start_time = timer()
try:
response = requests.get(f'http://{rag_settings.MINIO["host"]}/minio/health/live')
response = requests.get(f'http://{globals.MINIO["host"]}/minio/health/live')
if response.status_code == 200:
return {"status": "alive", "message": f"Confirm elapsed: {(timer() - start_time) * 1000.0:.1f} ms."}
else:

common/globals.py (new file, 64 lines)
View File

@ -0,0 +1,64 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
from common.config_utils import get_base_config, decrypt_database_config
EMBEDDING_MDL = ""
EMBEDDING_CFG = ""
DOC_ENGINE = os.getenv('DOC_ENGINE', 'elasticsearch')
docStoreConn = None
retriever = None
# move from rag.settings
ES = {}
INFINITY = {}
AZURE = {}
S3 = {}
MINIO = {}
OSS = {}
OS = {}
REDIS = {}
STORAGE_IMPL_TYPE = os.getenv('STORAGE_IMPL', 'MINIO')
# Initialize only the configuration for the backends selected via environment variables,
# to avoid initialization errors when the other backends are not configured
if DOC_ENGINE == 'elasticsearch':
ES = get_base_config("es", {})
elif DOC_ENGINE == 'opensearch':
OS = get_base_config("os", {})
elif DOC_ENGINE == 'infinity':
INFINITY = get_base_config("infinity", {"uri": "infinity:23817"})
if STORAGE_IMPL_TYPE in ['AZURE_SPN', 'AZURE_SAS']:
AZURE = get_base_config("azure", {})
elif STORAGE_IMPL_TYPE == 'AWS_S3':
S3 = get_base_config("s3", {})
elif STORAGE_IMPL_TYPE == 'MINIO':
MINIO = decrypt_database_config(name="minio")
elif STORAGE_IMPL_TYPE == 'OSS':
OSS = get_base_config("oss", {})
try:
REDIS = decrypt_database_config(name="redis")
except Exception:
try:
REDIS = get_base_config("redis", {})
except Exception:
REDIS = {}
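Note the import convention used throughout this changeset: call sites do `from common import globals` and read `globals.docStoreConn` as an attribute instead of importing the name directly. Since these slots start as `None` and are only filled in by `init_settings()`, a direct `from common.globals import docStoreConn` would copy the `None` binding at import time and never observe the later assignment, whereas attribute access is late-bound. A minimal demonstration:

```python
import types

# Simulate common.globals: a module whose slot is filled in later.
fake_globals = types.ModuleType("fake_globals")
fake_globals.docStoreConn = None

# "from common.globals import docStoreConn" copies the binding now:
docStoreConn = fake_globals.docStoreConn

# init_settings() later assigns the real connection:
fake_globals.docStoreConn = object()

print(docStoreConn is None)               # True:  the copied name is stale
print(fake_globals.docStoreConn is None)  # False: attribute access is late-bound
```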

common/signal_utils.py (new file, 55 lines)
View File

@ -0,0 +1,55 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import sys
from datetime import datetime
import logging
import tracemalloc
from common.log_utils import get_project_base_directory
# SIGUSR1 handler: start tracemalloc and take snapshot
def start_tracemalloc_and_snapshot(signum, frame):
if not tracemalloc.is_tracing():
logging.info("start tracemalloc")
tracemalloc.start()
else:
logging.info("tracemalloc is already running")
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
snapshot_file = f"snapshot_{timestamp}.trace"
snapshot_file = os.path.abspath(os.path.join(get_project_base_directory(), "logs", f"{os.getpid()}_snapshot_{timestamp}.trace"))
snapshot = tracemalloc.take_snapshot()
snapshot.dump(snapshot_file)
current, peak = tracemalloc.get_traced_memory()
if sys.platform == "win32":
import psutil
process = psutil.Process()
max_rss = process.memory_info().rss / 1024
else:
import resource
max_rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
logging.info(f"taken snapshot {snapshot_file}. max RSS={max_rss / 1000:.2f} MB, current memory usage: {current / 10**6:.2f} MB, Peak memory usage: {peak / 10**6:.2f} MB")
# SIGUSR2 handler: stop tracemalloc
def stop_tracemalloc(signum, frame):
if tracemalloc.is_tracing():
logging.info("stop tracemalloc")
tracemalloc.stop()
else:
logging.info("tracemalloc not running")

View File

@ -70,6 +70,17 @@ class RAGFlowMarkdownParser:
)
working_text = replace_tables_with_rendered_html(no_border_table_pattern, tables)
# Normalize any whitelisted TAGS with attributes, e.g. <table ...> -> <table>
TAGS = ["table", "td", "tr", "th", "tbody", "thead", "div"]
table_with_attributes_pattern = re.compile(
rf"<(?:{'|'.join(TAGS)})[^>]*>", re.IGNORECASE
)
def replace_tag(m):
tag_name = re.match(r"<(\w+)", m.group()).group(1)
return "<{}>".format(tag_name)
working_text = re.sub(table_with_attributes_pattern, replace_tag, working_text)
if "<table>" in working_text.lower(): # for optimize performance
# HTML table extraction - handle possible html/body wrapper tags
html_table_pattern = re.compile(
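The normalization rewrites any whitelisted opening tag that carries attributes into its bare form, so the later `"<table>" in working_text.lower()` fast path still fires for input like `<table border="1">`. A standalone check of the same pattern:

```python
import re

TAGS = ["table", "td", "tr", "th", "tbody", "thead", "div"]
pattern = re.compile(rf"<(?:{'|'.join(TAGS)})[^>]*>", re.IGNORECASE)

def replace_tag(m):
    # Keep only the tag name; drop attributes. Closing tags are untouched
    # because "</" does not match the pattern.
    return "<{}>".format(re.match(r"<(\w+)", m.group()).group(1))

html = '<table border="1" class="x"><tr align="left"><td>1</td></tr></table>'
print(pattern.sub(replace_tag, html))
# <table><tr><td>1</td></tr></table>
```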

View File

@ -6,10 +6,11 @@ set -e
# Usage and command-line argument parsing
# -----------------------------------------------------------------------------
function usage() {
echo "Usage: $0 [--disable-webserver] [--disable-taskexecutor] [--consumer-no-beg=<num>] [--consumer-no-end=<num>] [--workers=<num>] [--host-id=<string>]"
echo "Usage: $0 [--disable-webserver] [--disable-taskexecutor] [--disable-datasync] [--consumer-no-beg=<num>] [--consumer-no-end=<num>] [--workers=<num>] [--host-id=<string>]"
echo
echo " --disable-webserver Disables the web server (nginx + ragflow_server)."
echo " --disable-taskexecutor Disables task executor workers."
echo " --disable-datasync Disables synchronization of datasource workers."
echo " --enable-mcpserver Enables the MCP server."
echo " --enable-adminserver Enables the Admin server."
echo " --consumer-no-beg=<num> Start range for consumers (if using range-based)."
@ -28,6 +29,7 @@ function usage() {
ENABLE_WEBSERVER=1 # Default to enable web server
ENABLE_TASKEXECUTOR=1 # Default to enable task executor
ENABLE_DATASYNC=1
ENABLE_MCP_SERVER=0
ENABLE_ADMIN_SERVER=0 # Default close admin server
CONSUMER_NO_BEG=0
@ -69,6 +71,10 @@ for arg in "$@"; do
ENABLE_TASKEXECUTOR=0
shift
;;
--disable-datasync)
ENABLE_DATASYNC=0
shift
;;
--enable-mcpserver)
ENABLE_MCP_SERVER=1
shift
@ -236,6 +242,13 @@ if [[ "${ENABLE_WEBSERVER}" -eq 1 ]]; then
done &
fi
if [[ "${ENABLE_DATASYNC}" -eq 1 ]]; then
echo "Starting data sync..."
while true; do
"$PY" rag/svr/sync_data_source.py
done &
fi
if [[ "${ENABLE_ADMIN_SERVER}" -eq 1 ]]; then
echo "Starting admin_server..."
while true; do
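With this block in place, the sync service can be switched off at container start, e.g. `./entrypoint.sh --disable-datasync` (invocation illustrative; the script path depends on the image), matching the usage text added above.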

View File

@ -20,7 +20,6 @@ import os
import networkx as nx
import trio
from api import settings
from api.db.services.document_service import DocumentService
from common.misc_utils import get_uuid
from common.connection_utils import timeout
@ -40,6 +39,7 @@ from graphrag.utils import (
)
from rag.nlp import rag_tokenizer, search
from rag.utils.redis_conn import RedisDistributedLock
from common import globals
async def run_graphrag(
@ -55,7 +55,7 @@ async def run_graphrag(
start = trio.current_time()
tenant_id, kb_id, doc_id = row["tenant_id"], str(row["kb_id"]), row["doc_id"]
chunks = []
for d in settings.retriever.chunk_list(doc_id, tenant_id, [kb_id], fields=["content_with_weight", "doc_id"], sort_by_position=True):
for d in globals.retriever.chunk_list(doc_id, tenant_id, [kb_id], fields=["content_with_weight", "doc_id"], sort_by_position=True):
chunks.append(d["content_with_weight"])
with trio.fail_after(max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000):
@ -170,7 +170,7 @@ async def run_graphrag_for_kb(
chunks = []
current_chunk = ""
for d in settings.retriever.chunk_list(
for d in globals.retriever.chunk_list(
doc_id,
tenant_id,
[kb_id],
@ -387,8 +387,8 @@ async def generate_subgraph(
"removed_kwd": "N",
}
cid = chunk_id(chunk)
await trio.to_thread.run_sync(settings.docStoreConn.delete, {"knowledge_graph_kwd": "subgraph", "source_id": doc_id}, search.index_name(tenant_id), kb_id)
await trio.to_thread.run_sync(settings.docStoreConn.insert, [{"id": cid, **chunk}], search.index_name(tenant_id), kb_id)
await trio.to_thread.run_sync(globals.docStoreConn.delete, {"knowledge_graph_kwd": "subgraph", "source_id": doc_id}, search.index_name(tenant_id), kb_id)
await trio.to_thread.run_sync(globals.docStoreConn.insert, [{"id": cid, **chunk}], search.index_name(tenant_id), kb_id)
now = trio.current_time()
callback(msg=f"generated subgraph for doc {doc_id} in {now - start:.2f} seconds.")
return subgraph
@ -496,7 +496,7 @@ async def extract_community(
chunks.append(chunk)
await trio.to_thread.run_sync(
lambda: settings.docStoreConn.delete(
lambda: globals.docStoreConn.delete(
{"knowledge_graph_kwd": "community_report", "kb_id": kb_id},
search.index_name(tenant_id),
kb_id,
@ -504,7 +504,7 @@ async def extract_community(
)
es_bulk_size = 4
for b in range(0, len(chunks), es_bulk_size):
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b : b + es_bulk_size], search.index_name(tenant_id), kb_id))
doc_store_result = await trio.to_thread.run_sync(lambda: globals.docStoreConn.insert(chunks[b : b + es_bulk_size], search.index_name(tenant_id), kb_id))
if doc_store_result:
error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
raise Exception(error_message)
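These hunks replace `api.settings` attribute lookups with `common.globals`. A minimal sketch of what such a shared-state module might look like, with names mirroring the usages above (the initializer is an assumption; the real wiring happens in settings.init_settings()):

# common/globals.py -- hypothetical sketch
retriever = None      # search dealer used for chunk_list()/search()/retrieval()
docStoreConn = None   # doc-store connection (Elasticsearch/OpenSearch/Infinity)
EMBEDDING_CFG = {}    # embedding service configuration
DOC_ENGINE = "elasticsearch"

def init(doc_store_conn, dealer, embedding_cfg=None):
    """Assumed one-shot initializer for the shared singletons."""
    global retriever, docStoreConn, EMBEDDING_CFG
    docStoreConn = doc_store_conn
    retriever = dealer
    if embedding_cfg is not None:
        EMBEDDING_CFG = embedding_cfg

Because callers import the module itself (`from common import globals`) and read attributes at call time, reassignments made during startup are visible to every importer, which a `from globals import retriever` style binding would not guarantee.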

View File

@ -28,6 +28,7 @@ from api.db.services.llm_service import LLMBundle
from api.db.services.user_service import TenantService
from graphrag.general.graph_extractor import GraphExtractor
from graphrag.general.index import update_graph, with_resolution, with_community
from common import globals
settings.init_settings()
@ -62,7 +63,7 @@ async def main():
chunks = [
d["content_with_weight"]
for d in settings.retriever.chunk_list(
for d in globals.retriever.chunk_list(
args.doc_id,
args.tenant_id,
[kb_id],

View File

@ -28,6 +28,7 @@ from api.db.services.llm_service import LLMBundle
from api.db.services.user_service import TenantService
from graphrag.general.index import update_graph
from graphrag.light.graph_extractor import GraphExtractor
from common import globals
settings.init_settings()
@ -63,7 +64,7 @@ async def main():
chunks = [
d["content_with_weight"]
for d in settings.retriever.chunk_list(
for d in globals.retriever.chunk_list(
args.doc_id,
args.tenant_id,
[kb_id],

View File

@ -29,6 +29,7 @@ from rag.utils.doc_store_conn import OrderByExpr
from rag.nlp.search import Dealer, index_name
from common.float_utils import get_float
from common import globals
class KGSearch(Dealer):
@ -334,6 +335,6 @@ if __name__ == "__main__":
_, kb = KnowledgebaseService.get_by_id(kb_id)
embed_bdl = LLMBundle(args.tenant_id, LLMType.EMBEDDING, kb.embd_id)
kg = KGSearch(settings.docStoreConn)
kg = KGSearch(globals.docStoreConn)
print(kg.retrieval({"question": args.question, "kb_ids": [kb_id]},
search.index_name(kb.tenant_id), [kb_id], embed_bdl, llm_bdl))

View File

@ -23,12 +23,12 @@ import trio
import xxhash
from networkx.readwrite import json_graph
from api import settings
from common.misc_utils import get_uuid
from common.connection_utils import timeout
from rag.nlp import rag_tokenizer, search
from rag.utils.doc_store_conn import OrderByExpr
from rag.utils.redis_conn import REDIS_CONN
from common import globals
GRAPH_FIELD_SEP = "<SEP>"
@ -334,7 +334,7 @@ def get_relation(tenant_id, kb_id, from_ent_name, to_ent_name, size=1):
ents = list(set(ents))
conds = {"fields": ["content_with_weight"], "size": size, "from_entity_kwd": ents, "to_entity_kwd": ents, "knowledge_graph_kwd": ["relation"]}
res = []
es_res = settings.retriever.search(conds, search.index_name(tenant_id), [kb_id] if isinstance(kb_id, str) else kb_id)
es_res = globals.retriever.search(conds, search.index_name(tenant_id), [kb_id] if isinstance(kb_id, str) else kb_id)
for id in es_res.ids:
try:
if size == 1:
@ -381,8 +381,8 @@ async def does_graph_contains(tenant_id, kb_id, doc_id):
"knowledge_graph_kwd": ["graph"],
"removed_kwd": "N",
}
res = await trio.to_thread.run_sync(lambda: settings.docStoreConn.search(fields, [], condition, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [kb_id]))
fields2 = settings.docStoreConn.getFields(res, fields)
res = await trio.to_thread.run_sync(lambda: globals.docStoreConn.search(fields, [], condition, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [kb_id]))
fields2 = globals.docStoreConn.getFields(res, fields)
graph_doc_ids = set()
for chunk_id in fields2.keys():
graph_doc_ids = set(fields2[chunk_id]["source_id"])
@ -391,7 +391,7 @@ async def does_graph_contains(tenant_id, kb_id, doc_id):
async def get_graph_doc_ids(tenant_id, kb_id) -> list[str]:
conds = {"fields": ["source_id"], "removed_kwd": "N", "size": 1, "knowledge_graph_kwd": ["graph"]}
res = await trio.to_thread.run_sync(lambda: settings.retriever.search(conds, search.index_name(tenant_id), [kb_id]))
res = await trio.to_thread.run_sync(lambda: globals.retriever.search(conds, search.index_name(tenant_id), [kb_id]))
doc_ids = []
if res.total == 0:
return doc_ids
@ -402,7 +402,7 @@ async def get_graph_doc_ids(tenant_id, kb_id) -> list[str]:
async def get_graph(tenant_id, kb_id, exclude_rebuild=None):
conds = {"fields": ["content_with_weight", "removed_kwd", "source_id"], "size": 1, "knowledge_graph_kwd": ["graph"]}
res = await trio.to_thread.run_sync(settings.retriever.search, conds, search.index_name(tenant_id), [kb_id])
res = await trio.to_thread.run_sync(globals.retriever.search, conds, search.index_name(tenant_id), [kb_id])
if not res.total == 0:
for id in res.ids:
try:
@ -423,17 +423,17 @@ async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, chang
global chat_limiter
start = trio.current_time()
await trio.to_thread.run_sync(settings.docStoreConn.delete, {"knowledge_graph_kwd": ["graph", "subgraph"]}, search.index_name(tenant_id), kb_id)
await trio.to_thread.run_sync(globals.docStoreConn.delete, {"knowledge_graph_kwd": ["graph", "subgraph"]}, search.index_name(tenant_id), kb_id)
if change.removed_nodes:
await trio.to_thread.run_sync(settings.docStoreConn.delete, {"knowledge_graph_kwd": ["entity"], "entity_kwd": sorted(change.removed_nodes)}, search.index_name(tenant_id), kb_id)
await trio.to_thread.run_sync(globals.docStoreConn.delete, {"knowledge_graph_kwd": ["entity"], "entity_kwd": sorted(change.removed_nodes)}, search.index_name(tenant_id), kb_id)
if change.removed_edges:
async def del_edges(from_node, to_node):
async with chat_limiter:
await trio.to_thread.run_sync(
settings.docStoreConn.delete, {"knowledge_graph_kwd": ["relation"], "from_entity_kwd": from_node, "to_entity_kwd": to_node}, search.index_name(tenant_id), kb_id
globals.docStoreConn.delete, {"knowledge_graph_kwd": ["relation"], "from_entity_kwd": from_node, "to_entity_kwd": to_node}, search.index_name(tenant_id), kb_id
)
async with trio.open_nursery() as nursery:
@ -501,7 +501,7 @@ async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, chang
es_bulk_size = 4
for b in range(0, len(chunks), es_bulk_size):
with trio.fail_after(3 if enable_timeout_assertion else 30000000):
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b : b + es_bulk_size], search.index_name(tenant_id), kb_id))
doc_store_result = await trio.to_thread.run_sync(lambda: globals.docStoreConn.insert(chunks[b : b + es_bulk_size], search.index_name(tenant_id), kb_id))
if b % 100 == es_bulk_size and callback:
callback(msg=f"Insert chunks: {b}/{len(chunks)}")
if doc_store_result:
@ -555,7 +555,7 @@ def merge_tuples(list1, list2):
async def get_entity_type2samples(idxnms, kb_ids: list):
es_res = await trio.to_thread.run_sync(lambda: settings.retriever.search({"knowledge_graph_kwd": "ty2ents", "kb_id": kb_ids, "size": 10000, "fields": ["content_with_weight"]}, idxnms, kb_ids))
es_res = await trio.to_thread.run_sync(lambda: globals.retriever.search({"knowledge_graph_kwd": "ty2ents", "kb_id": kb_ids, "size": 10000, "fields": ["content_with_weight"]}, idxnms, kb_ids))
res = defaultdict(list)
for id in es_res.ids:
@ -589,10 +589,10 @@ async def rebuild_graph(tenant_id, kb_id, exclude_rebuild=None):
bs = 256
for i in range(0, 1024 * bs, bs):
es_res = await trio.to_thread.run_sync(
lambda: settings.docStoreConn.search(flds, [], {"kb_id": kb_id, "knowledge_graph_kwd": ["subgraph"]}, [], OrderByExpr(), i, bs, search.index_name(tenant_id), [kb_id])
lambda: globals.docStoreConn.search(flds, [], {"kb_id": kb_id, "knowledge_graph_kwd": ["subgraph"]}, [], OrderByExpr(), i, bs, search.index_name(tenant_id), [kb_id])
)
# tot = settings.docStoreConn.getTotal(es_res)
es_res = settings.docStoreConn.getFields(es_res, flds)
# tot = globals.docStoreConn.getTotal(es_res)
es_res = globals.docStoreConn.getFields(es_res, flds)
if len(es_res) == 0:
break

View File

@ -15,18 +15,18 @@
#
import logging
from tika import parser
import re
from io import BytesIO
from deepdoc.parser.utils import get_text
from rag.app import naive
from rag.app.naive import plaintext_parser, PARSERS
from rag.nlp import bullets_category, is_english, remove_contents_table, \
hierarchical_merge, make_colon_as_title, naive_merge, random_choices, tokenize_table, \
tokenize_chunks
from rag.nlp import rag_tokenizer
from deepdoc.parser import PdfParser, PlainParser, HtmlParser
from deepdoc.parser.figure_parser import vision_figure_parser_pdf_wrapper,vision_figure_parser_docx_wrapper
from deepdoc.parser import PdfParser, HtmlParser
from deepdoc.parser.figure_parser import vision_figure_parser_docx_wrapper
from PIL import Image
@ -96,13 +96,33 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
callback(0.8, "Finish parsing.")
elif re.search(r"\.pdf$", filename, re.IGNORECASE):
pdf_parser = Pdf()
if parser_config.get("layout_recognize", "DeepDOC") == "Plain Text":
pdf_parser = PlainParser()
sections, tbls = pdf_parser(filename if not binary else binary,
from_page=from_page, to_page=to_page, callback=callback)
tbls=vision_figure_parser_pdf_wrapper(tbls=tbls,callback=callback,**kwargs)
layout_recognizer = parser_config.get("layout_recognize", "DeepDOC")
if isinstance(layout_recognizer, bool):
layout_recognizer = "DeepDOC" if layout_recognizer else "Plain Text"
name = layout_recognizer.strip().lower()
parser = PARSERS.get(name, plaintext_parser)
callback(0.1, "Start to parse.")
sections, tables, _ = parser(
filename=filename,
binary=binary,
from_page=from_page,
to_page=to_page,
lang=lang,
callback=callback,
pdf_cls=Pdf,
**kwargs
)
if not sections and not tables:
return []
if name in ["tcadp", "docling", "mineru"]:
parser_config["chunk_token_num"] = 0
callback(0.8, "Finish parsing.")
elif re.search(r"\.txt$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.")
txt = get_text(filename, binary)

View File

@ -15,7 +15,6 @@
#
import logging
from tika import parser
import re
from io import BytesIO
from docx import Document
@ -25,8 +24,8 @@ from deepdoc.parser.utils import get_text
from rag.nlp import bullets_category, remove_contents_table, \
make_colon_as_title, tokenize_chunks, docx_question_level, tree_merge
from rag.nlp import rag_tokenizer, Node
from deepdoc.parser import PdfParser, DocxParser, PlainParser, HtmlParser
from deepdoc.parser import PdfParser, DocxParser, HtmlParser
from rag.app.naive import plaintext_parser, PARSERS
@ -156,13 +155,36 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
return tokenize_chunks(chunks, doc, eng, None)
elif re.search(r"\.pdf$", filename, re.IGNORECASE):
pdf_parser = Pdf()
if parser_config.get("layout_recognize", "DeepDOC") == "Plain Text":
pdf_parser = PlainParser()
for txt, poss in pdf_parser(filename if not binary else binary,
from_page=from_page, to_page=to_page, callback=callback)[0]:
layout_recognizer = parser_config.get("layout_recognize", "DeepDOC")
if isinstance(layout_recognizer, bool):
layout_recognizer = "DeepDOC" if layout_recognizer else "Plain Text"
name = layout_recognizer.strip().lower()
parser = PARSERS.get(name, plaintext_parser)
callback(0.1, "Start to parse.")
raw_sections, tables, _ = parser(
filename=filename,
binary=binary,
from_page=from_page,
to_page=to_page,
lang=lang,
callback=callback,
pdf_cls=Pdf,
**kwargs
)
if not raw_sections and not tables:
return []
if name in ["tcadp", "docling", "mineru"]:
parser_config["chunk_token_num"] = 0
for txt, poss in raw_sections:
sections.append(txt + poss)
callback(0.8, "Finish parsing.")
elif re.search(r"\.(txt|md|markdown|mdx)$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.")
txt = get_text(filename, binary)

View File

@ -22,11 +22,11 @@ from common.constants import ParserType
from io import BytesIO
from rag.nlp import rag_tokenizer, tokenize, tokenize_table, bullets_category, title_frequency, tokenize_chunks, docx_question_level
from common.token_utils import num_tokens_from_string
from deepdoc.parser import PdfParser, PlainParser, DocxParser
from deepdoc.parser import PdfParser, DocxParser
from deepdoc.parser.figure_parser import vision_figure_parser_pdf_wrapper, vision_figure_parser_docx_wrapper
from docx import Document
from PIL import Image
from rag.app.naive import plaintext_parser, PARSERS
class Pdf(PdfParser):
def __init__(self):
@ -196,15 +196,34 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
# is it English
eng = lang.lower() == "english" # pdf_parser.is_english
if re.search(r"\.pdf$", filename, re.IGNORECASE):
pdf_parser = Pdf()
if parser_config.get("layout_recognize", "DeepDOC") == "Plain Text":
pdf_parser = PlainParser()
sections, tbls = pdf_parser(filename if not binary else binary,
from_page=from_page, to_page=to_page, callback=callback)
if sections and len(sections[0]) < 3:
sections = [(t, lvl, [[0] * 5]) for t, lvl in sections]
# set pivot using the most frequent type of title,
# then merge between 2 pivot
layout_recognizer = parser_config.get("layout_recognize", "DeepDOC")
if isinstance(layout_recognizer, bool):
layout_recognizer = "DeepDOC" if layout_recognizer else "Plain Text"
name = layout_recognizer.strip().lower()
pdf_parser = PARSERS.get(name, plaintext_parser)
callback(0.1, "Start to parse.")
sections, tbls, pdf_parser = pdf_parser(
filename=filename,
binary=binary,
from_page=from_page,
to_page=to_page,
lang=lang,
callback=callback,
pdf_cls=Pdf,
**kwargs
)
if not sections and not tbls:
return []
if name in ["tcadp", "docling", "mineru"]:
parser_config["chunk_token_num"] = 0
callback(0.8, "Finish parsing.")
if len(sections) > 0 and len(pdf_parser.outlines) / len(sections) > 0.03:
max_lvl = max([lvl for _, lvl in pdf_parser.outlines])
most_level = max(0, max_lvl - 1)

View File

@ -26,7 +26,6 @@ from docx.opc.pkgreader import _SerializedRelationships, _SerializedRelationship
from docx.opc.oxml import parse_xml
from markdown import markdown
from PIL import Image
from tika import parser
from common.constants import LLMType
from api.db.services.llm_service import LLMBundle
@ -39,6 +38,100 @@ from deepdoc.parser.docling_parser import DoclingParser
from deepdoc.parser.tcadp_parser import TCADPParser
from rag.nlp import concat_img, find_codec, naive_merge, naive_merge_with_images, naive_merge_docx, rag_tokenizer, tokenize_chunks, tokenize_chunks_with_images, tokenize_table
def DeepDOC_parser(filename, binary=None, from_page=0, to_page=100000, callback=None, pdf_cls=None, **kwargs):
pdf_parser = pdf_cls() if pdf_cls else Pdf()
sections, tables = pdf_parser(
filename if not binary else binary,
from_page=from_page,
to_page=to_page,
callback=callback
)
tables = vision_figure_parser_pdf_wrapper(tbls=tables,
callback=callback,
**kwargs)
return sections, tables, pdf_parser
def MinerU_parser(filename, binary=None, callback=None, **kwargs):
mineru_executable = os.environ.get("MINERU_EXECUTABLE", "mineru")
mineru_api = os.environ.get("MINERU_APISERVER", "http://host.docker.internal:9987")
pdf_parser = MinerUParser(mineru_path=mineru_executable, mineru_api=mineru_api)
if not pdf_parser.check_installation():
callback(-1, "MinerU not found.")
return None, None, None # three values, matching the (sections, tables, parser) return convention
sections, tables = pdf_parser.parse_pdf(
filepath=filename,
binary=binary,
callback=callback,
output_dir=os.environ.get("MINERU_OUTPUT_DIR", ""),
backend=os.environ.get("MINERU_BACKEND", "pipeline"),
delete_output=bool(int(os.environ.get("MINERU_DELETE_OUTPUT", 1))),
)
return sections, tables, pdf_parser
def Docling_parser(filename, binary=None, callback=None, **kwargs):
pdf_parser = DoclingParser()
if not pdf_parser.check_installation():
callback(-1, "Docling not found.")
return None, None, None # three values, matching the (sections, tables, parser) return convention
sections, tables = pdf_parser.parse_pdf(
filepath=filename,
binary=binary,
callback=callback,
output_dir=os.environ.get("MINERU_OUTPUT_DIR", ""),
delete_output=bool(int(os.environ.get("MINERU_DELETE_OUTPUT", 1))),
)
return sections, tables, pdf_parser
def TCADP_parser(filename, binary=None, callback=None, **kwargs):
tcadp_parser = TCADPParser()
if not tcadp_parser.check_installation():
callback(-1, "TCADP parser not available. Please check Tencent Cloud API configuration.")
return None, None, None # three values, matching the (sections, tables, parser) return convention
sections, tables = tcadp_parser.parse_pdf(
filepath=filename,
binary=binary,
callback=callback,
output_dir=os.environ.get("TCADP_OUTPUT_DIR", ""),
file_type="PDF"
)
return sections, tables, tcadp_parser
def plaintext_parser(filename, binary=None, from_page=0, to_page=100000, callback=None, **kwargs):
if kwargs.get("layout_recognizer", "") == "Plain Text":
pdf_parser = PlainParser()
else:
vision_model = LLMBundle(kwargs["tenant_id"], LLMType.IMAGE2TEXT, llm_name=kwargs.get("layout_recognizer", ""), lang=kwargs.get("lang", "Chinese"))
pdf_parser = VisionParser(vision_model=vision_model, **kwargs)
sections, tables = pdf_parser(
filename if not binary else binary,
from_page=from_page,
to_page=to_page,
callback=callback
)
return sections, tables, pdf_parser
PARSERS = {
"deepdoc": DeepDOC_parser,
"mineru": MinerU_parser,
"docling": Docling_parser,
"tcadp": TCADP_parser,
"plaintext": plaintext_parser, # default
}
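Every chunker now resolves its PDF backend through this registry; the dispatch that recurs across the chunk() hunks in this compare condenses to:

layout_recognizer = parser_config.get("layout_recognize", "DeepDOC")
if isinstance(layout_recognizer, bool):  # legacy boolean config
    layout_recognizer = "DeepDOC" if layout_recognizer else "Plain Text"
name = layout_recognizer.strip().lower()
parser = PARSERS.get(name, plaintext_parser)
sections, tables, pdf_parser = parser(filename=filename, binary=binary,
                                      from_page=from_page, to_page=to_page,
                                      lang=lang, callback=callback,
                                      pdf_cls=Pdf, **kwargs)

Note that "Plain Text" lowercases to "plain text", which is not a registry key, so it reaches plaintext_parser through the .get() default; plaintext_parser then inspects kwargs to choose between PlainParser and a VisionParser.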
class Docx(DocxParser):
def __init__(self):
@ -365,7 +458,7 @@ class Markdown(MarkdownParser):
html_content = markdown(text)
soup = BeautifulSoup(html_content, 'html.parser')
return soup
def get_picture_urls(self, soup):
if soup:
return [img.get('src') for img in soup.find_all('img') if img.get('src')]
@ -375,7 +468,7 @@ class Markdown(MarkdownParser):
if soup:
return set([a.get('href') for a in soup.find_all('a') if a.get('href')])
return []
def get_pictures(self, text):
"""Download and open all images from markdown text."""
import requests
@ -416,11 +509,11 @@ class Markdown(MarkdownParser):
txt = f.read()
remainder, tables = self.extract_tables_and_remainder(f'{txt}\n', separate_tables=separate_tables)
# To eliminate duplicate tables in the chunking result, uncomment the line below and set separate_tables to True at line 410.
# extractor = MarkdownElementExtractor(remainder)
extractor = MarkdownElementExtractor(txt)
element_sections = extractor.extract_elements(delimiter)
sections = [(element, "") for element in element_sections]
tbls = []
for table in tables:
tbls.append(((None, markdown(table, extensions=['markdown.extensions.tables'])), ""))
@ -535,82 +628,29 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
if isinstance(layout_recognizer, bool):
layout_recognizer = "DeepDOC" if layout_recognizer else "Plain Text"
name = layout_recognizer.strip().lower()
parser = PARSERS.get(name, plaintext_parser)
callback(0.1, "Start to parse.")
if layout_recognizer == "DeepDOC":
pdf_parser = Pdf()
sections, tables = pdf_parser(filename if not binary else binary, from_page=from_page, to_page=to_page, callback=callback)
tables=vision_figure_parser_pdf_wrapper(tbls=tables,callback=callback,**kwargs)
sections, tables, _ = parser(
filename=filename,
binary=binary,
from_page=from_page,
to_page=to_page,
lang=lang,
callback=callback,
**kwargs
)
res = tokenize_table(tables, doc, is_english)
callback(0.8, "Finish parsing.")
if not sections and not tables:
return []
elif layout_recognizer == "MinerU":
mineru_executable = os.environ.get("MINERU_EXECUTABLE", "mineru")
mineru_api = os.environ.get("MINERU_APISERVER", "http://host.docker.internal:9987")
mineru_server_url = os.environ.get("MINERU_SERVER_URL", "")
mineru_backend = os.environ.get("MINERU_BACKEND", "pipeline")
pdf_parser = MinerUParser(mineru_path=mineru_executable, mineru_api=mineru_api, mineru_server_url=mineru_server_url)
ok, reason = pdf_parser.check_installation(backend=mineru_backend)
if not ok:
callback(-1, f"MinerU not found or server not accessible: {reason}")
return res
sections, tables = pdf_parser.parse_pdf(
filepath=filename,
binary=binary,
callback=callback,
output_dir=os.environ.get("MINERU_OUTPUT_DIR", ""),
backend=mineru_backend,
server_url=mineru_server_url,
delete_output=bool(int(os.environ.get("MINERU_DELETE_OUTPUT", 1))),
)
if name in ["tcadp", "docling", "mineru"]:
parser_config["chunk_token_num"] = 0
callback(0.8, "Finish parsing.")
elif layout_recognizer == "Docling":
pdf_parser = DoclingParser()
if not pdf_parser.check_installation():
callback(-1, "Docling not found.")
return res
sections, tables = pdf_parser.parse_pdf(
filepath=filename,
binary=binary,
callback=callback,
output_dir=os.environ.get("MINERU_OUTPUT_DIR", ""),
delete_output=bool(int(os.environ.get("MINERU_DELETE_OUTPUT", 1))),
)
parser_config["chunk_token_num"] = 0
res = tokenize_table(tables, doc, is_english)
callback(0.8, "Finish parsing.")
elif layout_recognizer == "TCADP Parser":
tcadp_parser = TCADPParser()
if not tcadp_parser.check_installation():
callback(-1, "TCADP parser not available. Please check Tencent Cloud API configuration.")
return res
sections, tables = tcadp_parser.parse_pdf(
filepath=filename,
binary=binary,
callback=callback,
output_dir=os.environ.get("TCADP_OUTPUT_DIR", ""),
file_type="PDF"
)
parser_config["chunk_token_num"] = 0
callback(0.8, "Finish parsing.")
else:
if layout_recognizer == "Plain Text":
pdf_parser = PlainParser()
else:
vision_model = LLMBundle(kwargs["tenant_id"], LLMType.IMAGE2TEXT, llm_name=layout_recognizer, lang=lang)
pdf_parser = VisionParser(vision_model=vision_model, **kwargs)
sections, tables = pdf_parser(filename if not binary else binary, from_page=from_page, to_page=to_page,
callback=callback)
res = tokenize_table(tables, doc, is_english)
callback(0.8, "Finish parsing.")
res = tokenize_table(tables, doc, is_english)
callback(0.8, "Finish parsing.")
elif re.search(r"\.(csv|xlsx?)$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.")
@ -735,9 +775,9 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
logging.info(f"Failed to chunk url in registered file type {url}: {e}")
sub_url_res = chunk(f"{index}.html", html_bytes, callback=callback, lang=lang, is_root=False, **kwargs)
url_res.extend(sub_url_res)
logging.info("naive_merge({}): {}".format(filename, timer() - st))
if embed_res:
res.extend(embed_res)
if url_res:

View File

@ -15,16 +15,15 @@
#
import logging
from tika import parser
from io import BytesIO
import re
from deepdoc.parser.utils import get_text
from rag.app import naive
from rag.nlp import rag_tokenizer, tokenize
from deepdoc.parser import PdfParser, ExcelParser, PlainParser, HtmlParser
from deepdoc.parser.figure_parser import vision_figure_parser_pdf_wrapper,vision_figure_parser_docx_wrapper
from deepdoc.parser import PdfParser, ExcelParser, HtmlParser
from deepdoc.parser.figure_parser import vision_figure_parser_docx_wrapper
from rag.app.naive import plaintext_parser, PARSERS
class Pdf(PdfParser):
def __call__(self, filename, binary=None, from_page=0,
@ -83,12 +82,34 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
callback(0.8, "Finish parsing.")
elif re.search(r"\.pdf$", filename, re.IGNORECASE):
pdf_parser = Pdf()
if parser_config.get("layout_recognize", "DeepDOC") == "Plain Text":
pdf_parser = PlainParser()
sections, tbls = pdf_parser(
filename if not binary else binary, to_page=to_page, callback=callback)
tbls=vision_figure_parser_pdf_wrapper(tbls=tbls,callback=callback,**kwargs)
layout_recognizer = parser_config.get("layout_recognize", "DeepDOC")
if isinstance(layout_recognizer, bool):
layout_recognizer = "DeepDOC" if layout_recognizer else "Plain Text"
name = layout_recognizer.strip().lower()
parser = PARSERS.get(name, plaintext_parser)
callback(0.1, "Start to parse.")
sections, tbls, _ = parser(
filename=filename,
binary=binary,
from_page=from_page,
to_page=to_page,
lang=lang,
callback=callback,
pdf_cls=Pdf,
**kwargs
)
if not sections and not tbls:
return []
if name in ["tcadp", "docling", "mineru"]:
parser_config["chunk_token_num"] = 0
callback(0.8, "Finish parsing.")
for (img, rows), poss in tbls:
if not rows:
continue

View File

@ -20,14 +20,11 @@ from io import BytesIO
from PIL import Image
from common.constants import LLMType
from api.db.services.llm_service import LLMBundle
from deepdoc.parser.pdf_parser import VisionParser
from rag.nlp import tokenize, is_english
from rag.nlp import rag_tokenizer
from deepdoc.parser import PdfParser, PptParser, PlainParser
from PyPDF2 import PdfReader as pdf2_read
from rag.app.naive import plaintext_parser, PARSERS
class Ppt(PptParser):
def __call__(self, fnm, from_page, to_page, callback=None):
@ -54,7 +51,6 @@ class Ppt(PptParser):
self.is_english = is_english(txts)
return [(txts[i], imgs[i]) for i in range(len(txts))]
class Pdf(PdfParser):
def __init__(self):
super().__init__()
@ -84,7 +80,7 @@ class Pdf(PdfParser):
res.append((lines, self.page_images[i]))
callback(0.9, "Page {}~{}: Parsing finished".format(
from_page, min(to_page, self.total_page)))
return res
return res, []
class PlainPdf(PlainParser):
@ -95,7 +91,7 @@ class PlainPdf(PlainParser):
for page in self.pdf.pages[from_page: to_page]:
page_txt.append(page.extract_text())
callback(0.9, "Parsing finished")
return [(txt, None) for txt in page_txt]
return [(txt, None) for txt in page_txt], []
def chunk(filename, binary=None, from_page=0, to_page=100000,
@ -130,20 +126,33 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
return res
elif re.search(r"\.pdf$", filename, re.IGNORECASE):
layout_recognizer = parser_config.get("layout_recognize", "DeepDOC")
if layout_recognizer == "DeepDOC":
pdf_parser = Pdf()
sections = pdf_parser(filename, binary, from_page=from_page, to_page=to_page, callback=callback)
elif layout_recognizer == "Plain Text":
pdf_parser = PlainParser()
sections, _ = pdf_parser(filename if not binary else binary, from_page=from_page, to_page=to_page,
callback=callback)
else:
vision_model = LLMBundle(kwargs["tenant_id"], LLMType.IMAGE2TEXT, llm_name=layout_recognizer, lang=lang)
pdf_parser = VisionParser(vision_model=vision_model, **kwargs)
sections, _ = pdf_parser(filename if not binary else binary, from_page=from_page, to_page=to_page,
callback=callback)
if isinstance(layout_recognizer, bool):
layout_recognizer = "DeepDOC" if layout_recognizer else "Plain Text"
name = layout_recognizer.strip().lower()
parser = PARSERS.get(name, plaintext_parser)
callback(0.1, "Start to parse.")
sections, _, _ = parser(
filename=filename,
binary=binary,
from_page=from_page,
to_page=to_page,
lang=lang,
callback=callback,
pdf_cls=Pdf,
**kwargs
)
if not sections:
return []
if name in ["tcadp", "docling", "mineru"]:
parser_config["chunk_token_num"] = 0
callback(0.8, "Finish parsing.")
for pn, (txt, img) in enumerate(sections):
d = copy.deepcopy(doc)
pn += from_page

View File

@ -21,6 +21,7 @@ from copy import deepcopy
from deepdoc.parser.utils import get_text
from rag.app.qa import Excel
from rag.nlp import rag_tokenizer
from common import globals
def beAdoc(d, q, a, eng, row_num=-1):
@ -124,7 +125,6 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
def label_question(question, kbs):
from api.db.services.knowledgebase_service import KnowledgebaseService
from graphrag.utils import get_tags_from_cache, set_tags_to_cache
from api import settings
tags = None
tag_kb_ids = []
for kb in kbs:
@ -133,14 +133,14 @@ def label_question(question, kbs):
if tag_kb_ids:
all_tags = get_tags_from_cache(tag_kb_ids)
if not all_tags:
all_tags = settings.retriever.all_tags_in_portion(kb.tenant_id, tag_kb_ids)
all_tags = globals.retriever.all_tags_in_portion(kb.tenant_id, tag_kb_ids)
set_tags_to_cache(tags=all_tags, kb_ids=tag_kb_ids)
else:
all_tags = json.loads(all_tags)
tag_kbs = KnowledgebaseService.get_by_ids(tag_kb_ids)
if not tag_kbs:
return tags
tags = settings.retriever.tag_query(question,
tags = globals.retriever.tag_query(question,
list(set([kb.tenant_id for kb in tag_kbs])),
tag_kb_ids,
all_tags,

View File

@ -20,10 +20,10 @@ import time
import argparse
from collections import defaultdict
from common import globals
from common.constants import LLMType
from api.db.services.llm_service import LLMBundle
from api.db.services.knowledgebase_service import KnowledgebaseService
from api import settings
from common.misc_utils import get_uuid
from rag.nlp import tokenize, search
from ranx import evaluate
@ -52,7 +52,7 @@ class Benchmark:
run = defaultdict(dict)
query_list = list(qrels.keys())
for query in query_list:
ranks = settings.retriever.retrieval(query, self.embd_mdl, self.tenant_id, [self.kb.id], 1, 30,
ranks = globals.retriever.retrieval(query, self.embd_mdl, self.tenant_id, [self.kb.id], 1, 30,
0.0, self.vector_similarity_weight)
if len(ranks["chunks"]) == 0:
print(f"deleted query: {query}")
@ -77,9 +77,9 @@ class Benchmark:
def init_index(self, vector_size: int):
if self.initialized_index:
return
if settings.docStoreConn.indexExist(self.index_name, self.kb_id):
settings.docStoreConn.deleteIdx(self.index_name, self.kb_id)
settings.docStoreConn.createIdx(self.index_name, self.kb_id, vector_size)
if globals.docStoreConn.indexExist(self.index_name, self.kb_id):
globals.docStoreConn.deleteIdx(self.index_name, self.kb_id)
globals.docStoreConn.createIdx(self.index_name, self.kb_id, vector_size)
self.initialized_index = True
def ms_marco_index(self, file_path, index_name):
@ -114,13 +114,13 @@ class Benchmark:
docs_count += len(docs)
docs, vector_size = self.embedding(docs)
self.init_index(vector_size)
settings.docStoreConn.insert(docs, self.index_name, self.kb_id)
globals.docStoreConn.insert(docs, self.index_name, self.kb_id)
docs = []
if docs:
docs, vector_size = self.embedding(docs)
self.init_index(vector_size)
settings.docStoreConn.insert(docs, self.index_name, self.kb_id)
globals.docStoreConn.insert(docs, self.index_name, self.kb_id)
return qrels, texts
def trivia_qa_index(self, file_path, index_name):
@ -155,12 +155,12 @@ class Benchmark:
docs_count += len(docs)
docs, vector_size = self.embedding(docs)
self.init_index(vector_size)
settings.docStoreConn.insert(docs,self.index_name)
globals.docStoreConn.insert(docs, self.index_name)
docs = []
docs, vector_size = self.embedding(docs)
self.init_index(vector_size)
settings.docStoreConn.insert(docs, self.index_name)
globals.docStoreConn.insert(docs, self.index_name)
return qrels, texts
def miracl_index(self, file_path, corpus_path, index_name):
@ -210,12 +210,12 @@ class Benchmark:
docs_count += len(docs)
docs, vector_size = self.embedding(docs)
self.init_index(vector_size)
settings.docStoreConn.insert(docs, self.index_name)
globals.docStoreConn.insert(docs, self.index_name)
docs = []
docs, vector_size = self.embedding(docs)
self.init_index(vector_size)
settings.docStoreConn.insert(docs, self.index_name)
globals.docStoreConn.insert(docs, self.index_name)
return qrels, texts
def save_results(self, qrels, run, texts, dataset, file_path):

View File

@ -21,7 +21,7 @@ from functools import partial
import trio
from common.misc_utils import get_uuid
from common.base64_image import id2image, image2id
from rag.utils.base64_image import id2image, image2id
from deepdoc.parser.pdf_parser import RAGFlowPdfParser
from rag.flow.base import ProcessBase, ProcessParamBase
from rag.flow.hierarchical_merger.schema import HierarchicalMergerFromUpstream

View File

@ -27,7 +27,7 @@ from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService
from api.db.services.llm_service import LLMBundle
from common.misc_utils import get_uuid
from common.base64_image import image2id
from rag.utils.base64_image import image2id
from deepdoc.parser import ExcelParser
from deepdoc.parser.mineru_parser import MinerUParser
from deepdoc.parser.pdf_parser import PlainParser, RAGFlowPdfParser, VisionParser

View File

@ -18,7 +18,7 @@ from functools import partial
import trio
from common.misc_utils import get_uuid
from common.base64_image import id2image, image2id
from rag.utils.base64_image import id2image, image2id
from deepdoc.parser.pdf_parser import RAGFlowPdfParser
from rag.flow.base import ProcessBase, ProcessParamBase
from rag.flow.splitter.schema import SplitterFromUpstream

View File

@ -29,7 +29,7 @@ from zhipuai import ZhipuAI
from common.log_utils import log_exception
from common.token_utils import num_tokens_from_string, truncate
from api import settings
from common import globals
import logging
@ -69,13 +69,13 @@ class BuiltinEmbed(Base):
_model_lock = threading.Lock()
def __init__(self, key, model_name, **kwargs):
logging.info(f"Initialize BuiltinEmbed according to settings.EMBEDDING_CFG: {settings.EMBEDDING_CFG}")
embedding_cfg = settings.EMBEDDING_CFG
logging.info(f"Initialize BuiltinEmbed according to globals.EMBEDDING_CFG: {globals.EMBEDDING_CFG}")
embedding_cfg = globals.EMBEDDING_CFG
if not BuiltinEmbed._model and "tei-" in os.getenv("COMPOSE_PROFILES", ""):
with BuiltinEmbed._model_lock:
BuiltinEmbed._model_name = settings.EMBEDDING_MDL
BuiltinEmbed._max_tokens = BuiltinEmbed.MAX_TOKENS.get(settings.EMBEDDING_MDL, 500)
BuiltinEmbed._model = HuggingFaceEmbed(embedding_cfg["api_key"], settings.EMBEDDING_MDL, base_url=embedding_cfg["base_url"])
BuiltinEmbed._model_name = globals.EMBEDDING_MDL
BuiltinEmbed._max_tokens = BuiltinEmbed.MAX_TOKENS.get(globals.EMBEDDING_MDL, 500)
BuiltinEmbed._model = HuggingFaceEmbed(embedding_cfg["api_key"], globals.EMBEDDING_MDL, base_url=embedding_cfg["base_url"])
self._model = BuiltinEmbed._model
self._model_name = BuiltinEmbed._model_name
self._max_tokens = BuiltinEmbed._max_tokens

View File

@ -15,49 +15,12 @@
#
import os
import logging
from common.config_utils import get_base_config, decrypt_database_config
from common.file_utils import get_project_base_directory
from common.misc_utils import pip_install_torch
# Server
RAG_CONF_PATH = os.path.join(get_project_base_directory(), "conf")
# Get storage type and document engine from system environment variables
STORAGE_IMPL_TYPE = os.getenv('STORAGE_IMPL', 'MINIO')
DOC_ENGINE = os.getenv('DOC_ENGINE', 'elasticsearch')
ES = {}
INFINITY = {}
AZURE = {}
S3 = {}
MINIO = {}
OSS = {}
OS = {}
# Initialize the selected configuration data based on environment variables to solve the problem of initialization errors due to lack of configuration
if DOC_ENGINE == 'elasticsearch':
ES = get_base_config("es", {})
elif DOC_ENGINE == 'opensearch':
OS = get_base_config("os", {})
elif DOC_ENGINE == 'infinity':
INFINITY = get_base_config("infinity", {"uri": "infinity:23817"})
if STORAGE_IMPL_TYPE in ['AZURE_SPN', 'AZURE_SAS']:
AZURE = get_base_config("azure", {})
elif STORAGE_IMPL_TYPE == 'AWS_S3':
S3 = get_base_config("s3", {})
elif STORAGE_IMPL_TYPE == 'MINIO':
MINIO = decrypt_database_config(name="minio")
elif STORAGE_IMPL_TYPE == 'OSS':
OSS = get_base_config("oss", {})
try:
REDIS = decrypt_database_config(name="redis")
except Exception:
try:
REDIS = get_base_config("redis", {})
except Exception:
REDIS = {}
DOC_MAXIMUM_SIZE = int(os.environ.get("MAX_CONTENT_LENGTH", 128 * 1024 * 1024))
DOC_BULK_SIZE = int(os.environ.get("DOC_BULK_SIZE", 4))
EMBEDDING_BATCH_SIZE = int(os.environ.get("EMBEDDING_BATCH_SIZE", 16))
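The values that remain here are read from the environment at import time, so overrides must be exported before the module is first imported; a small sketch:

import os

os.environ["DOC_BULK_SIZE"] = "8"          # must be set before `from rag import settings`
os.environ["EMBEDDING_BATCH_SIZE"] = "32"

from rag import settings

assert settings.DOC_BULK_SIZE == 8
assert settings.EMBEDDING_BATCH_SIZE == 32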

View File

@ -26,13 +26,12 @@ import traceback
from api.db.services.connector_service import SyncLogsService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.utils.log_utils import init_root_logger, get_project_base_directory
from api.utils.configs import show_configs
from common.log_utils import init_root_logger
from common.config_utils import show_configs
from common.data_source import BlobStorageConnector
import logging
import os
from datetime import datetime, timezone
import tracemalloc
import signal
import trio
import faulthandler
@ -41,6 +40,7 @@ from api import settings
from api.versions import get_ragflow_version
from common.data_source.confluence_connector import ConfluenceConnector
from common.data_source.utils import load_all_docs_from_checkpoint_connector
from common.signal_utils import start_tracemalloc_and_snapshot, stop_tracemalloc
MAX_CONCURRENT_TASKS = int(os.environ.get('MAX_CONCURRENT_TASKS', "5"))
task_limiter = trio.Semaphore(MAX_CONCURRENT_TASKS)
@ -51,11 +51,39 @@ class SyncBase:
self.conf = conf
async def __call__(self, task: dict):
SyncLogsService.start(task["id"])
SyncLogsService.start(task["id"], task["connector_id"])
try:
async with task_limiter:
with trio.fail_after(task["timeout_secs"]):
task["poll_range_start"] = await self._run(task)
document_batch_generator = await self._generate(task)
doc_num = 0
next_update = datetime(1970, 1, 1, tzinfo=timezone.utc)
if task["poll_range_start"]:
next_update = task["poll_range_start"]
for document_batch in document_batch_generator:
min_update = min([doc.doc_updated_at for doc in document_batch])
max_update = max([doc.doc_updated_at for doc in document_batch])
next_update = max([next_update, max_update])
docs = [{
"id": doc.id,
"connector_id": task["connector_id"],
"source": FileSource.S3,
"semantic_identifier": doc.semantic_identifier,
"extension": doc.extension,
"size_bytes": doc.size_bytes,
"doc_updated_at": doc.doc_updated_at,
"blob": doc.blob
} for doc in document_batch]
e, kb = KnowledgebaseService.get_by_id(task["kb_id"])
err, dids = SyncLogsService.duplicate_and_parse(kb, docs, task["tenant_id"], f"{FileSource.S3}/{task['connector_id']}")
SyncLogsService.increase_docs(task["id"], min_update, max_update, len(docs), "\n".join(err), len(err))
doc_num += len(docs)
logging.info("{} docs synchronized till {}".format(doc_num, next_update))
SyncLogsService.done(task["id"], task["connector_id"])
task["poll_range_start"] = next_update
except Exception as ex:
msg = '\n'.join([
''.join(traceback.format_exception_only(None, ex)).strip(),
@ -65,12 +93,12 @@ class SyncBase:
SyncLogsService.schedule(task["connector_id"], task["kb_id"], task["poll_range_start"])
async def _run(self, task: dict):
async def _generate(self, task: dict):
raise NotImplementedError
class S3(SyncBase):
async def _run(self, task: dict):
async def _generate(self, task: dict):
self.connector = BlobStorageConnector(
bucket_type=self.conf.get("bucket_type", "s3"),
bucket_name=self.conf["bucket_name"],
@ -85,40 +113,11 @@ class S3(SyncBase):
self.conf["bucket_name"],
begin_info
))
doc_num = 0
next_update = datetime(1970, 1, 1, tzinfo=timezone.utc)
if task["poll_range_start"]:
next_update = task["poll_range_start"]
for document_batch in document_batch_generator:
min_update = min([doc.doc_updated_at for doc in document_batch])
max_update = max([doc.doc_updated_at for doc in document_batch])
next_update = max([next_update, max_update])
docs = [{
"id": doc.id,
"connector_id": task["connector_id"],
"source": FileSource.S3,
"semantic_identifier": doc.semantic_identifier,
"extension": doc.extension,
"size_bytes": doc.size_bytes,
"doc_updated_at": doc.doc_updated_at,
"blob": doc.blob
} for doc in document_batch]
e, kb = KnowledgebaseService.get_by_id(task["kb_id"])
err, dids = SyncLogsService.duplicate_and_parse(kb, docs, task["tenant_id"], f"{FileSource.S3}/{task['connector_id']}")
SyncLogsService.increase_docs(task["id"], min_update, max_update, len(docs), "\n".join(err), len(err))
doc_num += len(docs)
logging.info("{} docs synchronized from {}: {} {}".format(doc_num, self.conf.get("bucket_type", "s3"),
self.conf["bucket_name"],
begin_info
))
SyncLogsService.done(task["id"])
return next_update
return document_batch_generator
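This refactor makes SyncBase a template method: __call__ now owns the shared loop (start/done logging, poll-range tracking, duplicate_and_parse), while each connector only supplies a document-batch generator from _generate. A sketch of the contract for a new connector (make_client is a hypothetical helper; note that __call__ above still labels every batch FileSource.S3, so non-S3 sources would need that parameterized):

class MyStorage(SyncBase):
    async def _generate(self, task: dict):
        # Connect using self.conf and return an iterable of document batches.
        # Each document must expose id, semantic_identifier, extension,
        # size_bytes, doc_updated_at and blob, as consumed by SyncBase.__call__.
        client = make_client(self.conf)  # hypothetical connector client
        return client.poll(start=task["poll_range_start"])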
class Confluence(SyncBase):
async def _run(self, task: dict):
async def _generate(self, task: dict):
from common.data_source.interfaces import StaticCredentialsProvider
from common.data_source.config import DocumentSource
@ -156,85 +155,57 @@ class Confluence(SyncBase):
)
logging.info("Connect to Confluence: {} {}".format(self.conf["wiki_base"], begin_info))
doc_num = 0
next_update = datetime(1970, 1, 1, tzinfo=timezone.utc)
if task["poll_range_start"]:
next_update = task["poll_range_start"]
for doc in document_generator:
min_update = doc.doc_updated_at if doc.doc_updated_at else next_update
max_update = doc.doc_updated_at if doc.doc_updated_at else next_update
next_update = max([next_update, max_update])
docs = [{
"id": doc.id,
"connector_id": task["connector_id"],
"source": FileSource.CONFLUENCE,
"semantic_identifier": doc.semantic_identifier,
"extension": doc.extension,
"size_bytes": doc.size_bytes,
"doc_updated_at": doc.doc_updated_at,
"blob": doc.blob
}]
e, kb = KnowledgebaseService.get_by_id(task["kb_id"])
err, dids = SyncLogsService.duplicate_and_parse(kb, docs, task["tenant_id"], f"{FileSource.CONFLUENCE}/{task['connector_id']}")
SyncLogsService.increase_docs(task["id"], min_update, max_update, len(docs), "\n".join(err), len(err))
doc_num += len(docs)
logging.info("{} docs synchronized from Confluence: {} {}".format(doc_num, self.conf["wiki_base"], begin_info))
SyncLogsService.done(task["id"])
return next_update
return document_generator
class Notion(SyncBase):
async def __call__(self, task: dict):
async def _generate(self, task: dict):
pass
class Discord(SyncBase):
async def __call__(self, task: dict):
async def _generate(self, task: dict):
pass
class Gmail(SyncBase):
async def __call__(self, task: dict):
async def _generate(self, task: dict):
pass
class GoogleDriver(SyncBase):
async def __call__(self, task: dict):
async def _generate(self, task: dict):
pass
class Jira(SyncBase):
async def __call__(self, task: dict):
async def _generate(self, task: dict):
pass
class SharePoint(SyncBase):
async def __call__(self, task: dict):
async def _generate(self, task: dict):
pass
class Slack(SyncBase):
async def __call__(self, task: dict):
async def _generate(self, task: dict):
pass
class Teams(SyncBase):
async def __call__(self, task: dict):
async def _generate(self, task: dict):
pass
func_factory = {
FileSource.S3: S3,
FileSource.NOTION: Notion,
@ -263,41 +234,6 @@ async def dispatch_tasks():
stop_event = threading.Event()
# SIGUSR1 handler: start tracemalloc and take snapshot
def start_tracemalloc_and_snapshot(signum, frame):
if not tracemalloc.is_tracing():
logging.info("start tracemalloc")
tracemalloc.start()
else:
logging.info("tracemalloc is already running")
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
snapshot_file = f"snapshot_{timestamp}.trace"
snapshot_file = os.path.abspath(os.path.join(get_project_base_directory(), "logs", f"{os.getpid()}_snapshot_{timestamp}.trace"))
snapshot = tracemalloc.take_snapshot()
snapshot.dump(snapshot_file)
current, peak = tracemalloc.get_traced_memory()
if sys.platform == "win32":
import psutil
process = psutil.Process()
max_rss = process.memory_info().rss / 1024
else:
import resource
max_rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
logging.info(f"taken snapshot {snapshot_file}. max RSS={max_rss / 1000:.2f} MB, current memory usage: {current / 10**6:.2f} MB, Peak memory usage: {peak / 10**6:.2f} MB")
# SIGUSR2 handler: stop tracemalloc
def stop_tracemalloc(signum, frame):
if tracemalloc.is_tracing():
logging.info("stop tracemalloc")
tracemalloc.stop()
else:
logging.info("tracemalloc not running")
def signal_handler(sig, frame):
logging.info("Received interrupt signal, shutting down...")
stop_event.set()

View File

@ -27,9 +27,8 @@ from api.db.services.canvas_service import UserCanvasService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
from common.connection_utils import timeout
from common.base64_image import image2id
from rag.utils.base64_image import image2id
from common.log_utils import init_root_logger
from common.file_utils import get_project_base_directory
from common.config_utils import show_configs
from graphrag.general.index import run_graphrag_for_kb
from graphrag.utils import get_llm_cache, set_llm_cache, get_tags_from_cache, set_tags_to_cache
@ -45,7 +44,6 @@ import re
from functools import partial
from multiprocessing.context import TimeoutError
from timeit import default_timer as timer
import tracemalloc
import signal
import trio
import exceptiongroup
@ -69,6 +67,8 @@ from common.token_utils import num_tokens_from_string, truncate
from rag.utils.redis_conn import REDIS_CONN, RedisDistributedLock
from rag.utils.storage_factory import STORAGE_IMPL
from graphrag.utils import chat_limiter
from common.signal_utils import start_tracemalloc_and_snapshot, stop_tracemalloc
from common import globals
BATCH_SIZE = 64
@ -129,40 +129,6 @@ def signal_handler(sig, frame):
sys.exit(0)
# SIGUSR1 handler: start tracemalloc and take snapshot
def start_tracemalloc_and_snapshot(signum, frame):
if not tracemalloc.is_tracing():
logging.info("start tracemalloc")
tracemalloc.start()
else:
logging.info("tracemalloc is already running")
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
snapshot_file = f"snapshot_{timestamp}.trace"
snapshot_file = os.path.abspath(os.path.join(get_project_base_directory(), "logs", f"{os.getpid()}_snapshot_{timestamp}.trace"))
snapshot = tracemalloc.take_snapshot()
snapshot.dump(snapshot_file)
current, peak = tracemalloc.get_traced_memory()
if sys.platform == "win32":
import psutil
process = psutil.Process()
max_rss = process.memory_info().rss / 1024
else:
import resource
max_rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
logging.info(f"taken snapshot {snapshot_file}. max RSS={max_rss / 1000:.2f} MB, current memory usage: {current / 10**6:.2f} MB, Peak memory usage: {peak / 10**6:.2f} MB")
# SIGUSR2 handler: stop tracemalloc
def stop_tracemalloc(signum, frame):
if tracemalloc.is_tracing():
logging.info("stop tracemalloc")
tracemalloc.stop()
else:
logging.info("tracemalloc not running")
class TaskCanceledException(Exception):
def __init__(self, msg):
self.msg = msg
@ -384,7 +350,7 @@ async def build_chunks(task, progress_callback):
examples = []
all_tags = get_tags_from_cache(kb_ids)
if not all_tags:
all_tags = settings.retriever.all_tags_in_portion(tenant_id, kb_ids, S)
all_tags = globals.retriever.all_tags_in_portion(tenant_id, kb_ids, S)
set_tags_to_cache(kb_ids, all_tags)
else:
all_tags = json.loads(all_tags)
@ -397,7 +363,7 @@ async def build_chunks(task, progress_callback):
if task_canceled:
progress_callback(-1, msg="Task has been canceled.")
return
if settings.retriever.tag_content(tenant_id, kb_ids, d, all_tags, topn_tags=topn_tags, S=S) and len(d[TAG_FLD]) > 0:
if globals.retriever.tag_content(tenant_id, kb_ids, d, all_tags, topn_tags=topn_tags, S=S) and len(d[TAG_FLD]) > 0:
examples.append({"content": d["content_with_weight"], TAG_FLD: d[TAG_FLD]})
else:
docs_to_tag.append(d)
@ -458,7 +424,7 @@ def build_TOC(task, docs, progress_callback):
def init_kb(row, vector_size: int):
idxnm = search.index_name(row["tenant_id"])
return settings.docStoreConn.createIdx(idxnm, row.get("kb_id", ""), vector_size)
return globals.docStoreConn.createIdx(idxnm, row.get("kb_id", ""), vector_size)
async def embedding(docs, mdl, parser_config=None, callback=None):
@ -682,7 +648,7 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si
chunks = []
vctr_nm = "q_%d_vec"%vector_size
for doc_id in doc_ids:
for d in settings.retriever.chunk_list(doc_id, row["tenant_id"], [str(row["kb_id"])],
for d in globals.retriever.chunk_list(doc_id, row["tenant_id"], [str(row["kb_id"])],
fields=["content_with_weight", vctr_nm],
sort_by_position=True):
chunks.append((d["content_with_weight"], np.array(d[vctr_nm])))
@ -733,7 +699,7 @@ async def delete_image(kb_id, chunk_id):
async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_callback):
for b in range(0, len(chunks), DOC_BULK_SIZE):
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b:b + DOC_BULK_SIZE], search.index_name(task_tenant_id), task_dataset_id))
doc_store_result = await trio.to_thread.run_sync(lambda: globals.docStoreConn.insert(chunks[b:b + DOC_BULK_SIZE], search.index_name(task_tenant_id), task_dataset_id))
task_canceled = has_canceled(task_id)
if task_canceled:
progress_callback(-1, msg="Task has been canceled.")
@ -750,7 +716,7 @@ async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_c
TaskService.update_chunk_ids(task_id, chunk_ids_str)
except DoesNotExist:
logging.warning(f"do_handle_task update_chunk_ids failed since task {task_id} is unknown.")
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(task_tenant_id), task_dataset_id))
doc_store_result = await trio.to_thread.run_sync(lambda: globals.docStoreConn.delete({"id": chunk_ids}, search.index_name(task_tenant_id), task_dataset_id))
async with trio.open_nursery() as nursery:
for chunk_id in chunk_ids:
nursery.start_soon(delete_image, task_dataset_id, chunk_id)
@ -786,7 +752,7 @@ async def do_handle_task(task):
progress_callback = partial(set_progress, task_id, task_from_page, task_to_page)
# FIXME: workaround, Infinity doesn't support table parsing method, this check is to notify user
lower_case_doc_engine = settings.DOC_ENGINE.lower()
lower_case_doc_engine = globals.DOC_ENGINE.lower()
if lower_case_doc_engine == 'infinity' and task['parser_id'].lower() == 'table':
error_message = "Table parsing method is not supported by Infinity, please use other parsing methods or use Elasticsearch as the document engine."
progress_callback(-1, msg=error_message)
@ -1067,8 +1033,8 @@ async def main():
logging.info(f'RAGFlow version: {get_ragflow_version()}')
show_configs()
settings.init_settings()
from api.settings import EMBEDDING_CFG
logging.info(f'api.settings.EMBEDDING_CFG: {EMBEDDING_CFG}')
from common import globals
logging.info(f'globals.EMBEDDING_CFG: {globals.EMBEDDING_CFG}')
print_rag_settings()
if sys.platform != "win32":
signal.signal(signal.SIGUSR1, start_tracemalloc_and_snapshot)

View File

@ -18,17 +18,17 @@ import logging
import os
import time
from io import BytesIO
from rag import settings
from common.decorator import singleton
from azure.storage.blob import ContainerClient
from common import globals
@singleton
class RAGFlowAzureSasBlob:
def __init__(self):
self.conn = None
self.container_url = os.getenv('CONTAINER_URL', settings.AZURE["container_url"])
self.sas_token = os.getenv('SAS_TOKEN', settings.AZURE["sas_token"])
self.container_url = os.getenv('CONTAINER_URL', globals.AZURE["container_url"])
self.sas_token = os.getenv('SAS_TOKEN', globals.AZURE["sas_token"])
self.__open__()
def __open__(self):

View File

@ -17,21 +17,21 @@
import logging
import os
import time
from rag import settings
from common.decorator import singleton
from azure.identity import ClientSecretCredential, AzureAuthorityHosts
from azure.storage.filedatalake import FileSystemClient
from common import globals
@singleton
class RAGFlowAzureSpnBlob:
def __init__(self):
self.conn = None
self.account_url = os.getenv('ACCOUNT_URL', settings.AZURE["account_url"])
self.client_id = os.getenv('CLIENT_ID', settings.AZURE["client_id"])
self.secret = os.getenv('SECRET', settings.AZURE["secret"])
self.tenant_id = os.getenv('TENANT_ID', settings.AZURE["tenant_id"])
self.container_name = os.getenv('CONTAINER_NAME', settings.AZURE["container_name"])
self.account_url = os.getenv('ACCOUNT_URL', globals.AZURE["account_url"])
self.client_id = os.getenv('CLIENT_ID', globals.AZURE["client_id"])
self.secret = os.getenv('SECRET', globals.AZURE["secret"])
self.tenant_id = os.getenv('TENANT_ID', globals.AZURE["tenant_id"])
self.container_name = os.getenv('CONTAINER_NAME', globals.AZURE["container_name"])
self.__open__()
def __open__(self):

View File

@ -24,7 +24,6 @@ import copy
from elasticsearch import Elasticsearch, NotFoundError
from elasticsearch_dsl import UpdateByQuery, Q, Search, Index
from elastic_transport import ConnectionTimeout
from rag import settings
from rag.settings import TAG_FLD, PAGERANK_FLD
from common.decorator import singleton
from common.file_utils import get_project_base_directory
@ -33,6 +32,7 @@ from rag.utils.doc_store_conn import DocStoreConnection, MatchExpr, OrderByExpr,
FusionExpr
from rag.nlp import is_english, rag_tokenizer
from common.float_utils import get_float
from common import globals
ATTEMPT_TIME = 2
@ -43,17 +43,17 @@ logger = logging.getLogger('ragflow.es_conn')
class ESConnection(DocStoreConnection):
def __init__(self):
self.info = {}
logger.info(f"Use Elasticsearch {settings.ES['hosts']} as the doc engine.")
logger.info(f"Use Elasticsearch {globals.ES['hosts']} as the doc engine.")
for _ in range(ATTEMPT_TIME):
try:
if self._connect():
break
except Exception as e:
logger.warning(f"{str(e)}. Waiting Elasticsearch {settings.ES['hosts']} to be healthy.")
logger.warning(f"{str(e)}. Waiting Elasticsearch {globals.ES['hosts']} to be healthy.")
time.sleep(5)
if not self.es.ping():
msg = f"Elasticsearch {settings.ES['hosts']} is unhealthy in 120s."
msg = f"Elasticsearch {globals.ES['hosts']} is unhealthy in 120s."
logger.error(msg)
raise Exception(msg)
v = self.info.get("version", {"number": "8.11.3"})
@ -68,14 +68,14 @@ class ESConnection(DocStoreConnection):
logger.error(msg)
raise Exception(msg)
self.mapping = json.load(open(fp_mapping, "r"))
logger.info(f"Elasticsearch {settings.ES['hosts']} is healthy.")
logger.info(f"Elasticsearch {globals.ES['hosts']} is healthy.")
def _connect(self):
self.es = Elasticsearch(
settings.ES["hosts"].split(","),
basic_auth=(settings.ES["username"], settings.ES[
"password"]) if "username" in settings.ES and "password" in settings.ES else None,
verify_certs= settings.ES.get("verify_certs", False),
globals.ES["hosts"].split(","),
basic_auth=(globals.ES["username"], globals.ES[
"password"]) if "username" in globals.ES and "password" in globals.ES else None,
verify_certs=globals.ES.get("verify_certs", False),
timeout=600)
if self.es:
self.info = self.es.info()

View File

@ -25,11 +25,11 @@ from infinity.common import ConflictType, InfinityException, SortType
from infinity.index import IndexInfo, IndexType
from infinity.connection_pool import ConnectionPool
from infinity.errors import ErrorCode
from rag import settings
from rag.settings import PAGERANK_FLD, TAG_FLD
from common.decorator import singleton
import pandas as pd
from common.file_utils import get_project_base_directory
from common import globals
from rag.nlp import is_english
from rag.utils.doc_store_conn import (
@ -130,8 +130,8 @@ def concat_dataframes(df_list: list[pd.DataFrame], selectFields: list[str]) -> p
@singleton
class InfinityConnection(DocStoreConnection):
def __init__(self):
self.dbName = settings.INFINITY.get("db_name", "default_db")
infinity_uri = settings.INFINITY["uri"]
self.dbName = globals.INFINITY.get("db_name", "default_db")
infinity_uri = globals.INFINITY["uri"]
if ":" in infinity_uri:
host, port = infinity_uri.split(":")
infinity_uri = infinity.common.NetworkAddress(host, int(port))

View File

@ -20,8 +20,8 @@ from minio import Minio
from minio.commonconfig import CopySource
from minio.error import S3Error
from io import BytesIO
from rag import settings
from common.decorator import singleton
from common import globals
@singleton
@ -38,14 +38,14 @@ class RAGFlowMinio:
pass
try:
self.conn = Minio(settings.MINIO["host"],
access_key=settings.MINIO["user"],
secret_key=settings.MINIO["password"],
self.conn = Minio(globals.MINIO["host"],
access_key=globals.MINIO["user"],
secret_key=globals.MINIO["password"],
secure=False
)
except Exception:
logging.exception(
"Fail to connect %s " % settings.MINIO["host"])
"Fail to connect %s " % globals.MINIO["host"])
def __close__(self):
del self.conn

View File

@ -24,13 +24,13 @@ import copy
from opensearchpy import OpenSearch, NotFoundError
from opensearchpy import UpdateByQuery, Q, Search, Index
from opensearchpy import ConnectionTimeout
from rag import settings
from rag.settings import TAG_FLD, PAGERANK_FLD
from common.decorator import singleton
from common.file_utils import get_project_base_directory
from rag.utils.doc_store_conn import DocStoreConnection, MatchExpr, OrderByExpr, MatchTextExpr, MatchDenseExpr, \
FusionExpr
from rag.nlp import is_english, rag_tokenizer
from common import globals
ATTEMPT_TIME = 2
@ -41,13 +41,13 @@ logger = logging.getLogger('ragflow.opensearch_conn')
class OSConnection(DocStoreConnection):
def __init__(self):
self.info = {}
logger.info(f"Use OpenSearch {settings.OS['hosts']} as the doc engine.")
logger.info(f"Use OpenSearch {globals.OS['hosts']} as the doc engine.")
for _ in range(ATTEMPT_TIME):
try:
self.os = OpenSearch(
settings.OS["hosts"].split(","),
http_auth=(settings.OS["username"], settings.OS[
"password"]) if "username" in settings.OS and "password" in settings.OS else None,
globals.OS["hosts"].split(","),
http_auth=(globals.OS["username"], globals.OS[
"password"]) if "username" in globals.OS and "password" in globals.OS else None,
verify_certs=False,
timeout=600
)
@ -55,10 +55,10 @@ class OSConnection(DocStoreConnection):
self.info = self.os.info()
break
except Exception as e:
logger.warning(f"{str(e)}. Waiting OpenSearch {settings.OS['hosts']} to be healthy.")
logger.warning(f"{str(e)}. Waiting OpenSearch {globals.OS['hosts']} to be healthy.")
time.sleep(5)
if not self.os.ping():
msg = f"OpenSearch {settings.OS['hosts']} is unhealthy in 120s."
msg = f"OpenSearch {globals.OS['hosts']} is unhealthy in 120s."
logger.error(msg)
raise Exception(msg)
v = self.info.get("version", {"number": "2.18.0"})
@ -73,7 +73,7 @@ class OSConnection(DocStoreConnection):
logger.error(msg)
raise Exception(msg)
self.mapping = json.load(open(fp_mapping, "r"))
logger.info(f"OpenSearch {settings.OS['hosts']} is healthy.")
logger.info(f"OpenSearch {globals.OS['hosts']} is healthy.")
"""
Database operations

View File

@ -20,14 +20,14 @@ from botocore.config import Config
import time
from io import BytesIO
from common.decorator import singleton
from rag import settings
from common import globals
@singleton
class RAGFlowOSS:
def __init__(self):
self.conn = None
self.oss_config = settings.OSS
self.oss_config = globals.OSS
self.access_key = self.oss_config.get('access_key', None)
self.secret_key = self.oss_config.get('secret_key', None)
self.endpoint_url = self.oss_config.get('endpoint_url', None)

View File

@ -19,8 +19,8 @@ import json
import uuid
import valkey as redis
from rag import settings
from common.decorator import singleton
from common import globals
from valkey.lock import Lock
import trio
@ -61,7 +61,7 @@ class RedisDB:
def __init__(self):
self.REDIS = None
self.config = settings.REDIS
self.config = globals.REDIS
self.__open__()
def register_scripts(self) -> None:

View File

@ -21,13 +21,14 @@ from botocore.config import Config
import time
from io import BytesIO
from common.decorator import singleton
from rag import settings
from common import globals
@singleton
class RAGFlowS3:
def __init__(self):
self.conn = None
self.s3_config = settings.S3
self.s3_config = globals.S3
self.access_key = self.s3_config.get('access_key', None)
self.secret_key = self.s3_config.get('secret_key', None)
self.session_token = self.s3_config.get('session_token', None)

View File

@ -15,10 +15,10 @@ import {
} from '@/components/ui/form';
import { Input } from '@/components/ui/input';
import { Separator } from '@/components/ui/separator';
import { SwitchOperatorOptions } from '@/constants/agent';
import { useBuildSwitchOperatorOptions } from '@/hooks/logic-hooks/use-build-operator-options';
import { useFetchKnowledgeMetadata } from '@/hooks/use-knowledge-request';
import { SwitchOperatorOptions } from '@/pages/agent/constant';
import { PromptEditor } from '@/pages/agent/form/components/prompt-editor';
import { useBuildSwitchOperatorOptions } from '@/pages/agent/form/switch-form';
import { Plus, X } from 'lucide-react';
import { useCallback } from 'react';
import { useFieldArray, useFormContext } from 'react-hook-form';

View File

@ -13,6 +13,7 @@ const ScrollBar = React.forwardRef<
ref={ref}
orientation={orientation}
className={cn(
'z-[100]',
'flex touch-none select-none transition-colors',
orientation === 'vertical' &&
'h-full w-2.5 border-l border-l-transparent p-[1px]',
@ -22,7 +23,7 @@ const ScrollBar = React.forwardRef<
)}
{...props}
>
<ScrollAreaPrimitive.ScrollAreaThumb className="relative flex-1 rounded-full bg-border" />
<ScrollAreaPrimitive.ScrollAreaThumb className="relative flex-1 rounded-full bg-border backdrop-blur-md" />
</ScrollAreaPrimitive.ScrollAreaScrollbar>
));
ScrollBar.displayName = ScrollAreaPrimitive.ScrollAreaScrollbar.displayName;

View File

@ -1,4 +1,5 @@
import { setInitialChatVariableEnabledFieldValue } from '@/utils/chat';
import { Circle, CircleSlash2 } from 'lucide-react';
import { ChatVariableEnabledField, variableEnabledFieldMap } from './chat';
export enum ProgrammingLanguage {
@ -117,3 +118,53 @@ export enum Operator {
HierarchicalMerger = 'HierarchicalMerger',
Extractor = 'Extractor',
}
export enum ComparisonOperator {
Equal = '=',
NotEqual = '≠',
GreatThan = '>',
GreatEqual = '≥',
LessThan = '<',
LessEqual = '≤',
Contains = 'contains',
NotContains = 'not contains',
StartWith = 'start with',
EndWith = 'end with',
Empty = 'empty',
NotEmpty = 'not empty',
}
export const SwitchOperatorOptions = [
{ value: ComparisonOperator.Equal, label: 'equal', icon: 'equal' },
{ value: ComparisonOperator.NotEqual, label: 'notEqual', icon: 'not-equals' },
{ value: ComparisonOperator.GreatThan, label: 'gt', icon: 'Less' },
{
value: ComparisonOperator.GreatEqual,
label: 'ge',
icon: 'Greater-or-equal',
},
{ value: ComparisonOperator.LessThan, label: 'lt', icon: 'Less' },
{ value: ComparisonOperator.LessEqual, label: 'le', icon: 'less-or-equal' },
{ value: ComparisonOperator.Contains, label: 'contains', icon: 'Contains' },
{
value: ComparisonOperator.NotContains,
label: 'notContains',
icon: 'not-contains',
},
{
value: ComparisonOperator.StartWith,
label: 'startWith',
icon: 'list-start',
},
{ value: ComparisonOperator.EndWith, label: 'endWith', icon: 'list-end' },
{
value: ComparisonOperator.Empty,
label: 'empty',
icon: <Circle className="size-4" />,
},
{
value: ComparisonOperator.NotEmpty,
label: 'notEmpty',
icon: <CircleSlash2 className="size-4" />,
},
];

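With `ComparisonOperator` and `SwitchOperatorOptions` now living in `@/constants/agent`, any consumer can share one canonical operator set. As a rough sketch (not part of this PR — `compare` is a hypothetical helper, and the numeric operators are omitted), the enum could be mapped to string predicates like this:

```typescript
// Hypothetical helper: maps the shared enum to string predicates.
import { ComparisonOperator } from '@/constants/agent';

export function compare(op: ComparisonOperator, a: string, b: string): boolean {
  switch (op) {
    case ComparisonOperator.Equal:
      return a === b;
    case ComparisonOperator.NotEqual:
      return a !== b;
    case ComparisonOperator.Contains:
      return a.includes(b);
    case ComparisonOperator.NotContains:
      return !a.includes(b);
    case ComparisonOperator.StartWith:
      return a.startsWith(b);
    case ComparisonOperator.EndWith:
      return a.endsWith(b);
    case ComparisonOperator.Empty:
      return a.length === 0;
    case ComparisonOperator.NotEmpty:
      return a.length > 0;
    default:
      // GreatThan/GreatEqual/LessThan/LessEqual would need numeric parsing.
      return false;
  }
}
```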
View File

@ -0,0 +1,45 @@
import { IconFont } from '@/components/icon-font';
import { ComparisonOperator, SwitchOperatorOptions } from '@/constants/agent';
import { cn } from '@/lib/utils';
import { useMemo } from 'react';
import { useTranslation } from 'react-i18next';
export const LogicalOperatorIcon = function OperatorIcon({
icon,
value,
}: Omit<(typeof SwitchOperatorOptions)[0], 'label'>) {
if (typeof icon === 'string') {
return (
<IconFont
name={icon}
className={cn('size-4', {
'rotate-180': value === '>',
})}
></IconFont>
);
}
return icon;
};
export function useBuildSwitchOperatorOptions(
subset: ComparisonOperator[] = [],
) {
const { t } = useTranslation();
const switchOperatorOptions = useMemo(() => {
return SwitchOperatorOptions.filter((x) =>
subset.some((y) => y === x.value),
).map((x) => ({
value: x.value,
icon: (
<LogicalOperatorIcon
icon={x.icon}
value={x.value}
></LogicalOperatorIcon>
),
label: t(`flow.switchOperatorOptions.${x.label}`),
}));
}, [t]);
return switchOperatorOptions;
}

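A minimal sketch of consuming the relocated hook, assuming the `flow.switchOperatorOptions.*` i18n keys that predate this PR. Since the hook's `useMemo` lists only `t` in its dependency array, callers should pass a `subset` whose identity is stable:

```typescript
// Hypothetical consumer of the relocated hook.
import { SelectWithSearch } from '@/components/originui/select-with-search';
import { ComparisonOperator } from '@/constants/agent';
import { useBuildSwitchOperatorOptions } from '@/hooks/logic-hooks/use-build-operator-options';

// Declared at module scope so the array identity never changes.
const EQUALITY_OPERATORS = [
  ComparisonOperator.Equal,
  ComparisonOperator.NotEqual,
];

export function EqualityOperatorSelect() {
  const options = useBuildSwitchOperatorOptions(EQUALITY_OPERATORS);
  return <SelectWithSearch options={options} />;
}
```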
View File

@ -1980,6 +1980,10 @@ Important structured information may include: names, dates, locations, events, k
deleteRole: 'Delete role',
deleteRoleConfirmation:
'Are you sure you want to delete this role? This action cannot be undone.',
alive: 'Alive',
timeout: 'Timeout',
fail: 'Fail',
},
},
};

View File

@ -1,8 +1,11 @@
import Spotlight from '@/components/spotlight';
import { Card, CardContent } from '@/components/ui/card';
function AdminMonitoring() {
return (
<Card className="!shadow-none h-full border border-border-button bg-transparent rounded-xl overflow-x-hidden overflow-y-auto">
<Card className="!shadow-none relative h-full border border-border-button bg-transparent rounded-xl overflow-x-hidden overflow-y-auto">
<Spotlight />
<CardContent className="size-full p-0">
<iframe />
</CardContent>

View File

@ -5,6 +5,7 @@ import { useTranslation } from 'react-i18next';
import { useMutation, useQuery, useQueryClient } from '@tanstack/react-query';
import Spotlight from '@/components/spotlight';
import { Button } from '@/components/ui/button';
import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card';
import {
@ -29,7 +30,6 @@ import {
import { LucideEdit3, LucideTrash2, LucideUserPlus } from 'lucide-react';
import {
AdminService,
assignRolePermissions,
createRole,
deleteRole,
@ -149,7 +149,9 @@ function AdminRoles() {
return (
<>
<Card className="!shadow-none w-full h-full border border-border-button bg-transparent rounded-xl">
<Card className="!shadow-none relative w-full h-full border border-border-button bg-transparent rounded-xl">
<Spotlight />
<ScrollArea className="size-full">
<CardHeader className="space-y-0 flex flex-row justify-between items-center">
<CardTitle>{t('admin.roles')}</CardTitle>

View File

@ -23,6 +23,7 @@ import { useQuery } from '@tanstack/react-query';
import { cn } from '@/lib/utils';
import Spotlight from '@/components/spotlight';
import { TableEmpty } from '@/components/table-skeleton';
import { Badge } from '@/components/ui/badge';
import { Button } from '@/components/ui/button';
@ -151,11 +152,11 @@ function AdminServiceStatus() {
alive: 'bg-state-success-5 text-state-success',
timeout: 'bg-state-error-5 text-state-error',
fail: 'bg-gray-500/5 text-text-disable',
}[cell.getValue<string>()],
}[cell.getValue()],
)}
>
<LucideDot className="size-[1em] stroke-[8] mr-1" />
{cell.getValue()}
{t(`admin.${cell.getValue()}`)}
</Badge>
),
enableSorting: false,
@ -215,7 +216,9 @@ function AdminServiceStatus() {
return (
<>
<Card className="!shadow-none h-full border border-border-button bg-transparent rounded-xl">
<Card className="!shadow-none relative h-full border border-border-button bg-transparent rounded-xl">
<Spotlight />
<ScrollArea className="size-full">
<CardHeader className="space-y-0 flex flex-row justify-between items-center">
<CardTitle>{t('admin.serviceStatus')}</CardTitle>
@ -252,7 +255,7 @@ function AdminServiceStatus() {
table.getColumn('service_type')!?.setFilterValue
}
>
<Label className="space-x-2">
<Label className="flex items-center space-x-2">
<RadioGroupItem
className="bg-bg-input border-border-button"
value=""
@ -261,7 +264,10 @@ function AdminServiceStatus() {
</Label>
{SERVICE_TYPE_FILTER_OPTIONS.map(({ label, value }) => (
<Label key={value} className="space-x-2">
<Label
key={value}
className="flex items-center space-x-2"
>
<RadioGroupItem
className="bg-bg-input border-border-button"
value={value}

View File

@ -16,6 +16,7 @@ import {
import { cn } from '@/lib/utils';
import { Routes } from '@/routes';
import Spotlight from '@/components/spotlight';
import { Avatar } from '@/components/ui/avatar';
import { Badge } from '@/components/ui/badge';
import { Button } from '@/components/ui/button';
@ -322,7 +323,9 @@ function AdminUserDetail() {
</Button>
</nav>
<Card className="!shadow-none h-0 basis-0 grow flex flex-col bg-transparent border dark:border-border-button overflow-hidden">
<Card className="!shadow-none relative h-0 basis-0 grow flex flex-col bg-transparent border dark:border-border-button overflow-hidden">
<Spotlight />
<CardHeader className="pb-10 border-b dark:border-border-button space-y-8">
<section className="flex items-center gap-4 text-base">
<Avatar className="justify-center items-center bg-bg-group uppercase">

View File

@ -25,6 +25,7 @@ import {
import { cn } from '@/lib/utils';
import { rsaPsw } from '@/utils';
import Spotlight from '@/components/spotlight';
import { TableEmpty } from '@/components/table-skeleton';
import { Badge } from '@/components/ui/badge';
import { Button } from '@/components/ui/button';
@ -357,7 +358,9 @@ function AdminUserManagement() {
return (
<>
<Card className="!shadow-none h-full border border-border-button bg-transparent rounded-xl overflow-x-hidden overflow-y-auto">
<Card className="!shadow-none relative h-full border border-border-button bg-transparent rounded-xl overflow-x-hidden overflow-y-auto">
<Spotlight />
<ScrollArea className="size-full">
<CardHeader className="space-y-0 flex flex-row justify-between items-center">
<CardTitle>{t('admin.userManagement')}</CardTitle>
@ -396,13 +399,16 @@ function AdminUserManagement() {
table.getColumn('role')?.setFilterValue(value)
}
>
<Label className="space-x-2">
<Label className="flex items-center space-x-2">
<RadioGroupItem value="" />
<span>{t('admin.all')}</span>
</Label>
{roleList?.map(({ id, role_name }) => (
<Label key={id} className="space-x-2">
<Label
key={id}
className="flex items-center space-x-2"
>
<RadioGroupItem
className="bg-bg-input border-border-button"
value={role_name}
@ -429,7 +435,10 @@ function AdminUserManagement() {
}
>
{STATUS_FILTER_OPTIONS.map(({ label, value }) => (
<Label key={value} className="space-x-2">
<Label
key={value}
className="flex items-center space-x-2"
>
<RadioGroupItem
className="bg-bg-input border-border-button"
value={value}

View File

@ -22,6 +22,7 @@ import {
LucideUserPen,
} from 'lucide-react';
import Spotlight from '@/components/spotlight';
import { TableEmpty } from '@/components/table-skeleton';
import { Button } from '@/components/ui/button';
import {
@ -58,7 +59,6 @@ import {
importWhitelistFromExcel,
listWhitelist,
updateWhitelistEntry,
type AdminService,
} from '@/services/admin-service';
import { EMPTY_DATA, createFuzzySearchFn, getSortIcon } from './utils';
@ -68,8 +68,6 @@ import useImportExcelForm, {
ImportExcelFormData,
} from './forms/import-excel-form';
// #endregion
const columnHelper = createColumnHelper<AdminService.ListWhitelistItem>();
const globalFilterFn = createFuzzySearchFn<AdminService.ListWhitelistItem>([
'email',
@ -233,7 +231,9 @@ function AdminWhitelist() {
return (
<>
<Card className="!shadow-none h-full border border-border-button bg-transparent rounded-xl overflow-x-hidden overflow-y-auto">
<Card className="!shadow-none relative h-full border border-border-button bg-transparent rounded-xl overflow-x-hidden overflow-y-auto">
<Spotlight />
<ScrollArea className="size-full">
<CardHeader className="space-y-0 flex flex-row justify-between items-center">
<CardTitle>{t('admin.whitelistManagement')}</CardTitle>

View File

@ -1,9 +1,9 @@
import { Card, CardContent } from '@/components/ui/card';
import { SwitchOperatorOptions } from '@/constants/agent';
import { LogicalOperatorIcon } from '@/hooks/logic-hooks/use-build-operator-options';
import { ISwitchCondition, ISwitchNode } from '@/interfaces/database/flow';
import { NodeProps, Position } from '@xyflow/react';
import { memo, useCallback } from 'react';
import { SwitchOperatorOptions } from '../../constant';
import { LogicalOperatorIcon } from '../../form/switch-form';
import { useGetVariableLabelByValue } from '../../hooks/use-get-begin-query';
import { CommonHandle, LeftEndHandle } from './handle';
import { RightHandleStyle } from './handle-icon';

View File

@ -6,8 +6,10 @@ import {
AgentGlobals,
AgentGlobalsSysQueryWithBrace,
CodeTemplateStrMap,
ComparisonOperator,
Operator,
ProgrammingLanguage,
SwitchOperatorOptions,
initialLlmBaseValues,
} from '@/constants/agent';
export { Operator } from '@/constants/agent';
@ -35,8 +37,6 @@ export enum PromptRole {
}
import {
Circle,
CircleSlash2,
CloudUpload,
ListOrdered,
OptionIcon,
@ -166,27 +166,12 @@ export const componentMenuList = [
},
];
export const SwitchOperatorOptions = [
{ value: '=', label: 'equal', icon: 'equal' },
{ value: '≠', label: 'notEqual', icon: 'not-equals' },
{ value: '>', label: 'gt', icon: 'Less' },
{ value: '≥', label: 'ge', icon: 'Greater-or-equal' },
{ value: '<', label: 'lt', icon: 'Less' },
{ value: '≤', label: 'le', icon: 'less-or-equal' },
{ value: 'contains', label: 'contains', icon: 'Contains' },
{ value: 'not contains', label: 'notContains', icon: 'not-contains' },
{ value: 'start with', label: 'startWith', icon: 'list-start' },
{ value: 'end with', label: 'endWith', icon: 'list-end' },
{
value: 'empty',
label: 'empty',
icon: <Circle className="size-4" />,
},
{
value: 'not empty',
label: 'notEmpty',
icon: <CircleSlash2 className="size-4" />,
},
export const DataOperationsOperatorOptions = [
ComparisonOperator.Equal,
ComparisonOperator.NotEqual,
ComparisonOperator.Contains,
ComparisonOperator.StartWith,
ComparisonOperator.EndWith,
];
export const SwitchElseTo = 'end_cpn_ids';
@ -716,16 +701,17 @@ export const initialPlaceholderValues = {
};
export enum Operations {
SelectKeys = 'select keys',
LiteralEval = 'literal eval',
SelectKeys = 'select_keys',
LiteralEval = 'literal_eval',
Combine = 'combine',
FilterValues = 'filter values',
AppendOrUpdate = 'append or update',
RemoveKeys = 'remove keys',
RenameKeys = 'rename keys',
FilterValues = 'filter_values',
AppendOrUpdate = 'append_or_update',
RemoveKeys = 'remove_keys',
RenameKeys = 'rename_keys',
}
export const initialDataOperationsValues = {
query: [],
operations: Operations.SelectKeys,
outputs: {
result: {

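The `Operations` values switch from spaced strings to snake_case, so the selected operation can be submitted to the backend verbatim as an identifier. A hypothetical payload, for illustration only:

```typescript
// Hypothetical payload: the enum value is now a backend-safe identifier,
// so no extra mapping step is needed on submit.
import { Operations } from '@/pages/agent/constant';

const params = {
  operations: Operations.FilterValues, // serializes as 'filter_values'
  query: ['sys.query'],
};
```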
View File

@ -0,0 +1,25 @@
import { Button } from '@/components/ui/button';
import { FormLabel } from '@/components/ui/form';
import { Plus } from 'lucide-react';
import { ReactNode } from 'react';
export type FormListHeaderProps = {
label: ReactNode;
tooltip?: string;
onClick?: () => void;
};
export function DynamicFormHeader({
label,
tooltip,
onClick,
}: FormListHeaderProps) {
return (
<div className="flex items-center justify-between">
<FormLabel tooltip={tooltip}>{label}</FormLabel>
<Button variant={'ghost'} type="button" onClick={onClick}>
<Plus />
</Button>
</div>
);
}

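`DynamicFormHeader` replaces the per-list `BlockButton` add rows removed in the forms below: the plus button moves into the header next to the label. A hedged usage sketch with `useFieldArray`, mirroring those forms:

```typescript
// Hypothetical field-array section whose header button appends a blank row.
import { useFieldArray, useFormContext } from 'react-hook-form';
import { DynamicFormHeader } from './dynamic-fom-header';

export function KeysSection() {
  const form = useFormContext();
  const { append } = useFieldArray({ name: 'select_keys', control: form.control });
  return (
    <DynamicFormHeader
      label="Keys"
      tooltip="Keys to keep"
      onClick={() => append({ name: '' })}
    />
  );
}
```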
View File

@ -1,4 +1,7 @@
import { RAGFlowFormItem } from '@/components/ragflow-form';
import { Input } from '@/components/ui/input';
import { t } from 'i18next';
import { z } from 'zod';
export type OutputType = {
title: string;
@ -7,6 +10,7 @@ export type OutputType = {
type OutputProps = {
list: Array<OutputType>;
isFormRequired?: boolean;
};
export function transferOutputs(outputs: Record<string, any>) {
@ -16,7 +20,11 @@ export function transferOutputs(outputs: Record<string, any>) {
}));
}
export function Output({ list }: OutputProps) {
export const OutputSchema = {
outputs: z.record(z.any()),
};
export function Output({ list, isFormRequired = false }: OutputProps) {
return (
<section className="space-y-2">
<div className="text-sm">{t('flow.output')}</div>
@ -30,6 +38,11 @@ export function Output({ list }: OutputProps) {
</li>
))}
</ul>
{isFormRequired && (
<RAGFlowFormItem name="outputs" className="hidden">
<Input></Input>
</RAGFlowFormItem>
)}
</section>
);
}

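`OutputSchema` lets operator form schemas spread in the `outputs` shape, while the hidden `RAGFlowFormItem` registers the field so zod validation sees it even though nothing visible is rendered. Roughly, mirroring the data-operations form later in this diff:

```typescript
// Sketch of composing OutputSchema into an operator form schema
// (relative import path assumed).
import { z } from 'zod';
import { OutputSchema } from '../components/output';

const FormSchema = z.object({
  query: z.array(z.object({ input: z.string().optional() })),
  operations: z.string(),
  ...OutputSchema, // contributes `outputs: z.record(z.any())`
});

type FormValues = z.infer<typeof FormSchema>;
```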
View File

@ -1,17 +1,20 @@
import { BlockButton, Button } from '@/components/ui/button';
import { Button } from '@/components/ui/button';
import { X } from 'lucide-react';
import { useFieldArray, useFormContext } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import { JsonSchemaDataType } from '../../constant';
import { DynamicFormHeader, FormListHeaderProps } from './dynamic-fom-header';
import { QueryVariable } from './query-variable';
type QueryVariableListProps = {
types?: JsonSchemaDataType[];
};
export function QueryVariableList({ types }: QueryVariableListProps) {
const { t } = useTranslation();
} & FormListHeaderProps;
export function QueryVariableList({
types,
label,
tooltip,
}: QueryVariableListProps) {
const form = useFormContext();
const name = 'inputs';
const name = 'query';
const { fields, remove, append } = useFieldArray({
name: name,
@ -19,28 +22,31 @@ export function QueryVariableList({ types }: QueryVariableListProps) {
});
return (
<div className="space-y-5">
{fields.map((field, index) => {
const nameField = `${name}.${index}.input`;
<section className="space-y-2">
<DynamicFormHeader
label={label}
tooltip={tooltip}
onClick={() => append({ input: '' })}
></DynamicFormHeader>
<div className="space-y-5">
{fields.map((field, index) => {
const nameField = `${name}.${index}.input`;
return (
<div key={field.id} className="flex items-center gap-2">
<QueryVariable
name={nameField}
hideLabel
className="flex-1"
types={types}
></QueryVariable>
<Button variant={'ghost'} onClick={() => remove(index)}>
<X className="text-text-sub-title-invert " />
</Button>
</div>
);
})}
<BlockButton onClick={() => append({ input: '' })}>
{t('common.add')}
</BlockButton>
</div>
return (
<div key={field.id} className="flex items-center gap-2">
<QueryVariable
name={nameField}
hideLabel
className="flex-1"
types={types}
></QueryVariable>
<Button variant={'ghost'} onClick={() => remove(index)}>
<X className="text-text-sub-title-invert " />
</Button>
</div>
);
})}
</div>
</section>
);
}

View File

@ -1,13 +1,14 @@
import { SelectWithSearch } from '@/components/originui/select-with-search';
import { RAGFlowFormItem } from '@/components/ragflow-form';
import { BlockButton, Button } from '@/components/ui/button';
import { FormLabel } from '@/components/ui/form';
import { Button } from '@/components/ui/button';
import { Input } from '@/components/ui/input';
import { Separator } from '@/components/ui/separator';
import { useBuildSwitchOperatorOptions } from '@/hooks/logic-hooks/use-build-operator-options';
import { X } from 'lucide-react';
import { ReactNode } from 'react';
import { useFieldArray, useFormContext } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import { DataOperationsOperatorOptions } from '../../constant';
import { DynamicFormHeader } from '../components/dynamic-fom-header';
type SelectKeysProps = {
name: string;
@ -25,7 +26,6 @@ export function FilterValues({
valueField = 'value',
operatorField = 'operator',
}: SelectKeysProps) {
const { t } = useTranslation();
const form = useFormContext();
const { fields, remove, append } = useFieldArray({
@ -33,9 +33,18 @@ export function FilterValues({
control: form.control,
});
const operatorOptions = useBuildSwitchOperatorOptions(
DataOperationsOperatorOptions,
);
return (
<section className="space-y-2">
<FormLabel tooltip={tooltip}>{label}</FormLabel>
<DynamicFormHeader
label={label}
tooltip={tooltip}
onClick={() => append({ [keyField]: '', [valueField]: '' })}
></DynamicFormHeader>
<div className="space-y-5">
{fields.map((field, index) => {
const keyFieldAlias = `${name}.${index}.${keyField}`;
@ -47,12 +56,15 @@ export function FilterValues({
<RAGFlowFormItem name={keyFieldAlias} className="flex-1">
<Input></Input>
</RAGFlowFormItem>
<Separator orientation="vertical" className="h-2.5" />
<Separator className="w-2" />
<RAGFlowFormItem name={operatorFieldAlias} className="flex-1">
<SelectWithSearch {...field} options={[]}></SelectWithSearch>
<SelectWithSearch
{...field}
options={operatorOptions}
></SelectWithSearch>
</RAGFlowFormItem>
<Separator orientation="vertical" className="h-2.5" />
<Separator className="w-2" />
<RAGFlowFormItem name={valueFieldAlias} className="flex-1">
<Input></Input>
@ -64,10 +76,6 @@ export function FilterValues({
);
})}
</div>
<BlockButton onClick={() => append({ [keyField]: '', [valueField]: '' })}>
{t('common.add')}
</BlockButton>
</section>
);
}

View File

@ -1,6 +1,7 @@
import { SelectWithSearch } from '@/components/originui/select-with-search';
import { RAGFlowFormItem } from '@/components/ragflow-form';
import { Form, FormLabel } from '@/components/ui/form';
import { Form } from '@/components/ui/form';
import { Separator } from '@/components/ui/separator';
import { buildOptions } from '@/utils/form';
import { zodResolver } from '@hookform/resolvers/zod';
import { memo } from 'react';
@ -17,14 +18,14 @@ import { useWatchFormChange } from '../../hooks/use-watch-form-change';
import { INextOperatorForm } from '../../interface';
import { buildOutputList } from '../../utils/build-output-list';
import { FormWrapper } from '../components/form-wrapper';
import { Output } from '../components/output';
import { Output, OutputSchema } from '../components/output';
import { QueryVariableList } from '../components/query-variable-list';
import { FilterValues } from './filter-values';
import { SelectKeys } from './select-keys';
import { Updates } from './updates';
export const RetrievalPartialSchema = {
inputs: z.array(z.object({ input: z.string().optional() })),
query: z.array(z.object({ input: z.string().optional() })),
operations: z.string(),
select_keys: z.array(z.object({ name: z.string().optional() })).optional(),
remove_keys: z.array(z.object({ name: z.string().optional() })).optional(),
@ -50,6 +51,7 @@ export const RetrievalPartialSchema = {
}),
)
.optional(),
...OutputSchema,
};
export const FormSchema = z.object(RetrievalPartialSchema);
@ -65,6 +67,7 @@ function DataOperationsForm({ node }: INextOperatorForm) {
const form = useForm<DataOperationsFormSchemaType>({
defaultValues: defaultValues,
mode: 'onChange',
resolver: zodResolver(FormSchema),
shouldUnregister: true,
});
@ -78,17 +81,17 @@ function DataOperationsForm({ node }: INextOperatorForm) {
true,
);
useWatchFormChange(node?.id, form);
useWatchFormChange(node?.id, form, true);
return (
<Form {...form}>
<FormWrapper>
<div className="space-y-2">
<FormLabel tooltip={t('flow.queryTip')}>{t('flow.query')}</FormLabel>
<QueryVariableList
types={[JsonSchemaDataType.Array, JsonSchemaDataType.Object]}
></QueryVariableList>
</div>
<QueryVariableList
tooltip={t('flow.queryTip')}
label={t('flow.query')}
types={[JsonSchemaDataType.Array, JsonSchemaDataType.Object]}
></QueryVariableList>
<Separator />
<RAGFlowFormItem name="operations" label={t('flow.operations')}>
<SelectWithSearch options={OperationsOptions} allowClear />
</RAGFlowFormItem>
@ -107,7 +110,7 @@ function DataOperationsForm({ node }: INextOperatorForm) {
{operations === Operations.AppendOrUpdate && (
<Updates
name="updates"
label={t('flow.operationsOptions.updates')}
label={t('flow.operationsOptions.appendOrUpdate')}
keyField="key"
valueField="value"
></Updates>
@ -126,7 +129,7 @@ function DataOperationsForm({ node }: INextOperatorForm) {
label={t('flow.operationsOptions.filterValues')}
></FilterValues>
)}
<Output list={outputList}></Output>
<Output list={outputList} isFormRequired></Output>
</FormWrapper>
</Form>
);

View File

@ -1,11 +1,10 @@
import { RAGFlowFormItem } from '@/components/ragflow-form';
import { BlockButton, Button } from '@/components/ui/button';
import { FormLabel } from '@/components/ui/form';
import { Button } from '@/components/ui/button';
import { Input } from '@/components/ui/input';
import { X } from 'lucide-react';
import { ReactNode } from 'react';
import { useFieldArray, useFormContext } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import { DynamicFormHeader } from '../components/dynamic-fom-header';
type SelectKeysProps = {
name: string;
@ -13,7 +12,6 @@ type SelectKeysProps = {
tooltip?: string;
};
export function SelectKeys({ name, label, tooltip }: SelectKeysProps) {
const { t } = useTranslation();
const form = useFormContext();
const { fields, remove, append } = useFieldArray({
@ -23,7 +21,11 @@ export function SelectKeys({ name, label, tooltip }: SelectKeysProps) {
return (
<section className="space-y-2">
<FormLabel tooltip={tooltip}>{label}</FormLabel>
<DynamicFormHeader
label={label}
tooltip={tooltip}
onClick={() => append({ name: '' })}
></DynamicFormHeader>
<div className="space-y-5">
{fields.map((field, index) => {
const nameField = `${name}.${index}.name`;
@ -40,10 +42,6 @@ export function SelectKeys({ name, label, tooltip }: SelectKeysProps) {
);
})}
</div>
<BlockButton onClick={() => append({ name: '' })}>
{t('common.add')}
</BlockButton>
</section>
);
}

View File

@ -1,11 +1,11 @@
import { RAGFlowFormItem } from '@/components/ragflow-form';
import { BlockButton, Button } from '@/components/ui/button';
import { FormLabel } from '@/components/ui/form';
import { Button } from '@/components/ui/button';
import { Input } from '@/components/ui/input';
import { Separator } from '@/components/ui/separator';
import { X } from 'lucide-react';
import { ReactNode } from 'react';
import { useFieldArray, useFormContext } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import { DynamicFormHeader } from '../components/dynamic-fom-header';
type SelectKeysProps = {
name: string;
@ -21,7 +21,6 @@ export function Updates({
keyField,
valueField,
}: SelectKeysProps) {
const { t } = useTranslation();
const form = useFormContext();
const { fields, remove, append } = useFieldArray({
@ -31,7 +30,11 @@ export function Updates({
return (
<section className="space-y-2">
<FormLabel tooltip={tooltip}>{label}</FormLabel>
<DynamicFormHeader
label={label}
tooltip={tooltip}
onClick={() => append({ [keyField]: '', [valueField]: '' })}
></DynamicFormHeader>
<div className="space-y-5">
{fields.map((field, index) => {
const keyFieldAlias = `${name}.${index}.${keyField}`;
@ -42,6 +45,7 @@ export function Updates({
<RAGFlowFormItem name={keyFieldAlias} className="flex-1">
<Input></Input>
</RAGFlowFormItem>
<Separator className="w-2" />
<RAGFlowFormItem name={valueFieldAlias} className="flex-1">
<Input></Input>
</RAGFlowFormItem>
@ -52,10 +56,6 @@ export function Updates({
);
})}
</div>
<BlockButton onClick={() => append({ [keyField]: '', [valueField]: '' })}>
{t('common.add')}
</BlockButton>
</section>
);
}

View File

@ -1,5 +1,4 @@
import { FormContainer } from '@/components/form-container';
import { IconFont } from '@/components/icon-font';
import { SelectWithSearch } from '@/components/originui/select-with-search';
import { BlockButton, Button } from '@/components/ui/button';
import { Card, CardContent } from '@/components/ui/card';
@ -13,6 +12,7 @@ import {
import { RAGFlowSelect } from '@/components/ui/select';
import { Separator } from '@/components/ui/separator';
import { Textarea } from '@/components/ui/textarea';
import { useBuildSwitchOperatorOptions } from '@/hooks/logic-hooks/use-build-operator-options';
import { cn } from '@/lib/utils';
import { zodResolver } from '@hookform/resolvers/zod';
import { t } from 'i18next';
@ -22,11 +22,7 @@ import { memo, useCallback, useMemo } from 'react';
import { useFieldArray, useForm, useFormContext } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import { z } from 'zod';
import {
SwitchLogicOperatorOptions,
SwitchOperatorOptions,
VariableType,
} from '../../constant';
import { SwitchLogicOperatorOptions, VariableType } from '../../constant';
import { useBuildQueryVariableOptions } from '../../hooks/use-get-begin-query';
import { IOperatorForm } from '../../interface';
import { FormWrapper } from '../components/form-wrapper';
@ -43,42 +39,6 @@ type ConditionCardsProps = {
parentLength: number;
} & IOperatorForm;
export const LogicalOperatorIcon = function OperatorIcon({
icon,
value,
}: Omit<(typeof SwitchOperatorOptions)[0], 'label'>) {
if (typeof icon === 'string') {
return (
<IconFont
name={icon}
className={cn('size-4', {
'rotate-180': value === '>',
})}
></IconFont>
);
}
return icon;
};
export function useBuildSwitchOperatorOptions() {
const { t } = useTranslation();
const switchOperatorOptions = useMemo(() => {
return SwitchOperatorOptions.map((x) => ({
value: x.value,
icon: (
<LogicalOperatorIcon
icon={x.icon}
value={x.value}
></LogicalOperatorIcon>
),
label: t(`flow.switchOperatorOptions.${x.label}`),
}));
}, [t]);
return switchOperatorOptions;
}
function ConditionCards({
name: parentName,
parentIndex,

View File

@ -2,9 +2,13 @@ import { useEffect } from 'react';
import { UseFormReturn, useWatch } from 'react-hook-form';
import useGraphStore from '../store';
export function useWatchFormChange(id?: string, form?: UseFormReturn<any>) {
export function useWatchFormChange(
id?: string,
form?: UseFormReturn<any>,
enableReplacement = false,
) {
let values = useWatch({ control: form?.control });
const updateNodeForm = useGraphStore((state) => state.updateNodeForm);
const { updateNodeForm, replaceNodeForm } = useGraphStore((state) => state);
useEffect(() => {
// Manually triggered form updates are synchronized to the canvas
@ -12,7 +16,7 @@ export function useWatchFormChange(id?: string, form?: UseFormReturn<any>) {
values = form?.getValues() || {};
let nextValues: any = values;
updateNodeForm(id, nextValues);
(enableReplacement ? replaceNodeForm : updateNodeForm)(id, nextValues);
}
}, [form?.formState.isDirty, id, updateNodeForm, values]);
}

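The new third argument lets a form opt into replacing the node's form data wholesale instead of merging into it; the data-operations form above passes `true` so that fields dropped by `shouldUnregister: true` actually disappear from the canvas state. A minimal sketch:

```typescript
// Hypothetical operator form opting into replacement semantics.
import { useForm } from 'react-hook-form';
import { useWatchFormChange } from '../../hooks/use-watch-form-change';

function useMyOperatorForm(nodeId?: string) {
  const form = useForm({ mode: 'onChange', shouldUnregister: true });
  useWatchFormChange(nodeId, form, true); // replace instead of merge
  return form;
}
```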
View File

@ -14,7 +14,7 @@ import {
applyEdgeChanges,
applyNodeChanges,
} from '@xyflow/react';
import { omit } from 'lodash';
import { cloneDeep, omit } from 'lodash';
import differenceWith from 'lodash/differenceWith';
import intersectionWith from 'lodash/intersectionWith';
import lodashSet from 'lodash/set';
@ -53,6 +53,7 @@ export type RFState = {
values: any,
path?: (string | number)[],
) => RAGFlowNodeType[];
replaceNodeForm: (nodeId: string, values: any) => void;
onSelectionChange: OnSelectionChangeFunc;
addNode: (nodes: RAGFlowNodeType) => void;
getNode: (id?: string | null) => RAGFlowNodeType | undefined;
@ -433,6 +434,19 @@ const useGraphStore = create<RFState>()(
return nextNodes;
},
replaceNodeForm(nodeId, values) {
if (nodeId) {
set((state) => {
for (const node of state.nodes) {
if (node.id === nodeId) {
// cloneDeep avoids a react-hook-form error: assigning the watched values
// directly throws "TypeError: Cannot assign to read only property '0' of object '[object Array]'".
node.data.form = cloneDeep(values);
break;
}
}
});
}
},
updateSwitchFormData: (source, sourceHandle, target, isConnecting) => {
const { updateNodeForm, edges } = get();
if (sourceHandle) {

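Roughly, `updateNodeForm` merges values into the existing `node.data.form` (optionally under a path), while `replaceNodeForm` swaps the whole object, deep-cloned. A hypothetical call site (the merge semantics of `updateNodeForm` are assumed; its body is not shown here):

```typescript
// Hypothetical illustration of the two store actions.
import useGraphStore from '../store';

const { updateNodeForm, replaceNodeForm } = useGraphStore.getState();

// Merge: keeps any other keys already present on node.data.form.
updateNodeForm('node-1', { operations: 'combine' });

// Replace: node.data.form becomes exactly this object (deep-cloned).
replaceNodeForm('node-1', { operations: 'combine' });
```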
View File

@ -273,7 +273,7 @@ function transformDataOperationsParams(params: DataOperationsFormSchemaType) {
...params,
select_keys: params?.select_keys?.map((x) => x.name),
remove_keys: params?.remove_keys?.map((x) => x.name),
inputs: params.inputs.map((x) => x.input),
query: params.query.map((x) => x.input),
};
}

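With the field renamed from `inputs` to `query`, the transform flattens each `{ input }` row into a plain string before submission. A hypothetical before/after:

```typescript
// Hypothetical input/output of the transform defined above (same module).
const formValues = {
  operations: 'select_keys',
  query: [{ input: 'sys.query' }, { input: 'file' }],
  select_keys: [{ name: 'title' }],
};

const payload = transformDataOperationsParams(formValues as any);
// payload.query       -> ['sys.query', 'file']
// payload.select_keys -> ['title']
```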
View File

@ -1,6 +1,6 @@
import { SwitchOperatorOptions } from '@/constants/agent';
import { useBuildSwitchOperatorOptions } from '@/hooks/logic-hooks/use-build-operator-options';
import { useFetchKnowledgeMetadata } from '@/hooks/use-knowledge-request';
import { SwitchOperatorOptions } from '@/pages/agent/constant';
import { useBuildSwitchOperatorOptions } from '@/pages/agent/form/switch-form';
import { MinusCircleOutlined, PlusOutlined } from '@ant-design/icons';
import {
Button,

View File

@ -15,9 +15,9 @@ import {
} from '@/components/ui/form';
import { Input } from '@/components/ui/input';
import { Separator } from '@/components/ui/separator';
import { SwitchOperatorOptions } from '@/constants/agent';
import { useBuildSwitchOperatorOptions } from '@/hooks/logic-hooks/use-build-operator-options';
import { useFetchKnowledgeMetadata } from '@/hooks/use-knowledge-request';
import { SwitchOperatorOptions } from '@/pages/agent/constant';
import { useBuildSwitchOperatorOptions } from '@/pages/agent/form/switch-form';
import { Plus, X } from 'lucide-react';
import { useCallback } from 'react';
import { useFieldArray, useFormContext } from 'react-hook-form';