mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Feat: init dataflow. (#9791)
### What problem does this PR solve? #9790 Close #9782 ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
181
agent/canvas.py
181
agent/canvas.py
@ -29,8 +29,7 @@ from api.utils import get_uuid, hash_str2int
|
||||
from rag.prompts.prompts import chunks_format
|
||||
from rag.utils.redis_conn import REDIS_CONN
|
||||
|
||||
|
||||
class Canvas:
|
||||
class Graph:
|
||||
"""
|
||||
dsl = {
|
||||
"components": {
|
||||
@ -73,39 +72,9 @@ class Canvas:
|
||||
|
||||
def __init__(self, dsl: str, tenant_id=None, task_id=None):
|
||||
self.path = []
|
||||
self.history = []
|
||||
self.components = {}
|
||||
self.error = ""
|
||||
self.globals = {
|
||||
"sys.query": "",
|
||||
"sys.user_id": tenant_id,
|
||||
"sys.conversation_turns": 0,
|
||||
"sys.files": []
|
||||
}
|
||||
self.dsl = json.loads(dsl) if dsl else {
|
||||
"components": {
|
||||
"begin": {
|
||||
"obj": {
|
||||
"component_name": "Begin",
|
||||
"params": {
|
||||
"prologue": "Hi there!"
|
||||
}
|
||||
},
|
||||
"downstream": [],
|
||||
"upstream": [],
|
||||
"parent_id": ""
|
||||
}
|
||||
},
|
||||
"history": [],
|
||||
"path": [],
|
||||
"retrieval": [],
|
||||
"globals": {
|
||||
"sys.query": "",
|
||||
"sys.user_id": "",
|
||||
"sys.conversation_turns": 0,
|
||||
"sys.files": []
|
||||
}
|
||||
}
|
||||
self.dsl = json.loads(dsl)
|
||||
self._tenant_id = tenant_id
|
||||
self.task_id = task_id if task_id else get_uuid()
|
||||
self.load()
|
||||
@ -116,8 +85,6 @@ class Canvas:
|
||||
for k, cpn in self.components.items():
|
||||
cpn_nms.add(cpn["obj"]["component_name"])
|
||||
|
||||
assert "Begin" in cpn_nms, "There have to be an 'Begin' component."
|
||||
|
||||
for k, cpn in self.components.items():
|
||||
cpn_nms.add(cpn["obj"]["component_name"])
|
||||
param = component_class(cpn["obj"]["component_name"] + "Param")()
|
||||
@ -130,27 +97,10 @@ class Canvas:
|
||||
cpn["obj"] = component_class(cpn["obj"]["component_name"])(self, k, param)
|
||||
|
||||
self.path = self.dsl["path"]
|
||||
self.history = self.dsl["history"]
|
||||
if "globals" in self.dsl:
|
||||
self.globals = self.dsl["globals"]
|
||||
else:
|
||||
self.globals = {
|
||||
"sys.query": "",
|
||||
"sys.user_id": "",
|
||||
"sys.conversation_turns": 0,
|
||||
"sys.files": []
|
||||
}
|
||||
|
||||
self.retrieval = self.dsl["retrieval"]
|
||||
self.memory = self.dsl.get("memory", [])
|
||||
|
||||
def __str__(self):
|
||||
self.dsl["path"] = self.path
|
||||
self.dsl["history"] = self.history
|
||||
self.dsl["globals"] = self.globals
|
||||
self.dsl["task_id"] = self.task_id
|
||||
self.dsl["retrieval"] = self.retrieval
|
||||
self.dsl["memory"] = self.memory
|
||||
dsl = {
|
||||
"components": {}
|
||||
}
|
||||
@ -169,14 +119,79 @@ class Canvas:
|
||||
dsl["components"][k][c] = deepcopy(cpn[c])
|
||||
return json.dumps(dsl, ensure_ascii=False)
|
||||
|
||||
def reset(self, mem=False):
|
||||
def reset(self):
|
||||
self.path = []
|
||||
for k, cpn in self.components.items():
|
||||
self.components[k]["obj"].reset()
|
||||
try:
|
||||
REDIS_CONN.delete(f"{self.task_id}-logs")
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
|
||||
def get_component_name(self, cid):
|
||||
for n in self.dsl.get("graph", {}).get("nodes", []):
|
||||
if cid == n["id"]:
|
||||
return n["data"]["name"]
|
||||
return ""
|
||||
|
||||
def run(self, **kwargs):
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_component(self, cpn_id) -> Union[None, dict[str, Any]]:
|
||||
return self.components.get(cpn_id)
|
||||
|
||||
def get_component_obj(self, cpn_id) -> ComponentBase:
|
||||
return self.components.get(cpn_id)["obj"]
|
||||
|
||||
def get_component_type(self, cpn_id) -> str:
|
||||
return self.components.get(cpn_id)["obj"].component_name
|
||||
|
||||
def get_component_input_form(self, cpn_id) -> dict:
|
||||
return self.components.get(cpn_id)["obj"].get_input_form()
|
||||
|
||||
def get_tenant_id(self):
|
||||
return self._tenant_id
|
||||
|
||||
|
||||
class Canvas(Graph):
|
||||
|
||||
def __init__(self, dsl: str, tenant_id=None, task_id=None):
|
||||
self.globals = {
|
||||
"sys.query": "",
|
||||
"sys.user_id": tenant_id,
|
||||
"sys.conversation_turns": 0,
|
||||
"sys.files": []
|
||||
}
|
||||
super().__init__(dsl, tenant_id, task_id)
|
||||
|
||||
def load(self):
|
||||
super().load()
|
||||
self.history = self.dsl["history"]
|
||||
if "globals" in self.dsl:
|
||||
self.globals = self.dsl["globals"]
|
||||
else:
|
||||
self.globals = {
|
||||
"sys.query": "",
|
||||
"sys.user_id": "",
|
||||
"sys.conversation_turns": 0,
|
||||
"sys.files": []
|
||||
}
|
||||
|
||||
self.retrieval = self.dsl["retrieval"]
|
||||
self.memory = self.dsl.get("memory", [])
|
||||
|
||||
def __str__(self):
|
||||
self.dsl["history"] = self.history
|
||||
self.dsl["retrieval"] = self.retrieval
|
||||
self.dsl["memory"] = self.memory
|
||||
return super().__str__()
|
||||
|
||||
def reset(self, mem=False):
|
||||
super().reset()
|
||||
if not mem:
|
||||
self.history = []
|
||||
self.retrieval = []
|
||||
self.memory = []
|
||||
for k, cpn in self.components.items():
|
||||
self.components[k]["obj"].reset()
|
||||
|
||||
for k in self.globals.keys():
|
||||
if isinstance(self.globals[k], str):
|
||||
@ -192,17 +207,6 @@ class Canvas:
|
||||
else:
|
||||
self.globals[k] = None
|
||||
|
||||
try:
|
||||
REDIS_CONN.delete(f"{self.task_id}-logs")
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
|
||||
def get_component_name(self, cid):
|
||||
for n in self.dsl.get("graph", {}).get("nodes", []):
|
||||
if cid == n["id"]:
|
||||
return n["data"]["name"]
|
||||
return ""
|
||||
|
||||
def run(self, **kwargs):
|
||||
st = time.perf_counter()
|
||||
self.message_id = get_uuid()
|
||||
@ -388,18 +392,6 @@ class Canvas:
|
||||
})
|
||||
self.history.append(("assistant", self.get_component_obj(self.path[-1]).output()))
|
||||
|
||||
def get_component(self, cpn_id) -> Union[None, dict[str, Any]]:
|
||||
return self.components.get(cpn_id)
|
||||
|
||||
def get_component_obj(self, cpn_id) -> ComponentBase:
|
||||
return self.components.get(cpn_id)["obj"]
|
||||
|
||||
def get_component_type(self, cpn_id) -> str:
|
||||
return self.components.get(cpn_id)["obj"].component_name
|
||||
|
||||
def get_component_input_form(self, cpn_id) -> dict:
|
||||
return self.components.get(cpn_id)["obj"].get_input_form()
|
||||
|
||||
def is_reff(self, exp: str) -> bool:
|
||||
exp = exp.strip("{").strip("}")
|
||||
if exp.find("@") < 0:
|
||||
@ -421,9 +413,6 @@ class Canvas:
|
||||
raise Exception(f"Can't find variable: '{cpn_id}@{var_nm}'")
|
||||
return cpn["obj"].output(var_nm)
|
||||
|
||||
def get_tenant_id(self):
|
||||
return self._tenant_id
|
||||
|
||||
def get_history(self, window_size):
|
||||
convs = []
|
||||
if window_size <= 0:
|
||||
@ -438,36 +427,6 @@ class Canvas:
|
||||
def add_user_input(self, question):
|
||||
self.history.append(("user", question))
|
||||
|
||||
def _find_loop(self, max_loops=6):
|
||||
path = self.path[-1][::-1]
|
||||
if len(path) < 2:
|
||||
return False
|
||||
|
||||
for i in range(len(path)):
|
||||
if path[i].lower().find("answer") == 0 or path[i].lower().find("iterationitem") == 0:
|
||||
path = path[:i]
|
||||
break
|
||||
|
||||
if len(path) < 2:
|
||||
return False
|
||||
|
||||
for loc in range(2, len(path) // 2):
|
||||
pat = ",".join(path[0:loc])
|
||||
path_str = ",".join(path)
|
||||
if len(pat) >= len(path_str):
|
||||
return False
|
||||
loop = max_loops
|
||||
while path_str.find(pat) == 0 and loop >= 0:
|
||||
loop -= 1
|
||||
if len(pat)+1 >= len(path_str):
|
||||
return False
|
||||
path_str = path_str[len(pat)+1:]
|
||||
if loop < 0:
|
||||
pat = " => ".join([p.split(":")[0] for p in path[0:loc]])
|
||||
return pat + " => " + pat
|
||||
|
||||
return False
|
||||
|
||||
def get_prologue(self):
|
||||
return self.components["begin"]["obj"]._param.prologue
|
||||
|
||||
|
||||
@ -50,8 +50,9 @@ del _package_path, _import_submodules, _extract_classes_from_module
|
||||
|
||||
|
||||
def component_class(class_name):
|
||||
m = importlib.import_module("agent.component")
|
||||
for mdl in ["agent.component", "agent.tools", "rag.flow"]:
|
||||
try:
|
||||
return getattr(m, class_name)
|
||||
return getattr(importlib.import_module(mdl), class_name)
|
||||
except Exception:
|
||||
return getattr(importlib.import_module("agent.tools"), class_name)
|
||||
pass
|
||||
assert False, f"Can't import {class_name}"
|
||||
|
||||
@ -16,7 +16,7 @@
|
||||
|
||||
import re
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from abc import ABC
|
||||
import builtins
|
||||
import json
|
||||
import os
|
||||
@ -410,8 +410,8 @@ class ComponentBase(ABC):
|
||||
)
|
||||
|
||||
def __init__(self, canvas, id, param: ComponentParamBase):
|
||||
from agent.canvas import Canvas # Local import to avoid cyclic dependency
|
||||
assert isinstance(canvas, Canvas), "canvas must be an instance of Canvas"
|
||||
from agent.canvas import Graph # Local import to avoid cyclic dependency
|
||||
assert isinstance(canvas, Graph), "canvas must be an instance of Canvas"
|
||||
self._canvas = canvas
|
||||
self._id = id
|
||||
self._param = param
|
||||
@ -528,6 +528,10 @@ class ComponentBase(ABC):
|
||||
cpn_nms = self._canvas.get_component(self._id)['upstream']
|
||||
return cpn_nms
|
||||
|
||||
def get_downstream(self) -> List[str]:
|
||||
cpn_nms = self._canvas.get_component(self._id)['downstream']
|
||||
return cpn_nms
|
||||
|
||||
@staticmethod
|
||||
def string_format(content: str, kv: dict[str, str]) -> str:
|
||||
for n, v in kv.items():
|
||||
@ -556,6 +560,5 @@ class ComponentBase(ABC):
|
||||
def set_exception_default_value(self):
|
||||
self.set_output("result", self.get_exception_default_value())
|
||||
|
||||
@abstractmethod
|
||||
def thoughts(self) -> str:
|
||||
...
|
||||
raise NotImplementedError()
|
||||
|
||||
@ -16,9 +16,8 @@
|
||||
from abc import ABC
|
||||
import asyncio
|
||||
from crawl4ai import AsyncWebCrawler
|
||||
|
||||
from agent.tools.base import ToolParamBase, ToolBase
|
||||
from api.utils.web_utils import is_valid_url
|
||||
|
||||
|
||||
|
||||
class CrawlerParam(ToolParamBase):
|
||||
@ -39,6 +38,7 @@ class Crawler(ToolBase, ABC):
|
||||
component_name = "Crawler"
|
||||
|
||||
def _run(self, history, **kwargs):
|
||||
from api.utils.web_utils import is_valid_url
|
||||
ans = self.get_input()
|
||||
ans = " - ".join(ans["content"]) if "content" in ans else ""
|
||||
if not is_valid_url(ans):
|
||||
|
||||
@ -74,7 +74,6 @@ def retrieval(tenant_id):
|
||||
[tenant_id],
|
||||
[kb_id],
|
||||
embd_mdl,
|
||||
doc_ids,
|
||||
LLMBundle(kb.tenant_id, LLMType.CHAT))
|
||||
if ck["content_with_weight"]:
|
||||
ranks["chunks"].insert(0, ck)
|
||||
|
||||
@ -133,6 +133,13 @@ class UserService(CommonService):
|
||||
cls.model.update(user_dict).where(
|
||||
cls.model.id == user_id).execute()
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def is_admin(cls, user_id):
|
||||
return cls.model.select().where(
|
||||
cls.model.id == user_id,
|
||||
cls.model.is_superuser == 1).count() > 0
|
||||
|
||||
|
||||
class TenantService(CommonService):
|
||||
"""Service class for managing tenant-related database operations.
|
||||
|
||||
@ -131,6 +131,12 @@ class RAGFlowExcelParser:
|
||||
|
||||
return tb_chunks
|
||||
|
||||
def markdown(self, fnm):
|
||||
import pandas as pd
|
||||
file_like_object = BytesIO(fnm) if not isinstance(fnm, str) else fnm
|
||||
df = pd.read_excel(file_like_object)
|
||||
return df.to_markdown(index=False)
|
||||
|
||||
def __call__(self, fnm):
|
||||
file_like_object = BytesIO(fnm) if not isinstance(fnm, str) else fnm
|
||||
wb = RAGFlowExcelParser._load_excel_to_workbook(file_like_object)
|
||||
|
||||
@ -93,6 +93,7 @@ class RAGFlowPdfParser:
|
||||
model_dir, "updown_concat_xgb.model"))
|
||||
|
||||
self.page_from = 0
|
||||
self.column_num = 1
|
||||
|
||||
def __char_width(self, c):
|
||||
return (c["x1"] - c["x0"]) // max(len(c["text"]), 1)
|
||||
@ -427,10 +428,18 @@ class RAGFlowPdfParser:
|
||||
i += 1
|
||||
self.boxes = bxs
|
||||
|
||||
def _naive_vertical_merge(self):
|
||||
def _naive_vertical_merge(self, zoomin=3):
|
||||
bxs = Recognizer.sort_Y_firstly(
|
||||
self.boxes, np.median(
|
||||
self.mean_height) / 3)
|
||||
|
||||
column_width = np.median([b["x1"] - b["x0"] for b in self.boxes])
|
||||
self.column_num = int(self.page_images[0].size[0] / zoomin / column_width)
|
||||
if column_width < self.page_images[0].size[0] / zoomin / self.column_num:
|
||||
logging.info("Multi-column................... {} {}".format(column_width,
|
||||
self.page_images[0].size[0] / zoomin / self.column_num))
|
||||
self.boxes = self.sort_X_by_page(self.boxes, column_width / self.column_num)
|
||||
|
||||
i = 0
|
||||
while i + 1 < len(bxs):
|
||||
b = bxs[i]
|
||||
@ -1139,20 +1148,94 @@ class RAGFlowPdfParser:
|
||||
need_image, zoomin, return_html, False)
|
||||
return self.__filterout_scraps(deepcopy(self.boxes), zoomin), tbls
|
||||
|
||||
def parse_into_bboxes(self, fnm, callback=None, zoomin=3):
|
||||
start = timer()
|
||||
self.__images__(fnm, zoomin)
|
||||
if callback:
|
||||
callback(0.40, "OCR finished ({:.2f}s)".format(timer() - start))
|
||||
|
||||
start = timer()
|
||||
self._layouts_rec(zoomin)
|
||||
if callback:
|
||||
callback(0.63, "Layout analysis ({:.2f}s)".format(timer() - start))
|
||||
|
||||
start = timer()
|
||||
self._table_transformer_job(zoomin)
|
||||
if callback:
|
||||
callback(0.83, "Table analysis ({:.2f}s)".format(timer() - start))
|
||||
|
||||
start = timer()
|
||||
self._text_merge()
|
||||
self._concat_downward()
|
||||
self._naive_vertical_merge(zoomin)
|
||||
if callback:
|
||||
callback(0.92, "Text merged ({:.2f}s)".format(timer() - start))
|
||||
|
||||
start = timer()
|
||||
tbls, figs = self._extract_table_figure(True, zoomin, True, True, True)
|
||||
|
||||
def insert_table_figures(tbls_or_figs, layout_type):
|
||||
def min_rectangle_distance(rect1, rect2):
|
||||
import math
|
||||
pn1, left1, right1, top1, bottom1 = rect1
|
||||
pn2, left2, right2, top2, bottom2 = rect2
|
||||
if (right1 >= left2 and right2 >= left1 and
|
||||
bottom1 >= top2 and bottom2 >= top1):
|
||||
return 0 + (pn1-pn2)*10000
|
||||
if right1 < left2:
|
||||
dx = left2 - right1
|
||||
elif right2 < left1:
|
||||
dx = left1 - right2
|
||||
else:
|
||||
dx = 0
|
||||
if bottom1 < top2:
|
||||
dy = top2 - bottom1
|
||||
elif bottom2 < top1:
|
||||
dy = top1 - bottom2
|
||||
else:
|
||||
dy = 0
|
||||
return math.sqrt(dx*dx + dy*dy) + (pn1-pn2)*10000
|
||||
|
||||
for (img, txt), poss in tbls_or_figs:
|
||||
bboxes = [(i, (b["page_number"], b["x0"], b["x1"], b["top"], b["bottom"])) for i, b in enumerate(self.boxes)]
|
||||
dists = [(min_rectangle_distance((pn, left, right, top, bott), rect),i) for i, rect in bboxes for pn, left, right, top, bott in poss]
|
||||
min_i = np.argmin(dists, axis=0)[0]
|
||||
min_i, rect = bboxes[dists[min_i][-1]]
|
||||
if isinstance(txt, list):
|
||||
txt = "\n".join(txt)
|
||||
self.boxes.insert(min_i, {
|
||||
"page_number": rect[0], "x0": rect[1], "x1": rect[2], "top": rect[3], "bottom": rect[4], "layout_type": layout_type, "text": txt, "image": img
|
||||
})
|
||||
|
||||
for b in self.boxes:
|
||||
b["position_tag"] = self._line_tag(b, zoomin)
|
||||
b["image"] = self.crop(b["position_tag"], zoomin)
|
||||
|
||||
insert_table_figures(tbls, "table")
|
||||
insert_table_figures(figs, "figure")
|
||||
if callback:
|
||||
callback(1, "Structured ({:.2f}s)".format(timer() - start))
|
||||
return deepcopy(self.boxes)
|
||||
|
||||
@staticmethod
|
||||
def remove_tag(txt):
|
||||
return re.sub(r"@@[\t0-9.-]+?##", "", txt)
|
||||
|
||||
def crop(self, text, ZM=3, need_position=False):
|
||||
imgs = []
|
||||
@staticmethod
|
||||
def extract_positions(txt):
|
||||
poss = []
|
||||
for tag in re.findall(r"@@[0-9-]+\t[0-9.\t]+##", text):
|
||||
for tag in re.findall(r"@@[0-9-]+\t[0-9.\t]+##", txt):
|
||||
pn, left, right, top, bottom = tag.strip(
|
||||
"#").strip("@").split("\t")
|
||||
left, right, top, bottom = float(left), float(
|
||||
right), float(top), float(bottom)
|
||||
poss.append(([int(p) - 1 for p in pn.split("-")],
|
||||
left, right, top, bottom))
|
||||
return poss
|
||||
|
||||
def crop(self, text, ZM=3, need_position=False):
|
||||
imgs = []
|
||||
poss = self.extract_positions(text)
|
||||
if not poss:
|
||||
if need_position:
|
||||
return None, None
|
||||
@ -1296,8 +1379,8 @@ class VisionParser(RAGFlowPdfParser):
|
||||
|
||||
def __call__(self, filename, from_page=0, to_page=100000, **kwargs):
|
||||
callback = kwargs.get("callback", lambda prog, msg: None)
|
||||
|
||||
self.__images__(fnm=filename, zoomin=3, page_from=from_page, page_to=to_page, **kwargs)
|
||||
zoomin = kwargs.get("zoomin", 3)
|
||||
self.__images__(fnm=filename, zoomin=zoomin, page_from=from_page, page_to=to_page, callback=callback)
|
||||
|
||||
total_pdf_pages = self.total_page
|
||||
|
||||
@ -1311,16 +1394,19 @@ class VisionParser(RAGFlowPdfParser):
|
||||
if pdf_page_num < start_page or pdf_page_num >= end_page:
|
||||
continue
|
||||
|
||||
docs = picture_vision_llm_chunk(
|
||||
text = picture_vision_llm_chunk(
|
||||
binary=img_binary,
|
||||
vision_model=self.vision_model,
|
||||
prompt=vision_llm_describe_prompt(page=pdf_page_num+1),
|
||||
callback=callback,
|
||||
)
|
||||
if kwargs.get("callback"):
|
||||
kwargs["callback"](idx*1./len(self.page_images), f"Processed: {idx+1}/{len(self.page_images)}")
|
||||
|
||||
if docs:
|
||||
all_docs.append(docs)
|
||||
return [(doc, "") for doc in all_docs], []
|
||||
if text:
|
||||
width, height = self.page_images[idx].size
|
||||
all_docs.append((text, f"{pdf_page_num+1} 0 {width/zoomin} 0 {height/zoomin}"))
|
||||
return all_docs, []
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
49
rag/flow/__init__.py
Normal file
49
rag/flow/__init__.py
Normal file
@ -0,0 +1,49 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import os
|
||||
import importlib
|
||||
import inspect
|
||||
from types import ModuleType
|
||||
from typing import Dict, Type
|
||||
|
||||
_package_path = os.path.dirname(__file__)
|
||||
__all_classes: Dict[str, Type] = {}
|
||||
|
||||
def _import_submodules() -> None:
|
||||
for filename in os.listdir(_package_path): # noqa: F821
|
||||
if filename.startswith("__") or not filename.endswith(".py") or filename.startswith("base"):
|
||||
continue
|
||||
module_name = filename[:-3]
|
||||
|
||||
try:
|
||||
module = importlib.import_module(f".{module_name}", package=__name__)
|
||||
_extract_classes_from_module(module) # noqa: F821
|
||||
except ImportError as e:
|
||||
print(f"Warning: Failed to import module {module_name}: {str(e)}")
|
||||
|
||||
def _extract_classes_from_module(module: ModuleType) -> None:
|
||||
for name, obj in inspect.getmembers(module):
|
||||
if (inspect.isclass(obj) and
|
||||
obj.__module__ == module.__name__ and not name.startswith("_")):
|
||||
__all_classes[name] = obj
|
||||
globals()[name] = obj
|
||||
|
||||
_import_submodules()
|
||||
|
||||
__all__ = list(__all_classes.keys()) + ["__all_classes"]
|
||||
|
||||
del _package_path, _import_submodules, _extract_classes_from_module
|
||||
59
rag/flow/base.py
Normal file
59
rag/flow/base.py
Normal file
@ -0,0 +1,59 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import time
|
||||
import os
|
||||
import logging
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
import trio
|
||||
from agent.component.base import ComponentParamBase, ComponentBase
|
||||
from api.utils.api_utils import timeout
|
||||
|
||||
|
||||
class ProcessParamBase(ComponentParamBase):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.timeout = 100000000
|
||||
self.persist_logs = True
|
||||
|
||||
|
||||
class ProcessBase(ComponentBase):
|
||||
|
||||
def __init__(self, pipeline, id, param: ProcessParamBase):
|
||||
super().__init__(pipeline, id, param)
|
||||
self.callback = partial(self._canvas.callback, self.component_name)
|
||||
|
||||
async def invoke(self, **kwargs) -> dict[str, Any]:
|
||||
self.set_output("_created_time", time.perf_counter())
|
||||
for k,v in kwargs.items():
|
||||
self.set_output(k, v)
|
||||
try:
|
||||
with trio.fail_after(self._param.timeout):
|
||||
await self._invoke(**kwargs)
|
||||
self.callback(1, "Done")
|
||||
except Exception as e:
|
||||
if self.get_exception_default_value():
|
||||
self.set_exception_default_value()
|
||||
else:
|
||||
self.set_output("_ERROR", str(e))
|
||||
logging.exception(e)
|
||||
self.callback(-1, str(e))
|
||||
self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
|
||||
return self.output()
|
||||
|
||||
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))
|
||||
async def _invoke(self, **kwargs):
|
||||
raise NotImplementedError()
|
||||
47
rag/flow/begin.py
Normal file
47
rag/flow/begin.py
Normal file
@ -0,0 +1,47 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.file2document_service import File2DocumentService
|
||||
from api.db.services.file_service import FileService
|
||||
from rag.flow.base import ProcessBase, ProcessParamBase
|
||||
from rag.utils.storage_factory import STORAGE_IMPL
|
||||
|
||||
|
||||
class FileParam(ProcessParamBase):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def check(self):
|
||||
pass
|
||||
|
||||
|
||||
class File(ProcessBase):
|
||||
component_name = "File"
|
||||
|
||||
async def _invoke(self, **kwargs):
|
||||
if self._canvas._doc_id:
|
||||
e, doc = DocumentService.get_by_id(self._canvas._doc_id)
|
||||
if not e:
|
||||
self.set_output("_ERROR", f"Document({self._canvas._doc_id}) not found!")
|
||||
return
|
||||
|
||||
b, n = File2DocumentService.get_storage_address(doc_id=self._canvas._doc_id)
|
||||
self.set_output("blob", STORAGE_IMPL.get(b, n))
|
||||
self.set_output("name", doc.name)
|
||||
else:
|
||||
file = kwargs.get("file")
|
||||
self.set_output("name", file["name"])
|
||||
self.set_output("blob", FileService.get_blob(file["created_by"], file["id"]))
|
||||
160
rag/flow/chunker.py
Normal file
160
rag/flow/chunker.py
Normal file
@ -0,0 +1,160 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import random
|
||||
import trio
|
||||
from api.db import LLMType
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from deepdoc.parser.pdf_parser import RAGFlowPdfParser
|
||||
from graphrag.utils import get_llm_cache, chat_limiter, set_llm_cache
|
||||
from rag.flow.base import ProcessBase, ProcessParamBase
|
||||
from rag.nlp import naive_merge, naive_merge_with_images
|
||||
from rag.prompts.prompts import keyword_extraction, question_proposal
|
||||
|
||||
|
||||
class ChunkerParam(ProcessParamBase):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.method_options = ["general", "q&a", "resume", "manual", "table", "paper", "book", "laws", "presentation", "one"]
|
||||
self.method = "general"
|
||||
self.chunk_token_size = 512
|
||||
self.delimiter = "\n"
|
||||
self.overlapped_percent = 0
|
||||
self.page_rank = 0
|
||||
self.auto_keywords = 0
|
||||
self.auto_questions = 0
|
||||
self.tag_sets = []
|
||||
self.llm_setting = {
|
||||
"llm_name": "",
|
||||
"lang": "Chinese"
|
||||
}
|
||||
|
||||
def check(self):
|
||||
self.check_valid_value(self.method.lower(), "Chunk method abnormal.", self.method_options)
|
||||
self.check_positive_integer(self.chunk_token_size, "Chunk token size.")
|
||||
self.check_nonnegative_number(self.page_rank, "Page rank value: (0, 10]")
|
||||
self.check_nonnegative_number(self.auto_keywords, "Auto-keyword value: (0, 10]")
|
||||
self.check_nonnegative_number(self.auto_questions, "Auto-question value: (0, 10]")
|
||||
self.check_decimal_float(self.overlapped_percent, "Overlapped percentage: [0, 1)")
|
||||
|
||||
|
||||
class Chunker(ProcessBase):
|
||||
component_name = "Chunker"
|
||||
|
||||
def _general(self, **kwargs):
|
||||
self.callback(random.randint(1,5)/100., "Start to chunk via `General`.")
|
||||
if kwargs.get("output_format") in ["markdown", "text"]:
|
||||
cks = naive_merge(kwargs.get(kwargs["output_format"]), self._param.chunk_token_size, self._param.delimiter, self._param.overlapped_percent)
|
||||
return [{"text": c} for c in cks]
|
||||
|
||||
sections, section_images = [], []
|
||||
for o in kwargs["json"]:
|
||||
sections.append((o["text"], o.get("position_tag","")))
|
||||
section_images.append(o.get("image"))
|
||||
|
||||
chunks, images = naive_merge_with_images(sections, section_images,self._param.chunk_token_size, self._param.delimiter, self._param.overlapped_percent)
|
||||
return [{
|
||||
"text": RAGFlowPdfParser.remove_tag(c),
|
||||
"image": img,
|
||||
"positions": RAGFlowPdfParser.extract_positions(c)
|
||||
} for c,img in zip(chunks,images)]
|
||||
|
||||
def _q_and_a(self, **kwargs):
|
||||
pass
|
||||
|
||||
def _resume(self, **kwargs):
|
||||
pass
|
||||
|
||||
def _manual(self, **kwargs):
|
||||
pass
|
||||
|
||||
def _table(self, **kwargs):
|
||||
pass
|
||||
|
||||
def _paper(self, **kwargs):
|
||||
pass
|
||||
|
||||
def _book(self, **kwargs):
|
||||
pass
|
||||
|
||||
def _laws(self, **kwargs):
|
||||
pass
|
||||
|
||||
def _presentation(self, **kwargs):
|
||||
pass
|
||||
|
||||
def _one(self, **kwargs):
|
||||
pass
|
||||
|
||||
async def _invoke(self, **kwargs):
|
||||
function_map = {
|
||||
"general": self._general,
|
||||
"q&a": self._q_and_a,
|
||||
"resume": self._resume,
|
||||
"manual": self._manual,
|
||||
"table": self._table,
|
||||
"paper": self._paper,
|
||||
"book": self._book,
|
||||
"laws": self._laws,
|
||||
"presentation": self._presentation,
|
||||
"one": self._one,
|
||||
}
|
||||
chunks = function_map[self._param.method](**kwargs)
|
||||
llm_setting = self._param.llm_setting
|
||||
|
||||
async def auto_keywords():
|
||||
nonlocal chunks, llm_setting
|
||||
chat_mdl = LLMBundle(self._canvas._tenant_id, LLMType.CHAT, llm_name=llm_setting["llm_name"], lang=llm_setting["lang"])
|
||||
|
||||
async def doc_keyword_extraction(chat_mdl, ck, topn):
|
||||
cached = get_llm_cache(chat_mdl.llm_name, ck["text"], "keywords", {"topn": topn})
|
||||
if not cached:
|
||||
async with chat_limiter:
|
||||
cached = await trio.to_thread.run_sync(lambda: keyword_extraction(chat_mdl, ck["text"], topn))
|
||||
set_llm_cache(chat_mdl.llm_name, ck["text"], cached, "keywords", {"topn": topn})
|
||||
if cached:
|
||||
ck["keywords"] = cached.split(",")
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
for ck in chunks:
|
||||
nursery.start_soon(doc_keyword_extraction, chat_mdl, ck, self._param.auto_keywords)
|
||||
|
||||
async def auto_questions():
|
||||
nonlocal chunks, llm_setting
|
||||
chat_mdl = LLMBundle(self._canvas._tenant_id, LLMType.CHAT, llm_name=llm_setting["llm_name"], lang=llm_setting["lang"])
|
||||
|
||||
async def doc_question_proposal(chat_mdl, d, topn):
|
||||
cached = get_llm_cache(chat_mdl.llm_name, ck["text"], "question", {"topn": topn})
|
||||
if not cached:
|
||||
async with chat_limiter:
|
||||
cached = await trio.to_thread.run_sync(lambda: question_proposal(chat_mdl, ck["text"], topn))
|
||||
set_llm_cache(chat_mdl.llm_name, ck["text"], cached, "question", {"topn": topn})
|
||||
if cached:
|
||||
d["questions"] = cached.split("\n")
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
for ck in chunks:
|
||||
nursery.start_soon(doc_question_proposal, chat_mdl, ck, self._param.auto_questions)
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
if self._param.auto_questions:
|
||||
nursery.start_soon(auto_questions)
|
||||
if self._param.auto_keywords:
|
||||
nursery.start_soon(auto_keywords)
|
||||
|
||||
if self._param.page_rank:
|
||||
for ck in chunks:
|
||||
ck["page_rank"] = self._param.page_rank
|
||||
|
||||
self.set_output("chunks", chunks)
|
||||
107
rag/flow/parser.py
Normal file
107
rag/flow/parser.py
Normal file
@ -0,0 +1,107 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import random
|
||||
import trio
|
||||
from api.db import LLMType
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from deepdoc.parser.pdf_parser import RAGFlowPdfParser, PlainParser, VisionParser
|
||||
from rag.flow.base import ProcessBase, ProcessParamBase
|
||||
from rag.llm.cv_model import Base as VLM
|
||||
from deepdoc.parser import ExcelParser
|
||||
|
||||
|
||||
class ParserParam(ProcessParamBase):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.setups = {
|
||||
"pdf": {
|
||||
"parse_method": "deepdoc", # deepdoc/plain_text/vlm
|
||||
"vlm_name": "",
|
||||
"lang": "Chinese",
|
||||
"suffix": ["pdf"],
|
||||
"output_format": "json"
|
||||
},
|
||||
"excel": {
|
||||
"output_format": "html"
|
||||
},
|
||||
"ppt": {},
|
||||
"image": {
|
||||
"parse_method": "ocr"
|
||||
},
|
||||
"email": {},
|
||||
"text": {},
|
||||
"audio": {},
|
||||
"video": {},
|
||||
}
|
||||
|
||||
def check(self):
|
||||
if self.setups["pdf"].get("parse_method") not in ["deepdoc", "plain_text"]:
|
||||
assert self.setups["pdf"].get("vlm_name"), "No VLM specified."
|
||||
assert self.setups["pdf"].get("lang"), "No language specified."
|
||||
|
||||
|
||||
class Parser(ProcessBase):
|
||||
component_name = "Parser"
|
||||
|
||||
def _pdf(self, blob):
|
||||
self.callback(random.randint(1,5)/100., "Start to work on a PDF.")
|
||||
conf = self._param.setups["pdf"]
|
||||
self.set_output("output_format", conf["output_format"])
|
||||
if conf.get("parse_method") == "deepdoc":
|
||||
bboxes = RAGFlowPdfParser().parse_into_bboxes(blob, callback=self.callback)
|
||||
elif conf.get("parse_method") == "plain_text":
|
||||
lines,_ = PlainParser()(blob)
|
||||
bboxes = [{"text": t} for t,_ in lines]
|
||||
else:
|
||||
assert conf.get("vlm_name")
|
||||
vision_model = LLMBundle(self._canvas.tenant_id, LLMType.IMAGE2TEXT, llm_name=conf.get("vlm_name"), lang=self.setups["pdf"].get("lang"))
|
||||
lines, _ = VisionParser(vision_model=vision_model)(bin, callback=self.callback)
|
||||
bboxes = []
|
||||
for t, poss in lines:
|
||||
pn, x0, x1, top, bott = poss.split(" ")
|
||||
bboxes.append({"page_number": int(pn), "x0": int(x0), "x1": int(x1), "top": int(top), "bottom": int(bott), "text": t})
|
||||
|
||||
self.set_output("json", bboxes)
|
||||
mkdn = ""
|
||||
for b in bboxes:
|
||||
if b.get("layout_type", "") == "title":
|
||||
mkdn += "\n## "
|
||||
if b.get("layout_type", "") == "figure":
|
||||
mkdn += "\n".format(VLM.image2base64(b["image"]))
|
||||
continue
|
||||
mkdn += b.get("text", "") + "\n"
|
||||
self.set_output("markdown", mkdn)
|
||||
|
||||
def _excel(self, blob):
|
||||
self.callback(random.randint(1,5)/100., "Start to work on a Excel.")
|
||||
conf = self._param.setups["excel"]
|
||||
excel_parser = ExcelParser()
|
||||
if conf.get("output_format") == "html":
|
||||
html = excel_parser.html(blob,1000000000)
|
||||
self.set_output("html", html)
|
||||
elif conf.get("output_format") == "json":
|
||||
self.set_output("json", [{"text": txt} for txt in excel_parser(blob) if txt])
|
||||
elif conf.get("output_format") == "markdown":
|
||||
self.set_output("markdown", excel_parser.markdown(blob))
|
||||
|
||||
async def _invoke(self, **kwargs):
|
||||
function_map = {
|
||||
"pdf": self._pdf,
|
||||
}
|
||||
for p_type, conf in self._param.setups.items():
|
||||
if kwargs.get("name", "").split(".")[-1].lower() not in conf.get("suffix", []):
|
||||
continue
|
||||
await trio.to_thread.run_sync(function_map[p_type], kwargs["blob"])
|
||||
break
|
||||
121
rag/flow/pipeline.py
Normal file
121
rag/flow/pipeline.py
Normal file
@ -0,0 +1,121 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
import time
|
||||
import trio
|
||||
from agent.canvas import Graph
|
||||
from api.db.services.document_service import DocumentService
|
||||
from rag.utils.redis_conn import REDIS_CONN
|
||||
|
||||
|
||||
class Pipeline(Graph):
|
||||
|
||||
def __init__(self, dsl: str, tenant_id=None, doc_id=None, task_id=None, flow_id=None):
|
||||
super().__init__(dsl, tenant_id, task_id)
|
||||
self._doc_id = doc_id
|
||||
self._flow_id = flow_id
|
||||
self._kb_id = None
|
||||
if doc_id:
|
||||
self._kb_id = DocumentService.get_knowledgebase_id(doc_id)
|
||||
assert self._kb_id, f"Can't find KB of this document: {doc_id}"
|
||||
|
||||
def callback(self, component_name: str, progress: float|int|None=None, message: str = "") -> None:
|
||||
log_key = f"{self._flow_id}-{self.task_id}-logs"
|
||||
try:
|
||||
bin = REDIS_CONN.get(log_key)
|
||||
obj = json.loads(bin.encode("utf-8"))
|
||||
if obj:
|
||||
if obj[-1]["component_name"] == component_name:
|
||||
obj[-1]["trace"].append({"progress": progress, "message": message, "datetime": datetime.datetime.now().strftime("%H:%M:%S")})
|
||||
else:
|
||||
obj.append({
|
||||
"component_name": component_name,
|
||||
"trace": [{"progress": progress, "message": message, "datetime": datetime.datetime.now().strftime("%H:%M:%S")}]
|
||||
})
|
||||
else:
|
||||
obj = [{
|
||||
"component_name": component_name,
|
||||
"trace": [{"progress": progress, "message": message, "datetime": datetime.datetime.now().strftime("%H:%M:%S")}]
|
||||
}]
|
||||
REDIS_CONN.set_obj(log_key, obj, 60*10)
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
|
||||
def fetch_logs(self):
|
||||
log_key = f"{self._flow_id}-{self.task_id}-logs"
|
||||
try:
|
||||
bin = REDIS_CONN.get(log_key)
|
||||
if bin:
|
||||
return json.loads(bin.encode("utf-8"))
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
return []
|
||||
|
||||
def reset(self):
|
||||
super().reset()
|
||||
log_key = f"{self._flow_id}-{self.task_id}-logs"
|
||||
try:
|
||||
REDIS_CONN.set_obj(log_key, [], 60*10)
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
|
||||
async def run(self, **kwargs):
|
||||
st = time.perf_counter()
|
||||
if not self.path:
|
||||
self.path.append("begin")
|
||||
|
||||
if self._doc_id:
|
||||
DocumentService.update_by_id(self._doc_id, {
|
||||
"progress": random.randint(0,5)/100.,
|
||||
"progress_msg": "Start the pipeline...",
|
||||
"process_begin_at": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
})
|
||||
|
||||
self.error = ""
|
||||
idx = len(self.path) - 1
|
||||
if idx == 0:
|
||||
cpn_obj = self.get_component_obj(self.path[0])
|
||||
await cpn_obj.invoke(**kwargs)
|
||||
if cpn_obj.error():
|
||||
self.error = "[ERROR]" + cpn_obj.error()
|
||||
else:
|
||||
idx += 1
|
||||
self.path.extend(cpn_obj.get_downstream())
|
||||
|
||||
while idx < len(self.path) and not self.error:
|
||||
last_cpn = self.get_component_obj(self.path[idx-1])
|
||||
cpn_obj = self.get_component_obj(self.path[idx])
|
||||
async def invoke():
|
||||
nonlocal last_cpn, cpn_obj
|
||||
await cpn_obj.invoke(**last_cpn.output())
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(invoke)
|
||||
if cpn_obj.error():
|
||||
self.error = "[ERROR]" + cpn_obj.error()
|
||||
break
|
||||
idx += 1
|
||||
self.path.extend(cpn_obj.get_downstream())
|
||||
|
||||
if self._doc_id:
|
||||
DocumentService.update_by_id(self._doc_id, {
|
||||
"progress": 1 if not self.error else -1,
|
||||
"progress_msg": "Pipeline finished...\n" + self.error,
|
||||
"process_duration": time.perf_counter() - st
|
||||
})
|
||||
|
||||
57
rag/flow/tests/client.py
Normal file
57
rag/flow/tests/client.py
Normal file
@ -0,0 +1,57 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import trio
|
||||
from api import settings
|
||||
from rag.flow.pipeline import Pipeline
|
||||
|
||||
|
||||
def print_logs(pipeline):
|
||||
last_logs = "[]"
|
||||
while True:
|
||||
time.sleep(5)
|
||||
logs = pipeline.fetch_logs()
|
||||
logs_str = json.dumps(logs)
|
||||
if logs_str != last_logs:
|
||||
print(logs_str)
|
||||
last_logs = logs_str
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
dsl_default_path = os.path.join(
|
||||
os.path.dirname(os.path.realpath(__file__)),
|
||||
"dsl_examples",
|
||||
"general_pdf_all.json",
|
||||
)
|
||||
parser.add_argument('-s', '--dsl', default=dsl_default_path, help="input dsl", action='store', required=True)
|
||||
parser.add_argument('-d', '--doc_id', default=False, help="Document ID", action='store', required=True)
|
||||
parser.add_argument('-t', '--tenant_id', default=False, help="Tenant ID", action='store', required=True)
|
||||
args = parser.parse_args()
|
||||
|
||||
settings.init_settings()
|
||||
pipeline = Pipeline(open(args.dsl, "r").read(), tenant_id=args.tenant_id, doc_id=args.doc_id, task_id="xxxx", flow_id="xxx")
|
||||
pipeline.reset()
|
||||
|
||||
exe = ThreadPoolExecutor(max_workers=5)
|
||||
thr = exe.submit(print_logs, pipeline)
|
||||
|
||||
trio.run(pipeline.run)
|
||||
thr.result()
|
||||
54
rag/flow/tests/dsl_examples/general_pdf_all.json
Normal file
54
rag/flow/tests/dsl_examples/general_pdf_all.json
Normal file
@ -0,0 +1,54 @@
|
||||
{
|
||||
"components": {
|
||||
"begin": {
|
||||
"obj":{
|
||||
"component_name": "File",
|
||||
"params": {
|
||||
}
|
||||
},
|
||||
"downstream": ["parser:0"],
|
||||
"upstream": []
|
||||
},
|
||||
"parser:0": {
|
||||
"obj": {
|
||||
"component_name": "Parser",
|
||||
"params": {
|
||||
"setups": {
|
||||
"pdf": {
|
||||
"parse_method": "deepdoc",
|
||||
"vlm_name": "",
|
||||
"lang": "Chinese",
|
||||
"suffix": [
|
||||
"pdf"
|
||||
],
|
||||
"output_format": "json"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"downstream": ["chunker:0"],
|
||||
"upstream": ["begin"]
|
||||
},
|
||||
"chunker:0": {
|
||||
"obj": {
|
||||
"component_name": "Chunker",
|
||||
"params": {
|
||||
"method": "general",
|
||||
"auto_keywords": 5
|
||||
}
|
||||
},
|
||||
"downstream": ["tokenizer:0"],
|
||||
"upstream": ["chunker:0"]
|
||||
},
|
||||
"tokenizer:0": {
|
||||
"obj": {
|
||||
"component_name": "Tokenizer",
|
||||
"params": {
|
||||
}
|
||||
},
|
||||
"downstream": [],
|
||||
"upstream": ["chunker:0"]
|
||||
}
|
||||
},
|
||||
"path": []
|
||||
}
|
||||
134
rag/flow/tokenizer.py
Normal file
134
rag/flow/tokenizer.py
Normal file
@ -0,0 +1,134 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import random
|
||||
import re
|
||||
|
||||
import numpy as np
|
||||
import trio
|
||||
|
||||
from api.db import LLMType
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.db.services.user_service import TenantService
|
||||
from api.utils.api_utils import timeout
|
||||
from rag.flow.base import ProcessBase, ProcessParamBase
|
||||
from rag.nlp import rag_tokenizer
|
||||
from rag.settings import EMBEDDING_BATCH_SIZE
|
||||
from rag.svr.task_executor import embed_limiter
|
||||
from rag.utils import truncate
|
||||
|
||||
|
||||
class TokenizerParam(ProcessParamBase):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.search_method = ["full_text", "embedding"]
|
||||
self.filename_embd_weight = 0.1
|
||||
|
||||
def check(self):
|
||||
for v in self.search_method:
|
||||
self.check_valid_value(v.lower(), "Chunk method abnormal.", ["full_text", "embedding"])
|
||||
|
||||
|
||||
class Tokenizer(ProcessBase):
|
||||
component_name = "Tokenizer"
|
||||
|
||||
async def _embedding(self, name, chunks):
|
||||
parts = sum(["full_text" in self._param.search_method, "embedding" in self._param.search_method])
|
||||
token_count = 0
|
||||
if self._canvas._kb_id:
|
||||
e, kb = KnowledgebaseService.get_by_id(self._canvas._kb_id)
|
||||
embedding_id = kb.embd_id
|
||||
else:
|
||||
e, ten = TenantService.get_by_id(self._canvas._tenant_id)
|
||||
embedding_id = ten.embd_id
|
||||
embedding_model = LLMBundle(self._canvas._tenant_id, LLMType.EMBEDDING, llm_name=embedding_id)
|
||||
texts = []
|
||||
for c in chunks:
|
||||
if c.get("questions"):
|
||||
texts.append("\n".join(c["questions"]))
|
||||
else:
|
||||
texts.append(re.sub(r"</?(table|td|caption|tr|th)( [^<>]{0,12})?>", " ", c["text"]))
|
||||
vts, c = embedding_model.encode([name])
|
||||
token_count += c
|
||||
tts = np.concatenate([vts[0] for _ in range(len(texts))], axis=0)
|
||||
|
||||
@timeout(60)
|
||||
def batch_encode(txts):
|
||||
nonlocal embedding_model
|
||||
return embedding_model.encode([truncate(c, embedding_model.max_length-10) for c in txts])
|
||||
|
||||
cnts_ = np.array([])
|
||||
for i in range(0, len(texts), EMBEDDING_BATCH_SIZE):
|
||||
async with embed_limiter:
|
||||
vts, c = await trio.to_thread.run_sync(lambda: batch_encode(texts[i: i + EMBEDDING_BATCH_SIZE]))
|
||||
if len(cnts_) == 0:
|
||||
cnts_ = vts
|
||||
else:
|
||||
cnts_ = np.concatenate((cnts_, vts), axis=0)
|
||||
token_count += c
|
||||
if i % 33 == 32:
|
||||
self.callback(i*1./len(texts)/parts/EMBEDDING_BATCH_SIZE + 0.5*(parts-1))
|
||||
|
||||
cnts = cnts_
|
||||
title_w = float(self._param.filename_embd_weight)
|
||||
vects = (title_w * tts + (1 - title_w) * cnts) if len(tts) == len(cnts) else cnts
|
||||
|
||||
assert len(vects) == len(chunks)
|
||||
for i, ck in enumerate(chunks):
|
||||
v = vects[i].tolist()
|
||||
ck["q_%d_vec" % len(v)] = v
|
||||
return chunks, token_count
|
||||
|
||||
async def _invoke(self, **kwargs):
|
||||
parts = sum(["full_text" in self._param.search_method, "embedding" in self._param.search_method])
|
||||
if "full_text" in self._param.search_method:
|
||||
self.callback(random.randint(1,5)/100., "Start to tokenize.")
|
||||
if kwargs.get("chunks"):
|
||||
chunks = kwargs["chunks"]
|
||||
for i, ck in enumerate(chunks):
|
||||
if ck.get("questions"):
|
||||
ck["question_tks"] = rag_tokenizer.tokenize("\n".join(ck["questions"]))
|
||||
if ck.get("keywords"):
|
||||
ck["important_tks"] = rag_tokenizer.tokenize("\n".join(ck["keywords"]))
|
||||
ck["content_ltks"] = rag_tokenizer.tokenize(ck["text"])
|
||||
ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"])
|
||||
if i % 100 == 99:
|
||||
self.callback(i*1./len(chunks)/parts)
|
||||
elif kwargs.get("output_format") in ["markdown", "text"]:
|
||||
ck = {
|
||||
"text": kwargs.get(kwargs["output_format"], "")
|
||||
}
|
||||
if "full_text" in self._param.search_method:
|
||||
ck["content_ltks"] = rag_tokenizer.tokenize(kwargs.get(kwargs["output_format"], ""))
|
||||
ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"])
|
||||
chunks = [ck]
|
||||
else:
|
||||
chunks = kwargs["json"]
|
||||
for i, ck in enumerate(chunks):
|
||||
ck["content_ltks"] = rag_tokenizer.tokenize(ck["text"])
|
||||
ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"])
|
||||
if i % 100 == 99:
|
||||
self.callback(i*1./len(chunks)/parts)
|
||||
|
||||
self.callback(1./parts, "Finish tokenizing.")
|
||||
|
||||
if "embedding" in self._param.search_method:
|
||||
self.callback(random.randint(1,5)/100. + 0.5*(parts-1), "Start embedding inference.")
|
||||
chunks, token_count = await self._embedding(kwargs.get("name", ""), chunks)
|
||||
self.set_output("embedding_token_consumption", token_count)
|
||||
|
||||
self.callback(1., "Finish embedding.")
|
||||
|
||||
self.set_output("chunks", chunks)
|
||||
@ -563,7 +563,8 @@ def naive_merge(sections, chunk_token_num=128, delimiter="\n。;!?", overl
|
||||
return cks
|
||||
|
||||
|
||||
def naive_merge_with_images(texts, images, chunk_token_num=128, delimiter="\n。;!?"):
|
||||
def naive_merge_with_images(texts, images, chunk_token_num=128, delimiter="\n。;!?", overlapped_percent=0):
|
||||
from deepdoc.parser.pdf_parser import RAGFlowPdfParser
|
||||
if not texts or len(texts) != len(images):
|
||||
return [], []
|
||||
cks = [""]
|
||||
@ -578,7 +579,10 @@ def naive_merge_with_images(texts, images, chunk_token_num=128, delimiter="\n。
|
||||
if tnum < 8:
|
||||
pos = ""
|
||||
# Ensure that the length of the merged chunk does not exceed chunk_token_num
|
||||
if cks[-1] == "" or tk_nums[-1] > chunk_token_num:
|
||||
if cks[-1] == "" or tk_nums[-1] > chunk_token_num * (100 - overlapped_percent)/100.:
|
||||
if cks:
|
||||
overlapped = RAGFlowPdfParser.remove_tag(cks[-1])
|
||||
t = overlapped[int(len(overlapped)*(100-overlapped_percent)/100.):] + t
|
||||
if t.find(pos) < 0:
|
||||
t += pos
|
||||
cks.append(t)
|
||||
|
||||
@ -93,7 +93,8 @@ class MCPToolCallSession(ToolCallSession):
|
||||
msg = f"Timeout initializing client_session for server {self._mcp_server.id}"
|
||||
logging.error(msg)
|
||||
await self._process_mcp_tasks(None, msg)
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
msg = "Connection failed (possibly due to auth error). Please check authentication settings first"
|
||||
await self._process_mcp_tasks(None, msg)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user