Feat: init dataflow. (#9791)

### What problem does this PR solve?

#9790

Close #9782

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
Kevin Hu
2025-08-28 18:40:32 +08:00
committed by GitHub
parent a246949b77
commit c27172b3bc
19 changed files with 1020 additions and 166 deletions

View File

@ -29,83 +29,52 @@ from api.utils import get_uuid, hash_str2int
from rag.prompts.prompts import chunks_format from rag.prompts.prompts import chunks_format
from rag.utils.redis_conn import REDIS_CONN from rag.utils.redis_conn import REDIS_CONN
class Graph:
class Canvas:
""" """
dsl = { dsl = {
"components": {
"begin": {
"obj":{
"component_name": "Begin",
"params": {},
},
"downstream": ["answer_0"],
"upstream": [],
},
"retrieval_0": {
"obj": {
"component_name": "Retrieval",
"params": {}
},
"downstream": ["generate_0"],
"upstream": ["answer_0"],
},
"generate_0": {
"obj": {
"component_name": "Generate",
"params": {}
},
"downstream": ["answer_0"],
"upstream": ["retrieval_0"],
}
},
"history": [],
"path": ["begin"],
"retrieval": {"chunks": [], "doc_aggs": []},
"globals": {
"sys.query": "",
"sys.user_id": tenant_id,
"sys.conversation_turns": 0,
"sys.files": []
}
}
"""
def __init__(self, dsl: str, tenant_id=None, task_id=None):
self.path = []
self.history = []
self.components = {}
self.error = ""
self.globals = {
"sys.query": "",
"sys.user_id": tenant_id,
"sys.conversation_turns": 0,
"sys.files": []
}
self.dsl = json.loads(dsl) if dsl else {
"components": { "components": {
"begin": { "begin": {
"obj": { "obj":{
"component_name": "Begin", "component_name": "Begin",
"params": { "params": {},
"prologue": "Hi there!"
}
}, },
"downstream": [], "downstream": ["answer_0"],
"upstream": [], "upstream": [],
"parent_id": "" },
"retrieval_0": {
"obj": {
"component_name": "Retrieval",
"params": {}
},
"downstream": ["generate_0"],
"upstream": ["answer_0"],
},
"generate_0": {
"obj": {
"component_name": "Generate",
"params": {}
},
"downstream": ["answer_0"],
"upstream": ["retrieval_0"],
} }
}, },
"history": [], "history": [],
"path": [], "path": ["begin"],
"retrieval": [], "retrieval": {"chunks": [], "doc_aggs": []},
"globals": { "globals": {
"sys.query": "", "sys.query": "",
"sys.user_id": "", "sys.user_id": tenant_id,
"sys.conversation_turns": 0, "sys.conversation_turns": 0,
"sys.files": [] "sys.files": []
} }
} }
"""
def __init__(self, dsl: str, tenant_id=None, task_id=None):
self.path = []
self.components = {}
self.error = ""
self.dsl = json.loads(dsl)
self._tenant_id = tenant_id self._tenant_id = tenant_id
self.task_id = task_id if task_id else get_uuid() self.task_id = task_id if task_id else get_uuid()
self.load() self.load()
@ -116,8 +85,6 @@ class Canvas:
for k, cpn in self.components.items(): for k, cpn in self.components.items():
cpn_nms.add(cpn["obj"]["component_name"]) cpn_nms.add(cpn["obj"]["component_name"])
assert "Begin" in cpn_nms, "There have to be an 'Begin' component."
for k, cpn in self.components.items(): for k, cpn in self.components.items():
cpn_nms.add(cpn["obj"]["component_name"]) cpn_nms.add(cpn["obj"]["component_name"])
param = component_class(cpn["obj"]["component_name"] + "Param")() param = component_class(cpn["obj"]["component_name"] + "Param")()
@ -130,27 +97,10 @@ class Canvas:
cpn["obj"] = component_class(cpn["obj"]["component_name"])(self, k, param) cpn["obj"] = component_class(cpn["obj"]["component_name"])(self, k, param)
self.path = self.dsl["path"] self.path = self.dsl["path"]
self.history = self.dsl["history"]
if "globals" in self.dsl:
self.globals = self.dsl["globals"]
else:
self.globals = {
"sys.query": "",
"sys.user_id": "",
"sys.conversation_turns": 0,
"sys.files": []
}
self.retrieval = self.dsl["retrieval"]
self.memory = self.dsl.get("memory", [])
def __str__(self): def __str__(self):
self.dsl["path"] = self.path self.dsl["path"] = self.path
self.dsl["history"] = self.history
self.dsl["globals"] = self.globals
self.dsl["task_id"] = self.task_id self.dsl["task_id"] = self.task_id
self.dsl["retrieval"] = self.retrieval
self.dsl["memory"] = self.memory
dsl = { dsl = {
"components": {} "components": {}
} }
@ -169,14 +119,79 @@ class Canvas:
dsl["components"][k][c] = deepcopy(cpn[c]) dsl["components"][k][c] = deepcopy(cpn[c])
return json.dumps(dsl, ensure_ascii=False) return json.dumps(dsl, ensure_ascii=False)
def reset(self, mem=False): def reset(self):
self.path = [] self.path = []
for k, cpn in self.components.items():
self.components[k]["obj"].reset()
try:
REDIS_CONN.delete(f"{self.task_id}-logs")
except Exception as e:
logging.exception(e)
def get_component_name(self, cid):
for n in self.dsl.get("graph", {}).get("nodes", []):
if cid == n["id"]:
return n["data"]["name"]
return ""
def run(self, **kwargs):
raise NotImplementedError()
def get_component(self, cpn_id) -> Union[None, dict[str, Any]]:
return self.components.get(cpn_id)
def get_component_obj(self, cpn_id) -> ComponentBase:
return self.components.get(cpn_id)["obj"]
def get_component_type(self, cpn_id) -> str:
return self.components.get(cpn_id)["obj"].component_name
def get_component_input_form(self, cpn_id) -> dict:
return self.components.get(cpn_id)["obj"].get_input_form()
def get_tenant_id(self):
return self._tenant_id
class Canvas(Graph):
def __init__(self, dsl: str, tenant_id=None, task_id=None):
self.globals = {
"sys.query": "",
"sys.user_id": tenant_id,
"sys.conversation_turns": 0,
"sys.files": []
}
super().__init__(dsl, tenant_id, task_id)
def load(self):
super().load()
self.history = self.dsl["history"]
if "globals" in self.dsl:
self.globals = self.dsl["globals"]
else:
self.globals = {
"sys.query": "",
"sys.user_id": "",
"sys.conversation_turns": 0,
"sys.files": []
}
self.retrieval = self.dsl["retrieval"]
self.memory = self.dsl.get("memory", [])
def __str__(self):
self.dsl["history"] = self.history
self.dsl["retrieval"] = self.retrieval
self.dsl["memory"] = self.memory
return super().__str__()
def reset(self, mem=False):
super().reset()
if not mem: if not mem:
self.history = [] self.history = []
self.retrieval = [] self.retrieval = []
self.memory = [] self.memory = []
for k, cpn in self.components.items():
self.components[k]["obj"].reset()
for k in self.globals.keys(): for k in self.globals.keys():
if isinstance(self.globals[k], str): if isinstance(self.globals[k], str):
@ -192,17 +207,6 @@ class Canvas:
else: else:
self.globals[k] = None self.globals[k] = None
try:
REDIS_CONN.delete(f"{self.task_id}-logs")
except Exception as e:
logging.exception(e)
def get_component_name(self, cid):
for n in self.dsl.get("graph", {}).get("nodes", []):
if cid == n["id"]:
return n["data"]["name"]
return ""
def run(self, **kwargs): def run(self, **kwargs):
st = time.perf_counter() st = time.perf_counter()
self.message_id = get_uuid() self.message_id = get_uuid()
@ -388,18 +392,6 @@ class Canvas:
}) })
self.history.append(("assistant", self.get_component_obj(self.path[-1]).output())) self.history.append(("assistant", self.get_component_obj(self.path[-1]).output()))
def get_component(self, cpn_id) -> Union[None, dict[str, Any]]:
return self.components.get(cpn_id)
def get_component_obj(self, cpn_id) -> ComponentBase:
return self.components.get(cpn_id)["obj"]
def get_component_type(self, cpn_id) -> str:
return self.components.get(cpn_id)["obj"].component_name
def get_component_input_form(self, cpn_id) -> dict:
return self.components.get(cpn_id)["obj"].get_input_form()
def is_reff(self, exp: str) -> bool: def is_reff(self, exp: str) -> bool:
exp = exp.strip("{").strip("}") exp = exp.strip("{").strip("}")
if exp.find("@") < 0: if exp.find("@") < 0:
@ -421,9 +413,6 @@ class Canvas:
raise Exception(f"Can't find variable: '{cpn_id}@{var_nm}'") raise Exception(f"Can't find variable: '{cpn_id}@{var_nm}'")
return cpn["obj"].output(var_nm) return cpn["obj"].output(var_nm)
def get_tenant_id(self):
return self._tenant_id
def get_history(self, window_size): def get_history(self, window_size):
convs = [] convs = []
if window_size <= 0: if window_size <= 0:
@ -438,36 +427,6 @@ class Canvas:
def add_user_input(self, question): def add_user_input(self, question):
self.history.append(("user", question)) self.history.append(("user", question))
def _find_loop(self, max_loops=6):
path = self.path[-1][::-1]
if len(path) < 2:
return False
for i in range(len(path)):
if path[i].lower().find("answer") == 0 or path[i].lower().find("iterationitem") == 0:
path = path[:i]
break
if len(path) < 2:
return False
for loc in range(2, len(path) // 2):
pat = ",".join(path[0:loc])
path_str = ",".join(path)
if len(pat) >= len(path_str):
return False
loop = max_loops
while path_str.find(pat) == 0 and loop >= 0:
loop -= 1
if len(pat)+1 >= len(path_str):
return False
path_str = path_str[len(pat)+1:]
if loop < 0:
pat = " => ".join([p.split(":")[0] for p in path[0:loc]])
return pat + " => " + pat
return False
def get_prologue(self): def get_prologue(self):
return self.components["begin"]["obj"]._param.prologue return self.components["begin"]["obj"]._param.prologue

View File

@ -50,8 +50,9 @@ del _package_path, _import_submodules, _extract_classes_from_module
def component_class(class_name): def component_class(class_name):
m = importlib.import_module("agent.component") for mdl in ["agent.component", "agent.tools", "rag.flow"]:
try: try:
return getattr(m, class_name) return getattr(importlib.import_module(mdl), class_name)
except Exception: except Exception:
return getattr(importlib.import_module("agent.tools"), class_name) pass
assert False, f"Can't import {class_name}"

View File

@ -16,7 +16,7 @@
import re import re
import time import time
from abc import ABC, abstractmethod from abc import ABC
import builtins import builtins
import json import json
import os import os
@ -410,8 +410,8 @@ class ComponentBase(ABC):
) )
def __init__(self, canvas, id, param: ComponentParamBase): def __init__(self, canvas, id, param: ComponentParamBase):
from agent.canvas import Canvas # Local import to avoid cyclic dependency from agent.canvas import Graph # Local import to avoid cyclic dependency
assert isinstance(canvas, Canvas), "canvas must be an instance of Canvas" assert isinstance(canvas, Graph), "canvas must be an instance of Canvas"
self._canvas = canvas self._canvas = canvas
self._id = id self._id = id
self._param = param self._param = param
@ -528,6 +528,10 @@ class ComponentBase(ABC):
cpn_nms = self._canvas.get_component(self._id)['upstream'] cpn_nms = self._canvas.get_component(self._id)['upstream']
return cpn_nms return cpn_nms
def get_downstream(self) -> List[str]:
cpn_nms = self._canvas.get_component(self._id)['downstream']
return cpn_nms
@staticmethod @staticmethod
def string_format(content: str, kv: dict[str, str]) -> str: def string_format(content: str, kv: dict[str, str]) -> str:
for n, v in kv.items(): for n, v in kv.items():
@ -556,6 +560,5 @@ class ComponentBase(ABC):
def set_exception_default_value(self): def set_exception_default_value(self):
self.set_output("result", self.get_exception_default_value()) self.set_output("result", self.get_exception_default_value())
@abstractmethod
def thoughts(self) -> str: def thoughts(self) -> str:
... raise NotImplementedError()

View File

@ -16,9 +16,8 @@
from abc import ABC from abc import ABC
import asyncio import asyncio
from crawl4ai import AsyncWebCrawler from crawl4ai import AsyncWebCrawler
from agent.tools.base import ToolParamBase, ToolBase from agent.tools.base import ToolParamBase, ToolBase
from api.utils.web_utils import is_valid_url
class CrawlerParam(ToolParamBase): class CrawlerParam(ToolParamBase):
@ -39,6 +38,7 @@ class Crawler(ToolBase, ABC):
component_name = "Crawler" component_name = "Crawler"
def _run(self, history, **kwargs): def _run(self, history, **kwargs):
from api.utils.web_utils import is_valid_url
ans = self.get_input() ans = self.get_input()
ans = " - ".join(ans["content"]) if "content" in ans else "" ans = " - ".join(ans["content"]) if "content" in ans else ""
if not is_valid_url(ans): if not is_valid_url(ans):

View File

@ -74,7 +74,6 @@ def retrieval(tenant_id):
[tenant_id], [tenant_id],
[kb_id], [kb_id],
embd_mdl, embd_mdl,
doc_ids,
LLMBundle(kb.tenant_id, LLMType.CHAT)) LLMBundle(kb.tenant_id, LLMType.CHAT))
if ck["content_with_weight"]: if ck["content_with_weight"]:
ranks["chunks"].insert(0, ck) ranks["chunks"].insert(0, ck)

View File

@ -133,6 +133,13 @@ class UserService(CommonService):
cls.model.update(user_dict).where( cls.model.update(user_dict).where(
cls.model.id == user_id).execute() cls.model.id == user_id).execute()
@classmethod
@DB.connection_context()
def is_admin(cls, user_id):
return cls.model.select().where(
cls.model.id == user_id,
cls.model.is_superuser == 1).count() > 0
class TenantService(CommonService): class TenantService(CommonService):
"""Service class for managing tenant-related database operations. """Service class for managing tenant-related database operations.

View File

@ -131,6 +131,12 @@ class RAGFlowExcelParser:
return tb_chunks return tb_chunks
def markdown(self, fnm):
import pandas as pd
file_like_object = BytesIO(fnm) if not isinstance(fnm, str) else fnm
df = pd.read_excel(file_like_object)
return df.to_markdown(index=False)
def __call__(self, fnm): def __call__(self, fnm):
file_like_object = BytesIO(fnm) if not isinstance(fnm, str) else fnm file_like_object = BytesIO(fnm) if not isinstance(fnm, str) else fnm
wb = RAGFlowExcelParser._load_excel_to_workbook(file_like_object) wb = RAGFlowExcelParser._load_excel_to_workbook(file_like_object)

View File

@ -93,6 +93,7 @@ class RAGFlowPdfParser:
model_dir, "updown_concat_xgb.model")) model_dir, "updown_concat_xgb.model"))
self.page_from = 0 self.page_from = 0
self.column_num = 1
def __char_width(self, c): def __char_width(self, c):
return (c["x1"] - c["x0"]) // max(len(c["text"]), 1) return (c["x1"] - c["x0"]) // max(len(c["text"]), 1)
@ -427,10 +428,18 @@ class RAGFlowPdfParser:
i += 1 i += 1
self.boxes = bxs self.boxes = bxs
def _naive_vertical_merge(self): def _naive_vertical_merge(self, zoomin=3):
bxs = Recognizer.sort_Y_firstly( bxs = Recognizer.sort_Y_firstly(
self.boxes, np.median( self.boxes, np.median(
self.mean_height) / 3) self.mean_height) / 3)
column_width = np.median([b["x1"] - b["x0"] for b in self.boxes])
self.column_num = int(self.page_images[0].size[0] / zoomin / column_width)
if column_width < self.page_images[0].size[0] / zoomin / self.column_num:
logging.info("Multi-column................... {} {}".format(column_width,
self.page_images[0].size[0] / zoomin / self.column_num))
self.boxes = self.sort_X_by_page(self.boxes, column_width / self.column_num)
i = 0 i = 0
while i + 1 < len(bxs): while i + 1 < len(bxs):
b = bxs[i] b = bxs[i]
@ -1139,20 +1148,94 @@ class RAGFlowPdfParser:
need_image, zoomin, return_html, False) need_image, zoomin, return_html, False)
return self.__filterout_scraps(deepcopy(self.boxes), zoomin), tbls return self.__filterout_scraps(deepcopy(self.boxes), zoomin), tbls
def parse_into_bboxes(self, fnm, callback=None, zoomin=3):
start = timer()
self.__images__(fnm, zoomin)
if callback:
callback(0.40, "OCR finished ({:.2f}s)".format(timer() - start))
start = timer()
self._layouts_rec(zoomin)
if callback:
callback(0.63, "Layout analysis ({:.2f}s)".format(timer() - start))
start = timer()
self._table_transformer_job(zoomin)
if callback:
callback(0.83, "Table analysis ({:.2f}s)".format(timer() - start))
start = timer()
self._text_merge()
self._concat_downward()
self._naive_vertical_merge(zoomin)
if callback:
callback(0.92, "Text merged ({:.2f}s)".format(timer() - start))
start = timer()
tbls, figs = self._extract_table_figure(True, zoomin, True, True, True)
def insert_table_figures(tbls_or_figs, layout_type):
def min_rectangle_distance(rect1, rect2):
import math
pn1, left1, right1, top1, bottom1 = rect1
pn2, left2, right2, top2, bottom2 = rect2
if (right1 >= left2 and right2 >= left1 and
bottom1 >= top2 and bottom2 >= top1):
return 0 + (pn1-pn2)*10000
if right1 < left2:
dx = left2 - right1
elif right2 < left1:
dx = left1 - right2
else:
dx = 0
if bottom1 < top2:
dy = top2 - bottom1
elif bottom2 < top1:
dy = top1 - bottom2
else:
dy = 0
return math.sqrt(dx*dx + dy*dy) + (pn1-pn2)*10000
for (img, txt), poss in tbls_or_figs:
bboxes = [(i, (b["page_number"], b["x0"], b["x1"], b["top"], b["bottom"])) for i, b in enumerate(self.boxes)]
dists = [(min_rectangle_distance((pn, left, right, top, bott), rect),i) for i, rect in bboxes for pn, left, right, top, bott in poss]
min_i = np.argmin(dists, axis=0)[0]
min_i, rect = bboxes[dists[min_i][-1]]
if isinstance(txt, list):
txt = "\n".join(txt)
self.boxes.insert(min_i, {
"page_number": rect[0], "x0": rect[1], "x1": rect[2], "top": rect[3], "bottom": rect[4], "layout_type": layout_type, "text": txt, "image": img
})
for b in self.boxes:
b["position_tag"] = self._line_tag(b, zoomin)
b["image"] = self.crop(b["position_tag"], zoomin)
insert_table_figures(tbls, "table")
insert_table_figures(figs, "figure")
if callback:
callback(1, "Structured ({:.2f}s)".format(timer() - start))
return deepcopy(self.boxes)
@staticmethod @staticmethod
def remove_tag(txt): def remove_tag(txt):
return re.sub(r"@@[\t0-9.-]+?##", "", txt) return re.sub(r"@@[\t0-9.-]+?##", "", txt)
def crop(self, text, ZM=3, need_position=False): @staticmethod
imgs = [] def extract_positions(txt):
poss = [] poss = []
for tag in re.findall(r"@@[0-9-]+\t[0-9.\t]+##", text): for tag in re.findall(r"@@[0-9-]+\t[0-9.\t]+##", txt):
pn, left, right, top, bottom = tag.strip( pn, left, right, top, bottom = tag.strip(
"#").strip("@").split("\t") "#").strip("@").split("\t")
left, right, top, bottom = float(left), float( left, right, top, bottom = float(left), float(
right), float(top), float(bottom) right), float(top), float(bottom)
poss.append(([int(p) - 1 for p in pn.split("-")], poss.append(([int(p) - 1 for p in pn.split("-")],
left, right, top, bottom)) left, right, top, bottom))
return poss
def crop(self, text, ZM=3, need_position=False):
imgs = []
poss = self.extract_positions(text)
if not poss: if not poss:
if need_position: if need_position:
return None, None return None, None
@ -1296,8 +1379,8 @@ class VisionParser(RAGFlowPdfParser):
def __call__(self, filename, from_page=0, to_page=100000, **kwargs): def __call__(self, filename, from_page=0, to_page=100000, **kwargs):
callback = kwargs.get("callback", lambda prog, msg: None) callback = kwargs.get("callback", lambda prog, msg: None)
zoomin = kwargs.get("zoomin", 3)
self.__images__(fnm=filename, zoomin=3, page_from=from_page, page_to=to_page, **kwargs) self.__images__(fnm=filename, zoomin=zoomin, page_from=from_page, page_to=to_page, callback=callback)
total_pdf_pages = self.total_page total_pdf_pages = self.total_page
@ -1311,16 +1394,19 @@ class VisionParser(RAGFlowPdfParser):
if pdf_page_num < start_page or pdf_page_num >= end_page: if pdf_page_num < start_page or pdf_page_num >= end_page:
continue continue
docs = picture_vision_llm_chunk( text = picture_vision_llm_chunk(
binary=img_binary, binary=img_binary,
vision_model=self.vision_model, vision_model=self.vision_model,
prompt=vision_llm_describe_prompt(page=pdf_page_num+1), prompt=vision_llm_describe_prompt(page=pdf_page_num+1),
callback=callback, callback=callback,
) )
if kwargs.get("callback"):
kwargs["callback"](idx*1./len(self.page_images), f"Processed: {idx+1}/{len(self.page_images)}")
if docs: if text:
all_docs.append(docs) width, height = self.page_images[idx].size
return [(doc, "") for doc in all_docs], [] all_docs.append((text, f"{pdf_page_num+1} 0 {width/zoomin} 0 {height/zoomin}"))
return all_docs, []
if __name__ == "__main__": if __name__ == "__main__":

49
rag/flow/__init__.py Normal file
View File

@ -0,0 +1,49 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import importlib
import inspect
from types import ModuleType
from typing import Dict, Type
_package_path = os.path.dirname(__file__)
__all_classes: Dict[str, Type] = {}
def _import_submodules() -> None:
for filename in os.listdir(_package_path): # noqa: F821
if filename.startswith("__") or not filename.endswith(".py") or filename.startswith("base"):
continue
module_name = filename[:-3]
try:
module = importlib.import_module(f".{module_name}", package=__name__)
_extract_classes_from_module(module) # noqa: F821
except ImportError as e:
print(f"Warning: Failed to import module {module_name}: {str(e)}")
def _extract_classes_from_module(module: ModuleType) -> None:
for name, obj in inspect.getmembers(module):
if (inspect.isclass(obj) and
obj.__module__ == module.__name__ and not name.startswith("_")):
__all_classes[name] = obj
globals()[name] = obj
_import_submodules()
__all__ = list(__all_classes.keys()) + ["__all_classes"]
del _package_path, _import_submodules, _extract_classes_from_module

59
rag/flow/base.py Normal file
View File

@ -0,0 +1,59 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import time
import os
import logging
from functools import partial
from typing import Any
import trio
from agent.component.base import ComponentParamBase, ComponentBase
from api.utils.api_utils import timeout
class ProcessParamBase(ComponentParamBase):
def __init__(self):
super().__init__()
self.timeout = 100000000
self.persist_logs = True
class ProcessBase(ComponentBase):
def __init__(self, pipeline, id, param: ProcessParamBase):
super().__init__(pipeline, id, param)
self.callback = partial(self._canvas.callback, self.component_name)
async def invoke(self, **kwargs) -> dict[str, Any]:
self.set_output("_created_time", time.perf_counter())
for k,v in kwargs.items():
self.set_output(k, v)
try:
with trio.fail_after(self._param.timeout):
await self._invoke(**kwargs)
self.callback(1, "Done")
except Exception as e:
if self.get_exception_default_value():
self.set_exception_default_value()
else:
self.set_output("_ERROR", str(e))
logging.exception(e)
self.callback(-1, str(e))
self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
return self.output()
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))
async def _invoke(self, **kwargs):
raise NotImplementedError()

47
rag/flow/begin.py Normal file
View File

@ -0,0 +1,47 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from api.db.services.document_service import DocumentService
from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService
from rag.flow.base import ProcessBase, ProcessParamBase
from rag.utils.storage_factory import STORAGE_IMPL
class FileParam(ProcessParamBase):
def __init__(self):
super().__init__()
def check(self):
pass
class File(ProcessBase):
component_name = "File"
async def _invoke(self, **kwargs):
if self._canvas._doc_id:
e, doc = DocumentService.get_by_id(self._canvas._doc_id)
if not e:
self.set_output("_ERROR", f"Document({self._canvas._doc_id}) not found!")
return
b, n = File2DocumentService.get_storage_address(doc_id=self._canvas._doc_id)
self.set_output("blob", STORAGE_IMPL.get(b, n))
self.set_output("name", doc.name)
else:
file = kwargs.get("file")
self.set_output("name", file["name"])
self.set_output("blob", FileService.get_blob(file["created_by"], file["id"]))

160
rag/flow/chunker.py Normal file
View File

@ -0,0 +1,160 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import trio
from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from deepdoc.parser.pdf_parser import RAGFlowPdfParser
from graphrag.utils import get_llm_cache, chat_limiter, set_llm_cache
from rag.flow.base import ProcessBase, ProcessParamBase
from rag.nlp import naive_merge, naive_merge_with_images
from rag.prompts.prompts import keyword_extraction, question_proposal
class ChunkerParam(ProcessParamBase):
def __init__(self):
super().__init__()
self.method_options = ["general", "q&a", "resume", "manual", "table", "paper", "book", "laws", "presentation", "one"]
self.method = "general"
self.chunk_token_size = 512
self.delimiter = "\n"
self.overlapped_percent = 0
self.page_rank = 0
self.auto_keywords = 0
self.auto_questions = 0
self.tag_sets = []
self.llm_setting = {
"llm_name": "",
"lang": "Chinese"
}
def check(self):
self.check_valid_value(self.method.lower(), "Chunk method abnormal.", self.method_options)
self.check_positive_integer(self.chunk_token_size, "Chunk token size.")
self.check_nonnegative_number(self.page_rank, "Page rank value: (0, 10]")
self.check_nonnegative_number(self.auto_keywords, "Auto-keyword value: (0, 10]")
self.check_nonnegative_number(self.auto_questions, "Auto-question value: (0, 10]")
self.check_decimal_float(self.overlapped_percent, "Overlapped percentage: [0, 1)")
class Chunker(ProcessBase):
component_name = "Chunker"
def _general(self, **kwargs):
self.callback(random.randint(1,5)/100., "Start to chunk via `General`.")
if kwargs.get("output_format") in ["markdown", "text"]:
cks = naive_merge(kwargs.get(kwargs["output_format"]), self._param.chunk_token_size, self._param.delimiter, self._param.overlapped_percent)
return [{"text": c} for c in cks]
sections, section_images = [], []
for o in kwargs["json"]:
sections.append((o["text"], o.get("position_tag","")))
section_images.append(o.get("image"))
chunks, images = naive_merge_with_images(sections, section_images,self._param.chunk_token_size, self._param.delimiter, self._param.overlapped_percent)
return [{
"text": RAGFlowPdfParser.remove_tag(c),
"image": img,
"positions": RAGFlowPdfParser.extract_positions(c)
} for c,img in zip(chunks,images)]
def _q_and_a(self, **kwargs):
pass
def _resume(self, **kwargs):
pass
def _manual(self, **kwargs):
pass
def _table(self, **kwargs):
pass
def _paper(self, **kwargs):
pass
def _book(self, **kwargs):
pass
def _laws(self, **kwargs):
pass
def _presentation(self, **kwargs):
pass
def _one(self, **kwargs):
pass
async def _invoke(self, **kwargs):
function_map = {
"general": self._general,
"q&a": self._q_and_a,
"resume": self._resume,
"manual": self._manual,
"table": self._table,
"paper": self._paper,
"book": self._book,
"laws": self._laws,
"presentation": self._presentation,
"one": self._one,
}
chunks = function_map[self._param.method](**kwargs)
llm_setting = self._param.llm_setting
async def auto_keywords():
nonlocal chunks, llm_setting
chat_mdl = LLMBundle(self._canvas._tenant_id, LLMType.CHAT, llm_name=llm_setting["llm_name"], lang=llm_setting["lang"])
async def doc_keyword_extraction(chat_mdl, ck, topn):
cached = get_llm_cache(chat_mdl.llm_name, ck["text"], "keywords", {"topn": topn})
if not cached:
async with chat_limiter:
cached = await trio.to_thread.run_sync(lambda: keyword_extraction(chat_mdl, ck["text"], topn))
set_llm_cache(chat_mdl.llm_name, ck["text"], cached, "keywords", {"topn": topn})
if cached:
ck["keywords"] = cached.split(",")
async with trio.open_nursery() as nursery:
for ck in chunks:
nursery.start_soon(doc_keyword_extraction, chat_mdl, ck, self._param.auto_keywords)
async def auto_questions():
nonlocal chunks, llm_setting
chat_mdl = LLMBundle(self._canvas._tenant_id, LLMType.CHAT, llm_name=llm_setting["llm_name"], lang=llm_setting["lang"])
async def doc_question_proposal(chat_mdl, d, topn):
cached = get_llm_cache(chat_mdl.llm_name, ck["text"], "question", {"topn": topn})
if not cached:
async with chat_limiter:
cached = await trio.to_thread.run_sync(lambda: question_proposal(chat_mdl, ck["text"], topn))
set_llm_cache(chat_mdl.llm_name, ck["text"], cached, "question", {"topn": topn})
if cached:
d["questions"] = cached.split("\n")
async with trio.open_nursery() as nursery:
for ck in chunks:
nursery.start_soon(doc_question_proposal, chat_mdl, ck, self._param.auto_questions)
async with trio.open_nursery() as nursery:
if self._param.auto_questions:
nursery.start_soon(auto_questions)
if self._param.auto_keywords:
nursery.start_soon(auto_keywords)
if self._param.page_rank:
for ck in chunks:
ck["page_rank"] = self._param.page_rank
self.set_output("chunks", chunks)

107
rag/flow/parser.py Normal file
View File

@ -0,0 +1,107 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import trio
from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from deepdoc.parser.pdf_parser import RAGFlowPdfParser, PlainParser, VisionParser
from rag.flow.base import ProcessBase, ProcessParamBase
from rag.llm.cv_model import Base as VLM
from deepdoc.parser import ExcelParser
class ParserParam(ProcessParamBase):
def __init__(self):
super().__init__()
self.setups = {
"pdf": {
"parse_method": "deepdoc", # deepdoc/plain_text/vlm
"vlm_name": "",
"lang": "Chinese",
"suffix": ["pdf"],
"output_format": "json"
},
"excel": {
"output_format": "html"
},
"ppt": {},
"image": {
"parse_method": "ocr"
},
"email": {},
"text": {},
"audio": {},
"video": {},
}
def check(self):
if self.setups["pdf"].get("parse_method") not in ["deepdoc", "plain_text"]:
assert self.setups["pdf"].get("vlm_name"), "No VLM specified."
assert self.setups["pdf"].get("lang"), "No language specified."
class Parser(ProcessBase):
component_name = "Parser"
def _pdf(self, blob):
self.callback(random.randint(1,5)/100., "Start to work on a PDF.")
conf = self._param.setups["pdf"]
self.set_output("output_format", conf["output_format"])
if conf.get("parse_method") == "deepdoc":
bboxes = RAGFlowPdfParser().parse_into_bboxes(blob, callback=self.callback)
elif conf.get("parse_method") == "plain_text":
lines,_ = PlainParser()(blob)
bboxes = [{"text": t} for t,_ in lines]
else:
assert conf.get("vlm_name")
vision_model = LLMBundle(self._canvas.tenant_id, LLMType.IMAGE2TEXT, llm_name=conf.get("vlm_name"), lang=self.setups["pdf"].get("lang"))
lines, _ = VisionParser(vision_model=vision_model)(bin, callback=self.callback)
bboxes = []
for t, poss in lines:
pn, x0, x1, top, bott = poss.split(" ")
bboxes.append({"page_number": int(pn), "x0": int(x0), "x1": int(x1), "top": int(top), "bottom": int(bott), "text": t})
self.set_output("json", bboxes)
mkdn = ""
for b in bboxes:
if b.get("layout_type", "") == "title":
mkdn += "\n## "
if b.get("layout_type", "") == "figure":
mkdn += "\n![Image]({})".format(VLM.image2base64(b["image"]))
continue
mkdn += b.get("text", "") + "\n"
self.set_output("markdown", mkdn)
def _excel(self, blob):
self.callback(random.randint(1,5)/100., "Start to work on a Excel.")
conf = self._param.setups["excel"]
excel_parser = ExcelParser()
if conf.get("output_format") == "html":
html = excel_parser.html(blob,1000000000)
self.set_output("html", html)
elif conf.get("output_format") == "json":
self.set_output("json", [{"text": txt} for txt in excel_parser(blob) if txt])
elif conf.get("output_format") == "markdown":
self.set_output("markdown", excel_parser.markdown(blob))
async def _invoke(self, **kwargs):
function_map = {
"pdf": self._pdf,
}
for p_type, conf in self._param.setups.items():
if kwargs.get("name", "").split(".")[-1].lower() not in conf.get("suffix", []):
continue
await trio.to_thread.run_sync(function_map[p_type], kwargs["blob"])
break

121
rag/flow/pipeline.py Normal file
View File

@ -0,0 +1,121 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import datetime
import json
import logging
import random
import time
import trio
from agent.canvas import Graph
from api.db.services.document_service import DocumentService
from rag.utils.redis_conn import REDIS_CONN
class Pipeline(Graph):
def __init__(self, dsl: str, tenant_id=None, doc_id=None, task_id=None, flow_id=None):
super().__init__(dsl, tenant_id, task_id)
self._doc_id = doc_id
self._flow_id = flow_id
self._kb_id = None
if doc_id:
self._kb_id = DocumentService.get_knowledgebase_id(doc_id)
assert self._kb_id, f"Can't find KB of this document: {doc_id}"
def callback(self, component_name: str, progress: float|int|None=None, message: str = "") -> None:
log_key = f"{self._flow_id}-{self.task_id}-logs"
try:
bin = REDIS_CONN.get(log_key)
obj = json.loads(bin.encode("utf-8"))
if obj:
if obj[-1]["component_name"] == component_name:
obj[-1]["trace"].append({"progress": progress, "message": message, "datetime": datetime.datetime.now().strftime("%H:%M:%S")})
else:
obj.append({
"component_name": component_name,
"trace": [{"progress": progress, "message": message, "datetime": datetime.datetime.now().strftime("%H:%M:%S")}]
})
else:
obj = [{
"component_name": component_name,
"trace": [{"progress": progress, "message": message, "datetime": datetime.datetime.now().strftime("%H:%M:%S")}]
}]
REDIS_CONN.set_obj(log_key, obj, 60*10)
except Exception as e:
logging.exception(e)
def fetch_logs(self):
log_key = f"{self._flow_id}-{self.task_id}-logs"
try:
bin = REDIS_CONN.get(log_key)
if bin:
return json.loads(bin.encode("utf-8"))
except Exception as e:
logging.exception(e)
return []
def reset(self):
super().reset()
log_key = f"{self._flow_id}-{self.task_id}-logs"
try:
REDIS_CONN.set_obj(log_key, [], 60*10)
except Exception as e:
logging.exception(e)
async def run(self, **kwargs):
st = time.perf_counter()
if not self.path:
self.path.append("begin")
if self._doc_id:
DocumentService.update_by_id(self._doc_id, {
"progress": random.randint(0,5)/100.,
"progress_msg": "Start the pipeline...",
"process_begin_at": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
})
self.error = ""
idx = len(self.path) - 1
if idx == 0:
cpn_obj = self.get_component_obj(self.path[0])
await cpn_obj.invoke(**kwargs)
if cpn_obj.error():
self.error = "[ERROR]" + cpn_obj.error()
else:
idx += 1
self.path.extend(cpn_obj.get_downstream())
while idx < len(self.path) and not self.error:
last_cpn = self.get_component_obj(self.path[idx-1])
cpn_obj = self.get_component_obj(self.path[idx])
async def invoke():
nonlocal last_cpn, cpn_obj
await cpn_obj.invoke(**last_cpn.output())
async with trio.open_nursery() as nursery:
nursery.start_soon(invoke)
if cpn_obj.error():
self.error = "[ERROR]" + cpn_obj.error()
break
idx += 1
self.path.extend(cpn_obj.get_downstream())
if self._doc_id:
DocumentService.update_by_id(self._doc_id, {
"progress": 1 if not self.error else -1,
"progress_msg": "Pipeline finished...\n" + self.error,
"process_duration": time.perf_counter() - st
})

57
rag/flow/tests/client.py Normal file
View File

@ -0,0 +1,57 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import argparse
import json
import os
import time
from concurrent.futures import ThreadPoolExecutor
import trio
from api import settings
from rag.flow.pipeline import Pipeline
def print_logs(pipeline):
last_logs = "[]"
while True:
time.sleep(5)
logs = pipeline.fetch_logs()
logs_str = json.dumps(logs)
if logs_str != last_logs:
print(logs_str)
last_logs = logs_str
if __name__ == '__main__':
parser = argparse.ArgumentParser()
dsl_default_path = os.path.join(
os.path.dirname(os.path.realpath(__file__)),
"dsl_examples",
"general_pdf_all.json",
)
parser.add_argument('-s', '--dsl', default=dsl_default_path, help="input dsl", action='store', required=True)
parser.add_argument('-d', '--doc_id', default=False, help="Document ID", action='store', required=True)
parser.add_argument('-t', '--tenant_id', default=False, help="Tenant ID", action='store', required=True)
args = parser.parse_args()
settings.init_settings()
pipeline = Pipeline(open(args.dsl, "r").read(), tenant_id=args.tenant_id, doc_id=args.doc_id, task_id="xxxx", flow_id="xxx")
pipeline.reset()
exe = ThreadPoolExecutor(max_workers=5)
thr = exe.submit(print_logs, pipeline)
trio.run(pipeline.run)
thr.result()

View File

@ -0,0 +1,54 @@
{
"components": {
"begin": {
"obj":{
"component_name": "File",
"params": {
}
},
"downstream": ["parser:0"],
"upstream": []
},
"parser:0": {
"obj": {
"component_name": "Parser",
"params": {
"setups": {
"pdf": {
"parse_method": "deepdoc",
"vlm_name": "",
"lang": "Chinese",
"suffix": [
"pdf"
],
"output_format": "json"
}
}
}
},
"downstream": ["chunker:0"],
"upstream": ["begin"]
},
"chunker:0": {
"obj": {
"component_name": "Chunker",
"params": {
"method": "general",
"auto_keywords": 5
}
},
"downstream": ["tokenizer:0"],
"upstream": ["chunker:0"]
},
"tokenizer:0": {
"obj": {
"component_name": "Tokenizer",
"params": {
}
},
"downstream": [],
"upstream": ["chunker:0"]
}
},
"path": []
}

134
rag/flow/tokenizer.py Normal file
View File

@ -0,0 +1,134 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import re
import numpy as np
import trio
from api.db import LLMType
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
from api.db.services.user_service import TenantService
from api.utils.api_utils import timeout
from rag.flow.base import ProcessBase, ProcessParamBase
from rag.nlp import rag_tokenizer
from rag.settings import EMBEDDING_BATCH_SIZE
from rag.svr.task_executor import embed_limiter
from rag.utils import truncate
class TokenizerParam(ProcessParamBase):
def __init__(self):
super().__init__()
self.search_method = ["full_text", "embedding"]
self.filename_embd_weight = 0.1
def check(self):
for v in self.search_method:
self.check_valid_value(v.lower(), "Chunk method abnormal.", ["full_text", "embedding"])
class Tokenizer(ProcessBase):
component_name = "Tokenizer"
async def _embedding(self, name, chunks):
parts = sum(["full_text" in self._param.search_method, "embedding" in self._param.search_method])
token_count = 0
if self._canvas._kb_id:
e, kb = KnowledgebaseService.get_by_id(self._canvas._kb_id)
embedding_id = kb.embd_id
else:
e, ten = TenantService.get_by_id(self._canvas._tenant_id)
embedding_id = ten.embd_id
embedding_model = LLMBundle(self._canvas._tenant_id, LLMType.EMBEDDING, llm_name=embedding_id)
texts = []
for c in chunks:
if c.get("questions"):
texts.append("\n".join(c["questions"]))
else:
texts.append(re.sub(r"</?(table|td|caption|tr|th)( [^<>]{0,12})?>", " ", c["text"]))
vts, c = embedding_model.encode([name])
token_count += c
tts = np.concatenate([vts[0] for _ in range(len(texts))], axis=0)
@timeout(60)
def batch_encode(txts):
nonlocal embedding_model
return embedding_model.encode([truncate(c, embedding_model.max_length-10) for c in txts])
cnts_ = np.array([])
for i in range(0, len(texts), EMBEDDING_BATCH_SIZE):
async with embed_limiter:
vts, c = await trio.to_thread.run_sync(lambda: batch_encode(texts[i: i + EMBEDDING_BATCH_SIZE]))
if len(cnts_) == 0:
cnts_ = vts
else:
cnts_ = np.concatenate((cnts_, vts), axis=0)
token_count += c
if i % 33 == 32:
self.callback(i*1./len(texts)/parts/EMBEDDING_BATCH_SIZE + 0.5*(parts-1))
cnts = cnts_
title_w = float(self._param.filename_embd_weight)
vects = (title_w * tts + (1 - title_w) * cnts) if len(tts) == len(cnts) else cnts
assert len(vects) == len(chunks)
for i, ck in enumerate(chunks):
v = vects[i].tolist()
ck["q_%d_vec" % len(v)] = v
return chunks, token_count
async def _invoke(self, **kwargs):
parts = sum(["full_text" in self._param.search_method, "embedding" in self._param.search_method])
if "full_text" in self._param.search_method:
self.callback(random.randint(1,5)/100., "Start to tokenize.")
if kwargs.get("chunks"):
chunks = kwargs["chunks"]
for i, ck in enumerate(chunks):
if ck.get("questions"):
ck["question_tks"] = rag_tokenizer.tokenize("\n".join(ck["questions"]))
if ck.get("keywords"):
ck["important_tks"] = rag_tokenizer.tokenize("\n".join(ck["keywords"]))
ck["content_ltks"] = rag_tokenizer.tokenize(ck["text"])
ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"])
if i % 100 == 99:
self.callback(i*1./len(chunks)/parts)
elif kwargs.get("output_format") in ["markdown", "text"]:
ck = {
"text": kwargs.get(kwargs["output_format"], "")
}
if "full_text" in self._param.search_method:
ck["content_ltks"] = rag_tokenizer.tokenize(kwargs.get(kwargs["output_format"], ""))
ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"])
chunks = [ck]
else:
chunks = kwargs["json"]
for i, ck in enumerate(chunks):
ck["content_ltks"] = rag_tokenizer.tokenize(ck["text"])
ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"])
if i % 100 == 99:
self.callback(i*1./len(chunks)/parts)
self.callback(1./parts, "Finish tokenizing.")
if "embedding" in self._param.search_method:
self.callback(random.randint(1,5)/100. + 0.5*(parts-1), "Start embedding inference.")
chunks, token_count = await self._embedding(kwargs.get("name", ""), chunks)
self.set_output("embedding_token_consumption", token_count)
self.callback(1., "Finish embedding.")
self.set_output("chunks", chunks)

View File

@ -563,7 +563,8 @@ def naive_merge(sections, chunk_token_num=128, delimiter="\n。", overl
return cks return cks
def naive_merge_with_images(texts, images, chunk_token_num=128, delimiter="\n。;!?"): def naive_merge_with_images(texts, images, chunk_token_num=128, delimiter="\n。;!?", overlapped_percent=0):
from deepdoc.parser.pdf_parser import RAGFlowPdfParser
if not texts or len(texts) != len(images): if not texts or len(texts) != len(images):
return [], [] return [], []
cks = [""] cks = [""]
@ -578,7 +579,10 @@ def naive_merge_with_images(texts, images, chunk_token_num=128, delimiter="\n。
if tnum < 8: if tnum < 8:
pos = "" pos = ""
# Ensure that the length of the merged chunk does not exceed chunk_token_num # Ensure that the length of the merged chunk does not exceed chunk_token_num
if cks[-1] == "" or tk_nums[-1] > chunk_token_num: if cks[-1] == "" or tk_nums[-1] > chunk_token_num * (100 - overlapped_percent)/100.:
if cks:
overlapped = RAGFlowPdfParser.remove_tag(cks[-1])
t = overlapped[int(len(overlapped)*(100-overlapped_percent)/100.):] + t
if t.find(pos) < 0: if t.find(pos) < 0:
t += pos t += pos
cks.append(t) cks.append(t)

View File

@ -93,7 +93,8 @@ class MCPToolCallSession(ToolCallSession):
msg = f"Timeout initializing client_session for server {self._mcp_server.id}" msg = f"Timeout initializing client_session for server {self._mcp_server.id}"
logging.error(msg) logging.error(msg)
await self._process_mcp_tasks(None, msg) await self._process_mcp_tasks(None, msg)
except Exception: except Exception as e:
logging.exception(e)
msg = "Connection failed (possibly due to auth error). Please check authentication settings first" msg = "Connection failed (possibly due to auth error). Please check authentication settings first"
await self._process_mcp_tasks(None, msg) await self._process_mcp_tasks(None, msg)