Feat: support dataflow run. (#10182)

### What problem does this PR solve?


### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
Kevin Hu
2025-09-22 09:36:21 +08:00
committed by GitHub
parent 028c2d83e9
commit d050ef568d
7 changed files with 50 additions and 78 deletions

View File

@@ -144,11 +144,10 @@ def run():
if cvs.canvas_category == CanvasCategory.DataFlow: if cvs.canvas_category == CanvasCategory.DataFlow:
task_id = get_uuid() task_id = get_uuid()
flow_id = get_uuid() ok, error_message = queue_dataflow(tenant_id=user_id, flow_id=req["id"], task_id=task_id, file=files[0], priority=0)
ok, error_message = queue_dataflow(dsl=cvs.dsl, tenant_id=user_id, file=files[0], task_id=task_id, flow_id=flow_id, priority=0)
if not ok: if not ok:
return server_error_response(error_message) return get_data_error_result(message=error_message)
return get_json_result(data={"task_id": task_id, "message_id": flow_id}) return get_json_result(data={"message_id": task_id})
try: try:
canvas = Canvas(cvs.dsl, current_user.id, req["id"]) canvas = Canvas(cvs.dsl, current_user.id, req["id"])

View File

@ -472,14 +472,10 @@ def has_canceled(task_id):
return False return False
def queue_dataflow(dsl:str, tenant_id:str, task_id:str, flow_id:str=None, doc_id:str=None, file:dict=None, priority: int=0, callback=None) -> tuple[bool, str]: def queue_dataflow(tenant_id:str, flow_id:str, task_id:str, doc_id:str="x", file:dict=None, priority: int=0) -> tuple[bool, str]:
"""
Returns a tuple (success: bool, error_message: str).
"""
_ = callback
task = dict( task = dict(
id=get_uuid() if not task_id else task_id, id=task_id,
doc_id=doc_id, doc_id=doc_id,
from_page=0, from_page=0,
to_page=100000000, to_page=100000000,
@@ -490,15 +486,10 @@ def queue_dataflow(dsl:str, tenant_id:str, task_id:str, flow_id:str=None, doc_id
TaskService.model.delete().where(TaskService.model.id == task["id"]).execute() TaskService.model.delete().where(TaskService.model.id == task["id"]).execute()
bulk_insert_into_db(model=Task, data_source=[task], replace_on_conflict=True) bulk_insert_into_db(model=Task, data_source=[task], replace_on_conflict=True)
kb_id = DocumentService.get_knowledgebase_id(doc_id) task["kb_id"] = DocumentService.get_knowledgebase_id(doc_id)
if not kb_id:
return False, f"Can't find KB of this document: {doc_id}"
task["kb_id"] = kb_id
task["tenant_id"] = tenant_id task["tenant_id"] = tenant_id
task["task_type"] = "dataflow" task["task_type"] = "dataflow"
task["dsl"] = dsl task["dataflow_id"] = flow_id
task["dataflow_id"] = get_uuid() if not flow_id else flow_id
task["file"] = file task["file"] = file
if not REDIS_CONN.queue_product( if not REDIS_CONN.queue_product(

View File

@@ -35,9 +35,9 @@ class ProcessBase(ComponentBase):
def __init__(self, pipeline, id, param: ProcessParamBase): def __init__(self, pipeline, id, param: ProcessParamBase):
super().__init__(pipeline, id, param) super().__init__(pipeline, id, param)
if hasattr(self._canvas, "callback"): if hasattr(self._canvas, "callback"):
self.callback = partial(self._canvas.callback, self.component_name) self.callback = partial(self._canvas.callback, id)
else: else:
self.callback = partial(lambda *args, **kwargs: None, self.component_name) self.callback = partial(lambda *args, **kwargs: None, id)
async def invoke(self, **kwargs) -> dict[str, Any]: async def invoke(self, **kwargs) -> dict[str, Any]:
self.set_output("_created_time", time.perf_counter()) self.set_output("_created_time", time.perf_counter())

View File

@@ -76,7 +76,6 @@ class ParserParam(ProcessParamBase):
self.setups = { self.setups = {
"pdf": { "pdf": {
"parse_method": "deepdoc", # deepdoc/plain_text/vlm "parse_method": "deepdoc", # deepdoc/plain_text/vlm
"llm_id": "",
"lang": "Chinese", "lang": "Chinese",
"suffix": [ "suffix": [
"pdf", "pdf",
@@ -98,8 +97,8 @@ class ParserParam(ProcessParamBase):
], ],
"output_format": "json", "output_format": "json",
}, },
"markdown": { "text&markdown": {
"suffix": ["md", "markdown", "mdx"], "suffix": ["md", "markdown", "mdx", "txt"],
"output_format": "json", "output_format": "json",
}, },
"slides": { "slides": {
@@ -156,13 +155,10 @@ class ParserParam(ProcessParamBase):
pdf_config = self.setups.get("pdf", {}) pdf_config = self.setups.get("pdf", {})
if pdf_config: if pdf_config:
pdf_parse_method = pdf_config.get("parse_method", "") pdf_parse_method = pdf_config.get("parse_method", "")
self.check_valid_value(pdf_parse_method.lower(), "Parse method abnormal.", ["deepdoc", "plain_text", "vlm"]) self.check_empty(pdf_parse_method, "Parse method abnormal.")
if pdf_parse_method not in ["deepdoc", "plain_text"]: if pdf_parse_method.lower() not in ["deepdoc", "plain_text"]:
self.check_empty(pdf_config.get("llm_id"), "VLM") self.check_empty(pdf_config.get("lang", ""), "Language")
pdf_language = pdf_config.get("lang", "")
self.check_empty(pdf_language, "Language")
pdf_output_format = pdf_config.get("output_format", "") pdf_output_format = pdf_config.get("output_format", "")
self.check_valid_value(pdf_output_format, "PDF output format abnormal.", self.allowed_output_format["pdf"]) self.check_valid_value(pdf_output_format, "PDF output format abnormal.", self.allowed_output_format["pdf"])
@@ -226,8 +222,7 @@ class Parser(ProcessBase):
lines, _ = PlainParser()(blob) lines, _ = PlainParser()(blob)
bboxes = [{"text": t} for t, _ in lines] bboxes = [{"text": t} for t, _ in lines]
else: else:
assert conf.get("llm_id") vision_model = LLMBundle(self._canvas._tenant_id, LLMType.IMAGE2TEXT, llm_name=conf.get("parse_method"), lang=self._param.setups["pdf"].get("lang"))
vision_model = LLMBundle(self._canvas._tenant_id, LLMType.IMAGE2TEXT, llm_name=conf.get("llm_id"), lang=self._param.setups["pdf"].get("lang"))
lines, _ = VisionParser(vision_model=vision_model)(blob, callback=self.callback) lines, _ = VisionParser(vision_model=vision_model)(blob, callback=self.callback)
bboxes = [] bboxes = []
for t, poss in lines: for t, poss in lines:
@@ -236,6 +231,7 @@ class Parser(ProcessBase):
if conf.get("output_format") == "json": if conf.get("output_format") == "json":
self.set_output("json", bboxes) self.set_output("json", bboxes)
if conf.get("output_format") == "markdown": if conf.get("output_format") == "markdown":
mkdn = "" mkdn = ""
for b in bboxes: for b in bboxes:
@@ -299,7 +295,6 @@ class Parser(ProcessBase):
def _markdown(self, name, blob): def _markdown(self, name, blob):
from functools import reduce from functools import reduce
from rag.app.naive import Markdown as naive_markdown_parser from rag.app.naive import Markdown as naive_markdown_parser
from rag.nlp import concat_img from rag.nlp import concat_img
@@ -330,22 +325,6 @@ class Parser(ProcessBase):
else: else:
self.set_output("text", "\n".join([section_text for section_text, _ in sections])) self.set_output("text", "\n".join([section_text for section_text, _ in sections]))
def _text(self, name, blob):
from deepdoc.parser.utils import get_text
self.callback(random.randint(1, 5) / 100.0, "Start to work on a text.")
conf = self._param.setups["text"]
self.set_output("output_format", conf["output_format"])
# parse binary to text
text_content = get_text(name, binary=blob)
if conf.get("output_format") == "json":
result = [{"text": text_content}]
self.set_output("json", result)
else:
result = text_content
self.set_output("text", result)
def _image(self, from_upstream: ParserFromUpstream): def _image(self, from_upstream: ParserFromUpstream):
from deepdoc.vision import OCR from deepdoc.vision import OCR
@@ -367,7 +346,7 @@ class Parser(ProcessBase):
else: else:
# use VLM to describe the picture # use VLM to describe the picture
cv_model = LLMBundle(self._canvas.get_tenant_id(), LLMType.IMAGE2TEXT, llm_name=conf["llm_id"], lang=lang) cv_model = LLMBundle(self._canvas.get_tenant_id(), LLMType.IMAGE2TEXT, llm_name=conf["llm_id"],lang=lang)
img_binary = io.BytesIO() img_binary = io.BytesIO()
img.save(img_binary, format="JPEG") img.save(img_binary, format="JPEG")
img_binary.seek(0) img_binary.seek(0)
@@ -519,11 +498,10 @@ class Parser(ProcessBase):
async def _invoke(self, **kwargs): async def _invoke(self, **kwargs):
function_map = { function_map = {
"pdf": self._pdf, "pdf": self._pdf,
"markdown": self._markdown, "text&markdown": self._markdown,
"spreadsheet": self._spreadsheet, "spreadsheet": self._spreadsheet,
"slides": self._slides, "slides": self._slides,
"word": self._word, "word": self._word,
"text": self._text,
"image": self._image, "image": self._image,
"audio": self._audio, "audio": self._audio,
"email": self._email, "email": self._email,

View File

@@ -18,7 +18,7 @@ import json
import logging import logging
import random import random
import time import time
from timeit import default_timer as timer
import trio import trio
from agent.canvas import Graph from agent.canvas import Graph
@@ -38,25 +38,26 @@ class Pipeline(Graph):
def callback(self, component_name: str, progress: float | int | None = None, message: str = "") -> None: def callback(self, component_name: str, progress: float | int | None = None, message: str = "") -> None:
log_key = f"{self._flow_id}-{self.task_id}-logs" log_key = f"{self._flow_id}-{self.task_id}-logs"
timestamp = timer()
try: try:
bin = REDIS_CONN.get(log_key) bin = REDIS_CONN.get(log_key)
obj = json.loads(bin.encode("utf-8")) obj = json.loads(bin.encode("utf-8"))
if obj: if obj:
if obj[-1]["component_name"] == component_name: if obj[-1]["component_id"] == component_name:
obj[-1]["trace"].append({"progress": progress, "message": message, "datetime": datetime.datetime.now().strftime("%H:%M:%S")}) obj[-1]["trace"].append({"progress": progress, "message": message, "datetime": datetime.datetime.now().strftime("%H:%M:%S"), "timestamp": timestamp, "elapsed_time": timestamp-obj[-1]["trace"][-1]["timestamp"]})
else: else:
obj.append({"component_name": component_name, "trace": [{"progress": progress, "message": message, "datetime": datetime.datetime.now().strftime("%H:%M:%S")}]}) obj.append({"component_id": component_name, "trace": [{"progress": progress, "message": message, "datetime": datetime.datetime.now().strftime("%H:%M:%S"), "timestamp": timestamp, "elapsed_time": 0}]})
else: else:
obj = [{"component_name": component_name, "trace": [{"progress": progress, "message": message, "datetime": datetime.datetime.now().strftime("%H:%M:%S")}]}] obj = [{"component_id": component_name, "trace": [{"progress": progress, "message": message, "datetime": datetime.datetime.now().strftime("%H:%M:%S"), "timestamp": timestamp, "elapsed_time": 0}]}]
REDIS_CONN.set_obj(log_key, obj, 60 * 30) REDIS_CONN.set_obj(log_key, obj, 60 * 30)
if self._doc_id: if self._doc_id:
percentage = 1./len(self.components.items()) percentage = 1./len(self.components.items())
msg = "" msg = ""
finished = 0. finished = 0.
for o in obj: for o in obj:
if o['component_name'] == "END": if o['component_id'] == "END":
continue continue
msg += f"\n[{o['component_name']}]:\n" msg += f"\n[{o['component_id']}]:\n"
for t in o["trace"]: for t in o["trace"]:
msg += "%s: %s\n"%(t["datetime"], t["message"]) msg += "%s: %s\n"%(t["datetime"], t["message"])
if t["progress"] < 0: if t["progress"] < 0:

View File

@@ -30,7 +30,7 @@ def print_logs(pipeline: Pipeline):
while True: while True:
time.sleep(5) time.sleep(5)
logs = pipeline.fetch_logs() logs = pipeline.fetch_logs()
logs_str = json.dumps(logs) logs_str = json.dumps(logs, ensure_ascii=False)
if logs_str != last_logs: if logs_str != last_logs:
print(logs_str) print(logs_str)
last_logs = logs_str last_logs = logs_str

View File

@@ -20,8 +20,7 @@ import random
import sys import sys
import threading import threading
import time import time
from api.db.services.canvas_service import UserCanvasService
from api.utils import get_uuid
from api.utils.api_utils import timeout from api.utils.api_utils import timeout
from api.utils.base64_image import image2id from api.utils.base64_image import image2id
from api.utils.log_utils import init_root_logger, get_project_base_directory from api.utils.log_utils import init_root_logger, get_project_base_directory
@@ -29,7 +28,6 @@ from graphrag.general.index import run_graphrag
from graphrag.utils import get_llm_cache, set_llm_cache, get_tags_from_cache, set_tags_to_cache from graphrag.utils import get_llm_cache, set_llm_cache, get_tags_from_cache, set_tags_to_cache
from rag.flow.pipeline import Pipeline from rag.flow.pipeline import Pipeline
from rag.prompts import keyword_extraction, question_proposal, content_tagging from rag.prompts import keyword_extraction, question_proposal, content_tagging
import logging import logging
import os import os
from datetime import datetime from datetime import datetime
@@ -45,10 +43,8 @@ import signal
import trio import trio
import exceptiongroup import exceptiongroup
import faulthandler import faulthandler
import numpy as np import numpy as np
from peewee import DoesNotExist from peewee import DoesNotExist
from api.db import LLMType, ParserType from api.db import LLMType, ParserType
from api.db.services.document_service import DocumentService from api.db.services.document_service import DocumentService
from api.db.services.llm_service import LLMBundle from api.db.services.llm_service import LLMBundle
@@ -216,7 +212,11 @@ async def collect():
return None, None return None, None
canceled = False canceled = False
if msg.get("doc_id", "") == "x":
task = msg
else:
task = TaskService.get_task(msg["id"]) task = TaskService.get_task(msg["id"])
if task: if task:
canceled = has_canceled(task["id"]) canceled = has_canceled(task["id"])
if not task or canceled: if not task or canceled:
@@ -229,9 +229,8 @@ async def collect():
task_type = msg.get("task_type", "") task_type = msg.get("task_type", "")
task["task_type"] = task_type task["task_type"] = task_type
if task_type == "dataflow": if task_type == "dataflow":
task["tenant_id"]=msg.get("tenant_id", "") task["tenant_id"] = msg["tenant_id"]
task["dsl"] = msg.get("dsl", "") task["dataflow_id"] = msg["dataflow_id"]
task["dataflow_id"] = msg.get("dataflow_id", get_uuid())
task["kb_id"] = msg.get("kb_id", "") task["kb_id"] = msg.get("kb_id", "")
return redis_msg, task return redis_msg, task
@@ -460,13 +459,12 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
return tk_count, vector_size return tk_count, vector_size
async def run_dataflow(dsl:str, tenant_id:str, doc_id:str, task_id:str, flow_id:str, callback=None): async def run_dataflow(task: dict):
_ = callback dataflow_id = task["dataflow_id"]
e, cvs = UserCanvasService.get_by_id(dataflow_id)
pipeline = Pipeline(dsl=dsl, tenant_id=tenant_id, doc_id=doc_id, task_id=task_id, flow_id=flow_id) pipeline = Pipeline(cvs.dsl, tenant_id=task["tenant_id"], doc_id=task["doc_id"], task_id=task["id"], flow_id=dataflow_id)
pipeline.reset() pipeline.reset()
await pipeline.run(file=task.get("file"))
await pipeline.run()
@timeout(3600) @timeout(3600)
@@ -513,6 +511,12 @@ async def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None):
@timeout(60*60*2, 1) @timeout(60*60*2, 1)
async def do_handle_task(task): async def do_handle_task(task):
task_type = task.get("task_type", "")
if task_type == "dataflow" and task.get("doc_id", "") == "x":
await run_dataflow(task)
return
task_id = task["id"] task_id = task["id"]
task_from_page = task["from_page"] task_from_page = task["from_page"]
task_to_page = task["to_page"] task_to_page = task["to_page"]
@@ -526,6 +530,7 @@ async def do_handle_task(task):
task_parser_config = task["parser_config"] task_parser_config = task["parser_config"]
task_start_ts = timer() task_start_ts = timer()
# prepare the progress callback function # prepare the progress callback function
progress_callback = partial(set_progress, task_id, task_from_page, task_to_page) progress_callback = partial(set_progress, task_id, task_from_page, task_to_page)
@@ -554,13 +559,11 @@ async def do_handle_task(task):
init_kb(task, vector_size) init_kb(task, vector_size)
task_type = task.get("task_type", "")
if task_type == "dataflow": if task_type == "dataflow":
task_dataflow_dsl = task["dsl"] await run_dataflow(task)
task_dataflow_id = task["dataflow_id"]
await run_dataflow(dsl=task_dataflow_dsl, tenant_id=task_tenant_id, doc_id=task_doc_id, task_id=task_id, flow_id=task_dataflow_id, callback=None)
return return
elif task_type == "raptor":
if task_type == "raptor":
# bind LLM for raptor # bind LLM for raptor
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language) chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
# run RAPTOR # run RAPTOR