Mirror of https://github.com/infiniflow/ragflow.git, synced 2025-12-24 23:46:52 +08:00
### What problem does this PR solve?

Message CRUD. Issue #4213

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
327 lines · 12 KiB · Python
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import logging
import re
import json
import time
import os
from abc import abstractmethod

from elasticsearch import Elasticsearch, NotFoundError
from elasticsearch_dsl import Index
from elastic_transport import ConnectionTimeout
from common.file_utils import get_project_base_directory
from common.misc_utils import convert_bytes
from common.doc_store.doc_store_base import DocStoreConnection, OrderByExpr, MatchExpr
from rag.nlp import is_english, rag_tokenizer
from common import settings

ATTEMPT_TIME = 2


class ESConnectionBase(DocStoreConnection):
    def __init__(self, mapping_file_name: str = "mapping.json", logger_name: str = 'ragflow.es_conn'):
        self.logger = logging.getLogger(logger_name)

        self.info = {}
        self.logger.info(f"Use Elasticsearch {settings.ES['hosts']} as the doc engine.")
        for _ in range(ATTEMPT_TIME):
            try:
                if self._connect():
                    break
            except Exception as e:
                self.logger.warning(f"{str(e)}. Waiting for Elasticsearch {settings.ES['hosts']} to become healthy.")
                time.sleep(5)

        if not self.es.ping():
            msg = f"Elasticsearch {settings.ES['hosts']} is still unhealthy after {ATTEMPT_TIME} connection attempts."
            self.logger.error(msg)
            raise Exception(msg)
        v = self.info.get("version", {"number": "8.11.3"})
        v = v["number"].split(".")[0]
        if int(v) < 8:
            msg = f"Elasticsearch version must be greater than or equal to 8, current version: {v}"
            self.logger.error(msg)
            raise Exception(msg)
        fp_mapping = os.path.join(get_project_base_directory(), "conf", mapping_file_name)
        if not os.path.exists(fp_mapping):
            msg = f"Elasticsearch mapping file not found at {fp_mapping}"
            self.logger.error(msg)
            raise Exception(msg)
        with open(fp_mapping, "r") as f:
            self.mapping = json.load(f)
        self.logger.info(f"Elasticsearch {settings.ES['hosts']} is healthy.")

    def _connect(self):
        self.es = Elasticsearch(
            settings.ES["hosts"].split(","),
            basic_auth=(settings.ES["username"], settings.ES["password"]) if "username" in settings.ES and "password" in settings.ES else None,
            verify_certs=settings.ES.get("verify_certs", False),
            timeout=600,
        )
        if self.es:
            self.info = self.es.info()
            return True
        return False

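    # A minimal sketch of the settings.ES dict this connector expects; only the keys
    # read above are shown, and the host and credential values are placeholders.
    #
    #   ES = {
    #       "hosts": "http://es01:9200",   # comma-separated string, split on ","
    #       "username": "elastic",         # optional, only used together with "password"
    #       "password": "<password>",      # optional
    #       "verify_certs": False,         # optional, defaults to False
    #   }
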
"""
|
|
Database operations
|
|
"""
|
|
|
|
    def db_type(self) -> str:
        return "elasticsearch"

    def health(self) -> dict:
        health_dict = dict(self.es.cluster.health())
        health_dict["type"] = "elasticsearch"
        return health_dict

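    # health() passes through the cluster-health payload and adds a "type" marker;
    # a typical result looks roughly like this (values are illustrative):
    #   {"cluster_name": "docker-cluster", "status": "green", "number_of_nodes": 1,
    #    "active_primary_shards": 10, ..., "type": "elasticsearch"}
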
    def get_cluster_stats(self):
        """
        curl -XGET "http://{es_host}/_cluster/stats" -H "kbn-xsrf: reporting" to view raw stats.
        """
        raw_stats = self.es.cluster.stats()
        self.logger.debug(f"ESConnection.get_cluster_stats: {raw_stats}")
        try:
            res = {
                'cluster_name': raw_stats['cluster_name'],
                'status': raw_stats['status']
            }
            indices_status = raw_stats['indices']
            res.update({
                'indices': indices_status['count'],
                'indices_shards': indices_status['shards']['total']
            })
            doc_info = indices_status['docs']
            res.update({
                'docs': doc_info['count'],
                'docs_deleted': doc_info['deleted']
            })
            store_info = indices_status['store']
            res.update({
                'store_size': convert_bytes(store_info['size_in_bytes']),
                'total_dataset_size': convert_bytes(store_info['total_data_set_size_in_bytes'])
            })
            mappings_info = indices_status['mappings']
            res.update({
                'mappings_fields': mappings_info['total_field_count'],
                'mappings_deduplicated_fields': mappings_info['total_deduplicated_field_count'],
                'mappings_deduplicated_size': convert_bytes(mappings_info['total_deduplicated_mapping_size_in_bytes'])
            })
            node_info = raw_stats['nodes']
            res.update({
                'nodes': node_info['count']['total'],
                'nodes_version': node_info['versions'],
                'os_mem': convert_bytes(node_info['os']['mem']['total_in_bytes']),
                'os_mem_used': convert_bytes(node_info['os']['mem']['used_in_bytes']),
                'os_mem_used_percent': node_info['os']['mem']['used_percent'],
                'jvm_versions': node_info['jvm']['versions'][0]['vm_version'],
                'jvm_heap_used': convert_bytes(node_info['jvm']['mem']['heap_used_in_bytes']),
                'jvm_heap_max': convert_bytes(node_info['jvm']['mem']['heap_max_in_bytes'])
            })
            return res

        except Exception as e:
            self.logger.exception(f"ESConnection.get_cluster_stats: {e}")
            return None

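    # Shape of the summary built by get_cluster_stats(); the keys are exactly the ones
    # assembled above, the values below are only illustrative:
    #   {"cluster_name": "docker-cluster", "status": "green",
    #    "indices": 2, "indices_shards": 4, "docs": 1234, "docs_deleted": 0,
    #    "store_size": "1.2 MB", "total_dataset_size": "1.2 MB",
    #    "mappings_fields": 30, "mappings_deduplicated_fields": 15,
    #    "mappings_deduplicated_size": "4.1 KB",
    #    "nodes": 1, "nodes_version": ["8.11.3"],
    #    "os_mem": "16.0 GB", "os_mem_used": "8.0 GB", "os_mem_used_percent": 50,
    #    "jvm_versions": "21.0.1+12", "jvm_heap_used": "512.0 MB", "jvm_heap_max": "4.0 GB"}
    # Any missing key or transport failure is logged and None is returned instead.
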
"""
|
|
Table operations
|
|
"""
|
|
|
|
    def create_idx(self, index_name: str, dataset_id: str, vector_size: int):
        if self.index_exist(index_name, dataset_id):
            return True
        try:
            from elasticsearch.client import IndicesClient
            return IndicesClient(self.es).create(index=index_name,
                                                 settings=self.mapping["settings"],
                                                 mappings=self.mapping["mappings"])
        except Exception:
            self.logger.exception("ESConnection.createIndex error %s" % index_name)

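    # create_idx() reads only the top-level "settings" and "mappings" keys of the
    # mapping file loaded in __init__, so conf/mapping.json is assumed to look
    # roughly like (contents elided):
    #   {"settings": {...}, "mappings": {...}}
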
    def delete_idx(self, index_name: str, dataset_id: str):
        if len(dataset_id) > 0:
            # The index needs to stay alive after any KB deletion, since all KBs under this tenant share one index.
            return
        try:
            self.es.indices.delete(index=index_name, allow_no_indices=True)
        except NotFoundError:
            pass
        except Exception:
            self.logger.exception("ESConnection.deleteIdx error %s" % index_name)

    def index_exist(self, index_name: str, dataset_id: str = None) -> bool:
        s = Index(index_name, self.es)
        for i in range(ATTEMPT_TIME):
            try:
                return s.exists()
            except ConnectionTimeout:
                self.logger.exception("ES request timeout")
                time.sleep(3)
                self._connect()
                continue
            except Exception as e:
                self.logger.exception(e)
                break
        return False

"""
|
|
CRUD operations
|
|
"""
|
|
|
|
    def get(self, doc_id: str, index_name: str, dataset_ids: list[str]) -> dict | None:
        for i in range(ATTEMPT_TIME):
            try:
                res = self.es.get(index=index_name,
                                  id=doc_id, source=True)
                if str(res.get("timed_out", "")).lower() == "true":
                    raise Exception("Es Timeout.")
                doc = res["_source"]
                doc["id"] = doc_id
                return doc
            except NotFoundError:
                return None
            except Exception as e:
                self.logger.exception(f"ESConnection.get({doc_id}) got exception")
                raise e
        self.logger.error(f"ESConnection.get timeout for {ATTEMPT_TIME} times!")
        raise Exception("ESConnection.get timeout.")

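    # get() returns the stored _source with the document id folded in, e.g.
    # (field names below are illustrative, not fixed by this class):
    #   {"id": "7f3c...", "content_with_weight": "...", "kb_id": "...", ...}
    # A missing document yields None; a failing request is re-raised.
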
    @abstractmethod
    def search(
            self, select_fields: list[str],
            highlight_fields: list[str],
            condition: dict,
            match_expressions: list[MatchExpr],
            order_by: OrderByExpr,
            offset: int,
            limit: int,
            index_names: str | list[str],
            dataset_ids: list[str],
            agg_fields: list[str] | None = None,
            rank_feature: dict | None = None
    ):
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def insert(self, documents: list[dict], index_name: str, dataset_id: str = None) -> list[str]:
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def update(self, condition: dict, new_value: dict, index_name: str, dataset_id: str) -> bool:
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def delete(self, condition: dict, index_name: str, dataset_id: str) -> int:
        raise NotImplementedError("Not implemented")

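    # A concrete connector subclasses ESConnectionBase and fills in the abstract hooks
    # above; a hypothetical skeleton (class name and bodies are placeholders):
    #
    #   class MyESConnection(ESConnectionBase):
    #       def search(self, select_fields, highlight_fields, condition, match_expressions,
    #                  order_by, offset, limit, index_names, dataset_ids,
    #                  agg_fields=None, rank_feature=None):
    #           ...  # translate the arguments into an ES query and call self.es.search()
    #       def insert(self, documents, index_name, dataset_id=None):
    #           ...
    #       def update(self, condition, new_value, index_name, dataset_id):
    #           ...
    #       def delete(self, condition, index_name, dataset_id):
    #           ...
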
"""
|
|
Helper functions for search result
|
|
"""
|
|
|
|
    def get_total(self, res):
        if isinstance(res["hits"]["total"], dict):
            return res["hits"]["total"]["value"]
        return res["hits"]["total"]

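    # Elasticsearch reports hits.total either as a plain integer or as an object such
    # as {"value": 123, "relation": "eq"}; get_total() normalizes both forms to an int.
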
    def get_doc_ids(self, res):
        return [d["_id"] for d in res["hits"]["hits"]]

    def _get_source(self, res):
        rr = []
        for d in res["hits"]["hits"]:
            d["_source"]["id"] = d["_id"]
            d["_source"]["_score"] = d["_score"]
            rr.append(d["_source"])
        return rr

    @abstractmethod
    def get_fields(self, res, fields: list[str]) -> dict[str, dict]:
        raise NotImplementedError("Not implemented")

    def get_highlight(self, res, keywords: list[str], field_name: str):
        ans = {}
        for d in res["hits"]["hits"]:
            highlights = d.get("highlight")
            if not highlights:
                continue
            txt = "...".join([a for a in list(highlights.items())[0][1]])
            if not is_english(txt.split()):
                ans[d["_id"]] = txt
                continue

            txt = d["_source"][field_name]
            txt = re.sub(r"[\r\n]", " ", txt, flags=re.IGNORECASE | re.MULTILINE)
            txt_list = []
            for t in re.split(r"[.?!;\n]", txt):
                for w in keywords:
                    t = re.sub(r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])" % re.escape(w), r"\1<em>\2</em>\3", t,
                               flags=re.IGNORECASE | re.MULTILINE)
                if not re.search(r"<em>[^<>]+</em>", t, flags=re.IGNORECASE | re.MULTILINE):
                    continue
                txt_list.append(t)
            ans[d["_id"]] = "...".join(txt_list) if txt_list else "...".join([a for a in list(highlights.items())[0][1]])

        return ans

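    # The keyword wrapping above is a plain regex pass; for example, with
    # keywords=["elasticsearch"], a sentence fragment " a doc engine based on
    # elasticsearch 8 " would be kept as " a doc engine based on
    # <em>elasticsearch</em> 8 ". The match is case-insensitive and only fires when the
    # keyword is bounded by spaces or punctuation; fragments without a match are dropped.
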
    def get_aggregation(self, res, field_name: str):
        agg_field = "aggs_" + field_name
        if "aggregations" not in res or agg_field not in res["aggregations"]:
            return list()
        buckets = res["aggregations"][agg_field]["buckets"]
        return [(b["key"], b["doc_count"]) for b in buckets]

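    # Aggregations are expected under "aggs_<field_name>" in the raw response, e.g.
    # (illustrative field name and counts):
    #   {"aggregations": {"aggs_kb_id": {"buckets": [
    #       {"key": "kb1", "doc_count": 42}, {"key": "kb2", "doc_count": 7}]}}}
    # which get_aggregation(res, "kb_id") turns into [("kb1", 42), ("kb2", 7)].
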
"""
|
|
SQL
|
|
"""
|
|
|
|
    def sql(self, sql: str, fetch_size: int, format: str):
        self.logger.debug(f"ESConnection.sql get sql: {sql}")
        sql = re.sub(r"[ `]+", " ", sql)
        sql = sql.replace("%", "")
        replaces = []
        for r in re.finditer(r" ([a-z_]+_l?tks)( like | ?= ?)'([^']+)'", sql):
            fld, v = r.group(1), r.group(3)
            match = " MATCH({}, '{}', 'operator=OR;minimum_should_match=30%') ".format(
                fld, rag_tokenizer.fine_grained_tokenize(rag_tokenizer.tokenize(v)))
            replaces.append(
                ("{}{}'{}'".format(
                    r.group(1),
                    r.group(2),
                    r.group(3)),
                 match))

        for p, r in replaces:
            sql = sql.replace(p, r, 1)
        self.logger.debug(f"ESConnection.sql to es: {sql}")

        for i in range(ATTEMPT_TIME):
            try:
                res = self.es.sql.query(body={"query": sql, "fetch_size": fetch_size}, format=format,
                                        request_timeout="2s")
                return res
            except ConnectionTimeout:
                self.logger.exception("ES request timeout")
                time.sleep(3)
                self._connect()
                continue
            except Exception as e:
                self.logger.exception(f"ESConnection.sql got exception. SQL:\n{sql}")
                raise Exception(f"SQL error: {e}\n\nSQL: {sql}")
        self.logger.error(f"ESConnection.sql timeout for {ATTEMPT_TIME} times!")
        return None

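    # sql() rewrites equality / LIKE predicates on tokenized "*_tks" / "*_ltks" columns
    # into full-text MATCH() calls before hitting the ES SQL endpoint, roughly:
    #   in : ... WHERE content_ltks = 'machine learning'
    #   out: ... WHERE MATCH(content_ltks, '<tokenized terms>', 'operator=OR;minimum_should_match=30%')
    # where <tokenized terms> comes from rag_tokenizer; the column name here is only an
    # example, any field matching the "*_tks" / "*_ltks" pattern is rewritten.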