Mirror of https://github.com/infiniflow/ragflow.git (synced 2025-12-08 20:42:30 +08:00)
Feat: refine dataflow and initialize dataflow app (#9952)
### What problem does this PR solve?

Refine dataflow and initialize dataflow app.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
@@ -418,12 +418,10 @@ def setting():
        return get_data_error_result(message="canvas not found.")
    flow = flow.to_dict()
    flow["title"] = req["title"]
-   if req["description"]:
-       flow["description"] = req["description"]
-   if req["permission"]:
-       flow["permission"] = req["permission"]
-   if req["avatar"]:
-       flow["avatar"] = req["avatar"]
+   for key in ["description", "permission", "avatar"]:
+       if value := req.get(key):
+           flow[key] = value

    num= UserCanvasService.update_by_id(req["id"], flow)
    return get_json_result(data=num)
api/apps/dataflow_app.py (new file, 353 lines)
@@ -0,0 +1,353 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import re
import sys
import time
from functools import partial

import trio
from flask import request
from flask_login import current_user, login_required

from agent.canvas import Canvas
from agent.component import LLM
from api.db import CanvasCategory, FileType
from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService
from api.db.services.document_service import DocumentService
from api.db.services.file_service import FileService
from api.db.services.task_service import queue_dataflow
from api.db.services.user_canvas_version import UserCanvasVersionService
from api.db.services.user_service import TenantService
from api.settings import RetCode
from api.utils import get_uuid
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request
from api.utils.file_utils import filename_type, read_potential_broken_pdf
from rag.flow.pipeline import Pipeline


@manager.route("/templates", methods=["GET"])  # noqa: F821
@login_required
def templates():
    return get_json_result(data=[c.to_dict() for c in CanvasTemplateService.query(canvas_category=CanvasCategory.DataFlow)])


@manager.route("/list", methods=["GET"])  # noqa: F821
@login_required
def canvas_list():
    return get_json_result(data=sorted([c.to_dict() for c in UserCanvasService.query(user_id=current_user.id, canvas_category=CanvasCategory.DataFlow)], key=lambda x: x["update_time"] * -1))


@manager.route("/rm", methods=["POST"])  # noqa: F821
@validate_request("canvas_ids")
@login_required
def rm():
    for i in request.json["canvas_ids"]:
        if not UserCanvasService.accessible(i, current_user.id):
            return get_json_result(data=False, message="Only owner of canvas authorized for this operation.", code=RetCode.OPERATING_ERROR)
        UserCanvasService.delete_by_id(i)
    return get_json_result(data=True)


@manager.route("/set", methods=["POST"])  # noqa: F821
@validate_request("dsl", "title")
@login_required
def save():
    req = request.json
    if not isinstance(req["dsl"], str):
        req["dsl"] = json.dumps(req["dsl"], ensure_ascii=False)
    req["dsl"] = json.loads(req["dsl"])
    req["canvas_category"] = CanvasCategory.DataFlow
    if "id" not in req:
        req["user_id"] = current_user.id
        if UserCanvasService.query(user_id=current_user.id, title=req["title"].strip(), canvas_category=CanvasCategory.DataFlow):
            return get_data_error_result(message=f"{req['title'].strip()} already exists.")
        req["id"] = get_uuid()

        if not UserCanvasService.save(**req):
            return get_data_error_result(message="Fail to save canvas.")
    else:
        if not UserCanvasService.accessible(req["id"], current_user.id):
            return get_json_result(data=False, message="Only owner of canvas authorized for this operation.", code=RetCode.OPERATING_ERROR)
        UserCanvasService.update_by_id(req["id"], req)
    # save version
    UserCanvasVersionService.insert(user_canvas_id=req["id"], dsl=req["dsl"], title="{0}_{1}".format(req["title"], time.strftime("%Y_%m_%d_%H_%M_%S")))
    UserCanvasVersionService.delete_all_versions(req["id"])
    return get_json_result(data=req)


@manager.route("/get/<canvas_id>", methods=["GET"])  # noqa: F821
@login_required
def get(canvas_id):
    if not UserCanvasService.accessible(canvas_id, current_user.id):
        return get_data_error_result(message="canvas not found.")
    e, c = UserCanvasService.get_by_tenant_id(canvas_id)
    return get_json_result(data=c)


@manager.route("/run", methods=["POST"])  # noqa: F821
@validate_request("id")
@login_required
def run():
    req = request.json
    flow_id = req.get("id", "")
    doc_id = req.get("doc_id", "")
    if not all([flow_id, doc_id]):
        return get_data_error_result(message="id and doc_id are required.")

    if not DocumentService.get_by_id(doc_id):
        return get_data_error_result(message=f"Document for {doc_id} not found.")

    user_id = req.get("user_id", current_user.id)
    if not UserCanvasService.accessible(flow_id, current_user.id):
        return get_json_result(data=False, message="Only owner of canvas authorized for this operation.", code=RetCode.OPERATING_ERROR)

    e, cvs = UserCanvasService.get_by_id(flow_id)
    if not e:
        return get_data_error_result(message="canvas not found.")

    if not isinstance(cvs.dsl, str):
        cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)

    task_id = get_uuid()

    ok, error_message = queue_dataflow(dsl=cvs.dsl, tenant_id=user_id, doc_id=doc_id, task_id=task_id, flow_id=flow_id, priority=0)
    if not ok:
        return server_error_response(error_message)

    return get_json_result(data={"task_id": task_id, "flow_id": flow_id})


@manager.route("/reset", methods=["POST"])  # noqa: F821
@validate_request("id")
@login_required
def reset():
    req = request.json
    flow_id = req.get("id", "")
    if not flow_id:
        return get_data_error_result(message="id is required.")

    if not UserCanvasService.accessible(flow_id, current_user.id):
        return get_json_result(data=False, message="Only owner of canvas authorized for this operation.", code=RetCode.OPERATING_ERROR)

    task_id = req.get("task_id", "")

    try:
        e, user_canvas = UserCanvasService.get_by_id(req["id"])
        if not e:
            return get_data_error_result(message="canvas not found.")

        dataflow = Pipeline(dsl=json.dumps(user_canvas.dsl), tenant_id=current_user.id, flow_id=flow_id, task_id=task_id)
        dataflow.reset()
        req["dsl"] = json.loads(str(dataflow))
        UserCanvasService.update_by_id(req["id"], {"dsl": req["dsl"]})
        return get_json_result(data=req["dsl"])
    except Exception as e:
        return server_error_response(e)


@manager.route("/upload/<canvas_id>", methods=["POST"])  # noqa: F821
def upload(canvas_id):
    e, cvs = UserCanvasService.get_by_tenant_id(canvas_id)
    if not e:
        return get_data_error_result(message="canvas not found.")

    user_id = cvs["user_id"]

    def structured(filename, filetype, blob, content_type):
        nonlocal user_id
        if filetype == FileType.PDF.value:
            blob = read_potential_broken_pdf(blob)

        location = get_uuid()
        FileService.put_blob(user_id, location, blob)

        return {
            "id": location,
            "name": filename,
            "size": sys.getsizeof(blob),
            "extension": filename.split(".")[-1].lower(),
            "mime_type": content_type,
            "created_by": user_id,
            "created_at": time.time(),
            "preview_url": None,
        }

    if request.args.get("url"):
        from crawl4ai import AsyncWebCrawler, BrowserConfig, CrawlerRunConfig, CrawlResult, DefaultMarkdownGenerator, PruningContentFilter

        try:
            url = request.args.get("url")
            filename = re.sub(r"\?.*", "", url.split("/")[-1])

            async def adownload():
                browser_config = BrowserConfig(
                    headless=True,
                    verbose=False,
                )
                async with AsyncWebCrawler(config=browser_config) as crawler:
                    crawler_config = CrawlerRunConfig(markdown_generator=DefaultMarkdownGenerator(content_filter=PruningContentFilter()), pdf=True, screenshot=False)
                    result: CrawlResult = await crawler.arun(url=url, config=crawler_config)
                    return result

            page = trio.run(adownload)  # trio.run expects the async callable itself
            if page.pdf:
                if filename.split(".")[-1].lower() != "pdf":
                    filename += ".pdf"
                return get_json_result(data=structured(filename, "pdf", page.pdf, page.response_headers["content-type"]))

return get_json_result(data=structured(filename, "html", str(page.markdown).encode("utf-8"), page.response_headers["content-type"], user_id))

        except Exception as e:
            return server_error_response(e)

    file = request.files["file"]
    try:
        DocumentService.check_doc_health(user_id, file.filename)
        return get_json_result(data=structured(file.filename, filename_type(file.filename), file.read(), file.content_type))
    except Exception as e:
        return server_error_response(e)


@manager.route("/input_form", methods=["GET"])  # noqa: F821
@login_required
def input_form():
    flow_id = request.args.get("id")
    cpn_id = request.args.get("component_id")
    try:
        e, user_canvas = UserCanvasService.get_by_id(flow_id)
        if not e:
            return get_data_error_result(message="canvas not found.")
        if not UserCanvasService.query(user_id=current_user.id, id=flow_id):
            return get_json_result(data=False, message="Only owner of canvas authorized for this operation.", code=RetCode.OPERATING_ERROR)

        dataflow = Pipeline(dsl=json.dumps(user_canvas.dsl), tenant_id=current_user.id, flow_id=flow_id, task_id="")

        return get_json_result(data=dataflow.get_component_input_form(cpn_id))
    except Exception as e:
        return server_error_response(e)


@manager.route("/debug", methods=["POST"])  # noqa: F821
@validate_request("id", "component_id", "params")
@login_required
def debug():
    req = request.json
    if not UserCanvasService.accessible(req["id"], current_user.id):
        return get_json_result(data=False, message="Only owner of canvas authorized for this operation.", code=RetCode.OPERATING_ERROR)
    try:
        e, user_canvas = UserCanvasService.get_by_id(req["id"])
        canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id)
        canvas.reset()
        canvas.message_id = get_uuid()
        component = canvas.get_component(req["component_id"])["obj"]
        component.reset()

        if isinstance(component, LLM):
            component.set_debug_inputs(req["params"])
        component.invoke(**{k: o["value"] for k, o in req["params"].items()})
        outputs = component.output()
        for k in outputs.keys():
            if isinstance(outputs[k], partial):
                txt = ""
                for c in outputs[k]():
                    txt += c
                outputs[k] = txt
        return get_json_result(data=outputs)
    except Exception as e:
        return server_error_response(e)


# api get list version dsl of canvas
@manager.route("/getlistversion/<canvas_id>", methods=["GET"])  # noqa: F821
@login_required
def getlistversion(canvas_id):
    try:
        list = sorted([c.to_dict() for c in UserCanvasVersionService.list_by_canvas_id(canvas_id)], key=lambda x: x["update_time"] * -1)
        return get_json_result(data=list)
    except Exception as e:
        return get_data_error_result(message=f"Error getting history files: {e}")


# api get version dsl of canvas
@manager.route("/getversion/<version_id>", methods=["GET"])  # noqa: F821
@login_required
def getversion(version_id):
    try:
        e, version = UserCanvasVersionService.get_by_id(version_id)
        if version:
            return get_json_result(data=version.to_dict())
    except Exception as e:
        return get_json_result(data=f"Error getting history file: {e}")


@manager.route("/listteam", methods=["GET"])  # noqa: F821
@login_required
def list_canvas():
    keywords = request.args.get("keywords", "")
    page_number = int(request.args.get("page", 1))
    items_per_page = int(request.args.get("page_size", 150))
    orderby = request.args.get("orderby", "create_time")
    desc = request.args.get("desc", True)
    try:
        tenants = TenantService.get_joined_tenants_by_user_id(current_user.id)
        canvas, total = UserCanvasService.get_by_tenant_ids(
            [m["tenant_id"] for m in tenants], current_user.id, page_number, items_per_page, orderby, desc, keywords, canvas_category=CanvasCategory.DataFlow
        )
        return get_json_result(data={"canvas": canvas, "total": total})
    except Exception as e:
        return server_error_response(e)


@manager.route("/setting", methods=["POST"])  # noqa: F821
@validate_request("id", "title", "permission")
@login_required
def setting():
    req = request.json
    req["user_id"] = current_user.id

    if not UserCanvasService.accessible(req["id"], current_user.id):
        return get_json_result(data=False, message="Only owner of canvas authorized for this operation.", code=RetCode.OPERATING_ERROR)

    e, flow = UserCanvasService.get_by_id(req["id"])
    if not e:
        return get_data_error_result(message="canvas not found.")
    flow = flow.to_dict()
    flow["title"] = req["title"]
    for key in ("description", "permission", "avatar"):
        if value := req.get(key):
            flow[key] = value

    num = UserCanvasService.update_by_id(req["id"], flow)
    return get_json_result(data=num)


@manager.route("/trace", methods=["GET"])  # noqa: F821
def trace():
    dataflow_id = request.args.get("dataflow_id")
    task_id = request.args.get("task_id")
    if not all([dataflow_id, task_id]):
        return get_data_error_result(message="dataflow_id and task_id are required.")

    e, dataflow_canvas = UserCanvasService.get_by_id(dataflow_id)
    if not e:
        return get_data_error_result(message="dataflow not found.")

    dsl_str = json.dumps(dataflow_canvas.dsl, ensure_ascii=False)
    dataflow = Pipeline(dsl=dsl_str, tenant_id=dataflow_canvas.user_id, flow_id=dataflow_id, task_id=task_id)
    log = dataflow.fetch_logs()

    return get_json_result(data=log)
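As a usage illustration (not part of the diff): a minimal client-side sketch for the new /run and /trace endpoints. The host, the `/v1/dataflow` mount point and the auth header are assumptions; only the endpoint paths, their parameters and the returned `task_id`/`flow_id` come from the handlers above.

```python
import time

import requests

BASE = "http://localhost:9380/v1/dataflow"  # assumed host and mount point
HEADERS = {"Authorization": "Bearer <api-token>"}  # assumed auth scheme

# Start a dataflow run for one document; the handler returns the ids to poll with.
run = requests.post(
    f"{BASE}/run",
    json={"id": "<canvas_id>", "doc_id": "<doc_id>"},
    headers=HEADERS,
).json()["data"]  # {"task_id": ..., "flow_id": ...}

# Poll /trace with both ids; it returns whatever Pipeline.fetch_logs() produced.
for _ in range(30):
    logs = requests.get(
        f"{BASE}/trace",
        params={"dataflow_id": run["flow_id"], "task_id": run["task_id"]},
        headers=HEADERS,
    ).json()["data"]
    print(logs)
    time.sleep(2)
```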
@@ -54,15 +54,15 @@ def trim_header_by_lines(text: str, max_length) -> str:

class TaskService(CommonService):
    """Service class for managing document processing tasks.

    This class extends CommonService to provide specialized functionality for document
    processing task management, including task creation, progress tracking, and chunk
    management. It handles various document types (PDF, Excel, etc.) and manages their
    processing lifecycle.

    The class implements a robust task queue system with retry mechanisms and progress
    tracking, supporting both synchronous and asynchronous task execution.

    Attributes:
        model: The Task model class for database operations.
    """
@@ -72,14 +72,14 @@ class TaskService(CommonService):
    @DB.connection_context()
    def get_task(cls, task_id):
        """Retrieve detailed task information by task ID.

        This method fetches comprehensive task details including associated document,
        knowledge base, and tenant information. It also handles task retry logic and
        progress updates.

        Args:
            task_id (str): The unique identifier of the task to retrieve.

        Returns:
            dict: Task details dictionary containing all task information and related metadata.
            Returns None if task is not found or has exceeded retry limit.
@@ -139,13 +139,13 @@ class TaskService(CommonService):
    @DB.connection_context()
    def get_tasks(cls, doc_id: str):
        """Retrieve all tasks associated with a document.

        This method fetches all processing tasks for a given document, ordered by page
        number and creation time. It includes task progress and chunk information.

        Args:
            doc_id (str): The unique identifier of the document.

        Returns:
            list[dict]: List of task dictionaries containing task details.
            Returns None if no tasks are found.
@@ -170,10 +170,10 @@ class TaskService(CommonService):
    @DB.connection_context()
    def update_chunk_ids(cls, id: str, chunk_ids: str):
        """Update the chunk IDs associated with a task.

        This method updates the chunk_ids field of a task, which stores the IDs of
        processed document chunks in a space-separated string format.

        Args:
            id (str): The unique identifier of the task.
            chunk_ids (str): Space-separated string of chunk identifiers.
@@ -184,11 +184,11 @@ class TaskService(CommonService):
    @DB.connection_context()
    def get_ongoing_doc_name(cls):
        """Get names of documents that are currently being processed.

        This method retrieves information about documents that are in the processing state,
        including their locations and associated IDs. It uses database locking to ensure
        thread safety when accessing the task information.

        Returns:
            list[tuple]: A list of tuples, each containing (parent_id/kb_id, location)
            for documents currently being processed. Returns empty list if
@@ -238,14 +238,14 @@ class TaskService(CommonService):
    @DB.connection_context()
    def do_cancel(cls, id):
        """Check if a task should be cancelled based on its document status.

        This method determines whether a task should be cancelled by checking the
        associated document's run status and progress. A task should be cancelled
        if its document is marked for cancellation or has negative progress.

        Args:
            id (str): The unique identifier of the task to check.

        Returns:
            bool: True if the task should be cancelled, False otherwise.
        """
@@ -311,18 +311,18 @@ class TaskService(CommonService):

def queue_tasks(doc: dict, bucket: str, name: str, priority: int):
    """Create and queue document processing tasks.

    This function creates processing tasks for a document based on its type and configuration.
    It handles different document types (PDF, Excel, etc.) differently and manages task
    chunking and configuration. It also implements task reuse optimization by checking
    for previously completed tasks.

    Args:
        doc (dict): Document dictionary containing metadata and configuration.
        bucket (str): Storage bucket name where the document is stored.
        name (str): File name of the document.
        priority (int, optional): Priority level for task queueing (default is 0).

    Note:
        - For PDF documents, tasks are created per page range based on configuration
        - For Excel documents, tasks are created per row range
@@ -410,19 +410,19 @@ def queue_tasks(doc: dict, bucket: str, name: str, priority: int):

def reuse_prev_task_chunks(task: dict, prev_tasks: list[dict], chunking_config: dict):
    """Attempt to reuse chunks from previous tasks for optimization.

    This function checks if chunks from previously completed tasks can be reused for
    the current task, which can significantly improve processing efficiency. It matches
    tasks based on page ranges and configuration digests.

    Args:
        task (dict): Current task dictionary to potentially reuse chunks for.
        prev_tasks (list[dict]): List of previous task dictionaries to check for reuse.
        chunking_config (dict): Configuration dictionary for chunk processing.

    Returns:
        int: Number of chunks successfully reused. Returns 0 if no chunks could be reused.

    Note:
        Chunks can only be reused if:
        - A previous task exists with matching page range and configuration digest
@@ -470,3 +470,39 @@ def has_canceled(task_id):
    except Exception as e:
        logging.exception(e)
    return False
+
+
+def queue_dataflow(dsl:str, tenant_id:str, doc_id:str, task_id:str, flow_id:str, priority: int, callback=None) -> tuple[bool, str]:
+    """
+    Returns a tuple (success: bool, error_message: str).
+    """
+    _ = callback
+
+    task = dict(
+        id=get_uuid() if not task_id else task_id,
+        doc_id=doc_id,
+        from_page=0,
+        to_page=100000000,
+        task_type="dataflow",
+        priority=priority,
+    )
+
+    TaskService.model.delete().where(TaskService.model.id == task["id"]).execute()
+    bulk_insert_into_db(model=Task, data_source=[task], replace_on_conflict=True)
+
+    kb_id = DocumentService.get_knowledgebase_id(doc_id)
+    if not kb_id:
+        return False, f"Can't find KB of this document: {doc_id}"
+
+    task["kb_id"] = kb_id
+    task["tenant_id"] = tenant_id
+    task["task_type"] = "dataflow"
+    task["dsl"] = dsl
+    task["dataflow_id"] = get_uuid() if not flow_id else flow_id
+
+    if not REDIS_CONN.queue_product(
+        get_svr_queue_name(priority), message=task
+    ):
+        return False, "Can't access Redis. Please check the Redis' status."
+
+    return True, ""
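For orientation, a hedged sketch of how a caller consumes `queue_dataflow` (the /run handler above does essentially this); the ids and the DSL string are placeholders.

```python
# Hypothetical caller: enqueue one dataflow task and handle the (ok, error) tuple.
ok, err = queue_dataflow(
    dsl=canvas_dsl_json,   # JSON string of the pipeline DSL (placeholder)
    tenant_id="<tenant_id>",
    doc_id="<doc_id>",
    task_id="",            # empty -> queue_dataflow generates a uuid itself
    flow_id="<flow_id>",
    priority=0,
)
if not ok:
    raise RuntimeError(err)  # e.g. unknown KB for the document, or Redis unreachable
```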
@@ -2690,21 +2690,21 @@
        "status": "1",
        "llm": [
            {
-               "llm_name": "Qwen3-Embedding-8B",
+               "llm_name": "Qwen/Qwen3-Embedding-8B",
                "tags": "TEXT EMBEDDING,TEXT RE-RANK,32k",
                "max_tokens": 32000,
                "model_type": "embedding",
                "is_tools": false
            },
            {
-               "llm_name": "Qwen3-Embedding-4B",
+               "llm_name": "Qwen/Qwen3-Embedding-4B",
                "tags": "TEXT EMBEDDING,TEXT RE-RANK,32k",
                "max_tokens": 32000,
                "model_type": "embedding",
                "is_tools": false
            },
            {
-               "llm_name": "Qwen3-Embedding-0.6B",
+               "llm_name": "Qwen/Qwen3-Embedding-0.6B",
                "tags": "TEXT EMBEDDING,TEXT RE-RANK,32k",
                "max_tokens": 32000,
                "model_type": "embedding",
@@ -14,36 +14,45 @@
# limitations under the License.
#

-import os
import importlib
import inspect
+import pkgutil
+from pathlib import Path
from types import ModuleType
from typing import Dict, Type

-_package_path = os.path.dirname(__file__)
__all_classes: Dict[str, Type] = {}

-def _import_submodules() -> None:
-    for filename in os.listdir(_package_path):  # noqa: F821
-        if filename.startswith("__") or not filename.endswith(".py") or filename.startswith("base"):
-            continue
-        module_name = filename[:-3]
+_pkg_dir = Path(__file__).resolve().parent
+_pkg_name = __name__
+
+
+def _should_skip_module(mod_name: str) -> bool:
+    leaf = mod_name.rsplit(".", 1)[-1]
+    return leaf in {"__init__"} or leaf.startswith("__") or leaf.startswith("_") or leaf.startswith("base")
+
+
+def _import_submodules() -> None:
+    for modinfo in pkgutil.walk_packages([str(_pkg_dir)], prefix=_pkg_name + "."):  # noqa: F821
+        mod_name = modinfo.name
+        if _should_skip_module(mod_name):  # noqa: F821
+            continue
        try:
-            module = importlib.import_module(f".{module_name}", package=__name__)
+            module = importlib.import_module(mod_name)
            _extract_classes_from_module(module)  # noqa: F821
        except ImportError as e:
-            print(f"Warning: Failed to import module {module_name}: {str(e)}")
+            print(f"Warning: Failed to import module {mod_name}: {e}")


def _extract_classes_from_module(module: ModuleType) -> None:
    for name, obj in inspect.getmembers(module):
-        if (inspect.isclass(obj) and
-                obj.__module__ == module.__name__ and not name.startswith("_")):
+        if inspect.isclass(obj) and obj.__module__ == module.__name__ and not name.startswith("_"):
            __all_classes[name] = obj
            globals()[name] = obj


_import_submodules()

__all__ = list(__all_classes.keys()) + ["__all_classes"]

-del _package_path, _import_submodules, _extract_classes_from_module
+del _pkg_dir, _pkg_name, _import_submodules, _extract_classes_from_module
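The refactor above swaps `os.listdir` for `pkgutil.walk_packages`. As a standalone illustration of the same auto-registration idea (a sketch with made-up names, not the RAGFlow module itself):

```python
import importlib
import inspect
import pkgutil
from pathlib import Path


def collect_classes(package_name: str) -> dict[str, type]:
    """Walk every module under a package, import it, and collect its classes."""
    pkg = importlib.import_module(package_name)
    pkg_dir = Path(pkg.__file__).resolve().parent
    registry: dict[str, type] = {}
    for modinfo in pkgutil.walk_packages([str(pkg_dir)], prefix=package_name + "."):
        leaf = modinfo.name.rsplit(".", 1)[-1]
        if leaf.startswith("_") or leaf.startswith("base"):
            continue  # skip dunder/private/base modules, as the diff above does
        module = importlib.import_module(modinfo.name)
        for name, obj in inspect.getmembers(module, inspect.isclass):
            # keep only classes defined in that module, not re-exported imports
            if obj.__module__ == module.__name__ and not name.startswith("_"):
                registry[name] = obj
    return registry

# e.g. collect_classes("some.package") exposes every public class its submodules define.
```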
@@ -1,5 +1,5 @@
#
-# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
+# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,13 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-import time
-import os
import logging
+import os
+import time
from functools import partial
from typing import Any

import trio
-from agent.component.base import ComponentParamBase, ComponentBase
+
+from agent.component.base import ComponentBase, ComponentParamBase
from api.utils.api_utils import timeout

@@ -31,14 +33,16 @@ class ProcessParamBase(ComponentParamBase):


class ProcessBase(ComponentBase):

    def __init__(self, pipeline, id, param: ProcessParamBase):
        super().__init__(pipeline, id, param)
-        self.callback = partial(self._canvas.callback, self.component_name)
+        if hasattr(self._canvas, "callback"):
+            self.callback = partial(self._canvas.callback, self.component_name)
+        else:
+            self.callback = partial(lambda *args, **kwargs: None, self.component_name)

    async def invoke(self, **kwargs) -> dict[str, Any]:
        self.set_output("_created_time", time.perf_counter())
-        for k,v in kwargs.items():
+        for k, v in kwargs.items():
            self.set_output(k, v)
        try:
            with trio.fail_after(self._param.timeout):
@@ -54,6 +58,6 @@ class ProcessBase(ComponentBase):
        self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
        return self.output()

-    @timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))
+    @timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10 * 60))
    async def _invoke(self, **kwargs):
        raise NotImplementedError()
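The `with trio.fail_after(self._param.timeout)` guard above bounds each component invocation. A self-contained sketch of that pattern, with a made-up 2-second budget and a deliberately slow coroutine:

```python
import trio


async def slow_step():
    await trio.sleep(5)  # stand-in for a component's _invoke work


async def main():
    try:
        with trio.fail_after(2):  # raises trio.TooSlowError once 2 s elapse
            await slow_step()
    except trio.TooSlowError:
        print("component timed out")


trio.run(main)
```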
rag/flow/chunker/__init__.py (new file, 15 lines)
@@ -0,0 +1,15 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@@ -1,5 +1,5 @@
#
-# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
+# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,12 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import random

import trio

from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from deepdoc.parser.pdf_parser import RAGFlowPdfParser
-from graphrag.utils import get_llm_cache, chat_limiter, set_llm_cache
+from graphrag.utils import chat_limiter, get_llm_cache, set_llm_cache
from rag.flow.base import ProcessBase, ProcessParamBase
+from rag.flow.chunker.schema import ChunkerFromUpstream
from rag.nlp import naive_merge, naive_merge_with_images
from rag.prompts.prompts import keyword_extraction, question_proposal

@@ -26,7 +29,23 @@ from rag.prompts.prompts import keyword_extraction, question_proposal
class ChunkerParam(ProcessParamBase):
    def __init__(self):
        super().__init__()
-        self.method_options = ["general", "q&a", "resume", "manual", "table", "paper", "book", "laws", "presentation", "one"]
+        self.method_options = [
+            # General
+            "general",
+            "onetable",
+            # Customer Service
+            "q&a",
+            "manual",
+            # Recruitment
+            "resume",
+            # Education & Research
+            "book",
+            "paper",
+            "laws",
+            "presentation",
+            # Other
+            # "Tag"  # TODO: Other method
+        ]
        self.method = "general"
        self.chunk_token_size = 512
        self.delimiter = "\n"
@@ -35,10 +54,7 @@ class ChunkerParam(ProcessParamBase):
        self.auto_keywords = 0
        self.auto_questions = 0
        self.tag_sets = []
-        self.llm_setting = {
-            "llm_name": "",
-            "lang": "Chinese"
-        }
+        self.llm_setting = {"llm_name": "", "lang": "Chinese"}

    def check(self):
        self.check_valid_value(self.method.lower(), "Chunk method abnormal.", self.method_options)
@@ -48,53 +64,79 @@ class ChunkerParam(ProcessParamBase):
        self.check_nonnegative_number(self.auto_questions, "Auto-question value: (0, 10]")
        self.check_decimal_float(self.overlapped_percent, "Overlapped percentage: [0, 1)")

+    def get_input_form(self) -> dict[str, dict]:
+        return {}
+
+
class Chunker(ProcessBase):
    component_name = "Chunker"

-    def _general(self, **kwargs):
-        self.callback(random.randint(1,5)/100., "Start to chunk via `General`.")
-        if kwargs.get("output_format") in ["markdown", "text"]:
-            cks = naive_merge(kwargs.get(kwargs["output_format"]), self._param.chunk_token_size, self._param.delimiter, self._param.overlapped_percent)
+    def _general(self, from_upstream: ChunkerFromUpstream):
+        self.callback(random.randint(1, 5) / 100.0, "Start to chunk via `General`.")
+        if from_upstream.output_format in ["markdown", "text"]:
+            if from_upstream.output_format == "markdown":
+                payload = from_upstream.markdown_result
+            else:  # == "text"
+                payload = from_upstream.text_result
+
+            if not payload:
+                payload = ""
+
+            cks = naive_merge(
+                payload,
+                self._param.chunk_token_size,
+                self._param.delimiter,
+                self._param.overlapped_percent,
+            )
            return [{"text": c} for c in cks]

        sections, section_images = [], []
-        for o in kwargs["json"]:
-            sections.append((o["text"], o.get("position_tag","")))
+        for o in from_upstream.json_result or []:
+            sections.append((o.get("text", ""), o.get("position_tag", "")))
            section_images.append(o.get("image"))

-        chunks, images = naive_merge_with_images(sections, section_images,self._param.chunk_token_size, self._param.delimiter, self._param.overlapped_percent)
-        return [{
-            "text": RAGFlowPdfParser.remove_tag(c),
-            "image": img,
-            "positions": RAGFlowPdfParser.extract_positions(c)
-        } for c,img in zip(chunks,images)]
+        chunks, images = naive_merge_with_images(
+            sections,
+            section_images,
+            self._param.chunk_token_size,
+            self._param.delimiter,
+            self._param.overlapped_percent,
+        )
+        return [
+            {
+                "text": RAGFlowPdfParser.remove_tag(c),
+                "image": img,
+                "positions": RAGFlowPdfParser.extract_positions(c),
+            }
+            for c, img in zip(chunks, images)
+        ]

-    def _q_and_a(self, **kwargs):
+    def _q_and_a(self, from_upstream: ChunkerFromUpstream):
        pass

-    def _resume(self, **kwargs):
+    def _resume(self, from_upstream: ChunkerFromUpstream):
        pass

-    def _manual(self, **kwargs):
+    def _manual(self, from_upstream: ChunkerFromUpstream):
        pass

-    def _table(self, **kwargs):
+    def _table(self, from_upstream: ChunkerFromUpstream):
        pass

-    def _paper(self, **kwargs):
+    def _paper(self, from_upstream: ChunkerFromUpstream):
        pass

-    def _book(self, **kwargs):
+    def _book(self, from_upstream: ChunkerFromUpstream):
        pass

-    def _laws(self, **kwargs):
+    def _laws(self, from_upstream: ChunkerFromUpstream):
        pass

-    def _presentation(self, **kwargs):
+    def _presentation(self, from_upstream: ChunkerFromUpstream):
        pass

-    def _one(self, **kwargs):
+    def _one(self, from_upstream: ChunkerFromUpstream):
        pass

    async def _invoke(self, **kwargs):
@@ -110,7 +152,14 @@ class Chunker(ProcessBase):
            "presentation": self._presentation,
            "one": self._one,
        }
-        chunks = function_map[self._param.method](**kwargs)
+
+        try:
+            from_upstream = ChunkerFromUpstream.model_validate(kwargs)
+        except Exception as e:
+            self.set_output("_ERROR", f"Input error: {str(e)}")
+            return
+
+        chunks = function_map[self._param.method](from_upstream)
        llm_setting = self._param.llm_setting

        async def auto_keywords():
rag/flow/chunker/schema.py (new file, 37 lines)
@@ -0,0 +1,37 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Literal

from pydantic import BaseModel, ConfigDict, Field


class ChunkerFromUpstream(BaseModel):
    created_time: float | None = Field(default=None, alias="_created_time")
    elapsed_time: float | None = Field(default=None, alias="_elapsed_time")

    name: str
    blob: bytes

    output_format: Literal["json", "markdown", "text", "html"] | None = Field(default=None)

    json_result: list[dict[str, Any]] | None = Field(default=None, alias="json")
    markdown_result: str | None = Field(default=None, alias="markdown")
    text_result: str | None = Field(default=None, alias="text")
    html_result: str | None = Field(default=None, alias="html")

    model_config = ConfigDict(populate_by_name=True, extra="forbid")

    # def to_dict(self, *, exclude_none: bool = True) -> dict:
    #     return self.model_dump(by_alias=True, exclude_none=exclude_none)
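A hedged sketch of how a kwargs dict from the upstream step validates into `ChunkerFromUpstream` (the chunker's `_invoke` above calls `model_validate(kwargs)` the same way); the payload values here are placeholders.

```python
# Upstream outputs arrive under their aliased keys ("output_format", "markdown", ...);
# model_validate maps them onto the fields declared above.
payload = {
    "name": "demo.pdf",
    "blob": b"%PDF-1.7 ...",          # placeholder bytes
    "output_format": "markdown",
    "markdown": "## Title\nSome text",
    "_created_time": 0.0,
}
from_upstream = ChunkerFromUpstream.model_validate(payload)
assert from_upstream.markdown_result == "## Title\nSome text"
# extra="forbid" means any unexpected key raises a pydantic ValidationError.
```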
@@ -1,5 +1,5 @@
#
-# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
+# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -27,6 +27,9 @@ class FileParam(ProcessParamBase):
    def check(self):
        pass

+    def get_input_form(self) -> dict[str, dict]:
+        return {}
+

class File(ProcessBase):
    component_name = "File"
@@ -1,107 +0,0 @@
-#
-# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import random
-import trio
-from api.db import LLMType
-from api.db.services.llm_service import LLMBundle
-from deepdoc.parser.pdf_parser import RAGFlowPdfParser, PlainParser, VisionParser
-from rag.flow.base import ProcessBase, ProcessParamBase
-from rag.llm.cv_model import Base as VLM
-from deepdoc.parser import ExcelParser
-
-
-class ParserParam(ProcessParamBase):
-    def __init__(self):
-        super().__init__()
-        self.setups = {
-            "pdf": {
-                "parse_method": "deepdoc",  # deepdoc/plain_text/vlm
-                "vlm_name": "",
-                "lang": "Chinese",
-                "suffix": ["pdf"],
-                "output_format": "json"
-            },
-            "excel": {
-                "output_format": "html"
-            },
-            "ppt": {},
-            "image": {
-                "parse_method": "ocr"
-            },
-            "email": {},
-            "text": {},
-            "audio": {},
-            "video": {},
-        }
-
-    def check(self):
-        if self.setups["pdf"].get("parse_method") not in ["deepdoc", "plain_text"]:
-            assert self.setups["pdf"].get("vlm_name"), "No VLM specified."
-        assert self.setups["pdf"].get("lang"), "No language specified."
-
-
-class Parser(ProcessBase):
-    component_name = "Parser"
-
-    def _pdf(self, blob):
-        self.callback(random.randint(1,5)/100., "Start to work on a PDF.")
-        conf = self._param.setups["pdf"]
-        self.set_output("output_format", conf["output_format"])
-        if conf.get("parse_method") == "deepdoc":
-            bboxes = RAGFlowPdfParser().parse_into_bboxes(blob, callback=self.callback)
-        elif conf.get("parse_method") == "plain_text":
-            lines,_ = PlainParser()(blob)
-            bboxes = [{"text": t} for t,_ in lines]
-        else:
-            assert conf.get("vlm_name")
-            vision_model = LLMBundle(self._canvas.tenant_id, LLMType.IMAGE2TEXT, llm_name=conf.get("vlm_name"), lang=self.setups["pdf"].get("lang"))
-            lines, _ = VisionParser(vision_model=vision_model)(bin, callback=self.callback)
-            bboxes = []
-            for t, poss in lines:
-                pn, x0, x1, top, bott = poss.split(" ")
-                bboxes.append({"page_number": int(pn), "x0": int(x0), "x1": int(x1), "top": int(top), "bottom": int(bott), "text": t})
-
-        self.set_output("json", bboxes)
-        mkdn = ""
-        for b in bboxes:
-            if b.get("layout_type", "") == "title":
-                mkdn += "\n## "
-            if b.get("layout_type", "") == "figure":
-                mkdn += "\n".format(VLM.image2base64(b["image"]))
-                continue
-            mkdn += b.get("text", "") + "\n"
-        self.set_output("markdown", mkdn)
-
-    def _excel(self, blob):
-        self.callback(random.randint(1,5)/100., "Start to work on a Excel.")
-        conf = self._param.setups["excel"]
-        excel_parser = ExcelParser()
-        if conf.get("output_format") == "html":
-            html = excel_parser.html(blob,1000000000)
-            self.set_output("html", html)
-        elif conf.get("output_format") == "json":
-            self.set_output("json", [{"text": txt} for txt in excel_parser(blob) if txt])
-        elif conf.get("output_format") == "markdown":
-            self.set_output("markdown", excel_parser.markdown(blob))
-
-    async def _invoke(self, **kwargs):
-        function_map = {
-            "pdf": self._pdf,
-        }
-        for p_type, conf in self._param.setups.items():
-            if kwargs.get("name", "").split(".")[-1].lower() not in conf.get("suffix", []):
-                continue
-            await trio.to_thread.run_sync(function_map[p_type], kwargs["blob"])
-            break
rag/flow/parser/__init__.py (new file, 14 lines)
@@ -0,0 +1,14 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
154
rag/flow/parser/parser.py
Normal file
154
rag/flow/parser/parser.py
Normal file
@ -0,0 +1,154 @@
|
|||||||
|
#
|
||||||
|
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
import random

import trio

from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from deepdoc.parser import ExcelParser
from deepdoc.parser.pdf_parser import PlainParser, RAGFlowPdfParser, VisionParser
from rag.flow.base import ProcessBase, ProcessParamBase
from rag.flow.parser.schema import ParserFromUpstream
from rag.llm.cv_model import Base as VLM


class ParserParam(ProcessParamBase):
    def __init__(self):
        super().__init__()
        self.allowed_output_format = {
            "pdf": ["json", "markdown"],
            "excel": ["json", "markdown", "html"],
            "ppt": [],
            "image": [],
            "email": [],
            "text": [],
            "audio": [],
            "video": [],
        }

        self.setups = {
            "pdf": {
                "parse_method": "deepdoc",  # deepdoc/plain_text/vlm
                "vlm_name": "",
                "lang": "Chinese",
                "suffix": ["pdf"],
                "output_format": "json",
            },
            "excel": {
                "output_format": "html",
                "suffix": ["xls", "xlsx", "csv"],
            },
            "ppt": {},
            "image": {
                "parse_method": "ocr",
            },
            "email": {},
            "text": {},
            "audio": {},
            "video": {},
        }

    def check(self):
        pdf_config = self.setups.get("pdf", {})
        if pdf_config:
            pdf_parse_method = pdf_config.get("parse_method", "")
            self.check_valid_value(pdf_parse_method.lower(), "Parse method abnormal.", ["deepdoc", "plain_text", "vlm"])

            if pdf_parse_method not in ["deepdoc", "plain_text"]:
                self.check_empty(pdf_config.get("vlm_name"), "VLM")

            pdf_language = pdf_config.get("lang", "")
            self.check_empty(pdf_language, "Language")

            pdf_output_format = pdf_config.get("output_format", "")
            self.check_valid_value(pdf_output_format, "PDF output format abnormal.", self.allowed_output_format["pdf"])

        excel_config = self.setups.get("excel", "")
        if excel_config:
            excel_output_format = excel_config.get("output_format", "")
            self.check_valid_value(excel_output_format, "Excel output format abnormal.", self.allowed_output_format["excel"])

        image_config = self.setups.get("image", "")
        if image_config:
            image_parse_method = image_config.get("parse_method", "")
            self.check_valid_value(image_parse_method.lower(), "Parse method abnormal.", ["ocr"])

    def get_input_form(self) -> dict[str, dict]:
        return {}


class Parser(ProcessBase):
    component_name = "Parser"

    def _pdf(self, blob):
        self.callback(random.randint(1, 5) / 100.0, "Start to work on a PDF.")
        conf = self._param.setups["pdf"]
        self.set_output("output_format", conf["output_format"])
        if conf.get("parse_method") == "deepdoc":
            bboxes = RAGFlowPdfParser().parse_into_bboxes(blob, callback=self.callback)
        elif conf.get("parse_method") == "plain_text":
            lines, _ = PlainParser()(blob)
            bboxes = [{"text": t} for t, _ in lines]
        else:
            assert conf.get("vlm_name")
            vision_model = LLMBundle(self._canvas._tenant_id, LLMType.IMAGE2TEXT, llm_name=conf.get("vlm_name"), lang=self._param.setups["pdf"].get("lang"))
            lines, _ = VisionParser(vision_model=vision_model)(blob, callback=self.callback)
            bboxes = []
            for t, poss in lines:
                pn, x0, x1, top, bott = poss.split(" ")
                bboxes.append({"page_number": int(pn), "x0": float(x0), "x1": float(x1), "top": float(top), "bottom": float(bott), "text": t})
        if conf.get("output_format") == "json":
            self.set_output("json", bboxes)
        if conf.get("output_format") == "markdown":
            mkdn = ""
            for b in bboxes:
                if b.get("layout_type", "") == "title":
                    mkdn += "\n## "
                if b.get("layout_type", "") == "figure":
                    mkdn += "\n![figure](data:image/png;base64,{})\n".format(VLM.image2base64(b["image"]))
                    continue
                mkdn += b.get("text", "") + "\n"
            self.set_output("markdown", mkdn)

    def _excel(self, blob):
        self.callback(random.randint(1, 5) / 100.0, "Start to work on a Excel.")
        conf = self._param.setups["excel"]
        self.set_output("output_format", conf["output_format"])
        excel_parser = ExcelParser()
        if conf.get("output_format") == "html":
            html = excel_parser.html(blob, 1000000000)
            self.set_output("html", html)
        elif conf.get("output_format") == "json":
            self.set_output("json", [{"text": txt} for txt in excel_parser(blob) if txt])
        elif conf.get("output_format") == "markdown":
            self.set_output("markdown", excel_parser.markdown(blob))

    async def _invoke(self, **kwargs):
        function_map = {
            "pdf": self._pdf,
            "excel": self._excel,
        }
        try:
            from_upstream = ParserFromUpstream.model_validate(kwargs)
        except Exception as e:
            self.set_output("_ERROR", f"Input error: {str(e)}")
            return

        for p_type, conf in self._param.setups.items():
            if from_upstream.name.split(".")[-1].lower() not in conf.get("suffix", []):
                continue
            await trio.to_thread.run_sync(function_map[p_type], from_upstream.blob)
            break
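
As a side note on the dispatch above: `Parser._invoke` routes purely on the file suffix declared per entry in `ParserParam.setups`, runs exactly one blocking handler in a worker thread, then stops. The standalone sketch below illustrates the same routing; the handler bodies and the trimmed `setups` table are made up, and `asyncio.to_thread` stands in for `trio.to_thread.run_sync`.

# Minimal sketch of suffix-based handler dispatch, outside the component framework.
import asyncio


def handle_pdf(blob: bytes) -> str:
    return f"parsed {len(blob)} bytes as PDF"


def handle_excel(blob: bytes) -> str:
    return f"parsed {len(blob)} bytes as spreadsheet"


SETUPS = {
    "pdf": {"suffix": ["pdf"]},
    "excel": {"suffix": ["xls", "xlsx", "csv"]},
}

HANDLERS = {"pdf": handle_pdf, "excel": handle_excel}


async def dispatch(name: str, blob: bytes) -> str | None:
    # Route on the file extension; exactly one handler runs, then we stop.
    ext = name.split(".")[-1].lower()
    for p_type, conf in SETUPS.items():
        if ext not in conf.get("suffix", []):
            continue
        # The component pushes the blocking parser into a worker thread;
        # asyncio.to_thread plays that role in this sketch.
        return await asyncio.to_thread(HANDLERS[p_type], blob)
    return None


if __name__ == "__main__":
    print(asyncio.run(dispatch("report.xlsx", b"\x00" * 16)))
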

25  rag/flow/parser/schema.py  Normal file
@@ -0,0 +1,25 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pydantic import BaseModel, ConfigDict, Field


class ParserFromUpstream(BaseModel):
    created_time: float | None = Field(default=None, alias="_created_time")
    elapsed_time: float | None = Field(default=None, alias="_elapsed_time")

    name: str
    blob: bytes

    model_config = ConfigDict(populate_by_name=True, extra="forbid")
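
A quick, illustrative look at what `ParserFromUpstream` accepts; the field values below are made up. The `_created_time`/`_elapsed_time` aliases populate the plain field names, and `extra="forbid"` is what turns an unexpected upstream key into the `_ERROR` output seen in `Parser._invoke`.

# Illustrative use of ParserFromUpstream; values are placeholders.
from pydantic import ValidationError

from rag.flow.parser.schema import ParserFromUpstream

ok = ParserFromUpstream.model_validate({"name": "report.pdf", "blob": b"%PDF-1.7", "_created_time": 1717000000.0})
print(ok.name, ok.created_time)  # the "_created_time" alias fills created_time

try:
    ParserFromUpstream.model_validate({"name": "report.pdf", "blob": b"%PDF-1.7", "unexpected": 1})
except ValidationError as e:
    print("rejected:", e.errors()[0]["type"])  # extra="forbid" rejects unknown upstream keys
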

@@ -1,5 +1,5 @@
 #
-# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
+# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -18,14 +18,15 @@ import json
 import logging
 import random
 import time
 
 import trio
 
 from agent.canvas import Graph
 from api.db.services.document_service import DocumentService
 from rag.utils.redis_conn import REDIS_CONN
 
 
 class Pipeline(Graph):
 
     def __init__(self, dsl: str, tenant_id=None, doc_id=None, task_id=None, flow_id=None):
         super().__init__(dsl, tenant_id, task_id)
         self._doc_id = doc_id
@@ -35,7 +36,7 @@ class Pipeline(Graph):
         self._kb_id = DocumentService.get_knowledgebase_id(doc_id)
         assert self._kb_id, f"Can't find KB of this document: {doc_id}"
 
-    def callback(self, component_name: str, progress: float|int|None=None, message: str = "") -> None:
+    def callback(self, component_name: str, progress: float | int | None = None, message: str = "") -> None:
         log_key = f"{self._flow_id}-{self.task_id}-logs"
         try:
             bin = REDIS_CONN.get(log_key)
@@ -44,16 +45,10 @@
                 if obj[-1]["component_name"] == component_name:
                     obj[-1]["trace"].append({"progress": progress, "message": message, "datetime": datetime.datetime.now().strftime("%H:%M:%S")})
                 else:
-                    obj.append({
-                        "component_name": component_name,
-                        "trace": [{"progress": progress, "message": message, "datetime": datetime.datetime.now().strftime("%H:%M:%S")}]
-                    })
+                    obj.append({"component_name": component_name, "trace": [{"progress": progress, "message": message, "datetime": datetime.datetime.now().strftime("%H:%M:%S")}]})
             else:
-                obj = [{
-                    "component_name": component_name,
-                    "trace": [{"progress": progress, "message": message, "datetime": datetime.datetime.now().strftime("%H:%M:%S")}]
-                }]
-            REDIS_CONN.set_obj(log_key, obj, 60*10)
+                obj = [{"component_name": component_name, "trace": [{"progress": progress, "message": message, "datetime": datetime.datetime.now().strftime("%H:%M:%S")}]}]
+            REDIS_CONN.set_obj(log_key, obj, 60 * 10)
         except Exception as e:
             logging.exception(e)
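
For reference, the callback above accumulates per-component traces under the Redis key f"{flow_id}-{task_id}-logs" with a 10-minute TTL. The stored object is shaped roughly like the Python literal below; the component names, progress values, and messages are illustrative only.

# Approximate shape of the log object Pipeline.callback keeps in Redis; values are made up.
logs = [
    {
        "component_name": "Parser:0",
        "trace": [
            {"progress": 0.03, "message": "Start to work on a PDF.", "datetime": "12:01:07"},
            {"progress": -1, "message": "[ERROR] ...", "datetime": "12:01:42"},
        ],
    },
    {
        "component_name": "Tokenizer:0",
        "trace": [
            {"progress": 0.02, "message": "Start to tokenize.", "datetime": "12:02:10"},
        ],
    },
]
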
@@ -71,21 +66,19 @@ class Pipeline(Graph):
         super().reset()
         log_key = f"{self._flow_id}-{self.task_id}-logs"
         try:
-            REDIS_CONN.set_obj(log_key, [], 60*10)
+            REDIS_CONN.set_obj(log_key, [], 60 * 10)
         except Exception as e:
             logging.exception(e)
 
     async def run(self, **kwargs):
         st = time.perf_counter()
         if not self.path:
-            self.path.append("begin")
+            self.path.append("File")
 
         if self._doc_id:
-            DocumentService.update_by_id(self._doc_id, {
-                "progress": random.randint(0,5)/100.,
-                "progress_msg": "Start the pipeline...",
-                "process_begin_at": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
-            })
+            DocumentService.update_by_id(
+                self._doc_id, {"progress": random.randint(0, 5) / 100.0, "progress_msg": "Start the pipeline...", "process_begin_at": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}
+            )
 
         self.error = ""
         idx = len(self.path) - 1
@@ -99,23 +92,21 @@
             self.path.extend(cpn_obj.get_downstream())
 
         while idx < len(self.path) and not self.error:
-            last_cpn = self.get_component_obj(self.path[idx-1])
+            last_cpn = self.get_component_obj(self.path[idx - 1])
             cpn_obj = self.get_component_obj(self.path[idx])
 
             async def invoke():
                 nonlocal last_cpn, cpn_obj
                 await cpn_obj.invoke(**last_cpn.output())
 
             async with trio.open_nursery() as nursery:
                 nursery.start_soon(invoke)
             if cpn_obj.error():
                 self.error = "[ERROR]" + cpn_obj.error()
+                self.callback(cpn_obj.component_name, -1, self.error)
                 break
             idx += 1
             self.path.extend(cpn_obj.get_downstream())
 
         if self._doc_id:
-            DocumentService.update_by_id(self._doc_id, {
-                "progress": 1 if not self.error else -1,
-                "progress_msg": "Pipeline finished...\n" + self.error,
-                "process_duration": time.perf_counter() - st
-            })
+            DocumentService.update_by_id(self._doc_id, {"progress": 1 if not self.error else -1, "progress_msg": "Pipeline finished...\n" + self.error, "process_duration": time.perf_counter() - st})

@@ -1,5 +1,5 @@
 #
-# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
+# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -18,12 +18,14 @@ import json
 import os
 import time
 from concurrent.futures import ThreadPoolExecutor
 
 import trio
 
 from api import settings
 from rag.flow.pipeline import Pipeline
 
 
-def print_logs(pipeline):
+def print_logs(pipeline: Pipeline):
     last_logs = "[]"
     while True:
         time.sleep(5)
@@ -34,16 +36,16 @@ def print_logs(pipeline):
         last_logs = logs_str
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     dsl_default_path = os.path.join(
         os.path.dirname(os.path.realpath(__file__)),
         "dsl_examples",
        "general_pdf_all.json",
     )
-    parser.add_argument('-s', '--dsl', default=dsl_default_path, help="input dsl", action='store', required=True)
-    parser.add_argument('-d', '--doc_id', default=False, help="Document ID", action='store', required=True)
-    parser.add_argument('-t', '--tenant_id', default=False, help="Tenant ID", action='store', required=True)
+    parser.add_argument("-s", "--dsl", default=dsl_default_path, help="input dsl", action="store", required=False)
+    parser.add_argument("-d", "--doc_id", default=False, help="Document ID", action="store", required=True)
+    parser.add_argument("-t", "--tenant_id", default=False, help="Tenant ID", action="store", required=True)
     args = parser.parse_args()
 
     settings.init_settings()
@@ -53,5 +55,7 @@ if __name__ == '__main__':
     exe = ThreadPoolExecutor(max_workers=5)
     thr = exe.submit(print_logs, pipeline)
 
+    # queue_dataflow(dsl=open(args.dsl, "r").read(), tenant_id=args.tenant_id, doc_id=args.doc_id, task_id="xxxx", flow_id="xxx", priority=0)
+
     trio.run(pipeline.run)
     thr.result()
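
The runner script above boils down to a few programmatic steps, sketched here with placeholder IDs and a placeholder DSL path (none of this is part of the commit): initialize settings, build a `Pipeline` from a DSL string, reset its Redis log, and drive it with trio.

# Hedged sketch of running a dataflow pipeline programmatically; IDs and path are placeholders.
import trio

from api import settings
from rag.flow.pipeline import Pipeline

settings.init_settings()

dsl = open("path/to/dsl_examples/general_pdf_all.json", "r").read()
pipeline = Pipeline(dsl, tenant_id="<tenant_id>", doc_id="<doc_id>", task_id="<task_id>", flow_id="<flow_id>")
pipeline.reset()

trio.run(pipeline.run)
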

@@ -1,15 +1,15 @@
 {
   "components": {
-    "begin": {
+    "File": {
       "obj":{
         "component_name": "File",
         "params": {
         }
       },
-      "downstream": ["parser:0"],
+      "downstream": ["Parser:0"],
       "upstream": []
     },
-    "parser:0": {
+    "Parser:0": {
       "obj": {
         "component_name": "Parser",
         "params": {
@@ -22,14 +22,22 @@
             "pdf"
           ],
           "output_format": "json"
+        },
+        "excel": {
+          "output_format": "html",
+          "suffix": [
+            "xls",
+            "xlsx",
+            "csv"
+          ]
         }
       }
     }
   },
-  "downstream": ["chunker:0"],
-  "upstream": ["begin"]
+  "downstream": ["Chunker:0"],
+  "upstream": ["Begin"]
   },
-  "chunker:0": {
+  "Chunker:0": {
   "obj": {
     "component_name": "Chunker",
     "params": {
@@ -37,18 +45,19 @@
       "auto_keywords": 5
     }
   },
-  "downstream": ["tokenizer:0"],
-  "upstream": ["chunker:0"]
+  "downstream": ["Tokenizer:0"],
+  "upstream": ["Parser:0"]
   },
-  "tokenizer:0": {
+  "Tokenizer:0": {
   "obj": {
     "component_name": "Tokenizer",
     "params": {
     }
   },
   "downstream": [],
-  "upstream": ["chunker:0"]
+  "upstream": ["Chunker:0"]
   }
 },
 "path": []
 }
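
After the renaming above, the example DSL wires a linear graph File -> Parser:0 -> Chunker:0 -> Tokenizer:0 through the `downstream`/`upstream` arrays. The small sketch below walks those links to print the execution order; the file path is a placeholder and the dict layout is taken from the JSON shown above.

# Walk the dataflow DSL's downstream links starting from the "File" component.
import json

with open("path/to/general_pdf_all.json") as f:  # placeholder path
    dsl = json.load(f)

order, cursor = [], "File"
while cursor:
    order.append(cursor)
    downstream = dsl["components"][cursor]["downstream"]
    cursor = downstream[0] if downstream else None

print(" -> ".join(order))  # File -> Parser:0 -> Chunker:0 -> Tokenizer:0
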

14  rag/flow/tokenizer/__init__.py  Normal file
@@ -0,0 +1,14 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

51  rag/flow/tokenizer/schema.py  Normal file
@@ -0,0 +1,51 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Literal

from pydantic import BaseModel, ConfigDict, Field, model_validator


class TokenizerFromUpstream(BaseModel):
    created_time: float | None = Field(default=None, alias="_created_time")
    elapsed_time: float | None = Field(default=None, alias="_elapsed_time")

    name: str = ""
    blob: bytes

    output_format: Literal["json", "markdown", "text", "html"] | None = Field(default=None)

    chunks: list[dict[str, Any]] | None = Field(default=None)

    json_result: list[dict[str, Any]] | None = Field(default=None, alias="json")
    markdown_result: str | None = Field(default=None, alias="markdown")
    text_result: str | None = Field(default=None, alias="text")
    html_result: str | None = Field(default=None, alias="html")

    model_config = ConfigDict(populate_by_name=True, extra="forbid")

    @model_validator(mode="after")
    def _check_payloads(self) -> "TokenizerFromUpstream":
        if self.chunks:
            return self

        if self.output_format in {"markdown", "text"}:
            if self.output_format == "markdown" and not self.markdown_result:
                raise ValueError("output_format=markdown requires a markdown payload (field: 'markdown' or 'markdown_result').")
            if self.output_format == "text" and not self.text_result:
                raise ValueError("output_format=text requires a text payload (field: 'text' or 'text_result').")
        else:
            if not self.json_result:
                raise ValueError("When no chunks are provided and output_format is not markdown/text, a JSON list payload is required (field: 'json' or 'json_result').")
        return self
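
Illustrative checks against `TokenizerFromUpstream` (the values are made up): supplying `chunks` short-circuits the payload validation, while `output_format="markdown"` without a markdown payload is rejected by the `_check_payloads` validator.

# Illustrative use of TokenizerFromUpstream; field values are placeholders.
from pydantic import ValidationError

from rag.flow.tokenizer.schema import TokenizerFromUpstream

ok = TokenizerFromUpstream.model_validate({"name": "a.pdf", "blob": b"x", "chunks": [{"text": "hello"}]})
print(len(ok.chunks))  # chunks bypass the per-format payload checks

try:
    TokenizerFromUpstream.model_validate({"name": "a.pdf", "blob": b"x", "output_format": "markdown"})
except ValidationError as e:
    print("rejected:", e.errors()[0]["msg"])  # markdown output requires a markdown payload
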

@@ -1,5 +1,5 @@
 #
-# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
+# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
 import random
 import re
 
@@ -24,6 +25,7 @@ from api.db.services.llm_service import LLMBundle
 from api.db.services.user_service import TenantService
 from api.utils.api_utils import timeout
 from rag.flow.base import ProcessBase, ProcessParamBase
+from rag.flow.tokenizer.schema import TokenizerFromUpstream
 from rag.nlp import rag_tokenizer
 from rag.settings import EMBEDDING_BATCH_SIZE
 from rag.svr.task_executor import embed_limiter
@@ -40,6 +42,9 @@ class TokenizerParam(ProcessParamBase):
         for v in self.search_method:
             self.check_valid_value(v.lower(), "Chunk method abnormal.", ["full_text", "embedding"])
 
+    def get_input_form(self) -> dict[str, dict]:
+        return {}
+
 
 class Tokenizer(ProcessBase):
     component_name = "Tokenizer"
@@ -67,19 +72,19 @@ class Tokenizer(ProcessBase):
         @timeout(60)
         def batch_encode(txts):
             nonlocal embedding_model
-            return embedding_model.encode([truncate(c, embedding_model.max_length-10) for c in txts])
+            return embedding_model.encode([truncate(c, embedding_model.max_length - 10) for c in txts])
 
         cnts_ = np.array([])
         for i in range(0, len(texts), EMBEDDING_BATCH_SIZE):
             async with embed_limiter:
-                vts, c = await trio.to_thread.run_sync(lambda: batch_encode(texts[i: i + EMBEDDING_BATCH_SIZE]))
+                vts, c = await trio.to_thread.run_sync(lambda: batch_encode(texts[i : i + EMBEDDING_BATCH_SIZE]))
             if len(cnts_) == 0:
                 cnts_ = vts
             else:
                 cnts_ = np.concatenate((cnts_, vts), axis=0)
             token_count += c
             if i % 33 == 32:
-                self.callback(i*1./len(texts)/parts/EMBEDDING_BATCH_SIZE + 0.5*(parts-1))
+                self.callback(i * 1.0 / len(texts) / parts / EMBEDDING_BATCH_SIZE + 0.5 * (parts - 1))
 
         cnts = cnts_
         title_w = float(self._param.filename_embd_weight)
@@ -92,11 +97,17 @@ class Tokenizer(ProcessBase):
         return chunks, token_count
 
     async def _invoke(self, **kwargs):
+        try:
+            from_upstream = TokenizerFromUpstream.model_validate(kwargs)
+        except Exception as e:
+            self.set_output("_ERROR", f"Input error: {str(e)}")
+            return
+
         parts = sum(["full_text" in self._param.search_method, "embedding" in self._param.search_method])
         if "full_text" in self._param.search_method:
-            self.callback(random.randint(1,5)/100., "Start to tokenize.")
-            if kwargs.get("chunks"):
-                chunks = kwargs["chunks"]
+            self.callback(random.randint(1, 5) / 100.0, "Start to tokenize.")
+            if from_upstream.chunks:
+                chunks = from_upstream.chunks
                 for i, ck in enumerate(chunks):
                     if ck.get("questions"):
                         ck["question_tks"] = rag_tokenizer.tokenize("\n".join(ck["questions"]))
@@ -105,30 +116,40 @@ class Tokenizer(ProcessBase):
                     ck["content_ltks"] = rag_tokenizer.tokenize(ck["text"])
                     ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"])
                     if i % 100 == 99:
-                        self.callback(i*1./len(chunks)/parts)
+                        self.callback(i * 1.0 / len(chunks) / parts)
-            elif kwargs.get("output_format") in ["markdown", "text"]:
-                ck = {
-                    "text": kwargs.get(kwargs["output_format"], "")
-                }
-                if "full_text" in self._param.search_method:
+            elif from_upstream.output_format in ["markdown", "text"]:
+                if from_upstream.output_format == "markdown":
+                    payload = from_upstream.markdown_result
+                else:  # == "text"
+                    payload = from_upstream.text_result
+
+                if not payload:
+                    return ""
+
+                ck = {"text": payload}
+                if "full_text" in self._param.search_method:
                     ck["content_ltks"] = rag_tokenizer.tokenize(kwargs.get(kwargs["output_format"], ""))
                     ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"])
                 chunks = [ck]
             else:
-                chunks = kwargs["json"]
+                chunks = from_upstream.json_result
                 for i, ck in enumerate(chunks):
                     ck["content_ltks"] = rag_tokenizer.tokenize(ck["text"])
                     ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"])
                     if i % 100 == 99:
-                        self.callback(i*1./len(chunks)/parts)
+                        self.callback(i * 1.0 / len(chunks) / parts)
 
-            self.callback(1./parts, "Finish tokenizing.")
+            self.callback(1.0 / parts, "Finish tokenizing.")
 
         if "embedding" in self._param.search_method:
-            self.callback(random.randint(1,5)/100. + 0.5*(parts-1), "Start embedding inference.")
-            chunks, token_count = await self._embedding(kwargs.get("name", ""), chunks)
+            self.callback(random.randint(1, 5) / 100.0 + 0.5 * (parts - 1), "Start embedding inference.")
+
+            if from_upstream.name.strip() == "":
+                logging.warning("Tokenizer: empty name provided from upstream, embedding may be not accurate.")
+
+            chunks, token_count = await self._embedding(from_upstream.name, chunks)
             self.set_output("embedding_token_consumption", token_count)
 
-            self.callback(1., "Finish embedding.")
+            self.callback(1.0, "Finish embedding.")
 
         self.set_output("chunks", chunks)
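
One detail worth noting in the progress arithmetic above: `parts` is 1 or 2 depending on which of `full_text` and `embedding` are enabled, so each enabled phase owns an equal slice of the [0, 1] progress range (tokenizing reports within [0, 1/parts], and embedding is offset by 0.5 when both run). The tiny sketch below restates that convention; it is not code from the commit.

# Sketch of the two-phase progress convention: each enabled phase owns 1/parts of [0, 1].
def phase_progress(i: int, total: int, parts: int, phase_index: int) -> float:
    # phase_index is 0 for tokenizing and 1 for embedding (when both phases are enabled).
    return (i / total) / parts + phase_index * (1.0 / parts)


print(phase_progress(50, 100, parts=2, phase_index=0))  # 0.25: halfway through tokenizing
print(phase_progress(50, 100, parts=2, phase_index=1))  # 0.75: halfway through embedding
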

@@ -21,10 +21,12 @@ import sys
 import threading
 import time
 
+from api.utils import get_uuid
 from api.utils.api_utils import timeout
 from api.utils.log_utils import init_root_logger, get_project_base_directory
 from graphrag.general.index import run_graphrag
 from graphrag.utils import get_llm_cache, set_llm_cache, get_tags_from_cache, set_tags_to_cache
+from rag.flow.pipeline import Pipeline
 from rag.prompts import keyword_extraction, question_proposal, content_tagging
 
 import logging
@@ -223,7 +225,14 @@ async def collect():
         logging.warning(f"collect task {msg['id']} {state}")
         redis_msg.ack()
         return None, None
-    task["task_type"] = msg.get("task_type", "")
+
+    task_type = msg.get("task_type", "")
+    task["task_type"] = task_type
+    if task_type == "dataflow":
+        task["tenant_id"]=msg.get("tenant_id", "")
+        task["dsl"] = msg.get("dsl", "")
+        task["dataflow_id"] = msg.get("dataflow_id", get_uuid())
+        task["kb_id"] = msg.get("kb_id", "")
     return redis_msg, task
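
Given the fields `collect()` now reads for `task_type == "dataflow"`, a queued message is expected to carry roughly the payload below; the IDs are placeholders, and the commented-out `queue_dataflow(...)` call in the runner script earlier is the producer side of this message.

# Approximate message shape collect() handles for a dataflow task; values are placeholders.
message = {
    "id": "task-uuid",
    "task_type": "dataflow",
    "tenant_id": "tenant-uuid",
    "dsl": "<the dataflow DSL JSON as a string>",
    "dataflow_id": "flow-uuid",  # falls back to get_uuid() when missing
    "kb_id": "kb-uuid",
}
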
@@ -473,6 +482,15 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
     return tk_count, vector_size
 
 
+async def run_dataflow(dsl:str, tenant_id:str, doc_id:str, task_id:str, flow_id:str, callback=None):
+    _ = callback
+
+    pipeline = Pipeline(dsl=dsl, tenant_id=tenant_id, doc_id=doc_id, task_id=task_id, flow_id=flow_id)
+    pipeline.reset()
+
+    await pipeline.run()
+
+
 @timeout(3600)
 async def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None):
     chunks = []
@@ -558,15 +576,20 @@ async def do_handle_task(task):
 
     init_kb(task, vector_size)
 
-    # Either using RAPTOR or Standard chunking methods
-    if task.get("task_type", "") == "raptor":
+    task_type = task.get("task_type", "")
+    if task_type == "dataflow":
+        task_dataflow_dsl = task["dsl"]
+        task_dataflow_id = task["dataflow_id"]
+        await run_dataflow(dsl=task_dataflow_dsl, tenant_id=task_tenant_id, doc_id=task_doc_id, task_id=task_id, flow_id=task_dataflow_id, callback=None)
+        return
+    elif task_type == "raptor":
         # bind LLM for raptor
         chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
         # run RAPTOR
         async with kg_limiter:
             chunks, token_count = await run_raptor(task, chat_model, embedding_model, vector_size, progress_callback)
     # Either using graphrag or Standard chunking methods
-    elif task.get("task_type", "") == "graphrag":
+    elif task_type == "graphrag":
         if not task_parser_config.get("graphrag", {}).get("use_graphrag", False):
             progress_callback(prog=-1.0, msg="Internal configuration error.")
             return