# # Copyright 2025 The InfiniFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import re import json import time import copy from elasticsearch import NotFoundError from elasticsearch_dsl import UpdateByQuery, Q, Search from elastic_transport import ConnectionTimeout from common.decorator import singleton from common.doc_store.doc_store_base import MatchExpr, OrderByExpr, MatchTextExpr, MatchDenseExpr, FusionExpr from common.doc_store.es_conn_base import ESConnectionBase from common.float_utils import get_float from common.constants import PAGERANK_FLD, TAG_FLD ATTEMPT_TIME = 2 @singleton class ESConnection(ESConnectionBase): @staticmethod def convert_field_name(field_name: str) -> str: match field_name: case "message_type": return "message_type_kwd" case "status": return "status_int" case "content": return "content_ltks" case _: return field_name @staticmethod def map_message_to_es_fields(message: dict) -> dict: """ Map message dictionary fields to Elasticsearch document/Infinity fields. :param message: A dictionary containing message details. :return: A dictionary formatted for Elasticsearch/Infinity indexing. """ storage_doc = { "id": message.get("id"), "message_id": message["message_id"], "message_type_kwd": message["message_type"], "source_id": message["source_id"], "memory_id": message["memory_id"], "user_id": message["user_id"], "agent_id": message["agent_id"], "session_id": message["session_id"], "valid_at": message["valid_at"], "invalid_at": message["invalid_at"], "forget_at": message["forget_at"], "status_int": 1 if message["status"] else 0, "zone_id": message.get("zone_id", 0), "content_ltks": message["content"], f"q_{len(message['content_embed'])}_vec": message["content_embed"], } return storage_doc @staticmethod def get_message_from_es_doc(doc: dict) -> dict: """ Convert an Elasticsearch/Infinity document back to a message dictionary. :param doc: A dictionary representing the Elasticsearch/Infinity document. :return: A dictionary formatted as a message. """ embd_field_name = next((key for key in doc.keys() if re.match(r"q_\d+_vec", key)), None) message = { "message_id": doc["message_id"], "message_type": doc["message_type_kwd"], "source_id": doc["source_id"] if doc["source_id"] else None, "memory_id": doc["memory_id"], "user_id": doc.get("user_id", ""), "agent_id": doc["agent_id"], "session_id": doc["session_id"], "zone_id": doc.get("zone_id", 0), "valid_at": doc["valid_at"], "invalid_at": doc.get("invalid_at", "-"), "forget_at": doc.get("forget_at", "-"), "status": bool(int(doc["status_int"])), "content": doc.get("content_ltks", ""), "content_embed": doc.get(embd_field_name, []) if embd_field_name else [], } if doc.get("id"): message["id"] = doc["id"] return message """ CRUD operations """ def search( self, select_fields: list[str], highlight_fields: list[str], condition: dict, match_expressions: list[MatchExpr], order_by: OrderByExpr, offset: int, limit: int, index_names: str | list[str], memory_ids: list[str], agg_fields: list[str] | None = None, rank_feature: dict | None = None, hide_forgotten: bool = True ): """ Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl.html """ if isinstance(index_names, str): index_names = index_names.split(",") assert isinstance(index_names, list) and len(index_names) > 0 assert "_id" not in condition bool_query = Q("bool", must=[], must_not=[]) if hide_forgotten: # filter not forget bool_query.must_not.append(Q("exists", field="forget_at")) condition["memory_id"] = memory_ids for k, v in condition.items(): if k == "session_id" and v: bool_query.filter.append(Q("query_string", **{"query": f"*{v}*", "fields": ["session_id"], "analyze_wildcard": True})) continue if not v: continue if isinstance(v, list): bool_query.filter.append(Q("terms", **{k: v})) elif isinstance(v, str) or isinstance(v, int): bool_query.filter.append(Q("term", **{k: v})) else: raise Exception( f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.") s = Search() vector_similarity_weight = 0.5 for m in match_expressions: if isinstance(m, FusionExpr) and m.method == "weighted_sum" and "weights" in m.fusion_params: assert len(match_expressions) == 3 and isinstance(match_expressions[0], MatchTextExpr) and isinstance(match_expressions[1], MatchDenseExpr) and isinstance( match_expressions[2], FusionExpr) weights = m.fusion_params["weights"] vector_similarity_weight = get_float(weights.split(",")[1]) for m in match_expressions: if isinstance(m, MatchTextExpr): minimum_should_match = m.extra_options.get("minimum_should_match", 0.0) if isinstance(minimum_should_match, float): minimum_should_match = str(int(minimum_should_match * 100)) + "%" bool_query.must.append(Q("query_string", fields=[self.convert_field_name(f) for f in m.fields], type="best_fields", query=m.matching_text, minimum_should_match=minimum_should_match, boost=1)) bool_query.boost = 1.0 - vector_similarity_weight elif isinstance(m, MatchDenseExpr): assert (bool_query is not None) similarity = 0.0 if "similarity" in m.extra_options: similarity = m.extra_options["similarity"] s = s.knn(self.convert_field_name(m.vector_column_name), m.topn, m.topn * 2, query_vector=list(m.embedding_data), filter=bool_query.to_dict(), similarity=similarity, ) if bool_query and rank_feature: for fld, sc in rank_feature.items(): if fld != PAGERANK_FLD: fld = f"{TAG_FLD}.{fld}" bool_query.should.append(Q("rank_feature", field=fld, linear={}, boost=sc)) if bool_query: s = s.query(bool_query) for field in highlight_fields: s = s.highlight(field) if order_by: orders = list() for field, order in order_by.fields: order = "asc" if order == 0 else "desc" if field.endswith("_int") or field.endswith("_flt"): order_info = {"order": order, "unmapped_type": "float"} else: order_info = {"order": order, "unmapped_type": "text"} orders.append({field: order_info}) s = s.sort(*orders) if agg_fields: for fld in agg_fields: s.aggs.bucket(f'aggs_{fld}', 'terms', field=fld, size=1000000) if limit > 0: s = s[offset:offset + limit] q = s.to_dict() self.logger.debug(f"ESConnection.search {str(index_names)} query: " + json.dumps(q)) for i in range(ATTEMPT_TIME): try: #print(json.dumps(q, ensure_ascii=False)) res = self.es.search(index=index_names, body=q, timeout="600s", # search_type="dfs_query_then_fetch", track_total_hits=True, _source=True) if str(res.get("timed_out", "")).lower() == "true": raise Exception("Es Timeout.") self.logger.debug(f"ESConnection.search {str(index_names)} res: " + str(res)) return res except ConnectionTimeout: self.logger.exception("ES request timeout") self._connect() continue except Exception as e: self.logger.exception(f"ESConnection.search {str(index_names)} query: " + str(q) + str(e)) raise e self.logger.error(f"ESConnection.search timeout for {ATTEMPT_TIME} times!") raise Exception("ESConnection.search timeout.") def get_forgotten_messages(self, select_fields: list[str], index_name: str, memory_id: str, limit: int=2000): bool_query = Q("bool", must_not=[]) bool_query.must_not.append(Q("term", forget_at=None)) bool_query.filter.append(Q("term", memory_id=memory_id)) # from old to new order_by = OrderByExpr() order_by.asc("forget_at") # build search s = Search() s = s.query(bool_query) s = s.sort(order_by) s = s[:limit] q = s.to_dict() # search for i in range(ATTEMPT_TIME): try: res = self.es.search(index=index_name, body=q, timeout="600s", track_total_hits=True, _source=True) if str(res.get("timed_out", "")).lower() == "true": raise Exception("Es Timeout.") self.logger.debug(f"ESConnection.search {str(index_name)} res: " + str(res)) return res except ConnectionTimeout: self.logger.exception("ES request timeout") self._connect() continue except Exception as e: self.logger.exception(f"ESConnection.search {str(index_name)} query: " + str(q) + str(e)) raise e self.logger.error(f"ESConnection.search timeout for {ATTEMPT_TIME} times!") raise Exception("ESConnection.search timeout.") def get(self, doc_id: str, index_name: str, memory_ids: list[str]) -> dict | None: for i in range(ATTEMPT_TIME): try: res = self.es.get(index=index_name, id=doc_id, source=True, ) if str(res.get("timed_out", "")).lower() == "true": raise Exception("Es Timeout.") message = res["_source"] message["id"] = doc_id return self.get_message_from_es_doc(message) except NotFoundError: return None except Exception as e: self.logger.exception(f"ESConnection.get({doc_id}) got exception") raise e self.logger.error(f"ESConnection.get timeout for {ATTEMPT_TIME} times!") raise Exception("ESConnection.get timeout.") def insert(self, documents: list[dict], index_name: str, memory_id: str = None) -> list[str]: # Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/docs-bulk.html operations = [] for d in documents: assert "_id" not in d assert "id" in d d_copy_raw = copy.deepcopy(d) d_copy = self.map_message_to_es_fields(d_copy_raw) d_copy["memory_id"] = memory_id meta_id = d_copy.pop("id", "") operations.append( {"index": {"_index": index_name, "_id": meta_id}}) operations.append(d_copy) res = [] for _ in range(ATTEMPT_TIME): try: res = [] r = self.es.bulk(index=index_name, operations=operations, refresh=False, timeout="60s") if re.search(r"False", str(r["errors"]), re.IGNORECASE): return res for item in r["items"]: for action in ["create", "delete", "index", "update"]: if action in item and "error" in item[action]: res.append(str(item[action]["_id"]) + ":" + str(item[action]["error"])) return res except ConnectionTimeout: self.logger.exception("ES request timeout") time.sleep(3) self._connect() continue except Exception as e: res.append(str(e)) self.logger.warning("ESConnection.insert got exception: " + str(e)) return res def update(self, condition: dict, new_value: dict, index_name: str, memory_id: str) -> bool: doc = copy.deepcopy(new_value) update_dict = {self.convert_field_name(k): v for k, v in doc.items()} update_dict.pop("id", None) condition_dict = {self.convert_field_name(k): v for k, v in condition.items()} condition_dict["memory_id"] = memory_id if "id" in condition_dict and isinstance(condition_dict["id"], str): # update specific single document message_id = condition_dict["id"] for i in range(ATTEMPT_TIME): for k in update_dict.keys(): if "feas" != k.split("_")[-1]: continue try: self.es.update(index=index_name, id=message_id, script=f"ctx._source.remove(\"{k}\");") except Exception: self.logger.exception(f"ESConnection.update(index={index_name}, id={message_id}, doc={json.dumps(condition, ensure_ascii=False)}) got exception") try: self.es.update(index=index_name, id=message_id, doc=update_dict) return True except Exception as e: self.logger.exception( f"ESConnection.update(index={index_name}, id={message_id}, doc={json.dumps(condition, ensure_ascii=False)}) got exception: " + str(e)) break return False # update unspecific maybe-multiple documents bool_query = Q("bool") for k, v in condition_dict.items(): if not isinstance(k, str) or not v: continue if k == "exists": bool_query.filter.append(Q("exists", field=v)) continue if isinstance(v, list): bool_query.filter.append(Q("terms", **{k: v})) elif isinstance(v, str) or isinstance(v, int): bool_query.filter.append(Q("term", **{k: v})) else: raise Exception( f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.") scripts = [] params = {} for k, v in update_dict.items(): if k == "remove": if isinstance(v, str): scripts.append(f"ctx._source.remove('{v}');") if isinstance(v, dict): for kk, vv in v.items(): scripts.append(f"int i=ctx._source.{kk}.indexOf(params.p_{kk});ctx._source.{kk}.remove(i);") params[f"p_{kk}"] = vv continue if k == "add": if isinstance(v, dict): for kk, vv in v.items(): scripts.append(f"ctx._source.{kk}.add(params.pp_{kk});") params[f"pp_{kk}"] = vv.strip() continue if (not isinstance(k, str) or not v) and k != "status_int": continue if isinstance(v, str): v = re.sub(r"(['\n\r]|\\.)", " ", v) params[f"pp_{k}"] = v scripts.append(f"ctx._source.{k}=params.pp_{k};") elif isinstance(v, int) or isinstance(v, float): scripts.append(f"ctx._source.{k}={v};") elif isinstance(v, list): scripts.append(f"ctx._source.{k}=params.pp_{k};") params[f"pp_{k}"] = json.dumps(v, ensure_ascii=False) else: raise Exception( f"newValue `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str.") ubq = UpdateByQuery( index=index_name).using( self.es).query(bool_query) ubq = ubq.script(source="".join(scripts), params=params) ubq = ubq.params(refresh=True) ubq = ubq.params(slices=5) ubq = ubq.params(conflicts="proceed") for _ in range(ATTEMPT_TIME): try: _ = ubq.execute() return True except ConnectionTimeout: self.logger.exception("ES request timeout") time.sleep(3) self._connect() continue except Exception as e: self.logger.error("ESConnection.update got exception: " + str(e) + "\n".join(scripts)) break return False def delete(self, condition: dict, index_name: str, memory_id: str) -> int: assert "_id" not in condition condition_dict = {self.convert_field_name(k): v for k, v in condition.items()} condition_dict["memory_id"] = memory_id if "id" in condition_dict: message_ids = condition_dict["id"] if not isinstance(message_ids, list): message_ids = [message_ids] if not message_ids: # when message_ids is empty, delete all qry = Q("match_all") else: qry = Q("ids", values=message_ids) else: qry = Q("bool") for k, v in condition_dict.items(): if k == "exists": qry.filter.append(Q("exists", field=v)) elif k == "must_not": if isinstance(v, dict): for kk, vv in v.items(): if kk == "exists": qry.must_not.append(Q("exists", field=vv)) elif isinstance(v, list): qry.must.append(Q("terms", **{k: v})) elif isinstance(v, str) or isinstance(v, int): qry.must.append(Q("term", **{k: v})) else: raise Exception("Condition value must be int, str or list.") self.logger.debug("ESConnection.delete query: " + json.dumps(qry.to_dict())) for _ in range(ATTEMPT_TIME): try: res = self.es.delete_by_query( index=index_name, body=Search().query(qry).to_dict(), refresh=True) return res["deleted"] except ConnectionTimeout: self.logger.exception("ES request timeout") time.sleep(3) self._connect() continue except Exception as e: self.logger.warning("ESConnection.delete got exception: " + str(e)) if re.search(r"(not_found)", str(e), re.IGNORECASE): return 0 return 0 """ Helper functions for search result """ def get_fields(self, res, fields: list[str]) -> dict[str, dict]: res_fields = {} if not fields: return {} for doc in self._get_source(res): message = self.get_message_from_es_doc(doc) m = {} for n, v in message.items(): if n not in fields: continue if isinstance(v, list): m[n] = v continue if n in ["message_id", "source_id", "valid_at", "invalid_at", "forget_at", "status"] and isinstance(v, (int, float, bool)): m[n] = v continue if not isinstance(v, str): m[n] = str(v) else: m[n] = v if m: res_fields[doc["id"]] = m return res_fields