mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-26 17:16:52 +08:00
Fix: judge retrieval from (#12223)
### What problem does this PR solve? Judge retrieval from in retrieval component, and fix bug in message component ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
@ -427,7 +427,7 @@ class Message(ComponentBase):
|
||||
logging.error(f"Error converting content to {self._param.output_format}: {e}")
|
||||
|
||||
async def _save_to_memory(self, content):
|
||||
if not self._param.memory_ids:
|
||||
if hasattr(self._param, "memory_ids") and not self._param.memory_ids:
|
||||
return True, "No memory selected."
|
||||
|
||||
message_dict = {
|
||||
|
||||
@ -282,7 +282,11 @@ class Retrieval(ToolBase, ABC):
|
||||
self.set_output("formalized_content", self._param.empty_response)
|
||||
return
|
||||
|
||||
if self._param.kb_ids:
|
||||
if hasattr(self._param, "retrieval_from") and self._param.retrieval_from == "dataset":
|
||||
return await self._retrieve_kb(kwargs["query"])
|
||||
elif hasattr(self._param, "retrieval_from") and self._param.retrieval_from == "memory":
|
||||
return await self._retrieve_memory(kwargs["query"])
|
||||
elif self._param.kb_ids:
|
||||
return await self._retrieve_kb(kwargs["query"])
|
||||
elif hasattr(self._param, "memory_ids") and self._param.memory_ids:
|
||||
return await self._retrieve_memory(kwargs["query"])
|
||||
|
||||
@ -21,11 +21,13 @@ from api.db import TenantPermission
|
||||
from api.db.services.memory_service import MemoryService
|
||||
from api.db.services.user_service import UserTenantService
|
||||
from api.db.services.canvas_service import UserCanvasService
|
||||
from api.db.joint_services.memory_message_service import get_memory_size_cache, judge_system_prompt_is_default
|
||||
from api.utils.api_utils import validate_request, get_request_json, get_error_argument_result, get_json_result, \
|
||||
not_allowed_parameters
|
||||
from api.utils.memory_utils import format_ret_data_from_memory, get_memory_type_human
|
||||
from api.constants import MEMORY_NAME_LIMIT, MEMORY_SIZE_LIMIT
|
||||
from memory.services.messages import MessageService
|
||||
from memory.utils.prompt_util import PromptAssembler
|
||||
from common.constants import MemoryType, RetCode, ForgettingPolicy
|
||||
|
||||
|
||||
@ -68,7 +70,7 @@ async def create_memory():
|
||||
|
||||
@manager.route("/<memory_id>", methods=["PUT"]) # noqa: F821
|
||||
@login_required
|
||||
@not_allowed_parameters("id", "tenant_id", "memory_type", "storage_type", "embd_id")
|
||||
@not_allowed_parameters("id", "tenant_id", "memory_type", "storage_type")
|
||||
async def update_memory(memory_id):
|
||||
req = await get_request_json()
|
||||
update_dict = {}
|
||||
@ -88,6 +90,14 @@ async def update_memory(memory_id):
|
||||
update_dict["permissions"] = req["permissions"]
|
||||
if req.get("llm_id"):
|
||||
update_dict["llm_id"] = req["llm_id"]
|
||||
if req.get("embd_id"):
|
||||
update_dict["embd_id"] = req["embd_id"]
|
||||
if req.get("memory_type"):
|
||||
memory_type = set(req["memory_type"])
|
||||
invalid_type = memory_type - {e.name.lower() for e in MemoryType}
|
||||
if invalid_type:
|
||||
return get_error_argument_result(f"Memory type '{invalid_type}' is not supported.")
|
||||
update_dict["memory_type"] = list(memory_type)
|
||||
# check memory_size valid
|
||||
if req.get("memory_size"):
|
||||
if not 0 < int(req["memory_size"]) <= MEMORY_SIZE_LIMIT:
|
||||
@ -123,6 +133,15 @@ async def update_memory(memory_id):
|
||||
|
||||
if not to_update:
|
||||
return get_json_result(message=True, data=memory_dict)
|
||||
# check memory empty when update embd_id, memory_type
|
||||
memory_size = get_memory_size_cache(memory_id, current_memory.tenant_id)
|
||||
not_allowed_update = [f for f in ["embd_id", "memory_type"] if f in to_update and memory_size > 0]
|
||||
if not_allowed_update:
|
||||
return get_error_argument_result(f"Can't update {not_allowed_update} when memory isn't empty.")
|
||||
if "memory_type" in to_update:
|
||||
if "system_prompt" not in to_update and judge_system_prompt_is_default(current_memory.system_prompt, current_memory.memory_type):
|
||||
# update old default prompt, assemble a new one
|
||||
to_update["system_prompt"] = PromptAssembler.assemble_system_prompt({"memory_type": to_update["memory_type"]})
|
||||
|
||||
try:
|
||||
MemoryService.update_memory(current_memory.tenant_id, memory_id, to_update)
|
||||
|
||||
@ -231,3 +231,8 @@ def init_memory_size_cache():
|
||||
memory_size = memory_size_map.get(memory.id, 0)
|
||||
set_memory_size_cache(memory.id, memory_size)
|
||||
logging.info("Memory size cache init done.")
|
||||
|
||||
|
||||
def judge_system_prompt_is_default(system_prompt: str, memory_type: int|list[str]):
|
||||
memory_type_list = memory_type if isinstance(memory_type, list) else get_memory_type_human(memory_type)
|
||||
return system_prompt == PromptAssembler.assemble_system_prompt({"memory_type": memory_type_list})
|
||||
|
||||
@ -149,6 +149,8 @@ class MemoryService(CommonService):
|
||||
return 0
|
||||
if "temperature" in update_dict and isinstance(update_dict["temperature"], str):
|
||||
update_dict["temperature"] = float(update_dict["temperature"])
|
||||
if "memory_type" in update_dict and isinstance(update_dict["memory_type"], list):
|
||||
update_dict["memory_type"] = calculate_memory_type(update_dict["memory_type"])
|
||||
if "name" in update_dict:
|
||||
update_dict["name"] = duplicate_name(
|
||||
cls.query,
|
||||
|
||||
Reference in New Issue
Block a user