Feat: message management (#12083)

### What problem does this PR solve?

Adds CRUD operations for memory messages (create, read, update, delete), backed by both the Elasticsearch and Infinity document stores.

Issue #4213 

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
Commit 17b8bb62b6 (parent bab6a4a219), authored by Lynn, 2025-12-23 21:16:25 +08:00, committed by GitHub.
49 changed files with 3480 additions and 1031 deletions

memory/utils/__init__.py (new file, 0 lines added)

memory/utils/es_conn.py (new file, 494 lines added)

@@ -0,0 +1,494 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import re
import json
import time
import copy
from elasticsearch import NotFoundError
from elasticsearch_dsl import UpdateByQuery, Q, Search
from elastic_transport import ConnectionTimeout
from common.decorator import singleton
from common.doc_store.doc_store_base import MatchExpr, OrderByExpr, MatchTextExpr, MatchDenseExpr, FusionExpr
from common.doc_store.es_conn_base import ESConnectionBase
from common.float_utils import get_float
from common.constants import PAGERANK_FLD, TAG_FLD
ATTEMPT_TIME = 2
@singleton
class ESConnection(ESConnectionBase):
@staticmethod
def convert_field_name(field_name: str) -> str:
match field_name:
case "message_type":
return "message_type_kwd"
case "status":
return "status_int"
case "content":
return "content_ltks"
case _:
return field_name
@staticmethod
def map_message_to_es_fields(message: dict) -> dict:
"""
Map message dictionary fields to the Elasticsearch/Infinity storage document fields.
:param message: A dictionary containing message details.
:return: A dictionary formatted for Elasticsearch/Infinity indexing.
"""
storage_doc = {
"id": message.get("id"),
"message_id": message["message_id"],
"message_type_kwd": message["message_type"],
"source_id": message["source_id"],
"memory_id": message["memory_id"],
"user_id": message["user_id"],
"agent_id": message["agent_id"],
"session_id": message["session_id"],
"valid_at": message["valid_at"],
"invalid_at": message["invalid_at"],
"forget_at": message["forget_at"],
"status_int": 1 if message["status"] else 0,
"zone_id": message.get("zone_id", 0),
"content_ltks": message["content"],
f"q_{len(message['content_embed'])}_vec": message["content_embed"],
}
return storage_doc
@staticmethod
def get_message_from_es_doc(doc: dict) -> dict:
"""
Convert an Elasticsearch/Infinity document back to a message dictionary.
:param doc: A dictionary representing the Elasticsearch/Infinity document.
:return: A dictionary formatted as a message.
"""
embd_field_name = next((key for key in doc.keys() if re.match(r"q_\d+_vec", key)), None)
message = {
"message_id": doc["message_id"],
"message_type": doc["message_type_kwd"],
"source_id": doc["source_id"] if doc["source_id"] else None,
"memory_id": doc["memory_id"],
"user_id": doc.get("user_id", ""),
"agent_id": doc["agent_id"],
"session_id": doc["session_id"],
"zone_id": doc.get("zone_id", 0),
"valid_at": doc["valid_at"],
"invalid_at": doc.get("invalid_at", "-"),
"forget_at": doc.get("forget_at", "-"),
"status": bool(int(doc["status_int"])),
"content": doc.get("content_ltks", ""),
"content_embed": doc.get(embd_field_name, []) if embd_field_name else [],
}
if doc.get("id"):
message["id"] = doc["id"]
return message
"""
CRUD operations
"""
def search(
self, select_fields: list[str],
highlight_fields: list[str],
condition: dict,
match_expressions: list[MatchExpr],
order_by: OrderByExpr,
offset: int,
limit: int,
index_names: str | list[str],
memory_ids: list[str],
agg_fields: list[str] | None = None,
rank_feature: dict | None = None,
hide_forgotten: bool = True
):
"""
Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl.html
"""
if isinstance(index_names, str):
index_names = index_names.split(",")
assert isinstance(index_names, list) and len(index_names) > 0
assert "_id" not in condition
bool_query = Q("bool", must=[], must_not=[])
if hide_forgotten:
# filter not forget
bool_query.must_not.append(Q("exists", field="forget_at"))
condition["memory_id"] = memory_ids
for k, v in condition.items():
if k == "session_id" and v:
bool_query.filter.append(Q("query_string", **{"query": f"*{v}*", "fields": ["session_id"], "analyze_wildcard": True}))
continue
if not v:
continue
if isinstance(v, list):
bool_query.filter.append(Q("terms", **{k: v}))
elif isinstance(v, str) or isinstance(v, int):
bool_query.filter.append(Q("term", **{k: v}))
else:
raise Exception(
f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")
s = Search()
vector_similarity_weight = 0.5
for m in match_expressions:
if isinstance(m, FusionExpr) and m.method == "weighted_sum" and "weights" in m.fusion_params:
assert len(match_expressions) == 3 and isinstance(match_expressions[0], MatchTextExpr) and isinstance(match_expressions[1],
MatchDenseExpr) and isinstance(
match_expressions[2], FusionExpr)
weights = m.fusion_params["weights"]
vector_similarity_weight = get_float(weights.split(",")[1])
for m in match_expressions:
if isinstance(m, MatchTextExpr):
minimum_should_match = m.extra_options.get("minimum_should_match", 0.0)
if isinstance(minimum_should_match, float):
minimum_should_match = str(int(minimum_should_match * 100)) + "%"
bool_query.must.append(Q("query_string", fields=[self.convert_field_name(f) for f in m.fields],
type="best_fields", query=m.matching_text,
minimum_should_match=minimum_should_match,
boost=1))
bool_query.boost = 1.0 - vector_similarity_weight
elif isinstance(m, MatchDenseExpr):
assert (bool_query is not None)
similarity = 0.0
if "similarity" in m.extra_options:
similarity = m.extra_options["similarity"]
s = s.knn(self.convert_field_name(m.vector_column_name),
m.topn,
m.topn * 2,
query_vector=list(m.embedding_data),
filter=bool_query.to_dict(),
similarity=similarity,
)
if bool_query and rank_feature:
for fld, sc in rank_feature.items():
if fld != PAGERANK_FLD:
fld = f"{TAG_FLD}.{fld}"
bool_query.should.append(Q("rank_feature", field=fld, linear={}, boost=sc))
if bool_query:
s = s.query(bool_query)
for field in highlight_fields:
s = s.highlight(field)
if order_by:
orders = list()
for field, order in order_by.fields:
order = "asc" if order == 0 else "desc"
if field.endswith("_int") or field.endswith("_flt"):
order_info = {"order": order, "unmapped_type": "float"}
else:
order_info = {"order": order, "unmapped_type": "text"}
orders.append({field: order_info})
s = s.sort(*orders)
if agg_fields:
for fld in agg_fields:
s.aggs.bucket(f'aggs_{fld}', 'terms', field=fld, size=1000000)
if limit > 0:
s = s[offset:offset + limit]
q = s.to_dict()
self.logger.debug(f"ESConnection.search {str(index_names)} query: " + json.dumps(q))
for i in range(ATTEMPT_TIME):
try:
#print(json.dumps(q, ensure_ascii=False))
res = self.es.search(index=index_names,
body=q,
timeout="600s",
# search_type="dfs_query_then_fetch",
track_total_hits=True,
_source=True)
if str(res.get("timed_out", "")).lower() == "true":
raise Exception("Es Timeout.")
self.logger.debug(f"ESConnection.search {str(index_names)} res: " + str(res))
return res
except ConnectionTimeout:
self.logger.exception("ES request timeout")
self._connect()
continue
except Exception as e:
self.logger.exception(f"ESConnection.search {str(index_names)} query: " + str(q) + str(e))
raise e
self.logger.error(f"ESConnection.search timeout for {ATTEMPT_TIME} times!")
raise Exception("ESConnection.search timeout.")
def get_forgotten_messages(self, select_fields: list[str], index_name: str, memory_id: str, limit: int=2000):
bool_query = Q("bool", must_not=[])
bool_query.must_not.append(Q("term", forget_at=None))
bool_query.filter.append(Q("term", memory_id=memory_id))
# from old to new
order_by = OrderByExpr()
order_by.asc("forget_at")
# build search
s = Search()
s = s.query(bool_query)
s = s.sort(*[{field: {"order": "asc" if order == 0 else "desc"}} for field, order in order_by.fields])
s = s[:limit]
q = s.to_dict()
# search
for i in range(ATTEMPT_TIME):
try:
res = self.es.search(index=index_name, body=q, timeout="600s", track_total_hits=True, _source=True)
if str(res.get("timed_out", "")).lower() == "true":
raise Exception("Es Timeout.")
self.logger.debug(f"ESConnection.search {str(index_name)} res: " + str(res))
return res
except ConnectionTimeout:
self.logger.exception("ES request timeout")
self._connect()
continue
except Exception as e:
self.logger.exception(f"ESConnection.search {str(index_name)} query: " + str(q) + str(e))
raise e
self.logger.error(f"ESConnection.search timeout for {ATTEMPT_TIME} times!")
raise Exception("ESConnection.search timeout.")
def get(self, doc_id: str, index_name: str, memory_ids: list[str]) -> dict | None:
for i in range(ATTEMPT_TIME):
try:
res = self.es.get(index=index_name,
id=doc_id, source=True, )
if str(res.get("timed_out", "")).lower() == "true":
raise Exception("Es Timeout.")
message = res["_source"]
message["id"] = doc_id
return self.get_message_from_es_doc(message)
except NotFoundError:
return None
except Exception as e:
self.logger.exception(f"ESConnection.get({doc_id}) got exception")
raise e
self.logger.error(f"ESConnection.get timeout for {ATTEMPT_TIME} times!")
raise Exception("ESConnection.get timeout.")
def insert(self, documents: list[dict], index_name: str, memory_id: str = None) -> list[str]:
# Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/docs-bulk.html
operations = []
for d in documents:
assert "_id" not in d
assert "id" in d
d_copy_raw = copy.deepcopy(d)
d_copy = self.map_message_to_es_fields(d_copy_raw)
d_copy["memory_id"] = memory_id
meta_id = d_copy.pop("id", "")
operations.append(
{"index": {"_index": index_name, "_id": meta_id}})
operations.append(d_copy)
res = []
for _ in range(ATTEMPT_TIME):
try:
res = []
r = self.es.bulk(index=index_name, operations=operations,
refresh=False, timeout="60s")
if re.search(r"False", str(r["errors"]), re.IGNORECASE):
return res
for item in r["items"]:
for action in ["create", "delete", "index", "update"]:
if action in item and "error" in item[action]:
res.append(str(item[action]["_id"]) + ":" + str(item[action]["error"]))
return res
except ConnectionTimeout:
self.logger.exception("ES request timeout")
time.sleep(3)
self._connect()
continue
except Exception as e:
res.append(str(e))
self.logger.warning("ESConnection.insert got exception: " + str(e))
return res
def update(self, condition: dict, new_value: dict, index_name: str, memory_id: str) -> bool:
doc = copy.deepcopy(new_value)
update_dict = {self.convert_field_name(k): v for k, v in doc.items()}
update_dict.pop("id", None)
condition_dict = {self.convert_field_name(k): v for k, v in condition.items()}
condition_dict["memory_id"] = memory_id
if "id" in condition_dict and isinstance(condition_dict["id"], str):
# update specific single document
message_id = condition_dict["id"]
for i in range(ATTEMPT_TIME):
for k in update_dict.keys():
if "feas" != k.split("_")[-1]:
continue
try:
self.es.update(index=index_name, id=message_id, script=f"ctx._source.remove(\"{k}\");")
except Exception:
self.logger.exception(f"ESConnection.update(index={index_name}, id={message_id}, doc={json.dumps(condition, ensure_ascii=False)}) got exception")
try:
self.es.update(index=index_name, id=message_id, doc=update_dict)
return True
except Exception as e:
self.logger.exception(
f"ESConnection.update(index={index_name}, id={message_id}, doc={json.dumps(condition, ensure_ascii=False)}) got exception: " + str(e))
break
return False
# update unspecific maybe-multiple documents
bool_query = Q("bool")
for k, v in condition_dict.items():
if not isinstance(k, str) or not v:
continue
if k == "exists":
bool_query.filter.append(Q("exists", field=v))
continue
if isinstance(v, list):
bool_query.filter.append(Q("terms", **{k: v}))
elif isinstance(v, str) or isinstance(v, int):
bool_query.filter.append(Q("term", **{k: v}))
else:
raise Exception(
f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")
scripts = []
params = {}
for k, v in update_dict.items():
if k == "remove":
if isinstance(v, str):
scripts.append(f"ctx._source.remove('{v}');")
if isinstance(v, dict):
for kk, vv in v.items():
scripts.append(f"int i=ctx._source.{kk}.indexOf(params.p_{kk});ctx._source.{kk}.remove(i);")
params[f"p_{kk}"] = vv
continue
if k == "add":
if isinstance(v, dict):
for kk, vv in v.items():
scripts.append(f"ctx._source.{kk}.add(params.pp_{kk});")
params[f"pp_{kk}"] = vv.strip()
continue
if (not isinstance(k, str) or not v) and k != "status_int":
continue
if isinstance(v, str):
v = re.sub(r"(['\n\r]|\\.)", " ", v)
params[f"pp_{k}"] = v
scripts.append(f"ctx._source.{k}=params.pp_{k};")
elif isinstance(v, int) or isinstance(v, float):
scripts.append(f"ctx._source.{k}={v};")
elif isinstance(v, list):
scripts.append(f"ctx._source.{k}=params.pp_{k};")
params[f"pp_{k}"] = json.dumps(v, ensure_ascii=False)
else:
raise Exception(
f"newValue `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str.")
ubq = UpdateByQuery(
index=index_name).using(
self.es).query(bool_query)
ubq = ubq.script(source="".join(scripts), params=params)
ubq = ubq.params(refresh=True)
ubq = ubq.params(slices=5)
ubq = ubq.params(conflicts="proceed")
for _ in range(ATTEMPT_TIME):
try:
_ = ubq.execute()
return True
except ConnectionTimeout:
self.logger.exception("ES request timeout")
time.sleep(3)
self._connect()
continue
except Exception as e:
self.logger.error("ESConnection.update got exception: " + str(e) + "\n".join(scripts))
break
return False
def delete(self, condition: dict, index_name: str, memory_id: str) -> int:
assert "_id" not in condition
condition_dict = {self.convert_field_name(k): v for k, v in condition.items()}
condition_dict["memory_id"] = memory_id
if "id" in condition_dict:
message_ids = condition_dict["id"]
if not isinstance(message_ids, list):
message_ids = [message_ids]
if not message_ids: # when message_ids is empty, delete all
qry = Q("match_all")
else:
qry = Q("ids", values=message_ids)
else:
qry = Q("bool")
for k, v in condition_dict.items():
if k == "exists":
qry.filter.append(Q("exists", field=v))
elif k == "must_not":
if isinstance(v, dict):
for kk, vv in v.items():
if kk == "exists":
qry.must_not.append(Q("exists", field=vv))
elif isinstance(v, list):
qry.must.append(Q("terms", **{k: v}))
elif isinstance(v, str) or isinstance(v, int):
qry.must.append(Q("term", **{k: v}))
else:
raise Exception("Condition value must be int, str or list.")
self.logger.debug("ESConnection.delete query: " + json.dumps(qry.to_dict()))
for _ in range(ATTEMPT_TIME):
try:
res = self.es.delete_by_query(
index=index_name,
body=Search().query(qry).to_dict(),
refresh=True)
return res["deleted"]
except ConnectionTimeout:
self.logger.exception("ES request timeout")
time.sleep(3)
self._connect()
continue
except Exception as e:
self.logger.warning("ESConnection.delete got exception: " + str(e))
if re.search(r"(not_found)", str(e), re.IGNORECASE):
return 0
return 0
"""
Helper functions for search result
"""
def get_fields(self, res, fields: list[str]) -> dict[str, dict]:
res_fields = {}
if not fields:
return {}
for doc in self._get_source(res):
message = self.get_message_from_es_doc(doc)
m = {}
for n, v in message.items():
if n not in fields:
continue
if isinstance(v, list):
m[n] = v
continue
if n in ["message_id", "source_id", "valid_at", "invalid_at", "forget_at", "status"] and isinstance(v, (int, float, bool)):
m[n] = v
continue
if not isinstance(v, str):
m[n] = str(v)
else:
m[n] = v
if m:
res_fields[doc["id"]] = m
return res_fields
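
The field-name mapping at the top of es_conn.py is the storage contract the rest of the memory module relies on. Below is a minimal, illustrative sketch of what that mapping produces for one message; the values are invented, and only the field-name conventions come from the code above.

```python
# Illustrative data only: the storage document map_message_to_es_fields() is expected
# to build for one message. Values are invented; field names follow the code above.
message = {
    "id": "doc-1", "message_id": "msg-001", "message_type": "semantic",
    "source_id": "", "memory_id": "mem-42", "user_id": "u-7", "agent_id": "a-1",
    "session_id": "s-9", "valid_at": "2024-01-15T10:00:00", "invalid_at": "",
    "forget_at": "", "status": True, "zone_id": 0,
    "content": "python lists are mutable",
    "content_embed": [0.1, 0.2, 0.3, 0.4],           # 4-dimensional embedding
}

storage_doc = {
    "id": "doc-1", "message_id": "msg-001",
    "message_type_kwd": "semantic",                   # message_type -> *_kwd keyword field
    "status_int": 1,                                  # bool status -> 0/1 integer field
    "content_ltks": "python lists are mutable",       # content -> tokenised text field
    "q_4_vec": [0.1, 0.2, 0.3, 0.4],                  # embedding -> q_{dim}_vec vector field
    "source_id": "", "memory_id": "mem-42", "user_id": "u-7", "agent_id": "a-1",
    "session_id": "s-9", "valid_at": "2024-01-15T10:00:00", "invalid_at": "",
    "forget_at": "", "zone_id": 0,
}
```

get_message_from_es_doc() reverses the same renames when hits come back from a search.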


@@ -0,0 +1,467 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import re
import json
import copy
from infinity.common import InfinityException, SortType
from infinity.errors import ErrorCode
from common.decorator import singleton
import pandas as pd
from common.constants import PAGERANK_FLD, TAG_FLD
from common.doc_store.doc_store_base import MatchExpr, MatchTextExpr, MatchDenseExpr, FusionExpr, OrderByExpr
from common.doc_store.infinity_conn_base import InfinityConnectionBase
from common.time_utils import date_string_to_timestamp
@singleton
class InfinityConnection(InfinityConnectionBase):
def __init__(self):
super().__init__()
self.mapping_file_name = "message_infinity_mapping.json"
"""
Dataframe and fields convert
"""
@staticmethod
def field_keyword(field_name: str):
# no keywords right now
return False
@staticmethod
def convert_message_field_to_infinity(field_name: str):
match field_name:
case "message_type":
return "message_type_kwd"
case "status":
return "status_int"
case _:
return field_name
@staticmethod
def convert_infinity_field_to_message(field_name: str):
if field_name.startswith("message_type"):
return "message_type"
if field_name.startswith("status"):
return "status"
if re.match(r"q_\d+_vec", field_name):
return "content_embed"
return field_name
def convert_select_fields(self, output_fields: list[str]) -> list[str]:
return list({self.convert_message_field_to_infinity(f) for f in output_fields})
@staticmethod
def convert_matching_field(field_weight_str: str) -> str:
tokens = field_weight_str.split("^")
field = tokens[0]
if field == "content":
field = "content@ft_contentm_rag_fine"
tokens[0] = field
return "^".join(tokens)
@staticmethod
def convert_condition_and_order_field(field_name: str):
match field_name:
case "message_type":
return "message_type_kwd"
case "status":
return "status_int"
case "valid_at":
return "valid_at_flt"
case "invalid_at":
return "invalid_at_flt"
case "forget_at":
return "forget_at_flt"
case _:
return field_name
"""
CRUD operations
"""
def search(
self,
select_fields: list[str],
highlight_fields: list[str],
condition: dict,
match_expressions: list[MatchExpr],
order_by: OrderByExpr,
offset: int,
limit: int,
index_names: str | list[str],
memory_ids: list[str],
agg_fields: list[str] | None = None,
rank_feature: dict | None = None,
hide_forgotten: bool = True,
) -> tuple[pd.DataFrame, int]:
"""
BUG: Infinity returns empty for a highlight field if the query string doesn't use that field.
"""
if isinstance(index_names, str):
index_names = index_names.split(",")
assert isinstance(index_names, list) and len(index_names) > 0
inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName)
df_list = list()
table_list = list()
if hide_forgotten:
condition.update({"must_not": {"exists": "forget_at_flt"}})
output = select_fields.copy()
output = self.convert_select_fields(output)
if agg_fields is None:
agg_fields = []
for essential_field in ["id"] + agg_fields:
if essential_field not in output:
output.append(essential_field)
score_func = ""
score_column = ""
for matchExpr in match_expressions:
if isinstance(matchExpr, MatchTextExpr):
score_func = "score()"
score_column = "SCORE"
break
if not score_func:
for matchExpr in match_expressions:
if isinstance(matchExpr, MatchDenseExpr):
score_func = "similarity()"
score_column = "SIMILARITY"
break
if match_expressions:
if score_func not in output:
output.append(score_func)
if PAGERANK_FLD not in output:
output.append(PAGERANK_FLD)
output = [f for f in output if f != "_score"]
if limit <= 0:
# ElasticSearch default limit is 10000
limit = 10000
# Prepare expressions common to all tables
filter_cond = None
filter_fulltext = ""
if condition:
condition_dict = {self.convert_condition_and_order_field(k): v for k, v in condition.items()}
table_found = False
for indexName in index_names:
for mem_id in memory_ids:
table_name = f"{indexName}_{mem_id}"
try:
filter_cond = self.equivalent_condition_to_str(condition_dict, db_instance.get_table(table_name))
table_found = True
break
except Exception:
pass
if table_found:
break
if not table_found:
self.logger.error(f"No valid tables found for indexNames {index_names} and memoryIds {memory_ids}")
return pd.DataFrame(), 0
for matchExpr in match_expressions:
if isinstance(matchExpr, MatchTextExpr):
if filter_cond and "filter" not in matchExpr.extra_options:
matchExpr.extra_options.update({"filter": filter_cond})
matchExpr.fields = [self.convert_matching_field(field) for field in matchExpr.fields]
fields = ",".join(matchExpr.fields)
filter_fulltext = f"filter_fulltext('{fields}', '{matchExpr.matching_text}')"
if filter_cond:
filter_fulltext = f"({filter_cond}) AND {filter_fulltext}"
minimum_should_match = matchExpr.extra_options.get("minimum_should_match", 0.0)
if isinstance(minimum_should_match, float):
str_minimum_should_match = str(int(minimum_should_match * 100)) + "%"
matchExpr.extra_options["minimum_should_match"] = str_minimum_should_match
# Add rank_feature support
if rank_feature and "rank_features" not in matchExpr.extra_options:
# Convert rank_feature dict to Infinity's rank_features string format
# Format: "field^feature_name^weight,field^feature_name^weight"
rank_features_list = []
for feature_name, weight in rank_feature.items():
# Use TAG_FLD as the field containing rank features
rank_features_list.append(f"{TAG_FLD}^{feature_name}^{weight}")
if rank_features_list:
matchExpr.extra_options["rank_features"] = ",".join(rank_features_list)
for k, v in matchExpr.extra_options.items():
if not isinstance(v, str):
matchExpr.extra_options[k] = str(v)
self.logger.debug(f"INFINITY search MatchTextExpr: {json.dumps(matchExpr.__dict__)}")
elif isinstance(matchExpr, MatchDenseExpr):
if filter_fulltext and "filter" not in matchExpr.extra_options:
matchExpr.extra_options.update({"filter": filter_fulltext})
for k, v in matchExpr.extra_options.items():
if not isinstance(v, str):
matchExpr.extra_options[k] = str(v)
similarity = matchExpr.extra_options.get("similarity")
if similarity:
matchExpr.extra_options["threshold"] = similarity
del matchExpr.extra_options["similarity"]
self.logger.debug(f"INFINITY search MatchDenseExpr: {json.dumps(matchExpr.__dict__)}")
elif isinstance(matchExpr, FusionExpr):
self.logger.debug(f"INFINITY search FusionExpr: {json.dumps(matchExpr.__dict__)}")
order_by_expr_list = list()
if order_by.fields:
for order_field in order_by.fields:
order_field_name = self.convert_condition_and_order_field(order_field[0])
if order_field[1] == 0:
order_by_expr_list.append((order_field_name, SortType.Asc))
else:
order_by_expr_list.append((order_field_name, SortType.Desc))
total_hits_count = 0
# Scatter search tables and gather the results
for indexName in index_names:
for memory_id in memory_ids:
table_name = f"{indexName}_{memory_id}"
try:
table_instance = db_instance.get_table(table_name)
except Exception:
continue
table_list.append(table_name)
builder = table_instance.output(output)
if len(match_expressions) > 0:
for matchExpr in match_expressions:
if isinstance(matchExpr, MatchTextExpr):
fields = ",".join(matchExpr.fields)
builder = builder.match_text(
fields,
matchExpr.matching_text,
matchExpr.topn,
matchExpr.extra_options.copy(),
)
elif isinstance(matchExpr, MatchDenseExpr):
builder = builder.match_dense(
matchExpr.vector_column_name,
matchExpr.embedding_data,
matchExpr.embedding_data_type,
matchExpr.distance_type,
matchExpr.topn,
matchExpr.extra_options.copy(),
)
elif isinstance(matchExpr, FusionExpr):
builder = builder.fusion(matchExpr.method, matchExpr.topn, matchExpr.fusion_params)
else:
if filter_cond and len(filter_cond) > 0:
builder.filter(filter_cond)
if order_by.fields:
builder.sort(order_by_expr_list)
builder.offset(offset).limit(limit)
mem_res, extra_result = builder.option({"total_hits_count": True}).to_df()
if extra_result:
total_hits_count += int(extra_result["total_hits_count"])
self.logger.debug(f"INFINITY search table: {str(table_name)}, result: {str(mem_res)}")
df_list.append(mem_res)
self.connPool.release_conn(inf_conn)
res = self.concat_dataframes(df_list, output)
if match_expressions:
res["_score"] = res[score_column] + res[PAGERANK_FLD]
res = res.sort_values(by="_score", ascending=False).reset_index(drop=True)
res = res.head(limit)
self.logger.debug(f"INFINITY search final result: {str(res)}")
return res, total_hits_count
def get_forgotten_messages(self, select_fields: list[str], index_name: str, memory_id: str, limit: int=2000):
condition = {"memory_id": memory_id, "exists": "forget_at_flt"}
order_by = OrderByExpr()
order_by.asc("forget_at_flt")
# query
inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName)
table_name = f"{index_name}_{memory_id}"
table_instance = db_instance.get_table(table_name)
output_fields = [self.convert_message_field_to_infinity(f) for f in select_fields]
builder = table_instance.output(output_fields)
filter_cond = self.equivalent_condition_to_str(condition, db_instance.get_table(table_name))
builder.filter(filter_cond)
order_by_expr_list = list()
if order_by.fields:
for order_field in order_by.fields:
order_field_name = self.convert_condition_and_order_field(order_field[0])
if order_field[1] == 0:
order_by_expr_list.append((order_field_name, SortType.Asc))
else:
order_by_expr_list.append((order_field_name, SortType.Desc))
builder.sort(order_by_expr_list)
builder.offset(0).limit(limit)
mem_res, _ = builder.option({"total_hits_count": True}).to_df()
res = self.concat_dataframes([mem_res], output_fields)
res = res.head(limit)
self.connPool.release_conn(inf_conn)
return res
def get(self, message_id: str, index_name: str, memory_ids: list[str]) -> dict | None:
inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName)
df_list = list()
assert isinstance(memory_ids, list)
table_list = list()
for memoryId in memory_ids:
table_name = f"{index_name}_{memoryId}"
table_list.append(table_name)
try:
table_instance = db_instance.get_table(table_name)
except Exception:
self.logger.warning(f"Table not found: {table_name}, this memory isn't created in Infinity. Maybe it is created in other document engine.")
continue
mem_res, _ = table_instance.output(["*"]).filter(f"id = '{message_id}'").to_df()
self.logger.debug(f"INFINITY get table: {str(table_list)}, result: {str(mem_res)}")
df_list.append(mem_res)
self.connPool.release_conn(inf_conn)
res = self.concat_dataframes(df_list, ["id"])
fields = set(res.columns.tolist())
res_fields = self.get_fields(res, list(fields))
return res_fields.get(message_id, None)
def insert(self, documents: list[dict], index_name: str, memory_id: str = None) -> list[str]:
if not documents:
return []
inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName)
table_name = f"{index_name}_{memory_id}"
vector_size = int(len(documents[0]["content_embed"]))
try:
table_instance = db_instance.get_table(table_name)
except InfinityException as e:
# src/common/status.cppm, kTableNotExist = 3022
if e.error_code != ErrorCode.TABLE_NOT_EXIST:
raise
if vector_size == 0:
raise ValueError("Cannot infer vector size from documents")
self.create_idx(index_name, memory_id, vector_size)
table_instance = db_instance.get_table(table_name)
# embedding fields can't have a default value....
embedding_columns = []
table_columns = table_instance.show_columns().rows()
for n, ty, _, _ in table_columns:
r = re.search(r"Embedding\([a-z]+,([0-9]+)\)", ty)
if not r:
continue
embedding_columns.append((n, int(r.group(1))))
docs = copy.deepcopy(documents)
for d in docs:
assert "_id" not in d
assert "id" in d
for k, v in list(d.items()):
field_name = self.convert_message_field_to_infinity(k)
if field_name in ["valid_at", "invalid_at", "forget_at"]:
d[f"{field_name}_flt"] = date_string_to_timestamp(v) if v else 0
if v is None:
d[field_name] = ""
elif self.field_keyword(k):
if isinstance(v, list):
d[k] = "###".join(v)
else:
d[k] = v
elif k == "memory_id":
if isinstance(d[k], list):
d[k] = d[k][0] # since d[k] is a list, but we need a str
elif field_name == "content_embed":
d[f"q_{vector_size}_vec"] = d["content_embed"]
d.pop("content_embed")
else:
d[field_name] = v
if k != field_name:
d.pop(k)
for n, vs in embedding_columns:
if n in d:
continue
d[n] = [0] * vs
ids = ["'{}'".format(d["id"]) for d in docs]
str_ids = ", ".join(ids)
str_filter = f"id IN ({str_ids})"
table_instance.delete(str_filter)
table_instance.insert(docs)
self.connPool.release_conn(inf_conn)
self.logger.debug(f"INFINITY inserted into {table_name} {str_ids}.")
return []
def update(self, condition: dict, new_value: dict, index_name: str, memory_id: str) -> bool:
inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName)
table_name = f"{index_name}_{memory_id}"
table_instance = db_instance.get_table(table_name)
columns = {}
if table_instance:
for n, ty, de, _ in table_instance.show_columns().rows():
columns[n] = (ty, de)
condition_dict = {self.convert_condition_and_order_field(k): v for k, v in condition.items()}
filter = self.equivalent_condition_to_str(condition_dict, table_instance)
update_dict = {self.convert_message_field_to_infinity(k): v for k, v in new_value.items()}
date_floats = {}
for k, v in update_dict.items():
if k in ["valid_at", "invalid_at", "forget_at"]:
date_floats[f"{k}_flt"] = date_string_to_timestamp(v) if v else 0
elif self.field_keyword(k):
if isinstance(v, list):
update_dict[k] = "###".join(v)
else:
update_dict[k] = v
elif k == "memory_id":
if isinstance(update_dict[k], list):
update_dict[k] = update_dict[k][0] # since d[k] is a list, but we need a str
else:
update_dict[k] = v
if date_floats:
update_dict.update(date_floats)
self.logger.debug(f"INFINITY update table {table_name}, filter {filter}, newValue {new_value}.")
table_instance.update(filter, update_dict)
self.connPool.release_conn(inf_conn)
return True
"""
Helper functions for search result
"""
def get_fields(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, fields: list[str]) -> dict[str, dict]:
if isinstance(res, tuple):
res = res[0]
if not fields:
return {}
fields_all = fields.copy()
fields_all.append("id")
fields_all = {self.convert_message_field_to_infinity(f) for f in fields_all}
column_map = {col.lower(): col for col in res.columns}
matched_columns = {column_map[col.lower()]: col for col in fields_all if col.lower() in column_map}
none_columns = [col for col in fields_all if col.lower() not in column_map]
res2 = res[matched_columns.keys()]
res2 = res2.rename(columns=matched_columns)
res2.drop_duplicates(subset=["id"], inplace=True)
for column in list(res2.columns):
k = column.lower()
if self.field_keyword(k):
res2[column] = res2[column].apply(lambda v: [kwd for kwd in v.split("###") if kwd])
else:
pass
for column in ["content"]:
if column in res2:
del res2[column]
for column in none_columns:
res2[column] = None
res_dict = res2.set_index("id").to_dict(orient="index")
return {_id: {self.convert_infinity_field_to_message(k): v for k, v in doc.items()} for _id, doc in res_dict.items()}
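
For the Infinity backend, the connector keys everything off a table-per-memory layout plus a few column-suffix conventions. A short sketch of how those pieces fit together, with invented names and values:

```python
# Sketch only; the index name and memory ids are invented. One Infinity table per memory:
index_name = "ragflow_msg_tenant1"
memory_ids = ["mem-42", "mem-43"]
tables_to_query = [f"{index_name}_{mid}" for mid in memory_ids]
# -> ['ragflow_msg_tenant1_mem-42', 'ragflow_msg_tenant1_mem-43']
# search() fans out over these tables and concatenates the per-table dataframes afterwards.

# Column-suffix conventions applied by convert_condition_and_order_field() above:
condition = {"message_type": "semantic", "status": True, "valid_at": "2024-01-15T10:00:00"}
converted = {
    "message_type_kwd": "semantic",   # message_type -> *_kwd keyword column
    "status_int": True,               # status -> status_int column
    "valid_at_flt": 1705312800.0,     # date columns get an epoch-seconds *_flt twin (illustrative value)
}
```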

memory/utils/msg_util.py (new file, 37 lines added)

@@ -0,0 +1,37 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
def get_json_result_from_llm_response(response_str: str) -> dict:
"""
Parse the LLM response string to extract JSON content.
The function strips an optional ```json ... ``` code fence from the response before parsing the remainder as JSON.
If parsing fails, it returns an empty dictionary.
:param response_str: The response string from the LLM.
:return: A dictionary parsed from the JSON content in the response.
"""
try:
clean_str = response_str.strip()
if clean_str.startswith('```json'):
clean_str = clean_str[7:] # Remove the starting ```json
if clean_str.endswith('```'):
clean_str = clean_str[:-3] # Remove the ending ```
return json.loads(clean_str.strip())
except (ValueError, json.JSONDecodeError):
return {}
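
A small usage sketch for get_json_result_from_llm_response; the reply text is invented, and the import assumes a RAGFlow checkout on the Python path:

```python
# Hypothetical usage; msg_util itself only needs the standard-library json module.
from memory.utils.msg_util import get_json_result_from_llm_response

fence = "`" * 3  # build the fenced reply without embedding literal backtick fences here
reply = fence + 'json\n{"semantic": [{"content": "Paris is the capital of France", "valid_at": "2024-01-15T10:00:00", "invalid_at": ""}]}\n' + fence

parsed = get_json_result_from_llm_response(reply)
print(parsed.get("semantic", []))   # an unparseable reply would give {} and therefore []
```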

memory/utils/prompt_util.py (new file, 201 lines added)

@@ -0,0 +1,201 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Optional, List
from common.constants import MemoryType
from common.time_utils import current_timestamp
class PromptAssembler:
SYSTEM_BASE_TEMPLATE = """**Memory Extraction Specialist**
You are an expert at analyzing conversations to extract structured memory.
{type_specific_instructions}
**OUTPUT REQUIREMENTS:**
1. Output MUST be valid JSON
2. Follow the specified output format exactly
3. Each extracted item MUST have: content, valid_at, invalid_at
4. Timestamps in {timestamp_format} format
5. Only extract memory types specified above
6. Maximum {max_items} items per type
"""
TYPE_INSTRUCTIONS = {
MemoryType.SEMANTIC.name.lower(): """
**EXTRACT SEMANTIC KNOWLEDGE:**
- Universal facts, definitions, concepts, relationships
- Time-invariant, generally true information
- Examples: "The capital of France is Paris", "Water boils at 100°C"
**Timestamp Rules for Semantic Knowledge:**
- valid_at: When the fact became true (e.g., law enactment, discovery)
- invalid_at: When it becomes false (e.g., repeal, disproven) or empty if still true
- Default: valid_at = conversation time, invalid_at = "" for timeless facts
""",
MemoryType.EPISODIC.name.lower(): """
**EXTRACT EPISODIC KNOWLEDGE:**
- Specific experiences, events, personal stories
- Time-bound, person-specific, contextual
- Examples: "Yesterday I fixed the bug", "User reported issue last week"
**Timestamp Rules for Episodic Knowledge:**
- valid_at: Event start/occurrence time
- invalid_at: Event end time or empty if instantaneous
- Extract explicit times: "at 3 PM", "last Monday", "from X to Y"
""",
MemoryType.PROCEDURAL.name.lower(): """
**EXTRACT PROCEDURAL KNOWLEDGE:**
- Processes, methods, step-by-step instructions
- Goal-oriented, actionable, often includes conditions
- Examples: "To reset password, click...", "Debugging steps: 1)..."
**Timestamp Rules for Procedural Knowledge:**
- valid_at: When procedure becomes valid/effective
- invalid_at: When it expires/becomes obsolete or empty if current
- For version-specific: use release dates
- For best practices: invalid_at = ""
"""
}
OUTPUT_TEMPLATES = {
MemoryType.SEMANTIC.name.lower(): """
"semantic": [
{
"content": "Clear factual statement",
"valid_at": "timestamp or empty",
"invalid_at": "timestamp or empty"
}
]
""",
MemoryType.EPISODIC.name.lower(): """
"episodic": [
{
"content": "Narrative event description",
"valid_at": "event start timestamp",
"invalid_at": "event end timestamp or empty"
}
]
""",
MemoryType.PROCEDURAL.name.lower(): """
"procedural": [
{
"content": "Actionable instructions",
"valid_at": "procedure effective timestamp",
"invalid_at": "procedure expiration timestamp or empty"
}
]
"""
}
BASE_USER_PROMPT = """
**CONVERSATION:**
{conversation}
**CONVERSATION TIME:** {conversation_time}
**CURRENT TIME:** {current_time}
"""
@classmethod
def assemble_system_prompt(cls, config: dict) -> str:
types_to_extract = cls._get_types_to_extract(config["memory_type"])
type_instructions = cls._generate_type_instructions(types_to_extract)
output_format = cls._generate_output_format(types_to_extract)
full_prompt = cls.SYSTEM_BASE_TEMPLATE.format(
type_specific_instructions=type_instructions,
timestamp_format=config.get("timestamp_format", "ISO 8601"),
max_items=config.get("max_items_per_type", 5)
)
full_prompt += f"\n**REQUIRED OUTPUT FORMAT (JSON):**\n```json\n{{\n{output_format}\n}}\n```\n"
examples = cls._generate_examples(types_to_extract)
if examples:
full_prompt += f"\n**EXAMPLES:**\n{examples}\n"
return full_prompt
@staticmethod
def _get_types_to_extract(requested_types: List[str]) -> List[str]:
types = set()
for rt in requested_types:
if rt in [e.name.lower() for e in MemoryType] and rt != MemoryType.RAW.name.lower():
types.add(rt)
return list(types)
@classmethod
def _generate_type_instructions(cls, types_to_extract: List[str]) -> str:
target_types = set(types_to_extract)
instructions = [cls.TYPE_INSTRUCTIONS[mt] for mt in target_types]
return "\n".join(instructions)
@classmethod
def _generate_output_format(cls, types_to_extract: List[str]) -> str:
target_types = set(types_to_extract)
output_parts = [cls.OUTPUT_TEMPLATES[mt] for mt in target_types]
return ",\n".join(output_parts)
@staticmethod
def _generate_examples(types_to_extract: list[str]) -> str:
examples = []
if MemoryType.SEMANTIC.name.lower() in types_to_extract:
examples.append("""
**Semantic Example:**
Input: "Python lists are mutable and support various operations."
Output: {"semantic": [{"content": "Python lists are mutable data structures", "valid_at": "2024-01-15T10:00:00", "invalid_at": ""}]}
""")
if MemoryType.EPISODIC.name.lower() in types_to_extract:
examples.append("""
**Episodic Example:**
Input: "I deployed the new feature yesterday afternoon."
Output: {"episodic": [{"content": "User deployed new feature", "valid_at": "2024-01-14T14:00:00", "invalid_at": "2024-01-14T18:00:00"}]}
""")
if MemoryType.PROCEDURAL.name.lower() in types_to_extract:
examples.append("""
**Procedural Example:**
Input: "To debug API errors: 1) Check logs 2) Verify endpoints 3) Test connectivity."
Output: {"procedural": [{"content": "API error debugging: 1. Check logs 2. Verify endpoints 3. Test connectivity", "valid_at": "2024-01-15T10:00:00", "invalid_at": ""}]}
""")
return "\n".join(examples)
@classmethod
def assemble_user_prompt(
cls,
conversation: str,
conversation_time: Optional[str] = None,
current_time: Optional[str] = None
) -> str:
return cls.BASE_USER_PROMPT.format(
conversation=conversation,
conversation_time=conversation_time or "Not specified",
current_time=current_time or current_timestamp(),
)
@classmethod
def get_raw_user_prompt(cls):
return cls.BASE_USER_PROMPT
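
Finally, a hedged sketch of how the assembler is expected to be driven. The config keys mirror the ones assemble_system_prompt() reads (memory_type, timestamp_format, max_items_per_type); the chat client and all values are invented.

```python
# Hypothetical driver; assumes a RAGFlow checkout so memory.utils.* and common.* import.
from memory.utils.prompt_util import PromptAssembler
from memory.utils.msg_util import get_json_result_from_llm_response

config = {
    "memory_type": ["semantic", "episodic"],   # "raw" would be ignored by _get_types_to_extract
    "timestamp_format": "ISO 8601",
    "max_items_per_type": 5,
}

system_prompt = PromptAssembler.assemble_system_prompt(config)
user_prompt = PromptAssembler.assemble_user_prompt(
    conversation="user: I deployed the new feature yesterday afternoon.",
    conversation_time="2024-01-15T10:00:00",
)

# llm_reply = chat_model.chat(system_prompt, user_prompt)   # whichever chat client is configured
# memories = get_json_result_from_llm_response(llm_reply)   # -> {"semantic": [...], "episodic": [...]}
```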