diff --git a/memory/services/messages.py b/memory/services/messages.py index 18b774432..8687f51a3 100644 --- a/memory/services/messages.py +++ b/memory/services/messages.py @@ -71,7 +71,7 @@ class MessageService: filter_dict["session_id"] = keywords order_by = OrderByExpr() order_by.desc("valid_at") - res = settings.msgStoreConn.search( + res, total_count = settings.msgStoreConn.search( select_fields=[ "message_id", "message_type", "source_id", "memory_id", "user_id", "agent_id", "session_id", "valid_at", "invalid_at", "forget_at", "status" @@ -82,13 +82,12 @@ class MessageService: offset=(page-1)*page_size, limit=page_size, index_names=index, memory_ids=[memory_id], agg_fields=[], hide_forgotten=False ) - if not res: + if not total_count: return { "message_list": [], "total_count": 0 } - total_count = settings.msgStoreConn.get_total(res) doc_mapping = settings.msgStoreConn.get_fields(res, [ "message_id", "message_type", "source_id", "memory_id", "user_id", "agent_id", "session_id", "valid_at", "invalid_at", "forget_at", "status" @@ -107,7 +106,7 @@ class MessageService: } order_by = OrderByExpr() order_by.desc("valid_at") - res = settings.msgStoreConn.search( + res, total_count = settings.msgStoreConn.search( select_fields=[ "message_id", "message_type", "source_id", "memory_id", "user_id", "agent_id", "session_id", "valid_at", "invalid_at", "forget_at", "status", "content" @@ -118,7 +117,7 @@ class MessageService: offset=0, limit=limit, index_names=index_names, memory_ids=memory_ids, agg_fields=[] ) - if not res: + if not total_count: return [] doc_mapping = settings.msgStoreConn.get_fields(res, [ @@ -136,7 +135,7 @@ class MessageService: order_by = OrderByExpr() order_by.desc("valid_at") - res = settings.msgStoreConn.search( + res, total_count = settings.msgStoreConn.search( select_fields=[ "message_id", "message_type", "source_id", "memory_id", "user_id", "agent_id", "session_id", "valid_at", @@ -149,7 +148,7 @@ class MessageService: offset=0, limit=top_n, index_names=index_names, memory_ids=memory_ids, agg_fields=[] ) - if not res: + if not total_count: return [] docs = settings.msgStoreConn.get_fields(res, [ @@ -211,7 +210,7 @@ class MessageService: order_by = OrderByExpr() order_by.asc("valid_at") - res = settings.msgStoreConn.search( + res, total_count = settings.msgStoreConn.search( select_fields=select_fields, highlight_fields=[], condition={}, @@ -240,7 +239,7 @@ class MessageService: order_by = OrderByExpr() order_by.desc("message_id") index_names = [index_name(uid) for uid in uid_list] - res = settings.msgStoreConn.search( + res, total_count = settings.msgStoreConn.search( select_fields=["message_id"], highlight_fields=[], condition={}, @@ -250,7 +249,7 @@ class MessageService: index_names=index_names, memory_ids=memory_ids, agg_fields=[], hide_forgotten=False ) - if not res: + if not total_count: return 1 docs = settings.msgStoreConn.get_fields(res, ["message_id"]) diff --git a/memory/utils/es_conn.py b/memory/utils/es_conn.py index 2c635ac51..77d16dc41 100644 --- a/memory/utils/es_conn.py +++ b/memory/utils/es_conn.py @@ -130,7 +130,7 @@ class ESConnection(ESConnectionBase): exist_index_list = [idx for idx in index_names if self.index_exist(idx)] if not exist_index_list: - return None + return None, 0 bool_query = Q("bool", must=[], must_not=[]) if hide_forgotten: