Mirror of https://github.com/infiniflow/ragflow.git (synced 2025-12-08 20:42:30 +08:00)
Feat: support dataflow run. (#10182)
### What problem does this PR solve?

Adds support for running dataflow (pipeline) tasks in the task executor: a new `dataflow` task type is collected from the queue and dispatched to `run_dataflow`, which loads the canvas DSL and executes it as a `Pipeline`.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
@@ -20,8 +20,7 @@ import random
import sys
import threading
import time

from api.utils import get_uuid
from api.db.services.canvas_service import UserCanvasService
from api.utils.api_utils import timeout
from api.utils.base64_image import image2id
from api.utils.log_utils import init_root_logger, get_project_base_directory
@@ -29,7 +28,6 @@ from graphrag.general.index import run_graphrag
from graphrag.utils import get_llm_cache, set_llm_cache, get_tags_from_cache, set_tags_to_cache
from rag.flow.pipeline import Pipeline
from rag.prompts import keyword_extraction, question_proposal, content_tagging

import logging
import os
from datetime import datetime
@@ -45,10 +43,8 @@ import signal
import trio
import exceptiongroup
import faulthandler

import numpy as np
from peewee import DoesNotExist

from api.db import LLMType, ParserType
from api.db.services.document_service import DocumentService
from api.db.services.llm_service import LLMBundle
@@ -216,7 +212,11 @@ async def collect():
        return None, None

    canceled = False
    task = TaskService.get_task(msg["id"])
    if msg.get("doc_id", "") == "x":
        task = msg
    else:
        task = TaskService.get_task(msg["id"])

    if task:
        canceled = has_canceled(task["id"])
    if not task or canceled:
@@ -229,9 +229,8 @@ async def collect():
    task_type = msg.get("task_type", "")
    task["task_type"] = task_type
    if task_type == "dataflow":
        task["tenant_id"]=msg.get("tenant_id", "")
        task["dsl"] = msg.get("dsl", "")
        task["dataflow_id"] = msg.get("dataflow_id", get_uuid())
        task["tenant_id"] = msg["tenant_id"]
        task["dataflow_id"] = msg["dataflow_id"]
        task["kb_id"] = msg.get("kb_id", "")
    return redis_msg, task
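For context, the fields read here imply a queue message shaped roughly as below for a standalone dataflow run. This is inferred from this hunk only (the producer side is not part of this diff), and the values are hypothetical. Note that the DSL itself is no longer expected in the message; `run_dataflow` now loads it from the saved canvas (next hunk).

```python
# Hypothetical message for a standalone dataflow run, inferred from the
# fields collect() reads above; the producer that enqueues it is not shown here.
msg = {
    "id": "task-uuid",             # also used by the has_canceled() check
    "task_type": "dataflow",
    "doc_id": "x",                 # sentinel: no real document is attached
    "tenant_id": "tenant-uuid",    # required, copied via msg["tenant_id"]
    "dataflow_id": "canvas-uuid",  # required, id of the saved dataflow canvas
    "kb_id": "",                   # optional, defaults to "" via msg.get()
}
# With doc_id == "x", collect() uses the message itself as the task instead of
# resolving it through TaskService.get_task().
```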
@@ -460,13 +459,12 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
    return tk_count, vector_size


async def run_dataflow(dsl:str, tenant_id:str, doc_id:str, task_id:str, flow_id:str, callback=None):
    _ = callback

    pipeline = Pipeline(dsl=dsl, tenant_id=tenant_id, doc_id=doc_id, task_id=task_id, flow_id=flow_id)
async def run_dataflow(task: dict):
    dataflow_id = task["dataflow_id"]
    e, cvs = UserCanvasService.get_by_id(dataflow_id)
    pipeline = Pipeline(cvs.dsl, tenant_id=task["tenant_id"], doc_id=task["doc_id"], task_id=task["id"], flow_id=dataflow_id)
    pipeline.reset()

    await pipeline.run()
    await pipeline.run(file=task.get("file"))


@timeout(3600)
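One detail worth noting in the new `run_dataflow`: the first value returned by `UserCanvasService.get_by_id` (bound to `e`) is ignored, so a missing canvas would surface later as an `AttributeError` on `cvs.dsl`. A more defensive variant could look like the sketch below. This is not part of the PR and assumes `get_by_id` returns a `(found, model)` pair, as other ragflow service classes do.

```python
# Defensive variant (not in this PR); assumes UserCanvasService.get_by_id
# returns a (found, model) pair like other ragflow service classes.
async def run_dataflow(task: dict):
    dataflow_id = task["dataflow_id"]
    found, cvs = UserCanvasService.get_by_id(dataflow_id)
    if not found:
        logging.error("Dataflow canvas %s not found for task %s.", dataflow_id, task["id"])
        return
    pipeline = Pipeline(cvs.dsl, tenant_id=task["tenant_id"], doc_id=task["doc_id"],
                        task_id=task["id"], flow_id=dataflow_id)
    pipeline.reset()
    await pipeline.run(file=task.get("file"))
```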
@@ -513,6 +511,12 @@ async def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None):

@timeout(60*60*2, 1)
async def do_handle_task(task):
    task_type = task.get("task_type", "")

    if task_type == "dataflow" and task.get("doc_id", "") == "x":
        await run_dataflow(task)
        return

    task_id = task["id"]
    task_from_page = task["from_page"]
    task_to_page = task["to_page"]
@@ -526,6 +530,7 @@ async def do_handle_task(task):
    task_parser_config = task["parser_config"]
    task_start_ts = timer()


    # prepare the progress callback function
    progress_callback = partial(set_progress, task_id, task_from_page, task_to_page)
@@ -554,13 +559,11 @@ async def do_handle_task(task):

    init_kb(task, vector_size)

    task_type = task.get("task_type", "")
    if task_type == "dataflow":
        task_dataflow_dsl = task["dsl"]
        task_dataflow_id = task["dataflow_id"]
        await run_dataflow(dsl=task_dataflow_dsl, tenant_id=task_tenant_id, doc_id=task_doc_id, task_id=task_id, flow_id=task_dataflow_id, callback=None)
        await run_dataflow(task)
        return
    elif task_type == "raptor":

    if task_type == "raptor":
        # bind LLM for raptor
        chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
        # run RAPTOR
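Taken together with the earlier hunk in `do_handle_task`, dataflow tasks now follow two paths: standalone runs (sentinel `doc_id == "x"`) short-circuit at the top of the handler, while document-bound dataflow tasks run only after the embedding model is bound and `init_kb` has prepared the index. A condensed sketch of the resulting dispatch, reconstructed from the hunks in this diff with the untouched code elided:

```python
# Condensed sketch of do_handle_task after this PR; only the branches touched
# by this diff are shown, the rest of the handler is elided.
async def do_handle_task(task):
    task_type = task.get("task_type", "")

    if task_type == "dataflow" and task.get("doc_id", "") == "x":
        await run_dataflow(task)   # standalone run: no document, no index setup
        return

    # ... progress callback, embedding model binding, init_kb(task, vector_size) ...

    if task_type == "dataflow":
        await run_dataflow(task)   # document-bound run, after the KB index exists
        return
    if task_type == "raptor":
        ...                        # bind the chat LLM and run RAPTOR as before
```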