mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-31 01:01:30 +08:00
Fix: use async task to save memory (#12308)
### What problem does this PR solve? Use async task to save memory. ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --------- Co-authored-by: Jin Hai <haijin.chn@gmail.com>
This commit is contained in:
@ -33,7 +33,7 @@ from common.connection_utils import timeout
|
||||
from common.misc_utils import get_uuid
|
||||
from common import settings
|
||||
|
||||
from api.db.joint_services.memory_message_service import save_to_memory
|
||||
from api.db.joint_services.memory_message_service import queue_save_to_memory_task
|
||||
|
||||
|
||||
class MessageParam(ComponentParamBase):
|
||||
@ -437,17 +437,4 @@ class Message(ComponentBase):
|
||||
"user_input": self._canvas.get_sys_query(),
|
||||
"agent_response": content
|
||||
}
|
||||
res = []
|
||||
for memory_id in self._param.memory_ids:
|
||||
success, msg = await save_to_memory(memory_id, message_dict)
|
||||
res.append({
|
||||
"memory_id": memory_id,
|
||||
"success": success,
|
||||
"msg": msg
|
||||
})
|
||||
if all([r["success"] for r in res]):
|
||||
return True, "Successfully added to memories."
|
||||
|
||||
error_text = "Some messages failed to add. " + " ".join([f"Add to memory {r['memory_id']} failed, detail: {r['msg']}" for r in res if not r["success"]])
|
||||
logging.error(error_text)
|
||||
return False, error_text
|
||||
return await queue_save_to_memory_task(self._param.memory_ids, message_dict)
|
||||
|
||||
@ -16,9 +16,14 @@
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from api.db.services.task_service import TaskService
|
||||
from common import settings
|
||||
from common.time_utils import current_timestamp, timestamp_to_date, format_iso_8601_to_ymd_hms
|
||||
from common.constants import MemoryType, LLMType
|
||||
from common.doc_store.doc_store_base import FusionExpr
|
||||
from common.misc_utils import get_uuid
|
||||
from api.db.db_utils import bulk_insert_into_db
|
||||
from api.db.db_models import Task
|
||||
from api.db.services.memory_service import MemoryService
|
||||
from api.db.services.tenant_llm_service import TenantLLMService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
@ -82,32 +87,44 @@ async def save_to_memory(memory_id: str, message_dict: dict):
|
||||
"forget_at": None,
|
||||
"status": True
|
||||
} for content in extracted_content]]
|
||||
embedding_model = LLMBundle(tenant_id, llm_type=LLMType.EMBEDDING, llm_name=memory.embd_id)
|
||||
vector_list, _ = embedding_model.encode([msg["content"] for msg in message_list])
|
||||
for idx, msg in enumerate(message_list):
|
||||
msg["content_embed"] = vector_list[idx]
|
||||
vector_dimension = len(vector_list[0])
|
||||
if not MessageService.has_index(tenant_id, memory_id):
|
||||
created = MessageService.create_index(tenant_id, memory_id, vector_size=vector_dimension)
|
||||
if not created:
|
||||
return False, "Failed to create message index."
|
||||
return await embed_and_save(memory, message_list)
|
||||
|
||||
new_msg_size = sum([MessageService.calculate_message_size(m) for m in message_list])
|
||||
current_memory_size = get_memory_size_cache(memory_id, tenant_id)
|
||||
if new_msg_size + current_memory_size > memory.memory_size:
|
||||
size_to_delete = current_memory_size + new_msg_size - memory.memory_size
|
||||
if memory.forgetting_policy == "FIFO":
|
||||
message_ids_to_delete, delete_size = MessageService.pick_messages_to_delete_by_fifo(memory_id, tenant_id, size_to_delete)
|
||||
MessageService.delete_message({"message_id": message_ids_to_delete}, tenant_id, memory_id)
|
||||
decrease_memory_size_cache(memory_id, delete_size)
|
||||
else:
|
||||
return False, "Failed to insert message into memory. Memory size reached limit and cannot decide which to delete."
|
||||
fail_cases = MessageService.insert_message(message_list, tenant_id, memory_id)
|
||||
if fail_cases:
|
||||
return False, "Failed to insert message into memory. Details: " + "; ".join(fail_cases)
|
||||
|
||||
increase_memory_size_cache(memory_id, new_msg_size)
|
||||
return True, "Message saved successfully."
|
||||
async def save_extracted_to_memory_only(memory_id: str, message_dict, source_message_id: int):
|
||||
memory = MemoryService.get_by_memory_id(memory_id)
|
||||
if not memory:
|
||||
return False, f"Memory '{memory_id}' not found."
|
||||
|
||||
if memory.memory_type == MemoryType.RAW.value:
|
||||
return True, f"Memory '{memory_id}' don't need to extract."
|
||||
|
||||
tenant_id = memory.tenant_id
|
||||
extracted_content = await extract_by_llm(
|
||||
tenant_id,
|
||||
memory.llm_id,
|
||||
{"temperature": memory.temperature},
|
||||
get_memory_type_human(memory.memory_type),
|
||||
message_dict.get("user_input", ""),
|
||||
message_dict.get("agent_response", "")
|
||||
)
|
||||
message_list = [{
|
||||
"message_id": REDIS_CONN.generate_auto_increment_id(namespace="memory"),
|
||||
"message_type": content["message_type"],
|
||||
"source_id": source_message_id,
|
||||
"memory_id": memory_id,
|
||||
"user_id": "",
|
||||
"agent_id": message_dict["agent_id"],
|
||||
"session_id": message_dict["session_id"],
|
||||
"content": content["content"],
|
||||
"valid_at": content["valid_at"],
|
||||
"invalid_at": content["invalid_at"] if content["invalid_at"] else None,
|
||||
"forget_at": None,
|
||||
"status": True
|
||||
} for content in extracted_content]
|
||||
if not message_list:
|
||||
return True, "No memory extracted from raw message."
|
||||
|
||||
return await embed_and_save(memory, message_list)
|
||||
|
||||
|
||||
async def extract_by_llm(tenant_id: str, llm_id: str, extract_conf: dict, memory_type: List[str], user_input: str,
|
||||
@ -136,6 +153,36 @@ async def extract_by_llm(tenant_id: str, llm_id: str, extract_conf: dict, memory
|
||||
} for message_type, extracted_content_list in res_json.items() for extracted_content in extracted_content_list]
|
||||
|
||||
|
||||
async def embed_and_save(memory, message_list: list[dict]):
|
||||
embedding_model = LLMBundle(memory.tenant_id, llm_type=LLMType.EMBEDDING, llm_name=memory.embd_id)
|
||||
vector_list, _ = embedding_model.encode([msg["content"] for msg in message_list])
|
||||
for idx, msg in enumerate(message_list):
|
||||
msg["content_embed"] = vector_list[idx]
|
||||
vector_dimension = len(vector_list[0])
|
||||
if not MessageService.has_index(memory.tenant_id, memory.id):
|
||||
created = MessageService.create_index(memory.tenant_id, memory.id, vector_size=vector_dimension)
|
||||
if not created:
|
||||
return False, "Failed to create message index."
|
||||
|
||||
new_msg_size = sum([MessageService.calculate_message_size(m) for m in message_list])
|
||||
current_memory_size = get_memory_size_cache(memory.tenant_id, memory.id)
|
||||
if new_msg_size + current_memory_size > memory.memory_size:
|
||||
size_to_delete = current_memory_size + new_msg_size - memory.memory_size
|
||||
if memory.forgetting_policy == "FIFO":
|
||||
message_ids_to_delete, delete_size = MessageService.pick_messages_to_delete_by_fifo(memory.id, memory.tenant_id,
|
||||
size_to_delete)
|
||||
MessageService.delete_message({"message_id": message_ids_to_delete}, memory.tenant_id, memory.id)
|
||||
decrease_memory_size_cache(memory.id, delete_size)
|
||||
else:
|
||||
return False, "Failed to insert message into memory. Memory size reached limit and cannot decide which to delete."
|
||||
fail_cases = MessageService.insert_message(message_list, memory.tenant_id, memory.id)
|
||||
if fail_cases:
|
||||
return False, "Failed to insert message into memory. Details: " + "; ".join(fail_cases)
|
||||
|
||||
increase_memory_size_cache(memory.id, new_msg_size)
|
||||
return True, "Message saved successfully."
|
||||
|
||||
|
||||
def query_message(filter_dict: dict, params: dict):
|
||||
"""
|
||||
:param filter_dict: {
|
||||
@ -231,3 +278,112 @@ def init_memory_size_cache():
|
||||
def judge_system_prompt_is_default(system_prompt: str, memory_type: int|list[str]):
|
||||
memory_type_list = memory_type if isinstance(memory_type, list) else get_memory_type_human(memory_type)
|
||||
return system_prompt == PromptAssembler.assemble_system_prompt({"memory_type": memory_type_list})
|
||||
|
||||
|
||||
async def queue_save_to_memory_task(memory_ids: list[str], message_dict: dict):
|
||||
"""
|
||||
:param memory_ids:
|
||||
:param message_dict: {
|
||||
"user_id": str,
|
||||
"agent_id": str,
|
||||
"session_id": str,
|
||||
"user_input": str,
|
||||
"agent_response": str
|
||||
}
|
||||
"""
|
||||
def new_task(_memory_id: str, _source_id: int):
|
||||
return {
|
||||
"id": get_uuid(),
|
||||
"doc_id": _memory_id,
|
||||
"task_type": "memory",
|
||||
"progress": 0.0,
|
||||
"digest": str(_source_id)
|
||||
}
|
||||
|
||||
not_found_memory = []
|
||||
failed_memory = []
|
||||
for memory_id in memory_ids:
|
||||
memory = MemoryService.get_by_memory_id(memory_id)
|
||||
if not memory:
|
||||
not_found_memory.append(memory_id)
|
||||
continue
|
||||
|
||||
raw_message_id = REDIS_CONN.generate_auto_increment_id(namespace="memory")
|
||||
raw_message = {
|
||||
"message_id": raw_message_id,
|
||||
"message_type": MemoryType.RAW.name.lower(),
|
||||
"source_id": 0,
|
||||
"memory_id": memory_id,
|
||||
"user_id": "",
|
||||
"agent_id": message_dict["agent_id"],
|
||||
"session_id": message_dict["session_id"],
|
||||
"content": f"User Input: {message_dict.get('user_input')}\nAgent Response: {message_dict.get('agent_response')}",
|
||||
"valid_at": timestamp_to_date(current_timestamp()),
|
||||
"invalid_at": None,
|
||||
"forget_at": None,
|
||||
"status": True
|
||||
}
|
||||
res, msg = await embed_and_save(memory, [raw_message])
|
||||
if not res:
|
||||
failed_memory.append({"memory_id": memory_id, "fail_msg": msg})
|
||||
continue
|
||||
|
||||
task = new_task(memory_id, raw_message_id)
|
||||
bulk_insert_into_db(Task, [task], replace_on_conflict=True)
|
||||
task_message = {
|
||||
"id": task["id"],
|
||||
"task_id": task["id"],
|
||||
"task_type": task["task_type"],
|
||||
"memory_id": memory_id,
|
||||
"source_id": raw_message_id,
|
||||
"message_dict": message_dict
|
||||
}
|
||||
if not REDIS_CONN.queue_product(settings.get_svr_queue_name(priority=0), message=task_message):
|
||||
failed_memory.append({"memory_id": memory_id, "fail_msg": "Can't access Redis."})
|
||||
|
||||
error_msg = ""
|
||||
if not_found_memory:
|
||||
error_msg = f"Memory {not_found_memory} not found."
|
||||
if failed_memory:
|
||||
error_msg += "".join([f"Memory {fm['memory_id']} failed. Detail: {fm['fail_msg']}" for fm in failed_memory])
|
||||
|
||||
if error_msg:
|
||||
return False, error_msg
|
||||
|
||||
return True, "All add to task."
|
||||
|
||||
|
||||
async def handle_save_to_memory_task(task_param: dict):
|
||||
"""
|
||||
:param task_param: {
|
||||
"id": task_id
|
||||
"memory_id": id
|
||||
"source_id": id
|
||||
"message_dict": {
|
||||
"user_id": str,
|
||||
"agent_id": str,
|
||||
"session_id": str,
|
||||
"user_input": str,
|
||||
"agent_response": str
|
||||
}
|
||||
}
|
||||
"""
|
||||
_, task = TaskService.get_by_id(task_param["id"])
|
||||
if not task:
|
||||
return False, f"Task {task_param['id']} is not found."
|
||||
if task.progress == -1:
|
||||
return False, f"Task {task_param['id']} is already failed."
|
||||
now_time = current_timestamp()
|
||||
TaskService.update_by_id(task_param["id"], {"begin_at": timestamp_to_date(now_time)})
|
||||
|
||||
memory_id = task_param["memory_id"]
|
||||
source_id = task_param["source_id"]
|
||||
message_dict = task_param["message_dict"]
|
||||
success, msg = await save_extracted_to_memory_only(memory_id, message_dict, source_id)
|
||||
if success:
|
||||
TaskService.update_progress(task.id, {"progress": 1.0, "progress_msg": msg})
|
||||
return True, msg
|
||||
|
||||
logging.error(msg)
|
||||
TaskService.update_progress(task.id, {"progress": -1, "progress_msg": None})
|
||||
return False, msg
|
||||
|
||||
@ -138,6 +138,7 @@ class PipelineTaskType(StrEnum):
|
||||
RAPTOR = "RAPTOR"
|
||||
GRAPH_RAG = "GraphRAG"
|
||||
MINDMAP = "Mindmap"
|
||||
MEMORY = "Memory"
|
||||
|
||||
|
||||
VALID_PIPELINE_TASK_TYPES = {PipelineTaskType.PARSE, PipelineTaskType.DOWNLOAD, PipelineTaskType.RAPTOR,
|
||||
|
||||
@ -26,6 +26,7 @@ import time
|
||||
from api.db import PIPELINE_SPECIAL_PROGRESS_FREEZE_TASK_TYPES
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
|
||||
from api.db.joint_services.memory_message_service import handle_save_to_memory_task
|
||||
from common.connection_utils import timeout
|
||||
from common.metadata_utils import update_metadata_to, metadata_schema
|
||||
from rag.utils.base64_image import image2id
|
||||
@ -96,6 +97,7 @@ TASK_TYPE_TO_PIPELINE_TASK_TYPE = {
|
||||
"raptor": PipelineTaskType.RAPTOR,
|
||||
"graphrag": PipelineTaskType.GRAPH_RAG,
|
||||
"mindmap": PipelineTaskType.MINDMAP,
|
||||
"memory": PipelineTaskType.MEMORY,
|
||||
}
|
||||
|
||||
UNACKED_ITERATOR = None
|
||||
@ -197,6 +199,9 @@ async def collect():
|
||||
if task:
|
||||
task["doc_id"] = msg["doc_id"]
|
||||
task["doc_ids"] = msg.get("doc_ids", []) or []
|
||||
elif msg.get("task_type") == PipelineTaskType.MEMORY.lower():
|
||||
_, task_obj = TaskService.get_by_id(msg["id"])
|
||||
task = task_obj.to_dict()
|
||||
else:
|
||||
task = TaskService.get_task(msg["id"])
|
||||
|
||||
@ -215,6 +220,10 @@ async def collect():
|
||||
task["tenant_id"] = msg["tenant_id"]
|
||||
task["dataflow_id"] = msg["dataflow_id"]
|
||||
task["kb_id"] = msg.get("kb_id", "")
|
||||
if task_type[:6] == "memory":
|
||||
task["memory_id"] = msg["memory_id"]
|
||||
task["source_id"] = msg["source_id"]
|
||||
task["message_dict"] = msg["message_dict"]
|
||||
return redis_msg, task
|
||||
|
||||
|
||||
@ -866,6 +875,10 @@ async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_c
|
||||
async def do_handle_task(task):
|
||||
task_type = task.get("task_type", "")
|
||||
|
||||
if task_type == "memory":
|
||||
await handle_save_to_memory_task(task)
|
||||
return
|
||||
|
||||
if task_type == "dataflow" and task.get("doc_id", "") == CANVAS_DEBUG_DOC_ID:
|
||||
await run_dataflow(task)
|
||||
return
|
||||
|
||||
Reference in New Issue
Block a user