Feat: support dataflow run. (#10182)

### What problem does this PR solve?

Adds end-to-end support for running a dataflow (pipeline) canvas: the run endpoint queues a dataflow task keyed by the canvas id, the task executor loads the canvas DSL and runs the pipeline (optionally on an uploaded file that has no document record), and per-component logs now record timestamps and elapsed time.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
Kevin Hu
2025-09-22 09:36:21 +08:00
committed by GitHub
parent 028c2d83e9
commit d050ef568d
7 changed files with 50 additions and 78 deletions

View File

@@ -144,11 +144,10 @@ def run():
if cvs.canvas_category == CanvasCategory.DataFlow:
task_id = get_uuid()
flow_id = get_uuid()
ok, error_message = queue_dataflow(dsl=cvs.dsl, tenant_id=user_id, file=files[0], task_id=task_id, flow_id=flow_id, priority=0)
ok, error_message = queue_dataflow(tenant_id=user_id, flow_id=req["id"], task_id=task_id, file=files[0], priority=0)
if not ok:
return server_error_response(error_message)
return get_json_result(data={"task_id": task_id, "message_id": flow_id})
return get_data_error_result(message=error_message)
return get_json_result(data={"message_id": task_id})
try:
canvas = Canvas(cvs.dsl, current_user.id, req["id"])
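
With this change the endpoint stops minting a separate flow id: the canvas id from the request (`req["id"]`) doubles as the dataflow id, and only the queued task id is returned, as `message_id`. A minimal client sketch of the new contract, assuming a hypothetical route, auth header, and response envelope (none of which are defined in this diff):

```python
import requests

RUN_URL = "http://localhost:9380/v1/canvas/run"   # hypothetical route; use the real one
HEADERS = {"Authorization": "Bearer <api-key>"}    # placeholder credentials

def run_dataflow_canvas(canvas_id: str, file_path: str) -> str:
    """Kick off a dataflow canvas run and return its task id (message_id)."""
    with open(file_path, "rb") as fp:
        resp = requests.post(
            RUN_URL,
            headers=HEADERS,
            data={"id": canvas_id},   # req["id"] is reused as the dataflow/flow id
            files={"file": fp},       # becomes files[0] on the server side
        )
    resp.raise_for_status()
    return resp.json()["data"]["message_id"]   # the queued task id, no separate flow id
```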

View File

@@ -472,14 +472,10 @@ def has_canceled(task_id):
return False
def queue_dataflow(dsl:str, tenant_id:str, task_id:str, flow_id:str=None, doc_id:str=None, file:dict=None, priority: int=0, callback=None) -> tuple[bool, str]:
"""
Returns a tuple (success: bool, error_message: str).
"""
_ = callback
def queue_dataflow(tenant_id:str, flow_id:str, task_id:str, doc_id:str="x", file:dict=None, priority: int=0) -> tuple[bool, str]:
task = dict(
id=get_uuid() if not task_id else task_id,
id=task_id,
doc_id=doc_id,
from_page=0,
to_page=100000000,
@@ -490,15 +486,10 @@ def queue_dataflow(dsl:str, tenant_id:str, task_id:str, flow_id:str=None, doc_id
TaskService.model.delete().where(TaskService.model.id == task["id"]).execute()
bulk_insert_into_db(model=Task, data_source=[task], replace_on_conflict=True)
kb_id = DocumentService.get_knowledgebase_id(doc_id)
if not kb_id:
return False, f"Can't find KB of this document: {doc_id}"
task["kb_id"] = kb_id
task["kb_id"] = DocumentService.get_knowledgebase_id(doc_id)
task["tenant_id"] = tenant_id
task["task_type"] = "dataflow"
task["dsl"] = dsl
task["dataflow_id"] = get_uuid() if not flow_id else flow_id
task["dataflow_id"] = flow_id
task["file"] = file
if not REDIS_CONN.queue_product(
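
For reference, the message queued by the new `queue_dataflow` signature looks roughly like the sketch below; field names follow the diff, while the helper itself and the sample values are illustrative only:

```python
from uuid import uuid4

def build_dataflow_task(tenant_id: str, flow_id: str, task_id: str,
                        doc_id: str = "x", file: dict | None = None,
                        priority: int = 0) -> dict:
    """Approximate shape of the task dict that queue_dataflow now enqueues."""
    return {
        "id": task_id,              # supplied by the caller; no get_uuid() fallback anymore
        "doc_id": doc_id,           # "x" marks a file-only run without a document record
        "from_page": 0,
        "to_page": 100000000,
        "priority": priority,
        "kb_id": None,              # filled from DocumentService.get_knowledgebase_id(doc_id)
        "tenant_id": tenant_id,
        "task_type": "dataflow",
        "dataflow_id": flow_id,     # the canvas id; the DSL is loaded later by the executor
        "file": file,
    }

task = build_dataflow_task("tenant-1", "canvas-42", str(uuid4()), file={"name": "a.pdf"})
```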

View File

@@ -35,9 +35,9 @@ class ProcessBase(ComponentBase):
def __init__(self, pipeline, id, param: ProcessParamBase):
super().__init__(pipeline, id, param)
if hasattr(self._canvas, "callback"):
self.callback = partial(self._canvas.callback, self.component_name)
self.callback = partial(self._canvas.callback, id)
else:
self.callback = partial(lambda *args, **kwargs: None, self.component_name)
self.callback = partial(lambda *args, **kwargs: None, id)
async def invoke(self, **kwargs) -> dict[str, Any]:
self.set_output("_created_time", time.perf_counter())

View File

@@ -76,7 +76,6 @@ class ParserParam(ProcessParamBase):
self.setups = {
"pdf": {
"parse_method": "deepdoc", # deepdoc/plain_text/vlm
"llm_id": "",
"lang": "Chinese",
"suffix": [
"pdf",
@@ -98,8 +97,8 @@
],
"output_format": "json",
},
"markdown": {
"suffix": ["md", "markdown", "mdx"],
"text&markdown": {
"suffix": ["md", "markdown", "mdx", "txt"],
"output_format": "json",
},
"slides": {
@@ -156,13 +155,10 @@
pdf_config = self.setups.get("pdf", {})
if pdf_config:
pdf_parse_method = pdf_config.get("parse_method", "")
self.check_valid_value(pdf_parse_method.lower(), "Parse method abnormal.", ["deepdoc", "plain_text", "vlm"])
self.check_empty(pdf_parse_method, "Parse method abnormal.")
if pdf_parse_method not in ["deepdoc", "plain_text"]:
self.check_empty(pdf_config.get("llm_id"), "VLM")
pdf_language = pdf_config.get("lang", "")
self.check_empty(pdf_language, "Language")
if pdf_parse_method.lower() not in ["deepdoc", "plain_text"]:
self.check_empty(pdf_config.get("lang", ""), "Language")
pdf_output_format = pdf_config.get("output_format", "")
self.check_valid_value(pdf_output_format, "PDF output format abnormal.", self.allowed_output_format["pdf"])
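
With `llm_id` dropped from the pdf setup, `parse_method` does double duty: either a built-in method (`deepdoc`/`plain_text`) or, per the `_pdf` change below, the name of a vision model, in which case `lang` must be present. A hedged sketch of the relaxed validation (the helper and model name are illustrative, not the component's API):

```python
def validate_pdf_setup(pdf_config: dict) -> None:
    """Any non-empty parse_method passes; VLM-style methods must also set a language."""
    parse_method = pdf_config.get("parse_method", "")
    if not parse_method:
        raise ValueError("Parse method abnormal.")
    if parse_method.lower() not in ("deepdoc", "plain_text"):
        # parse_method is treated downstream as the IMAGE2TEXT model name
        if not pdf_config.get("lang", ""):
            raise ValueError("Language")

validate_pdf_setup({"parse_method": "some-vlm-model", "lang": "Chinese"})  # hypothetical VLM
```
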
@@ -226,8 +222,7 @@ class Parser(ProcessBase):
lines, _ = PlainParser()(blob)
bboxes = [{"text": t} for t, _ in lines]
else:
assert conf.get("llm_id")
vision_model = LLMBundle(self._canvas._tenant_id, LLMType.IMAGE2TEXT, llm_name=conf.get("llm_id"), lang=self._param.setups["pdf"].get("lang"))
vision_model = LLMBundle(self._canvas._tenant_id, LLMType.IMAGE2TEXT, llm_name=conf.get("parse_method"), lang=self._param.setups["pdf"].get("lang"))
lines, _ = VisionParser(vision_model=vision_model)(blob, callback=self.callback)
bboxes = []
for t, poss in lines:
@@ -236,6 +231,7 @@
if conf.get("output_format") == "json":
self.set_output("json", bboxes)
if conf.get("output_format") == "markdown":
mkdn = ""
for b in bboxes:
@@ -299,7 +295,6 @@
def _markdown(self, name, blob):
from functools import reduce
from rag.app.naive import Markdown as naive_markdown_parser
from rag.nlp import concat_img
@@ -330,22 +325,6 @@
else:
self.set_output("text", "\n".join([section_text for section_text, _ in sections]))
def _text(self, name, blob):
from deepdoc.parser.utils import get_text
self.callback(random.randint(1, 5) / 100.0, "Start to work on a text.")
conf = self._param.setups["text"]
self.set_output("output_format", conf["output_format"])
# parse binary to text
text_content = get_text(name, binary=blob)
if conf.get("output_format") == "json":
result = [{"text": text_content}]
self.set_output("json", result)
else:
result = text_content
self.set_output("text", result)
def _image(self, from_upstream: ParserFromUpstream):
from deepdoc.vision import OCR
@@ -519,11 +498,10 @@
async def _invoke(self, **kwargs):
function_map = {
"pdf": self._pdf,
"markdown": self._markdown,
"text&markdown": self._markdown,
"spreadsheet": self._spreadsheet,
"slides": self._slides,
"word": self._word,
"text": self._text,
"image": self._image,
"audio": self._audio,
"email": self._email,

View File

@@ -18,7 +18,7 @@ import json
import logging
import random
import time
from timeit import default_timer as timer
import trio
from agent.canvas import Graph
@@ -38,25 +38,26 @@ class Pipeline(Graph):
def callback(self, component_name: str, progress: float | int | None = None, message: str = "") -> None:
log_key = f"{self._flow_id}-{self.task_id}-logs"
timestamp = timer()
try:
bin = REDIS_CONN.get(log_key)
obj = json.loads(bin.encode("utf-8"))
if obj:
if obj[-1]["component_name"] == component_name:
obj[-1]["trace"].append({"progress": progress, "message": message, "datetime": datetime.datetime.now().strftime("%H:%M:%S")})
if obj[-1]["component_id"] == component_name:
obj[-1]["trace"].append({"progress": progress, "message": message, "datetime": datetime.datetime.now().strftime("%H:%M:%S"), "timestamp": timestamp, "elapsed_time": timestamp-obj[-1]["trace"][-1]["timestamp"]})
else:
obj.append({"component_name": component_name, "trace": [{"progress": progress, "message": message, "datetime": datetime.datetime.now().strftime("%H:%M:%S")}]})
obj.append({"component_id": component_name, "trace": [{"progress": progress, "message": message, "datetime": datetime.datetime.now().strftime("%H:%M:%S"), "timestamp": timestamp, "elapsed_time": 0}]})
else:
obj = [{"component_name": component_name, "trace": [{"progress": progress, "message": message, "datetime": datetime.datetime.now().strftime("%H:%M:%S")}]}]
obj = [{"component_id": component_name, "trace": [{"progress": progress, "message": message, "datetime": datetime.datetime.now().strftime("%H:%M:%S"), "timestamp": timestamp, "elapsed_time": 0}]}]
REDIS_CONN.set_obj(log_key, obj, 60 * 30)
if self._doc_id:
percentage = 1./len(self.components.items())
msg = ""
finished = 0.
for o in obj:
if o['component_name'] == "END":
if o['component_id'] == "END":
continue
msg += f"\n[{o['component_name']}]:\n"
msg += f"\n[{o['component_id']}]:\n"
for t in o["trace"]:
msg += "%s: %s\n"%(t["datetime"], t["message"])
if t["progress"] < 0:

View File

@@ -30,7 +30,7 @@ def print_logs(pipeline: Pipeline):
while True:
time.sleep(5)
logs = pipeline.fetch_logs()
logs_str = json.dumps(logs)
logs_str = json.dumps(logs, ensure_ascii=False)
if logs_str != last_logs:
print(logs_str)
last_logs = logs_str
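
`ensure_ascii=False` matters here because progress messages can contain non-ASCII text (for example Chinese); without it `json.dumps` escapes every such character:

```python
import json

msg = {"message": "解析完成"}
print(json.dumps(msg))                      # {"message": "\u89e3\u6790\u5b8c\u6210"}
print(json.dumps(msg, ensure_ascii=False))  # {"message": "解析完成"}
```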

View File

@@ -20,8 +20,7 @@ import random
import sys
import threading
import time
from api.utils import get_uuid
from api.db.services.canvas_service import UserCanvasService
from api.utils.api_utils import timeout
from api.utils.base64_image import image2id
from api.utils.log_utils import init_root_logger, get_project_base_directory
@@ -29,7 +28,6 @@ from graphrag.general.index import run_graphrag
from graphrag.utils import get_llm_cache, set_llm_cache, get_tags_from_cache, set_tags_to_cache
from rag.flow.pipeline import Pipeline
from rag.prompts import keyword_extraction, question_proposal, content_tagging
import logging
import os
from datetime import datetime
@@ -45,10 +43,8 @@ import signal
import trio
import exceptiongroup
import faulthandler
import numpy as np
from peewee import DoesNotExist
from api.db import LLMType, ParserType
from api.db.services.document_service import DocumentService
from api.db.services.llm_service import LLMBundle
@@ -216,7 +212,11 @@ async def collect():
return None, None
canceled = False
if msg.get("doc_id", "") == "x":
task = msg
else:
task = TaskService.get_task(msg["id"])
if task:
canceled = has_canceled(task["id"])
if not task or canceled:
@@ -229,9 +229,8 @@ async def collect():
task_type = msg.get("task_type", "")
task["task_type"] = task_type
if task_type == "dataflow":
task["tenant_id"]=msg.get("tenant_id", "")
task["dsl"] = msg.get("dsl", "")
task["dataflow_id"] = msg.get("dataflow_id", get_uuid())
task["tenant_id"] = msg["tenant_id"]
task["dataflow_id"] = msg["dataflow_id"]
task["kb_id"] = msg.get("kb_id", "")
return redis_msg, task
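
`collect()` now treats `doc_id == "x"` as the marker for a standalone dataflow run: the Redis message itself is the task, so nothing is read from the task table. A condensed sketch of that branch, with the service objects passed in as stand-ins for `TaskService` and `has_canceled`:

```python
def resolve_task(msg: dict, task_service, has_canceled) -> dict | None:
    """Return the task to execute, or None if it is missing or was canceled."""
    if msg.get("doc_id", "") == "x":
        task = msg                                 # file-only dataflow: message carries everything
    else:
        task = task_service.get_task(msg["id"])    # normal path: load the persisted task
    if not task or has_canceled(task["id"]):
        return None
    if msg.get("task_type", "") == "dataflow":
        task["tenant_id"] = msg["tenant_id"]
        task["dataflow_id"] = msg["dataflow_id"]
        task["kb_id"] = msg.get("kb_id", "")
    return task
```
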
@@ -460,13 +459,12 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
return tk_count, vector_size
async def run_dataflow(dsl:str, tenant_id:str, doc_id:str, task_id:str, flow_id:str, callback=None):
_ = callback
pipeline = Pipeline(dsl=dsl, tenant_id=tenant_id, doc_id=doc_id, task_id=task_id, flow_id=flow_id)
async def run_dataflow(task: dict):
dataflow_id = task["dataflow_id"]
e, cvs = UserCanvasService.get_by_id(dataflow_id)
pipeline = Pipeline(cvs.dsl, tenant_id=task["tenant_id"], doc_id=task["doc_id"], task_id=task["id"], flow_id=dataflow_id)
pipeline.reset()
await pipeline.run()
await pipeline.run(file=task.get("file"))
@timeout(3600)
@@ -513,6 +511,12 @@ async def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None):
@timeout(60*60*2, 1)
async def do_handle_task(task):
task_type = task.get("task_type", "")
if task_type == "dataflow" and task.get("doc_id", "") == "x":
await run_dataflow(task)
return
task_id = task["id"]
task_from_page = task["from_page"]
task_to_page = task["to_page"]
@@ -526,6 +530,7 @@ async def do_handle_task(task):
task_parser_config = task["parser_config"]
task_start_ts = timer()
# prepare the progress callback function
progress_callback = partial(set_progress, task_id, task_from_page, task_to_page)
@@ -554,13 +559,11 @@
init_kb(task, vector_size)
task_type = task.get("task_type", "")
if task_type == "dataflow":
task_dataflow_dsl = task["dsl"]
task_dataflow_id = task["dataflow_id"]
await run_dataflow(dsl=task_dataflow_dsl, tenant_id=task_tenant_id, doc_id=task_doc_id, task_id=task_id, flow_id=task_dataflow_id, callback=None)
await run_dataflow(task)
return
elif task_type == "raptor":
if task_type == "raptor":
# bind LLM for raptor
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
# run RAPTOR
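
Putting the two hunks together, `do_handle_task` now reaches dataflow execution by two routes: an early return for file-only runs (`doc_id == "x"`) before any document bookkeeping, and the regular path for document-bound dataflow tasks after setup. A condensed control-flow sketch, with a stub standing in for the executor's `run_dataflow` shown above:

```python
async def run_dataflow(task: dict) -> None:
    """Stub for the coroutine that loads the canvas DSL and runs the Pipeline."""
    ...

async def do_handle_task_sketch(task: dict) -> None:
    task_type = task.get("task_type", "")

    # File-only dataflow run: skip progress callbacks, models and init_kb entirely.
    if task_type == "dataflow" and task.get("doc_id", "") == "x":
        await run_dataflow(task)
        return

    # ... regular setup: progress callback, embedding model, init_kb(...), etc. ...

    if task_type == "dataflow":
        # Document-bound dataflow: the DSL is fetched via task["dataflow_id"].
        await run_dataflow(task)
        return
    if task_type == "raptor":
        pass  # RAPTOR branch, now a plain `if` after the dataflow early return
```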