Fix: debug pipeline... (#10311)

### What problem does this PR solve?

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
Kevin Hu
2025-09-26 19:11:30 +08:00
committed by GitHub
parent 771a38434f
commit 76b1ee2a00
18 changed files with 116 additions and 474 deletions

View File

@ -29,7 +29,7 @@ from api.db.services.canvas_service import CanvasTemplateService, UserCanvasServ
from api.db.services.document_service import DocumentService from api.db.services.document_service import DocumentService
from api.db.services.file_service import FileService from api.db.services.file_service import FileService
from api.db.services.pipeline_operation_log_service import PipelineOperationLogService from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
from api.db.services.task_service import queue_dataflow from api.db.services.task_service import queue_dataflow, CANVAS_DEBUG_DOC_ID
from api.db.services.user_service import TenantService from api.db.services.user_service import TenantService
from api.db.services.user_canvas_version import UserCanvasVersionService from api.db.services.user_canvas_version import UserCanvasVersionService
from api.settings import RetCode from api.settings import RetCode
@ -41,6 +41,7 @@ from api.db.db_models import APIToken
import time import time
from api.utils.file_utils import filename_type, read_potential_broken_pdf from api.utils.file_utils import filename_type, read_potential_broken_pdf
from rag.flow.pipeline import Pipeline
from rag.utils.redis_conn import REDIS_CONN from rag.utils.redis_conn import REDIS_CONN
@ -145,6 +146,7 @@ def run():
if cvs.canvas_category == CanvasCategory.DataFlow: if cvs.canvas_category == CanvasCategory.DataFlow:
task_id = get_uuid() task_id = get_uuid()
Pipeline(cvs.dsl, tenant_id=current_user.id, doc_id=CANVAS_DEBUG_DOC_ID, task_id=task_id, flow_id=req["id"])
ok, error_message = queue_dataflow(tenant_id=user_id, flow_id=req["id"], task_id=task_id, file=files[0], priority=0) ok, error_message = queue_dataflow(tenant_id=user_id, flow_id=req["id"], task_id=task_id, file=files[0], priority=0)
if not ok: if not ok:
return get_data_error_result(message=error_message) return get_data_error_result(message=error_message)

View File

@ -1,353 +0,0 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import re
import sys
import time
from functools import partial
import trio
from flask import request
from flask_login import current_user, login_required
from agent.canvas import Canvas
from agent.component import LLM
from api.db import CanvasCategory, FileType
from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService
from api.db.services.document_service import DocumentService
from api.db.services.file_service import FileService
from api.db.services.task_service import queue_dataflow
from api.db.services.user_canvas_version import UserCanvasVersionService
from api.db.services.user_service import TenantService
from api.settings import RetCode
from api.utils import get_uuid
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request
from api.utils.file_utils import filename_type, read_potential_broken_pdf
from rag.flow.pipeline import Pipeline
@manager.route("/templates", methods=["GET"]) # noqa: F821
@login_required
def templates():
return get_json_result(data=[c.to_dict() for c in CanvasTemplateService.query(canvas_category=CanvasCategory.DataFlow)])
@manager.route("/list", methods=["GET"]) # noqa: F821
@login_required
def canvas_list():
return get_json_result(data=sorted([c.to_dict() for c in UserCanvasService.query(user_id=current_user.id, canvas_category=CanvasCategory.DataFlow)], key=lambda x: x["update_time"] * -1))
@manager.route("/rm", methods=["POST"]) # noqa: F821
@validate_request("canvas_ids")
@login_required
def rm():
for i in request.json["canvas_ids"]:
if not UserCanvasService.accessible(i, current_user.id):
return get_json_result(data=False, message="Only owner of canvas authorized for this operation.", code=RetCode.OPERATING_ERROR)
UserCanvasService.delete_by_id(i)
return get_json_result(data=True)
@manager.route("/set", methods=["POST"]) # noqa: F821
@validate_request("dsl", "title")
@login_required
def save():
req = request.json
if not isinstance(req["dsl"], str):
req["dsl"] = json.dumps(req["dsl"], ensure_ascii=False)
req["dsl"] = json.loads(req["dsl"])
req["canvas_category"] = CanvasCategory.DataFlow
if "id" not in req:
req["user_id"] = current_user.id
if UserCanvasService.query(user_id=current_user.id, title=req["title"].strip(), canvas_category=CanvasCategory.DataFlow):
return get_data_error_result(message=f"{req['title'].strip()} already exists.")
req["id"] = get_uuid()
if not UserCanvasService.save(**req):
return get_data_error_result(message="Fail to save canvas.")
else:
if not UserCanvasService.accessible(req["id"], current_user.id):
return get_json_result(data=False, message="Only owner of canvas authorized for this operation.", code=RetCode.OPERATING_ERROR)
UserCanvasService.update_by_id(req["id"], req)
# save version
UserCanvasVersionService.insert(user_canvas_id=req["id"], dsl=req["dsl"], title="{0}_{1}".format(req["title"], time.strftime("%Y_%m_%d_%H_%M_%S")))
UserCanvasVersionService.delete_all_versions(req["id"])
return get_json_result(data=req)
@manager.route("/get/<canvas_id>", methods=["GET"]) # noqa: F821
@login_required
def get(canvas_id):
if not UserCanvasService.accessible(canvas_id, current_user.id):
return get_data_error_result(message="canvas not found.")
e, c = UserCanvasService.get_by_tenant_id(canvas_id)
return get_json_result(data=c)
@manager.route("/run", methods=["POST"]) # noqa: F821
@validate_request("id")
@login_required
def run():
req = request.json
flow_id = req.get("id", "")
doc_id = req.get("doc_id", "")
if not all([flow_id, doc_id]):
return get_data_error_result(message="id and doc_id are required.")
if not DocumentService.get_by_id(doc_id):
return get_data_error_result(message=f"Document for {doc_id} not found.")
user_id = req.get("user_id", current_user.id)
if not UserCanvasService.accessible(flow_id, current_user.id):
return get_json_result(data=False, message="Only owner of canvas authorized for this operation.", code=RetCode.OPERATING_ERROR)
e, cvs = UserCanvasService.get_by_id(flow_id)
if not e:
return get_data_error_result(message="canvas not found.")
if not isinstance(cvs.dsl, str):
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
task_id = get_uuid()
ok, error_message = queue_dataflow(dsl=cvs.dsl, tenant_id=user_id, doc_id=doc_id, task_id=task_id, flow_id=flow_id, priority=0)
if not ok:
return server_error_response(error_message)
return get_json_result(data={"task_id": task_id, "flow_id": flow_id})
@manager.route("/reset", methods=["POST"]) # noqa: F821
@validate_request("id")
@login_required
def reset():
req = request.json
flow_id = req.get("id", "")
if not flow_id:
return get_data_error_result(message="id is required.")
if not UserCanvasService.accessible(flow_id, current_user.id):
return get_json_result(data=False, message="Only owner of canvas authorized for this operation.", code=RetCode.OPERATING_ERROR)
task_id = req.get("task_id", "")
try:
e, user_canvas = UserCanvasService.get_by_id(req["id"])
if not e:
return get_data_error_result(message="canvas not found.")
dataflow = Pipeline(dsl=json.dumps(user_canvas.dsl), tenant_id=current_user.id, flow_id=flow_id, task_id=task_id)
dataflow.reset()
req["dsl"] = json.loads(str(dataflow))
UserCanvasService.update_by_id(req["id"], {"dsl": req["dsl"]})
return get_json_result(data=req["dsl"])
except Exception as e:
return server_error_response(e)
@manager.route("/upload/<canvas_id>", methods=["POST"]) # noqa: F821
def upload(canvas_id):
e, cvs = UserCanvasService.get_by_tenant_id(canvas_id)
if not e:
return get_data_error_result(message="canvas not found.")
user_id = cvs["user_id"]
def structured(filename, filetype, blob, content_type):
nonlocal user_id
if filetype == FileType.PDF.value:
blob = read_potential_broken_pdf(blob)
location = get_uuid()
FileService.put_blob(user_id, location, blob)
return {
"id": location,
"name": filename,
"size": sys.getsizeof(blob),
"extension": filename.split(".")[-1].lower(),
"mime_type": content_type,
"created_by": user_id,
"created_at": time.time(),
"preview_url": None,
}
if request.args.get("url"):
from crawl4ai import AsyncWebCrawler, BrowserConfig, CrawlerRunConfig, CrawlResult, DefaultMarkdownGenerator, PruningContentFilter
try:
url = request.args.get("url")
filename = re.sub(r"\?.*", "", url.split("/")[-1])
async def adownload():
browser_config = BrowserConfig(
headless=True,
verbose=False,
)
async with AsyncWebCrawler(config=browser_config) as crawler:
crawler_config = CrawlerRunConfig(markdown_generator=DefaultMarkdownGenerator(content_filter=PruningContentFilter()), pdf=True, screenshot=False)
result: CrawlResult = await crawler.arun(url=url, config=crawler_config)
return result
page = trio.run(adownload())
if page.pdf:
if filename.split(".")[-1].lower() != "pdf":
filename += ".pdf"
return get_json_result(data=structured(filename, "pdf", page.pdf, page.response_headers["content-type"]))
return get_json_result(data=structured(filename, "html", str(page.markdown).encode("utf-8"), page.response_headers["content-type"], user_id))
except Exception as e:
return server_error_response(e)
file = request.files["file"]
try:
DocumentService.check_doc_health(user_id, file.filename)
return get_json_result(data=structured(file.filename, filename_type(file.filename), file.read(), file.content_type))
except Exception as e:
return server_error_response(e)
@manager.route("/input_form", methods=["GET"]) # noqa: F821
@login_required
def input_form():
flow_id = request.args.get("id")
cpn_id = request.args.get("component_id")
try:
e, user_canvas = UserCanvasService.get_by_id(flow_id)
if not e:
return get_data_error_result(message="canvas not found.")
if not UserCanvasService.query(user_id=current_user.id, id=flow_id):
return get_json_result(data=False, message="Only owner of canvas authorized for this operation.", code=RetCode.OPERATING_ERROR)
dataflow = Pipeline(dsl=json.dumps(user_canvas.dsl), tenant_id=current_user.id, flow_id=flow_id, task_id="")
return get_json_result(data=dataflow.get_component_input_form(cpn_id))
except Exception as e:
return server_error_response(e)
@manager.route("/debug", methods=["POST"]) # noqa: F821
@validate_request("id", "component_id", "params")
@login_required
def debug():
req = request.json
if not UserCanvasService.accessible(req["id"], current_user.id):
return get_json_result(data=False, message="Only owner of canvas authorized for this operation.", code=RetCode.OPERATING_ERROR)
try:
e, user_canvas = UserCanvasService.get_by_id(req["id"])
canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id)
canvas.reset()
canvas.message_id = get_uuid()
component = canvas.get_component(req["component_id"])["obj"]
component.reset()
if isinstance(component, LLM):
component.set_debug_inputs(req["params"])
component.invoke(**{k: o["value"] for k, o in req["params"].items()})
outputs = component.output()
for k in outputs.keys():
if isinstance(outputs[k], partial):
txt = ""
for c in outputs[k]():
txt += c
outputs[k] = txt
return get_json_result(data=outputs)
except Exception as e:
return server_error_response(e)
# api get list version dsl of canvas
@manager.route("/getlistversion/<canvas_id>", methods=["GET"]) # noqa: F821
@login_required
def getlistversion(canvas_id):
try:
list = sorted([c.to_dict() for c in UserCanvasVersionService.list_by_canvas_id(canvas_id)], key=lambda x: x["update_time"] * -1)
return get_json_result(data=list)
except Exception as e:
return get_data_error_result(message=f"Error getting history files: {e}")
# api get version dsl of canvas
@manager.route("/getversion/<version_id>", methods=["GET"]) # noqa: F821
@login_required
def getversion(version_id):
try:
e, version = UserCanvasVersionService.get_by_id(version_id)
if version:
return get_json_result(data=version.to_dict())
except Exception as e:
return get_json_result(data=f"Error getting history file: {e}")
@manager.route("/listteam", methods=["GET"]) # noqa: F821
@login_required
def list_canvas():
keywords = request.args.get("keywords", "")
page_number = int(request.args.get("page", 1))
items_per_page = int(request.args.get("page_size", 150))
orderby = request.args.get("orderby", "create_time")
desc = request.args.get("desc", True)
try:
tenants = TenantService.get_joined_tenants_by_user_id(current_user.id)
canvas, total = UserCanvasService.get_by_tenant_ids(
[m["tenant_id"] for m in tenants], current_user.id, page_number, items_per_page, orderby, desc, keywords, canvas_category=CanvasCategory.DataFlow
)
return get_json_result(data={"canvas": canvas, "total": total})
except Exception as e:
return server_error_response(e)
@manager.route("/setting", methods=["POST"]) # noqa: F821
@validate_request("id", "title", "permission")
@login_required
def setting():
req = request.json
req["user_id"] = current_user.id
if not UserCanvasService.accessible(req["id"], current_user.id):
return get_json_result(data=False, message="Only owner of canvas authorized for this operation.", code=RetCode.OPERATING_ERROR)
e, flow = UserCanvasService.get_by_id(req["id"])
if not e:
return get_data_error_result(message="canvas not found.")
flow = flow.to_dict()
flow["title"] = req["title"]
for key in ("description", "permission", "avatar"):
if value := req.get(key):
flow[key] = value
num = UserCanvasService.update_by_id(req["id"], flow)
return get_json_result(data=num)
@manager.route("/trace", methods=["GET"]) # noqa: F821
def trace():
dataflow_id = request.args.get("dataflow_id")
task_id = request.args.get("task_id")
if not all([dataflow_id, task_id]):
return get_data_error_result(message="dataflow_id and task_id are required.")
e, dataflow_canvas = UserCanvasService.get_by_id(dataflow_id)
if not e:
return get_data_error_result(message="dataflow not found.")
dsl_str = json.dumps(dataflow_canvas.dsl, ensure_ascii=False)
dataflow = Pipeline(dsl=dsl_str, tenant_id=dataflow_canvas.user_id, flow_id=dataflow_id, task_id=task_id)
log = dataflow.fetch_logs()
return get_json_result(data=log)

View File

@ -32,7 +32,7 @@ from api.db.services.document_service import DocumentService, doc_upload_and_par
from api.db.services.file2document_service import File2DocumentService from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService from api.db.services.file_service import FileService
from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.task_service import TaskService, cancel_all_task_of, queue_tasks from api.db.services.task_service import TaskService, cancel_all_task_of, queue_tasks, queue_dataflow
from api.db.services.user_service import UserTenantService from api.db.services.user_service import UserTenantService
from api.utils import get_uuid from api.utils import get_uuid
from api.utils.api_utils import ( from api.utils.api_utils import (
@ -480,6 +480,9 @@ def run():
kb_table_num_map[kb_id] = count kb_table_num_map[kb_id] = count
if kb_table_num_map[kb_id] <= 0: if kb_table_num_map[kb_id] <= 0:
KnowledgebaseService.delete_field_map(kb_id) KnowledgebaseService.delete_field_map(kb_id)
if doc.get("pipeline_id", ""):
queue_dataflow(tenant_id, flow_id=doc["pipeline_id"], task_id=get_uuid(), doc_id=id)
else:
bucket, name = File2DocumentService.get_storage_address(doc_id=doc["id"]) bucket, name = File2DocumentService.get_storage_address(doc_id=doc["id"])
queue_tasks(doc, bucket, name, 0) queue_tasks(doc, bucket, name, 0)

View File

@ -417,8 +417,10 @@ def list_pipeline_logs():
desc = False desc = False
else: else:
desc = True desc = True
create_time_from = int(request.args.get("create_time_from", 0)) create_date_from = request.args.get("create_date_from", "")
create_time_to = int(request.args.get("create_time_to", 0)) create_date_to = request.args.get("create_date_to", "")
if create_date_to > create_date_from:
return get_data_error_result(message="Create data filter is abnormal.")
req = request.get_json() req = request.get_json()
@ -437,17 +439,7 @@ def list_pipeline_logs():
suffix = req.get("suffix", []) suffix = req.get("suffix", [])
try: try:
logs, tol = PipelineOperationLogService.get_file_logs_by_kb_id(kb_id, page_number, items_per_page, orderby, desc, keywords, operation_status, types, suffix) logs, tol = PipelineOperationLogService.get_file_logs_by_kb_id(kb_id, page_number, items_per_page, orderby, desc, keywords, operation_status, types, suffix, create_date_from, create_date_to)
if create_time_from or create_time_to:
filtered_docs = []
for doc in logs:
doc_create_time = doc.get("create_time", 0)
if (create_time_from == 0 or doc_create_time >= create_time_from) and (create_time_to == 0 or doc_create_time <= create_time_to):
filtered_docs.append(doc)
logs = filtered_docs
return get_json_result(data={"total": tol, "logs": logs}) return get_json_result(data={"total": tol, "logs": logs})
except Exception as e: except Exception as e:
return server_error_response(e) return server_error_response(e)
@ -467,8 +459,10 @@ def list_pipeline_dataset_logs():
desc = False desc = False
else: else:
desc = True desc = True
create_time_from = int(request.args.get("create_time_from", 0)) create_date_from = request.args.get("create_date_from", "")
create_time_to = int(request.args.get("create_time_to", 0)) create_date_to = request.args.get("create_date_to", "")
if create_date_to > create_date_from:
return get_data_error_result(message="Create data filter is abnormal.")
req = request.get_json() req = request.get_json()
@ -479,17 +473,7 @@ def list_pipeline_dataset_logs():
return get_data_error_result(message=f"Invalid filter operation_status status conditions: {', '.join(invalid_status)}") return get_data_error_result(message=f"Invalid filter operation_status status conditions: {', '.join(invalid_status)}")
try: try:
logs, tol = PipelineOperationLogService.get_dataset_logs_by_kb_id(kb_id, page_number, items_per_page, orderby, desc, operation_status) logs, tol = PipelineOperationLogService.get_dataset_logs_by_kb_id(kb_id, page_number, items_per_page, orderby, desc, operation_status, create_date_from, create_date_to)
if create_time_from or create_time_to:
filtered_docs = []
for doc in logs:
doc_create_time = doc.get("create_time", 0)
if (create_time_from == 0 or doc_create_time >= create_time_from) and (create_time_to == 0 or doc_create_time <= create_time_to):
filtered_docs.append(doc)
logs = filtered_docs
return get_json_result(data={"total": tol, "logs": logs}) return get_json_result(data={"total": tol, "logs": logs})
except Exception as e: except Exception as e:
return server_error_response(e) return server_error_response(e)

View File

@ -15,10 +15,10 @@
# #
from datetime import datetime from datetime import datetime
from peewee import fn from peewee import fn, JOIN
from api.db import StatusEnum, TenantPermission from api.db import StatusEnum, TenantPermission
from api.db.db_models import DB, Document, Knowledgebase, Tenant, User, UserTenant from api.db.db_models import DB, Document, Knowledgebase, Tenant, User, UserTenant, UserCanvas
from api.db.services.common_service import CommonService from api.db.services.common_service import CommonService
from api.utils import current_timestamp, datetime_format from api.utils import current_timestamp, datetime_format
@ -226,13 +226,17 @@ class KnowledgebaseService(CommonService):
cls.model.chunk_num, cls.model.chunk_num,
cls.model.parser_id, cls.model.parser_id,
cls.model.pipeline_id, cls.model.pipeline_id,
UserCanvas.title,
UserCanvas.avatar.alias("pipeline_avatar"),
cls.model.parser_config, cls.model.parser_config,
cls.model.pagerank, cls.model.pagerank,
cls.model.create_time, cls.model.create_time,
cls.model.update_time cls.model.update_time
] ]
kbs = cls.model.select(*fields).join(Tenant, on=( kbs = cls.model.select(*fields)\
(Tenant.id == cls.model.tenant_id) & (Tenant.status == StatusEnum.VALID.value))).where( .join(Tenant, on=((Tenant.id == cls.model.tenant_id) & (Tenant.status == StatusEnum.VALID.value)))\
.join(UserCanvas, on=(cls.model.pipeline_id == UserCanvas.id), join_type=JOIN.LEFT_OUTER)\
.where(
(cls.model.id == kb_id), (cls.model.id == kb_id),
(cls.model.status == StatusEnum.VALID.value) (cls.model.status == StatusEnum.VALID.value)
) )

View File

@ -83,10 +83,7 @@ class PipelineOperationLogService(CommonService):
@classmethod @classmethod
@DB.connection_context() @DB.connection_context()
def create(cls, document_id, pipeline_id, task_type, fake_document_ids=[]): def create(cls, document_id, pipeline_id, task_type, fake_document_ids=[], dsl:str="{}"):
from rag.flow.pipeline import Pipeline
dsl = ""
referred_document_id = document_id referred_document_id = document_id
if referred_document_id == GRAPH_RAPTOR_FAKE_DOC_ID and fake_document_ids: if referred_document_id == GRAPH_RAPTOR_FAKE_DOC_ID and fake_document_ids:
@ -108,13 +105,9 @@ class PipelineOperationLogService(CommonService):
ok, user_pipeline = UserCanvasService.get_by_id(pipeline_id) ok, user_pipeline = UserCanvasService.get_by_id(pipeline_id)
if not ok: if not ok:
raise RuntimeError(f"Pipeline {pipeline_id} not found") raise RuntimeError(f"Pipeline {pipeline_id} not found")
pipeline = Pipeline(dsl=json.dumps(user_pipeline.dsl), tenant_id=user_pipeline.user_id, doc_id=referred_document_id, task_id="", flow_id=pipeline_id)
tenant_id = user_pipeline.user_id tenant_id = user_pipeline.user_id
title = user_pipeline.title title = user_pipeline.title
avatar = user_pipeline.avatar avatar = user_pipeline.avatar
dsl = json.loads(str(pipeline))
else: else:
ok, kb_info = KnowledgebaseService.get_by_id(document.kb_id) ok, kb_info = KnowledgebaseService.get_by_id(document.kb_id)
if not ok: if not ok:
@ -143,7 +136,7 @@ class PipelineOperationLogService(CommonService):
progress_msg=document.progress_msg, progress_msg=document.progress_msg,
process_begin_at=document.process_begin_at, process_begin_at=document.process_begin_at,
process_duration=document.process_duration, process_duration=document.process_duration,
dsl=dsl, dsl=json.loads(dsl),
task_type=task_type, task_type=task_type,
operation_status=operation_status, operation_status=operation_status,
avatar=avatar, avatar=avatar,
@ -162,7 +155,7 @@ class PipelineOperationLogService(CommonService):
@classmethod @classmethod
@DB.connection_context() @DB.connection_context()
def get_file_logs_by_kb_id(cls, kb_id, page_number, items_per_page, orderby, desc, keywords, operation_status, types, suffix): def get_file_logs_by_kb_id(cls, kb_id, page_number, items_per_page, orderby, desc, keywords, operation_status, types, suffix, create_date_from=None, create_date_to=None):
fields = cls.get_file_logs_fields() fields = cls.get_file_logs_fields()
if keywords: if keywords:
logs = cls.model.select(*fields).where((cls.model.kb_id == kb_id), (fn.LOWER(cls.model.document_name).contains(keywords.lower()))) logs = cls.model.select(*fields).where((cls.model.kb_id == kb_id), (fn.LOWER(cls.model.document_name).contains(keywords.lower())))
@ -177,6 +170,10 @@ class PipelineOperationLogService(CommonService):
logs = logs.where(cls.model.document_type.in_(types)) logs = logs.where(cls.model.document_type.in_(types))
if suffix: if suffix:
logs = logs.where(cls.model.document_suffix.in_(suffix)) logs = logs.where(cls.model.document_suffix.in_(suffix))
if create_date_from:
logs = logs.where(cls.model.create_date >= create_date_from)
if create_date_to:
logs = logs.where(cls.model.create_date <= create_date_to)
count = logs.count() count = logs.count()
if desc: if desc:
@ -205,12 +202,16 @@ class PipelineOperationLogService(CommonService):
@classmethod @classmethod
@DB.connection_context() @DB.connection_context()
def get_dataset_logs_by_kb_id(cls, kb_id, page_number, items_per_page, orderby, desc, operation_status): def get_dataset_logs_by_kb_id(cls, kb_id, page_number, items_per_page, orderby, desc, operation_status, create_date_from=None, create_date_to=None):
fields = cls.get_dataset_logs_fields() fields = cls.get_dataset_logs_fields()
logs = cls.model.select(*fields).where((cls.model.kb_id == kb_id), (cls.model.document_id == GRAPH_RAPTOR_FAKE_DOC_ID)) logs = cls.model.select(*fields).where((cls.model.kb_id == kb_id), (cls.model.document_id == GRAPH_RAPTOR_FAKE_DOC_ID))
if operation_status: if operation_status:
logs = logs.where(cls.model.operation_status.in_(operation_status)) logs = logs.where(cls.model.operation_status.in_(operation_status))
if create_date_from:
logs = logs.where(cls.model.create_date >= create_date_from)
if create_date_to:
logs = logs.where(cls.model.create_date <= create_date_to)
count = logs.count() count = logs.count()
if desc: if desc:

View File

@ -488,8 +488,9 @@ def queue_dataflow(tenant_id:str, flow_id:str, task_id:str, doc_id:str=CANVAS_DE
task_type="dataflow" if not rerun else "dataflow_rerun", task_type="dataflow" if not rerun else "dataflow_rerun",
priority=priority, priority=priority,
) )
if doc_id not in [CANVAS_DEBUG_DOC_ID, GRAPH_RAPTOR_FAKE_DOC_ID]:
TaskService.model.delete().where(TaskService.model.id == task["id"]).execute() TaskService.model.delete().where(TaskService.model.doc_id == doc_id).execute()
DocumentService.begin2parse(doc_id)
bulk_insert_into_db(model=Task, data_source=[task], replace_on_conflict=True) bulk_insert_into_db(model=Task, data_source=[task], replace_on_conflict=True)
task["kb_id"] = DocumentService.get_knowledgebase_id(doc_id) task["kb_id"] = DocumentService.get_knowledgebase_id(doc_id)

View File

@ -1127,7 +1127,7 @@ class RAGFlowPdfParser:
for tag in re.findall(r"@@[0-9-]+\t[0-9.\t]+##", txt): for tag in re.findall(r"@@[0-9-]+\t[0-9.\t]+##", txt):
pn, left, right, top, bottom = tag.strip("#").strip("@").split("\t") pn, left, right, top, bottom = tag.strip("#").strip("@").split("\t")
left, right, top, bottom = float(left), float(right), float(top), float(bottom) left, right, top, bottom = float(left), float(right), float(top), float(bottom)
poss.append(([int(p) - 1 for p in pn.split("-")], left, right, top, bottom)) poss.append(([int(p) - 1 for p in pn.split("-")], int(left), int(right), int(top), int(bottom)))
return poss return poss
def crop(self, text, ZM=3, need_position=False): def crop(self, text, ZM=3, need_position=False):

View File

@ -31,6 +31,7 @@ class Extractor(ProcessBase, LLM):
component_name = "Extractor" component_name = "Extractor"
async def _invoke(self, **kwargs): async def _invoke(self, **kwargs):
self.set_output("output_format", "chunks")
self.callback(random.randint(1, 5) / 100.0, "Start to generate.") self.callback(random.randint(1, 5) / 100.0, "Start to generate.")
inputs = self.get_input_elements() inputs = self.get_input_elements()
chunks = [] chunks = []
@ -50,6 +51,7 @@ class Extractor(ProcessBase, LLM):
msg.insert(0, {"role": "system", "content": sys_prompt}) msg.insert(0, {"role": "system", "content": sys_prompt})
ck[self._param.field_name] = self._generate(msg) ck[self._param.field_name] = self._generate(msg)
prog += 1./len(chunks) prog += 1./len(chunks)
if i % (len(chunks)//100+1) == 1:
self.callback(prog, f"{i+1} / {len(chunks)}") self.callback(prog, f"{i+1} / {len(chunks)}")
self.set_output("chunks", chunks) self.set_output("chunks", chunks)
else: else:

View File

@ -25,7 +25,7 @@ class ExtractorFromUpstream(BaseModel):
file: dict | None = Field(default=None) file: dict | None = Field(default=None)
chunks: list[dict[str, Any]] | None = Field(default=None) chunks: list[dict[str, Any]] | None = Field(default=None)
output_format: Literal["json", "markdown", "text", "html"] | None = Field(default=None) output_format: Literal["json", "markdown", "text", "html", "chunks"] | None = Field(default=None)
json_result: list[dict[str, Any]] | None = Field(default=None, alias="json") json_result: list[dict[str, Any]] | None = Field(default=None, alias="json")
markdown_result: str | None = Field(default=None, alias="markdown") markdown_result: str | None = Field(default=None, alias="markdown")

View File

@ -53,6 +53,7 @@ class HierarchicalMerger(ProcessBase):
self.set_output("_ERROR", f"Input error: {str(e)}") self.set_output("_ERROR", f"Input error: {str(e)}")
return return
self.set_output("output_format", "chunks")
self.callback(random.randint(1, 5) / 100.0, "Start to merge hierarchically.") self.callback(random.randint(1, 5) / 100.0, "Start to merge hierarchically.")
if from_upstream.output_format in ["markdown", "text", "html"]: if from_upstream.output_format in ["markdown", "text", "html"]:
if from_upstream.output_format == "markdown": if from_upstream.output_format == "markdown":

View File

@ -25,7 +25,7 @@ class HierarchicalMergerFromUpstream(BaseModel):
file: dict | None = Field(default=None) file: dict | None = Field(default=None)
chunks: list[dict[str, Any]] | None = Field(default=None) chunks: list[dict[str, Any]] | None = Field(default=None)
output_format: Literal["json", "markdown", "text", "html"] | None = Field(default=None) output_format: Literal["json", "chunks"] | None = Field(default=None)
json_result: list[dict[str, Any]] | None = Field(default=None, alias="json") json_result: list[dict[str, Any]] | None = Field(default=None, alias="json")
markdown_result: str | None = Field(default=None, alias="markdown") markdown_result: str | None = Field(default=None, alias="markdown")
text_result: str | None = Field(default=None, alias="text") text_result: str | None = Field(default=None, alias="text")

View File

@ -148,7 +148,7 @@ class ParserParam(ProcessParamBase):
self.check_empty(pdf_parse_method, "Parse method abnormal.") self.check_empty(pdf_parse_method, "Parse method abnormal.")
if pdf_parse_method.lower() not in ["deepdoc", "plain_text"]: if pdf_parse_method.lower() not in ["deepdoc", "plain_text"]:
self.check_empty(pdf_config.get("lang", ""), "Language") self.check_empty(pdf_config.get("lang", ""), "PDF VLM language")
pdf_output_format = pdf_config.get("output_format", "") pdf_output_format = pdf_config.get("output_format", "")
self.check_valid_value(pdf_output_format, "PDF output format abnormal.", self.allowed_output_format["pdf"]) self.check_valid_value(pdf_output_format, "PDF output format abnormal.", self.allowed_output_format["pdf"])
@ -172,7 +172,7 @@ class ParserParam(ProcessParamBase):
if image_config: if image_config:
image_parse_method = image_config.get("parse_method", "") image_parse_method = image_config.get("parse_method", "")
if image_parse_method not in ["ocr"]: if image_parse_method not in ["ocr"]:
self.check_empty(image_config.get("lang", ""), "Language") self.check_empty(image_config.get("lang", ""), "Image VLM language")
text_config = self.setups.get("text&markdown", "") text_config = self.setups.get("text&markdown", "")
if text_config: if text_config:
@ -181,7 +181,7 @@ class ParserParam(ProcessParamBase):
audio_config = self.setups.get("audio", "") audio_config = self.setups.get("audio", "")
if audio_config: if audio_config:
self.check_empty(audio_config.get("llm_id"), "VLM") self.check_empty(audio_config.get("llm_id"), "Audio VLM")
audio_language = audio_config.get("lang", "") audio_language = audio_config.get("lang", "")
self.check_empty(audio_language, "Language") self.check_empty(audio_language, "Language")

View File

@ -76,22 +76,23 @@ class Pipeline(Graph):
} }
] ]
REDIS_CONN.set_obj(log_key, obj, 60 * 30) REDIS_CONN.set_obj(log_key, obj, 60 * 30)
if self._doc_id and self.task_id: if component_name != "END" and self._doc_id and self.task_id:
percentage = 1.0 / len(self.components.items()) percentage = 1.0 / len(self.components.items())
msg = ""
finished = 0.0 finished = 0.0
for o in obj: for o in obj:
if o["component_id"] == "END":
continue
msg += f"\n[{o['component_id']}]:\n"
for t in o["trace"]: for t in o["trace"]:
msg += "%s: %s\n" % (t["datetime"], t["message"])
if t["progress"] < 0: if t["progress"] < 0:
finished = -1 finished = -1
break break
if finished < 0: if finished < 0:
break break
finished += o["trace"][-1]["progress"] * percentage finished += o["trace"][-1]["progress"] * percentage
msg = ""
if len(obj[-1]["trace"]) == 1:
msg += f"\n-------------------------------------\n[{self.get_component_name(o['component_id'])}]:\n"
t = obj[-1]["trace"][-1]
msg += "%s: %s\n" % (t["datetime"], t["message"])
TaskService.update_progress(self.task_id, {"progress": finished, "progress_msg": msg}) TaskService.update_progress(self.task_id, {"progress": finished, "progress_msg": msg})
except Exception as e: except Exception as e:
logging.exception(e) logging.exception(e)

View File

@ -59,6 +59,7 @@ class Splitter(ProcessBase):
else: else:
deli += d deli += d
self.set_output("output_format", "chunks")
self.callback(random.randint(1, 5) / 100.0, "Start to split into chunks.") self.callback(random.randint(1, 5) / 100.0, "Start to split into chunks.")
if from_upstream.output_format in ["markdown", "text", "html"]: if from_upstream.output_format in ["markdown", "text", "html"]:
if from_upstream.output_format == "markdown": if from_upstream.output_format == "markdown":
@ -99,7 +100,7 @@ class Splitter(ProcessBase):
{ {
"text": RAGFlowPdfParser.remove_tag(c), "text": RAGFlowPdfParser.remove_tag(c),
"image": img, "image": img,
"positions": [[pos[0][-1]+1, *pos[1:]] for pos in RAGFlowPdfParser.extract_positions(c)], "positions": [[pos[0][-1], *pos[1:]] for pos in RAGFlowPdfParser.extract_positions(c)],
} }
for c, img in zip(chunks, images) for c, img in zip(chunks, images)
] ]

View File

@ -24,7 +24,7 @@ class TokenizerFromUpstream(BaseModel):
name: str = "" name: str = ""
file: dict | None = Field(default=None) file: dict | None = Field(default=None)
output_format: Literal["json", "markdown", "text", "html"] | None = Field(default=None) output_format: Literal["json", "markdown", "text", "html", "chunks"] | None = Field(default=None)
chunks: list[dict[str, Any]] | None = Field(default=None) chunks: list[dict[str, Any]] | None = Field(default=None)

View File

@ -108,6 +108,7 @@ class Tokenizer(ProcessBase):
self.set_output("_ERROR", f"Input error: {str(e)}") self.set_output("_ERROR", f"Input error: {str(e)}")
return return
self.set_output("output_format", "chunks")
parts = sum(["full_text" in self._param.search_method, "embedding" in self._param.search_method]) parts = sum(["full_text" in self._param.search_method, "embedding" in self._param.search_method])
if "full_text" in self._param.search_method: if "full_text" in self._param.search_method:
self.callback(random.randint(1, 5) / 100.0, "Start to tokenize.") self.callback(random.randint(1, 5) / 100.0, "Start to tokenize.")
@ -117,11 +118,13 @@ class Tokenizer(ProcessBase):
ck["title_tks"] = rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", from_upstream.name)) ck["title_tks"] = rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", from_upstream.name))
ck["title_sm_tks"] = rag_tokenizer.fine_grained_tokenize(ck["title_tks"]) ck["title_sm_tks"] = rag_tokenizer.fine_grained_tokenize(ck["title_tks"])
if ck.get("questions"): if ck.get("questions"):
ck["question_tks"] = rag_tokenizer.tokenize("\n".join(ck["questions"])) ck["question_kwd"] = ck["questions"].split("\n")
ck["question_tks"] = rag_tokenizer.tokenize(str(ck["questions"]))
if ck.get("keywords"): if ck.get("keywords"):
ck["important_tks"] = rag_tokenizer.tokenize(",".join(ck["keywords"])) ck["important_kwd"] = ck["keywords"].split(",")
ck["important_tks"] = rag_tokenizer.tokenize(str(ck["keywords"]))
if ck.get("summary"): if ck.get("summary"):
ck["content_ltks"] = rag_tokenizer.tokenize(ck["summary"]) ck["content_ltks"] = rag_tokenizer.tokenize(str(ck["summary"]))
ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"]) ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"])
else: else:
ck["content_ltks"] = rag_tokenizer.tokenize(ck["text"]) ck["content_ltks"] = rag_tokenizer.tokenize(ck["text"])

View File

@ -20,6 +20,9 @@ import random
import sys import sys
import threading import threading
import time import time
import json_repair
from api.db.services.canvas_service import UserCanvasService from api.db.services.canvas_service import UserCanvasService
from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.pipeline_operation_log_service import PipelineOperationLogService from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
@ -57,7 +60,7 @@ from api.versions import get_ragflow_version
from api.db.db_models import close_connection from api.db.db_models import close_connection
from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, \ from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, \
email, tag email, tag
from rag.nlp import search, rag_tokenizer from rag.nlp import search, rag_tokenizer, add_positions
from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
from rag.settings import DOC_MAXIMUM_SIZE, DOC_BULK_SIZE, EMBEDDING_BATCH_SIZE, SVR_CONSUMER_GROUP_NAME, get_svr_queue_name, get_svr_queue_names, print_rag_settings, TAG_FLD, PAGERANK_FLD from rag.settings import DOC_MAXIMUM_SIZE, DOC_BULK_SIZE, EMBEDDING_BATCH_SIZE, SVR_CONSUMER_GROUP_NAME, get_svr_queue_name, get_svr_queue_names, print_rag_settings, TAG_FLD, PAGERANK_FLD
from rag.utils import num_tokens_from_string, truncate from rag.utils import num_tokens_from_string, truncate
@ -477,6 +480,8 @@ async def run_dataflow(task: dict):
dataflow_id = task["dataflow_id"] dataflow_id = task["dataflow_id"]
doc_id = task["doc_id"] doc_id = task["doc_id"]
task_id = task["id"] task_id = task["id"]
task_dataset_id = task["kb_id"]
if task["task_type"] == "dataflow": if task["task_type"] == "dataflow":
e, cvs = UserCanvasService.get_by_id(dataflow_id) e, cvs = UserCanvasService.get_by_id(dataflow_id)
assert e, "User pipeline not found." assert e, "User pipeline not found."
@ -486,12 +491,12 @@ async def run_dataflow(task: dict):
assert e, "Pipeline log not found." assert e, "Pipeline log not found."
dsl = pipeline_log.dsl dsl = pipeline_log.dsl
pipeline = Pipeline(dsl, tenant_id=task["tenant_id"], doc_id=doc_id, task_id=task_id, flow_id=dataflow_id) pipeline = Pipeline(dsl, tenant_id=task["tenant_id"], doc_id=doc_id, task_id=task_id, flow_id=dataflow_id)
chunks = await pipeline.run(file=task["file"]) if task.get("file") else pipeline.run() chunks = await pipeline.run(file=task["file"]) if task.get("file") else await pipeline.run()
if doc_id == CANVAS_DEBUG_DOC_ID: if doc_id == CANVAS_DEBUG_DOC_ID:
return return
if not chunks: if not chunks:
PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE) PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE, dsl=str(pipeline))
return return
embedding_token_consumption = chunks.get("embedding_token_consumption", 0) embedding_token_consumption = chunks.get("embedding_token_consumption", 0)
@ -508,7 +513,7 @@ async def run_dataflow(task: dict):
keys = [k for o in chunks for k in list(o.keys())] keys = [k for o in chunks for k in list(o.keys())]
if not any([re.match(r"q_[0-9]+_vec", k) for k in keys]): if not any([re.match(r"q_[0-9]+_vec", k) for k in keys]):
set_progress(task_id, prog=0.82, msg="Start to embedding...") set_progress(task_id, prog=0.82, msg="\n-------------------------------------\nStart to embedding...")
e, kb = KnowledgebaseService.get_by_id(task["kb_id"]) e, kb = KnowledgebaseService.get_by_id(task["kb_id"])
embedding_id = kb.embd_id embedding_id = kb.embd_id
embedding_model = LLMBundle(task["tenant_id"], LLMType.EMBEDDING, llm_name=embedding_id) embedding_model = LLMBundle(task["tenant_id"], LLMType.EMBEDDING, llm_name=embedding_id)
@ -518,7 +523,7 @@ async def run_dataflow(task: dict):
return embedding_model.encode([truncate(c, embedding_model.max_length - 10) for c in txts]) return embedding_model.encode([truncate(c, embedding_model.max_length - 10) for c in txts])
vects = np.array([]) vects = np.array([])
texts = [o.get("questions", o.get("summary", o["text"])) for o in chunks] texts = [o.get("questions", o.get("summary", o["text"])) for o in chunks]
delta = 0.20/(len(texts)//EMBEDDING_BATCH_SIZE) delta = 0.20/(len(texts)//EMBEDDING_BATCH_SIZE+1)
prog = 0.8 prog = 0.8
for i in range(0, len(texts), EMBEDDING_BATCH_SIZE): for i in range(0, len(texts), EMBEDDING_BATCH_SIZE):
async with embed_limiter: async with embed_limiter:
@ -529,6 +534,7 @@ async def run_dataflow(task: dict):
vects = np.concatenate((vects, vts), axis=0) vects = np.concatenate((vects, vts), axis=0)
embedding_token_consumption += c embedding_token_consumption += c
prog += delta prog += delta
if i % (len(texts)//EMBEDDING_BATCH_SIZE/100+1) == 1:
set_progress(task_id, prog=prog, msg=f"{i+1} / {len(texts)//EMBEDDING_BATCH_SIZE}") set_progress(task_id, prog=prog, msg=f"{i+1} / {len(texts)//EMBEDDING_BATCH_SIZE}")
assert len(vects) == len(chunks) assert len(vects) == len(chunks)
@ -539,9 +545,23 @@ async def run_dataflow(task: dict):
metadata = {} metadata = {}
def dict_update(meta): def dict_update(meta):
nonlocal metadata nonlocal metadata
if not meta or not isinstance(meta, dict): if not meta:
return
if isinstance(meta, str):
try:
meta = json_repair.loads(meta)
except Exception:
logging.error("Meta data format error.")
return
if not isinstance(meta, dict):
return return
for k, v in meta.items(): for k, v in meta.items():
if isinstance(v, list):
v = [vv for vv in v if isinstance(vv, str)]
if not v:
continue
if not isinstance(v, list) and not isinstance(v, str):
continue
if k not in metadata: if k not in metadata:
metadata[k] = v metadata[k] = v
continue continue
@ -561,15 +581,29 @@ async def run_dataflow(task: dict):
ck["create_timestamp_flt"] = datetime.now().timestamp() ck["create_timestamp_flt"] = datetime.now().timestamp()
ck["id"] = xxhash.xxh64((ck["text"] + str(ck["doc_id"])).encode("utf-8")).hexdigest() ck["id"] = xxhash.xxh64((ck["text"] + str(ck["doc_id"])).encode("utf-8")).hexdigest()
if "questions" in ck: if "questions" in ck:
if "question_tks" not in ck:
ck["question_kwd"] = ck["questions"].split("\n")
ck["question_tks"] = rag_tokenizer.tokenize(str(ck["questions"]))
del ck["questions"] del ck["questions"]
if "keywords" in ck: if "keywords" in ck:
if "important_tks" not in ck:
ck["important_kwd"] = ck["keywords"].split(",")
ck["important_tks"] = rag_tokenizer.tokenize(str(ck["keywords"]))
del ck["keywords"] del ck["keywords"]
if "summary" in ck: if "summary" in ck:
if "content_ltks" not in ck:
ck["content_ltks"] = rag_tokenizer.tokenize(str(ck["summary"]))
ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"])
del ck["summary"] del ck["summary"]
if "metadata" in ck: if "metadata" in ck:
dict_update(ck["metadata"]) dict_update(ck["metadata"])
del ck["metadata"] del ck["metadata"]
if "content_with_weight" not in ck:
ck["content_with_weight"] = ck["text"]
del ck["text"] del ck["text"]
if "positions" in ck:
add_positions(ck, ck["positions"])
del ck["positions"]
if metadata: if metadata:
e, doc = DocumentService.get_by_id(doc_id) e, doc = DocumentService.get_by_id(doc_id)
@ -580,59 +614,18 @@ async def run_dataflow(task: dict):
DocumentService.update_by_id(doc_id, {"meta_fields": metadata}) DocumentService.update_by_id(doc_id, {"meta_fields": metadata})
start_ts = timer() start_ts = timer()
set_progress(task_id, prog=0.82, msg="Start to index...") set_progress(task_id, prog=0.82, msg="[DOC Engine]:\nStart to index...")
e = await insert_es(task_id, task["tenant_id"], task["kb_id"], chunks, partial(set_progress, task_id, 0, 100000000)) e = await insert_es(task_id, task["tenant_id"], task["kb_id"], chunks, partial(set_progress, task_id, 0, 100000000))
if not e: if not e:
PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE) PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE, dsl=str(pipeline))
return return
time_cost = timer() - start_ts time_cost = timer() - start_ts
task_time_cost = timer() - task_start_ts task_time_cost = timer() - task_start_ts
set_progress(task_id, prog=1., msg="Indexing done ({:.2f}s). Task done ({:.2f}s)".format(time_cost, task_time_cost)) set_progress(task_id, prog=1., msg="Indexing done ({:.2f}s). Task done ({:.2f}s)".format(time_cost, task_time_cost))
DocumentService.increment_chunk_num(doc_id, task_dataset_id, embedding_token_consumption, len(chunks), task_time_cost)
logging.info("[Done], chunks({}), token({}), elapsed:{:.2f}".format(len(chunks), embedding_token_consumption, task_time_cost)) logging.info("[Done], chunks({}), token({}), elapsed:{:.2f}".format(len(chunks), embedding_token_consumption, task_time_cost))
PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE) PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE, dsl=str(pipeline))
@timeout(3600)
async def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None):
chunks = []
vctr_nm = "q_%d_vec"%vector_size
for d in settings.retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])],
fields=["content_with_weight", vctr_nm]):
chunks.append((d["content_with_weight"], np.array(d[vctr_nm])))
raptor = Raptor(
row["parser_config"]["raptor"].get("max_cluster", 64),
chat_mdl,
embd_mdl,
row["parser_config"]["raptor"]["prompt"],
row["parser_config"]["raptor"]["max_token"],
row["parser_config"]["raptor"]["threshold"]
)
original_length = len(chunks)
chunks = await raptor(chunks, row["parser_config"]["raptor"]["random_seed"], callback)
doc = {
"doc_id": row["doc_id"],
"kb_id": [str(row["kb_id"])],
"docnm_kwd": row["name"],
"title_tks": rag_tokenizer.tokenize(row["name"])
}
if row["pagerank"]:
doc[PAGERANK_FLD] = int(row["pagerank"])
res = []
tk_count = 0
for content, vctr in chunks[original_length:]:
d = copy.deepcopy(doc)
d["id"] = xxhash.xxh64((content + str(d["doc_id"])).encode("utf-8")).hexdigest()
d["create_time"] = str(datetime.now()).replace("T", " ")[:19]
d["create_timestamp_flt"] = datetime.now().timestamp()
d[vctr_nm] = vctr.tolist()
d["content_with_weight"] = content
d["content_ltks"] = rag_tokenizer.tokenize(content)
d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
res.append(d)
tk_count += num_tokens_from_string(content)
return res, tk_count
@timeout(3600) @timeout(3600)
@ -787,7 +780,6 @@ async def do_handle_task(task):
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language) chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
# run RAPTOR # run RAPTOR
async with kg_limiter: async with kg_limiter:
# chunks, token_count = await run_raptor(task, chat_model, embedding_model, vector_size, progress_callback)
chunks, token_count = await run_raptor_for_kb( chunks, token_count = await run_raptor_for_kb(
row=task, row=task,
kb_parser_config=kb_parser_config, kb_parser_config=kb_parser_config,
@ -908,8 +900,8 @@ async def handle_task():
task_document_ids = [] task_document_ids = []
if task_type in ["graphrag", "raptor"]: if task_type in ["graphrag", "raptor"]:
task_document_ids = task["doc_ids"] task_document_ids = task["doc_ids"]
if task["doc_id"] != CANVAS_DEBUG_DOC_ID: if not task.get("dataflow_id", ""):
PipelineOperationLogService.record_pipeline_operation(document_id=task["doc_id"], pipeline_id=task.get("dataflow_id", "") or "", task_type=pipeline_task_type, fake_document_ids=task_document_ids) PipelineOperationLogService.record_pipeline_operation(document_id=task["doc_id"], pipeline_id="", task_type=pipeline_task_type, fake_document_ids=task_document_ids)
redis_msg.ack() redis_msg.ack()