From 4c0a89f262c2ce4d5e23f8580b52895928cc250e Mon Sep 17 00:00:00 2001 From: Yongteng Lei Date: Fri, 26 Sep 2025 19:45:01 +0800 Subject: [PATCH] Feat: add initial support for Mindmap (#10310) ### What problem does this PR solve? Add initial support for Mindmap. ### Type of change - [x] New Feature (non-breaking change which adds functionality) --------- Co-authored-by: Kevin Hu --- api/apps/kb_app.py | 95 ++++++++++++++++--- api/db/__init__.py | 3 +- api/db/db_models.py | 20 ++++ api/db/services/document_service.py | 24 ++--- api/db/services/knowledgebase_service.py | 6 ++ .../pipeline_operation_log_service.py | 48 +++++++--- api/db/services/task_service.py | 40 +++++--- rag/svr/task_executor.py | 9 +- 8 files changed, 182 insertions(+), 63 deletions(-) diff --git a/api/apps/kb_app.py b/api/apps/kb_app.py index 10d0a6a7b..be7c819ac 100644 --- a/api/apps/kb_app.py +++ b/api/apps/kb_app.py @@ -522,12 +522,13 @@ def run_graphrag(): return get_error_data_result(message="Invalid Knowledgebase ID") task_id = kb.graphrag_task_id - ok, task = TaskService.get_by_id(task_id) - if not ok: - logging.warning(f"A valid GraphRAG task id is expected for kb {kb_id}") + if task_id: + ok, task = TaskService.get_by_id(task_id) + if not ok: + logging.warning(f"A valid GraphRAG task id is expected for kb {kb_id}") - if task and task.progress not in [-1, 1]: - return get_error_data_result(message=f"Task {task_id} in progress with status {task.progress}. A Graph Task is already running.") + if task and task.progress not in [-1, 1]: + return get_error_data_result(message=f"Task {task_id} in progress with status {task.progress}. A Graph Task is already running.") documents, _ = DocumentService.get_by_kb_id( kb_id=kb_id, @@ -567,7 +568,7 @@ def trace_graphrag(): task_id = kb.graphrag_task_id if not task_id: - return get_error_data_result(message="GraphRAG Task ID Not Found") + return get_json_result(data={}) ok, task = TaskService.get_by_id(task_id) if not ok: @@ -590,12 +591,13 @@ def run_raptor(): return get_error_data_result(message="Invalid Knowledgebase ID") task_id = kb.raptor_task_id - ok, task = TaskService.get_by_id(task_id) - if not ok: - logging.warning(f"A valid RAPTOR task id is expected for kb {kb_id}") + if task_id: + ok, task = TaskService.get_by_id(task_id) + if not ok: + logging.warning(f"A valid RAPTOR task id is expected for kb {kb_id}") - if task and task.progress not in [-1, 1]: - return get_error_data_result(message=f"Task {task_id} in progress with status {task.progress}. A RAPTOR Task is already running.") + if task and task.progress not in [-1, 1]: + return get_error_data_result(message=f"Task {task_id} in progress with status {task.progress}. 
A RAPTOR Task is already running.") documents, _ = DocumentService.get_by_kb_id( kb_id=kb_id, @@ -635,10 +637,79 @@ def trace_raptor(): task_id = kb.raptor_task_id if not task_id: - return get_error_data_result(message="RAPTOR Task ID Not Found") + return get_json_result(data={}) ok, task = TaskService.get_by_id(task_id) if not ok: return get_error_data_result(message="RAPTOR Task Not Found or Error Occurred") return get_json_result(data=task.to_dict()) + + +@manager.route("/run_mindmap", methods=["POST"]) # noqa: F821 +@login_required +def run_mindmap(): + req = request.json + + kb_id = req.get("kb_id", "") + if not kb_id: + return get_error_data_result(message='Lack of "KB ID"') + + ok, kb = KnowledgebaseService.get_by_id(kb_id) + if not ok: + return get_error_data_result(message="Invalid Knowledgebase ID") + + task_id = kb.mindmap_task_id + if task_id: + ok, task = TaskService.get_by_id(task_id) + if not ok: + logging.warning(f"A valid Mindmap task id is expected for kb {kb_id}") + + if task and task.progress not in [-1, 1]: + return get_error_data_result(message=f"Task {task_id} in progress with status {task.progress}. A Mindmap Task is already running.") + + documents, _ = DocumentService.get_by_kb_id( + kb_id=kb_id, + page_number=0, + items_per_page=0, + orderby="create_time", + desc=False, + keywords="", + run_status=[], + types=[], + suffix=[], + ) + if not documents: + return get_error_data_result(message=f"No documents in Knowledgebase {kb_id}") + + sample_document = documents[0] + document_ids = [document["id"] for document in documents] + + task_id = queue_raptor_o_graphrag_tasks(doc=sample_document, ty="mindmap", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids)) + + if not KnowledgebaseService.update_by_id(kb.id, {"mindmap_task_id": task_id}): + logging.warning(f"Cannot save mindmap_task_id for kb {kb_id}") + + return get_json_result(data={"mindmap_task_id": task_id}) + + +@manager.route("/trace_mindmap", methods=["GET"]) # noqa: F821 +@login_required +def trace_mindmap(): + kb_id = request.args.get("kb_id", "") + if not kb_id: + return get_error_data_result(message='Lack of "KB ID"') + + ok, kb = KnowledgebaseService.get_by_id(kb_id) + if not ok: + return get_error_data_result(message="Invalid Knowledgebase ID") + + task_id = kb.mindmap_task_id + if not task_id: + return get_json_result(data={}) + + ok, task = TaskService.get_by_id(task_id) + if not ok: + return get_error_data_result(message="Mindmap Task Not Found or Error Occurred") + + return get_json_result(data=task.to_dict()) diff --git a/api/db/__init__.py b/api/db/__init__.py index 8f2806419..c93932db8 100644 --- a/api/db/__init__.py +++ b/api/db/__init__.py @@ -127,9 +127,10 @@ class PipelineTaskType(StrEnum): DOWNLOAD = "Download" RAPTOR = "RAPTOR" GRAPH_RAG = "GraphRAG" + MINDMAP = "Mindmap" -VALID_PIPELINE_TASK_TYPES = {PipelineTaskType.PARSE, PipelineTaskType.DOWNLOAD, PipelineTaskType.RAPTOR, PipelineTaskType.GRAPH_RAG} +VALID_PIPELINE_TASK_TYPES = {PipelineTaskType.PARSE, PipelineTaskType.DOWNLOAD, PipelineTaskType.RAPTOR, PipelineTaskType.GRAPH_RAG, PipelineTaskType.MINDMAP} KNOWLEDGEBASE_FOLDER_NAME=".knowledgebase" diff --git a/api/db/db_models.py b/api/db/db_models.py index b63e9c900..3d88c8b88 100644 --- a/api/db/db_models.py +++ b/api/db/db_models.py @@ -651,7 +651,11 @@ class Knowledgebase(DataBaseModel): pagerank = IntegerField(default=0, index=False) graphrag_task_id = CharField(max_length=32, null=True, help_text="Graph RAG task ID", index=True) + graphrag_task_finish_at = 
DateTimeField(null=True) raptor_task_id = CharField(max_length=32, null=True, help_text="RAPTOR task ID", index=True) + raptor_task_finish_at = DateTimeField(null=True) + mindmap_task_id = CharField(max_length=32, null=True, help_text="Mindmap task ID", index=True) + mindmap_task_finish_at = DateTimeField(null=True) status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True) @@ -1084,4 +1088,20 @@ def migrate_db(): migrate(migrator.add_column("knowledgebase", "raptor_task_id", CharField(max_length=32, null=True, help_text="RAPTOR task ID", index=True))) except Exception: pass + try: + migrate(migrator.add_column("knowledgebase", "graphrag_task_finish_at", DateTimeField(null=True))) + except Exception: + pass + try: + migrate(migrator.add_column("knowledgebase", "raptor_task_finish_at", DateTimeField(null=True))) + except Exception: + pass + try: + migrate(migrator.add_column("knowledgebase", "mindmap_task_id", CharField(max_length=32, null=True, help_text="Mindmap task ID", index=True))) + except Exception: + pass + try: + migrate(migrator.add_column("knowledgebase", "mindmap_task_finish_at", DateTimeField(null=True))) + except Exception: + pass logging.disable(logging.NOTSET) diff --git a/api/db/services/document_service.py b/api/db/services/document_service.py index d28110799..cdbe48c39 100644 --- a/api/db/services/document_service.py +++ b/api/db/services/document_service.py @@ -636,8 +636,6 @@ class DocumentService(CommonService): prg = 0 finished = True bad = 0 - has_raptor = False - has_graphrag = False e, doc = DocumentService.get_by_id(d["id"]) status = doc.run # TaskStatus.RUNNING.value priority = 0 @@ -649,25 +647,14 @@ class DocumentService(CommonService): prg += t.progress if t.progress >= 0 else 0 if t.progress_msg.strip(): msg.append(t.progress_msg) - if t.task_type == "raptor": - has_raptor = True - elif t.task_type == "graphrag": - has_graphrag = True priority = max(priority, t.priority) prg /= len(tsks) if finished and bad: prg = -1 status = TaskStatus.FAIL.value elif finished: - if (d["parser_config"].get("raptor") or {}).get("use_raptor") and not has_raptor: - queue_raptor_o_graphrag_tasks(d, "raptor", priority) - prg = 0.98 * len(tsks) / (len(tsks) + 1) - elif (d["parser_config"].get("graphrag") or {}).get("use_graphrag") and not has_graphrag: - queue_raptor_o_graphrag_tasks(d, "graphrag", priority) - prg = 0.98 * len(tsks) / (len(tsks) + 1) - else: - prg = 1 - status = TaskStatus.DONE.value + prg = 1 + status = TaskStatus.DONE.value msg = "\n".join(sorted(msg)) info = { @@ -679,7 +666,7 @@ class DocumentService(CommonService): info["progress"] = prg if msg: info["progress_msg"] = msg - if msg.endswith("created task graphrag") or msg.endswith("created task raptor"): + if msg.endswith("created task graphrag") or msg.endswith("created task raptor") or msg.endswith("created task mindmap"): info["progress_msg"] += "\n%d tasks are ahead in the queue..."%get_queue_length(priority) else: info["progress_msg"] = "%d tasks are ahead in the queue..."%get_queue_length(priority) @@ -770,7 +757,8 @@ def queue_raptor_o_graphrag_tasks(doc, ty, priority, fake_doc_id="", doc_ids=[]) "from_page": 100000000, "to_page": 100000000, "task_type": ty, - "progress_msg": datetime.now().strftime("%H:%M:%S") + " created task " + ty + "progress_msg": datetime.now().strftime("%H:%M:%S") + " created task " + ty, + "begin_at": datetime.now(), } task = new_task() @@ -780,7 +768,7 @@ def queue_raptor_o_graphrag_tasks(doc, ty, priority, fake_doc_id="",
doc_ids=[]) task["digest"] = hasher.hexdigest() bulk_insert_into_db(Task, [task], True) - if ty in ["graphrag", "raptor"]: + if ty in ["graphrag", "raptor", "mindmap"]: task["doc_ids"] = doc_ids DocumentService.begin2parse(doc["id"]) assert REDIS_CONN.queue_product(get_svr_queue_name(priority), message=task), "Can't access Redis. Please check the Redis' status." diff --git a/api/db/services/knowledgebase_service.py b/api/db/services/knowledgebase_service.py index d858265b9..f80dca04f 100644 --- a/api/db/services/knowledgebase_service.py +++ b/api/db/services/knowledgebase_service.py @@ -230,6 +230,12 @@ class KnowledgebaseService(CommonService): UserCanvas.avatar.alias("pipeline_avatar"), cls.model.parser_config, cls.model.pagerank, + cls.model.graphrag_task_id, + cls.model.graphrag_task_finish_at, + cls.model.raptor_task_id, + cls.model.raptor_task_finish_at, + cls.model.mindmap_task_id, + cls.model.mindmap_task_finish_at, cls.model.create_time, cls.model.update_time ] diff --git a/api/db/services/pipeline_operation_log_service.py b/api/db/services/pipeline_operation_log_service.py index f69ec4288..df95d2517 100644 --- a/api/db/services/pipeline_operation_log_service.py +++ b/api/db/services/pipeline_operation_log_service.py @@ -15,12 +15,12 @@ # import json import logging -from datetime import datetime +from datetime import datetime, timedelta from peewee import fn -from api.db import VALID_PIPELINE_TASK_TYPES -from api.db.db_models import DB, PipelineOperationLog, Document +from api.db import VALID_PIPELINE_TASK_TYPES, PipelineTaskType +from api.db.db_models import DB, Document, PipelineOperationLog from api.db.services.canvas_service import UserCanvasService from api.db.services.common_service import CommonService from api.db.services.document_service import DocumentService @@ -120,6 +120,24 @@ class PipelineOperationLogService(CommonService): if task_type not in VALID_PIPELINE_TASK_TYPES: raise ValueError(f"Invalid task type: {task_type}") + if task_type in [PipelineTaskType.GRAPH_RAG, PipelineTaskType.RAPTOR, PipelineTaskType.MINDMAP]: + finish_at = document.process_begin_at + timedelta(seconds=document.process_duration) + if task_type == PipelineTaskType.GRAPH_RAG: + KnowledgebaseService.update_by_id( + document.kb_id, + {"graphrag_task_finish_at": finish_at}, + ) + elif task_type == PipelineTaskType.RAPTOR: + KnowledgebaseService.update_by_id( + document.kb_id, + {"raptor_task_finish_at": finish_at}, + ) + elif task_type == PipelineTaskType.MINDMAP: + KnowledgebaseService.update_by_id( + document.kb_id, + {"mindmap_task_finish_at": finish_at}, + ) + log = dict( id=get_uuid(), document_id=document_id, # GRAPH_RAPTOR_FAKE_DOC_ID or real document_id @@ -189,17 +207,18 @@ class PipelineOperationLogService(CommonService): @classmethod @DB.connection_context() def get_documents_info(cls, id): - fields = [ - Document.id, - Document.name, - Document.progress - ] - return cls.model.select(*fields).join(Document, on=(cls.model.document_id == Document.id)).where( - cls.model.id == id, - Document.progress > 0, - Document.progress < 1 - ).dicts() - + fields = [Document.id, Document.name, Document.progress] + return ( + cls.model.select(*fields) + .join(Document, on=(cls.model.document_id == Document.id)) + .where( + cls.model.id == id, + Document.progress > 0, + Document.progress < 1, + ) + .dicts() + ) + @classmethod @DB.connection_context() def get_dataset_logs_by_kb_id(cls, kb_id, page_number, items_per_page, orderby, desc, operation_status, create_date_from=None, create_date_to=None): @@ 
-223,4 +242,3 @@ class PipelineOperationLogService(CommonService): logs = logs.paginate(page_number, items_per_page) return list(logs.dicts()), count - diff --git a/api/db/services/task_service.py b/api/db/services/task_service.py index 3005e14de..077970fa7 100644 --- a/api/db/services/task_service.py +++ b/api/db/services/task_service.py @@ -298,21 +298,23 @@ class TaskService(CommonService): ((prog == -1) | (prog > cls.model.progress)) ) ).execute() - return + else: + with DB.lock("update_progress", -1): + if info["progress_msg"]: + progress_msg = trim_header_by_lines(task.progress_msg + "\n" + info["progress_msg"], 3000) + cls.model.update(progress_msg=progress_msg).where(cls.model.id == id).execute() + if "progress" in info: + prog = info["progress"] + cls.model.update(progress=prog).where( + (cls.model.id == id) & + ( + (cls.model.progress != -1) & + ((prog == -1) | (prog > cls.model.progress)) + ) + ).execute() - with DB.lock("update_progress", -1): - if info["progress_msg"]: - progress_msg = trim_header_by_lines(task.progress_msg + "\n" + info["progress_msg"], 3000) - cls.model.update(progress_msg=progress_msg).where(cls.model.id == id).execute() - if "progress" in info: - prog = info["progress"] - cls.model.update(progress=prog).where( - (cls.model.id == id) & - ( - (cls.model.progress != -1) & - ((prog == -1) | (prog > cls.model.progress)) - ) - ).execute() + process_duration = (datetime.now() - task.begin_at).total_seconds() + cls.model.update(process_duration=process_duration).where(cls.model.id == id).execute() def queue_tasks(doc: dict, bucket: str, name: str, priority: int): @@ -336,7 +338,14 @@ def queue_tasks(doc: dict, bucket: str, name: str, priority: int): - Previous task chunks may be reused if available """ def new_task(): - return {"id": get_uuid(), "doc_id": doc["id"], "progress": 0.0, "from_page": 0, "to_page": 100000000} + return { + "id": get_uuid(), + "doc_id": doc["id"], + "progress": 0.0, + "from_page": 0, + "to_page": 100000000, + "begin_at": datetime.now(), + } parse_task_array = [] @@ -487,6 +496,7 @@ def queue_dataflow(tenant_id:str, flow_id:str, task_id:str, doc_id:str=CANVAS_DE to_page=100000000, task_type="dataflow" if not rerun else "dataflow_rerun", priority=priority, + begin_at=datetime.now(), ) if doc_id not in [CANVAS_DEBUG_DOC_ID, GRAPH_RAPTOR_FAKE_DOC_ID]: TaskService.model.delete().where(TaskService.model.doc_id == doc_id).execute() diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py index 3b78e08e1..f9b6f33a3 100644 --- a/rag/svr/task_executor.py +++ b/rag/svr/task_executor.py @@ -93,6 +93,7 @@ TASK_TYPE_TO_PIPELINE_TASK_TYPE = { "dataflow" : PipelineTaskType.PARSE, "raptor": PipelineTaskType.RAPTOR, "graphrag": PipelineTaskType.GRAPH_RAG, + "mindmap": PipelineTaskType.MINDMAP, } UNACKED_ITERATOR = None @@ -227,7 +228,7 @@ async def collect(): canceled = False if msg.get("doc_id", "") in [GRAPH_RAPTOR_FAKE_DOC_ID, CANVAS_DEBUG_DOC_ID]: task = msg - if task["task_type"] in ["graphrag", "raptor"] and msg.get("doc_ids", []): + if task["task_type"] in ["graphrag", "raptor", "mindmap"] and msg.get("doc_ids", []): task = TaskService.get_task(msg["id"], msg["doc_ids"]) task["doc_ids"] = msg["doc_ids"] else: @@ -822,6 +823,10 @@ async def do_handle_task(task): logging.info(f"GraphRAG task result for task {task}:\n{result}") progress_callback(prog=1.0, msg="Knowledge Graph done ({:.2f}s)".format(timer() - start_ts)) return + elif task_type == "mindmap": + progress_callback(1, "place holder") + pass + return else: # Standard chunking methods 
start_ts = timer() @@ -898,7 +903,7 @@ async def handle_task(): logging.exception(f"handle_task got exception for task {json.dumps(task)}") finally: task_document_ids = [] - if task_type in ["graphrag", "raptor"]: + if task_type in ["graphrag", "raptor", "mindmap"]: task_document_ids = task["doc_ids"] if not task.get("dataflow_id", ""): PipelineOperationLogService.record_pipeline_operation(document_id=task["doc_id"], pipeline_id="", task_type=pipeline_task_type, fake_document_ids=task_document_ids)
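
A few illustrative notes for reviewers follow; none of them are part of the diff itself. First, a rough client-side sketch of how the two new endpoints fit together. The `/run_mindmap` and `/trace_mindmap` routes, the `kb_id` parameter, the `mindmap_task_id` field, and the progress values `-1`/`1` come from the patch; the base URL, the authenticated `requests.Session`, and the helper name are assumptions made for the example.

```python
# Illustrative only: kick off a mindmap build for a knowledge base and poll it.
# Assumes a reachable RAGFlow instance and an already-authenticated requests.Session.
import time

import requests

BASE_URL = "http://localhost:9380/v1/kb"  # hypothetical deployment URL and route prefix
session = requests.Session()              # assumed to already carry a valid login session


def run_and_wait_for_mindmap(kb_id: str, poll_seconds: float = 5.0) -> dict:
    # POST /run_mindmap queues one KB-level "mindmap" task and stores its id on the KB.
    resp = session.post(f"{BASE_URL}/run_mindmap", json={"kb_id": kb_id})
    resp.raise_for_status()
    task_id = (resp.json().get("data") or {}).get("mindmap_task_id")
    print(f"queued mindmap task {task_id}")

    # GET /trace_mindmap returns the task row, or an empty dict if no task was ever queued.
    while True:
        trace = session.get(f"{BASE_URL}/trace_mindmap", params={"kb_id": kb_id})
        trace.raise_for_status()
        task = trace.json().get("data") or {}
        if not task:
            return task  # nothing recorded for this knowledge base yet
        # progress == 1 means done, -1 means failed; anything else is still running.
        if task.get("progress") in (-1, 1):
            return task
        time.sleep(poll_seconds)
```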
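The KB-level task that `run_mindmap` enqueues goes through `queue_raptor_o_graphrag_tasks`, so its payload on the Redis queue looks roughly like the dict below. Field names follow the parts of the function visible in the diff; the concrete values are illustrative, and anything not shown there (for example, exactly what `doc_id` the fake-document path ends up with) should be read as an assumption.

```python
# Approximate shape of the queued "mindmap" task message (illustrative values only).
mindmap_task_message = {
    "id": "<uuid generated for the new task>",
    "doc_id": "<GRAPH_RAPTOR_FAKE_DOC_ID>",      # assumption: the fake_doc_id passed from kb_app
    "from_page": 100000000,
    "to_page": 100000000,
    "task_type": "mindmap",
    "progress_msg": "19:45:01 created task mindmap",
    "begin_at": "<datetime.now() at queue time>",
    "digest": "<hash of the task payload>",
    "doc_ids": ["<id of every document in the knowledge base>"],  # added only to the queued copy
}
```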
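The new `begin_at` stamp and the derived finish timestamps tie together as follows: `TaskService.update_progress` now records `process_duration` as wall-clock seconds since the task's `begin_at`, and `PipelineOperationLogService.record_pipeline_operation` turns the document's `process_begin_at`/`process_duration` pair into the `*_task_finish_at` value stored on the knowledge base. A worked example with made-up timestamps:

```python
# Made-up timestamps, just to show the arithmetic the patch performs.
from datetime import datetime, timedelta

begin_at = datetime(2025, 9, 26, 19, 45, 1)           # stamped on the task when it is queued
now = datetime(2025, 9, 26, 19, 47, 31)                # when update_progress runs
process_duration = (now - begin_at).total_seconds()    # 150.0 seconds

# Later, when the pipeline operation log is recorded, the KB-level finish time is
# reconstructed from the document's begin/duration pair in the same way:
process_begin_at = datetime(2025, 9, 26, 19, 45, 1)
finish_at = process_begin_at + timedelta(seconds=process_duration)
# -> 2025-09-26 19:47:31, stored in mindmap_task_finish_at (or the raptor/graphrag variant)
```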
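Because `migrate_db` adds each new column inside its own `try`/`except`, re-running it against an existing database is safe. An optional sanity check after migration could look like the sketch below; it relies on peewee's `Database.get_columns`, and the `DB` handle import mirrors the one already used elsewhere in the codebase, so treat it as a verification aid rather than part of the patch.

```python
# Optional post-migration sanity check (a sketch, not part of the patch).
from api.db.db_models import DB

columns = {col.name for col in DB.get_columns("knowledgebase")}
expected = {
    "graphrag_task_finish_at",
    "raptor_task_finish_at",
    "mindmap_task_id",
    "mindmap_task_finish_at",
}
missing = expected - columns
assert not missing, f"knowledgebase is missing columns after migrate_db(): {missing}"
```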