From 45f52e85d720b539ac17c1153056a766b2923373 Mon Sep 17 00:00:00 2001 From: Yongteng Lei Date: Fri, 5 Sep 2025 18:50:46 +0800 Subject: [PATCH] Feat: refine dataflow and initialize dataflow app (#9952) ### What problem does this PR solve? Refine dataflow and initialize dataflow app. ### Type of change - [x] New Feature (non-breaking change which adds functionality) --- api/apps/canvas_app.py | 10 +- api/apps/dataflow_app.py | 353 ++++++++++++++++++ api/db/services/task_service.py | 82 ++-- conf/llm_factories.json | 6 +- rag/flow/__init__.py | 33 +- rag/flow/base.py | 20 +- rag/flow/chunker/__init__.py | 15 + rag/flow/{ => chunker}/chunker.py | 107 ++++-- rag/flow/chunker/schema.py | 37 ++ rag/flow/{begin.py => file.py} | 5 +- rag/flow/parser.py | 107 ------ rag/flow/parser/__init__.py | 14 + rag/flow/parser/parser.py | 154 ++++++++ rag/flow/parser/schema.py | 25 ++ rag/flow/pipeline.py | 43 +-- rag/flow/tests/client.py | 18 +- .../tests/dsl_examples/general_pdf_all.json | 31 +- rag/flow/tokenizer/__init__.py | 14 + rag/flow/tokenizer/schema.py | 51 +++ rag/flow/{ => tokenizer}/tokenizer.py | 59 ++- rag/svr/task_executor.py | 31 +- 21 files changed, 959 insertions(+), 256 deletions(-) create mode 100644 api/apps/dataflow_app.py create mode 100644 rag/flow/chunker/__init__.py rename rag/flow/{ => chunker}/chunker.py (63%) create mode 100644 rag/flow/chunker/schema.py rename rag/flow/{begin.py => file.py} (92%) delete mode 100644 rag/flow/parser.py create mode 100644 rag/flow/parser/__init__.py create mode 100644 rag/flow/parser/parser.py create mode 100644 rag/flow/parser/schema.py create mode 100644 rag/flow/tokenizer/__init__.py create mode 100644 rag/flow/tokenizer/schema.py rename rag/flow/{ => tokenizer}/tokenizer.py (71%) diff --git a/api/apps/canvas_app.py b/api/apps/canvas_app.py index 22d99e3e8..a220af096 100644 --- a/api/apps/canvas_app.py +++ b/api/apps/canvas_app.py @@ -418,12 +418,10 @@ def setting(): return get_data_error_result(message="canvas not found.") flow = flow.to_dict() flow["title"] = req["title"] - if req["description"]: - flow["description"] = req["description"] - if req["permission"]: - flow["permission"] = req["permission"] - if req["avatar"]: - flow["avatar"] = req["avatar"] + + for key in ["description", "permission", "avatar"]: + if value := req.get(key): + flow[key] = value num= UserCanvasService.update_by_id(req["id"], flow) return get_json_result(data=num) diff --git a/api/apps/dataflow_app.py b/api/apps/dataflow_app.py new file mode 100644 index 000000000..49bc8687b --- /dev/null +++ b/api/apps/dataflow_app.py @@ -0,0 +1,353 @@ +# +# Copyright 2024 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import json +import re +import sys +import time +from functools import partial + +import trio +from flask import request +from flask_login import current_user, login_required + +from agent.canvas import Canvas +from agent.component import LLM +from api.db import CanvasCategory, FileType +from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService +from api.db.services.document_service import DocumentService +from api.db.services.file_service import FileService +from api.db.services.task_service import queue_dataflow +from api.db.services.user_canvas_version import UserCanvasVersionService +from api.db.services.user_service import TenantService +from api.settings import RetCode +from api.utils import get_uuid +from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request +from api.utils.file_utils import filename_type, read_potential_broken_pdf +from rag.flow.pipeline import Pipeline + + +@manager.route("/templates", methods=["GET"]) # noqa: F821 +@login_required +def templates(): + return get_json_result(data=[c.to_dict() for c in CanvasTemplateService.query(canvas_category=CanvasCategory.DataFlow)]) + + +@manager.route("/list", methods=["GET"]) # noqa: F821 +@login_required +def canvas_list(): + return get_json_result(data=sorted([c.to_dict() for c in UserCanvasService.query(user_id=current_user.id, canvas_category=CanvasCategory.DataFlow)], key=lambda x: x["update_time"] * -1)) + + +@manager.route("/rm", methods=["POST"]) # noqa: F821 +@validate_request("canvas_ids") +@login_required +def rm(): + for i in request.json["canvas_ids"]: + if not UserCanvasService.accessible(i, current_user.id): + return get_json_result(data=False, message="Only owner of canvas authorized for this operation.", code=RetCode.OPERATING_ERROR) + UserCanvasService.delete_by_id(i) + return get_json_result(data=True) + + +@manager.route("/set", methods=["POST"]) # noqa: F821 +@validate_request("dsl", "title") +@login_required +def save(): + req = request.json + if not isinstance(req["dsl"], str): + req["dsl"] = json.dumps(req["dsl"], ensure_ascii=False) + req["dsl"] = json.loads(req["dsl"]) + req["canvas_category"] = CanvasCategory.DataFlow + if "id" not in req: + req["user_id"] = current_user.id + if UserCanvasService.query(user_id=current_user.id, title=req["title"].strip(), canvas_category=CanvasCategory.DataFlow): + return get_data_error_result(message=f"{req['title'].strip()} already exists.") + req["id"] = get_uuid() + + if not UserCanvasService.save(**req): + return get_data_error_result(message="Fail to save canvas.") + else: + if not UserCanvasService.accessible(req["id"], current_user.id): + return get_json_result(data=False, message="Only owner of canvas authorized for this operation.", code=RetCode.OPERATING_ERROR) + UserCanvasService.update_by_id(req["id"], req) + # save version + UserCanvasVersionService.insert(user_canvas_id=req["id"], dsl=req["dsl"], title="{0}_{1}".format(req["title"], time.strftime("%Y_%m_%d_%H_%M_%S"))) + UserCanvasVersionService.delete_all_versions(req["id"]) + return get_json_result(data=req) + + +@manager.route("/get/", methods=["GET"]) # noqa: F821 +@login_required +def get(canvas_id): + if not UserCanvasService.accessible(canvas_id, current_user.id): + return get_data_error_result(message="canvas not found.") + e, c = UserCanvasService.get_by_tenant_id(canvas_id) + return get_json_result(data=c) + + +@manager.route("/run", methods=["POST"]) # noqa: F821 +@validate_request("id") +@login_required +def 
run(): + req = request.json + flow_id = req.get("id", "") + doc_id = req.get("doc_id", "") + if not all([flow_id, doc_id]): + return get_data_error_result(message="id and doc_id are required.") + + if not DocumentService.get_by_id(doc_id): + return get_data_error_result(message=f"Document for {doc_id} not found.") + + user_id = req.get("user_id", current_user.id) + if not UserCanvasService.accessible(flow_id, current_user.id): + return get_json_result(data=False, message="Only owner of canvas authorized for this operation.", code=RetCode.OPERATING_ERROR) + + e, cvs = UserCanvasService.get_by_id(flow_id) + if not e: + return get_data_error_result(message="canvas not found.") + + if not isinstance(cvs.dsl, str): + cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False) + + task_id = get_uuid() + + ok, error_message = queue_dataflow(dsl=cvs.dsl, tenant_id=user_id, doc_id=doc_id, task_id=task_id, flow_id=flow_id, priority=0) + if not ok: + return server_error_response(error_message) + + return get_json_result(data={"task_id": task_id, "flow_id": flow_id}) + + +@manager.route("/reset", methods=["POST"]) # noqa: F821 +@validate_request("id") +@login_required +def reset(): + req = request.json + flow_id = req.get("id", "") + if not flow_id: + return get_data_error_result(message="id is required.") + + if not UserCanvasService.accessible(flow_id, current_user.id): + return get_json_result(data=False, message="Only owner of canvas authorized for this operation.", code=RetCode.OPERATING_ERROR) + + task_id = req.get("task_id", "") + + try: + e, user_canvas = UserCanvasService.get_by_id(req["id"]) + if not e: + return get_data_error_result(message="canvas not found.") + + dataflow = Pipeline(dsl=json.dumps(user_canvas.dsl), tenant_id=current_user.id, flow_id=flow_id, task_id=task_id) + dataflow.reset() + req["dsl"] = json.loads(str(dataflow)) + UserCanvasService.update_by_id(req["id"], {"dsl": req["dsl"]}) + return get_json_result(data=req["dsl"]) + except Exception as e: + return server_error_response(e) + + +@manager.route("/upload/", methods=["POST"]) # noqa: F821 +def upload(canvas_id): + e, cvs = UserCanvasService.get_by_tenant_id(canvas_id) + if not e: + return get_data_error_result(message="canvas not found.") + + user_id = cvs["user_id"] + + def structured(filename, filetype, blob, content_type): + nonlocal user_id + if filetype == FileType.PDF.value: + blob = read_potential_broken_pdf(blob) + + location = get_uuid() + FileService.put_blob(user_id, location, blob) + + return { + "id": location, + "name": filename, + "size": sys.getsizeof(blob), + "extension": filename.split(".")[-1].lower(), + "mime_type": content_type, + "created_by": user_id, + "created_at": time.time(), + "preview_url": None, + } + + if request.args.get("url"): + from crawl4ai import AsyncWebCrawler, BrowserConfig, CrawlerRunConfig, CrawlResult, DefaultMarkdownGenerator, PruningContentFilter + + try: + url = request.args.get("url") + filename = re.sub(r"\?.*", "", url.split("/")[-1]) + + async def adownload(): + browser_config = BrowserConfig( + headless=True, + verbose=False, + ) + async with AsyncWebCrawler(config=browser_config) as crawler: + crawler_config = CrawlerRunConfig(markdown_generator=DefaultMarkdownGenerator(content_filter=PruningContentFilter()), pdf=True, screenshot=False) + result: CrawlResult = await crawler.arun(url=url, config=crawler_config) + return result + + page = trio.run(adownload()) + if page.pdf: + if filename.split(".")[-1].lower() != "pdf": + filename += ".pdf" + return 
get_json_result(data=structured(filename, "pdf", page.pdf, page.response_headers["content-type"])) + + return get_json_result(data=structured(filename, "html", str(page.markdown).encode("utf-8"), page.response_headers["content-type"], user_id)) + + except Exception as e: + return server_error_response(e) + + file = request.files["file"] + try: + DocumentService.check_doc_health(user_id, file.filename) + return get_json_result(data=structured(file.filename, filename_type(file.filename), file.read(), file.content_type)) + except Exception as e: + return server_error_response(e) + + +@manager.route("/input_form", methods=["GET"]) # noqa: F821 +@login_required +def input_form(): + flow_id = request.args.get("id") + cpn_id = request.args.get("component_id") + try: + e, user_canvas = UserCanvasService.get_by_id(flow_id) + if not e: + return get_data_error_result(message="canvas not found.") + if not UserCanvasService.query(user_id=current_user.id, id=flow_id): + return get_json_result(data=False, message="Only owner of canvas authorized for this operation.", code=RetCode.OPERATING_ERROR) + + dataflow = Pipeline(dsl=json.dumps(user_canvas.dsl), tenant_id=current_user.id, flow_id=flow_id, task_id="") + + return get_json_result(data=dataflow.get_component_input_form(cpn_id)) + except Exception as e: + return server_error_response(e) + + +@manager.route("/debug", methods=["POST"]) # noqa: F821 +@validate_request("id", "component_id", "params") +@login_required +def debug(): + req = request.json + if not UserCanvasService.accessible(req["id"], current_user.id): + return get_json_result(data=False, message="Only owner of canvas authorized for this operation.", code=RetCode.OPERATING_ERROR) + try: + e, user_canvas = UserCanvasService.get_by_id(req["id"]) + canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id) + canvas.reset() + canvas.message_id = get_uuid() + component = canvas.get_component(req["component_id"])["obj"] + component.reset() + + if isinstance(component, LLM): + component.set_debug_inputs(req["params"]) + component.invoke(**{k: o["value"] for k, o in req["params"].items()}) + outputs = component.output() + for k in outputs.keys(): + if isinstance(outputs[k], partial): + txt = "" + for c in outputs[k](): + txt += c + outputs[k] = txt + return get_json_result(data=outputs) + except Exception as e: + return server_error_response(e) + + +# api get list version dsl of canvas +@manager.route("/getlistversion/", methods=["GET"]) # noqa: F821 +@login_required +def getlistversion(canvas_id): + try: + list = sorted([c.to_dict() for c in UserCanvasVersionService.list_by_canvas_id(canvas_id)], key=lambda x: x["update_time"] * -1) + return get_json_result(data=list) + except Exception as e: + return get_data_error_result(message=f"Error getting history files: {e}") + + +# api get version dsl of canvas +@manager.route("/getversion/", methods=["GET"]) # noqa: F821 +@login_required +def getversion(version_id): + try: + e, version = UserCanvasVersionService.get_by_id(version_id) + if version: + return get_json_result(data=version.to_dict()) + except Exception as e: + return get_json_result(data=f"Error getting history file: {e}") + + +@manager.route("/listteam", methods=["GET"]) # noqa: F821 +@login_required +def list_canvas(): + keywords = request.args.get("keywords", "") + page_number = int(request.args.get("page", 1)) + items_per_page = int(request.args.get("page_size", 150)) + orderby = request.args.get("orderby", "create_time") + desc = request.args.get("desc", True) + try: + tenants = 
TenantService.get_joined_tenants_by_user_id(current_user.id) + canvas, total = UserCanvasService.get_by_tenant_ids( + [m["tenant_id"] for m in tenants], current_user.id, page_number, items_per_page, orderby, desc, keywords, canvas_category=CanvasCategory.DataFlow + ) + return get_json_result(data={"canvas": canvas, "total": total}) + except Exception as e: + return server_error_response(e) + + +@manager.route("/setting", methods=["POST"]) # noqa: F821 +@validate_request("id", "title", "permission") +@login_required +def setting(): + req = request.json + req["user_id"] = current_user.id + + if not UserCanvasService.accessible(req["id"], current_user.id): + return get_json_result(data=False, message="Only owner of canvas authorized for this operation.", code=RetCode.OPERATING_ERROR) + + e, flow = UserCanvasService.get_by_id(req["id"]) + if not e: + return get_data_error_result(message="canvas not found.") + flow = flow.to_dict() + flow["title"] = req["title"] + for key in ("description", "permission", "avatar"): + if value := req.get(key): + flow[key] = value + + num = UserCanvasService.update_by_id(req["id"], flow) + return get_json_result(data=num) + + +@manager.route("/trace", methods=["GET"]) # noqa: F821 +def trace(): + dataflow_id = request.args.get("dataflow_id") + task_id = request.args.get("task_id") + if not all([dataflow_id, task_id]): + return get_data_error_result(message="dataflow_id and task_id are required.") + + e, dataflow_canvas = UserCanvasService.get_by_id(dataflow_id) + if not e: + return get_data_error_result(message="dataflow not found.") + + dsl_str = json.dumps(dataflow_canvas.dsl, ensure_ascii=False) + dataflow = Pipeline(dsl=dsl_str, tenant_id=dataflow_canvas.user_id, flow_id=dataflow_id, task_id=task_id) + log = dataflow.fetch_logs() + + return get_json_result(data=log) diff --git a/api/db/services/task_service.py b/api/db/services/task_service.py index 207b6355d..46087f8ba 100644 --- a/api/db/services/task_service.py +++ b/api/db/services/task_service.py @@ -54,15 +54,15 @@ def trim_header_by_lines(text: str, max_length) -> str: class TaskService(CommonService): """Service class for managing document processing tasks. - + This class extends CommonService to provide specialized functionality for document processing task management, including task creation, progress tracking, and chunk management. It handles various document types (PDF, Excel, etc.) and manages their processing lifecycle. - + The class implements a robust task queue system with retry mechanisms and progress tracking, supporting both synchronous and asynchronous task execution. - + Attributes: model: The Task model class for database operations. """ @@ -72,14 +72,14 @@ class TaskService(CommonService): @DB.connection_context() def get_task(cls, task_id): """Retrieve detailed task information by task ID. - + This method fetches comprehensive task details including associated document, knowledge base, and tenant information. It also handles task retry logic and progress updates. - + Args: task_id (str): The unique identifier of the task to retrieve. - + Returns: dict: Task details dictionary containing all task information and related metadata. Returns None if task is not found or has exceeded retry limit. @@ -139,13 +139,13 @@ class TaskService(CommonService): @DB.connection_context() def get_tasks(cls, doc_id: str): """Retrieve all tasks associated with a document. - + This method fetches all processing tasks for a given document, ordered by page number and creation time. 
It includes task progress and chunk information. - + Args: doc_id (str): The unique identifier of the document. - + Returns: list[dict]: List of task dictionaries containing task details. Returns None if no tasks are found. @@ -170,10 +170,10 @@ class TaskService(CommonService): @DB.connection_context() def update_chunk_ids(cls, id: str, chunk_ids: str): """Update the chunk IDs associated with a task. - + This method updates the chunk_ids field of a task, which stores the IDs of processed document chunks in a space-separated string format. - + Args: id (str): The unique identifier of the task. chunk_ids (str): Space-separated string of chunk identifiers. @@ -184,11 +184,11 @@ class TaskService(CommonService): @DB.connection_context() def get_ongoing_doc_name(cls): """Get names of documents that are currently being processed. - + This method retrieves information about documents that are in the processing state, including their locations and associated IDs. It uses database locking to ensure thread safety when accessing the task information. - + Returns: list[tuple]: A list of tuples, each containing (parent_id/kb_id, location) for documents currently being processed. Returns empty list if @@ -238,14 +238,14 @@ class TaskService(CommonService): @DB.connection_context() def do_cancel(cls, id): """Check if a task should be cancelled based on its document status. - + This method determines whether a task should be cancelled by checking the associated document's run status and progress. A task should be cancelled if its document is marked for cancellation or has negative progress. - + Args: id (str): The unique identifier of the task to check. - + Returns: bool: True if the task should be cancelled, False otherwise. """ @@ -311,18 +311,18 @@ class TaskService(CommonService): def queue_tasks(doc: dict, bucket: str, name: str, priority: int): """Create and queue document processing tasks. - + This function creates processing tasks for a document based on its type and configuration. It handles different document types (PDF, Excel, etc.) differently and manages task chunking and configuration. It also implements task reuse optimization by checking for previously completed tasks. - + Args: doc (dict): Document dictionary containing metadata and configuration. bucket (str): Storage bucket name where the document is stored. name (str): File name of the document. priority (int, optional): Priority level for task queueing (default is 0). - + Note: - For PDF documents, tasks are created per page range based on configuration - For Excel documents, tasks are created per row range @@ -410,19 +410,19 @@ def queue_tasks(doc: dict, bucket: str, name: str, priority: int): def reuse_prev_task_chunks(task: dict, prev_tasks: list[dict], chunking_config: dict): """Attempt to reuse chunks from previous tasks for optimization. - + This function checks if chunks from previously completed tasks can be reused for the current task, which can significantly improve processing efficiency. It matches tasks based on page ranges and configuration digests. - + Args: task (dict): Current task dictionary to potentially reuse chunks for. prev_tasks (list[dict]): List of previous task dictionaries to check for reuse. chunking_config (dict): Configuration dictionary for chunk processing. - + Returns: int: Number of chunks successfully reused. Returns 0 if no chunks could be reused. 
- + Note: Chunks can only be reused if: - A previous task exists with matching page range and configuration digest @@ -470,3 +470,39 @@ def has_canceled(task_id): except Exception as e: logging.exception(e) return False + + +def queue_dataflow(dsl:str, tenant_id:str, doc_id:str, task_id:str, flow_id:str, priority: int, callback=None) -> tuple[bool, str]: + """ + Returns a tuple (success: bool, error_message: str). + """ + _ = callback + + task = dict( + id=get_uuid() if not task_id else task_id, + doc_id=doc_id, + from_page=0, + to_page=100000000, + task_type="dataflow", + priority=priority, + ) + + TaskService.model.delete().where(TaskService.model.id == task["id"]).execute() + bulk_insert_into_db(model=Task, data_source=[task], replace_on_conflict=True) + + kb_id = DocumentService.get_knowledgebase_id(doc_id) + if not kb_id: + return False, f"Can't find KB of this document: {doc_id}" + + task["kb_id"] = kb_id + task["tenant_id"] = tenant_id + task["task_type"] = "dataflow" + task["dsl"] = dsl + task["dataflow_id"] = get_uuid() if not flow_id else flow_id + + if not REDIS_CONN.queue_product( + get_svr_queue_name(priority), message=task + ): + return False, "Can't access Redis. Please check the Redis' status." + + return True, "" diff --git a/conf/llm_factories.json b/conf/llm_factories.json index 1c0ea19b6..71f48609a 100644 --- a/conf/llm_factories.json +++ b/conf/llm_factories.json @@ -2690,21 +2690,21 @@ "status": "1", "llm": [ { - "llm_name": "Qwen3-Embedding-8B", + "llm_name": "Qwen/Qwen3-Embedding-8B", "tags": "TEXT EMBEDDING,TEXT RE-RANK,32k", "max_tokens": 32000, "model_type": "embedding", "is_tools": false }, { - "llm_name": "Qwen3-Embedding-4B", + "llm_name": "Qwen/Qwen3-Embedding-4B", "tags": "TEXT EMBEDDING,TEXT RE-RANK,32k", "max_tokens": 32000, "model_type": "embedding", "is_tools": false }, { - "llm_name": "Qwen3-Embedding-0.6B", + "llm_name": "Qwen/Qwen3-Embedding-0.6B", "tags": "TEXT EMBEDDING,TEXT RE-RANK,32k", "max_tokens": 32000, "model_type": "embedding", diff --git a/rag/flow/__init__.py b/rag/flow/__init__.py index 318507ac4..ca6202e15 100644 --- a/rag/flow/__init__.py +++ b/rag/flow/__init__.py @@ -14,36 +14,45 @@ # limitations under the License. 
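For reference, a minimal sketch of how the `queue_dataflow()` helper added to `api/db/services/task_service.py` above might be driven from service code. The IDs below are placeholders (in the API layer they come from the `/run` request in `dataflow_app.py`), and the call assumes the document already belongs to a knowledge base:

```python
# Sketch of a caller for queue_dataflow(): it deletes any stale task row,
# bulk-inserts the new task and pushes the message onto the Redis task queue,
# returning (ok, error_message).
from api.db.services.task_service import queue_dataflow
from api.utils import get_uuid

task_id = get_uuid()
ok, error_message = queue_dataflow(
    dsl=open("rag/flow/tests/dsl_examples/general_pdf_all.json").read(),  # DSL as a JSON string
    tenant_id="<tenant-id>",   # placeholder: owner of the canvas
    doc_id="<doc-id>",         # placeholder: document to run the dataflow on
    task_id=task_id,
    flow_id="<flow-id>",       # placeholder: id of the dataflow canvas
    priority=0,
)
if not ok:
    raise RuntimeError(error_message)
```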
# -import os import importlib import inspect +import pkgutil +from pathlib import Path from types import ModuleType from typing import Dict, Type -_package_path = os.path.dirname(__file__) __all_classes: Dict[str, Type] = {} -def _import_submodules() -> None: - for filename in os.listdir(_package_path): # noqa: F821 - if filename.startswith("__") or not filename.endswith(".py") or filename.startswith("base"): - continue - module_name = filename[:-3] +_pkg_dir = Path(__file__).resolve().parent +_pkg_name = __name__ + +def _should_skip_module(mod_name: str) -> bool: + leaf = mod_name.rsplit(".", 1)[-1] + return leaf in {"__init__"} or leaf.startswith("__") or leaf.startswith("_") or leaf.startswith("base") + + +def _import_submodules() -> None: + for modinfo in pkgutil.walk_packages([str(_pkg_dir)], prefix=_pkg_name + "."): # noqa: F821 + mod_name = modinfo.name + if _should_skip_module(mod_name): # noqa: F821 + continue try: - module = importlib.import_module(f".{module_name}", package=__name__) + module = importlib.import_module(mod_name) _extract_classes_from_module(module) # noqa: F821 except ImportError as e: - print(f"Warning: Failed to import module {module_name}: {str(e)}") + print(f"Warning: Failed to import module {mod_name}: {e}") + def _extract_classes_from_module(module: ModuleType) -> None: for name, obj in inspect.getmembers(module): - if (inspect.isclass(obj) and - obj.__module__ == module.__name__ and not name.startswith("_")): + if inspect.isclass(obj) and obj.__module__ == module.__name__ and not name.startswith("_"): __all_classes[name] = obj globals()[name] = obj + _import_submodules() __all__ = list(__all_classes.keys()) + ["__all_classes"] -del _package_path, _import_submodules, _extract_classes_from_module \ No newline at end of file +del _pkg_dir, _pkg_name, _import_submodules, _extract_classes_from_module diff --git a/rag/flow/base.py b/rag/flow/base.py index d0c486f3b..89b37b501 100644 --- a/rag/flow/base.py +++ b/rag/flow/base.py @@ -1,5 +1,5 @@ # -# Copyright 2024 The InfiniFlow Authors. All Rights Reserved. +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,13 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. 
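The rewritten `rag/flow/__init__.py` above now auto-imports every sub-package with `pkgutil.walk_packages` and records the discovered classes in `__all_classes`. A tiny illustrative lookup, assuming the package imports cleanly in a configured environment:

```python
# Hypothetical use of the auto-discovery registry: classes found by
# _import_submodules() are exposed both as module attributes and via the
# name -> class mapping in __all_classes.
import rag.flow as flow

print(sorted(flow.__all_classes))            # discovered class names, e.g. "Chunker", "Parser", ...
ChunkerCls = flow.__all_classes["Chunker"]   # same object as rag.flow.Chunker
```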
# -import time -import os import logging +import os +import time from functools import partial from typing import Any + import trio -from agent.component.base import ComponentParamBase, ComponentBase + +from agent.component.base import ComponentBase, ComponentParamBase from api.utils.api_utils import timeout @@ -31,14 +33,16 @@ class ProcessParamBase(ComponentParamBase): class ProcessBase(ComponentBase): - def __init__(self, pipeline, id, param: ProcessParamBase): super().__init__(pipeline, id, param) - self.callback = partial(self._canvas.callback, self.component_name) + if hasattr(self._canvas, "callback"): + self.callback = partial(self._canvas.callback, self.component_name) + else: + self.callback = partial(lambda *args, **kwargs: None, self.component_name) async def invoke(self, **kwargs) -> dict[str, Any]: self.set_output("_created_time", time.perf_counter()) - for k,v in kwargs.items(): + for k, v in kwargs.items(): self.set_output(k, v) try: with trio.fail_after(self._param.timeout): @@ -54,6 +58,6 @@ class ProcessBase(ComponentBase): self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time")) return self.output() - @timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)) + @timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10 * 60)) async def _invoke(self, **kwargs): raise NotImplementedError() diff --git a/rag/flow/chunker/__init__.py b/rag/flow/chunker/__init__.py new file mode 100644 index 000000000..b4663378e --- /dev/null +++ b/rag/flow/chunker/__init__.py @@ -0,0 +1,15 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/rag/flow/chunker.py b/rag/flow/chunker/chunker.py similarity index 63% rename from rag/flow/chunker.py rename to rag/flow/chunker/chunker.py index d869c94ec..f853fc9e7 100644 --- a/rag/flow/chunker.py +++ b/rag/flow/chunker/chunker.py @@ -1,5 +1,5 @@ # -# Copyright 2024 The InfiniFlow Authors. All Rights Reserved. +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,12 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. 
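For orientation, a minimal sketch of a pipeline component following the contract in `rag/flow/base.py` above; the `Echo` component is hypothetical and not part of this PR:

```python
# Hypothetical component: ProcessParamBase subclasses carry configuration and
# validation; ProcessBase subclasses implement the async _invoke() hook, which
# ProcessBase.invoke() wraps with timing, error capture and trio.fail_after().
from rag.flow.base import ProcessBase, ProcessParamBase


class EchoParam(ProcessParamBase):
    def check(self):
        pass

    def get_input_form(self) -> dict[str, dict]:
        return {}


class Echo(ProcessBase):
    component_name = "Echo"

    async def _invoke(self, **kwargs):
        # Whatever is set as output here is passed as kwargs to the downstream component.
        self.set_output("text", kwargs.get("text", ""))
```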
import random + import trio + from api.db import LLMType from api.db.services.llm_service import LLMBundle from deepdoc.parser.pdf_parser import RAGFlowPdfParser -from graphrag.utils import get_llm_cache, chat_limiter, set_llm_cache +from graphrag.utils import chat_limiter, get_llm_cache, set_llm_cache from rag.flow.base import ProcessBase, ProcessParamBase +from rag.flow.chunker.schema import ChunkerFromUpstream from rag.nlp import naive_merge, naive_merge_with_images from rag.prompts.prompts import keyword_extraction, question_proposal @@ -26,7 +29,23 @@ from rag.prompts.prompts import keyword_extraction, question_proposal class ChunkerParam(ProcessParamBase): def __init__(self): super().__init__() - self.method_options = ["general", "q&a", "resume", "manual", "table", "paper", "book", "laws", "presentation", "one"] + self.method_options = [ + # General + "general", + "onetable", + # Customer Service + "q&a", + "manual", + # Recruitment + "resume", + # Education & Research + "book", + "paper", + "laws", + "presentation", + # Other + # "Tag" # TODO: Other method + ] self.method = "general" self.chunk_token_size = 512 self.delimiter = "\n" @@ -35,10 +54,7 @@ class ChunkerParam(ProcessParamBase): self.auto_keywords = 0 self.auto_questions = 0 self.tag_sets = [] - self.llm_setting = { - "llm_name": "", - "lang": "Chinese" - } + self.llm_setting = {"llm_name": "", "lang": "Chinese"} def check(self): self.check_valid_value(self.method.lower(), "Chunk method abnormal.", self.method_options) @@ -48,53 +64,79 @@ class ChunkerParam(ProcessParamBase): self.check_nonnegative_number(self.auto_questions, "Auto-question value: (0, 10]") self.check_decimal_float(self.overlapped_percent, "Overlapped percentage: [0, 1)") + def get_input_form(self) -> dict[str, dict]: + return {} + class Chunker(ProcessBase): component_name = "Chunker" - def _general(self, **kwargs): - self.callback(random.randint(1,5)/100., "Start to chunk via `General`.") - if kwargs.get("output_format") in ["markdown", "text"]: - cks = naive_merge(kwargs.get(kwargs["output_format"]), self._param.chunk_token_size, self._param.delimiter, self._param.overlapped_percent) + def _general(self, from_upstream: ChunkerFromUpstream): + self.callback(random.randint(1, 5) / 100.0, "Start to chunk via `General`.") + if from_upstream.output_format in ["markdown", "text"]: + if from_upstream.output_format == "markdown": + payload = from_upstream.markdown_result + else: # == "text" + payload = from_upstream.text_result + + if not payload: + payload = "" + + cks = naive_merge( + payload, + self._param.chunk_token_size, + self._param.delimiter, + self._param.overlapped_percent, + ) return [{"text": c} for c in cks] sections, section_images = [], [] - for o in kwargs["json"]: - sections.append((o["text"], o.get("position_tag",""))) + for o in from_upstream.json_result or []: + sections.append((o.get("text", ""), o.get("position_tag", ""))) section_images.append(o.get("image")) - chunks, images = naive_merge_with_images(sections, section_images,self._param.chunk_token_size, self._param.delimiter, self._param.overlapped_percent) - return [{ - "text": RAGFlowPdfParser.remove_tag(c), - "image": img, - "positions": RAGFlowPdfParser.extract_positions(c) - } for c,img in zip(chunks,images)] + chunks, images = naive_merge_with_images( + sections, + section_images, + self._param.chunk_token_size, + self._param.delimiter, + self._param.overlapped_percent, + ) - def _q_and_a(self, **kwargs): + return [ + { + "text": RAGFlowPdfParser.remove_tag(c), + "image": img, + 
"positions": RAGFlowPdfParser.extract_positions(c), + } + for c, img in zip(chunks, images) + ] + + def _q_and_a(self, from_upstream: ChunkerFromUpstream): pass - def _resume(self, **kwargs): + def _resume(self, from_upstream: ChunkerFromUpstream): pass - def _manual(self, **kwargs): + def _manual(self, from_upstream: ChunkerFromUpstream): pass - def _table(self, **kwargs): + def _table(self, from_upstream: ChunkerFromUpstream): pass - def _paper(self, **kwargs): + def _paper(self, from_upstream: ChunkerFromUpstream): pass - def _book(self, **kwargs): + def _book(self, from_upstream: ChunkerFromUpstream): pass - def _laws(self, **kwargs): + def _laws(self, from_upstream: ChunkerFromUpstream): pass - def _presentation(self, **kwargs): + def _presentation(self, from_upstream: ChunkerFromUpstream): pass - def _one(self, **kwargs): + def _one(self, from_upstream: ChunkerFromUpstream): pass async def _invoke(self, **kwargs): @@ -110,7 +152,14 @@ class Chunker(ProcessBase): "presentation": self._presentation, "one": self._one, } - chunks = function_map[self._param.method](**kwargs) + + try: + from_upstream = ChunkerFromUpstream.model_validate(kwargs) + except Exception as e: + self.set_output("_ERROR", f"Input error: {str(e)}") + return + + chunks = function_map[self._param.method](from_upstream) llm_setting = self._param.llm_setting async def auto_keywords(): diff --git a/rag/flow/chunker/schema.py b/rag/flow/chunker/schema.py new file mode 100644 index 000000000..0f0e3042c --- /dev/null +++ b/rag/flow/chunker/schema.py @@ -0,0 +1,37 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Literal + +from pydantic import BaseModel, ConfigDict, Field + + +class ChunkerFromUpstream(BaseModel): + created_time: float | None = Field(default=None, alias="_created_time") + elapsed_time: float | None = Field(default=None, alias="_elapsed_time") + + name: str + blob: bytes + + output_format: Literal["json", "markdown", "text", "html"] | None = Field(default=None) + + json_result: list[dict[str, Any]] | None = Field(default=None, alias="json") + markdown_result: str | None = Field(default=None, alias="markdown") + text_result: str | None = Field(default=None, alias="text") + html_result: str | None = Field(default=None, alias="html") + + model_config = ConfigDict(populate_by_name=True, extra="forbid") + + # def to_dict(self, *, exclude_none: bool = True) -> dict: + # return self.model_dump(by_alias=True, exclude_none=exclude_none) diff --git a/rag/flow/begin.py b/rag/flow/file.py similarity index 92% rename from rag/flow/begin.py rename to rag/flow/file.py index 3b3622fb9..584b0ff9c 100644 --- a/rag/flow/begin.py +++ b/rag/flow/file.py @@ -1,5 +1,5 @@ # -# Copyright 2024 The InfiniFlow Authors. All Rights Reserved. +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -27,6 +27,9 @@ class FileParam(ProcessParamBase): def check(self): pass + def get_input_form(self) -> dict[str, dict]: + return {} + class File(ProcessBase): component_name = "File" diff --git a/rag/flow/parser.py b/rag/flow/parser.py deleted file mode 100644 index a991a969d..000000000 --- a/rag/flow/parser.py +++ /dev/null @@ -1,107 +0,0 @@ -# -# Copyright 2024 The InfiniFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import random -import trio -from api.db import LLMType -from api.db.services.llm_service import LLMBundle -from deepdoc.parser.pdf_parser import RAGFlowPdfParser, PlainParser, VisionParser -from rag.flow.base import ProcessBase, ProcessParamBase -from rag.llm.cv_model import Base as VLM -from deepdoc.parser import ExcelParser - - -class ParserParam(ProcessParamBase): - def __init__(self): - super().__init__() - self.setups = { - "pdf": { - "parse_method": "deepdoc", # deepdoc/plain_text/vlm - "vlm_name": "", - "lang": "Chinese", - "suffix": ["pdf"], - "output_format": "json" - }, - "excel": { - "output_format": "html" - }, - "ppt": {}, - "image": { - "parse_method": "ocr" - }, - "email": {}, - "text": {}, - "audio": {}, - "video": {}, - } - - def check(self): - if self.setups["pdf"].get("parse_method") not in ["deepdoc", "plain_text"]: - assert self.setups["pdf"].get("vlm_name"), "No VLM specified." - assert self.setups["pdf"].get("lang"), "No language specified." 
- - -class Parser(ProcessBase): - component_name = "Parser" - - def _pdf(self, blob): - self.callback(random.randint(1,5)/100., "Start to work on a PDF.") - conf = self._param.setups["pdf"] - self.set_output("output_format", conf["output_format"]) - if conf.get("parse_method") == "deepdoc": - bboxes = RAGFlowPdfParser().parse_into_bboxes(blob, callback=self.callback) - elif conf.get("parse_method") == "plain_text": - lines,_ = PlainParser()(blob) - bboxes = [{"text": t} for t,_ in lines] - else: - assert conf.get("vlm_name") - vision_model = LLMBundle(self._canvas.tenant_id, LLMType.IMAGE2TEXT, llm_name=conf.get("vlm_name"), lang=self.setups["pdf"].get("lang")) - lines, _ = VisionParser(vision_model=vision_model)(bin, callback=self.callback) - bboxes = [] - for t, poss in lines: - pn, x0, x1, top, bott = poss.split(" ") - bboxes.append({"page_number": int(pn), "x0": int(x0), "x1": int(x1), "top": int(top), "bottom": int(bott), "text": t}) - - self.set_output("json", bboxes) - mkdn = "" - for b in bboxes: - if b.get("layout_type", "") == "title": - mkdn += "\n## " - if b.get("layout_type", "") == "figure": - mkdn += "\n![Image]({})".format(VLM.image2base64(b["image"])) - continue - mkdn += b.get("text", "") + "\n" - self.set_output("markdown", mkdn) - - def _excel(self, blob): - self.callback(random.randint(1,5)/100., "Start to work on a Excel.") - conf = self._param.setups["excel"] - excel_parser = ExcelParser() - if conf.get("output_format") == "html": - html = excel_parser.html(blob,1000000000) - self.set_output("html", html) - elif conf.get("output_format") == "json": - self.set_output("json", [{"text": txt} for txt in excel_parser(blob) if txt]) - elif conf.get("output_format") == "markdown": - self.set_output("markdown", excel_parser.markdown(blob)) - - async def _invoke(self, **kwargs): - function_map = { - "pdf": self._pdf, - } - for p_type, conf in self._param.setups.items(): - if kwargs.get("name", "").split(".")[-1].lower() not in conf.get("suffix", []): - continue - await trio.to_thread.run_sync(function_map[p_type], kwargs["blob"]) - break \ No newline at end of file diff --git a/rag/flow/parser/__init__.py b/rag/flow/parser/__init__.py new file mode 100644 index 000000000..e6ddad793 --- /dev/null +++ b/rag/flow/parser/__init__.py @@ -0,0 +1,14 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/rag/flow/parser/parser.py b/rag/flow/parser/parser.py new file mode 100644 index 000000000..fd65665fa --- /dev/null +++ b/rag/flow/parser/parser.py @@ -0,0 +1,154 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import random + +import trio + +from api.db import LLMType +from api.db.services.llm_service import LLMBundle +from deepdoc.parser import ExcelParser +from deepdoc.parser.pdf_parser import PlainParser, RAGFlowPdfParser, VisionParser +from rag.flow.base import ProcessBase, ProcessParamBase +from rag.flow.parser.schema import ParserFromUpstream +from rag.llm.cv_model import Base as VLM + + +class ParserParam(ProcessParamBase): + def __init__(self): + super().__init__() + self.allowed_output_format = { + "pdf": ["json", "markdown"], + "excel": ["json", "markdown", "html"], + "ppt": [], + "image": [], + "email": [], + "text": [], + "audio": [], + "video": [], + } + + self.setups = { + "pdf": { + "parse_method": "deepdoc", # deepdoc/plain_text/vlm + "vlm_name": "", + "lang": "Chinese", + "suffix": ["pdf"], + "output_format": "json", + }, + "excel": { + "output_format": "html", + "suffix": ["xls", "xlsx", "csv"], + }, + "ppt": {}, + "image": { + "parse_method": "ocr", + }, + "email": {}, + "text": {}, + "audio": {}, + "video": {}, + } + + def check(self): + pdf_config = self.setups.get("pdf", {}) + if pdf_config: + pdf_parse_method = pdf_config.get("parse_method", "") + self.check_valid_value(pdf_parse_method.lower(), "Parse method abnormal.", ["deepdoc", "plain_text", "vlm"]) + + if pdf_parse_method not in ["deepdoc", "plain_text"]: + self.check_empty(pdf_config.get("vlm_name"), "VLM") + + pdf_language = pdf_config.get("lang", "") + self.check_empty(pdf_language, "Language") + + pdf_output_format = pdf_config.get("output_format", "") + self.check_valid_value(pdf_output_format, "PDF output format abnormal.", self.allowed_output_format["pdf"]) + + excel_config = self.setups.get("excel", "") + if excel_config: + excel_output_format = excel_config.get("output_format", "") + self.check_valid_value(excel_output_format, "Excel output format abnormal.", self.allowed_output_format["excel"]) + + image_config = self.setups.get("image", "") + if image_config: + image_parse_method = image_config.get("parse_method", "") + self.check_valid_value(image_parse_method.lower(), "Parse method abnormal.", ["ocr"]) + + def get_input_form(self) -> dict[str, dict]: + return {} + + +class Parser(ProcessBase): + component_name = "Parser" + + def _pdf(self, blob): + self.callback(random.randint(1, 5) / 100.0, "Start to work on a PDF.") + conf = self._param.setups["pdf"] + self.set_output("output_format", conf["output_format"]) + if conf.get("parse_method") == "deepdoc": + bboxes = RAGFlowPdfParser().parse_into_bboxes(blob, callback=self.callback) + elif conf.get("parse_method") == "plain_text": + lines, _ = PlainParser()(blob) + bboxes = [{"text": t} for t, _ in lines] + else: + assert conf.get("vlm_name") + vision_model = LLMBundle(self._canvas._tenant_id, LLMType.IMAGE2TEXT, llm_name=conf.get("vlm_name"), lang=self._param.setups["pdf"].get("lang")) + lines, _ = VisionParser(vision_model=vision_model)(blob, callback=self.callback) + bboxes = [] + for t, poss in lines: + pn, x0, x1, top, bott = poss.split(" ") + bboxes.append({"page_number": int(pn), "x0": float(x0), "x1": float(x1), "top": float(top), "bottom": 
float(bott), "text": t}) + if conf.get("output_format") == "json": + self.set_output("json", bboxes) + if conf.get("output_format") == "markdown": + mkdn = "" + for b in bboxes: + if b.get("layout_type", "") == "title": + mkdn += "\n## " + if b.get("layout_type", "") == "figure": + mkdn += "\n![Image]({})".format(VLM.image2base64(b["image"])) + continue + mkdn += b.get("text", "") + "\n" + self.set_output("markdown", mkdn) + + def _excel(self, blob): + self.callback(random.randint(1, 5) / 100.0, "Start to work on a Excel.") + conf = self._param.setups["excel"] + self.set_output("output_format", conf["output_format"]) + excel_parser = ExcelParser() + if conf.get("output_format") == "html": + html = excel_parser.html(blob, 1000000000) + self.set_output("html", html) + elif conf.get("output_format") == "json": + self.set_output("json", [{"text": txt} for txt in excel_parser(blob) if txt]) + elif conf.get("output_format") == "markdown": + self.set_output("markdown", excel_parser.markdown(blob)) + + async def _invoke(self, **kwargs): + function_map = { + "pdf": self._pdf, + "excel": self._excel, + } + try: + from_upstream = ParserFromUpstream.model_validate(kwargs) + except Exception as e: + self.set_output("_ERROR", f"Input error: {str(e)}") + return + + for p_type, conf in self._param.setups.items(): + if from_upstream.name.split(".")[-1].lower() not in conf.get("suffix", []): + continue + await trio.to_thread.run_sync(function_map[p_type], from_upstream.blob) + break diff --git a/rag/flow/parser/schema.py b/rag/flow/parser/schema.py new file mode 100644 index 000000000..37292e058 --- /dev/null +++ b/rag/flow/parser/schema.py @@ -0,0 +1,25 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from pydantic import BaseModel, ConfigDict, Field + + +class ParserFromUpstream(BaseModel): + created_time: float | None = Field(default=None, alias="_created_time") + elapsed_time: float | None = Field(default=None, alias="_elapsed_time") + + name: str + blob: bytes + + model_config = ConfigDict(populate_by_name=True, extra="forbid") diff --git a/rag/flow/pipeline.py b/rag/flow/pipeline.py index c57d04e2a..9f88d29ea 100644 --- a/rag/flow/pipeline.py +++ b/rag/flow/pipeline.py @@ -1,5 +1,5 @@ # -# Copyright 2024 The InfiniFlow Authors. All Rights Reserved. +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
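To make the dispatch in `Parser._invoke()` above concrete: the component walks `setups` and hands the blob to the first entry whose `suffix` list contains the file extension. A standalone illustration, not code from this PR:

```python
# Illustration of the suffix-based dispatch: the first setups entry whose
# "suffix" list contains the file extension handles the file.
setups = {
    "pdf": {"suffix": ["pdf"], "output_format": "json"},
    "excel": {"suffix": ["xls", "xlsx", "csv"], "output_format": "html"},
}

def pick_parser(filename: str) -> str | None:
    ext = filename.split(".")[-1].lower()
    for p_type, conf in setups.items():
        if ext in conf.get("suffix", []):
            return p_type          # e.g. "report.xlsx" -> "excel"
    return None                    # unmatched suffixes are skipped

print(pick_parser("report.xlsx"))  # -> "excel"
```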
@@ -18,14 +18,15 @@ import json import logging import random import time + import trio + from agent.canvas import Graph from api.db.services.document_service import DocumentService from rag.utils.redis_conn import REDIS_CONN class Pipeline(Graph): - def __init__(self, dsl: str, tenant_id=None, doc_id=None, task_id=None, flow_id=None): super().__init__(dsl, tenant_id, task_id) self._doc_id = doc_id @@ -35,7 +36,7 @@ class Pipeline(Graph): self._kb_id = DocumentService.get_knowledgebase_id(doc_id) assert self._kb_id, f"Can't find KB of this document: {doc_id}" - def callback(self, component_name: str, progress: float|int|None=None, message: str = "") -> None: + def callback(self, component_name: str, progress: float | int | None = None, message: str = "") -> None: log_key = f"{self._flow_id}-{self.task_id}-logs" try: bin = REDIS_CONN.get(log_key) @@ -44,16 +45,10 @@ class Pipeline(Graph): if obj[-1]["component_name"] == component_name: obj[-1]["trace"].append({"progress": progress, "message": message, "datetime": datetime.datetime.now().strftime("%H:%M:%S")}) else: - obj.append({ - "component_name": component_name, - "trace": [{"progress": progress, "message": message, "datetime": datetime.datetime.now().strftime("%H:%M:%S")}] - }) + obj.append({"component_name": component_name, "trace": [{"progress": progress, "message": message, "datetime": datetime.datetime.now().strftime("%H:%M:%S")}]}) else: - obj = [{ - "component_name": component_name, - "trace": [{"progress": progress, "message": message, "datetime": datetime.datetime.now().strftime("%H:%M:%S")}] - }] - REDIS_CONN.set_obj(log_key, obj, 60*10) + obj = [{"component_name": component_name, "trace": [{"progress": progress, "message": message, "datetime": datetime.datetime.now().strftime("%H:%M:%S")}]}] + REDIS_CONN.set_obj(log_key, obj, 60 * 10) except Exception as e: logging.exception(e) @@ -71,21 +66,19 @@ class Pipeline(Graph): super().reset() log_key = f"{self._flow_id}-{self.task_id}-logs" try: - REDIS_CONN.set_obj(log_key, [], 60*10) + REDIS_CONN.set_obj(log_key, [], 60 * 10) except Exception as e: logging.exception(e) async def run(self, **kwargs): st = time.perf_counter() if not self.path: - self.path.append("begin") + self.path.append("File") if self._doc_id: - DocumentService.update_by_id(self._doc_id, { - "progress": random.randint(0,5)/100., - "progress_msg": "Start the pipeline...", - "process_begin_at": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") - }) + DocumentService.update_by_id( + self._doc_id, {"progress": random.randint(0, 5) / 100.0, "progress_msg": "Start the pipeline...", "process_begin_at": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")} + ) self.error = "" idx = len(self.path) - 1 @@ -99,23 +92,21 @@ class Pipeline(Graph): self.path.extend(cpn_obj.get_downstream()) while idx < len(self.path) and not self.error: - last_cpn = self.get_component_obj(self.path[idx-1]) + last_cpn = self.get_component_obj(self.path[idx - 1]) cpn_obj = self.get_component_obj(self.path[idx]) + async def invoke(): nonlocal last_cpn, cpn_obj await cpn_obj.invoke(**last_cpn.output()) + async with trio.open_nursery() as nursery: nursery.start_soon(invoke) if cpn_obj.error(): self.error = "[ERROR]" + cpn_obj.error() + self.callback(cpn_obj.component_name, -1, self.error) break idx += 1 self.path.extend(cpn_obj.get_downstream()) if self._doc_id: - DocumentService.update_by_id(self._doc_id, { - "progress": 1 if not self.error else -1, - "progress_msg": "Pipeline finished...\n" + self.error, - "process_duration": 
time.perf_counter() - st
-            })
-
+        DocumentService.update_by_id(self._doc_id, {"progress": 1 if not self.error else -1, "progress_msg": "Pipeline finished...\n" + self.error, "process_duration": time.perf_counter() - st})
diff --git a/rag/flow/tests/client.py b/rag/flow/tests/client.py
index eedaf7efc..cf4a4db37 100644
--- a/rag/flow/tests/client.py
+++ b/rag/flow/tests/client.py
@@ -1,5 +1,5 @@
 #
-# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
+# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -18,12 +18,14 @@ import json
 import os
 import time
 from concurrent.futures import ThreadPoolExecutor
+
 import trio
+
 from api import settings
 from rag.flow.pipeline import Pipeline
 
 
-def print_logs(pipeline):
+def print_logs(pipeline: Pipeline):
     last_logs = "[]"
     while True:
         time.sleep(5)
@@ -34,16 +36,16 @@ def print_logs(pipeline):
         last_logs = logs_str
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     dsl_default_path = os.path.join(
         os.path.dirname(os.path.realpath(__file__)),
         "dsl_examples",
         "general_pdf_all.json",
     )
-    parser.add_argument('-s', '--dsl', default=dsl_default_path, help="input dsl", action='store', required=True)
-    parser.add_argument('-d', '--doc_id', default=False, help="Document ID", action='store', required=True)
-    parser.add_argument('-t', '--tenant_id', default=False, help="Tenant ID", action='store', required=True)
+    parser.add_argument("-s", "--dsl", default=dsl_default_path, help="input dsl", action="store", required=False)
+    parser.add_argument("-d", "--doc_id", default=False, help="Document ID", action="store", required=True)
+    parser.add_argument("-t", "--tenant_id", default=False, help="Tenant ID", action="store", required=True)
     args = parser.parse_args()
 
     settings.init_settings()
@@ -53,5 +55,7 @@ if __name__ == '__main__':
     exe = ThreadPoolExecutor(max_workers=5)
     thr = exe.submit(print_logs, pipeline)
 
+    # queue_dataflow(dsl=open(args.dsl, "r").read(), tenant_id=args.tenant_id, doc_id=args.doc_id, task_id="xxxx", flow_id="xxx", priority=0)
+
     trio.run(pipeline.run)
-    thr.result()
\ No newline at end of file
+    thr.result()
diff --git a/rag/flow/tests/dsl_examples/general_pdf_all.json b/rag/flow/tests/dsl_examples/general_pdf_all.json
index 5c29433ed..7142e5547 100644
--- a/rag/flow/tests/dsl_examples/general_pdf_all.json
+++ b/rag/flow/tests/dsl_examples/general_pdf_all.json
@@ -1,15 +1,15 @@
 {
     "components": {
-        "begin": {
+        "File": {
             "obj":{
                 "component_name": "File",
                 "params": {
                 }
             },
-            "downstream": ["parser:0"],
+            "downstream": ["Parser:0"],
             "upstream": []
         },
-        "parser:0": {
+        "Parser:0": {
             "obj": {
                 "component_name": "Parser",
                 "params": {
@@ -22,14 +22,22 @@
                             "pdf"
                         ],
                         "output_format": "json"
+                    },
+                    "excel": {
+                        "output_format": "html",
+                        "suffix": [
+                            "xls",
+                            "xlsx",
+                            "csv"
+                        ]
                     }
                 }
             }
         },
-            "downstream": ["chunker:0"],
-            "upstream": ["begin"]
+            "downstream": ["Chunker:0"],
+            "upstream": ["Begin"]
         },
-        "chunker:0": {
+        "Chunker:0": {
             "obj": {
                 "component_name": "Chunker",
                 "params": {
@@ -37,18 +45,19 @@
                     "auto_keywords": 5
                 }
             },
-            "downstream": ["tokenizer:0"],
-            "upstream": ["chunker:0"]
+            "downstream": ["Tokenizer:0"],
+            "upstream": ["Parser:0"]
         },
-        "tokenizer:0": {
+        "Tokenizer:0": {
            "obj": {
                 "component_name": "Tokenizer",
                 "params": {
                 }
            },
            "downstream": [],
-           "upstream": ["chunker:0"]
+           "upstream": ["Chunker:0"]
        }
    },
    "path": []
-}
\ No newline at end of file
+}
+
diff --git a/rag/flow/tokenizer/__init__.py b/rag/flow/tokenizer/__init__.py
new file mode 100644
index 000000000..e6ddad793
--- /dev/null
+++ b/rag/flow/tokenizer/__init__.py
@@ -0,0 +1,14 @@
+#
+# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/rag/flow/tokenizer/schema.py b/rag/flow/tokenizer/schema.py
new file mode 100644
index 000000000..508fa002c
--- /dev/null
+++ b/rag/flow/tokenizer/schema.py
@@ -0,0 +1,51 @@
+#
+# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any, Literal
+
+from pydantic import BaseModel, ConfigDict, Field, model_validator
+
+
+class TokenizerFromUpstream(BaseModel):
+    created_time: float | None = Field(default=None, alias="_created_time")
+    elapsed_time: float | None = Field(default=None, alias="_elapsed_time")
+
+    name: str = ""
+    blob: bytes
+
+    output_format: Literal["json", "markdown", "text", "html"] | None = Field(default=None)
+
+    chunks: list[dict[str, Any]] | None = Field(default=None)
+
+    json_result: list[dict[str, Any]] | None = Field(default=None, alias="json")
+    markdown_result: str | None = Field(default=None, alias="markdown")
+    text_result: str | None = Field(default=None, alias="text")
+    html_result: str | None = Field(default=None, alias="html")
+
+    model_config = ConfigDict(populate_by_name=True, extra="forbid")
+
+    @model_validator(mode="after")
+    def _check_payloads(self) -> "TokenizerFromUpstream":
+        if self.chunks:
+            return self
+
+        if self.output_format in {"markdown", "text"}:
+            if self.output_format == "markdown" and not self.markdown_result:
+                raise ValueError("output_format=markdown requires a markdown payload (field: 'markdown' or 'markdown_result').")
+            if self.output_format == "text" and not self.text_result:
+                raise ValueError("output_format=text requires a text payload (field: 'text' or 'text_result').")
+        else:
+            if not self.json_result:
+                raise ValueError("When no chunks are provided and output_format is not markdown/text, a JSON list payload is required (field: 'json' or 'json_result').")
+        return self
diff --git a/rag/flow/tokenizer.py b/rag/flow/tokenizer/tokenizer.py
similarity index 71%
rename from rag/flow/tokenizer.py
rename to rag/flow/tokenizer/tokenizer.py
index 0f11e8ece..5ad209776 100644
--- a/rag/flow/tokenizer.py
+++ b/rag/flow/tokenizer/tokenizer.py
@@ -1,5 +1,5 @@
 #
-# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
+# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
 import random
 import re
 
@@ -24,6 +25,7 @@ from api.db.services.llm_service import LLMBundle
 from api.db.services.user_service import TenantService
 from api.utils.api_utils import timeout
 from rag.flow.base import ProcessBase, ProcessParamBase
+from rag.flow.tokenizer.schema import TokenizerFromUpstream
 from rag.nlp import rag_tokenizer
 from rag.settings import EMBEDDING_BATCH_SIZE
 from rag.svr.task_executor import embed_limiter
@@ -40,6 +42,9 @@ class TokenizerParam(ProcessParamBase):
         for v in self.search_method:
             self.check_valid_value(v.lower(), "Chunk method abnormal.", ["full_text", "embedding"])
 
+    def get_input_form(self) -> dict[str, dict]:
+        return {}
+
 
 class Tokenizer(ProcessBase):
     component_name = "Tokenizer"
@@ -67,19 +72,19 @@ class Tokenizer(ProcessBase):
         @timeout(60)
         def batch_encode(txts):
             nonlocal embedding_model
-            return embedding_model.encode([truncate(c, embedding_model.max_length-10) for c in txts])
+            return embedding_model.encode([truncate(c, embedding_model.max_length - 10) for c in txts])
 
         cnts_ = np.array([])
         for i in range(0, len(texts), EMBEDDING_BATCH_SIZE):
             async with embed_limiter:
-                vts, c = await trio.to_thread.run_sync(lambda: batch_encode(texts[i: i + EMBEDDING_BATCH_SIZE]))
+                vts, c = await trio.to_thread.run_sync(lambda: batch_encode(texts[i : i + EMBEDDING_BATCH_SIZE]))
             if len(cnts_) == 0:
                 cnts_ = vts
             else:
                 cnts_ = np.concatenate((cnts_, vts), axis=0)
             token_count += c
             if i % 33 == 32:
-                self.callback(i*1./len(texts)/parts/EMBEDDING_BATCH_SIZE + 0.5*(parts-1))
+                self.callback(i * 1.0 / len(texts) / parts / EMBEDDING_BATCH_SIZE + 0.5 * (parts - 1))
         cnts = cnts_
 
         title_w = float(self._param.filename_embd_weight)
@@ -92,11 +97,17 @@ class Tokenizer(ProcessBase):
         return chunks, token_count
 
     async def _invoke(self, **kwargs):
+        try:
+            from_upstream = TokenizerFromUpstream.model_validate(kwargs)
+        except Exception as e:
+            self.set_output("_ERROR", f"Input error: {str(e)}")
+            return
+
         parts = sum(["full_text" in self._param.search_method, "embedding" in self._param.search_method])
         if "full_text" in self._param.search_method:
-            self.callback(random.randint(1,5)/100., "Start to tokenize.")
-            if kwargs.get("chunks"):
-                chunks = kwargs["chunks"]
+            self.callback(random.randint(1, 5) / 100.0, "Start to tokenize.")
+            if from_upstream.chunks:
+                chunks = from_upstream.chunks
                 for i, ck in enumerate(chunks):
                     if ck.get("questions"):
                         ck["question_tks"] = rag_tokenizer.tokenize("\n".join(ck["questions"]))
@@ -105,30 +116,40 @@
                     ck["content_ltks"] = rag_tokenizer.tokenize(ck["text"])
                     ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"])
                     if i % 100 == 99:
-                        self.callback(i*1./len(chunks)/parts)
-            elif kwargs.get("output_format") in ["markdown", "text"]:
-                ck = {
-                    "text": kwargs.get(kwargs["output_format"], "")
-                }
-                if "full_text" in self._param.search_method:
+                        self.callback(i * 1.0 / len(chunks) / parts)
+            elif from_upstream.output_format in ["markdown", "text"]:
+                if from_upstream.output_format == "markdown":
+                    payload = from_upstream.markdown_result
+                else:  # == "text"
+                    payload = from_upstream.text_result
+
+                if not payload:
+                    return ""
+
+                ck = {"text": payload}
+                if "full_text" in self._param.search_method:
                     ck["content_ltks"] = rag_tokenizer.tokenize(kwargs.get(kwargs["output_format"], ""))
                     ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"])
                 chunks = [ck]
             else:
-                chunks = kwargs["json"]
+                chunks = from_upstream.json_result
                 for i, ck in enumerate(chunks):
                     ck["content_ltks"] = rag_tokenizer.tokenize(ck["text"])
                     ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"])
                     if i % 100 == 99:
-                        self.callback(i*1./len(chunks)/parts)
+                        self.callback(i * 1.0 / len(chunks) / parts)
 
-            self.callback(1./parts, "Finish tokenizing.")
+            self.callback(1.0 / parts, "Finish tokenizing.")
 
         if "embedding" in self._param.search_method:
-            self.callback(random.randint(1,5)/100. + 0.5*(parts-1), "Start embedding inference.")
-            chunks, token_count = await self._embedding(kwargs.get("name", ""), chunks)
+            self.callback(random.randint(1, 5) / 100.0 + 0.5 * (parts - 1), "Start embedding inference.")
+
+            if from_upstream.name.strip() == "":
+                logging.warning("Tokenizer: empty name provided from upstream, embedding may be not accurate.")
+
+            chunks, token_count = await self._embedding(from_upstream.name, chunks)
             self.set_output("embedding_token_consumption", token_count)
-            self.callback(1., "Finish embedding.")
+            self.callback(1.0, "Finish embedding.")
 
         self.set_output("chunks", chunks)
diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py
index e9151b70d..84c73d2b6 100644
--- a/rag/svr/task_executor.py
+++ b/rag/svr/task_executor.py
@@ -21,10 +21,12 @@ import sys
 import threading
 import time
 
+from api.utils import get_uuid
 from api.utils.api_utils import timeout
 from api.utils.log_utils import init_root_logger, get_project_base_directory
 from graphrag.general.index import run_graphrag
 from graphrag.utils import get_llm_cache, set_llm_cache, get_tags_from_cache, set_tags_to_cache
+from rag.flow.pipeline import Pipeline
 from rag.prompts import keyword_extraction, question_proposal, content_tagging
 
 import logging
@@ -223,7 +225,14 @@ async def collect():
             logging.warning(f"collect task {msg['id']} {state}")
             redis_msg.ack()
             return None, None
-    task["task_type"] = msg.get("task_type", "")
+
+    task_type = msg.get("task_type", "")
+    task["task_type"] = task_type
+    if task_type == "dataflow":
+        task["tenant_id"]=msg.get("tenant_id", "")
+        task["dsl"] = msg.get("dsl", "")
+        task["dataflow_id"] = msg.get("dataflow_id", get_uuid())
+        task["kb_id"] = msg.get("kb_id", "")
     return redis_msg, task
 
 
@@ -473,6 +482,15 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
     return tk_count, vector_size
 
 
+async def run_dataflow(dsl:str, tenant_id:str, doc_id:str, task_id:str, flow_id:str, callback=None):
+    _ = callback
+
+    pipeline = Pipeline(dsl=dsl, tenant_id=tenant_id, doc_id=doc_id, task_id=task_id, flow_id=flow_id)
+    pipeline.reset()
+
+    await pipeline.run()
+
+
 @timeout(3600)
 async def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None):
     chunks = []
@@ -558,15 +576,20 @@ async def do_handle_task(task):
     init_kb(task, vector_size)
 
-    # Either using RAPTOR or Standard chunking methods
-    if task.get("task_type", "") == "raptor":
+    task_type = task.get("task_type", "")
+    if task_type == "dataflow":
+        task_dataflow_dsl = task["dsl"]
+        task_dataflow_id = task["dataflow_id"]
+        await run_dataflow(dsl=task_dataflow_dsl, tenant_id=task_tenant_id, doc_id=task_doc_id, task_id=task_id, flow_id=task_dataflow_id, callback=None)
+        return
+    elif task_type == "raptor":
         # bind LLM for raptor
         chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
 
         # run RAPTOR
         async with kg_limiter:
             chunks, token_count = await run_raptor(task, chat_model, embedding_model, vector_size, progress_callback)
     # Either using graphrag or Standard chunking methods
-    elif task.get("task_type", "") == "graphrag":
+    elif task_type == "graphrag":
         if not task_parser_config.get("graphrag", {}).get("use_graphrag", False):
             progress_callback(prog=-1.0, msg="Internal configuration error.")
             return