mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-24 23:46:52 +08:00
Feat: message manage (#12083)
### What problem does this PR solve? Message CRUD. Issue #4213 ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@ -30,6 +30,7 @@ from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.tenant_llm_service import LLMFactoriesService, TenantLLMService
|
||||
from api.db.services.llm_service import LLMService, LLMBundle, get_init_tenant_llm
|
||||
from api.db.services.user_service import TenantService, UserTenantService
|
||||
from api.db.joint_services.memory_message_service import init_message_id_sequence, init_memory_size_cache
|
||||
from common.constants import LLMType
|
||||
from common.file_utils import get_project_base_directory
|
||||
from common import settings
|
||||
@ -169,6 +170,8 @@ def init_web_data():
|
||||
# init_superuser()
|
||||
|
||||
add_graph_templates()
|
||||
init_message_id_sequence()
|
||||
init_memory_size_cache()
|
||||
logging.info("init web data success:{}".format(time.time() - start_time))
|
||||
|
||||
|
||||
|
||||
233
api/db/joint_services/memory_message_service.py
Normal file
233
api/db/joint_services/memory_message_service.py
Normal file
@ -0,0 +1,233 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from common.time_utils import current_timestamp, timestamp_to_date, format_iso_8601_to_ymd_hms
|
||||
from common.constants import MemoryType, LLMType
|
||||
from common.doc_store.doc_store_base import FusionExpr
|
||||
from api.db.services.memory_service import MemoryService
|
||||
from api.db.services.tenant_llm_service import TenantLLMService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.utils.memory_utils import get_memory_type_human
|
||||
from memory.services.messages import MessageService
|
||||
from memory.services.query import MsgTextQuery, get_vector
|
||||
from memory.utils.prompt_util import PromptAssembler
|
||||
from memory.utils.msg_util import get_json_result_from_llm_response
|
||||
from rag.utils.redis_conn import REDIS_CONN
|
||||
|
||||
|
||||
async def save_to_memory(memory_id: str, message_dict: dict):
|
||||
"""
|
||||
:param memory_id:
|
||||
:param message_dict: {
|
||||
"user_id": str,
|
||||
"agent_id": str,
|
||||
"session_id": str,
|
||||
"user_input": str,
|
||||
"agent_response": str
|
||||
}
|
||||
"""
|
||||
memory = MemoryService.get_by_memory_id(memory_id)
|
||||
if not memory:
|
||||
return False, f"Memory '{memory_id}' not found."
|
||||
|
||||
tenant_id = memory.tenant_id
|
||||
extracted_content = await extract_by_llm(
|
||||
tenant_id,
|
||||
memory.llm_id,
|
||||
{"temperature": memory.temperature},
|
||||
get_memory_type_human(memory.memory_type),
|
||||
message_dict.get("user_input", ""),
|
||||
message_dict.get("agent_response", "")
|
||||
) if memory.memory_type != MemoryType.RAW.value else [] # if only RAW, no need to extract
|
||||
raw_message_id = REDIS_CONN.generate_auto_increment_id(namespace="memory")
|
||||
message_list = [{
|
||||
"message_id": raw_message_id,
|
||||
"message_type": MemoryType.RAW.name.lower(),
|
||||
"source_id": 0,
|
||||
"memory_id": memory_id,
|
||||
"user_id": "",
|
||||
"agent_id": message_dict["agent_id"],
|
||||
"session_id": message_dict["session_id"],
|
||||
"content": f"User Input: {message_dict.get('user_input')}\nAgent Response: {message_dict.get('agent_response')}",
|
||||
"valid_at": timestamp_to_date(current_timestamp()),
|
||||
"invalid_at": None,
|
||||
"forget_at": None,
|
||||
"status": True
|
||||
}, *[{
|
||||
"message_id": REDIS_CONN.generate_auto_increment_id(namespace="memory"),
|
||||
"message_type": content["message_type"],
|
||||
"source_id": raw_message_id,
|
||||
"memory_id": memory_id,
|
||||
"user_id": "",
|
||||
"agent_id": message_dict["agent_id"],
|
||||
"session_id": message_dict["session_id"],
|
||||
"content": content["content"],
|
||||
"valid_at": content["valid_at"],
|
||||
"invalid_at": content["invalid_at"] if content["invalid_at"] else None,
|
||||
"forget_at": None,
|
||||
"status": True
|
||||
} for content in extracted_content]]
|
||||
embedding_model = LLMBundle(tenant_id, llm_type=LLMType.EMBEDDING, llm_name=memory.embd_id)
|
||||
vector_list, _ = embedding_model.encode([msg["content"] for msg in message_list])
|
||||
for idx, msg in enumerate(message_list):
|
||||
msg["content_embed"] = vector_list[idx]
|
||||
vector_dimension = len(vector_list[0])
|
||||
if not MessageService.has_index(tenant_id, memory_id):
|
||||
created = MessageService.create_index(tenant_id, memory_id, vector_size=vector_dimension)
|
||||
if not created:
|
||||
return False, "Failed to create message index."
|
||||
|
||||
new_msg_size = sum([MessageService.calculate_message_size(m) for m in message_list])
|
||||
current_memory_size = get_memory_size_cache(memory_id, tenant_id)
|
||||
if new_msg_size + current_memory_size > memory.memory_size:
|
||||
size_to_delete = current_memory_size + new_msg_size - memory.memory_size
|
||||
if memory.forgetting_policy == "fifo":
|
||||
message_ids_to_delete, delete_size = MessageService.pick_messages_to_delete_by_fifo(memory_id, tenant_id, size_to_delete)
|
||||
MessageService.delete_message({"message_id": message_ids_to_delete}, tenant_id, memory_id)
|
||||
decrease_memory_size_cache(memory_id, tenant_id, delete_size)
|
||||
else:
|
||||
return False, "Failed to insert message into memory. Memory size reached limit and cannot decide which to delete."
|
||||
fail_cases = MessageService.insert_message(message_list, tenant_id, memory_id)
|
||||
if fail_cases:
|
||||
return False, "Failed to insert message into memory. Details: " + "; ".join(fail_cases)
|
||||
|
||||
increase_memory_size_cache(memory_id, tenant_id, new_msg_size)
|
||||
return True, "Message saved successfully."
|
||||
|
||||
|
||||
async def extract_by_llm(tenant_id: str, llm_id: str, extract_conf: dict, memory_type: List[str], user_input: str,
|
||||
agent_response: str, system_prompt: str = "", user_prompt: str="") -> List[dict]:
|
||||
llm_type = TenantLLMService.llm_id2llm_type(llm_id)
|
||||
if not llm_type:
|
||||
raise RuntimeError(f"Unknown type of LLM '{llm_id}'")
|
||||
if not system_prompt:
|
||||
system_prompt = PromptAssembler.assemble_system_prompt({"memory_type": memory_type})
|
||||
conversation_content = f"User Input: {user_input}\nAgent Response: {agent_response}"
|
||||
conversation_time = timestamp_to_date(current_timestamp())
|
||||
user_prompts = []
|
||||
if user_prompt:
|
||||
user_prompts.append({"role": "user", "content": user_prompt})
|
||||
user_prompts.append({"role": "user", "content": f"Conversation: {conversation_content}\nConversation Time: {conversation_time}\nCurrent Time: {conversation_time}"})
|
||||
else:
|
||||
user_prompts.append({"role": "user", "content": PromptAssembler.assemble_user_prompt(conversation_content, conversation_time, conversation_time)})
|
||||
llm = LLMBundle(tenant_id, llm_type, llm_id)
|
||||
res = await llm.async_chat(system_prompt, user_prompts, extract_conf)
|
||||
res_json = get_json_result_from_llm_response(res)
|
||||
return [{
|
||||
"content": extracted_content["content"],
|
||||
"valid_at": format_iso_8601_to_ymd_hms(extracted_content["valid_at"]),
|
||||
"invalid_at": format_iso_8601_to_ymd_hms(extracted_content["invalid_at"]) if extracted_content.get("invalid_at") else "",
|
||||
"message_type": message_type
|
||||
} for message_type, extracted_content_list in res_json.items() for extracted_content in extracted_content_list]
|
||||
|
||||
|
||||
def query_message(filter_dict: dict, params: dict):
|
||||
"""
|
||||
:param filter_dict: {
|
||||
"memory_id": List[str],
|
||||
"agent_id": optional
|
||||
"session_id": optional
|
||||
}
|
||||
:param params: {
|
||||
"query": question str,
|
||||
"similarity_threshold": float,
|
||||
"keywords_similarity_weight": float,
|
||||
"top_n": int
|
||||
}
|
||||
"""
|
||||
memory_ids = filter_dict["memory_id"]
|
||||
memory_list = MemoryService.get_by_ids(memory_ids)
|
||||
if not memory_list:
|
||||
return []
|
||||
|
||||
condition_dict = {k: v for k, v in filter_dict.items() if v}
|
||||
uids = [memory.tenant_id for memory in memory_list]
|
||||
|
||||
question = params["query"]
|
||||
question = question.strip()
|
||||
memory = memory_list[0]
|
||||
embd_model = LLMBundle(memory.tenant_id, llm_type=LLMType.EMBEDDING, llm_name=memory.embd_id)
|
||||
match_dense = get_vector(question, embd_model, similarity=params["similarity_threshold"])
|
||||
match_text, _ = MsgTextQuery().question(question, min_match=0.3)
|
||||
keywords_similarity_weight = params.get("keywords_similarity_weight", 0.7)
|
||||
fusion_expr = FusionExpr("weighted_sum", params["top_n"], {"weights": ",".join([str(keywords_similarity_weight), str(1 - keywords_similarity_weight)])})
|
||||
|
||||
return MessageService.search_message(memory_ids, condition_dict, uids, [match_text, match_dense, fusion_expr], params["top_n"])
|
||||
|
||||
|
||||
def init_message_id_sequence():
|
||||
message_id_redis_key = "id_generator:memory"
|
||||
if REDIS_CONN.exist(message_id_redis_key):
|
||||
current_max_id = REDIS_CONN.get(message_id_redis_key)
|
||||
logging.info(f"No need to init message_id sequence, current max id is {current_max_id}.")
|
||||
else:
|
||||
max_id = 1
|
||||
exist_memory_list = MemoryService.get_all_memory()
|
||||
if not exist_memory_list:
|
||||
REDIS_CONN.set(message_id_redis_key, max_id)
|
||||
else:
|
||||
max_id = MessageService.get_max_message_id(
|
||||
uid_list=[m.tenant_id for m in exist_memory_list],
|
||||
memory_ids=[m.id for m in exist_memory_list]
|
||||
)
|
||||
REDIS_CONN.set(message_id_redis_key, max_id)
|
||||
logging.info(f"Init message_id sequence done, current max id is {max_id}.")
|
||||
|
||||
|
||||
def get_memory_size_cache(memory_id: str, uid: str):
|
||||
redis_key = f"memory_{memory_id}"
|
||||
if REDIS_CONN.exists(redis_key):
|
||||
return REDIS_CONN.get(redis_key)
|
||||
else:
|
||||
memory_size_map = MessageService.calculate_memory_size(
|
||||
[memory_id],
|
||||
[uid]
|
||||
)
|
||||
memory_size = memory_size_map.get(memory_id, 0)
|
||||
set_memory_size_cache(memory_id, memory_size)
|
||||
return memory_size
|
||||
|
||||
|
||||
def set_memory_size_cache(memory_id: str, size: int):
|
||||
redis_key = f"memory_{memory_id}"
|
||||
return REDIS_CONN.set(redis_key, size)
|
||||
|
||||
|
||||
def increase_memory_size_cache(memory_id: str, uid: str, size: int):
|
||||
current_value = get_memory_size_cache(memory_id, uid)
|
||||
return set_memory_size_cache(memory_id, current_value + size)
|
||||
|
||||
|
||||
def decrease_memory_size_cache(memory_id: str, uid: str, size: int):
|
||||
current_value = get_memory_size_cache(memory_id, uid)
|
||||
return set_memory_size_cache(memory_id, max(current_value - size, 0))
|
||||
|
||||
|
||||
def init_memory_size_cache():
|
||||
memory_list = MemoryService.get_all_memory()
|
||||
if not memory_list:
|
||||
logging.info("No memory found, no need to init memory size.")
|
||||
else:
|
||||
memory_size_map = MessageService.calculate_memory_size(
|
||||
memory_ids=[m.id for m in memory_list],
|
||||
uid_list=[m.tenant_id for m in memory_list],
|
||||
)
|
||||
for memory in memory_list:
|
||||
memory_size = memory_size_map.get(memory.id, 0)
|
||||
set_memory_size_cache(memory.id, memory_size)
|
||||
logging.info("Memory size cache init done.")
|
||||
@ -34,6 +34,8 @@ from api.db.services.task_service import TaskService
|
||||
from api.db.services.tenant_llm_service import TenantLLMService
|
||||
from api.db.services.user_canvas_version import UserCanvasVersionService
|
||||
from api.db.services.user_service import TenantService, UserService, UserTenantService
|
||||
from api.db.services.memory_service import MemoryService
|
||||
from memory.services.messages import MessageService
|
||||
from rag.nlp import search
|
||||
from common.constants import ActiveEnum
|
||||
from common import settings
|
||||
@ -200,7 +202,16 @@ def delete_user_data(user_id: str) -> dict:
|
||||
done_msg += f"- Deleted {llm_delete_res} tenant-LLM records.\n"
|
||||
langfuse_delete_res = TenantLangfuseService.delete_ty_tenant_id(tenant_id)
|
||||
done_msg += f"- Deleted {langfuse_delete_res} langfuse records.\n"
|
||||
# step1.3 delete own tenant
|
||||
# step1.3 delete memory and messages
|
||||
user_memory = MemoryService.get_by_tenant_id(tenant_id)
|
||||
if user_memory:
|
||||
for memory in user_memory:
|
||||
if MessageService.has_index(tenant_id, memory.id):
|
||||
MessageService.delete_index(tenant_id, memory.id)
|
||||
done_msg += " Deleted memory index."
|
||||
memory_delete_res = MemoryService.delete_by_ids([m.id for m in user_memory])
|
||||
done_msg += f"Deleted {memory_delete_res} memory datasets."
|
||||
# step1.4 delete own tenant
|
||||
tenant_delete_res = TenantService.delete_by_id(tenant_id)
|
||||
done_msg += f"- Deleted {tenant_delete_res} tenant.\n"
|
||||
# step2 delete user-tenant relation
|
||||
|
||||
@ -123,6 +123,19 @@ class UserCanvasService(CommonService):
|
||||
logging.exception(e)
|
||||
return False, None
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_basic_info_by_canvas_ids(cls, canvas_id):
|
||||
fields = [
|
||||
cls.model.id,
|
||||
cls.model.avatar,
|
||||
cls.model.user_id,
|
||||
cls.model.title,
|
||||
cls.model.permission,
|
||||
cls.model.canvas_category
|
||||
]
|
||||
return cls.model.select(*fields).where(cls.model.id.in_(canvas_id)).dicts()
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_by_tenant_ids(cls, joined_tenant_ids, user_id,
|
||||
|
||||
@ -38,7 +38,7 @@ from common.time_utils import current_timestamp, get_format_time
|
||||
from common.constants import LLMType, ParserType, StatusEnum, TaskStatus, SVR_CONSUMER_GROUP_NAME
|
||||
from rag.nlp import rag_tokenizer, search
|
||||
from rag.utils.redis_conn import REDIS_CONN
|
||||
from rag.utils.doc_store_conn import OrderByExpr
|
||||
from common.doc_store.doc_store_base import OrderByExpr
|
||||
from common import settings
|
||||
|
||||
|
||||
@ -345,7 +345,7 @@ class DocumentService(CommonService):
|
||||
chunks = settings.docStoreConn.search(["img_id"], [], {"doc_id": doc.id}, [], OrderByExpr(),
|
||||
page * page_size, page_size, search.index_name(tenant_id),
|
||||
[doc.kb_id])
|
||||
chunk_ids = settings.docStoreConn.get_chunk_ids(chunks)
|
||||
chunk_ids = settings.docStoreConn.get_doc_ids(chunks)
|
||||
if not chunk_ids:
|
||||
break
|
||||
all_chunk_ids.extend(chunk_ids)
|
||||
@ -1230,8 +1230,8 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
||||
d["q_%d_vec" % len(v)] = v
|
||||
for b in range(0, len(cks), es_bulk_size):
|
||||
if try_create_idx:
|
||||
if not settings.docStoreConn.indexExist(idxnm, kb_id):
|
||||
settings.docStoreConn.createIdx(idxnm, kb_id, len(vectors[0]))
|
||||
if not settings.docStoreConn.index_exist(idxnm, kb_id):
|
||||
settings.docStoreConn.create_idx(idxnm, kb_id, len(vectors[0]))
|
||||
try_create_idx = False
|
||||
settings.docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id)
|
||||
|
||||
|
||||
@ -15,7 +15,6 @@
|
||||
#
|
||||
from typing import List
|
||||
|
||||
from api.apps import current_user
|
||||
from api.db.db_models import DB, Memory, User
|
||||
from api.db.services import duplicate_name
|
||||
from api.db.services.common_service import CommonService
|
||||
@ -23,6 +22,7 @@ from api.utils.memory_utils import calculate_memory_type
|
||||
from api.constants import MEMORY_NAME_LIMIT
|
||||
from common.misc_utils import get_uuid
|
||||
from common.time_utils import get_format_time, current_timestamp
|
||||
from memory.utils.prompt_util import PromptAssembler
|
||||
|
||||
|
||||
class MemoryService(CommonService):
|
||||
@ -34,6 +34,17 @@ class MemoryService(CommonService):
|
||||
def get_by_memory_id(cls, memory_id: str):
|
||||
return cls.model.select().where(cls.model.id == memory_id).first()
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_by_tenant_id(cls, tenant_id: str):
|
||||
return cls.model.select().where(cls.model.tenant_id == tenant_id)
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_all_memory(cls):
|
||||
memory_list = cls.model.select()
|
||||
return list(memory_list)
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_with_owner_name_by_id(cls, memory_id: str):
|
||||
@ -53,7 +64,9 @@ class MemoryService(CommonService):
|
||||
cls.model.forgetting_policy,
|
||||
cls.model.temperature,
|
||||
cls.model.system_prompt,
|
||||
cls.model.user_prompt
|
||||
cls.model.user_prompt,
|
||||
cls.model.create_date,
|
||||
cls.model.create_time
|
||||
]
|
||||
memory = cls.model.select(*fields).join(User, on=(cls.model.tenant_id == User.id)).where(
|
||||
cls.model.id == memory_id
|
||||
@ -72,7 +85,9 @@ class MemoryService(CommonService):
|
||||
cls.model.memory_type,
|
||||
cls.model.storage_type,
|
||||
cls.model.permissions,
|
||||
cls.model.description
|
||||
cls.model.description,
|
||||
cls.model.create_time,
|
||||
cls.model.create_date
|
||||
]
|
||||
memories = cls.model.select(*fields).join(User, on=(cls.model.tenant_id == User.id))
|
||||
if filter_dict.get("tenant_id"):
|
||||
@ -110,6 +125,7 @@ class MemoryService(CommonService):
|
||||
"tenant_id": tenant_id,
|
||||
"embd_id": embd_id,
|
||||
"llm_id": llm_id,
|
||||
"system_prompt": PromptAssembler.assemble_system_prompt({"memory_type": memory_type}),
|
||||
"create_time": current_timestamp(),
|
||||
"create_date": get_format_time(),
|
||||
"update_time": current_timestamp(),
|
||||
@ -126,7 +142,7 @@ class MemoryService(CommonService):
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def update_memory(cls, memory_id: str, update_dict: dict):
|
||||
def update_memory(cls, tenant_id: str, memory_id: str, update_dict: dict):
|
||||
if not update_dict:
|
||||
return 0
|
||||
if "temperature" in update_dict and isinstance(update_dict["temperature"], str):
|
||||
@ -135,7 +151,7 @@ class MemoryService(CommonService):
|
||||
update_dict["name"] = duplicate_name(
|
||||
cls.query,
|
||||
name=update_dict["name"],
|
||||
tenant_id=current_user.id
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
update_dict.update({
|
||||
"update_time": current_timestamp(),
|
||||
|
||||
Reference in New Issue
Block a user