diff --git a/api/db/joint_services/memory_message_service.py b/api/db/joint_services/memory_message_service.py index 0b1016bc6..2d581cebf 100644 --- a/api/db/joint_services/memory_message_service.py +++ b/api/db/joint_services/memory_message_service.py @@ -223,13 +223,8 @@ def init_memory_size_cache(): if not memory_list: logging.info("No memory found, no need to init memory size.") else: - memory_size_map = MessageService.calculate_memory_size( - memory_ids=[m.id for m in memory_list], - uid_list=[m.tenant_id for m in memory_list], - ) - for memory in memory_list: - memory_size = memory_size_map.get(memory.id, 0) - set_memory_size_cache(memory.id, memory_size) + for m in memory_list: + get_memory_size_cache(m.id, m.tenant_id) logging.info("Memory size cache init done.") diff --git a/memory/services/messages.py b/memory/services/messages.py index cfa722100..29f0f4da1 100644 --- a/memory/services/messages.py +++ b/memory/services/messages.py @@ -198,7 +198,7 @@ class MessageService: message_list = settings.msgStoreConn.get_fields(res, select_fields) current_size = 0 ids_to_remove = [] - for message in message_list: + for message in message_list.values(): if current_size < size_to_delete: current_size += cls.calculate_message_size(message) ids_to_remove.append(message["message_id"]) @@ -210,7 +210,7 @@ class MessageService: order_by = OrderByExpr() order_by.asc("valid_at") res = settings.msgStoreConn.search( - select_fields=["memory_id", "content", "content_embed"], + select_fields=select_fields, highlight_fields=[], condition={}, match_expressions=[], @@ -222,7 +222,7 @@ class MessageService: for doc in docs.values(): if current_size < size_to_delete: current_size += cls.calculate_message_size(doc) - ids_to_remove.append(doc["memory_id"]) + ids_to_remove.append(doc["message_id"]) else: return ids_to_remove, current_size return ids_to_remove, current_size diff --git a/memory/utils/es_conn.py b/memory/utils/es_conn.py index dc98e871d..b75b9df56 100644 --- a/memory/utils/es_conn.py +++ b/memory/utils/es_conn.py @@ -127,6 +127,11 @@ class ESConnection(ESConnectionBase): index_names = index_names.split(",") assert isinstance(index_names, list) and len(index_names) > 0 assert "_id" not in condition + + exist_index_list = [idx for idx in index_names if self.index_exist(idx)] + if not exist_index_list: + return None + bool_query = Q("bool", must=[], must_not=[]) if hide_forgotten: # filter not forget @@ -214,7 +219,7 @@ class ESConnection(ESConnectionBase): for i in range(ATTEMPT_TIME): try: #print(json.dumps(q, ensure_ascii=False)) - res = self.es.search(index=index_names, + res = self.es.search(index=exist_index_list, body=q, timeout="600s", # search_type="dfs_query_then_fetch", @@ -239,8 +244,8 @@ class ESConnection(ESConnectionBase): raise Exception("ESConnection.search timeout.") def get_forgotten_messages(self, select_fields: list[str], index_name: str, memory_id: str, limit: int=512): - bool_query = Q("bool", must_not=[]) - bool_query.must_not.append(Q("term", forget_at=None)) + bool_query = Q("bool", must=[]) + bool_query.must.append(Q("exists", field="forget_at")) bool_query.filter.append(Q("term", memory_id=memory_id)) # from old to new order_by = OrderByExpr()