Feat: message manage (#12083)

### What problem does this PR solve?

Implements message CRUD (create, read, update, delete).

Issue #4213 

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
Authored by Lynn on 2025-12-23 21:16:25 +08:00, committed by GitHub · commit 17b8bb62b6 (parent bab6a4a219)
49 changed files with 3480 additions and 1031 deletions


@@ -1,271 +0,0 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from abc import ABC, abstractmethod
from dataclasses import dataclass
import numpy as np
DEFAULT_MATCH_VECTOR_TOPN = 10
DEFAULT_MATCH_SPARSE_TOPN = 10
VEC = list | np.ndarray
@dataclass
class SparseVector:
indices: list[int]
values: list[float] | list[int] | None = None
def __post_init__(self):
assert (self.values is None) or (len(self.indices) == len(self.values))
def to_dict_old(self):
d = {"indices": self.indices}
if self.values is not None:
d["values"] = self.values
return d
def to_dict(self):
if self.values is None:
raise ValueError("SparseVector.values is None")
result = {}
for i, v in zip(self.indices, self.values):
result[str(i)] = v
return result
@staticmethod
def from_dict(d):
return SparseVector(d["indices"], d.get("values"))
def __str__(self):
return f"SparseVector(indices={self.indices}{'' if self.values is None else f', values={self.values}'})"
def __repr__(self):
return str(self)
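As a quick illustration of the two serialization formats above, here is a minimal usage sketch (it relies only on the class as defined):

```python
# Minimal sketch of SparseVector's two dict formats.
sv = SparseVector(indices=[3, 17, 42], values=[0.5, 1.25, 0.75])

# to_dict() keys each value by its stringified index:
assert sv.to_dict() == {"3": 0.5, "17": 1.25, "42": 0.75}

# to_dict_old()/from_dict() round-trip the older {"indices", "values"} layout:
sv2 = SparseVector.from_dict(sv.to_dict_old())
assert sv2.indices == sv.indices and sv2.values == sv.values
```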
class MatchTextExpr(ABC):
def __init__(
self,
fields: list[str],
matching_text: str,
topn: int,
extra_options: dict = dict(),
):
self.fields = fields
self.matching_text = matching_text
self.topn = topn
self.extra_options = extra_options
class MatchDenseExpr(ABC):
def __init__(
self,
vector_column_name: str,
embedding_data: VEC,
embedding_data_type: str,
distance_type: str,
topn: int = DEFAULT_MATCH_VECTOR_TOPN,
extra_options: dict = dict(),
):
self.vector_column_name = vector_column_name
self.embedding_data = embedding_data
self.embedding_data_type = embedding_data_type
self.distance_type = distance_type
self.topn = topn
self.extra_options = extra_options
class MatchSparseExpr(ABC):
def __init__(
self,
vector_column_name: str,
sparse_data: SparseVector | dict,
distance_type: str,
topn: int,
opt_params: dict | None = None,
):
self.vector_column_name = vector_column_name
self.sparse_data = sparse_data
self.distance_type = distance_type
self.topn = topn
self.opt_params = opt_params
class MatchTensorExpr(ABC):
def __init__(
self,
column_name: str,
query_data: VEC,
query_data_type: str,
topn: int,
extra_option: dict | None = None,
):
self.column_name = column_name
self.query_data = query_data
self.query_data_type = query_data_type
self.topn = topn
self.extra_option = extra_option
class FusionExpr(ABC):
def __init__(self, method: str, topn: int, fusion_params: dict | None = None):
self.method = method
self.topn = topn
self.fusion_params = fusion_params
MatchExpr = MatchTextExpr | MatchDenseExpr | MatchSparseExpr | MatchTensorExpr | FusionExpr
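For orientation, here is a hypothetical hybrid-retrieval request built from these expression types; the ES backend's weighted_sum path (shown further below) expects exactly this [MatchTextExpr, MatchDenseExpr, FusionExpr] shape. Field and column names are illustrative:

```python
# Hypothetical hybrid query: full-text + dense vector, fused by weighted sum.
match_expressions: list[MatchExpr] = [
    MatchTextExpr(
        fields=["content_ltks^2", "title_tks"],      # illustrative field names
        matching_text="message management",
        topn=64,
        extra_options={"minimum_should_match": 0.3},
    ),
    MatchDenseExpr(
        vector_column_name="q_768_vec",              # assumed embedding column
        embedding_data=[0.0] * 768,                  # placeholder query embedding
        embedding_data_type="float",
        distance_type="cosine",
        topn=64,
        extra_options={"similarity": 0.2},
    ),
    # "0.3,0.7" => 30% text score, 70% vector score.
    FusionExpr(method="weighted_sum", topn=64, fusion_params={"weights": "0.3,0.7"}),
]
```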
class OrderByExpr(ABC):
def __init__(self):
self.fields = list()
def asc(self, field: str):
self.fields.append((field, 0))
return self
def desc(self, field: str):
self.fields.append((field, 1))
return self
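The builder chains, so a compound sort reads naturally (a minimal sketch):

```python
# OrderByExpr accumulates (field, direction) pairs; 0 means asc, 1 means desc.
order_by = OrderByExpr().desc("create_timestamp_flt").asc("page_num_int")
assert order_by.fields == [("create_timestamp_flt", 1), ("page_num_int", 0)]
```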
class DocStoreConnection(ABC):
"""
Database operations
"""
@abstractmethod
def dbType(self) -> str:
"""
Return the type of the database.
"""
raise NotImplementedError("Not implemented")
@abstractmethod
def health(self) -> dict:
"""
Return the health status of the database.
"""
raise NotImplementedError("Not implemented")
"""
Table operations
"""
@abstractmethod
def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
"""
Create an index with the given name.
"""
raise NotImplementedError("Not implemented")
@abstractmethod
def deleteIdx(self, indexName: str, knowledgebaseId: str):
"""
Delete the index with the given name.
"""
raise NotImplementedError("Not implemented")
@abstractmethod
def indexExist(self, indexName: str, knowledgebaseId: str) -> bool:
"""
Check whether an index with the given name exists.
"""
raise NotImplementedError("Not implemented")
"""
CRUD operations
"""
@abstractmethod
def search(
self, selectFields: list[str],
highlightFields: list[str],
condition: dict,
matchExprs: list[MatchExpr],
orderBy: OrderByExpr,
offset: int,
limit: int,
indexNames: str|list[str],
knowledgebaseIds: list[str],
aggFields: list[str] = [],
rank_feature: dict | None = None
):
"""
Search with the given conjunctive equality filter and return the requested fields of all matched documents.
"""
raise NotImplementedError("Not implemented")
@abstractmethod
def get(self, chunkId: str, indexName: str, knowledgebaseIds: list[str]) -> dict | None:
"""
Get a single chunk by its id.
"""
raise NotImplementedError("Not implemented")
@abstractmethod
def insert(self, rows: list[dict], indexName: str, knowledgebaseId: str = None) -> list[str]:
"""
Update or insert (upsert) a batch of rows.
"""
raise NotImplementedError("Not implemented")
@abstractmethod
def update(self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str) -> bool:
"""
Update rows matching the given conjunctive equality filter.
"""
raise NotImplementedError("Not implemented")
@abstractmethod
def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int:
"""
Delete rows matching the given conjunctive equality filter.
"""
raise NotImplementedError("Not implemented")
"""
Helper functions for search result
"""
@abstractmethod
def get_total(self, res):
raise NotImplementedError("Not implemented")
@abstractmethod
def get_chunk_ids(self, res):
raise NotImplementedError("Not implemented")
@abstractmethod
def get_fields(self, res, fields: list[str]) -> dict[str, dict]:
raise NotImplementedError("Not implemented")
@abstractmethod
def get_highlight(self, res, keywords: list[str], fieldnm: str):
raise NotImplementedError("Not implemented")
@abstractmethod
def get_aggregation(self, res, fieldnm: str):
raise NotImplementedError("Not implemented")
"""
SQL
"""
@abstractmethod
def sql(self, sql: str, fetch_size: int, format: str):
"""
Run the SQL generated by text-to-SQL.
"""
raise NotImplementedError("Not implemented")
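Since every backend (Elasticsearch, Infinity, OpenSearch) implements this contract, callers can stay engine-agnostic. A sketch, with hypothetical index and kb ids:

```python
# Hypothetical engine-agnostic helper; index/kb names are made up.
def fetch_chunk(conn: DocStoreConnection, chunk_id: str) -> dict | None:
    if not conn.indexExist("ragflow_idx", "kb_1"):
        return None
    return conn.get(chunk_id, "ragflow_idx", ["kb_1"])
```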


@@ -14,194 +14,92 @@
# limitations under the License.
#
import logging
import re
import json
import time
import os
import copy
from elasticsearch import Elasticsearch, NotFoundError
from elasticsearch_dsl import UpdateByQuery, Q, Search, Index
from elasticsearch_dsl import UpdateByQuery, Q, Search
from elastic_transport import ConnectionTimeout
from common.decorator import singleton
from common.file_utils import get_project_base_directory
from common.misc_utils import convert_bytes
from rag.utils.doc_store_conn import DocStoreConnection, MatchExpr, OrderByExpr, MatchTextExpr, MatchDenseExpr, \
FusionExpr
from rag.nlp import is_english, rag_tokenizer
from common.doc_store.doc_store_base import MatchTextExpr, OrderByExpr, MatchExpr, MatchDenseExpr, FusionExpr
from common.doc_store.es_conn_base import ESConnectionBase
from common.float_utils import get_float
from common import settings
from common.constants import PAGERANK_FLD, TAG_FLD
ATTEMPT_TIME = 2
logger = logging.getLogger('ragflow.es_conn')
@singleton
class ESConnection(DocStoreConnection):
def __init__(self):
self.info = {}
logger.info(f"Use Elasticsearch {settings.ES['hosts']} as the doc engine.")
for _ in range(ATTEMPT_TIME):
try:
if self._connect():
break
except Exception as e:
logger.warning(f"{str(e)}. Waiting Elasticsearch {settings.ES['hosts']} to be healthy.")
time.sleep(5)
if not self.es.ping():
msg = f"Elasticsearch {settings.ES['hosts']} is unhealthy in 120s."
logger.error(msg)
raise Exception(msg)
v = self.info.get("version", {"number": "8.11.3"})
v = v["number"].split(".")[0]
if int(v) < 8:
msg = f"Elasticsearch version must be greater than or equal to 8, current version: {v}"
logger.error(msg)
raise Exception(msg)
fp_mapping = os.path.join(get_project_base_directory(), "conf", "mapping.json")
if not os.path.exists(fp_mapping):
msg = f"Elasticsearch mapping file not found at {fp_mapping}"
logger.error(msg)
raise Exception(msg)
self.mapping = json.load(open(fp_mapping, "r"))
logger.info(f"Elasticsearch {settings.ES['hosts']} is healthy.")
def _connect(self):
self.es = Elasticsearch(
settings.ES["hosts"].split(","),
basic_auth=(settings.ES["username"], settings.ES[
"password"]) if "username" in settings.ES and "password" in settings.ES else None,
verify_certs= settings.ES.get("verify_certs", False),
timeout=600 )
if self.es:
self.info = self.es.info()
return True
return False
"""
Database operations
"""
def dbType(self) -> str:
return "elasticsearch"
def health(self) -> dict:
health_dict = dict(self.es.cluster.health())
health_dict["type"] = "elasticsearch"
return health_dict
"""
Table operations
"""
def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
if self.indexExist(indexName, knowledgebaseId):
return True
try:
from elasticsearch.client import IndicesClient
return IndicesClient(self.es).create(index=indexName,
settings=self.mapping["settings"],
mappings=self.mapping["mappings"])
except Exception:
logger.exception("ESConnection.createIndex error %s" % (indexName))
def deleteIdx(self, indexName: str, knowledgebaseId: str):
if len(knowledgebaseId) > 0:
# The index needs to stay alive after any kb deletion, since all kbs under this tenant share one index.
return
try:
self.es.indices.delete(index=indexName, allow_no_indices=True)
except NotFoundError:
pass
except Exception:
logger.exception("ESConnection.deleteIdx error %s" % (indexName))
def indexExist(self, indexName: str, knowledgebaseId: str = None) -> bool:
s = Index(indexName, self.es)
for i in range(ATTEMPT_TIME):
try:
return s.exists()
except ConnectionTimeout:
logger.exception("ES request timeout")
time.sleep(3)
self._connect()
continue
except Exception as e:
logger.exception(e)
break
return False
class ESConnection(ESConnectionBase):
"""
CRUD operations
"""
def search(
self, selectFields: list[str],
highlightFields: list[str],
self, select_fields: list[str],
highlight_fields: list[str],
condition: dict,
matchExprs: list[MatchExpr],
orderBy: OrderByExpr,
match_expressions: list[MatchExpr],
order_by: OrderByExpr,
offset: int,
limit: int,
indexNames: str | list[str],
knowledgebaseIds: list[str],
aggFields: list[str] = [],
index_names: str | list[str],
knowledgebase_ids: list[str],
agg_fields: list[str] | None = None,
rank_feature: dict | None = None
):
"""
Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl.html
"""
if isinstance(indexNames, str):
indexNames = indexNames.split(",")
assert isinstance(indexNames, list) and len(indexNames) > 0
if isinstance(index_names, str):
index_names = index_names.split(",")
assert isinstance(index_names, list) and len(index_names) > 0
assert "_id" not in condition
bqry = Q("bool", must=[])
condition["kb_id"] = knowledgebaseIds
bool_query = Q("bool", must=[])
condition["kb_id"] = knowledgebase_ids
for k, v in condition.items():
if k == "available_int":
if v == 0:
bqry.filter.append(Q("range", available_int={"lt": 1}))
bool_query.filter.append(Q("range", available_int={"lt": 1}))
else:
bqry.filter.append(
bool_query.filter.append(
Q("bool", must_not=Q("range", available_int={"lt": 1})))
continue
if not v:
continue
if isinstance(v, list):
bqry.filter.append(Q("terms", **{k: v}))
bool_query.filter.append(Q("terms", **{k: v}))
elif isinstance(v, str) or isinstance(v, int):
bqry.filter.append(Q("term", **{k: v}))
bool_query.filter.append(Q("term", **{k: v}))
else:
raise Exception(
f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")
s = Search()
vector_similarity_weight = 0.5
for m in matchExprs:
for m in match_expressions:
if isinstance(m, FusionExpr) and m.method == "weighted_sum" and "weights" in m.fusion_params:
assert len(matchExprs) == 3 and isinstance(matchExprs[0], MatchTextExpr) and isinstance(matchExprs[1],
MatchDenseExpr) and isinstance(
matchExprs[2], FusionExpr)
assert len(match_expressions) == 3 and isinstance(match_expressions[0], MatchTextExpr) and isinstance(match_expressions[1],
MatchDenseExpr) and isinstance(
match_expressions[2], FusionExpr)
weights = m.fusion_params["weights"]
vector_similarity_weight = get_float(weights.split(",")[1])
for m in matchExprs:
for m in match_expressions:
if isinstance(m, MatchTextExpr):
minimum_should_match = m.extra_options.get("minimum_should_match", 0.0)
if isinstance(minimum_should_match, float):
minimum_should_match = str(int(minimum_should_match * 100)) + "%"
bqry.must.append(Q("query_string", fields=m.fields,
bool_query.must.append(Q("query_string", fields=m.fields,
type="best_fields", query=m.matching_text,
minimum_should_match=minimum_should_match,
boost=1))
bqry.boost = 1.0 - vector_similarity_weight
bool_query.boost = 1.0 - vector_similarity_weight
elif isinstance(m, MatchDenseExpr):
assert (bqry is not None)
assert (bool_query is not None)
similarity = 0.0
if "similarity" in m.extra_options:
similarity = m.extra_options["similarity"]
@@ -209,24 +107,24 @@ class ESConnection(DocStoreConnection):
m.topn,
m.topn * 2,
query_vector=list(m.embedding_data),
filter=bqry.to_dict(),
filter=bool_query.to_dict(),
similarity=similarity,
)
if bqry and rank_feature:
if bool_query and rank_feature:
for fld, sc in rank_feature.items():
if fld != PAGERANK_FLD:
fld = f"{TAG_FLD}.{fld}"
bqry.should.append(Q("rank_feature", field=fld, linear={}, boost=sc))
bool_query.should.append(Q("rank_feature", field=fld, linear={}, boost=sc))
if bqry:
s = s.query(bqry)
for field in highlightFields:
if bool_query:
s = s.query(bool_query)
for field in highlight_fields:
s = s.highlight(field)
if orderBy:
if order_by:
orders = list()
for field, order in orderBy.fields:
for field, order in order_by.fields:
order = "asc" if order == 0 else "desc"
if field in ["page_num_int", "top_int"]:
order_info = {"order": order, "unmapped_type": "float",
@@ -237,19 +135,19 @@ class ESConnection(DocStoreConnection):
order_info = {"order": order, "unmapped_type": "text"}
orders.append({field: order_info})
s = s.sort(*orders)
for fld in aggFields:
s.aggs.bucket(f'aggs_{fld}', 'terms', field=fld, size=1000000)
if agg_fields:
for fld in agg_fields:
s.aggs.bucket(f'aggs_{fld}', 'terms', field=fld, size=1000000)
if limit > 0:
s = s[offset:offset + limit]
q = s.to_dict()
logger.debug(f"ESConnection.search {str(indexNames)} query: " + json.dumps(q))
self.logger.debug(f"ESConnection.search {str(index_names)} query: " + json.dumps(q))
for i in range(ATTEMPT_TIME):
try:
#print(json.dumps(q, ensure_ascii=False))
res = self.es.search(index=indexNames,
res = self.es.search(index=index_names,
body=q,
timeout="600s",
# search_type="dfs_query_then_fetch",
@@ -257,55 +155,37 @@ class ESConnection(DocStoreConnection):
_source=True)
if str(res.get("timed_out", "")).lower() == "true":
raise Exception("Es Timeout.")
logger.debug(f"ESConnection.search {str(indexNames)} res: " + str(res))
self.logger.debug(f"ESConnection.search {str(index_names)} res: " + str(res))
return res
except ConnectionTimeout:
logger.exception("ES request timeout")
self.logger.exception("ES request timeout")
self._connect()
continue
except Exception as e:
logger.exception(f"ESConnection.search {str(indexNames)} query: " + str(q) + str(e))
self.logger.exception(f"ESConnection.search {str(index_names)} query: " + str(q) + str(e))
raise e
logger.error(f"ESConnection.search timeout for {ATTEMPT_TIME} times!")
self.logger.error(f"ESConnection.search timeout for {ATTEMPT_TIME} times!")
raise Exception("ESConnection.search timeout.")
def get(self, chunkId: str, indexName: str, knowledgebaseIds: list[str]) -> dict | None:
for i in range(ATTEMPT_TIME):
try:
res = self.es.get(index=(indexName),
id=chunkId, source=True, )
if str(res.get("timed_out", "")).lower() == "true":
raise Exception("Es Timeout.")
chunk = res["_source"]
chunk["id"] = chunkId
return chunk
except NotFoundError:
return None
except Exception as e:
logger.exception(f"ESConnection.get({chunkId}) got exception")
raise e
logger.error(f"ESConnection.get timeout for {ATTEMPT_TIME} times!")
raise Exception("ESConnection.get timeout.")
def insert(self, documents: list[dict], indexName: str, knowledgebaseId: str = None) -> list[str]:
def insert(self, documents: list[dict], index_name: str, knowledgebase_id: str = None) -> list[str]:
# Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/docs-bulk.html
operations = []
for d in documents:
assert "_id" not in d
assert "id" in d
d_copy = copy.deepcopy(d)
d_copy["kb_id"] = knowledgebaseId
d_copy["kb_id"] = knowledgebase_id
meta_id = d_copy.pop("id", "")
operations.append(
{"index": {"_index": indexName, "_id": meta_id}})
{"index": {"_index": index_name, "_id": meta_id}})
operations.append(d_copy)
res = []
for _ in range(ATTEMPT_TIME):
try:
res = []
r = self.es.bulk(index=(indexName), operations=operations,
r = self.es.bulk(index=index_name, operations=operations,
refresh=False, timeout="60s")
if re.search(r"False", str(r["errors"]), re.IGNORECASE):
return res
@@ -316,58 +196,58 @@ class ESConnection(DocStoreConnection):
res.append(str(item[action]["_id"]) + ":" + str(item[action]["error"]))
return res
except ConnectionTimeout:
logger.exception("ES request timeout")
self.logger.exception("ES request timeout")
time.sleep(3)
self._connect()
continue
except Exception as e:
res.append(str(e))
logger.warning("ESConnection.insert got exception: " + str(e))
self.logger.warning("ESConnection.insert got exception: " + str(e))
return res
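A sketch of the calling convention: each document must carry an "id" (which becomes the ES _id) and must not carry "_id"; an empty return list means every row was indexed:

```python
# Hypothetical bulk upsert; the empty list signals success.
docs = [{"id": "chunk_1", "doc_id": "doc_42", "content_with_weight": "hello"}]
errors = conn.insert(docs, index_name="ragflow_idx", knowledgebase_id="kb_1")
assert errors == []
```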
def update(self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str) -> bool:
doc = copy.deepcopy(newValue)
def update(self, condition: dict, new_value: dict, index_name: str, knowledgebase_id: str) -> bool:
doc = copy.deepcopy(new_value)
doc.pop("id", None)
condition["kb_id"] = knowledgebaseId
condition["kb_id"] = knowledgebase_id
if "id" in condition and isinstance(condition["id"], str):
# update specific single document
chunkId = condition["id"]
chunk_id = condition["id"]
for i in range(ATTEMPT_TIME):
for k in doc.keys():
if "feas" != k.split("_")[-1]:
continue
try:
self.es.update(index=indexName, id=chunkId, script=f"ctx._source.remove(\"{k}\");")
self.es.update(index=index_name, id=chunk_id, script=f"ctx._source.remove(\"{k}\");")
except Exception:
logger.exception(f"ESConnection.update(index={indexName}, id={chunkId}, doc={json.dumps(condition, ensure_ascii=False)}) got exception")
self.logger.exception(f"ESConnection.update(index={index_name}, id={chunk_id}, doc={json.dumps(condition, ensure_ascii=False)}) got exception")
try:
self.es.update(index=indexName, id=chunkId, doc=doc)
self.es.update(index=index_name, id=chunk_id, doc=doc)
return True
except Exception as e:
logger.exception(
f"ESConnection.update(index={indexName}, id={chunkId}, doc={json.dumps(condition, ensure_ascii=False)}) got exception: "+str(e))
self.logger.exception(
f"ESConnection.update(index={index_name}, id={chunk_id}, doc={json.dumps(condition, ensure_ascii=False)}) got exception: " + str(e))
break
return False
# update unspecific maybe-multiple documents
bqry = Q("bool")
bool_query = Q("bool")
for k, v in condition.items():
if not isinstance(k, str) or not v:
continue
if k == "exists":
bqry.filter.append(Q("exists", field=v))
bool_query.filter.append(Q("exists", field=v))
continue
if isinstance(v, list):
bqry.filter.append(Q("terms", **{k: v}))
bool_query.filter.append(Q("terms", **{k: v}))
elif isinstance(v, str) or isinstance(v, int):
bqry.filter.append(Q("term", **{k: v}))
bool_query.filter.append(Q("term", **{k: v}))
else:
raise Exception(
f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")
scripts = []
params = {}
for k, v in newValue.items():
for k, v in new_value.items():
if k == "remove":
if isinstance(v, str):
scripts.append(f"ctx._source.remove('{v}');")
@@ -397,8 +277,8 @@ class ESConnection(DocStoreConnection):
raise Exception(
f"newValue `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str.")
ubq = UpdateByQuery(
index=indexName).using(
self.es).query(bqry)
index=index_name).using(
self.es).query(bool_query)
ubq = ubq.script(source="".join(scripts), params=params)
ubq = ubq.params(refresh=True)
ubq = ubq.params(slices=5)
@@ -409,19 +289,18 @@ class ESConnection(DocStoreConnection):
_ = ubq.execute()
return True
except ConnectionTimeout:
logger.exception("ES request timeout")
self.logger.exception("ES request timeout")
time.sleep(3)
self._connect()
continue
except Exception as e:
logger.error("ESConnection.update got exception: " + str(e) + "\n".join(scripts))
self.logger.error("ESConnection.update got exception: " + str(e) + "\n".join(scripts))
break
return False
def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int:
qry = None
def delete(self, condition: dict, index_name: str, knowledgebase_id: str) -> int:
assert "_id" not in condition
condition["kb_id"] = knowledgebaseId
condition["kb_id"] = knowledgebase_id
if "id" in condition:
chunk_ids = condition["id"]
if not isinstance(chunk_ids, list):
@@ -448,21 +327,21 @@ class ESConnection(DocStoreConnection):
qry.must.append(Q("term", **{k: v}))
else:
raise Exception("Condition value must be int, str or list.")
logger.debug("ESConnection.delete query: " + json.dumps(qry.to_dict()))
self.logger.debug("ESConnection.delete query: " + json.dumps(qry.to_dict()))
for _ in range(ATTEMPT_TIME):
try:
res = self.es.delete_by_query(
index=indexName,
index=index_name,
body=Search().query(qry).to_dict(),
refresh=True)
return res["deleted"]
except ConnectionTimeout:
logger.exception("ES request timeout")
self.logger.exception("ES request timeout")
time.sleep(3)
self._connect()
continue
except Exception as e:
logger.warning("ESConnection.delete got exception: " + str(e))
self.logger.warning("ESConnection.delete got exception: " + str(e))
if re.search(r"(not_found)", str(e), re.IGNORECASE):
return 0
return 0
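For example, deleting a couple of chunks by id from one kb looks like this (a sketch; names are made up):

```python
# Hypothetical deletion by chunk id; returns the number of deleted rows.
deleted = conn.delete(
    condition={"id": ["chunk_1", "chunk_2"]},
    index_name="ragflow_idx",
    knowledgebase_id="kb_1",
)
```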
@@ -471,27 +350,11 @@ class ESConnection(DocStoreConnection):
Helper functions for search result
"""
def get_total(self, res):
if isinstance(res["hits"]["total"], type({})):
return res["hits"]["total"]["value"]
return res["hits"]["total"]
def get_chunk_ids(self, res):
return [d["_id"] for d in res["hits"]["hits"]]
def __getSource(self, res):
rr = []
for d in res["hits"]["hits"]:
d["_source"]["id"] = d["_id"]
d["_source"]["_score"] = d["_score"]
rr.append(d["_source"])
return rr
def get_fields(self, res, fields: list[str]) -> dict[str, dict]:
res_fields = {}
if not fields:
return {}
for d in self.__getSource(res):
for d in self._get_source(res):
m = {n: d.get(n) for n in fields if d.get(n) is not None}
for n, v in m.items():
if isinstance(v, list):
@@ -508,124 +371,3 @@ class ESConnection(DocStoreConnection):
if m:
res_fields[d["id"]] = m
return res_fields
def get_highlight(self, res, keywords: list[str], fieldnm: str):
ans = {}
for d in res["hits"]["hits"]:
hlts = d.get("highlight")
if not hlts:
continue
txt = "...".join([a for a in list(hlts.items())[0][1]])
if not is_english(txt.split()):
ans[d["_id"]] = txt
continue
txt = d["_source"][fieldnm]
txt = re.sub(r"[\r\n]", " ", txt, flags=re.IGNORECASE | re.MULTILINE)
txts = []
for t in re.split(r"[.?!;\n]", txt):
for w in keywords:
t = re.sub(r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])" % re.escape(w), r"\1<em>\2</em>\3", t,
flags=re.IGNORECASE | re.MULTILINE)
if not re.search(r"<em>[^<>]+</em>", t, flags=re.IGNORECASE | re.MULTILINE):
continue
txts.append(t)
ans[d["_id"]] = "...".join(txts) if txts else "...".join([a for a in list(hlts.items())[0][1]])
return ans
def get_aggregation(self, res, fieldnm: str):
agg_field = "aggs_" + fieldnm
if "aggregations" not in res or agg_field not in res["aggregations"]:
return list()
bkts = res["aggregations"][agg_field]["buckets"]
return [(b["key"], b["doc_count"]) for b in bkts]
"""
SQL
"""
def sql(self, sql: str, fetch_size: int, format: str):
logger.debug(f"ESConnection.sql get sql: {sql}")
sql = re.sub(r"[ `]+", " ", sql)
sql = sql.replace("%", "")
replaces = []
for r in re.finditer(r" ([a-z_]+_l?tks)( like | ?= ?)'([^']+)'", sql):
fld, v = r.group(1), r.group(3)
match = " MATCH({}, '{}', 'operator=OR;minimum_should_match=30%') ".format(
fld, rag_tokenizer.fine_grained_tokenize(rag_tokenizer.tokenize(v)))
replaces.append(
("{}{}'{}'".format(
r.group(1),
r.group(2),
r.group(3)),
match))
for p, r in replaces:
sql = sql.replace(p, r, 1)
logger.debug(f"ESConnection.sql to es: {sql}")
for i in range(ATTEMPT_TIME):
try:
res = self.es.sql.query(body={"query": sql, "fetch_size": fetch_size}, format=format,
request_timeout="2s")
return res
except ConnectionTimeout:
logger.exception("ES request timeout")
time.sleep(3)
self._connect()
continue
except Exception as e:
logger.exception(f"ESConnection.sql got exception. SQL:\n{sql}")
raise Exception(f"SQL error: {e}\n\nSQL: {sql}")
logger.error(f"ESConnection.sql timeout for {ATTEMPT_TIME} times!")
return None
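To make the rewrite concrete: the substitution loop turns LIKE/equality predicates on tokenized fields into ES full-text MATCH calls, roughly like this (the real code first runs the value through rag_tokenizer):

```python
# Illustrative before/after of the predicate rewrite (tokenization simplified).
before = "SELECT doc_id FROM ragflow_idx WHERE content_ltks like 'message manage'"
after = ("SELECT doc_id FROM ragflow_idx WHERE "
         "MATCH(content_ltks, 'message manage', 'operator=OR;minimum_should_match=30%')")
```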
def get_cluster_stats(self):
"""
curl -XGET "http://{es_host}/_cluster/stats" -H "kbn-xsrf: reporting" to view raw stats.
"""
raw_stats = self.es.cluster.stats()
logger.debug(f"ESConnection.get_cluster_stats: {raw_stats}")
try:
res = {
'cluster_name': raw_stats['cluster_name'],
'status': raw_stats['status']
}
indices_status = raw_stats['indices']
res.update({
'indices': indices_status['count'],
'indices_shards': indices_status['shards']['total']
})
doc_info = indices_status['docs']
res.update({
'docs': doc_info['count'],
'docs_deleted': doc_info['deleted']
})
store_info = indices_status['store']
res.update({
'store_size': convert_bytes(store_info['size_in_bytes']),
'total_dataset_size': convert_bytes(store_info['total_data_set_size_in_bytes'])
})
mappings_info = indices_status['mappings']
res.update({
'mappings_fields': mappings_info['total_field_count'],
'mappings_deduplicated_fields': mappings_info['total_deduplicated_field_count'],
'mappings_deduplicated_size': convert_bytes(mappings_info['total_deduplicated_mapping_size_in_bytes'])
})
node_info = raw_stats['nodes']
res.update({
'nodes': node_info['count']['total'],
'nodes_version': node_info['versions'],
'os_mem': convert_bytes(node_info['os']['mem']['total_in_bytes']),
'os_mem_used': convert_bytes(node_info['os']['mem']['used_in_bytes']),
'os_mem_used_percent': node_info['os']['mem']['used_percent'],
'jvm_versions': node_info['jvm']['versions'][0]['vm_version'],
'jvm_heap_used': convert_bytes(node_info['jvm']['mem']['heap_used_in_bytes']),
'jvm_heap_max': convert_bytes(node_info['jvm']['mem']['heap_max_in_bytes'])
})
return res
except Exception as e:
logger.exception(f"ESConnection.get_cluster_stats: {e}")
return None
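Roughly, the flattened stats dict looks like the sketch below (all values are made up; keys follow the code above):

```python
# Illustrative shape of ESConnection.get_cluster_stats() output (values made up).
stats = {
    "cluster_name": "docker-cluster", "status": "green",
    "indices": 3, "indices_shards": 3,
    "docs": 12345, "docs_deleted": 10,
    "store_size": "1.2 GB", "total_dataset_size": "1.2 GB",
    "mappings_fields": 420, "mappings_deduplicated_fields": 140,
    "mappings_deduplicated_size": "12.3 KB",
    "nodes": 1, "nodes_version": ["8.11.3"],
    "os_mem": "16 GB", "os_mem_used": "8 GB", "os_mem_used_percent": 50,
    "jvm_versions": "21.0.1", "jvm_heap_used": "512 MB", "jvm_heap_max": "4 GB",
}
```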


@@ -14,365 +14,125 @@
# limitations under the License.
#
import logging
import os
import re
import json
import time
import copy
import infinity
from infinity.common import ConflictType, InfinityException, SortType
from infinity.index import IndexInfo, IndexType
from infinity.connection_pool import ConnectionPool
from infinity.common import InfinityException, SortType
from infinity.errors import ErrorCode
from common.decorator import singleton
import pandas as pd
from common.file_utils import get_project_base_directory
from rag.nlp import is_english
from common.constants import PAGERANK_FLD, TAG_FLD
from common import settings
from rag.utils.doc_store_conn import (
DocStoreConnection,
MatchExpr,
MatchTextExpr,
MatchDenseExpr,
FusionExpr,
OrderByExpr,
)
logger = logging.getLogger("ragflow.infinity_conn")
def field_keyword(field_name: str):
# Treat "*_kwd" tag-like columns as keyword lists except knowledge_graph_kwd; source_id is also keyword-like.
if field_name == "source_id" or (field_name.endswith("_kwd") and field_name not in ["knowledge_graph_kwd", "docnm_kwd", "important_kwd", "question_kwd"]):
return True
return False
def convert_select_fields(output_fields: list[str]) -> list[str]:
for i, field in enumerate(output_fields):
if field in ["docnm_kwd", "title_tks", "title_sm_tks"]:
output_fields[i] = "docnm"
elif field in ["important_kwd", "important_tks"]:
output_fields[i] = "important_keywords"
elif field in ["question_kwd", "question_tks"]:
output_fields[i] = "questions"
elif field in ["content_with_weight", "content_ltks", "content_sm_ltks"]:
output_fields[i] = "content"
elif field in ["authors_tks", "authors_sm_tks"]:
output_fields[i] = "authors"
return list(set(output_fields))
def convert_matching_field(field_weightstr: str) -> str:
tokens = field_weightstr.split("^")
field = tokens[0]
if field == "docnm_kwd" or field == "title_tks":
field = "docnm@ft_docnm_rag_coarse"
elif field == "title_sm_tks":
field = "docnm@ft_docnm_rag_fine"
elif field == "important_kwd":
field = "important_keywords@ft_important_keywords_rag_coarse"
elif field == "important_tks":
field = "important_keywords@ft_important_keywords_rag_fine"
elif field == "question_kwd":
field = "questions@ft_questions_rag_coarse"
elif field == "question_tks":
field = "questions@ft_questions_rag_fine"
elif field == "content_with_weight" or field == "content_ltks":
field = "content@ft_content_rag_coarse"
elif field == "content_sm_ltks":
field = "content@ft_content_rag_fine"
elif field == "authors_tks":
field = "authors@ft_authors_rag_coarse"
elif field == "authors_sm_tks":
field = "authors@ft_authors_rag_fine"
tokens[0] = field
return "^".join(tokens)
def list2str(lst: str|list, sep: str = " ") -> str:
if isinstance(lst, str):
return lst
return sep.join(lst)
def equivalent_condition_to_str(condition: dict, table_instance=None) -> str | None:
assert "_id" not in condition
clmns = {}
if table_instance:
for n, ty, de, _ in table_instance.show_columns().rows():
clmns[n] = (ty, de)
def exists(cln):
nonlocal clmns
assert cln in clmns, f"'{cln}' should be in '{clmns}'."
ty, de = clmns[cln]
if ty.lower().find("cha"):
if not de:
de = ""
return f" {cln}!='{de}' "
return f"{cln}!={de}"
cond = list()
for k, v in condition.items():
if not isinstance(k, str) or not v:
continue
if field_keyword(k):
if isinstance(v, list):
inCond = list()
for item in v:
if isinstance(item, str):
item = item.replace("'", "''")
inCond.append(f"filter_fulltext('{convert_matching_field(k)}', '{item}')")
if inCond:
strInCond = " or ".join(inCond)
strInCond = f"({strInCond})"
cond.append(strInCond)
else:
cond.append(f"filter_fulltext('{convert_matching_field(k)}', '{v}')")
elif isinstance(v, list):
inCond = list()
for item in v:
if isinstance(item, str):
item = item.replace("'", "''")
inCond.append(f"'{item}'")
else:
inCond.append(str(item))
if inCond:
strInCond = ", ".join(inCond)
strInCond = f"{k} IN ({strInCond})"
cond.append(strInCond)
elif k == "must_not":
if isinstance(v, dict):
for kk, vv in v.items():
if kk == "exists":
cond.append("NOT (%s)" % exists(vv))
elif isinstance(v, str):
cond.append(f"{k}='{v}'")
elif k == "exists":
cond.append(exists(v))
else:
cond.append(f"{k}={str(v)}")
return " AND ".join(cond) if cond else "1=1"
def concat_dataframes(df_list: list[pd.DataFrame], selectFields: list[str]) -> pd.DataFrame:
df_list2 = [df for df in df_list if not df.empty]
if df_list2:
return pd.concat(df_list2, axis=0).reset_index(drop=True)
schema = []
for field_name in selectFields:
if field_name == "score()": # Workaround: fix schema is changed to score()
schema.append("SCORE")
elif field_name == "similarity()": # Workaround: fix schema is changed to similarity()
schema.append("SIMILARITY")
else:
schema.append(field_name)
return pd.DataFrame(columns=schema)
from common.doc_store.doc_store_base import MatchExpr, MatchTextExpr, MatchDenseExpr, FusionExpr, OrderByExpr
from common.doc_store.infinity_conn_base import InfinityConnectionBase
@singleton
class InfinityConnection(DocStoreConnection):
def __init__(self):
self.dbName = settings.INFINITY.get("db_name", "default_db")
infinity_uri = settings.INFINITY["uri"]
if ":" in infinity_uri:
host, port = infinity_uri.split(":")
infinity_uri = infinity.common.NetworkAddress(host, int(port))
self.connPool = None
logger.info(f"Use Infinity {infinity_uri} as the doc engine.")
for _ in range(24):
try:
connPool = ConnectionPool(infinity_uri, max_size=4)
inf_conn = connPool.get_conn()
res = inf_conn.show_current_node()
if res.error_code == ErrorCode.OK and res.server_status in ["started", "alive"]:
self._migrate_db(inf_conn)
self.connPool = connPool
connPool.release_conn(inf_conn)
break
connPool.release_conn(inf_conn)
logger.warn(f"Infinity status: {res.server_status}. Waiting Infinity {infinity_uri} to be healthy.")
time.sleep(5)
except Exception as e:
logger.warning(f"{str(e)}. Waiting Infinity {infinity_uri} to be healthy.")
time.sleep(5)
if self.connPool is None:
msg = f"Infinity {infinity_uri} is unhealthy in 120s."
logger.error(msg)
raise Exception(msg)
logger.info(f"Infinity {infinity_uri} is healthy.")
def _migrate_db(self, inf_conn):
inf_db = inf_conn.create_database(self.dbName, ConflictType.Ignore)
fp_mapping = os.path.join(get_project_base_directory(), "conf", "infinity_mapping.json")
if not os.path.exists(fp_mapping):
raise Exception(f"Mapping file not found at {fp_mapping}")
schema = json.load(open(fp_mapping))
table_names = inf_db.list_tables().table_names
for table_name in table_names:
inf_table = inf_db.get_table(table_name)
index_names = inf_table.list_indexes().index_names
if "q_vec_idx" not in index_names:
# Skip tables that weren't created by this module
continue
column_names = inf_table.show_columns()["name"]
column_names = set(column_names)
for field_name, field_info in schema.items():
if field_name in column_names:
continue
res = inf_table.add_columns({field_name: field_info})
assert res.error_code == infinity.ErrorCode.OK
logger.info(f"INFINITY added following column to table {table_name}: {field_name} {field_info}")
if field_info["type"] != "varchar" or "analyzer" not in field_info:
continue
analyzers = field_info["analyzer"]
if isinstance(analyzers, str):
analyzers = [analyzers]
for analyzer in analyzers:
inf_table.create_index(
f"ft_{re.sub(r'[^a-zA-Z0-9]', '_', field_name)}_{re.sub(r'[^a-zA-Z0-9]', '_', analyzer)}",
IndexInfo(field_name, IndexType.FullText, {"ANALYZER": analyzer}),
ConflictType.Ignore,
)
class InfinityConnection(InfinityConnectionBase):
"""
Database operations
DataFrame and field conversion
"""
def dbType(self) -> str:
return "infinity"
def health(self) -> dict:
"""
Return the health status of the database.
"""
inf_conn = self.connPool.get_conn()
res = inf_conn.show_current_node()
self.connPool.release_conn(inf_conn)
res2 = {
"type": "infinity",
"status": "green" if res.error_code == 0 and res.server_status in ["started", "alive"] else "red",
"error": res.error_msg,
}
return res2
"""
Table operations
"""
def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
table_name = f"{indexName}_{knowledgebaseId}"
inf_conn = self.connPool.get_conn()
inf_db = inf_conn.create_database(self.dbName, ConflictType.Ignore)
fp_mapping = os.path.join(get_project_base_directory(), "conf", "infinity_mapping.json")
if not os.path.exists(fp_mapping):
raise Exception(f"Mapping file not found at {fp_mapping}")
schema = json.load(open(fp_mapping))
vector_name = f"q_{vectorSize}_vec"
schema[vector_name] = {"type": f"vector,{vectorSize},float"}
inf_table = inf_db.create_table(
table_name,
schema,
ConflictType.Ignore,
)
inf_table.create_index(
"q_vec_idx",
IndexInfo(
vector_name,
IndexType.Hnsw,
{
"M": "16",
"ef_construction": "50",
"metric": "cosine",
"encode": "lvq",
},
),
ConflictType.Ignore,
)
for field_name, field_info in schema.items():
if field_info["type"] != "varchar" or "analyzer" not in field_info:
continue
analyzers = field_info["analyzer"]
if isinstance(analyzers, str):
analyzers = [analyzers]
for analyzer in analyzers:
inf_table.create_index(
f"ft_{re.sub(r'[^a-zA-Z0-9]', '_', field_name)}_{re.sub(r'[^a-zA-Z0-9]', '_', analyzer)}",
IndexInfo(field_name, IndexType.FullText, {"ANALYZER": analyzer}),
ConflictType.Ignore,
)
self.connPool.release_conn(inf_conn)
logger.info(f"INFINITY created table {table_name}, vector size {vectorSize}")
def deleteIdx(self, indexName: str, knowledgebaseId: str):
table_name = f"{indexName}_{knowledgebaseId}"
inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName)
db_instance.drop_table(table_name, ConflictType.Ignore)
self.connPool.release_conn(inf_conn)
logger.info(f"INFINITY dropped table {table_name}")
def indexExist(self, indexName: str, knowledgebaseId: str) -> bool:
table_name = f"{indexName}_{knowledgebaseId}"
try:
inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName)
_ = db_instance.get_table(table_name)
self.connPool.release_conn(inf_conn)
@staticmethod
def field_keyword(field_name: str):
# Treat "*_kwd" tag-like columns as keyword lists except knowledge_graph_kwd; source_id is also keyword-like.
if field_name == "source_id" or (
field_name.endswith("_kwd") and field_name not in ["knowledge_graph_kwd", "docnm_kwd", "important_kwd",
"question_kwd"]):
return True
except Exception as e:
logger.warning(f"INFINITY indexExist {str(e)}")
return False
def convert_select_fields(self, output_fields: list[str]) -> list[str]:
for i, field in enumerate(output_fields):
if field in ["docnm_kwd", "title_tks", "title_sm_tks"]:
output_fields[i] = "docnm"
elif field in ["important_kwd", "important_tks"]:
output_fields[i] = "important_keywords"
elif field in ["question_kwd", "question_tks"]:
output_fields[i] = "questions"
elif field in ["content_with_weight", "content_ltks", "content_sm_ltks"]:
output_fields[i] = "content"
elif field in ["authors_tks", "authors_sm_tks"]:
output_fields[i] = "authors"
return list(set(output_fields))
@staticmethod
def convert_matching_field(field_weight_str: str) -> str:
tokens = field_weight_str.split("^")
field = tokens[0]
if field == "docnm_kwd" or field == "title_tks":
field = "docnm@ft_docnm_rag_coarse"
elif field == "title_sm_tks":
field = "docnm@ft_docnm_rag_fine"
elif field == "important_kwd":
field = "important_keywords@ft_important_keywords_rag_coarse"
elif field == "important_tks":
field = "important_keywords@ft_important_keywords_rag_fine"
elif field == "question_kwd":
field = "questions@ft_questions_rag_coarse"
elif field == "question_tks":
field = "questions@ft_questions_rag_fine"
elif field == "content_with_weight" or field == "content_ltks":
field = "content@ft_content_rag_coarse"
elif field == "content_sm_ltks":
field = "content@ft_content_rag_fine"
elif field == "authors_tks":
field = "authors@ft_authors_rag_coarse"
elif field == "authors_sm_tks":
field = "authors@ft_authors_rag_fine"
tokens[0] = field
return "^".join(tokens)
"""
CRUD operations
"""
def search(
self,
selectFields: list[str],
highlightFields: list[str],
select_fields: list[str],
highlight_fields: list[str],
condition: dict,
matchExprs: list[MatchExpr],
orderBy: OrderByExpr,
match_expressions: list[MatchExpr],
order_by: OrderByExpr,
offset: int,
limit: int,
indexNames: str | list[str],
knowledgebaseIds: list[str],
aggFields: list[str] = [],
index_names: str | list[str],
knowledgebase_ids: list[str],
agg_fields: list[str] | None = None,
rank_feature: dict | None = None,
) -> tuple[pd.DataFrame, int]:
"""
BUG: Infinity returns empty for a highlight field if the query string doesn't use that field.
"""
if isinstance(indexNames, str):
indexNames = indexNames.split(",")
assert isinstance(indexNames, list) and len(indexNames) > 0
if isinstance(index_names, str):
index_names = index_names.split(",")
assert isinstance(index_names, list) and len(index_names) > 0
inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName)
df_list = list()
table_list = list()
output = selectFields.copy()
output = convert_select_fields(output)
for essential_field in ["id"] + aggFields:
output = select_fields.copy()
output = self.convert_select_fields(output)
if agg_fields is None:
agg_fields = []
for essential_field in ["id"] + agg_fields:
if essential_field not in output:
output.append(essential_field)
score_func = ""
score_column = ""
for matchExpr in matchExprs:
for matchExpr in match_expressions:
if isinstance(matchExpr, MatchTextExpr):
score_func = "score()"
score_column = "SCORE"
break
if not score_func:
for matchExpr in matchExprs:
for matchExpr in match_expressions:
if isinstance(matchExpr, MatchDenseExpr):
score_func = "similarity()"
score_column = "SIMILARITY"
break
if matchExprs:
if match_expressions:
if score_func not in output:
output.append(score_func)
if PAGERANK_FLD not in output:
@@ -387,11 +147,11 @@ class InfinityConnection(DocStoreConnection):
filter_fulltext = ""
if condition:
table_found = False
for indexName in indexNames:
for kb_id in knowledgebaseIds:
for indexName in index_names:
for kb_id in knowledgebase_ids:
table_name = f"{indexName}_{kb_id}"
try:
filter_cond = equivalent_condition_to_str(condition, db_instance.get_table(table_name))
filter_cond = self.equivalent_condition_to_str(condition, db_instance.get_table(table_name))
table_found = True
break
except Exception:
@@ -399,14 +159,14 @@ class InfinityConnection(DocStoreConnection):
if table_found:
break
if not table_found:
logger.error(f"No valid tables found for indexNames {indexNames} and knowledgebaseIds {knowledgebaseIds}")
self.logger.error(f"No valid tables found for indexNames {index_names} and knowledgebaseIds {knowledgebase_ids}")
return pd.DataFrame(), 0
for matchExpr in matchExprs:
for matchExpr in match_expressions:
if isinstance(matchExpr, MatchTextExpr):
if filter_cond and "filter" not in matchExpr.extra_options:
matchExpr.extra_options.update({"filter": filter_cond})
matchExpr.fields = [convert_matching_field(field) for field in matchExpr.fields]
matchExpr.fields = [self.convert_matching_field(field) for field in matchExpr.fields]
fields = ",".join(matchExpr.fields)
filter_fulltext = f"filter_fulltext('{fields}', '{matchExpr.matching_text}')"
if filter_cond:
@@ -430,7 +190,7 @@ class InfinityConnection(DocStoreConnection):
for k, v in matchExpr.extra_options.items():
if not isinstance(v, str):
matchExpr.extra_options[k] = str(v)
logger.debug(f"INFINITY search MatchTextExpr: {json.dumps(matchExpr.__dict__)}")
self.logger.debug(f"INFINITY search MatchTextExpr: {json.dumps(matchExpr.__dict__)}")
elif isinstance(matchExpr, MatchDenseExpr):
if filter_fulltext and "filter" not in matchExpr.extra_options:
matchExpr.extra_options.update({"filter": filter_fulltext})
@@ -441,16 +201,16 @@ class InfinityConnection(DocStoreConnection):
if similarity:
matchExpr.extra_options["threshold"] = similarity
del matchExpr.extra_options["similarity"]
logger.debug(f"INFINITY search MatchDenseExpr: {json.dumps(matchExpr.__dict__)}")
self.logger.debug(f"INFINITY search MatchDenseExpr: {json.dumps(matchExpr.__dict__)}")
elif isinstance(matchExpr, FusionExpr):
if matchExpr.method == "weighted_sum":
# The default is "minmax" which gives a zero score for the last doc.
matchExpr.fusion_params["normalize"] = "atan"
logger.debug(f"INFINITY search FusionExpr: {json.dumps(matchExpr.__dict__)}")
self.logger.debug(f"INFINITY search FusionExpr: {json.dumps(matchExpr.__dict__)}")
order_by_expr_list = list()
if orderBy.fields:
for order_field in orderBy.fields:
if order_by.fields:
for order_field in order_by.fields:
if order_field[1] == 0:
order_by_expr_list.append((order_field[0], SortType.Asc))
else:
@@ -458,8 +218,8 @@ class InfinityConnection(DocStoreConnection):
total_hits_count = 0
# Scatter search tables and gather the results
for indexName in indexNames:
for knowledgebaseId in knowledgebaseIds:
for indexName in index_names:
for knowledgebaseId in knowledgebase_ids:
table_name = f"{indexName}_{knowledgebaseId}"
try:
table_instance = db_instance.get_table(table_name)
@@ -467,8 +227,8 @@ class InfinityConnection(DocStoreConnection):
continue
table_list.append(table_name)
builder = table_instance.output(output)
if len(matchExprs) > 0:
for matchExpr in matchExprs:
if len(match_expressions) > 0:
for matchExpr in match_expressions:
if isinstance(matchExpr, MatchTextExpr):
fields = ",".join(matchExpr.fields)
builder = builder.match_text(
@@ -491,53 +251,52 @@ class InfinityConnection(DocStoreConnection):
else:
if filter_cond and len(filter_cond) > 0:
builder.filter(filter_cond)
if orderBy.fields:
if order_by.fields:
builder.sort(order_by_expr_list)
builder.offset(offset).limit(limit)
kb_res, extra_result = builder.option({"total_hits_count": True}).to_df()
if extra_result:
total_hits_count += int(extra_result["total_hits_count"])
logger.debug(f"INFINITY search table: {str(table_name)}, result: {str(kb_res)}")
self.logger.debug(f"INFINITY search table: {str(table_name)}, result: {str(kb_res)}")
df_list.append(kb_res)
self.connPool.release_conn(inf_conn)
res = concat_dataframes(df_list, output)
if matchExprs:
res = self.concat_dataframes(df_list, output)
if match_expressions:
res["_score"] = res[score_column] + res[PAGERANK_FLD]
res = res.sort_values(by="_score", ascending=False).reset_index(drop=True)
res = res.head(limit)
logger.debug(f"INFINITY search final result: {str(res)}")
self.logger.debug(f"INFINITY search final result: {str(res)}")
return res, total_hits_count
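A sketch of driving this scatter-gather search from a caller; index and kb names are hypothetical:

```python
# Hypothetical caller of InfinityConnection.search.
conn = InfinityConnection()
df, total = conn.search(
    select_fields=["id", "content_with_weight"],
    highlight_fields=[],
    condition={"doc_id": "doc_42"},
    match_expressions=[],            # pure filter scan, no ranking
    order_by=OrderByExpr(),
    offset=0,
    limit=10,
    index_names="ragflow_idx",
    knowledgebase_ids=["kb_1"],
)
print(conn.get_total((df, total)))   # total hits across all scanned tables
```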
def get(self, chunkId: str, indexName: str, knowledgebaseIds: list[str]) -> dict | None:
def get(self, chunk_id: str, index_name: str, knowledgebase_ids: list[str]) -> dict | None:
inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName)
df_list = list()
assert isinstance(knowledgebaseIds, list)
assert isinstance(knowledgebase_ids, list)
table_list = list()
for knowledgebaseId in knowledgebaseIds:
table_name = f"{indexName}_{knowledgebaseId}"
for knowledgebaseId in knowledgebase_ids:
table_name = f"{index_name}_{knowledgebaseId}"
table_list.append(table_name)
table_instance = None
try:
table_instance = db_instance.get_table(table_name)
except Exception:
logger.warning(f"Table not found: {table_name}, this dataset isn't created in Infinity. Maybe it is created in other document engine.")
self.logger.warning(f"Table not found: {table_name}, this dataset isn't created in Infinity. Maybe it is created in other document engine.")
continue
kb_res, _ = table_instance.output(["*"]).filter(f"id = '{chunkId}'").to_df()
logger.debug(f"INFINITY get table: {str(table_list)}, result: {str(kb_res)}")
kb_res, _ = table_instance.output(["*"]).filter(f"id = '{chunk_id}'").to_df()
self.logger.debug(f"INFINITY get table: {str(table_list)}, result: {str(kb_res)}")
df_list.append(kb_res)
self.connPool.release_conn(inf_conn)
res = concat_dataframes(df_list, ["id"])
res = self.concat_dataframes(df_list, ["id"])
fields = set(res.columns.tolist())
for field in ["docnm_kwd", "title_tks", "title_sm_tks", "important_kwd", "important_tks", "question_kwd", "question_tks","content_with_weight", "content_ltks", "content_sm_ltks", "authors_tks", "authors_sm_tks"]:
fields.add(field)
res_fields = self.get_fields(res, list(fields))
return res_fields.get(chunkId, None)
return res_fields.get(chunk_id, None)
def insert(self, documents: list[dict], indexName: str, knowledgebaseId: str = None) -> list[str]:
def insert(self, documents: list[dict], index_name: str, knowledgebase_id: str = None) -> list[str]:
inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName)
table_name = f"{indexName}_{knowledgebaseId}"
table_name = f"{index_name}_{knowledgebase_id}"
try:
table_instance = db_instance.get_table(table_name)
except InfinityException as e:
@@ -553,7 +312,7 @@ class InfinityConnection(DocStoreConnection):
break
if vector_size == 0:
raise ValueError("Cannot infer vector size from documents")
self.createIdx(indexName, knowledgebaseId, vector_size)
self.create_idx(index_name, knowledgebase_id, vector_size)
table_instance = db_instance.get_table(table_name)
# embedding fields can't have a default value....
@@ -574,12 +333,12 @@ class InfinityConnection(DocStoreConnection):
d["docnm"] = v
elif k == "title_kwd":
if not d.get("docnm_kwd"):
d["docnm"] = list2str(v)
d["docnm"] = self.list2str(v)
elif k == "title_sm_tks":
if not d.get("docnm_kwd"):
d["docnm"] = list2str(v)
d["docnm"] = self.list2str(v)
elif k == "important_kwd":
d["important_keywords"] = list2str(v)
d["important_keywords"] = self.list2str(v)
elif k == "important_tks":
if not d.get("important_kwd"):
d["important_keywords"] = v
@@ -597,11 +356,11 @@ class InfinityConnection(DocStoreConnection):
if not d.get("authors_tks"):
d["authors"] = v
elif k == "question_kwd":
d["questions"] = list2str(v, "\n")
d["questions"] = self.list2str(v, "\n")
elif k == "question_tks":
if not d.get("question_kwd"):
d["questions"] = list2str(v)
elif field_keyword(k):
d["questions"] = self.list2str(v)
elif self.field_keyword(k):
if isinstance(v, list):
d[k] = "###".join(v)
else:
@@ -637,15 +396,15 @@ class InfinityConnection(DocStoreConnection):
# logger.info(f"InfinityConnection.insert {json.dumps(documents)}")
table_instance.insert(docs)
self.connPool.release_conn(inf_conn)
logger.debug(f"INFINITY inserted into {table_name} {str_ids}.")
self.logger.debug(f"INFINITY inserted into {table_name} {str_ids}.")
return []
def update(self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str) -> bool:
def update(self, condition: dict, new_value: dict, index_name: str, knowledgebase_id: str) -> bool:
# if 'position_int' in newValue:
# logger.info(f"update position_int: {newValue['position_int']}")
inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName)
table_name = f"{indexName}_{knowledgebaseId}"
table_name = f"{index_name}_{knowledgebase_id}"
table_instance = db_instance.get_table(table_name)
# if "exists" in condition:
# del condition["exists"]
@@ -654,57 +413,57 @@ class InfinityConnection(DocStoreConnection):
if table_instance:
for n, ty, de, _ in table_instance.show_columns().rows():
clmns[n] = (ty, de)
filter = equivalent_condition_to_str(condition, table_instance)
filter = self.equivalent_condition_to_str(condition, table_instance)
removeValue = {}
for k, v in list(newValue.items()):
for k, v in list(new_value.items()):
if k == "docnm_kwd":
newValue["docnm"] = list2str(v)
new_value["docnm"] = self.list2str(v)
elif k == "title_kwd":
if not newValue.get("docnm_kwd"):
newValue["docnm"] = list2str(v)
if not new_value.get("docnm_kwd"):
new_value["docnm"] = self.list2str(v)
elif k == "title_sm_tks":
if not newValue.get("docnm_kwd"):
newValue["docnm"] = v
if not new_value.get("docnm_kwd"):
new_value["docnm"] = v
elif k == "important_kwd":
newValue["important_keywords"] = list2str(v)
new_value["important_keywords"] = self.list2str(v)
elif k == "important_tks":
if not newValue.get("important_kwd"):
newValue["important_keywords"] = v
if not new_value.get("important_kwd"):
new_value["important_keywords"] = v
elif k == "content_with_weight":
newValue["content"] = v
new_value["content"] = v
elif k == "content_ltks":
if not newValue.get("content_with_weight"):
newValue["content"] = v
if not new_value.get("content_with_weight"):
new_value["content"] = v
elif k == "content_sm_ltks":
if not newValue.get("content_with_weight"):
newValue["content"] = v
if not new_value.get("content_with_weight"):
new_value["content"] = v
elif k == "authors_tks":
newValue["authors"] = v
new_value["authors"] = v
elif k == "authors_sm_tks":
if not newValue.get("authors_tks"):
newValue["authors"] = v
if not new_value.get("authors_tks"):
new_value["authors"] = v
elif k == "question_kwd":
newValue["questions"] = "\n".join(v)
new_value["questions"] = "\n".join(v)
elif k == "question_tks":
if not newValue.get("question_kwd"):
newValue["questions"] = list2str(v)
elif field_keyword(k):
if not new_value.get("question_kwd"):
new_value["questions"] = self.list2str(v)
elif self.field_keyword(k):
if isinstance(v, list):
newValue[k] = "###".join(v)
new_value[k] = "###".join(v)
else:
newValue[k] = v
new_value[k] = v
elif re.search(r"_feas$", k):
newValue[k] = json.dumps(v)
new_value[k] = json.dumps(v)
elif k == "kb_id":
if isinstance(newValue[k], list):
newValue[k] = newValue[k][0] # the value is a list, but kb_id must be a str
if isinstance(new_value[k], list):
new_value[k] = new_value[k][0] # the value is a list, but kb_id must be a str
elif k == "position_int":
assert isinstance(v, list)
arr = [num for row in v for num in row]
newValue[k] = "_".join(f"{num:08x}" for num in arr)
new_value[k] = "_".join(f"{num:08x}" for num in arr)
elif k in ["page_num_int", "top_int"]:
assert isinstance(v, list)
newValue[k] = "_".join(f"{num:08x}" for num in v)
new_value[k] = "_".join(f"{num:08x}" for num in v)
elif k == "remove":
if isinstance(v, str):
assert v in clmns, f"'{v}' should be in '{clmns}'."
@@ -712,22 +471,22 @@ class InfinityConnection(DocStoreConnection):
if ty.lower().find("cha"):
if not de:
de = ""
newValue[v] = de
new_value[v] = de
else:
for kk, vv in v.items():
removeValue[kk] = vv
del newValue[k]
del new_value[k]
else:
newValue[k] = v
new_value[k] = v
for k in ["docnm_kwd", "title_tks", "title_sm_tks", "important_kwd", "important_tks", "content_with_weight", "content_ltks", "content_sm_ltks", "authors_tks", "authors_sm_tks", "question_kwd", "question_tks"]:
if k in newValue:
del newValue[k]
if k in new_value:
del new_value[k]
remove_opt = {} # "[k,new_value]": [id_to_update, ...]
if removeValue:
col_to_remove = list(removeValue.keys())
row_to_opt = table_instance.output(col_to_remove + ["id"]).filter(filter).to_df()
logger.debug(f"INFINITY search table {str(table_name)}, filter {filter}, result: {str(row_to_opt[0])}")
self.logger.debug(f"INFINITY search table {str(table_name)}, filter {filter}, result: {str(row_to_opt[0])}")
row_to_opt = self.get_fields(row_to_opt, col_to_remove)
for id, old_v in row_to_opt.items():
for k, remove_v in removeValue.items():
@@ -740,78 +499,53 @@ class InfinityConnection(DocStoreConnection):
else:
remove_opt[kv_key].append(id)
logger.debug(f"INFINITY update table {table_name}, filter {filter}, newValue {newValue}.")
self.logger.debug(f"INFINITY update table {table_name}, filter {filter}, newValue {new_value}.")
for update_kv, ids in remove_opt.items():
k, v = json.loads(update_kv)
table_instance.update(filter + " AND id in ({0})".format(",".join([f"'{id}'" for id in ids])), {k: "###".join(v)})
table_instance.update(filter, newValue)
table_instance.update(filter, new_value)
self.connPool.release_conn(inf_conn)
return True
def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int:
inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName)
table_name = f"{indexName}_{knowledgebaseId}"
try:
table_instance = db_instance.get_table(table_name)
except Exception:
logger.warning(f"Skipped deleting from table {table_name} since the table doesn't exist.")
return 0
filter = equivalent_condition_to_str(condition, table_instance)
logger.debug(f"INFINITY delete table {table_name}, filter {filter}.")
res = table_instance.delete(filter)
self.connPool.release_conn(inf_conn)
return res.deleted_rows
"""
Helper functions for search result
"""
def get_total(self, res: tuple[pd.DataFrame, int] | pd.DataFrame) -> int:
if isinstance(res, tuple):
return res[1]
return len(res)
def get_chunk_ids(self, res: tuple[pd.DataFrame, int] | pd.DataFrame) -> list[str]:
if isinstance(res, tuple):
res = res[0]
return list(res["id"])
def get_fields(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, fields: list[str]) -> dict[str, dict]:
if isinstance(res, tuple):
res = res[0]
if not fields:
return {}
fieldsAll = fields.copy()
fieldsAll.append("id")
fieldsAll = set(fieldsAll)
fields_all = fields.copy()
fields_all.append("id")
fields_all = set(fields_all)
if "docnm" in res.columns:
for field in ["docnm_kwd", "title_tks", "title_sm_tks"]:
if field in fieldsAll:
if field in fields_all:
res[field] = res["docnm"]
if "important_keywords" in res.columns:
if "important_kwd" in fieldsAll:
if "important_kwd" in fields_all:
res["important_kwd"] = res["important_keywords"].apply(lambda v: v.split())
if "important_tks" in fieldsAll:
if "important_tks" in fields_all:
res["important_tks"] = res["important_keywords"]
if "questions" in res.columns:
if "question_kwd" in fieldsAll:
if "question_kwd" in fields_all:
res["question_kwd"] = res["questions"].apply(lambda v: v.splitlines())
if "question_tks" in fieldsAll:
if "question_tks" in fields_all:
res["question_tks"] = res["questions"]
if "content" in res.columns:
for field in ["content_with_weight", "content_ltks", "content_sm_ltks"]:
if field in fieldsAll:
if field in fields_all:
res[field] = res["content"]
if "authors" in res.columns:
for field in ["authors_tks", "authors_sm_tks"]:
-if field in fieldsAll:
+if field in fields_all:
res[field] = res["authors"]
column_map = {col.lower(): col for col in res.columns}
-matched_columns = {column_map[col.lower()]: col for col in fieldsAll if col.lower() in column_map}
-none_columns = [col for col in fieldsAll if col.lower() not in column_map]
+matched_columns = {column_map[col.lower()]: col for col in fields_all if col.lower() in column_map}
+none_columns = [col for col in fields_all if col.lower() not in column_map]
res2 = res[matched_columns.keys()]
res2 = res2.rename(columns=matched_columns)
@@ -819,7 +553,7 @@ class InfinityConnection(DocStoreConnection):
for column in list(res2.columns):
k = column.lower()
-if field_keyword(k):
+if self.field_keyword(k):
res2[column] = res2[column].apply(lambda v: [kwd for kwd in v.split("###") if kwd])
elif re.search(r"_feas$", k):
res2[column] = res2[column].apply(lambda v: json.loads(v) if v else {})
@@ -844,95 +578,3 @@ class InfinityConnection(DocStoreConnection):
res2[column] = None
return res2.set_index("id").to_dict(orient="index")
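A toy walk-through of the aliasing above, assuming only the column names visible in this hunk: the physical "docnm" column backs several logical fields, and requested fields absent from the result come back as None.

    import pandas as pd

    res = pd.DataFrame({
        "id": ["c1"],
        "docnm": ["paper.pdf"],
        "content": ["hello world"],
    })
    # get_fields(res, ["docnm_kwd", "content_with_weight"]) would yield roughly:
    # {"c1": {"docnm_kwd": "paper.pdf", "content_with_weight": "hello world"}}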
def get_highlight(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, keywords: list[str], fieldnm: str):
if isinstance(res, tuple):
res = res[0]
ans = {}
num_rows = len(res)
column_id = res["id"]
if fieldnm not in res:
return {}
for i in range(num_rows):
id = column_id[i]
txt = res[fieldnm][i]
if re.search(r"<em>[^<>]+</em>", txt, flags=re.IGNORECASE | re.MULTILINE):
ans[id] = txt
continue
txt = re.sub(r"[\r\n]", " ", txt, flags=re.IGNORECASE | re.MULTILINE)
txts = []
for t in re.split(r"[.?!;\n]", txt):
if is_english([t]):
for w in keywords:
t = re.sub(
r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])" % re.escape(w),
r"\1<em>\2</em>\3",
t,
flags=re.IGNORECASE | re.MULTILINE,
)
else:
for w in sorted(keywords, key=len, reverse=True):
t = re.sub(
re.escape(w),
f"<em>{w}</em>",
t,
flags=re.IGNORECASE | re.MULTILINE,
)
if not re.search(r"<em>[^<>]+</em>", t, flags=re.IGNORECASE | re.MULTILINE):
continue
txts.append(t)
if txts:
ans[id] = "...".join(txts)
else:
ans[id] = txt
return ans
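The English branch only wraps a keyword when it is delimited by spaces or punctuation, so substrings of longer words stay untouched. The regex in isolation:

    import re

    t = "RAG pipelines are popular. "
    w = "RAG"
    out = re.sub(
        r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])" % re.escape(w),
        r"\1<em>\2</em>\3",
        t,
        flags=re.IGNORECASE | re.MULTILINE,
    )
    # out == "<em>RAG</em> pipelines are popular. "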
def get_aggregation(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, fieldnm: str):
"""
Manual aggregation for tag fields since Infinity doesn't provide native aggregation
"""
from collections import Counter
# Extract DataFrame from result
if isinstance(res, tuple):
df, _ = res
else:
df = res
if df.empty or fieldnm not in df.columns:
return []
# Aggregate tag counts
tag_counter = Counter()
for value in df[fieldnm]:
if pd.isna(value) or not value:
continue
# Handle different tag formats
if isinstance(value, str):
# Split by ### for tag_kwd field or comma for other formats
if fieldnm == "tag_kwd" and "###" in value:
tags = [tag.strip() for tag in value.split("###") if tag.strip()]
else:
# Try comma separation as fallback
tags = [tag.strip() for tag in value.split(",") if tag.strip()]
for tag in tags:
if tag: # Only count non-empty tags
tag_counter[tag] += 1
elif isinstance(value, list):
# Handle list format
for tag in value:
if tag and isinstance(tag, str):
tag_counter[tag.strip()] += 1
# Return as list of [tag, count] pairs, sorted by count descending
return [[tag, count] for tag, count in tag_counter.most_common()]
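Usage on a toy frame, following the branches above ("###" split for tag_kwd, comma fallback for other strings, NaN skipped):

    import pandas as pd

    df = pd.DataFrame({"tag_kwd": ["ai###rag", "rag", None]})
    # get_aggregation((df, 3), "tag_kwd") -> [["rag", 2], ["ai", 1]]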
"""
SQL
"""
def sql(self, sql: str, fetch_size: int, format: str):
raise NotImplementedError("Not implemented")


@@ -37,9 +37,8 @@ from common import settings
from common.constants import PAGERANK_FLD, TAG_FLD
from common.decorator import singleton
from common.float_utils import get_float
+from common.doc_store.doc_store_base import DocStoreConnection, MatchExpr, OrderByExpr, FusionExpr, MatchTextExpr, MatchDenseExpr
from rag.nlp import rag_tokenizer
-from rag.utils.doc_store_conn import DocStoreConnection, MatchExpr, OrderByExpr, FusionExpr, MatchTextExpr, \
-MatchDenseExpr
ATTEMPT_TIME = 2
OB_QUERY_TIMEOUT = int(os.environ.get("OB_QUERY_TIMEOUT", "100_000_000"))
@@ -497,7 +496,7 @@ class OBConnection(DocStoreConnection):
Database operations
"""
-def dbType(self) -> str:
+def db_type(self) -> str:
return "oceanbase"
def health(self) -> dict:
@@ -553,7 +552,7 @@ class OBConnection(DocStoreConnection):
Table operations
"""
-def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
+def create_idx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
vector_field_name = f"q_{vectorSize}_vec"
vector_index_name = f"{vector_field_name}_idx"
@@ -604,7 +603,7 @@ class OBConnection(DocStoreConnection):
# always refresh metadata to make sure it contains the latest table structure
self.client.refresh_metadata([indexName])
-def deleteIdx(self, indexName: str, knowledgebaseId: str):
+def delete_idx(self, indexName: str, knowledgebaseId: str):
if len(knowledgebaseId) > 0:
# The index needs to stay alive after any kb deletion, since all kbs under this tenant share one index.
return
@@ -615,7 +614,7 @@ class OBConnection(DocStoreConnection):
except Exception as e:
raise Exception(f"OBConnection.deleteIndex error: {str(e)}")
-def indexExist(self, indexName: str, knowledgebaseId: str = None) -> bool:
+def index_exist(self, indexName: str, knowledgebaseId: str = None) -> bool:
return self._check_table_exists_cached(indexName)
def _get_count(self, table_name: str, filter_list: list[str] = None) -> int:
@@ -1500,7 +1499,7 @@ class OBConnection(DocStoreConnection):
def get_total(self, res) -> int:
return res.total
-def get_chunk_ids(self, res) -> list[str]:
+def get_doc_ids(self, res) -> list[str]:
return [row["id"] for row in res.chunks]
def get_fields(self, res, fields: list[str]) -> dict[str, dict]:


@@ -26,8 +26,7 @@ from opensearchpy import UpdateByQuery, Q, Search, Index
from opensearchpy import ConnectionTimeout
from common.decorator import singleton
from common.file_utils import get_project_base_directory
-from rag.utils.doc_store_conn import DocStoreConnection, MatchExpr, OrderByExpr, MatchTextExpr, MatchDenseExpr, \
-FusionExpr
+from common.doc_store.doc_store_base import DocStoreConnection, MatchExpr, OrderByExpr, MatchTextExpr, MatchDenseExpr, FusionExpr
from rag.nlp import is_english, rag_tokenizer
from common.constants import PAGERANK_FLD, TAG_FLD
from common import settings
@@ -79,7 +78,7 @@ class OSConnection(DocStoreConnection):
Database operations
"""
-def dbType(self) -> str:
+def db_type(self) -> str:
return "opensearch"
def health(self) -> dict:
@@ -91,8 +90,8 @@ class OSConnection(DocStoreConnection):
Table operations
"""
-def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
-if self.indexExist(indexName, knowledgebaseId):
+def create_idx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
+if self.index_exist(indexName, knowledgebaseId):
return True
try:
from opensearchpy.client import IndicesClient
@@ -101,7 +100,7 @@ class OSConnection(DocStoreConnection):
except Exception:
logger.exception("OSConnection.createIndex error %s" % (indexName))
-def deleteIdx(self, indexName: str, knowledgebaseId: str):
+def delete_idx(self, indexName: str, knowledgebaseId: str):
if len(knowledgebaseId) > 0:
# The index needs to stay alive after any kb deletion, since all kbs under this tenant share one index.
return
@@ -112,7 +111,7 @@ class OSConnection(DocStoreConnection):
except Exception:
logger.exception("OSConnection.deleteIdx error %s" % (indexName))
-def indexExist(self, indexName: str, knowledgebaseId: str = None) -> bool:
+def index_exist(self, indexName: str, knowledgebaseId: str = None) -> bool:
s = Index(indexName, self.os)
for i in range(ATTEMPT_TIME):
try:
@@ -460,7 +459,7 @@ class OSConnection(DocStoreConnection):
return res["hits"]["total"]["value"]
return res["hits"]["total"]
-def get_chunk_ids(self, res):
+def get_doc_ids(self, res):
return [d["_id"] for d in res["hits"]["hits"]]
def __getSource(self, res):


@@ -272,6 +272,49 @@ class RedisDB:
self.__open__()
return None
def generate_auto_increment_id(self, key_prefix: str = "id_generator", namespace: str = "default", increment: int = 1, ensure_minimum: int | None = None) -> int:
redis_key = f"{key_prefix}:{namespace}"
try:
# Use pipeline for atomicity
pipe = self.REDIS.pipeline()
# Check if key exists
pipe.exists(redis_key)
# Get/Increment
if ensure_minimum is not None:
# Ensure minimum value
pipe.get(redis_key)
results = pipe.execute()
if results[0] == 0: # Key doesn't exist
start_id = max(1, ensure_minimum)
pipe.set(redis_key, start_id)
pipe.execute()
return start_id
else:
current = int(results[1])
if current < ensure_minimum:
pipe.set(redis_key, ensure_minimum)
pipe.execute()
return ensure_minimum
# Increment operation
next_id = self.REDIS.incrby(redis_key, increment)
# The key was just created by INCRBY (the counter started from 0); rebase so the first issued id is 1 + increment
if next_id == increment:
self.REDIS.set(redis_key, 1 + increment)
return 1 + increment
return next_id
except Exception as e:
logging.warning("RedisDB.generate_auto_increment_id got exception: " + str(e))
self.__open__()
return -1
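Note that the ensure_minimum branch above takes two round trips (EXISTS/GET, then SET), so two concurrent callers can both observe a stale value. If strict atomicity mattered, one option, sketched here as an assumption rather than as part of this PR, is to push the check into a Lua script that Redis evaluates atomically:

    # Hedged alternative using redis-py's register_script.
    LUA = """
    local cur = redis.call('GET', KEYS[1])
    if (not cur) or (tonumber(cur) < tonumber(ARGV[1])) then
        redis.call('SET', KEYS[1], ARGV[1])
        return tonumber(ARGV[1])
    end
    return redis.call('INCRBY', KEYS[1], ARGV[2])
    """

    def atomic_next_id(redis_client, key: str, minimum: int = 1, increment: int = 1) -> int:
        # register_script returns a callable that runs the script server-side.
        script = redis_client.register_script(LUA)
        return int(script(keys=[key], args=[minimum, increment]))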
def transaction(self, key, value, exp=3600):
try:
pipeline = self.REDIS.pipeline(transaction=True)