Feat: init dataflow. (#9791)

### What problem does this PR solve?

#9790

Close #9782

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
Author: Kevin Hu
Date: 2025-08-28 18:40:32 +08:00
Committed by: GitHub
Parent: a246949b77
Commit: c27172b3bc

19 changed files with 1020 additions and 166 deletions


@@ -29,8 +29,7 @@ from api.utils import get_uuid, hash_str2int
from rag.prompts.prompts import chunks_format
from rag.utils.redis_conn import REDIS_CONN
class Canvas:
class Graph:
"""
dsl = {
"components": {
@@ -73,39 +72,9 @@ class Canvas:
def __init__(self, dsl: str, tenant_id=None, task_id=None):
self.path = []
self.history = []
self.components = {}
self.error = ""
self.globals = {
"sys.query": "",
"sys.user_id": tenant_id,
"sys.conversation_turns": 0,
"sys.files": []
}
self.dsl = json.loads(dsl) if dsl else {
"components": {
"begin": {
"obj": {
"component_name": "Begin",
"params": {
"prologue": "Hi there!"
}
},
"downstream": [],
"upstream": [],
"parent_id": ""
}
},
"history": [],
"path": [],
"retrieval": [],
"globals": {
"sys.query": "",
"sys.user_id": "",
"sys.conversation_turns": 0,
"sys.files": []
}
}
self.dsl = json.loads(dsl)
self._tenant_id = tenant_id
self.task_id = task_id if task_id else get_uuid()
self.load()
@@ -116,8 +85,6 @@ class Canvas:
for k, cpn in self.components.items():
cpn_nms.add(cpn["obj"]["component_name"])
assert "Begin" in cpn_nms, "There have to be an 'Begin' component."
for k, cpn in self.components.items():
cpn_nms.add(cpn["obj"]["component_name"])
param = component_class(cpn["obj"]["component_name"] + "Param")()
@@ -130,27 +97,10 @@ class Canvas:
cpn["obj"] = component_class(cpn["obj"]["component_name"])(self, k, param)
self.path = self.dsl["path"]
self.history = self.dsl["history"]
if "globals" in self.dsl:
self.globals = self.dsl["globals"]
else:
self.globals = {
"sys.query": "",
"sys.user_id": "",
"sys.conversation_turns": 0,
"sys.files": []
}
self.retrieval = self.dsl["retrieval"]
self.memory = self.dsl.get("memory", [])
def __str__(self):
self.dsl["path"] = self.path
self.dsl["history"] = self.history
self.dsl["globals"] = self.globals
self.dsl["task_id"] = self.task_id
self.dsl["retrieval"] = self.retrieval
self.dsl["memory"] = self.memory
dsl = {
"components": {}
}
@@ -169,14 +119,79 @@ class Canvas:
dsl["components"][k][c] = deepcopy(cpn[c])
return json.dumps(dsl, ensure_ascii=False)
def reset(self, mem=False):
def reset(self):
self.path = []
for k, cpn in self.components.items():
self.components[k]["obj"].reset()
try:
REDIS_CONN.delete(f"{self.task_id}-logs")
except Exception as e:
logging.exception(e)
def get_component_name(self, cid):
for n in self.dsl.get("graph", {}).get("nodes", []):
if cid == n["id"]:
return n["data"]["name"]
return ""
def run(self, **kwargs):
raise NotImplementedError()
def get_component(self, cpn_id) -> Union[None, dict[str, Any]]:
return self.components.get(cpn_id)
def get_component_obj(self, cpn_id) -> ComponentBase:
return self.components.get(cpn_id)["obj"]
def get_component_type(self, cpn_id) -> str:
return self.components.get(cpn_id)["obj"].component_name
def get_component_input_form(self, cpn_id) -> dict:
return self.components.get(cpn_id)["obj"].get_input_form()
def get_tenant_id(self):
return self._tenant_id
class Canvas(Graph):
def __init__(self, dsl: str, tenant_id=None, task_id=None):
self.globals = {
"sys.query": "",
"sys.user_id": tenant_id,
"sys.conversation_turns": 0,
"sys.files": []
}
super().__init__(dsl, tenant_id, task_id)
def load(self):
super().load()
self.history = self.dsl["history"]
if "globals" in self.dsl:
self.globals = self.dsl["globals"]
else:
self.globals = {
"sys.query": "",
"sys.user_id": "",
"sys.conversation_turns": 0,
"sys.files": []
}
self.retrieval = self.dsl["retrieval"]
self.memory = self.dsl.get("memory", [])
def __str__(self):
self.dsl["history"] = self.history
self.dsl["retrieval"] = self.retrieval
self.dsl["memory"] = self.memory
return super().__str__()
def reset(self, mem=False):
super().reset()
if not mem:
self.history = []
self.retrieval = []
self.memory = []
for k, cpn in self.components.items():
self.components[k]["obj"].reset()
for k in self.globals.keys():
if isinstance(self.globals[k], str):
@@ -192,17 +207,6 @@ class Canvas:
else:
self.globals[k] = None
try:
REDIS_CONN.delete(f"{self.task_id}-logs")
except Exception as e:
logging.exception(e)
def get_component_name(self, cid):
for n in self.dsl.get("graph", {}).get("nodes", []):
if cid == n["id"]:
return n["data"]["name"]
return ""
def run(self, **kwargs):
st = time.perf_counter()
self.message_id = get_uuid()
@@ -388,18 +392,6 @@ class Canvas:
})
self.history.append(("assistant", self.get_component_obj(self.path[-1]).output()))
def get_component(self, cpn_id) -> Union[None, dict[str, Any]]:
return self.components.get(cpn_id)
def get_component_obj(self, cpn_id) -> ComponentBase:
return self.components.get(cpn_id)["obj"]
def get_component_type(self, cpn_id) -> str:
return self.components.get(cpn_id)["obj"].component_name
def get_component_input_form(self, cpn_id) -> dict:
return self.components.get(cpn_id)["obj"].get_input_form()
def is_reff(self, exp: str) -> bool:
exp = exp.strip("{").strip("}")
if exp.find("@") < 0:
@@ -421,9 +413,6 @@ class Canvas:
raise Exception(f"Can't find variable: '{cpn_id}@{var_nm}'")
return cpn["obj"].output(var_nm)
def get_tenant_id(self):
return self._tenant_id
def get_history(self, window_size):
convs = []
if window_size <= 0:
@@ -438,36 +427,6 @@ class Canvas:
def add_user_input(self, question):
self.history.append(("user", question))
def _find_loop(self, max_loops=6):
path = self.path[-1][::-1]
if len(path) < 2:
return False
for i in range(len(path)):
if path[i].lower().find("answer") == 0 or path[i].lower().find("iterationitem") == 0:
path = path[:i]
break
if len(path) < 2:
return False
for loc in range(2, len(path) // 2):
pat = ",".join(path[0:loc])
path_str = ",".join(path)
if len(pat) >= len(path_str):
return False
loop = max_loops
while path_str.find(pat) == 0 and loop >= 0:
loop -= 1
if len(pat)+1 >= len(path_str):
return False
path_str = path_str[len(pat)+1:]
if loop < 0:
pat = " => ".join([p.split(":")[0] for p in path[0:loc]])
return pat + " => " + pat
return False
def get_prologue(self):
return self.components["begin"]["obj"]._param.prologue


@@ -50,8 +50,9 @@ del _package_path, _import_submodules, _extract_classes_from_module
def component_class(class_name):
m = importlib.import_module("agent.component")
for mdl in ["agent.component", "agent.tools", "rag.flow"]:
try:
return getattr(m, class_name)
return getattr(importlib.import_module(mdl), class_name)
except Exception:
return getattr(importlib.import_module("agent.tools"), class_name)
pass
assert False, f"Can't import {class_name}"


@@ -16,7 +16,7 @@
import re
import time
from abc import ABC, abstractmethod
from abc import ABC
import builtins
import json
import os
@@ -410,8 +410,8 @@ class ComponentBase(ABC):
)
def __init__(self, canvas, id, param: ComponentParamBase):
from agent.canvas import Canvas # Local import to avoid cyclic dependency
assert isinstance(canvas, Canvas), "canvas must be an instance of Canvas"
from agent.canvas import Graph # Local import to avoid cyclic dependency
assert isinstance(canvas, Graph), "canvas must be an instance of Graph"
self._canvas = canvas
self._id = id
self._param = param
@@ -528,6 +528,10 @@ class ComponentBase(ABC):
cpn_nms = self._canvas.get_component(self._id)['upstream']
return cpn_nms
def get_downstream(self) -> List[str]:
cpn_nms = self._canvas.get_component(self._id)['downstream']
return cpn_nms
@staticmethod
def string_format(content: str, kv: dict[str, str]) -> str:
for n, v in kv.items():
@@ -556,6 +560,5 @@ class ComponentBase(ABC):
def set_exception_default_value(self):
self.set_output("result", self.get_exception_default_value())
@abstractmethod
def thoughts(self) -> str:
...
raise NotImplementedError()


@@ -16,9 +16,8 @@
from abc import ABC
import asyncio
from crawl4ai import AsyncWebCrawler
from agent.tools.base import ToolParamBase, ToolBase
from api.utils.web_utils import is_valid_url
class CrawlerParam(ToolParamBase):
@@ -39,6 +38,7 @@ class Crawler(ToolBase, ABC):
component_name = "Crawler"
def _run(self, history, **kwargs):
from api.utils.web_utils import is_valid_url
ans = self.get_input()
ans = " - ".join(ans["content"]) if "content" in ans else ""
if not is_valid_url(ans):


@@ -74,7 +74,6 @@ def retrieval(tenant_id):
[tenant_id],
[kb_id],
embd_mdl,
doc_ids,
LLMBundle(kb.tenant_id, LLMType.CHAT))
if ck["content_with_weight"]:
ranks["chunks"].insert(0, ck)


@@ -133,6 +133,13 @@ class UserService(CommonService):
cls.model.update(user_dict).where(
cls.model.id == user_id).execute()
@classmethod
@DB.connection_context()
def is_admin(cls, user_id):
return cls.model.select().where(
cls.model.id == user_id,
cls.model.is_superuser == 1).count() > 0
class TenantService(CommonService):
"""Service class for managing tenant-related database operations.


@@ -131,6 +131,12 @@ class RAGFlowExcelParser:
return tb_chunks
def markdown(self, fnm):
import pandas as pd
file_like_object = BytesIO(fnm) if not isinstance(fnm, str) else fnm
df = pd.read_excel(file_like_object)
return df.to_markdown(index=False)
def __call__(self, fnm):
file_like_object = BytesIO(fnm) if not isinstance(fnm, str) else fnm
wb = RAGFlowExcelParser._load_excel_to_workbook(file_like_object)
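
The new `markdown` helper mirrors `html`: it accepts either a filename or raw bytes and renders the first sheet via pandas. A sketch, assuming a local `.xlsx` file (note `DataFrame.to_markdown` additionally requires the `tabulate` package):

```python
from deepdoc.parser import ExcelParser

parser = ExcelParser()                       # alias of RAGFlowExcelParser
print(parser.markdown("report.xlsx"))        # hypothetical file path
with open("report.xlsx", "rb") as f:
    print(parser.markdown(f.read()))         # bytes are wrapped in BytesIO
```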


@@ -93,6 +93,7 @@ class RAGFlowPdfParser:
model_dir, "updown_concat_xgb.model"))
self.page_from = 0
self.column_num = 1
def __char_width(self, c):
return (c["x1"] - c["x0"]) // max(len(c["text"]), 1)
@@ -427,10 +428,18 @@ class RAGFlowPdfParser:
i += 1
self.boxes = bxs
def _naive_vertical_merge(self):
def _naive_vertical_merge(self, zoomin=3):
bxs = Recognizer.sort_Y_firstly(
self.boxes, np.median(
self.mean_height) / 3)
column_width = np.median([b["x1"] - b["x0"] for b in self.boxes])
self.column_num = int(self.page_images[0].size[0] / zoomin / column_width)
if column_width < self.page_images[0].size[0] / zoomin / self.column_num:
logging.info("Multi-column................... {} {}".format(column_width,
self.page_images[0].size[0] / zoomin / self.column_num))
self.boxes = self.sort_X_by_page(self.boxes, column_width / self.column_num)
i = 0
while i + 1 < len(bxs):
b = bxs[i]
@@ -1139,20 +1148,94 @@ class RAGFlowPdfParser:
need_image, zoomin, return_html, False)
return self.__filterout_scraps(deepcopy(self.boxes), zoomin), tbls
def parse_into_bboxes(self, fnm, callback=None, zoomin=3):
start = timer()
self.__images__(fnm, zoomin)
if callback:
callback(0.40, "OCR finished ({:.2f}s)".format(timer() - start))
start = timer()
self._layouts_rec(zoomin)
if callback:
callback(0.63, "Layout analysis ({:.2f}s)".format(timer() - start))
start = timer()
self._table_transformer_job(zoomin)
if callback:
callback(0.83, "Table analysis ({:.2f}s)".format(timer() - start))
start = timer()
self._text_merge()
self._concat_downward()
self._naive_vertical_merge(zoomin)
if callback:
callback(0.92, "Text merged ({:.2f}s)".format(timer() - start))
start = timer()
tbls, figs = self._extract_table_figure(True, zoomin, True, True, True)
def insert_table_figures(tbls_or_figs, layout_type):
def min_rectangle_distance(rect1, rect2):
import math
pn1, left1, right1, top1, bottom1 = rect1
pn2, left2, right2, top2, bottom2 = rect2
if (right1 >= left2 and right2 >= left1 and
bottom1 >= top2 and bottom2 >= top1):
return 0 + (pn1-pn2)*10000
if right1 < left2:
dx = left2 - right1
elif right2 < left1:
dx = left1 - right2
else:
dx = 0
if bottom1 < top2:
dy = top2 - bottom1
elif bottom2 < top1:
dy = top1 - bottom2
else:
dy = 0
return math.sqrt(dx*dx + dy*dy) + (pn1-pn2)*10000
for (img, txt), poss in tbls_or_figs:
bboxes = [(i, (b["page_number"], b["x0"], b["x1"], b["top"], b["bottom"])) for i, b in enumerate(self.boxes)]
dists = [(min_rectangle_distance((pn, left, right, top, bott), rect),i) for i, rect in bboxes for pn, left, right, top, bott in poss]
min_i = np.argmin(dists, axis=0)[0]
min_i, rect = bboxes[dists[min_i][-1]]
if isinstance(txt, list):
txt = "\n".join(txt)
self.boxes.insert(min_i, {
"page_number": rect[0], "x0": rect[1], "x1": rect[2], "top": rect[3], "bottom": rect[4], "layout_type": layout_type, "text": txt, "image": img
})
for b in self.boxes:
b["position_tag"] = self._line_tag(b, zoomin)
b["image"] = self.crop(b["position_tag"], zoomin)
insert_table_figures(tbls, "table")
insert_table_figures(figs, "figure")
if callback:
callback(1, "Structured ({:.2f}s)".format(timer() - start))
return deepcopy(self.boxes)
@staticmethod
def remove_tag(txt):
return re.sub(r"@@[\t0-9.-]+?##", "", txt)
def crop(self, text, ZM=3, need_position=False):
imgs = []
@staticmethod
def extract_positions(txt):
poss = []
for tag in re.findall(r"@@[0-9-]+\t[0-9.\t]+##", text):
for tag in re.findall(r"@@[0-9-]+\t[0-9.\t]+##", txt):
pn, left, right, top, bottom = tag.strip(
"#").strip("@").split("\t")
left, right, top, bottom = float(left), float(
right), float(top), float(bottom)
poss.append(([int(p) - 1 for p in pn.split("-")],
left, right, top, bottom))
return poss
def crop(self, text, ZM=3, need_position=False):
imgs = []
poss = self.extract_positions(text)
if not poss:
if need_position:
return None, None
@@ -1296,8 +1379,8 @@ class VisionParser(RAGFlowPdfParser):
def __call__(self, filename, from_page=0, to_page=100000, **kwargs):
callback = kwargs.get("callback", lambda prog, msg: None)
self.__images__(fnm=filename, zoomin=3, page_from=from_page, page_to=to_page, **kwargs)
zoomin = kwargs.get("zoomin", 3)
self.__images__(fnm=filename, zoomin=zoomin, page_from=from_page, page_to=to_page, callback=callback)
total_pdf_pages = self.total_page
@@ -1311,16 +1394,19 @@ class VisionParser(RAGFlowPdfParser):
if pdf_page_num < start_page or pdf_page_num >= end_page:
continue
docs = picture_vision_llm_chunk(
text = picture_vision_llm_chunk(
binary=img_binary,
vision_model=self.vision_model,
prompt=vision_llm_describe_prompt(page=pdf_page_num+1),
callback=callback,
)
if kwargs.get("callback"):
kwargs["callback"](idx*1./len(self.page_images), f"Processed: {idx+1}/{len(self.page_images)}")
if docs:
all_docs.append(docs)
return [(doc, "") for doc in all_docs], []
if text:
width, height = self.page_images[idx].size
all_docs.append((text, f"{pdf_page_num+1} 0 {width/zoomin} 0 {height/zoomin}"))
return all_docs, []
if __name__ == "__main__":
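
The new `parse_into_bboxes` chains the whole DeepDoc pipeline (OCR, layout recognition, table transformer, vertical merge) and returns a single position-tagged box list with tables and figures spliced in beside their nearest text box. A sketch of calling it directly, assuming a local PDF:

```python
from deepdoc.parser.pdf_parser import RAGFlowPdfParser

def progress(prog, msg=""):
    print(f"{prog:5.2f}  {msg}")

parser = RAGFlowPdfParser()
with open("paper.pdf", "rb") as fp:          # hypothetical input file
    boxes = parser.parse_into_bboxes(fp.read(), callback=progress)
for b in boxes[:3]:
    print(b["page_number"], b.get("layout_type", "text"), b["text"][:60])
```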

rag/flow/__init__.py (new file, 49 lines)

@@ -0,0 +1,49 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import importlib
import inspect
from types import ModuleType
from typing import Dict, Type
_package_path = os.path.dirname(__file__)
__all_classes: Dict[str, Type] = {}
def _import_submodules() -> None:
for filename in os.listdir(_package_path): # noqa: F821
if filename.startswith("__") or not filename.endswith(".py") or filename.startswith("base"):
continue
module_name = filename[:-3]
try:
module = importlib.import_module(f".{module_name}", package=__name__)
_extract_classes_from_module(module) # noqa: F821
except ImportError as e:
print(f"Warning: Failed to import module {module_name}: {str(e)}")
def _extract_classes_from_module(module: ModuleType) -> None:
for name, obj in inspect.getmembers(module):
if (inspect.isclass(obj) and
obj.__module__ == module.__name__ and not name.startswith("_")):
__all_classes[name] = obj
globals()[name] = obj
_import_submodules()
__all__ = list(__all_classes.keys()) + ["__all_classes"]
del _package_path, _import_submodules, _extract_classes_from_module

rag/flow/base.py (new file, 59 lines)

@@ -0,0 +1,59 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import time
import os
import logging
from functools import partial
from typing import Any
import trio
from agent.component.base import ComponentParamBase, ComponentBase
from api.utils.api_utils import timeout
class ProcessParamBase(ComponentParamBase):
def __init__(self):
super().__init__()
self.timeout = 100000000
self.persist_logs = True
class ProcessBase(ComponentBase):
def __init__(self, pipeline, id, param: ProcessParamBase):
super().__init__(pipeline, id, param)
self.callback = partial(self._canvas.callback, self.component_name)
async def invoke(self, **kwargs) -> dict[str, Any]:
self.set_output("_created_time", time.perf_counter())
for k,v in kwargs.items():
self.set_output(k, v)
try:
with trio.fail_after(self._param.timeout):
await self._invoke(**kwargs)
self.callback(1, "Done")
except Exception as e:
if self.get_exception_default_value():
self.set_exception_default_value()
else:
self.set_output("_ERROR", str(e))
logging.exception(e)
self.callback(-1, str(e))
self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
return self.output()
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
async def _invoke(self, **kwargs):
raise NotImplementedError()
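
Dataflow processors subclass `ProcessBase` and implement only `_invoke`; the base `invoke` adds timing outputs, a trio timeout, error capture into `_ERROR`, and the progress callback. A minimal sketch of a custom step under those assumptions:

```python
from rag.flow.base import ProcessBase, ProcessParamBase

class UpperCaseParam(ProcessParamBase):
    def check(self):
        pass                                  # no parameters to validate

class UpperCase(ProcessBase):
    component_name = "UpperCase"

    async def _invoke(self, **kwargs):
        # kwargs carries the upstream component's outputs (see Pipeline.run).
        self.set_output("text", kwargs.get("text", "").upper())
```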

rag/flow/begin.py (new file, 47 lines)

@@ -0,0 +1,47 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from api.db.services.document_service import DocumentService
from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService
from rag.flow.base import ProcessBase, ProcessParamBase
from rag.utils.storage_factory import STORAGE_IMPL
class FileParam(ProcessParamBase):
def __init__(self):
super().__init__()
def check(self):
pass
class File(ProcessBase):
component_name = "File"
async def _invoke(self, **kwargs):
if self._canvas._doc_id:
e, doc = DocumentService.get_by_id(self._canvas._doc_id)
if not e:
self.set_output("_ERROR", f"Document({self._canvas._doc_id}) not found!")
return
b, n = File2DocumentService.get_storage_address(doc_id=self._canvas._doc_id)
self.set_output("blob", STORAGE_IMPL.get(b, n))
self.set_output("name", doc.name)
else:
file = kwargs.get("file")
self.set_output("name", file["name"])
self.set_output("blob", FileService.get_blob(file["created_by"], file["id"]))

rag/flow/chunker.py (new file, 160 lines)

@@ -0,0 +1,160 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import trio
from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from deepdoc.parser.pdf_parser import RAGFlowPdfParser
from graphrag.utils import get_llm_cache, chat_limiter, set_llm_cache
from rag.flow.base import ProcessBase, ProcessParamBase
from rag.nlp import naive_merge, naive_merge_with_images
from rag.prompts.prompts import keyword_extraction, question_proposal
class ChunkerParam(ProcessParamBase):
def __init__(self):
super().__init__()
self.method_options = ["general", "q&a", "resume", "manual", "table", "paper", "book", "laws", "presentation", "one"]
self.method = "general"
self.chunk_token_size = 512
self.delimiter = "\n"
self.overlapped_percent = 0
self.page_rank = 0
self.auto_keywords = 0
self.auto_questions = 0
self.tag_sets = []
self.llm_setting = {
"llm_name": "",
"lang": "Chinese"
}
def check(self):
self.check_valid_value(self.method.lower(), "Chunk method abnormal.", self.method_options)
self.check_positive_integer(self.chunk_token_size, "Chunk token size.")
self.check_nonnegative_number(self.page_rank, "Page rank value: (0, 10]")
self.check_nonnegative_number(self.auto_keywords, "Auto-keyword value: (0, 10]")
self.check_nonnegative_number(self.auto_questions, "Auto-question value: (0, 10]")
self.check_decimal_float(self.overlapped_percent, "Overlapped percentage: [0, 1)")
class Chunker(ProcessBase):
component_name = "Chunker"
def _general(self, **kwargs):
self.callback(random.randint(1,5)/100., "Start to chunk via `General`.")
if kwargs.get("output_format") in ["markdown", "text"]:
cks = naive_merge(kwargs.get(kwargs["output_format"]), self._param.chunk_token_size, self._param.delimiter, self._param.overlapped_percent)
return [{"text": c} for c in cks]
sections, section_images = [], []
for o in kwargs["json"]:
sections.append((o["text"], o.get("position_tag","")))
section_images.append(o.get("image"))
chunks, images = naive_merge_with_images(sections, section_images,self._param.chunk_token_size, self._param.delimiter, self._param.overlapped_percent)
return [{
"text": RAGFlowPdfParser.remove_tag(c),
"image": img,
"positions": RAGFlowPdfParser.extract_positions(c)
} for c,img in zip(chunks,images)]
def _q_and_a(self, **kwargs):
pass
def _resume(self, **kwargs):
pass
def _manual(self, **kwargs):
pass
def _table(self, **kwargs):
pass
def _paper(self, **kwargs):
pass
def _book(self, **kwargs):
pass
def _laws(self, **kwargs):
pass
def _presentation(self, **kwargs):
pass
def _one(self, **kwargs):
pass
async def _invoke(self, **kwargs):
function_map = {
"general": self._general,
"q&a": self._q_and_a,
"resume": self._resume,
"manual": self._manual,
"table": self._table,
"paper": self._paper,
"book": self._book,
"laws": self._laws,
"presentation": self._presentation,
"one": self._one,
}
chunks = function_map[self._param.method](**kwargs)
llm_setting = self._param.llm_setting
async def auto_keywords():
nonlocal chunks, llm_setting
chat_mdl = LLMBundle(self._canvas._tenant_id, LLMType.CHAT, llm_name=llm_setting["llm_name"], lang=llm_setting["lang"])
async def doc_keyword_extraction(chat_mdl, ck, topn):
cached = get_llm_cache(chat_mdl.llm_name, ck["text"], "keywords", {"topn": topn})
if not cached:
async with chat_limiter:
cached = await trio.to_thread.run_sync(lambda: keyword_extraction(chat_mdl, ck["text"], topn))
set_llm_cache(chat_mdl.llm_name, ck["text"], cached, "keywords", {"topn": topn})
if cached:
ck["keywords"] = cached.split(",")
async with trio.open_nursery() as nursery:
for ck in chunks:
nursery.start_soon(doc_keyword_extraction, chat_mdl, ck, self._param.auto_keywords)
async def auto_questions():
nonlocal chunks, llm_setting
chat_mdl = LLMBundle(self._canvas._tenant_id, LLMType.CHAT, llm_name=llm_setting["llm_name"], lang=llm_setting["lang"])
async def doc_question_proposal(chat_mdl, d, topn):
cached = get_llm_cache(chat_mdl.llm_name, d["text"], "question", {"topn": topn})
if not cached:
async with chat_limiter:
cached = await trio.to_thread.run_sync(lambda: question_proposal(chat_mdl, d["text"], topn))
set_llm_cache(chat_mdl.llm_name, d["text"], cached, "question", {"topn": topn})
if cached:
d["questions"] = cached.split("\n")
async with trio.open_nursery() as nursery:
for ck in chunks:
nursery.start_soon(doc_question_proposal, chat_mdl, ck, self._param.auto_questions)
async with trio.open_nursery() as nursery:
if self._param.auto_questions:
nursery.start_soon(auto_questions)
if self._param.auto_keywords:
nursery.start_soon(auto_keywords)
if self._param.page_rank:
for ck in chunks:
ck["page_rank"] = self._param.page_rank
self.set_output("chunks", chunks)

rag/flow/parser.py (new file, 107 lines)

@@ -0,0 +1,107 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import trio
from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from deepdoc.parser.pdf_parser import RAGFlowPdfParser, PlainParser, VisionParser
from rag.flow.base import ProcessBase, ProcessParamBase
from rag.llm.cv_model import Base as VLM
from deepdoc.parser import ExcelParser
class ParserParam(ProcessParamBase):
def __init__(self):
super().__init__()
self.setups = {
"pdf": {
"parse_method": "deepdoc", # deepdoc/plain_text/vlm
"vlm_name": "",
"lang": "Chinese",
"suffix": ["pdf"],
"output_format": "json"
},
"excel": {
"output_format": "html"
},
"ppt": {},
"image": {
"parse_method": "ocr"
},
"email": {},
"text": {},
"audio": {},
"video": {},
}
def check(self):
if self.setups["pdf"].get("parse_method") not in ["deepdoc", "plain_text"]:
assert self.setups["pdf"].get("vlm_name"), "No VLM specified."
assert self.setups["pdf"].get("lang"), "No language specified."
class Parser(ProcessBase):
component_name = "Parser"
def _pdf(self, blob):
self.callback(random.randint(1,5)/100., "Start to work on a PDF.")
conf = self._param.setups["pdf"]
self.set_output("output_format", conf["output_format"])
if conf.get("parse_method") == "deepdoc":
bboxes = RAGFlowPdfParser().parse_into_bboxes(blob, callback=self.callback)
elif conf.get("parse_method") == "plain_text":
lines,_ = PlainParser()(blob)
bboxes = [{"text": t} for t,_ in lines]
else:
assert conf.get("vlm_name")
vision_model = LLMBundle(self._canvas._tenant_id, LLMType.IMAGE2TEXT, llm_name=conf.get("vlm_name"), lang=conf.get("lang"))
lines, _ = VisionParser(vision_model=vision_model)(blob, callback=self.callback)
bboxes = []
for t, poss in lines:
pn, x0, x1, top, bott = poss.split(" ")
bboxes.append({"page_number": int(pn), "x0": int(x0), "x1": int(x1), "top": int(top), "bottom": int(bott), "text": t})
self.set_output("json", bboxes)
mkdn = ""
for b in bboxes:
if b.get("layout_type", "") == "title":
mkdn += "\n## "
if b.get("layout_type", "") == "figure":
mkdn += "\n![Image]({})".format(VLM.image2base64(b["image"]))
continue
mkdn += b.get("text", "") + "\n"
self.set_output("markdown", mkdn)
def _excel(self, blob):
self.callback(random.randint(1,5)/100., "Start to work on a Excel.")
conf = self._param.setups["excel"]
excel_parser = ExcelParser()
if conf.get("output_format") == "html":
html = excel_parser.html(blob,1000000000)
self.set_output("html", html)
elif conf.get("output_format") == "json":
self.set_output("json", [{"text": txt} for txt in excel_parser(blob) if txt])
elif conf.get("output_format") == "markdown":
self.set_output("markdown", excel_parser.markdown(blob))
async def _invoke(self, **kwargs):
function_map = {
"pdf": self._pdf,
}
for p_type, conf in self._param.setups.items():
if kwargs.get("name", "").split(".")[-1].lower() not in conf.get("suffix", []):
continue
await trio.to_thread.run_sync(function_map[p_type], kwargs["blob"])
break

rag/flow/pipeline.py (new file, 121 lines)

@@ -0,0 +1,121 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import datetime
import json
import logging
import random
import time
import trio
from agent.canvas import Graph
from api.db.services.document_service import DocumentService
from rag.utils.redis_conn import REDIS_CONN
class Pipeline(Graph):
def __init__(self, dsl: str, tenant_id=None, doc_id=None, task_id=None, flow_id=None):
super().__init__(dsl, tenant_id, task_id)
self._doc_id = doc_id
self._flow_id = flow_id
self._kb_id = None
if doc_id:
self._kb_id = DocumentService.get_knowledgebase_id(doc_id)
assert self._kb_id, f"Can't find KB of this document: {doc_id}"
def callback(self, component_name: str, progress: float|int|None=None, message: str = "") -> None:
log_key = f"{self._flow_id}-{self.task_id}-logs"
try:
bin = REDIS_CONN.get(log_key)
obj = json.loads(bin.encode("utf-8")) if bin else None
if obj:
if obj[-1]["component_name"] == component_name:
obj[-1]["trace"].append({"progress": progress, "message": message, "datetime": datetime.datetime.now().strftime("%H:%M:%S")})
else:
obj.append({
"component_name": component_name,
"trace": [{"progress": progress, "message": message, "datetime": datetime.datetime.now().strftime("%H:%M:%S")}]
})
else:
obj = [{
"component_name": component_name,
"trace": [{"progress": progress, "message": message, "datetime": datetime.datetime.now().strftime("%H:%M:%S")}]
}]
REDIS_CONN.set_obj(log_key, obj, 60*10)
except Exception as e:
logging.exception(e)
def fetch_logs(self):
log_key = f"{self._flow_id}-{self.task_id}-logs"
try:
bin = REDIS_CONN.get(log_key)
if bin:
return json.loads(bin.encode("utf-8"))
except Exception as e:
logging.exception(e)
return []
def reset(self):
super().reset()
log_key = f"{self._flow_id}-{self.task_id}-logs"
try:
REDIS_CONN.set_obj(log_key, [], 60*10)
except Exception as e:
logging.exception(e)
async def run(self, **kwargs):
st = time.perf_counter()
if not self.path:
self.path.append("begin")
if self._doc_id:
DocumentService.update_by_id(self._doc_id, {
"progress": random.randint(0,5)/100.,
"progress_msg": "Start the pipeline...",
"process_begin_at": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
})
self.error = ""
idx = len(self.path) - 1
if idx == 0:
cpn_obj = self.get_component_obj(self.path[0])
await cpn_obj.invoke(**kwargs)
if cpn_obj.error():
self.error = "[ERROR]" + cpn_obj.error()
else:
idx += 1
self.path.extend(cpn_obj.get_downstream())
while idx < len(self.path) and not self.error:
last_cpn = self.get_component_obj(self.path[idx-1])
cpn_obj = self.get_component_obj(self.path[idx])
async def invoke():
nonlocal last_cpn, cpn_obj
await cpn_obj.invoke(**last_cpn.output())
async with trio.open_nursery() as nursery:
nursery.start_soon(invoke)
if cpn_obj.error():
self.error = "[ERROR]" + cpn_obj.error()
break
idx += 1
self.path.extend(cpn_obj.get_downstream())
if self._doc_id:
DocumentService.update_by_id(self._doc_id, {
"progress": 1 if not self.error else -1,
"progress_msg": "Pipeline finished...\n" + self.error,
"process_duration": time.perf_counter() - st
})
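
`callback` and `fetch_logs` share the Redis key `f"{self._flow_id}-{self.task_id}-logs"`; each component gets one entry whose trace accumulates progress ticks. A sketch of the stored object (values illustrative):

```python
logs = [
    {
        "component_name": "Parser",
        "trace": [
            {"progress": 0.03, "message": "Start to work on a PDF.", "datetime": "18:40:01"},
            {"progress": 1, "message": "Done", "datetime": "18:40:32"},
        ],
    },
    {
        "component_name": "Chunker",
        "trace": [{"progress": 0.02, "message": "Start to chunk via `General`.", "datetime": "18:40:33"}],
    },
]
```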

rag/flow/tests/client.py (new file, 57 lines)

@@ -0,0 +1,57 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import argparse
import json
import os
import time
from concurrent.futures import ThreadPoolExecutor
import trio
from api import settings
from rag.flow.pipeline import Pipeline
def print_logs(pipeline):
last_logs = "[]"
while True:
time.sleep(5)
logs = pipeline.fetch_logs()
logs_str = json.dumps(logs)
if logs_str != last_logs:
print(logs_str)
last_logs = logs_str
if __name__ == '__main__':
parser = argparse.ArgumentParser()
dsl_default_path = os.path.join(
os.path.dirname(os.path.realpath(__file__)),
"dsl_examples",
"general_pdf_all.json",
)
parser.add_argument('-s', '--dsl', default=dsl_default_path, help="input dsl", action='store', required=False)
parser.add_argument('-d', '--doc_id', default=False, help="Document ID", action='store', required=True)
parser.add_argument('-t', '--tenant_id', default=False, help="Tenant ID", action='store', required=True)
args = parser.parse_args()
settings.init_settings()
pipeline = Pipeline(open(args.dsl, "r").read(), tenant_id=args.tenant_id, doc_id=args.doc_id, task_id="xxxx", flow_id="xxx")
pipeline.reset()
exe = ThreadPoolExecutor(max_workers=5)
thr = exe.submit(print_logs, pipeline)
trio.run(pipeline.run)
thr.result()


@@ -0,0 +1,54 @@
{
"components": {
"begin": {
"obj":{
"component_name": "File",
"params": {
}
},
"downstream": ["parser:0"],
"upstream": []
},
"parser:0": {
"obj": {
"component_name": "Parser",
"params": {
"setups": {
"pdf": {
"parse_method": "deepdoc",
"vlm_name": "",
"lang": "Chinese",
"suffix": [
"pdf"
],
"output_format": "json"
}
}
}
},
"downstream": ["chunker:0"],
"upstream": ["begin"]
},
"chunker:0": {
"obj": {
"component_name": "Chunker",
"params": {
"method": "general",
"auto_keywords": 5
}
},
"downstream": ["tokenizer:0"],
"upstream": ["chunker:0"]
},
"tokenizer:0": {
"obj": {
"component_name": "Tokenizer",
"params": {
}
},
"downstream": [],
"upstream": ["chunker:0"]
}
},
"path": []
}

rag/flow/tokenizer.py (new file, 134 lines)

@@ -0,0 +1,134 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import re
import numpy as np
import trio
from api.db import LLMType
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
from api.db.services.user_service import TenantService
from api.utils.api_utils import timeout
from rag.flow.base import ProcessBase, ProcessParamBase
from rag.nlp import rag_tokenizer
from rag.settings import EMBEDDING_BATCH_SIZE
from rag.svr.task_executor import embed_limiter
from rag.utils import truncate
class TokenizerParam(ProcessParamBase):
def __init__(self):
super().__init__()
self.search_method = ["full_text", "embedding"]
self.filename_embd_weight = 0.1
def check(self):
for v in self.search_method:
self.check_valid_value(v.lower(), "Chunk method abnormal.", ["full_text", "embedding"])
class Tokenizer(ProcessBase):
component_name = "Tokenizer"
async def _embedding(self, name, chunks):
parts = sum(["full_text" in self._param.search_method, "embedding" in self._param.search_method])
token_count = 0
if self._canvas._kb_id:
e, kb = KnowledgebaseService.get_by_id(self._canvas._kb_id)
embedding_id = kb.embd_id
else:
e, ten = TenantService.get_by_id(self._canvas._tenant_id)
embedding_id = ten.embd_id
embedding_model = LLMBundle(self._canvas._tenant_id, LLMType.EMBEDDING, llm_name=embedding_id)
texts = []
for c in chunks:
if c.get("questions"):
texts.append("\n".join(c["questions"]))
else:
texts.append(re.sub(r"</?(table|td|caption|tr|th)( [^<>]{0,12})?>", " ", c["text"]))
vts, c = embedding_model.encode([name])
token_count += c
tts = np.concatenate([vts[0] for _ in range(len(texts))], axis=0)
@timeout(60)
def batch_encode(txts):
nonlocal embedding_model
return embedding_model.encode([truncate(c, embedding_model.max_length-10) for c in txts])
cnts_ = np.array([])
for i in range(0, len(texts), EMBEDDING_BATCH_SIZE):
async with embed_limiter:
vts, c = await trio.to_thread.run_sync(lambda: batch_encode(texts[i: i + EMBEDDING_BATCH_SIZE]))
if len(cnts_) == 0:
cnts_ = vts
else:
cnts_ = np.concatenate((cnts_, vts), axis=0)
token_count += c
if i % 33 == 32:
self.callback(i*1./len(texts)/parts/EMBEDDING_BATCH_SIZE + 0.5*(parts-1))
cnts = cnts_
title_w = float(self._param.filename_embd_weight)
vects = (title_w * tts + (1 - title_w) * cnts) if len(tts) == len(cnts) else cnts
assert len(vects) == len(chunks)
for i, ck in enumerate(chunks):
v = vects[i].tolist()
ck["q_%d_vec" % len(v)] = v
return chunks, token_count
async def _invoke(self, **kwargs):
parts = sum(["full_text" in self._param.search_method, "embedding" in self._param.search_method])
if "full_text" in self._param.search_method:
self.callback(random.randint(1,5)/100., "Start to tokenize.")
if kwargs.get("chunks"):
chunks = kwargs["chunks"]
for i, ck in enumerate(chunks):
if ck.get("questions"):
ck["question_tks"] = rag_tokenizer.tokenize("\n".join(ck["questions"]))
if ck.get("keywords"):
ck["important_tks"] = rag_tokenizer.tokenize("\n".join(ck["keywords"]))
ck["content_ltks"] = rag_tokenizer.tokenize(ck["text"])
ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"])
if i % 100 == 99:
self.callback(i*1./len(chunks)/parts)
elif kwargs.get("output_format") in ["markdown", "text"]:
ck = {
"text": kwargs.get(kwargs["output_format"], "")
}
if "full_text" in self._param.search_method:
ck["content_ltks"] = rag_tokenizer.tokenize(kwargs.get(kwargs["output_format"], ""))
ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"])
chunks = [ck]
else:
chunks = kwargs["json"]
for i, ck in enumerate(chunks):
ck["content_ltks"] = rag_tokenizer.tokenize(ck["text"])
ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"])
if i % 100 == 99:
self.callback(i*1./len(chunks)/parts)
self.callback(1./parts, "Finish tokenizing.")
if "embedding" in self._param.search_method:
self.callback(random.randint(1,5)/100. + 0.5*(parts-1), "Start embedding inference.")
chunks, token_count = await self._embedding(kwargs.get("name", ""), chunks)
self.set_output("embedding_token_consumption", token_count)
self.callback(1., "Finish embedding.")
self.set_output("chunks", chunks)


@@ -563,7 +563,8 @@ def naive_merge(sections, chunk_token_num=128, delimiter="\n。", overl
return cks
def naive_merge_with_images(texts, images, chunk_token_num=128, delimiter="\n。;!?"):
def naive_merge_with_images(texts, images, chunk_token_num=128, delimiter="\n。;!?", overlapped_percent=0):
from deepdoc.parser.pdf_parser import RAGFlowPdfParser
if not texts or len(texts) != len(images):
return [], []
cks = [""]
@@ -578,7 +579,10 @@ def naive_merge_with_images(texts, images, chunk_token_num=128, delimiter="\n。
if tnum < 8:
pos = ""
# Ensure that the length of the merged chunk does not exceed chunk_token_num
if cks[-1] == "" or tk_nums[-1] > chunk_token_num:
if cks[-1] == "" or tk_nums[-1] > chunk_token_num * (100 - overlapped_percent)/100.:
if cks:
overlapped = RAGFlowPdfParser.remove_tag(cks[-1])
t = overlapped[int(len(overlapped)*(100-overlapped_percent)/100.):] + t
if t.find(pos) < 0:
t += pos
cks.append(t)
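
The overlap works in two parts: a chunk is closed once it exceeds `chunk_token_num * (100 - overlapped_percent) / 100` tokens, and the next chunk is seeded with the trailing share of the previous (tag-stripped) text. A toy illustration of the seeding arithmetic; note the percentage is on a 0-100 scale in this function:

```python
overlapped_percent = 20
prev = "A B C D E F G H I J"                  # previous chunk, position tags removed
seed = prev[int(len(prev) * (100 - overlapped_percent) / 100.):]
print(repr(seed))                             # ' I J' (roughly the last 20% of characters)
next_chunk = seed + " K L M"                  # new text accumulates after the seed
```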


@@ -93,7 +93,8 @@ class MCPToolCallSession(ToolCallSession):
msg = f"Timeout initializing client_session for server {self._mcp_server.id}"
logging.error(msg)
await self._process_mcp_tasks(None, msg)
except Exception:
except Exception as e:
logging.exception(e)
msg = "Connection failed (possibly due to auth error). Please check authentication settings first"
await self._process_mcp_tasks(None, msg)