# # Copyright 2025 The InfiniFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import re import json import time import copy from elasticsearch_dsl import UpdateByQuery, Q, Search from elastic_transport import ConnectionTimeout from common.decorator import singleton from common.doc_store.doc_store_base import MatchTextExpr, OrderByExpr, MatchExpr, MatchDenseExpr, FusionExpr from common.doc_store.es_conn_base import ESConnectionBase from common.float_utils import get_float from common.constants import PAGERANK_FLD, TAG_FLD ATTEMPT_TIME = 2 @singleton class ESConnection(ESConnectionBase): """ CRUD operations """ def search( self, select_fields: list[str], highlight_fields: list[str], condition: dict, match_expressions: list[MatchExpr], order_by: OrderByExpr, offset: int, limit: int, index_names: str | list[str], knowledgebase_ids: list[str], agg_fields: list[str] | None = None, rank_feature: dict | None = None ): """ Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl.html """ if isinstance(index_names, str): index_names = index_names.split(",") assert isinstance(index_names, list) and len(index_names) > 0 assert "_id" not in condition bool_query = Q("bool", must=[]) condition["kb_id"] = knowledgebase_ids for k, v in condition.items(): if k == "available_int": if v == 0: bool_query.filter.append(Q("range", available_int={"lt": 1})) else: bool_query.filter.append( Q("bool", must_not=Q("range", available_int={"lt": 1}))) continue if not v: continue if isinstance(v, list): bool_query.filter.append(Q("terms", **{k: v})) elif isinstance(v, str) or isinstance(v, int): bool_query.filter.append(Q("term", **{k: v})) else: raise Exception( f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.") s = Search() vector_similarity_weight = 0.5 for m in match_expressions: if isinstance(m, FusionExpr) and m.method == "weighted_sum" and "weights" in m.fusion_params: assert len(match_expressions) == 3 and isinstance(match_expressions[0], MatchTextExpr) and isinstance(match_expressions[1], MatchDenseExpr) and isinstance( match_expressions[2], FusionExpr) weights = m.fusion_params["weights"] vector_similarity_weight = get_float(weights.split(",")[1]) for m in match_expressions: if isinstance(m, MatchTextExpr): minimum_should_match = m.extra_options.get("minimum_should_match", 0.0) if isinstance(minimum_should_match, float): minimum_should_match = str(int(minimum_should_match * 100)) + "%" bool_query.must.append(Q("query_string", fields=m.fields, type="best_fields", query=m.matching_text, minimum_should_match=minimum_should_match, boost=1)) bool_query.boost = 1.0 - vector_similarity_weight elif isinstance(m, MatchDenseExpr): assert (bool_query is not None) similarity = 0.0 if "similarity" in m.extra_options: similarity = m.extra_options["similarity"] s = s.knn(m.vector_column_name, m.topn, m.topn * 2, query_vector=list(m.embedding_data), filter=bool_query.to_dict(), similarity=similarity, ) if bool_query and rank_feature: for fld, sc in rank_feature.items(): if fld != PAGERANK_FLD: fld = f"{TAG_FLD}.{fld}" bool_query.should.append(Q("rank_feature", field=fld, linear={}, boost=sc)) if bool_query: s = s.query(bool_query) for field in highlight_fields: s = s.highlight(field) if order_by: orders = list() for field, order in order_by.fields: order = "asc" if order == 0 else "desc" if field in ["page_num_int", "top_int"]: order_info = {"order": order, "unmapped_type": "float", "mode": "avg", "numeric_type": "double"} elif field.endswith("_int") or field.endswith("_flt"): order_info = {"order": order, "unmapped_type": "float"} else: order_info = {"order": order, "unmapped_type": "text"} orders.append({field: order_info}) s = s.sort(*orders) if agg_fields: for fld in agg_fields: s.aggs.bucket(f'aggs_{fld}', 'terms', field=fld, size=1000000) if limit > 0: s = s[offset:offset + limit] q = s.to_dict() self.logger.debug(f"ESConnection.search {str(index_names)} query: " + json.dumps(q)) for i in range(ATTEMPT_TIME): try: #print(json.dumps(q, ensure_ascii=False)) res = self.es.search(index=index_names, body=q, timeout="600s", # search_type="dfs_query_then_fetch", track_total_hits=True, _source=True) if str(res.get("timed_out", "")).lower() == "true": raise Exception("Es Timeout.") self.logger.debug(f"ESConnection.search {str(index_names)} res: " + str(res)) return res except ConnectionTimeout: self.logger.exception("ES request timeout") self._connect() continue except Exception as e: self.logger.exception(f"ESConnection.search {str(index_names)} query: " + str(q) + str(e)) raise e self.logger.error(f"ESConnection.search timeout for {ATTEMPT_TIME} times!") raise Exception("ESConnection.search timeout.") def insert(self, documents: list[dict], index_name: str, knowledgebase_id: str = None) -> list[str]: # Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/docs-bulk.html operations = [] for d in documents: assert "_id" not in d assert "id" in d d_copy = copy.deepcopy(d) d_copy["kb_id"] = knowledgebase_id meta_id = d_copy.pop("id", "") operations.append( {"index": {"_index": index_name, "_id": meta_id}}) operations.append(d_copy) res = [] for _ in range(ATTEMPT_TIME): try: res = [] r = self.es.bulk(index=index_name, operations=operations, refresh=False, timeout="60s") if re.search(r"False", str(r["errors"]), re.IGNORECASE): return res for item in r["items"]: for action in ["create", "delete", "index", "update"]: if action in item and "error" in item[action]: res.append(str(item[action]["_id"]) + ":" + str(item[action]["error"])) return res except ConnectionTimeout: self.logger.exception("ES request timeout") time.sleep(3) self._connect() continue except Exception as e: res.append(str(e)) self.logger.warning("ESConnection.insert got exception: " + str(e)) return res def update(self, condition: dict, new_value: dict, index_name: str, knowledgebase_id: str) -> bool: doc = copy.deepcopy(new_value) doc.pop("id", None) condition["kb_id"] = knowledgebase_id if "id" in condition and isinstance(condition["id"], str): # update specific single document chunk_id = condition["id"] for i in range(ATTEMPT_TIME): for k in doc.keys(): if "feas" != k.split("_")[-1]: continue try: self.es.update(index=index_name, id=chunk_id, script=f"ctx._source.remove(\"{k}\");") except Exception: self.logger.exception(f"ESConnection.update(index={index_name}, id={chunk_id}, doc={json.dumps(condition, ensure_ascii=False)}) got exception") try: self.es.update(index=index_name, id=chunk_id, doc=doc) return True except Exception as e: self.logger.exception( f"ESConnection.update(index={index_name}, id={chunk_id}, doc={json.dumps(condition, ensure_ascii=False)}) got exception: " + str(e)) break return False # update unspecific maybe-multiple documents bool_query = Q("bool") for k, v in condition.items(): if not isinstance(k, str) or not v: continue if k == "exists": bool_query.filter.append(Q("exists", field=v)) continue if isinstance(v, list): bool_query.filter.append(Q("terms", **{k: v})) elif isinstance(v, str) or isinstance(v, int): bool_query.filter.append(Q("term", **{k: v})) else: raise Exception( f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.") scripts = [] params = {} for k, v in new_value.items(): if k == "remove": if isinstance(v, str): scripts.append(f"ctx._source.remove('{v}');") if isinstance(v, dict): for kk, vv in v.items(): scripts.append(f"int i=ctx._source.{kk}.indexOf(params.p_{kk});ctx._source.{kk}.remove(i);") params[f"p_{kk}"] = vv continue if k == "add": if isinstance(v, dict): for kk, vv in v.items(): scripts.append(f"ctx._source.{kk}.add(params.pp_{kk});") params[f"pp_{kk}"] = vv.strip() continue if (not isinstance(k, str) or not v) and k != "available_int": continue if isinstance(v, str): v = re.sub(r"(['\n\r]|\\.)", " ", v) params[f"pp_{k}"] = v scripts.append(f"ctx._source.{k}=params.pp_{k};") elif isinstance(v, int) or isinstance(v, float): scripts.append(f"ctx._source.{k}={v};") elif isinstance(v, list): scripts.append(f"ctx._source.{k}=params.pp_{k};") params[f"pp_{k}"] = json.dumps(v, ensure_ascii=False) else: raise Exception( f"newValue `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str.") ubq = UpdateByQuery( index=index_name).using( self.es).query(bool_query) ubq = ubq.script(source="".join(scripts), params=params) ubq = ubq.params(refresh=True) ubq = ubq.params(slices=5) ubq = ubq.params(conflicts="proceed") for _ in range(ATTEMPT_TIME): try: _ = ubq.execute() return True except ConnectionTimeout: self.logger.exception("ES request timeout") time.sleep(3) self._connect() continue except Exception as e: self.logger.error("ESConnection.update got exception: " + str(e) + "\n".join(scripts)) break return False def delete(self, condition: dict, index_name: str, knowledgebase_id: str) -> int: assert "_id" not in condition condition["kb_id"] = knowledgebase_id if "id" in condition: chunk_ids = condition["id"] if not isinstance(chunk_ids, list): chunk_ids = [chunk_ids] if not chunk_ids: # when chunk_ids is empty, delete all qry = Q("match_all") else: qry = Q("ids", values=chunk_ids) else: qry = Q("bool") for k, v in condition.items(): if k == "exists": qry.filter.append(Q("exists", field=v)) elif k == "must_not": if isinstance(v, dict): for kk, vv in v.items(): if kk == "exists": qry.must_not.append(Q("exists", field=vv)) elif isinstance(v, list): qry.must.append(Q("terms", **{k: v})) elif isinstance(v, str) or isinstance(v, int): qry.must.append(Q("term", **{k: v})) else: raise Exception("Condition value must be int, str or list.") self.logger.debug("ESConnection.delete query: " + json.dumps(qry.to_dict())) for _ in range(ATTEMPT_TIME): try: res = self.es.delete_by_query( index=index_name, body=Search().query(qry).to_dict(), refresh=True) return res["deleted"] except ConnectionTimeout: self.logger.exception("ES request timeout") time.sleep(3) self._connect() continue except Exception as e: self.logger.warning("ESConnection.delete got exception: " + str(e)) if re.search(r"(not_found)", str(e), re.IGNORECASE): return 0 return 0 """ Helper functions for search result """ def get_fields(self, res, fields: list[str]) -> dict[str, dict]: res_fields = {} if not fields: return {} for d in self._get_source(res): m = {n: d.get(n) for n in fields if d.get(n) is not None} for n, v in m.items(): if isinstance(v, list): m[n] = v continue if n == "available_int" and isinstance(v, (int, float)): m[n] = v continue if not isinstance(v, str): m[n] = str(m[n]) # if n.find("tks") > 0: # m[n] = remove_redundant_spaces(m[n]) if m: res_fields[d["id"]] = m return res_fields