mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-26 17:16:52 +08:00
Fix: judge retrieval from (#12223)
### What problem does this PR solve? Judge retrieval from in retrieval component, and fix bug in message component ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
@ -21,11 +21,13 @@ from api.db import TenantPermission
|
||||
from api.db.services.memory_service import MemoryService
|
||||
from api.db.services.user_service import UserTenantService
|
||||
from api.db.services.canvas_service import UserCanvasService
|
||||
from api.db.joint_services.memory_message_service import get_memory_size_cache, judge_system_prompt_is_default
|
||||
from api.utils.api_utils import validate_request, get_request_json, get_error_argument_result, get_json_result, \
|
||||
not_allowed_parameters
|
||||
from api.utils.memory_utils import format_ret_data_from_memory, get_memory_type_human
|
||||
from api.constants import MEMORY_NAME_LIMIT, MEMORY_SIZE_LIMIT
|
||||
from memory.services.messages import MessageService
|
||||
from memory.utils.prompt_util import PromptAssembler
|
||||
from common.constants import MemoryType, RetCode, ForgettingPolicy
|
||||
|
||||
|
||||
@ -68,7 +70,7 @@ async def create_memory():
|
||||
|
||||
@manager.route("/<memory_id>", methods=["PUT"]) # noqa: F821
|
||||
@login_required
|
||||
@not_allowed_parameters("id", "tenant_id", "memory_type", "storage_type", "embd_id")
|
||||
@not_allowed_parameters("id", "tenant_id", "memory_type", "storage_type")
|
||||
async def update_memory(memory_id):
|
||||
req = await get_request_json()
|
||||
update_dict = {}
|
||||
@ -88,6 +90,14 @@ async def update_memory(memory_id):
|
||||
update_dict["permissions"] = req["permissions"]
|
||||
if req.get("llm_id"):
|
||||
update_dict["llm_id"] = req["llm_id"]
|
||||
if req.get("embd_id"):
|
||||
update_dict["embd_id"] = req["embd_id"]
|
||||
if req.get("memory_type"):
|
||||
memory_type = set(req["memory_type"])
|
||||
invalid_type = memory_type - {e.name.lower() for e in MemoryType}
|
||||
if invalid_type:
|
||||
return get_error_argument_result(f"Memory type '{invalid_type}' is not supported.")
|
||||
update_dict["memory_type"] = list(memory_type)
|
||||
# check memory_size valid
|
||||
if req.get("memory_size"):
|
||||
if not 0 < int(req["memory_size"]) <= MEMORY_SIZE_LIMIT:
|
||||
@ -123,6 +133,15 @@ async def update_memory(memory_id):
|
||||
|
||||
if not to_update:
|
||||
return get_json_result(message=True, data=memory_dict)
|
||||
# check memory empty when update embd_id, memory_type
|
||||
memory_size = get_memory_size_cache(memory_id, current_memory.tenant_id)
|
||||
not_allowed_update = [f for f in ["embd_id", "memory_type"] if f in to_update and memory_size > 0]
|
||||
if not_allowed_update:
|
||||
return get_error_argument_result(f"Can't update {not_allowed_update} when memory isn't empty.")
|
||||
if "memory_type" in to_update:
|
||||
if "system_prompt" not in to_update and judge_system_prompt_is_default(current_memory.system_prompt, current_memory.memory_type):
|
||||
# update old default prompt, assemble a new one
|
||||
to_update["system_prompt"] = PromptAssembler.assemble_system_prompt({"memory_type": to_update["memory_type"]})
|
||||
|
||||
try:
|
||||
MemoryService.update_memory(current_memory.tenant_id, memory_id, to_update)
|
||||
|
||||
Reference in New Issue
Block a user