Feat: refine dataflow and initialize dataflow app (#9952)

### What problem does this PR solve?

Refine dataflow and initialize dataflow app.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
Yongteng Lei
2025-09-05 18:50:46 +08:00
committed by GitHub
parent 9aa8cfb73a
commit 45f52e85d7
21 changed files with 959 additions and 256 deletions

View File

@ -418,12 +418,10 @@ def setting():
return get_data_error_result(message="canvas not found.")
flow = flow.to_dict()
flow["title"] = req["title"]
if req["description"]:
flow["description"] = req["description"]
if req["permission"]:
flow["permission"] = req["permission"]
if req["avatar"]:
flow["avatar"] = req["avatar"]
for key in ["description", "permission", "avatar"]:
if value := req.get(key):
flow[key] = value
num= UserCanvasService.update_by_id(req["id"], flow)
return get_json_result(data=num)

353
api/apps/dataflow_app.py Normal file
View File

@ -0,0 +1,353 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import re
import sys
import time
from functools import partial
import trio
from flask import request
from flask_login import current_user, login_required
from agent.canvas import Canvas
from agent.component import LLM
from api.db import CanvasCategory, FileType
from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService
from api.db.services.document_service import DocumentService
from api.db.services.file_service import FileService
from api.db.services.task_service import queue_dataflow
from api.db.services.user_canvas_version import UserCanvasVersionService
from api.db.services.user_service import TenantService
from api.settings import RetCode
from api.utils import get_uuid
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request
from api.utils.file_utils import filename_type, read_potential_broken_pdf
from rag.flow.pipeline import Pipeline
@manager.route("/templates", methods=["GET"]) # noqa: F821
@login_required
def templates():
return get_json_result(data=[c.to_dict() for c in CanvasTemplateService.query(canvas_category=CanvasCategory.DataFlow)])
@manager.route("/list", methods=["GET"]) # noqa: F821
@login_required
def canvas_list():
return get_json_result(data=sorted([c.to_dict() for c in UserCanvasService.query(user_id=current_user.id, canvas_category=CanvasCategory.DataFlow)], key=lambda x: x["update_time"] * -1))
@manager.route("/rm", methods=["POST"]) # noqa: F821
@validate_request("canvas_ids")
@login_required
def rm():
for i in request.json["canvas_ids"]:
if not UserCanvasService.accessible(i, current_user.id):
return get_json_result(data=False, message="Only owner of canvas authorized for this operation.", code=RetCode.OPERATING_ERROR)
UserCanvasService.delete_by_id(i)
return get_json_result(data=True)
@manager.route("/set", methods=["POST"]) # noqa: F821
@validate_request("dsl", "title")
@login_required
def save():
req = request.json
if not isinstance(req["dsl"], str):
req["dsl"] = json.dumps(req["dsl"], ensure_ascii=False)
req["dsl"] = json.loads(req["dsl"])
req["canvas_category"] = CanvasCategory.DataFlow
if "id" not in req:
req["user_id"] = current_user.id
if UserCanvasService.query(user_id=current_user.id, title=req["title"].strip(), canvas_category=CanvasCategory.DataFlow):
return get_data_error_result(message=f"{req['title'].strip()} already exists.")
req["id"] = get_uuid()
if not UserCanvasService.save(**req):
return get_data_error_result(message="Fail to save canvas.")
else:
if not UserCanvasService.accessible(req["id"], current_user.id):
return get_json_result(data=False, message="Only owner of canvas authorized for this operation.", code=RetCode.OPERATING_ERROR)
UserCanvasService.update_by_id(req["id"], req)
# save version
UserCanvasVersionService.insert(user_canvas_id=req["id"], dsl=req["dsl"], title="{0}_{1}".format(req["title"], time.strftime("%Y_%m_%d_%H_%M_%S")))
UserCanvasVersionService.delete_all_versions(req["id"])
return get_json_result(data=req)
@manager.route("/get/<canvas_id>", methods=["GET"]) # noqa: F821
@login_required
def get(canvas_id):
if not UserCanvasService.accessible(canvas_id, current_user.id):
return get_data_error_result(message="canvas not found.")
e, c = UserCanvasService.get_by_tenant_id(canvas_id)
return get_json_result(data=c)
@manager.route("/run", methods=["POST"]) # noqa: F821
@validate_request("id")
@login_required
def run():
req = request.json
flow_id = req.get("id", "")
doc_id = req.get("doc_id", "")
if not all([flow_id, doc_id]):
return get_data_error_result(message="id and doc_id are required.")
if not DocumentService.get_by_id(doc_id):
return get_data_error_result(message=f"Document for {doc_id} not found.")
user_id = req.get("user_id", current_user.id)
if not UserCanvasService.accessible(flow_id, current_user.id):
return get_json_result(data=False, message="Only owner of canvas authorized for this operation.", code=RetCode.OPERATING_ERROR)
e, cvs = UserCanvasService.get_by_id(flow_id)
if not e:
return get_data_error_result(message="canvas not found.")
if not isinstance(cvs.dsl, str):
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
task_id = get_uuid()
ok, error_message = queue_dataflow(dsl=cvs.dsl, tenant_id=user_id, doc_id=doc_id, task_id=task_id, flow_id=flow_id, priority=0)
if not ok:
return server_error_response(error_message)
return get_json_result(data={"task_id": task_id, "flow_id": flow_id})
@manager.route("/reset", methods=["POST"]) # noqa: F821
@validate_request("id")
@login_required
def reset():
req = request.json
flow_id = req.get("id", "")
if not flow_id:
return get_data_error_result(message="id is required.")
if not UserCanvasService.accessible(flow_id, current_user.id):
return get_json_result(data=False, message="Only owner of canvas authorized for this operation.", code=RetCode.OPERATING_ERROR)
task_id = req.get("task_id", "")
try:
e, user_canvas = UserCanvasService.get_by_id(req["id"])
if not e:
return get_data_error_result(message="canvas not found.")
dataflow = Pipeline(dsl=json.dumps(user_canvas.dsl), tenant_id=current_user.id, flow_id=flow_id, task_id=task_id)
dataflow.reset()
req["dsl"] = json.loads(str(dataflow))
UserCanvasService.update_by_id(req["id"], {"dsl": req["dsl"]})
return get_json_result(data=req["dsl"])
except Exception as e:
return server_error_response(e)
@manager.route("/upload/<canvas_id>", methods=["POST"]) # noqa: F821
def upload(canvas_id):
e, cvs = UserCanvasService.get_by_tenant_id(canvas_id)
if not e:
return get_data_error_result(message="canvas not found.")
user_id = cvs["user_id"]
def structured(filename, filetype, blob, content_type):
nonlocal user_id
if filetype == FileType.PDF.value:
blob = read_potential_broken_pdf(blob)
location = get_uuid()
FileService.put_blob(user_id, location, blob)
return {
"id": location,
"name": filename,
"size": sys.getsizeof(blob),
"extension": filename.split(".")[-1].lower(),
"mime_type": content_type,
"created_by": user_id,
"created_at": time.time(),
"preview_url": None,
}
if request.args.get("url"):
from crawl4ai import AsyncWebCrawler, BrowserConfig, CrawlerRunConfig, CrawlResult, DefaultMarkdownGenerator, PruningContentFilter
try:
url = request.args.get("url")
filename = re.sub(r"\?.*", "", url.split("/")[-1])
async def adownload():
browser_config = BrowserConfig(
headless=True,
verbose=False,
)
async with AsyncWebCrawler(config=browser_config) as crawler:
crawler_config = CrawlerRunConfig(markdown_generator=DefaultMarkdownGenerator(content_filter=PruningContentFilter()), pdf=True, screenshot=False)
result: CrawlResult = await crawler.arun(url=url, config=crawler_config)
return result
page = trio.run(adownload())
if page.pdf:
if filename.split(".")[-1].lower() != "pdf":
filename += ".pdf"
return get_json_result(data=structured(filename, "pdf", page.pdf, page.response_headers["content-type"]))
return get_json_result(data=structured(filename, "html", str(page.markdown).encode("utf-8"), page.response_headers["content-type"], user_id))
except Exception as e:
return server_error_response(e)
file = request.files["file"]
try:
DocumentService.check_doc_health(user_id, file.filename)
return get_json_result(data=structured(file.filename, filename_type(file.filename), file.read(), file.content_type))
except Exception as e:
return server_error_response(e)
@manager.route("/input_form", methods=["GET"]) # noqa: F821
@login_required
def input_form():
flow_id = request.args.get("id")
cpn_id = request.args.get("component_id")
try:
e, user_canvas = UserCanvasService.get_by_id(flow_id)
if not e:
return get_data_error_result(message="canvas not found.")
if not UserCanvasService.query(user_id=current_user.id, id=flow_id):
return get_json_result(data=False, message="Only owner of canvas authorized for this operation.", code=RetCode.OPERATING_ERROR)
dataflow = Pipeline(dsl=json.dumps(user_canvas.dsl), tenant_id=current_user.id, flow_id=flow_id, task_id="")
return get_json_result(data=dataflow.get_component_input_form(cpn_id))
except Exception as e:
return server_error_response(e)
@manager.route("/debug", methods=["POST"]) # noqa: F821
@validate_request("id", "component_id", "params")
@login_required
def debug():
req = request.json
if not UserCanvasService.accessible(req["id"], current_user.id):
return get_json_result(data=False, message="Only owner of canvas authorized for this operation.", code=RetCode.OPERATING_ERROR)
try:
e, user_canvas = UserCanvasService.get_by_id(req["id"])
canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id)
canvas.reset()
canvas.message_id = get_uuid()
component = canvas.get_component(req["component_id"])["obj"]
component.reset()
if isinstance(component, LLM):
component.set_debug_inputs(req["params"])
component.invoke(**{k: o["value"] for k, o in req["params"].items()})
outputs = component.output()
for k in outputs.keys():
if isinstance(outputs[k], partial):
txt = ""
for c in outputs[k]():
txt += c
outputs[k] = txt
return get_json_result(data=outputs)
except Exception as e:
return server_error_response(e)
# api get list version dsl of canvas
@manager.route("/getlistversion/<canvas_id>", methods=["GET"]) # noqa: F821
@login_required
def getlistversion(canvas_id):
try:
list = sorted([c.to_dict() for c in UserCanvasVersionService.list_by_canvas_id(canvas_id)], key=lambda x: x["update_time"] * -1)
return get_json_result(data=list)
except Exception as e:
return get_data_error_result(message=f"Error getting history files: {e}")
# api get version dsl of canvas
@manager.route("/getversion/<version_id>", methods=["GET"]) # noqa: F821
@login_required
def getversion(version_id):
try:
e, version = UserCanvasVersionService.get_by_id(version_id)
if version:
return get_json_result(data=version.to_dict())
except Exception as e:
return get_json_result(data=f"Error getting history file: {e}")
@manager.route("/listteam", methods=["GET"]) # noqa: F821
@login_required
def list_canvas():
keywords = request.args.get("keywords", "")
page_number = int(request.args.get("page", 1))
items_per_page = int(request.args.get("page_size", 150))
orderby = request.args.get("orderby", "create_time")
desc = request.args.get("desc", True)
try:
tenants = TenantService.get_joined_tenants_by_user_id(current_user.id)
canvas, total = UserCanvasService.get_by_tenant_ids(
[m["tenant_id"] for m in tenants], current_user.id, page_number, items_per_page, orderby, desc, keywords, canvas_category=CanvasCategory.DataFlow
)
return get_json_result(data={"canvas": canvas, "total": total})
except Exception as e:
return server_error_response(e)
@manager.route("/setting", methods=["POST"]) # noqa: F821
@validate_request("id", "title", "permission")
@login_required
def setting():
req = request.json
req["user_id"] = current_user.id
if not UserCanvasService.accessible(req["id"], current_user.id):
return get_json_result(data=False, message="Only owner of canvas authorized for this operation.", code=RetCode.OPERATING_ERROR)
e, flow = UserCanvasService.get_by_id(req["id"])
if not e:
return get_data_error_result(message="canvas not found.")
flow = flow.to_dict()
flow["title"] = req["title"]
for key in ("description", "permission", "avatar"):
if value := req.get(key):
flow[key] = value
num = UserCanvasService.update_by_id(req["id"], flow)
return get_json_result(data=num)
@manager.route("/trace", methods=["GET"]) # noqa: F821
def trace():
dataflow_id = request.args.get("dataflow_id")
task_id = request.args.get("task_id")
if not all([dataflow_id, task_id]):
return get_data_error_result(message="dataflow_id and task_id are required.")
e, dataflow_canvas = UserCanvasService.get_by_id(dataflow_id)
if not e:
return get_data_error_result(message="dataflow not found.")
dsl_str = json.dumps(dataflow_canvas.dsl, ensure_ascii=False)
dataflow = Pipeline(dsl=dsl_str, tenant_id=dataflow_canvas.user_id, flow_id=dataflow_id, task_id=task_id)
log = dataflow.fetch_logs()
return get_json_result(data=log)

View File

@ -54,15 +54,15 @@ def trim_header_by_lines(text: str, max_length) -> str:
class TaskService(CommonService):
"""Service class for managing document processing tasks.
This class extends CommonService to provide specialized functionality for document
processing task management, including task creation, progress tracking, and chunk
management. It handles various document types (PDF, Excel, etc.) and manages their
processing lifecycle.
The class implements a robust task queue system with retry mechanisms and progress
tracking, supporting both synchronous and asynchronous task execution.
Attributes:
model: The Task model class for database operations.
"""
@ -72,14 +72,14 @@ class TaskService(CommonService):
@DB.connection_context()
def get_task(cls, task_id):
"""Retrieve detailed task information by task ID.
This method fetches comprehensive task details including associated document,
knowledge base, and tenant information. It also handles task retry logic and
progress updates.
Args:
task_id (str): The unique identifier of the task to retrieve.
Returns:
dict: Task details dictionary containing all task information and related metadata.
Returns None if task is not found or has exceeded retry limit.
@ -139,13 +139,13 @@ class TaskService(CommonService):
@DB.connection_context()
def get_tasks(cls, doc_id: str):
"""Retrieve all tasks associated with a document.
This method fetches all processing tasks for a given document, ordered by page
number and creation time. It includes task progress and chunk information.
Args:
doc_id (str): The unique identifier of the document.
Returns:
list[dict]: List of task dictionaries containing task details.
Returns None if no tasks are found.
@ -170,10 +170,10 @@ class TaskService(CommonService):
@DB.connection_context()
def update_chunk_ids(cls, id: str, chunk_ids: str):
"""Update the chunk IDs associated with a task.
This method updates the chunk_ids field of a task, which stores the IDs of
processed document chunks in a space-separated string format.
Args:
id (str): The unique identifier of the task.
chunk_ids (str): Space-separated string of chunk identifiers.
@ -184,11 +184,11 @@ class TaskService(CommonService):
@DB.connection_context()
def get_ongoing_doc_name(cls):
"""Get names of documents that are currently being processed.
This method retrieves information about documents that are in the processing state,
including their locations and associated IDs. It uses database locking to ensure
thread safety when accessing the task information.
Returns:
list[tuple]: A list of tuples, each containing (parent_id/kb_id, location)
for documents currently being processed. Returns empty list if
@ -238,14 +238,14 @@ class TaskService(CommonService):
@DB.connection_context()
def do_cancel(cls, id):
"""Check if a task should be cancelled based on its document status.
This method determines whether a task should be cancelled by checking the
associated document's run status and progress. A task should be cancelled
if its document is marked for cancellation or has negative progress.
Args:
id (str): The unique identifier of the task to check.
Returns:
bool: True if the task should be cancelled, False otherwise.
"""
@ -311,18 +311,18 @@ class TaskService(CommonService):
def queue_tasks(doc: dict, bucket: str, name: str, priority: int):
"""Create and queue document processing tasks.
This function creates processing tasks for a document based on its type and configuration.
It handles different document types (PDF, Excel, etc.) differently and manages task
chunking and configuration. It also implements task reuse optimization by checking
for previously completed tasks.
Args:
doc (dict): Document dictionary containing metadata and configuration.
bucket (str): Storage bucket name where the document is stored.
name (str): File name of the document.
priority (int, optional): Priority level for task queueing (default is 0).
Note:
- For PDF documents, tasks are created per page range based on configuration
- For Excel documents, tasks are created per row range
@ -410,19 +410,19 @@ def queue_tasks(doc: dict, bucket: str, name: str, priority: int):
def reuse_prev_task_chunks(task: dict, prev_tasks: list[dict], chunking_config: dict):
"""Attempt to reuse chunks from previous tasks for optimization.
This function checks if chunks from previously completed tasks can be reused for
the current task, which can significantly improve processing efficiency. It matches
tasks based on page ranges and configuration digests.
Args:
task (dict): Current task dictionary to potentially reuse chunks for.
prev_tasks (list[dict]): List of previous task dictionaries to check for reuse.
chunking_config (dict): Configuration dictionary for chunk processing.
Returns:
int: Number of chunks successfully reused. Returns 0 if no chunks could be reused.
Note:
Chunks can only be reused if:
- A previous task exists with matching page range and configuration digest
@ -470,3 +470,39 @@ def has_canceled(task_id):
except Exception as e:
logging.exception(e)
return False
def queue_dataflow(dsl:str, tenant_id:str, doc_id:str, task_id:str, flow_id:str, priority: int, callback=None) -> tuple[bool, str]:
"""
Returns a tuple (success: bool, error_message: str).
"""
_ = callback
task = dict(
id=get_uuid() if not task_id else task_id,
doc_id=doc_id,
from_page=0,
to_page=100000000,
task_type="dataflow",
priority=priority,
)
TaskService.model.delete().where(TaskService.model.id == task["id"]).execute()
bulk_insert_into_db(model=Task, data_source=[task], replace_on_conflict=True)
kb_id = DocumentService.get_knowledgebase_id(doc_id)
if not kb_id:
return False, f"Can't find KB of this document: {doc_id}"
task["kb_id"] = kb_id
task["tenant_id"] = tenant_id
task["task_type"] = "dataflow"
task["dsl"] = dsl
task["dataflow_id"] = get_uuid() if not flow_id else flow_id
if not REDIS_CONN.queue_product(
get_svr_queue_name(priority), message=task
):
return False, "Can't access Redis. Please check the Redis' status."
return True, ""

View File

@ -2690,21 +2690,21 @@
"status": "1",
"llm": [
{
"llm_name": "Qwen3-Embedding-8B",
"llm_name": "Qwen/Qwen3-Embedding-8B",
"tags": "TEXT EMBEDDING,TEXT RE-RANK,32k",
"max_tokens": 32000,
"model_type": "embedding",
"is_tools": false
},
{
"llm_name": "Qwen3-Embedding-4B",
"llm_name": "Qwen/Qwen3-Embedding-4B",
"tags": "TEXT EMBEDDING,TEXT RE-RANK,32k",
"max_tokens": 32000,
"model_type": "embedding",
"is_tools": false
},
{
"llm_name": "Qwen3-Embedding-0.6B",
"llm_name": "Qwen/Qwen3-Embedding-0.6B",
"tags": "TEXT EMBEDDING,TEXT RE-RANK,32k",
"max_tokens": 32000,
"model_type": "embedding",

View File

@ -14,36 +14,45 @@
# limitations under the License.
#
import os
import importlib
import inspect
import pkgutil
from pathlib import Path
from types import ModuleType
from typing import Dict, Type
_package_path = os.path.dirname(__file__)
__all_classes: Dict[str, Type] = {}
def _import_submodules() -> None:
for filename in os.listdir(_package_path): # noqa: F821
if filename.startswith("__") or not filename.endswith(".py") or filename.startswith("base"):
continue
module_name = filename[:-3]
_pkg_dir = Path(__file__).resolve().parent
_pkg_name = __name__
def _should_skip_module(mod_name: str) -> bool:
leaf = mod_name.rsplit(".", 1)[-1]
return leaf in {"__init__"} or leaf.startswith("__") or leaf.startswith("_") or leaf.startswith("base")
def _import_submodules() -> None:
for modinfo in pkgutil.walk_packages([str(_pkg_dir)], prefix=_pkg_name + "."): # noqa: F821
mod_name = modinfo.name
if _should_skip_module(mod_name): # noqa: F821
continue
try:
module = importlib.import_module(f".{module_name}", package=__name__)
module = importlib.import_module(mod_name)
_extract_classes_from_module(module) # noqa: F821
except ImportError as e:
print(f"Warning: Failed to import module {module_name}: {str(e)}")
print(f"Warning: Failed to import module {mod_name}: {e}")
def _extract_classes_from_module(module: ModuleType) -> None:
for name, obj in inspect.getmembers(module):
if (inspect.isclass(obj) and
obj.__module__ == module.__name__ and not name.startswith("_")):
if inspect.isclass(obj) and obj.__module__ == module.__name__ and not name.startswith("_"):
__all_classes[name] = obj
globals()[name] = obj
_import_submodules()
__all__ = list(__all_classes.keys()) + ["__all_classes"]
del _package_path, _import_submodules, _extract_classes_from_module
del _pkg_dir, _pkg_name, _import_submodules, _extract_classes_from_module

View File

@ -1,5 +1,5 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -13,13 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import time
import os
import logging
import os
import time
from functools import partial
from typing import Any
import trio
from agent.component.base import ComponentParamBase, ComponentBase
from agent.component.base import ComponentBase, ComponentParamBase
from api.utils.api_utils import timeout
@ -31,14 +33,16 @@ class ProcessParamBase(ComponentParamBase):
class ProcessBase(ComponentBase):
def __init__(self, pipeline, id, param: ProcessParamBase):
super().__init__(pipeline, id, param)
self.callback = partial(self._canvas.callback, self.component_name)
if hasattr(self._canvas, "callback"):
self.callback = partial(self._canvas.callback, self.component_name)
else:
self.callback = partial(lambda *args, **kwargs: None, self.component_name)
async def invoke(self, **kwargs) -> dict[str, Any]:
self.set_output("_created_time", time.perf_counter())
for k,v in kwargs.items():
for k, v in kwargs.items():
self.set_output(k, v)
try:
with trio.fail_after(self._param.timeout):
@ -54,6 +58,6 @@ class ProcessBase(ComponentBase):
self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
return self.output()
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10 * 60))
async def _invoke(self, **kwargs):
raise NotImplementedError()

View File

@ -0,0 +1,15 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -1,5 +1,5 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -13,12 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import trio
from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from deepdoc.parser.pdf_parser import RAGFlowPdfParser
from graphrag.utils import get_llm_cache, chat_limiter, set_llm_cache
from graphrag.utils import chat_limiter, get_llm_cache, set_llm_cache
from rag.flow.base import ProcessBase, ProcessParamBase
from rag.flow.chunker.schema import ChunkerFromUpstream
from rag.nlp import naive_merge, naive_merge_with_images
from rag.prompts.prompts import keyword_extraction, question_proposal
@ -26,7 +29,23 @@ from rag.prompts.prompts import keyword_extraction, question_proposal
class ChunkerParam(ProcessParamBase):
def __init__(self):
super().__init__()
self.method_options = ["general", "q&a", "resume", "manual", "table", "paper", "book", "laws", "presentation", "one"]
self.method_options = [
# General
"general",
"onetable",
# Customer Service
"q&a",
"manual",
# Recruitment
"resume",
# Education & Research
"book",
"paper",
"laws",
"presentation",
# Other
# "Tag" # TODO: Other method
]
self.method = "general"
self.chunk_token_size = 512
self.delimiter = "\n"
@ -35,10 +54,7 @@ class ChunkerParam(ProcessParamBase):
self.auto_keywords = 0
self.auto_questions = 0
self.tag_sets = []
self.llm_setting = {
"llm_name": "",
"lang": "Chinese"
}
self.llm_setting = {"llm_name": "", "lang": "Chinese"}
def check(self):
self.check_valid_value(self.method.lower(), "Chunk method abnormal.", self.method_options)
@ -48,53 +64,79 @@ class ChunkerParam(ProcessParamBase):
self.check_nonnegative_number(self.auto_questions, "Auto-question value: (0, 10]")
self.check_decimal_float(self.overlapped_percent, "Overlapped percentage: [0, 1)")
def get_input_form(self) -> dict[str, dict]:
return {}
class Chunker(ProcessBase):
component_name = "Chunker"
def _general(self, **kwargs):
self.callback(random.randint(1,5)/100., "Start to chunk via `General`.")
if kwargs.get("output_format") in ["markdown", "text"]:
cks = naive_merge(kwargs.get(kwargs["output_format"]), self._param.chunk_token_size, self._param.delimiter, self._param.overlapped_percent)
def _general(self, from_upstream: ChunkerFromUpstream):
self.callback(random.randint(1, 5) / 100.0, "Start to chunk via `General`.")
if from_upstream.output_format in ["markdown", "text"]:
if from_upstream.output_format == "markdown":
payload = from_upstream.markdown_result
else: # == "text"
payload = from_upstream.text_result
if not payload:
payload = ""
cks = naive_merge(
payload,
self._param.chunk_token_size,
self._param.delimiter,
self._param.overlapped_percent,
)
return [{"text": c} for c in cks]
sections, section_images = [], []
for o in kwargs["json"]:
sections.append((o["text"], o.get("position_tag","")))
for o in from_upstream.json_result or []:
sections.append((o.get("text", ""), o.get("position_tag", "")))
section_images.append(o.get("image"))
chunks, images = naive_merge_with_images(sections, section_images,self._param.chunk_token_size, self._param.delimiter, self._param.overlapped_percent)
return [{
"text": RAGFlowPdfParser.remove_tag(c),
"image": img,
"positions": RAGFlowPdfParser.extract_positions(c)
} for c,img in zip(chunks,images)]
chunks, images = naive_merge_with_images(
sections,
section_images,
self._param.chunk_token_size,
self._param.delimiter,
self._param.overlapped_percent,
)
def _q_and_a(self, **kwargs):
return [
{
"text": RAGFlowPdfParser.remove_tag(c),
"image": img,
"positions": RAGFlowPdfParser.extract_positions(c),
}
for c, img in zip(chunks, images)
]
def _q_and_a(self, from_upstream: ChunkerFromUpstream):
pass
def _resume(self, **kwargs):
def _resume(self, from_upstream: ChunkerFromUpstream):
pass
def _manual(self, **kwargs):
def _manual(self, from_upstream: ChunkerFromUpstream):
pass
def _table(self, **kwargs):
def _table(self, from_upstream: ChunkerFromUpstream):
pass
def _paper(self, **kwargs):
def _paper(self, from_upstream: ChunkerFromUpstream):
pass
def _book(self, **kwargs):
def _book(self, from_upstream: ChunkerFromUpstream):
pass
def _laws(self, **kwargs):
def _laws(self, from_upstream: ChunkerFromUpstream):
pass
def _presentation(self, **kwargs):
def _presentation(self, from_upstream: ChunkerFromUpstream):
pass
def _one(self, **kwargs):
def _one(self, from_upstream: ChunkerFromUpstream):
pass
async def _invoke(self, **kwargs):
@ -110,7 +152,14 @@ class Chunker(ProcessBase):
"presentation": self._presentation,
"one": self._one,
}
chunks = function_map[self._param.method](**kwargs)
try:
from_upstream = ChunkerFromUpstream.model_validate(kwargs)
except Exception as e:
self.set_output("_ERROR", f"Input error: {str(e)}")
return
chunks = function_map[self._param.method](from_upstream)
llm_setting = self._param.llm_setting
async def auto_keywords():

View File

@ -0,0 +1,37 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Literal
from pydantic import BaseModel, ConfigDict, Field
class ChunkerFromUpstream(BaseModel):
created_time: float | None = Field(default=None, alias="_created_time")
elapsed_time: float | None = Field(default=None, alias="_elapsed_time")
name: str
blob: bytes
output_format: Literal["json", "markdown", "text", "html"] | None = Field(default=None)
json_result: list[dict[str, Any]] | None = Field(default=None, alias="json")
markdown_result: str | None = Field(default=None, alias="markdown")
text_result: str | None = Field(default=None, alias="text")
html_result: str | None = Field(default=None, alias="html")
model_config = ConfigDict(populate_by_name=True, extra="forbid")
# def to_dict(self, *, exclude_none: bool = True) -> dict:
# return self.model_dump(by_alias=True, exclude_none=exclude_none)

View File

@ -1,5 +1,5 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -27,6 +27,9 @@ class FileParam(ProcessParamBase):
def check(self):
pass
def get_input_form(self) -> dict[str, dict]:
return {}
class File(ProcessBase):
component_name = "File"

View File

@ -1,107 +0,0 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import trio
from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from deepdoc.parser.pdf_parser import RAGFlowPdfParser, PlainParser, VisionParser
from rag.flow.base import ProcessBase, ProcessParamBase
from rag.llm.cv_model import Base as VLM
from deepdoc.parser import ExcelParser
class ParserParam(ProcessParamBase):
def __init__(self):
super().__init__()
self.setups = {
"pdf": {
"parse_method": "deepdoc", # deepdoc/plain_text/vlm
"vlm_name": "",
"lang": "Chinese",
"suffix": ["pdf"],
"output_format": "json"
},
"excel": {
"output_format": "html"
},
"ppt": {},
"image": {
"parse_method": "ocr"
},
"email": {},
"text": {},
"audio": {},
"video": {},
}
def check(self):
if self.setups["pdf"].get("parse_method") not in ["deepdoc", "plain_text"]:
assert self.setups["pdf"].get("vlm_name"), "No VLM specified."
assert self.setups["pdf"].get("lang"), "No language specified."
class Parser(ProcessBase):
component_name = "Parser"
def _pdf(self, blob):
self.callback(random.randint(1,5)/100., "Start to work on a PDF.")
conf = self._param.setups["pdf"]
self.set_output("output_format", conf["output_format"])
if conf.get("parse_method") == "deepdoc":
bboxes = RAGFlowPdfParser().parse_into_bboxes(blob, callback=self.callback)
elif conf.get("parse_method") == "plain_text":
lines,_ = PlainParser()(blob)
bboxes = [{"text": t} for t,_ in lines]
else:
assert conf.get("vlm_name")
vision_model = LLMBundle(self._canvas.tenant_id, LLMType.IMAGE2TEXT, llm_name=conf.get("vlm_name"), lang=self.setups["pdf"].get("lang"))
lines, _ = VisionParser(vision_model=vision_model)(bin, callback=self.callback)
bboxes = []
for t, poss in lines:
pn, x0, x1, top, bott = poss.split(" ")
bboxes.append({"page_number": int(pn), "x0": int(x0), "x1": int(x1), "top": int(top), "bottom": int(bott), "text": t})
self.set_output("json", bboxes)
mkdn = ""
for b in bboxes:
if b.get("layout_type", "") == "title":
mkdn += "\n## "
if b.get("layout_type", "") == "figure":
mkdn += "\n![Image]({})".format(VLM.image2base64(b["image"]))
continue
mkdn += b.get("text", "") + "\n"
self.set_output("markdown", mkdn)
def _excel(self, blob):
self.callback(random.randint(1,5)/100., "Start to work on a Excel.")
conf = self._param.setups["excel"]
excel_parser = ExcelParser()
if conf.get("output_format") == "html":
html = excel_parser.html(blob,1000000000)
self.set_output("html", html)
elif conf.get("output_format") == "json":
self.set_output("json", [{"text": txt} for txt in excel_parser(blob) if txt])
elif conf.get("output_format") == "markdown":
self.set_output("markdown", excel_parser.markdown(blob))
async def _invoke(self, **kwargs):
function_map = {
"pdf": self._pdf,
}
for p_type, conf in self._param.setups.items():
if kwargs.get("name", "").split(".")[-1].lower() not in conf.get("suffix", []):
continue
await trio.to_thread.run_sync(function_map[p_type], kwargs["blob"])
break

View File

@ -0,0 +1,14 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

154
rag/flow/parser/parser.py Normal file
View File

@ -0,0 +1,154 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import trio
from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from deepdoc.parser import ExcelParser
from deepdoc.parser.pdf_parser import PlainParser, RAGFlowPdfParser, VisionParser
from rag.flow.base import ProcessBase, ProcessParamBase
from rag.flow.parser.schema import ParserFromUpstream
from rag.llm.cv_model import Base as VLM
class ParserParam(ProcessParamBase):
def __init__(self):
super().__init__()
self.allowed_output_format = {
"pdf": ["json", "markdown"],
"excel": ["json", "markdown", "html"],
"ppt": [],
"image": [],
"email": [],
"text": [],
"audio": [],
"video": [],
}
self.setups = {
"pdf": {
"parse_method": "deepdoc", # deepdoc/plain_text/vlm
"vlm_name": "",
"lang": "Chinese",
"suffix": ["pdf"],
"output_format": "json",
},
"excel": {
"output_format": "html",
"suffix": ["xls", "xlsx", "csv"],
},
"ppt": {},
"image": {
"parse_method": "ocr",
},
"email": {},
"text": {},
"audio": {},
"video": {},
}
def check(self):
pdf_config = self.setups.get("pdf", {})
if pdf_config:
pdf_parse_method = pdf_config.get("parse_method", "")
self.check_valid_value(pdf_parse_method.lower(), "Parse method abnormal.", ["deepdoc", "plain_text", "vlm"])
if pdf_parse_method not in ["deepdoc", "plain_text"]:
self.check_empty(pdf_config.get("vlm_name"), "VLM")
pdf_language = pdf_config.get("lang", "")
self.check_empty(pdf_language, "Language")
pdf_output_format = pdf_config.get("output_format", "")
self.check_valid_value(pdf_output_format, "PDF output format abnormal.", self.allowed_output_format["pdf"])
excel_config = self.setups.get("excel", "")
if excel_config:
excel_output_format = excel_config.get("output_format", "")
self.check_valid_value(excel_output_format, "Excel output format abnormal.", self.allowed_output_format["excel"])
image_config = self.setups.get("image", "")
if image_config:
image_parse_method = image_config.get("parse_method", "")
self.check_valid_value(image_parse_method.lower(), "Parse method abnormal.", ["ocr"])
def get_input_form(self) -> dict[str, dict]:
return {}
class Parser(ProcessBase):
component_name = "Parser"
def _pdf(self, blob):
self.callback(random.randint(1, 5) / 100.0, "Start to work on a PDF.")
conf = self._param.setups["pdf"]
self.set_output("output_format", conf["output_format"])
if conf.get("parse_method") == "deepdoc":
bboxes = RAGFlowPdfParser().parse_into_bboxes(blob, callback=self.callback)
elif conf.get("parse_method") == "plain_text":
lines, _ = PlainParser()(blob)
bboxes = [{"text": t} for t, _ in lines]
else:
assert conf.get("vlm_name")
vision_model = LLMBundle(self._canvas._tenant_id, LLMType.IMAGE2TEXT, llm_name=conf.get("vlm_name"), lang=self._param.setups["pdf"].get("lang"))
lines, _ = VisionParser(vision_model=vision_model)(blob, callback=self.callback)
bboxes = []
for t, poss in lines:
pn, x0, x1, top, bott = poss.split(" ")
bboxes.append({"page_number": int(pn), "x0": float(x0), "x1": float(x1), "top": float(top), "bottom": float(bott), "text": t})
if conf.get("output_format") == "json":
self.set_output("json", bboxes)
if conf.get("output_format") == "markdown":
mkdn = ""
for b in bboxes:
if b.get("layout_type", "") == "title":
mkdn += "\n## "
if b.get("layout_type", "") == "figure":
mkdn += "\n![Image]({})".format(VLM.image2base64(b["image"]))
continue
mkdn += b.get("text", "") + "\n"
self.set_output("markdown", mkdn)
def _excel(self, blob):
self.callback(random.randint(1, 5) / 100.0, "Start to work on a Excel.")
conf = self._param.setups["excel"]
self.set_output("output_format", conf["output_format"])
excel_parser = ExcelParser()
if conf.get("output_format") == "html":
html = excel_parser.html(blob, 1000000000)
self.set_output("html", html)
elif conf.get("output_format") == "json":
self.set_output("json", [{"text": txt} for txt in excel_parser(blob) if txt])
elif conf.get("output_format") == "markdown":
self.set_output("markdown", excel_parser.markdown(blob))
async def _invoke(self, **kwargs):
function_map = {
"pdf": self._pdf,
"excel": self._excel,
}
try:
from_upstream = ParserFromUpstream.model_validate(kwargs)
except Exception as e:
self.set_output("_ERROR", f"Input error: {str(e)}")
return
for p_type, conf in self._param.setups.items():
if from_upstream.name.split(".")[-1].lower() not in conf.get("suffix", []):
continue
await trio.to_thread.run_sync(function_map[p_type], from_upstream.blob)
break

25
rag/flow/parser/schema.py Normal file
View File

@ -0,0 +1,25 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pydantic import BaseModel, ConfigDict, Field
class ParserFromUpstream(BaseModel):
created_time: float | None = Field(default=None, alias="_created_time")
elapsed_time: float | None = Field(default=None, alias="_elapsed_time")
name: str
blob: bytes
model_config = ConfigDict(populate_by_name=True, extra="forbid")

View File

@ -1,5 +1,5 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -18,14 +18,15 @@ import json
import logging
import random
import time
import trio
from agent.canvas import Graph
from api.db.services.document_service import DocumentService
from rag.utils.redis_conn import REDIS_CONN
class Pipeline(Graph):
def __init__(self, dsl: str, tenant_id=None, doc_id=None, task_id=None, flow_id=None):
super().__init__(dsl, tenant_id, task_id)
self._doc_id = doc_id
@ -35,7 +36,7 @@ class Pipeline(Graph):
self._kb_id = DocumentService.get_knowledgebase_id(doc_id)
assert self._kb_id, f"Can't find KB of this document: {doc_id}"
def callback(self, component_name: str, progress: float|int|None=None, message: str = "") -> None:
def callback(self, component_name: str, progress: float | int | None = None, message: str = "") -> None:
log_key = f"{self._flow_id}-{self.task_id}-logs"
try:
bin = REDIS_CONN.get(log_key)
@ -44,16 +45,10 @@ class Pipeline(Graph):
if obj[-1]["component_name"] == component_name:
obj[-1]["trace"].append({"progress": progress, "message": message, "datetime": datetime.datetime.now().strftime("%H:%M:%S")})
else:
obj.append({
"component_name": component_name,
"trace": [{"progress": progress, "message": message, "datetime": datetime.datetime.now().strftime("%H:%M:%S")}]
})
obj.append({"component_name": component_name, "trace": [{"progress": progress, "message": message, "datetime": datetime.datetime.now().strftime("%H:%M:%S")}]})
else:
obj = [{
"component_name": component_name,
"trace": [{"progress": progress, "message": message, "datetime": datetime.datetime.now().strftime("%H:%M:%S")}]
}]
REDIS_CONN.set_obj(log_key, obj, 60*10)
obj = [{"component_name": component_name, "trace": [{"progress": progress, "message": message, "datetime": datetime.datetime.now().strftime("%H:%M:%S")}]}]
REDIS_CONN.set_obj(log_key, obj, 60 * 10)
except Exception as e:
logging.exception(e)
@ -71,21 +66,19 @@ class Pipeline(Graph):
super().reset()
log_key = f"{self._flow_id}-{self.task_id}-logs"
try:
REDIS_CONN.set_obj(log_key, [], 60*10)
REDIS_CONN.set_obj(log_key, [], 60 * 10)
except Exception as e:
logging.exception(e)
async def run(self, **kwargs):
st = time.perf_counter()
if not self.path:
self.path.append("begin")
self.path.append("File")
if self._doc_id:
DocumentService.update_by_id(self._doc_id, {
"progress": random.randint(0,5)/100.,
"progress_msg": "Start the pipeline...",
"process_begin_at": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
})
DocumentService.update_by_id(
self._doc_id, {"progress": random.randint(0, 5) / 100.0, "progress_msg": "Start the pipeline...", "process_begin_at": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}
)
self.error = ""
idx = len(self.path) - 1
@ -99,23 +92,21 @@ class Pipeline(Graph):
self.path.extend(cpn_obj.get_downstream())
while idx < len(self.path) and not self.error:
last_cpn = self.get_component_obj(self.path[idx-1])
last_cpn = self.get_component_obj(self.path[idx - 1])
cpn_obj = self.get_component_obj(self.path[idx])
async def invoke():
nonlocal last_cpn, cpn_obj
await cpn_obj.invoke(**last_cpn.output())
async with trio.open_nursery() as nursery:
nursery.start_soon(invoke)
if cpn_obj.error():
self.error = "[ERROR]" + cpn_obj.error()
self.callback(cpn_obj.component_name, -1, self.error)
break
idx += 1
self.path.extend(cpn_obj.get_downstream())
if self._doc_id:
DocumentService.update_by_id(self._doc_id, {
"progress": 1 if not self.error else -1,
"progress_msg": "Pipeline finished...\n" + self.error,
"process_duration": time.perf_counter() - st
})
DocumentService.update_by_id(self._doc_id, {"progress": 1 if not self.error else -1, "progress_msg": "Pipeline finished...\n" + self.error, "process_duration": time.perf_counter() - st})

View File

@ -1,5 +1,5 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -18,12 +18,14 @@ import json
import os
import time
from concurrent.futures import ThreadPoolExecutor
import trio
from api import settings
from rag.flow.pipeline import Pipeline
def print_logs(pipeline):
def print_logs(pipeline: Pipeline):
last_logs = "[]"
while True:
time.sleep(5)
@ -34,16 +36,16 @@ def print_logs(pipeline):
last_logs = logs_str
if __name__ == '__main__':
if __name__ == "__main__":
parser = argparse.ArgumentParser()
dsl_default_path = os.path.join(
os.path.dirname(os.path.realpath(__file__)),
"dsl_examples",
"general_pdf_all.json",
)
parser.add_argument('-s', '--dsl', default=dsl_default_path, help="input dsl", action='store', required=True)
parser.add_argument('-d', '--doc_id', default=False, help="Document ID", action='store', required=True)
parser.add_argument('-t', '--tenant_id', default=False, help="Tenant ID", action='store', required=True)
parser.add_argument("-s", "--dsl", default=dsl_default_path, help="input dsl", action="store", required=False)
parser.add_argument("-d", "--doc_id", default=False, help="Document ID", action="store", required=True)
parser.add_argument("-t", "--tenant_id", default=False, help="Tenant ID", action="store", required=True)
args = parser.parse_args()
settings.init_settings()
@ -53,5 +55,7 @@ if __name__ == '__main__':
exe = ThreadPoolExecutor(max_workers=5)
thr = exe.submit(print_logs, pipeline)
# queue_dataflow(dsl=open(args.dsl, "r").read(), tenant_id=args.tenant_id, doc_id=args.doc_id, task_id="xxxx", flow_id="xxx", priority=0)
trio.run(pipeline.run)
thr.result()
thr.result()

View File

@ -1,15 +1,15 @@
{
"components": {
"begin": {
"File": {
"obj":{
"component_name": "File",
"params": {
}
},
"downstream": ["parser:0"],
"downstream": ["Parser:0"],
"upstream": []
},
"parser:0": {
"Parser:0": {
"obj": {
"component_name": "Parser",
"params": {
@ -22,14 +22,22 @@
"pdf"
],
"output_format": "json"
},
"excel": {
"output_format": "html",
"suffix": [
"xls",
"xlsx",
"csv"
]
}
}
}
},
"downstream": ["chunker:0"],
"upstream": ["begin"]
"downstream": ["Chunker:0"],
"upstream": ["Begin"]
},
"chunker:0": {
"Chunker:0": {
"obj": {
"component_name": "Chunker",
"params": {
@ -37,18 +45,19 @@
"auto_keywords": 5
}
},
"downstream": ["tokenizer:0"],
"upstream": ["chunker:0"]
"downstream": ["Tokenizer:0"],
"upstream": ["Parser:0"]
},
"tokenizer:0": {
"Tokenizer:0": {
"obj": {
"component_name": "Tokenizer",
"params": {
}
},
"downstream": [],
"upstream": ["chunker:0"]
"upstream": ["Chunker:0"]
}
},
"path": []
}
}

View File

@ -0,0 +1,14 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -0,0 +1,51 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Literal
from pydantic import BaseModel, ConfigDict, Field, model_validator
class TokenizerFromUpstream(BaseModel):
created_time: float | None = Field(default=None, alias="_created_time")
elapsed_time: float | None = Field(default=None, alias="_elapsed_time")
name: str = ""
blob: bytes
output_format: Literal["json", "markdown", "text", "html"] | None = Field(default=None)
chunks: list[dict[str, Any]] | None = Field(default=None)
json_result: list[dict[str, Any]] | None = Field(default=None, alias="json")
markdown_result: str | None = Field(default=None, alias="markdown")
text_result: str | None = Field(default=None, alias="text")
html_result: str | None = Field(default=None, alias="html")
model_config = ConfigDict(populate_by_name=True, extra="forbid")
@model_validator(mode="after")
def _check_payloads(self) -> "TokenizerFromUpstream":
if self.chunks:
return self
if self.output_format in {"markdown", "text"}:
if self.output_format == "markdown" and not self.markdown_result:
raise ValueError("output_format=markdown requires a markdown payload (field: 'markdown' or 'markdown_result').")
if self.output_format == "text" and not self.text_result:
raise ValueError("output_format=text requires a text payload (field: 'text' or 'text_result').")
else:
if not self.json_result:
raise ValueError("When no chunks are provided and output_format is not markdown/text, a JSON list payload is required (field: 'json' or 'json_result').")
return self

View File

@ -1,5 +1,5 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import random
import re
@ -24,6 +25,7 @@ from api.db.services.llm_service import LLMBundle
from api.db.services.user_service import TenantService
from api.utils.api_utils import timeout
from rag.flow.base import ProcessBase, ProcessParamBase
from rag.flow.tokenizer.schema import TokenizerFromUpstream
from rag.nlp import rag_tokenizer
from rag.settings import EMBEDDING_BATCH_SIZE
from rag.svr.task_executor import embed_limiter
@ -40,6 +42,9 @@ class TokenizerParam(ProcessParamBase):
for v in self.search_method:
self.check_valid_value(v.lower(), "Chunk method abnormal.", ["full_text", "embedding"])
def get_input_form(self) -> dict[str, dict]:
return {}
class Tokenizer(ProcessBase):
component_name = "Tokenizer"
@ -67,19 +72,19 @@ class Tokenizer(ProcessBase):
@timeout(60)
def batch_encode(txts):
nonlocal embedding_model
return embedding_model.encode([truncate(c, embedding_model.max_length-10) for c in txts])
return embedding_model.encode([truncate(c, embedding_model.max_length - 10) for c in txts])
cnts_ = np.array([])
for i in range(0, len(texts), EMBEDDING_BATCH_SIZE):
async with embed_limiter:
vts, c = await trio.to_thread.run_sync(lambda: batch_encode(texts[i: i + EMBEDDING_BATCH_SIZE]))
vts, c = await trio.to_thread.run_sync(lambda: batch_encode(texts[i : i + EMBEDDING_BATCH_SIZE]))
if len(cnts_) == 0:
cnts_ = vts
else:
cnts_ = np.concatenate((cnts_, vts), axis=0)
token_count += c
if i % 33 == 32:
self.callback(i*1./len(texts)/parts/EMBEDDING_BATCH_SIZE + 0.5*(parts-1))
self.callback(i * 1.0 / len(texts) / parts / EMBEDDING_BATCH_SIZE + 0.5 * (parts - 1))
cnts = cnts_
title_w = float(self._param.filename_embd_weight)
@ -92,11 +97,17 @@ class Tokenizer(ProcessBase):
return chunks, token_count
async def _invoke(self, **kwargs):
try:
from_upstream = TokenizerFromUpstream.model_validate(kwargs)
except Exception as e:
self.set_output("_ERROR", f"Input error: {str(e)}")
return
parts = sum(["full_text" in self._param.search_method, "embedding" in self._param.search_method])
if "full_text" in self._param.search_method:
self.callback(random.randint(1,5)/100., "Start to tokenize.")
if kwargs.get("chunks"):
chunks = kwargs["chunks"]
self.callback(random.randint(1, 5) / 100.0, "Start to tokenize.")
if from_upstream.chunks:
chunks = from_upstream.chunks
for i, ck in enumerate(chunks):
if ck.get("questions"):
ck["question_tks"] = rag_tokenizer.tokenize("\n".join(ck["questions"]))
@ -105,30 +116,40 @@ class Tokenizer(ProcessBase):
ck["content_ltks"] = rag_tokenizer.tokenize(ck["text"])
ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"])
if i % 100 == 99:
self.callback(i*1./len(chunks)/parts)
elif kwargs.get("output_format") in ["markdown", "text"]:
ck = {
"text": kwargs.get(kwargs["output_format"], "")
}
if "full_text" in self._param.search_method:
self.callback(i * 1.0 / len(chunks) / parts)
elif from_upstream.output_format in ["markdown", "text"]:
if from_upstream.output_format == "markdown":
payload = from_upstream.markdown_result
else: # == "text"
payload = from_upstream.text_result
if not payload:
return ""
ck = {"text": payload}
if "full_text" in self._param.search_method:
ck["content_ltks"] = rag_tokenizer.tokenize(kwargs.get(kwargs["output_format"], ""))
ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"])
chunks = [ck]
else:
chunks = kwargs["json"]
chunks = from_upstream.json_result
for i, ck in enumerate(chunks):
ck["content_ltks"] = rag_tokenizer.tokenize(ck["text"])
ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"])
if i % 100 == 99:
self.callback(i*1./len(chunks)/parts)
self.callback(i * 1.0 / len(chunks) / parts)
self.callback(1./parts, "Finish tokenizing.")
self.callback(1.0 / parts, "Finish tokenizing.")
if "embedding" in self._param.search_method:
self.callback(random.randint(1,5)/100. + 0.5*(parts-1), "Start embedding inference.")
chunks, token_count = await self._embedding(kwargs.get("name", ""), chunks)
self.callback(random.randint(1, 5) / 100.0 + 0.5 * (parts - 1), "Start embedding inference.")
if from_upstream.name.strip() == "":
logging.warning("Tokenizer: empty name provided from upstream, embedding may be not accurate.")
chunks, token_count = await self._embedding(from_upstream.name, chunks)
self.set_output("embedding_token_consumption", token_count)
self.callback(1., "Finish embedding.")
self.callback(1.0, "Finish embedding.")
self.set_output("chunks", chunks)

View File

@ -21,10 +21,12 @@ import sys
import threading
import time
from api.utils import get_uuid
from api.utils.api_utils import timeout
from api.utils.log_utils import init_root_logger, get_project_base_directory
from graphrag.general.index import run_graphrag
from graphrag.utils import get_llm_cache, set_llm_cache, get_tags_from_cache, set_tags_to_cache
from rag.flow.pipeline import Pipeline
from rag.prompts import keyword_extraction, question_proposal, content_tagging
import logging
@ -223,7 +225,14 @@ async def collect():
logging.warning(f"collect task {msg['id']} {state}")
redis_msg.ack()
return None, None
task["task_type"] = msg.get("task_type", "")
task_type = msg.get("task_type", "")
task["task_type"] = task_type
if task_type == "dataflow":
task["tenant_id"]=msg.get("tenant_id", "")
task["dsl"] = msg.get("dsl", "")
task["dataflow_id"] = msg.get("dataflow_id", get_uuid())
task["kb_id"] = msg.get("kb_id", "")
return redis_msg, task
@ -473,6 +482,15 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
return tk_count, vector_size
async def run_dataflow(dsl:str, tenant_id:str, doc_id:str, task_id:str, flow_id:str, callback=None):
_ = callback
pipeline = Pipeline(dsl=dsl, tenant_id=tenant_id, doc_id=doc_id, task_id=task_id, flow_id=flow_id)
pipeline.reset()
await pipeline.run()
@timeout(3600)
async def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None):
chunks = []
@ -558,15 +576,20 @@ async def do_handle_task(task):
init_kb(task, vector_size)
# Either using RAPTOR or Standard chunking methods
if task.get("task_type", "") == "raptor":
task_type = task.get("task_type", "")
if task_type == "dataflow":
task_dataflow_dsl = task["dsl"]
task_dataflow_id = task["dataflow_id"]
await run_dataflow(dsl=task_dataflow_dsl, tenant_id=task_tenant_id, doc_id=task_doc_id, task_id=task_id, flow_id=task_dataflow_id, callback=None)
return
elif task_type == "raptor":
# bind LLM for raptor
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
# run RAPTOR
async with kg_limiter:
chunks, token_count = await run_raptor(task, chat_model, embedding_model, vector_size, progress_callback)
# Either using graphrag or Standard chunking methods
elif task.get("task_type", "") == "graphrag":
elif task_type == "graphrag":
if not task_parser_config.get("graphrag", {}).get("use_graphrag", False):
progress_callback(prog=-1.0, msg="Internal configuration error.")
return