Feat: message manage (#12083)

### What problem does this PR solve?

Message CRUD.

Issue #4213 

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
Lynn
2025-12-23 21:16:25 +08:00
committed by GitHub
parent bab6a4a219
commit 17b8bb62b6
49 changed files with 3480 additions and 1031 deletions

View File

@ -192,7 +192,7 @@ async def rerun():
if 0 < doc["progress"] < 1:
return get_data_error_result(message=f"`{doc['name']}` is processing...")
if settings.docStoreConn.indexExist(search.index_name(current_user.id), doc["kb_id"]):
if settings.docStoreConn.index_exist(search.index_name(current_user.id), doc["kb_id"]):
settings.docStoreConn.delete({"doc_id": doc["id"]}, search.index_name(current_user.id), doc["kb_id"])
doc["progress_msg"] = ""
doc["chunk_num"] = 0

View File

@ -564,7 +564,7 @@ async def run():
DocumentService.update_by_id(id, info)
if req.get("delete", False):
TaskService.filter_delete([Task.doc_id == id])
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
if settings.docStoreConn.index_exist(search.index_name(tenant_id), doc.kb_id):
settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id)
if str(req["run"]) == TaskStatus.RUNNING.value:
@ -615,7 +615,7 @@ async def rename():
"title_tks": title_tks,
"title_sm_tks": rag_tokenizer.fine_grained_tokenize(title_tks),
}
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
if settings.docStoreConn.index_exist(search.index_name(tenant_id), doc.kb_id):
settings.docStoreConn.update(
{"doc_id": req["doc_id"]},
es_body,
@ -696,7 +696,7 @@ async def change_parser():
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
if not tenant_id:
return get_data_error_result(message="Tenant not found!")
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
if settings.docStoreConn.index_exist(search.index_name(tenant_id), doc.kb_id):
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
return None

View File

@ -39,9 +39,9 @@ from api.utils.api_utils import get_json_result
from rag.nlp import search
from api.constants import DATASET_NAME_LIMIT
from rag.utils.redis_conn import REDIS_CONN
from rag.utils.doc_store_conn import OrderByExpr
from common.constants import RetCode, PipelineTaskType, StatusEnum, VALID_TASK_STATUS, FileSource, LLMType, PAGERANK_FLD
from common import settings
from common.doc_store.doc_store_base import OrderByExpr
from api.apps import login_required, current_user
@ -285,7 +285,7 @@ async def rm():
message="Database error (Knowledgebase removal)!")
for kb in kbs:
settings.docStoreConn.delete({"kb_id": kb.id}, search.index_name(kb.tenant_id), kb.id)
settings.docStoreConn.deleteIdx(search.index_name(kb.tenant_id), kb.id)
settings.docStoreConn.delete_idx(search.index_name(kb.tenant_id), kb.id)
if hasattr(settings.STORAGE_IMPL, 'remove_bucket'):
settings.STORAGE_IMPL.remove_bucket(kb.id)
return get_json_result(data=True)
@ -386,7 +386,7 @@ def knowledge_graph(kb_id):
}
obj = {"graph": {}, "mind_map": {}}
if not settings.docStoreConn.indexExist(search.index_name(kb.tenant_id), kb_id):
if not settings.docStoreConn.index_exist(search.index_name(kb.tenant_id), kb_id):
return get_json_result(data=obj)
sres = settings.retriever.search(req, search.index_name(kb.tenant_id), [kb_id])
if not len(sres.ids):
@ -858,11 +858,11 @@ async def check_embedding():
index_nm = search.index_name(tenant_id)
res0 = docStoreConn.search(
selectFields=[], highlightFields=[],
select_fields=[], highlight_fields=[],
condition={"kb_id": kb_id, "available_int": 1},
matchExprs=[], orderBy=OrderByExpr(),
match_expressions=[], order_by=OrderByExpr(),
offset=0, limit=1,
indexNames=index_nm, knowledgebaseIds=[kb_id]
index_names=index_nm, knowledgebase_ids=[kb_id]
)
total = docStoreConn.get_total(res0)
if total <= 0:
@ -874,14 +874,14 @@ async def check_embedding():
for off in offsets:
res1 = docStoreConn.search(
selectFields=list(base_fields),
highlightFields=[],
select_fields=list(base_fields),
highlight_fields=[],
condition={"kb_id": kb_id, "available_int": 1},
matchExprs=[], orderBy=OrderByExpr(),
match_expressions=[], order_by=OrderByExpr(),
offset=off, limit=1,
indexNames=index_nm, knowledgebaseIds=[kb_id]
index_names=index_nm, knowledgebase_ids=[kb_id]
)
ids = docStoreConn.get_chunk_ids(res1)
ids = docStoreConn.get_doc_ids(res1)
if not ids:
continue

View File

@ -20,10 +20,12 @@ from api.apps import login_required, current_user
from api.db import TenantPermission
from api.db.services.memory_service import MemoryService
from api.db.services.user_service import UserTenantService
from api.db.services.canvas_service import UserCanvasService
from api.utils.api_utils import validate_request, get_request_json, get_error_argument_result, get_json_result, \
not_allowed_parameters
from api.utils.memory_utils import format_ret_data_from_memory, get_memory_type_human
from api.constants import MEMORY_NAME_LIMIT, MEMORY_SIZE_LIMIT
from memory.services.messages import MessageService
from common.constants import MemoryType, RetCode, ForgettingPolicy
@ -57,7 +59,6 @@ async def create_memory():
if res:
return get_json_result(message=True, data=format_ret_data_from_memory(memory))
else:
return get_json_result(message=memory, code=RetCode.SERVER_ERROR)
@ -124,7 +125,7 @@ async def update_memory(memory_id):
return get_json_result(message=True, data=memory_dict)
try:
MemoryService.update_memory(memory_id, to_update)
MemoryService.update_memory(current_memory.tenant_id, memory_id, to_update)
updated_memory = MemoryService.get_by_memory_id(memory_id)
return get_json_result(message=True, data=format_ret_data_from_memory(updated_memory))
@ -133,7 +134,7 @@ async def update_memory(memory_id):
return get_json_result(message=str(e), code=RetCode.SERVER_ERROR)
@manager.route("/<memory_id>", methods=["DELETE"]) # noqa: F821
@manager.route("/<memory_id>", methods=["DELETE"]) # noqa: F821
@login_required
async def delete_memory(memory_id):
memory = MemoryService.get_by_memory_id(memory_id)
@ -141,13 +142,14 @@ async def delete_memory(memory_id):
return get_json_result(message=True, code=RetCode.NOT_FOUND)
try:
MemoryService.delete_memory(memory_id)
MessageService.delete_message({"memory_id": memory_id}, memory.tenant_id, memory_id)
return get_json_result(message=True)
except Exception as e:
logging.error(e)
return get_json_result(message=str(e), code=RetCode.SERVER_ERROR)
@manager.route("", methods=["GET"]) # noqa: F821
@manager.route("", methods=["GET"]) # noqa: F821
@login_required
async def list_memory():
args = request.args
@ -183,3 +185,26 @@ async def get_memory_config(memory_id):
if not memory:
return get_json_result(code=RetCode.NOT_FOUND, message=f"Memory '{memory_id}' not found.")
return get_json_result(message=True, data=format_ret_data_from_memory(memory))
@manager.route("/<memory_id>", methods=["GET"]) # noqa: F821
@login_required
async def get_memory_detail(memory_id):
args = request.args
agent_ids = args.getlist("agent_id")
keywords = args.get("keywords", "")
keywords = keywords.strip()
page = int(args.get("page", 1))
page_size = int(args.get("page_size", 50))
memory = MemoryService.get_by_memory_id(memory_id)
if not memory:
return get_json_result(code=RetCode.NOT_FOUND, message=f"Memory '{memory_id}' not found.")
messages = MessageService.list_message(
memory.tenant_id, memory_id, agent_ids, keywords, page, page_size)
agent_name_mapping = {}
if messages["message_list"]:
agent_list = UserCanvasService.get_basic_info_by_canvas_ids([message["agent_id"] for message in messages["message_list"]])
agent_name_mapping = {agent["id"]: agent["title"] for agent in agent_list}
for message in messages["message_list"]:
message["agent_name"] = agent_name_mapping.get(message["agent_id"], "Unknown")
return get_json_result(data={"messages": messages, "storage_type": memory.storage_type}, message=True)

169
api/apps/messages_app.py Normal file
View File

@ -0,0 +1,169 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from quart import request
from api.apps import login_required
from api.db.services.memory_service import MemoryService
from common.time_utils import current_timestamp, timestamp_to_date
from memory.services.messages import MessageService
from api.db.joint_services import memory_message_service
from api.db.joint_services.memory_message_service import query_message
from api.utils.api_utils import validate_request, get_request_json, get_error_argument_result, get_json_result
from common.constants import RetCode
@manager.route("", methods=["POST"]) # noqa: F821
@login_required
@validate_request("memory_id", "agent_id", "session_id", "user_input", "agent_response")
async def add_message():
req = await get_request_json()
memory_ids = req["memory_id"]
agent_id = req["agent_id"]
session_id = req["session_id"]
user_id = req["user_id"] if req.get("user_id") else ""
user_input = req["user_input"]
agent_response = req["agent_response"]
res = []
for memory_id in memory_ids:
success, msg = await memory_message_service.save_to_memory(
memory_id,
{
"user_id": user_id,
"agent_id": agent_id,
"session_id": session_id,
"user_input": user_input,
"agent_response": agent_response
}
)
res.append({
"memory_id": memory_id,
"success": success,
"message": msg
})
if all([r["success"] for r in res]):
return get_json_result(message="Successfully added to memories.")
return get_json_result(code=RetCode.SERVER_ERROR, message="Some messages failed to add.", data=res)
@manager.route("/<memory_id>:<message_id>", methods=["DELETE"]) # noqa: F821
@login_required
async def forget_message(memory_id: str, message_id: int):
memory = MemoryService.get_by_memory_id(memory_id)
if not memory:
return get_json_result(code=RetCode.NOT_FOUND, message=f"Memory '{memory_id}' not found.")
forget_time = timestamp_to_date(current_timestamp())
update_succeed = MessageService.update_message(
{"memory_id": memory_id, "message_id": int(message_id)},
{"forget_at": forget_time},
memory.tenant_id, memory_id)
if update_succeed:
return get_json_result(message=update_succeed)
else:
return get_json_result(code=RetCode.SERVER_ERROR, message=f"Failed to forget message '{message_id}' in memory '{memory_id}'.")
@manager.route("/<memory_id>:<message_id>", methods=["PUT"]) # noqa: F821
@login_required
@validate_request("status")
async def update_message(memory_id: str, message_id: int):
req = await get_request_json()
status = req["status"]
if not isinstance(status, bool):
return get_error_argument_result("Status must be a boolean.")
memory = MemoryService.get_by_memory_id(memory_id)
if not memory:
return get_json_result(code=RetCode.NOT_FOUND, message=f"Memory '{memory_id}' not found.")
update_succeed = MessageService.update_message({"memory_id": memory_id, "message_id": int(message_id)}, {"status": status}, memory.tenant_id, memory_id)
if update_succeed:
return get_json_result(message=update_succeed)
else:
return get_json_result(code=RetCode.SERVER_ERROR, message=f"Failed to set status for message '{message_id}' in memory '{memory_id}'.")
@manager.route("/search", methods=["GET"]) # noqa: F821
@login_required
async def search_message():
args = request.args
print(args, flush=True)
empty_fields = [f for f in ["memory_id", "query"] if not args.get(f)]
if empty_fields:
return get_error_argument_result(f"{', '.join(empty_fields)} can't be empty.")
memory_ids = args.getlist("memory_id")
query = args.get("query")
similarity_threshold = float(args.get("similarity_threshold", 0.2))
keywords_similarity_weight = float(args.get("keywords_similarity_weight", 0.7))
top_n = int(args.get("top_n", 5))
agent_id = args.get("agent_id", "")
session_id = args.get("session_id", "")
filter_dict = {
"memory_id": memory_ids,
"agent_id": agent_id,
"session_id": session_id
}
params = {
"query": query,
"similarity_threshold": similarity_threshold,
"keywords_similarity_weight": keywords_similarity_weight,
"top_n": top_n
}
res = query_message(filter_dict, params)
return get_json_result(message=True, data=res)
@manager.route("", methods=["GET"]) # noqa: F821
@login_required
async def get_messages():
args = request.args
memory_ids = args.getlist("memory_id")
agent_id = args.get("agent_id", "")
session_id = args.get("session_id", "")
limit = int(args.get("limit", 10))
if not memory_ids:
return get_error_argument_result("memory_ids is required.")
memory_list = MemoryService.get_by_ids(memory_ids)
uids = [memory.tenant_id for memory in memory_list]
res = MessageService.get_recent_messages(
uids,
memory_ids,
agent_id,
session_id,
limit
)
return get_json_result(message=True, data=res)
@manager.route("/<memory_id>:<message_id>/content", methods=["GET"]) # noqa: F821
@login_required
async def get_message_content(memory_id:str, message_id: int):
memory = MemoryService.get_by_memory_id(memory_id)
if not memory:
return get_json_result(code=RetCode.NOT_FOUND, message=f"Memory '{memory_id}' not found.")
res = MessageService.get_by_message_id(memory_id, message_id, memory.tenant_id)
if res:
return get_json_result(message=True, data=res)
else:
return get_json_result(code=RetCode.NOT_FOUND, message=f"Message '{message_id}' in memory '{memory_id}' not found.")

View File

@ -495,7 +495,7 @@ def knowledge_graph(tenant_id, dataset_id):
}
obj = {"graph": {}, "mind_map": {}}
if not settings.docStoreConn.indexExist(search.index_name(kb.tenant_id), dataset_id):
if not settings.docStoreConn.index_exist(search.index_name(kb.tenant_id), dataset_id):
return get_result(data=obj)
sres = settings.retriever.search(req, search.index_name(kb.tenant_id), [dataset_id])
if not len(sres.ids):

View File

@ -1080,7 +1080,7 @@ def list_chunks(tenant_id, dataset_id, document_id):
res["chunks"].append(final_chunk)
_ = Chunk(**final_chunk)
elif settings.docStoreConn.indexExist(search.index_name(tenant_id), dataset_id):
elif settings.docStoreConn.index_exist(search.index_name(tenant_id), dataset_id):
sres = settings.retriever.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None, highlight=True)
res["total"] = sres.total
for id in sres.ids:

View File

@ -30,6 +30,7 @@ from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.tenant_llm_service import LLMFactoriesService, TenantLLMService
from api.db.services.llm_service import LLMService, LLMBundle, get_init_tenant_llm
from api.db.services.user_service import TenantService, UserTenantService
from api.db.joint_services.memory_message_service import init_message_id_sequence, init_memory_size_cache
from common.constants import LLMType
from common.file_utils import get_project_base_directory
from common import settings
@ -169,6 +170,8 @@ def init_web_data():
# init_superuser()
add_graph_templates()
init_message_id_sequence()
init_memory_size_cache()
logging.info("init web data success:{}".format(time.time() - start_time))

View File

@ -0,0 +1,233 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
from typing import List
from common.time_utils import current_timestamp, timestamp_to_date, format_iso_8601_to_ymd_hms
from common.constants import MemoryType, LLMType
from common.doc_store.doc_store_base import FusionExpr
from api.db.services.memory_service import MemoryService
from api.db.services.tenant_llm_service import TenantLLMService
from api.db.services.llm_service import LLMBundle
from api.utils.memory_utils import get_memory_type_human
from memory.services.messages import MessageService
from memory.services.query import MsgTextQuery, get_vector
from memory.utils.prompt_util import PromptAssembler
from memory.utils.msg_util import get_json_result_from_llm_response
from rag.utils.redis_conn import REDIS_CONN
async def save_to_memory(memory_id: str, message_dict: dict):
"""
:param memory_id:
:param message_dict: {
"user_id": str,
"agent_id": str,
"session_id": str,
"user_input": str,
"agent_response": str
}
"""
memory = MemoryService.get_by_memory_id(memory_id)
if not memory:
return False, f"Memory '{memory_id}' not found."
tenant_id = memory.tenant_id
extracted_content = await extract_by_llm(
tenant_id,
memory.llm_id,
{"temperature": memory.temperature},
get_memory_type_human(memory.memory_type),
message_dict.get("user_input", ""),
message_dict.get("agent_response", "")
) if memory.memory_type != MemoryType.RAW.value else [] # if only RAW, no need to extract
raw_message_id = REDIS_CONN.generate_auto_increment_id(namespace="memory")
message_list = [{
"message_id": raw_message_id,
"message_type": MemoryType.RAW.name.lower(),
"source_id": 0,
"memory_id": memory_id,
"user_id": "",
"agent_id": message_dict["agent_id"],
"session_id": message_dict["session_id"],
"content": f"User Input: {message_dict.get('user_input')}\nAgent Response: {message_dict.get('agent_response')}",
"valid_at": timestamp_to_date(current_timestamp()),
"invalid_at": None,
"forget_at": None,
"status": True
}, *[{
"message_id": REDIS_CONN.generate_auto_increment_id(namespace="memory"),
"message_type": content["message_type"],
"source_id": raw_message_id,
"memory_id": memory_id,
"user_id": "",
"agent_id": message_dict["agent_id"],
"session_id": message_dict["session_id"],
"content": content["content"],
"valid_at": content["valid_at"],
"invalid_at": content["invalid_at"] if content["invalid_at"] else None,
"forget_at": None,
"status": True
} for content in extracted_content]]
embedding_model = LLMBundle(tenant_id, llm_type=LLMType.EMBEDDING, llm_name=memory.embd_id)
vector_list, _ = embedding_model.encode([msg["content"] for msg in message_list])
for idx, msg in enumerate(message_list):
msg["content_embed"] = vector_list[idx]
vector_dimension = len(vector_list[0])
if not MessageService.has_index(tenant_id, memory_id):
created = MessageService.create_index(tenant_id, memory_id, vector_size=vector_dimension)
if not created:
return False, "Failed to create message index."
new_msg_size = sum([MessageService.calculate_message_size(m) for m in message_list])
current_memory_size = get_memory_size_cache(memory_id, tenant_id)
if new_msg_size + current_memory_size > memory.memory_size:
size_to_delete = current_memory_size + new_msg_size - memory.memory_size
if memory.forgetting_policy == "fifo":
message_ids_to_delete, delete_size = MessageService.pick_messages_to_delete_by_fifo(memory_id, tenant_id, size_to_delete)
MessageService.delete_message({"message_id": message_ids_to_delete}, tenant_id, memory_id)
decrease_memory_size_cache(memory_id, tenant_id, delete_size)
else:
return False, "Failed to insert message into memory. Memory size reached limit and cannot decide which to delete."
fail_cases = MessageService.insert_message(message_list, tenant_id, memory_id)
if fail_cases:
return False, "Failed to insert message into memory. Details: " + "; ".join(fail_cases)
increase_memory_size_cache(memory_id, tenant_id, new_msg_size)
return True, "Message saved successfully."
async def extract_by_llm(tenant_id: str, llm_id: str, extract_conf: dict, memory_type: List[str], user_input: str,
agent_response: str, system_prompt: str = "", user_prompt: str="") -> List[dict]:
llm_type = TenantLLMService.llm_id2llm_type(llm_id)
if not llm_type:
raise RuntimeError(f"Unknown type of LLM '{llm_id}'")
if not system_prompt:
system_prompt = PromptAssembler.assemble_system_prompt({"memory_type": memory_type})
conversation_content = f"User Input: {user_input}\nAgent Response: {agent_response}"
conversation_time = timestamp_to_date(current_timestamp())
user_prompts = []
if user_prompt:
user_prompts.append({"role": "user", "content": user_prompt})
user_prompts.append({"role": "user", "content": f"Conversation: {conversation_content}\nConversation Time: {conversation_time}\nCurrent Time: {conversation_time}"})
else:
user_prompts.append({"role": "user", "content": PromptAssembler.assemble_user_prompt(conversation_content, conversation_time, conversation_time)})
llm = LLMBundle(tenant_id, llm_type, llm_id)
res = await llm.async_chat(system_prompt, user_prompts, extract_conf)
res_json = get_json_result_from_llm_response(res)
return [{
"content": extracted_content["content"],
"valid_at": format_iso_8601_to_ymd_hms(extracted_content["valid_at"]),
"invalid_at": format_iso_8601_to_ymd_hms(extracted_content["invalid_at"]) if extracted_content.get("invalid_at") else "",
"message_type": message_type
} for message_type, extracted_content_list in res_json.items() for extracted_content in extracted_content_list]
def query_message(filter_dict: dict, params: dict):
"""
:param filter_dict: {
"memory_id": List[str],
"agent_id": optional
"session_id": optional
}
:param params: {
"query": question str,
"similarity_threshold": float,
"keywords_similarity_weight": float,
"top_n": int
}
"""
memory_ids = filter_dict["memory_id"]
memory_list = MemoryService.get_by_ids(memory_ids)
if not memory_list:
return []
condition_dict = {k: v for k, v in filter_dict.items() if v}
uids = [memory.tenant_id for memory in memory_list]
question = params["query"]
question = question.strip()
memory = memory_list[0]
embd_model = LLMBundle(memory.tenant_id, llm_type=LLMType.EMBEDDING, llm_name=memory.embd_id)
match_dense = get_vector(question, embd_model, similarity=params["similarity_threshold"])
match_text, _ = MsgTextQuery().question(question, min_match=0.3)
keywords_similarity_weight = params.get("keywords_similarity_weight", 0.7)
fusion_expr = FusionExpr("weighted_sum", params["top_n"], {"weights": ",".join([str(keywords_similarity_weight), str(1 - keywords_similarity_weight)])})
return MessageService.search_message(memory_ids, condition_dict, uids, [match_text, match_dense, fusion_expr], params["top_n"])
def init_message_id_sequence():
message_id_redis_key = "id_generator:memory"
if REDIS_CONN.exist(message_id_redis_key):
current_max_id = REDIS_CONN.get(message_id_redis_key)
logging.info(f"No need to init message_id sequence, current max id is {current_max_id}.")
else:
max_id = 1
exist_memory_list = MemoryService.get_all_memory()
if not exist_memory_list:
REDIS_CONN.set(message_id_redis_key, max_id)
else:
max_id = MessageService.get_max_message_id(
uid_list=[m.tenant_id for m in exist_memory_list],
memory_ids=[m.id for m in exist_memory_list]
)
REDIS_CONN.set(message_id_redis_key, max_id)
logging.info(f"Init message_id sequence done, current max id is {max_id}.")
def get_memory_size_cache(memory_id: str, uid: str):
redis_key = f"memory_{memory_id}"
if REDIS_CONN.exists(redis_key):
return REDIS_CONN.get(redis_key)
else:
memory_size_map = MessageService.calculate_memory_size(
[memory_id],
[uid]
)
memory_size = memory_size_map.get(memory_id, 0)
set_memory_size_cache(memory_id, memory_size)
return memory_size
def set_memory_size_cache(memory_id: str, size: int):
redis_key = f"memory_{memory_id}"
return REDIS_CONN.set(redis_key, size)
def increase_memory_size_cache(memory_id: str, uid: str, size: int):
current_value = get_memory_size_cache(memory_id, uid)
return set_memory_size_cache(memory_id, current_value + size)
def decrease_memory_size_cache(memory_id: str, uid: str, size: int):
current_value = get_memory_size_cache(memory_id, uid)
return set_memory_size_cache(memory_id, max(current_value - size, 0))
def init_memory_size_cache():
memory_list = MemoryService.get_all_memory()
if not memory_list:
logging.info("No memory found, no need to init memory size.")
else:
memory_size_map = MessageService.calculate_memory_size(
memory_ids=[m.id for m in memory_list],
uid_list=[m.tenant_id for m in memory_list],
)
for memory in memory_list:
memory_size = memory_size_map.get(memory.id, 0)
set_memory_size_cache(memory.id, memory_size)
logging.info("Memory size cache init done.")

View File

@ -34,6 +34,8 @@ from api.db.services.task_service import TaskService
from api.db.services.tenant_llm_service import TenantLLMService
from api.db.services.user_canvas_version import UserCanvasVersionService
from api.db.services.user_service import TenantService, UserService, UserTenantService
from api.db.services.memory_service import MemoryService
from memory.services.messages import MessageService
from rag.nlp import search
from common.constants import ActiveEnum
from common import settings
@ -200,7 +202,16 @@ def delete_user_data(user_id: str) -> dict:
done_msg += f"- Deleted {llm_delete_res} tenant-LLM records.\n"
langfuse_delete_res = TenantLangfuseService.delete_ty_tenant_id(tenant_id)
done_msg += f"- Deleted {langfuse_delete_res} langfuse records.\n"
# step1.3 delete own tenant
# step1.3 delete memory and messages
user_memory = MemoryService.get_by_tenant_id(tenant_id)
if user_memory:
for memory in user_memory:
if MessageService.has_index(tenant_id, memory.id):
MessageService.delete_index(tenant_id, memory.id)
done_msg += " Deleted memory index."
memory_delete_res = MemoryService.delete_by_ids([m.id for m in user_memory])
done_msg += f"Deleted {memory_delete_res} memory datasets."
# step1.4 delete own tenant
tenant_delete_res = TenantService.delete_by_id(tenant_id)
done_msg += f"- Deleted {tenant_delete_res} tenant.\n"
# step2 delete user-tenant relation

View File

@ -123,6 +123,19 @@ class UserCanvasService(CommonService):
logging.exception(e)
return False, None
@classmethod
@DB.connection_context()
def get_basic_info_by_canvas_ids(cls, canvas_id):
fields = [
cls.model.id,
cls.model.avatar,
cls.model.user_id,
cls.model.title,
cls.model.permission,
cls.model.canvas_category
]
return cls.model.select(*fields).where(cls.model.id.in_(canvas_id)).dicts()
@classmethod
@DB.connection_context()
def get_by_tenant_ids(cls, joined_tenant_ids, user_id,

View File

@ -38,7 +38,7 @@ from common.time_utils import current_timestamp, get_format_time
from common.constants import LLMType, ParserType, StatusEnum, TaskStatus, SVR_CONSUMER_GROUP_NAME
from rag.nlp import rag_tokenizer, search
from rag.utils.redis_conn import REDIS_CONN
from rag.utils.doc_store_conn import OrderByExpr
from common.doc_store.doc_store_base import OrderByExpr
from common import settings
@ -345,7 +345,7 @@ class DocumentService(CommonService):
chunks = settings.docStoreConn.search(["img_id"], [], {"doc_id": doc.id}, [], OrderByExpr(),
page * page_size, page_size, search.index_name(tenant_id),
[doc.kb_id])
chunk_ids = settings.docStoreConn.get_chunk_ids(chunks)
chunk_ids = settings.docStoreConn.get_doc_ids(chunks)
if not chunk_ids:
break
all_chunk_ids.extend(chunk_ids)
@ -1230,8 +1230,8 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
d["q_%d_vec" % len(v)] = v
for b in range(0, len(cks), es_bulk_size):
if try_create_idx:
if not settings.docStoreConn.indexExist(idxnm, kb_id):
settings.docStoreConn.createIdx(idxnm, kb_id, len(vectors[0]))
if not settings.docStoreConn.index_exist(idxnm, kb_id):
settings.docStoreConn.create_idx(idxnm, kb_id, len(vectors[0]))
try_create_idx = False
settings.docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id)

View File

@ -15,7 +15,6 @@
#
from typing import List
from api.apps import current_user
from api.db.db_models import DB, Memory, User
from api.db.services import duplicate_name
from api.db.services.common_service import CommonService
@ -23,6 +22,7 @@ from api.utils.memory_utils import calculate_memory_type
from api.constants import MEMORY_NAME_LIMIT
from common.misc_utils import get_uuid
from common.time_utils import get_format_time, current_timestamp
from memory.utils.prompt_util import PromptAssembler
class MemoryService(CommonService):
@ -34,6 +34,17 @@ class MemoryService(CommonService):
def get_by_memory_id(cls, memory_id: str):
return cls.model.select().where(cls.model.id == memory_id).first()
@classmethod
@DB.connection_context()
def get_by_tenant_id(cls, tenant_id: str):
return cls.model.select().where(cls.model.tenant_id == tenant_id)
@classmethod
@DB.connection_context()
def get_all_memory(cls):
memory_list = cls.model.select()
return list(memory_list)
@classmethod
@DB.connection_context()
def get_with_owner_name_by_id(cls, memory_id: str):
@ -53,7 +64,9 @@ class MemoryService(CommonService):
cls.model.forgetting_policy,
cls.model.temperature,
cls.model.system_prompt,
cls.model.user_prompt
cls.model.user_prompt,
cls.model.create_date,
cls.model.create_time
]
memory = cls.model.select(*fields).join(User, on=(cls.model.tenant_id == User.id)).where(
cls.model.id == memory_id
@ -72,7 +85,9 @@ class MemoryService(CommonService):
cls.model.memory_type,
cls.model.storage_type,
cls.model.permissions,
cls.model.description
cls.model.description,
cls.model.create_time,
cls.model.create_date
]
memories = cls.model.select(*fields).join(User, on=(cls.model.tenant_id == User.id))
if filter_dict.get("tenant_id"):
@ -110,6 +125,7 @@ class MemoryService(CommonService):
"tenant_id": tenant_id,
"embd_id": embd_id,
"llm_id": llm_id,
"system_prompt": PromptAssembler.assemble_system_prompt({"memory_type": memory_type}),
"create_time": current_timestamp(),
"create_date": get_format_time(),
"update_time": current_timestamp(),
@ -126,7 +142,7 @@ class MemoryService(CommonService):
@classmethod
@DB.connection_context()
def update_memory(cls, memory_id: str, update_dict: dict):
def update_memory(cls, tenant_id: str, memory_id: str, update_dict: dict):
if not update_dict:
return 0
if "temperature" in update_dict and isinstance(update_dict["temperature"], str):
@ -135,7 +151,7 @@ class MemoryService(CommonService):
update_dict["name"] = duplicate_name(
cls.query,
name=update_dict["name"],
tenant_id=current_user.id
tenant_id=tenant_id
)
update_dict.update({
"update_time": current_timestamp(),