Files
ragflow/common/doc_store/es_conn_base.py
Lynn 17b8bb62b6 Feat: message manage (#12083)
### What problem does this PR solve?

Message CRUD.

Issue #4213 

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-12-23 21:16:25 +08:00

327 lines
12 KiB
Python

#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import re
import json
import time
import os
from abc import abstractmethod
from elasticsearch import Elasticsearch, NotFoundError
from elasticsearch_dsl import Index
from elastic_transport import ConnectionTimeout
from common.file_utils import get_project_base_directory
from common.misc_utils import convert_bytes
from common.doc_store.doc_store_base import DocStoreConnection, OrderByExpr, MatchExpr
from rag.nlp import is_english, rag_tokenizer
from common import settings
# How many times the retry loops in ESConnectionBase try an Elasticsearch call.
ATTEMPT_TIME: int = 2
class ESConnectionBase(DocStoreConnection):
    """Common Elasticsearch logic shared by concrete doc-store connections.

    Provides connection setup with a health/version check, index management,
    single-document lookup, helpers that read search responses, and the SQL
    endpoint. ``search``, ``insert``, ``update``, ``delete`` and ``get_fields``
    are abstract and must be implemented by subclasses.
    """

    def __init__(self, mapping_file_name: str = "mapping.json", logger_name: str = 'ragflow.es_conn'):
        """Connect to the hosts in ``settings.ES`` and load the index mapping.

        Args:
            mapping_file_name: JSON file under ``<project base>/conf`` whose
                ``settings`` and ``mappings`` are used by ``create_idx``.
            logger_name: Name of the logger this connection writes to.

        Raises:
            Exception: if the cluster cannot be reached or fails ``ping()``,
                its major version is below 8, or the mapping file is missing.
        """
        self.logger = logging.getLogger(logger_name)
        self.info = {}
        # Pre-set so a client that could never be constructed is reported as
        # "unhealthy" below instead of surfacing as an AttributeError.
        self.es = None
        self.logger.info(f"Use Elasticsearch {settings.ES['hosts']} as the doc engine.")
        for _ in range(ATTEMPT_TIME):
            try:
                if self._connect():
                    break
            except Exception as e:
                self.logger.warning(f"{str(e)}. Waiting Elasticsearch {settings.ES['hosts']} to be healthy.")
                time.sleep(5)
        # NOTE(review): the message says 120s, but the loop above waits at most
        # ATTEMPT_TIME * 5 seconds before giving up.
        if self.es is None or not self.es.ping():
            msg = f"Elasticsearch {settings.ES['hosts']} is unhealthy in 120s."
            self.logger.error(msg)
            raise Exception(msg)
        # Fall back to an 8.x version string when the info response has none.
        v = self.info.get("version", {"number": "8.11.3"})
        v = v["number"].split(".")[0]
        if int(v) < 8:
            msg = f"Elasticsearch version must be greater than or equal to 8, current version: {v}"
            self.logger.error(msg)
            raise Exception(msg)
        fp_mapping = os.path.join(get_project_base_directory(), "conf", mapping_file_name)
        if not os.path.exists(fp_mapping):
            msg = f"Elasticsearch mapping file not found at {fp_mapping}"
            self.logger.error(msg)
            raise Exception(msg)
        # Context manager closes the handle; JSON files are UTF-8 by spec.
        with open(fp_mapping, "r", encoding="utf-8") as f:
            self.mapping = json.load(f)
        self.logger.info(f"Elasticsearch {settings.ES['hosts']} is healthy.")

    def _connect(self):
        """(Re)create ``self.es`` from ``settings.ES`` and cache ``es.info()`` in ``self.info``.

        Returns True when the client object is truthy, else False. Exceptions
        from the client constructor or ``info()`` propagate to the caller.
        """
        self.es = Elasticsearch(
            settings.ES["hosts"].split(","),
            basic_auth=(settings.ES["username"], settings.ES[
                "password"]) if "username" in settings.ES and "password" in settings.ES else None,
            verify_certs=settings.ES.get("verify_certs", False),
            timeout=600)
        if self.es:
            self.info = self.es.info()
            return True
        return False

    # ---- Database operations ----

    def db_type(self) -> str:
        """Return the engine identifier ``"elasticsearch"``."""
        return "elasticsearch"

    def health(self) -> dict:
        """Return the cluster health response as a dict, tagged with ``type``."""
        health_dict = dict(self.es.cluster.health())
        health_dict["type"] = "elasticsearch"
        return health_dict

    def get_cluster_stats(self):
        """Summarize ``_cluster/stats`` into a flat dict of counts and human-readable sizes.

        curl -XGET "http://{es_host}/_cluster/stats" -H "kbn-xsrf: reporting" to view raw stats.

        Returns None (after logging) if the response lacks an expected key.
        The stats request itself is outside the ``try``, so its errors propagate.
        """
        raw_stats = self.es.cluster.stats()
        self.logger.debug(f"ESConnection.get_cluster_stats: {raw_stats}")
        try:
            res = {
                'cluster_name': raw_stats['cluster_name'],
                'status': raw_stats['status']
            }
            indices_status = raw_stats['indices']
            res.update({
                'indices': indices_status['count'],
                'indices_shards': indices_status['shards']['total']
            })
            doc_info = indices_status['docs']
            res.update({
                'docs': doc_info['count'],
                'docs_deleted': doc_info['deleted']
            })
            store_info = indices_status['store']
            res.update({
                'store_size': convert_bytes(store_info['size_in_bytes']),
                'total_dataset_size': convert_bytes(store_info['total_data_set_size_in_bytes'])
            })
            mappings_info = indices_status['mappings']
            res.update({
                'mappings_fields': mappings_info['total_field_count'],
                'mappings_deduplicated_fields': mappings_info['total_deduplicated_field_count'],
                'mappings_deduplicated_size': convert_bytes(mappings_info['total_deduplicated_mapping_size_in_bytes'])
            })
            node_info = raw_stats['nodes']
            res.update({
                'nodes': node_info['count']['total'],
                'nodes_version': node_info['versions'],
                'os_mem': convert_bytes(node_info['os']['mem']['total_in_bytes']),
                'os_mem_used': convert_bytes(node_info['os']['mem']['used_in_bytes']),
                'os_mem_used_percent': node_info['os']['mem']['used_percent'],
                'jvm_versions': node_info['jvm']['versions'][0]['vm_version'],
                'jvm_heap_used': convert_bytes(node_info['jvm']['mem']['heap_used_in_bytes']),
                'jvm_heap_max': convert_bytes(node_info['jvm']['mem']['heap_max_in_bytes'])
            })
            return res
        except Exception as e:
            self.logger.exception(f"ESConnection.get_cluster_stats: {e}")
            return None

    # ---- Table operations ----

    def create_idx(self, index_name: str, dataset_id: str, vector_size: int):
        """Create ``index_name`` with the loaded mapping's settings/mappings if it does not exist.

        Returns True if the index already exists, otherwise the create response.
        Returns None after logging when creation fails. ``vector_size`` is not
        used here.
        """
        if self.index_exist(index_name, dataset_id):
            return True
        try:
            from elasticsearch.client import IndicesClient
            return IndicesClient(self.es).create(index=index_name,
                                                 settings=self.mapping["settings"],
                                                 mappings=self.mapping["mappings"])
        except Exception:
            self.logger.exception("ESConnection.createIndex error %s" % index_name)

    def delete_idx(self, index_name: str, dataset_id: str):
        """Delete ``index_name``, but only when ``dataset_id`` is empty.

        A missing index is ignored; other errors are logged, not raised.
        """
        if len(dataset_id) > 0:
            # The index need to be alive after any kb deletion since all kb under this tenant are in one index.
            return
        try:
            self.es.indices.delete(index=index_name, allow_no_indices=True)
        except NotFoundError:
            pass
        except Exception:
            self.logger.exception("ESConnection.deleteIdx error %s" % index_name)

    def index_exist(self, index_name: str, dataset_id: str = None) -> bool:
        """Return whether ``index_name`` exists (``dataset_id`` is unused).

        Retries with a reconnect on ``ConnectionTimeout``. Returns False on any
        other error, or after ``ATTEMPT_TIME`` timeouts.
        """
        for _ in range(ATTEMPT_TIME):
            try:
                # Bind to the *current* client on every attempt so that a retry
                # uses the connection re-created by self._connect().
                return Index(index_name, self.es).exists()
            except ConnectionTimeout:
                self.logger.exception("ES request timeout")
                time.sleep(3)
                self._connect()
                continue
            except Exception as e:
                self.logger.exception(e)
                break
        return False

    # ---- CRUD operations ----

    def get(self, doc_id: str, index_name: str, dataset_ids: list[str]) -> dict | None:
        """Fetch one document's ``_source`` with its id added under ``"id"``.

        Returns None if the document is not found. ``dataset_ids`` is not used
        here. On ``ConnectionTimeout`` it reconnects and retries up to
        ``ATTEMPT_TIME`` times, then raises; other errors are logged and re-raised.
        """
        for _ in range(ATTEMPT_TIME):
            try:
                res = self.es.get(index=index_name, id=doc_id, source=True)
                if str(res.get("timed_out", "")).lower() == "true":
                    raise Exception("Es Timeout.")
                doc = res["_source"]
                doc["id"] = doc_id
                return doc
            except NotFoundError:
                return None
            except ConnectionTimeout:
                # Same retry policy as index_exist/sql; makes the loop and the
                # trailing timeout error below reachable.
                self.logger.exception("ES request timeout")
                time.sleep(3)
                self._connect()
                continue
            except Exception:
                self.logger.exception(f"ESConnection.get({doc_id}) got exception")
                raise
        self.logger.error(f"ESConnection.get timeout for {ATTEMPT_TIME} times!")
        raise Exception("ESConnection.get timeout.")

    @abstractmethod
    def search(
            self, select_fields: list[str],
            highlight_fields: list[str],
            condition: dict,
            match_expressions: list[MatchExpr],
            order_by: OrderByExpr,
            offset: int,
            limit: int,
            index_names: str | list[str],
            dataset_ids: list[str],
            agg_fields: list[str] | None = None,
            rank_feature: dict | None = None
    ):
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def insert(self, documents: list[dict], index_name: str, dataset_id: str = None) -> list[str]:
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def update(self, condition: dict, new_value: dict, index_name: str, dataset_id: str) -> bool:
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def delete(self, condition: dict, index_name: str, dataset_id: str) -> int:
        raise NotImplementedError("Not implemented")

    # ---- Helper functions for search result ----

    def get_total(self, res):
        """Return the hit count, accepting both ``{"value": n}`` and plain-number forms."""
        total = res["hits"]["total"]
        if isinstance(total, dict):
            return total["value"]
        return total

    def get_doc_ids(self, res):
        """Return the ``_id`` of every hit, in response order."""
        return [d["_id"] for d in res["hits"]["hits"]]

    def _get_source(self, res):
        """Return each hit's ``_source``, mutated in place to include ``id`` and ``_score``."""
        rr = []
        for d in res["hits"]["hits"]:
            d["_source"]["id"] = d["_id"]
            d["_source"]["_score"] = d["_score"]
            rr.append(d["_source"])
        return rr

    @abstractmethod
    def get_fields(self, res, fields: list[str]) -> dict[str, dict]:
        raise NotImplementedError("Not implemented")

    def get_highlight(self, res, keywords: list[str], field_name: str):
        """Build a highlighted snippet per hit id.

        For hits that carry an ES ``highlight`` section:
        - if the first highlight fragments are judged non-English by
          ``is_english``, those fragments are used as-is;
        - otherwise the hit's ``field_name`` source text is split into
          sentences, each keyword is wrapped in ``<em>...</em>``, and the
          sentences containing a highlight are joined with ``"..."``. If none
          match, the ES fragments are used instead.
        Hits without ``highlight`` are omitted.
        """
        # Compile once per call. The trailing delimiter is a lookahead, so it is
        # not consumed: keywords at the end of a sentence (whose punctuation was
        # removed by the split below) and back-to-back repeats still match.
        # Empty keywords are skipped (they would inject bare "<em></em>").
        patterns = [
            re.compile(r"(^|[ .?/'\"\(\)!,:;-])(%s)(?=[ .?/'\"\(\)!,:;-]|$)" % re.escape(w),
                       flags=re.IGNORECASE | re.MULTILINE)
            for w in keywords if w
        ]
        ans = {}
        for d in res["hits"]["hits"]:
            highlights = d.get("highlight")
            if not highlights:
                continue
            # Fragments of the first highlighted field, as returned by ES.
            es_txt = "...".join(next(iter(highlights.values())))
            if not is_english(es_txt.split()):
                ans[d["_id"]] = es_txt
                continue
            txt = d["_source"][field_name]
            txt = re.sub(r"[\r\n]", " ", txt, flags=re.IGNORECASE | re.MULTILINE)
            txt_list = []
            for t in re.split(r"[.?!;\n]", txt):
                for pattern in patterns:
                    t = pattern.sub(r"\1<em>\2</em>", t)
                if not re.search(r"<em>[^<>]+</em>", t, flags=re.IGNORECASE | re.MULTILINE):
                    continue
                txt_list.append(t)
            ans[d["_id"]] = "...".join(txt_list) if txt_list else es_txt
        return ans

    def get_aggregation(self, res, field_name: str):
        """Return ``(key, doc_count)`` pairs from the ``aggs_<field_name>`` buckets, or ``[]`` if absent."""
        agg_field = "aggs_" + field_name
        if "aggregations" not in res or agg_field not in res["aggregations"]:
            return list()
        buckets = res["aggregations"][agg_field]["buckets"]
        return [(b["key"], b["doc_count"]) for b in buckets]

    # ---- SQL ----

    def sql(self, sql: str, fetch_size: int, format: str):
        """Run ``sql`` through the Elasticsearch SQL API.

        Before sending, backticks and space runs collapse to one space, ``%``
        is removed, and equality/``like`` predicates on ``*_tks`` / ``*_ltks``
        fields are rewritten into ``MATCH(...)`` full-text clauses over the
        tokenized value.

        Retries with a reconnect on ``ConnectionTimeout`` and returns None after
        ``ATTEMPT_TIME`` timeouts. Other errors are re-raised as an Exception
        that includes the rewritten SQL.
        """
        self.logger.debug(f"ESConnection.sql get sql: {sql}")
        sql = re.sub(r"[ `]+", " ", sql)
        sql = sql.replace("%", "")
        replaces = []
        for r in re.finditer(r" ([a-z_]+_l?tks)( like | ?= ?)'([^']+)'", sql):
            fld, v = r.group(1), r.group(3)
            match = " MATCH({}, '{}', 'operator=OR;minimum_should_match=30%') ".format(
                fld, rag_tokenizer.fine_grained_tokenize(rag_tokenizer.tokenize(v)))
            replaces.append(
                ("{}{}'{}'".format(
                    r.group(1),
                    r.group(2),
                    r.group(3)),
                 match))
        for p, r in replaces:
            sql = sql.replace(p, r, 1)
        self.logger.debug(f"ESConnection.sql to es: {sql}")
        for _ in range(ATTEMPT_TIME):
            try:
                res = self.es.sql.query(body={"query": sql, "fetch_size": fetch_size}, format=format,
                                        request_timeout="2s")
                return res
            except ConnectionTimeout:
                self.logger.exception("ES request timeout")
                time.sleep(3)
                self._connect()
                continue
            except Exception as e:
                self.logger.exception(f"ESConnection.sql got exception. SQL:\n{sql}")
                raise Exception(f"SQL error: {e}\n\nSQL: {sql}") from e
        self.logger.error(f"ESConnection.sql timeout for {ATTEMPT_TIME} times!")
        return None