From 7498bc63a382bc0e07c4185216992f0b21b8f053 Mon Sep 17 00:00:00 2001 From: Lynn Date: Fri, 26 Dec 2025 13:01:46 +0800 Subject: [PATCH] Fix: judge retrieval from (#12223) ### What problem does this PR solve? Judge retrieval from in retrieval component, and fix bug in message component ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --- agent/component/message.py | 2 +- agent/tools/retrieval.py | 6 +++++- api/apps/memories_app.py | 21 ++++++++++++++++++- .../joint_services/memory_message_service.py | 5 +++++ api/db/services/memory_service.py | 2 ++ 5 files changed, 33 insertions(+), 3 deletions(-) diff --git a/agent/component/message.py b/agent/component/message.py index b4f21018c..15acf77f3 100644 --- a/agent/component/message.py +++ b/agent/component/message.py @@ -427,7 +427,7 @@ class Message(ComponentBase): logging.error(f"Error converting content to {self._param.output_format}: {e}") async def _save_to_memory(self, content): - if not self._param.memory_ids: + if hasattr(self._param, "memory_ids") and not self._param.memory_ids: return True, "No memory selected." message_dict = { diff --git a/agent/tools/retrieval.py b/agent/tools/retrieval.py index fe69e90f8..21df960be 100644 --- a/agent/tools/retrieval.py +++ b/agent/tools/retrieval.py @@ -282,7 +282,11 @@ class Retrieval(ToolBase, ABC): self.set_output("formalized_content", self._param.empty_response) return - if self._param.kb_ids: + if hasattr(self._param, "retrieval_from") and self._param.retrieval_from == "dataset": + return await self._retrieve_kb(kwargs["query"]) + elif hasattr(self._param, "retrieval_from") and self._param.retrieval_from == "memory": + return await self._retrieve_memory(kwargs["query"]) + elif self._param.kb_ids: return await self._retrieve_kb(kwargs["query"]) elif hasattr(self._param, "memory_ids") and self._param.memory_ids: return await self._retrieve_memory(kwargs["query"]) diff --git a/api/apps/memories_app.py b/api/apps/memories_app.py index 4882b9526..746abad7d 100644 --- a/api/apps/memories_app.py +++ b/api/apps/memories_app.py @@ -21,11 +21,13 @@ from api.db import TenantPermission from api.db.services.memory_service import MemoryService from api.db.services.user_service import UserTenantService from api.db.services.canvas_service import UserCanvasService +from api.db.joint_services.memory_message_service import get_memory_size_cache, judge_system_prompt_is_default from api.utils.api_utils import validate_request, get_request_json, get_error_argument_result, get_json_result, \ not_allowed_parameters from api.utils.memory_utils import format_ret_data_from_memory, get_memory_type_human from api.constants import MEMORY_NAME_LIMIT, MEMORY_SIZE_LIMIT from memory.services.messages import MessageService +from memory.utils.prompt_util import PromptAssembler from common.constants import MemoryType, RetCode, ForgettingPolicy @@ -68,7 +70,7 @@ async def create_memory(): @manager.route("/", methods=["PUT"]) # noqa: F821 @login_required -@not_allowed_parameters("id", "tenant_id", "memory_type", "storage_type", "embd_id") +@not_allowed_parameters("id", "tenant_id", "memory_type", "storage_type") async def update_memory(memory_id): req = await get_request_json() update_dict = {} @@ -88,6 +90,14 @@ async def update_memory(memory_id): update_dict["permissions"] = req["permissions"] if req.get("llm_id"): update_dict["llm_id"] = req["llm_id"] + if req.get("embd_id"): + update_dict["embd_id"] = req["embd_id"] + if req.get("memory_type"): + memory_type = set(req["memory_type"]) + invalid_type = memory_type - {e.name.lower() for e in MemoryType} + if invalid_type: + return get_error_argument_result(f"Memory type '{invalid_type}' is not supported.") + update_dict["memory_type"] = list(memory_type) # check memory_size valid if req.get("memory_size"): if not 0 < int(req["memory_size"]) <= MEMORY_SIZE_LIMIT: @@ -123,6 +133,15 @@ async def update_memory(memory_id): if not to_update: return get_json_result(message=True, data=memory_dict) + # check memory empty when update embd_id, memory_type + memory_size = get_memory_size_cache(memory_id, current_memory.tenant_id) + not_allowed_update = [f for f in ["embd_id", "memory_type"] if f in to_update and memory_size > 0] + if not_allowed_update: + return get_error_argument_result(f"Can't update {not_allowed_update} when memory isn't empty.") + if "memory_type" in to_update: + if "system_prompt" not in to_update and judge_system_prompt_is_default(current_memory.system_prompt, current_memory.memory_type): + # update old default prompt, assemble a new one + to_update["system_prompt"] = PromptAssembler.assemble_system_prompt({"memory_type": to_update["memory_type"]}) try: MemoryService.update_memory(current_memory.tenant_id, memory_id, to_update) diff --git a/api/db/joint_services/memory_message_service.py b/api/db/joint_services/memory_message_service.py index 97e41b366..0b1016bc6 100644 --- a/api/db/joint_services/memory_message_service.py +++ b/api/db/joint_services/memory_message_service.py @@ -231,3 +231,8 @@ def init_memory_size_cache(): memory_size = memory_size_map.get(memory.id, 0) set_memory_size_cache(memory.id, memory_size) logging.info("Memory size cache init done.") + + +def judge_system_prompt_is_default(system_prompt: str, memory_type: int|list[str]): + memory_type_list = memory_type if isinstance(memory_type, list) else get_memory_type_human(memory_type) + return system_prompt == PromptAssembler.assemble_system_prompt({"memory_type": memory_type_list}) diff --git a/api/db/services/memory_service.py b/api/db/services/memory_service.py index da75adcb8..8a65d15e2 100644 --- a/api/db/services/memory_service.py +++ b/api/db/services/memory_service.py @@ -149,6 +149,8 @@ class MemoryService(CommonService): return 0 if "temperature" in update_dict and isinstance(update_dict["temperature"], str): update_dict["temperature"] = float(update_dict["temperature"]) + if "memory_type" in update_dict and isinstance(update_dict["memory_type"], list): + update_dict["memory_type"] = calculate_memory_type(update_dict["memory_type"]) if "name" in update_dict: update_dict["name"] = duplicate_name( cls.query,