Feat: support dataflow run. (#10182)

### What problem does this PR solve?

Adds dataflow-run support to the task executor: `collect()` now accepts dataflow messages directly (sentinel `doc_id == "x"`), and `run_dataflow` loads the canvas DSL by `dataflow_id` via `UserCanvasService` instead of receiving it inline in the message.
### Type of change

- [x] New Feature (non-breaking change which adds functionality)
Author: Kevin Hu
Date: 2025-09-22 09:36:21 +08:00
Committed by: GitHub
Parent: 028c2d83e9
Commit: d050ef568d
7 changed files with 50 additions and 78 deletions


@@ -20,8 +20,7 @@ import random
 import sys
 import threading
 import time
-from api.utils import get_uuid
 from api.db.services.canvas_service import UserCanvasService
 from api.utils.api_utils import timeout
 from api.utils.base64_image import image2id
 from api.utils.log_utils import init_root_logger, get_project_base_directory
@@ -29,7 +28,6 @@ from graphrag.general.index import run_graphrag
 from graphrag.utils import get_llm_cache, set_llm_cache, get_tags_from_cache, set_tags_to_cache
 from rag.flow.pipeline import Pipeline
 from rag.prompts import keyword_extraction, question_proposal, content_tagging
 import logging
 import os
 from datetime import datetime
@@ -45,10 +43,8 @@ import signal
 import trio
 import exceptiongroup
 import faulthandler
 import numpy as np
 from peewee import DoesNotExist
 from api.db import LLMType, ParserType
 from api.db.services.document_service import DocumentService
 from api.db.services.llm_service import LLMBundle
@@ -216,7 +212,11 @@ async def collect():
         return None, None
     canceled = False
-    task = TaskService.get_task(msg["id"])
+    if msg.get("doc_id", "") == "x":
+        task = msg
+    else:
+        task = TaskService.get_task(msg["id"])
     if task:
         canceled = has_canceled(task["id"])
     if not task or canceled:
@@ -229,9 +229,8 @@
     task_type = msg.get("task_type", "")
     task["task_type"] = task_type
     if task_type == "dataflow":
-        task["tenant_id"]=msg.get("tenant_id", "")
-        task["dsl"] = msg.get("dsl", "")
-        task["dataflow_id"] = msg.get("dataflow_id", get_uuid())
+        task["tenant_id"] = msg["tenant_id"]
+        task["dataflow_id"] = msg["dataflow_id"]
         task["kb_id"] = msg.get("kb_id", "")
     return redis_msg, task
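
For reference, a queue message that takes the new `collect()` branch might look like this — a minimal sketch, with key names taken from the hunks above and all values purely illustrative:

```python
# Hypothetical payload that routes a dataflow run through collect().
# doc_id == "x" is the sentinel checked above: the message itself is used
# as the task instead of being fetched via TaskService.get_task().
msg = {
    "id": "task-uuid",             # illustrative value
    "doc_id": "x",                 # sentinel: no backing document row
    "task_type": "dataflow",
    "tenant_id": "tenant-uuid",    # now mandatory: read as msg["tenant_id"]
    "dataflow_id": "canvas-uuid",  # now mandatory: read as msg["dataflow_id"]
    "kb_id": "kb-uuid",            # still optional: msg.get("kb_id", "")
}
```

Note the tightening: `tenant_id` and `dataflow_id` move from `msg.get(...)` with fallbacks to direct indexing, so a malformed producer now fails fast with a `KeyError` rather than silently creating a task with an empty tenant or a freshly generated `dataflow_id`.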
@@ -460,13 +459,12 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
     return tk_count, vector_size
-async def run_dataflow(dsl:str, tenant_id:str, doc_id:str, task_id:str, flow_id:str, callback=None):
-    _ = callback
-    pipeline = Pipeline(dsl=dsl, tenant_id=tenant_id, doc_id=doc_id, task_id=task_id, flow_id=flow_id)
+async def run_dataflow(task: dict):
+    dataflow_id = task["dataflow_id"]
+    e, cvs = UserCanvasService.get_by_id(dataflow_id)
+    pipeline = Pipeline(cvs.dsl, tenant_id=task["tenant_id"], doc_id=task["doc_id"], task_id=task["id"], flow_id=dataflow_id)
     pipeline.reset()
-    await pipeline.run()
+    await pipeline.run(file=task.get("file"))
 @timeout(3600)
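
The DSL now travels by reference: rather than shipping the serialized DSL inside the message, the reworked `run_dataflow` looks the canvas up by `dataflow_id`. The hunk above discards the success flag (`e`) returned by `UserCanvasService.get_by_id`; a defensive variant — a sketch under that assumption, not the PR's code — could bail out when the canvas no longer exists:

```python
import logging

async def run_dataflow_checked(task: dict):
    # Variant sketch only: same flow as run_dataflow above, plus a guard
    # for the case where the referenced canvas has been deleted.
    dataflow_id = task["dataflow_id"]
    found, cvs = UserCanvasService.get_by_id(dataflow_id)
    if not found:
        logging.error("Canvas %s not found; dropping task %s", dataflow_id, task["id"])
        return
    pipeline = Pipeline(cvs.dsl, tenant_id=task["tenant_id"], doc_id=task["doc_id"],
                        task_id=task["id"], flow_id=dataflow_id)
    pipeline.reset()
    await pipeline.run(file=task.get("file"))
```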
@@ -513,6 +511,12 @@ async def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None):
 @timeout(60*60*2, 1)
 async def do_handle_task(task):
+    task_type = task.get("task_type", "")
+    if task_type == "dataflow" and task.get("doc_id", "") == "x":
+        await run_dataflow(task)
+        return
     task_id = task["id"]
     task_from_page = task["from_page"]
     task_to_page = task["to_page"]
@@ -526,6 +530,7 @@ async def do_handle_task(task):
     task_parser_config = task["parser_config"]
     task_start_ts = timer()

     # prepare the progress callback function
     progress_callback = partial(set_progress, task_id, task_from_page, task_to_page)
@@ -554,13 +559,11 @@
     init_kb(task, vector_size)
     task_type = task.get("task_type", "")
     if task_type == "dataflow":
-        task_dataflow_dsl = task["dsl"]
-        task_dataflow_id = task["dataflow_id"]
-        await run_dataflow(dsl=task_dataflow_dsl, tenant_id=task_tenant_id, doc_id=task_doc_id, task_id=task_id, flow_id=task_dataflow_id, callback=None)
+        await run_dataflow(task)
         return
-    elif task_type == "raptor":
+    if task_type == "raptor":
         # bind LLM for raptor
         chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
         # run RAPTOR
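
Taken together, the hunks leave `do_handle_task` with two dataflow entry points: an early exit for sentinel tasks before any document or model setup, and the original branch for document-bound dataflow tasks after `init_kb`. A structural sketch assembled from the hunks above (elisions marked, not verbatim):

```python
async def do_handle_task(task):
    task_type = task.get("task_type", "")
    # Fast path: sentinel dataflow tasks skip document lookup, model binding
    # and init_kb entirely, going straight to the pipeline.
    if task_type == "dataflow" and task.get("doc_id", "") == "x":
        await run_dataflow(task)
        return
    ...  # bind embedding model, compute vector_size, init_kb(task, vector_size)
    # Document-bound dataflow tasks still run only once the KB is initialized.
    if task_type == "dataflow":
        await run_dataflow(task)
        return
    if task_type == "raptor":  # previously `elif`; equivalent given the early return
        ...  # RAPTOR branch unchanged
```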