Feat: support dataflow run. (#10182)
### What problem does this PR solve?

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
@@ -144,11 +144,10 @@ def run():
     if cvs.canvas_category == CanvasCategory.DataFlow:
         task_id = get_uuid()
-        flow_id = get_uuid()
-        ok, error_message = queue_dataflow(dsl=cvs.dsl, tenant_id=user_id, file=files[0], task_id=task_id, flow_id=flow_id, priority=0)
+        ok, error_message = queue_dataflow(tenant_id=user_id, flow_id=req["id"], task_id=task_id, file=files[0], priority=0)
         if not ok:
-            return server_error_response(error_message)
-        return get_json_result(data={"task_id": task_id, "message_id": flow_id})
+            return get_data_error_result(message=error_message)
+        return get_json_result(data={"message_id": task_id})

     try:
         canvas = Canvas(cvs.dsl, current_user.id, req["id"])
@@ -472,14 +472,10 @@ def has_canceled(task_id):
     return False


-def queue_dataflow(dsl:str, tenant_id:str, task_id:str, flow_id:str=None, doc_id:str=None, file:dict=None, priority: int=0, callback=None) -> tuple[bool, str]:
-    """
-    Returns a tuple (success: bool, error_message: str).
-    """
-    _ = callback
+def queue_dataflow(tenant_id:str, flow_id:str, task_id:str, doc_id:str="x", file:dict=None, priority: int=0) -> tuple[bool, str]:

     task = dict(
-        id=get_uuid() if not task_id else task_id,
+        id=task_id,
         doc_id=doc_id,
         from_page=0,
         to_page=100000000,
@@ -490,15 +486,10 @@ def queue_dataflow(dsl:str, tenant_id:str, task_id:str, flow_id:str=None, doc_id
     TaskService.model.delete().where(TaskService.model.id == task["id"]).execute()
     bulk_insert_into_db(model=Task, data_source=[task], replace_on_conflict=True)

-    kb_id = DocumentService.get_knowledgebase_id(doc_id)
-    if not kb_id:
-        return False, f"Can't find KB of this document: {doc_id}"
-
-    task["kb_id"] = kb_id
+    task["kb_id"] = DocumentService.get_knowledgebase_id(doc_id)
     task["tenant_id"] = tenant_id
     task["task_type"] = "dataflow"
-    task["dsl"] = dsl
-    task["dataflow_id"] = get_uuid() if not flow_id else flow_id
+    task["dataflow_id"] = flow_id
     task["file"] = file

     if not REDIS_CONN.queue_product(
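
Note: with this change the queue message no longer carries the pipeline DSL; it only references the canvas via `dataflow_id`, and the executor resolves the DSL later. Below is a minimal, self-contained sketch of the payload shape that `queue_dataflow` assembles, using the field names visible in the diff; the helper name and sample values are illustrative, not part of the codebase.

```python
import uuid

def build_dataflow_task(tenant_id: str, flow_id: str, task_id: str,
                        doc_id: str = "x", file: dict | None = None) -> dict:
    """Sketch of the payload queue_dataflow pushes to Redis (field names from the diff)."""
    return {
        "id": task_id,              # caller-supplied; no longer defaulted via get_uuid()
        "doc_id": doc_id,           # "x" marks a run that is not bound to a stored document
        "from_page": 0,
        "to_page": 100000000,
        "kb_id": None,              # looked up via DocumentService in the real code
        "tenant_id": tenant_id,
        "task_type": "dataflow",
        "dataflow_id": flow_id,     # the executor uses this id to fetch the canvas DSL
        "file": file,
    }

task = build_dataflow_task("tenant-1", "flow-1", str(uuid.uuid4()))
print(task["task_type"], task["dataflow_id"])
```
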
@@ -35,9 +35,9 @@ class ProcessBase(ComponentBase):
     def __init__(self, pipeline, id, param: ProcessParamBase):
         super().__init__(pipeline, id, param)
         if hasattr(self._canvas, "callback"):
-            self.callback = partial(self._canvas.callback, self.component_name)
+            self.callback = partial(self._canvas.callback, id)
         else:
-            self.callback = partial(lambda *args, **kwargs: None, self.component_name)
+            self.callback = partial(lambda *args, **kwargs: None, id)

     async def invoke(self, **kwargs) -> dict[str, Any]:
         self.set_output("_created_time", time.perf_counter())
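
Binding the callback to the component's `id` rather than `self.component_name` means trace entries are keyed per component instance. A stand-alone sketch of the `functools.partial` pattern involved; the callback function and the id `"Parser:0"` are placeholders, not the real `Pipeline.callback`.

```python
from functools import partial

def pipeline_callback(component_id: str, progress: float | None = None, message: str = "") -> None:
    # Stand-in for Pipeline.callback: the first positional argument identifies the component.
    print(f"[{component_id}] progress={progress} message={message}")

# Binding the instance id (e.g. "Parser:0") instead of the class name ("Parser")
# gives two components of the same type separate log entries.
callback = partial(pipeline_callback, "Parser:0")
callback(0.5, "halfway through parsing")
```
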
@@ -76,7 +76,6 @@ class ParserParam(ProcessParamBase):
         self.setups = {
             "pdf": {
                 "parse_method": "deepdoc", # deepdoc/plain_text/vlm
-                "llm_id": "",
                 "lang": "Chinese",
                 "suffix": [
                     "pdf",
@@ -98,8 +97,8 @@ class ParserParam(ProcessParamBase):
                 ],
                 "output_format": "json",
             },
-            "markdown": {
-                "suffix": ["md", "markdown", "mdx"],
+            "text&markdown": {
+                "suffix": ["md", "markdown", "mdx", "txt"],
                 "output_format": "json",
             },
             "slides": {
@@ -156,13 +155,10 @@ class ParserParam(ProcessParamBase):
         pdf_config = self.setups.get("pdf", {})
         if pdf_config:
             pdf_parse_method = pdf_config.get("parse_method", "")
-            self.check_valid_value(pdf_parse_method.lower(), "Parse method abnormal.", ["deepdoc", "plain_text", "vlm"])
+            self.check_empty(pdf_parse_method, "Parse method abnormal.")

-            if pdf_parse_method not in ["deepdoc", "plain_text"]:
-                self.check_empty(pdf_config.get("llm_id"), "VLM")
-
-            pdf_language = pdf_config.get("lang", "")
-            self.check_empty(pdf_language, "Language")
+            if pdf_parse_method.lower() not in ["deepdoc", "plain_text"]:
+                self.check_empty(pdf_config.get("lang", ""), "Language")

             pdf_output_format = pdf_config.get("output_format", "")
             self.check_valid_value(pdf_output_format, "PDF output format abnormal.", self.allowed_output_format["pdf"])
@@ -226,8 +222,7 @@ class Parser(ProcessBase):
                 lines, _ = PlainParser()(blob)
                 bboxes = [{"text": t} for t, _ in lines]
             else:
-                assert conf.get("llm_id")
-                vision_model = LLMBundle(self._canvas._tenant_id, LLMType.IMAGE2TEXT, llm_name=conf.get("llm_id"), lang=self._param.setups["pdf"].get("lang"))
+                vision_model = LLMBundle(self._canvas._tenant_id, LLMType.IMAGE2TEXT, llm_name=conf.get("parse_method"), lang=self._param.setups["pdf"].get("lang"))
                 lines, _ = VisionParser(vision_model=vision_model)(blob, callback=self.callback)
                 bboxes = []
                 for t, poss in lines:
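
With `llm_id` removed from the PDF setup, `parse_method` now does double duty: `deepdoc` and `plain_text` pick the built-in parsers, and any other value is passed to `LLMBundle` as the vision model name. A rough sketch of that dispatch decision; the handler descriptions and the sample model name are illustrative assumptions.

```python
def pick_pdf_parser(parse_method: str) -> str:
    """Return which parser branch a given parse_method selects (illustrative only)."""
    method = parse_method.lower()
    if method == "deepdoc":
        return "deepdoc layout parser"
    if method == "plain_text":
        return "plain text parser"
    # Anything else is interpreted as a VLM model name and handed to LLMBundle as llm_name.
    return f"vision parser backed by model '{parse_method}'"

for m in ["deepdoc", "plain_text", "qwen-vl-max"]:
    print(m, "->", pick_pdf_parser(m))
```
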
@@ -236,6 +231,7 @@ class Parser(ProcessBase):

         if conf.get("output_format") == "json":
             self.set_output("json", bboxes)

         if conf.get("output_format") == "markdown":
             mkdn = ""
             for b in bboxes:
@@ -299,7 +295,6 @@ class Parser(ProcessBase):

     def _markdown(self, name, blob):
         from functools import reduce

         from rag.app.naive import Markdown as naive_markdown_parser
         from rag.nlp import concat_img

@@ -330,22 +325,6 @@ class Parser(ProcessBase):
         else:
             self.set_output("text", "\n".join([section_text for section_text, _ in sections]))

-    def _text(self, name, blob):
-        from deepdoc.parser.utils import get_text
-
-        self.callback(random.randint(1, 5) / 100.0, "Start to work on a text.")
-        conf = self._param.setups["text"]
-        self.set_output("output_format", conf["output_format"])
-
-        # parse binary to text
-        text_content = get_text(name, binary=blob)
-
-        if conf.get("output_format") == "json":
-            result = [{"text": text_content}]
-            self.set_output("json", result)
-        else:
-            result = text_content
-            self.set_output("text", result)
-
     def _image(self, from_upstream: ParserFromUpstream):
         from deepdoc.vision import OCR
@@ -519,11 +498,10 @@ class Parser(ProcessBase):
     async def _invoke(self, **kwargs):
         function_map = {
             "pdf": self._pdf,
-            "markdown": self._markdown,
+            "text&markdown": self._markdown,
             "spreadsheet": self._spreadsheet,
             "slides": self._slides,
             "word": self._word,
-            "text": self._text,
             "image": self._image,
             "audio": self._audio,
             "email": self._email,
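
The `_invoke` dispatch stays a plain dict of bound methods; this hunk just renames the `markdown` key to `text&markdown` (whose suffix list now also covers `txt`) and drops the separate `text` entry together with the removed `_text` handler. A tiny stand-alone sketch of the same table-dispatch pattern, with made-up handlers:

```python
class TinyParser:
    def _pdf(self, blob: bytes) -> str:
        return "parsed pdf"

    def _markdown(self, blob: bytes) -> str:
        return "parsed markdown/plain text"

    def parse(self, category: str, blob: bytes) -> str:
        # category -> handler table, mirroring Parser._invoke's function_map
        function_map = {
            "pdf": self._pdf,
            "text&markdown": self._markdown,
        }
        if category not in function_map:
            raise ValueError(f"unsupported category: {category}")
        return function_map[category](blob)

print(TinyParser().parse("text&markdown", b"# hello"))
```
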
@@ -18,7 +18,7 @@ import json
 import logging
 import random
 import time
+from timeit import default_timer as timer
 import trio

 from agent.canvas import Graph
@@ -38,25 +38,26 @@ class Pipeline(Graph):

     def callback(self, component_name: str, progress: float | int | None = None, message: str = "") -> None:
         log_key = f"{self._flow_id}-{self.task_id}-logs"
+        timestamp = timer()
         try:
             bin = REDIS_CONN.get(log_key)
             obj = json.loads(bin.encode("utf-8"))
             if obj:
-                if obj[-1]["component_name"] == component_name:
-                    obj[-1]["trace"].append({"progress": progress, "message": message, "datetime": datetime.datetime.now().strftime("%H:%M:%S")})
+                if obj[-1]["component_id"] == component_name:
+                    obj[-1]["trace"].append({"progress": progress, "message": message, "datetime": datetime.datetime.now().strftime("%H:%M:%S"), "timestamp": timestamp, "elapsed_time": timestamp-obj[-1]["trace"][-1]["timestamp"]})
                 else:
-                    obj.append({"component_name": component_name, "trace": [{"progress": progress, "message": message, "datetime": datetime.datetime.now().strftime("%H:%M:%S")}]})
+                    obj.append({"component_id": component_name, "trace": [{"progress": progress, "message": message, "datetime": datetime.datetime.now().strftime("%H:%M:%S"), "timestamp": timestamp, "elapsed_time": 0}]})
             else:
-                obj = [{"component_name": component_name, "trace": [{"progress": progress, "message": message, "datetime": datetime.datetime.now().strftime("%H:%M:%S")}]}]
+                obj = [{"component_id": component_name, "trace": [{"progress": progress, "message": message, "datetime": datetime.datetime.now().strftime("%H:%M:%S"), "timestamp": timestamp, "elapsed_time": 0}]}]
             REDIS_CONN.set_obj(log_key, obj, 60 * 30)
             if self._doc_id:
                 percentage = 1./len(self.components.items())
                 msg = ""
                 finished = 0.
                 for o in obj:
-                    if o['component_name'] == "END":
+                    if o['component_id'] == "END":
                         continue
-                    msg += f"\n[{o['component_name']}]:\n"
+                    msg += f"\n[{o['component_id']}]:\n"
                     for t in o["trace"]:
                         msg += "%s: %s\n"%(t["datetime"], t["message"])
                         if t["progress"] < 0:
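
Each trace entry now carries a monotonic `timestamp` from `timeit.default_timer` and an `elapsed_time` relative to the previous entry of the same component, and records are keyed by `component_id`. A self-contained sketch of that record shape and the elapsed-time bookkeeping, with the Redis round-trip left out:

```python
import datetime
from timeit import default_timer as timer

def append_trace(obj: list, component_id: str, progress: float | None, message: str) -> list:
    """Append a trace entry, mirroring the record shape Pipeline.callback writes to Redis."""
    timestamp = timer()
    entry = {
        "progress": progress,
        "message": message,
        "datetime": datetime.datetime.now().strftime("%H:%M:%S"),
        "timestamp": timestamp,
        "elapsed_time": 0,
    }
    if obj and obj[-1]["component_id"] == component_id:
        # Same component as the previous record: elapsed time since its last trace entry.
        entry["elapsed_time"] = timestamp - obj[-1]["trace"][-1]["timestamp"]
        obj[-1]["trace"].append(entry)
    else:
        obj.append({"component_id": component_id, "trace": [entry]})
    return obj

logs: list = []
append_trace(logs, "Parser:0", 0.1, "start")
append_trace(logs, "Parser:0", 0.9, "almost done")
print(logs[0]["trace"][1]["elapsed_time"] >= 0)
```
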
@@ -30,7 +30,7 @@ def print_logs(pipeline: Pipeline):
     while True:
         time.sleep(5)
         logs = pipeline.fetch_logs()
-        logs_str = json.dumps(logs)
+        logs_str = json.dumps(logs, ensure_ascii=False)
         if logs_str != last_logs:
             print(logs_str)
             last_logs = logs_str
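
`ensure_ascii=False` only affects how the log snapshot is printed: non-ASCII characters (for example Chinese progress messages) come out verbatim instead of as `\uXXXX` escapes. For instance:

```python
import json

logs = [{"component_id": "Parser:0", "trace": [{"message": "解析完成"}]}]

print(json.dumps(logs))                      # ... "\u89e3\u6790\u5b8c\u6210" ...
print(json.dumps(logs, ensure_ascii=False))  # ... "解析完成" ...
```
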
@@ -20,8 +20,7 @@ import random
 import sys
 import threading
 import time
-from api.utils import get_uuid
+from api.db.services.canvas_service import UserCanvasService
 from api.utils.api_utils import timeout
 from api.utils.base64_image import image2id
 from api.utils.log_utils import init_root_logger, get_project_base_directory
@@ -29,7 +28,6 @@ from graphrag.general.index import run_graphrag
 from graphrag.utils import get_llm_cache, set_llm_cache, get_tags_from_cache, set_tags_to_cache
 from rag.flow.pipeline import Pipeline
 from rag.prompts import keyword_extraction, question_proposal, content_tagging
-
 import logging
 import os
 from datetime import datetime
@@ -45,10 +43,8 @@ import signal
 import trio
 import exceptiongroup
 import faulthandler
-
 import numpy as np
 from peewee import DoesNotExist
-
 from api.db import LLMType, ParserType
 from api.db.services.document_service import DocumentService
 from api.db.services.llm_service import LLMBundle
@@ -216,7 +212,11 @@ async def collect():
         return None, None

     canceled = False
-    task = TaskService.get_task(msg["id"])
+    if msg.get("doc_id", "") == "x":
+        task = msg
+    else:
+        task = TaskService.get_task(msg["id"])

     if task:
         canceled = has_canceled(task["id"])
     if not task or canceled:
@@ -229,9 +229,8 @@ async def collect():
     task_type = msg.get("task_type", "")
     task["task_type"] = task_type
     if task_type == "dataflow":
-        task["tenant_id"]=msg.get("tenant_id", "")
-        task["dsl"] = msg.get("dsl", "")
-        task["dataflow_id"] = msg.get("dataflow_id", get_uuid())
+        task["tenant_id"] = msg["tenant_id"]
+        task["dataflow_id"] = msg["dataflow_id"]
         task["kb_id"] = msg.get("kb_id", "")
     return redis_msg, task

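
`collect()` now treats `doc_id == "x"` as a sentinel meaning the queued message is self-contained: the Redis payload itself becomes the task instead of being looked up through `TaskService`, and the dataflow fields are copied straight from the message. A minimal sketch of that routing; `lookup_task` stands in for `TaskService.get_task` and the sample payloads are made up.

```python
def resolve_task(msg: dict, lookup_task) -> dict | None:
    """Mirror of collect()'s routing: self-contained dataflow messages bypass the task table."""
    if msg.get("doc_id", "") == "x":
        task = dict(msg)                 # the queued payload already is the task
    else:
        task = lookup_task(msg["id"])    # normal path: fetch the stored task row
    if task and msg.get("task_type", "") == "dataflow":
        task["tenant_id"] = msg["tenant_id"]
        task["dataflow_id"] = msg["dataflow_id"]
        task["kb_id"] = msg.get("kb_id", "")
    return task

fake_db = {"t1": {"id": "t1", "doc_id": "doc-42"}}
print(resolve_task({"id": "t1", "doc_id": "doc-42"}, fake_db.get))
print(resolve_task({"id": "t2", "doc_id": "x", "task_type": "dataflow",
                    "tenant_id": "tenant-1", "dataflow_id": "flow-1"}, fake_db.get))
```
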
@@ -460,13 +459,12 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
     return tk_count, vector_size


-async def run_dataflow(dsl:str, tenant_id:str, doc_id:str, task_id:str, flow_id:str, callback=None):
-    _ = callback
-    pipeline = Pipeline(dsl=dsl, tenant_id=tenant_id, doc_id=doc_id, task_id=task_id, flow_id=flow_id)
+async def run_dataflow(task: dict):
+    dataflow_id = task["dataflow_id"]
+    e, cvs = UserCanvasService.get_by_id(dataflow_id)
+    pipeline = Pipeline(cvs.dsl, tenant_id=task["tenant_id"], doc_id=task["doc_id"], task_id=task["id"], flow_id=dataflow_id)
     pipeline.reset()
-    await pipeline.run()
+    await pipeline.run(file=task.get("file"))


 @timeout(3600)
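
`run_dataflow` now takes the whole task dict, fetches the canvas DSL through `UserCanvasService.get_by_id(dataflow_id)` instead of reading it from the queue message, and passes the uploaded file into `pipeline.run`. The sketch below only illustrates which task fields are consumed; `CANVASES` and `StubPipeline` are stand-ins, not the real services.

```python
import asyncio

# Stand-ins for UserCanvasService and Pipeline, just to show which task fields are consumed.
CANVASES = {"flow-1": {"dsl": "{...canvas dsl...}"}}

class StubPipeline:
    def __init__(self, dsl, tenant_id, doc_id, task_id, flow_id):
        self.args = (dsl, tenant_id, doc_id, task_id, flow_id)

    def reset(self):
        pass

    async def run(self, file=None):
        print("running pipeline", self.args, "file:", file)

async def run_dataflow_sketch(task: dict) -> None:
    dataflow_id = task["dataflow_id"]
    canvas = CANVASES[dataflow_id]   # stands in for UserCanvasService.get_by_id(dataflow_id)
    pipeline = StubPipeline(canvas["dsl"], tenant_id=task["tenant_id"],
                            doc_id=task["doc_id"], task_id=task["id"], flow_id=dataflow_id)
    pipeline.reset()
    await pipeline.run(file=task.get("file"))

asyncio.run(run_dataflow_sketch({"id": "t2", "doc_id": "x", "tenant_id": "tenant-1",
                                 "dataflow_id": "flow-1", "file": {"name": "a.pdf"}}))
```
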
@@ -513,6 +511,12 @@ async def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None):

 @timeout(60*60*2, 1)
 async def do_handle_task(task):
+    task_type = task.get("task_type", "")
+
+    if task_type == "dataflow" and task.get("doc_id", "") == "x":
+        await run_dataflow(task)
+        return
+
     task_id = task["id"]
     task_from_page = task["from_page"]
     task_to_page = task["to_page"]
@@ -526,6 +530,7 @@ async def do_handle_task(task):
     task_parser_config = task["parser_config"]
     task_start_ts = timer()
+

     # prepare the progress callback function
     progress_callback = partial(set_progress, task_id, task_from_page, task_to_page)
@@ -554,13 +559,11 @@ async def do_handle_task(task):

     init_kb(task, vector_size)

-    task_type = task.get("task_type", "")
     if task_type == "dataflow":
-        task_dataflow_dsl = task["dsl"]
-        task_dataflow_id = task["dataflow_id"]
-        await run_dataflow(dsl=task_dataflow_dsl, tenant_id=task_tenant_id, doc_id=task_doc_id, task_id=task_id, flow_id=task_dataflow_id, callback=None)
+        await run_dataflow(task)
         return
-    elif task_type == "raptor":
+
+    if task_type == "raptor":
         # bind LLM for raptor
         chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
         # run RAPTOR