mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-02-02 16:45:08 +08:00
Compare commits
5 Commits
c7efaab30e
...
6b9b785b5c
| Author | SHA1 | Date | |
|---|---|---|---|
| 6b9b785b5c | |||
| 4c0a89f262 | |||
| 76b1ee2a00 | |||
| 771a38434f | |||
| 886d38620e |
@ -29,7 +29,7 @@ from api.db.services.canvas_service import CanvasTemplateService, UserCanvasServ
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.file_service import FileService
|
||||
from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
|
||||
from api.db.services.task_service import queue_dataflow
|
||||
from api.db.services.task_service import queue_dataflow, CANVAS_DEBUG_DOC_ID
|
||||
from api.db.services.user_service import TenantService
|
||||
from api.db.services.user_canvas_version import UserCanvasVersionService
|
||||
from api.settings import RetCode
|
||||
@ -41,6 +41,7 @@ from api.db.db_models import APIToken
|
||||
import time
|
||||
|
||||
from api.utils.file_utils import filename_type, read_potential_broken_pdf
|
||||
from rag.flow.pipeline import Pipeline
|
||||
from rag.utils.redis_conn import REDIS_CONN
|
||||
|
||||
|
||||
@ -145,6 +146,7 @@ def run():
|
||||
|
||||
if cvs.canvas_category == CanvasCategory.DataFlow:
|
||||
task_id = get_uuid()
|
||||
Pipeline(cvs.dsl, tenant_id=current_user.id, doc_id=CANVAS_DEBUG_DOC_ID, task_id=task_id, flow_id=req["id"])
|
||||
ok, error_message = queue_dataflow(tenant_id=user_id, flow_id=req["id"], task_id=task_id, file=files[0], priority=0)
|
||||
if not ok:
|
||||
return get_data_error_result(message=error_message)
|
||||
|
||||
@ -1,353 +0,0 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import json
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
from functools import partial
|
||||
|
||||
import trio
|
||||
from flask import request
|
||||
from flask_login import current_user, login_required
|
||||
|
||||
from agent.canvas import Canvas
|
||||
from agent.component import LLM
|
||||
from api.db import CanvasCategory, FileType
|
||||
from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.file_service import FileService
|
||||
from api.db.services.task_service import queue_dataflow
|
||||
from api.db.services.user_canvas_version import UserCanvasVersionService
|
||||
from api.db.services.user_service import TenantService
|
||||
from api.settings import RetCode
|
||||
from api.utils import get_uuid
|
||||
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request
|
||||
from api.utils.file_utils import filename_type, read_potential_broken_pdf
|
||||
from rag.flow.pipeline import Pipeline
|
||||
|
||||
|
||||
@manager.route("/templates", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
def templates():
|
||||
return get_json_result(data=[c.to_dict() for c in CanvasTemplateService.query(canvas_category=CanvasCategory.DataFlow)])
|
||||
|
||||
|
||||
@manager.route("/list", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
def canvas_list():
|
||||
return get_json_result(data=sorted([c.to_dict() for c in UserCanvasService.query(user_id=current_user.id, canvas_category=CanvasCategory.DataFlow)], key=lambda x: x["update_time"] * -1))
|
||||
|
||||
|
||||
@manager.route("/rm", methods=["POST"]) # noqa: F821
|
||||
@validate_request("canvas_ids")
|
||||
@login_required
|
||||
def rm():
|
||||
for i in request.json["canvas_ids"]:
|
||||
if not UserCanvasService.accessible(i, current_user.id):
|
||||
return get_json_result(data=False, message="Only owner of canvas authorized for this operation.", code=RetCode.OPERATING_ERROR)
|
||||
UserCanvasService.delete_by_id(i)
|
||||
return get_json_result(data=True)
|
||||
|
||||
|
||||
@manager.route("/set", methods=["POST"]) # noqa: F821
|
||||
@validate_request("dsl", "title")
|
||||
@login_required
|
||||
def save():
|
||||
req = request.json
|
||||
if not isinstance(req["dsl"], str):
|
||||
req["dsl"] = json.dumps(req["dsl"], ensure_ascii=False)
|
||||
req["dsl"] = json.loads(req["dsl"])
|
||||
req["canvas_category"] = CanvasCategory.DataFlow
|
||||
if "id" not in req:
|
||||
req["user_id"] = current_user.id
|
||||
if UserCanvasService.query(user_id=current_user.id, title=req["title"].strip(), canvas_category=CanvasCategory.DataFlow):
|
||||
return get_data_error_result(message=f"{req['title'].strip()} already exists.")
|
||||
req["id"] = get_uuid()
|
||||
|
||||
if not UserCanvasService.save(**req):
|
||||
return get_data_error_result(message="Fail to save canvas.")
|
||||
else:
|
||||
if not UserCanvasService.accessible(req["id"], current_user.id):
|
||||
return get_json_result(data=False, message="Only owner of canvas authorized for this operation.", code=RetCode.OPERATING_ERROR)
|
||||
UserCanvasService.update_by_id(req["id"], req)
|
||||
# save version
|
||||
UserCanvasVersionService.insert(user_canvas_id=req["id"], dsl=req["dsl"], title="{0}_{1}".format(req["title"], time.strftime("%Y_%m_%d_%H_%M_%S")))
|
||||
UserCanvasVersionService.delete_all_versions(req["id"])
|
||||
return get_json_result(data=req)
|
||||
|
||||
|
||||
@manager.route("/get/<canvas_id>", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
def get(canvas_id):
|
||||
if not UserCanvasService.accessible(canvas_id, current_user.id):
|
||||
return get_data_error_result(message="canvas not found.")
|
||||
e, c = UserCanvasService.get_by_tenant_id(canvas_id)
|
||||
return get_json_result(data=c)
|
||||
|
||||
|
||||
@manager.route("/run", methods=["POST"]) # noqa: F821
|
||||
@validate_request("id")
|
||||
@login_required
|
||||
def run():
|
||||
req = request.json
|
||||
flow_id = req.get("id", "")
|
||||
doc_id = req.get("doc_id", "")
|
||||
if not all([flow_id, doc_id]):
|
||||
return get_data_error_result(message="id and doc_id are required.")
|
||||
|
||||
if not DocumentService.get_by_id(doc_id):
|
||||
return get_data_error_result(message=f"Document for {doc_id} not found.")
|
||||
|
||||
user_id = req.get("user_id", current_user.id)
|
||||
if not UserCanvasService.accessible(flow_id, current_user.id):
|
||||
return get_json_result(data=False, message="Only owner of canvas authorized for this operation.", code=RetCode.OPERATING_ERROR)
|
||||
|
||||
e, cvs = UserCanvasService.get_by_id(flow_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="canvas not found.")
|
||||
|
||||
if not isinstance(cvs.dsl, str):
|
||||
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
|
||||
|
||||
task_id = get_uuid()
|
||||
|
||||
ok, error_message = queue_dataflow(dsl=cvs.dsl, tenant_id=user_id, doc_id=doc_id, task_id=task_id, flow_id=flow_id, priority=0)
|
||||
if not ok:
|
||||
return server_error_response(error_message)
|
||||
|
||||
return get_json_result(data={"task_id": task_id, "flow_id": flow_id})
|
||||
|
||||
|
||||
@manager.route("/reset", methods=["POST"]) # noqa: F821
|
||||
@validate_request("id")
|
||||
@login_required
|
||||
def reset():
|
||||
req = request.json
|
||||
flow_id = req.get("id", "")
|
||||
if not flow_id:
|
||||
return get_data_error_result(message="id is required.")
|
||||
|
||||
if not UserCanvasService.accessible(flow_id, current_user.id):
|
||||
return get_json_result(data=False, message="Only owner of canvas authorized for this operation.", code=RetCode.OPERATING_ERROR)
|
||||
|
||||
task_id = req.get("task_id", "")
|
||||
|
||||
try:
|
||||
e, user_canvas = UserCanvasService.get_by_id(req["id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="canvas not found.")
|
||||
|
||||
dataflow = Pipeline(dsl=json.dumps(user_canvas.dsl), tenant_id=current_user.id, flow_id=flow_id, task_id=task_id)
|
||||
dataflow.reset()
|
||||
req["dsl"] = json.loads(str(dataflow))
|
||||
UserCanvasService.update_by_id(req["id"], {"dsl": req["dsl"]})
|
||||
return get_json_result(data=req["dsl"])
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route("/upload/<canvas_id>", methods=["POST"]) # noqa: F821
|
||||
def upload(canvas_id):
|
||||
e, cvs = UserCanvasService.get_by_tenant_id(canvas_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="canvas not found.")
|
||||
|
||||
user_id = cvs["user_id"]
|
||||
|
||||
def structured(filename, filetype, blob, content_type):
|
||||
nonlocal user_id
|
||||
if filetype == FileType.PDF.value:
|
||||
blob = read_potential_broken_pdf(blob)
|
||||
|
||||
location = get_uuid()
|
||||
FileService.put_blob(user_id, location, blob)
|
||||
|
||||
return {
|
||||
"id": location,
|
||||
"name": filename,
|
||||
"size": sys.getsizeof(blob),
|
||||
"extension": filename.split(".")[-1].lower(),
|
||||
"mime_type": content_type,
|
||||
"created_by": user_id,
|
||||
"created_at": time.time(),
|
||||
"preview_url": None,
|
||||
}
|
||||
|
||||
if request.args.get("url"):
|
||||
from crawl4ai import AsyncWebCrawler, BrowserConfig, CrawlerRunConfig, CrawlResult, DefaultMarkdownGenerator, PruningContentFilter
|
||||
|
||||
try:
|
||||
url = request.args.get("url")
|
||||
filename = re.sub(r"\?.*", "", url.split("/")[-1])
|
||||
|
||||
async def adownload():
|
||||
browser_config = BrowserConfig(
|
||||
headless=True,
|
||||
verbose=False,
|
||||
)
|
||||
async with AsyncWebCrawler(config=browser_config) as crawler:
|
||||
crawler_config = CrawlerRunConfig(markdown_generator=DefaultMarkdownGenerator(content_filter=PruningContentFilter()), pdf=True, screenshot=False)
|
||||
result: CrawlResult = await crawler.arun(url=url, config=crawler_config)
|
||||
return result
|
||||
|
||||
page = trio.run(adownload())
|
||||
if page.pdf:
|
||||
if filename.split(".")[-1].lower() != "pdf":
|
||||
filename += ".pdf"
|
||||
return get_json_result(data=structured(filename, "pdf", page.pdf, page.response_headers["content-type"]))
|
||||
|
||||
return get_json_result(data=structured(filename, "html", str(page.markdown).encode("utf-8"), page.response_headers["content-type"], user_id))
|
||||
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
file = request.files["file"]
|
||||
try:
|
||||
DocumentService.check_doc_health(user_id, file.filename)
|
||||
return get_json_result(data=structured(file.filename, filename_type(file.filename), file.read(), file.content_type))
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route("/input_form", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
def input_form():
|
||||
flow_id = request.args.get("id")
|
||||
cpn_id = request.args.get("component_id")
|
||||
try:
|
||||
e, user_canvas = UserCanvasService.get_by_id(flow_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="canvas not found.")
|
||||
if not UserCanvasService.query(user_id=current_user.id, id=flow_id):
|
||||
return get_json_result(data=False, message="Only owner of canvas authorized for this operation.", code=RetCode.OPERATING_ERROR)
|
||||
|
||||
dataflow = Pipeline(dsl=json.dumps(user_canvas.dsl), tenant_id=current_user.id, flow_id=flow_id, task_id="")
|
||||
|
||||
return get_json_result(data=dataflow.get_component_input_form(cpn_id))
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route("/debug", methods=["POST"]) # noqa: F821
|
||||
@validate_request("id", "component_id", "params")
|
||||
@login_required
|
||||
def debug():
|
||||
req = request.json
|
||||
if not UserCanvasService.accessible(req["id"], current_user.id):
|
||||
return get_json_result(data=False, message="Only owner of canvas authorized for this operation.", code=RetCode.OPERATING_ERROR)
|
||||
try:
|
||||
e, user_canvas = UserCanvasService.get_by_id(req["id"])
|
||||
canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id)
|
||||
canvas.reset()
|
||||
canvas.message_id = get_uuid()
|
||||
component = canvas.get_component(req["component_id"])["obj"]
|
||||
component.reset()
|
||||
|
||||
if isinstance(component, LLM):
|
||||
component.set_debug_inputs(req["params"])
|
||||
component.invoke(**{k: o["value"] for k, o in req["params"].items()})
|
||||
outputs = component.output()
|
||||
for k in outputs.keys():
|
||||
if isinstance(outputs[k], partial):
|
||||
txt = ""
|
||||
for c in outputs[k]():
|
||||
txt += c
|
||||
outputs[k] = txt
|
||||
return get_json_result(data=outputs)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
# api get list version dsl of canvas
|
||||
@manager.route("/getlistversion/<canvas_id>", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
def getlistversion(canvas_id):
|
||||
try:
|
||||
list = sorted([c.to_dict() for c in UserCanvasVersionService.list_by_canvas_id(canvas_id)], key=lambda x: x["update_time"] * -1)
|
||||
return get_json_result(data=list)
|
||||
except Exception as e:
|
||||
return get_data_error_result(message=f"Error getting history files: {e}")
|
||||
|
||||
|
||||
# api get version dsl of canvas
|
||||
@manager.route("/getversion/<version_id>", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
def getversion(version_id):
|
||||
try:
|
||||
e, version = UserCanvasVersionService.get_by_id(version_id)
|
||||
if version:
|
||||
return get_json_result(data=version.to_dict())
|
||||
except Exception as e:
|
||||
return get_json_result(data=f"Error getting history file: {e}")
|
||||
|
||||
|
||||
@manager.route("/listteam", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
def list_canvas():
|
||||
keywords = request.args.get("keywords", "")
|
||||
page_number = int(request.args.get("page", 1))
|
||||
items_per_page = int(request.args.get("page_size", 150))
|
||||
orderby = request.args.get("orderby", "create_time")
|
||||
desc = request.args.get("desc", True)
|
||||
try:
|
||||
tenants = TenantService.get_joined_tenants_by_user_id(current_user.id)
|
||||
canvas, total = UserCanvasService.get_by_tenant_ids(
|
||||
[m["tenant_id"] for m in tenants], current_user.id, page_number, items_per_page, orderby, desc, keywords, canvas_category=CanvasCategory.DataFlow
|
||||
)
|
||||
return get_json_result(data={"canvas": canvas, "total": total})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route("/setting", methods=["POST"]) # noqa: F821
|
||||
@validate_request("id", "title", "permission")
|
||||
@login_required
|
||||
def setting():
|
||||
req = request.json
|
||||
req["user_id"] = current_user.id
|
||||
|
||||
if not UserCanvasService.accessible(req["id"], current_user.id):
|
||||
return get_json_result(data=False, message="Only owner of canvas authorized for this operation.", code=RetCode.OPERATING_ERROR)
|
||||
|
||||
e, flow = UserCanvasService.get_by_id(req["id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="canvas not found.")
|
||||
flow = flow.to_dict()
|
||||
flow["title"] = req["title"]
|
||||
for key in ("description", "permission", "avatar"):
|
||||
if value := req.get(key):
|
||||
flow[key] = value
|
||||
|
||||
num = UserCanvasService.update_by_id(req["id"], flow)
|
||||
return get_json_result(data=num)
|
||||
|
||||
|
||||
@manager.route("/trace", methods=["GET"]) # noqa: F821
|
||||
def trace():
|
||||
dataflow_id = request.args.get("dataflow_id")
|
||||
task_id = request.args.get("task_id")
|
||||
if not all([dataflow_id, task_id]):
|
||||
return get_data_error_result(message="dataflow_id and task_id are required.")
|
||||
|
||||
e, dataflow_canvas = UserCanvasService.get_by_id(dataflow_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="dataflow not found.")
|
||||
|
||||
dsl_str = json.dumps(dataflow_canvas.dsl, ensure_ascii=False)
|
||||
dataflow = Pipeline(dsl=dsl_str, tenant_id=dataflow_canvas.user_id, flow_id=dataflow_id, task_id=task_id)
|
||||
log = dataflow.fetch_logs()
|
||||
|
||||
return get_json_result(data=log)
|
||||
@ -32,7 +32,7 @@ from api.db.services.document_service import DocumentService, doc_upload_and_par
|
||||
from api.db.services.file2document_service import File2DocumentService
|
||||
from api.db.services.file_service import FileService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.task_service import TaskService, cancel_all_task_of, queue_tasks
|
||||
from api.db.services.task_service import TaskService, cancel_all_task_of, queue_tasks, queue_dataflow
|
||||
from api.db.services.user_service import UserTenantService
|
||||
from api.utils import get_uuid
|
||||
from api.utils.api_utils import (
|
||||
@ -480,8 +480,11 @@ def run():
|
||||
kb_table_num_map[kb_id] = count
|
||||
if kb_table_num_map[kb_id] <= 0:
|
||||
KnowledgebaseService.delete_field_map(kb_id)
|
||||
bucket, name = File2DocumentService.get_storage_address(doc_id=doc["id"])
|
||||
queue_tasks(doc, bucket, name, 0)
|
||||
if doc.get("pipeline_id", ""):
|
||||
queue_dataflow(tenant_id, flow_id=doc["pipeline_id"], task_id=get_uuid(), doc_id=id)
|
||||
else:
|
||||
bucket, name = File2DocumentService.get_storage_address(doc_id=doc["id"])
|
||||
queue_tasks(doc, bucket, name, 0)
|
||||
|
||||
return get_json_result(data=True)
|
||||
except Exception as e:
|
||||
|
||||
@ -417,8 +417,10 @@ def list_pipeline_logs():
|
||||
desc = False
|
||||
else:
|
||||
desc = True
|
||||
create_time_from = int(request.args.get("create_time_from", 0))
|
||||
create_time_to = int(request.args.get("create_time_to", 0))
|
||||
create_date_from = request.args.get("create_date_from", "")
|
||||
create_date_to = request.args.get("create_date_to", "")
|
||||
if create_date_to > create_date_from:
|
||||
return get_data_error_result(message="Create data filter is abnormal.")
|
||||
|
||||
req = request.get_json()
|
||||
|
||||
@ -437,17 +439,7 @@ def list_pipeline_logs():
|
||||
suffix = req.get("suffix", [])
|
||||
|
||||
try:
|
||||
logs, tol = PipelineOperationLogService.get_file_logs_by_kb_id(kb_id, page_number, items_per_page, orderby, desc, keywords, operation_status, types, suffix)
|
||||
|
||||
if create_time_from or create_time_to:
|
||||
filtered_docs = []
|
||||
for doc in logs:
|
||||
doc_create_time = doc.get("create_time", 0)
|
||||
if (create_time_from == 0 or doc_create_time >= create_time_from) and (create_time_to == 0 or doc_create_time <= create_time_to):
|
||||
filtered_docs.append(doc)
|
||||
logs = filtered_docs
|
||||
|
||||
|
||||
logs, tol = PipelineOperationLogService.get_file_logs_by_kb_id(kb_id, page_number, items_per_page, orderby, desc, keywords, operation_status, types, suffix, create_date_from, create_date_to)
|
||||
return get_json_result(data={"total": tol, "logs": logs})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
@ -467,8 +459,10 @@ def list_pipeline_dataset_logs():
|
||||
desc = False
|
||||
else:
|
||||
desc = True
|
||||
create_time_from = int(request.args.get("create_time_from", 0))
|
||||
create_time_to = int(request.args.get("create_time_to", 0))
|
||||
create_date_from = request.args.get("create_date_from", "")
|
||||
create_date_to = request.args.get("create_date_to", "")
|
||||
if create_date_to > create_date_from:
|
||||
return get_data_error_result(message="Create data filter is abnormal.")
|
||||
|
||||
req = request.get_json()
|
||||
|
||||
@ -479,17 +473,7 @@ def list_pipeline_dataset_logs():
|
||||
return get_data_error_result(message=f"Invalid filter operation_status status conditions: {', '.join(invalid_status)}")
|
||||
|
||||
try:
|
||||
logs, tol = PipelineOperationLogService.get_dataset_logs_by_kb_id(kb_id, page_number, items_per_page, orderby, desc, operation_status)
|
||||
|
||||
if create_time_from or create_time_to:
|
||||
filtered_docs = []
|
||||
for doc in logs:
|
||||
doc_create_time = doc.get("create_time", 0)
|
||||
if (create_time_from == 0 or doc_create_time >= create_time_from) and (create_time_to == 0 or doc_create_time <= create_time_to):
|
||||
filtered_docs.append(doc)
|
||||
logs = filtered_docs
|
||||
|
||||
|
||||
logs, tol = PipelineOperationLogService.get_dataset_logs_by_kb_id(kb_id, page_number, items_per_page, orderby, desc, operation_status, create_date_from, create_date_to)
|
||||
return get_json_result(data={"total": tol, "logs": logs})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
@ -538,12 +522,13 @@ def run_graphrag():
|
||||
return get_error_data_result(message="Invalid Knowledgebase ID")
|
||||
|
||||
task_id = kb.graphrag_task_id
|
||||
ok, task = TaskService.get_by_id(task_id)
|
||||
if not ok:
|
||||
logging.warning(f"A valid GraphRAG task id is expected for kb {kb_id}")
|
||||
if task_id:
|
||||
ok, task = TaskService.get_by_id(task_id)
|
||||
if not ok:
|
||||
logging.warning(f"A valid GraphRAG task id is expected for kb {kb_id}")
|
||||
|
||||
if task and task.progress not in [-1, 1]:
|
||||
return get_error_data_result(message=f"Task {task_id} in progress with status {task.progress}. A Graph Task is already running.")
|
||||
if task and task.progress not in [-1, 1]:
|
||||
return get_error_data_result(message=f"Task {task_id} in progress with status {task.progress}. A Graph Task is already running.")
|
||||
|
||||
documents, _ = DocumentService.get_by_kb_id(
|
||||
kb_id=kb_id,
|
||||
@ -583,7 +568,7 @@ def trace_graphrag():
|
||||
|
||||
task_id = kb.graphrag_task_id
|
||||
if not task_id:
|
||||
return get_error_data_result(message="GraphRAG Task ID Not Found")
|
||||
return get_json_result(data={})
|
||||
|
||||
ok, task = TaskService.get_by_id(task_id)
|
||||
if not ok:
|
||||
@ -606,12 +591,13 @@ def run_raptor():
|
||||
return get_error_data_result(message="Invalid Knowledgebase ID")
|
||||
|
||||
task_id = kb.raptor_task_id
|
||||
ok, task = TaskService.get_by_id(task_id)
|
||||
if not ok:
|
||||
logging.warning(f"A valid RAPTOR task id is expected for kb {kb_id}")
|
||||
if task_id:
|
||||
ok, task = TaskService.get_by_id(task_id)
|
||||
if not ok:
|
||||
logging.warning(f"A valid RAPTOR task id is expected for kb {kb_id}")
|
||||
|
||||
if task and task.progress not in [-1, 1]:
|
||||
return get_error_data_result(message=f"Task {task_id} in progress with status {task.progress}. A RAPTOR Task is already running.")
|
||||
if task and task.progress not in [-1, 1]:
|
||||
return get_error_data_result(message=f"Task {task_id} in progress with status {task.progress}. A RAPTOR Task is already running.")
|
||||
|
||||
documents, _ = DocumentService.get_by_kb_id(
|
||||
kb_id=kb_id,
|
||||
@ -651,10 +637,79 @@ def trace_raptor():
|
||||
|
||||
task_id = kb.raptor_task_id
|
||||
if not task_id:
|
||||
return get_error_data_result(message="RAPTOR Task ID Not Found")
|
||||
return get_json_result(data={})
|
||||
|
||||
ok, task = TaskService.get_by_id(task_id)
|
||||
if not ok:
|
||||
return get_error_data_result(message="RAPTOR Task Not Found or Error Occurred")
|
||||
|
||||
return get_json_result(data=task.to_dict())
|
||||
|
||||
|
||||
@manager.route("/run_mindmap", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def run_mindmap():
|
||||
req = request.json
|
||||
|
||||
kb_id = req.get("kb_id", "")
|
||||
if not kb_id:
|
||||
return get_error_data_result(message='Lack of "KB ID"')
|
||||
|
||||
ok, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||
if not ok:
|
||||
return get_error_data_result(message="Invalid Knowledgebase ID")
|
||||
|
||||
task_id = kb.mindmap_task_id
|
||||
if task_id:
|
||||
ok, task = TaskService.get_by_id(task_id)
|
||||
if not ok:
|
||||
logging.warning(f"A valid Mindmap task id is expected for kb {kb_id}")
|
||||
|
||||
if task and task.progress not in [-1, 1]:
|
||||
return get_error_data_result(message=f"Task {task_id} in progress with status {task.progress}. A Mindmap Task is already running.")
|
||||
|
||||
documents, _ = DocumentService.get_by_kb_id(
|
||||
kb_id=kb_id,
|
||||
page_number=0,
|
||||
items_per_page=0,
|
||||
orderby="create_time",
|
||||
desc=False,
|
||||
keywords="",
|
||||
run_status=[],
|
||||
types=[],
|
||||
suffix=[],
|
||||
)
|
||||
if not documents:
|
||||
return get_error_data_result(message=f"No documents in Knowledgebase {kb_id}")
|
||||
|
||||
sample_document = documents[0]
|
||||
document_ids = [document["id"] for document in documents]
|
||||
|
||||
task_id = queue_raptor_o_graphrag_tasks(doc=sample_document, ty="mindmap", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids))
|
||||
|
||||
if not KnowledgebaseService.update_by_id(kb.id, {"mindmap_task_id": task_id}):
|
||||
logging.warning(f"Cannot save mindmap_task_id for kb {kb_id}")
|
||||
|
||||
return get_json_result(data={"mindmap_task_id": task_id})
|
||||
|
||||
|
||||
@manager.route("/trace_mindmap", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
def trace_mindmap():
|
||||
kb_id = request.args.get("kb_id", "")
|
||||
if not kb_id:
|
||||
return get_error_data_result(message='Lack of "KB ID"')
|
||||
|
||||
ok, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||
if not ok:
|
||||
return get_error_data_result(message="Invalid Knowledgebase ID")
|
||||
|
||||
task_id = kb.mindmap_task_id
|
||||
if not task_id:
|
||||
return get_json_result(data={})
|
||||
|
||||
ok, task = TaskService.get_by_id(task_id)
|
||||
if not ok:
|
||||
return get_error_data_result(message="Mindmap Task Not Found or Error Occurred")
|
||||
|
||||
return get_json_result(data=task.to_dict())
|
||||
|
||||
@ -127,9 +127,10 @@ class PipelineTaskType(StrEnum):
|
||||
DOWNLOAD = "Download"
|
||||
RAPTOR = "RAPTOR"
|
||||
GRAPH_RAG = "GraphRAG"
|
||||
MINDMAP = "Mindmap"
|
||||
|
||||
|
||||
VALID_PIPELINE_TASK_TYPES = {PipelineTaskType.PARSE, PipelineTaskType.DOWNLOAD, PipelineTaskType.RAPTOR, PipelineTaskType.GRAPH_RAG}
|
||||
VALID_PIPELINE_TASK_TYPES = {PipelineTaskType.PARSE, PipelineTaskType.DOWNLOAD, PipelineTaskType.RAPTOR, PipelineTaskType.GRAPH_RAG, PipelineTaskType.MINDMAP}
|
||||
|
||||
|
||||
KNOWLEDGEBASE_FOLDER_NAME=".knowledgebase"
|
||||
|
||||
@ -651,7 +651,11 @@ class Knowledgebase(DataBaseModel):
|
||||
pagerank = IntegerField(default=0, index=False)
|
||||
|
||||
graphrag_task_id = CharField(max_length=32, null=True, help_text="Graph RAG task ID", index=True)
|
||||
graphrag_task_finish_at = DateTimeField(null=True)
|
||||
raptor_task_id = CharField(max_length=32, null=True, help_text="RAPTOR task ID", index=True)
|
||||
raptor_task_finish_at = DateTimeField(null=True)
|
||||
mindmap_task_id = CharField(max_length=32, null=True, help_text="Mindmap task ID", index=True)
|
||||
mindmap_task_finish_at = DateTimeField(null=True)
|
||||
|
||||
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True)
|
||||
|
||||
@ -1084,4 +1088,20 @@ def migrate_db():
|
||||
migrate(migrator.add_column("knowledgebase", "raptor_task_id", CharField(max_length=32, null=True, help_text="RAPTOR task ID", index=True)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("knowledgebase", "graphrag_task_finish_at", DateTimeField(null=True)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("knowledgebase", "raptor_task_finish_at", CharField(null=True)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("knowledgebase", "mindmap_task_id", CharField(max_length=32, null=True, help_text="Mindmap task ID", index=True)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.add_column("knowledgebase", "mindmap_task_finish_at", CharField(null=True)))
|
||||
except Exception:
|
||||
pass
|
||||
logging.disable(logging.NOTSET)
|
||||
|
||||
@ -636,8 +636,6 @@ class DocumentService(CommonService):
|
||||
prg = 0
|
||||
finished = True
|
||||
bad = 0
|
||||
has_raptor = False
|
||||
has_graphrag = False
|
||||
e, doc = DocumentService.get_by_id(d["id"])
|
||||
status = doc.run # TaskStatus.RUNNING.value
|
||||
priority = 0
|
||||
@ -649,25 +647,14 @@ class DocumentService(CommonService):
|
||||
prg += t.progress if t.progress >= 0 else 0
|
||||
if t.progress_msg.strip():
|
||||
msg.append(t.progress_msg)
|
||||
if t.task_type == "raptor":
|
||||
has_raptor = True
|
||||
elif t.task_type == "graphrag":
|
||||
has_graphrag = True
|
||||
priority = max(priority, t.priority)
|
||||
prg /= len(tsks)
|
||||
if finished and bad:
|
||||
prg = -1
|
||||
status = TaskStatus.FAIL.value
|
||||
elif finished:
|
||||
if (d["parser_config"].get("raptor") or {}).get("use_raptor") and not has_raptor:
|
||||
queue_raptor_o_graphrag_tasks(d, "raptor", priority)
|
||||
prg = 0.98 * len(tsks) / (len(tsks) + 1)
|
||||
elif (d["parser_config"].get("graphrag") or {}).get("use_graphrag") and not has_graphrag:
|
||||
queue_raptor_o_graphrag_tasks(d, "graphrag", priority)
|
||||
prg = 0.98 * len(tsks) / (len(tsks) + 1)
|
||||
else:
|
||||
prg = 1
|
||||
status = TaskStatus.DONE.value
|
||||
prg = 1
|
||||
status = TaskStatus.DONE.value
|
||||
|
||||
msg = "\n".join(sorted(msg))
|
||||
info = {
|
||||
@ -679,7 +666,7 @@ class DocumentService(CommonService):
|
||||
info["progress"] = prg
|
||||
if msg:
|
||||
info["progress_msg"] = msg
|
||||
if msg.endswith("created task graphrag") or msg.endswith("created task raptor"):
|
||||
if msg.endswith("created task graphrag") or msg.endswith("created task raptor") or msg.endswith("created task mindmap"):
|
||||
info["progress_msg"] += "\n%d tasks are ahead in the queue..."%get_queue_length(priority)
|
||||
else:
|
||||
info["progress_msg"] = "%d tasks are ahead in the queue..."%get_queue_length(priority)
|
||||
@ -770,7 +757,8 @@ def queue_raptor_o_graphrag_tasks(doc, ty, priority, fake_doc_id="", doc_ids=[])
|
||||
"from_page": 100000000,
|
||||
"to_page": 100000000,
|
||||
"task_type": ty,
|
||||
"progress_msg": datetime.now().strftime("%H:%M:%S") + " created task " + ty
|
||||
"progress_msg": datetime.now().strftime("%H:%M:%S") + " created task " + ty,
|
||||
"begin_at": datetime.now(),
|
||||
}
|
||||
|
||||
task = new_task()
|
||||
@ -780,7 +768,7 @@ def queue_raptor_o_graphrag_tasks(doc, ty, priority, fake_doc_id="", doc_ids=[])
|
||||
task["digest"] = hasher.hexdigest()
|
||||
bulk_insert_into_db(Task, [task], True)
|
||||
|
||||
if ty in ["graphrag", "raptor"]:
|
||||
if ty in ["graphrag", "raptor", "mindmap"]:
|
||||
task["doc_ids"] = doc_ids
|
||||
DocumentService.begin2parse(doc["id"])
|
||||
assert REDIS_CONN.queue_product(get_svr_queue_name(priority), message=task), "Can't access Redis. Please check the Redis' status."
|
||||
|
||||
@ -15,10 +15,10 @@
|
||||
#
|
||||
from datetime import datetime
|
||||
|
||||
from peewee import fn
|
||||
from peewee import fn, JOIN
|
||||
|
||||
from api.db import StatusEnum, TenantPermission
|
||||
from api.db.db_models import DB, Document, Knowledgebase, Tenant, User, UserTenant
|
||||
from api.db.db_models import DB, Document, Knowledgebase, Tenant, User, UserTenant, UserCanvas
|
||||
from api.db.services.common_service import CommonService
|
||||
from api.utils import current_timestamp, datetime_format
|
||||
|
||||
@ -226,13 +226,23 @@ class KnowledgebaseService(CommonService):
|
||||
cls.model.chunk_num,
|
||||
cls.model.parser_id,
|
||||
cls.model.pipeline_id,
|
||||
UserCanvas.title,
|
||||
UserCanvas.avatar.alias("pipeline_avatar"),
|
||||
cls.model.parser_config,
|
||||
cls.model.pagerank,
|
||||
cls.model.graphrag_task_id,
|
||||
cls.model.graphrag_task_finish_at,
|
||||
cls.model.raptor_task_id,
|
||||
cls.model.raptor_task_finish_at,
|
||||
cls.model.mindmap_task_id,
|
||||
cls.model.mindmap_task_finish_at,
|
||||
cls.model.create_time,
|
||||
cls.model.update_time
|
||||
]
|
||||
kbs = cls.model.select(*fields).join(Tenant, on=(
|
||||
(Tenant.id == cls.model.tenant_id) & (Tenant.status == StatusEnum.VALID.value))).where(
|
||||
kbs = cls.model.select(*fields)\
|
||||
.join(Tenant, on=((Tenant.id == cls.model.tenant_id) & (Tenant.status == StatusEnum.VALID.value)))\
|
||||
.join(UserCanvas, on=(cls.model.pipeline_id == UserCanvas.id), join_type=JOIN.LEFT_OUTER)\
|
||||
.where(
|
||||
(cls.model.id == kb_id),
|
||||
(cls.model.status == StatusEnum.VALID.value)
|
||||
)
|
||||
|
||||
@ -15,12 +15,12 @@
|
||||
#
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from peewee import fn
|
||||
|
||||
from api.db import VALID_PIPELINE_TASK_TYPES
|
||||
from api.db.db_models import DB, PipelineOperationLog, Document
|
||||
from api.db import VALID_PIPELINE_TASK_TYPES, PipelineTaskType
|
||||
from api.db.db_models import DB, Document, PipelineOperationLog
|
||||
from api.db.services.canvas_service import UserCanvasService
|
||||
from api.db.services.common_service import CommonService
|
||||
from api.db.services.document_service import DocumentService
|
||||
@ -83,10 +83,7 @@ class PipelineOperationLogService(CommonService):
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def create(cls, document_id, pipeline_id, task_type, fake_document_ids=[]):
|
||||
from rag.flow.pipeline import Pipeline
|
||||
|
||||
dsl = ""
|
||||
def create(cls, document_id, pipeline_id, task_type, fake_document_ids=[], dsl:str="{}"):
|
||||
referred_document_id = document_id
|
||||
|
||||
if referred_document_id == GRAPH_RAPTOR_FAKE_DOC_ID and fake_document_ids:
|
||||
@ -108,13 +105,9 @@ class PipelineOperationLogService(CommonService):
|
||||
ok, user_pipeline = UserCanvasService.get_by_id(pipeline_id)
|
||||
if not ok:
|
||||
raise RuntimeError(f"Pipeline {pipeline_id} not found")
|
||||
|
||||
pipeline = Pipeline(dsl=json.dumps(user_pipeline.dsl), tenant_id=user_pipeline.user_id, doc_id=referred_document_id, task_id="", flow_id=pipeline_id)
|
||||
|
||||
tenant_id = user_pipeline.user_id
|
||||
title = user_pipeline.title
|
||||
avatar = user_pipeline.avatar
|
||||
dsl = json.loads(str(pipeline))
|
||||
else:
|
||||
ok, kb_info = KnowledgebaseService.get_by_id(document.kb_id)
|
||||
if not ok:
|
||||
@ -127,6 +120,24 @@ class PipelineOperationLogService(CommonService):
|
||||
if task_type not in VALID_PIPELINE_TASK_TYPES:
|
||||
raise ValueError(f"Invalid task type: {task_type}")
|
||||
|
||||
if task_type in [PipelineTaskType.GRAPH_RAG, PipelineTaskType.RAPTOR, PipelineTaskType.MINDMAP]:
|
||||
finish_at = document.process_begin_at + timedelta(seconds=document.process_duration)
|
||||
if task_type == PipelineTaskType.GRAPH_RAG:
|
||||
KnowledgebaseService.update_by_id(
|
||||
document.kb_id,
|
||||
{"graphrag_task_finish_at": finish_at},
|
||||
)
|
||||
elif task_type == PipelineTaskType.RAPTOR:
|
||||
KnowledgebaseService.update_by_id(
|
||||
document.kb_id,
|
||||
{"raptor_task_finish_at": finish_at},
|
||||
)
|
||||
elif task_type == PipelineTaskType.MINDMAP:
|
||||
KnowledgebaseService.update_by_id(
|
||||
document.kb_id,
|
||||
{"mindmap_task_finish_at": finish_at},
|
||||
)
|
||||
|
||||
log = dict(
|
||||
id=get_uuid(),
|
||||
document_id=document_id, # GRAPH_RAPTOR_FAKE_DOC_ID or real document_id
|
||||
@ -143,7 +154,7 @@ class PipelineOperationLogService(CommonService):
|
||||
progress_msg=document.progress_msg,
|
||||
process_begin_at=document.process_begin_at,
|
||||
process_duration=document.process_duration,
|
||||
dsl=dsl,
|
||||
dsl=json.loads(dsl),
|
||||
task_type=task_type,
|
||||
operation_status=operation_status,
|
||||
avatar=avatar,
|
||||
@ -162,7 +173,7 @@ class PipelineOperationLogService(CommonService):
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_file_logs_by_kb_id(cls, kb_id, page_number, items_per_page, orderby, desc, keywords, operation_status, types, suffix):
|
||||
def get_file_logs_by_kb_id(cls, kb_id, page_number, items_per_page, orderby, desc, keywords, operation_status, types, suffix, create_date_from=None, create_date_to=None):
|
||||
fields = cls.get_file_logs_fields()
|
||||
if keywords:
|
||||
logs = cls.model.select(*fields).where((cls.model.kb_id == kb_id), (fn.LOWER(cls.model.document_name).contains(keywords.lower())))
|
||||
@ -177,6 +188,10 @@ class PipelineOperationLogService(CommonService):
|
||||
logs = logs.where(cls.model.document_type.in_(types))
|
||||
if suffix:
|
||||
logs = logs.where(cls.model.document_suffix.in_(suffix))
|
||||
if create_date_from:
|
||||
logs = logs.where(cls.model.create_date >= create_date_from)
|
||||
if create_date_to:
|
||||
logs = logs.where(cls.model.create_date <= create_date_to)
|
||||
|
||||
count = logs.count()
|
||||
if desc:
|
||||
@ -192,25 +207,30 @@ class PipelineOperationLogService(CommonService):
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_documents_info(cls, id):
|
||||
fields = [
|
||||
Document.id,
|
||||
Document.name,
|
||||
Document.progress
|
||||
]
|
||||
return cls.model.select(*fields).join(Document, on=(cls.model.document_id == Document.id)).where(
|
||||
cls.model.id == id,
|
||||
Document.progress > 0,
|
||||
Document.progress < 1
|
||||
).dicts()
|
||||
|
||||
fields = [Document.id, Document.name, Document.progress]
|
||||
return (
|
||||
cls.model.select(*fields)
|
||||
.join(Document, on=(cls.model.document_id == Document.id))
|
||||
.where(
|
||||
cls.model.id == id,
|
||||
Document.progress > 0,
|
||||
Document.progress < 1,
|
||||
)
|
||||
.dicts()
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_dataset_logs_by_kb_id(cls, kb_id, page_number, items_per_page, orderby, desc, operation_status):
|
||||
def get_dataset_logs_by_kb_id(cls, kb_id, page_number, items_per_page, orderby, desc, operation_status, create_date_from=None, create_date_to=None):
|
||||
fields = cls.get_dataset_logs_fields()
|
||||
logs = cls.model.select(*fields).where((cls.model.kb_id == kb_id), (cls.model.document_id == GRAPH_RAPTOR_FAKE_DOC_ID))
|
||||
|
||||
if operation_status:
|
||||
logs = logs.where(cls.model.operation_status.in_(operation_status))
|
||||
if create_date_from:
|
||||
logs = logs.where(cls.model.create_date >= create_date_from)
|
||||
if create_date_to:
|
||||
logs = logs.where(cls.model.create_date <= create_date_to)
|
||||
|
||||
count = logs.count()
|
||||
if desc:
|
||||
@ -222,4 +242,3 @@ class PipelineOperationLogService(CommonService):
|
||||
logs = logs.paginate(page_number, items_per_page)
|
||||
|
||||
return list(logs.dicts()), count
|
||||
|
||||
|
||||
@ -298,21 +298,23 @@ class TaskService(CommonService):
|
||||
((prog == -1) | (prog > cls.model.progress))
|
||||
)
|
||||
).execute()
|
||||
return
|
||||
else:
|
||||
with DB.lock("update_progress", -1):
|
||||
if info["progress_msg"]:
|
||||
progress_msg = trim_header_by_lines(task.progress_msg + "\n" + info["progress_msg"], 3000)
|
||||
cls.model.update(progress_msg=progress_msg).where(cls.model.id == id).execute()
|
||||
if "progress" in info:
|
||||
prog = info["progress"]
|
||||
cls.model.update(progress=prog).where(
|
||||
(cls.model.id == id) &
|
||||
(
|
||||
(cls.model.progress != -1) &
|
||||
((prog == -1) | (prog > cls.model.progress))
|
||||
)
|
||||
).execute()
|
||||
|
||||
with DB.lock("update_progress", -1):
|
||||
if info["progress_msg"]:
|
||||
progress_msg = trim_header_by_lines(task.progress_msg + "\n" + info["progress_msg"], 3000)
|
||||
cls.model.update(progress_msg=progress_msg).where(cls.model.id == id).execute()
|
||||
if "progress" in info:
|
||||
prog = info["progress"]
|
||||
cls.model.update(progress=prog).where(
|
||||
(cls.model.id == id) &
|
||||
(
|
||||
(cls.model.progress != -1) &
|
||||
((prog == -1) | (prog > cls.model.progress))
|
||||
)
|
||||
).execute()
|
||||
process_duration = (datetime.now() - task.begin_at).total_seconds()
|
||||
cls.model.update(process_duration=process_duration).where(cls.model.id == id).execute()
|
||||
|
||||
|
||||
def queue_tasks(doc: dict, bucket: str, name: str, priority: int):
|
||||
@ -336,7 +338,14 @@ def queue_tasks(doc: dict, bucket: str, name: str, priority: int):
|
||||
- Previous task chunks may be reused if available
|
||||
"""
|
||||
def new_task():
|
||||
return {"id": get_uuid(), "doc_id": doc["id"], "progress": 0.0, "from_page": 0, "to_page": 100000000}
|
||||
return {
|
||||
"id": get_uuid(),
|
||||
"doc_id": doc["id"],
|
||||
"progress": 0.0,
|
||||
"from_page": 0,
|
||||
"to_page": 100000000,
|
||||
"begin_at": datetime.now(),
|
||||
}
|
||||
|
||||
parse_task_array = []
|
||||
|
||||
@ -487,9 +496,11 @@ def queue_dataflow(tenant_id:str, flow_id:str, task_id:str, doc_id:str=CANVAS_DE
|
||||
to_page=100000000,
|
||||
task_type="dataflow" if not rerun else "dataflow_rerun",
|
||||
priority=priority,
|
||||
begin_at=datetime.now(),
|
||||
)
|
||||
|
||||
TaskService.model.delete().where(TaskService.model.id == task["id"]).execute()
|
||||
if doc_id not in [CANVAS_DEBUG_DOC_ID, GRAPH_RAPTOR_FAKE_DOC_ID]:
|
||||
TaskService.model.delete().where(TaskService.model.doc_id == doc_id).execute()
|
||||
DocumentService.begin2parse(doc_id)
|
||||
bulk_insert_into_db(model=Task, data_source=[task], replace_on_conflict=True)
|
||||
|
||||
task["kb_id"] = DocumentService.get_knowledgebase_id(doc_id)
|
||||
|
||||
@ -1127,7 +1127,7 @@ class RAGFlowPdfParser:
|
||||
for tag in re.findall(r"@@[0-9-]+\t[0-9.\t]+##", txt):
|
||||
pn, left, right, top, bottom = tag.strip("#").strip("@").split("\t")
|
||||
left, right, top, bottom = float(left), float(right), float(top), float(bottom)
|
||||
poss.append(([int(p) - 1 for p in pn.split("-")], left, right, top, bottom))
|
||||
poss.append(([int(p) - 1 for p in pn.split("-")], int(left), int(right), int(top), int(bottom)))
|
||||
return poss
|
||||
|
||||
def crop(self, text, ZM=3, need_position=False):
|
||||
|
||||
@ -31,6 +31,7 @@ class Extractor(ProcessBase, LLM):
|
||||
component_name = "Extractor"
|
||||
|
||||
async def _invoke(self, **kwargs):
|
||||
self.set_output("output_format", "chunks")
|
||||
self.callback(random.randint(1, 5) / 100.0, "Start to generate.")
|
||||
inputs = self.get_input_elements()
|
||||
chunks = []
|
||||
@ -50,7 +51,8 @@ class Extractor(ProcessBase, LLM):
|
||||
msg.insert(0, {"role": "system", "content": sys_prompt})
|
||||
ck[self._param.field_name] = self._generate(msg)
|
||||
prog += 1./len(chunks)
|
||||
self.callback(prog, f"{i+1} / {len(chunks)}")
|
||||
if i % (len(chunks)//100+1) == 1:
|
||||
self.callback(prog, f"{i+1} / {len(chunks)}")
|
||||
self.set_output("chunks", chunks)
|
||||
else:
|
||||
msg, sys_prompt = self._sys_prompt_and_msg([], args)
|
||||
|
||||
@ -25,7 +25,7 @@ class ExtractorFromUpstream(BaseModel):
|
||||
file: dict | None = Field(default=None)
|
||||
chunks: list[dict[str, Any]] | None = Field(default=None)
|
||||
|
||||
output_format: Literal["json", "markdown", "text", "html"] | None = Field(default=None)
|
||||
output_format: Literal["json", "markdown", "text", "html", "chunks"] | None = Field(default=None)
|
||||
|
||||
json_result: list[dict[str, Any]] | None = Field(default=None, alias="json")
|
||||
markdown_result: str | None = Field(default=None, alias="markdown")
|
||||
|
||||
@ -53,6 +53,7 @@ class HierarchicalMerger(ProcessBase):
|
||||
self.set_output("_ERROR", f"Input error: {str(e)}")
|
||||
return
|
||||
|
||||
self.set_output("output_format", "chunks")
|
||||
self.callback(random.randint(1, 5) / 100.0, "Start to merge hierarchically.")
|
||||
if from_upstream.output_format in ["markdown", "text", "html"]:
|
||||
if from_upstream.output_format == "markdown":
|
||||
|
||||
@ -25,7 +25,7 @@ class HierarchicalMergerFromUpstream(BaseModel):
|
||||
file: dict | None = Field(default=None)
|
||||
chunks: list[dict[str, Any]] | None = Field(default=None)
|
||||
|
||||
output_format: Literal["json", "markdown", "text", "html"] | None = Field(default=None)
|
||||
output_format: Literal["json", "chunks"] | None = Field(default=None)
|
||||
json_result: list[dict[str, Any]] | None = Field(default=None, alias="json")
|
||||
markdown_result: str | None = Field(default=None, alias="markdown")
|
||||
text_result: str | None = Field(default=None, alias="text")
|
||||
|
||||
@ -148,7 +148,7 @@ class ParserParam(ProcessParamBase):
|
||||
self.check_empty(pdf_parse_method, "Parse method abnormal.")
|
||||
|
||||
if pdf_parse_method.lower() not in ["deepdoc", "plain_text"]:
|
||||
self.check_empty(pdf_config.get("lang", ""), "Language")
|
||||
self.check_empty(pdf_config.get("lang", ""), "PDF VLM language")
|
||||
|
||||
pdf_output_format = pdf_config.get("output_format", "")
|
||||
self.check_valid_value(pdf_output_format, "PDF output format abnormal.", self.allowed_output_format["pdf"])
|
||||
@ -172,7 +172,7 @@ class ParserParam(ProcessParamBase):
|
||||
if image_config:
|
||||
image_parse_method = image_config.get("parse_method", "")
|
||||
if image_parse_method not in ["ocr"]:
|
||||
self.check_empty(image_config.get("lang", ""), "Language")
|
||||
self.check_empty(image_config.get("lang", ""), "Image VLM language")
|
||||
|
||||
text_config = self.setups.get("text&markdown", "")
|
||||
if text_config:
|
||||
@ -181,7 +181,7 @@ class ParserParam(ProcessParamBase):
|
||||
|
||||
audio_config = self.setups.get("audio", "")
|
||||
if audio_config:
|
||||
self.check_empty(audio_config.get("llm_id"), "VLM")
|
||||
self.check_empty(audio_config.get("llm_id"), "Audio VLM")
|
||||
audio_language = audio_config.get("lang", "")
|
||||
self.check_empty(audio_language, "Language")
|
||||
|
||||
|
||||
@ -76,22 +76,23 @@ class Pipeline(Graph):
|
||||
}
|
||||
]
|
||||
REDIS_CONN.set_obj(log_key, obj, 60 * 30)
|
||||
if self._doc_id and self.task_id:
|
||||
if component_name != "END" and self._doc_id and self.task_id:
|
||||
percentage = 1.0 / len(self.components.items())
|
||||
msg = ""
|
||||
finished = 0.0
|
||||
for o in obj:
|
||||
if o["component_id"] == "END":
|
||||
continue
|
||||
msg += f"\n[{o['component_id']}]:\n"
|
||||
for t in o["trace"]:
|
||||
msg += "%s: %s\n" % (t["datetime"], t["message"])
|
||||
if t["progress"] < 0:
|
||||
finished = -1
|
||||
break
|
||||
if finished < 0:
|
||||
break
|
||||
finished += o["trace"][-1]["progress"] * percentage
|
||||
|
||||
msg = ""
|
||||
if len(obj[-1]["trace"]) == 1:
|
||||
msg += f"\n-------------------------------------\n[{self.get_component_name(o['component_id'])}]:\n"
|
||||
t = obj[-1]["trace"][-1]
|
||||
msg += "%s: %s\n" % (t["datetime"], t["message"])
|
||||
TaskService.update_progress(self.task_id, {"progress": finished, "progress_msg": msg})
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
|
||||
@ -59,6 +59,7 @@ class Splitter(ProcessBase):
|
||||
else:
|
||||
deli += d
|
||||
|
||||
self.set_output("output_format", "chunks")
|
||||
self.callback(random.randint(1, 5) / 100.0, "Start to split into chunks.")
|
||||
if from_upstream.output_format in ["markdown", "text", "html"]:
|
||||
if from_upstream.output_format == "markdown":
|
||||
@ -99,7 +100,7 @@ class Splitter(ProcessBase):
|
||||
{
|
||||
"text": RAGFlowPdfParser.remove_tag(c),
|
||||
"image": img,
|
||||
"positions": [[pos[0][-1]+1, *pos[1:]] for pos in RAGFlowPdfParser.extract_positions(c)],
|
||||
"positions": [[pos[0][-1], *pos[1:]] for pos in RAGFlowPdfParser.extract_positions(c)],
|
||||
}
|
||||
for c, img in zip(chunks, images)
|
||||
]
|
||||
|
||||
@ -24,7 +24,7 @@ class TokenizerFromUpstream(BaseModel):
|
||||
name: str = ""
|
||||
file: dict | None = Field(default=None)
|
||||
|
||||
output_format: Literal["json", "markdown", "text", "html"] | None = Field(default=None)
|
||||
output_format: Literal["json", "markdown", "text", "html", "chunks"] | None = Field(default=None)
|
||||
|
||||
chunks: list[dict[str, Any]] | None = Field(default=None)
|
||||
|
||||
|
||||
@ -108,6 +108,7 @@ class Tokenizer(ProcessBase):
|
||||
self.set_output("_ERROR", f"Input error: {str(e)}")
|
||||
return
|
||||
|
||||
self.set_output("output_format", "chunks")
|
||||
parts = sum(["full_text" in self._param.search_method, "embedding" in self._param.search_method])
|
||||
if "full_text" in self._param.search_method:
|
||||
self.callback(random.randint(1, 5) / 100.0, "Start to tokenize.")
|
||||
@ -117,11 +118,13 @@ class Tokenizer(ProcessBase):
|
||||
ck["title_tks"] = rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", from_upstream.name))
|
||||
ck["title_sm_tks"] = rag_tokenizer.fine_grained_tokenize(ck["title_tks"])
|
||||
if ck.get("questions"):
|
||||
ck["question_tks"] = rag_tokenizer.tokenize("\n".join(ck["questions"]))
|
||||
ck["question_kwd"] = ck["questions"].split("\n")
|
||||
ck["question_tks"] = rag_tokenizer.tokenize(str(ck["questions"]))
|
||||
if ck.get("keywords"):
|
||||
ck["important_tks"] = rag_tokenizer.tokenize(",".join(ck["keywords"]))
|
||||
ck["important_kwd"] = ck["keywords"].split(",")
|
||||
ck["important_tks"] = rag_tokenizer.tokenize(str(ck["keywords"]))
|
||||
if ck.get("summary"):
|
||||
ck["content_ltks"] = rag_tokenizer.tokenize(ck["summary"])
|
||||
ck["content_ltks"] = rag_tokenizer.tokenize(str(ck["summary"]))
|
||||
ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"])
|
||||
else:
|
||||
ck["content_ltks"] = rag_tokenizer.tokenize(ck["text"])
|
||||
|
||||
@ -20,6 +20,9 @@ import random
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
|
||||
import json_repair
|
||||
|
||||
from api.db.services.canvas_service import UserCanvasService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
|
||||
@ -57,7 +60,7 @@ from api.versions import get_ragflow_version
|
||||
from api.db.db_models import close_connection
|
||||
from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, \
|
||||
email, tag
|
||||
from rag.nlp import search, rag_tokenizer
|
||||
from rag.nlp import search, rag_tokenizer, add_positions
|
||||
from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
|
||||
from rag.settings import DOC_MAXIMUM_SIZE, DOC_BULK_SIZE, EMBEDDING_BATCH_SIZE, SVR_CONSUMER_GROUP_NAME, get_svr_queue_name, get_svr_queue_names, print_rag_settings, TAG_FLD, PAGERANK_FLD
|
||||
from rag.utils import num_tokens_from_string, truncate
|
||||
@ -90,6 +93,7 @@ TASK_TYPE_TO_PIPELINE_TASK_TYPE = {
|
||||
"dataflow" : PipelineTaskType.PARSE,
|
||||
"raptor": PipelineTaskType.RAPTOR,
|
||||
"graphrag": PipelineTaskType.GRAPH_RAG,
|
||||
"mindmap": PipelineTaskType.MINDMAP,
|
||||
}
|
||||
|
||||
UNACKED_ITERATOR = None
|
||||
@ -224,7 +228,7 @@ async def collect():
|
||||
canceled = False
|
||||
if msg.get("doc_id", "") in [GRAPH_RAPTOR_FAKE_DOC_ID, CANVAS_DEBUG_DOC_ID]:
|
||||
task = msg
|
||||
if task["task_type"] in ["graphrag", "raptor"] and msg.get("doc_ids", []):
|
||||
if task["task_type"] in ["graphrag", "raptor", "mindmap"] and msg.get("doc_ids", []):
|
||||
task = TaskService.get_task(msg["id"], msg["doc_ids"])
|
||||
task["doc_ids"] = msg["doc_ids"]
|
||||
else:
|
||||
@ -477,6 +481,8 @@ async def run_dataflow(task: dict):
|
||||
dataflow_id = task["dataflow_id"]
|
||||
doc_id = task["doc_id"]
|
||||
task_id = task["id"]
|
||||
task_dataset_id = task["kb_id"]
|
||||
|
||||
if task["task_type"] == "dataflow":
|
||||
e, cvs = UserCanvasService.get_by_id(dataflow_id)
|
||||
assert e, "User pipeline not found."
|
||||
@ -486,12 +492,12 @@ async def run_dataflow(task: dict):
|
||||
assert e, "Pipeline log not found."
|
||||
dsl = pipeline_log.dsl
|
||||
pipeline = Pipeline(dsl, tenant_id=task["tenant_id"], doc_id=doc_id, task_id=task_id, flow_id=dataflow_id)
|
||||
chunks = await pipeline.run(file=task["file"]) if task.get("file") else pipeline.run()
|
||||
chunks = await pipeline.run(file=task["file"]) if task.get("file") else await pipeline.run()
|
||||
if doc_id == CANVAS_DEBUG_DOC_ID:
|
||||
return
|
||||
|
||||
if not chunks:
|
||||
PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE)
|
||||
PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE, dsl=str(pipeline))
|
||||
return
|
||||
|
||||
embedding_token_consumption = chunks.get("embedding_token_consumption", 0)
|
||||
@ -508,7 +514,7 @@ async def run_dataflow(task: dict):
|
||||
|
||||
keys = [k for o in chunks for k in list(o.keys())]
|
||||
if not any([re.match(r"q_[0-9]+_vec", k) for k in keys]):
|
||||
set_progress(task_id, prog=0.82, msg="Start to embedding...")
|
||||
set_progress(task_id, prog=0.82, msg="\n-------------------------------------\nStart to embedding...")
|
||||
e, kb = KnowledgebaseService.get_by_id(task["kb_id"])
|
||||
embedding_id = kb.embd_id
|
||||
embedding_model = LLMBundle(task["tenant_id"], LLMType.EMBEDDING, llm_name=embedding_id)
|
||||
@ -518,7 +524,7 @@ async def run_dataflow(task: dict):
|
||||
return embedding_model.encode([truncate(c, embedding_model.max_length - 10) for c in txts])
|
||||
vects = np.array([])
|
||||
texts = [o.get("questions", o.get("summary", o["text"])) for o in chunks]
|
||||
delta = 0.20/(len(texts)//EMBEDDING_BATCH_SIZE)
|
||||
delta = 0.20/(len(texts)//EMBEDDING_BATCH_SIZE+1)
|
||||
prog = 0.8
|
||||
for i in range(0, len(texts), EMBEDDING_BATCH_SIZE):
|
||||
async with embed_limiter:
|
||||
@ -529,7 +535,8 @@ async def run_dataflow(task: dict):
|
||||
vects = np.concatenate((vects, vts), axis=0)
|
||||
embedding_token_consumption += c
|
||||
prog += delta
|
||||
set_progress(task_id, prog=prog, msg=f"{i+1} / {len(texts)//EMBEDDING_BATCH_SIZE}")
|
||||
if i % (len(texts)//EMBEDDING_BATCH_SIZE/100+1) == 1:
|
||||
set_progress(task_id, prog=prog, msg=f"{i+1} / {len(texts)//EMBEDDING_BATCH_SIZE}")
|
||||
|
||||
assert len(vects) == len(chunks)
|
||||
for i, ck in enumerate(chunks):
|
||||
@ -539,9 +546,23 @@ async def run_dataflow(task: dict):
|
||||
metadata = {}
|
||||
def dict_update(meta):
|
||||
nonlocal metadata
|
||||
if not meta or not isinstance(meta, dict):
|
||||
if not meta:
|
||||
return
|
||||
for k,v in meta.items():
|
||||
if isinstance(meta, str):
|
||||
try:
|
||||
meta = json_repair.loads(meta)
|
||||
except Exception:
|
||||
logging.error("Meta data format error.")
|
||||
return
|
||||
if not isinstance(meta, dict):
|
||||
return
|
||||
for k, v in meta.items():
|
||||
if isinstance(v, list):
|
||||
v = [vv for vv in v if isinstance(vv, str)]
|
||||
if not v:
|
||||
continue
|
||||
if not isinstance(v, list) and not isinstance(v, str):
|
||||
continue
|
||||
if k not in metadata:
|
||||
metadata[k] = v
|
||||
continue
|
||||
@ -561,15 +582,29 @@ async def run_dataflow(task: dict):
|
||||
ck["create_timestamp_flt"] = datetime.now().timestamp()
|
||||
ck["id"] = xxhash.xxh64((ck["text"] + str(ck["doc_id"])).encode("utf-8")).hexdigest()
|
||||
if "questions" in ck:
|
||||
if "question_tks" not in ck:
|
||||
ck["question_kwd"] = ck["questions"].split("\n")
|
||||
ck["question_tks"] = rag_tokenizer.tokenize(str(ck["questions"]))
|
||||
del ck["questions"]
|
||||
if "keywords" in ck:
|
||||
if "important_tks" not in ck:
|
||||
ck["important_kwd"] = ck["keywords"].split(",")
|
||||
ck["important_tks"] = rag_tokenizer.tokenize(str(ck["keywords"]))
|
||||
del ck["keywords"]
|
||||
if "summary" in ck:
|
||||
if "content_ltks" not in ck:
|
||||
ck["content_ltks"] = rag_tokenizer.tokenize(str(ck["summary"]))
|
||||
ck["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(ck["content_ltks"])
|
||||
del ck["summary"]
|
||||
if "metadata" in ck:
|
||||
dict_update(ck["metadata"])
|
||||
del ck["metadata"]
|
||||
if "content_with_weight" not in ck:
|
||||
ck["content_with_weight"] = ck["text"]
|
||||
del ck["text"]
|
||||
if "positions" in ck:
|
||||
add_positions(ck, ck["positions"])
|
||||
del ck["positions"]
|
||||
|
||||
if metadata:
|
||||
e, doc = DocumentService.get_by_id(doc_id)
|
||||
@ -580,59 +615,18 @@ async def run_dataflow(task: dict):
|
||||
DocumentService.update_by_id(doc_id, {"meta_fields": metadata})
|
||||
|
||||
start_ts = timer()
|
||||
set_progress(task_id, prog=0.82, msg="Start to index...")
|
||||
set_progress(task_id, prog=0.82, msg="[DOC Engine]:\nStart to index...")
|
||||
e = await insert_es(task_id, task["tenant_id"], task["kb_id"], chunks, partial(set_progress, task_id, 0, 100000000))
|
||||
if not e:
|
||||
PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE)
|
||||
PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE, dsl=str(pipeline))
|
||||
return
|
||||
|
||||
time_cost = timer() - start_ts
|
||||
task_time_cost = timer() - task_start_ts
|
||||
set_progress(task_id, prog=1., msg="Indexing done ({:.2f}s). Task done ({:.2f}s)".format(time_cost, task_time_cost))
|
||||
DocumentService.increment_chunk_num(doc_id, task_dataset_id, embedding_token_consumption, len(chunks), task_time_cost)
|
||||
logging.info("[Done], chunks({}), token({}), elapsed:{:.2f}".format(len(chunks), embedding_token_consumption, task_time_cost))
|
||||
PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE)
|
||||
|
||||
|
||||
@timeout(3600)
|
||||
async def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None):
|
||||
chunks = []
|
||||
vctr_nm = "q_%d_vec"%vector_size
|
||||
for d in settings.retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])],
|
||||
fields=["content_with_weight", vctr_nm]):
|
||||
chunks.append((d["content_with_weight"], np.array(d[vctr_nm])))
|
||||
|
||||
raptor = Raptor(
|
||||
row["parser_config"]["raptor"].get("max_cluster", 64),
|
||||
chat_mdl,
|
||||
embd_mdl,
|
||||
row["parser_config"]["raptor"]["prompt"],
|
||||
row["parser_config"]["raptor"]["max_token"],
|
||||
row["parser_config"]["raptor"]["threshold"]
|
||||
)
|
||||
original_length = len(chunks)
|
||||
chunks = await raptor(chunks, row["parser_config"]["raptor"]["random_seed"], callback)
|
||||
doc = {
|
||||
"doc_id": row["doc_id"],
|
||||
"kb_id": [str(row["kb_id"])],
|
||||
"docnm_kwd": row["name"],
|
||||
"title_tks": rag_tokenizer.tokenize(row["name"])
|
||||
}
|
||||
if row["pagerank"]:
|
||||
doc[PAGERANK_FLD] = int(row["pagerank"])
|
||||
res = []
|
||||
tk_count = 0
|
||||
for content, vctr in chunks[original_length:]:
|
||||
d = copy.deepcopy(doc)
|
||||
d["id"] = xxhash.xxh64((content + str(d["doc_id"])).encode("utf-8")).hexdigest()
|
||||
d["create_time"] = str(datetime.now()).replace("T", " ")[:19]
|
||||
d["create_timestamp_flt"] = datetime.now().timestamp()
|
||||
d[vctr_nm] = vctr.tolist()
|
||||
d["content_with_weight"] = content
|
||||
d["content_ltks"] = rag_tokenizer.tokenize(content)
|
||||
d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
|
||||
res.append(d)
|
||||
tk_count += num_tokens_from_string(content)
|
||||
return res, tk_count
|
||||
PipelineOperationLogService.create(document_id=doc_id, pipeline_id=dataflow_id, task_type=PipelineTaskType.PARSE, dsl=str(pipeline))
|
||||
|
||||
|
||||
@timeout(3600)
|
||||
@ -787,7 +781,6 @@ async def do_handle_task(task):
|
||||
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
|
||||
# run RAPTOR
|
||||
async with kg_limiter:
|
||||
# chunks, token_count = await run_raptor(task, chat_model, embedding_model, vector_size, progress_callback)
|
||||
chunks, token_count = await run_raptor_for_kb(
|
||||
row=task,
|
||||
kb_parser_config=kb_parser_config,
|
||||
@ -830,6 +823,10 @@ async def do_handle_task(task):
|
||||
logging.info(f"GraphRAG task result for task {task}:\n{result}")
|
||||
progress_callback(prog=1.0, msg="Knowledge Graph done ({:.2f}s)".format(timer() - start_ts))
|
||||
return
|
||||
elif task_type == "mindmap":
|
||||
progress_callback(1, "place holder")
|
||||
pass
|
||||
return
|
||||
else:
|
||||
# Standard chunking methods
|
||||
start_ts = timer()
|
||||
@ -906,10 +903,10 @@ async def handle_task():
|
||||
logging.exception(f"handle_task got exception for task {json.dumps(task)}")
|
||||
finally:
|
||||
task_document_ids = []
|
||||
if task_type in ["graphrag", "raptor"]:
|
||||
if task_type in ["graphrag", "raptor", "mindmap"]:
|
||||
task_document_ids = task["doc_ids"]
|
||||
if task["doc_id"] != CANVAS_DEBUG_DOC_ID:
|
||||
PipelineOperationLogService.record_pipeline_operation(document_id=task["doc_id"], pipeline_id=task.get("dataflow_id", "") or "", task_type=pipeline_task_type, fake_document_ids=task_document_ids)
|
||||
if not task.get("dataflow_id", ""):
|
||||
PipelineOperationLogService.record_pipeline_operation(document_id=task["doc_id"], pipeline_id="", task_type=pipeline_task_type, fake_document_ids=task_document_ids)
|
||||
|
||||
redis_msg.ack()
|
||||
|
||||
|
||||
@ -23,7 +23,6 @@ import {
|
||||
ParseTypeItem,
|
||||
} from '@/pages/dataset/dataset-setting/configuration/common-item';
|
||||
import { zodResolver } from '@hookform/resolvers/zod';
|
||||
import get from 'lodash/get';
|
||||
import omit from 'lodash/omit';
|
||||
import {} from 'module';
|
||||
import { useEffect, useMemo } from 'react';
|
||||
@ -41,13 +40,6 @@ import { ExcelToHtmlFormField } from '../excel-to-html-form-field';
|
||||
import { FormContainer } from '../form-container';
|
||||
import { LayoutRecognizeFormField } from '../layout-recognize-form-field';
|
||||
import { MaxTokenNumberFormField } from '../max-token-number-from-field';
|
||||
import {
|
||||
UseGraphRagFormField,
|
||||
showGraphRagItems,
|
||||
} from '../parse-configuration/graph-rag-form-fields';
|
||||
import RaptorFormFields, {
|
||||
showRaptorParseConfiguration,
|
||||
} from '../parse-configuration/raptor-form-fields';
|
||||
import { ButtonLoading } from '../ui/button';
|
||||
import { Input } from '../ui/input';
|
||||
import { DynamicPageRange } from './dynamic-page-range';
|
||||
@ -121,19 +113,19 @@ export function ChunkMethodDialog({
|
||||
auto_keywords: z.coerce.number().optional(),
|
||||
auto_questions: z.coerce.number().optional(),
|
||||
html4excel: z.boolean().optional(),
|
||||
raptor: z
|
||||
.object({
|
||||
use_raptor: z.boolean().optional(),
|
||||
prompt: z.string().optional().optional(),
|
||||
max_token: z.coerce.number().optional(),
|
||||
threshold: z.coerce.number().optional(),
|
||||
max_cluster: z.coerce.number().optional(),
|
||||
random_seed: z.coerce.number().optional(),
|
||||
})
|
||||
.optional(),
|
||||
graphrag: z.object({
|
||||
use_graphrag: z.boolean().optional(),
|
||||
}),
|
||||
// raptor: z
|
||||
// .object({
|
||||
// use_raptor: z.boolean().optional(),
|
||||
// prompt: z.string().optional().optional(),
|
||||
// max_token: z.coerce.number().optional(),
|
||||
// threshold: z.coerce.number().optional(),
|
||||
// max_cluster: z.coerce.number().optional(),
|
||||
// random_seed: z.coerce.number().optional(),
|
||||
// })
|
||||
// .optional(),
|
||||
// graphrag: z.object({
|
||||
// use_graphrag: z.boolean().optional(),
|
||||
// }),
|
||||
entity_types: z.array(z.string()).optional(),
|
||||
pages: z
|
||||
.array(z.object({ from: z.coerce.number(), to: z.coerce.number() }))
|
||||
@ -223,13 +215,13 @@ export function ChunkMethodDialog({
|
||||
parser_config: fillDefaultParserValue({
|
||||
pages: pages.length > 0 ? pages : [{ from: 1, to: 1024 }],
|
||||
...omit(parserConfig, 'pages'),
|
||||
graphrag: {
|
||||
use_graphrag: get(
|
||||
parserConfig,
|
||||
'graphrag.use_graphrag',
|
||||
useGraphRag,
|
||||
),
|
||||
},
|
||||
// graphrag: {
|
||||
// use_graphrag: get(
|
||||
// parserConfig,
|
||||
// 'graphrag.use_graphrag',
|
||||
// useGraphRag,
|
||||
// ),
|
||||
// },
|
||||
}),
|
||||
});
|
||||
}
|
||||
@ -351,19 +343,19 @@ export function ChunkMethodDialog({
|
||||
<ExcelToHtmlFormField></ExcelToHtmlFormField>
|
||||
)}
|
||||
</FormContainer>
|
||||
{showRaptorParseConfiguration(
|
||||
{/* {showRaptorParseConfiguration(
|
||||
selectedTag as DocumentParserType,
|
||||
) && (
|
||||
<FormContainer>
|
||||
<RaptorFormFields></RaptorFormFields>
|
||||
</FormContainer>
|
||||
)}
|
||||
{showGraphRagItems(selectedTag as DocumentParserType) &&
|
||||
)} */}
|
||||
{/* {showGraphRagItems(selectedTag as DocumentParserType) &&
|
||||
useGraphRag && (
|
||||
<FormContainer>
|
||||
<UseGraphRagFormField></UseGraphRagFormField>
|
||||
</FormContainer>
|
||||
)}
|
||||
)} */}
|
||||
{showEntityTypes && (
|
||||
<EntityTypesFormField></EntityTypesFormField>
|
||||
)}
|
||||
|
||||
@ -15,17 +15,17 @@ export function useDefaultParserValues() {
|
||||
auto_keywords: 0,
|
||||
auto_questions: 0,
|
||||
html4excel: false,
|
||||
raptor: {
|
||||
use_raptor: false,
|
||||
prompt: t('knowledgeConfiguration.promptText'),
|
||||
max_token: 256,
|
||||
threshold: 0.1,
|
||||
max_cluster: 64,
|
||||
random_seed: 0,
|
||||
},
|
||||
graphrag: {
|
||||
use_graphrag: false,
|
||||
},
|
||||
// raptor: {
|
||||
// use_raptor: false,
|
||||
// prompt: t('knowledgeConfiguration.promptText'),
|
||||
// max_token: 256,
|
||||
// threshold: 0.1,
|
||||
// max_cluster: 64,
|
||||
// random_seed: 0,
|
||||
// },
|
||||
// graphrag: {
|
||||
// use_graphrag: false,
|
||||
// },
|
||||
entity_types: [],
|
||||
pages: [],
|
||||
};
|
||||
|
||||
@ -8,7 +8,7 @@ import {
|
||||
AlertDialogTitle,
|
||||
AlertDialogTrigger,
|
||||
} from '@/components/ui/alert-dialog';
|
||||
import { PropsWithChildren } from 'react';
|
||||
import { DialogProps } from '@radix-ui/react-dialog';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
interface IProps {
|
||||
@ -24,7 +24,10 @@ export function ConfirmDeleteDialog({
|
||||
onOk,
|
||||
onCancel,
|
||||
hidden = false,
|
||||
}: IProps & PropsWithChildren) {
|
||||
onOpenChange,
|
||||
open,
|
||||
defaultOpen,
|
||||
}: IProps & DialogProps) {
|
||||
const { t } = useTranslation();
|
||||
|
||||
if (hidden) {
|
||||
@ -32,7 +35,11 @@ export function ConfirmDeleteDialog({
|
||||
}
|
||||
|
||||
return (
|
||||
<AlertDialog>
|
||||
<AlertDialog
|
||||
onOpenChange={onOpenChange}
|
||||
open={open}
|
||||
defaultOpen={defaultOpen}
|
||||
>
|
||||
<AlertDialogTrigger asChild>{children}</AlertDialogTrigger>
|
||||
<AlertDialogContent
|
||||
onSelect={(e) => e.preventDefault()}
|
||||
|
||||
@ -2,7 +2,7 @@ import { useTranslate } from '@/hooks/common-hooks';
|
||||
import { useFetchAgentList } from '@/hooks/use-agent-request';
|
||||
import { buildSelectOptions } from '@/utils/component-util';
|
||||
import { ArrowUpRight } from 'lucide-react';
|
||||
import { useMemo } from 'react';
|
||||
import { useEffect, useMemo } from 'react';
|
||||
import { useFormContext } from 'react-hook-form';
|
||||
import { SelectWithSearch } from '../originui/select-with-search';
|
||||
import {
|
||||
@ -13,15 +13,21 @@ import {
|
||||
FormMessage,
|
||||
} from '../ui/form';
|
||||
import { MultiSelect } from '../ui/multi-select';
|
||||
export interface IDataPipelineSelectNode {
|
||||
id?: string;
|
||||
name?: string;
|
||||
avatar?: string;
|
||||
}
|
||||
|
||||
interface IProps {
|
||||
toDataPipeline?: () => void;
|
||||
formFieldName: string;
|
||||
isMult?: boolean;
|
||||
setDataList?: (data: IDataPipelineSelectNode[]) => void;
|
||||
}
|
||||
|
||||
export function DataFlowSelect(props: IProps) {
|
||||
const { toDataPipeline, formFieldName, isMult = true } = props;
|
||||
const { toDataPipeline, formFieldName, isMult = false, setDataList } = props;
|
||||
const { t } = useTranslate('knowledgeConfiguration');
|
||||
const form = useFormContext();
|
||||
const toDataPipLine = () => {
|
||||
@ -36,8 +42,26 @@ export function DataFlowSelect(props: IProps) {
|
||||
'id',
|
||||
'title',
|
||||
);
|
||||
|
||||
return option || [];
|
||||
}, [dataPipelineOptions]);
|
||||
|
||||
const nodes = useMemo(() => {
|
||||
return (
|
||||
dataPipelineOptions?.canvas?.map((item) => {
|
||||
return {
|
||||
id: item?.id,
|
||||
name: item?.title,
|
||||
avatar: item?.avatar,
|
||||
};
|
||||
}) || []
|
||||
);
|
||||
}, [dataPipelineOptions]);
|
||||
|
||||
useEffect(() => {
|
||||
setDataList?.(nodes);
|
||||
}, [nodes, setDataList]);
|
||||
|
||||
return (
|
||||
<FormField
|
||||
control={form.control}
|
||||
|
||||
@ -4,6 +4,7 @@ import { cn } from '@/lib/utils';
|
||||
import {
|
||||
GenerateLogButton,
|
||||
GenerateType,
|
||||
IGenerateLogButtonProps,
|
||||
} from '@/pages/dataset/dataset/generate-button/generate';
|
||||
import { upperFirst } from 'lodash';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
@ -51,10 +52,14 @@ export const showGraphRagItems = (parserId: DocumentParserType | undefined) => {
|
||||
type GraphRagItemsProps = {
|
||||
marginBottom?: boolean;
|
||||
className?: string;
|
||||
showGenerateItem?: boolean;
|
||||
data: IGenerateLogButtonProps;
|
||||
};
|
||||
|
||||
export function UseGraphRagFormField() {
|
||||
export function UseGraphRagFormField({
|
||||
data,
|
||||
}: {
|
||||
data: IGenerateLogButtonProps;
|
||||
}) {
|
||||
const form = useFormContext();
|
||||
const { t } = useTranslate('knowledgeConfiguration');
|
||||
|
||||
@ -73,10 +78,16 @@ export function UseGraphRagFormField() {
|
||||
</FormLabel>
|
||||
<div className="w-3/4">
|
||||
<FormControl>
|
||||
<Switch
|
||||
{/* <Switch
|
||||
checked={field.value}
|
||||
onCheckedChange={field.onChange}
|
||||
></Switch>
|
||||
></Switch> */}
|
||||
<GenerateLogButton
|
||||
{...data}
|
||||
className="w-full text-text-secondary"
|
||||
status={1}
|
||||
type={GenerateType.KnowledgeGraph}
|
||||
/>
|
||||
</FormControl>
|
||||
</div>
|
||||
</div>
|
||||
@ -93,8 +104,8 @@ export function UseGraphRagFormField() {
|
||||
// The three types "table", "resume" and "one" do not display this configuration.
|
||||
const GraphRagItems = ({
|
||||
marginBottom = false,
|
||||
showGenerateItem = false,
|
||||
className = 'p-10',
|
||||
data,
|
||||
}: GraphRagItemsProps) => {
|
||||
const { t } = useTranslate('knowledgeConfiguration');
|
||||
const form = useFormContext();
|
||||
@ -120,7 +131,7 @@ const GraphRagItems = ({
|
||||
|
||||
return (
|
||||
<FormContainer className={cn({ 'mb-4': marginBottom }, className)}>
|
||||
<UseGraphRagFormField></UseGraphRagFormField>
|
||||
<UseGraphRagFormField data={data}></UseGraphRagFormField>
|
||||
{useRaptor && (
|
||||
<>
|
||||
<EntityTypesFormField name="parser_config.graphrag.entity_types"></EntityTypesFormField>
|
||||
@ -216,7 +227,7 @@ const GraphRagItems = ({
|
||||
</FormItem>
|
||||
)}
|
||||
/>
|
||||
{showGenerateItem && (
|
||||
{/* {showGenerateItem && (
|
||||
<div className="w-full flex items-center">
|
||||
<div className="text-sm whitespace-nowrap w-1/4">
|
||||
{t('extractKnowledgeGraph')}
|
||||
@ -227,7 +238,7 @@ const GraphRagItems = ({
|
||||
type={GenerateType.KnowledgeGraph}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
)} */}
|
||||
</>
|
||||
)}
|
||||
</FormContainer>
|
||||
|
||||
@ -4,6 +4,7 @@ import { useTranslate } from '@/hooks/common-hooks';
|
||||
import {
|
||||
GenerateLogButton,
|
||||
GenerateType,
|
||||
IGenerateLogButtonProps,
|
||||
} from '@/pages/dataset/dataset/generate-button/generate';
|
||||
import random from 'lodash/random';
|
||||
import { Shuffle } from 'lucide-react';
|
||||
@ -18,7 +19,6 @@ import {
|
||||
FormMessage,
|
||||
} from '../ui/form';
|
||||
import { ExpandedInput } from '../ui/input';
|
||||
import { Switch } from '../ui/switch';
|
||||
import { Textarea } from '../ui/textarea';
|
||||
|
||||
export const excludedParseMethods = [
|
||||
@ -56,11 +56,7 @@ const Prompt = 'parser_config.raptor.prompt';
|
||||
|
||||
// The three types "table", "resume" and "one" do not display this configuration.
|
||||
|
||||
const RaptorFormFields = ({
|
||||
showGenerateItem = false,
|
||||
}: {
|
||||
showGenerateItem?: boolean;
|
||||
}) => {
|
||||
const RaptorFormFields = ({ data }: { data: IGenerateLogButtonProps }) => {
|
||||
const form = useFormContext();
|
||||
const { t } = useTranslate('knowledgeConfiguration');
|
||||
const useRaptor = useWatch({ name: UseRaptorField });
|
||||
@ -108,13 +104,12 @@ const RaptorFormFields = ({
|
||||
</FormLabel>
|
||||
<div className="w-3/4">
|
||||
<FormControl>
|
||||
<Switch
|
||||
checked={field.value}
|
||||
onCheckedChange={(e) => {
|
||||
changeRaptor(e);
|
||||
field.onChange(e);
|
||||
}}
|
||||
></Switch>
|
||||
<GenerateLogButton
|
||||
{...data}
|
||||
className="w-full text-text-secondary"
|
||||
status={1}
|
||||
type={GenerateType.Raptor}
|
||||
/>
|
||||
</FormControl>
|
||||
</div>
|
||||
</div>
|
||||
@ -219,18 +214,6 @@ const RaptorFormFields = ({
|
||||
</FormItem>
|
||||
)}
|
||||
/>
|
||||
{showGenerateItem && (
|
||||
<div className="w-full flex items-center">
|
||||
<div className="text-sm whitespace-nowrap w-1/4">
|
||||
{t('extractRaptor')}
|
||||
</div>
|
||||
<GenerateLogButton
|
||||
className="w-3/4 text-text-secondary"
|
||||
status={1}
|
||||
type={GenerateType.Raptor}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
</>
|
||||
|
||||
@ -14,6 +14,9 @@ export interface IKnowledge {
|
||||
name: string;
|
||||
parser_config: ParserConfig;
|
||||
parser_id: string;
|
||||
pipeline_id: string;
|
||||
pipeline_name: string;
|
||||
pipeline_avatar: string;
|
||||
permission: string;
|
||||
similarity_threshold: number;
|
||||
status: string;
|
||||
@ -26,6 +29,10 @@ export interface IKnowledge {
|
||||
nickname: string;
|
||||
operator_permission: number;
|
||||
size: number;
|
||||
raptor_task_finish_at?: string;
|
||||
raptor_task_id?: string;
|
||||
mindmap_task_finish_at?: string;
|
||||
mindmap_task_id?: string;
|
||||
}
|
||||
|
||||
export interface IKnowledgeResult {
|
||||
|
||||
@ -5,7 +5,7 @@
|
||||
|
||||
.chunkText() {
|
||||
em {
|
||||
color: red;
|
||||
color: var(--accent-primary);
|
||||
font-style: normal;
|
||||
}
|
||||
table {
|
||||
|
||||
@ -1583,6 +1583,9 @@ This delimiter is used to split the input text into several text pieces echo of
|
||||
'Write your SQL query here. You can use variables, raw SQL, or mix both using variable syntax.',
|
||||
frameworkPrompts: 'Framework',
|
||||
release: 'Publish',
|
||||
createFromBlank: 'Create from Blank',
|
||||
createFromTemplate: 'Create from Template',
|
||||
importJsonFile: 'Import json file',
|
||||
},
|
||||
llmTools: {
|
||||
bad_calculator: {
|
||||
@ -1762,6 +1765,9 @@ Important structured information may include: names, dates, locations, events, k
|
||||
metadata: `Content: [INSERT CONTENT HERE]`,
|
||||
},
|
||||
},
|
||||
cancel: 'Cancel',
|
||||
swicthPromptMessage:
|
||||
'The prompt word will change. Please confirm whether to abandon the existing prompt word?',
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
@ -1494,6 +1494,9 @@ General:实体和关系提取提示来自 GitHub - microsoft/graphrag:基于
|
||||
'在此处编写您的 SQL 查询。您可以使用变量、原始 SQL,或使用变量语法混合使用两者。',
|
||||
frameworkPrompts: '框架',
|
||||
release: '发布',
|
||||
createFromBlank: '从空白创建',
|
||||
createFromTemplate: '从模板创建',
|
||||
importJsonFile: '导入 JSON 文件',
|
||||
},
|
||||
footer: {
|
||||
profile: 'All rights reserved @ React',
|
||||
@ -1680,6 +1683,8 @@ General:实体和关系提取提示来自 GitHub - microsoft/graphrag:基于
|
||||
metadata: `内容:[在此处插入内容]`,
|
||||
},
|
||||
},
|
||||
cancel: '取消',
|
||||
switchPromptMessage: '提示词将发生变化,请确认是否放弃已有提示词?',
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
@ -10,7 +10,6 @@ import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext
|
||||
import {
|
||||
LexicalTypeaheadMenuPlugin,
|
||||
MenuOption,
|
||||
useBasicTypeaheadTriggerMatch,
|
||||
} from '@lexical/react/LexicalTypeaheadMenuPlugin';
|
||||
import {
|
||||
$createParagraphNode,
|
||||
@ -131,9 +130,23 @@ export default function VariablePickerMenuPlugin({
|
||||
baseOptions,
|
||||
}: VariablePickerMenuPluginProps): JSX.Element {
|
||||
const [editor] = useLexicalComposerContext();
|
||||
const checkForTriggerMatch = useBasicTypeaheadTriggerMatch('/', {
|
||||
minLength: 0,
|
||||
});
|
||||
|
||||
// const checkForTriggerMatch = useBasicTypeaheadTriggerMatch('/', {
|
||||
// minLength: 0,
|
||||
// });
|
||||
|
||||
const testTriggerFn = React.useCallback((text: string) => {
|
||||
const lastChar = text.slice(-1);
|
||||
if (lastChar === '/') {
|
||||
console.log('Found trigger character "/"');
|
||||
return {
|
||||
leadOffset: text.length - 1,
|
||||
matchingString: '',
|
||||
replaceableString: '/',
|
||||
};
|
||||
}
|
||||
return null;
|
||||
}, []);
|
||||
|
||||
const previousValue = useRef<string | undefined>();
|
||||
|
||||
@ -291,6 +304,21 @@ export default function VariablePickerMenuPlugin({
|
||||
}
|
||||
}, [parseTextToVariableNodes, editor, value]);
|
||||
|
||||
// Fixed the issue where the cursor would go to the end when changing its own data
|
||||
useEffect(() => {
|
||||
return editor.registerUpdateListener(({ editorState, tags }) => {
|
||||
// If we trigger the programmatic update ourselves, we should not write back to avoid an infinite loop.
|
||||
if (tags.has(ProgrammaticTag)) return;
|
||||
|
||||
editorState.read(() => {
|
||||
const text = $getRoot().getTextContent();
|
||||
if (text !== previousValue.current) {
|
||||
previousValue.current = text;
|
||||
}
|
||||
});
|
||||
});
|
||||
}, [editor]);
|
||||
|
||||
return (
|
||||
<LexicalTypeaheadMenuPlugin<VariableOption | VariableInnerOption>
|
||||
onQueryChange={setQueryString}
|
||||
@ -301,7 +329,7 @@ export default function VariablePickerMenuPlugin({
|
||||
closeMenu,
|
||||
)
|
||||
}
|
||||
triggerFn={checkForTriggerMatch}
|
||||
triggerFn={testTriggerFn}
|
||||
options={buildNextOptions()}
|
||||
menuRenderFn={(anchorElementRef, { selectOptionAndCleanUp }) => {
|
||||
const nextOptions = buildNextOptions();
|
||||
|
||||
@ -2,7 +2,11 @@ import { useSetModalState } from '@/hooks/common-hooks';
|
||||
import { EmptyDsl, useSetAgent } from '@/hooks/use-agent-request';
|
||||
import { DSL } from '@/interfaces/database/agent';
|
||||
import { AgentCategory } from '@/pages/agent/constant';
|
||||
import { BeginId, Operator } from '@/pages/data-flow/constant';
|
||||
import {
|
||||
BeginId,
|
||||
Operator,
|
||||
initialParserValues,
|
||||
} from '@/pages/data-flow/constant';
|
||||
import { useCallback } from 'react';
|
||||
import { FlowType } from '../constant';
|
||||
import { FormSchemaType } from '../create-agent-form';
|
||||
@ -24,8 +28,37 @@ export const DataflowEmptyDsl = {
|
||||
sourcePosition: 'left',
|
||||
targetPosition: 'right',
|
||||
},
|
||||
{
|
||||
data: {
|
||||
form: initialParserValues,
|
||||
label: 'Parser',
|
||||
name: 'Parser_0',
|
||||
},
|
||||
dragging: false,
|
||||
id: 'Parser:HipSignsRhyme',
|
||||
measured: {
|
||||
height: 57,
|
||||
width: 200,
|
||||
},
|
||||
position: {
|
||||
x: 316.99524094206413,
|
||||
y: 195.39629819663406,
|
||||
},
|
||||
selected: true,
|
||||
sourcePosition: 'right',
|
||||
targetPosition: 'left',
|
||||
type: 'parserNode',
|
||||
},
|
||||
],
|
||||
edges: [
|
||||
{
|
||||
id: 'xy-edge__Filestart-Parser:HipSignsRhymeend',
|
||||
source: BeginId,
|
||||
sourceHandle: 'start',
|
||||
target: 'Parser:HipSignsRhyme',
|
||||
targetHandle: 'end',
|
||||
},
|
||||
],
|
||||
edges: [],
|
||||
},
|
||||
components: {
|
||||
[Operator.Begin]: {
|
||||
|
||||
@ -79,21 +79,21 @@ export default function Agents() {
|
||||
onClick={showCreatingModal}
|
||||
>
|
||||
<Clipboard />
|
||||
Create from Blank
|
||||
{t('flow.createFromBlank')}
|
||||
</DropdownMenuItem>
|
||||
<DropdownMenuItem
|
||||
justifyBetween={false}
|
||||
onClick={navigateToAgentTemplates}
|
||||
>
|
||||
<ClipboardPlus />
|
||||
Create from Template
|
||||
{t('flow.createFromTemplate')}
|
||||
</DropdownMenuItem>
|
||||
<DropdownMenuItem
|
||||
justifyBetween={false}
|
||||
onClick={handleImportJson}
|
||||
>
|
||||
<FileInput />
|
||||
Import json file
|
||||
{t('flow.importJsonFile')}
|
||||
</DropdownMenuItem>
|
||||
</DropdownMenuContent>
|
||||
</DropdownMenu>
|
||||
|
||||
@ -268,11 +268,6 @@ export const initialParserValues = {
|
||||
fileFormat: FileType.PowerPoint,
|
||||
output_format: PptOutputFormat.Json,
|
||||
},
|
||||
{
|
||||
fileFormat: FileType.Audio,
|
||||
llm_id: '',
|
||||
output_format: AudioOutputFormat.Text,
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import { ConfirmDeleteDialog } from '@/components/confirm-delete-dialog';
|
||||
import { LargeModelFormField } from '@/components/large-model-form-field';
|
||||
import { LlmSettingSchema } from '@/components/llm-setting-items/next';
|
||||
import { SelectWithSearch } from '@/components/originui/select-with-search';
|
||||
@ -6,7 +7,7 @@ import { Form } from '@/components/ui/form';
|
||||
import { PromptEditor } from '@/pages/agent/form/components/prompt-editor';
|
||||
import { buildOptions } from '@/utils/form';
|
||||
import { zodResolver } from '@hookform/resolvers/zod';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { memo } from 'react';
|
||||
import { useForm } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { z } from 'zod';
|
||||
@ -19,6 +20,7 @@ import { useFormValues } from '../../hooks/use-form-values';
|
||||
import { useWatchFormChange } from '../../hooks/use-watch-form-change';
|
||||
import { INextOperatorForm } from '../../interface';
|
||||
import { FormWrapper } from '../components/form-wrapper';
|
||||
import { useSwitchPrompt } from './use-switch-prompt';
|
||||
|
||||
export const FormSchema = z.object({
|
||||
field_name: z.string(),
|
||||
@ -43,25 +45,13 @@ const ExtractorForm = ({ node }: INextOperatorForm) => {
|
||||
|
||||
const options = buildOptions(ContextGeneratorFieldName, t, 'dataflow');
|
||||
|
||||
const setPromptValue = useCallback(
|
||||
(field: keyof ExtractorFormSchemaType, key: string, value: string) => {
|
||||
form.setValue(field, t(`dataflow.prompts.${key}.${value}`), {
|
||||
shouldDirty: true,
|
||||
shouldValidate: true,
|
||||
});
|
||||
},
|
||||
[form, t],
|
||||
);
|
||||
|
||||
const handleFieldNameChange = useCallback(
|
||||
(value: string) => {
|
||||
if (value) {
|
||||
setPromptValue('sys_prompt', 'system', value);
|
||||
setPromptValue('prompts', 'user', value);
|
||||
}
|
||||
},
|
||||
[setPromptValue],
|
||||
);
|
||||
const {
|
||||
handleFieldNameChange,
|
||||
confirmSwitch,
|
||||
hideModal,
|
||||
visible,
|
||||
cancelSwitch,
|
||||
} = useSwitchPrompt(form);
|
||||
|
||||
useWatchFormChange(node?.id, form);
|
||||
|
||||
@ -96,6 +86,15 @@ const ExtractorForm = ({ node }: INextOperatorForm) => {
|
||||
></PromptEditor>
|
||||
</RAGFlowFormItem>
|
||||
</FormWrapper>
|
||||
{visible && (
|
||||
<ConfirmDeleteDialog
|
||||
title={t('dataflow.switchPromptMessage')}
|
||||
open
|
||||
onOpenChange={hideModal}
|
||||
onOk={confirmSwitch}
|
||||
onCancel={cancelSwitch}
|
||||
></ConfirmDeleteDialog>
|
||||
)}
|
||||
</Form>
|
||||
);
|
||||
};
|
||||
|
||||
@ -0,0 +1,69 @@
|
||||
import { LlmSettingSchema } from '@/components/llm-setting-items/next';
|
||||
import { useSetModalState } from '@/hooks/common-hooks';
|
||||
import { useCallback, useRef } from 'react';
|
||||
import { UseFormReturn } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { z } from 'zod';
|
||||
|
||||
export const FormSchema = z.object({
|
||||
field_name: z.string(),
|
||||
sys_prompt: z.string(),
|
||||
prompts: z.string().optional(),
|
||||
...LlmSettingSchema,
|
||||
});
|
||||
|
||||
export type ExtractorFormSchemaType = z.infer<typeof FormSchema>;
|
||||
|
||||
export function useSwitchPrompt(form: UseFormReturn<ExtractorFormSchemaType>) {
|
||||
const { visible, showModal, hideModal } = useSetModalState();
|
||||
const { t } = useTranslation();
|
||||
const previousFieldNames = useRef<string[]>([form.getValues('field_name')]);
|
||||
|
||||
const setPromptValue = useCallback(
|
||||
(field: keyof ExtractorFormSchemaType, key: string, value: string) => {
|
||||
form.setValue(field, t(`dataflow.prompts.${key}.${value}`), {
|
||||
shouldDirty: true,
|
||||
shouldValidate: true,
|
||||
});
|
||||
},
|
||||
[form, t],
|
||||
);
|
||||
|
||||
const handleFieldNameChange = useCallback(
|
||||
(value: string) => {
|
||||
if (value) {
|
||||
const names = previousFieldNames.current;
|
||||
if (names.length > 1) {
|
||||
names.shift();
|
||||
}
|
||||
names.push(value);
|
||||
showModal();
|
||||
}
|
||||
},
|
||||
[showModal],
|
||||
);
|
||||
|
||||
const confirmSwitch = useCallback(() => {
|
||||
const value = form.getValues('field_name');
|
||||
setPromptValue('sys_prompt', 'system', value);
|
||||
setPromptValue('prompts', 'user', value);
|
||||
}, [form, setPromptValue]);
|
||||
|
||||
const cancelSwitch = useCallback(() => {
|
||||
const previousValue = previousFieldNames.current.at(-2);
|
||||
if (previousValue) {
|
||||
form.setValue('field_name', previousValue, {
|
||||
shouldDirty: true,
|
||||
shouldValidate: true,
|
||||
});
|
||||
}
|
||||
}, [form]);
|
||||
|
||||
return {
|
||||
handleFieldNameChange,
|
||||
confirmSwitch,
|
||||
hideModal,
|
||||
visible,
|
||||
cancelSwitch,
|
||||
};
|
||||
}
|
||||
@ -4,11 +4,9 @@ import { useCallback } from 'react';
|
||||
export function useCancelCurrentDataflow({
|
||||
messageId,
|
||||
setMessageId,
|
||||
hideLogSheet,
|
||||
}: {
|
||||
messageId: string;
|
||||
setMessageId: (messageId: string) => void;
|
||||
hideLogSheet(): void;
|
||||
}) {
|
||||
const { cancelDataflow } = useCancelDataflow();
|
||||
|
||||
@ -16,9 +14,8 @@ export function useCancelCurrentDataflow({
|
||||
const code = await cancelDataflow(messageId);
|
||||
if (code === 0) {
|
||||
setMessageId('');
|
||||
hideLogSheet();
|
||||
}
|
||||
}, [cancelDataflow, hideLogSheet, messageId, setMessageId]);
|
||||
}, [cancelDataflow, messageId, setMessageId]);
|
||||
|
||||
return { handleCancel };
|
||||
}
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import message from '@/components/ui/message';
|
||||
import { useSendMessageBySSE } from '@/hooks/use-send-message';
|
||||
import api from '@/utils/api';
|
||||
import { get } from 'lodash';
|
||||
@ -38,6 +39,8 @@ export function useRunDataflow(
|
||||
}
|
||||
|
||||
return msgId;
|
||||
} else {
|
||||
message.error(get(res, 'data.message', ''));
|
||||
}
|
||||
},
|
||||
[hideRunOrChatDrawer, id, saveGraph, send, setMessageId],
|
||||
|
||||
@ -4,10 +4,6 @@ import useGraphStore from '../store';
|
||||
|
||||
export function useWatchFormChange(id?: string, form?: UseFormReturn<any>) {
|
||||
let values = useWatch({ control: form?.control });
|
||||
console.log(
|
||||
'🚀 ~ useWatchFormChange ~ values:',
|
||||
JSON.stringify(values, null, 2),
|
||||
);
|
||||
|
||||
const updateNodeForm = useGraphStore((state) => state.updateNodeForm);
|
||||
|
||||
|
||||
@ -68,7 +68,8 @@ export default function DataFlow() {
|
||||
const { handleExportJson } = useHandleExportOrImportJsonFile();
|
||||
const { saveGraph, loading } = useSaveGraph();
|
||||
const { flowDetail: agentDetail } = useFetchDataOnMount();
|
||||
const { handleRun } = useSaveGraphBeforeOpeningDebugDrawer(showChatDrawer);
|
||||
const { handleRun, loading: running } =
|
||||
useSaveGraphBeforeOpeningDebugDrawer(showChatDrawer);
|
||||
|
||||
const {
|
||||
visible: versionDialogVisible,
|
||||
@ -102,7 +103,6 @@ export default function DataFlow() {
|
||||
const { handleCancel } = useCancelCurrentDataflow({
|
||||
messageId,
|
||||
setMessageId,
|
||||
hideLogSheet,
|
||||
});
|
||||
|
||||
const time = useWatchAgentChange(chatDrawerVisible);
|
||||
@ -136,14 +136,18 @@ export default function DataFlow() {
|
||||
>
|
||||
<LaptopMinimalCheck /> {t('flow.save')}
|
||||
</ButtonLoading>
|
||||
<Button
|
||||
<ButtonLoading
|
||||
variant={'secondary'}
|
||||
onClick={handleRunAgent}
|
||||
disabled={isParsing}
|
||||
loading={running}
|
||||
>
|
||||
<CirclePlay className={isParsing ? 'animate-spin' : ''} />
|
||||
{isParsing ? t('dataflow.running') : t('flow.run')}
|
||||
</Button>
|
||||
{running || (
|
||||
<CirclePlay className={isParsing ? 'animate-spin' : ''} />
|
||||
)}
|
||||
|
||||
{isParsing || running ? t('dataflow.running') : t('flow.run')}
|
||||
</ButtonLoading>
|
||||
<Button variant={'secondary'} onClick={showVersionDialog}>
|
||||
<History />
|
||||
{t('flow.historyversion')}
|
||||
|
||||
@ -57,16 +57,16 @@ export function LogSheet({
|
||||
</section>
|
||||
{isParsing ? (
|
||||
<Button
|
||||
className="w-full mt-8 bg-state-error/10 text-state-error"
|
||||
className="w-full mt-8 bg-state-error/10 text-state-error hover:bg-state-error hover:text-bg-base"
|
||||
onClick={handleCancel}
|
||||
>
|
||||
<CirclePause /> Cancel
|
||||
<CirclePause /> {t('dataflow.cancel')}
|
||||
</Button>
|
||||
) : (
|
||||
<Button
|
||||
onClick={handleDownloadJson}
|
||||
disabled={isEndOutputEmpty(logs)}
|
||||
className="w-full mt-8"
|
||||
className="w-full mt-8 bg-accent-primary-5 text-text-secondary hover:bg-accent-primary-5 hover:text-accent-primary hover:border-accent-primary hover:border"
|
||||
>
|
||||
<SquareArrowOutUpRight />
|
||||
{t('dataflow.exportJson')}
|
||||
|
||||
@ -14,6 +14,7 @@ import {
|
||||
NodeHandleId,
|
||||
Operator,
|
||||
} from './constant';
|
||||
import { ExtractorFormSchemaType } from './form/extractor-form';
|
||||
import { HierarchicalMergerFormSchemaType } from './form/hierarchical-merger-form';
|
||||
import { ParserFormSchemaType } from './form/parser-form';
|
||||
import { SplitterFormSchemaType } from './form/splitter-form';
|
||||
@ -143,6 +144,10 @@ function transformHierarchicalMergerParams(
|
||||
return { ...params, hierarchy: Number(params.hierarchy), levels };
|
||||
}
|
||||
|
||||
function transformExtractorParams(params: ExtractorFormSchemaType) {
|
||||
return { ...params, prompts: [{ content: params.prompts, role: 'user' }] };
|
||||
}
|
||||
|
||||
// construct a dsl based on the node information of the graph
|
||||
export const buildDslComponentsByGraph = (
|
||||
nodes: RAGFlowNodeType[],
|
||||
@ -174,6 +179,9 @@ export const buildDslComponentsByGraph = (
|
||||
case Operator.HierarchicalMerger:
|
||||
params = transformHierarchicalMergerParams(params);
|
||||
break;
|
||||
case Operator.Extractor:
|
||||
params = transformExtractorParams(params);
|
||||
break;
|
||||
|
||||
default:
|
||||
break;
|
||||
|
||||
@ -115,7 +115,7 @@ const FormatPreserveEditor = ({
|
||||
) : (
|
||||
<>
|
||||
{content.key === 'json' && */}
|
||||
{content.value.map((item, index) => (
|
||||
{content.value?.map((item, index) => (
|
||||
<section
|
||||
key={index}
|
||||
className={
|
||||
|
||||
@ -45,11 +45,17 @@ const useFetchFileLogList = () => {
|
||||
queryKey: [
|
||||
'fileLogList',
|
||||
knowledgeBaseId,
|
||||
pagination.current,
|
||||
pagination.pageSize,
|
||||
pagination,
|
||||
searchString,
|
||||
active,
|
||||
],
|
||||
placeholderData: (previousData) => {
|
||||
if (previousData === undefined) {
|
||||
return { logs: [], total: 0 };
|
||||
}
|
||||
return previousData;
|
||||
},
|
||||
enabled: true,
|
||||
queryFn: async () => {
|
||||
const { data: res = {} } = await fetchFunc({
|
||||
kb_id: knowledgeBaseId,
|
||||
@ -73,6 +79,7 @@ const useFetchFileLogList = () => {
|
||||
searchString,
|
||||
handleInputChange: onInputChange,
|
||||
pagination: { ...pagination, total: data?.total },
|
||||
setPagination,
|
||||
active,
|
||||
setActive,
|
||||
};
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import SvgIcon from '@/components/svg-icon';
|
||||
import { useIsDarkTheme } from '@/components/theme-provider';
|
||||
import { useFetchDocumentList } from '@/hooks/use-document-request';
|
||||
import { parseColorToRGBA } from '@/utils/common-util';
|
||||
import { CircleQuestionMark } from 'lucide-react';
|
||||
import { FC, useEffect, useMemo, useState } from 'react';
|
||||
@ -90,6 +91,9 @@ const FileLogsPage: FC = () => {
|
||||
});
|
||||
|
||||
const { data: topData } = useFetchOverviewTital();
|
||||
const {
|
||||
pagination: { total: fileTotal },
|
||||
} = useFetchDocumentList();
|
||||
console.log('topData --> ', topData);
|
||||
useEffect(() => {
|
||||
setTopAllData((prev) => {
|
||||
@ -104,11 +108,24 @@ const FileLogsPage: FC = () => {
|
||||
});
|
||||
}, [topData]);
|
||||
|
||||
useEffect(() => {
|
||||
setTopAllData((prev) => {
|
||||
return {
|
||||
...prev,
|
||||
totalFiles: {
|
||||
value: fileTotal || 0,
|
||||
precent: 0,
|
||||
},
|
||||
};
|
||||
});
|
||||
}, [fileTotal]);
|
||||
|
||||
const {
|
||||
data: tableOriginData,
|
||||
searchString,
|
||||
handleInputChange,
|
||||
pagination,
|
||||
setPagination,
|
||||
active,
|
||||
setActive,
|
||||
} = useFetchFileLogList();
|
||||
@ -131,6 +148,11 @@ const FileLogsPage: FC = () => {
|
||||
};
|
||||
const handlePaginationChange = (page: number, pageSize: number) => {
|
||||
console.log('Pagination changed:', { page, pageSize });
|
||||
setPagination({
|
||||
...pagination,
|
||||
page,
|
||||
pageSize: pageSize,
|
||||
});
|
||||
};
|
||||
|
||||
const isDark = useIsDarkTheme();
|
||||
|
||||
@ -1,23 +1,26 @@
|
||||
import { IDataPipelineSelectNode } from '@/components/data-pipeline-select';
|
||||
import { IconFont } from '@/components/icon-font';
|
||||
import { RAGFlowAvatar } from '@/components/ragflow-avatar';
|
||||
import { Button } from '@/components/ui/button';
|
||||
import { Modal } from '@/components/ui/modal/modal';
|
||||
import { omit } from 'lodash';
|
||||
import { Link, Settings2, Unlink } from 'lucide-react';
|
||||
import { useState } from 'react';
|
||||
import { Link } from 'lucide-react';
|
||||
import { useMemo, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { z } from 'zod';
|
||||
import { linkPiplineFormSchema } from '../form-schema';
|
||||
import LinkDataPipelineModal from './link-data-pipline-modal';
|
||||
|
||||
interface DataPipelineItemProps {
|
||||
id: string;
|
||||
name: string;
|
||||
avatar?: string;
|
||||
export interface IDataPipelineNodeProps extends IDataPipelineSelectNode {
|
||||
isDefault?: boolean;
|
||||
linked?: boolean;
|
||||
}
|
||||
|
||||
export interface ILinkDataPipelineProps {
|
||||
data?: IDataPipelineNodeProps;
|
||||
handleLinkOrEditSubmit?: (data: IDataPipelineNodeProps | undefined) => void;
|
||||
}
|
||||
|
||||
interface DataPipelineItemProps extends IDataPipelineNodeProps {
|
||||
openLinkModalFunc?: (open: boolean, data?: IDataPipelineNodeProps) => void;
|
||||
}
|
||||
|
||||
const DataPipelineItem = (props: DataPipelineItemProps) => {
|
||||
const { t } = useTranslation();
|
||||
const { name, avatar, isDefault, linked, openLinkModalFunc } = props;
|
||||
@ -57,17 +60,17 @@ const DataPipelineItem = (props: DataPipelineItemProps) => {
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="flex items-center justify-between gap-1 px-2 rounded-lg border">
|
||||
<div className="flex items-center justify-between gap-1 px-2 rounded-md border">
|
||||
<div className="flex items-center gap-1">
|
||||
<RAGFlowAvatar avatar={avatar} name={name} className="size-4" />
|
||||
<div>{name}</div>
|
||||
{isDefault && (
|
||||
{/* {isDefault && (
|
||||
<div className="text-xs bg-text-secondary text-bg-base px-2 py-1 rounded-md">
|
||||
{t('knowledgeConfiguration.default')}
|
||||
</div>
|
||||
)}
|
||||
)} */}
|
||||
</div>
|
||||
<div className="flex gap-1 items-center">
|
||||
{/* <div className="flex gap-1 items-center">
|
||||
<Button
|
||||
variant={'transparent'}
|
||||
className="border-none"
|
||||
@ -94,50 +97,29 @@ const DataPipelineItem = (props: DataPipelineItemProps) => {
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
</div> */}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export interface IDataPipelineNodeProps {
|
||||
id: string;
|
||||
name: string;
|
||||
avatar?: string;
|
||||
isDefault?: boolean;
|
||||
linked?: boolean;
|
||||
}
|
||||
const LinkDataPipeline = () => {
|
||||
const LinkDataPipeline = (props: ILinkDataPipelineProps) => {
|
||||
const { data, handleLinkOrEditSubmit: submit } = props;
|
||||
const { t } = useTranslation();
|
||||
const [openLinkModal, setOpenLinkModal] = useState(false);
|
||||
const [currentDataPipeline, setCurrentDataPipeline] =
|
||||
useState<IDataPipelineNodeProps>();
|
||||
const testNode = [
|
||||
{
|
||||
id: '1',
|
||||
name: 'Data Pipeline 1',
|
||||
avatar: 'https://avatars.githubusercontent.com/u/10656201?v=4',
|
||||
isDefault: true,
|
||||
linked: true,
|
||||
},
|
||||
{
|
||||
id: '2',
|
||||
name: 'Data Pipeline 2',
|
||||
avatar: 'https://avatars.githubusercontent.com/u/10656201?v=4',
|
||||
linked: false,
|
||||
},
|
||||
{
|
||||
id: '3',
|
||||
name: 'Data Pipeline 3',
|
||||
avatar: 'https://avatars.githubusercontent.com/u/10656201?v=4',
|
||||
linked: false,
|
||||
},
|
||||
{
|
||||
id: '4',
|
||||
name: 'Data Pipeline 4',
|
||||
avatar: 'https://avatars.githubusercontent.com/u/10656201?v=4',
|
||||
linked: true,
|
||||
},
|
||||
];
|
||||
const pipelineNode: IDataPipelineNodeProps[] = useMemo(
|
||||
() => [
|
||||
{
|
||||
id: data?.id,
|
||||
name: data?.name,
|
||||
avatar: data?.avatar,
|
||||
isDefault: data?.isDefault,
|
||||
linked: true,
|
||||
},
|
||||
],
|
||||
[data],
|
||||
);
|
||||
const openLinkModalFunc = (open: boolean, data?: IDataPipelineNodeProps) => {
|
||||
console.log('open', open, data);
|
||||
setOpenLinkModal(open);
|
||||
@ -148,9 +130,11 @@ const LinkDataPipeline = () => {
|
||||
}
|
||||
};
|
||||
const handleLinkOrEditSubmit = (
|
||||
data: z.infer<typeof linkPiplineFormSchema>,
|
||||
data: IDataPipelineSelectNode | undefined,
|
||||
) => {
|
||||
console.log('handleLinkOrEditSubmit', data);
|
||||
submit?.(data);
|
||||
setOpenLinkModal(false);
|
||||
};
|
||||
return (
|
||||
<div className="flex flex-col gap-2">
|
||||
@ -178,13 +162,20 @@ const LinkDataPipeline = () => {
|
||||
</div>
|
||||
</section>
|
||||
<section className="flex flex-col gap-2">
|
||||
{testNode.map((item) => (
|
||||
<DataPipelineItem
|
||||
key={item.name}
|
||||
openLinkModalFunc={openLinkModalFunc}
|
||||
{...item}
|
||||
/>
|
||||
))}
|
||||
{pipelineNode.map(
|
||||
(item) =>
|
||||
item.id && (
|
||||
<DataPipelineItem
|
||||
key={item.id}
|
||||
openLinkModalFunc={openLinkModalFunc}
|
||||
id={item.id}
|
||||
name={item.name}
|
||||
avatar={item.avatar}
|
||||
isDefault={item.isDefault}
|
||||
linked={item.linked}
|
||||
/>
|
||||
),
|
||||
)}
|
||||
</section>
|
||||
<LinkDataPipelineModal
|
||||
data={currentDataPipeline}
|
||||
|
||||
@ -1,19 +1,14 @@
|
||||
import { DataFlowSelect } from '@/components/data-pipeline-select';
|
||||
import Input from '@/components/originui/input';
|
||||
import { Button } from '@/components/ui/button';
|
||||
import {
|
||||
Form,
|
||||
FormControl,
|
||||
FormField,
|
||||
FormItem,
|
||||
FormLabel,
|
||||
FormMessage,
|
||||
} from '@/components/ui/form';
|
||||
DataFlowSelect,
|
||||
IDataPipelineSelectNode,
|
||||
} from '@/components/data-pipeline-select';
|
||||
import { Button } from '@/components/ui/button';
|
||||
import { Form } from '@/components/ui/form';
|
||||
import { Modal } from '@/components/ui/modal/modal';
|
||||
import { Switch } from '@/components/ui/switch';
|
||||
import { useNavigatePage } from '@/hooks/logic-hooks/navigate-hooks';
|
||||
import { zodResolver } from '@hookform/resolvers/zod';
|
||||
import { t } from 'i18next';
|
||||
import { useState } from 'react';
|
||||
import { useForm } from 'react-hook-form';
|
||||
import { z } from 'zod';
|
||||
import { pipelineFormSchema } from '../form-schema';
|
||||
@ -28,13 +23,14 @@ const LinkDataPipelineModal = ({
|
||||
data: IDataPipelineNodeProps | undefined;
|
||||
open: boolean;
|
||||
setOpen: (open: boolean) => void;
|
||||
onSubmit?: (data: any) => void;
|
||||
onSubmit?: (pipeline: IDataPipelineSelectNode | undefined) => void;
|
||||
}) => {
|
||||
const isEdit = !!data;
|
||||
const [list, setList] = useState<IDataPipelineSelectNode[]>();
|
||||
const form = useForm<z.infer<typeof pipelineFormSchema>>({
|
||||
resolver: zodResolver(pipelineFormSchema),
|
||||
defaultValues: {
|
||||
data_flow: [],
|
||||
pipeline_id: '',
|
||||
set_default: false,
|
||||
file_filter: '',
|
||||
},
|
||||
@ -43,11 +39,12 @@ const LinkDataPipelineModal = ({
|
||||
const { navigateToAgents } = useNavigatePage();
|
||||
const handleFormSubmit = (values: any) => {
|
||||
console.log(values, data);
|
||||
const param = {
|
||||
...data,
|
||||
...values,
|
||||
};
|
||||
onSubmit?.(param);
|
||||
// const param = {
|
||||
// ...data,
|
||||
// ...values,
|
||||
// };
|
||||
const pipeline = list?.find((item) => item.id === values.pipeline_id);
|
||||
onSubmit?.(pipeline);
|
||||
};
|
||||
return (
|
||||
<Modal
|
||||
@ -67,10 +64,11 @@ const LinkDataPipelineModal = ({
|
||||
{!isEdit && (
|
||||
<DataFlowSelect
|
||||
toDataPipeline={navigateToAgents}
|
||||
formFieldName="data_flow"
|
||||
formFieldName="pipeline_id"
|
||||
setDataList={setList}
|
||||
/>
|
||||
)}
|
||||
<FormField
|
||||
{/* <FormField
|
||||
control={form.control}
|
||||
name={'file_filter'}
|
||||
render={({ field }) => (
|
||||
@ -135,7 +133,7 @@ const LinkDataPipelineModal = ({
|
||||
</FormItem>
|
||||
)}
|
||||
/>
|
||||
)}
|
||||
)} */}
|
||||
<div className="flex justify-end gap-1">
|
||||
<Button
|
||||
type="button"
|
||||
|
||||
@ -11,6 +11,9 @@ export const formSchema = z.object({
|
||||
avatar: z.any().nullish(),
|
||||
permission: z.string().optional(),
|
||||
parser_id: z.string(),
|
||||
pipeline_id: z.string().optional(),
|
||||
pipeline_name: z.string().optional(),
|
||||
pipeline_avatar: z.string().optional(),
|
||||
embd_id: z.string(),
|
||||
parser_config: z
|
||||
.object({
|
||||
@ -73,16 +76,16 @@ export const formSchema = z.object({
|
||||
});
|
||||
|
||||
export const pipelineFormSchema = z.object({
|
||||
data_flow: z.array(z.string()).optional(),
|
||||
pipeline_id: z.string().optional(),
|
||||
set_default: z.boolean().optional(),
|
||||
file_filter: z.string().optional(),
|
||||
});
|
||||
|
||||
export const linkPiplineFormSchema = pipelineFormSchema.pick({
|
||||
data_flow: true,
|
||||
file_filter: true,
|
||||
});
|
||||
export const editPiplineFormSchema = pipelineFormSchema.pick({
|
||||
set_default: true,
|
||||
file_filter: true,
|
||||
});
|
||||
// export const linkPiplineFormSchema = pipelineFormSchema.pick({
|
||||
// pipeline_id: true,
|
||||
// file_filter: true,
|
||||
// });
|
||||
// export const editPiplineFormSchema = pipelineFormSchema.pick({
|
||||
// set_default: true,
|
||||
// file_filter: true,
|
||||
// });
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import { AvatarUpload } from '@/components/avatar-upload';
|
||||
import PageRankFormField from '@/components/page-rank-form-field';
|
||||
import {
|
||||
FormControl,
|
||||
FormField,
|
||||
@ -9,6 +10,7 @@ import {
|
||||
import { Input } from '@/components/ui/input';
|
||||
import { useFormContext } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { TagItems } from './components/tag-item';
|
||||
import { EmbeddingModelItem } from './configuration/common-item';
|
||||
import { PermissionFormField } from './permission-form-field';
|
||||
|
||||
@ -87,6 +89,9 @@ export function GeneralForm() {
|
||||
/>
|
||||
<PermissionFormField></PermissionFormField>
|
||||
<EmbeddingModelItem></EmbeddingModelItem>
|
||||
<PageRankFormField></PageRankFormField>
|
||||
|
||||
<TagItems></TagItems>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
@ -41,6 +41,16 @@ export const useFetchKnowledgeConfigurationOnMount = (
|
||||
const parser_config = {
|
||||
...form.formState?.defaultValues?.parser_config,
|
||||
...knowledgeDetails.parser_config,
|
||||
raptor: {
|
||||
...form.formState?.defaultValues?.parser_config?.raptor,
|
||||
...knowledgeDetails.parser_config?.raptor,
|
||||
use_raptor: true,
|
||||
},
|
||||
graphrag: {
|
||||
...form.formState?.defaultValues?.parser_config?.graphrag,
|
||||
...knowledgeDetails.parser_config?.graphrag,
|
||||
use_graphrag: true,
|
||||
},
|
||||
};
|
||||
const formValues = {
|
||||
...pick({ ...knowledgeDetails, parser_config: parser_config }, [
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import { IDataPipelineSelectNode } from '@/components/data-pipeline-select';
|
||||
import GraphRagItems from '@/components/parse-configuration/graph-rag-form-fields';
|
||||
import RaptorFormFields from '@/components/parse-configuration/raptor-form-fields';
|
||||
import { Button } from '@/components/ui/button';
|
||||
@ -6,11 +7,15 @@ import { Form } from '@/components/ui/form';
|
||||
import { DocumentParserType } from '@/constants/knowledge';
|
||||
import { PermissionRole } from '@/constants/permission';
|
||||
import { zodResolver } from '@hookform/resolvers/zod';
|
||||
import { useEffect, useState } from 'react';
|
||||
import { useForm } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { z } from 'zod';
|
||||
import { TopTitle } from '../dataset-title';
|
||||
import LinkDataPipeline from './components/link-data-pipeline';
|
||||
import { IGenerateLogButtonProps } from '../dataset/generate-button/generate';
|
||||
import LinkDataPipeline, {
|
||||
IDataPipelineNodeProps,
|
||||
} from './components/link-data-pipeline';
|
||||
import { MainContainer } from './configuration-form-container';
|
||||
import { formSchema } from './form-schema';
|
||||
import { GeneralForm } from './general-form';
|
||||
@ -51,24 +56,70 @@ export default function DatasetSettings() {
|
||||
html4excel: false,
|
||||
topn_tags: 3,
|
||||
raptor: {
|
||||
use_raptor: false,
|
||||
use_raptor: true,
|
||||
max_token: 256,
|
||||
threshold: 0.1,
|
||||
max_cluster: 64,
|
||||
random_seed: 0,
|
||||
prompt: t('knowledgeConfiguration.promptText'),
|
||||
},
|
||||
graphrag: {
|
||||
use_graphrag: false,
|
||||
use_graphrag: true,
|
||||
entity_types: initialEntityTypes,
|
||||
method: MethodValue.Light,
|
||||
},
|
||||
},
|
||||
pipeline_id: '',
|
||||
pagerank: 0,
|
||||
},
|
||||
});
|
||||
|
||||
useFetchKnowledgeConfigurationOnMount(form);
|
||||
const knowledgeDetails = useFetchKnowledgeConfigurationOnMount(form);
|
||||
|
||||
const [pipelineData, setPipelineData] = useState<IDataPipelineNodeProps>();
|
||||
const [graphRagGenerateData, setGraphRagGenerateData] =
|
||||
useState<IGenerateLogButtonProps>();
|
||||
const [raptorGenerateData, setRaptorGenerateData] =
|
||||
useState<IGenerateLogButtonProps>();
|
||||
useEffect(() => {
|
||||
console.log('🚀 ~ DatasetSettings ~ knowledgeDetails:', knowledgeDetails);
|
||||
if (knowledgeDetails) {
|
||||
const data: IDataPipelineNodeProps = {
|
||||
id: knowledgeDetails.pipeline_id,
|
||||
name: knowledgeDetails.pipeline_name,
|
||||
avatar: knowledgeDetails.pipeline_avatar,
|
||||
linked: true,
|
||||
};
|
||||
setPipelineData(data);
|
||||
setGraphRagGenerateData({
|
||||
finish_at: knowledgeDetails.mindmap_task_finish_at,
|
||||
task_id: knowledgeDetails.mindmap_task_id,
|
||||
} as IGenerateLogButtonProps);
|
||||
setRaptorGenerateData({
|
||||
finish_at: knowledgeDetails.raptor_task_finish_at,
|
||||
task_id: knowledgeDetails.raptor_task_id,
|
||||
} as IGenerateLogButtonProps);
|
||||
}
|
||||
}, [knowledgeDetails]);
|
||||
|
||||
async function onSubmit(data: z.infer<typeof formSchema>) {
|
||||
console.log('🚀 ~ DatasetSettings ~ data:', data);
|
||||
try {
|
||||
console.log('Form validation passed, submit data', data);
|
||||
} catch (error) {
|
||||
console.error('An error occurred during submission:', error);
|
||||
}
|
||||
}
|
||||
|
||||
const handleLinkOrEditSubmit = (
|
||||
data: IDataPipelineSelectNode | undefined,
|
||||
) => {
|
||||
console.log('🚀 ~ DatasetSettings ~ data:', data);
|
||||
if (data) {
|
||||
setPipelineData(data);
|
||||
form.setValue('pipeline_id', data.id || '');
|
||||
// form.setValue('pipeline_name', data.name || '');
|
||||
// form.setValue('pipeline_avatar', data.avatar || '');
|
||||
}
|
||||
};
|
||||
return (
|
||||
<section className="p-5 h-full flex flex-col">
|
||||
<TopTitle
|
||||
@ -88,12 +139,17 @@ export default function DatasetSettings() {
|
||||
|
||||
<GraphRagItems
|
||||
className="border-none p-0"
|
||||
showGenerateItem={true}
|
||||
data={graphRagGenerateData as IGenerateLogButtonProps}
|
||||
></GraphRagItems>
|
||||
<Divider />
|
||||
<RaptorFormFields showGenerateItem={true}></RaptorFormFields>
|
||||
<RaptorFormFields
|
||||
data={raptorGenerateData as IGenerateLogButtonProps}
|
||||
></RaptorFormFields>
|
||||
<Divider />
|
||||
<LinkDataPipeline />
|
||||
<LinkDataPipeline
|
||||
data={pipelineData}
|
||||
handleLinkOrEditSubmit={handleLinkOrEditSubmit}
|
||||
/>
|
||||
</MainContainer>
|
||||
</div>
|
||||
<div className="text-right items-center flex justify-end gap-3 w-[768px]">
|
||||
|
||||
@ -62,7 +62,7 @@ export function SavingButton() {
|
||||
if (beValid) {
|
||||
form.handleSubmit(async (values) => {
|
||||
console.log('saveKnowledgeConfiguration: ', values);
|
||||
delete values['avatar'];
|
||||
// delete values['avatar'];
|
||||
await saveKnowledgeConfiguration({
|
||||
kb_id,
|
||||
...values,
|
||||
|
||||
@ -9,6 +9,7 @@ import {
|
||||
import { Modal } from '@/components/ui/modal/modal';
|
||||
import { cn } from '@/lib/utils';
|
||||
import { toFixed } from '@/utils/common-util';
|
||||
import { formatDate } from '@/utils/date';
|
||||
import { UseMutateAsyncFunction } from '@tanstack/react-query';
|
||||
import { t } from 'i18next';
|
||||
import { lowerFirst } from 'lodash';
|
||||
@ -29,7 +30,13 @@ export enum GenerateType {
|
||||
const MenuItem: React.FC<{
|
||||
name: GenerateType;
|
||||
data: ITraceInfo;
|
||||
pauseGenerate: () => void;
|
||||
pauseGenerate: ({
|
||||
task_id,
|
||||
type,
|
||||
}: {
|
||||
task_id: string;
|
||||
type: GenerateType;
|
||||
}) => void;
|
||||
runGenerate: UseMutateAsyncFunction<
|
||||
any,
|
||||
Error,
|
||||
@ -38,13 +45,12 @@ const MenuItem: React.FC<{
|
||||
},
|
||||
unknown
|
||||
>;
|
||||
}> = ({ name, runGenerate, data, pauseGenerate }) => {
|
||||
console.log(name, 'pppp', data);
|
||||
}> = ({ name: type, runGenerate, data, pauseGenerate }) => {
|
||||
const iconKeyMap = {
|
||||
KnowledgeGraph: 'knowledgegraph',
|
||||
Raptor: 'dataflow-01',
|
||||
};
|
||||
const type = useMemo(() => {
|
||||
const status = useMemo(() => {
|
||||
if (!data) {
|
||||
return generateStatus.start;
|
||||
}
|
||||
@ -60,9 +66,9 @@ const MenuItem: React.FC<{
|
||||
}, [data]);
|
||||
|
||||
const percent =
|
||||
type === generateStatus.failed
|
||||
status === generateStatus.failed
|
||||
? 100
|
||||
: type === generateStatus.running
|
||||
: status === generateStatus.running
|
||||
? data.progress * 100
|
||||
: 0;
|
||||
|
||||
@ -72,9 +78,9 @@ const MenuItem: React.FC<{
|
||||
'border cursor-pointer p-2 rounded-md focus:bg-transparent',
|
||||
{
|
||||
'hover:border-accent-primary hover:bg-[rgba(59,160,92,0.1)]':
|
||||
type === generateStatus.start,
|
||||
status === generateStatus.start,
|
||||
'hover:border-border hover:bg-[rgba(59,160,92,0)]':
|
||||
type !== generateStatus.start,
|
||||
status !== generateStatus.start,
|
||||
},
|
||||
)}
|
||||
onSelect={(e) => {
|
||||
@ -87,56 +93,65 @@ const MenuItem: React.FC<{
|
||||
<div
|
||||
className="flex items-start gap-2 flex-col w-full"
|
||||
onClick={() => {
|
||||
if (type === generateStatus.start) {
|
||||
runGenerate({ type: name });
|
||||
if (status === generateStatus.start) {
|
||||
runGenerate({ type });
|
||||
}
|
||||
}}
|
||||
>
|
||||
<div className="flex justify-start text-text-primary items-center gap-2">
|
||||
<IconFontFill
|
||||
name={iconKeyMap[name]}
|
||||
name={iconKeyMap[type]}
|
||||
className="text-accent-primary"
|
||||
/>
|
||||
{t(`knowledgeDetails.${lowerFirst(name)}`)}
|
||||
{t(`knowledgeDetails.${lowerFirst(type)}`)}
|
||||
</div>
|
||||
{type === generateStatus.start && (
|
||||
{status === generateStatus.start && (
|
||||
<div className="text-text-secondary text-sm">
|
||||
{t(`knowledgeDetails.generate${name}`)}
|
||||
{t(`knowledgeDetails.generate${type}`)}
|
||||
</div>
|
||||
)}
|
||||
{(type === generateStatus.running ||
|
||||
type === generateStatus.failed) && (
|
||||
{(status === generateStatus.running ||
|
||||
status === generateStatus.failed) && (
|
||||
<div className="flex justify-between items-center w-full px-2.5 py-1">
|
||||
<div
|
||||
className={cn(' bg-border-button h-1 rounded-full', {
|
||||
'w-[calc(100%-100px)]': type === generateStatus.running,
|
||||
'w-[calc(100%-50px)]': type === generateStatus.failed,
|
||||
'w-[calc(100%-100px)]': status === generateStatus.running,
|
||||
'w-[calc(100%-50px)]': status === generateStatus.failed,
|
||||
})}
|
||||
>
|
||||
<div
|
||||
className={cn('h-1 rounded-full', {
|
||||
'bg-state-error': type === generateStatus.failed,
|
||||
'bg-accent-primary': type === generateStatus.running,
|
||||
'bg-state-error': status === generateStatus.failed,
|
||||
'bg-accent-primary': status === generateStatus.running,
|
||||
})}
|
||||
style={{ width: `${toFixed(percent)}%` }}
|
||||
></div>
|
||||
</div>
|
||||
{type === generateStatus.running && (
|
||||
{status === generateStatus.running && (
|
||||
<span>{(toFixed(percent) as string) + '%'}</span>
|
||||
)}
|
||||
<span
|
||||
className="text-state-error"
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
pauseGenerate();
|
||||
}}
|
||||
>
|
||||
{type === generateStatus.failed ? (
|
||||
{status === generateStatus.failed && (
|
||||
<span
|
||||
className="text-state-error"
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
runGenerate({ type });
|
||||
}}
|
||||
>
|
||||
<IconFontFill name="reparse" className="text-accent-primary" />
|
||||
) : (
|
||||
</span>
|
||||
)}
|
||||
{status !== generateStatus.failed && (
|
||||
<span
|
||||
className="text-state-error"
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
pauseGenerate({ task_id: data.id, type });
|
||||
}}
|
||||
>
|
||||
<CirclePause />
|
||||
)}
|
||||
</span>
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
<div className="w-full whitespace-pre-line text-wrap rounded-lg h-fit max-h-[350px] overflow-y-auto scrollbar-auto px-2.5 py-1">
|
||||
@ -202,7 +217,12 @@ const Generate: React.FC = () => {
|
||||
|
||||
export default Generate;
|
||||
|
||||
export type IGenerateLogProps = {
|
||||
export type IGenerateLogButtonProps = {
|
||||
finish_at: string;
|
||||
task_id: string;
|
||||
};
|
||||
|
||||
export type IGenerateLogProps = IGenerateLogButtonProps & {
|
||||
id?: string;
|
||||
status: 0 | 1;
|
||||
message?: string;
|
||||
@ -214,16 +234,7 @@ export type IGenerateLogProps = {
|
||||
};
|
||||
export const GenerateLogButton = (props: IGenerateLogProps) => {
|
||||
const { t } = useTranslation();
|
||||
const {
|
||||
id,
|
||||
status,
|
||||
message,
|
||||
created_at,
|
||||
updated_at,
|
||||
type,
|
||||
className,
|
||||
onDelete,
|
||||
} = props;
|
||||
const { task_id, message, finish_at, type, onDelete } = props;
|
||||
const handleDelete = () => {
|
||||
Modal.show({
|
||||
visible: true,
|
||||
@ -278,11 +289,11 @@ export const GenerateLogButton = (props: IGenerateLogProps) => {
|
||||
className={cn('flex bg-bg-card rounded-md py-1 px-3', props.className)}
|
||||
>
|
||||
<div className="flex items-center justify-between w-full">
|
||||
{status === 1 && (
|
||||
{finish_at && (
|
||||
<>
|
||||
<div>
|
||||
{message || t('knowledgeDetails.generatedOn')}
|
||||
{created_at}
|
||||
{formatDate(finish_at)}
|
||||
</div>
|
||||
<Trash2
|
||||
size={14}
|
||||
@ -295,7 +306,7 @@ export const GenerateLogButton = (props: IGenerateLogProps) => {
|
||||
/>
|
||||
</>
|
||||
)}
|
||||
{status === 0 && <div>{t('knowledgeDetails.notGenerated')}</div>}
|
||||
{!finish_at && <div>{t('knowledgeDetails.notGenerated')}</div>}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
|
||||
@ -1,8 +1,9 @@
|
||||
import message from '@/components/ui/message';
|
||||
import agentService from '@/services/agent-service';
|
||||
import kbService from '@/services/knowledge-service';
|
||||
import { useMutation, useQuery, useQueryClient } from '@tanstack/react-query';
|
||||
import { t } from 'i18next';
|
||||
import { useCallback, useEffect, useState } from 'react';
|
||||
import { useEffect, useState } from 'react';
|
||||
import { useParams } from 'umi';
|
||||
import { GenerateType } from './generate';
|
||||
export const generateStatus = {
|
||||
@ -14,6 +15,7 @@ export const generateStatus = {
|
||||
|
||||
enum DatasetKey {
|
||||
generate = 'generate',
|
||||
pauseGenerate = 'pauseGenerate',
|
||||
}
|
||||
|
||||
export interface ITraceInfo {
|
||||
@ -126,9 +128,28 @@ export const useDatasetGenerate = () => {
|
||||
return data;
|
||||
},
|
||||
});
|
||||
const pauseGenerate = useCallback(() => {
|
||||
// TODO: pause generate
|
||||
console.log('pause generate');
|
||||
}, []);
|
||||
// const pauseGenerate = useCallback(() => {
|
||||
// // TODO: pause generate
|
||||
// console.log('pause generate');
|
||||
// }, []);
|
||||
const { mutateAsync: pauseGenerate } = useMutation({
|
||||
mutationKey: [DatasetKey.pauseGenerate],
|
||||
mutationFn: async ({
|
||||
task_id,
|
||||
type,
|
||||
}: {
|
||||
task_id: string;
|
||||
type: GenerateType;
|
||||
}) => {
|
||||
const { data } = await agentService.cancelDataflow(task_id);
|
||||
if (data.code === 0) {
|
||||
message.success(t('message.operated'));
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: [type],
|
||||
});
|
||||
}
|
||||
return data;
|
||||
},
|
||||
});
|
||||
return { runGenerate: mutateAsync, pauseGenerate, data, loading };
|
||||
};
|
||||
|
||||
@ -23,7 +23,9 @@ const IconMap = {
|
||||
[RunningStatus.UNSTART]: (
|
||||
<div className="w-0 h-0 border-l-[10px] border-l-accent-primary border-t-8 border-r-4 border-b-8 border-transparent"></div>
|
||||
),
|
||||
[RunningStatus.RUNNING]: <CircleX size={14} color="var(--state-error)" />,
|
||||
[RunningStatus.RUNNING]: (
|
||||
<CircleX size={14} color="rgba(var(--state-error))" />
|
||||
),
|
||||
[RunningStatus.CANCEL]: (
|
||||
<IconFontFill name="reparse" className="text-accent-primary" />
|
||||
),
|
||||
|
||||
@ -76,7 +76,10 @@ module.exports = {
|
||||
'border-default': 'var(--border-default)',
|
||||
'border-accent': 'var(--border-accent)',
|
||||
'border-button': 'var(--border-button)',
|
||||
'accent-primary': 'var(--accent-primary)',
|
||||
'accent-primary': {
|
||||
DEFAULT: 'rgb(var(--accent-primary) / <alpha-value>)',
|
||||
5: 'rgba(var(--accent-primary) / 0.05)', // 5%
|
||||
},
|
||||
'bg-accent': 'var(--bg-accent)',
|
||||
'state-success': 'var(--state-success)',
|
||||
'state-warning': 'var(--state-warning)',
|
||||
|
||||
@ -112,7 +112,7 @@
|
||||
--border-accent: #000000;
|
||||
--border-button: rgba(0, 0, 0, 0.1);
|
||||
/* Regulators, parsing, switches, variables */
|
||||
--accent-primary: #00beb4;
|
||||
--accent-primary: 0 190 180;
|
||||
/* Output Variables Box */
|
||||
--bg-accent: rgba(76, 164, 231, 0.05);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user