Mirror of https://github.com/infiniflow/ragflow.git (synced 2026-02-06 02:25:05 +08:00)

Compare commits: 5 commits, 6ff7cfe005 ... 63781bde3f

| Author | SHA1 | Date |
|---|---|---|
|  | 63781bde3f |  |
|  | 91d6fb8061 |  |
|  | 45f52e85d7 |  |
|  | 9aa8cfb73a |  |
|  | 79ca25ec7e |  |
@@ -418,12 +418,10 @@ def setting():
         return get_data_error_result(message="canvas not found.")
     flow = flow.to_dict()
     flow["title"] = req["title"]
-    if req["description"]:
-        flow["description"] = req["description"]
-    if req["permission"]:
-        flow["permission"] = req["permission"]
-    if req["avatar"]:
-        flow["avatar"] = req["avatar"]
+    for key in ["description", "permission", "avatar"]:
+        if value := req.get(key):
+            flow[key] = value
     num= UserCanvasService.update_by_id(req["id"], flow)
     return get_json_result(data=num)
api/apps/dataflow_app.py (new file, 353 lines)
@@ -0,0 +1,353 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import re
import sys
import time
from functools import partial

import trio
from flask import request
from flask_login import current_user, login_required

from agent.canvas import Canvas
from agent.component import LLM
from api.db import CanvasCategory, FileType
from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService
from api.db.services.document_service import DocumentService
from api.db.services.file_service import FileService
from api.db.services.task_service import queue_dataflow
from api.db.services.user_canvas_version import UserCanvasVersionService
from api.db.services.user_service import TenantService
from api.settings import RetCode
from api.utils import get_uuid
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request
from api.utils.file_utils import filename_type, read_potential_broken_pdf
from rag.flow.pipeline import Pipeline


@manager.route("/templates", methods=["GET"]) # noqa: F821
@login_required
def templates():
    return get_json_result(data=[c.to_dict() for c in CanvasTemplateService.query(canvas_category=CanvasCategory.DataFlow)])


@manager.route("/list", methods=["GET"]) # noqa: F821
@login_required
def canvas_list():
    return get_json_result(data=sorted([c.to_dict() for c in UserCanvasService.query(user_id=current_user.id, canvas_category=CanvasCategory.DataFlow)], key=lambda x: x["update_time"] * -1))


@manager.route("/rm", methods=["POST"]) # noqa: F821
@validate_request("canvas_ids")
@login_required
def rm():
    for i in request.json["canvas_ids"]:
        if not UserCanvasService.accessible(i, current_user.id):
            return get_json_result(data=False, message="Only owner of canvas authorized for this operation.", code=RetCode.OPERATING_ERROR)
        UserCanvasService.delete_by_id(i)
    return get_json_result(data=True)


@manager.route("/set", methods=["POST"]) # noqa: F821
@validate_request("dsl", "title")
@login_required
def save():
    req = request.json
    if not isinstance(req["dsl"], str):
        req["dsl"] = json.dumps(req["dsl"], ensure_ascii=False)
    req["dsl"] = json.loads(req["dsl"])
    req["canvas_category"] = CanvasCategory.DataFlow
    if "id" not in req:
        req["user_id"] = current_user.id
        if UserCanvasService.query(user_id=current_user.id, title=req["title"].strip(), canvas_category=CanvasCategory.DataFlow):
            return get_data_error_result(message=f"{req['title'].strip()} already exists.")
        req["id"] = get_uuid()

        if not UserCanvasService.save(**req):
            return get_data_error_result(message="Fail to save canvas.")
    else:
        if not UserCanvasService.accessible(req["id"], current_user.id):
            return get_json_result(data=False, message="Only owner of canvas authorized for this operation.", code=RetCode.OPERATING_ERROR)
        UserCanvasService.update_by_id(req["id"], req)
    # save version
    UserCanvasVersionService.insert(user_canvas_id=req["id"], dsl=req["dsl"], title="{0}_{1}".format(req["title"], time.strftime("%Y_%m_%d_%H_%M_%S")))
    UserCanvasVersionService.delete_all_versions(req["id"])
    return get_json_result(data=req)


@manager.route("/get/<canvas_id>", methods=["GET"]) # noqa: F821
@login_required
def get(canvas_id):
    if not UserCanvasService.accessible(canvas_id, current_user.id):
        return get_data_error_result(message="canvas not found.")
    e, c = UserCanvasService.get_by_tenant_id(canvas_id)
    return get_json_result(data=c)


@manager.route("/run", methods=["POST"]) # noqa: F821
@validate_request("id")
@login_required
def run():
    req = request.json
    flow_id = req.get("id", "")
    doc_id = req.get("doc_id", "")
    if not all([flow_id, doc_id]):
        return get_data_error_result(message="id and doc_id are required.")

    if not DocumentService.get_by_id(doc_id):
        return get_data_error_result(message=f"Document for {doc_id} not found.")

    user_id = req.get("user_id", current_user.id)
    if not UserCanvasService.accessible(flow_id, current_user.id):
        return get_json_result(data=False, message="Only owner of canvas authorized for this operation.", code=RetCode.OPERATING_ERROR)

    e, cvs = UserCanvasService.get_by_id(flow_id)
    if not e:
        return get_data_error_result(message="canvas not found.")

    if not isinstance(cvs.dsl, str):
        cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)

    task_id = get_uuid()

    ok, error_message = queue_dataflow(dsl=cvs.dsl, tenant_id=user_id, doc_id=doc_id, task_id=task_id, flow_id=flow_id, priority=0)
    if not ok:
        return server_error_response(error_message)

    return get_json_result(data={"task_id": task_id, "flow_id": flow_id})


@manager.route("/reset", methods=["POST"]) # noqa: F821
@validate_request("id")
@login_required
def reset():
    req = request.json
    flow_id = req.get("id", "")
    if not flow_id:
        return get_data_error_result(message="id is required.")

    if not UserCanvasService.accessible(flow_id, current_user.id):
        return get_json_result(data=False, message="Only owner of canvas authorized for this operation.", code=RetCode.OPERATING_ERROR)

    task_id = req.get("task_id", "")

    try:
        e, user_canvas = UserCanvasService.get_by_id(req["id"])
        if not e:
            return get_data_error_result(message="canvas not found.")

        dataflow = Pipeline(dsl=json.dumps(user_canvas.dsl), tenant_id=current_user.id, flow_id=flow_id, task_id=task_id)
        dataflow.reset()
        req["dsl"] = json.loads(str(dataflow))
        UserCanvasService.update_by_id(req["id"], {"dsl": req["dsl"]})
        return get_json_result(data=req["dsl"])
    except Exception as e:
        return server_error_response(e)


@manager.route("/upload/<canvas_id>", methods=["POST"]) # noqa: F821
def upload(canvas_id):
    e, cvs = UserCanvasService.get_by_tenant_id(canvas_id)
    if not e:
        return get_data_error_result(message="canvas not found.")

    user_id = cvs["user_id"]

    def structured(filename, filetype, blob, content_type):
        nonlocal user_id
        if filetype == FileType.PDF.value:
            blob = read_potential_broken_pdf(blob)

        location = get_uuid()
        FileService.put_blob(user_id, location, blob)

        return {
            "id": location,
            "name": filename,
            "size": sys.getsizeof(blob),
            "extension": filename.split(".")[-1].lower(),
            "mime_type": content_type,
            "created_by": user_id,
            "created_at": time.time(),
            "preview_url": None,
        }

    if request.args.get("url"):
        from crawl4ai import AsyncWebCrawler, BrowserConfig, CrawlerRunConfig, CrawlResult, DefaultMarkdownGenerator, PruningContentFilter

        try:
            url = request.args.get("url")
            filename = re.sub(r"\?.*", "", url.split("/")[-1])

            async def adownload():
                browser_config = BrowserConfig(
                    headless=True,
                    verbose=False,
                )
                async with AsyncWebCrawler(config=browser_config) as crawler:
                    crawler_config = CrawlerRunConfig(markdown_generator=DefaultMarkdownGenerator(content_filter=PruningContentFilter()), pdf=True, screenshot=False)
                    result: CrawlResult = await crawler.arun(url=url, config=crawler_config)
                    return result

            page = trio.run(adownload())
            if page.pdf:
                if filename.split(".")[-1].lower() != "pdf":
                    filename += ".pdf"
                return get_json_result(data=structured(filename, "pdf", page.pdf, page.response_headers["content-type"]))

            return get_json_result(data=structured(filename, "html", str(page.markdown).encode("utf-8"), page.response_headers["content-type"], user_id))

        except Exception as e:
            return server_error_response(e)

    file = request.files["file"]
    try:
        DocumentService.check_doc_health(user_id, file.filename)
        return get_json_result(data=structured(file.filename, filename_type(file.filename), file.read(), file.content_type))
    except Exception as e:
        return server_error_response(e)


@manager.route("/input_form", methods=["GET"]) # noqa: F821
@login_required
def input_form():
    flow_id = request.args.get("id")
    cpn_id = request.args.get("component_id")
    try:
        e, user_canvas = UserCanvasService.get_by_id(flow_id)
        if not e:
            return get_data_error_result(message="canvas not found.")
        if not UserCanvasService.query(user_id=current_user.id, id=flow_id):
            return get_json_result(data=False, message="Only owner of canvas authorized for this operation.", code=RetCode.OPERATING_ERROR)

        dataflow = Pipeline(dsl=json.dumps(user_canvas.dsl), tenant_id=current_user.id, flow_id=flow_id, task_id="")

        return get_json_result(data=dataflow.get_component_input_form(cpn_id))
    except Exception as e:
        return server_error_response(e)


@manager.route("/debug", methods=["POST"]) # noqa: F821
@validate_request("id", "component_id", "params")
@login_required
def debug():
    req = request.json
    if not UserCanvasService.accessible(req["id"], current_user.id):
        return get_json_result(data=False, message="Only owner of canvas authorized for this operation.", code=RetCode.OPERATING_ERROR)
    try:
        e, user_canvas = UserCanvasService.get_by_id(req["id"])
        canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id)
        canvas.reset()
        canvas.message_id = get_uuid()
        component = canvas.get_component(req["component_id"])["obj"]
        component.reset()

        if isinstance(component, LLM):
            component.set_debug_inputs(req["params"])
        component.invoke(**{k: o["value"] for k, o in req["params"].items()})
        outputs = component.output()
        for k in outputs.keys():
            if isinstance(outputs[k], partial):
                txt = ""
                for c in outputs[k]():
                    txt += c
                outputs[k] = txt
        return get_json_result(data=outputs)
    except Exception as e:
        return server_error_response(e)


# api get list version dsl of canvas
@manager.route("/getlistversion/<canvas_id>", methods=["GET"]) # noqa: F821
@login_required
def getlistversion(canvas_id):
    try:
        list = sorted([c.to_dict() for c in UserCanvasVersionService.list_by_canvas_id(canvas_id)], key=lambda x: x["update_time"] * -1)
        return get_json_result(data=list)
    except Exception as e:
        return get_data_error_result(message=f"Error getting history files: {e}")


# api get version dsl of canvas
@manager.route("/getversion/<version_id>", methods=["GET"]) # noqa: F821
@login_required
def getversion(version_id):
    try:
        e, version = UserCanvasVersionService.get_by_id(version_id)
        if version:
            return get_json_result(data=version.to_dict())
    except Exception as e:
        return get_json_result(data=f"Error getting history file: {e}")


@manager.route("/listteam", methods=["GET"]) # noqa: F821
@login_required
def list_canvas():
    keywords = request.args.get("keywords", "")
    page_number = int(request.args.get("page", 1))
    items_per_page = int(request.args.get("page_size", 150))
    orderby = request.args.get("orderby", "create_time")
    desc = request.args.get("desc", True)
    try:
        tenants = TenantService.get_joined_tenants_by_user_id(current_user.id)
        canvas, total = UserCanvasService.get_by_tenant_ids(
            [m["tenant_id"] for m in tenants], current_user.id, page_number, items_per_page, orderby, desc, keywords, canvas_category=CanvasCategory.DataFlow
        )
        return get_json_result(data={"canvas": canvas, "total": total})
    except Exception as e:
        return server_error_response(e)


@manager.route("/setting", methods=["POST"]) # noqa: F821
@validate_request("id", "title", "permission")
@login_required
def setting():
    req = request.json
    req["user_id"] = current_user.id

    if not UserCanvasService.accessible(req["id"], current_user.id):
        return get_json_result(data=False, message="Only owner of canvas authorized for this operation.", code=RetCode.OPERATING_ERROR)

    e, flow = UserCanvasService.get_by_id(req["id"])
    if not e:
        return get_data_error_result(message="canvas not found.")
    flow = flow.to_dict()
    flow["title"] = req["title"]
    for key in ("description", "permission", "avatar"):
        if value := req.get(key):
            flow[key] = value

    num = UserCanvasService.update_by_id(req["id"], flow)
    return get_json_result(data=num)


@manager.route("/trace", methods=["GET"]) # noqa: F821
def trace():
    dataflow_id = request.args.get("dataflow_id")
    task_id = request.args.get("task_id")
    if not all([dataflow_id, task_id]):
        return get_data_error_result(message="dataflow_id and task_id are required.")

    e, dataflow_canvas = UserCanvasService.get_by_id(dataflow_id)
    if not e:
        return get_data_error_result(message="dataflow not found.")

    dsl_str = json.dumps(dataflow_canvas.dsl, ensure_ascii=False)
    dataflow = Pipeline(dsl=dsl_str, tenant_id=dataflow_canvas.user_id, flow_id=dataflow_id, task_id=task_id)
    log = dataflow.fetch_logs()

    return get_json_result(data=log)
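For orientation, here is a minimal client-side sketch of how the new dataflow endpoints might be exercised once this blueprint is registered. The base URL, mount path, session cookie, and IDs are placeholders, and the `requests` usage is illustrative rather than part of this commit:

```python
import requests

BASE = "http://localhost:9380/v1/dataflow"  # hypothetical mount point for this blueprint
session = requests.Session()
session.cookies.set("session", "<your-login-cookie>")  # placeholder credentials

# Kick off a dataflow run for one document (mirrors the /run handler above).
resp = session.post(f"{BASE}/run", json={"id": "<canvas_id>", "doc_id": "<doc_id>"})
run_info = resp.json()["data"]  # {"task_id": ..., "flow_id": ...}

# Poll execution logs for that run (mirrors the /trace handler above).
trace = session.get(
    f"{BASE}/trace",
    params={"dataflow_id": run_info["flow_id"], "task_id": run_info["task_id"]},
)
print(trace.json()["data"])
```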
@@ -24,7 +24,7 @@ from api.db.services.llm_service import LLMBundle
 from api import settings
 from api.utils.api_utils import validate_request, build_error_result, apikey_required
 from rag.app.tag import label_question
-from api.db.services.dialog_service import meta_filter
+from api.db.services.dialog_service import meta_filter, convert_conditions


 @manager.route('/dify/retrieval', methods=['POST']) # noqa: F821
@@ -101,19 +101,4 @@ def retrieval(tenant_id):
         logging.exception(e)
         return build_error_result(message=str(e), code=settings.RetCode.SERVER_ERROR)


-def convert_conditions(metadata_condition):
-    if metadata_condition is None:
-        metadata_condition = {}
-    op_mapping = {
-        "is": "=",
-        "not is": "≠"
-    }
-    return [
-        {
-            "op": op_mapping.get(cond["comparison_operator"], cond["comparison_operator"]),
-            "key": cond["name"],
-            "value": cond["value"]
-        }
-        for cond in metadata_condition.get("conditions", [])
-    ]
@@ -35,8 +35,7 @@ from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.llm_service import LLMBundle
 from api.db.services.tenant_llm_service import TenantLLMService
 from api.db.services.task_service import TaskService, queue_tasks
-from api.db.services.dialog_service import meta_filter
-from api.apps.sdk.dify_retrieval import convert_conditions
+from api.db.services.dialog_service import meta_filter, convert_conditions
 from api.utils.api_utils import check_duplicate_ids, construct_json_result, get_error_data_result, get_parser_config, get_result, server_error_response, token_required
 from rag.app.qa import beAdoc, rmPrefix
 from rag.app.tag import label_question
@@ -21,11 +21,9 @@ from copy import deepcopy
 from datetime import datetime
 from functools import partial
 from timeit import default_timer as timer
-
 import trio
 from langfuse import Langfuse
 from peewee import fn
-
 from agentic_reasoning import DeepResearcher
 from api import settings
 from api.db import LLMType, ParserType, StatusEnum
@@ -255,6 +253,23 @@ def repair_bad_citation_formats(answer: str, kbinfos: dict, idx: set):
     return answer, idx


+def convert_conditions(metadata_condition):
+    if metadata_condition is None:
+        metadata_condition = {}
+    op_mapping = {
+        "is": "=",
+        "not is": "≠"
+    }
+    return [
+        {
+            "op": op_mapping.get(cond["comparison_operator"], cond["comparison_operator"]),
+            "key": cond["name"],
+            "value": cond["value"]
+        }
+        for cond in metadata_condition.get("conditions", [])
+    ]
+
+
 def meta_filter(metas: dict, filters: list[dict]):
     doc_ids = set([])
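A quick illustration of the `convert_conditions` helper that now lives in `dialog_service`, a sketch based only on the body shown above, with made-up condition values:

```python
# Dify-style metadata condition, as accepted by the /dify/retrieval endpoint.
metadata_condition = {
    "conditions": [
        {"name": "author", "comparison_operator": "is", "value": "Alice"},
        {"name": "year", "comparison_operator": "not is", "value": "2023"},
    ]
}

filters = convert_conditions(metadata_condition)
# -> [{"op": "=", "key": "author", "value": "Alice"},
#     {"op": "≠", "key": "year", "value": "2023"}]
# Presumably the list-of-filter shape that meta_filter() consumes.
```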
@@ -54,15 +54,15 @@ def trim_header_by_lines(text: str, max_length) -> str:
 class TaskService(CommonService):
     """Service class for managing document processing tasks.

     This class extends CommonService to provide specialized functionality for document
     processing task management, including task creation, progress tracking, and chunk
     management. It handles various document types (PDF, Excel, etc.) and manages their
     processing lifecycle.

     The class implements a robust task queue system with retry mechanisms and progress
     tracking, supporting both synchronous and asynchronous task execution.

     Attributes:
         model: The Task model class for database operations.
     """
@@ -72,14 +72,14 @@ class TaskService(CommonService):
     @DB.connection_context()
     def get_task(cls, task_id):
         """Retrieve detailed task information by task ID.

         This method fetches comprehensive task details including associated document,
         knowledge base, and tenant information. It also handles task retry logic and
         progress updates.

         Args:
             task_id (str): The unique identifier of the task to retrieve.

         Returns:
             dict: Task details dictionary containing all task information and related metadata.
                 Returns None if task is not found or has exceeded retry limit.
@@ -139,13 +139,13 @@ class TaskService(CommonService):
     @DB.connection_context()
     def get_tasks(cls, doc_id: str):
         """Retrieve all tasks associated with a document.

         This method fetches all processing tasks for a given document, ordered by page
         number and creation time. It includes task progress and chunk information.

         Args:
             doc_id (str): The unique identifier of the document.

         Returns:
             list[dict]: List of task dictionaries containing task details.
                 Returns None if no tasks are found.
@@ -170,10 +170,10 @@ class TaskService(CommonService):
     @DB.connection_context()
     def update_chunk_ids(cls, id: str, chunk_ids: str):
         """Update the chunk IDs associated with a task.

         This method updates the chunk_ids field of a task, which stores the IDs of
         processed document chunks in a space-separated string format.

         Args:
             id (str): The unique identifier of the task.
             chunk_ids (str): Space-separated string of chunk identifiers.
@@ -184,11 +184,11 @@ class TaskService(CommonService):
     @DB.connection_context()
     def get_ongoing_doc_name(cls):
         """Get names of documents that are currently being processed.

         This method retrieves information about documents that are in the processing state,
         including their locations and associated IDs. It uses database locking to ensure
         thread safety when accessing the task information.

         Returns:
             list[tuple]: A list of tuples, each containing (parent_id/kb_id, location)
                 for documents currently being processed. Returns empty list if
@@ -238,14 +238,14 @@ class TaskService(CommonService):
     @DB.connection_context()
     def do_cancel(cls, id):
         """Check if a task should be cancelled based on its document status.

         This method determines whether a task should be cancelled by checking the
         associated document's run status and progress. A task should be cancelled
         if its document is marked for cancellation or has negative progress.

         Args:
             id (str): The unique identifier of the task to check.

         Returns:
             bool: True if the task should be cancelled, False otherwise.
         """
@@ -311,18 +311,18 @@ class TaskService(CommonService):
 def queue_tasks(doc: dict, bucket: str, name: str, priority: int):
     """Create and queue document processing tasks.

     This function creates processing tasks for a document based on its type and configuration.
     It handles different document types (PDF, Excel, etc.) differently and manages task
     chunking and configuration. It also implements task reuse optimization by checking
     for previously completed tasks.

     Args:
         doc (dict): Document dictionary containing metadata and configuration.
         bucket (str): Storage bucket name where the document is stored.
         name (str): File name of the document.
         priority (int, optional): Priority level for task queueing (default is 0).

     Note:
         - For PDF documents, tasks are created per page range based on configuration
         - For Excel documents, tasks are created per row range
@@ -410,19 +410,19 @@ def queue_tasks(doc: dict, bucket: str, name: str, priority: int):
 def reuse_prev_task_chunks(task: dict, prev_tasks: list[dict], chunking_config: dict):
     """Attempt to reuse chunks from previous tasks for optimization.

     This function checks if chunks from previously completed tasks can be reused for
     the current task, which can significantly improve processing efficiency. It matches
     tasks based on page ranges and configuration digests.

     Args:
         task (dict): Current task dictionary to potentially reuse chunks for.
         prev_tasks (list[dict]): List of previous task dictionaries to check for reuse.
         chunking_config (dict): Configuration dictionary for chunk processing.

     Returns:
         int: Number of chunks successfully reused. Returns 0 if no chunks could be reused.

     Note:
         Chunks can only be reused if:
         - A previous task exists with matching page range and configuration digest
@@ -470,3 +470,39 @@ def has_canceled(task_id):
     except Exception as e:
         logging.exception(e)
         return False
+
+
+def queue_dataflow(dsl:str, tenant_id:str, doc_id:str, task_id:str, flow_id:str, priority: int, callback=None) -> tuple[bool, str]:
+    """
+    Returns a tuple (success: bool, error_message: str).
+    """
+    _ = callback
+
+    task = dict(
+        id=get_uuid() if not task_id else task_id,
+        doc_id=doc_id,
+        from_page=0,
+        to_page=100000000,
+        task_type="dataflow",
+        priority=priority,
+    )
+
+    TaskService.model.delete().where(TaskService.model.id == task["id"]).execute()
+    bulk_insert_into_db(model=Task, data_source=[task], replace_on_conflict=True)
+
+    kb_id = DocumentService.get_knowledgebase_id(doc_id)
+    if not kb_id:
+        return False, f"Can't find KB of this document: {doc_id}"
+
+    task["kb_id"] = kb_id
+    task["tenant_id"] = tenant_id
+    task["task_type"] = "dataflow"
+    task["dsl"] = dsl
+    task["dataflow_id"] = get_uuid() if not flow_id else flow_id
+
+    if not REDIS_CONN.queue_product(
+        get_svr_queue_name(priority), message=task
+    ):
+        return False, "Can't access Redis. Please check the Redis' status."
+
+    return True, ""
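A hedged sketch of how `queue_dataflow` is meant to be called; the argument values are placeholders, and the `/run` handler in `dataflow_app.py` above calls it in the same way:

```python
task_id = get_uuid()
ok, err = queue_dataflow(
    dsl=canvas_dsl_json_string,   # the canvas DSL serialized to a JSON string (placeholder variable)
    tenant_id=current_user_id,    # placeholder tenant/user id
    doc_id=doc_id,                # document to run the pipeline against
    task_id=task_id,
    flow_id=flow_id,              # the user canvas id
    priority=0,
)
if not ok:
    raise RuntimeError(err)       # err explains the failure (missing KB, Redis unreachable, ...)
```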
@@ -2690,21 +2690,21 @@
         "status": "1",
         "llm": [
             {
-                "llm_name": "Qwen3-Embedding-8B",
+                "llm_name": "Qwen/Qwen3-Embedding-8B",
                 "tags": "TEXT EMBEDDING,TEXT RE-RANK,32k",
                 "max_tokens": 32000,
                 "model_type": "embedding",
                 "is_tools": false
             },
             {
-                "llm_name": "Qwen3-Embedding-4B",
+                "llm_name": "Qwen/Qwen3-Embedding-4B",
                 "tags": "TEXT EMBEDDING,TEXT RE-RANK,32k",
                 "max_tokens": 32000,
                 "model_type": "embedding",
                 "is_tools": false
             },
             {
-                "llm_name": "Qwen3-Embedding-0.6B",
+                "llm_name": "Qwen/Qwen3-Embedding-0.6B",
                 "tags": "TEXT EMBEDDING,TEXT RE-RANK,32k",
                 "max_tokens": 32000,
                 "model_type": "embedding",
@@ -14,36 +14,45 @@
 # limitations under the License.
 #

-import os
 import importlib
 import inspect
+import pkgutil
+from pathlib import Path
 from types import ModuleType
 from typing import Dict, Type

-_package_path = os.path.dirname(__file__)
 __all_classes: Dict[str, Type] = {}

-def _import_submodules() -> None:
-    for filename in os.listdir(_package_path): # noqa: F821
-        if filename.startswith("__") or not filename.endswith(".py") or filename.startswith("base"):
-            continue
-        module_name = filename[:-3]
+_pkg_dir = Path(__file__).resolve().parent
+_pkg_name = __name__
+
+
+def _should_skip_module(mod_name: str) -> bool:
+    leaf = mod_name.rsplit(".", 1)[-1]
+    return leaf in {"__init__"} or leaf.startswith("__") or leaf.startswith("_") or leaf.startswith("base")
+
+
+def _import_submodules() -> None:
+    for modinfo in pkgutil.walk_packages([str(_pkg_dir)], prefix=_pkg_name + "."): # noqa: F821
+        mod_name = modinfo.name
+        if _should_skip_module(mod_name): # noqa: F821
+            continue
         try:
-            module = importlib.import_module(f".{module_name}", package=__name__)
+            module = importlib.import_module(mod_name)
             _extract_classes_from_module(module) # noqa: F821
         except ImportError as e:
-            print(f"Warning: Failed to import module {module_name}: {str(e)}")
+            print(f"Warning: Failed to import module {mod_name}: {e}")


 def _extract_classes_from_module(module: ModuleType) -> None:
     for name, obj in inspect.getmembers(module):
-        if (inspect.isclass(obj) and
-                obj.__module__ == module.__name__ and not name.startswith("_")):
+        if inspect.isclass(obj) and obj.__module__ == module.__name__ and not name.startswith("_"):
             __all_classes[name] = obj
             globals()[name] = obj


 _import_submodules()

 __all__ = list(__all_classes.keys()) + ["__all_classes"]

-del _package_path, _import_submodules, _extract_classes_from_module
+del _pkg_dir, _pkg_name, _import_submodules, _extract_classes_from_module
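The rewritten loader switches from `os.listdir` to `pkgutil.walk_packages`, which also recurses into sub-packages rather than only scanning `.py` files in one directory. A standalone sketch of the same pattern, for reference; the module names it would print are illustrative:

```python
import importlib
import pkgutil
from pathlib import Path

pkg_dir = Path(__file__).resolve().parent

# walk_packages yields a ModuleInfo for every module and sub-package found
# under pkg_dir, with names qualified by the given prefix.
for modinfo in pkgutil.walk_packages([str(pkg_dir)], prefix=__name__ + "."):
    print(modinfo.name)                 # e.g. "<package>.chunker.chunker" (illustrative)
    importlib.import_module(modinfo.name)
```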
@@ -1,5 +1,5 @@
 #
-# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
+# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,13 +13,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-import time
-import os
 import logging
+import os
+import time
 from functools import partial
 from typing import Any

 import trio
-from agent.component.base import ComponentParamBase, ComponentBase
+
+from agent.component.base import ComponentBase, ComponentParamBase
 from api.utils.api_utils import timeout
@@ -31,14 +33,16 @@ class ProcessParamBase(ComponentParamBase):
 class ProcessBase(ComponentBase):

     def __init__(self, pipeline, id, param: ProcessParamBase):
         super().__init__(pipeline, id, param)
-        self.callback = partial(self._canvas.callback, self.component_name)
+        if hasattr(self._canvas, "callback"):
+            self.callback = partial(self._canvas.callback, self.component_name)
+        else:
+            self.callback = partial(lambda *args, **kwargs: None, self.component_name)

     async def invoke(self, **kwargs) -> dict[str, Any]:
         self.set_output("_created_time", time.perf_counter())
-        for k,v in kwargs.items():
+        for k, v in kwargs.items():
             self.set_output(k, v)
         try:
             with trio.fail_after(self._param.timeout):
@@ -54,6 +58,6 @@ class ProcessBase(ComponentBase):
             self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
         return self.output()

-    @timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))
+    @timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10 * 60))
     async def _invoke(self, **kwargs):
         raise NotImplementedError()
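Taken together with the chunker and parser processes below, `ProcessBase` defines a small contract: a param class that validates its configuration (`check`, `get_input_form`) and a process class that sets outputs from `_invoke` and reports progress through `self.callback`. A minimal hypothetical subclass to make that contract explicit; the names here are illustrative and not part of this commit:

```python
class EchoParam(ProcessParamBase):
    def check(self):
        pass

    def get_input_form(self) -> dict[str, dict]:
        return {}


class Echo(ProcessBase):
    component_name = "Echo"

    async def _invoke(self, **kwargs):
        # Report progress through the (possibly no-op) callback set up in
        # ProcessBase.__init__, then copy the upstream inputs to this
        # component's outputs.
        self.callback(0.5, "Echoing upstream inputs.")
        for k, v in kwargs.items():
            self.set_output(k, v)
```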
rag/flow/chunker/__init__.py (new file, 15 lines)
@@ -0,0 +1,15 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@@ -1,5 +1,5 @@
 #
-# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
+# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,12 +13,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import random

 import trio

 from api.db import LLMType
 from api.db.services.llm_service import LLMBundle
 from deepdoc.parser.pdf_parser import RAGFlowPdfParser
-from graphrag.utils import get_llm_cache, chat_limiter, set_llm_cache
+from graphrag.utils import chat_limiter, get_llm_cache, set_llm_cache
 from rag.flow.base import ProcessBase, ProcessParamBase
+from rag.flow.chunker.schema import ChunkerFromUpstream
 from rag.nlp import naive_merge, naive_merge_with_images
 from rag.prompts.prompts import keyword_extraction, question_proposal
@@ -26,7 +29,23 @@ from rag.prompts.prompts import keyword_extraction, question_proposal
 class ChunkerParam(ProcessParamBase):
     def __init__(self):
         super().__init__()
-        self.method_options = ["general", "q&a", "resume", "manual", "table", "paper", "book", "laws", "presentation", "one"]
+        self.method_options = [
+            # General
+            "general",
+            "onetable",
+            # Customer Service
+            "q&a",
+            "manual",
+            # Recruitment
+            "resume",
+            # Education & Research
+            "book",
+            "paper",
+            "laws",
+            "presentation",
+            # Other
+            # "Tag"  # TODO: Other method
+        ]
         self.method = "general"
         self.chunk_token_size = 512
         self.delimiter = "\n"
@@ -35,10 +54,7 @@ class ChunkerParam(ProcessParamBase):
         self.auto_keywords = 0
         self.auto_questions = 0
         self.tag_sets = []
-        self.llm_setting = {
-            "llm_name": "",
-            "lang": "Chinese"
-        }
+        self.llm_setting = {"llm_name": "", "lang": "Chinese"}

     def check(self):
         self.check_valid_value(self.method.lower(), "Chunk method abnormal.", self.method_options)
@@ -48,53 +64,79 @@ class ChunkerParam(ProcessParamBase):
         self.check_nonnegative_number(self.auto_questions, "Auto-question value: (0, 10]")
         self.check_decimal_float(self.overlapped_percent, "Overlapped percentage: [0, 1)")

+    def get_input_form(self) -> dict[str, dict]:
+        return {}
+

 class Chunker(ProcessBase):
     component_name = "Chunker"

-    def _general(self, **kwargs):
-        self.callback(random.randint(1,5)/100., "Start to chunk via `General`.")
-        if kwargs.get("output_format") in ["markdown", "text"]:
-            cks = naive_merge(kwargs.get(kwargs["output_format"]), self._param.chunk_token_size, self._param.delimiter, self._param.overlapped_percent)
+    def _general(self, from_upstream: ChunkerFromUpstream):
+        self.callback(random.randint(1, 5) / 100.0, "Start to chunk via `General`.")
+        if from_upstream.output_format in ["markdown", "text"]:
+            if from_upstream.output_format == "markdown":
+                payload = from_upstream.markdown_result
+            else:  # == "text"
+                payload = from_upstream.text_result
+
+            if not payload:
+                payload = ""
+
+            cks = naive_merge(
+                payload,
+                self._param.chunk_token_size,
+                self._param.delimiter,
+                self._param.overlapped_percent,
+            )
             return [{"text": c} for c in cks]

         sections, section_images = [], []
-        for o in kwargs["json"]:
-            sections.append((o["text"], o.get("position_tag","")))
+        for o in from_upstream.json_result or []:
+            sections.append((o.get("text", ""), o.get("position_tag", "")))
             section_images.append(o.get("image"))

-        chunks, images = naive_merge_with_images(sections, section_images,self._param.chunk_token_size, self._param.delimiter, self._param.overlapped_percent)
-        return [{
-            "text": RAGFlowPdfParser.remove_tag(c),
-            "image": img,
-            "positions": RAGFlowPdfParser.extract_positions(c)
-        } for c,img in zip(chunks,images)]
+        chunks, images = naive_merge_with_images(
+            sections,
+            section_images,
+            self._param.chunk_token_size,
+            self._param.delimiter,
+            self._param.overlapped_percent,
+        )

-    def _q_and_a(self, **kwargs):
+        return [
+            {
+                "text": RAGFlowPdfParser.remove_tag(c),
+                "image": img,
+                "positions": RAGFlowPdfParser.extract_positions(c),
+            }
+            for c, img in zip(chunks, images)
+        ]
+
+    def _q_and_a(self, from_upstream: ChunkerFromUpstream):
         pass

-    def _resume(self, **kwargs):
+    def _resume(self, from_upstream: ChunkerFromUpstream):
         pass

-    def _manual(self, **kwargs):
+    def _manual(self, from_upstream: ChunkerFromUpstream):
         pass

-    def _table(self, **kwargs):
+    def _table(self, from_upstream: ChunkerFromUpstream):
         pass

-    def _paper(self, **kwargs):
+    def _paper(self, from_upstream: ChunkerFromUpstream):
         pass

-    def _book(self, **kwargs):
+    def _book(self, from_upstream: ChunkerFromUpstream):
         pass

-    def _laws(self, **kwargs):
+    def _laws(self, from_upstream: ChunkerFromUpstream):
         pass

-    def _presentation(self, **kwargs):
+    def _presentation(self, from_upstream: ChunkerFromUpstream):
         pass

-    def _one(self, **kwargs):
+    def _one(self, from_upstream: ChunkerFromUpstream):
         pass

     async def _invoke(self, **kwargs):
@@ -110,7 +152,14 @@ class Chunker(ProcessBase):
             "presentation": self._presentation,
             "one": self._one,
         }
-        chunks = function_map[self._param.method](**kwargs)
+
+        try:
+            from_upstream = ChunkerFromUpstream.model_validate(kwargs)
+        except Exception as e:
+            self.set_output("_ERROR", f"Input error: {str(e)}")
+            return
+
+        chunks = function_map[self._param.method](from_upstream)
         llm_setting = self._param.llm_setting

         async def auto_keywords():
rag/flow/chunker/schema.py (new file, 37 lines)
@@ -0,0 +1,37 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Literal

from pydantic import BaseModel, ConfigDict, Field


class ChunkerFromUpstream(BaseModel):
    created_time: float | None = Field(default=None, alias="_created_time")
    elapsed_time: float | None = Field(default=None, alias="_elapsed_time")

    name: str
    blob: bytes

    output_format: Literal["json", "markdown", "text", "html"] | None = Field(default=None)

    json_result: list[dict[str, Any]] | None = Field(default=None, alias="json")
    markdown_result: str | None = Field(default=None, alias="markdown")
    text_result: str | None = Field(default=None, alias="text")
    html_result: str | None = Field(default=None, alias="html")

    model_config = ConfigDict(populate_by_name=True, extra="forbid")

    # def to_dict(self, *, exclude_none: bool = True) -> dict:
    #     return self.model_dump(by_alias=True, exclude_none=exclude_none)
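Because the model declares aliases and `populate_by_name=True`, the upstream outputs (which use keys such as `json`, `markdown`, and `_created_time`) can be validated directly, exactly as `Chunker._invoke` does above. A small sketch with made-up values:

```python
upstream_outputs = {
    "name": "report.pdf",
    "blob": b"%PDF-1.7 ...",            # raw file bytes from the upstream step
    "output_format": "markdown",
    "markdown": "# Title\nSome text.",  # alias for markdown_result
    "_created_time": 12.5,              # alias for created_time
}

from_upstream = ChunkerFromUpstream.model_validate(upstream_outputs)
assert from_upstream.markdown_result == "# Title\nSome text."
# extra="forbid" means unknown keys raise a ValidationError, which _invoke
# above converts into the "_ERROR" output.
```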
@@ -1,5 +1,5 @@
 #
-# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
+# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -27,6 +27,9 @@ class FileParam(ProcessParamBase):
     def check(self):
         pass

+    def get_input_form(self) -> dict[str, dict]:
+        return {}
+

 class File(ProcessBase):
     component_name = "File"
(deleted file, 107 lines)
@@ -1,107 +0,0 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import trio
from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from deepdoc.parser.pdf_parser import RAGFlowPdfParser, PlainParser, VisionParser
from rag.flow.base import ProcessBase, ProcessParamBase
from rag.llm.cv_model import Base as VLM
from deepdoc.parser import ExcelParser


class ParserParam(ProcessParamBase):
    def __init__(self):
        super().__init__()
        self.setups = {
            "pdf": {
                "parse_method": "deepdoc",  # deepdoc/plain_text/vlm
                "vlm_name": "",
                "lang": "Chinese",
                "suffix": ["pdf"],
                "output_format": "json"
            },
            "excel": {
                "output_format": "html"
            },
            "ppt": {},
            "image": {
                "parse_method": "ocr"
            },
            "email": {},
            "text": {},
            "audio": {},
            "video": {},
        }

    def check(self):
        if self.setups["pdf"].get("parse_method") not in ["deepdoc", "plain_text"]:
            assert self.setups["pdf"].get("vlm_name"), "No VLM specified."
            assert self.setups["pdf"].get("lang"), "No language specified."


class Parser(ProcessBase):
    component_name = "Parser"

    def _pdf(self, blob):
        self.callback(random.randint(1,5)/100., "Start to work on a PDF.")
        conf = self._param.setups["pdf"]
        self.set_output("output_format", conf["output_format"])
        if conf.get("parse_method") == "deepdoc":
            bboxes = RAGFlowPdfParser().parse_into_bboxes(blob, callback=self.callback)
        elif conf.get("parse_method") == "plain_text":
            lines,_ = PlainParser()(blob)
            bboxes = [{"text": t} for t,_ in lines]
        else:
            assert conf.get("vlm_name")
            vision_model = LLMBundle(self._canvas.tenant_id, LLMType.IMAGE2TEXT, llm_name=conf.get("vlm_name"), lang=self.setups["pdf"].get("lang"))
            lines, _ = VisionParser(vision_model=vision_model)(bin, callback=self.callback)
            bboxes = []
            for t, poss in lines:
                pn, x0, x1, top, bott = poss.split(" ")
                bboxes.append({"page_number": int(pn), "x0": int(x0), "x1": int(x1), "top": int(top), "bottom": int(bott), "text": t})

        self.set_output("json", bboxes)
        mkdn = ""
        for b in bboxes:
            if b.get("layout_type", "") == "title":
                mkdn += "\n## "
            if b.get("layout_type", "") == "figure":
                mkdn += "\n".format(VLM.image2base64(b["image"]))
                continue
            mkdn += b.get("text", "") + "\n"
        self.set_output("markdown", mkdn)

    def _excel(self, blob):
        self.callback(random.randint(1,5)/100., "Start to work on a Excel.")
        conf = self._param.setups["excel"]
        excel_parser = ExcelParser()
        if conf.get("output_format") == "html":
            html = excel_parser.html(blob, 1000000000)
            self.set_output("html", html)
        elif conf.get("output_format") == "json":
            self.set_output("json", [{"text": txt} for txt in excel_parser(blob) if txt])
        elif conf.get("output_format") == "markdown":
            self.set_output("markdown", excel_parser.markdown(blob))

    async def _invoke(self, **kwargs):
        function_map = {
            "pdf": self._pdf,
        }
        for p_type, conf in self._param.setups.items():
            if kwargs.get("name", "").split(".")[-1].lower() not in conf.get("suffix", []):
                continue
            await trio.to_thread.run_sync(function_map[p_type], kwargs["blob"])
            break
14
rag/flow/parser/__init__.py
Normal file
@ -0,0 +1,14 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
154
rag/flow/parser/parser.py
Normal file
@ -0,0 +1,154 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random

import trio

from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from deepdoc.parser import ExcelParser
from deepdoc.parser.pdf_parser import PlainParser, RAGFlowPdfParser, VisionParser
from rag.flow.base import ProcessBase, ProcessParamBase
from rag.flow.parser.schema import ParserFromUpstream
from rag.llm.cv_model import Base as VLM


class ParserParam(ProcessParamBase):
    def __init__(self):
        super().__init__()
        self.allowed_output_format = {
            "pdf": ["json", "markdown"],
            "excel": ["json", "markdown", "html"],
            "ppt": [],
            "image": [],
            "email": [],
            "text": [],
            "audio": [],
            "video": [],
        }
        self.setups = {
            "pdf": {
                "parse_method": "deepdoc",  # deepdoc/plain_text/vlm
                "vlm_name": "",
                "lang": "Chinese",
                "suffix": ["pdf"],
                "output_format": "json",
            },
            "excel": {
                "output_format": "html",
                "suffix": ["xls", "xlsx", "csv"],
            },
            "ppt": {},
            "image": {
                "parse_method": "ocr",
            },
            "email": {},
            "text": {},
            "audio": {},
            "video": {},
        }

    def check(self):
        pdf_config = self.setups.get("pdf", {})
        if pdf_config:
            pdf_parse_method = pdf_config.get("parse_method", "")
            self.check_valid_value(pdf_parse_method.lower(), "Parse method abnormal.", ["deepdoc", "plain_text", "vlm"])

            if pdf_parse_method not in ["deepdoc", "plain_text"]:
                self.check_empty(pdf_config.get("vlm_name"), "VLM")

            pdf_language = pdf_config.get("lang", "")
            self.check_empty(pdf_language, "Language")

            pdf_output_format = pdf_config.get("output_format", "")
            self.check_valid_value(pdf_output_format, "PDF output format abnormal.", self.allowed_output_format["pdf"])

        excel_config = self.setups.get("excel", "")
        if excel_config:
            excel_output_format = excel_config.get("output_format", "")
            self.check_valid_value(excel_output_format, "Excel output format abnormal.", self.allowed_output_format["excel"])

        image_config = self.setups.get("image", "")
        if image_config:
            image_parse_method = image_config.get("parse_method", "")
            self.check_valid_value(image_parse_method.lower(), "Parse method abnormal.", ["ocr"])

    def get_input_form(self) -> dict[str, dict]:
        return {}


class Parser(ProcessBase):
    component_name = "Parser"

    def _pdf(self, blob):
        self.callback(random.randint(1, 5) / 100.0, "Start to work on a PDF.")
        conf = self._param.setups["pdf"]
        self.set_output("output_format", conf["output_format"])
        if conf.get("parse_method") == "deepdoc":
            bboxes = RAGFlowPdfParser().parse_into_bboxes(blob, callback=self.callback)
        elif conf.get("parse_method") == "plain_text":
            lines, _ = PlainParser()(blob)
            bboxes = [{"text": t} for t, _ in lines]
        else:
            assert conf.get("vlm_name")
            vision_model = LLMBundle(self._canvas._tenant_id, LLMType.IMAGE2TEXT, llm_name=conf.get("vlm_name"), lang=self._param.setups["pdf"].get("lang"))
            lines, _ = VisionParser(vision_model=vision_model)(blob, callback=self.callback)
            bboxes = []
            for t, poss in lines:
                pn, x0, x1, top, bott = poss.split(" ")
                bboxes.append({"page_number": int(pn), "x0": float(x0), "x1": float(x1), "top": float(top), "bottom": float(bott), "text": t})
        if conf.get("output_format") == "json":
            self.set_output("json", bboxes)
        if conf.get("output_format") == "markdown":
            mkdn = ""
            for b in bboxes:
                if b.get("layout_type", "") == "title":
                    mkdn += "\n## "
                if b.get("layout_type", "") == "figure":
                    mkdn += "\n".format(VLM.image2base64(b["image"]))
                    continue
                mkdn += b.get("text", "") + "\n"
            self.set_output("markdown", mkdn)

    def _excel(self, blob):
        self.callback(random.randint(1, 5) / 100.0, "Start to work on a Excel.")
        conf = self._param.setups["excel"]
        self.set_output("output_format", conf["output_format"])
        excel_parser = ExcelParser()
        if conf.get("output_format") == "html":
            html = excel_parser.html(blob, 1000000000)
            self.set_output("html", html)
        elif conf.get("output_format") == "json":
            self.set_output("json", [{"text": txt} for txt in excel_parser(blob) if txt])
        elif conf.get("output_format") == "markdown":
            self.set_output("markdown", excel_parser.markdown(blob))

    async def _invoke(self, **kwargs):
        function_map = {
            "pdf": self._pdf,
            "excel": self._excel,
        }
        try:
            from_upstream = ParserFromUpstream.model_validate(kwargs)
        except Exception as e:
            self.set_output("_ERROR", f"Input error: {str(e)}")
            return

        for p_type, conf in self._param.setups.items():
            if from_upstream.name.split(".")[-1].lower() not in conf.get("suffix", []):
                continue
            await trio.to_thread.run_sync(function_map[p_type], from_upstream.blob)
            break
25
rag/flow/parser/schema.py
Normal file
@ -0,0 +1,25 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pydantic import BaseModel, ConfigDict, Field


class ParserFromUpstream(BaseModel):
    created_time: float | None = Field(default=None, alias="_created_time")
    elapsed_time: float | None = Field(default=None, alias="_elapsed_time")

    name: str
    blob: bytes

    model_config = ConfigDict(populate_by_name=True, extra="forbid")
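As an illustration of the contract above (a sketch, not part of the change set): Parser._invoke validates its raw upstream kwargs against ParserFromUpstream, so a payload only needs "name" and "blob", may carry the aliased timing fields, and is rejected if it contains unknown keys because the model is declared with extra="forbid". The file name and bytes below are hypothetical.

from rag.flow.parser.schema import ParserFromUpstream

payload = {"name": "report.pdf", "blob": b"%PDF-1.7 ...", "_created_time": 1721010000.0}
upstream = ParserFromUpstream.model_validate(payload)
# Parser._invoke routes on the file suffix, so "report.pdf" selects the "pdf" setup.
print(upstream.name.split(".")[-1].lower())  # -> "pdf"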
@ -1,5 +1,5 @@
 #
-# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
+# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@ -18,14 +18,15 @@ import json
 import logging
 import random
 import time
 
 import trio
 
 from agent.canvas import Graph
 from api.db.services.document_service import DocumentService
 from rag.utils.redis_conn import REDIS_CONN
 
 
 class Pipeline(Graph):
 
     def __init__(self, dsl: str, tenant_id=None, doc_id=None, task_id=None, flow_id=None):
         super().__init__(dsl, tenant_id, task_id)
         self._doc_id = doc_id
@ -35,7 +36,7 @@ class Pipeline(Graph):
         self._kb_id = DocumentService.get_knowledgebase_id(doc_id)
         assert self._kb_id, f"Can't find KB of this document: {doc_id}"
 
-    def callback(self, component_name: str, progress: float|int|None=None, message: str = "") -> None:
+    def callback(self, component_name: str, progress: float | int | None = None, message: str = "") -> None:
         log_key = f"{self._flow_id}-{self.task_id}-logs"
         try:
             bin = REDIS_CONN.get(log_key)
@ -44,16 +45,10 @@ class Pipeline(Graph):
                 if obj[-1]["component_name"] == component_name:
                     obj[-1]["trace"].append({"progress": progress, "message": message, "datetime": datetime.datetime.now().strftime("%H:%M:%S")})
                 else:
-                    obj.append({
-                        "component_name": component_name,
-                        "trace": [{"progress": progress, "message": message, "datetime": datetime.datetime.now().strftime("%H:%M:%S")}]
-                    })
+                    obj.append({"component_name": component_name, "trace": [{"progress": progress, "message": message, "datetime": datetime.datetime.now().strftime("%H:%M:%S")}]})
             else:
-                obj = [{
-                    "component_name": component_name,
-                    "trace": [{"progress": progress, "message": message, "datetime": datetime.datetime.now().strftime("%H:%M:%S")}]
-                }]
-            REDIS_CONN.set_obj(log_key, obj, 60*10)
+                obj = [{"component_name": component_name, "trace": [{"progress": progress, "message": message, "datetime": datetime.datetime.now().strftime("%H:%M:%S")}]}]
+            REDIS_CONN.set_obj(log_key, obj, 60 * 10)
         except Exception as e:
             logging.exception(e)
 
@ -71,21 +66,19 @@ class Pipeline(Graph):
         super().reset()
         log_key = f"{self._flow_id}-{self.task_id}-logs"
         try:
-            REDIS_CONN.set_obj(log_key, [], 60*10)
+            REDIS_CONN.set_obj(log_key, [], 60 * 10)
         except Exception as e:
             logging.exception(e)
 
     async def run(self, **kwargs):
         st = time.perf_counter()
         if not self.path:
-            self.path.append("begin")
+            self.path.append("File")
 
         if self._doc_id:
-            DocumentService.update_by_id(self._doc_id, {
-                "progress": random.randint(0,5)/100.,
-                "progress_msg": "Start the pipeline...",
-                "process_begin_at": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
-            })
+            DocumentService.update_by_id(
+                self._doc_id, {"progress": random.randint(0, 5) / 100.0, "progress_msg": "Start the pipeline...", "process_begin_at": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}
+            )
 
         self.error = ""
         idx = len(self.path) - 1
@ -99,23 +92,21 @@ class Pipeline(Graph):
         self.path.extend(cpn_obj.get_downstream())
 
         while idx < len(self.path) and not self.error:
-            last_cpn = self.get_component_obj(self.path[idx-1])
+            last_cpn = self.get_component_obj(self.path[idx - 1])
             cpn_obj = self.get_component_obj(self.path[idx])
 
             async def invoke():
                 nonlocal last_cpn, cpn_obj
                 await cpn_obj.invoke(**last_cpn.output())
 
             async with trio.open_nursery() as nursery:
                 nursery.start_soon(invoke)
             if cpn_obj.error():
                 self.error = "[ERROR]" + cpn_obj.error()
+                self.callback(cpn_obj.component_name, -1, self.error)
                 break
             idx += 1
             self.path.extend(cpn_obj.get_downstream())
 
         if self._doc_id:
-            DocumentService.update_by_id(self._doc_id, {
-                "progress": 1 if not self.error else -1,
-                "progress_msg": "Pipeline finished...\n" + self.error,
-                "process_duration": time.perf_counter() - st
-            })
+            DocumentService.update_by_id(self._doc_id, {"progress": 1 if not self.error else -1, "progress_msg": "Pipeline finished...\n" + self.error, "process_duration": time.perf_counter() - st})
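For context on the callback shown above: each run keeps its progress trace in Redis under the key "{flow_id}-{task_id}-logs" with a 10-minute TTL. A sketch of what the stored object can look like after two callbacks on the same component (the component name, progress values and messages below are illustrative, not taken from the change set):

trace_log = [
    {
        "component_name": "Parser:0",
        "trace": [
            {"progress": 0.03, "message": "Start to work on a PDF.", "datetime": "12:00:01"},
            {"progress": 0.5, "message": "", "datetime": "12:00:27"},
        ],
    },
]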
@ -1,5 +1,5 @@
 #
-# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
+# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@ -18,12 +18,14 @@ import json
 import os
 import time
 from concurrent.futures import ThreadPoolExecutor
 
 import trio
 
 from api import settings
 from rag.flow.pipeline import Pipeline
 
 
-def print_logs(pipeline):
+def print_logs(pipeline: Pipeline):
     last_logs = "[]"
     while True:
         time.sleep(5)
@ -34,16 +36,16 @@ def print_logs(pipeline):
             last_logs = logs_str
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     dsl_default_path = os.path.join(
         os.path.dirname(os.path.realpath(__file__)),
         "dsl_examples",
         "general_pdf_all.json",
     )
-    parser.add_argument('-s', '--dsl', default=dsl_default_path, help="input dsl", action='store', required=True)
+    parser.add_argument("-s", "--dsl", default=dsl_default_path, help="input dsl", action="store", required=False)
-    parser.add_argument('-d', '--doc_id', default=False, help="Document ID", action='store', required=True)
+    parser.add_argument("-d", "--doc_id", default=False, help="Document ID", action="store", required=True)
-    parser.add_argument('-t', '--tenant_id', default=False, help="Tenant ID", action='store', required=True)
+    parser.add_argument("-t", "--tenant_id", default=False, help="Tenant ID", action="store", required=True)
     args = parser.parse_args()
 
     settings.init_settings()
@ -53,5 +55,7 @@ if __name__ == '__main__':
     exe = ThreadPoolExecutor(max_workers=5)
     thr = exe.submit(print_logs, pipeline)
 
+    # queue_dataflow(dsl=open(args.dsl, "r").read(), tenant_id=args.tenant_id, doc_id=args.doc_id, task_id="xxxx", flow_id="xxx", priority=0)
+
     trio.run(pipeline.run)
     thr.result()
@ -1,15 +1,15 @@
 {
   "components": {
-    "begin": {
+    "File": {
       "obj":{
         "component_name": "File",
         "params": {
         }
       },
-      "downstream": ["parser:0"],
+      "downstream": ["Parser:0"],
       "upstream": []
     },
-    "parser:0": {
+    "Parser:0": {
       "obj": {
         "component_name": "Parser",
         "params": {
@ -22,14 +22,22 @@
               "pdf"
             ],
             "output_format": "json"
+          },
+          "excel": {
+            "output_format": "html",
+            "suffix": [
+              "xls",
+              "xlsx",
+              "csv"
+            ]
           }
         }
       },
-      "downstream": ["chunker:0"],
+      "downstream": ["Chunker:0"],
-      "upstream": ["begin"]
+      "upstream": ["Begin"]
     },
-    "chunker:0": {
+    "Chunker:0": {
       "obj": {
         "component_name": "Chunker",
        "params": {
@ -37,18 +45,19 @@
           "auto_keywords": 5
         }
       },
-      "downstream": ["tokenizer:0"],
+      "downstream": ["Tokenizer:0"],
-      "upstream": ["chunker:0"]
+      "upstream": ["Parser:0"]
     },
-    "tokenizer:0": {
+    "Tokenizer:0": {
       "obj": {
         "component_name": "Tokenizer",
         "params": {
         }
       },
       "downstream": [],
-      "upstream": ["chunker:0"]
+      "upstream": ["Chunker:0"]
     }
   },
   "path": []
 }
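A minimal sketch of driving a DSL like the one above through the Pipeline class from this change set; the IDs and the DSL path are placeholders, and settings.init_settings() mirrors what the test runner shown earlier does before constructing the pipeline:

import trio

from api import settings
from rag.flow.pipeline import Pipeline

settings.init_settings()
dsl = open("dsl_examples/general_pdf_all.json", "r").read()  # assumed path, relative to the runner script
pipeline = Pipeline(dsl, tenant_id="<tenant_id>", doc_id="<doc_id>", task_id="<task_id>", flow_id="<flow_id>")
pipeline.reset()        # clears the "<flow_id>-<task_id>-logs" key in Redis
trio.run(pipeline.run)  # walks File -> Parser:0 -> Chunker:0 -> Tokenizer:0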
14
rag/flow/tokenizer/__init__.py
Normal file
@ -0,0 +1,14 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
51
rag/flow/tokenizer/schema.py
Normal file
@ -0,0 +1,51 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Literal

from pydantic import BaseModel, ConfigDict, Field, model_validator


class TokenizerFromUpstream(BaseModel):
    created_time: float | None = Field(default=None, alias="_created_time")
    elapsed_time: float | None = Field(default=None, alias="_elapsed_time")

    name: str = ""
    blob: bytes

    output_format: Literal["json", "markdown", "text", "html"] | None = Field(default=None)

    chunks: list[dict[str, Any]] | None = Field(default=None)

    json_result: list[dict[str, Any]] | None = Field(default=None, alias="json")
    markdown_result: str | None = Field(default=None, alias="markdown")
    text_result: str | None = Field(default=None, alias="text")
    html_result: str | None = Field(default=None, alias="html")

    model_config = ConfigDict(populate_by_name=True, extra="forbid")

    @model_validator(mode="after")
    def _check_payloads(self) -> "TokenizerFromUpstream":
        if self.chunks:
            return self

        if self.output_format in {"markdown", "text"}:
            if self.output_format == "markdown" and not self.markdown_result:
                raise ValueError("output_format=markdown requires a markdown payload (field: 'markdown' or 'markdown_result').")
            if self.output_format == "text" and not self.text_result:
                raise ValueError("output_format=text requires a text payload (field: 'text' or 'text_result').")
        else:
            if not self.json_result:
                raise ValueError("When no chunks are provided and output_format is not markdown/text, a JSON list payload is required (field: 'json' or 'json_result').")
        return self
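To make the validator above concrete, here is a small, hypothetical illustration: chunks short-circuit the payload check, while a markdown/text output_format must come with the matching payload; violations surface as a pydantic ValidationError.

from rag.flow.tokenizer.schema import TokenizerFromUpstream

# Accepted: chunks are present, so no further payload check is performed.
TokenizerFromUpstream.model_validate({"blob": b"", "chunks": [{"text": "hello"}]})

# Rejected: output_format is markdown but no markdown payload was supplied.
try:
    TokenizerFromUpstream.model_validate({"blob": b"", "output_format": "markdown"})
except Exception as e:
    print(type(e).__name__)  # ValidationError, wrapping the ValueError raised in _check_payloads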
@ -1,5 +1,5 @@
 #
-# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
+# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
 import random
 import re
 
@ -24,6 +25,7 @@ from api.db.services.llm_service import LLMBundle
 from api.db.services.user_service import TenantService
 from api.utils.api_utils import timeout
 from rag.flow.base import ProcessBase, ProcessParamBase
+from rag.flow.tokenizer.schema import TokenizerFromUpstream
 from rag.nlp import rag_tokenizer
 from rag.settings import EMBEDDING_BATCH_SIZE
 from rag.svr.task_executor import embed_limiter
@ -40,6 +42,9 @@ class TokenizerParam(ProcessParamBase):
         for v in self.search_method:
             self.check_valid_value(v.lower(), "Chunk method abnormal.", ["full_text", "embedding"])
 
+    def get_input_form(self) -> dict[str, dict]:
+        return {}
+
 
 class Tokenizer(ProcessBase):
     component_name = "Tokenizer"
@ -67,19 +72,19 @@ class Tokenizer(ProcessBase):
         @timeout(60)
         def batch_encode(txts):
             nonlocal embedding_model
-            return embedding_model.encode([truncate(c, embedding_model.max_length-10) for c in txts])
+            return embedding_model.encode([truncate(c, embedding_model.max_length - 10) for c in txts])
 
         cnts_ = np.array([])
         for i in range(0, len(texts), EMBEDDING_BATCH_SIZE):
             async with embed_limiter:
-                vts, c = await trio.to_thread.run_sync(lambda: batch_encode(texts[i: i + EMBEDDING_BATCH_SIZE]))
+                vts, c = await trio.to_thread.run_sync(lambda: batch_encode(texts[i : i + EMBEDDING_BATCH_SIZE]))
             if len(cnts_) == 0:
                 cnts_ = vts
             else:
                 cnts_ = np.concatenate((cnts_, vts), axis=0)
             token_count += c
             if i % 33 == 32:
-                self.callback(i*1./len(texts)/parts/EMBEDDING_BATCH_SIZE + 0.5*(parts-1))
+                self.callback(i * 1.0 / len(texts) / parts / EMBEDDING_BATCH_SIZE + 0.5 * (parts - 1))
 
         cnts = cnts_
         title_w = float(self._param.filename_embd_weight)
@ -92,11 +97,17 @@ class Tokenizer(ProcessBase):
         return chunks, token_count
 
     async def _invoke(self, **kwargs):
+        try:
+            from_upstream = TokenizerFromUpstream.model_validate(kwargs)
+        except Exception as e:
+            self.set_output("_ERROR", f"Input error: {str(e)}")
+            return
+
         parts = sum(["full_text" in self._param.search_method, "embedding" in self._param.search_method])
         if "full_text" in self._param.search_method:
-            self.callback(random.randint(1,5)/100., "Start to tokenize.")
+            self.callback(random.randint(1, 5) / 100.0, "Start to tokenize.")
-        if kwargs.get("chunks"):
+        if from_upstream.chunks:
-            chunks = kwargs["chunks"]
+            chunks = from_upstream.chunks
             for i, ck in enumerate(chunks):
                 if ck.get("questions"):
                     ck["question_tks"] = rag_tokenizer.tokenize("\n".join(ck["questions"]))
@ -105,30 +116,40 @@ class Tokenizer(ProcessBase):
                 ck["content_ltks"] = rag_tokenizer.tokenize(ck["text"])
                 ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"])
                 if i % 100 == 99:
-                    self.callback(i*1./len(chunks)/parts)
+                    self.callback(i * 1.0 / len(chunks) / parts)
-        elif kwargs.get("output_format") in ["markdown", "text"]:
+        elif from_upstream.output_format in ["markdown", "text"]:
-            ck = {
-                "text": kwargs.get(kwargs["output_format"], "")
-            }
-            if "full_text" in self._param.search_method:
+            if from_upstream.output_format == "markdown":
+                payload = from_upstream.markdown_result
+            else:  # == "text"
+                payload = from_upstream.text_result
+
+            if not payload:
+                return ""
+
+            ck = {"text": payload}
+            if "full_text" in self._param.search_method:
                 ck["content_ltks"] = rag_tokenizer.tokenize(kwargs.get(kwargs["output_format"], ""))
                 ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"])
             chunks = [ck]
         else:
-            chunks = kwargs["json"]
+            chunks = from_upstream.json_result
             for i, ck in enumerate(chunks):
                 ck["content_ltks"] = rag_tokenizer.tokenize(ck["text"])
                 ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"])
                 if i % 100 == 99:
-                    self.callback(i*1./len(chunks)/parts)
+                    self.callback(i * 1.0 / len(chunks) / parts)
 
-        self.callback(1./parts, "Finish tokenizing.")
+        self.callback(1.0 / parts, "Finish tokenizing.")
 
         if "embedding" in self._param.search_method:
-            self.callback(random.randint(1,5)/100. + 0.5*(parts-1), "Start embedding inference.")
+            self.callback(random.randint(1, 5) / 100.0 + 0.5 * (parts - 1), "Start embedding inference.")
-            chunks, token_count = await self._embedding(kwargs.get("name", ""), chunks)
+            if from_upstream.name.strip() == "":
+                logging.warning("Tokenizer: empty name provided from upstream, embedding may be not accurate.")
+
+            chunks, token_count = await self._embedding(from_upstream.name, chunks)
             self.set_output("embedding_token_consumption", token_count)
 
-            self.callback(1., "Finish embedding.")
+            self.callback(1.0, "Finish embedding.")
 
         self.set_output("chunks", chunks)
@ -374,7 +374,7 @@ class Base(ABC):
                     if not tol:
                         total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
                     else:
-                        total_tokens += tol
+                        total_tokens = tol
 
                 finish_reason = resp.choices[0].finish_reason if hasattr(resp.choices[0], "finish_reason") else ""
                 if finish_reason == "length":
@ -410,7 +410,7 @@ class Base(ABC):
                     if not tol:
                         total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
                     else:
-                        total_tokens += tol
+                        total_tokens = tol
                 answer += resp.choices[0].delta.content
                 yield resp.choices[0].delta.content
 
@ -21,10 +21,12 @@ import sys
 import threading
 import time
 
+from api.utils import get_uuid
 from api.utils.api_utils import timeout
 from api.utils.log_utils import init_root_logger, get_project_base_directory
 from graphrag.general.index import run_graphrag
 from graphrag.utils import get_llm_cache, set_llm_cache, get_tags_from_cache, set_tags_to_cache
+from rag.flow.pipeline import Pipeline
 from rag.prompts import keyword_extraction, question_proposal, content_tagging
 
 import logging
@ -223,7 +225,14 @@ async def collect():
         logging.warning(f"collect task {msg['id']} {state}")
         redis_msg.ack()
         return None, None
-    task["task_type"] = msg.get("task_type", "")
+
+    task_type = msg.get("task_type", "")
+    task["task_type"] = task_type
+    if task_type == "dataflow":
+        task["tenant_id"]=msg.get("tenant_id", "")
+        task["dsl"] = msg.get("dsl", "")
+        task["dataflow_id"] = msg.get("dataflow_id", get_uuid())
+        task["kb_id"] = msg.get("kb_id", "")
     return redis_msg, task
 
 
@ -473,6 +482,15 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
     return tk_count, vector_size
 
 
+async def run_dataflow(dsl:str, tenant_id:str, doc_id:str, task_id:str, flow_id:str, callback=None):
+    _ = callback
+
+    pipeline = Pipeline(dsl=dsl, tenant_id=tenant_id, doc_id=doc_id, task_id=task_id, flow_id=flow_id)
+    pipeline.reset()
+
+    await pipeline.run()
+
+
 @timeout(3600)
 async def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None):
     chunks = []
@ -558,15 +576,20 @@ async def do_handle_task(task):
 
     init_kb(task, vector_size)
 
-    # Either using RAPTOR or Standard chunking methods
-    if task.get("task_type", "") == "raptor":
+    task_type = task.get("task_type", "")
+    if task_type == "dataflow":
+        task_dataflow_dsl = task["dsl"]
+        task_dataflow_id = task["dataflow_id"]
+        await run_dataflow(dsl=task_dataflow_dsl, tenant_id=task_tenant_id, doc_id=task_doc_id, task_id=task_id, flow_id=task_dataflow_id, callback=None)
+        return
+    elif task_type == "raptor":
         # bind LLM for raptor
         chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
         # run RAPTOR
         async with kg_limiter:
             chunks, token_count = await run_raptor(task, chat_model, embedding_model, vector_size, progress_callback)
     # Either using graphrag or Standard chunking methods
-    elif task.get("task_type", "") == "graphrag":
+    elif task_type == "graphrag":
         if not task_parser_config.get("graphrag", {}).get("use_graphrag", False):
             progress_callback(prog=-1.0, msg="Internal configuration error.")
             return
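A rough sketch of the queue message shape that the collect() change above turns into a dataflow task; the field values are placeholders, and the real message is produced by queue_dataflow on the API side:

msg = {
    "id": "<task_id>",
    "task_type": "dataflow",
    "tenant_id": "<tenant_id>",
    "dsl": "<canvas DSL as a JSON string>",
    "dataflow_id": "<flow_id>",  # collect() falls back to get_uuid() when absent
    "kb_id": "<kb_id>",
}
# do_handle_task() then dispatches task_type == "dataflow" to run_dataflow(...),
# bypassing the standard chunking path.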
@ -1,28 +1,34 @@
 import { toast } from 'sonner';
 
+const duration = { duration: 1500 };
+
 const message = {
   success: (msg: string) => {
     toast.success(msg, {
       position: 'top-center',
       closeButton: false,
+      ...duration,
     });
   },
   error: (msg: string) => {
     toast.error(msg, {
       position: 'top-center',
       closeButton: false,
+      ...duration,
     });
   },
   warning: (msg: string) => {
     toast.warning(msg, {
       position: 'top-center',
       closeButton: false,
+      ...duration,
     });
   },
   info: (msg: string) => {
     toast.info(msg, {
       position: 'top-center',
       closeButton: false,
+      ...duration,
     });
   },
 };
@ -1,10 +1,11 @@
+import message from '@/components/ui/message';
 import { ResponseType } from '@/interfaces/database/base';
 import { IFolder } from '@/interfaces/database/file-manager';
 import { IConnectRequestBody } from '@/interfaces/request/file-manager';
 import fileManagerService from '@/services/file-manager-service';
 import { downloadFileFromBlob } from '@/utils/file-util';
 import { useMutation, useQuery, useQueryClient } from '@tanstack/react-query';
-import { PaginationProps, UploadFile, message } from 'antd';
+import { PaginationProps, UploadFile } from 'antd';
 import React, { useCallback } from 'react';
 import { useTranslation } from 'react-i18next';
 import { useSearchParams } from 'umi';
@ -1,3 +1,4 @@
+import message from '@/components/ui/message';
 import { Authorization } from '@/constants/authorization';
 import userService, {
   getLoginChannels,
@ -5,7 +6,7 @@ import userService, {
 } from '@/services/user-service';
 import authorizationUtil, { redirectToLogin } from '@/utils/authorization-util';
 import { useMutation, useQuery } from '@tanstack/react-query';
-import { Form, message } from 'antd';
+import { Form } from 'antd';
 import { FormInstance } from 'antd/lib';
 import { useEffect, useState } from 'react';
 import { useTranslation } from 'react-i18next';
@ -50,8 +51,6 @@ export const useLoginWithChannel = () => {
 };
 
 export const useLogin = () => {
-  const { t } = useTranslation();
-
   const {
     data,
     isPending: loading,
@ -62,7 +61,6 @@ export const useLogin = () => {
       const { data: res = {}, response } = await userService.login(params);
       if (res.code === 0) {
         const { data } = res;
-        message.success(t('message.logged'));
         const authorization = response.headers.get(Authorization);
         const token = data.access_token;
         const userInfo = {
@ -51,6 +51,7 @@ export const enum AgentApiAction {
   FetchAgentAvatar = 'fetchAgentAvatar',
   FetchExternalAgentInputs = 'fetchExternalAgentInputs',
   SetAgentSetting = 'setAgentSetting',
+  FetchPrompt = 'fetchPrompt',
 }
 
 export const EmptyDsl = {
@ -637,3 +638,24 @@ export const useSetAgentSetting = () => {
 
   return { data, loading, setAgentSetting: mutateAsync };
 };
+
+export const useFetchPrompt = () => {
+  const {
+    data,
+    isFetching: loading,
+    refetch,
+  } = useQuery<Record<string, string>>({
+    queryKey: [AgentApiAction.FetchPrompt],
+    refetchOnReconnect: false,
+    refetchOnMount: false,
+    refetchOnWindowFocus: false,
+    gcTime: 0,
+    queryFn: async () => {
+      const { data } = await agentService.fetchPrompt();
+
+      return data?.data ?? {};
+    },
+  });
+
+  return { data, loading, refetch };
+};
@ -1,4 +1,5 @@
 import { useHandleFilterSubmit } from '@/components/list-filter-bar/use-handle-filter-submit';
+import message from '@/components/ui/message';
 import { ResponseType } from '@/interfaces/database/base';
 import {
   IDocumentInfo,
@ -12,7 +13,6 @@ import i18n from '@/locales/config';
 import kbService, { listDocument } from '@/services/knowledge-service';
 import { useMutation, useQuery, useQueryClient } from '@tanstack/react-query';
 import { useDebounce } from 'ahooks';
-import { message } from 'antd';
 import { get } from 'lodash';
 import { useCallback, useMemo, useState } from 'react';
 import { useParams } from 'umi';
@ -1,3 +1,4 @@
+import message from '@/components/ui/message';
 import {
   IFetchFileListResult,
   IFolder,
@ -5,7 +6,7 @@ import {
 import fileManagerService from '@/services/file-manager-service';
 import { useMutation, useQuery, useQueryClient } from '@tanstack/react-query';
 import { useDebounce } from 'ahooks';
-import { PaginationProps, message } from 'antd';
+import { PaginationProps } from 'antd';
 import { useCallback } from 'react';
 import { useTranslation } from 'react-i18next';
 import { useSearchParams } from 'umi';
@ -1,3 +1,4 @@
+import message from '@/components/ui/message';
 import { LanguageTranslationMap } from '@/constants/common';
 import { ResponseGetType } from '@/interfaces/database/base';
 import { IToken } from '@/interfaces/database/chat';
@ -18,7 +19,7 @@ import userService, {
   listTenantUser,
 } from '@/services/user-service';
 import { useMutation, useQuery, useQueryClient } from '@tanstack/react-query';
-import { Modal, message } from 'antd';
+import { Modal } from 'antd';
 import DOMPurify from 'dompurify';
 import { isEmpty } from 'lodash';
 import { useCallback, useMemo, useState } from 'react';
@ -1518,6 +1518,7 @@ This delimiter is used to split the input text into several text pieces echo of
       sqlStatement: 'SQL Statement',
       sqlStatementTip:
         'Write your SQL query here. You can use variables, raw SQL, or mix both using variable syntax.',
+      frameworkPrompts: 'Framework Prompts',
     },
     llmTools: {
       bad_calculator: {
@ -1433,6 +1433,7 @@ General:实体和关系提取提示来自 GitHub - microsoft/graphrag:基于
       sqlStatement: 'SQL 语句',
       sqlStatementTip:
         '在此处编写您的 SQL 查询。您可以使用变量、原始 SQL,或使用变量语法混合使用两者。',
+      frameworkPrompts: '框架提示词',
     },
     footer: {
       profile: 'All rights reserved @ React',
@ -65,7 +65,7 @@ const FormSheet = ({
   return (
     <Sheet open={visible} modal={false}>
       <SheetContent
-        className={cn('top-20 p-0 flex flex-col pb-20 ', {
+        className={cn('top-20 p-0 flex flex-col pb-20', {
          'right-[620px]': chatVisible,
         })}
         closeIcon={false}
@ -39,6 +39,7 @@ import { Output } from '../components/output';
 import { PromptEditor } from '../components/prompt-editor';
 import { QueryVariable } from '../components/query-variable';
 import { AgentTools, Agents } from './agent-tools';
+import { useBuildPromptExtraPromptOptions } from './use-build-prompt-options';
 import { useValues } from './use-values';
 import { useWatchFormChange } from './use-watch-change';
 
@ -85,6 +86,8 @@ function AgentForm({ node }: INextOperatorForm) {
 
   const defaultValues = useValues(node);
 
+  const { extraOptions } = useBuildPromptExtraPromptOptions();
+
   const ExceptionMethodOptions = Object.values(AgentExceptionMethod).map(
     (x) => ({
       label: t(`flow.${x}`),
@ -150,6 +153,7 @@ function AgentForm({ node }: INextOperatorForm) {
                     {...field}
                     placeholder={t('flow.messagePlaceholder')}
                     showToolbar={false}
+                    extraOptions={extraOptions}
                   ></PromptEditor>
                 </FormControl>
               </FormItem>
@ -0,0 +1,30 @@
import { useFetchPrompt } from '@/hooks/use-agent-request';
import { useMemo } from 'react';
import { useTranslation } from 'react-i18next';

export const PromptIdentity = 'RAGFlow-Prompt';

function wrapPromptWithTag(text: string, tag: string) {
  const capitalTag = tag.toUpperCase();
  return `<${capitalTag}>
${text}
</${capitalTag}>`;
}

export function useBuildPromptExtraPromptOptions() {
  const { data: prompts } = useFetchPrompt();
  const { t } = useTranslation();

  const options = useMemo(() => {
    return Object.entries(prompts || {}).map(([key, value]) => ({
      label: key,
      value: wrapPromptWithTag(value, key),
    }));
  }, [prompts]);

  const extraOptions = [
    { label: PromptIdentity, title: t('flow.frameworkPrompts'), options },
  ];

  return { extraOptions };
}
@ -29,7 +29,9 @@ import { PasteHandlerPlugin } from './paste-handler-plugin';
 import theme from './theme';
 import { VariableNode } from './variable-node';
 import { VariableOnChangePlugin } from './variable-on-change-plugin';
-import VariablePickerMenuPlugin from './variable-picker-plugin';
+import VariablePickerMenuPlugin, {
+  VariablePickerMenuPluginProps,
+} from './variable-picker-plugin';
 
 // Catch any errors that occur during Lexical updates and log them
 // or throw them as needed. If you don't throw them, Lexical will
@ -52,7 +54,8 @@ type IProps = {
   value?: string;
   onChange?: (value?: string) => void;
   placeholder?: ReactNode;
-} & PromptContentProps;
+} & PromptContentProps &
+  Pick<VariablePickerMenuPluginProps, 'extraOptions'>;
 
 function PromptContent({
   showToolbar = true,
@ -122,6 +125,7 @@ export function PromptEditor({
   placeholder,
   showToolbar,
   multiLine = true,
+  extraOptions,
 }: IProps) {
   const { t } = useTranslation();
   const initialConfig: InitialConfigType = {
@ -170,7 +174,10 @@ export function PromptEditor({
             }
             ErrorBoundary={LexicalErrorBoundary}
           />
-          <VariablePickerMenuPlugin value={value}></VariablePickerMenuPlugin>
+          <VariablePickerMenuPlugin
+            value={value}
+            extraOptions={extraOptions}
+          ></VariablePickerMenuPlugin>
           <PasteHandlerPlugin />
           <VariableOnChangePlugin
             onChange={onValueChange}
@ -1,7 +1,5 @@
-import { BeginId } from '@/pages/flow/constant';
 import { DecoratorNode, LexicalNode, NodeKey } from 'lexical';
 import { ReactNode } from 'react';
-const prefix = BeginId + '@';
 
 export class VariableNode extends DecoratorNode<ReactNode> {
   __value: string;
@ -3,7 +3,7 @@ import { EditorState, LexicalEditor } from 'lexical';
 import { useEffect } from 'react';
 import { ProgrammaticTag } from './constant';
 
-interface IProps {
+interface VariableOnChangePluginProps {
   onChange: (
     editorState: EditorState,
     editor?: LexicalEditor,
@ -11,7 +11,9 @@ interface IProps {
   ) => void;
 }
 
-export function VariableOnChangePlugin({ onChange }: IProps) {
+export function VariableOnChangePlugin({
+  onChange,
+}: VariableOnChangePluginProps) {
   // Access the editor through the LexicalComposerContext
   const [editor] = useLexicalComposerContext();
   // Wrap our listener in useEffect to handle the teardown and avoid stale references.
@@ -32,6 +32,7 @@ import * as ReactDOM from 'react-dom';
 import { $createVariableNode } from './variable-node';
 
 import { useBuildQueryVariableOptions } from '@/pages/agent/hooks/use-get-begin-query';
+import { PromptIdentity } from '../../agent-form/use-build-prompt-options';
 import { ProgrammaticTag } from './constant';
 import './index.css';
 class VariableInnerOption extends MenuOption {
@@ -108,11 +109,18 @@ function VariablePickerMenuItem({
   );
 }
 
+export type VariablePickerMenuPluginProps = {
+  value?: string;
+  extraOptions?: Array<{
+    label: string;
+    title: string;
+    options: Array<{ label: string; value: string; icon?: ReactNode }>;
+  }>;
+};
 export default function VariablePickerMenuPlugin({
   value,
-}: {
-  value?: string;
-}): JSX.Element {
+  extraOptions,
+}: VariablePickerMenuPluginProps): JSX.Element {
   const [editor] = useLexicalComposerContext();
   const isFirstRender = useRef(true);
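For orientation, a minimal usage sketch of the new extraOptions prop, assuming it is supplied through PromptEditor; the import path, group label, and option values below are illustrative assumptions, not part of this change.

// Hypothetical consumer of the new extraOptions prop.
// Import path and option values are assumptions for illustration only.
import { PromptEditor } from '@/components/prompt-editor';

const extraOptions = [
  {
    label: 'Prompts',
    title: 'Prompts',
    options: [{ label: 'System prompt', value: 'system_prompt' }],
  },
];

export function PromptField(props: {
  value?: string;
  onChange?: (value?: string) => void;
}) {
  // PromptEditor forwards extraOptions to VariablePickerMenuPlugin, which
  // appends them to the variable groups built by useBuildQueryVariableOptions.
  return <PromptEditor {...props} extraOptions={extraOptions} />;
}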
@@ -122,10 +130,10 @@ export default function VariablePickerMenuPlugin({
 
   const [queryString, setQueryString] = React.useState<string | null>('');
 
-  const options = useBuildQueryVariableOptions();
+  let options = useBuildQueryVariableOptions();
 
   const buildNextOptions = useCallback(() => {
-    let filteredOptions = options;
+    let filteredOptions = [...options, ...(extraOptions ?? [])];
     if (queryString) {
       const lowerQuery = queryString.toLowerCase();
       filteredOptions = options
@@ -140,7 +148,7 @@ export default function VariablePickerMenuPlugin({
       .filter((x) => x.options.length > 0);
     }
 
-    const nextOptions: VariableOption[] = filteredOptions.map(
+    const finalOptions: VariableOption[] = filteredOptions.map(
       (x) =>
         new VariableOption(
           x.label,
@@ -150,8 +158,8 @@ export default function VariablePickerMenuPlugin({
         }),
       ),
     );
-    return nextOptions;
-  }, [options, queryString]);
+    return finalOptions;
+  }, [extraOptions, options, queryString]);
 
   const findItemByValue = useCallback(
     (value: string) => {
@@ -173,7 +181,7 @@ export default function VariablePickerMenuPlugin({
 
   const onSelectOption = useCallback(
     (
-      selectedOption: VariableOption | VariableInnerOption,
+      selectedOption: VariableInnerOption,
       nodeToRemove: TextNode | null,
       closeMenu: () => void,
     ) => {
@@ -193,7 +201,11 @@ export default function VariablePickerMenuPlugin({
           selectedOption.parentLabel as string | ReactNode,
           selectedOption.icon as ReactNode,
         );
-        selection.insertNodes([variableNode]);
+        if (selectedOption.parentLabel === PromptIdentity) {
+          selection.insertText(selectedOption.value);
+        } else {
+          selection.insertNodes([variableNode]);
+        }
 
         closeMenu();
       });
@@ -269,7 +281,13 @@ export default function VariablePickerMenuPlugin({
   return (
     <LexicalTypeaheadMenuPlugin<VariableOption | VariableInnerOption>
       onQueryChange={setQueryString}
-      onSelectOption={onSelectOption}
+      onSelectOption={(option, textNodeContainingQuery, closeMenu) =>
+        onSelectOption(
+          option as VariableInnerOption, // Only the second level menu can be selected
+          textNodeContainingQuery,
+          closeMenu,
+        )
+      }
       triggerFn={checkForTriggerMatch}
       options={buildNextOptions()}
       menuRenderFn={(anchorElementRef, { selectOptionAndCleanUp }) => {
web/src/pages/agent/hooks/use-calculate-sheet-right.ts (new file, 8 lines)
@@ -0,0 +1,8 @@
+import { useSize } from 'ahooks';
+
+export function useCalculateSheetRight() {
+  const size = useSize(document.querySelector('body'));
+  const bodyWidth = size?.width ?? 0;
+
+  return bodyWidth > 1800 ? 'right-[620px]' : `right-1/3`;
+}
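A small sketch of how the class returned by this hook might be consumed, assuming it is combined with cn in a sheet component; the wrapper hook below is hypothetical and the import paths are assumed.

// Hypothetical wrapper showing one way to use useCalculateSheetRight.
// Import paths are assumptions for illustration only.
import { cn } from '@/lib/utils';
import { useCalculateSheetRight } from '@/pages/agent/hooks/use-calculate-sheet-right';

export function useSheetContentClass() {
  // 'right-[620px]' on viewports wider than 1800px, otherwise 'right-1/3'.
  const right = useCalculateSheetRight();
  return cn('top-20', right);
}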
@@ -5,6 +5,7 @@ import {
   SheetTitle,
 } from '@/components/ui/sheet';
 import { IModalProps } from '@/interfaces/common';
+import { cn } from '@/lib/utils';
 import { NotebookText } from 'lucide-react';
 import 'react18-json-view/src/style.css';
 import { useCacheChatLog } from '../hooks/use-cache-chat-log';
@@ -24,7 +25,7 @@ export function LogSheet({
 }: LogSheetProps) {
   return (
     <Sheet open onOpenChange={hideModal} modal={false}>
-      <SheetContent className="top-20 right-[620px]">
+      <SheetContent className={cn('top-20 right-[620px]')}>
         <SheetHeader>
           <SheetTitle className="flex items-center gap-1">
             <NotebookText className="size-4" />
@@ -25,6 +25,7 @@ const {
   fetchAgentAvatar,
   fetchAgentLogs,
   fetchExternalAgentInputs,
+  prompt,
 } = api;
 
 const methods = {
@@ -112,6 +113,10 @@ const methods = {
     url: fetchExternalAgentInputs,
     method: 'get',
   },
+  fetchPrompt: {
+    url: prompt,
+    method: 'get',
+  },
 } as const;
 
 const agentService = registerNextServer<keyof typeof methods>(methods);
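A hedged sketch of calling the new fetchPrompt method from a react-query hook; only agentService.fetchPrompt comes from this diff, while the query wrapper, import paths, and response shape are assumptions.

// Hypothetical caller of the new fetchPrompt method (GET /canvas/prompts).
// Import paths, query key, and response handling are illustrative only.
import { useQuery } from '@tanstack/react-query';
import agentService from '@/services/agent-service';

export function useFetchPrompts() {
  return useQuery({
    queryKey: ['canvasPrompts'],
    queryFn: async () => {
      const { data } = await agentService.fetchPrompt();
      return data?.data ?? [];
    },
  });
}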
@@ -164,6 +164,7 @@ export default {
     `${api_host}/canvas/${canvasId}/sessions`,
   fetchExternalAgentInputs: (canvasId: string) =>
     `${ExternalApi}${api_host}/agentbots/${canvasId}/inputs`,
+  prompt: `${api_host}/canvas/prompts`,
 
   // mcp server
   listMcpServer: `${api_host}/mcp_server/list`,