Fix: use async task to save memory (#12308)

### What problem does this PR solve?

Use an async task to save messages to memory: tasks whose type is `memory` are now dispatched to `handle_save_to_memory_task`.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

---------

Co-authored-by: Jin Hai <haijin.chn@gmail.com>
This commit is contained in:
Lynn
2025-12-30 11:41:38 +08:00
committed by GitHub
parent 731e2d5f26
commit 4a6d37f0e8
4 changed files with 196 additions and 39 deletions

View File

@ -26,6 +26,7 @@ import time
from api.db import PIPELINE_SPECIAL_PROGRESS_FREEZE_TASK_TYPES
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
from api.db.joint_services.memory_message_service import handle_save_to_memory_task
from common.connection_utils import timeout
from common.metadata_utils import update_metadata_to, metadata_schema
from rag.utils.base64_image import image2id
@ -96,6 +97,7 @@ TASK_TYPE_TO_PIPELINE_TASK_TYPE = {
"raptor": PipelineTaskType.RAPTOR,
"graphrag": PipelineTaskType.GRAPH_RAG,
"mindmap": PipelineTaskType.MINDMAP,
"memory": PipelineTaskType.MEMORY,
}
UNACKED_ITERATOR = None
@ -197,6 +199,9 @@ async def collect():
if task:
task["doc_id"] = msg["doc_id"]
task["doc_ids"] = msg.get("doc_ids", []) or []
elif msg.get("task_type") == PipelineTaskType.MEMORY.lower():
_, task_obj = TaskService.get_by_id(msg["id"])
task = task_obj.to_dict()
else:
task = TaskService.get_task(msg["id"])
@ -215,6 +220,10 @@ async def collect():
task["tenant_id"] = msg["tenant_id"]
task["dataflow_id"] = msg["dataflow_id"]
task["kb_id"] = msg.get("kb_id", "")
if task_type[:6] == "memory":
task["memory_id"] = msg["memory_id"]
task["source_id"] = msg["source_id"]
task["message_dict"] = msg["message_dict"]
return redis_msg, task
@ -866,6 +875,10 @@ async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_c
async def do_handle_task(task):
task_type = task.get("task_type", "")
if task_type == "memory":
await handle_save_to_memory_task(task)
return
if task_type == "dataflow" and task.get("doc_id", "") == CANVAS_DEBUG_DOC_ID:
await run_dataflow(task)
return