Fix: use async task to save memory (#12308)
### What problem does this PR solve?

Use an async task to save memory.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

---------

Co-authored-by: Jin Hai <haijin.chn@gmail.com>
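At a glance, the change replaces an inline save with a queued background task: the request path persists the raw message and a `Task` row, pushes a message onto the Redis task queue, and returns immediately; the task executor later picks the message up and runs the LLM extraction plus embedding off the request path. A minimal, self-contained sketch of that producer/consumer shape, using a stdlib queue as a stand-in for Redis (everything below is illustrative, not RAGFlow code):

```python
import queue
import threading

task_queue: "queue.Queue[dict]" = queue.Queue()

def enqueue_save_to_memory(memory_id: str, message_dict: dict) -> tuple[bool, str]:
    # Producer: cheap and non-blocking, mirroring queue_save_to_memory_task.
    task_queue.put({"task_type": "memory", "memory_id": memory_id,
                    "message_dict": message_dict})
    return True, "All add to task."

def worker() -> None:
    # Consumer: mirrors handle_save_to_memory_task running in the executor;
    # the expensive extract + embed steps would happen here.
    while True:
        task = task_queue.get()
        print(f"processed memory task for {task['memory_id']}")
        task_queue.task_done()

threading.Thread(target=worker, daemon=True).start()
enqueue_save_to_memory("mem-1", {"user_input": "hi", "agent_response": "hello"})
task_queue.join()
```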
@@ -33,7 +33,7 @@ from common.connection_utils import timeout
 from common.misc_utils import get_uuid
 from common import settings
 
-from api.db.joint_services.memory_message_service import save_to_memory
+from api.db.joint_services.memory_message_service import queue_save_to_memory_task
 
 
 class MessageParam(ComponentParamBase):
@@ -437,17 +437,4 @@ class Message(ComponentBase):
             "user_input": self._canvas.get_sys_query(),
             "agent_response": content
         }
-        res = []
-        for memory_id in self._param.memory_ids:
-            success, msg = await save_to_memory(memory_id, message_dict)
-            res.append({
-                "memory_id": memory_id,
-                "success": success,
-                "msg": msg
-            })
-        if all([r["success"] for r in res]):
-            return True, "Successfully added to memories."
-
-        error_text = "Some messages failed to add. " + " ".join([f"Add to memory {r['memory_id']} failed, detail: {r['msg']}" for r in res if not r["success"]])
-        logging.error(error_text)
-        return False, error_text
+        return await queue_save_to_memory_task(self._param.memory_ids, message_dict)
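Note the component keeps its `(success, message)` return contract: the per-memory error aggregation that used to live in this loop moves into `queue_save_to_memory_task` (its `not_found_memory` / `failed_memory` lists further down). A hedged sketch of the preserved contract, with illustrative names:

```python
# Both the removed loop and the new queue call report one boolean plus a
# joined detail string; callers of the Message component see the same shape.
def aggregate(results: list[tuple[str, bool, str]]) -> tuple[bool, str]:
    failed = [(mid, msg) for mid, ok, msg in results if not ok]
    if not failed:
        return True, "Successfully added to memories."
    return False, " ".join(f"Add to memory {mid} failed, detail: {msg}"
                           for mid, msg in failed)

assert aggregate([("m1", True, "ok")]) == (True, "Successfully added to memories.")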
@@ -16,9 +16,14 @@
 import logging
 from typing import List
 
+from api.db.services.task_service import TaskService
+from common import settings
 from common.time_utils import current_timestamp, timestamp_to_date, format_iso_8601_to_ymd_hms
 from common.constants import MemoryType, LLMType
 from common.doc_store.doc_store_base import FusionExpr
+from common.misc_utils import get_uuid
+from api.db.db_utils import bulk_insert_into_db
+from api.db.db_models import Task
 from api.db.services.memory_service import MemoryService
 from api.db.services.tenant_llm_service import TenantLLMService
 from api.db.services.llm_service import LLMBundle
@@ -82,32 +87,44 @@ async def save_to_memory(memory_id: str, message_dict: dict):
         "forget_at": None,
         "status": True
     } for content in extracted_content]]
-    embedding_model = LLMBundle(tenant_id, llm_type=LLMType.EMBEDDING, llm_name=memory.embd_id)
-    vector_list, _ = embedding_model.encode([msg["content"] for msg in message_list])
-    for idx, msg in enumerate(message_list):
-        msg["content_embed"] = vector_list[idx]
-    vector_dimension = len(vector_list[0])
-    if not MessageService.has_index(tenant_id, memory_id):
-        created = MessageService.create_index(tenant_id, memory_id, vector_size=vector_dimension)
-        if not created:
-            return False, "Failed to create message index."
-
-    new_msg_size = sum([MessageService.calculate_message_size(m) for m in message_list])
-    current_memory_size = get_memory_size_cache(memory_id, tenant_id)
-    if new_msg_size + current_memory_size > memory.memory_size:
-        size_to_delete = current_memory_size + new_msg_size - memory.memory_size
-        if memory.forgetting_policy == "FIFO":
-            message_ids_to_delete, delete_size = MessageService.pick_messages_to_delete_by_fifo(memory_id, tenant_id, size_to_delete)
-            MessageService.delete_message({"message_id": message_ids_to_delete}, tenant_id, memory_id)
-            decrease_memory_size_cache(memory_id, delete_size)
-        else:
-            return False, "Failed to insert message into memory. Memory size reached limit and cannot decide which to delete."
-    fail_cases = MessageService.insert_message(message_list, tenant_id, memory_id)
-    if fail_cases:
-        return False, "Failed to insert message into memory. Details: " + "; ".join(fail_cases)
-
-    increase_memory_size_cache(memory_id, new_msg_size)
-    return True, "Message saved successfully."
+    return await embed_and_save(memory, message_list)
+
+
+async def save_extracted_to_memory_only(memory_id: str, message_dict, source_message_id: int):
+    memory = MemoryService.get_by_memory_id(memory_id)
+    if not memory:
+        return False, f"Memory '{memory_id}' not found."
+
+    if memory.memory_type == MemoryType.RAW.value:
+        return True, f"Memory '{memory_id}' don't need to extract."
+
+    tenant_id = memory.tenant_id
+    extracted_content = await extract_by_llm(
+        tenant_id,
+        memory.llm_id,
+        {"temperature": memory.temperature},
+        get_memory_type_human(memory.memory_type),
+        message_dict.get("user_input", ""),
+        message_dict.get("agent_response", "")
+    )
+    message_list = [{
+        "message_id": REDIS_CONN.generate_auto_increment_id(namespace="memory"),
+        "message_type": content["message_type"],
+        "source_id": source_message_id,
+        "memory_id": memory_id,
+        "user_id": "",
+        "agent_id": message_dict["agent_id"],
+        "session_id": message_dict["session_id"],
+        "content": content["content"],
+        "valid_at": content["valid_at"],
+        "invalid_at": content["invalid_at"] if content["invalid_at"] else None,
+        "forget_at": None,
+        "status": True
+    } for content in extracted_content]
+    if not message_list:
+        return True, "No memory extracted from raw message."
+
+    return await embed_and_save(memory, message_list)
 
 
 async def extract_by_llm(tenant_id: str, llm_id: str, extract_conf: dict, memory_type: List[str], user_input: str,
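The worker-side entry point is deliberately narrower than `save_to_memory`: the raw message was already persisted at enqueue time, so `save_extracted_to_memory_only` only runs extraction and links each extracted message back to the stored raw one through `source_id`. A sketch of that linkage, with invented ids and a simplified record type:

```python
from dataclasses import dataclass

@dataclass
class StoredMessage:
    message_id: int
    message_type: str  # "raw", or an extracted type from extract_by_llm
    source_id: int     # 0 for raw messages; raw_message_id for extracted ones

raw = StoredMessage(message_id=101, message_type="raw", source_id=0)
extracted = [StoredMessage(message_id=102, message_type="fact",
                           source_id=raw.message_id)]
assert all(m.source_id == raw.message_id for m in extracted)
```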
@@ -136,6 +153,36 @@ async def extract_by_llm(tenant_id: str, llm_id: str, extract_conf: dict, memory
 } for message_type, extracted_content_list in res_json.items() for extracted_content in extracted_content_list]
 
 
+async def embed_and_save(memory, message_list: list[dict]):
+    embedding_model = LLMBundle(memory.tenant_id, llm_type=LLMType.EMBEDDING, llm_name=memory.embd_id)
+    vector_list, _ = embedding_model.encode([msg["content"] for msg in message_list])
+    for idx, msg in enumerate(message_list):
+        msg["content_embed"] = vector_list[idx]
+    vector_dimension = len(vector_list[0])
+    if not MessageService.has_index(memory.tenant_id, memory.id):
+        created = MessageService.create_index(memory.tenant_id, memory.id, vector_size=vector_dimension)
+        if not created:
+            return False, "Failed to create message index."
+
+    new_msg_size = sum([MessageService.calculate_message_size(m) for m in message_list])
+    current_memory_size = get_memory_size_cache(memory.tenant_id, memory.id)
+    if new_msg_size + current_memory_size > memory.memory_size:
+        size_to_delete = current_memory_size + new_msg_size - memory.memory_size
+        if memory.forgetting_policy == "FIFO":
+            message_ids_to_delete, delete_size = MessageService.pick_messages_to_delete_by_fifo(memory.id, memory.tenant_id,
+                                                                                                size_to_delete)
+            MessageService.delete_message({"message_id": message_ids_to_delete}, memory.tenant_id, memory.id)
+            decrease_memory_size_cache(memory.id, delete_size)
+        else:
+            return False, "Failed to insert message into memory. Memory size reached limit and cannot decide which to delete."
+    fail_cases = MessageService.insert_message(message_list, memory.tenant_id, memory.id)
+    if fail_cases:
+        return False, "Failed to insert message into memory. Details: " + "; ".join(fail_cases)
+
+    increase_memory_size_cache(memory.id, new_msg_size)
+    return True, "Message saved successfully."
+
+
 def query_message(filter_dict: dict, params: dict):
     """
     :param filter_dict: {
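`embed_and_save` now carries the size-limit bookkeeping for both call sites. When an insert would overflow `memory.memory_size`, the FIFO branch deletes just enough of the oldest messages to make room. A worked example of that arithmetic (only the numbers are invented; the formula is the one in the diff):

```python
memory_size_limit = 1000   # memory.memory_size
current_memory_size = 900  # cached running total for this memory
new_msg_size = 250         # sum of calculate_message_size(...) over new messages

if new_msg_size + current_memory_size > memory_size_limit:
    size_to_delete = current_memory_size + new_msg_size - memory_size_limit
    # FIFO: reclaim at least this much from the oldest messages, then
    # decrease the size cache by whatever was actually freed.
    assert size_to_delete == 150
```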
@@ -231,3 +278,112 @@ def init_memory_size_cache():
 def judge_system_prompt_is_default(system_prompt: str, memory_type: int|list[str]):
     memory_type_list = memory_type if isinstance(memory_type, list) else get_memory_type_human(memory_type)
     return system_prompt == PromptAssembler.assemble_system_prompt({"memory_type": memory_type_list})
+
+
+async def queue_save_to_memory_task(memory_ids: list[str], message_dict: dict):
+    """
+    :param memory_ids:
+    :param message_dict: {
+        "user_id": str,
+        "agent_id": str,
+        "session_id": str,
+        "user_input": str,
+        "agent_response": str
+    }
+    """
+    def new_task(_memory_id: str, _source_id: int):
+        return {
+            "id": get_uuid(),
+            "doc_id": _memory_id,
+            "task_type": "memory",
+            "progress": 0.0,
+            "digest": str(_source_id)
+        }
+
+    not_found_memory = []
+    failed_memory = []
+    for memory_id in memory_ids:
+        memory = MemoryService.get_by_memory_id(memory_id)
+        if not memory:
+            not_found_memory.append(memory_id)
+            continue
+
+        raw_message_id = REDIS_CONN.generate_auto_increment_id(namespace="memory")
+        raw_message = {
+            "message_id": raw_message_id,
+            "message_type": MemoryType.RAW.name.lower(),
+            "source_id": 0,
+            "memory_id": memory_id,
+            "user_id": "",
+            "agent_id": message_dict["agent_id"],
+            "session_id": message_dict["session_id"],
+            "content": f"User Input: {message_dict.get('user_input')}\nAgent Response: {message_dict.get('agent_response')}",
+            "valid_at": timestamp_to_date(current_timestamp()),
+            "invalid_at": None,
+            "forget_at": None,
+            "status": True
+        }
+        res, msg = await embed_and_save(memory, [raw_message])
+        if not res:
+            failed_memory.append({"memory_id": memory_id, "fail_msg": msg})
+            continue
+
+        task = new_task(memory_id, raw_message_id)
+        bulk_insert_into_db(Task, [task], replace_on_conflict=True)
+        task_message = {
+            "id": task["id"],
+            "task_id": task["id"],
+            "task_type": task["task_type"],
+            "memory_id": memory_id,
+            "source_id": raw_message_id,
+            "message_dict": message_dict
+        }
+        if not REDIS_CONN.queue_product(settings.get_svr_queue_name(priority=0), message=task_message):
+            failed_memory.append({"memory_id": memory_id, "fail_msg": "Can't access Redis."})
+
+    error_msg = ""
+    if not_found_memory:
+        error_msg = f"Memory {not_found_memory} not found."
+    if failed_memory:
+        error_msg += "".join([f"Memory {fm['memory_id']} failed. Detail: {fm['fail_msg']}" for fm in failed_memory])
+
+    if error_msg:
+        return False, error_msg
+
+    return True, "All add to task."
+
+
+async def handle_save_to_memory_task(task_param: dict):
+    """
+    :param task_param: {
+        "id": task_id
+        "memory_id": id
+        "source_id": id
+        "message_dict": {
+            "user_id": str,
+            "agent_id": str,
+            "session_id": str,
+            "user_input": str,
+            "agent_response": str
+        }
+    }
+    """
+    _, task = TaskService.get_by_id(task_param["id"])
+    if not task:
+        return False, f"Task {task_param['id']} is not found."
+    if task.progress == -1:
+        return False, f"Task {task_param['id']} is already failed."
+    now_time = current_timestamp()
+    TaskService.update_by_id(task_param["id"], {"begin_at": timestamp_to_date(now_time)})
+
+    memory_id = task_param["memory_id"]
+    source_id = task_param["source_id"]
+    message_dict = task_param["message_dict"]
+    success, msg = await save_extracted_to_memory_only(memory_id, message_dict, source_id)
+    if success:
+        TaskService.update_progress(task.id, {"progress": 1.0, "progress_msg": msg})
+        return True, msg
+
+    logging.error(msg)
+    TaskService.update_progress(task.id, {"progress": -1, "progress_msg": None})
+    return False, msg
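The Redis payload built in `queue_save_to_memory_task` is self-contained: it carries the task id, the memory id, the raw message's id, and the original `message_dict`, so the executor never has to re-read the conversation. A small illustrative check of that contract (the helper below is not part of RAGFlow):

```python
REQUIRED_KEYS = {"id", "task_id", "task_type", "memory_id", "source_id", "message_dict"}

def looks_like_memory_task(msg: dict) -> bool:
    # Mirrors the keys of task_message in queue_save_to_memory_task.
    return REQUIRED_KEYS.issubset(msg) and msg["task_type"] == "memory"

example = {
    "id": "uuid", "task_id": "uuid", "task_type": "memory",
    "memory_id": "mem-1", "source_id": 101,
    "message_dict": {"user_id": "", "agent_id": "a", "session_id": "s",
                     "user_input": "hi", "agent_response": "hello"},
}
assert looks_like_memory_task(example)
```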
@@ -138,6 +138,7 @@ class PipelineTaskType(StrEnum):
     RAPTOR = "RAPTOR"
     GRAPH_RAG = "GraphRAG"
     MINDMAP = "Mindmap"
+    MEMORY = "Memory"
 
 
 VALID_PIPELINE_TASK_TYPES = {PipelineTaskType.PARSE, PipelineTaskType.DOWNLOAD, PipelineTaskType.RAPTOR,
@@ -26,6 +26,7 @@ import time
 from api.db import PIPELINE_SPECIAL_PROGRESS_FREEZE_TASK_TYPES
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
+from api.db.joint_services.memory_message_service import handle_save_to_memory_task
 from common.connection_utils import timeout
 from common.metadata_utils import update_metadata_to, metadata_schema
 from rag.utils.base64_image import image2id
@@ -96,6 +97,7 @@ TASK_TYPE_TO_PIPELINE_TASK_TYPE = {
     "raptor": PipelineTaskType.RAPTOR,
     "graphrag": PipelineTaskType.GRAPH_RAG,
     "mindmap": PipelineTaskType.MINDMAP,
+    "memory": PipelineTaskType.MEMORY,
 }
 
 UNACKED_ITERATOR = None
@@ -197,6 +199,9 @@ async def collect():
         if task:
             task["doc_id"] = msg["doc_id"]
             task["doc_ids"] = msg.get("doc_ids", []) or []
+    elif msg.get("task_type") == PipelineTaskType.MEMORY.lower():
+        _, task_obj = TaskService.get_by_id(msg["id"])
+        task = task_obj.to_dict()
     else:
         task = TaskService.get_task(msg["id"])
 
@@ -215,6 +220,10 @@ async def collect():
         task["tenant_id"] = msg["tenant_id"]
         task["dataflow_id"] = msg["dataflow_id"]
         task["kb_id"] = msg.get("kb_id", "")
+    if task_type[:6] == "memory":
+        task["memory_id"] = msg["memory_id"]
+        task["source_id"] = msg["source_id"]
+        task["message_dict"] = msg["message_dict"]
     return redis_msg, task
 
 
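Memory tasks are not document tasks, so `collect()` cannot hydrate them through `TaskService.get_task`; instead it loads the row by id and copies the memory-specific fields straight off the queue message. An illustrative reduction of that enrichment step (the surrounding function is simplified away):

```python
def enrich_memory_task(task: dict, msg: dict) -> dict:
    # task_type[:6] == "memory" in the diff is the same prefix test as startswith.
    if msg.get("task_type", "")[:6] == "memory":
        task["memory_id"] = msg["memory_id"]
        task["source_id"] = msg["source_id"]
        task["message_dict"] = msg["message_dict"]
    return task

task = enrich_memory_task({"id": "t1"},
                          {"task_type": "memory", "memory_id": "m1",
                           "source_id": 7, "message_dict": {}})
assert task["memory_id"] == "m1"
```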
@@ -866,6 +875,10 @@ async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_c
 async def do_handle_task(task):
     task_type = task.get("task_type", "")
 
+    if task_type == "memory":
+        await handle_save_to_memory_task(task)
+        return
+
     if task_type == "dataflow" and task.get("doc_id", "") == CANVAS_DEBUG_DOC_ID:
         await run_dataflow(task)
         return
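`do_handle_task` short-circuits memory tasks before any document-pipeline setup, the same early-return pattern already used for dataflow debug tasks. A hedged sketch of the resulting dispatch order (handler names below are stand-ins, not the real coroutines):

```python
async def do_handle_task_sketch(task: dict) -> None:
    task_type = task.get("task_type", "")
    if task_type == "memory":
        await handle_memory(task)      # never touches the document pipeline
        return
    if task_type == "dataflow":
        await run_dataflow_stub(task)
        return
    await run_document_pipeline(task)  # everything else: parse/raptor/graphrag/...

async def handle_memory(task: dict) -> None: ...
async def run_dataflow_stub(task: dict) -> None: ...
async def run_document_pipeline(task: dict) -> None: ...
```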