Feat: Use data pipeline to visualize the parsing configuration of the knowledge base (#10423)

### What problem does this PR solve?

#9869

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Signed-off-by: dependabot[bot] <support@github.com>
Signed-off-by: jinhai <haijin.chn@gmail.com>
Signed-off-by: Jin Hai <haijin.chn@gmail.com>
Co-authored-by: chanx <1243304602@qq.com>
Co-authored-by: balibabu <cike8899@users.noreply.github.com>
Co-authored-by: Lynn <lynn_inf@hotmail.com>
Co-authored-by: 纷繁下的无奈 <zhileihuang@126.com>
Co-authored-by: huangzl <huangzl@shinemo.com>
Co-authored-by: writinwaters <93570324+writinwaters@users.noreply.github.com>
Co-authored-by: Wilmer <33392318@qq.com>
Co-authored-by: Adrian Weidig <adrianweidig@gmx.net>
Co-authored-by: Zhichang Yu <yuzhichang@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Yongteng Lei <yongtengrey@outlook.com>
Co-authored-by: Liu An <asiro@qq.com>
Co-authored-by: buua436 <66937541+buua436@users.noreply.github.com>
Co-authored-by: BadwomanCraZY <511528396@qq.com>
Co-authored-by: cucusenok <31804608+cucusenok@users.noreply.github.com>
Co-authored-by: Russell Valentine <russ@coldstonelabs.org>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Billy Bao <newyorkupperbay@gmail.com>
Co-authored-by: Zhedong Cen <cenzhedong2@126.com>
Co-authored-by: TensorNull <129579691+TensorNull@users.noreply.github.com>
Co-authored-by: TensorNull <tensor.null@gmail.com>
Co-authored-by: TeslaZY <TeslaZY@outlook.com>
Co-authored-by: Ajay <160579663+aybanda@users.noreply.github.com>
Co-authored-by: AB <aj@Ajays-MacBook-Air.local>
Co-authored-by: 天海蒼灆 <huangaoqin@tecpie.com>
Co-authored-by: He Wang <wanghechn@qq.com>
Co-authored-by: Atsushi Hatakeyama <atu729@icloud.com>
Co-authored-by: Jin Hai <haijin.chn@gmail.com>
Co-authored-by: Mohamed Mathari <155896313+melmathari@users.noreply.github.com>
Co-authored-by: Mohamed Mathari <nocodeventure@Mac-mini-van-Mohamed.fritz.box>
Co-authored-by: Stephen Hu <stephenhu@seismic.com>
Co-authored-by: Shaun Zhang <zhangwfjh@users.noreply.github.com>
Co-authored-by: zhimeng123 <60221886+zhimeng123@users.noreply.github.com>
Co-authored-by: mxc <mxc@example.com>
Co-authored-by: Dominik Novotný <50611433+SgtMarmite@users.noreply.github.com>
Co-authored-by: EVGENY M <168018528+rjohny55@users.noreply.github.com>
Co-authored-by: mcoder6425 <mcoder64@gmail.com>
Co-authored-by: lemsn <lemsn@msn.com>
Co-authored-by: lemsn <lemsn@126.com>
Co-authored-by: Adrian Gora <47756404+adagora@users.noreply.github.com>
Co-authored-by: Womsxd <45663319+Womsxd@users.noreply.github.com>
Co-authored-by: FatMii <39074672+FatMii@users.noreply.github.com>
Kevin Hu
2025-10-09 12:36:19 +08:00
committed by GitHub
parent ef0aecea3b
commit cbf04ee470
490 changed files with 10630 additions and 30688 deletions

admin/exceptions.py Normal file
View File

@ -0,0 +1,17 @@
class AdminException(Exception):
def __init__(self, message, code=400):
super().__init__(message)
self.code = code
self.message = message
class UserNotFoundError(AdminException):
def __init__(self, username):
super().__init__(f"User '{username}' not found", 404)
class UserAlreadyExistsError(AdminException):
def __init__(self, username):
super().__init__(f"User '{username}' already exists", 409)
class CannotDeleteAdminError(AdminException):
def __init__(self):
super().__init__("Cannot delete admin account", 403)

View File

@ -153,6 +153,16 @@ class Graph:
def get_tenant_id(self):
return self._tenant_id
def get_variable_value(self, exp: str) -> Any:
exp = exp.strip("{").strip("}").strip(" ").strip("{").strip("}")
if exp.find("@") < 0:
return self.globals[exp]
cpn_id, var_nm = exp.split("@")
cpn = self.get_component(cpn_id)
if not cpn:
raise Exception(f"Can't find variable: '{cpn_id}@{var_nm}'")
return cpn["obj"].output(var_nm)
class Canvas(Graph):
@ -406,16 +416,6 @@ class Canvas(Graph):
return False
return True
def get_variable_value(self, exp: str) -> Any:
exp = exp.strip("{").strip("}").strip(" ").strip("{").strip("}")
if exp.find("@") < 0:
return self.globals[exp]
cpn_id, var_nm = exp.split("@")
cpn = self.get_component(cpn_id)
if not cpn:
raise Exception(f"Can't find variable: '{cpn_id}@{var_nm}'")
return cpn["obj"].output(var_nm)
def get_history(self, window_size):
convs = []
if window_size <= 0:

View File

@ -101,6 +101,8 @@ class LLM(ComponentBase):
def get_input_elements(self) -> dict[str, Any]:
res = self.get_input_elements_from_text(self._param.sys_prompt)
if isinstance(self._param.prompts, str):
self._param.prompts = [{"role": "user", "content": self._param.prompts}]
for prompt in self._param.prompts:
d = self.get_input_elements_from_text(prompt["content"])
res.update(d)
@ -112,6 +114,17 @@ class LLM(ComponentBase):
def add2system_prompt(self, txt):
self._param.sys_prompt += txt
def _sys_prompt_and_msg(self, msg, args):
if isinstance(self._param.prompts, str):
self._param.prompts = [{"role": "user", "content": self._param.prompts}]
for p in self._param.prompts:
if msg and msg[-1]["role"] == p["role"]:
continue
p = deepcopy(p)
p["content"] = self.string_format(p["content"], args)
msg.append(p)
return msg, self.string_format(self._param.sys_prompt, args)
def _prepare_prompt_variables(self):
if self._param.visual_files_var:
self.imgs = self._canvas.get_variable_value(self._param.visual_files_var)
@ -127,7 +140,6 @@ class LLM(ComponentBase):
args = {}
vars = self.get_input_elements() if not self._param.debug_inputs else self._param.debug_inputs
sys_prompt = self._param.sys_prompt
for k, o in vars.items():
args[k] = o["value"]
if not isinstance(args[k], str):
@ -137,16 +149,8 @@ class LLM(ComponentBase):
args[k] = str(args[k])
self.set_input_value(k, args[k])
msg = self._canvas.get_history(self._param.message_history_window_size)[:-1]
msg, sys_prompt = self._sys_prompt_and_msg(self._canvas.get_history(self._param.message_history_window_size)[:-1], args)
for p in self._param.prompts:
if msg and msg[-1]["role"] == p["role"]:
continue
msg.append(deepcopy(p))
sys_prompt = self.string_format(sys_prompt, args)
user_defined_prompt, sys_prompt = self._extract_prompts(sys_prompt)
for m in msg:
m["content"] = self.string_format(m["content"], args)
if self._param.cite and self._canvas.get_reference()["chunks"]:
sys_prompt += citation_prompt(user_defined_prompt)

View File

@ -19,15 +19,19 @@ import re
import sys
from functools import partial
import flask
import trio
from flask import request, Response
from flask_login import login_required, current_user
from agent.component.llm import LLM
from agent.component import LLM
from api import settings
from api.db import CanvasCategory, FileType
from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService, API4ConversationService
from api.db.services.document_service import DocumentService
from api.db.services.file_service import FileService
from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
from api.db.services.task_service import queue_dataflow, CANVAS_DEBUG_DOC_ID, TaskService
from api.db.services.user_service import TenantService
from api.db.services.user_canvas_version import UserCanvasVersionService
from api.settings import RetCode
@ -35,10 +39,12 @@ from api.utils import get_uuid
from api.utils.api_utils import get_json_result, server_error_response, validate_request, get_data_error_result
from agent.canvas import Canvas
from peewee import MySQLDatabase, PostgresqlDatabase
from api.db.db_models import APIToken
from api.db.db_models import APIToken, Task
import time
from api.utils.file_utils import filename_type, read_potential_broken_pdf
from rag.flow.pipeline import Pipeline
from rag.nlp import search
from rag.utils.redis_conn import REDIS_CONN
@ -48,14 +54,6 @@ def templates():
return get_json_result(data=[c.to_dict() for c in CanvasTemplateService.query(canvas_category=CanvasCategory.Agent)])
@manager.route('/list', methods=['GET']) # noqa: F821
@login_required
def canvas_list():
return get_json_result(data=sorted([c.to_dict() for c in \
UserCanvasService.query(user_id=current_user.id, canvas_category=CanvasCategory.Agent)], key=lambda x: x["update_time"]*-1)
)
@manager.route('/rm', methods=['POST']) # noqa: F821
@validate_request("canvas_ids")
@login_required
@ -77,9 +75,10 @@ def save():
if not isinstance(req["dsl"], str):
req["dsl"] = json.dumps(req["dsl"], ensure_ascii=False)
req["dsl"] = json.loads(req["dsl"])
cate = req.get("canvas_category", CanvasCategory.Agent)
if "id" not in req:
req["user_id"] = current_user.id
if UserCanvasService.query(user_id=current_user.id, title=req["title"].strip(), canvas_category=CanvasCategory.Agent):
if UserCanvasService.query(user_id=current_user.id, title=req["title"].strip(), canvas_category=cate):
return get_data_error_result(message=f"{req['title'].strip()} already exists.")
req["id"] = get_uuid()
if not UserCanvasService.save(**req):
@ -148,6 +147,14 @@ def run():
if not isinstance(cvs.dsl, str):
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
if cvs.canvas_category == CanvasCategory.DataFlow:
task_id = get_uuid()
Pipeline(cvs.dsl, tenant_id=current_user.id, doc_id=CANVAS_DEBUG_DOC_ID, task_id=task_id, flow_id=req["id"])
ok, error_message = queue_dataflow(tenant_id=user_id, flow_id=req["id"], task_id=task_id, file=files[0], priority=0)
if not ok:
return get_data_error_result(message=error_message)
return get_json_result(data={"message_id": task_id})
try:
canvas = Canvas(cvs.dsl, current_user.id, req["id"])
except Exception as e:
@ -173,6 +180,44 @@ def run():
return resp
@manager.route('/rerun', methods=['POST']) # noqa: F821
@validate_request("id", "dsl", "component_id")
@login_required
def rerun():
req = request.json
doc = PipelineOperationLogService.get_documents_info(req["id"])
if not doc:
return get_data_error_result(message="Document not found.")
doc = doc[0]
if 0 < doc["progress"] < 1:
return get_data_error_result(message=f"`{doc['name']}` is processing...")
if settings.docStoreConn.indexExist(search.index_name(current_user.id), doc["kb_id"]):
settings.docStoreConn.delete({"doc_id": doc["id"]}, search.index_name(current_user.id), doc["kb_id"])
doc["progress_msg"] = ""
doc["chunk_num"] = 0
doc["token_num"] = 0
DocumentService.clear_chunk_num_when_rerun(doc["id"])
DocumentService.update_by_id(id, doc)
TaskService.filter_delete([Task.doc_id == id])
dsl = req["dsl"]
dsl["path"] = [req["component_id"]]
PipelineOperationLogService.update_by_id(req["id"], {"dsl": dsl})
queue_dataflow(tenant_id=current_user.id, flow_id=req["id"], task_id=get_uuid(), doc_id=doc["id"], priority=0, rerun=True)
return get_json_result(data=True)
@manager.route('/cancel/<task_id>', methods=['PUT']) # noqa: F821
@login_required
def cancel(task_id):
try:
REDIS_CONN.set(f"{task_id}-cancel", "x")
except Exception as e:
logging.exception(e)
return get_json_result(data=True)
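For orientation, a rough client-side sketch (not part of the diff) of how the new `/rerun` and `/cancel/<task_id>` routes might be called; the base URL, authenticated session, DSL payload, and ids are placeholders:

```python
import requests

BASE = "http://localhost:9380/v1/canvas"  # assumed deployment URL and route prefix
session = requests.Session()              # assumes a logged-in session cookie

log_id = "<pipeline_log_id>"              # placeholder ids
component_id = "<component_id>"
task_id = "<task_id>"

# Re-run a logged pipeline starting from a specific component.
resp = session.post(f"{BASE}/rerun", json={
    "id": log_id,
    "dsl": {"components": {}, "path": []},  # placeholder DSL payload
    "component_id": component_id,
})
print(resp.json())

# Flag a running dataflow task for cancellation via its Redis "-cancel" key.
session.put(f"{BASE}/cancel/{task_id}")
```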
@manager.route('/reset', methods=['POST']) # noqa: F821
@validate_request("id")
@login_required
@ -399,22 +444,32 @@ def getversion( version_id):
return get_json_result(data=f"Error getting history file: {e}")
@manager.route('/listteam', methods=['GET']) # noqa: F821
@manager.route('/list', methods=['GET']) # noqa: F821
@login_required
def list_canvas():
keywords = request.args.get("keywords", "")
page_number = int(request.args.get("page", 1))
items_per_page = int(request.args.get("page_size", 150))
page_number = int(request.args.get("page", 0))
items_per_page = int(request.args.get("page_size", 0))
orderby = request.args.get("orderby", "create_time")
desc = request.args.get("desc", True)
try:
canvas_category = request.args.get("canvas_category")
if request.args.get("desc", "true").lower() == "false":
desc = False
else:
desc = True
owner_ids = [id for id in request.args.get("owner_ids", "").strip().split(",") if id]
if not owner_ids:
tenants = TenantService.get_joined_tenants_by_user_id(current_user.id)
tenants = [m["tenant_id"] for m in tenants]
tenants.append(current_user.id)
canvas, total = UserCanvasService.get_by_tenant_ids(
[m["tenant_id"] for m in tenants], current_user.id, page_number,
items_per_page, orderby, desc, keywords, canvas_category=CanvasCategory.Agent)
return get_json_result(data={"canvas": canvas, "total": total})
except Exception as e:
return server_error_response(e)
tenants, current_user.id, page_number,
items_per_page, orderby, desc, keywords, canvas_category)
else:
tenants = owner_ids
canvas, total = UserCanvasService.get_by_tenant_ids(
tenants, current_user.id, 0,
0, orderby, desc, keywords, canvas_category)
return get_json_result(data={"canvas": canvas, "total": total})
@manager.route('/setting', methods=['POST']) # noqa: F821
@ -499,3 +554,11 @@ def prompts():
#"context_ranking": RANK_MEMORY,
"citation_guidelines": CITATION_PROMPT_TEMPLATE
})
@manager.route('/download', methods=['GET']) # noqa: F821
def download():
id = request.args.get("id")
created_by = request.args.get("created_by")
blob = FileService.get_blob(created_by, id)
return flask.make_response(blob)

View File

@ -1,353 +0,0 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import re
import sys
import time
from functools import partial
import trio
from flask import request
from flask_login import current_user, login_required
from agent.canvas import Canvas
from agent.component.llm import LLM
from api.db import CanvasCategory, FileType
from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService
from api.db.services.document_service import DocumentService
from api.db.services.file_service import FileService
from api.db.services.task_service import queue_dataflow
from api.db.services.user_canvas_version import UserCanvasVersionService
from api.db.services.user_service import TenantService
from api.settings import RetCode
from api.utils import get_uuid
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request
from api.utils.file_utils import filename_type, read_potential_broken_pdf
from rag.flow.pipeline import Pipeline
@manager.route("/templates", methods=["GET"]) # noqa: F821
@login_required
def templates():
return get_json_result(data=[c.to_dict() for c in CanvasTemplateService.query(canvas_category=CanvasCategory.DataFlow)])
@manager.route("/list", methods=["GET"]) # noqa: F821
@login_required
def canvas_list():
return get_json_result(data=sorted([c.to_dict() for c in UserCanvasService.query(user_id=current_user.id, canvas_category=CanvasCategory.DataFlow)], key=lambda x: x["update_time"] * -1))
@manager.route("/rm", methods=["POST"]) # noqa: F821
@validate_request("canvas_ids")
@login_required
def rm():
for i in request.json["canvas_ids"]:
if not UserCanvasService.accessible(i, current_user.id):
return get_json_result(data=False, message="Only owner of canvas authorized for this operation.", code=RetCode.OPERATING_ERROR)
UserCanvasService.delete_by_id(i)
return get_json_result(data=True)
@manager.route("/set", methods=["POST"]) # noqa: F821
@validate_request("dsl", "title")
@login_required
def save():
req = request.json
if not isinstance(req["dsl"], str):
req["dsl"] = json.dumps(req["dsl"], ensure_ascii=False)
req["dsl"] = json.loads(req["dsl"])
req["canvas_category"] = CanvasCategory.DataFlow
if "id" not in req:
req["user_id"] = current_user.id
if UserCanvasService.query(user_id=current_user.id, title=req["title"].strip(), canvas_category=CanvasCategory.DataFlow):
return get_data_error_result(message=f"{req['title'].strip()} already exists.")
req["id"] = get_uuid()
if not UserCanvasService.save(**req):
return get_data_error_result(message="Fail to save canvas.")
else:
if not UserCanvasService.accessible(req["id"], current_user.id):
return get_json_result(data=False, message="Only owner of canvas authorized for this operation.", code=RetCode.OPERATING_ERROR)
UserCanvasService.update_by_id(req["id"], req)
# save version
UserCanvasVersionService.insert(user_canvas_id=req["id"], dsl=req["dsl"], title="{0}_{1}".format(req["title"], time.strftime("%Y_%m_%d_%H_%M_%S")))
UserCanvasVersionService.delete_all_versions(req["id"])
return get_json_result(data=req)
@manager.route("/get/<canvas_id>", methods=["GET"]) # noqa: F821
@login_required
def get(canvas_id):
if not UserCanvasService.accessible(canvas_id, current_user.id):
return get_data_error_result(message="canvas not found.")
e, c = UserCanvasService.get_by_canvas_id(canvas_id)
return get_json_result(data=c)
@manager.route("/run", methods=["POST"]) # noqa: F821
@validate_request("id")
@login_required
def run():
req = request.json
flow_id = req.get("id", "")
doc_id = req.get("doc_id", "")
if not all([flow_id, doc_id]):
return get_data_error_result(message="id and doc_id are required.")
if not DocumentService.get_by_id(doc_id):
return get_data_error_result(message=f"Document for {doc_id} not found.")
user_id = req.get("user_id", current_user.id)
if not UserCanvasService.accessible(flow_id, current_user.id):
return get_json_result(data=False, message="Only owner of canvas authorized for this operation.", code=RetCode.OPERATING_ERROR)
e, cvs = UserCanvasService.get_by_id(flow_id)
if not e:
return get_data_error_result(message="canvas not found.")
if not isinstance(cvs.dsl, str):
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
task_id = get_uuid()
ok, error_message = queue_dataflow(dsl=cvs.dsl, tenant_id=user_id, doc_id=doc_id, task_id=task_id, flow_id=flow_id, priority=0)
if not ok:
return server_error_response(error_message)
return get_json_result(data={"task_id": task_id, "flow_id": flow_id})
@manager.route("/reset", methods=["POST"]) # noqa: F821
@validate_request("id")
@login_required
def reset():
req = request.json
flow_id = req.get("id", "")
if not flow_id:
return get_data_error_result(message="id is required.")
if not UserCanvasService.accessible(flow_id, current_user.id):
return get_json_result(data=False, message="Only owner of canvas authorized for this operation.", code=RetCode.OPERATING_ERROR)
task_id = req.get("task_id", "")
try:
e, user_canvas = UserCanvasService.get_by_id(req["id"])
if not e:
return get_data_error_result(message="canvas not found.")
dataflow = Pipeline(dsl=json.dumps(user_canvas.dsl), tenant_id=current_user.id, flow_id=flow_id, task_id=task_id)
dataflow.reset()
req["dsl"] = json.loads(str(dataflow))
UserCanvasService.update_by_id(req["id"], {"dsl": req["dsl"]})
return get_json_result(data=req["dsl"])
except Exception as e:
return server_error_response(e)
@manager.route("/upload/<canvas_id>", methods=["POST"]) # noqa: F821
def upload(canvas_id):
e, cvs = UserCanvasService.get_by_canvas_id(canvas_id)
if not e:
return get_data_error_result(message="canvas not found.")
user_id = cvs["user_id"]
def structured(filename, filetype, blob, content_type):
nonlocal user_id
if filetype == FileType.PDF.value:
blob = read_potential_broken_pdf(blob)
location = get_uuid()
FileService.put_blob(user_id, location, blob)
return {
"id": location,
"name": filename,
"size": sys.getsizeof(blob),
"extension": filename.split(".")[-1].lower(),
"mime_type": content_type,
"created_by": user_id,
"created_at": time.time(),
"preview_url": None,
}
if request.args.get("url"):
from crawl4ai import AsyncWebCrawler, BrowserConfig, CrawlerRunConfig, CrawlResult, DefaultMarkdownGenerator, PruningContentFilter
try:
url = request.args.get("url")
filename = re.sub(r"\?.*", "", url.split("/")[-1])
async def adownload():
browser_config = BrowserConfig(
headless=True,
verbose=False,
)
async with AsyncWebCrawler(config=browser_config) as crawler:
crawler_config = CrawlerRunConfig(markdown_generator=DefaultMarkdownGenerator(content_filter=PruningContentFilter()), pdf=True, screenshot=False)
result: CrawlResult = await crawler.arun(url=url, config=crawler_config)
return result
page = trio.run(adownload())
if page.pdf:
if filename.split(".")[-1].lower() != "pdf":
filename += ".pdf"
return get_json_result(data=structured(filename, "pdf", page.pdf, page.response_headers["content-type"]))
return get_json_result(data=structured(filename, "html", str(page.markdown).encode("utf-8"), page.response_headers["content-type"], user_id))
except Exception as e:
return server_error_response(e)
file = request.files["file"]
try:
DocumentService.check_doc_health(user_id, file.filename)
return get_json_result(data=structured(file.filename, filename_type(file.filename), file.read(), file.content_type))
except Exception as e:
return server_error_response(e)
@manager.route("/input_form", methods=["GET"]) # noqa: F821
@login_required
def input_form():
flow_id = request.args.get("id")
cpn_id = request.args.get("component_id")
try:
e, user_canvas = UserCanvasService.get_by_id(flow_id)
if not e:
return get_data_error_result(message="canvas not found.")
if not UserCanvasService.query(user_id=current_user.id, id=flow_id):
return get_json_result(data=False, message="Only owner of canvas authorized for this operation.", code=RetCode.OPERATING_ERROR)
dataflow = Pipeline(dsl=json.dumps(user_canvas.dsl), tenant_id=current_user.id, flow_id=flow_id, task_id="")
return get_json_result(data=dataflow.get_component_input_form(cpn_id))
except Exception as e:
return server_error_response(e)
@manager.route("/debug", methods=["POST"]) # noqa: F821
@validate_request("id", "component_id", "params")
@login_required
def debug():
req = request.json
if not UserCanvasService.accessible(req["id"], current_user.id):
return get_json_result(data=False, message="Only owner of canvas authorized for this operation.", code=RetCode.OPERATING_ERROR)
try:
e, user_canvas = UserCanvasService.get_by_id(req["id"])
canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id)
canvas.reset()
canvas.message_id = get_uuid()
component = canvas.get_component(req["component_id"])["obj"]
component.reset()
if isinstance(component, LLM):
component.set_debug_inputs(req["params"])
component.invoke(**{k: o["value"] for k, o in req["params"].items()})
outputs = component.output()
for k in outputs.keys():
if isinstance(outputs[k], partial):
txt = ""
for c in outputs[k]():
txt += c
outputs[k] = txt
return get_json_result(data=outputs)
except Exception as e:
return server_error_response(e)
# api get list version dsl of canvas
@manager.route("/getlistversion/<canvas_id>", methods=["GET"]) # noqa: F821
@login_required
def getlistversion(canvas_id):
try:
list = sorted([c.to_dict() for c in UserCanvasVersionService.list_by_canvas_id(canvas_id)], key=lambda x: x["update_time"] * -1)
return get_json_result(data=list)
except Exception as e:
return get_data_error_result(message=f"Error getting history files: {e}")
# api get version dsl of canvas
@manager.route("/getversion/<version_id>", methods=["GET"]) # noqa: F821
@login_required
def getversion(version_id):
try:
e, version = UserCanvasVersionService.get_by_id(version_id)
if version:
return get_json_result(data=version.to_dict())
except Exception as e:
return get_json_result(data=f"Error getting history file: {e}")
@manager.route("/listteam", methods=["GET"]) # noqa: F821
@login_required
def list_canvas():
keywords = request.args.get("keywords", "")
page_number = int(request.args.get("page", 1))
items_per_page = int(request.args.get("page_size", 150))
orderby = request.args.get("orderby", "create_time")
desc = request.args.get("desc", True)
try:
tenants = TenantService.get_joined_tenants_by_user_id(current_user.id)
canvas, total = UserCanvasService.get_by_tenant_ids(
[m["tenant_id"] for m in tenants], current_user.id, page_number, items_per_page, orderby, desc, keywords, canvas_category=CanvasCategory.DataFlow
)
return get_json_result(data={"canvas": canvas, "total": total})
except Exception as e:
return server_error_response(e)
@manager.route("/setting", methods=["POST"]) # noqa: F821
@validate_request("id", "title", "permission")
@login_required
def setting():
req = request.json
req["user_id"] = current_user.id
if not UserCanvasService.accessible(req["id"], current_user.id):
return get_json_result(data=False, message="Only owner of canvas authorized for this operation.", code=RetCode.OPERATING_ERROR)
e, flow = UserCanvasService.get_by_id(req["id"])
if not e:
return get_data_error_result(message="canvas not found.")
flow = flow.to_dict()
flow["title"] = req["title"]
for key in ("description", "permission", "avatar"):
if value := req.get(key):
flow[key] = value
num = UserCanvasService.update_by_id(req["id"], flow)
return get_json_result(data=num)
@manager.route("/trace", methods=["GET"]) # noqa: F821
def trace():
dataflow_id = request.args.get("dataflow_id")
task_id = request.args.get("task_id")
if not all([dataflow_id, task_id]):
return get_data_error_result(message="dataflow_id and task_id are required.")
e, dataflow_canvas = UserCanvasService.get_by_id(dataflow_id)
if not e:
return get_data_error_result(message="dataflow not found.")
dsl_str = json.dumps(dataflow_canvas.dsl, ensure_ascii=False)
dataflow = Pipeline(dsl=dsl_str, tenant_id=dataflow_canvas.user_id, flow_id=dataflow_id, task_id=task_id)
log = dataflow.fetch_logs()
return get_json_result(data=log)

View File

@ -33,7 +33,7 @@ from api.db.services.document_service import DocumentService, doc_upload_and_par
from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.task_service import TaskService, cancel_all_task_of, queue_tasks
from api.db.services.task_service import TaskService, cancel_all_task_of, queue_tasks, queue_dataflow
from api.db.services.user_service import UserTenantService
from api.utils import get_uuid
from api.utils.api_utils import (
@ -187,6 +187,7 @@ def create():
"id": get_uuid(), "id": get_uuid(),
"kb_id": kb.id, "kb_id": kb.id,
"parser_id": kb.parser_id, "parser_id": kb.parser_id,
"pipeline_id": kb.pipeline_id,
"parser_config": kb.parser_config, "parser_config": kb.parser_config,
"created_by": current_user.id, "created_by": current_user.id,
"type": FileType.VIRTUAL, "type": FileType.VIRTUAL,
@ -484,8 +485,11 @@ def run():
kb_table_num_map[kb_id] = count
if kb_table_num_map[kb_id] <= 0:
KnowledgebaseService.delete_field_map(kb_id)
bucket, name = File2DocumentService.get_storage_address(doc_id=doc["id"])
queue_tasks(doc, bucket, name, 0)
if doc.get("pipeline_id", ""):
queue_dataflow(tenant_id, flow_id=doc["pipeline_id"], task_id=get_uuid(), doc_id=id)
else:
bucket, name = File2DocumentService.get_storage_address(doc_id=doc["id"])
queue_tasks(doc, bucket, name, 0)
return get_json_result(data=True)
except Exception as e:
@ -551,31 +555,22 @@ def get(doc_id):
@manager.route("/change_parser", methods=["POST"]) # noqa: F821 @manager.route("/change_parser", methods=["POST"]) # noqa: F821
@login_required @login_required
@validate_request("doc_id", "parser_id") @validate_request("doc_id")
def change_parser(): def change_parser():
req = request.json req = request.json
if not DocumentService.accessible(req["doc_id"], current_user.id): if not DocumentService.accessible(req["doc_id"], current_user.id):
return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR) return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR)
try:
e, doc = DocumentService.get_by_id(req["doc_id"])
if not e:
return get_data_error_result(message="Document not found!")
if doc.parser_id.lower() == req["parser_id"].lower():
if "parser_config" in req:
if req["parser_config"] == doc.parser_config:
return get_json_result(data=True)
else:
return get_json_result(data=True)
if (doc.type == FileType.VISUAL and req["parser_id"] != "picture") or (re.search(r"\.(ppt|pptx|pages)$", doc.name) and req["parser_id"] != "presentation"):
return get_data_error_result(message="Not supported yet!")
e, doc = DocumentService.get_by_id(req["doc_id"])
if not e:
return get_data_error_result(message="Document not found!")
def reset_doc():
nonlocal doc
e = DocumentService.update_by_id(doc.id, {"parser_id": req["parser_id"], "progress": 0, "progress_msg": "", "run": TaskStatus.UNSTART.value})
if not e:
return get_data_error_result(message="Document not found!")
if "parser_config" in req:
DocumentService.update_parser_config(doc.id, req["parser_config"])
if doc.token_num > 0:
e = DocumentService.increment_chunk_num(doc.id, doc.kb_id, doc.token_num * -1, doc.chunk_num * -1, doc.process_duration * -1)
if not e:
@ -586,6 +581,26 @@ def change_parser():
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
try:
if "pipeline_id" in req:
if doc.pipeline_id == req["pipeline_id"]:
return get_json_result(data=True)
DocumentService.update_by_id(doc.id, {"pipeline_id": req["pipeline_id"]})
reset_doc()
return get_json_result(data=True)
if doc.parser_id.lower() == req["parser_id"].lower():
if "parser_config" in req:
if req["parser_config"] == doc.parser_config:
return get_json_result(data=True)
else:
return get_json_result(data=True)
if (doc.type == FileType.VISUAL and req["parser_id"] != "picture") or (re.search(r"\.(ppt|pptx|pages)$", doc.name) and req["parser_id"] != "presentation"):
return get_data_error_result(message="Not supported yet!")
if "parser_config" in req:
DocumentService.update_parser_config(doc.id, req["parser_config"])
reset_doc()
return get_json_result(data=True)
except Exception as e:
return server_error_response(e)

View File

@ -179,9 +179,6 @@ def list_files():
if not e:
return get_data_error_result(message="Folder not found!")
if not check_file_team_permission(file, current_user.id):
return get_json_result(data=False, message='No authorization.', code=settings.RetCode.AUTHENTICATION_ERROR)
files, total = FileService.get_by_pf_id(
current_user.id, pf_id, page_number, items_per_page, orderby, desc, keywords)
@ -213,9 +210,6 @@ def get_parent_folder():
if not e:
return get_data_error_result(message="Folder not found!")
if not check_file_team_permission(file, current_user.id):
return get_json_result(data=False, message='No authorization.', code=settings.RetCode.AUTHENTICATION_ERROR)
parent_folder = FileService.get_parent_folder(file_id)
return get_json_result(data={"parent_folder": parent_folder.to_json()})
except Exception as e:
@ -231,9 +225,6 @@ def get_all_parent_folders():
if not e:
return get_data_error_result(message="Folder not found!")
if not check_file_team_permission(file, current_user.id):
return get_json_result(data=False, message='No authorization.', code=settings.RetCode.AUTHENTICATION_ERROR)
parent_folders = FileService.get_all_parent_folders(file_id)
parent_folders_res = []
for parent_folder in parent_folders:

View File

@ -14,18 +14,21 @@
# limitations under the License.
#
import json
import logging
from flask import request
from flask_login import login_required, current_user
from api.db.services import duplicate_name
from api.db.services.document_service import DocumentService
from api.db.services.document_service import DocumentService, queue_raptor_o_graphrag_tasks
from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService
from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
from api.db.services.task_service import TaskService, GRAPH_RAPTOR_FAKE_DOC_ID
from api.db.services.user_service import TenantService, UserTenantService
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request, not_allowed_parameters, active_required
from api.utils.api_utils import get_error_data_result, server_error_response, get_data_error_result, validate_request, not_allowed_parameters
from api.utils import get_uuid
from api.db import StatusEnum, FileSource
from api.db import PipelineTaskType, StatusEnum, FileSource, VALID_FILE_TYPES, VALID_TASK_STATUS
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.db_models import File
from api.utils.api_utils import get_json_result
@ -38,7 +41,6 @@ from rag.utils.storage_factory import STORAGE_IMPL
@manager.route('/create', methods=['post']) # noqa: F821
@login_required
@active_required
@validate_request("name")
def create():
req = request.json
@ -62,10 +64,39 @@ def create():
req["name"] = dataset_name req["name"] = dataset_name
req["tenant_id"] = current_user.id req["tenant_id"] = current_user.id
req["created_by"] = current_user.id req["created_by"] = current_user.id
if not req.get("parser_id"):
req["parser_id"] = "naive"
e, t = TenantService.get_by_id(current_user.id) e, t = TenantService.get_by_id(current_user.id)
if not e: if not e:
return get_data_error_result(message="Tenant not found.") return get_data_error_result(message="Tenant not found.")
req["embd_id"] = t.embd_id req["parser_config"] = {
"layout_recognize": "DeepDOC",
"chunk_token_num": 512,
"delimiter": "\n",
"auto_keywords": 0,
"auto_questions": 0,
"html4excel": False,
"topn_tags": 3,
"raptor": {
"use_raptor": True,
"prompt": "Please summarize the following paragraphs. Be careful with the numbers, do not make things up. Paragraphs as following:\n {cluster_content}\nThe above is the content you need to summarize.",
"max_token": 256,
"threshold": 0.1,
"max_cluster": 64,
"random_seed": 0
},
"graphrag": {
"use_graphrag": True,
"entity_types": [
"organization",
"person",
"geo",
"event",
"category"
],
"method": "light"
}
}
if not KnowledgebaseService.save(**req):
return get_data_error_result()
return get_json_result(data={"kb_id": req["id"]})
@ -396,3 +427,352 @@ def get_basic_info():
basic_info = DocumentService.knowledgebase_basic_info(kb_id)
return get_json_result(data=basic_info)
@manager.route("/list_pipeline_logs", methods=["POST"]) # noqa: F821
@login_required
def list_pipeline_logs():
kb_id = request.args.get("kb_id")
if not kb_id:
return get_json_result(data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
keywords = request.args.get("keywords", "")
page_number = int(request.args.get("page", 0))
items_per_page = int(request.args.get("page_size", 0))
orderby = request.args.get("orderby", "create_time")
if request.args.get("desc", "true").lower() == "false":
desc = False
else:
desc = True
create_date_from = request.args.get("create_date_from", "")
create_date_to = request.args.get("create_date_to", "")
if create_date_to > create_date_from:
return get_data_error_result(message="Create data filter is abnormal.")
req = request.get_json()
operation_status = req.get("operation_status", [])
if operation_status:
invalid_status = {s for s in operation_status if s not in VALID_TASK_STATUS}
if invalid_status:
return get_data_error_result(message=f"Invalid filter operation_status status conditions: {', '.join(invalid_status)}")
types = req.get("types", [])
if types:
invalid_types = {t for t in types if t not in VALID_FILE_TYPES}
if invalid_types:
return get_data_error_result(message=f"Invalid filter conditions: {', '.join(invalid_types)} type{'s' if len(invalid_types) > 1 else ''}")
suffix = req.get("suffix", [])
try:
logs, tol = PipelineOperationLogService.get_file_logs_by_kb_id(kb_id, page_number, items_per_page, orderby, desc, keywords, operation_status, types, suffix, create_date_from, create_date_to)
return get_json_result(data={"total": tol, "logs": logs})
except Exception as e:
return server_error_response(e)
@manager.route("/list_pipeline_dataset_logs", methods=["POST"]) # noqa: F821
@login_required
def list_pipeline_dataset_logs():
kb_id = request.args.get("kb_id")
if not kb_id:
return get_json_result(data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
page_number = int(request.args.get("page", 0))
items_per_page = int(request.args.get("page_size", 0))
orderby = request.args.get("orderby", "create_time")
if request.args.get("desc", "true").lower() == "false":
desc = False
else:
desc = True
create_date_from = request.args.get("create_date_from", "")
create_date_to = request.args.get("create_date_to", "")
if create_date_to > create_date_from:
return get_data_error_result(message="Create data filter is abnormal.")
req = request.get_json()
operation_status = req.get("operation_status", [])
if operation_status:
invalid_status = {s for s in operation_status if s not in VALID_TASK_STATUS}
if invalid_status:
return get_data_error_result(message=f"Invalid filter operation_status status conditions: {', '.join(invalid_status)}")
try:
logs, tol = PipelineOperationLogService.get_dataset_logs_by_kb_id(kb_id, page_number, items_per_page, orderby, desc, operation_status, create_date_from, create_date_to)
return get_json_result(data={"total": tol, "logs": logs})
except Exception as e:
return server_error_response(e)
@manager.route("/delete_pipeline_logs", methods=["POST"]) # noqa: F821
@login_required
def delete_pipeline_logs():
kb_id = request.args.get("kb_id")
if not kb_id:
return get_json_result(data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
req = request.get_json()
log_ids = req.get("log_ids", [])
PipelineOperationLogService.delete_by_ids(log_ids)
return get_json_result(data=True)
@manager.route("/pipeline_log_detail", methods=["GET"]) # noqa: F821
@login_required
def pipeline_log_detail():
log_id = request.args.get("log_id")
if not log_id:
return get_json_result(data=False, message='Lack of "Pipeline log ID"', code=settings.RetCode.ARGUMENT_ERROR)
ok, log = PipelineOperationLogService.get_by_id(log_id)
if not ok:
return get_data_error_result(message="Invalid pipeline log ID")
return get_json_result(data=log.to_dict())
@manager.route("/run_graphrag", methods=["POST"]) # noqa: F821
@login_required
def run_graphrag():
req = request.json
kb_id = req.get("kb_id", "")
if not kb_id:
return get_error_data_result(message='Lack of "KB ID"')
ok, kb = KnowledgebaseService.get_by_id(kb_id)
if not ok:
return get_error_data_result(message="Invalid Knowledgebase ID")
task_id = kb.graphrag_task_id
if task_id:
ok, task = TaskService.get_by_id(task_id)
if not ok:
logging.warning(f"A valid GraphRAG task id is expected for kb {kb_id}")
if task and task.progress not in [-1, 1]:
return get_error_data_result(message=f"Task {task_id} in progress with status {task.progress}. A Graph Task is already running.")
documents, _ = DocumentService.get_by_kb_id(
kb_id=kb_id,
page_number=0,
items_per_page=0,
orderby="create_time",
desc=False,
keywords="",
run_status=[],
types=[],
suffix=[],
)
if not documents:
return get_error_data_result(message=f"No documents in Knowledgebase {kb_id}")
sample_document = documents[0]
document_ids = [document["id"] for document in documents]
task_id = queue_raptor_o_graphrag_tasks(doc=sample_document, ty="graphrag", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids))
if not KnowledgebaseService.update_by_id(kb.id, {"graphrag_task_id": task_id}):
logging.warning(f"Cannot save graphrag_task_id for kb {kb_id}")
return get_json_result(data={"graphrag_task_id": task_id})
@manager.route("/trace_graphrag", methods=["GET"]) # noqa: F821
@login_required
def trace_graphrag():
kb_id = request.args.get("kb_id", "")
if not kb_id:
return get_error_data_result(message='Lack of "KB ID"')
ok, kb = KnowledgebaseService.get_by_id(kb_id)
if not ok:
return get_error_data_result(message="Invalid Knowledgebase ID")
task_id = kb.graphrag_task_id
if not task_id:
return get_json_result(data={})
ok, task = TaskService.get_by_id(task_id)
if not ok:
return get_error_data_result(message="GraphRAG Task Not Found or Error Occurred")
return get_json_result(data=task.to_dict())
@manager.route("/run_raptor", methods=["POST"]) # noqa: F821
@login_required
def run_raptor():
req = request.json
kb_id = req.get("kb_id", "")
if not kb_id:
return get_error_data_result(message='Lack of "KB ID"')
ok, kb = KnowledgebaseService.get_by_id(kb_id)
if not ok:
return get_error_data_result(message="Invalid Knowledgebase ID")
task_id = kb.raptor_task_id
if task_id:
ok, task = TaskService.get_by_id(task_id)
if not ok:
logging.warning(f"A valid RAPTOR task id is expected for kb {kb_id}")
if task and task.progress not in [-1, 1]:
return get_error_data_result(message=f"Task {task_id} in progress with status {task.progress}. A RAPTOR Task is already running.")
documents, _ = DocumentService.get_by_kb_id(
kb_id=kb_id,
page_number=0,
items_per_page=0,
orderby="create_time",
desc=False,
keywords="",
run_status=[],
types=[],
suffix=[],
)
if not documents:
return get_error_data_result(message=f"No documents in Knowledgebase {kb_id}")
sample_document = documents[0]
document_ids = [document["id"] for document in documents]
task_id = queue_raptor_o_graphrag_tasks(doc=sample_document, ty="raptor", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids))
if not KnowledgebaseService.update_by_id(kb.id, {"raptor_task_id": task_id}):
logging.warning(f"Cannot save raptor_task_id for kb {kb_id}")
return get_json_result(data={"raptor_task_id": task_id})
@manager.route("/trace_raptor", methods=["GET"]) # noqa: F821
@login_required
def trace_raptor():
kb_id = request.args.get("kb_id", "")
if not kb_id:
return get_error_data_result(message='Lack of "KB ID"')
ok, kb = KnowledgebaseService.get_by_id(kb_id)
if not ok:
return get_error_data_result(message="Invalid Knowledgebase ID")
task_id = kb.raptor_task_id
if not task_id:
return get_json_result(data={})
ok, task = TaskService.get_by_id(task_id)
if not ok:
return get_error_data_result(message="RAPTOR Task Not Found or Error Occurred")
return get_json_result(data=task.to_dict())
@manager.route("/run_mindmap", methods=["POST"]) # noqa: F821
@login_required
def run_mindmap():
req = request.json
kb_id = req.get("kb_id", "")
if not kb_id:
return get_error_data_result(message='Lack of "KB ID"')
ok, kb = KnowledgebaseService.get_by_id(kb_id)
if not ok:
return get_error_data_result(message="Invalid Knowledgebase ID")
task_id = kb.mindmap_task_id
if task_id:
ok, task = TaskService.get_by_id(task_id)
if not ok:
logging.warning(f"A valid Mindmap task id is expected for kb {kb_id}")
if task and task.progress not in [-1, 1]:
return get_error_data_result(message=f"Task {task_id} in progress with status {task.progress}. A Mindmap Task is already running.")
documents, _ = DocumentService.get_by_kb_id(
kb_id=kb_id,
page_number=0,
items_per_page=0,
orderby="create_time",
desc=False,
keywords="",
run_status=[],
types=[],
suffix=[],
)
if not documents:
return get_error_data_result(message=f"No documents in Knowledgebase {kb_id}")
sample_document = documents[0]
document_ids = [document["id"] for document in documents]
task_id = queue_raptor_o_graphrag_tasks(doc=sample_document, ty="mindmap", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids))
if not KnowledgebaseService.update_by_id(kb.id, {"mindmap_task_id": task_id}):
logging.warning(f"Cannot save mindmap_task_id for kb {kb_id}")
return get_json_result(data={"mindmap_task_id": task_id})
@manager.route("/trace_mindmap", methods=["GET"]) # noqa: F821
@login_required
def trace_mindmap():
kb_id = request.args.get("kb_id", "")
if not kb_id:
return get_error_data_result(message='Lack of "KB ID"')
ok, kb = KnowledgebaseService.get_by_id(kb_id)
if not ok:
return get_error_data_result(message="Invalid Knowledgebase ID")
task_id = kb.mindmap_task_id
if not task_id:
return get_json_result(data={})
ok, task = TaskService.get_by_id(task_id)
if not ok:
return get_error_data_result(message="Mindmap Task Not Found or Error Occurred")
return get_json_result(data=task.to_dict())
@manager.route("/unbind_task", methods=["DELETE"]) # noqa: F821
@login_required
def delete_kb_task():
kb_id = request.args.get("kb_id", "")
if not kb_id:
return get_error_data_result(message='Lack of "KB ID"')
ok, kb = KnowledgebaseService.get_by_id(kb_id)
if not ok:
return get_json_result(data=True)
pipeline_task_type = request.args.get("pipeline_task_type", "")
if not pipeline_task_type or pipeline_task_type not in [PipelineTaskType.GRAPH_RAG, PipelineTaskType.RAPTOR, PipelineTaskType.MINDMAP]:
return get_error_data_result(message="Invalid task type")
match pipeline_task_type:
case PipelineTaskType.GRAPH_RAG:
settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]}, search.index_name(kb.tenant_id), kb_id)
kb_task_id = "graphrag_task_id"
kb_task_finish_at = "graphrag_task_finish_at"
case PipelineTaskType.RAPTOR:
kb_task_id = "raptor_task_id"
kb_task_finish_at = "raptor_task_finish_at"
case PipelineTaskType.MINDMAP:
kb_task_id = "mindmap_task_id"
kb_task_finish_at = "mindmap_task_finish_at"
case _:
return get_error_data_result(message="Internal Error: Invalid task type")
ok = KnowledgebaseService.update_by_id(kb_id, {kb_task_id: "", kb_task_finish_at: None})
if not ok:
return server_error_response(f"Internal error: cannot delete task {pipeline_task_type}")
return get_json_result(data=True)
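A hedged client-side sketch (not from the diff) of the kick-off-then-poll pattern these knowledge-base task routes suggest, shown for GraphRAG; the route prefix, base URL, and kb_id are assumptions and the session is presumed authenticated:

```python
import time

import requests

BASE = "http://localhost:9380/v1/kb"  # assumed route prefix for kb_app
session = requests.Session()          # assumes a logged-in session cookie
kb_id = "<kb_id>"                     # placeholder knowledge-base id

# Kick off a GraphRAG build over all documents in the knowledge base.
start = session.post(f"{BASE}/run_graphrag", json={"kb_id": kb_id}).json()
task_id = start.get("data", {}).get("graphrag_task_id")

# Poll the paired trace endpoint until the task reports success (1) or failure (-1).
while task_id:
    task = session.get(f"{BASE}/trace_graphrag", params={"kb_id": kb_id}).json().get("data", {})
    if task.get("progress") in (-1, 1):
        break
    time.sleep(5)
```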

View File

@ -127,4 +127,15 @@ class MCPServerType(StrEnum):
VALID_MCP_SERVER_TYPES = {MCPServerType.SSE, MCPServerType.STREAMABLE_HTTP}
class PipelineTaskType(StrEnum):
PARSE = "Parse"
DOWNLOAD = "Download"
RAPTOR = "RAPTOR"
GRAPH_RAG = "GraphRAG"
MINDMAP = "Mindmap"
VALID_PIPELINE_TASK_TYPES = {PipelineTaskType.PARSE, PipelineTaskType.DOWNLOAD, PipelineTaskType.RAPTOR, PipelineTaskType.GRAPH_RAG, PipelineTaskType.MINDMAP}
KNOWLEDGEBASE_FOLDER_NAME=".knowledgebase"

View File

@ -684,8 +684,17 @@ class Knowledgebase(DataBaseModel):
vector_similarity_weight = FloatField(default=0.3, index=True)
parser_id = CharField(max_length=32, null=False, help_text="default parser ID", default=ParserType.NAIVE.value, index=True)
pipeline_id = CharField(max_length=32, null=True, help_text="Pipeline ID", index=True)
parser_config = JSONField(null=False, default={"pages": [[1, 1000000]]})
pagerank = IntegerField(default=0, index=False)
graphrag_task_id = CharField(max_length=32, null=True, help_text="Graph RAG task ID", index=True)
graphrag_task_finish_at = DateTimeField(null=True)
raptor_task_id = CharField(max_length=32, null=True, help_text="RAPTOR task ID", index=True)
raptor_task_finish_at = DateTimeField(null=True)
mindmap_task_id = CharField(max_length=32, null=True, help_text="Mindmap task ID", index=True)
mindmap_task_finish_at = DateTimeField(null=True)
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True)
def __str__(self):
@ -700,6 +709,7 @@ class Document(DataBaseModel):
thumbnail = TextField(null=True, help_text="thumbnail base64 string")
kb_id = CharField(max_length=256, null=False, index=True)
parser_id = CharField(max_length=32, null=False, help_text="default parser ID", index=True)
pipeline_id = CharField(max_length=32, null=True, help_text="pipleline ID", index=True)
parser_config = JSONField(null=False, default={"pages": [[1, 1000000]]})
source_type = CharField(max_length=128, null=False, default="local", help_text="where dose this document come from", index=True)
type = CharField(max_length=32, null=False, help_text="file extension", index=True)
@ -942,6 +952,32 @@ class Search(DataBaseModel):
db_table = "search" db_table = "search"
class PipelineOperationLog(DataBaseModel):
id = CharField(max_length=32, primary_key=True)
document_id = CharField(max_length=32, index=True)
tenant_id = CharField(max_length=32, null=False, index=True)
kb_id = CharField(max_length=32, null=False, index=True)
pipeline_id = CharField(max_length=32, null=True, help_text="Pipeline ID", index=True)
pipeline_title = CharField(max_length=32, null=True, help_text="Pipeline title", index=True)
parser_id = CharField(max_length=32, null=False, help_text="Parser ID", index=True)
document_name = CharField(max_length=255, null=False, help_text="File name")
document_suffix = CharField(max_length=255, null=False, help_text="File suffix")
document_type = CharField(max_length=255, null=False, help_text="Document type")
source_from = CharField(max_length=255, null=False, help_text="Source")
progress = FloatField(default=0, index=True)
progress_msg = TextField(null=True, help_text="process message", default="")
process_begin_at = DateTimeField(null=True, index=True)
process_duration = FloatField(default=0)
dsl = JSONField(null=True, default=dict)
task_type = CharField(max_length=32, null=False, default="")
operation_status = CharField(max_length=32, null=False, help_text="Operation status")
avatar = TextField(null=True, help_text="avatar base64 string")
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True)
class Meta:
db_table = "pipeline_operation_log"
def migrate_db():
logging.disable(logging.ERROR)
migrator = DatabaseMigrator[settings.DATABASE_TYPE.upper()].value(DB)
@ -1058,7 +1094,6 @@ def migrate_db():
migrate(migrator.add_column("dialog", "meta_data_filter", JSONField(null=True, default={}))) migrate(migrator.add_column("dialog", "meta_data_filter", JSONField(null=True, default={})))
except Exception: except Exception:
pass pass
try: try:
migrate(migrator.alter_column_type("canvas_template", "title", JSONField(null=True, default=dict, help_text="Canvas title"))) migrate(migrator.alter_column_type("canvas_template", "title", JSONField(null=True, default=dict, help_text="Canvas title")))
except Exception: except Exception:
@ -1075,4 +1110,36 @@ def migrate_db():
migrate(migrator.add_column("canvas_template", "canvas_category", CharField(max_length=32, null=False, default="agent_canvas", help_text="agent_canvas|dataflow_canvas", index=True))) migrate(migrator.add_column("canvas_template", "canvas_category", CharField(max_length=32, null=False, default="agent_canvas", help_text="agent_canvas|dataflow_canvas", index=True)))
except Exception: except Exception:
pass pass
try:
migrate(migrator.add_column("knowledgebase", "pipeline_id", CharField(max_length=32, null=True, help_text="Pipeline ID", index=True)))
except Exception:
pass
try:
migrate(migrator.add_column("document", "pipeline_id", CharField(max_length=32, null=True, help_text="Pipeline ID", index=True)))
except Exception:
pass
try:
migrate(migrator.add_column("knowledgebase", "graphrag_task_id", CharField(max_length=32, null=True, help_text="Gragh RAG task ID", index=True)))
except Exception:
pass
try:
migrate(migrator.add_column("knowledgebase", "raptor_task_id", CharField(max_length=32, null=True, help_text="RAPTOR task ID", index=True)))
except Exception:
pass
try:
migrate(migrator.add_column("knowledgebase", "graphrag_task_finish_at", DateTimeField(null=True)))
except Exception:
pass
try:
migrate(migrator.add_column("knowledgebase", "raptor_task_finish_at", CharField(null=True)))
except Exception:
pass
try:
migrate(migrator.add_column("knowledgebase", "mindmap_task_id", CharField(max_length=32, null=True, help_text="Mindmap task ID", index=True)))
except Exception:
pass
try:
migrate(migrator.add_column("knowledgebase", "mindmap_task_finish_at", CharField(null=True)))
except Exception:
pass
logging.disable(logging.NOTSET)
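The repeated try/except-pass idiom above tolerates columns that already exist. A hypothetical helper along these lines could remove the duplication; this is only a sketch under the assumption that migrate and migrator come from playhouse.migrate, as elsewhere in this module, and it is not part of the commit.

# Hypothetical helper: add a column and swallow the failure if it already exists,
# mirroring the try/except blocks above.
from playhouse.migrate import migrate

def _add_column_if_missing(migrator, table: str, name: str, field) -> None:
    try:
        migrate(migrator.add_column(table, name, field))
    except Exception:
        pass  # column already present or unsupported by this backend; keep going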

View File

@ -126,7 +126,7 @@ class UserCanvasService(CommonService):
@DB.connection_context() @DB.connection_context()
def get_by_tenant_ids(cls, joined_tenant_ids, user_id, def get_by_tenant_ids(cls, joined_tenant_ids, user_id,
page_number, items_per_page, page_number, items_per_page,
orderby, desc, keywords, canvas_category=CanvasCategory.Agent, orderby, desc, keywords, canvas_category=None
): ):
fields = [ fields = [
cls.model.id, cls.model.id,
@ -135,6 +135,7 @@ class UserCanvasService(CommonService):
cls.model.dsl, cls.model.dsl,
cls.model.description, cls.model.description,
cls.model.permission, cls.model.permission,
cls.model.user_id.alias("tenant_id"),
User.nickname, User.nickname,
User.avatar.alias('tenant_avatar'), User.avatar.alias('tenant_avatar'),
cls.model.update_time, cls.model.update_time,
@ -142,24 +143,26 @@ class UserCanvasService(CommonService):
] ]
if keywords: if keywords:
agents = cls.model.select(*fields).join(User, on=(cls.model.user_id == User.id)).where( agents = cls.model.select(*fields).join(User, on=(cls.model.user_id == User.id)).where(
((cls.model.user_id.in_(joined_tenant_ids) & (cls.model.permission == cls.model.user_id.in_(joined_tenant_ids),
TenantPermission.TEAM.value)) | ( fn.LOWER(cls.model.title).contains(keywords.lower())
cls.model.user_id == user_id)), #(((cls.model.user_id.in_(joined_tenant_ids)) & (cls.model.permission == TenantPermission.TEAM.value)) | (cls.model.user_id == user_id)),
(fn.LOWER(cls.model.title).contains(keywords.lower())) #(fn.LOWER(cls.model.title).contains(keywords.lower()))
) )
else: else:
agents = cls.model.select(*fields).join(User, on=(cls.model.user_id == User.id)).where( agents = cls.model.select(*fields).join(User, on=(cls.model.user_id == User.id)).where(
((cls.model.user_id.in_(joined_tenant_ids) & (cls.model.permission == cls.model.user_id.in_(joined_tenant_ids)
TenantPermission.TEAM.value)) | ( #(((cls.model.user_id.in_(joined_tenant_ids)) & (cls.model.permission == TenantPermission.TEAM.value)) | (cls.model.user_id == user_id))
cls.model.user_id == user_id))
) )
agents = agents.where(cls.model.canvas_category == canvas_category) if canvas_category:
agents = agents.where(cls.model.canvas_category == canvas_category)
if desc: if desc:
agents = agents.order_by(cls.model.getter_by(orderby).desc()) agents = agents.order_by(cls.model.getter_by(orderby).desc())
else: else:
agents = agents.order_by(cls.model.getter_by(orderby).asc()) agents = agents.order_by(cls.model.getter_by(orderby).asc())
count = agents.count() count = agents.count()
agents = agents.paginate(page_number, items_per_page) if page_number and items_per_page:
agents = agents.paginate(page_number, items_per_page)
return list(agents.dicts()), count return list(agents.dicts()), count
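A hedged call sketch for the relaxed signature: canvas_category and pagination are now optional, and falsy values skip the corresponding filters. All argument values below are placeholders.

# Illustrative only: list every canvas visible to the given tenants,
# without paging and without filtering by canvas_category.
agents, total = UserCanvasService.get_by_tenant_ids(
    joined_tenant_ids=["tenant_a", "tenant_b"],
    user_id="user_x",
    page_number=0,           # falsy -> no pagination
    items_per_page=0,
    orderby="update_time",
    desc=True,
    keywords="",
    canvas_category=None,    # None -> all categories
)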
@classmethod @classmethod

View File

@ -24,12 +24,13 @@ from io import BytesIO
import trio import trio
import xxhash import xxhash
from peewee import fn, Case from peewee import fn, Case, JOIN
from api import settings from api import settings
from api.constants import IMG_BASE64_PREFIX, FILE_NAME_LEN_LIMIT from api.constants import IMG_BASE64_PREFIX, FILE_NAME_LEN_LIMIT
from api.db import FileType, LLMType, ParserType, StatusEnum, TaskStatus, UserTenantRole from api.db import FileType, LLMType, ParserType, StatusEnum, TaskStatus, UserTenantRole, CanvasCategory
from api.db.db_models import DB, Document, Knowledgebase, Task, Tenant, UserTenant, File2Document, File from api.db.db_models import DB, Document, Knowledgebase, Task, Tenant, UserTenant, File2Document, File, UserCanvas, \
User
from api.db.db_utils import bulk_insert_into_db from api.db.db_utils import bulk_insert_into_db
from api.db.services.common_service import CommonService from api.db.services.common_service import CommonService
from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.knowledgebase_service import KnowledgebaseService
@ -51,6 +52,7 @@ class DocumentService(CommonService):
cls.model.thumbnail, cls.model.thumbnail,
cls.model.kb_id, cls.model.kb_id,
cls.model.parser_id, cls.model.parser_id,
cls.model.pipeline_id,
cls.model.parser_config, cls.model.parser_config,
cls.model.source_type, cls.model.source_type,
cls.model.type, cls.model.type,
@ -79,7 +81,10 @@ class DocumentService(CommonService):
def get_list(cls, kb_id, page_number, items_per_page, def get_list(cls, kb_id, page_number, items_per_page,
orderby, desc, keywords, id, name): orderby, desc, keywords, id, name):
fields = cls.get_cls_model_fields() fields = cls.get_cls_model_fields()
docs = cls.model.select(*fields).join(File2Document, on = (File2Document.document_id == cls.model.id)).join(File, on = (File.id == File2Document.file_id)).where(cls.model.kb_id == kb_id) docs = cls.model.select(*[*fields, UserCanvas.title]).join(File2Document, on = (File2Document.document_id == cls.model.id))\
.join(File, on = (File.id == File2Document.file_id))\
.join(UserCanvas, on = ((cls.model.pipeline_id == UserCanvas.id) & (UserCanvas.canvas_category == CanvasCategory.DataFlow.value)), join_type=JOIN.LEFT_OUTER)\
.where(cls.model.kb_id == kb_id)
if id: if id:
docs = docs.where( docs = docs.where(
cls.model.id == id) cls.model.id == id)
@ -117,12 +122,22 @@ class DocumentService(CommonService):
orderby, desc, keywords, run_status, types, suffix): orderby, desc, keywords, run_status, types, suffix):
fields = cls.get_cls_model_fields() fields = cls.get_cls_model_fields()
if keywords: if keywords:
docs = cls.model.select(*fields).join(File2Document, on=(File2Document.document_id == cls.model.id)).join(File, on=(File.id == File2Document.file_id)).where( docs = cls.model.select(*[*fields, UserCanvas.title.alias("pipeline_name"), User.nickname])\
(cls.model.kb_id == kb_id), .join(File2Document, on=(File2Document.document_id == cls.model.id))\
(fn.LOWER(cls.model.name).contains(keywords.lower())) .join(File, on=(File.id == File2Document.file_id))\
) .join(UserCanvas, on=(cls.model.pipeline_id == UserCanvas.id), join_type=JOIN.LEFT_OUTER)\
.join(User, on=(cls.model.created_by == User.id), join_type=JOIN.LEFT_OUTER)\
.where(
(cls.model.kb_id == kb_id),
(fn.LOWER(cls.model.name).contains(keywords.lower()))
)
else: else:
docs = cls.model.select(*fields).join(File2Document, on=(File2Document.document_id == cls.model.id)).join(File, on=(File.id == File2Document.file_id)).where(cls.model.kb_id == kb_id) docs = cls.model.select(*[*fields, UserCanvas.title.alias("pipeline_name"), User.nickname])\
.join(File2Document, on=(File2Document.document_id == cls.model.id))\
.join(UserCanvas, on=(cls.model.pipeline_id == UserCanvas.id), join_type=JOIN.LEFT_OUTER)\
.join(File, on=(File.id == File2Document.file_id))\
.join(User, on=(cls.model.created_by == User.id), join_type=JOIN.LEFT_OUTER)\
.where(cls.model.kb_id == kb_id)
if run_status: if run_status:
docs = docs.where(cls.model.run.in_(run_status)) docs = docs.where(cls.model.run.in_(run_status))
@ -370,8 +385,7 @@ class DocumentService(CommonService):
process_duration=cls.model.process_duration + duration).where( process_duration=cls.model.process_duration + duration).where(
cls.model.id == doc_id).execute() cls.model.id == doc_id).execute()
if num == 0: if num == 0:
raise LookupError( logging.warning("Document not found which is supposed to be there")
"Document not found which is supposed to be there")
num = Knowledgebase.update( num = Knowledgebase.update(
token_num=Knowledgebase.token_num + token_num=Knowledgebase.token_num +
token_num, token_num,
@ -637,6 +651,22 @@ class DocumentService(CommonService):
@DB.connection_context() @DB.connection_context()
def update_progress(cls): def update_progress(cls):
docs = cls.get_unfinished_docs() docs = cls.get_unfinished_docs()
cls._sync_progress(docs)
@classmethod
@DB.connection_context()
def update_progress_immediately(cls, docs:list[dict]):
if not docs:
return
cls._sync_progress(docs)
@classmethod
@DB.connection_context()
def _sync_progress(cls, docs:list[dict]):
for d in docs: for d in docs:
try: try:
tsks = Task.query(doc_id=d["id"], order_by=Task.create_time) tsks = Task.query(doc_id=d["id"], order_by=Task.create_time)
@ -646,8 +676,6 @@ class DocumentService(CommonService):
prg = 0 prg = 0
finished = True finished = True
bad = 0 bad = 0
has_raptor = False
has_graphrag = False
e, doc = DocumentService.get_by_id(d["id"]) e, doc = DocumentService.get_by_id(d["id"])
status = doc.run # TaskStatus.RUNNING.value status = doc.run # TaskStatus.RUNNING.value
priority = 0 priority = 0
@ -659,24 +687,14 @@ class DocumentService(CommonService):
prg += t.progress if t.progress >= 0 else 0 prg += t.progress if t.progress >= 0 else 0
if t.progress_msg.strip(): if t.progress_msg.strip():
msg.append(t.progress_msg) msg.append(t.progress_msg)
if t.task_type == "raptor":
has_raptor = True
elif t.task_type == "graphrag":
has_graphrag = True
priority = max(priority, t.priority) priority = max(priority, t.priority)
prg /= len(tsks) prg /= len(tsks)
if finished and bad: if finished and bad:
prg = -1 prg = -1
status = TaskStatus.FAIL.value status = TaskStatus.FAIL.value
elif finished: elif finished:
if (d["parser_config"].get("raptor") or {}).get("use_raptor") and not has_raptor: prg = 1
queue_raptor_o_graphrag_tasks(d, "raptor", priority) status = TaskStatus.DONE.value
prg = 0.98 * len(tsks) / (len(tsks) + 1)
elif (d["parser_config"].get("graphrag") or {}).get("use_graphrag") and not has_graphrag:
queue_raptor_o_graphrag_tasks(d, "graphrag", priority)
prg = 0.98 * len(tsks) / (len(tsks) + 1)
else:
status = TaskStatus.DONE.value
msg = "\n".join(sorted(msg)) msg = "\n".join(sorted(msg))
info = { info = {
@ -688,7 +706,7 @@ class DocumentService(CommonService):
info["progress"] = prg info["progress"] = prg
if msg: if msg:
info["progress_msg"] = msg info["progress_msg"] = msg
if msg.endswith("created task graphrag") or msg.endswith("created task raptor"): if msg.endswith("created task graphrag") or msg.endswith("created task raptor") or msg.endswith("created task mindmap"):
info["progress_msg"] += "\n%d tasks are ahead in the queue..."%get_queue_length(priority) info["progress_msg"] += "\n%d tasks are ahead in the queue..."%get_queue_length(priority)
else: else:
info["progress_msg"] = "%d tasks are ahead in the queue..."%get_queue_length(priority) info["progress_msg"] = "%d tasks are ahead in the queue..."%get_queue_length(priority)
@ -769,7 +787,11 @@ class DocumentService(CommonService):
"cancelled": int(cancelled), "cancelled": int(cancelled),
} }
def queue_raptor_o_graphrag_tasks(doc, ty, priority): def queue_raptor_o_graphrag_tasks(doc, ty, priority, fake_doc_id="", doc_ids=[]):
"""
You can provide a fake_doc_id to bypass the restriction of tasks at the knowledgebase level.
Optionally, specify a list of doc_ids to determine which documents participate in the task.
"""
chunking_config = DocumentService.get_chunking_config(doc["id"]) chunking_config = DocumentService.get_chunking_config(doc["id"])
hasher = xxhash.xxh64() hasher = xxhash.xxh64()
for field in sorted(chunking_config.keys()): for field in sorted(chunking_config.keys()):
@ -779,11 +801,12 @@ def queue_raptor_o_graphrag_tasks(doc, ty, priority):
nonlocal doc nonlocal doc
return { return {
"id": get_uuid(), "id": get_uuid(),
"doc_id": doc["id"], "doc_id": fake_doc_id if fake_doc_id else doc["id"],
"from_page": 100000000, "from_page": 100000000,
"to_page": 100000000, "to_page": 100000000,
"task_type": ty, "task_type": ty,
"progress_msg": datetime.now().strftime("%H:%M:%S") + " created task " + ty "progress_msg": datetime.now().strftime("%H:%M:%S") + " created task " + ty,
"begin_at": datetime.now(),
} }
task = new_task() task = new_task()
@ -792,7 +815,12 @@ def queue_raptor_o_graphrag_tasks(doc, ty, priority):
hasher.update(ty.encode("utf-8")) hasher.update(ty.encode("utf-8"))
task["digest"] = hasher.hexdigest() task["digest"] = hasher.hexdigest()
bulk_insert_into_db(Task, [task], True) bulk_insert_into_db(Task, [task], True)
if ty in ["graphrag", "raptor", "mindmap"]:
task["doc_ids"] = doc_ids
DocumentService.begin2parse(doc["id"])
assert REDIS_CONN.queue_product(get_svr_queue_name(priority), message=task), "Can't access Redis. Please check the Redis' status." assert REDIS_CONN.queue_product(get_svr_queue_name(priority), message=task), "Can't access Redis. Please check the Redis' status."
return task["id"]
def get_queue_length(priority): def get_queue_length(priority):

View File

@ -457,6 +457,7 @@ class FileService(CommonService):
"id": doc_id, "id": doc_id,
"kb_id": kb.id, "kb_id": kb.id,
"parser_id": self.get_parser(filetype, filename, kb.parser_id), "parser_id": self.get_parser(filetype, filename, kb.parser_id),
"pipeline_id": kb.pipeline_id,
"parser_config": kb.parser_config, "parser_config": kb.parser_config,
"created_by": user_id, "created_by": user_id,
"type": filetype, "type": filetype,
@ -512,7 +513,7 @@ class FileService(CommonService):
return ParserType.AUDIO.value return ParserType.AUDIO.value
if re.search(r"\.(ppt|pptx|pages)$", filename): if re.search(r"\.(ppt|pptx|pages)$", filename):
return ParserType.PRESENTATION.value return ParserType.PRESENTATION.value
if re.search(r"\.(eml)$", filename): if re.search(r"\.(msg|eml)$", filename):
return ParserType.EMAIL.value return ParserType.EMAIL.value
return default return default

View File

@ -15,10 +15,10 @@
# #
from datetime import datetime from datetime import datetime
from peewee import fn from peewee import fn, JOIN
from api.db import StatusEnum, TenantPermission from api.db import StatusEnum, TenantPermission
from api.db.db_models import DB, Document, Knowledgebase, Tenant, User, UserTenant from api.db.db_models import DB, Document, Knowledgebase, User, UserTenant, UserCanvas
from api.db.services.common_service import CommonService from api.db.services.common_service import CommonService
from api.utils import current_timestamp, datetime_format from api.utils import current_timestamp, datetime_format
@ -260,20 +260,29 @@ class KnowledgebaseService(CommonService):
cls.model.token_num, cls.model.token_num,
cls.model.chunk_num, cls.model.chunk_num,
cls.model.parser_id, cls.model.parser_id,
cls.model.pipeline_id,
UserCanvas.title.alias("pipeline_name"),
UserCanvas.avatar.alias("pipeline_avatar"),
cls.model.parser_config, cls.model.parser_config,
cls.model.pagerank, cls.model.pagerank,
cls.model.graphrag_task_id,
cls.model.graphrag_task_finish_at,
cls.model.raptor_task_id,
cls.model.raptor_task_finish_at,
cls.model.mindmap_task_id,
cls.model.mindmap_task_finish_at,
cls.model.create_time, cls.model.create_time,
cls.model.update_time cls.model.update_time
] ]
kbs = cls.model.select(*fields).join(Tenant, on=( kbs = cls.model.select(*fields)\
(Tenant.id == cls.model.tenant_id) & (Tenant.status == StatusEnum.VALID.value))).where( .join(UserCanvas, on=(cls.model.pipeline_id == UserCanvas.id), join_type=JOIN.LEFT_OUTER)\
.where(
(cls.model.id == kb_id), (cls.model.id == kb_id),
(cls.model.status == StatusEnum.VALID.value) (cls.model.status == StatusEnum.VALID.value)
) ).dicts()
if not kbs: if not kbs:
return return
d = kbs[0].to_dict() return kbs[0]
return d
@classmethod @classmethod
@DB.connection_context() @DB.connection_context()

View File

@ -0,0 +1,263 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import logging
import os
from datetime import datetime, timedelta
from peewee import fn
from api.db import VALID_PIPELINE_TASK_TYPES, PipelineTaskType
from api.db.db_models import DB, Document, PipelineOperationLog
from api.db.services.canvas_service import UserCanvasService
from api.db.services.common_service import CommonService
from api.db.services.document_service import DocumentService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.task_service import GRAPH_RAPTOR_FAKE_DOC_ID
from api.utils import current_timestamp, datetime_format, get_uuid
class PipelineOperationLogService(CommonService):
model = PipelineOperationLog
@classmethod
def get_file_logs_fields(cls):
return [
cls.model.id,
cls.model.document_id,
cls.model.tenant_id,
cls.model.kb_id,
cls.model.pipeline_id,
cls.model.pipeline_title,
cls.model.parser_id,
cls.model.document_name,
cls.model.document_suffix,
cls.model.document_type,
cls.model.source_from,
cls.model.progress,
cls.model.progress_msg,
cls.model.process_begin_at,
cls.model.process_duration,
cls.model.dsl,
cls.model.task_type,
cls.model.operation_status,
cls.model.avatar,
cls.model.status,
cls.model.create_time,
cls.model.create_date,
cls.model.update_time,
cls.model.update_date,
]
@classmethod
def get_dataset_logs_fields(cls):
return [
cls.model.id,
cls.model.tenant_id,
cls.model.kb_id,
cls.model.progress,
cls.model.progress_msg,
cls.model.process_begin_at,
cls.model.process_duration,
cls.model.task_type,
cls.model.operation_status,
cls.model.avatar,
cls.model.status,
cls.model.create_time,
cls.model.create_date,
cls.model.update_time,
cls.model.update_date,
]
@classmethod
def save(cls, **kwargs):
"""
Callers are expected to wrap this call in a database transaction (e.g. DB.atomic()), as `create` does below.
"""
sample_obj = cls.model(**kwargs).save(force_insert=True)
return sample_obj
@classmethod
@DB.connection_context()
def create(cls, document_id, pipeline_id, task_type, fake_document_ids=[], dsl: str = "{}"):
referred_document_id = document_id
if referred_document_id == GRAPH_RAPTOR_FAKE_DOC_ID and fake_document_ids:
referred_document_id = fake_document_ids[0]
ok, document = DocumentService.get_by_id(referred_document_id)
if not ok:
logging.warning(f"Document for referred_document_id {referred_document_id} not found")
return
DocumentService.update_progress_immediately([document.to_dict()])
ok, document = DocumentService.get_by_id(referred_document_id)
if not ok:
logging.warning(f"Document for referred_document_id {referred_document_id} not found")
return
if document.progress not in [1, -1]:
return
operation_status = document.run
if pipeline_id:
ok, user_pipeline = UserCanvasService.get_by_id(pipeline_id)
if not ok:
raise RuntimeError(f"Pipeline {pipeline_id} not found")
tenant_id = user_pipeline.user_id
title = user_pipeline.title
avatar = user_pipeline.avatar
else:
ok, kb_info = KnowledgebaseService.get_by_id(document.kb_id)
if not ok:
raise RuntimeError(f"Cannot find knowledge base {document.kb_id} for referred_document {referred_document_id}")
tenant_id = kb_info.tenant_id
title = document.parser_id
avatar = document.thumbnail
if task_type not in VALID_PIPELINE_TASK_TYPES:
raise ValueError(f"Invalid task type: {task_type}")
if task_type in [PipelineTaskType.GRAPH_RAG, PipelineTaskType.RAPTOR, PipelineTaskType.MINDMAP]:
finish_at = document.process_begin_at + timedelta(seconds=document.process_duration)
if task_type == PipelineTaskType.GRAPH_RAG:
KnowledgebaseService.update_by_id(
document.kb_id,
{"graphrag_task_finish_at": finish_at},
)
elif task_type == PipelineTaskType.RAPTOR:
KnowledgebaseService.update_by_id(
document.kb_id,
{"raptor_task_finish_at": finish_at},
)
elif task_type == PipelineTaskType.MINDMAP:
KnowledgebaseService.update_by_id(
document.kb_id,
{"mindmap_task_finish_at": finish_at},
)
log = dict(
id=get_uuid(),
document_id=document_id, # GRAPH_RAPTOR_FAKE_DOC_ID or real document_id
tenant_id=tenant_id,
kb_id=document.kb_id,
pipeline_id=pipeline_id,
pipeline_title=title,
parser_id=document.parser_id,
document_name=document.name,
document_suffix=document.suffix,
document_type=document.type,
source_from="", # TODO: add in the future
progress=document.progress,
progress_msg=document.progress_msg,
process_begin_at=document.process_begin_at,
process_duration=document.process_duration,
dsl=json.loads(dsl),
task_type=task_type,
operation_status=operation_status,
avatar=avatar,
)
log["create_time"] = current_timestamp()
log["create_date"] = datetime_format(datetime.now())
log["update_time"] = current_timestamp()
log["update_date"] = datetime_format(datetime.now())
with DB.atomic():
obj = cls.save(**log)
limit = int(os.getenv("PIPELINE_OPERATION_LOG_LIMIT", 1000))
total = cls.model.select().where(cls.model.kb_id == document.kb_id).count()
if total > limit:
keep_ids = [m.id for m in cls.model.select(cls.model.id).where(cls.model.kb_id == document.kb_id).order_by(cls.model.create_time.desc()).limit(limit)]
deleted = cls.model.delete().where(cls.model.kb_id == document.kb_id, cls.model.id.not_in(keep_ids)).execute()
logging.info(f"[PipelineOperationLogService] Cleaned {deleted} old logs, kept latest {limit} for {document.kb_id}")
return obj
@classmethod
@DB.connection_context()
def record_pipeline_operation(cls, document_id, pipeline_id, task_type, fake_document_ids=[]):
return cls.create(document_id=document_id, pipeline_id=pipeline_id, task_type=task_type, fake_document_ids=fake_document_ids)
@classmethod
@DB.connection_context()
def get_file_logs_by_kb_id(cls, kb_id, page_number, items_per_page, orderby, desc, keywords, operation_status, types, suffix, create_date_from=None, create_date_to=None):
fields = cls.get_file_logs_fields()
if keywords:
logs = cls.model.select(*fields).where((cls.model.kb_id == kb_id), (fn.LOWER(cls.model.document_name).contains(keywords.lower())))
else:
logs = cls.model.select(*fields).where(cls.model.kb_id == kb_id)
logs = logs.where(cls.model.document_id != GRAPH_RAPTOR_FAKE_DOC_ID)
if operation_status:
logs = logs.where(cls.model.operation_status.in_(operation_status))
if types:
logs = logs.where(cls.model.document_type.in_(types))
if suffix:
logs = logs.where(cls.model.document_suffix.in_(suffix))
if create_date_from:
logs = logs.where(cls.model.create_date >= create_date_from)
if create_date_to:
logs = logs.where(cls.model.create_date <= create_date_to)
count = logs.count()
if desc:
logs = logs.order_by(cls.model.getter_by(orderby).desc())
else:
logs = logs.order_by(cls.model.getter_by(orderby).asc())
if page_number and items_per_page:
logs = logs.paginate(page_number, items_per_page)
return list(logs.dicts()), count
@classmethod
@DB.connection_context()
def get_documents_info(cls, id):
fields = [Document.id, Document.name, Document.progress, Document.kb_id]
return (
cls.model.select(*fields)
.join(Document, on=(cls.model.document_id == Document.id))
.where(
cls.model.id == id
)
.dicts()
)
@classmethod
@DB.connection_context()
def get_dataset_logs_by_kb_id(cls, kb_id, page_number, items_per_page, orderby, desc, operation_status, create_date_from=None, create_date_to=None):
fields = cls.get_dataset_logs_fields()
logs = cls.model.select(*fields).where((cls.model.kb_id == kb_id), (cls.model.document_id == GRAPH_RAPTOR_FAKE_DOC_ID))
if operation_status:
logs = logs.where(cls.model.operation_status.in_(operation_status))
if create_date_from:
logs = logs.where(cls.model.create_date >= create_date_from)
if create_date_to:
logs = logs.where(cls.model.create_date <= create_date_to)
count = logs.count()
if desc:
logs = logs.order_by(cls.model.getter_by(orderby).desc())
else:
logs = logs.order_by(cls.model.getter_by(orderby).asc())
if page_number and items_per_page:
logs = logs.paginate(page_number, items_per_page)
return list(logs.dicts()), count
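A minimal sketch of how the new service might be called from an API handler; the kb_id and filter values are placeholders, not part of this commit.

# Illustrative only: page through per-file pipeline logs of one knowledge base.
logs, total = PipelineOperationLogService.get_file_logs_by_kb_id(
    kb_id="kb_id_here",
    page_number=1,
    items_per_page=30,
    orderby="create_time",
    desc=True,
    keywords="",
    operation_status=[],   # list of TaskStatus values to filter on, if any
    types=[],
    suffix=[],
)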

View File

@ -35,6 +35,8 @@ from rag.utils.redis_conn import REDIS_CONN
from api import settings from api import settings
from rag.nlp import search from rag.nlp import search
CANVAS_DEBUG_DOC_ID = "dataflow_x"
GRAPH_RAPTOR_FAKE_DOC_ID = "graph_raptor_x"
def trim_header_by_lines(text: str, max_length) -> str: def trim_header_by_lines(text: str, max_length) -> str:
# Trim header text to maximum length while preserving line breaks # Trim header text to maximum length while preserving line breaks
@ -70,7 +72,7 @@ class TaskService(CommonService):
@classmethod @classmethod
@DB.connection_context() @DB.connection_context()
def get_task(cls, task_id): def get_task(cls, task_id, doc_ids=[]):
"""Retrieve detailed task information by task ID. """Retrieve detailed task information by task ID.
This method fetches comprehensive task details including associated document, This method fetches comprehensive task details including associated document,
@ -84,6 +86,10 @@ class TaskService(CommonService):
dict: Task details dictionary containing all task information and related metadata. dict: Task details dictionary containing all task information and related metadata.
Returns None if task is not found or has exceeded retry limit. Returns None if task is not found or has exceeded retry limit.
""" """
doc_id = cls.model.doc_id
if doc_id == CANVAS_DEBUG_DOC_ID and doc_ids:
doc_id = doc_ids[0]
fields = [ fields = [
cls.model.id, cls.model.id,
cls.model.doc_id, cls.model.doc_id,
@ -109,7 +115,7 @@ class TaskService(CommonService):
] ]
docs = ( docs = (
cls.model.select(*fields) cls.model.select(*fields)
.join(Document, on=(cls.model.doc_id == Document.id)) .join(Document, on=(doc_id == Document.id))
.join(Knowledgebase, on=(Document.kb_id == Knowledgebase.id)) .join(Knowledgebase, on=(Document.kb_id == Knowledgebase.id))
.join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id)) .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))
.where(cls.model.id == task_id) .where(cls.model.id == task_id)
@ -292,21 +298,23 @@ class TaskService(CommonService):
((prog == -1) | (prog > cls.model.progress)) ((prog == -1) | (prog > cls.model.progress))
) )
).execute() ).execute()
return else:
with DB.lock("update_progress", -1):
if info["progress_msg"]:
progress_msg = trim_header_by_lines(task.progress_msg + "\n" + info["progress_msg"], 3000)
cls.model.update(progress_msg=progress_msg).where(cls.model.id == id).execute()
if "progress" in info:
prog = info["progress"]
cls.model.update(progress=prog).where(
(cls.model.id == id) &
(
(cls.model.progress != -1) &
((prog == -1) | (prog > cls.model.progress))
)
).execute()
with DB.lock("update_progress", -1): process_duration = (datetime.now() - task.begin_at).total_seconds()
if info["progress_msg"]: cls.model.update(process_duration=process_duration).where(cls.model.id == id).execute()
progress_msg = trim_header_by_lines(task.progress_msg + "\n" + info["progress_msg"], 3000)
cls.model.update(progress_msg=progress_msg).where(cls.model.id == id).execute()
if "progress" in info:
prog = info["progress"]
cls.model.update(progress=prog).where(
(cls.model.id == id) &
(
(cls.model.progress != -1) &
((prog == -1) | (prog > cls.model.progress))
)
).execute()
@classmethod @classmethod
@DB.connection_context() @DB.connection_context()
@ -336,7 +344,14 @@ def queue_tasks(doc: dict, bucket: str, name: str, priority: int):
- Previous task chunks may be reused if available - Previous task chunks may be reused if available
""" """
def new_task(): def new_task():
return {"id": get_uuid(), "doc_id": doc["id"], "progress": 0.0, "from_page": 0, "to_page": 100000000} return {
"id": get_uuid(),
"doc_id": doc["id"],
"progress": 0.0,
"from_page": 0,
"to_page": 100000000,
"begin_at": datetime.now(),
}
parse_task_array = [] parse_task_array = []
@ -349,7 +364,7 @@ def queue_tasks(doc: dict, bucket: str, name: str, priority: int):
page_size = doc["parser_config"].get("task_page_size") or 12 page_size = doc["parser_config"].get("task_page_size") or 12
if doc["parser_id"] == "paper": if doc["parser_id"] == "paper":
page_size = doc["parser_config"].get("task_page_size") or 22 page_size = doc["parser_config"].get("task_page_size") or 22
if doc["parser_id"] in ["one", "knowledge_graph"] or do_layout != "DeepDOC": if doc["parser_id"] in ["one", "knowledge_graph"] or do_layout != "DeepDOC" or doc["parser_config"].get("toc", True):
page_size = 10 ** 9 page_size = 10 ** 9
page_ranges = doc["parser_config"].get("pages") or [(1, 10 ** 5)] page_ranges = doc["parser_config"].get("pages") or [(1, 10 ** 5)]
for s, e in page_ranges: for s, e in page_ranges:
@ -478,33 +493,26 @@ def has_canceled(task_id):
return False return False
def queue_dataflow(dsl:str, tenant_id:str, doc_id:str, task_id:str, flow_id:str, priority: int, callback=None) -> tuple[bool, str]: def queue_dataflow(tenant_id:str, flow_id:str, task_id:str, doc_id:str=CANVAS_DEBUG_DOC_ID, file:dict=None, priority: int=0, rerun:bool=False) -> tuple[bool, str]:
"""
Returns a tuple (success: bool, error_message: str).
"""
_ = callback
task = dict( task = dict(
id=get_uuid() if not task_id else task_id, id=task_id,
doc_id=doc_id, doc_id=doc_id,
from_page=0, from_page=0,
to_page=100000000, to_page=100000000,
task_type="dataflow", task_type="dataflow" if not rerun else "dataflow_rerun",
priority=priority, priority=priority,
begin_at=datetime.now(),
) )
if doc_id not in [CANVAS_DEBUG_DOC_ID, GRAPH_RAPTOR_FAKE_DOC_ID]:
TaskService.model.delete().where(TaskService.model.id == task["id"]).execute() TaskService.model.delete().where(TaskService.model.doc_id == doc_id).execute()
DocumentService.begin2parse(doc_id)
bulk_insert_into_db(model=Task, data_source=[task], replace_on_conflict=True) bulk_insert_into_db(model=Task, data_source=[task], replace_on_conflict=True)
kb_id = DocumentService.get_knowledgebase_id(doc_id) task["kb_id"] = DocumentService.get_knowledgebase_id(doc_id)
if not kb_id:
return False, f"Can't find KB of this document: {doc_id}"
task["kb_id"] = kb_id
task["tenant_id"] = tenant_id task["tenant_id"] = tenant_id
task["task_type"] = "dataflow" task["dataflow_id"] = flow_id
task["dsl"] = dsl task["file"] = file
task["dataflow_id"] = get_uuid() if not flow_id else flow_id
if not REDIS_CONN.queue_product( if not REDIS_CONN.queue_product(
get_svr_queue_name(priority), message=task get_svr_queue_name(priority), message=task

View File

@ -705,7 +705,9 @@ TimeoutException = Union[Type[BaseException], BaseException]
OnTimeoutCallback = Union[Callable[..., Any], Coroutine[Any, Any, Any]] OnTimeoutCallback = Union[Callable[..., Any], Coroutine[Any, Any, Any]]
def timeout(seconds: float | int = None, attempts: int = 2, *, exception: Optional[TimeoutException] = None, on_timeout: Optional[OnTimeoutCallback] = None): def timeout(seconds: float | int | str = None, attempts: int = 2, *, exception: Optional[TimeoutException] = None, on_timeout: Optional[OnTimeoutCallback] = None):
if isinstance(seconds, str):
seconds = float(seconds)
def decorator(func): def decorator(func):
@wraps(func) @wraps(func)
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
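The change above lets seconds arrive as a string (for example straight from an environment variable) and coerces it to float. A hedged usage sketch; the environment-variable name is purely illustrative.

# Illustrative only: the timeout budget can now come from the environment as a string.
import os

@timeout(seconds=os.getenv("MY_STEP_TIMEOUT", "30"), attempts=1)
def long_running_step():
    ...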

View File

@ -1,3 +1,56 @@
import base64 import base64
import logging
from functools import partial
from io import BytesIO
from PIL import Image
test_image_base64 = "iVBORw0KGgoAAAANSUhEUgAAAGQAAABkCAIAAAD/gAIDAAAA6ElEQVR4nO3QwQ3AIBDAsIP9d25XIC+EZE8QZc18w5l9O+AlZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBT+IYAHHLHkdEgAAAABJRU5ErkJggg==" test_image_base64 = "iVBORw0KGgoAAAANSUhEUgAAAGQAAABkCAIAAAD/gAIDAAAA6ElEQVR4nO3QwQ3AIBDAsIP9d25XIC+EZE8QZc18w5l9O+AlZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBWYFZgVmBT+IYAHHLHkdEgAAAABJRU5ErkJggg=="
test_image = base64.b64decode(test_image_base64) test_image = base64.b64decode(test_image_base64)
async def image2id(d: dict, storage_put_func: partial, objname:str, bucket:str="imagetemps"):
import logging
from io import BytesIO
import trio
from rag.svr.task_executor import minio_limiter
if not d.get("image"):
return
with BytesIO() as output_buffer:
if isinstance(d["image"], bytes):
output_buffer.write(d["image"])
output_buffer.seek(0)
else:
# If the image is in RGBA mode, convert it to RGB mode before saving it in JPEG format.
if d["image"].mode in ("RGBA", "P"):
converted_image = d["image"].convert("RGB")
d["image"] = converted_image
try:
d["image"].save(output_buffer, format='JPEG')
except OSError as e:
logging.warning(
"Saving image exception, ignore: {}".format(str(e)))
async with minio_limiter:
await trio.to_thread.run_sync(lambda: storage_put_func(bucket=bucket, fnm=objname, binary=output_buffer.getvalue()))
d["img_id"] = f"{bucket}-{objname}"
if not isinstance(d["image"], bytes):
d["image"].close()
del d["image"] # Remove image reference
def id2image(image_id:str|None, storage_get_func: partial):
if not image_id:
return
arr = image_id.split("-")
if len(arr) != 2:
return
bkt, nm = image_id.split("-")
try:
blob = storage_get_func(bucket=bkt, filename=nm)
if not blob:
return
return Image.open(BytesIO(blob))
except Exception as e:
logging.exception(e)
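A sketch of the intended round trip. It assumes STORAGE_IMPL exposes put()/get() accepting the keyword arguments used by image2id/id2image above (bucket, fnm, binary and bucket, filename respectively); the chunk dict and object name are placeholders.

# Illustrative round trip only; not part of this commit.
from functools import partial
from rag.utils.storage_factory import STORAGE_IMPL

async def store_then_load(chunk: dict, objname: str):
    await image2id(chunk, partial(STORAGE_IMPL.put), objname)        # sets chunk["img_id"]
    return id2image(chunk.get("img_id"), partial(STORAGE_IMPL.get))  # PIL.Image or None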

View File

@ -155,7 +155,7 @@ def filename_type(filename):
if re.match(r".*\.pdf$", filename): if re.match(r".*\.pdf$", filename):
return FileType.PDF.value return FileType.PDF.value
if re.match(r".*\.(eml|doc|docx|ppt|pptx|yml|xml|htm|json|jsonl|ldjson|csv|txt|ini|xls|xlsx|wps|rtf|hlp|pages|numbers|key|md|py|js|java|c|cpp|h|php|go|ts|sh|cs|kt|html|sql)$", filename): if re.match(r".*\.(msg|eml|doc|docx|ppt|pptx|yml|xml|htm|json|jsonl|ldjson|csv|txt|ini|xls|xlsx|wps|rtf|hlp|pages|numbers|key|md|py|js|java|c|cpp|h|php|go|ts|sh|cs|kt|html|sql)$", filename):
return FileType.DOC.value return FileType.DOC.value
if re.match(r".*\.(wav|flac|ape|alac|wavpack|wv|mp3|aac|ogg|vorbis|opus)$", filename): if re.match(r".*\.(wav|flac|ape|alac|wavpack|wv|mp3|aac|ogg|vorbis|opus)$", filename):

104
api/utils/health.py Normal file
View File

@ -0,0 +1,104 @@
from timeit import default_timer as timer
from api import settings
from api.db.db_models import DB
from rag.utils.redis_conn import REDIS_CONN
from rag.utils.storage_factory import STORAGE_IMPL
def _ok_nok(ok: bool) -> str:
return "ok" if ok else "nok"
def check_db() -> tuple[bool, dict]:
st = timer()
try:
# lightweight probe; works for MySQL/Postgres
DB.execute_sql("SELECT 1")
return True, {"elapsed": f"{(timer() - st) * 1000.0:.1f}"}
except Exception as e:
return False, {"elapsed": f"{(timer() - st) * 1000.0:.1f}", "error": str(e)}
def check_redis() -> tuple[bool, dict]:
st = timer()
try:
ok = bool(REDIS_CONN.health())
return ok, {"elapsed": f"{(timer() - st) * 1000.0:.1f}"}
except Exception as e:
return False, {"elapsed": f"{(timer() - st) * 1000.0:.1f}", "error": str(e)}
def check_doc_engine() -> tuple[bool, dict]:
st = timer()
try:
meta = settings.docStoreConn.health()
# treat any successful call as ok
return True, {"elapsed": f"{(timer() - st) * 1000.0:.1f}", **(meta or {})}
except Exception as e:
return False, {"elapsed": f"{(timer() - st) * 1000.0:.1f}", "error": str(e)}
def check_storage() -> tuple[bool, dict]:
st = timer()
try:
STORAGE_IMPL.health()
return True, {"elapsed": f"{(timer() - st) * 1000.0:.1f}"}
except Exception as e:
return False, {"elapsed": f"{(timer() - st) * 1000.0:.1f}", "error": str(e)}
def check_chat() -> tuple[bool, dict]:
st = timer()
try:
cfg = getattr(settings, "CHAT_CFG", None)
ok = bool(cfg and cfg.get("factory"))
return ok, {"elapsed": f"{(timer() - st) * 1000.0:.1f}"}
except Exception as e:
return False, {"elapsed": f"{(timer() - st) * 1000.0:.1f}", "error": str(e)}
def run_health_checks() -> tuple[dict, bool]:
result: dict[str, str | dict] = {}
db_ok, db_meta = check_db()
chat_ok, chat_meta = check_chat()
result["db"] = _ok_nok(db_ok)
if not db_ok:
result.setdefault("_meta", {})["db"] = db_meta
result["chat"] = _ok_nok(chat_ok)
if not chat_ok:
result.setdefault("_meta", {})["chat"] = chat_meta
# Optional probes (do not change minimal contract but exposed for observability)
try:
redis_ok, redis_meta = check_redis()
result["redis"] = _ok_nok(redis_ok)
if not redis_ok:
result.setdefault("_meta", {})["redis"] = redis_meta
except Exception:
result["redis"] = "nok"
try:
doc_ok, doc_meta = check_doc_engine()
result["doc_engine"] = _ok_nok(doc_ok)
if not doc_ok:
result.setdefault("_meta", {})["doc_engine"] = doc_meta
except Exception:
result["doc_engine"] = "nok"
try:
sto_ok, sto_meta = check_storage()
result["storage"] = _ok_nok(sto_ok)
if not sto_ok:
result.setdefault("_meta", {})["storage"] = sto_meta
except Exception:
result["storage"] = "nok"
all_ok = (result.get("db") == "ok") and (result.get("chat") == "ok")
result["status"] = "ok" if all_ok else "nok"
return result, all_ok
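A minimal usage sketch, e.g. from a health endpoint; the function name and HTTP status mapping are illustrative, not part of this commit.

# Illustrative only: expose the checks through a simple endpoint handler.
from api.utils.health import run_health_checks

def healthz():
    result, all_ok = run_health_checks()
    return result, 200 if all_ok else 503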

View File

@ -5147,4 +5147,4 @@
] ]
} }
] ]
} }

View File

@ -1075,11 +1075,10 @@ class RAGFlowPdfParser:
def insert_table_figures(tbls_or_figs, layout_type): def insert_table_figures(tbls_or_figs, layout_type):
def min_rectangle_distance(rect1, rect2): def min_rectangle_distance(rect1, rect2):
import math import math
pn1, left1, right1, top1, bottom1 = rect1 pn1, left1, right1, top1, bottom1 = rect1
pn2, left2, right2, top2, bottom2 = rect2 pn2, left2, right2, top2, bottom2 = rect2
if right1 >= left2 and right2 >= left1 and bottom1 >= top2 and bottom2 >= top1: if right1 >= left2 and right2 >= left1 and bottom1 >= top2 and bottom2 >= top1:
return 0 + (pn1 - pn2) * 10000 return 0
if right1 < left2: if right1 < left2:
dx = left2 - right1 dx = left2 - right1
elif right2 < left1: elif right2 < left1:
@ -1092,20 +1091,27 @@ class RAGFlowPdfParser:
dy = top1 - bottom2 dy = top1 - bottom2
else: else:
dy = 0 dy = 0
return math.sqrt(dx * dx + dy * dy) + (pn1 - pn2) * 10000 return math.sqrt(dx*dx + dy*dy)# + (pn2-pn1)*10000
for (img, txt), poss in tbls_or_figs: for (img, txt), poss in tbls_or_figs:
bboxes = [(i, (b["page_number"], b["x0"], b["x1"], b["top"], b["bottom"])) for i, b in enumerate(self.boxes)] bboxes = [(i, (b["page_number"], b["x0"], b["x1"], b["top"], b["bottom"])) for i, b in enumerate(self.boxes)]
dists = [(min_rectangle_distance((pn, left, right, top, bott), rect), i) for i, rect in bboxes for pn, left, right, top, bott in poss] dists = [(min_rectangle_distance((pn, left, right, top+self.page_cum_height[pn], bott+self.page_cum_height[pn]), rect),i) for i, rect in bboxes for pn, left, right, top, bott in poss]
min_i = np.argmin(dists, axis=0)[0] min_i = np.argmin(dists, axis=0)[0]
min_i, rect = bboxes[dists[min_i][-1]] min_i, rect = bboxes[dists[min_i][-1]]
if isinstance(txt, list): if isinstance(txt, list):
txt = "\n".join(txt) txt = "\n".join(txt)
self.boxes.insert(min_i, {"page_number": rect[0], "x0": rect[1], "x1": rect[2], "top": rect[3], "bottom": rect[4], "layout_type": layout_type, "text": txt, "image": img}) pn, left, right, top, bott = poss[0]
if self.boxes[min_i]["bottom"] < top+self.page_cum_height[pn]:
min_i += 1
self.boxes.insert(min_i, {
"page_number": pn+1, "x0": left, "x1": right, "top": top+self.page_cum_height[pn], "bottom": bott+self.page_cum_height[pn], "layout_type": layout_type, "text": txt, "image": img,
"positions": [[pn+1, int(left), int(right), int(top), int(bott)]]
})
for b in self.boxes: for b in self.boxes:
b["position_tag"] = self._line_tag(b, zoomin) b["position_tag"] = self._line_tag(b, zoomin)
b["image"] = self.crop(b["position_tag"], zoomin) b["image"] = self.crop(b["position_tag"], zoomin)
b["positions"] = [[pos[0][-1]+1, *pos[1:]] for pos in RAGFlowPdfParser.extract_positions(b["position_tag"])]
insert_table_figures(tbls, "table") insert_table_figures(tbls, "table")
insert_table_figures(figs, "figure") insert_table_figures(figs, "figure")

View File

@ -21,6 +21,7 @@ import networkx as nx
import trio import trio
from api import settings from api import settings
from api.db.services.document_service import DocumentService
from api.utils import get_uuid from api.utils import get_uuid
from api.utils.api_utils import timeout from api.utils.api_utils import timeout
from graphrag.entity_resolution import EntityResolution from graphrag.entity_resolution import EntityResolution
@ -54,7 +55,7 @@ async def run_graphrag(
start = trio.current_time() start = trio.current_time()
tenant_id, kb_id, doc_id = row["tenant_id"], str(row["kb_id"]), row["doc_id"] tenant_id, kb_id, doc_id = row["tenant_id"], str(row["kb_id"]), row["doc_id"]
chunks = [] chunks = []
for d in settings.retrievaler.chunk_list(doc_id, tenant_id, [kb_id], fields=["content_with_weight", "doc_id"]): for d in settings.retrievaler.chunk_list(doc_id, tenant_id, [kb_id], fields=["content_with_weight", "doc_id"], sort_by_position=True):
chunks.append(d["content_with_weight"]) chunks.append(d["content_with_weight"])
with trio.fail_after(max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000): with trio.fail_after(max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000):
@ -125,6 +126,212 @@ async def run_graphrag(
return return
async def run_graphrag_for_kb(
row: dict,
doc_ids: list[str],
language: str,
kb_parser_config: dict,
chat_model,
embedding_model,
callback,
*,
with_resolution: bool = True,
with_community: bool = True,
max_parallel_docs: int = 4,
) -> dict:
tenant_id, kb_id = row["tenant_id"], row["kb_id"]
enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
start = trio.current_time()
fields_for_chunks = ["content_with_weight", "doc_id"]
if not doc_ids:
logging.info(f"Fetching all docs for {kb_id}")
docs, _ = DocumentService.get_by_kb_id(
kb_id=kb_id,
page_number=0,
items_per_page=0,
orderby="create_time",
desc=False,
keywords="",
run_status=[],
types=[],
suffix=[],
)
doc_ids = [doc["id"] for doc in docs]
doc_ids = list(dict.fromkeys(doc_ids))
if not doc_ids:
callback(msg=f"[GraphRAG] kb:{kb_id} has no processable doc_id.")
return {"ok_docs": [], "failed_docs": [], "total_docs": 0, "total_chunks": 0, "seconds": 0.0}
def load_doc_chunks(doc_id: str) -> list[str]:
from rag.utils import num_tokens_from_string
chunks = []
current_chunk = ""
for d in settings.retrievaler.chunk_list(
doc_id,
tenant_id,
[kb_id],
fields=fields_for_chunks,
sort_by_position=True,
):
content = d["content_with_weight"]
if num_tokens_from_string(current_chunk + content) < 1024:
current_chunk += content
else:
if current_chunk:
chunks.append(current_chunk)
current_chunk = content
if current_chunk:
chunks.append(current_chunk)
return chunks
all_doc_chunks: dict[str, list[str]] = {}
total_chunks = 0
for doc_id in doc_ids:
chunks = load_doc_chunks(doc_id)
all_doc_chunks[doc_id] = chunks
total_chunks += len(chunks)
if total_chunks == 0:
callback(msg=f"[GraphRAG] kb:{kb_id} has no available chunks in all documents, skip.")
return {"ok_docs": [], "failed_docs": doc_ids, "total_docs": len(doc_ids), "total_chunks": 0, "seconds": 0.0}
semaphore = trio.Semaphore(max_parallel_docs)
subgraphs: dict[str, object] = {}
failed_docs: list[tuple[str, str]] = [] # (doc_id, error)
async def build_one(doc_id: str):
chunks = all_doc_chunks.get(doc_id, [])
if not chunks:
callback(msg=f"[GraphRAG] doc:{doc_id} has no available chunks, skip generation.")
return
kg_extractor = LightKGExt if ("method" not in kb_parser_config.get("graphrag", {}) or kb_parser_config["graphrag"]["method"] != "general") else GeneralKGExt
deadline = max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000
async with semaphore:
try:
msg = f"[GraphRAG] build_subgraph doc:{doc_id}"
callback(msg=f"{msg} start (chunks={len(chunks)}, timeout={deadline}s)")
with trio.fail_after(deadline):
sg = await generate_subgraph(
kg_extractor,
tenant_id,
kb_id,
doc_id,
chunks,
language,
kb_parser_config.get("graphrag", {}).get("entity_types", []),
chat_model,
embedding_model,
callback,
)
if sg:
subgraphs[doc_id] = sg
callback(msg=f"{msg} done")
else:
failed_docs.append((doc_id, "subgraph is empty"))
callback(msg=f"{msg} empty")
except Exception as e:
failed_docs.append((doc_id, repr(e)))
callback(msg=f"[GraphRAG] build_subgraph doc:{doc_id} FAILED: {e!r}")
async with trio.open_nursery() as nursery:
for doc_id in doc_ids:
nursery.start_soon(build_one, doc_id)
ok_docs = [d for d in doc_ids if d in subgraphs]
if not ok_docs:
callback(msg=f"[GraphRAG] kb:{kb_id} no subgraphs generated successfully, end.")
now = trio.current_time()
return {"ok_docs": [], "failed_docs": failed_docs, "total_docs": len(doc_ids), "total_chunks": total_chunks, "seconds": now - start}
kb_lock = RedisDistributedLock(f"graphrag_task_{kb_id}", lock_value="batch_merge", timeout=1200)
await kb_lock.spin_acquire()
callback(msg=f"[GraphRAG] kb:{kb_id} merge lock acquired")
try:
union_nodes: set = set()
final_graph = None
for doc_id in ok_docs:
sg = subgraphs[doc_id]
union_nodes.update(set(sg.nodes()))
new_graph = await merge_subgraph(
tenant_id,
kb_id,
doc_id,
sg,
embedding_model,
callback,
)
if new_graph is not None:
final_graph = new_graph
if final_graph is None:
callback(msg=f"[GraphRAG] kb:{kb_id} merge finished (no in-memory graph returned).")
else:
callback(msg=f"[GraphRAG] kb:{kb_id} merge finished, graph ready.")
finally:
kb_lock.release()
if not with_resolution and not with_community:
now = trio.current_time()
callback(msg=f"[GraphRAG] KB merge done in {now - start:.2f}s. ok={len(ok_docs)} / total={len(doc_ids)}")
return {"ok_docs": ok_docs, "failed_docs": failed_docs, "total_docs": len(doc_ids), "total_chunks": total_chunks, "seconds": now - start}
await kb_lock.spin_acquire()
callback(msg=f"[GraphRAG] kb:{kb_id} post-merge lock acquired for resolution/community")
try:
subgraph_nodes = set()
for sg in subgraphs.values():
subgraph_nodes.update(set(sg.nodes()))
if with_resolution:
await resolve_entities(
final_graph,
subgraph_nodes,
tenant_id,
kb_id,
None,
chat_model,
embedding_model,
callback,
)
if with_community:
await extract_community(
final_graph,
tenant_id,
kb_id,
None,
chat_model,
embedding_model,
callback,
)
finally:
kb_lock.release()
now = trio.current_time()
callback(msg=f"[GraphRAG] GraphRAG for KB {kb_id} done in {now - start:.2f} seconds. ok={len(ok_docs)} failed={len(failed_docs)} total_docs={len(doc_ids)} total_chunks={total_chunks}")
return {
"ok_docs": ok_docs,
"failed_docs": failed_docs, # [(doc_id, error), ...]
"total_docs": len(doc_ids),
"total_chunks": total_chunks,
"seconds": now - start,
}
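A hedged sketch of driving the KB-level run from an async context. The row dict, model bundles, and callback are placeholders for whatever the task executor already provides, and the "light" method value in the parser config is illustrative.

# Illustrative only: run GraphRAG across a whole knowledge base.
async def kb_graphrag_demo(row, chat_model, embedding_model, callback):
    summary = await run_graphrag_for_kb(
        row,                      # expects "tenant_id" and "kb_id" keys
        doc_ids=[],               # empty -> every document in the KB
        language="English",
        kb_parser_config={"graphrag": {"method": "light", "entity_types": []}},
        chat_model=chat_model,
        embedding_model=embedding_model,
        callback=callback,
        with_resolution=True,
        with_community=True,
    )
    return summary  # {"ok_docs": [...], "failed_docs": [...], ...}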
async def generate_subgraph( async def generate_subgraph(
extractor: Extractor, extractor: Extractor,
tenant_id: str, tenant_id: str,

View File

@ -34,6 +34,7 @@ dependencies = [
"elastic-transport==8.12.0", "elastic-transport==8.12.0",
"elasticsearch==8.12.1", "elasticsearch==8.12.1",
"elasticsearch-dsl==8.12.0", "elasticsearch-dsl==8.12.0",
"extract-msg>=0.39.0",
"filelock==3.15.4", "filelock==3.15.4",
"flask==3.0.3", "flask==3.0.3",
"flask-cors==5.0.0", "flask-cors==5.0.0",

View File

@ -78,7 +78,7 @@ def chunk(
_add_content(msg, msg.get_content_type()) _add_content(msg, msg.get_content_type())
sections = TxtParser.parser_txt("\n".join(text_txt)) + [ sections = TxtParser.parser_txt("\n".join(text_txt)) + [
(line, "") for line in HtmlParser.parser_txt("\n".join(html_txt)) if line (line, "") for line in HtmlParser.parser_txt("\n".join(html_txt), chunk_token_num=parser_config["chunk_token_num"]) if line
] ]
st = timer() st = timer()

View File

@ -18,9 +18,7 @@ import os
import time import time
from functools import partial from functools import partial
from typing import Any from typing import Any
import trio import trio
from agent.component.base import ComponentBase, ComponentParamBase from agent.component.base import ComponentBase, ComponentParamBase
from api.utils.api_utils import timeout from api.utils.api_utils import timeout
@ -36,9 +34,9 @@ class ProcessBase(ComponentBase):
def __init__(self, pipeline, id, param: ProcessParamBase): def __init__(self, pipeline, id, param: ProcessParamBase):
super().__init__(pipeline, id, param) super().__init__(pipeline, id, param)
if hasattr(self._canvas, "callback"): if hasattr(self._canvas, "callback"):
self.callback = partial(self._canvas.callback, self.component_name) self.callback = partial(self._canvas.callback, id)
else: else:
self.callback = partial(lambda *args, **kwargs: None, self.component_name) self.callback = partial(lambda *args, **kwargs: None, id)
async def invoke(self, **kwargs) -> dict[str, Any]: async def invoke(self, **kwargs) -> dict[str, Any]:
self.set_output("_created_time", time.perf_counter()) self.set_output("_created_time", time.perf_counter())

View File

@ -1,212 +0,0 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import trio
from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from deepdoc.parser.pdf_parser import RAGFlowPdfParser
from graphrag.utils import chat_limiter, get_llm_cache, set_llm_cache
from rag.flow.base import ProcessBase, ProcessParamBase
from rag.flow.chunker.schema import ChunkerFromUpstream
from rag.nlp import naive_merge, naive_merge_with_images
from rag.prompts.generator import keyword_extraction, question_proposal
class ChunkerParam(ProcessParamBase):
def __init__(self):
super().__init__()
self.method_options = [
# General
"general",
"onetable",
# Customer Service
"q&a",
"manual",
# Recruitment
"resume",
# Education & Research
"book",
"paper",
"laws",
"presentation",
# Other
# "Tag" # TODO: Other method
]
self.method = "general"
self.chunk_token_size = 512
self.delimiter = "\n"
self.overlapped_percent = 0
self.page_rank = 0
self.auto_keywords = 0
self.auto_questions = 0
self.tag_sets = []
self.llm_setting = {"llm_name": "", "lang": "Chinese"}
def check(self):
self.check_valid_value(self.method.lower(), "Chunk method abnormal.", self.method_options)
self.check_positive_integer(self.chunk_token_size, "Chunk token size.")
self.check_nonnegative_number(self.page_rank, "Page rank value: (0, 10]")
self.check_nonnegative_number(self.auto_keywords, "Auto-keyword value: (0, 10]")
self.check_nonnegative_number(self.auto_questions, "Auto-question value: (0, 10]")
self.check_decimal_float(self.overlapped_percent, "Overlapped percentage: [0, 1)")
def get_input_form(self) -> dict[str, dict]:
return {}
class Chunker(ProcessBase):
component_name = "Chunker"
def _general(self, from_upstream: ChunkerFromUpstream):
self.callback(random.randint(1, 5) / 100.0, "Start to chunk via `General`.")
if from_upstream.output_format in ["markdown", "text", "html"]:
if from_upstream.output_format == "markdown":
payload = from_upstream.markdown_result
elif from_upstream.output_format == "text":
payload = from_upstream.text_result
else: # == "html"
payload = from_upstream.html_result
if not payload:
payload = ""
cks = naive_merge(
payload,
self._param.chunk_token_size,
self._param.delimiter,
self._param.overlapped_percent,
)
return [{"text": c} for c in cks]
# json
sections, section_images = [], []
for o in from_upstream.json_result or []:
sections.append((o.get("text", ""), o.get("position_tag", "")))
section_images.append(o.get("image"))
chunks, images = naive_merge_with_images(
sections,
section_images,
self._param.chunk_token_size,
self._param.delimiter,
self._param.overlapped_percent,
)
return [
{
"text": RAGFlowPdfParser.remove_tag(c),
"image": img,
"positions": RAGFlowPdfParser.extract_positions(c),
}
for c, img in zip(chunks, images)
]
def _q_and_a(self, from_upstream: ChunkerFromUpstream):
pass
def _resume(self, from_upstream: ChunkerFromUpstream):
pass
def _manual(self, from_upstream: ChunkerFromUpstream):
pass
def _table(self, from_upstream: ChunkerFromUpstream):
pass
def _paper(self, from_upstream: ChunkerFromUpstream):
pass
def _book(self, from_upstream: ChunkerFromUpstream):
pass
def _laws(self, from_upstream: ChunkerFromUpstream):
pass
def _presentation(self, from_upstream: ChunkerFromUpstream):
pass
def _one(self, from_upstream: ChunkerFromUpstream):
pass
async def _invoke(self, **kwargs):
function_map = {
"general": self._general,
"q&a": self._q_and_a,
"resume": self._resume,
"manual": self._manual,
"table": self._table,
"paper": self._paper,
"book": self._book,
"laws": self._laws,
"presentation": self._presentation,
"one": self._one,
}
try:
from_upstream = ChunkerFromUpstream.model_validate(kwargs)
except Exception as e:
self.set_output("_ERROR", f"Input error: {str(e)}")
return
chunks = function_map[self._param.method](from_upstream)
llm_setting = self._param.llm_setting
async def auto_keywords():
nonlocal chunks, llm_setting
chat_mdl = LLMBundle(self._canvas._tenant_id, LLMType.CHAT, llm_name=llm_setting["llm_name"], lang=llm_setting["lang"])
async def doc_keyword_extraction(chat_mdl, ck, topn):
cached = get_llm_cache(chat_mdl.llm_name, ck["text"], "keywords", {"topn": topn})
if not cached:
async with chat_limiter:
cached = await trio.to_thread.run_sync(lambda: keyword_extraction(chat_mdl, ck["text"], topn))
set_llm_cache(chat_mdl.llm_name, ck["text"], cached, "keywords", {"topn": topn})
if cached:
ck["keywords"] = cached.split(",")
async with trio.open_nursery() as nursery:
for ck in chunks:
nursery.start_soon(doc_keyword_extraction, chat_mdl, ck, self._param.auto_keywords)
async def auto_questions():
nonlocal chunks, llm_setting
chat_mdl = LLMBundle(self._canvas._tenant_id, LLMType.CHAT, llm_name=llm_setting["llm_name"], lang=llm_setting["lang"])
async def doc_question_proposal(chat_mdl, d, topn):
cached = get_llm_cache(chat_mdl.llm_name, ck["text"], "question", {"topn": topn})
if not cached:
async with chat_limiter:
cached = await trio.to_thread.run_sync(lambda: question_proposal(chat_mdl, ck["text"], topn))
set_llm_cache(chat_mdl.llm_name, ck["text"], cached, "question", {"topn": topn})
if cached:
d["questions"] = cached.split("\n")
async with trio.open_nursery() as nursery:
for ck in chunks:
nursery.start_soon(doc_question_proposal, chat_mdl, ck, self._param.auto_questions)
async with trio.open_nursery() as nursery:
if self._param.auto_questions:
nursery.start_soon(auto_questions)
if self._param.auto_keywords:
nursery.start_soon(auto_keywords)
if self._param.page_rank:
for ck in chunks:
ck["page_rank"] = self._param.page_rank
self.set_output("chunks", chunks)

View File

@ -0,0 +1,63 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
from copy import deepcopy
from agent.component.llm import LLMParam, LLM
from rag.flow.base import ProcessBase, ProcessParamBase
class ExtractorParam(ProcessParamBase, LLMParam):
def __init__(self):
super().__init__()
self.field_name = ""
def check(self):
super().check()
self.check_empty(self.field_name, "Result Destination")
class Extractor(ProcessBase, LLM):
component_name = "Extractor"
async def _invoke(self, **kwargs):
self.set_output("output_format", "chunks")
self.callback(random.randint(1, 5) / 100.0, "Start to generate.")
inputs = self.get_input_elements()
chunks = []
chunks_key = ""
args = {}
for k, v in inputs.items():
args[k] = v["value"]
if isinstance(args[k], list):
chunks = deepcopy(args[k])
chunks_key = k
if chunks:
prog = 0
for i, ck in enumerate(chunks):
args[chunks_key] = ck["text"]
msg, sys_prompt = self._sys_prompt_and_msg([], args)
msg.insert(0, {"role": "system", "content": sys_prompt})
ck[self._param.field_name] = self._generate(msg)
prog += 1./len(chunks)
if i % (len(chunks)//100+1) == 1:
self.callback(prog, f"{i+1} / {len(chunks)}")
self.set_output("chunks", chunks)
else:
msg, sys_prompt = self._sys_prompt_and_msg([], args)
msg.insert(0, {"role": "system", "content": sys_prompt})
self.set_output("chunks", [{self._param.field_name: self._generate(msg)}])

View File

@ -0,0 +1,38 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Literal
from pydantic import BaseModel, ConfigDict, Field
class ExtractorFromUpstream(BaseModel):
created_time: float | None = Field(default=None, alias="_created_time")
elapsed_time: float | None = Field(default=None, alias="_elapsed_time")
name: str
file: dict | None = Field(default=None)
chunks: list[dict[str, Any]] | None = Field(default=None)
output_format: Literal["json", "markdown", "text", "html", "chunks"] | None = Field(default=None)
json_result: list[dict[str, Any]] | None = Field(default=None, alias="json")
markdown_result: str | None = Field(default=None, alias="markdown")
text_result: str | None = Field(default=None, alias="text")
html_result: str | None = Field(default=None, alias="html")
model_config = ConfigDict(populate_by_name=True, extra="forbid")
# def to_dict(self, *, exclude_none: bool = True) -> dict:
# return self.model_dump(by_alias=True, exclude_none=exclude_none)
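The schema accepts upstream payloads either by field name or by alias (`json`, `markdown`, ...) and rejects unknown keys because of `extra="forbid"`, which is what turns a malformed upstream output into the `_ERROR` message above. A short sketch of that behaviour, assuming pydantic v2:

from typing import Literal
from pydantic import BaseModel, ConfigDict, Field

class Upstream(BaseModel):                    # cut-down stand-in for the schema above
    name: str
    output_format: Literal["json", "markdown", "text", "html", "chunks"] | None = None
    markdown_result: str | None = Field(default=None, alias="markdown")
    model_config = ConfigDict(populate_by_name=True, extra="forbid")

ok = Upstream.model_validate({"name": "a.md", "output_format": "markdown", "markdown": "# Title"})
print(ok.markdown_result)                     # prints "# Title"; the alias populated the field

try:
    Upstream.model_validate({"name": "a.md", "unexpected": 1})
except Exception as e:
    print("rejected:", e)                     # extra="forbid" surfaces unknown keys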

View File

@ -14,10 +14,7 @@
# limitations under the License. # limitations under the License.
# #
from api.db.services.document_service import DocumentService from api.db.services.document_service import DocumentService
from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService
from rag.flow.base import ProcessBase, ProcessParamBase from rag.flow.base import ProcessBase, ProcessParamBase
from rag.utils.storage_factory import STORAGE_IMPL
class FileParam(ProcessParamBase): class FileParam(ProcessParamBase):
@ -41,10 +38,13 @@ class File(ProcessBase):
self.set_output("_ERROR", f"Document({self._canvas._doc_id}) not found!") self.set_output("_ERROR", f"Document({self._canvas._doc_id}) not found!")
return return
b, n = File2DocumentService.get_storage_address(doc_id=self._canvas._doc_id) #b, n = File2DocumentService.get_storage_address(doc_id=self._canvas._doc_id)
self.set_output("blob", STORAGE_IMPL.get(b, n)) #self.set_output("blob", STORAGE_IMPL.get(b, n))
self.set_output("name", doc.name) self.set_output("name", doc.name)
else: else:
file = kwargs.get("file") file = kwargs.get("file")
self.set_output("name", file["name"]) self.set_output("name", file["name"])
self.set_output("blob", FileService.get_blob(file["created_by"], file["id"])) self.set_output("file", file)
#self.set_output("blob", FileService.get_blob(file["created_by"], file["id"]))
self.callback(1, "File fetched.")

View File

@ -0,0 +1,15 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -0,0 +1,186 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import re
from copy import deepcopy
from functools import partial
import trio
from api.utils import get_uuid
from api.utils.base64_image import id2image, image2id
from deepdoc.parser.pdf_parser import RAGFlowPdfParser
from rag.flow.base import ProcessBase, ProcessParamBase
from rag.flow.hierarchical_merger.schema import HierarchicalMergerFromUpstream
from rag.nlp import concat_img
from rag.utils.storage_factory import STORAGE_IMPL
class HierarchicalMergerParam(ProcessParamBase):
def __init__(self):
super().__init__()
self.levels = []
self.hierarchy = None
def check(self):
self.check_empty(self.levels, "Hierarchical setups.")
self.check_empty(self.hierarchy, "Hierarchy number.")
def get_input_form(self) -> dict[str, dict]:
return {}
class HierarchicalMerger(ProcessBase):
component_name = "HierarchicalMerger"
async def _invoke(self, **kwargs):
try:
from_upstream = HierarchicalMergerFromUpstream.model_validate(kwargs)
except Exception as e:
self.set_output("_ERROR", f"Input error: {str(e)}")
return
self.set_output("output_format", "chunks")
self.callback(random.randint(1, 5) / 100.0, "Start to merge hierarchically.")
if from_upstream.output_format in ["markdown", "text", "html"]:
if from_upstream.output_format == "markdown":
payload = from_upstream.markdown_result
elif from_upstream.output_format == "text":
payload = from_upstream.text_result
else: # == "html"
payload = from_upstream.html_result
if not payload:
payload = ""
lines = [ln for ln in payload.split("\n") if ln]
else:
arr = from_upstream.chunks if from_upstream.output_format == "chunks" else from_upstream.json_result
lines = [o.get("text", "") for o in arr]
sections, section_images = [], []
for o in arr or []:
sections.append((o.get("text", ""), o.get("position_tag", "")))
section_images.append(o.get("img_id"))
matches = []
for txt in lines:
good = False
for lvl, regs in enumerate(self._param.levels):
for reg in regs:
if re.search(reg, txt):
matches.append(lvl)
good = True
break
if good:
break
if not good:
matches.append(len(self._param.levels))
assert len(matches) == len(lines), f"{len(matches)} vs. {len(lines)}"
root = {
"level": -1,
"index": -1,
"texts": [],
"children": []
}
for i, m in enumerate(matches):
if m == 0:
root["children"].append({
"level": m,
"index": i,
"texts": [],
"children": []
})
elif m == len(self._param.levels):
def dfs(b):
if not b["children"]:
b["texts"].append(i)
else:
dfs(b["children"][-1])
dfs(root)
else:
def dfs(b):
nonlocal m, i
if not b["children"] or m == b["level"] + 1:
b["children"].append({
"level": m,
"index": i,
"texts": [],
"children": []
})
return
dfs(b["children"][-1])
dfs(root)
all_pathes = []
def dfs(n, path, depth):
nonlocal all_pathes
if not n["children"] and path:
all_pathes.append(path)
for nn in n["children"]:
if depth < self._param.hierarchy:
_path = deepcopy(path)
else:
_path = path
_path.extend([nn["index"], *nn["texts"]])
dfs(nn, _path, depth+1)
if depth == self._param.hierarchy:
all_pathes.append(_path)
for i in range(len(lines)):
print(i, lines[i])
dfs(root, [], 0)
if root["texts"]:
all_pathes.insert(0, root["texts"])
if from_upstream.output_format in ["markdown", "text", "html"]:
cks = []
for path in all_pathes:
txt = ""
for i in path:
txt += lines[i] + "\n"
cks.append(txt)
self.set_output("chunks", [{"text": c} for c in cks if c])
else:
cks = []
images = []
for path in all_pathes:
txt = ""
img = None
for i in path:
txt += lines[i] + "\n"
img = concat_img(img, id2image(section_images[i], partial(STORAGE_IMPL.get)))
cks.append(txt)
images.append(img)
cks = [
{
"text": RAGFlowPdfParser.remove_tag(c),
"image": img,
"positions": RAGFlowPdfParser.extract_positions(c),
}
for c, img in zip(cks, images)
]
async with trio.open_nursery() as nursery:
for d in cks:
nursery.start_soon(image2id, d, partial(STORAGE_IMPL.put), get_uuid())
self.set_output("chunks", cks)
self.callback(1, "Done.")

View File

@ -0,0 +1,37 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Literal
from pydantic import BaseModel, ConfigDict, Field
class HierarchicalMergerFromUpstream(BaseModel):
created_time: float | None = Field(default=None, alias="_created_time")
elapsed_time: float | None = Field(default=None, alias="_elapsed_time")
name: str
file: dict | None = Field(default=None)
chunks: list[dict[str, Any]] | None = Field(default=None)
output_format: Literal["json", "chunks"] | None = Field(default=None)
json_result: list[dict[str, Any]] | None = Field(default=None, alias="json")
markdown_result: str | None = Field(default=None, alias="markdown")
text_result: str | None = Field(default=None, alias="text")
html_result: str | None = Field(default=None, alias="html")
model_config = ConfigDict(populate_by_name=True, extra="forbid")
# def to_dict(self, *, exclude_none: bool = True) -> dict:
# return self.model_dump(by_alias=True, exclude_none=exclude_none)

View File

@ -13,20 +13,28 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import io import io
import logging import json
import os
import random import random
from functools import partial
import trio import trio
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from api.db import LLMType from api.db import LLMType
from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService
from api.db.services.llm_service import LLMBundle from api.db.services.llm_service import LLMBundle
from api.utils import get_uuid
from api.utils.base64_image import image2id
from deepdoc.parser import ExcelParser from deepdoc.parser import ExcelParser
from deepdoc.parser.pdf_parser import PlainParser, RAGFlowPdfParser, VisionParser from deepdoc.parser.pdf_parser import PlainParser, RAGFlowPdfParser, VisionParser
from rag.app.naive import Docx
from rag.flow.base import ProcessBase, ProcessParamBase from rag.flow.base import ProcessBase, ProcessParamBase
from rag.flow.parser.schema import ParserFromUpstream from rag.flow.parser.schema import ParserFromUpstream
from rag.llm.cv_model import Base as VLM from rag.llm.cv_model import Base as VLM
from rag.utils.storage_factory import STORAGE_IMPL
class ParserParam(ProcessParamBase): class ParserParam(ProcessParamBase):
@ -45,12 +53,14 @@ class ParserParam(ProcessParamBase):
"word": [ "word": [
"json", "json",
], ],
"ppt": [], "slides": [
"json",
],
"image": [ "image": [
"text" "text"
], ],
"email": [], "email": ["text", "json"],
"text": [ "text&markdown": [
"text", "text",
"json" "json"
], ],
@ -63,7 +73,6 @@ class ParserParam(ProcessParamBase):
self.setups = { self.setups = {
"pdf": { "pdf": {
"parse_method": "deepdoc", # deepdoc/plain_text/vlm "parse_method": "deepdoc", # deepdoc/plain_text/vlm
"llm_id": "",
"lang": "Chinese", "lang": "Chinese",
"suffix": [ "suffix": [
"pdf", "pdf",
@ -85,23 +94,29 @@ class ParserParam(ProcessParamBase):
], ],
"output_format": "json", "output_format": "json",
}, },
"markdown": { "text&markdown": {
"suffix": ["md", "markdown"], "suffix": ["md", "markdown", "mdx", "txt"],
"output_format": "json",
},
"slides": {
"suffix": [
"pptx",
],
"output_format": "json", "output_format": "json",
}, },
"ppt": {},
"image": { "image": {
"parse_method": ["ocr", "vlm"], "parse_method": "ocr",
"llm_id": "", "llm_id": "",
"lang": "Chinese", "lang": "Chinese",
"system_prompt": "",
"suffix": ["jpg", "jpeg", "png", "gif"], "suffix": ["jpg", "jpeg", "png", "gif"],
"output_format": "json", "output_format": "text",
}, },
"email": {}, "email": {
"text": {
"suffix": [ "suffix": [
"txt" "eml", "msg"
], ],
"fields": ["from", "to", "cc", "bcc", "date", "subject", "body", "attachments", "metadata"],
"output_format": "json", "output_format": "json",
}, },
"audio": { "audio": {
@ -131,13 +146,10 @@ class ParserParam(ProcessParamBase):
pdf_config = self.setups.get("pdf", {}) pdf_config = self.setups.get("pdf", {})
if pdf_config: if pdf_config:
pdf_parse_method = pdf_config.get("parse_method", "") pdf_parse_method = pdf_config.get("parse_method", "")
self.check_valid_value(pdf_parse_method.lower(), "Parse method abnormal.", ["deepdoc", "plain_text", "vlm"]) self.check_empty(pdf_parse_method, "Parse method abnormal.")
if pdf_parse_method not in ["deepdoc", "plain_text"]: if pdf_parse_method.lower() not in ["deepdoc", "plain_text"]:
self.check_empty(pdf_config.get("llm_id"), "VLM") self.check_empty(pdf_config.get("lang", ""), "PDF VLM language")
pdf_language = pdf_config.get("lang", "")
self.check_empty(pdf_language, "Language")
pdf_output_format = pdf_config.get("output_format", "") pdf_output_format = pdf_config.get("output_format", "")
self.check_valid_value(pdf_output_format, "PDF output format abnormal.", self.allowed_output_format["pdf"]) self.check_valid_value(pdf_output_format, "PDF output format abnormal.", self.allowed_output_format["pdf"])
@ -147,32 +159,38 @@ class ParserParam(ProcessParamBase):
spreadsheet_output_format = spreadsheet_config.get("output_format", "") spreadsheet_output_format = spreadsheet_config.get("output_format", "")
self.check_valid_value(spreadsheet_output_format, "Spreadsheet output format abnormal.", self.allowed_output_format["spreadsheet"]) self.check_valid_value(spreadsheet_output_format, "Spreadsheet output format abnormal.", self.allowed_output_format["spreadsheet"])
doc_config = self.setups.get("doc", "") doc_config = self.setups.get("word", "")
if doc_config: if doc_config:
doc_output_format = doc_config.get("output_format", "") doc_output_format = doc_config.get("output_format", "")
self.check_valid_value(doc_output_format, "Word processer document output format abnormal.", self.allowed_output_format["doc"]) self.check_valid_value(doc_output_format, "Word processer document output format abnormal.", self.allowed_output_format["word"])
slides_config = self.setups.get("slides", "")
if slides_config:
slides_output_format = slides_config.get("output_format", "")
self.check_valid_value(slides_output_format, "Slides output format abnormal.", self.allowed_output_format["slides"])
image_config = self.setups.get("image", "") image_config = self.setups.get("image", "")
if image_config: if image_config:
image_parse_method = image_config.get("parse_method", "") image_parse_method = image_config.get("parse_method", "")
self.check_valid_value(image_parse_method.lower(), "Parse method abnormal.", ["ocr", "vlm"])
if image_parse_method not in ["ocr"]: if image_parse_method not in ["ocr"]:
self.check_empty(image_config.get("llm_id"), "VLM") self.check_empty(image_config.get("lang", ""), "Image VLM language")
image_language = image_config.get("lang", "") text_config = self.setups.get("text&markdown", "")
self.check_empty(image_language, "Language")
text_config = self.setups.get("text", "")
if text_config: if text_config:
text_output_format = text_config.get("output_format", "") text_output_format = text_config.get("output_format", "")
self.check_valid_value(text_output_format, "Text output format abnormal.", self.allowed_output_format["text"]) self.check_valid_value(text_output_format, "Text output format abnormal.", self.allowed_output_format["text&markdown"])
audio_config = self.setups.get("audio", "") audio_config = self.setups.get("audio", "")
if audio_config: if audio_config:
self.check_empty(audio_config.get("llm_id"), "VLM") self.check_empty(audio_config.get("llm_id"), "Audio VLM")
audio_language = audio_config.get("lang", "") audio_language = audio_config.get("lang", "")
self.check_empty(audio_language, "Language") self.check_empty(audio_language, "Language")
email_config = self.setups.get("email", "")
if email_config:
email_output_format = email_config.get("output_format", "")
self.check_valid_value(email_output_format, "Email output format abnormal.", self.allowed_output_format["email"])
def get_input_form(self) -> dict[str, dict]: def get_input_form(self) -> dict[str, dict]:
return {} return {}
@ -180,21 +198,18 @@ class ParserParam(ProcessParamBase):
class Parser(ProcessBase): class Parser(ProcessBase):
component_name = "Parser" component_name = "Parser"
def _pdf(self, from_upstream: ParserFromUpstream): def _pdf(self, name, blob):
self.callback(random.randint(1, 5) / 100.0, "Start to work on a PDF.") self.callback(random.randint(1, 5) / 100.0, "Start to work on a PDF.")
blob = from_upstream.blob
conf = self._param.setups["pdf"] conf = self._param.setups["pdf"]
self.set_output("output_format", conf["output_format"]) self.set_output("output_format", conf["output_format"])
if conf.get("parse_method") == "deepdoc": if conf.get("parse_method").lower() == "deepdoc":
bboxes = RAGFlowPdfParser().parse_into_bboxes(blob, callback=self.callback) bboxes = RAGFlowPdfParser().parse_into_bboxes(blob, callback=self.callback)
elif conf.get("parse_method") == "plain_text": elif conf.get("parse_method").lower() == "plain_text":
lines, _ = PlainParser()(blob) lines, _ = PlainParser()(blob)
bboxes = [{"text": t} for t, _ in lines] bboxes = [{"text": t} for t, _ in lines]
else: else:
assert conf.get("llm_id") vision_model = LLMBundle(self._canvas._tenant_id, LLMType.IMAGE2TEXT, llm_name=conf.get("parse_method"), lang=self._param.setups["pdf"].get("lang"))
vision_model = LLMBundle(self._canvas._tenant_id, LLMType.IMAGE2TEXT, llm_name=conf.get("llm_id"), lang=self._param.setups["pdf"].get("lang"))
lines, _ = VisionParser(vision_model=vision_model)(blob, callback=self.callback) lines, _ = VisionParser(vision_model=vision_model)(blob, callback=self.callback)
bboxes = [] bboxes = []
for t, poss in lines: for t, poss in lines:
@ -214,66 +229,63 @@ class Parser(ProcessBase):
mkdn += b.get("text", "") + "\n" mkdn += b.get("text", "") + "\n"
self.set_output("markdown", mkdn) self.set_output("markdown", mkdn)
def _spreadsheet(self, from_upstream: ParserFromUpstream): def _spreadsheet(self, name, blob):
self.callback(random.randint(1, 5) / 100.0, "Start to work on a Spreadsheet.") self.callback(random.randint(1, 5) / 100.0, "Start to work on a Spreadsheet.")
blob = from_upstream.blob
conf = self._param.setups["spreadsheet"] conf = self._param.setups["spreadsheet"]
self.set_output("output_format", conf["output_format"]) self.set_output("output_format", conf["output_format"])
print("spreadsheet {conf=}", flush=True)
spreadsheet_parser = ExcelParser() spreadsheet_parser = ExcelParser()
if conf.get("output_format") == "html": if conf.get("output_format") == "html":
html = spreadsheet_parser.html(blob, 1000000000) htmls = spreadsheet_parser.html(blob, 1000000000)
self.set_output("html", html) self.set_output("html", htmls[0])
elif conf.get("output_format") == "json": elif conf.get("output_format") == "json":
self.set_output("json", [{"text": txt} for txt in spreadsheet_parser(blob) if txt]) self.set_output("json", [{"text": txt} for txt in spreadsheet_parser(blob) if txt])
elif conf.get("output_format") == "markdown": elif conf.get("output_format") == "markdown":
self.set_output("markdown", spreadsheet_parser.markdown(blob)) self.set_output("markdown", spreadsheet_parser.markdown(blob))
def _word(self, from_upstream: ParserFromUpstream): def _word(self, name, blob):
from tika import parser as word_parser
self.callback(random.randint(1, 5) / 100.0, "Start to work on a Word Processor Document") self.callback(random.randint(1, 5) / 100.0, "Start to work on a Word Processor Document")
blob = from_upstream.blob
name = from_upstream.name
conf = self._param.setups["word"] conf = self._param.setups["word"]
self.set_output("output_format", conf["output_format"]) self.set_output("output_format", conf["output_format"])
docx_parser = Docx()
print("word {conf=}", flush=True) sections, tbls = docx_parser(name, binary=blob)
doc_parsed = word_parser.from_buffer(blob) sections = [{"text": section[0], "image": section[1]} for section in sections if section]
sections.extend([{"text": tb, "image": None} for ((_,tb), _) in tbls])
sections = []
if doc_parsed.get("content"):
sections = doc_parsed["content"].split("\n")
sections = [{"text": section} for section in sections if section]
else:
logging.warning(f"tika.parser got empty content from {name}.")
# json # json
assert conf.get("output_format") == "json", "have to be json for doc" assert conf.get("output_format") == "json", "have to be json for doc"
if conf.get("output_format") == "json": if conf.get("output_format") == "json":
self.set_output("json", sections) self.set_output("json", sections)
def _markdown(self, from_upstream: ParserFromUpstream): def _slides(self, name, blob):
from deepdoc.parser.ppt_parser import RAGFlowPptParser as ppt_parser
self.callback(random.randint(1, 5) / 100.0, "Start to work on a PowerPoint Document")
conf = self._param.setups["slides"]
self.set_output("output_format", conf["output_format"])
ppt_parser = ppt_parser()
txts = ppt_parser(blob, 0, 100000, None)
sections = [{"text": section} for section in txts if section.strip()]
# json
assert conf.get("output_format") == "json", "have to be json for ppt"
if conf.get("output_format") == "json":
self.set_output("json", sections)
def _markdown(self, name, blob):
from functools import reduce from functools import reduce
from rag.app.naive import Markdown as naive_markdown_parser from rag.app.naive import Markdown as naive_markdown_parser
from rag.nlp import concat_img from rag.nlp import concat_img
self.callback(random.randint(1, 5) / 100.0, "Start to work on a markdown.") self.callback(random.randint(1, 5) / 100.0, "Start to work on a markdown.")
conf = self._param.setups["text&markdown"]
blob = from_upstream.blob
name = from_upstream.name
conf = self._param.setups["markdown"]
self.set_output("output_format", conf["output_format"]) self.set_output("output_format", conf["output_format"])
markdown_parser = naive_markdown_parser() markdown_parser = naive_markdown_parser()
sections, tables = markdown_parser(name, blob, separate_tables=False) sections, tables = markdown_parser(name, blob, separate_tables=False)
# json
assert conf.get("output_format") == "json", "have to be json for doc"
if conf.get("output_format") == "json": if conf.get("output_format") == "json":
json_results = [] json_results = []
@ -291,69 +303,51 @@ class Parser(ProcessBase):
json_results.append(json_result) json_results.append(json_result)
self.set_output("json", json_results) self.set_output("json", json_results)
def _text(self, from_upstream: ParserFromUpstream):
from deepdoc.parser.utils import get_text
self.callback(random.randint(1, 5) / 100.0, "Start to work on a text.")
blob = from_upstream.blob
name = from_upstream.name
conf = self._param.setups["text"]
self.set_output("output_format", conf["output_format"])
# parse binary to text
text_content = get_text(name, binary=blob)
if conf.get("output_format") == "json":
result = [{"text": text_content}]
self.set_output("json", result)
else: else:
result = text_content self.set_output("text", "\n".join([section_text for section_text, _ in sections]))
self.set_output("text", result)
def _image(self, from_upstream: ParserFromUpstream):
def _image(self, name, blob):
from deepdoc.vision import OCR from deepdoc.vision import OCR
self.callback(random.randint(1, 5) / 100.0, "Start to work on an image.") self.callback(random.randint(1, 5) / 100.0, "Start to work on an image.")
blob = from_upstream.blob
conf = self._param.setups["image"] conf = self._param.setups["image"]
self.set_output("output_format", conf["output_format"]) self.set_output("output_format", conf["output_format"])
img = Image.open(io.BytesIO(blob)).convert("RGB") img = Image.open(io.BytesIO(blob)).convert("RGB")
lang = conf["lang"]
if conf["parse_method"] == "ocr": if conf["parse_method"] == "ocr":
# use ocr, recognize chars only # use ocr, recognize chars only
ocr = OCR() ocr = OCR()
bxs = ocr(np.array(img)) # return boxes and recognize result bxs = ocr(np.array(img)) # return boxes and recognize result
txt = "\n".join([t[0] for _, t in bxs if t[0]]) txt = "\n".join([t[0] for _, t in bxs if t[0]])
else: else:
lang = conf["lang"]
# use VLM to describe the picture # use VLM to describe the picture
cv_model = LLMBundle(self._canvas.get_tenant_id(), LLMType.IMAGE2TEXT, llm_name=conf["llm_id"],lang=lang) cv_model = LLMBundle(self._canvas.get_tenant_id(), LLMType.IMAGE2TEXT, llm_name=conf["parse_method"], lang=lang)
img_binary = io.BytesIO() img_binary = io.BytesIO()
img.save(img_binary, format="JPEG") img.save(img_binary, format="JPEG")
img_binary.seek(0) img_binary.seek(0)
txt = cv_model.describe(img_binary.read())
system_prompt = conf.get("system_prompt")
if system_prompt:
txt = cv_model.describe_with_prompt(img_binary.read(), system_prompt)
else:
txt = cv_model.describe(img_binary.read())
self.set_output("text", txt) self.set_output("text", txt)
def _audio(self, from_upstream: ParserFromUpstream): def _audio(self, name, blob):
import os import os
import tempfile import tempfile
self.callback(random.randint(1, 5) / 100.0, "Start to work on an audio.") self.callback(random.randint(1, 5) / 100.0, "Start to work on an audio.")
blob = from_upstream.blob
name = from_upstream.name
conf = self._param.setups["audio"] conf = self._param.setups["audio"]
self.set_output("output_format", conf["output_format"]) self.set_output("output_format", conf["output_format"])
lang = conf["lang"] lang = conf["lang"]
_, ext = os.path.splitext(name) _, ext = os.path.splitext(name)
tmp_path = ""
with tempfile.NamedTemporaryFile(suffix=ext) as tmpf: with tempfile.NamedTemporaryFile(suffix=ext) as tmpf:
tmpf.write(blob) tmpf.write(blob)
tmpf.flush() tmpf.flush()
@ -364,15 +358,131 @@ class Parser(ProcessBase):
self.set_output("text", txt) self.set_output("text", txt)
def _email(self, name, blob):
self.callback(random.randint(1, 5) / 100.0, "Start to work on an email.")
email_content = {}
conf = self._param.setups["email"]
target_fields = conf["fields"]
_, ext = os.path.splitext(name)
if ext == ".eml":
# handle eml file
from email import policy
from email.parser import BytesParser
msg = BytesParser(policy=policy.default).parse(io.BytesIO(blob))
email_content['metadata'] = {}
# handle header info
for header, value in msg.items():
# get fields like from, to, cc, bcc, date, subject
if header.lower() in target_fields:
email_content[header.lower()] = value
# get metadata
elif header.lower() not in ["from", "to", "cc", "bcc", "date", "subject"]:
email_content["metadata"][header.lower()] = value
# get body
if "body" in target_fields:
body_text, body_html = [], []
def _add_content(m, content_type):
if content_type == "text/plain":
body_text.append(
m.get_payload(decode=True).decode(m.get_content_charset())
)
elif content_type == "text/html":
body_html.append(
m.get_payload(decode=True).decode(m.get_content_charset())
)
elif "multipart" in content_type:
if m.is_multipart():
for part in m.iter_parts():
_add_content(part, part.get_content_type())
_add_content(msg, msg.get_content_type())
email_content["text"] = body_text
email_content["text_html"] = body_html
# get attachment
if "attachments" in target_fields:
attachments = []
for part in msg.iter_attachments():
content_disposition = part.get("Content-Disposition")
if content_disposition:
dispositions = content_disposition.strip().split(";")
if dispositions[0].lower() == "attachment":
filename = part.get_filename()
payload = part.get_payload(decode=True)
attachments.append({
"filename": filename,
"payload": payload,
})
email_content["attachments"] = attachments
else:
# handle msg file
import extract_msg
print("handle a msg file.")
msg = extract_msg.Message(blob)
# handle header info
basic_content = {
"from": msg.sender,
"to": msg.to,
"cc": msg.cc,
"bcc": msg.bcc,
"date": msg.date,
"subject": msg.subject,
}
email_content.update({k: v for k, v in basic_content.items() if k in target_fields})
# get metadata
email_content['metadata'] = {
'message_id': msg.messageId,
'in_reply_to': msg.inReplyTo,
}
# get body
if "body" in target_fields:
email_content["text"] = msg.body # usually empty. try text_html instead
email_content["text_html"] = msg.htmlBody
# get attachments
if "attachments" in target_fields:
attachments = []
for t in msg.attachments:
attachments.append({
"filename": t.name,
"payload": t.data # binary
})
email_content["attachments"] = attachments
if conf["output_format"] == "json":
self.set_output("json", [email_content])
else:
content_txt = ''
for k, v in email_content.items():
if isinstance(v, str):
# basic info
content_txt += f'{k}:{v}' + "\n"
elif isinstance(v, dict):
# metadata
content_txt += f'{k}:{json.dumps(v)}' + "\n"
elif isinstance(v, list):
# attachments or others
for fb in v:
if isinstance(fb, dict):
# attachments
content_txt += f'{fb["filename"]}:{fb["payload"]}' + "\n"
else:
# str, usually plain text
content_txt += fb
self.set_output("text", content_txt)
async def _invoke(self, **kwargs): async def _invoke(self, **kwargs):
function_map = { function_map = {
"pdf": self._pdf, "pdf": self._pdf,
"markdown": self._markdown, "text&markdown": self._markdown,
"spreadsheet": self._spreadsheet, "spreadsheet": self._spreadsheet,
"slides": self._slides,
"word": self._word, "word": self._word,
"text": self._text,
"image": self._image, "image": self._image,
"audio": self._audio, "audio": self._audio,
"email": self._email,
} }
try: try:
from_upstream = ParserFromUpstream.model_validate(kwargs) from_upstream = ParserFromUpstream.model_validate(kwargs)
@ -380,8 +490,25 @@ class Parser(ProcessBase):
self.set_output("_ERROR", f"Input error: {str(e)}") self.set_output("_ERROR", f"Input error: {str(e)}")
return return
name = from_upstream.name
if self._canvas._doc_id:
b, n = File2DocumentService.get_storage_address(doc_id=self._canvas._doc_id)
blob = STORAGE_IMPL.get(b, n)
else:
blob = FileService.get_blob(from_upstream.file["created_by"], from_upstream.file["id"])
done = False
for p_type, conf in self._param.setups.items(): for p_type, conf in self._param.setups.items():
if from_upstream.name.split(".")[-1].lower() not in conf.get("suffix", []): if from_upstream.name.split(".")[-1].lower() not in conf.get("suffix", []):
continue continue
await trio.to_thread.run_sync(function_map[p_type], from_upstream) await trio.to_thread.run_sync(function_map[p_type], name, blob)
done = True
break break
if not done:
raise Exception("No suitable for file extension: `.%s`" % from_upstream.name.split(".")[-1].lower())
outs = self.output()
async with trio.open_nursery() as nursery:
for d in outs.get("json", []):
nursery.start_soon(image2id, d, partial(STORAGE_IMPL.put), get_uuid())
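The `_email` branch above returns either plain text or a single JSON object whose keys follow the configured `fields` list. A hypothetical example of the JSON output for an `.eml` message, with all values invented:

email_json = [{
    "from": "alice@example.com",
    "to": "bob@example.com",
    "date": "Thu, 09 Oct 2025 12:00:00 +0800",
    "subject": "Weekly report",
    "text": ["plain-text body parts collected from text/plain parts"],
    "text_html": ["<p>html body parts</p>"],
    "attachments": [{"filename": "report.pdf", "payload": b"<binary>"}],
    "metadata": {"message-id": "<abc@example.com>"},
}]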

View File

@ -20,6 +20,5 @@ class ParserFromUpstream(BaseModel):
elapsed_time: float | None = Field(default=None, alias="_elapsed_time") elapsed_time: float | None = Field(default=None, alias="_elapsed_time")
name: str name: str
blob: bytes file: dict | None = Field(default=None)
model_config = ConfigDict(populate_by_name=True, extra="forbid") model_config = ConfigDict(populate_by_name=True, extra="forbid")

View File

@ -17,41 +17,92 @@ import datetime
import json import json
import logging import logging
import random import random
import time from timeit import default_timer as timer
import trio import trio
from agent.canvas import Graph from agent.canvas import Graph
from api.db.services.document_service import DocumentService from api.db.services.document_service import DocumentService
from api.db.services.task_service import has_canceled, TaskService, CANVAS_DEBUG_DOC_ID
from rag.utils.redis_conn import REDIS_CONN from rag.utils.redis_conn import REDIS_CONN
class Pipeline(Graph): class Pipeline(Graph):
def __init__(self, dsl: str, tenant_id=None, doc_id=None, task_id=None, flow_id=None): def __init__(self, dsl: str|dict, tenant_id=None, doc_id=None, task_id=None, flow_id=None):
if isinstance(dsl, dict):
dsl = json.dumps(dsl, ensure_ascii=False)
super().__init__(dsl, tenant_id, task_id) super().__init__(dsl, tenant_id, task_id)
if doc_id == CANVAS_DEBUG_DOC_ID:
doc_id = None
self._doc_id = doc_id self._doc_id = doc_id
self._flow_id = flow_id self._flow_id = flow_id
self._kb_id = None self._kb_id = None
if doc_id: if self._doc_id:
self._kb_id = DocumentService.get_knowledgebase_id(doc_id) self._kb_id = DocumentService.get_knowledgebase_id(doc_id)
assert self._kb_id, f"Can't find KB of this document: {doc_id}" if not self._kb_id:
self._doc_id = None
def callback(self, component_name: str, progress: float | int | None = None, message: str = "") -> None: def callback(self, component_name: str, progress: float | int | None = None, message: str = "") -> None:
from rag.svr.task_executor import TaskCanceledException
log_key = f"{self._flow_id}-{self.task_id}-logs" log_key = f"{self._flow_id}-{self.task_id}-logs"
timestamp = timer()
if has_canceled(self.task_id):
progress = -1
message += "[CANCEL]"
try: try:
bin = REDIS_CONN.get(log_key) bin = REDIS_CONN.get(log_key)
obj = json.loads(bin.encode("utf-8")) obj = json.loads(bin.encode("utf-8"))
if obj: if obj:
if obj[-1]["component_name"] == component_name: if obj[-1]["component_id"] == component_name:
obj[-1]["trace"].append({"progress": progress, "message": message, "datetime": datetime.datetime.now().strftime("%H:%M:%S")}) obj[-1]["trace"].append(
{
"progress": progress,
"message": message,
"datetime": datetime.datetime.now().strftime("%H:%M:%S"),
"timestamp": timestamp,
"elapsed_time": timestamp - obj[-1]["trace"][-1]["timestamp"],
}
)
else: else:
obj.append({"component_name": component_name, "trace": [{"progress": progress, "message": message, "datetime": datetime.datetime.now().strftime("%H:%M:%S")}]}) obj.append(
{
"component_id": component_name,
"trace": [{"progress": progress, "message": message, "datetime": datetime.datetime.now().strftime("%H:%M:%S"), "timestamp": timestamp, "elapsed_time": 0}],
}
)
else: else:
obj = [{"component_name": component_name, "trace": [{"progress": progress, "message": message, "datetime": datetime.datetime.now().strftime("%H:%M:%S")}]}] obj = [
REDIS_CONN.set_obj(log_key, obj, 60 * 10) {
"component_id": component_name,
"trace": [{"progress": progress, "message": message, "datetime": datetime.datetime.now().strftime("%H:%M:%S"), "timestamp": timestamp, "elapsed_time": 0}],
}
]
if component_name != "END" and self._doc_id and self.task_id:
percentage = 1.0 / len(self.components.items())
finished = 0.0
for o in obj:
for t in o["trace"]:
if t["progress"] < 0:
finished = -1
break
if finished < 0:
break
finished += o["trace"][-1]["progress"] * percentage
msg = ""
if len(obj[-1]["trace"]) == 1:
msg += f"\n-------------------------------------\n[{self.get_component_name(o['component_id'])}]:\n"
t = obj[-1]["trace"][-1]
msg += "%s: %s\n" % (t["datetime"], t["message"])
TaskService.update_progress(self.task_id, {"progress": finished, "progress_msg": msg})
elif component_name == "END" and not self._doc_id:
obj[-1]["trace"][-1]["dsl"] = json.loads(str(self))
REDIS_CONN.set_obj(log_key, obj, 60 * 30)
except Exception as e: except Exception as e:
logging.exception(e) logging.exception(e)
if has_canceled(self.task_id):
raise TaskCanceledException(message)
def fetch_logs(self): def fetch_logs(self):
log_key = f"{self._flow_id}-{self.task_id}-logs" log_key = f"{self._flow_id}-{self.task_id}-logs"
try: try:
@ -62,34 +113,32 @@ class Pipeline(Graph):
logging.exception(e) logging.exception(e)
return [] return []
def reset(self):
super().reset() async def run(self, **kwargs):
log_key = f"{self._flow_id}-{self.task_id}-logs" log_key = f"{self._flow_id}-{self.task_id}-logs"
try: try:
REDIS_CONN.set_obj(log_key, [], 60 * 10) REDIS_CONN.set_obj(log_key, [], 60 * 10)
except Exception as e: except Exception as e:
logging.exception(e) logging.exception(e)
self.error = ""
async def run(self, **kwargs):
st = time.perf_counter()
if not self.path: if not self.path:
self.path.append("File") self.path.append("File")
if self._doc_id:
DocumentService.update_by_id(
self._doc_id, {"progress": random.randint(0, 5) / 100.0, "progress_msg": "Start the pipeline...", "process_begin_at": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}
)
self.error = ""
idx = len(self.path) - 1
if idx == 0:
cpn_obj = self.get_component_obj(self.path[0]) cpn_obj = self.get_component_obj(self.path[0])
await cpn_obj.invoke(**kwargs) await cpn_obj.invoke(**kwargs)
if cpn_obj.error(): if cpn_obj.error():
self.error = "[ERROR]" + cpn_obj.error() self.error = "[ERROR]" + cpn_obj.error()
else: self.callback(cpn_obj.component_name, -1, self.error)
idx += 1
self.path.extend(cpn_obj.get_downstream()) if self._doc_id:
TaskService.update_progress(self.task_id, {
"progress": random.randint(0, 5) / 100.0,
"progress_msg": "Start the pipeline...",
"begin_at": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")})
idx = len(self.path) - 1
cpn_obj = self.get_component_obj(self.path[idx])
idx += 1
self.path.extend(cpn_obj.get_downstream())
while idx < len(self.path) and not self.error: while idx < len(self.path) and not self.error:
last_cpn = self.get_component_obj(self.path[idx - 1]) last_cpn = self.get_component_obj(self.path[idx - 1])
@ -98,15 +147,28 @@ class Pipeline(Graph):
async def invoke(): async def invoke():
nonlocal last_cpn, cpn_obj nonlocal last_cpn, cpn_obj
await cpn_obj.invoke(**last_cpn.output()) await cpn_obj.invoke(**last_cpn.output())
#if inspect.iscoroutinefunction(cpn_obj.invoke):
# await cpn_obj.invoke(**last_cpn.output())
#else:
# cpn_obj.invoke(**last_cpn.output())
async with trio.open_nursery() as nursery: async with trio.open_nursery() as nursery:
nursery.start_soon(invoke) nursery.start_soon(invoke)
if cpn_obj.error(): if cpn_obj.error():
self.error = "[ERROR]" + cpn_obj.error() self.error = "[ERROR]" + cpn_obj.error()
self.callback(cpn_obj.component_name, -1, self.error) self.callback(cpn_obj._id, -1, self.error)
break break
idx += 1 idx += 1
self.path.extend(cpn_obj.get_downstream()) self.path.extend(cpn_obj.get_downstream())
if self._doc_id: self.callback("END", 1 if not self.error else -1, json.dumps(self.get_component_obj(self.path[-1]).output(), ensure_ascii=False))
DocumentService.update_by_id(self._doc_id, {"progress": 1 if not self.error else -1, "progress_msg": "Pipeline finished...\n" + self.error, "process_duration": time.perf_counter() - st})
if not self.error:
return self.get_component_obj(self.path[-1]).output()
TaskService.update_progress(self.task_id, {
"progress": -1,
"progress_msg": f"[ERROR]: {self.error}"})
return {}
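`callback` above appends one trace record per progress update under the Redis key `{flow_id}-{task_id}-logs`, grouped by component, and `fetch_logs` returns that list. A rough sketch of a single entry, with invented values:

log_entry = {
    "component_id": "Parser:0",
    "trace": [
        {"progress": 0.03, "message": "Start to work on a PDF.",
         "datetime": "12:36:19", "timestamp": 1020.5, "elapsed_time": 0},
        {"progress": 1, "message": "Done.",
         "datetime": "12:36:25", "timestamp": 1026.7, "elapsed_time": 6.2},
    ],
}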

View File

@ -0,0 +1,15 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -17,19 +17,20 @@ from typing import Any, Literal
from pydantic import BaseModel, ConfigDict, Field from pydantic import BaseModel, ConfigDict, Field
class ChunkerFromUpstream(BaseModel): class SplitterFromUpstream(BaseModel):
created_time: float | None = Field(default=None, alias="_created_time") created_time: float | None = Field(default=None, alias="_created_time")
elapsed_time: float | None = Field(default=None, alias="_elapsed_time") elapsed_time: float | None = Field(default=None, alias="_elapsed_time")
name: str name: str
blob: bytes file: dict | None = Field(default=None)
chunks: list[dict[str, Any]] | None = Field(default=None)
output_format: Literal["json", "markdown", "text", "html"] | None = Field(default=None) output_format: Literal["json", "markdown", "text", "html"] | None = Field(default=None)
json_result: list[dict[str, Any]] | None = Field(default=None, alias="json") json_result: list[dict[str, Any]] | None = Field(default=None, alias="json")
markdown_result: str | None = Field(default=None, alias="markdown") markdown_result: str | None = Field(default=None, alias="markdown")
text_result: str | None = Field(default=None, alias="text") text_result: str | None = Field(default=None, alias="text")
html_result: list[str] | None = Field(default=None, alias="html") html_result: str | None = Field(default=None, alias="html")
model_config = ConfigDict(populate_by_name=True, extra="forbid") model_config = ConfigDict(populate_by_name=True, extra="forbid")

View File

@ -0,0 +1,111 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
from functools import partial
import trio
from api.utils import get_uuid
from api.utils.base64_image import id2image, image2id
from deepdoc.parser.pdf_parser import RAGFlowPdfParser
from rag.flow.base import ProcessBase, ProcessParamBase
from rag.flow.splitter.schema import SplitterFromUpstream
from rag.nlp import naive_merge, naive_merge_with_images
from rag.utils.storage_factory import STORAGE_IMPL
class SplitterParam(ProcessParamBase):
def __init__(self):
super().__init__()
self.chunk_token_size = 512
self.delimiters = ["\n"]
self.overlapped_percent = 0
def check(self):
self.check_empty(self.delimiters, "Delimiters.")
self.check_positive_integer(self.chunk_token_size, "Chunk token size.")
self.check_decimal_float(self.overlapped_percent, "Overlapped percentage: [0, 1)")
def get_input_form(self) -> dict[str, dict]:
return {}
class Splitter(ProcessBase):
component_name = "Splitter"
async def _invoke(self, **kwargs):
try:
from_upstream = SplitterFromUpstream.model_validate(kwargs)
except Exception as e:
self.set_output("_ERROR", f"Input error: {str(e)}")
return
deli = ""
for d in self._param.delimiters:
if len(d) > 1:
deli += f"`{d}`"
else:
deli += d
self.set_output("output_format", "chunks")
self.callback(random.randint(1, 5) / 100.0, "Start to split into chunks.")
if from_upstream.output_format in ["markdown", "text", "html"]:
if from_upstream.output_format == "markdown":
payload = from_upstream.markdown_result
elif from_upstream.output_format == "text":
payload = from_upstream.text_result
else: # == "html"
payload = from_upstream.html_result
if not payload:
payload = ""
cks = naive_merge(
payload,
self._param.chunk_token_size,
deli,
self._param.overlapped_percent,
)
self.set_output("chunks", [{"text": c.strip()} for c in cks if c.strip()])
self.callback(1, "Done.")
return
# json
sections, section_images = [], []
for o in from_upstream.json_result or []:
sections.append((o.get("text", ""), o.get("position_tag", "")))
section_images.append(id2image(o.get("img_id"), partial(STORAGE_IMPL.get)))
chunks, images = naive_merge_with_images(
sections,
section_images,
self._param.chunk_token_size,
deli,
self._param.overlapped_percent,
)
cks = [
{
"text": RAGFlowPdfParser.remove_tag(c),
"image": img,
"positions": [[pos[0][-1]+1, *pos[1:]] for pos in RAGFlowPdfParser.extract_positions(c)],
}
for c, img in zip(chunks, images) if c.strip()
]
async with trio.open_nursery() as nursery:
for d in cks:
nursery.start_soon(image2id, d, partial(STORAGE_IMPL.put), get_uuid())
self.set_output("chunks", cks)
self.callback(1, "Done.")
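Multi-character delimiters are wrapped in backticks before being passed to `naive_merge`, so they can be told apart from single-character separators. A tiny sketch of that encoding; the delimiter list below is an invented configuration:

delimiters = ["\n", "。", "##"]           # invented example configuration
deli = ""
for d in delimiters:
    deli += f"`{d}`" if len(d) > 1 else d
print(repr(deli))                          # '\n。`##`'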

View File

@ -30,7 +30,7 @@ def print_logs(pipeline: Pipeline):
while True: while True:
time.sleep(5) time.sleep(5)
logs = pipeline.fetch_logs() logs = pipeline.fetch_logs()
logs_str = json.dumps(logs) logs_str = json.dumps(logs, ensure_ascii=False)
if logs_str != last_logs: if logs_str != last_logs:
print(logs_str) print(logs_str)
last_logs = logs_str last_logs = logs_str

View File

@ -38,6 +38,13 @@
], ],
"output_format": "json" "output_format": "json"
}, },
"slides": {
"parse_method": "presentation",
"suffix": [
"pptx"
],
"output_format": "json"
},
"markdown": { "markdown": {
"suffix": [ "suffix": [
"md", "md",
@ -82,19 +89,36 @@
"lang": "Chinese", "lang": "Chinese",
"llm_id": "SenseVoiceSmall", "llm_id": "SenseVoiceSmall",
"output_format": "json" "output_format": "json"
},
"email": {
"suffix": [
"msg"
],
"fields": [
"from",
"to",
"cc",
"bcc",
"date",
"subject",
"body",
"attachments"
],
"output_format": "json"
} }
} }
} }
}, },
"downstream": ["Chunker:0"], "downstream": ["Splitter:0"],
"upstream": ["Begin"] "upstream": ["Begin"]
}, },
"Chunker:0": { "Splitter:0": {
"obj": { "obj": {
"component_name": "Chunker", "component_name": "Splitter",
"params": { "params": {
"method": "general", "chunk_token_size": 512,
"auto_keywords": 5 "delimiters": ["\n"],
"overlapped_percent": 0
} }
}, },
"downstream": ["Tokenizer:0"], "downstream": ["Tokenizer:0"],

View File

@ -0,0 +1,84 @@
{
"components": {
"File": {
"obj":{
"component_name": "File",
"params": {
}
},
"downstream": ["Parser:0"],
"upstream": []
},
"Parser:0": {
"obj": {
"component_name": "Parser",
"params": {
"setups": {
"pdf": {
"parse_method": "deepdoc",
"vlm_name": "",
"lang": "Chinese",
"suffix": [
"pdf"
],
"output_format": "json"
},
"spreadsheet": {
"suffix": [
"xls",
"xlsx",
"csv"
],
"output_format": "html"
},
"word": {
"suffix": [
"doc",
"docx"
],
"output_format": "json"
},
"markdown": {
"suffix": [
"md",
"markdown"
],
"output_format": "text"
},
"text": {
"suffix": ["txt"],
"output_format": "json"
}
}
}
},
"downstream": ["Splitter:0"],
"upstream": ["File"]
},
"Splitter:0": {
"obj": {
"component_name": "Splitter",
"params": {
"chunk_token_size": 512,
"delimiters": ["\r\n"],
"overlapped_percent": 0
}
},
"downstream": ["HierarchicalMerger:0"],
"upstream": ["Parser:0"]
},
"HierarchicalMerger:0": {
"obj": {
"component_name": "HierarchicalMerger",
"params": {
"levels": [["^#[^#]"], ["^##[^#]"], ["^###[^#]"], ["^####[^#]"]],
"hierarchy": 2
}
},
"downstream": [],
"upstream": ["Splitter:0"]
}
},
"path": []
}
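This DSL wires File into Parser:0, Splitter:0 and HierarchicalMerger:0 in sequence. A minimal sketch of a debug run over it; the module path, local file name, the `file` kwargs and a fully configured backend (DB, Redis, storage) are all assumptions rather than anything confirmed by this commit:

import json
import trio
from rag.flow.pipeline import Pipeline        # assumed module path

with open("pipeline_dsl.json") as f:           # hypothetical local copy of the DSL above
    dsl = json.load(f)

pipe = Pipeline(dsl, tenant_id="tenant-1", task_id="task-1", flow_id="flow-1")

async def main():
    # In debug mode (no doc_id) the File component takes the upload through kwargs.
    await pipe.run(file={"id": "file-1", "name": "guide.md", "created_by": "user-1"})
    print(pipe.fetch_logs())

trio.run(main)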

View File

@ -22,16 +22,16 @@ class TokenizerFromUpstream(BaseModel):
elapsed_time: float | None = Field(default=None, alias="_elapsed_time") elapsed_time: float | None = Field(default=None, alias="_elapsed_time")
name: str = "" name: str = ""
blob: bytes file: dict | None = Field(default=None)
output_format: Literal["json", "markdown", "text", "html"] | None = Field(default=None) output_format: Literal["json", "markdown", "text", "html", "chunks"] | None = Field(default=None)
chunks: list[dict[str, Any]] | None = Field(default=None) chunks: list[dict[str, Any]] | None = Field(default=None)
json_result: list[dict[str, Any]] | None = Field(default=None, alias="json") json_result: list[dict[str, Any]] | None = Field(default=None, alias="json")
markdown_result: str | None = Field(default=None, alias="markdown") markdown_result: str | None = Field(default=None, alias="markdown")
text_result: str | None = Field(default=None, alias="text") text_result: str | None = Field(default=None, alias="text")
html_result: list[str] | None = Field(default=None, alias="html") html_result: str | None = Field(default=None, alias="html")
model_config = ConfigDict(populate_by_name=True, extra="forbid") model_config = ConfigDict(populate_by_name=True, extra="forbid")
@ -40,12 +40,14 @@ class TokenizerFromUpstream(BaseModel):
if self.chunks: if self.chunks:
return self return self
if self.output_format in {"markdown", "text"}: if self.output_format in {"markdown", "text", "html"}:
if self.output_format == "markdown" and not self.markdown_result: if self.output_format == "markdown" and not self.markdown_result:
raise ValueError("output_format=markdown requires a markdown payload (field: 'markdown' or 'markdown_result').") raise ValueError("output_format=markdown requires a markdown payload (field: 'markdown' or 'markdown_result').")
if self.output_format == "text" and not self.text_result: if self.output_format == "text" and not self.text_result:
raise ValueError("output_format=text requires a text payload (field: 'text' or 'text_result').") raise ValueError("output_format=text requires a text payload (field: 'text' or 'text_result').")
if self.output_format == "html" and not self.html_result:
raise ValueError("output_format=text requires a html payload (field: 'html' or 'html_result').")
else: else:
if not self.json_result: if not self.json_result and not self.chunks:
raise ValueError("When no chunks are provided and output_format is not markdown/text, a JSON list payload is required (field: 'json' or 'json_result').") raise ValueError("When no chunks are provided and output_format is not markdown/text, a JSON list payload is required (field: 'json' or 'json_result').")
return self return self

View File

@ -37,6 +37,7 @@ class TokenizerParam(ProcessParamBase):
super().__init__() super().__init__()
self.search_method = ["full_text", "embedding"] self.search_method = ["full_text", "embedding"]
self.filename_embd_weight = 0.1 self.filename_embd_weight = 0.1
self.fields = ["text"]
def check(self): def check(self):
for v in self.search_method: for v in self.search_method:
@ -61,10 +62,14 @@ class Tokenizer(ProcessBase):
embedding_model = LLMBundle(self._canvas._tenant_id, LLMType.EMBEDDING, llm_name=embedding_id) embedding_model = LLMBundle(self._canvas._tenant_id, LLMType.EMBEDDING, llm_name=embedding_id)
texts = [] texts = []
for c in chunks: for c in chunks:
if c.get("questions"): txt = ""
texts.append("\n".join(c["questions"])) for f in self._param.fields:
else: f = c.get(f)
texts.append(re.sub(r"</?(table|td|caption|tr|th)( [^<>]{0,12})?>", " ", c["text"])) if isinstance(f, str):
txt += f
elif isinstance(f, list):
txt += "\n".join(f)
texts.append(re.sub(r"</?(table|td|caption|tr|th)( [^<>]{0,12})?>", " ", txt))
vts, c = embedding_model.encode([name]) vts, c = embedding_model.encode([name])
token_count += c token_count += c
tts = np.concatenate([vts[0] for _ in range(len(texts))], axis=0) tts = np.concatenate([vts[0] for _ in range(len(texts))], axis=0)
@ -103,26 +108,36 @@ class Tokenizer(ProcessBase):
self.set_output("_ERROR", f"Input error: {str(e)}") self.set_output("_ERROR", f"Input error: {str(e)}")
return return
self.set_output("output_format", "chunks")
parts = sum(["full_text" in self._param.search_method, "embedding" in self._param.search_method]) parts = sum(["full_text" in self._param.search_method, "embedding" in self._param.search_method])
if "full_text" in self._param.search_method: if "full_text" in self._param.search_method:
self.callback(random.randint(1, 5) / 100.0, "Start to tokenize.") self.callback(random.randint(1, 5) / 100.0, "Start to tokenize.")
if from_upstream.chunks: if from_upstream.chunks:
chunks = from_upstream.chunks chunks = from_upstream.chunks
for i, ck in enumerate(chunks): for i, ck in enumerate(chunks):
ck["title_tks"] = rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", from_upstream.name))
ck["title_sm_tks"] = rag_tokenizer.fine_grained_tokenize(ck["title_tks"])
if ck.get("questions"): if ck.get("questions"):
ck["question_tks"] = rag_tokenizer.tokenize("\n".join(ck["questions"])) ck["question_kwd"] = ck["questions"].split("\n")
ck["question_tks"] = rag_tokenizer.tokenize(str(ck["questions"]))
if ck.get("keywords"): if ck.get("keywords"):
ck["important_tks"] = rag_tokenizer.tokenize("\n".join(ck["keywords"])) ck["important_kwd"] = ck["keywords"].split(",")
ck["content_ltks"] = rag_tokenizer.tokenize(ck["text"]) ck["important_tks"] = rag_tokenizer.tokenize(str(ck["keywords"]))
ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"]) if ck.get("summary"):
ck["content_ltks"] = rag_tokenizer.tokenize(str(ck["summary"]))
ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"])
else:
ck["content_ltks"] = rag_tokenizer.tokenize(ck["text"])
ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"])
if i % 100 == 99: if i % 100 == 99:
self.callback(i * 1.0 / len(chunks) / parts) self.callback(i * 1.0 / len(chunks) / parts)
elif from_upstream.output_format in ["markdown", "text", "html"]: elif from_upstream.output_format in ["markdown", "text", "html"]:
if from_upstream.output_format == "markdown": if from_upstream.output_format == "markdown":
payload = from_upstream.markdown_result payload = from_upstream.markdown_result
elif from_upstream.output_format == "text": elif from_upstream.output_format == "text":
payload = from_upstream.text_result payload = from_upstream.text_result
else: # == "html" else:
payload = from_upstream.html_result payload = from_upstream.html_result
if not payload: if not payload:
@ -130,12 +145,16 @@ class Tokenizer(ProcessBase):
ck = {"text": payload} ck = {"text": payload}
if "full_text" in self._param.search_method: if "full_text" in self._param.search_method:
ck["content_ltks"] = rag_tokenizer.tokenize(kwargs.get(kwargs["output_format"], "")) ck["title_tks"] = rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", from_upstream.name))
ck["title_sm_tks"] = rag_tokenizer.fine_grained_tokenize(ck["title_tks"])
ck["content_ltks"] = rag_tokenizer.tokenize(payload)
ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"]) ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"])
chunks = [ck] chunks = [ck]
else: else:
chunks = from_upstream.json_result chunks = from_upstream.json_result
for i, ck in enumerate(chunks): for i, ck in enumerate(chunks):
ck["title_tks"] = rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", from_upstream.name))
ck["title_sm_tks"] = rag_tokenizer.fine_grained_tokenize(ck["title_tks"])
ck["content_ltks"] = rag_tokenizer.tokenize(ck["text"]) ck["content_ltks"] = rag_tokenizer.tokenize(ck["text"])
ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"]) ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"])
if i % 100 == 99: if i % 100 == 99:

View File

@ -33,7 +33,7 @@ from zhipuai import ZhipuAI
from api import settings from api import settings
from api.utils.file_utils import get_home_cache_dir from api.utils.file_utils import get_home_cache_dir
from api.utils.log_utils import log_exception from api.utils.log_utils import log_exception
from rag.utils import num_tokens_from_string, truncate, total_token_count_from_response from rag.utils import num_tokens_from_string, truncate
class Base(ABC): class Base(ABC):
@ -52,7 +52,15 @@ class Base(ABC):
raise NotImplementedError("Please implement encode method!") raise NotImplementedError("Please implement encode method!")
def total_token_count(self, resp): def total_token_count(self, resp):
return total_token_count_from_response(resp) try:
return resp.usage.total_tokens
except Exception:
pass
try:
return resp["usage"]["total_tokens"]
except Exception:
pass
return 0
class DefaultEmbedding(Base): class DefaultEmbedding(Base):
@ -138,7 +146,7 @@ class OpenAIEmbed(Base):
ress = [] ress = []
total_tokens = 0 total_tokens = 0
for i in range(0, len(texts), batch_size): for i in range(0, len(texts), batch_size):
res = self.client.embeddings.create(input=texts[i : i + batch_size], model=self.model_name, encoding_format="float") res = self.client.embeddings.create(input=texts[i : i + batch_size], model=self.model_name, encoding_format="float", extra_body={"drop_params": True})
try:
ress.extend([d.embedding for d in res.data])
total_tokens += self.total_token_count(res)
@ -147,7 +155,7 @@ class OpenAIEmbed(Base):
return np.array(ress), total_tokens
def encode_queries(self, text):
res = self.client.embeddings.create(input=[truncate(text, 8191)], model=self.model_name, encoding_format="float") res = self.client.embeddings.create(input=[truncate(text, 8191)], model=self.model_name, encoding_format="float",extra_body={"drop_params": True})
return np.array(res.data[0].embedding), self.total_token_count(res)
@ -489,7 +497,6 @@ class MistralEmbed(Base):
def encode_queries(self, text):
import time
import random
retry_max = 5
while retry_max > 0:
try:
@ -748,7 +755,7 @@ class SILICONFLOWEmbed(Base):
texts_batch = texts[i : i + batch_size]
if self.model_name in ["BAAI/bge-large-zh-v1.5", "BAAI/bge-large-en-v1.5"]:
# limit 512, 340 is almost safe
texts_batch = [" " if not text.strip() else truncate(text, 340) for text in texts_batch] texts_batch = [" " if not text.strip() else truncate(text, 256) for text in texts_batch]
else:
texts_batch = [" " if not text.strip() else text for text in texts_batch]
@ -937,7 +944,6 @@ class GiteeEmbed(SILICONFLOWEmbed):
base_url = "https://ai.gitee.com/v1/embeddings" base_url = "https://ai.gitee.com/v1/embeddings"
super().__init__(key, model_name, base_url) super().__init__(key, model_name, base_url)
class DeepInfraEmbed(OpenAIEmbed): class DeepInfraEmbed(OpenAIEmbed):
_FACTORY_NAME = "DeepInfra" _FACTORY_NAME = "DeepInfra"


@ -292,6 +292,7 @@ def tokenize_chunks(chunks, doc, eng, pdf_parser=None):
res.append(d)
return res
def tokenize_chunks_with_images(chunks, doc, eng, images):
res = []
# wrap up as es documents
@ -306,6 +307,7 @@ def tokenize_chunks_with_images(chunks, doc, eng, images):
res.append(d)
return res
def tokenize_table(tbls, doc, eng, batch_size=10):
res = []
# add tables
@ -579,7 +581,9 @@ def naive_merge(sections: str | list, chunk_token_num=128, delimiter="\n。
from deepdoc.parser.pdf_parser import RAGFlowPdfParser
if not sections:
return []
if isinstance(sections[0], type("")): if isinstance(sections, str):
sections = [sections]
if isinstance(sections[0], str):
sections = [(s, "") for s in sections] sections = [(s, "") for s in sections]
cks = [""] cks = [""]
tk_nums = [0] tk_nums = [0]
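A standalone sketch (plain strings only, no tokenizer involved) of the input normalization this hunk adds to `naive_merge`: a bare string is first wrapped into a list, and a list of plain strings is then paired with empty position tags before merging begins.

```python
def normalize_sections(sections):
    # Same normalization order as the hunk above, extracted for illustration.
    if not sections:
        return []
    if isinstance(sections, str):
        sections = [sections]                   # bare string -> one-element list
    if isinstance(sections[0], str):
        sections = [(s, "") for s in sections]  # pair each string with an empty position tag
    return sections

print(normalize_sections("a single block of text"))   # [('a single block of text', '')]
print(normalize_sections(["part one", "part two"]))   # [('part one', ''), ('part two', '')]
```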


@ -383,7 +383,7 @@ class Dealer:
vector_column = f"q_{dim}_vec" vector_column = f"q_{dim}_vec"
zero_vector = [0.0] * dim zero_vector = [0.0] * dim
sim_np = np.array(sim) sim_np = np.array(sim)
filtered_count = (sim_np >= similarity_threshold).sum() filtered_count = (sim_np >= similarity_threshold).sum()
ranks["total"] = int(filtered_count) # Convert from np.int64 to Python int otherwise JSON serializable error ranks["total"] = int(filtered_count) # Convert from np.int64 to Python int otherwise JSON serializable error
for i in idx: for i in idx:
if sim[i] < similarity_threshold: if sim[i] < similarity_threshold:
@ -444,12 +444,27 @@ class Dealer:
def chunk_list(self, doc_id: str, tenant_id: str,
kb_ids: list[str], max_count=1024,
offset=0,
fields=["docnm_kwd", "content_with_weight", "img_id"]): fields=["docnm_kwd", "content_with_weight", "img_id"],
sort_by_position: bool = False):
condition = {"doc_id": doc_id} condition = {"doc_id": doc_id}
fields_set = set(fields or [])
if sort_by_position:
for need in ("page_num_int", "position_int", "top_int"):
if need not in fields_set:
fields_set.add(need)
fields = list(fields_set)
orderBy = OrderByExpr()
if sort_by_position:
orderBy.asc("page_num_int")
orderBy.asc("position_int")
orderBy.asc("top_int")
res = []
bs = 128
for p in range(offset, max_count, bs):
es_res = self.dataStore.search(fields, [], condition, [], OrderByExpr(), p, bs, index_name(tenant_id), es_res = self.dataStore.search(fields, [], condition, [], orderBy, p, bs, index_name(tenant_id),
kb_ids)
dict_chunks = self.dataStore.getFields(es_res, fields)
for id, doc in dict_chunks.items():
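A hedged usage sketch of the new `sort_by_position` flag (the IDs below are placeholders, and `settings.retrievaler` is assumed to be initialized as elsewhere in the codebase): the position fields are appended to the requested field list automatically, and results come back ordered by `page_num_int`, `position_int`, `top_int`.

```python
from api import settings

# Placeholder IDs, for illustration only.
for ck in settings.retrievaler.chunk_list(
    doc_id="doc-123",
    tenant_id="tenant-abc",
    kb_ids=["kb-xyz"],
    fields=["docnm_kwd", "content_with_weight"],  # page_num_int / position_int / top_int are added automatically
    sort_by_position=True,
):
    print(ck.get("page_num_int"), str(ck.get("content_with_weight", ""))[:40])
```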


@ -436,4 +436,217 @@ def gen_meta_filter(chat_mdl, meta_data:dict, query: str) -> list:
return ans
except Exception:
logging.exception(f"Loading json failure: {ans}")
return []
def gen_json(system_prompt:str, user_prompt:str, chat_mdl):
_, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length)
ans = chat_mdl.chat(msg[0]["content"], msg[1:])
ans = re.sub(r"(^.*</think>|```json\n|```\n*$)", "", ans, flags=re.DOTALL)
try:
return json_repair.loads(ans)
except Exception:
logging.exception(f"Loading json failure: {ans}")
TOC_DETECTION = load_prompt("toc_detection")
def detect_table_of_contents(page_1024:list[str], chat_mdl):
toc_secs = []
for i, sec in enumerate(page_1024[:22]):
ans = gen_json(PROMPT_JINJA_ENV.from_string(TOC_DETECTION).render(page_txt=sec), "Only JSON please.", chat_mdl)
if toc_secs and not ans["exists"]:
break
toc_secs.append(sec)
return toc_secs
TOC_EXTRACTION = load_prompt("toc_extraction")
TOC_EXTRACTION_CONTINUE = load_prompt("toc_extraction_continue")
def extract_table_of_contents(toc_pages, chat_mdl):
if not toc_pages:
return []
return gen_json(PROMPT_JINJA_ENV.from_string(TOC_EXTRACTION).render(toc_page="\n".join(toc_pages)), "Only JSON please.", chat_mdl)
def toc_index_extractor(toc:list[dict], content:str, chat_mdl):
tob_extractor_prompt = """
You are given a table of contents in a json format and several pages of a document. Your job is to add the physical_index to the table of contents in the json format.
The provided pages contain tags like <physical_index_X> and </physical_index_X> to indicate the physical location of page X.
The structure variable is the numeric system which represents the index of the hierarchy section in the table of contents. For example, the first section has structure index 1, the first subsection has structure index 1.1, the second subsection has structure index 1.2, etc.
The response should be in the following JSON format:
[
{
"structure": <structure index, "x.x.x" or None> (string),
"title": <title of the section>,
"physical_index": "<physical_index_X>" (keep the format)
},
...
]
Only add the physical_index to the sections that are in the provided pages.
If the title of a section is not in the provided pages, do not add the physical_index to it.
Directly return the final JSON structure. Do not output anything else."""
prompt = tob_extractor_prompt + '\nTable of contents:\n' + json.dumps(toc, ensure_ascii=False, indent=2) + '\nDocument pages:\n' + content
return gen_json(prompt, "Only JSON please.", chat_mdl)
TOC_INDEX = load_prompt("toc_index")
def table_of_contents_index(toc_arr: list[dict], sections: list[str], chat_mdl):
if not toc_arr or not sections:
return []
toc_map = {}
for i, it in enumerate(toc_arr):
k1 = (it["structure"]+it["title"]).replace(" ", "")
k2 = it["title"].strip()
if k1 not in toc_map:
toc_map[k1] = []
if k2 not in toc_map:
toc_map[k2] = []
toc_map[k1].append(i)
toc_map[k2].append(i)
for it in toc_arr:
it["indices"] = []
for i, sec in enumerate(sections):
sec = sec.strip()
if sec.replace(" ", "") in toc_map:
for j in toc_map[sec.replace(" ", "")]:
toc_arr[j]["indices"].append(i)
all_pathes = []
def dfs(start, path):
nonlocal all_pathes
if start >= len(toc_arr):
if path:
all_pathes.append(path)
return
if not toc_arr[start]["indices"]:
dfs(start+1, path)
return
added = False
for j in toc_arr[start]["indices"]:
if path and j < path[-1][0]:
continue
_path = deepcopy(path)
_path.append((j, start))
added = True
dfs(start+1, _path)
if not added and path:
all_pathes.append(path)
dfs(0, [])
path = max(all_pathes, key=lambda x:len(x))
for it in toc_arr:
it["indices"] = []
for j, i in path:
toc_arr[i]["indices"] = [j]
print(json.dumps(toc_arr, ensure_ascii=False, indent=2))
i = 0
while i < len(toc_arr):
it = toc_arr[i]
if it["indices"]:
i += 1
continue
if i>0 and toc_arr[i-1]["indices"]:
st_i = toc_arr[i-1]["indices"][-1]
else:
st_i = 0
e = i + 1
while e <len(toc_arr) and not toc_arr[e]["indices"]:
e += 1
if e >= len(toc_arr):
e = len(sections)
else:
e = toc_arr[e]["indices"][0]
for j in range(st_i, min(e+1, len(sections))):
ans = gen_json(PROMPT_JINJA_ENV.from_string(TOC_INDEX).render(
structure=it["structure"],
title=it["title"],
text=sections[j]), "Only JSON please.", chat_mdl)
if ans["exist"] == "yes":
it["indices"].append(j)
break
i += 1
return toc_arr
def check_if_toc_transformation_is_complete(content, toc, chat_mdl):
prompt = """
You are given a raw table of contents and a cleaned table of contents.
Your job is to check if the table of contents is complete.
Reply format:
{{
"thinking": <why do you think the cleaned table of contents is complete or not>
"completed": "yes" or "no"
}}
Directly return the final JSON structure. Do not output anything else."""
prompt = prompt + '\n Raw Table of contents:\n' + content + '\n Cleaned Table of contents:\n' + toc
response = gen_json(prompt, "Only JSON please.", chat_mdl)
return response['completed']
def toc_transformer(toc_pages, chat_mdl):
init_prompt = """
You are given a table of contents. Your job is to transform the whole table of contents into the JSON format described below.
The `structure` is the numeric system which represents the index of the hierarchy section in the table of contents. For example, the first section has structure index 1, the first subsection has structure index 1.1, the second subsection has structure index 1.2, etc.
The `title` is a short phrase or a term of a few words.
The response should be in the following JSON format:
[
{
"structure": <structure index, "x.x.x" or None> (string),
"title": <title of the section>
},
...
],
You should transform the full table of contents in one go.
Directly return the final JSON structure, do not output anything else. """
toc_content = "\n".join(toc_pages)
prompt = init_prompt + '\n Given table of contents\n:' + toc_content
def clean_toc(arr):
for a in arr:
a["title"] = re.sub(r"[.·….]{2,}", "", a["title"])
last_complete = gen_json(prompt, "Only JSON please.", chat_mdl)
if_complete = check_if_toc_transformation_is_complete(toc_content, json.dumps(last_complete, ensure_ascii=False, indent=2), chat_mdl)
clean_toc(last_complete)
if if_complete == "yes":
return last_complete
while not (if_complete == "yes"):
prompt = f"""
Your task is to continue the table of contents json structure, directly output the remaining part of the json structure.
The response should be in the following JSON format:
The raw table of contents json structure is:
{toc_content}
The incomplete transformed table of contents json structure is:
{json.dumps(last_complete[-24:], ensure_ascii=False, indent=2)}
Please continue the json structure, directly output the remaining part of the json structure."""
new_complete = gen_json(prompt, "Only JSON please.", chat_mdl)
if not new_complete or str(last_complete).find(str(new_complete)) >= 0:
break
clean_toc(new_complete)
last_complete.extend(new_complete)
if_complete = check_if_toc_transformation_is_complete(toc_content, json.dumps(last_complete, ensure_ascii=False, indent=2), chat_mdl)
return last_complete
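Taken together, a hedged sketch of how these helpers can be chained for one document (assuming a ready `chat_mdl` plus the page texts and merged sections produced earlier in the pipeline; this wiring is illustrative, not the exact call sequence used by the callers):

```python
def build_toc(pages: list[str], sections: list[str], chat_mdl):
    # 1) find the pages that actually hold a table of contents
    toc_pages = detect_table_of_contents(pages, chat_mdl)
    if not toc_pages:
        return []
    # 2) normalize the raw TOC text into [{"structure": ..., "title": ...}, ...]
    toc = toc_transformer(toc_pages, chat_mdl)
    # 3) attach the index of the matching section to every TOC entry
    return table_of_contents_index(toc, sections, chat_mdl)
```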


@ -0,0 +1,29 @@
You are an AI assistant designed to analyze text content and detect whether a table of contents (TOC) list exists on the given page. Follow these steps:
1. **Analyze the Input**: Carefully review the provided text content.
2. **Identify Key Features**: Look for common indicators of a TOC, such as:
- Section titles or headings paired with page numbers.
- Patterns like repeated formatting (e.g., bold/italicized text, dots/dashes between titles and numbers).
- Phrases like "Table of Contents," "Contents," or similar headings.
- Logical grouping of topics/subtopics with sequential page references.
3. **Discern Negative Features**:
- The text contains no numbers, or the numbers present are clearly not page references (e.g., dates, statistical figures, phone numbers, version numbers).
- The text consists of full, descriptive sentences and paragraphs that form a narrative, present arguments, or explain concepts, rather than succinctly listing topics.
- Contains citations with authors, publication years, journal titles, and page ranges (e.g., "Smith, J. (2020). Journal Title, 10(2), 45-67.").
- Lists keywords or terms followed by multiple page numbers, often in alphabetical order.
- Comprises terms followed by their definitions or explanations.
- Labeled with headers like "Appendix A," "Appendix B," etc.
- Contains expressive language thanking individuals or organizations for their support or contributions.
4. **Evaluate Evidence**: Weigh the presence/absence of these features to determine if the content resembles a TOC.
5. **Output Format**: Provide your response in the following JSON structure:
```json
{
"reasoning": "Step-by-step explanation of your analysis based on the features identified." ,
"exists": true/false
}
```
6. **DO NOT** output anything else except JSON structure.
**Input text Content ( Text-Only Extraction ):**
{{ page_txt }}


@ -0,0 +1,53 @@
You are an expert parser and data formatter. Your task is to analyze the provided table of contents (TOC) text and convert it into a valid JSON array of objects.
**Instructions:**
1. Analyze each line of the input TOC.
2. For each line, extract the following two pieces of information:
* `structure`: The hierarchical index/numbering (e.g., "1", "2.1", "3.2.5", "A.1"). If a line has no visible numbering or structure indicator (like a main "Chapter" title), use `null`.
* `title`: The textual title of the section or chapter. This should be the main descriptive text, clean and without the page number.
3. Output **only** a valid JSON array. Do not include any other text, explanations, or markdown code block fences (like ```json) in your response.
**JSON Format:**
The output must be a list of objects following this exact schema:
```json
[
{
"structure": <structure index, "x.x.x" or None> (string,
"title": <title of the section>
},
...
]
```
**Input Example:**
```
Contents
1 Introduction to the System ... 1
1.1 Overview .... 2
1.2 Key Features .... 5
2 Installation Guide ....8
2.1 Prerequisites ........ 9
2.2 Step-by-Step Process ........ 12
Appendix A: Specifications ..... 45
References ... 47
```
**Expected Output For The Example:**
```json
[
{"structure": null, "title": "Contents"},
{"structure": "1", "title": "Introduction to the System"},
{"structure": "1.1", "title": "Overview"},
{"structure": "1.2", "title": "Key Features"},
{"structure": "2", "title": "Installation Guide"},
{"structure": "2.1", "title": "Prerequisites"},
{"structure": "2.2", "title": "Step-by-Step Process"},
{"structure": "A", "title": "Specifications"},
{"structure": null, "title": "References"}
]
```
**Now, process the following TOC input:**
```
{{ toc_page }}
```


@ -0,0 +1,60 @@
You are an expert parser and data formatter, currently in the process of building a JSON array from a multi-page table of contents (TOC). Your task is to analyze the new page of content and **append** the new entries to the existing JSON array.
**Instructions:**
1. You will be given two inputs:
* `current_page_text`: The text content from the new page of the TOC.
* `existing_json`: The valid JSON array you have generated from the previous pages.
2. Analyze each line of the `current_page_text` input.
3. For each new line, extract the following three pieces of information:
* `structure`: The hierarchical index/numbering (e.g., "1", "2.1", "3.2.5"). Use `null` if none exists.
* `title`: The clean textual title of the section or chapter.
* `page`: The page number on which the section starts. Extract only the number. Use `null` if not present.
4. **Append these new entries** to the `existing_json` array. Do not modify, reorder, or delete any of the existing entries.
5. Output **only** the complete, updated JSON array. Do not include any other text, explanations, or markdown code block fences (like ```json).
**JSON Format:**
The output must be a valid JSON array following this schema:
```json
[
{
"structure": <string or null>,
"title": <string>,
"page": <number or null>
},
...
]
```
**Input Example:**
`current_page_text`:
```
3.2 Advanced Configuration ........... 25
3.3 Troubleshooting .................. 28
4 User Management .................... 30
```
`existing_json`:
```json
[
{"structure": "1", "title": "Introduction", "page": 1},
{"structure": "2", "title": "Installation", "page": 5},
{"structure": "3", "title": "Configuration", "page": 12},
{"structure": "3.1", "title": "Basic Setup", "page": 15}
]
```
**Expected Output For The Example:**
```json
[
{"structure": "3.2", "title": "Advanced Configuration", "page": 25},
{"structure": "3.3", "title": "Troubleshooting", "page": 28},
{"structure": "4", "title": "User Management", "page": 30}
]
```
**Now, process the following inputs:**
`current_page_text`:
{{ toc_page }}
`existing_json`:
{{ toc_json }}

rag/prompts/toc_index.md Normal file

@ -0,0 +1,20 @@
You are an expert analyst tasked with matching text content to the title.
**Instructions:**
1. Analyze the given title with its numeric structure index and the provided text.
2. Determine whether the title is mentioned as a section title in the given text.
3. Provide a concise, step-by-step reasoning for your decision.
4. Output **only** the complete JSON object. Do not include any other text, explanations, or markdown code block fences (like ```json).
**Output Format:**
Your output must be a valid JSON object with the following keys:
{
"reasoning": "Step-by-step explanation of your analysis.",
"exist": "<yes or no>",
}
** The title: **
{{ structure }} {{ title }}
** Given text: **
{{ text }}


@ -21,14 +21,18 @@ import sys
import threading
import time
from api.utils import get_uuid import json_repair
from api.db.services.canvas_service import UserCanvasService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
from api.utils.api_utils import timeout
from api.utils.base64_image import image2id
from api.utils.log_utils import init_root_logger, get_project_base_directory
from graphrag.general.index import run_graphrag from graphrag.general.index import run_graphrag_for_kb
from graphrag.utils import get_llm_cache, set_llm_cache, get_tags_from_cache, set_tags_to_cache
from rag.flow.pipeline import Pipeline
from rag.prompts.generator import keyword_extraction, question_proposal, content_tagging
import logging
import os
from datetime import datetime
@ -37,7 +41,6 @@ import xxhash
import copy
import re
from functools import partial
from io import BytesIO
from multiprocessing.context import TimeoutError
from timeit import default_timer as timer
import tracemalloc
@ -45,21 +48,19 @@ import signal
import trio
import exceptiongroup
import faulthandler
import numpy as np
from peewee import DoesNotExist
from api.db import LLMType, ParserType, PipelineTaskType
from api.db import LLMType, ParserType
from api.db.services.document_service import DocumentService
from api.db.services.llm_service import LLMBundle
from api.db.services.task_service import TaskService, has_canceled from api.db.services.task_service import TaskService, has_canceled, CANVAS_DEBUG_DOC_ID, GRAPH_RAPTOR_FAKE_DOC_ID
from api.db.services.file2document_service import File2DocumentService
from api import settings
from api.versions import get_ragflow_version
from api.db.db_models import close_connection
from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, \
email, tag
from rag.nlp import search, rag_tokenizer from rag.nlp import search, rag_tokenizer, add_positions
from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
from rag.settings import DOC_MAXIMUM_SIZE, DOC_BULK_SIZE, EMBEDDING_BATCH_SIZE, SVR_CONSUMER_GROUP_NAME, get_svr_queue_name, get_svr_queue_names, print_rag_settings, TAG_FLD, PAGERANK_FLD
from rag.utils import num_tokens_from_string, truncate
@ -88,6 +89,13 @@ FACTORY = {
ParserType.TAG.value: tag
}
TASK_TYPE_TO_PIPELINE_TASK_TYPE = {
"dataflow" : PipelineTaskType.PARSE,
"raptor": PipelineTaskType.RAPTOR,
"graphrag": PipelineTaskType.GRAPH_RAG,
"mindmap": PipelineTaskType.MINDMAP,
}
UNACKED_ITERATOR = None
CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]
@ -143,6 +151,7 @@ def start_tracemalloc_and_snapshot(signum, frame):
max_rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
logging.info(f"taken snapshot {snapshot_file}. max RSS={max_rss / 1000:.2f} MB, current memory usage: {current / 10**6:.2f} MB, Peak memory usage: {peak / 10**6:.2f} MB")
# SIGUSR2 handler: stop tracemalloc
def stop_tracemalloc(signum, frame):
if tracemalloc.is_tracing():
@ -151,6 +160,7 @@ def stop_tracemalloc(signum, frame):
else:
logging.info("tracemalloc not running")
class TaskCanceledException(Exception):
def __init__(self, msg):
self.msg = msg
@ -216,7 +226,14 @@ async def collect():
return None, None
canceled = False
task = TaskService.get_task(msg["id"]) if msg.get("doc_id", "") in [GRAPH_RAPTOR_FAKE_DOC_ID, CANVAS_DEBUG_DOC_ID]:
task = msg
if task["task_type"] in ["graphrag", "raptor", "mindmap"] and msg.get("doc_ids", []):
task = TaskService.get_task(msg["id"], msg["doc_ids"])
task["doc_ids"] = msg["doc_ids"]
else:
task = TaskService.get_task(msg["id"])
if task:
canceled = has_canceled(task["id"])
if not task or canceled:
@ -228,10 +245,9 @@ async def collect():
task_type = msg.get("task_type", "") task_type = msg.get("task_type", "")
task["task_type"] = task_type task["task_type"] = task_type
if task_type == "dataflow": if task_type[:8] == "dataflow":
task["tenant_id"]=msg.get("tenant_id", "") task["tenant_id"] = msg["tenant_id"]
task["dsl"] = msg.get("dsl", "") task["dataflow_id"] = msg["dataflow_id"]
task["dataflow_id"] = msg.get("dataflow_id", get_uuid())
task["kb_id"] = msg.get("kb_id", "") task["kb_id"] = msg.get("kb_id", "")
return redis_msg, task return redis_msg, task
@ -301,30 +317,8 @@ async def build_chunks(task, progress_callback):
d["img_id"] = "" d["img_id"] = ""
docs.append(d) docs.append(d)
return return
await image2id(d, partial(STORAGE_IMPL.put), d["id"], task["kb_id"])
with BytesIO() as output_buffer: docs.append(d)
if isinstance(d["image"], bytes):
output_buffer.write(d["image"])
output_buffer.seek(0)
else:
# If the image is in RGBA mode, convert it to RGB mode before saving it in JPEG format.
if d["image"].mode in ("RGBA", "P"):
converted_image = d["image"].convert("RGB")
#d["image"].close() # Close original image
d["image"] = converted_image
try:
d["image"].save(output_buffer, format='JPEG')
except OSError as e:
logging.warning(
"Saving image of chunk {}/{}/{} got exception, ignore: {}".format(task["location"], task["name"], d["id"], str(e)))
async with minio_limiter:
await trio.to_thread.run_sync(lambda: STORAGE_IMPL.put(task["kb_id"], d["id"], output_buffer.getvalue()))
d["img_id"] = "{}-{}".format(task["kb_id"], d["id"])
if not isinstance(d["image"], bytes):
d["image"].close()
del d["image"] # Remove image reference
docs.append(d)
except Exception:
logging.exception(
"Saving image of chunk {}/{}/{} got exception".format(task["location"], task["name"], d["id"]))
@ -482,35 +476,192 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
return tk_count, vector_size
async def run_dataflow(dsl:str, tenant_id:str, doc_id:str, task_id:str, flow_id:str, callback=None): async def run_dataflow(task: dict):
_ = callback task_start_ts = timer()
dataflow_id = task["dataflow_id"]
doc_id = task["doc_id"]
task_id = task["id"]
task_dataset_id = task["kb_id"]
pipeline = Pipeline(dsl=dsl, tenant_id=tenant_id, doc_id=doc_id, task_id=task_id, flow_id=flow_id) if task["task_type"] == "dataflow":
pipeline.reset() e, cvs = UserCanvasService.get_by_id(dataflow_id)
assert e, "User pipeline not found."
dsl = cvs.dsl
else:
e, pipeline_log = PipelineOperationLogService.get_by_id(dataflow_id)
assert e, "Pipeline log not found."
dsl = pipeline_log.dsl
dataflow_id = pipeline_log.pipeline_id
pipeline = Pipeline(dsl, tenant_id=task["tenant_id"], doc_id=doc_id, task_id=task_id, flow_id=dataflow_id)
chunks = await pipeline.run(file=task["file"]) if task.get("file") else await pipeline.run()
if doc_id == CANVAS_DEBUG_DOC_ID:
return
await pipeline.run() if not chunks:
PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE, dsl=str(pipeline))
return
embedding_token_consumption = chunks.get("embedding_token_consumption", 0)
if chunks.get("chunks"):
chunks = copy.deepcopy(chunks["chunks"])
elif chunks.get("json"):
chunks = copy.deepcopy(chunks["json"])
elif chunks.get("markdown"):
chunks = [{"text": [chunks["markdown"]]}]
elif chunks.get("text"):
chunks = [{"text": [chunks["text"]]}]
elif chunks.get("html"):
chunks = [{"text": [chunks["html"]]}]
keys = [k for o in chunks for k in list(o.keys())]
if not any([re.match(r"q_[0-9]+_vec", k) for k in keys]):
try:
set_progress(task_id, prog=0.82, msg="\n-------------------------------------\nStart to embedding...")
e, kb = KnowledgebaseService.get_by_id(task["kb_id"])
embedding_id = kb.embd_id
embedding_model = LLMBundle(task["tenant_id"], LLMType.EMBEDDING, llm_name=embedding_id)
@timeout(60)
def batch_encode(txts):
nonlocal embedding_model
return embedding_model.encode([truncate(c, embedding_model.max_length - 10) for c in txts])
vects = np.array([])
texts = [o.get("questions", o.get("summary", o["text"])) for o in chunks]
delta = 0.20/(len(texts)//EMBEDDING_BATCH_SIZE+1)
prog = 0.8
for i in range(0, len(texts), EMBEDDING_BATCH_SIZE):
async with embed_limiter:
vts, c = await trio.to_thread.run_sync(lambda: batch_encode(texts[i : i + EMBEDDING_BATCH_SIZE]))
if len(vects) == 0:
vects = vts
else:
vects = np.concatenate((vects, vts), axis=0)
embedding_token_consumption += c
prog += delta
if i % (len(texts)//EMBEDDING_BATCH_SIZE/100+1) == 1:
set_progress(task_id, prog=prog, msg=f"{i+1} / {len(texts)//EMBEDDING_BATCH_SIZE}")
assert len(vects) == len(chunks)
for i, ck in enumerate(chunks):
v = vects[i].tolist()
ck["q_%d_vec" % len(v)] = v
except Exception as e:
set_progress(task_id, prog=-1, msg=f"[ERROR]: {e}")
PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE, dsl=str(pipeline))
return
metadata = {}
def dict_update(meta):
nonlocal metadata
if not meta:
return
if isinstance(meta, str):
try:
meta = json_repair.loads(meta)
except Exception:
logging.error("Meta data format error.")
return
if not isinstance(meta, dict):
return
for k, v in meta.items():
if isinstance(v, list):
v = [vv for vv in v if isinstance(vv, str)]
if not v:
continue
if not isinstance(v, list) and not isinstance(v, str):
continue
if k not in metadata:
metadata[k] = v
continue
if isinstance(metadata[k], list):
if isinstance(v, list):
metadata[k].extend(v)
else:
metadata[k].append(v)
else:
metadata[k] = v
for ck in chunks:
ck["doc_id"] = doc_id
ck["kb_id"] = [str(task["kb_id"])]
ck["docnm_kwd"] = task["name"]
ck["create_time"] = str(datetime.now()).replace("T", " ")[:19]
ck["create_timestamp_flt"] = datetime.now().timestamp()
ck["id"] = xxhash.xxh64((ck["text"] + str(ck["doc_id"])).encode("utf-8")).hexdigest()
if "questions" in ck:
if "question_tks" not in ck:
ck["question_kwd"] = ck["questions"].split("\n")
ck["question_tks"] = rag_tokenizer.tokenize(str(ck["questions"]))
del ck["questions"]
if "keywords" in ck:
if "important_tks" not in ck:
ck["important_kwd"] = ck["keywords"].split(",")
ck["important_tks"] = rag_tokenizer.tokenize(str(ck["keywords"]))
del ck["keywords"]
if "summary" in ck:
if "content_ltks" not in ck:
ck["content_ltks"] = rag_tokenizer.tokenize(str(ck["summary"]))
ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"])
del ck["summary"]
if "metadata" in ck:
dict_update(ck["metadata"])
del ck["metadata"]
if "content_with_weight" not in ck:
ck["content_with_weight"] = ck["text"]
del ck["text"]
if "positions" in ck:
add_positions(ck, ck["positions"])
del ck["positions"]
if metadata:
e, doc = DocumentService.get_by_id(doc_id)
if e:
if isinstance(doc.meta_fields, str):
doc.meta_fields = json.loads(doc.meta_fields)
dict_update(doc.meta_fields)
DocumentService.update_by_id(doc_id, {"meta_fields": metadata})
start_ts = timer()
set_progress(task_id, prog=0.82, msg="[DOC Engine]:\nStart to index...")
e = await insert_es(task_id, task["tenant_id"], task["kb_id"], chunks, partial(set_progress, task_id, 0, 100000000))
if not e:
PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE, dsl=str(pipeline))
return
time_cost = timer() - start_ts
task_time_cost = timer() - task_start_ts
set_progress(task_id, prog=1., msg="Indexing done ({:.2f}s). Task done ({:.2f}s)".format(time_cost, task_time_cost))
DocumentService.increment_chunk_num(doc_id, task_dataset_id, embedding_token_consumption, len(chunks), task_time_cost)
logging.info("[Done], chunks({}), token({}), elapsed:{:.2f}".format(len(chunks), embedding_token_consumption, task_time_cost))
PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE, dsl=str(pipeline))
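A standalone sketch (plain dicts, no database services) of the metadata merge performed by the nested `dict_update` above: JSON strings are repaired into dicts, only string values or lists of strings are kept, and repeated keys are merged into lists.

```python
import json_repair

def merge_meta(metadata: dict, meta) -> dict:
    # Illustrative re-statement of dict_update's merge rules.
    if isinstance(meta, str):
        meta = json_repair.loads(meta)
    if not isinstance(meta, dict):
        return metadata
    for k, v in meta.items():
        if isinstance(v, list):
            v = [vv for vv in v if isinstance(vv, str)]
            if not v:
                continue
        elif not isinstance(v, str):
            continue
        if k not in metadata:
            metadata[k] = v
        elif isinstance(metadata[k], list):
            if isinstance(v, list):
                metadata[k].extend(v)
            else:
                metadata[k].append(v)
        else:
            metadata[k] = v
    return metadata

acc = {}
merge_meta(acc, {"authors": ["Alice"]})
merge_meta(acc, '{"authors": ["Bob"], "year": "2024"}')
print(acc)  # {'authors': ['Alice', 'Bob'], 'year': '2024'}
```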
@timeout(3600)
async def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None): async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_size, callback=None, doc_ids=[]):
fake_doc_id = GRAPH_RAPTOR_FAKE_DOC_ID
raptor_config = kb_parser_config.get("raptor", {})
chunks = []
vctr_nm = "q_%d_vec"%vector_size
for d in settings.retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])], for doc_id in doc_ids:
fields=["content_with_weight", vctr_nm]): for d in settings.retrievaler.chunk_list(doc_id, row["tenant_id"], [str(row["kb_id"])],
chunks.append((d["content_with_weight"], np.array(d[vctr_nm]))) fields=["content_with_weight", vctr_nm],
sort_by_position=True):
chunks.append((d["content_with_weight"], np.array(d[vctr_nm])))
raptor = Raptor(
row["parser_config"]["raptor"].get("max_cluster", 64), raptor_config.get("max_cluster", 64),
chat_mdl,
embd_mdl,
row["parser_config"]["raptor"]["prompt"], raptor_config["prompt"],
row["parser_config"]["raptor"]["max_token"], raptor_config["max_token"],
row["parser_config"]["raptor"]["threshold"] raptor_config["threshold"],
)
original_length = len(chunks)
chunks = await raptor(chunks, row["parser_config"]["raptor"]["random_seed"], callback) chunks = await raptor(chunks, row["kb_parser_config"]["raptor"]["random_seed"], callback)
doc = {
"doc_id": row["doc_id"], "doc_id": fake_doc_id,
"kb_id": [str(row["kb_id"])], "kb_id": [str(row["kb_id"])],
"docnm_kwd": row["name"], "docnm_kwd": row["name"],
"title_tks": rag_tokenizer.tokenize(row["name"]) "title_tks": rag_tokenizer.tokenize(row["name"])
@ -521,7 +672,7 @@ async def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None):
tk_count = 0
for content, vctr in chunks[original_length:]:
d = copy.deepcopy(doc)
d["id"] = xxhash.xxh64((content + str(d["doc_id"])).encode("utf-8")).hexdigest() d["id"] = xxhash.xxh64((content + str(fake_doc_id)).encode("utf-8")).hexdigest()
d["create_time"] = str(datetime.now()).replace("T", " ")[:19] d["create_time"] = str(datetime.now()).replace("T", " ")[:19]
d["create_timestamp_flt"] = datetime.now().timestamp() d["create_timestamp_flt"] = datetime.now().timestamp()
d[vctr_nm] = vctr.tolist() d[vctr_nm] = vctr.tolist()
@ -533,8 +684,51 @@ async def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None):
return res, tk_count
async def delete_image(kb_id, chunk_id):
try:
async with minio_limiter:
STORAGE_IMPL.delete(kb_id, chunk_id)
except Exception:
logging.exception(f"Deleting image of chunk {chunk_id} got exception")
raise
async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_callback):
for b in range(0, len(chunks), DOC_BULK_SIZE):
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b:b + DOC_BULK_SIZE], search.index_name(task_tenant_id), task_dataset_id))
task_canceled = has_canceled(task_id)
if task_canceled:
progress_callback(-1, msg="Task has been canceled.")
return
if b % 128 == 0:
progress_callback(prog=0.8 + 0.1 * (b + 1) / len(chunks), msg="")
if doc_store_result:
error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
progress_callback(-1, msg=error_message)
raise Exception(error_message)
chunk_ids = [chunk["id"] for chunk in chunks[:b + DOC_BULK_SIZE]]
chunk_ids_str = " ".join(chunk_ids)
try:
TaskService.update_chunk_ids(task_id, chunk_ids_str)
except DoesNotExist:
logging.warning(f"do_handle_task update_chunk_ids failed since task {task_id} is unknown.")
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(task_tenant_id), task_dataset_id))
async with trio.open_nursery() as nursery:
for chunk_id in chunk_ids:
nursery.start_soon(delete_image, task_dataset_id, chunk_id)
progress_callback(-1, msg=f"Chunk updates failed since task {task_id} is unknown.")
return
return True
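A hedged usage sketch of the refactored helper (placeholder IDs, a no-op progress callback, and an already prepared `chunks` list are assumed); `insert_es` now serves both the standard chunking path and the dataflow path, and returns `True` only when every batch was written and the task was not cancelled.

```python
def noop_progress(prog=None, msg=""):
    # Stand-in for the real set_progress-based callback.
    pass

async def index_chunks(chunks):
    # Placeholder task, tenant and dataset IDs, for illustration only.
    ok = await insert_es("task-001", "tenant-abc", "kb-xyz", chunks, noop_progress)
    if not ok:
        # insert_es has already reported the failure or cancellation through the callback.
        return False
    return True
```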
@timeout(60*60*2, 1)
async def do_handle_task(task):
task_type = task.get("task_type", "")
if task_type == "dataflow" and task.get("doc_id", "") == CANVAS_DEBUG_DOC_ID:
await run_dataflow(task)
return
task_id = task["id"] task_id = task["id"]
task_from_page = task["from_page"] task_from_page = task["from_page"]
task_to_page = task["to_page"] task_to_page = task["to_page"]
@ -576,32 +770,70 @@ async def do_handle_task(task):
init_kb(task, vector_size)
task_type = task.get("task_type", "") if task_type[:len("dataflow")] == "dataflow":
if task_type == "dataflow": await run_dataflow(task)
task_dataflow_dsl = task["dsl"]
task_dataflow_id = task["dataflow_id"]
await run_dataflow(dsl=task_dataflow_dsl, tenant_id=task_tenant_id, doc_id=task_doc_id, task_id=task_id, flow_id=task_dataflow_id, callback=None)
return
elif task_type == "raptor":
if task_type == "raptor":
ok, kb = KnowledgebaseService.get_by_id(task_dataset_id)
if not ok:
progress_callback(prog=-1.0, msg="Cannot found valid knowledgebase for RAPTOR task")
return
kb_parser_config = kb.parser_config
if not kb_parser_config.get("raptor", {}).get("use_raptor", False):
progress_callback(prog=-1.0, msg="Internal error: Invalid RAPTOR configuration")
return
# bind LLM for raptor
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
# run RAPTOR
async with kg_limiter:
chunks, token_count = await run_raptor(task, chat_model, embedding_model, vector_size, progress_callback) chunks, token_count = await run_raptor_for_kb(
row=task,
kb_parser_config=kb_parser_config,
chat_mdl=chat_model,
embd_mdl=embedding_model,
vector_size=vector_size,
callback=progress_callback,
doc_ids=task.get("doc_ids", []),
)
# Either using graphrag or Standard chunking methods
elif task_type == "graphrag":
if not task_parser_config.get("graphrag", {}).get("use_graphrag", False): ok, kb = KnowledgebaseService.get_by_id(task_dataset_id)
progress_callback(prog=-1.0, msg="Internal configuration error.") if not ok:
progress_callback(prog=-1.0, msg="Cannot found valid knowledgebase for GraphRAG task")
return
graphrag_conf = task["kb_parser_config"].get("graphrag", {})
kb_parser_config = kb.parser_config
if not kb_parser_config.get("graphrag", {}).get("use_graphrag", False):
progress_callback(prog=-1.0, msg="Internal error: Invalid GraphRAG configuration")
return
graphrag_conf = kb_parser_config.get("graphrag", {})
start_ts = timer()
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
with_resolution = graphrag_conf.get("resolution", False)
with_community = graphrag_conf.get("community", False)
async with kg_limiter:
await run_graphrag(task, task_language, with_resolution, with_community, chat_model, embedding_model, progress_callback) # await run_graphrag(task, task_language, with_resolution, with_community, chat_model, embedding_model, progress_callback)
result = await run_graphrag_for_kb(
row=task,
doc_ids=task.get("doc_ids", []),
language=task_language,
kb_parser_config=kb_parser_config,
chat_model=chat_model,
embedding_model=embedding_model,
callback=progress_callback,
with_resolution=with_resolution,
with_community=with_community,
)
logging.info(f"GraphRAG task result for task {task}:\n{result}")
progress_callback(prog=1.0, msg="Knowledge Graph done ({:.2f}s)".format(timer() - start_ts)) progress_callback(prog=1.0, msg="Knowledge Graph done ({:.2f}s)".format(timer() - start_ts))
return return
elif task_type == "mindmap":
progress_callback(1, "place holder")
pass
return
else:
# Standard chunking methods
start_ts = timer()
@ -628,41 +860,9 @@ async def do_handle_task(task):
chunk_count = len(set([chunk["id"] for chunk in chunks])) chunk_count = len(set([chunk["id"] for chunk in chunks]))
start_ts = timer() start_ts = timer()
doc_store_result = "" e = await insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_callback)
if not e:
async def delete_image(kb_id, chunk_id): return
try:
async with minio_limiter:
STORAGE_IMPL.delete(kb_id, chunk_id)
except Exception:
logging.exception(
"Deleting image of chunk {}/{}/{} got exception".format(task["location"], task["name"], chunk_id))
raise
for b in range(0, len(chunks), DOC_BULK_SIZE):
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b:b + DOC_BULK_SIZE], search.index_name(task_tenant_id), task_dataset_id))
task_canceled = has_canceled(task_id)
if task_canceled:
progress_callback(-1, msg="Task has been canceled.")
return
if b % 128 == 0:
progress_callback(prog=0.8 + 0.1 * (b + 1) / len(chunks), msg="")
if doc_store_result:
error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
progress_callback(-1, msg=error_message)
raise Exception(error_message)
chunk_ids = [chunk["id"] for chunk in chunks[:b + DOC_BULK_SIZE]]
chunk_ids_str = " ".join(chunk_ids)
try:
TaskService.update_chunk_ids(task["id"], chunk_ids_str)
except DoesNotExist:
logging.warning(f"do_handle_task update_chunk_ids failed since task {task['id']} is unknown.")
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(task_tenant_id), task_dataset_id))
async with trio.open_nursery() as nursery:
for chunk_id in chunk_ids:
nursery.start_soon(delete_image, task_dataset_id, chunk_id)
progress_callback(-1, msg=f"Chunk updates failed since task {task['id']} is unknown.")
return
logging.info("Indexing doc({}), page({}-{}), chunks({}), elapsed: {:.2f}".format(task_document_name, task_from_page, logging.info("Indexing doc({}), page({}-{}), chunks({}), elapsed: {:.2f}".format(task_document_name, task_from_page,
task_to_page, len(chunks), task_to_page, len(chunks),
@ -685,6 +885,10 @@ async def handle_task():
if not task:
await trio.sleep(5)
return
task_type = task["task_type"]
pipeline_task_type = TASK_TYPE_TO_PIPELINE_TASK_TYPE.get(task_type, PipelineTaskType.PARSE) or PipelineTaskType.PARSE
try:
logging.info(f"handle_task begin for task {json.dumps(task)}")
CURRENT_TASKS[task["id"]] = copy.deepcopy(task)
@ -704,6 +908,13 @@ async def handle_task():
except Exception:
pass
logging.exception(f"handle_task got exception for task {json.dumps(task)}")
finally:
task_document_ids = []
if task_type in ["graphrag", "raptor", "mindmap"]:
task_document_ids = task["doc_ids"]
if not task.get("dataflow_id", ""):
PipelineOperationLogService.record_pipeline_operation(document_id=task["doc_id"], pipeline_id="", task_type=pipeline_task_type, fake_document_ids=task_document_ids)
redis_msg.ack()

File diff suppressed because one or more lines are too long


@ -0,0 +1,15 @@
<svg width="40" height="40" viewBox="0 0 40 40" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M35.3194 10.6367H20.4258C19.4857 10.6367 18.7236 11.3988 18.7236 12.3388V34.892C18.7236 35.8321 19.4857 36.5942 20.4258 36.5942H35.3194C36.2594 36.5942 37.0215 35.8321 37.0215 34.892V12.3388C37.0215 11.3988 36.2594 10.6367 35.3194 10.6367Z" fill="url(#paint0_linear_488_37636)"/>
<path d="M31.0639 4.25391H5.10642C4.16637 4.25391 3.4043 5.01597 3.4043 5.95603V18.2965C3.4043 19.2365 4.16637 19.9986 5.10642 19.9986H31.0639C32.0039 19.9986 32.766 19.2365 32.766 18.2965V5.95603C32.766 5.01597 32.0039 4.25391 31.0639 4.25391Z" fill="#00BEB4" fill-opacity="0.1"/>
<path d="M31.0639 4.25391C32.0039 4.25391 32.766 5.01597 32.766 5.95603V18.2965C32.766 19.2365 32.0039 19.9986 31.0639 19.9986H5.10642C4.16637 19.9986 3.4043 19.2365 3.4043 18.2965V5.95603C3.4043 5.01597 4.16637 4.25391 5.10642 4.25391H31.0639ZM31.0639 4.67944H5.10642C4.40138 4.67944 3.82983 5.25099 3.82983 5.95603V18.2965C3.82983 19.0015 4.40138 19.5731 5.10642 19.5731H31.0639C31.7689 19.5731 32.3405 19.0015 32.3405 18.2965V5.95603C32.3405 5.25099 31.7689 4.67944 31.0639 4.67944Z" fill="#00BEB4"/>
<path d="M31.0639 22.5547H5.10642C4.16637 22.5547 3.4043 23.3168 3.4043 24.2568V34.8951C3.4043 35.8352 4.16637 36.5972 5.10642 36.5972H31.0639C32.0039 36.5972 32.766 35.8352 32.766 34.8951V24.2568C32.766 23.3168 32.0039 22.5547 31.0639 22.5547Z" fill="#00BEB4" fill-opacity="0.1"/>
<path d="M31.0639 22.5547C32.0039 22.5547 32.766 23.3168 32.766 24.2568V34.8951C32.766 35.8352 32.0039 36.5972 31.0639 36.5972H5.10642C4.16637 36.5972 3.4043 35.8352 3.4043 34.8951V24.2568C3.4043 23.3168 4.16637 22.5547 5.10642 22.5547H31.0639ZM31.0639 22.9802H5.10642C4.40138 22.9802 3.82983 23.5518 3.82983 24.2568V34.8951C3.82983 35.6002 4.40138 36.1717 5.10642 36.1717H31.0639C31.7689 36.1717 32.3405 35.6002 32.3405 34.8951V24.2568C32.3405 23.5518 31.7689 22.9802 31.0639 22.9802Z" fill="#00BEB4"/>
<path d="M10.6384 14.8949C12.2835 14.8949 13.6171 13.5613 13.6171 11.9162C13.6171 10.2711 12.2835 8.9375 10.6384 8.9375C8.99329 8.9375 7.65967 10.2711 7.65967 11.9162C7.65967 13.5613 8.99329 14.8949 10.6384 14.8949Z" fill="#00BEB4"/>
<path d="M10.6384 32.766C12.2835 32.766 13.6171 31.4324 13.6171 29.7873C13.6171 28.1422 12.2835 26.8086 10.6384 26.8086C8.99329 26.8086 7.65967 28.1422 7.65967 29.7873C7.65967 31.4324 8.99329 32.766 10.6384 32.766Z" fill="#00BEB4"/>
<defs>
<linearGradient id="paint0_linear_488_37636" x1="933.617" y1="10.6367" x2="933.617" y2="2606.38" gradientUnits="userSpaceOnUse">
<stop stop-color="#C9F1EF"/>
<stop offset="1" stop-color="#00BEB4"/>
</linearGradient>
</defs>
</svg>


@ -0,0 +1,15 @@
<svg width="40" height="40" viewBox="0 0 40 40" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M35.3194 10.6387H20.4258C19.4857 10.6387 18.7236 11.4007 18.7236 12.3408V34.894C18.7236 35.834 19.4857 36.5961 20.4258 36.5961H35.3194C36.2594 36.5961 37.0215 35.834 37.0215 34.894V12.3408C37.0215 11.4007 36.2594 10.6387 35.3194 10.6387Z" fill="url(#paint0_linear_491_41413)"/>
<path d="M31.0639 4.25586H5.10642C4.16637 4.25586 3.4043 5.01793 3.4043 5.95799V18.2984C3.4043 19.2385 4.16637 20.0005 5.10642 20.0005H31.0639C32.0039 20.0005 32.766 19.2385 32.766 18.2984V5.95799C32.766 5.01793 32.0039 4.25586 31.0639 4.25586Z" fill="#00BEB4" fill-opacity="0.2"/>
<path d="M31.0639 4.25586C32.0039 4.25586 32.766 5.01793 32.766 5.95799V18.2984C32.766 19.2385 32.0039 20.0005 31.0639 20.0005H5.10642C4.16637 20.0005 3.4043 19.2385 3.4043 18.2984V5.95799C3.4043 5.01793 4.16637 4.25586 5.10642 4.25586H31.0639ZM31.0639 4.68139H5.10642C4.40138 4.68139 3.82983 5.25294 3.82983 5.95799V18.2984C3.82983 19.0035 4.40138 19.575 5.10642 19.575H31.0639C31.7689 19.575 32.3405 19.0035 32.3405 18.2984V5.95799C32.3405 5.25294 31.7689 4.68139 31.0639 4.68139Z" fill="#226365"/>
<path d="M31.0639 22.5527H5.10642C4.16637 22.5527 3.4043 23.3148 3.4043 24.2549V34.8932C3.4043 35.8332 4.16637 36.5953 5.10642 36.5953H31.0639C32.0039 36.5953 32.766 35.8332 32.766 34.8932V24.2549C32.766 23.3148 32.0039 22.5527 31.0639 22.5527Z" fill="#3A9093" fill-opacity="0.2"/>
<path d="M31.0639 22.5527C32.0039 22.5527 32.766 23.3148 32.766 24.2549V34.8932C32.766 35.8332 32.0039 36.5953 31.0639 36.5953H5.10642C4.16637 36.5953 3.4043 35.8332 3.4043 34.8932V24.2549C3.4043 23.3148 4.16637 22.5527 5.10642 22.5527H31.0639ZM31.0639 22.9783H5.10642C4.40138 22.9783 3.82983 23.5498 3.82983 24.2549V34.8932C3.82983 35.5982 4.40138 36.1698 5.10642 36.1698H31.0639C31.7689 36.1698 32.3405 35.5982 32.3405 34.8932V24.2549C32.3405 23.5498 31.7689 22.9783 31.0639 22.9783Z" fill="#226365"/>
<path d="M10.6384 14.893C12.2835 14.893 13.6171 13.5594 13.6171 11.9143C13.6171 10.2692 12.2835 8.93555 10.6384 8.93555C8.99329 8.93555 7.65967 10.2692 7.65967 11.9143C7.65967 13.5594 8.99329 14.893 10.6384 14.893Z" fill="#3A9093"/>
<path d="M10.6384 32.766C12.2835 32.766 13.6171 31.4324 13.6171 29.7873C13.6171 28.1422 12.2835 26.8086 10.6384 26.8086C8.99329 26.8086 7.65967 28.1422 7.65967 29.7873C7.65967 31.4324 8.99329 32.766 10.6384 32.766Z" fill="#3A9093"/>
<defs>
<linearGradient id="paint0_linear_491_41413" x1="933.617" y1="10.6387" x2="933.617" y2="2606.38" gradientUnits="userSpaceOnUse">
<stop stop-color="#1B3C3D"/>
<stop offset="1" stop-color="#164142"/>
</linearGradient>
</defs>
</svg>


@ -1 +0,0 @@
<?xml version="1.0" standalone="no"?><!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN" "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd"><svg t="1756884949583" class="icon" viewBox="0 0 1024 1024" version="1.1" xmlns="http://www.w3.org/2000/svg" p-id="11332" xmlns:xlink="http://www.w3.org/1999/xlink" width="200" height="200"><path d="M190.464 489.472h327.68v40.96h-327.68z" fill="#C7DCFE" p-id="11333"></path><path d="M482.34496 516.5056l111.26784-308.20352 38.54336 13.9264L520.86784 530.432z" fill="#C7DCFE" p-id="11334"></path><path d="M620.544 196.608m-122.88 0a122.88 122.88 0 1 0 245.76 0 122.88 122.88 0 1 0-245.76 0Z" fill="#8FB8FC" p-id="11335"></path><path d="M182.272 509.952m-122.88 0a122.88 122.88 0 1 0 245.76 0 122.88 122.88 0 1 0-245.76 0Z" fill="#C7DCFE" p-id="11336"></path><path d="M558.65344 520.9088l283.77088 163.84-20.48 35.47136-283.77088-163.84z" fill="#C7DCFE" p-id="11337"></path><path d="M841.728 686.08m-122.88 0a122.88 122.88 0 1 0 245.76 0 122.88 122.88 0 1 0-245.76 0Z" fill="#B3CEFE" p-id="11338"></path><path d="M448.67584 803.77856l49.60256-323.91168 40.48896 6.20544-49.60256 323.91168z" fill="#C7DCFE" p-id="11339"></path><path d="M512 530.432m-143.36 0a143.36 143.36 0 1 0 286.72 0 143.36 143.36 0 1 0-286.72 0Z" fill="#4185FF" p-id="11340"></path><path d="M462.848 843.776m-102.4 0a102.4 102.4 0 1 0 204.8 0 102.4 102.4 0 1 0-204.8 0Z" fill="#8FB8FC" p-id="11341"></path></svg>


@ -0,0 +1,6 @@
<svg width="40" height="40" viewBox="0 0 40 40" fill="none" xmlns="http://www.w3.org/2000/svg">
<path fill-rule="evenodd" clip-rule="evenodd" d="M21.8074 21.9283L30.4051 33.9033C30.9531 34.667 30.7785 35.7307 30.0148 36.2787C29.7258 36.4865 29.3785 36.5982 29.0223 36.5982H11.8273C10.8871 36.5982 10.125 35.8361 10.125 34.8963C10.125 34.54 10.2367 34.1928 10.4445 33.9033L19.0422 21.9283C19.5902 21.1646 20.6539 20.99 21.4176 21.5385C21.5676 21.6463 21.6996 21.7779 21.8074 21.9283Z" fill="#C6EFED"/>
<path fill-rule="evenodd" clip-rule="evenodd" d="M5.94336 3.39844H34.0285C35.9086 3.39844 37.4328 4.92266 37.4328 6.80273V27.2281C37.4328 29.1082 35.9086 30.6324 34.0285 30.6324H5.94336C4.06328 30.6324 2.53906 29.1082 2.53906 27.2281V6.80273C2.53906 4.92266 4.06328 3.39844 5.94336 3.39844Z" fill="#00BEB4" fill-opacity="0.2"/>
<path d="M34.0422 3.40625C35.9223 3.40625 37.4465 4.93047 37.4465 6.81055V27.2359C37.4465 29.116 35.9223 30.6402 34.0422 30.6402H5.95703C4.07695 30.6402 2.55273 29.116 2.55273 27.2359V6.81055C2.55273 4.93047 4.07695 3.40625 5.95703 3.40625H34.0422ZM34.0422 3.83164H5.95703C4.31211 3.83164 2.97852 5.16523 2.97852 6.81055V27.2359C2.97852 28.8812 4.31211 30.2148 5.95703 30.2148H34.0422C35.6871 30.2148 37.0207 28.8812 37.0207 27.2359V6.81055C37.0207 5.16523 35.6871 3.83164 34.0422 3.83164Z" fill="#00BEB4"/>
<path fill-rule="evenodd" clip-rule="evenodd" d="M19.9785 11.6797C20.6836 11.6797 21.2551 12.2512 21.2551 12.9562V21.0414C21.2551 21.7465 20.6836 22.318 19.9785 22.318C19.2734 22.318 18.702 21.7465 18.702 21.0414V12.9562C18.702 12.2512 19.2734 11.6797 19.9785 11.6797ZM11.0422 11.6797C11.7473 11.6797 12.3187 12.2512 12.3187 12.9562V21.0414C12.3187 21.7465 11.7473 22.318 11.0422 22.318C10.3371 22.318 9.76562 21.7465 9.76562 21.0414V12.9562C9.76562 12.2512 10.3371 11.6797 11.0422 11.6797ZM28.9145 11.6797C29.6195 11.6797 30.191 12.2512 30.191 12.9562V21.0414C30.191 21.7465 29.6195 22.318 28.9145 22.318C28.2094 22.318 27.6379 21.7465 27.6379 21.0414V12.9562C27.6379 12.2512 28.2094 11.6797 28.9145 11.6797Z" fill="#00BEB4"/>
</svg>


@ -0,0 +1,6 @@
<svg width="40" height="40" viewBox="0 0 40 40" fill="none" xmlns="http://www.w3.org/2000/svg">
<path fill-rule="evenodd" clip-rule="evenodd" d="M21.8074 21.9264L30.4051 33.9014C30.9531 34.665 30.7785 35.7287 30.0148 36.2767C29.7258 36.4846 29.3785 36.5963 29.0223 36.5963H11.8273C10.8871 36.5963 10.125 35.8342 10.125 34.8943C10.125 34.5381 10.2367 34.1908 10.4445 33.9014L19.0422 21.9264C19.5902 21.1627 20.6539 20.9881 21.4176 21.5365C21.5676 21.6443 21.6996 21.776 21.8074 21.9264Z" fill="#1C3C3D"/>
<path fill-rule="evenodd" clip-rule="evenodd" d="M5.94336 3.39844H34.0285C35.9086 3.39844 37.4328 4.92266 37.4328 6.80273V27.2281C37.4328 29.1082 35.9086 30.6324 34.0285 30.6324H5.94336C4.06328 30.6324 2.53906 29.1082 2.53906 27.2281V6.80273C2.53906 4.92266 4.06328 3.39844 5.94336 3.39844Z" fill="#00BEB4" fill-opacity="0.2"/>
<path d="M34.0422 3.4043C35.9223 3.4043 37.4465 4.92852 37.4465 6.80859V27.234C37.4465 29.1141 35.9223 30.6383 34.0422 30.6383H5.95703C4.07695 30.6383 2.55273 29.1141 2.55273 27.234V6.80859C2.55273 4.92852 4.07695 3.4043 5.95703 3.4043H34.0422ZM34.0422 3.82969H5.95703C4.31211 3.82969 2.97852 5.16328 2.97852 6.80859V27.234C2.97852 28.8793 4.31211 30.2129 5.95703 30.2129H34.0422C35.6871 30.2129 37.0207 28.8793 37.0207 27.234V6.80859C37.0207 5.16328 35.6871 3.82969 34.0422 3.82969Z" fill="#1B3B3C"/>
<path fill-rule="evenodd" clip-rule="evenodd" d="M19.9785 11.6797C20.6836 11.6797 21.2551 12.2512 21.2551 12.9562V21.0414C21.2551 21.7465 20.6836 22.318 19.9785 22.318C19.2734 22.318 18.702 21.7465 18.702 21.0414V12.9562C18.702 12.2512 19.2734 11.6797 19.9785 11.6797ZM11.0422 11.6797C11.7473 11.6797 12.3187 12.2512 12.3187 12.9562V21.0414C12.3187 21.7465 11.7473 22.318 11.0422 22.318C10.3371 22.318 9.76562 21.7465 9.76562 21.0414V12.9562C9.76562 12.2512 10.3371 11.6797 11.0422 11.6797ZM28.9145 11.6797C29.6195 11.6797 30.191 12.2512 30.191 12.9562V21.0414C30.191 21.7465 29.6195 22.318 28.9145 22.318C28.2094 22.318 27.6379 21.7465 27.6379 21.0414V12.9562C27.6379 12.2512 28.2094 11.6797 28.9145 11.6797Z" fill="#00BEB4"/>
</svg>


File diff suppressed because one or more lines are too long


@ -0,0 +1,6 @@
<svg width="40" height="40" viewBox="0 0 40 40" fill="none" xmlns="http://www.w3.org/2000/svg">
<path fill-rule="evenodd" clip-rule="evenodd" d="M11.0291 4.67969C11.8025 4.67969 12.4787 5.20078 12.6752 5.94844L13.3494 8.50937H31.4275C33.1599 8.50937 34.6158 9.81055 34.8103 11.5316L37.0205 31.1062C37.231 32.9746 35.8877 34.6602 34.0193 34.8711C33.8927 34.8852 33.765 34.8926 33.6377 34.8926H6.30289C4.92476 34.8926 3.79547 33.7988 3.75094 32.4215L3.11734 12.7746H3.115L2.90719 6.4375C2.87633 5.49805 3.61304 4.71133 4.5525 4.68047C4.57086 4.68008 4.58961 4.67969 4.60836 4.67969H11.0291Z" fill="#00BEB4" fill-opacity="0.1"/>
<path d="M11.0291 4.67969C11.8025 4.67969 12.4787 5.20078 12.6752 5.94844L13.349 8.50937H31.4275C33.1599 8.50937 34.6158 9.81055 34.8103 11.5316L37.0205 31.1062C37.231 32.9746 35.8877 34.6602 34.0193 34.8711C33.8927 34.8852 33.765 34.8926 33.6377 34.8926H6.30289C4.92476 34.8926 3.79547 33.7988 3.75094 32.4215L3.11656 12.7742L2.90719 6.4375C2.87633 5.49805 3.61304 4.71133 4.5525 4.68047L4.58023 4.67969H11.0291ZM11.0291 5.10508H4.59078L4.56656 5.10586C3.86187 5.12891 3.30914 5.71914 3.33219 6.42344L3.54195 12.7605L4.17633 32.4078C4.21344 33.5555 5.15445 34.4668 6.30289 34.4668H33.6377C33.749 34.4668 33.8607 34.4605 33.9716 34.448C35.6064 34.2637 36.7822 32.7887 36.5974 31.1539L34.3873 11.5797C34.2173 10.0734 32.9431 8.93516 31.4275 8.93516H13.0209L12.9377 8.61758L12.2638 6.05703C12.1162 5.49609 11.6091 5.10508 11.0291 5.10508Z" fill="#00BEB4"/>
<path fill-rule="evenodd" clip-rule="evenodd" d="M9.72812 12.7656H36.6539C38.0637 12.7656 39.207 13.9086 39.207 15.3188C39.207 15.4328 39.1992 15.5465 39.184 15.6594L36.9922 31.943C36.7648 33.6324 35.323 34.8934 33.6184 34.8934H6.37969C4.96953 34.8934 3.82617 33.75 3.82617 32.3398C3.82617 32.2102 3.83633 32.0801 3.85586 31.952L6.36367 15.6523C6.61914 13.9914 8.04805 12.7656 9.72812 12.7656Z" fill="#CAF2F0"/>
<path fill-rule="evenodd" clip-rule="evenodd" d="M8.98438 14.6172H20.4848C20.899 14.6172 21.2348 14.9529 21.2348 15.3672C21.2348 15.7814 20.899 16.1172 20.4848 16.1172H8.98438C8.57013 16.1172 8.23438 15.7814 8.23438 15.3672C8.23438 14.9529 8.57013 14.6172 8.98438 14.6172Z" fill="#00BEB4"/>
</svg>

After (SVG icon, 2.1 KiB)

View File

@ -0,0 +1,6 @@
<svg width="40" height="40" viewBox="0 0 40 40" fill="none" xmlns="http://www.w3.org/2000/svg">
<path fill-rule="evenodd" clip-rule="evenodd" d="M11.0291 4.68164C11.8025 4.68164 12.4787 5.20273 12.6752 5.95039L13.3494 8.51133H31.4275C33.1599 8.51133 34.6158 9.8125 34.8103 11.5336L37.0205 31.1082C37.231 32.9766 35.8877 34.6621 34.0193 34.873C33.8927 34.8871 33.765 34.8945 33.6377 34.8945H6.30289C4.92476 34.8945 3.79547 33.8008 3.75094 32.4234L3.11734 12.7766H3.115L2.90719 6.43945C2.87633 5.5 3.61304 4.71328 4.5525 4.68242C4.57086 4.68203 4.58961 4.68164 4.60836 4.68164H11.0291Z" fill="#1F3232"/>
<path d="M11.0291 4.68164C11.8025 4.68164 12.4787 5.20273 12.6752 5.95039L13.349 8.51133H31.4275C33.1599 8.51133 34.6158 9.8125 34.8103 11.5336L37.0205 31.1082C37.231 32.9766 35.8877 34.6621 34.0193 34.873C33.8927 34.8871 33.765 34.8945 33.6377 34.8945H6.30289C4.92476 34.8945 3.79547 33.8008 3.75094 32.4234L3.11656 12.7762L2.90719 6.43945C2.87633 5.5 3.61304 4.71328 4.5525 4.68242L4.58023 4.68164H11.0291ZM11.0291 5.10703H4.59078L4.56656 5.10781C3.86187 5.13086 3.30914 5.72109 3.33219 6.42539L3.54195 12.7625L4.17633 32.4098C4.21344 33.5574 5.15445 34.4687 6.30289 34.4687H33.6377C33.749 34.4687 33.8607 34.4625 33.9716 34.45C35.6064 34.2656 36.7822 32.7906 36.5974 31.1559L34.3873 11.5816C34.2173 10.0754 32.9431 8.93711 31.4275 8.93711H13.0209L12.9377 8.61953L12.2638 6.05898C12.1162 5.49805 11.6091 5.10703 11.0291 5.10703Z" fill="#1B3B3C"/>
<path fill-rule="evenodd" clip-rule="evenodd" d="M9.72812 12.7656H36.6539C38.0637 12.7656 39.207 13.9086 39.207 15.3188C39.207 15.4328 39.1992 15.5465 39.184 15.6594L36.9922 31.943C36.7648 33.6324 35.323 34.8934 33.6184 34.8934H6.37969C4.96953 34.8934 3.82617 33.75 3.82617 32.3398C3.82617 32.2102 3.83633 32.0801 3.85586 31.952L6.36367 15.6523C6.61914 13.9914 8.04805 12.7656 9.72812 12.7656Z" fill="#1B3B3C"/>
<path fill-rule="evenodd" clip-rule="evenodd" d="M8.98438 14.6172H20.4848C20.899 14.6172 21.2348 14.9529 21.2348 15.3672C21.2348 15.7814 20.899 16.1172 20.4848 16.1172H8.98438C8.57013 16.1172 8.23438 15.7814 8.23438 15.3672C8.23438 14.9529 8.57013 14.6172 8.98438 14.6172Z" fill="#00BEB4"/>
</svg>

After (SVG icon, 2.1 KiB)

View File

@ -18,8 +18,11 @@ import { useFetchKnowledgeBaseConfiguration } from '@/hooks/use-knowledge-reques
import { IModalProps } from '@/interfaces/common'; import { IModalProps } from '@/interfaces/common';
import { IParserConfig } from '@/interfaces/database/document'; import { IParserConfig } from '@/interfaces/database/document';
import { IChangeParserConfigRequestBody } from '@/interfaces/request/document'; import { IChangeParserConfigRequestBody } from '@/interfaces/request/document';
import {
ChunkMethodItem,
ParseTypeItem,
} from '@/pages/dataset/dataset-setting/configuration/common-item';
import { zodResolver } from '@hookform/resolvers/zod'; import { zodResolver } from '@hookform/resolvers/zod';
import get from 'lodash/get';
import omit from 'lodash/omit'; import omit from 'lodash/omit';
import {} from 'module'; import {} from 'module';
import { useEffect, useMemo } from 'react'; import { useEffect, useMemo } from 'react';
@ -30,24 +33,17 @@ import {
AutoKeywordsFormField, AutoKeywordsFormField,
AutoQuestionsFormField, AutoQuestionsFormField,
} from '../auto-keywords-form-field'; } from '../auto-keywords-form-field';
import { DataFlowSelect } from '../data-pipeline-select';
import { DelimiterFormField } from '../delimiter-form-field'; import { DelimiterFormField } from '../delimiter-form-field';
import { EntityTypesFormField } from '../entity-types-form-field'; import { EntityTypesFormField } from '../entity-types-form-field';
import { ExcelToHtmlFormField } from '../excel-to-html-form-field'; import { ExcelToHtmlFormField } from '../excel-to-html-form-field';
import { FormContainer } from '../form-container'; import { FormContainer } from '../form-container';
import { LayoutRecognizeFormField } from '../layout-recognize-form-field'; import { LayoutRecognizeFormField } from '../layout-recognize-form-field';
import { MaxTokenNumberFormField } from '../max-token-number-from-field'; import { MaxTokenNumberFormField } from '../max-token-number-from-field';
import {
UseGraphRagFormField,
showGraphRagItems,
} from '../parse-configuration/graph-rag-form-fields';
import RaptorFormFields, {
showRaptorParseConfiguration,
} from '../parse-configuration/raptor-form-fields';
import { ButtonLoading } from '../ui/button'; import { ButtonLoading } from '../ui/button';
import { Input } from '../ui/input'; import { Input } from '../ui/input';
import { RAGFlowSelect } from '../ui/select';
import { DynamicPageRange } from './dynamic-page-range'; import { DynamicPageRange } from './dynamic-page-range';
import { useFetchParserListOnMount, useShowAutoKeywords } from './hooks'; import { useShowAutoKeywords } from './hooks';
import { import {
useDefaultParserValues, useDefaultParserValues,
useFillDefaultValueOnMount, useFillDefaultValueOnMount,
@ -62,6 +58,7 @@ interface IProps
}> { }> {
loading: boolean; loading: boolean;
parserId: string; parserId: string;
pipelineId?: string;
parserConfig: IParserConfig; parserConfig: IParserConfig;
documentExtension: string; documentExtension: string;
documentId: string; documentId: string;
@ -80,6 +77,7 @@ export function ChunkMethodDialog({
hideModal, hideModal,
onOk, onOk,
parserId, parserId,
pipelineId,
documentExtension, documentExtension,
visible, visible,
parserConfig, parserConfig,
@ -87,8 +85,6 @@ export function ChunkMethodDialog({
}: IProps) { }: IProps) {
const { t } = useTranslation(); const { t } = useTranslation();
const { parserList } = useFetchParserListOnMount(documentExtension);
const { data: knowledgeDetails } = useFetchKnowledgeBaseConfiguration(); const { data: knowledgeDetails } = useFetchKnowledgeBaseConfiguration();
const useGraphRag = useMemo(() => { const useGraphRag = useMemo(() => {
@ -99,46 +95,59 @@ export function ChunkMethodDialog({
const fillDefaultParserValue = useFillDefaultValueOnMount(); const fillDefaultParserValue = useFillDefaultValueOnMount();
const FormSchema = z.object({ const FormSchema = z
parser_id: z .object({
.string() parseType: z.number(),
.min(1, { parser_id: z
message: t('common.pleaseSelect'), .string()
}) .min(1, {
.trim(), message: t('common.pleaseSelect'),
parser_config: z.object({
task_page_size: z.coerce.number().optional(),
layout_recognize: z.string().optional(),
chunk_token_num: z.coerce.number().optional(),
delimiter: z.string().optional(),
auto_keywords: z.coerce.number().optional(),
auto_questions: z.coerce.number().optional(),
html4excel: z.boolean().optional(),
raptor: z
.object({
use_raptor: z.boolean().optional(),
prompt: z.string().optional().optional(),
max_token: z.coerce.number().optional(),
threshold: z.coerce.number().optional(),
max_cluster: z.coerce.number().optional(),
random_seed: z.coerce.number().optional(),
}) })
.optional(), .trim(),
graphrag: z.object({ pipeline_id: z.string().optional(),
use_graphrag: z.boolean().optional(), parser_config: z.object({
task_page_size: z.coerce.number().optional(),
layout_recognize: z.string().optional(),
chunk_token_num: z.coerce.number().optional(),
delimiter: z.string().optional(),
auto_keywords: z.coerce.number().optional(),
auto_questions: z.coerce.number().optional(),
html4excel: z.boolean().optional(),
// raptor: z
// .object({
// use_raptor: z.boolean().optional(),
// prompt: z.string().optional().optional(),
// max_token: z.coerce.number().optional(),
// threshold: z.coerce.number().optional(),
// max_cluster: z.coerce.number().optional(),
// random_seed: z.coerce.number().optional(),
// })
// .optional(),
// graphrag: z.object({
// use_graphrag: z.boolean().optional(),
// }),
entity_types: z.array(z.string()).optional(),
pages: z
.array(z.object({ from: z.coerce.number(), to: z.coerce.number() }))
.optional(),
}), }),
entity_types: z.array(z.string()).optional(), })
pages: z .superRefine((data, ctx) => {
.array(z.object({ from: z.coerce.number(), to: z.coerce.number() })) if (data.parseType === 2 && !data.pipeline_id) {
.optional(), ctx.addIssue({
}), path: ['pipeline_id'],
}); message: t('common.pleaseSelect'),
code: 'custom',
});
}
});
const form = useForm<z.infer<typeof FormSchema>>({ const form = useForm<z.infer<typeof FormSchema>>({
resolver: zodResolver(FormSchema), resolver: zodResolver(FormSchema),
defaultValues: { defaultValues: {
parser_id: parserId, parser_id: parserId || '',
pipeline_id: pipelineId || '',
parseType: pipelineId ? 2 : 1,
parser_config: defaultParserValues, parser_config: defaultParserValues,
}, },
}); });
@ -200,17 +209,19 @@ export function ChunkMethodDialog({
const pages = const pages =
parserConfig?.pages?.map((x) => ({ from: x[0], to: x[1] })) ?? []; parserConfig?.pages?.map((x) => ({ from: x[0], to: x[1] })) ?? [];
form.reset({ form.reset({
parser_id: parserId, parser_id: parserId || '',
pipeline_id: pipelineId || '',
parseType: pipelineId ? 2 : 1,
parser_config: fillDefaultParserValue({ parser_config: fillDefaultParserValue({
pages: pages.length > 0 ? pages : [{ from: 1, to: 1024 }], pages: pages.length > 0 ? pages : [{ from: 1, to: 1024 }],
...omit(parserConfig, 'pages'), ...omit(parserConfig, 'pages'),
graphrag: { // graphrag: {
use_graphrag: get( // use_graphrag: get(
parserConfig, // parserConfig,
'graphrag.use_graphrag', // 'graphrag.use_graphrag',
useGraphRag, // useGraphRag,
), // ),
}, // },
}), }),
}); });
} }
@ -220,10 +231,20 @@ export function ChunkMethodDialog({
knowledgeDetails.parser_config, knowledgeDetails.parser_config,
parserConfig, parserConfig,
parserId, parserId,
pipelineId,
useGraphRag, useGraphRag,
visible, visible,
]); ]);
const parseType = useWatch({
control: form.control,
name: 'parseType',
defaultValue: pipelineId ? 2 : 1,
});
useEffect(() => {
if (parseType === 1) {
form.setValue('pipeline_id', '');
}
}, [parseType, form]);
return ( return (
<Dialog open onOpenChange={hideModal}> <Dialog open onOpenChange={hideModal}>
<DialogContent className="max-w-[50vw]"> <DialogContent className="max-w-[50vw]">
@ -237,7 +258,17 @@ export function ChunkMethodDialog({
id={FormId} id={FormId}
> >
<FormContainer> <FormContainer>
<FormField <ParseTypeItem />
{parseType === 1 && <ChunkMethodItem></ChunkMethodItem>}
{parseType === 2 && (
<DataFlowSelect
isMult={false}
// toDataPipeline={navigateToAgents}
formFieldName="pipeline_id"
/>
)}
{/* <FormField
control={form.control} control={form.control}
name="parser_id" name="parser_id"
render={({ field }) => ( render={({ field }) => (
@ -252,9 +283,11 @@ export function ChunkMethodDialog({
<FormMessage /> <FormMessage />
</FormItem> </FormItem>
)} )}
/> /> */}
{showPages && <DynamicPageRange></DynamicPageRange>} {showPages && parseType === 1 && (
{showPages && layoutRecognize && ( <DynamicPageRange></DynamicPageRange>
)}
{showPages && parseType === 1 && layoutRecognize && (
<FormField <FormField
control={form.control} control={form.control}
name="parser_config.task_page_size" name="parser_config.task_page_size"
@ -279,50 +312,60 @@ export function ChunkMethodDialog({
/> />
)} )}
</FormContainer> </FormContainer>
<FormContainer {parseType === 1 && (
show={showOne || showMaxTokenNumber} <>
className="space-y-3" <FormContainer
> show={showOne || showMaxTokenNumber}
{showOne && <LayoutRecognizeFormField></LayoutRecognizeFormField>} className="space-y-3"
{showMaxTokenNumber && ( >
<> {showOne && (
<MaxTokenNumberFormField <LayoutRecognizeFormField></LayoutRecognizeFormField>
max={ )}
selectedTag === DocumentParserType.KnowledgeGraph {showMaxTokenNumber && (
? 8192 * 2 <>
: 2048 <MaxTokenNumberFormField
} max={
></MaxTokenNumberFormField> selectedTag === DocumentParserType.KnowledgeGraph
<DelimiterFormField></DelimiterFormField> ? 8192 * 2
</> : 2048
)} }
</FormContainer> ></MaxTokenNumberFormField>
<FormContainer <DelimiterFormField></DelimiterFormField>
show={showAutoKeywords(selectedTag) || showExcelToHtml} </>
className="space-y-3" )}
>
{showAutoKeywords(selectedTag) && (
<>
<AutoKeywordsFormField></AutoKeywordsFormField>
<AutoQuestionsFormField></AutoQuestionsFormField>
</>
)}
{showExcelToHtml && <ExcelToHtmlFormField></ExcelToHtmlFormField>}
</FormContainer>
{showRaptorParseConfiguration(
selectedTag as DocumentParserType,
) && (
<FormContainer>
<RaptorFormFields></RaptorFormFields>
</FormContainer>
)}
{showGraphRagItems(selectedTag as DocumentParserType) &&
useGraphRag && (
<FormContainer>
<UseGraphRagFormField></UseGraphRagFormField>
</FormContainer> </FormContainer>
)} <FormContainer
{showEntityTypes && <EntityTypesFormField></EntityTypesFormField>} show={showAutoKeywords(selectedTag) || showExcelToHtml}
className="space-y-3"
>
{showAutoKeywords(selectedTag) && (
<>
<AutoKeywordsFormField></AutoKeywordsFormField>
<AutoQuestionsFormField></AutoQuestionsFormField>
</>
)}
{showExcelToHtml && (
<ExcelToHtmlFormField></ExcelToHtmlFormField>
)}
</FormContainer>
{/* {showRaptorParseConfiguration(
selectedTag as DocumentParserType,
) && (
<FormContainer>
<RaptorFormFields></RaptorFormFields>
</FormContainer>
)} */}
{/* {showGraphRagItems(selectedTag as DocumentParserType) &&
useGraphRag && (
<FormContainer>
<UseGraphRagFormField></UseGraphRagFormField>
</FormContainer>
)} */}
{showEntityTypes && (
<EntityTypesFormField></EntityTypesFormField>
)}
</>
)}
</form> </form>
</Form> </Form>
<DialogFooter> <DialogFooter>
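
Editor's note: a minimal, standalone sketch (not part of the commit) of the conditional validation the new FormSchema introduces above, where a pipeline must be selected only when the pipeline-based parse type is chosen. Field names mirror the diff; the message string is simplified.

import { z } from 'zod';

const Schema = z
  .object({
    parseType: z.number(), // 1 = built-in chunk method, 2 = data pipeline
    parser_id: z.string().trim(),
    pipeline_id: z.string().optional(),
  })
  .superRefine((data, ctx) => {
    // pipeline_id becomes mandatory only when the pipeline-based parse type is chosen
    if (data.parseType === 2 && !data.pipeline_id) {
      ctx.addIssue({
        path: ['pipeline_id'],
        message: 'Please select a data pipeline',
        code: 'custom',
      });
    }
  });

console.log(Schema.safeParse({ parseType: 1, parser_id: 'naive' }).success); // true
console.log(Schema.safeParse({ parseType: 2, parser_id: 'naive' }).success); // false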

View File

@ -1,7 +1,7 @@
import { IParserConfig } from '@/interfaces/database/document'; import { IParserConfig } from '@/interfaces/database/document';
import { useCallback, useMemo } from 'react'; import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { DocumentType } from '../layout-recognize-form-field'; import { ParseDocumentType } from '../layout-recognize-form-field';
export function useDefaultParserValues() { export function useDefaultParserValues() {
const { t } = useTranslation(); const { t } = useTranslation();
@ -9,23 +9,23 @@ export function useDefaultParserValues() {
const defaultParserValues = useMemo(() => { const defaultParserValues = useMemo(() => {
const defaultParserValues = { const defaultParserValues = {
task_page_size: 12, task_page_size: 12,
layout_recognize: DocumentType.DeepDOC, layout_recognize: ParseDocumentType.DeepDOC,
chunk_token_num: 512, chunk_token_num: 512,
delimiter: '\n', delimiter: '\n',
auto_keywords: 0, auto_keywords: 0,
auto_questions: 0, auto_questions: 0,
html4excel: false, html4excel: false,
raptor: { // raptor: {
use_raptor: false, // use_raptor: false,
prompt: t('knowledgeConfiguration.promptText'), // prompt: t('knowledgeConfiguration.promptText'),
max_token: 256, // max_token: 256,
threshold: 0.1, // threshold: 0.1,
max_cluster: 64, // max_cluster: 64,
random_seed: 0, // random_seed: 0,
}, // },
graphrag: { // graphrag: {
use_graphrag: false, // use_graphrag: false,
}, // },
entity_types: [], entity_types: [],
pages: [], pages: [],
}; };

View File

@ -8,7 +8,7 @@ import {
AlertDialogTitle, AlertDialogTitle,
AlertDialogTrigger, AlertDialogTrigger,
} from '@/components/ui/alert-dialog'; } from '@/components/ui/alert-dialog';
import { PropsWithChildren } from 'react'; import { DialogProps } from '@radix-ui/react-dialog';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
interface IProps { interface IProps {
@ -24,7 +24,10 @@ export function ConfirmDeleteDialog({
onOk, onOk,
onCancel, onCancel,
hidden = false, hidden = false,
}: IProps & PropsWithChildren) { onOpenChange,
open,
defaultOpen,
}: IProps & DialogProps) {
const { t } = useTranslation(); const { t } = useTranslation();
if (hidden) { if (hidden) {
@ -32,7 +35,11 @@ export function ConfirmDeleteDialog({
} }
return ( return (
<AlertDialog> <AlertDialog
onOpenChange={onOpenChange}
open={open}
defaultOpen={defaultOpen}
>
<AlertDialogTrigger asChild>{children}</AlertDialogTrigger> <AlertDialogTrigger asChild>{children}</AlertDialogTrigger>
<AlertDialogContent <AlertDialogContent
onSelect={(e) => e.preventDefault()} onSelect={(e) => e.preventDefault()}
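
Editor's note: a usage sketch, not part of the commit. With the DialogProps pass-through, ConfirmDeleteDialog can now be driven as a controlled component instead of relying only on its trigger child. Import paths and the Button component are assumptions about the surrounding codebase.

import { useState } from 'react';
import { ConfirmDeleteDialog } from '@/components/confirm-delete-dialog';
import { Button } from '@/components/ui/button';

export function DeleteWithConfirm({ onDelete }: { onDelete: () => void }) {
  const [open, setOpen] = useState(false);

  return (
    // open/onOpenChange are forwarded to the underlying AlertDialog;
    // the trigger child still opens it, and local state stays in sync.
    <ConfirmDeleteDialog open={open} onOpenChange={setOpen} onOk={onDelete}>
      <Button>Delete</Button>
    </ConfirmDeleteDialog>
  );
}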

View File

@ -22,7 +22,7 @@ const Languages = [
'Vietnamese', 'Vietnamese',
]; ];
const options = Languages.map((x) => ({ export const crossLanguageOptions = Languages.map((x) => ({
label: t('language.' + toLower(x)), label: t('language.' + toLower(x)),
value: x, value: x,
})); }));
@ -30,11 +30,13 @@ const options = Languages.map((x) => ({
type CrossLanguageItemProps = { type CrossLanguageItemProps = {
name?: string; name?: string;
vertical?: boolean; vertical?: boolean;
label?: string;
}; };
export const CrossLanguageFormField = ({ export const CrossLanguageFormField = ({
name = 'prompt_config.cross_languages', name = 'prompt_config.cross_languages',
vertical = true, vertical = true,
label,
}: CrossLanguageItemProps) => { }: CrossLanguageItemProps) => {
const { t } = useTranslation(); const { t } = useTranslation();
const form = useFormContext(); const form = useFormContext();
@ -53,11 +55,11 @@ export const CrossLanguageFormField = ({
})} })}
> >
<FormLabel tooltip={t('chat.crossLanguageTip')}> <FormLabel tooltip={t('chat.crossLanguageTip')}>
{t('chat.crossLanguage')} {label || t('chat.crossLanguage')}
</FormLabel> </FormLabel>
<FormControl> <FormControl>
<MultiSelect <MultiSelect
options={options} options={crossLanguageOptions}
placeholder={t('fileManager.pleaseSelect')} placeholder={t('fileManager.pleaseSelect')}
maxCount={100} maxCount={100}
{...field} {...field}

View File

@ -0,0 +1,120 @@
import { AgentCategory } from '@/constants/agent';
import { useTranslate } from '@/hooks/common-hooks';
import { useFetchAgentList } from '@/hooks/use-agent-request';
import { buildSelectOptions } from '@/utils/component-util';
import { ArrowUpRight } from 'lucide-react';
import { useEffect, useMemo } from 'react';
import { useFormContext } from 'react-hook-form';
import { SelectWithSearch } from '../originui/select-with-search';
import {
FormControl,
FormField,
FormItem,
FormLabel,
FormMessage,
} from '../ui/form';
import { MultiSelect } from '../ui/multi-select';
export interface IDataPipelineSelectNode {
id?: string;
name?: string;
avatar?: string;
}
interface IProps {
toDataPipeline?: () => void;
formFieldName: string;
isMult?: boolean;
setDataList?: (data: IDataPipelineSelectNode[]) => void;
}
export function DataFlowSelect(props: IProps) {
const { toDataPipeline, formFieldName, isMult = false, setDataList } = props;
const { t } = useTranslate('knowledgeConfiguration');
const form = useFormContext();
const toDataPipLine = () => {
toDataPipeline?.();
};
const { data: dataPipelineOptions } = useFetchAgentList({
canvas_category: AgentCategory.DataflowCanvas,
});
const options = useMemo(() => {
const option = buildSelectOptions(
dataPipelineOptions?.canvas,
'id',
'title',
);
return option || [];
}, [dataPipelineOptions]);
const nodes = useMemo(() => {
return (
dataPipelineOptions?.canvas?.map((item) => {
return {
id: item?.id,
name: item?.title,
avatar: item?.avatar,
};
}) || []
);
}, [dataPipelineOptions]);
useEffect(() => {
setDataList?.(nodes);
}, [nodes, setDataList]);
return (
<FormField
control={form.control}
name={formFieldName}
render={({ field }) => (
<FormItem className=" items-center space-y-0 ">
<div className="flex flex-col gap-1">
<div className="flex gap-2 justify-between ">
<FormLabel
tooltip={t('dataFlowTip')}
className="text-sm text-text-primary whitespace-wrap "
>
{t('dataPipeline')}
</FormLabel>
{toDataPipeline && (
<div
className="text-sm flex text-text-primary cursor-pointer"
onClick={toDataPipLine}
>
{t('buildItFromScratch')}
<ArrowUpRight size={14} />
</div>
)}
</div>
<div className="text-muted-foreground">
<FormControl>
<>
{!isMult && (
<SelectWithSearch
{...field}
placeholder={t('dataFlowPlaceholder')}
options={options}
/>
)}
{isMult && (
<MultiSelect
{...field}
onValueChange={field.onChange}
placeholder={t('dataFlowPlaceholder')}
options={options}
/>
)}
</>
</FormControl>
</div>
</div>
<div className="flex pt-1">
<FormMessage />
</div>
</FormItem>
)}
/>
);
}
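
Editor's note: a usage sketch, not part of the commit, showing how a parent form such as ChunkMethodDialog above might mount the new DataFlowSelect. It assumes a react-hook-form FormProvider and the app's react-query provider (needed by useFetchAgentList); the import path is an assumption.

import { FormProvider, useForm } from 'react-hook-form';
import { DataFlowSelect } from '@/components/data-pipeline-select';

export function PipelinePickerDemo() {
  const form = useForm<{ pipeline_id: string }>({
    defaultValues: { pipeline_id: '' },
  });

  return (
    <FormProvider {...form}>
      <form onSubmit={form.handleSubmit((values) => console.log(values))}>
        {/* single-select mode, bound to pipeline_id as in ChunkMethodDialog */}
        <DataFlowSelect isMult={false} formFieldName="pipeline_id" />
        <button type="submit">Save</button>
      </form>
    </FormProvider>
  );
}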

View File

@ -16,11 +16,17 @@ interface IProps {
} }
export const DelimiterInput = forwardRef<HTMLInputElement, InputProps & IProps>( export const DelimiterInput = forwardRef<HTMLInputElement, InputProps & IProps>(
({ value, onChange, maxLength, defaultValue }, ref) => { ({ value, onChange, maxLength, defaultValue, ...props }, ref) => {
const nextValue = value?.replaceAll('\n', '\\n'); const nextValue = value
?.replaceAll('\n', '\\n')
.replaceAll('\t', '\\t')
.replaceAll('\r', '\\r');
const handleInputChange = (e: React.ChangeEvent<HTMLInputElement>) => { const handleInputChange = (e: React.ChangeEvent<HTMLInputElement>) => {
const val = e.target.value; const val = e.target.value;
const nextValue = val.replaceAll('\\n', '\n'); const nextValue = val
.replaceAll('\\n', '\n')
.replaceAll('\\t', '\t')
.replaceAll('\\r', '\r');
onChange?.(nextValue); onChange?.(nextValue);
}; };
return ( return (
@ -30,6 +36,7 @@ export const DelimiterInput = forwardRef<HTMLInputElement, InputProps & IProps>(
maxLength={maxLength} maxLength={maxLength}
defaultValue={defaultValue} defaultValue={defaultValue}
ref={ref} ref={ref}
{...props}
></Input> ></Input>
); );
}, },
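
Editor's note: a small sketch (not part of the commit) of the escape round trip the widened replaceAll calls now perform, so that newlines, tabs, and carriage returns are shown as visible \n, \t, \r in the input and converted back on change.

const toDisplay = (v: string) =>
  v.replaceAll('\n', '\\n').replaceAll('\t', '\\t').replaceAll('\r', '\\r');
const fromDisplay = (v: string) =>
  v.replaceAll('\\n', '\n').replaceAll('\\t', '\t').replaceAll('\\r', '\r');

console.log(toDisplay('a\n\tb'));                  // prints a\n\tb (escape sequences, not control chars)
console.log(fromDisplay('a\\n\\tb') === 'a\n\tb'); // true: the round trip is lossless here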

View File

@ -26,7 +26,7 @@ export function EntityTypesFormField({
return ( return (
<FormItem className=" items-center space-y-0 "> <FormItem className=" items-center space-y-0 ">
<div className="flex items-center"> <div className="flex items-center">
<FormLabel className="text-sm text-muted-foreground whitespace-nowrap w-1/4"> <FormLabel className="text-sm whitespace-nowrap w-1/4">
<span className="text-red-600">*</span> {t('entityTypes')} <span className="text-red-600">*</span> {t('entityTypes')}
</FormLabel> </FormLabel>
<div className="w-3/4"> <div className="w-3/4">

View File

@ -1,24 +1,29 @@
// src/pages/dataset/file-logs/file-status-badge.tsx // src/pages/dataset/file-logs/file-status-badge.tsx
import { RunningStatus } from '@/pages/dataset/dataset/constant';
import { FC } from 'react'; import { FC } from 'react';
/**
 * status: 0 = not started, 1 = running, 2 = cancelled, 3 = success, 4 = failed
*/
interface StatusBadgeProps { interface StatusBadgeProps {
status: 'Success' | 'Failed' | 'Running' | 'Pending'; // status: 'Success' | 'Failed' | 'Running' | 'Pending';
status: RunningStatus;
name?: string;
} }
const FileStatusBadge: FC<StatusBadgeProps> = ({ status }) => { const FileStatusBadge: FC<StatusBadgeProps> = ({ status, name }) => {
const getStatusColor = () => { const getStatusColor = () => {
// #3ba05c → rgb(59, 160, 92) // state-success // #3ba05c → rgb(59, 160, 92) // state-success
// #d8494b → rgb(216, 73, 75) // state-error // #d8494b → rgb(216, 73, 75) // state-error
// #00beb4 → rgb(0, 190, 180) // accent-primary // #00beb4 → rgb(0, 190, 180) // accent-primary
// #faad14 → rgb(250, 173, 20) // state-warning // #faad14 → rgb(250, 173, 20) // state-warning
switch (status) { switch (status) {
case 'Success': case RunningStatus.DONE:
return `bg-[rgba(59,160,92,0.1)] text-state-success`; return `bg-[rgba(59,160,92,0.1)] text-state-success`;
case 'Failed': case RunningStatus.FAIL:
return `bg-[rgba(216,73,75,0.1)] text-state-error`; return `bg-[rgba(216,73,75,0.1)] text-state-error`;
case 'Running': case RunningStatus.RUNNING:
return `bg-[rgba(0,190,180,0.1)] text-accent-primary`; return `bg-[rgba(0,190,180,0.1)] text-accent-primary`;
case 'Pending': case RunningStatus.UNSTART:
return `bg-[rgba(250,173,20,0.1)] text-state-warning`; return `bg-[rgba(250,173,20,0.1)] text-state-warning`;
default: default:
return 'bg-gray-500/10 text-white'; return 'bg-gray-500/10 text-white';
@ -31,13 +36,13 @@ const FileStatusBadge: FC<StatusBadgeProps> = ({ status }) => {
// #00beb4 → rgb(0, 190, 180) // accent-primary // #00beb4 → rgb(0, 190, 180) // accent-primary
// #faad14 → rgb(250, 173, 20) // state-warning // #faad14 → rgb(250, 173, 20) // state-warning
switch (status) { switch (status) {
case 'Success': case RunningStatus.DONE:
return `bg-[rgba(59,160,92,1)] text-state-success`; return `bg-[rgba(59,160,92,1)] text-state-success`;
case 'Failed': case RunningStatus.FAIL:
return `bg-[rgba(216,73,75,1)] text-state-error`; return `bg-[rgba(216,73,75,1)] text-state-error`;
case 'Running': case RunningStatus.RUNNING:
return `bg-[rgba(0,190,180,1)] text-accent-primary`; return `bg-[rgba(0,190,180,1)] text-accent-primary`;
case 'Pending': case RunningStatus.UNSTART:
return `bg-[rgba(250,173,20,1)] text-state-warning`; return `bg-[rgba(250,173,20,1)] text-state-warning`;
default: default:
return 'bg-gray-500/10 text-white'; return 'bg-gray-500/10 text-white';
@ -46,10 +51,10 @@ const FileStatusBadge: FC<StatusBadgeProps> = ({ status }) => {
return ( return (
<span <span
className={`inline-flex items-center w-[75px] px-2 py-1 rounded-full text-xs font-medium ${getStatusColor(0.1)}`} className={`inline-flex items-center w-[75px] px-2 py-1 rounded-full text-xs font-medium ${getStatusColor()}`}
> >
<div className={`w-1 h-1 mr-1 rounded-full ${getBgStatusColor()}`}></div> <div className={`w-1 h-1 mr-1 rounded-full ${getBgStatusColor()}`}></div>
{status} {name || ''}
</span> </span>
); );
}; };
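
Editor's note: a usage sketch, not part of the commit. The badge now consumes the RunningStatus enum plus a display name instead of a string literal; the hard-coded 'Success' mirrors the RunningStatusMap entry added later in this commit. The default export and import paths are assumptions.

import FileStatusBadge from '@/pages/dataset/file-logs/file-status-badge';
import { RunningStatus } from '@/pages/dataset/dataset/constant';

export function DocStatusCell() {
  // 'Success' matches RunningStatusMap[RunningStatus.DONE] from the constants change below
  return <FileStatusBadge status={RunningStatus.DONE} name="Success" />;
}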

View File

@ -13,8 +13,15 @@ interface IProps {
onClick?: () => void; onClick?: () => void;
moreDropdown: React.ReactNode; moreDropdown: React.ReactNode;
sharedBadge?: ReactNode; sharedBadge?: ReactNode;
icon?: React.ReactNode;
} }
export function HomeCard({ data, onClick, moreDropdown, sharedBadge }: IProps) { export function HomeCard({
data,
onClick,
moreDropdown,
sharedBadge,
icon,
}: IProps) {
return ( return (
<Card <Card
className="bg-bg-card border-colors-outline-neutral-standard" className="bg-bg-card border-colors-outline-neutral-standard"
@ -32,10 +39,13 @@ export function HomeCard({ data, onClick, moreDropdown, sharedBadge }: IProps) {
/> />
</div> </div>
<div className="flex flex-col justify-between gap-1 flex-1 h-full w-[calc(100%-50px)]"> <div className="flex flex-col justify-between gap-1 flex-1 h-full w-[calc(100%-50px)]">
<section className="flex justify-between"> <section className="flex justify-between w-full">
<div className="text-[20px] font-bold w-80% leading-5 text-ellipsis overflow-hidden"> <section className="flex gap-1 items-center w-full">
{data.name} <div className="text-[20px] font-bold w-80% leading-5 text-ellipsis overflow-hidden">
</div> {data.name}
</div>
{icon}
</section>
{moreDropdown} {moreDropdown}
</section> </section>

View File

@ -4,6 +4,7 @@ import { getExtension } from '@/utils/document-util';
type IconFontType = { type IconFontType = {
name: string; name: string;
className?: string; className?: string;
}; };
@ -13,6 +14,23 @@ export const IconFont = ({ name, className }: IconFontType) => (
</svg> </svg>
); );
export function IconFontFill({
name,
className,
isFill = true,
}: IconFontType & { isFill?: boolean }) {
return (
<span className={cn('size-4', className)}>
<svg
className={cn('size-4', className)}
style={{ fill: isFill ? 'currentColor' : '' }}
>
<use xlinkHref={`#icon-${name}`} />
</svg>
</span>
);
}
export function FileIcon({ export function FileIcon({
name, name,
className, className,

View File

@ -1,9 +1,11 @@
import { LlmModelType } from '@/constants/knowledge'; import { LlmModelType } from '@/constants/knowledge';
import { useTranslate } from '@/hooks/common-hooks'; import { useTranslate } from '@/hooks/common-hooks';
import { useSelectLlmOptionsByModelType } from '@/hooks/llm-hooks'; import { useSelectLlmOptionsByModelType } from '@/hooks/llm-hooks';
import { cn } from '@/lib/utils';
import { camelCase } from 'lodash'; import { camelCase } from 'lodash';
import { useMemo } from 'react'; import { ReactNode, useMemo } from 'react';
import { useFormContext } from 'react-hook-form'; import { useFormContext } from 'react-hook-form';
import { SelectWithSearch } from './originui/select-with-search';
import { import {
FormControl, FormControl,
FormField, FormField,
@ -11,24 +13,36 @@ import {
FormLabel, FormLabel,
FormMessage, FormMessage,
} from './ui/form'; } from './ui/form';
import { RAGFlowSelect } from './ui/select';
export const enum DocumentType { export const enum ParseDocumentType {
DeepDOC = 'DeepDOC', DeepDOC = 'DeepDOC',
PlainText = 'Plain Text', PlainText = 'Plain Text',
} }
export function LayoutRecognizeFormField() { export function LayoutRecognizeFormField({
name = 'parser_config.layout_recognize',
horizontal = true,
optionsWithoutLLM,
label,
}: {
name?: string;
horizontal?: boolean;
optionsWithoutLLM?: { value: string; label: string }[];
label?: ReactNode;
}) {
const form = useFormContext(); const form = useFormContext();
const { t } = useTranslate('knowledgeDetails'); const { t } = useTranslate('knowledgeDetails');
const allOptions = useSelectLlmOptionsByModelType(); const allOptions = useSelectLlmOptionsByModelType();
const options = useMemo(() => { const options = useMemo(() => {
const list = [DocumentType.DeepDOC, DocumentType.PlainText].map((x) => ({ const list = optionsWithoutLLM
label: x === DocumentType.PlainText ? t(camelCase(x)) : 'DeepDoc', ? optionsWithoutLLM
value: x, : [ParseDocumentType.DeepDOC, ParseDocumentType.PlainText].map((x) => ({
})); label:
x === ParseDocumentType.PlainText ? t(camelCase(x)) : 'DeepDoc',
value: x,
}));
const image2TextList = allOptions[LlmModelType.Image2text].map((x) => { const image2TextList = allOptions[LlmModelType.Image2text].map((x) => {
return { return {
@ -48,38 +62,40 @@ export function LayoutRecognizeFormField() {
}); });
return [...list, ...image2TextList]; return [...list, ...image2TextList];
}, [allOptions, t]); }, [allOptions, optionsWithoutLLM, t]);
return ( return (
<FormField <FormField
control={form.control} control={form.control}
name="parser_config.layout_recognize" name={name}
render={({ field }) => { render={({ field }) => {
if (typeof field.value === 'undefined') {
// default value set
form.setValue(
'parser_config.layout_recognize',
form.formState.defaultValues?.parser_config?.layout_recognize ??
'DeepDOC',
);
}
return ( return (
<FormItem className=" items-center space-y-0 "> <FormItem className={'items-center space-y-0 '}>
<div className="flex items-center"> <div
className={cn('flex', {
'flex-col ': !horizontal,
'items-center': horizontal,
})}
>
<FormLabel <FormLabel
tooltip={t('layoutRecognizeTip')} tooltip={t('layoutRecognizeTip')}
className="text-sm text-muted-foreground whitespace-wrap w-1/4" className={cn('text-sm text-muted-foreground whitespace-wrap', {
['w-1/4']: horizontal,
})}
> >
{t('layoutRecognize')} {label || t('layoutRecognize')}
</FormLabel> </FormLabel>
<div className="w-3/4"> <div className={horizontal ? 'w-3/4' : 'w-full'}>
<FormControl> <FormControl>
<RAGFlowSelect {...field} options={options}></RAGFlowSelect> <SelectWithSearch
{...field}
options={options}
></SelectWithSearch>
</FormControl> </FormControl>
</div> </div>
</div> </div>
<div className="flex pt-1"> <div className="flex pt-1">
<div className="w-1/4"></div> <div className={horizontal ? 'w-1/4' : 'w-full'}></div>
<FormMessage /> <FormMessage />
</div> </div>
</FormItem> </FormItem>

View File

@ -0,0 +1,25 @@
import { LlmModelType } from '@/constants/knowledge';
import { useComposeLlmOptionsByModelTypes } from '@/hooks/llm-hooks';
import { useTranslation } from 'react-i18next';
import { SelectWithSearch } from '../originui/select-with-search';
import { RAGFlowFormItem } from '../ragflow-form';
type LLMFormFieldProps = {
options?: any[];
name?: string;
};
export function LLMFormField({ options, name }: LLMFormFieldProps) {
const { t } = useTranslation();
const modelOptions = useComposeLlmOptionsByModelTypes([
LlmModelType.Chat,
LlmModelType.Image2text,
]);
return (
<RAGFlowFormItem name={name || 'llm_id'} label={t('chat.model')}>
<SelectWithSearch options={options || modelOptions}></SelectWithSearch>
</RAGFlowFormItem>
);
}

View File

@ -1,11 +1,9 @@
import { LlmModelType, ModelVariableType } from '@/constants/knowledge'; import { ModelVariableType } from '@/constants/knowledge';
import { useTranslate } from '@/hooks/common-hooks'; import { useTranslate } from '@/hooks/common-hooks';
import { useComposeLlmOptionsByModelTypes } from '@/hooks/llm-hooks';
import { camelCase } from 'lodash'; import { camelCase } from 'lodash';
import { useCallback } from 'react'; import { useCallback } from 'react';
import { useFormContext } from 'react-hook-form'; import { useFormContext } from 'react-hook-form';
import { z } from 'zod'; import { z } from 'zod';
import { SelectWithSearch } from '../originui/select-with-search';
import { import {
FormControl, FormControl,
FormField, FormField,
@ -20,6 +18,7 @@ import {
SelectTrigger, SelectTrigger,
SelectValue, SelectValue,
} from '../ui/select'; } from '../ui/select';
import { LLMFormField } from './llm-form-field';
import { SliderInputSwitchFormField } from './slider'; import { SliderInputSwitchFormField } from './slider';
import { useHandleFreedomChange } from './use-watch-change'; import { useHandleFreedomChange } from './use-watch-change';
@ -61,11 +60,6 @@ export function LlmSettingFieldItems({
const form = useFormContext(); const form = useFormContext();
const { t } = useTranslate('chat'); const { t } = useTranslate('chat');
const modelOptions = useComposeLlmOptionsByModelTypes([
LlmModelType.Chat,
LlmModelType.Image2text,
]);
const getFieldWithPrefix = useCallback( const getFieldWithPrefix = useCallback(
(name: string) => { (name: string) => {
return prefix ? `${prefix}.${name}` : name; return prefix ? `${prefix}.${name}` : name;
@ -82,22 +76,7 @@ export function LlmSettingFieldItems({
return ( return (
<div className="space-y-5"> <div className="space-y-5">
<FormField <LLMFormField options={options}></LLMFormField>
control={form.control}
name={'llm_id'}
render={({ field }) => (
<FormItem>
<FormLabel>{t('model')}</FormLabel>
<FormControl>
<SelectWithSearch
options={options || modelOptions}
{...field}
></SelectWithSearch>
</FormControl>
<FormMessage />
</FormItem>
)}
/>
<FormField <FormField
control={form.control} control={form.control}
name={'parameter'} name={'parameter'}

View File

@ -45,8 +45,26 @@ export type SelectWithSearchFlagProps = {
onChange?(value: string): void; onChange?(value: string): void;
triggerClassName?: string; triggerClassName?: string;
allowClear?: boolean; allowClear?: boolean;
disabled?: boolean;
placeholder?: string;
}; };
function findLabelWithoutOptions(
options: SelectWithSearchFlagOptionType[],
value: string,
) {
return options.find((opt) => opt.value === value)?.label || '';
}
function findLabelWithOptions(
options: SelectWithSearchFlagOptionType[],
value: string,
) {
return options
.map((group) => group?.options?.find((item) => item.value === value))
.filter(Boolean)[0]?.label;
}
export const SelectWithSearch = forwardRef< export const SelectWithSearch = forwardRef<
React.ElementRef<typeof Button>, React.ElementRef<typeof Button>,
SelectWithSearchFlagProps SelectWithSearchFlagProps
@ -58,6 +76,8 @@ export const SelectWithSearch = forwardRef<
options = [], options = [],
triggerClassName, triggerClassName,
allowClear = false, allowClear = false,
disabled = false,
placeholder = t('common.selectPlaceholder'),
}, },
ref, ref,
) => { ) => {
@ -65,6 +85,28 @@ export const SelectWithSearch = forwardRef<
const [open, setOpen] = useState<boolean>(false); const [open, setOpen] = useState<boolean>(false);
const [value, setValue] = useState<string>(''); const [value, setValue] = useState<string>('');
const selectLabel = useMemo(() => {
if (options.every((x) => x.options === undefined)) {
return findLabelWithoutOptions(options, value);
} else if (options.every((x) => Array.isArray(x.options))) {
return findLabelWithOptions(options, value);
} else {
// Some have options, some don't
const optionsWithOptions = options.filter((x) =>
Array.isArray(x.options),
);
const optionsWithoutOptions = options.filter(
(x) => x.options === undefined,
);
const label = findLabelWithOptions(optionsWithOptions, value);
if (label) {
return label;
}
return findLabelWithoutOptions(optionsWithoutOptions, value);
}
}, [options, value]);
const handleSelect = useCallback( const handleSelect = useCallback(
(val: string) => { (val: string) => {
setValue(val); setValue(val);
@ -86,16 +128,7 @@ export const SelectWithSearch = forwardRef<
useEffect(() => { useEffect(() => {
setValue(val); setValue(val);
}, [val]); }, [val]);
const selectLabel = useMemo(() => {
const optionTemp = options[0];
if (optionTemp?.options) {
return options
.map((group) => group?.options?.find((item) => item.value === value))
.filter(Boolean)[0]?.label;
} else {
return options.find((opt) => opt.value === value)?.label || '';
}
}, [options, value]);
return ( return (
<Popover open={open} onOpenChange={setOpen}> <Popover open={open} onOpenChange={setOpen}>
<PopoverTrigger asChild> <PopoverTrigger asChild>
@ -105,6 +138,7 @@ export const SelectWithSearch = forwardRef<
role="combobox" role="combobox"
aria-expanded={open} aria-expanded={open}
ref={ref} ref={ref}
disabled={disabled}
className={cn( className={cn(
'bg-background hover:bg-background border-input w-full justify-between px-3 font-normal outline-offset-0 outline-none focus-visible:outline-[3px] [&_svg]:pointer-events-auto', 'bg-background hover:bg-background border-input w-full justify-between px-3 font-normal outline-offset-0 outline-none focus-visible:outline-[3px] [&_svg]:pointer-events-auto',
triggerClassName, triggerClassName,
@ -115,9 +149,7 @@ export const SelectWithSearch = forwardRef<
<span className="leading-none truncate">{selectLabel}</span> <span className="leading-none truncate">{selectLabel}</span>
</span> </span>
) : ( ) : (
<span className="text-muted-foreground"> <span className="text-muted-foreground">{placeholder}</span>
{t('common.selectPlaceholder')}
</span>
)} )}
<div className="flex items-center justify-between"> <div className="flex items-center justify-between">
{value && allowClear && ( {value && allowClear && (
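
Editor's note: a standalone sketch (not part of the commit) of the label resolution the new selectLabel memo handles, for flat, grouped, and mixed option lists. The option shape loosely mirrors SelectWithSearchFlagOptionType.

type Opt = { label: string; value: string; options?: Opt[] };

const findLabelWithoutOptions = (opts: Opt[], value: string) =>
  opts.find((o) => o.value === value)?.label || '';

const findLabelWithOptions = (opts: Opt[], value: string) =>
  opts.map((g) => g.options?.find((o) => o.value === value)).filter(Boolean)[0]?.label;

const flat: Opt[] = [{ label: 'DeepDoc', value: 'DeepDOC' }];
const grouped: Opt[] = [
  { label: 'OpenAI', value: 'OpenAI', options: [{ label: 'gpt-4o', value: 'gpt-4o@OpenAI' }] },
];

console.log(findLabelWithoutOptions(flat, 'DeepDOC'));       // DeepDoc
console.log(findLabelWithOptions(grouped, 'gpt-4o@OpenAI')); // gpt-4o
// For mixed lists, the memo splits them: grouped entries are searched first,
// then flat entries, so a value can live in either shape.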

View File

@ -1,7 +1,8 @@
'use client'; 'use client';
import { cn } from '@/lib/utils'; import { cn } from '@/lib/utils';
import { parseColorToRGBA } from '@/utils/common-util'; import { TimelineNodeType } from '@/pages/dataflow-result/constant';
import { parseColorToRGB } from '@/utils/common-util';
import { Slot } from '@radix-ui/react-slot'; import { Slot } from '@radix-ui/react-slot';
import * as React from 'react'; import * as React from 'react';
@ -220,6 +221,8 @@ interface TimelineNode
completed?: boolean; completed?: boolean;
clickable?: boolean; clickable?: boolean;
activeStyle?: TimelineIndicatorNodeProps; activeStyle?: TimelineIndicatorNodeProps;
detail?: any;
type?: TimelineNodeType;
} }
interface CustomTimelineProps extends React.HTMLAttributes<HTMLDivElement> { interface CustomTimelineProps extends React.HTMLAttributes<HTMLDivElement> {
@ -243,7 +246,7 @@ const CustomTimeline = ({
orientation = 'horizontal', orientation = 'horizontal',
lineStyle = 'solid', lineStyle = 'solid',
lineColor = 'var(--text-secondary)', lineColor = 'var(--text-secondary)',
indicatorColor = 'var(--accent-primary)', indicatorColor = 'rgb(var(--accent-primary))',
defaultValue = 1, defaultValue = 1,
className, className,
activeStyle, activeStyle,
@ -251,8 +254,7 @@ const CustomTimeline = ({
}: CustomTimelineProps) => { }: CustomTimelineProps) => {
const [internalActiveStep, setInternalActiveStep] = const [internalActiveStep, setInternalActiveStep] =
React.useState(defaultValue); React.useState(defaultValue);
const _lineColor = `rgb(${parseColorToRGBA(lineColor)})`; const _lineColor = `rgb(${parseColorToRGB(lineColor)})`;
console.log(lineColor, _lineColor);
const currentActiveStep = activeStep ?? internalActiveStep; const currentActiveStep = activeStep ?? internalActiveStep;
const handleStepChange = (step: number, id: string | number) => { const handleStepChange = (step: number, id: string | number) => {
@ -261,7 +263,7 @@ const CustomTimeline = ({
} }
onStepChange?.(step, id); onStepChange?.(step, id);
}; };
const [r, g, b] = parseColorToRGBA(indicatorColor); const [r, g, b] = parseColorToRGB(indicatorColor);
return ( return (
<Timeline <Timeline
value={currentActiveStep} value={currentActiveStep}
@ -284,8 +286,6 @@ const CustomTimeline = ({
typeof _nodeSizeTemp === 'number' typeof _nodeSizeTemp === 'number'
? `${_nodeSizeTemp}px` ? `${_nodeSizeTemp}px`
: _nodeSizeTemp; : _nodeSizeTemp;
console.log('icon-size', nodeSize, node.nodeSize, _nodeSize);
// const activeStyle = _activeStyle || {};
return ( return (
<TimelineItem <TimelineItem
@ -372,11 +372,10 @@ const CustomTimeline = ({
)} )}
</TimelineIndicator> </TimelineIndicator>
<TimelineHeader> <TimelineHeader className="transform -translate-x-[40%] text-center">
{node.date && <TimelineDate>{node.date}</TimelineDate>}
<TimelineTitle <TimelineTitle
className={cn( className={cn(
'text-sm font-medium', 'text-sm font-medium -ml-1',
isActive && _activeStyle.textColor isActive && _activeStyle.textColor
? `text-${_activeStyle.textColor}` ? `text-${_activeStyle.textColor}`
: '', : '',
@ -387,6 +386,7 @@ const CustomTimeline = ({
> >
{node.title} {node.title}
</TimelineTitle> </TimelineTitle>
{node.date && <TimelineDate>{node.date}</TimelineDate>}
</TimelineHeader> </TimelineHeader>
{node.content && <TimelineContent>{node.content}</TimelineContent>} {node.content && <TimelineContent>{node.content}</TimelineContent>}
</TimelineItem> </TimelineItem>

View File

@ -1,6 +1,11 @@
import { DocumentParserType } from '@/constants/knowledge'; import { DocumentParserType } from '@/constants/knowledge';
import { useTranslate } from '@/hooks/common-hooks'; import { useTranslate } from '@/hooks/common-hooks';
import { cn } from '@/lib/utils'; import { cn } from '@/lib/utils';
import {
GenerateLogButton,
GenerateType,
IGenerateLogButtonProps,
} from '@/pages/dataset/dataset/generate-button/generate';
import { upperFirst } from 'lodash'; import { upperFirst } from 'lodash';
import { useCallback, useMemo } from 'react'; import { useCallback, useMemo } from 'react';
import { useFormContext, useWatch } from 'react-hook-form'; import { useFormContext, useWatch } from 'react-hook-form';
@ -47,9 +52,17 @@ export const showGraphRagItems = (parserId: DocumentParserType | undefined) => {
type GraphRagItemsProps = { type GraphRagItemsProps = {
marginBottom?: boolean; marginBottom?: boolean;
className?: string; className?: string;
data: IGenerateLogButtonProps;
onDelete?: () => void;
}; };
export function UseGraphRagFormField() { export function UseGraphRagFormField({
data,
onDelete,
}: {
data: IGenerateLogButtonProps;
onDelete?: () => void;
}) {
const form = useFormContext(); const form = useFormContext();
const { t } = useTranslate('knowledgeConfiguration'); const { t } = useTranslate('knowledgeConfiguration');
@ -62,16 +75,23 @@ export function UseGraphRagFormField() {
<div className="flex items-center gap-1"> <div className="flex items-center gap-1">
<FormLabel <FormLabel
tooltip={t('useGraphRagTip')} tooltip={t('useGraphRagTip')}
className="text-sm text-muted-foreground whitespace-break-spaces w-1/4" className="text-sm whitespace-break-spaces w-1/4"
> >
{t('useGraphRag')} {t('useGraphRag')}
</FormLabel> </FormLabel>
<div className="w-3/4"> <div className="w-3/4">
<FormControl> <FormControl>
<Switch {/* <Switch
checked={field.value} checked={field.value}
onCheckedChange={field.onChange} onCheckedChange={field.onChange}
></Switch> ></Switch> */}
<GenerateLogButton
{...data}
onDelete={onDelete}
className="w-full text-text-secondary"
status={1}
type={GenerateType.KnowledgeGraph}
/>
</FormControl> </FormControl>
</div> </div>
</div> </div>
@ -89,6 +109,8 @@ export function UseGraphRagFormField() {
const GraphRagItems = ({ const GraphRagItems = ({
marginBottom = false, marginBottom = false,
className = 'p-10', className = 'p-10',
data,
onDelete,
}: GraphRagItemsProps) => { }: GraphRagItemsProps) => {
const { t } = useTranslate('knowledgeConfiguration'); const { t } = useTranslate('knowledgeConfiguration');
const form = useFormContext(); const form = useFormContext();
@ -114,7 +136,10 @@ const GraphRagItems = ({
return ( return (
<FormContainer className={cn({ 'mb-4': marginBottom }, className)}> <FormContainer className={cn({ 'mb-4': marginBottom }, className)}>
<UseGraphRagFormField></UseGraphRagFormField> <UseGraphRagFormField
data={data}
onDelete={onDelete}
></UseGraphRagFormField>
{useRaptor && ( {useRaptor && (
<> <>
<EntityTypesFormField name="parser_config.graphrag.entity_types"></EntityTypesFormField> <EntityTypesFormField name="parser_config.graphrag.entity_types"></EntityTypesFormField>
@ -125,7 +150,7 @@ const GraphRagItems = ({
<FormItem className=" items-center space-y-0 "> <FormItem className=" items-center space-y-0 ">
<div className="flex items-center"> <div className="flex items-center">
<FormLabel <FormLabel
className="text-sm text-muted-foreground whitespace-nowrap w-1/4" className="text-sm whitespace-nowrap w-1/4"
tooltip={renderWideTooltip( tooltip={renderWideTooltip(
<div <div
dangerouslySetInnerHTML={{ dangerouslySetInnerHTML={{
@ -161,7 +186,7 @@ const GraphRagItems = ({
<div className="flex items-center"> <div className="flex items-center">
<FormLabel <FormLabel
tooltip={renderWideTooltip('resolutionTip')} tooltip={renderWideTooltip('resolutionTip')}
className="text-sm text-muted-foreground whitespace-nowrap w-1/4" className="text-sm whitespace-nowrap w-1/4"
> >
{t('resolution')} {t('resolution')}
</FormLabel> </FormLabel>
@ -190,7 +215,7 @@ const GraphRagItems = ({
<div className="flex items-center"> <div className="flex items-center">
<FormLabel <FormLabel
tooltip={renderWideTooltip('communityTip')} tooltip={renderWideTooltip('communityTip')}
className="text-sm text-muted-foreground whitespace-nowrap w-1/4" className="text-sm whitespace-nowrap w-1/4"
> >
{t('community')} {t('community')}
</FormLabel> </FormLabel>
@ -210,6 +235,18 @@ const GraphRagItems = ({
</FormItem> </FormItem>
)} )}
/> />
{/* {showGenerateItem && (
<div className="w-full flex items-center">
<div className="text-sm whitespace-nowrap w-1/4">
{t('extractKnowledgeGraph')}
</div>
<GenerateLogButton
className="w-3/4 text-text-secondary"
status={1}
type={GenerateType.KnowledgeGraph}
/>
</div>
)} */}
</> </>
)} )}
</FormContainer> </FormContainer>

View File

@ -1,12 +1,16 @@
import { FormLayout } from '@/constants/form'; import { FormLayout } from '@/constants/form';
import { DocumentParserType } from '@/constants/knowledge'; import { DocumentParserType } from '@/constants/knowledge';
import { useTranslate } from '@/hooks/common-hooks'; import { useTranslate } from '@/hooks/common-hooks';
import {
GenerateLogButton,
GenerateType,
IGenerateLogButtonProps,
} from '@/pages/dataset/dataset/generate-button/generate';
import random from 'lodash/random'; import random from 'lodash/random';
import { Plus } from 'lucide-react'; import { Shuffle } from 'lucide-react';
import { useCallback } from 'react'; import { useCallback } from 'react';
import { useFormContext, useWatch } from 'react-hook-form'; import { useFormContext, useWatch } from 'react-hook-form';
import { SliderInputFormField } from '../slider-input-form-field'; import { SliderInputFormField } from '../slider-input-form-field';
import { Button } from '../ui/button';
import { import {
FormControl, FormControl,
FormField, FormField,
@ -14,8 +18,7 @@ import {
FormLabel, FormLabel,
FormMessage, FormMessage,
} from '../ui/form'; } from '../ui/form';
import { Input } from '../ui/input'; import { ExpandedInput } from '../ui/input';
import { Switch } from '../ui/switch';
import { Textarea } from '../ui/textarea'; import { Textarea } from '../ui/textarea';
export const excludedParseMethods = [ export const excludedParseMethods = [
@ -53,7 +56,13 @@ const Prompt = 'parser_config.raptor.prompt';
// The three types "table", "resume" and "one" do not display this configuration. // The three types "table", "resume" and "one" do not display this configuration.
const RaptorFormFields = () => { const RaptorFormFields = ({
data,
onDelete,
}: {
data: IGenerateLogButtonProps;
onDelete: () => void;
}) => {
const form = useFormContext(); const form = useFormContext();
const { t } = useTranslate('knowledgeConfiguration'); const { t } = useTranslate('knowledgeConfiguration');
const useRaptor = useWatch({ name: UseRaptorField }); const useRaptor = useWatch({ name: UseRaptorField });
@ -93,7 +102,7 @@ const RaptorFormFields = () => {
<div className="flex items-center gap-1"> <div className="flex items-center gap-1">
<FormLabel <FormLabel
tooltip={t('useRaptorTip')} tooltip={t('useRaptorTip')}
className="text-sm text-muted-foreground w-1/4 whitespace-break-spaces" className="text-sm w-1/4 whitespace-break-spaces"
> >
<div className="w-auto xl:w-20 2xl:w-24 3xl:w-28 4xl:w-auto "> <div className="w-auto xl:w-20 2xl:w-24 3xl:w-28 4xl:w-auto ">
{t('useRaptor')} {t('useRaptor')}
@ -101,13 +110,13 @@ const RaptorFormFields = () => {
</FormLabel> </FormLabel>
<div className="w-3/4"> <div className="w-3/4">
<FormControl> <FormControl>
<Switch <GenerateLogButton
checked={field.value} {...data}
onCheckedChange={(e) => { onDelete={onDelete}
changeRaptor(e); className="w-full text-text-secondary"
field.onChange(e); status={1}
}} type={GenerateType.Raptor}
></Switch> />
</FormControl> </FormControl>
</div> </div>
</div> </div>
@ -130,7 +139,7 @@ const RaptorFormFields = () => {
<div className="flex items-start"> <div className="flex items-start">
<FormLabel <FormLabel
tooltip={t('promptTip')} tooltip={t('promptTip')}
className="text-sm text-muted-foreground whitespace-nowrap w-1/4" className="text-sm whitespace-nowrap w-1/4"
> >
{t('prompt')} {t('prompt')}
</FormLabel> </FormLabel>
@ -185,21 +194,23 @@ const RaptorFormFields = () => {
render={({ field }) => ( render={({ field }) => (
<FormItem className=" items-center space-y-0 "> <FormItem className=" items-center space-y-0 ">
<div className="flex items-center"> <div className="flex items-center">
<FormLabel className="text-sm text-muted-foreground whitespace-wrap w-1/4"> <FormLabel className="text-sm whitespace-wrap w-1/4">
{t('randomSeed')} {t('randomSeed')}
</FormLabel> </FormLabel>
<div className="w-3/4"> <div className="w-3/4">
<FormControl defaultValue={0}> <FormControl defaultValue={0}>
<div className="flex gap-4 items-center"> <ExpandedInput
<Input {...field} defaultValue={0} type="number" /> {...field}
<Button className="w-full"
size={'sm'} defaultValue={0}
onClick={handleGenerate} type="number"
type={'button'} suffix={
> <Shuffle
<Plus /> className="size-3.5 cursor-pointer"
</Button> onClick={handleGenerate}
</div> />
}
/>
</FormControl> </FormControl>
</div> </div>
</div> </div>
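
Editor's note: the Shuffle suffix presumably re-rolls the RAPTOR random seed using the lodash random import kept in this file; the handler below is a hypothetical sketch, and the bounds are an assumption rather than the commit's actual values.

import random from 'lodash/random';

// Hypothetical handler: write a fresh integer seed back into the form field
const handleGenerate = (setValue: (name: string, value: number) => void) => {
  setValue('parser_config.raptor.random_seed', random(0, 10000));
};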

View File

@ -11,11 +11,12 @@ import { ControllerRenderProps, useFormContext } from 'react-hook-form';
type RAGFlowFormItemProps = { type RAGFlowFormItemProps = {
name: string; name: string;
label: ReactNode; label?: ReactNode;
tooltip?: ReactNode; tooltip?: ReactNode;
children: ReactNode | ((field: ControllerRenderProps) => ReactNode); children: ReactNode | ((field: ControllerRenderProps) => ReactNode);
horizontal?: boolean; horizontal?: boolean;
required?: boolean; required?: boolean;
labelClassName?: string;
}; };
export function RAGFlowFormItem({ export function RAGFlowFormItem({
@ -25,6 +26,7 @@ export function RAGFlowFormItem({
children, children,
horizontal = false, horizontal = false,
required = false, required = false,
labelClassName,
}: RAGFlowFormItemProps) { }: RAGFlowFormItemProps) {
const form = useFormContext(); const form = useFormContext();
return ( return (
@ -37,13 +39,15 @@ export function RAGFlowFormItem({
'flex items-center': horizontal, 'flex items-center': horizontal,
})} })}
> >
<FormLabel {label && (
required={required} <FormLabel
tooltip={tooltip} required={required}
className={cn({ 'w-1/4': horizontal })} tooltip={tooltip}
> className={cn({ 'w-1/4': horizontal }, labelClassName)}
{label} >
</FormLabel> {label}
</FormLabel>
)}
<FormControl> <FormControl>
{typeof children === 'function' {typeof children === 'function'
? children(field) ? children(field)

View File

@ -54,8 +54,7 @@ export function SliderInputFormField({
<FormLabel <FormLabel
tooltip={tooltip} tooltip={tooltip}
className={cn({ className={cn({
'text-sm text-muted-foreground whitespace-break-spaces w-1/4': 'text-sm whitespace-break-spaces w-1/4': isHorizontal,
isHorizontal,
})} })}
> >
{label} {label}

View File

@ -28,7 +28,7 @@ const DualRangeSlider = React.forwardRef<
)} )}
{...props} {...props}
> >
<SliderPrimitive.Track className="relative h-2 w-full grow overflow-hidden rounded-full bg-secondary"> <SliderPrimitive.Track className="relative h-2 w-full grow overflow-hidden rounded-full bg-border-button">
<SliderPrimitive.Range className="absolute h-full bg-accent-primary" /> <SliderPrimitive.Range className="absolute h-full bg-accent-primary" />
</SliderPrimitive.Track> </SliderPrimitive.Track>
{initialValue.map((value, index) => ( {initialValue.map((value, index) => (

View File

@ -31,6 +31,7 @@ export interface ModalProps {
export interface ModalType extends FC<ModalProps> { export interface ModalType extends FC<ModalProps> {
show: typeof modalIns.show; show: typeof modalIns.show;
hide: typeof modalIns.hide; hide: typeof modalIns.hide;
destroy: typeof modalIns.destroy;
} }
const Modal: ModalType = ({ const Modal: ModalType = ({
@ -76,20 +77,20 @@ const Modal: ModalType = ({
const handleCancel = useCallback(() => { const handleCancel = useCallback(() => {
onOpenChange?.(false); onOpenChange?.(false);
onCancel?.(); onCancel?.();
}, [onOpenChange, onCancel]); }, [onCancel, onOpenChange]);
const handleOk = useCallback(() => { const handleOk = useCallback(() => {
onOpenChange?.(true); onOpenChange?.(true);
onOk?.(); onOk?.();
}, [onOpenChange, onOk]); }, [onOk, onOpenChange]);
const handleChange = (open: boolean) => { const handleChange = (open: boolean) => {
onOpenChange?.(open); onOpenChange?.(open);
console.log('open', open, onOpenChange); console.log('open', open, onOpenChange);
if (open) { if (open) {
handleOk(); onOk?.();
} }
if (!open) { if (!open) {
handleCancel(); onCancel?.();
} }
}; };
const footEl = useMemo(() => { const footEl = useMemo(() => {
@ -177,7 +178,7 @@ const Modal: ModalType = ({
<DialogPrimitive.Close asChild> <DialogPrimitive.Close asChild>
<button <button
type="button" type="button"
className="flex h-7 w-7 items-center justify-center rounded-full hover:bg-muted" className="flex h-7 w-7 items-center justify-center rounded-full hover:bg-muted focus-visible:outline-none"
> >
{closeIcon} {closeIcon}
</button> </button>
@ -187,7 +188,7 @@ const Modal: ModalType = ({
)} )}
{/* content */} {/* content */}
<div className="py-2 px-6 overflow-y-auto max-h-[80vh] focus-visible:!outline-none"> <div className="py-2 px-6 overflow-y-auto scrollbar-auto max-h-[80vh] focus-visible:!outline-none">
{destroyOnClose && !open ? null : children} {destroyOnClose && !open ? null : children}
</div> </div>
@ -208,5 +209,6 @@ Modal.show = modalIns
return modalIns.show; return modalIns.show;
}; };
Modal.hide = modalIns.hide; Modal.hide = modalIns.hide;
Modal.destroy = modalIns.destroy;
export { Modal }; export { Modal };

View File

@ -49,7 +49,7 @@ function Radio({ value, checked, disabled, onChange, children }: RadioProps) {
> >
<span <span
className={cn( className={cn(
'flex h-4 w-4 items-center justify-center rounded-full border border-input transition-colors', 'flex h-4 w-4 items-center justify-center rounded-full border border-border transition-colors',
'peer ring-offset-background focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2', 'peer ring-offset-background focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2',
isChecked && 'border-primary bg-primary/10', isChecked && 'border-primary bg-primary/10',
mergedDisabled && 'border-muted', mergedDisabled && 'border-muted',

View File

@ -33,7 +33,12 @@ export { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger };
export const FormTooltip = ({ tooltip }: { tooltip: React.ReactNode }) => {
return (
<Tooltip>
- <TooltipTrigger tabIndex={-1}>
+ <TooltipTrigger
tabIndex={-1}
onClick={(e) => {
e.preventDefault(); // Prevent clicking the tooltip from triggering form save
}}
>
<Info className="size-3 ml-2" />
</TooltipTrigger>
<TooltipContent>
@ -107,7 +112,7 @@ export const AntToolTip: React.FC<AntToolTipProps> = ({
{visible && title && (
<div
className={cn(
- 'absolute z-50 px-2.5 py-2 text-xs text-text-primary bg-muted rounded-sm shadow-sm whitespace-wrap',
+ 'absolute z-50 px-2.5 py-2 text-xs text-text-primary bg-muted rounded-sm shadow-sm whitespace-wrap w-max',
getPlacementClasses(),
className,
)}

View File

@ -1,3 +1,6 @@
import { setInitialChatVariableEnabledFieldValue } from '@/utils/chat';
import { ChatVariableEnabledField, variableEnabledFieldMap } from './chat';
export enum ProgrammingLanguage {
Python = 'python',
Javascript = 'javascript',
@ -26,3 +29,26 @@ export enum AgentGlobals {
}
export const AgentGlobalsSysQueryWithBrace = `{${AgentGlobals.SysQuery}}`;
export const variableCheckBoxFieldMap = Object.keys(
variableEnabledFieldMap,
).reduce<Record<string, boolean>>((pre, cur) => {
pre[cur] = setInitialChatVariableEnabledFieldValue(
cur as ChatVariableEnabledField,
);
return pre;
}, {});
export const initialLlmBaseValues = {
...variableCheckBoxFieldMap,
temperature: 0.1,
top_p: 0.3,
frequency_penalty: 0.7,
presence_penalty: 0.4,
max_tokens: 256,
};
export enum AgentCategory {
AgentCanvas = 'agent_canvas',
DataflowCanvas = 'dataflow_canvas',
}
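Illustrative use of the new constants (import path and override values are assumptions):
import { AgentCategory, initialLlmBaseValues } from '@/constants/agent'; // assumed module path

// Seed an LLM settings form with the shared defaults, overriding a single field.
export const defaultLlmSettings = { ...initialLlmBaseValues, temperature: 0.2 };

// Branch UI behaviour on the new canvas category.
export const isDataflowCanvas = (category: string) =>
  category === AgentCategory.DataflowCanvas;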

View File

@ -15,6 +15,14 @@ export enum RunningStatus {
FAIL = '4', // need to refresh
}
export const RunningStatusMap = {
[RunningStatus.UNSTART]: 'Pending',
[RunningStatus.RUNNING]: 'Running',
[RunningStatus.CANCEL]: 'Cancel',
[RunningStatus.DONE]: 'Success',
[RunningStatus.FAIL]: 'Failed',
};
export enum ModelVariableType {
Improvise = 'Improvise',
Precise = 'Precise',
@ -57,6 +65,7 @@ export enum LlmModelType {
export enum KnowledgeSearchParams {
DocumentId = 'doc_id',
KnowledgeId = 'id',
Type = 'type',
}
export enum DocumentType {
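Illustrative lookup against the new RunningStatusMap (import path assumed):
import { RunningStatus, RunningStatusMap } from '@/constants/knowledge'; // assumed module path

// Turn a raw run code ('0'..'4') into the display label used by the parsing-status column.
export const runningStatusLabel = (run: RunningStatus) => RunningStatusMap[run];

runningStatusLabel(RunningStatus.RUNNING); // 'Running'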

View File

@ -1,79 +1,14 @@
- import { ResponseType } from '@/interfaces/database/base';
+ import { DSL, IFlow } from '@/interfaces/database/flow';
import { DSL, IFlow, IFlowTemplate } from '@/interfaces/database/flow';
import { IDebugSingleRequestBody } from '@/interfaces/request/flow';
import i18n from '@/locales/config';
import { useGetSharedChatSearchParams } from '@/pages/chat/shared-hooks';
import { BeginId } from '@/pages/flow/constant';
import flowService from '@/services/flow-service';
import { buildMessageListWithUuid } from '@/utils/chat';
import { useMutation, useQuery, useQueryClient } from '@tanstack/react-query';
import { message } from 'antd';
import { set } from 'lodash';
import get from 'lodash/get';
import { useTranslation } from 'react-i18next';
import { useParams } from 'umi';
import { v4 as uuid } from 'uuid';
export const EmptyDsl = {
graph: {
nodes: [
{
id: BeginId,
type: 'beginNode',
position: {
x: 50,
y: 200,
},
data: {
label: 'Begin',
name: 'begin',
},
sourcePosition: 'left',
targetPosition: 'right',
},
],
edges: [],
},
components: {
begin: {
obj: {
component_name: 'Begin',
params: {},
},
downstream: ['Answer:China'], // other edge target is downstream, edge source is current node id
upstream: [], // edge source is upstream, edge target is current node id
},
},
messages: [],
reference: [],
history: [],
path: [],
answer: [],
};
export const useFetchFlowTemplates = (): ResponseType<IFlowTemplate[]> => {
const { t } = useTranslation();
const { data } = useQuery({
queryKey: ['fetchFlowTemplates'],
initialData: [],
queryFn: async () => {
const { data } = await flowService.listTemplates();
if (Array.isArray(data?.data)) {
data.data.unshift({
id: uuid(),
title: t('flow.blank'),
description: t('flow.createFromNothing'),
dsl: EmptyDsl,
});
}
return data;
},
});
return data;
};
export const useFetchFlowList = (): { data: IFlow[]; loading: boolean } => {
const { data, isFetching: loading } = useQuery({

View File

@ -1,3 +1,4 @@
import { NavigateToDataflowResultProps } from '@/pages/dataflow-result/interface';
import { Routes } from '@/routes';
import { useCallback } from 'react';
import { useNavigate, useParams, useSearchParams } from 'umi';
@ -18,7 +19,14 @@ export const useNavigatePage = () => {
const navigateToDataset = useCallback(
(id: string) => () => {
- navigate(`${Routes.Dataset}/${id}`);
+ navigate(`${Routes.DatasetBase}${Routes.DataSetOverview}/${id}`);
},
[navigate],
);
const navigateToDataFile = useCallback(
(id: string) => () => {
navigate(`${Routes.DatasetBase}${Routes.DatasetBase}/${id}`);
},
[navigate],
);
@ -61,6 +69,13 @@ export const useNavigatePage = () => {
[navigate],
);
const navigateToDataflow = useCallback(
(id: string) => () => {
navigate(`${Routes.DataFlow}/${id}`);
},
[navigate],
);
const navigateToAgentLogs = useCallback(
(id: string) => () => {
navigate(`${Routes.AgentLogPage}/${id}`);
@ -86,8 +101,8 @@ export const useNavigatePage = () => {
const navigateToChunkParsedResult = useCallback(
(id: string, knowledgeId?: string) => () => {
navigate(
// `${Routes.ParsedResult}/${id}?${QueryStringMap.KnowledgeId}=${knowledgeId}`,
`${Routes.ParsedResult}/chunks?id=${knowledgeId}&doc_id=${id}`,
// `${Routes.DataflowResult}?id=${knowledgeId}&doc_id=${id}&type=chunk`,
);
},
[navigate],
@ -126,10 +141,16 @@ export const useNavigatePage = () => {
);
const navigateToDataflowResult = useCallback(
- (id: string, knowledgeId?: string) => () => {
+ (props: NavigateToDataflowResultProps) => () => {
let params: string[] = [];
Object.keys(props).forEach((key) => {
if (props[key]) {
params.push(`${key}=${props[key]}`);
}
});
navigate(
// `${Routes.ParsedResult}/${id}?${QueryStringMap.KnowledgeId}=${knowledgeId}`,
- `${Routes.DataflowResult}/${id}`,
+ `${Routes.DataflowResult}?${params.join('&')}`,
);
},
[navigate],
@ -155,5 +176,7 @@ export const useNavigatePage = () => {
navigateToAgentList,
navigateToOldProfile,
navigateToDataflowResult,
navigateToDataflow,
navigateToDataFile,
};
};
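Sketch of calling the reworked navigateToDataflowResult with a props object (field names and import path are assumptions):
import { useNavigatePage } from '@/hooks/logic-hooks/navigate-hooks'; // assumed import path

// Hypothetical action: open the dataflow result page for one parsed document.
export function useOpenDataflowResult(knowledgeId: string, documentId: string) {
  const { navigateToDataflowResult } = useNavigatePage();
  // Produces `${Routes.DataflowResult}?id=...&doc_id=...&type=chunk`; empty values are skipped.
  return navigateToDataflowResult({ id: knowledgeId, doc_id: documentId, type: 'chunk' });
}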

View File

@ -29,6 +29,7 @@ export const useGetKnowledgeSearchParams = () => {
const [currentQueryParameters] = useSearchParams();
return {
type: currentQueryParameters.get(KnowledgeSearchParams.Type) || '',
documentId:
currentQueryParameters.get(KnowledgeSearchParams.DocumentId) || '',
knowledgeId:

View File

@ -1,4 +1,5 @@
import { FileUploadProps } from '@/components/file-upload';
import { useHandleFilterSubmit } from '@/components/list-filter-bar/use-handle-filter-submit';
import message from '@/components/ui/message';
import { AgentGlobals } from '@/constants/agent';
import {
@ -7,6 +8,7 @@ import {
IAgentLogsResponse,
IFlow,
IFlowTemplate,
IPipeLineListRequest,
ITraceData,
} from '@/interfaces/database/agent';
import { IDebugSingleRequestBody } from '@/interfaces/request/agent';
@ -16,6 +18,7 @@ import { IInputs } from '@/pages/agent/interface';
import { useGetSharedChatSearchParams } from '@/pages/chat/shared-hooks';
import agentService, {
fetchAgentLogsByCanvasId,
fetchPipeLineList,
fetchTrace,
} from '@/services/agent-service';
import api from '@/utils/api';
@ -31,6 +34,7 @@ import {
} from './logic-hooks';
export const enum AgentApiAction {
FetchAgentListByPage = 'fetchAgentListByPage',
FetchAgentList = 'fetchAgentList',
UpdateAgentSetting = 'updateAgentSetting',
DeleteAgent = 'deleteAgent',
@ -50,6 +54,7 @@ export const enum AgentApiAction {
FetchExternalAgentInputs = 'fetchExternalAgentInputs',
SetAgentSetting = 'setAgentSetting',
FetchPrompt = 'fetchPrompt',
CancelDataflow = 'cancelDataflow',
}
export const EmptyDsl = {
@ -111,28 +116,47 @@ export const useFetchAgentListByPage = () => {
const { searchString, handleInputChange } = useHandleSearchChange();
const { pagination, setPagination } = useGetPaginationWithRouter();
const debouncedSearchString = useDebounce(searchString, { wait: 500 });
const { filterValue, handleFilterSubmit } = useHandleFilterSubmit();
const canvasCategory = Array.isArray(filterValue.canvasCategory)
? filterValue.canvasCategory
: [];
const owner = filterValue.owner;
const requestParams: Record<string, any> = {
keywords: debouncedSearchString,
page_size: pagination.pageSize,
page: pagination.current,
canvas_category:
canvasCategory.length === 1 ? canvasCategory[0] : undefined,
};
if (Array.isArray(owner) && owner.length > 0) {
requestParams.owner_ids = owner.join(',');
}
const { data, isFetching: loading } = useQuery<{
canvas: IFlow[];
total: number;
}>({
queryKey: [
- AgentApiAction.FetchAgentList,
+ AgentApiAction.FetchAgentListByPage,
{
debouncedSearchString,
...pagination,
filterValue,
},
],
- initialData: { canvas: [], total: 0 },
+ placeholderData: (previousData) => {
if (previousData === undefined) {
return { canvas: [], total: 0 };
}
return previousData;
},
gcTime: 0,
queryFn: async () => {
- const { data } = await agentService.listCanvasTeam(
+ const { data } = await agentService.listCanvas(
{
- params: {
+ params: requestParams,
keywords: debouncedSearchString,
page_size: pagination.pageSize,
page: pagination.current,
},
},
true,
);
@ -150,12 +174,14 @@ export const useFetchAgentListByPage = () => {
);
return {
- data: data.canvas,
+ data: data?.canvas ?? [],
loading,
searchString,
handleInputChange: onInputChange,
pagination: { ...pagination, total: data?.total },
setPagination,
filterValue,
handleFilterSubmit,
};
};
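Sketch of a page consuming the extended useFetchAgentListByPage hook (component and import path are illustrative):
import { useFetchAgentListByPage } from '@/hooks/use-agent-request'; // assumed import path

// Hypothetical list page wiring; filterValue.canvasCategory and filterValue.owner feed the
// canvas_category and owner_ids request params assembled inside the hook.
export function AgentListPage() {
  const {
    data,
    loading,
    pagination,
    setPagination,
    handleInputChange,
    filterValue,
    handleFilterSubmit,
  } = useFetchAgentListByPage();

  return null; // table/filter rendering omitted
}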
@ -173,7 +199,7 @@ export const useUpdateAgentSetting = () => {
if (ret?.data?.code === 0) {
message.success('success');
queryClient.invalidateQueries({
- queryKey: [AgentApiAction.FetchAgentList],
+ queryKey: [AgentApiAction.FetchAgentListByPage],
});
} else {
message.error(ret?.data?.data);
@ -197,7 +223,7 @@ export const useDeleteAgent = () => {
const { data } = await agentService.removeCanvas({ canvasIds });
if (data.code === 0) {
queryClient.invalidateQueries({
- queryKey: [AgentApiAction.FetchAgentList],
+ queryKey: [AgentApiAction.FetchAgentListByPage],
});
}
return data?.data ?? [];
@ -271,6 +297,7 @@ export const useSetAgent = (showMessage: boolean = true) => {
title?: string;
dsl?: DSL;
avatar?: string;
canvas_category?: string;
}) => {
const { data = {} } = await agentService.setCanvas(params);
if (data.code === 0) {
@ -280,7 +307,7 @@ export const useSetAgent = (showMessage: boolean = true) => {
);
}
queryClient.invalidateQueries({
- queryKey: [AgentApiAction.FetchAgentList],
+ queryKey: [AgentApiAction.FetchAgentListByPage],
});
}
return data;
@ -379,7 +406,7 @@ export const useUploadCanvasFileWithProgress = (
files.forEach((file) => {
onError(file, error as Error);
});
- message.error(error?.message);
+ message.error((error as Error)?.message || 'Upload failed');
}
},
});
@ -387,13 +414,11 @@ export const useUploadCanvasFileWithProgress = (
return { data, loading, uploadCanvasFile: mutateAsync };
};
- export const useFetchMessageTrace = (
+ export const useFetchMessageTrace = (canvasId?: string) => {
isStopFetchTrace: boolean,
canvasId?: string,
) => {
const { id } = useParams();
const queryId = id || canvasId;
const [messageId, setMessageId] = useState('');
const [isStopFetchTrace, setISStopFetchTrace] = useState(false);
const {
data,
@ -413,11 +438,19 @@ export const useFetchMessageTrace = (
message_id: messageId,
});
- return data?.data ?? [];
+ return Array.isArray(data?.data) ? data?.data : [];
},
});
- return { data, loading, refetch, setMessageId };
+ return {
data,
loading,
refetch,
setMessageId,
messageId,
isStopFetchTrace,
setISStopFetchTrace,
};
};
export const useTestDbConnect = () => {
@ -563,7 +596,6 @@ export const useFetchAgentLog = (searchParams: IAgentLogsRequest) => {
initialData: {} as IAgentLogsResponse,
gcTime: 0,
queryFn: async () => {
console.log('useFetchAgentLog', searchParams);
const { data } = await fetchAgentLogsByCanvasId(id as string, {
...searchParams,
});
@ -647,3 +679,59 @@ export const useFetchPrompt = () => {
return { data, loading, refetch };
};
export const useFetchAgentList = ({
canvas_category,
}: IPipeLineListRequest) => {
const { data, isFetching: loading } = useQuery<{
canvas: IFlow[];
total: number;
}>({
queryKey: [AgentApiAction.FetchAgentList],
initialData: { canvas: [], total: 0 },
gcTime: 0,
queryFn: async () => {
const { data } = await fetchPipeLineList({ canvas_category });
return data?.data ?? [];
},
});
return { data, loading };
};
export const useCancelDataflow = () => {
const {
data,
isPending: loading,
mutateAsync,
} = useMutation({
mutationKey: [AgentApiAction.CancelDataflow],
mutationFn: async (taskId: string) => {
const ret = await agentService.cancelDataflow(taskId);
if (ret?.data?.code === 0) {
message.success('success');
} else {
message.error(ret?.data?.data);
}
return ret?.data?.code;
},
});
return { data, loading, cancelDataflow: mutateAsync };
};
// export const useFetchKnowledgeList = () => {
// const { data, isFetching: loading } = useQuery<IFlow[]>({
// queryKey: [AgentApiAction.FetchAgentList],
// initialData: [],
// gcTime: 0, // https://tanstack.com/query/latest/docs/framework/react/guides/caching?from=reactQueryV3
// queryFn: async () => {
// const { data } = await agentService.listCanvas();
// return data?.data ?? [];
// },
// });
// return { list: data, loading };
// };
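Sketch of a cancel handler built on the new useCancelDataflow hook (surrounding names and import path are assumptions):
import { useCallback } from 'react';
import { useCancelDataflow } from '@/hooks/use-agent-request'; // assumed import path

// Hypothetical cancel handler for a running pipeline task.
export function useCancelRunButton(taskId: string) {
  const { cancelDataflow, loading } = useCancelDataflow();
  const onCancel = useCallback(async () => {
    const code = await cancelDataflow(taskId); // the hook already shows a success/error toast
    return code === 0;
  }, [cancelDataflow, taskId]);
  return { onCancel, canceling: loading };
}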

View File

@ -13,7 +13,9 @@ import {
} from './logic-hooks';
import { useGetKnowledgeSearchParams } from './route-hook';
- export const useFetchNextChunkList = (): ResponseGetType<{
+ export const useFetchNextChunkList = (
enabled = true,
): ResponseGetType<{
data: IChunk[];
total: number;
documentInfo: IKnowledgeFile;
@ -37,6 +39,7 @@ export const useFetchNextChunkList = (): ResponseGetType<{
placeholderData: (previousData: any) =>
previousData ?? { data: [], total: 0, documentInfo: {} }, // https://github.com/TanStack/query/issues/8183
gcTime: 0,
enabled,
queryFn: async () => {
const { data } = await kbService.chunk_list({
doc_id: documentId,

View File

@ -0,0 +1,91 @@
import message from '@/components/ui/message';
import { IFlow } from '@/interfaces/database/agent';
import dataflowService from '@/services/dataflow-service';
import { useMutation, useQuery, useQueryClient } from '@tanstack/react-query';
import { useTranslation } from 'react-i18next';
import { useParams } from 'umi';
export const enum DataflowApiAction {
ListDataflow = 'listDataflow',
RemoveDataflow = 'removeDataflow',
FetchDataflow = 'fetchDataflow',
RunDataflow = 'runDataflow',
SetDataflow = 'setDataflow',
}
export const useRemoveDataflow = () => {
const queryClient = useQueryClient();
const { t } = useTranslation();
const {
data,
isPending: loading,
mutateAsync,
} = useMutation({
mutationKey: [DataflowApiAction.RemoveDataflow],
mutationFn: async (ids: string[]) => {
const { data } = await dataflowService.removeDataflow({
canvas_ids: ids,
});
if (data.code === 0) {
queryClient.invalidateQueries({
queryKey: [DataflowApiAction.ListDataflow],
});
message.success(t('message.deleted'));
}
return data.code;
},
});
return { data, loading, removeDataflow: mutateAsync };
};
export const useSetDataflow = () => {
const queryClient = useQueryClient();
const { t } = useTranslation();
const {
data,
isPending: loading,
mutateAsync,
} = useMutation({
mutationKey: [DataflowApiAction.SetDataflow],
mutationFn: async (params: Partial<IFlow>) => {
const { data } = await dataflowService.setDataflow(params);
if (data.code === 0) {
queryClient.invalidateQueries({
queryKey: [DataflowApiAction.FetchDataflow],
});
message.success(t(`message.${params.id ? 'modified' : 'created'}`));
}
return data?.code;
},
});
return { data, loading, setDataflow: mutateAsync };
};
export const useFetchDataflow = () => {
const { id } = useParams();
const {
data,
isFetching: loading,
refetch,
} = useQuery<IFlow>({
queryKey: [DataflowApiAction.FetchDataflow, id],
gcTime: 0,
initialData: {} as IFlow,
enabled: !!id,
refetchOnWindowFocus: false,
queryFn: async () => {
const { data } = await dataflowService.fetchDataflow(id);
return data?.data ?? ({} as IFlow);
},
});
return { data, loading, refetch };
};
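Sketch combining the new dataflow hooks (import path and the title field are assumptions):
import { useFetchDataflow, useSetDataflow } from '@/hooks/use-dataflow-request'; // assumed import path

// Hypothetical rename action: read the dataflow for the current route, then persist a new title.
export function useRenameDataflow() {
  const { data: dataflow } = useFetchDataflow(); // id is taken from the route params
  const { setDataflow, loading } = useSetDataflow();
  const rename = (title: string) => setDataflow({ id: dataflow.id, title }); // invalidates FetchDataflow on success
  return { rename, saving: loading };
}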

View File

@ -335,15 +335,18 @@ export const useSetDocumentParser = () => {
mutationKey: [DocumentApiAction.SetDocumentParser],
mutationFn: async ({
parserId,
pipelineId,
documentId,
parserConfig,
}: {
parserId: string;
pipelineId: string;
documentId: string;
parserConfig: IChangeParserConfigRequestBody;
}) => {
const { data } = await kbService.document_change_parser({
parser_id: parserId,
pipeline_id: pipelineId,
doc_id: documentId,
parser_config: parserConfig,
});
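Sketch of the updated setDocumentParser call, now threading a pipeline id through to document_change_parser (helper names and import paths are assumptions):
import { useCallback } from 'react';
import { IChangeParserConfigRequestBody } from '@/interfaces/request/document'; // assumed path
import { useSetDocumentParser } from '@/hooks/use-document-request'; // assumed path

// Hypothetical handler: re-parse a document with an explicit ingestion pipeline.
export function useApplyParser() {
  const { setDocumentParser } = useSetDocumentParser(); // assumed to expose the mutate function
  return useCallback(
    (documentId: string, pipelineId: string, parserConfig: IChangeParserConfigRequestBody) =>
      setDocumentParser({
        documentId,
        pipelineId, // new field, forwarded as pipeline_id to document_change_parser
        parserId: 'naive', // placeholder parser id
        parserConfig,
      }),
    [setDocumentParser],
  );
}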

View File

@ -31,6 +31,7 @@ export const enum KnowledgeApiAction {
FetchKnowledgeDetail = 'fetchKnowledgeDetail',
FetchKnowledgeGraph = 'fetchKnowledgeGraph',
FetchMetadata = 'fetchMetadata',
FetchKnowledgeList = 'fetchKnowledgeList',
RemoveKnowledgeGraph = 'removeKnowledgeGraph',
}
@ -238,7 +239,11 @@ export const useUpdateKnowledge = (shouldFetchList = false) => {
return { data, loading, saveKnowledgeConfiguration: mutateAsync };
};
- export const useFetchKnowledgeBaseConfiguration = (refreshCount?: number) => {
+ export const useFetchKnowledgeBaseConfiguration = (props?: {
isEdit?: boolean;
refreshCount?: number;
}) => {
const { isEdit = true, refreshCount } = props || { isEdit: true };
const { id } = useParams();
const [searchParams] = useSearchParams();
const knowledgeBaseId = searchParams.get('id') || id;
@ -255,10 +260,14 @@ export const useFetchKnowledgeBaseConfiguration = (refreshCount?: number) => {
initialData: {} as IKnowledge,
gcTime: 0,
queryFn: async () => {
- const { data } = await kbService.get_kb_detail({
- kb_id: knowledgeBaseId,
- });
- return data?.data ?? {};
+ if (isEdit) {
+ const { data } = await kbService.get_kb_detail({
+ kb_id: knowledgeBaseId,
+ });
return data?.data ?? {};
} else {
return {};
}
},
});
@ -323,3 +332,25 @@ export const useRemoveKnowledgeGraph = () => {
return { data, loading, removeKnowledgeGraph: mutateAsync };
};
export const useFetchKnowledgeList = (
shouldFilterListWithoutDocument: boolean = false,
): {
list: IKnowledge[];
loading: boolean;
} => {
const { data, isFetching: loading } = useQuery({
queryKey: [KnowledgeApiAction.FetchKnowledgeList],
initialData: [],
gcTime: 0, // https://tanstack.com/query/latest/docs/framework/react/guides/caching?from=reactQueryV3
queryFn: async () => {
const { data } = await listDataset();
const list = data?.data?.kbs ?? [];
return shouldFilterListWithoutDocument
? list.filter((x: IKnowledge) => x.chunk_num > 0)
: list;
},
});
return { list: data, loading };
};
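Sketch of consuming the new useFetchKnowledgeList hook for a dataset picker (import path and option shape are illustrative):
import { useFetchKnowledgeList } from '@/hooks/use-knowledge-request'; // assumed import path

// Hypothetical dataset picker options, keeping only datasets that already have chunks.
export function useDatasetOptions() {
  const { list, loading } = useFetchKnowledgeList(true);
  const options = list.map((kb) => ({ label: kb.name, value: kb.id }));
  return { options, loading };
}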

View File

@ -30,6 +30,7 @@ export interface ISwitchForm {
no: string;
}
import { AgentCategory } from '@/constants/agent';
import { Edge, Node } from '@xyflow/react';
import { IReference, Message } from './chat';
@ -74,6 +75,7 @@ export declare interface IFlow {
permission: string;
nickname: string;
operator_permission: number;
canvas_category: string;
}
export interface IFlowTemplate {
@ -265,3 +267,12 @@ export interface IAgentLogMessage {
role: 'user' | 'assistant';
id: string;
}
export interface IPipeLineListRequest {
page?: number;
page_size?: number;
keywords?: string;
orderby?: string;
desc?: boolean;
canvas_category?: AgentCategory;
}
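Example request object for the new IPipeLineListRequest interface (import paths and the orderby value are assumptions):
import { AgentCategory } from '@/constants/agent';
import { IPipeLineListRequest } from '@/interfaces/database/agent'; // assumed path

// Example query: first page of dataflow pipelines, newest first (orderby value is a guess).
export const pipelineQuery: IPipeLineListRequest = {
  page: 1,
  page_size: 20,
  keywords: '',
  orderby: 'update_time',
  desc: true,
  canvas_category: AgentCategory.DataflowCanvas,
};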

View File

@ -5,12 +5,15 @@ export interface IDocumentInfo {
create_date: string;
create_time: number;
created_by: string;
nickname: string;
id: string;
kb_id: string;
location: string;
name: string;
parser_config: IParserConfig;
parser_id: string;
pipeline_id: string;
pipeline_name: string;
process_begin_at?: string;
process_duration: number;
progress: number;
@ -19,6 +22,7 @@ export interface IDocumentInfo {
size: number;
source_type: string;
status: string;
suffix: string;
thumbnail: string;
token_num: number;
type: string;

View File

@ -14,6 +14,9 @@ export interface IKnowledge {
name: string;
parser_config: ParserConfig;
parser_id: string;
pipeline_id: string;
pipeline_name: string;
pipeline_avatar: string;
permission: string;
similarity_threshold: number;
status: string;
@ -26,6 +29,10 @@ export interface IKnowledge {
nickname: string;
operator_permission: number;
size: number;
raptor_task_finish_at?: string;
raptor_task_id?: string;
mindmap_task_finish_at?: string;
mindmap_task_id?: string;
}
export interface IKnowledgeResult {

View File

@ -7,6 +7,7 @@ export interface IChangeParserConfigRequestBody {
export interface IChangeParserRequestBody {
parser_id: string;
pipeline_id: string;
doc_id: string;
parser_config: IChangeParserConfigRequestBody;
}

Some files were not shown because too many files have changed in this diff.