From 4a6d37f0e888535f8fac68e6eef0c5d0f9044320 Mon Sep 17 00:00:00 2001
From: Lynn
Date: Tue, 30 Dec 2025 11:41:38 +0800
Subject: [PATCH] Fix: use async task to save memory (#12308)

### What problem does this PR solve?

Use an async task to save memory: the `Message` component now only stores the raw transcript and queues a save-to-memory task, and the task executor performs the LLM extraction and embedding in the background, so the agent's response is no longer blocked by the save.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
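In outline, the producer no longer awaits the slow extract-and-embed path; it enqueues a small task record and returns immediately, and a background worker drains the queue. A minimal self-contained sketch of that pattern using only the standard library (the names `slow_save` and `save_worker` are illustrative, not RAGFlow APIs; the real queue is Redis, not `asyncio.Queue`):

```python
import asyncio


async def slow_save(task: dict) -> None:
    # Stand-in for the real work: LLM extraction, embedding, index insert.
    await asyncio.sleep(1.0)
    print(f"saved message for memory {task['memory_id']}")


async def save_worker(queue: asyncio.Queue) -> None:
    # Mirrors the task executor: drain the queue and do the slow save.
    while True:
        task = await queue.get()
        await slow_save(task)
        queue.task_done()


async def main() -> None:
    queue: asyncio.Queue = asyncio.Queue()
    asyncio.create_task(save_worker(queue))
    # Producer side: enqueue and return immediately, as the Message
    # component now does via queue_save_to_memory_task.
    await queue.put({"memory_id": "mem_42", "user_input": "hi", "agent_response": "hello"})
    print("agent reply returned without waiting for the save")
    await queue.join()


asyncio.run(main())
```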
---------

Co-authored-by: Jin Hai
---
 agent/component/message.py                    |  17 +-
 .../joint_services/memory_message_service.py  | 204 +++++++++++++++---
 common/constants.py                           |   1 +
 rag/svr/task_executor.py                      |  13 ++
 4 files changed, 196 insertions(+), 39 deletions(-)
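Note for reviewers: per memory, the flow is (1) embed and store the RAW transcript synchronously, (2) insert a `Task` row, and (3) push a message onto the Redis task queue. The queued payload carries everything the executor needs, so no extra lookup is required to start the job. Its shape, following `queue_save_to_memory_task` below (all values here are placeholders):

```python
# Shape of the queued Redis message (placeholder values, not real ids).
task_message = {
    "id": "f1e2d3c4",             # Task row uuid; doubles as task_id
    "task_id": "f1e2d3c4",
    "task_type": "memory",
    "memory_id": "mem_42",
    "source_id": 1001,            # message_id of the RAW message saved in step (1)
    "message_dict": {
        "user_id": "",
        "agent_id": "agent_7",
        "session_id": "sess_3",
        "user_input": "hi",
        "agent_response": "hello",
    },
}
```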
diff --git a/agent/component/message.py b/agent/component/message.py
index 7c888d886..bf393f541 100644
--- a/agent/component/message.py
+++ b/agent/component/message.py
@@ -33,7 +33,7 @@
 from common.connection_utils import timeout
 from common.misc_utils import get_uuid
 from common import settings
-from api.db.joint_services.memory_message_service import save_to_memory
+from api.db.joint_services.memory_message_service import queue_save_to_memory_task
 
 
 class MessageParam(ComponentParamBase):
@@ -437,17 +437,4 @@ class Message(ComponentBase):
             "user_input": self._canvas.get_sys_query(),
             "agent_response": content
         }
-        res = []
-        for memory_id in self._param.memory_ids:
-            success, msg = await save_to_memory(memory_id, message_dict)
-            res.append({
-                "memory_id": memory_id,
-                "success": success,
-                "msg": msg
-            })
-        if all([r["success"] for r in res]):
-            return True, "Successfully added to memories."
-
-        error_text = "Some messages failed to add. " + " ".join([f"Add to memory {r['memory_id']} failed, detail: {r['msg']}" for r in res if not r["success"]])
-        logging.error(error_text)
-        return False, error_text
+        return await queue_save_to_memory_task(self._param.memory_ids, message_dict)
diff --git a/api/db/joint_services/memory_message_service.py b/api/db/joint_services/memory_message_service.py
index 2d581cebf..79848cad5 100644
--- a/api/db/joint_services/memory_message_service.py
+++ b/api/db/joint_services/memory_message_service.py
@@ -16,9 +16,14 @@
 import logging
 from typing import List
 
+from api.db.services.task_service import TaskService
+from common import settings
 from common.time_utils import current_timestamp, timestamp_to_date, format_iso_8601_to_ymd_hms
 from common.constants import MemoryType, LLMType
 from common.doc_store.doc_store_base import FusionExpr
+from common.misc_utils import get_uuid
+from api.db.db_utils import bulk_insert_into_db
+from api.db.db_models import Task
 from api.db.services.memory_service import MemoryService
 from api.db.services.tenant_llm_service import TenantLLMService
 from api.db.services.llm_service import LLMBundle
@@ -82,32 +87,44 @@ async def save_to_memory(memory_id: str, message_dict: dict):
         "forget_at": None,
         "status": True
     } for content in extracted_content]]
-    embedding_model = LLMBundle(tenant_id, llm_type=LLMType.EMBEDDING, llm_name=memory.embd_id)
-    vector_list, _ = embedding_model.encode([msg["content"] for msg in message_list])
-    for idx, msg in enumerate(message_list):
-        msg["content_embed"] = vector_list[idx]
-    vector_dimension = len(vector_list[0])
-    if not MessageService.has_index(tenant_id, memory_id):
-        created = MessageService.create_index(tenant_id, memory_id, vector_size=vector_dimension)
-        if not created:
-            return False, "Failed to create message index."
+    return await embed_and_save(memory, message_list)
 
-    new_msg_size = sum([MessageService.calculate_message_size(m) for m in message_list])
-    current_memory_size = get_memory_size_cache(memory_id, tenant_id)
-    if new_msg_size + current_memory_size > memory.memory_size:
-        size_to_delete = current_memory_size + new_msg_size - memory.memory_size
-        if memory.forgetting_policy == "FIFO":
-            message_ids_to_delete, delete_size = MessageService.pick_messages_to_delete_by_fifo(memory_id, tenant_id, size_to_delete)
-            MessageService.delete_message({"message_id": message_ids_to_delete}, tenant_id, memory_id)
-            decrease_memory_size_cache(memory_id, delete_size)
-        else:
-            return False, "Failed to insert message into memory. Memory size reached limit and cannot decide which to delete."
-    fail_cases = MessageService.insert_message(message_list, tenant_id, memory_id)
-    if fail_cases:
-        return False, "Failed to insert message into memory. Details: " + "; ".join(fail_cases)
-    increase_memory_size_cache(memory_id, new_msg_size)
-    return True, "Message saved successfully."
+
+
+async def save_extracted_to_memory_only(memory_id: str, message_dict, source_message_id: int):
+    memory = MemoryService.get_by_memory_id(memory_id)
+    if not memory:
+        return False, f"Memory '{memory_id}' not found."
+
+    if memory.memory_type == MemoryType.RAW.value:
+        return True, f"Memory '{memory_id}' does not need extraction."
+
+    tenant_id = memory.tenant_id
+    extracted_content = await extract_by_llm(
+        tenant_id,
+        memory.llm_id,
+        {"temperature": memory.temperature},
+        get_memory_type_human(memory.memory_type),
+        message_dict.get("user_input", ""),
+        message_dict.get("agent_response", "")
+    )
+    message_list = [{
+        "message_id": REDIS_CONN.generate_auto_increment_id(namespace="memory"),
+        "message_type": content["message_type"],
+        "source_id": source_message_id,
+        "memory_id": memory_id,
+        "user_id": "",
+        "agent_id": message_dict["agent_id"],
+        "session_id": message_dict["session_id"],
+        "content": content["content"],
+        "valid_at": content["valid_at"],
+        "invalid_at": content["invalid_at"] if content["invalid_at"] else None,
+        "forget_at": None,
+        "status": True
+    } for content in extracted_content]
+    if not message_list:
+        return True, "No memory extracted from raw message."
+
+    return await embed_and_save(memory, message_list)
 
 
 async def extract_by_llm(tenant_id: str, llm_id: str, extract_conf: dict, memory_type: List[str], user_input: str,
@@ -136,6 +153,36 @@
     } for message_type, extracted_content_list in res_json.items() for extracted_content in extracted_content_list]
 
 
+async def embed_and_save(memory, message_list: list[dict]):
+    embedding_model = LLMBundle(memory.tenant_id, llm_type=LLMType.EMBEDDING, llm_name=memory.embd_id)
+    vector_list, _ = embedding_model.encode([msg["content"] for msg in message_list])
+    for idx, msg in enumerate(message_list):
+        msg["content_embed"] = vector_list[idx]
+    vector_dimension = len(vector_list[0])
+    if not MessageService.has_index(memory.tenant_id, memory.id):
+        created = MessageService.create_index(memory.tenant_id, memory.id, vector_size=vector_dimension)
+        if not created:
+            return False, "Failed to create message index."
+
+    new_msg_size = sum([MessageService.calculate_message_size(m) for m in message_list])
+    current_memory_size = get_memory_size_cache(memory.tenant_id, memory.id)
+    if new_msg_size + current_memory_size > memory.memory_size:
+        size_to_delete = current_memory_size + new_msg_size - memory.memory_size
+        if memory.forgetting_policy == "FIFO":
+            message_ids_to_delete, delete_size = MessageService.pick_messages_to_delete_by_fifo(memory.id, memory.tenant_id, size_to_delete)
+            MessageService.delete_message({"message_id": message_ids_to_delete}, memory.tenant_id, memory.id)
+            decrease_memory_size_cache(memory.id, delete_size)
+        else:
+            return False, "Failed to insert message into memory: memory size limit reached and the forgetting policy cannot decide which messages to delete."
+    fail_cases = MessageService.insert_message(message_list, memory.tenant_id, memory.id)
+    if fail_cases:
+        return False, "Failed to insert message into memory. Details: " + "; ".join(fail_cases)
+
+    increase_memory_size_cache(memory.id, new_msg_size)
+    return True, "Message saved successfully."
+
+
 def query_message(filter_dict: dict, params: dict):
     """
     :param filter_dict: {
@@ -231,3 +278,112 @@ def judge_system_prompt_is_default(system_prompt: str, memory_type: int|list[str]):
     memory_type_list = memory_type if isinstance(memory_type, list) else get_memory_type_human(memory_type)
     return system_prompt == PromptAssembler.assemble_system_prompt({"memory_type": memory_type_list})
+
+
+async def queue_save_to_memory_task(memory_ids: list[str], message_dict: dict):
+    """
+    :param memory_ids: ids of the memories to save into
+    :param message_dict: {
+        "user_id": str,
+        "agent_id": str,
+        "session_id": str,
+        "user_input": str,
+        "agent_response": str
+    }
+    """
+    def new_task(_memory_id: str, _source_id: int):
+        return {
+            "id": get_uuid(),
+            "doc_id": _memory_id,
+            "task_type": "memory",
+            "progress": 0.0,
+            "digest": str(_source_id)
+        }
+
+    not_found_memory = []
+    failed_memory = []
+    for memory_id in memory_ids:
+        memory = MemoryService.get_by_memory_id(memory_id)
+        if not memory:
+            not_found_memory.append(memory_id)
+            continue
+
+        raw_message_id = REDIS_CONN.generate_auto_increment_id(namespace="memory")
+        raw_message = {
+            "message_id": raw_message_id,
+            "message_type": MemoryType.RAW.name.lower(),
+            "source_id": 0,
+            "memory_id": memory_id,
+            "user_id": "",
+            "agent_id": message_dict["agent_id"],
+            "session_id": message_dict["session_id"],
+            "content": f"User Input: {message_dict.get('user_input')}\nAgent Response: {message_dict.get('agent_response')}",
+            "valid_at": timestamp_to_date(current_timestamp()),
+            "invalid_at": None,
+            "forget_at": None,
+            "status": True
+        }
+        res, msg = await embed_and_save(memory, [raw_message])
+        if not res:
+            failed_memory.append({"memory_id": memory_id, "fail_msg": msg})
+            continue
+
+        task = new_task(memory_id, raw_message_id)
+        bulk_insert_into_db(Task, [task], replace_on_conflict=True)
+        task_message = {
+            "id": task["id"],
+            "task_id": task["id"],
+            "task_type": task["task_type"],
+            "memory_id": memory_id,
+            "source_id": raw_message_id,
+            "message_dict": message_dict
+        }
+        if not REDIS_CONN.queue_product(settings.get_svr_queue_name(priority=0), message=task_message):
+            failed_memory.append({"memory_id": memory_id, "fail_msg": "Can't access Redis."})
+
+    error_msg = ""
+    if not_found_memory:
+        error_msg = f"Memory {not_found_memory} not found. "
+    if failed_memory:
+        error_msg += " ".join([f"Memory {fm['memory_id']} failed. Detail: {fm['fail_msg']}" for fm in failed_memory])
+
+    if error_msg:
+        return False, error_msg
+
+    return True, "All save-to-memory tasks queued."
+
+
+async def handle_save_to_memory_task(task_param: dict):
+    """
+    :param task_param: {
+        "id": task_id,
+        "memory_id": id,
+        "source_id": id,
+        "message_dict": {
+            "user_id": str,
+            "agent_id": str,
+            "session_id": str,
+            "user_input": str,
+            "agent_response": str
+        }
+    }
+    """
+    _, task = TaskService.get_by_id(task_param["id"])
+    if not task:
+        return False, f"Task {task_param['id']} was not found."
+    if task.progress == -1:
+        return False, f"Task {task_param['id']} has already failed."
+    now_time = current_timestamp()
+    TaskService.update_by_id(task_param["id"], {"begin_at": timestamp_to_date(now_time)})
+
+    memory_id = task_param["memory_id"]
+    source_id = task_param["source_id"]
+    message_dict = task_param["message_dict"]
+    success, msg = await save_extracted_to_memory_only(memory_id, message_dict, source_id)
+    if success:
+        TaskService.update_progress(task.id, {"progress": 1.0, "progress_msg": msg})
+        return True, msg
+
+    logging.error(msg)
+    TaskService.update_progress(task.id, {"progress": -1, "progress_msg": None})
+    return False, msg
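The service above splits the save into two phases linked by `source_id`: `queue_save_to_memory_task` stores the RAW transcript synchronously with no parent, and `save_extracted_to_memory_only` later stores the LLM-extracted records pointing back at it. A minimal illustration of the resulting records (all ids and the non-RAW message types are made up for the example; the real types come from `extract_by_llm`):

```python
# Phase 1 (synchronous, in queue_save_to_memory_task): the raw transcript
# is embedded and stored right away, with no parent (source_id = 0).
raw = {"message_id": 1001, "message_type": "raw", "source_id": 0}

# Phase 2 (async, in save_extracted_to_memory_only): LLM-extracted records
# point back at the raw message through source_id.
extracted = [
    {"message_id": 1002, "message_type": "profile", "source_id": 1001},
    {"message_id": 1003, "message_type": "preference", "source_id": 1001},
]
assert all(m["source_id"] == raw["message_id"] for m in extracted)
```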
diff --git a/common/constants.py b/common/constants.py
index 776e27447..8ddf3c4b4 100644
--- a/common/constants.py
+++ b/common/constants.py
@@ -138,6 +138,7 @@ class PipelineTaskType(StrEnum):
     RAPTOR = "RAPTOR"
     GRAPH_RAG = "GraphRAG"
     MINDMAP = "Mindmap"
+    MEMORY = "Memory"
 
 
 VALID_PIPELINE_TASK_TYPES = {PipelineTaskType.PARSE, PipelineTaskType.DOWNLOAD, PipelineTaskType.RAPTOR,
diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py
index b793d9c35..cf339b15a 100644
--- a/rag/svr/task_executor.py
+++ b/rag/svr/task_executor.py
@@ -26,6 +26,7 @@ import time
 from api.db import PIPELINE_SPECIAL_PROGRESS_FREEZE_TASK_TYPES
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
+from api.db.joint_services.memory_message_service import handle_save_to_memory_task
 from common.connection_utils import timeout
 from common.metadata_utils import update_metadata_to, metadata_schema
 from rag.utils.base64_image import image2id
@@ -96,6 +97,7 @@ TASK_TYPE_TO_PIPELINE_TASK_TYPE = {
     "raptor": PipelineTaskType.RAPTOR,
     "graphrag": PipelineTaskType.GRAPH_RAG,
     "mindmap": PipelineTaskType.MINDMAP,
+    "memory": PipelineTaskType.MEMORY,
 }
 
 UNACKED_ITERATOR = None
@@ -197,6 +199,9 @@ async def collect():
         if task:
             task["doc_id"] = msg["doc_id"]
             task["doc_ids"] = msg.get("doc_ids", []) or []
+    elif msg.get("task_type") == PipelineTaskType.MEMORY.lower():
+        _, task_obj = TaskService.get_by_id(msg["id"])
+        task = task_obj.to_dict()
     else:
         task = TaskService.get_task(msg["id"])
 
@@ -215,6 +220,10 @@ async def collect():
         task["tenant_id"] = msg["tenant_id"]
         task["dataflow_id"] = msg["dataflow_id"]
         task["kb_id"] = msg.get("kb_id", "")
+    if task_type[:6] == "memory":
+        task["memory_id"] = msg["memory_id"]
+        task["source_id"] = msg["source_id"]
+        task["message_dict"] = msg["message_dict"]
 
     return redis_msg, task
@@ -866,6 +875,10 @@ async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_c
 
 async def do_handle_task(task):
     task_type = task.get("task_type", "")
+    if task_type == "memory":
+        await handle_save_to_memory_task(task)
+        return
+
     if task_type == "dataflow" and task.get("doc_id", "") == CANVAS_DEBUG_DOC_ID:
         await run_dataflow(task)
         return
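A small readability note on the dispatch above: in `collect()`, `task_type[:6] == "memory"` is a slice-based prefix check, while `do_handle_task` uses an equality test. Since this patch only ever queues the exact type `"memory"`, the three spellings below agree; `startswith` or plain equality would state the intent more directly:

```python
task_type = "memory"
# All three checks agree for the exact task type queued by this patch.
assert task_type[:6] == "memory"
assert task_type.startswith("memory")
assert task_type == "memory"
```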