Feat: message manage (#12196)

### What problem does this PR solve?

Manage message and use in agent.

Issue #4213 

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
Lynn
2025-12-25 21:18:13 +08:00
committed by GitHub
parent fd53b83190
commit 6e9691a419
54 changed files with 3715 additions and 1039 deletions

View File

View File

@ -0,0 +1,270 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from abc import ABC, abstractmethod
from dataclasses import dataclass
import numpy as np
DEFAULT_MATCH_VECTOR_TOPN = 10
DEFAULT_MATCH_SPARSE_TOPN = 10
VEC = list | np.ndarray
@dataclass
class SparseVector:
indices: list[int]
values: list[float] | list[int] | None = None
def __post_init__(self):
assert (self.values is None) or (len(self.indices) == len(self.values))
def to_dict_old(self):
d = {"indices": self.indices}
if self.values is not None:
d["values"] = self.values
return d
def to_dict(self):
if self.values is None:
raise ValueError("SparseVector.values is None")
result = {}
for i, v in zip(self.indices, self.values):
result[str(i)] = v
return result
@staticmethod
def from_dict(d):
return SparseVector(d["indices"], d.get("values"))
def __str__(self):
return f"SparseVector(indices={self.indices}{'' if self.values is None else f', values={self.values}'})"
def __repr__(self):
return str(self)
class MatchTextExpr:
def __init__(
self,
fields: list[str],
matching_text: str,
topn: int,
extra_options: dict | None = None,
):
self.fields = fields
self.matching_text = matching_text
self.topn = topn
self.extra_options = extra_options
class MatchDenseExpr:
def __init__(
self,
vector_column_name: str,
embedding_data: VEC,
embedding_data_type: str,
distance_type: str,
topn: int = DEFAULT_MATCH_VECTOR_TOPN,
extra_options: dict | None = None,
):
self.vector_column_name = vector_column_name
self.embedding_data = embedding_data
self.embedding_data_type = embedding_data_type
self.distance_type = distance_type
self.topn = topn
self.extra_options = extra_options
class MatchSparseExpr:
def __init__(
self,
vector_column_name: str,
sparse_data: SparseVector | dict,
distance_type: str,
topn: int,
opt_params: dict | None = None,
):
self.vector_column_name = vector_column_name
self.sparse_data = sparse_data
self.distance_type = distance_type
self.topn = topn
self.opt_params = opt_params
class MatchTensorExpr:
def __init__(
self,
column_name: str,
query_data: VEC,
query_data_type: str,
topn: int,
extra_option: dict | None = None,
):
self.column_name = column_name
self.query_data = query_data
self.query_data_type = query_data_type
self.topn = topn
self.extra_option = extra_option
class FusionExpr:
def __init__(self, method: str, topn: int, fusion_params: dict | None = None):
self.method = method
self.topn = topn
self.fusion_params = fusion_params
MatchExpr = MatchTextExpr | MatchDenseExpr | MatchSparseExpr | MatchTensorExpr | FusionExpr
class OrderByExpr:
def __init__(self):
self.fields = list()
def asc(self, field: str):
self.fields.append((field, 0))
return self
def desc(self, field: str):
self.fields.append((field, 1))
return self
def fields(self):
return self.fields
class DocStoreConnection(ABC):
"""
Database operations
"""
@abstractmethod
def db_type(self) -> str:
"""
Return the type of the database.
"""
raise NotImplementedError("Not implemented")
@abstractmethod
def health(self) -> dict:
"""
Return the health status of the database.
"""
raise NotImplementedError("Not implemented")
"""
Table operations
"""
@abstractmethod
def create_idx(self, index_name: str, dataset_id: str, vector_size: int):
"""
Create an index with given name
"""
raise NotImplementedError("Not implemented")
@abstractmethod
def delete_idx(self, index_name: str, dataset_id: str):
"""
Delete an index with given name
"""
raise NotImplementedError("Not implemented")
@abstractmethod
def index_exist(self, index_name: str, dataset_id: str) -> bool:
"""
Check if an index with given name exists
"""
raise NotImplementedError("Not implemented")
"""
CRUD operations
"""
@abstractmethod
def search(
self, select_fields: list[str],
highlight_fields: list[str],
condition: dict,
match_expressions: list[MatchExpr],
order_by: OrderByExpr,
offset: int,
limit: int,
index_names: str|list[str],
dataset_ids: list[str],
agg_fields: list[str] | None = None,
rank_feature: dict | None = None
):
"""
Search with given conjunctive equivalent filtering condition and return all fields of matched documents
"""
raise NotImplementedError("Not implemented")
@abstractmethod
def get(self, data_id: str, index_name: str, dataset_ids: list[str]) -> dict | None:
"""
Get single chunk with given id
"""
raise NotImplementedError("Not implemented")
@abstractmethod
def insert(self, rows: list[dict], index_name: str, dataset_id: str = None) -> list[str]:
"""
Update or insert a bulk of rows
"""
raise NotImplementedError("Not implemented")
@abstractmethod
def update(self, condition: dict, new_value: dict, index_name: str, dataset_id: str) -> bool:
"""
Update rows with given conjunctive equivalent filtering condition
"""
raise NotImplementedError("Not implemented")
@abstractmethod
def delete(self, condition: dict, index_name: str, dataset_id: str) -> int:
"""
Delete rows with given conjunctive equivalent filtering condition
"""
raise NotImplementedError("Not implemented")
"""
Helper functions for search result
"""
@abstractmethod
def get_total(self, res):
raise NotImplementedError("Not implemented")
@abstractmethod
def get_doc_ids(self, res):
raise NotImplementedError("Not implemented")
@abstractmethod
def get_fields(self, res, fields: list[str]) -> dict[str, dict]:
raise NotImplementedError("Not implemented")
@abstractmethod
def get_highlight(self, res, keywords: list[str], field_name: str):
raise NotImplementedError("Not implemented")
@abstractmethod
def get_aggregation(self, res, field_name: str):
raise NotImplementedError("Not implemented")
"""
SQL
"""
@abstractmethod
def sql(self, sql: str, fetch_size: int, format: str):
"""
Run the sql generated by text-to-sql
"""
raise NotImplementedError("Not implemented")

View File

@ -0,0 +1,307 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import re
import json
import time
import os
from abc import abstractmethod
from elasticsearch import NotFoundError
from elasticsearch_dsl import Index
from elastic_transport import ConnectionTimeout
from common.file_utils import get_project_base_directory
from common.misc_utils import convert_bytes
from common.doc_store.doc_store_base import DocStoreConnection, OrderByExpr, MatchExpr
from rag.nlp import is_english, rag_tokenizer
from common import settings
ATTEMPT_TIME = 2
class ESConnectionBase(DocStoreConnection):
def __init__(self, mapping_file_name: str="mapping.json", logger_name: str='ragflow.es_conn'):
from common.doc_store.es_conn_pool import ES_CONN
self.logger = logging.getLogger(logger_name)
self.info = {}
self.logger.info(f"Use Elasticsearch {settings.ES['hosts']} as the doc engine.")
self.es = ES_CONN.get_conn()
fp_mapping = os.path.join(get_project_base_directory(), "conf", mapping_file_name)
if not os.path.exists(fp_mapping):
msg = f"Elasticsearch mapping file not found at {fp_mapping}"
self.logger.error(msg)
raise Exception(msg)
self.mapping = json.load(open(fp_mapping, "r"))
self.logger.info(f"Elasticsearch {settings.ES['hosts']} is healthy.")
def _connect(self):
from common.doc_store.es_conn_pool import ES_CONN
if self.es.ping():
return True
self.es = ES_CONN.refresh_conn()
return True
"""
Database operations
"""
def db_type(self) -> str:
return "elasticsearch"
def health(self) -> dict:
health_dict = dict(self.es.cluster.health())
health_dict["type"] = "elasticsearch"
return health_dict
def get_cluster_stats(self):
"""
curl -XGET "http://{es_host}/_cluster/stats" -H "kbn-xsrf: reporting" to view raw stats.
"""
raw_stats = self.es.cluster.stats()
self.logger.debug(f"ESConnection.get_cluster_stats: {raw_stats}")
try:
res = {
'cluster_name': raw_stats['cluster_name'],
'status': raw_stats['status']
}
indices_status = raw_stats['indices']
res.update({
'indices': indices_status['count'],
'indices_shards': indices_status['shards']['total']
})
doc_info = indices_status['docs']
res.update({
'docs': doc_info['count'],
'docs_deleted': doc_info['deleted']
})
store_info = indices_status['store']
res.update({
'store_size': convert_bytes(store_info['size_in_bytes']),
'total_dataset_size': convert_bytes(store_info['total_data_set_size_in_bytes'])
})
mappings_info = indices_status['mappings']
res.update({
'mappings_fields': mappings_info['total_field_count'],
'mappings_deduplicated_fields': mappings_info['total_deduplicated_field_count'],
'mappings_deduplicated_size': convert_bytes(mappings_info['total_deduplicated_mapping_size_in_bytes'])
})
node_info = raw_stats['nodes']
res.update({
'nodes': node_info['count']['total'],
'nodes_version': node_info['versions'],
'os_mem': convert_bytes(node_info['os']['mem']['total_in_bytes']),
'os_mem_used': convert_bytes(node_info['os']['mem']['used_in_bytes']),
'os_mem_used_percent': node_info['os']['mem']['used_percent'],
'jvm_versions': node_info['jvm']['versions'][0]['vm_version'],
'jvm_heap_used': convert_bytes(node_info['jvm']['mem']['heap_used_in_bytes']),
'jvm_heap_max': convert_bytes(node_info['jvm']['mem']['heap_max_in_bytes'])
})
return res
except Exception as e:
self.logger.exception(f"ESConnection.get_cluster_stats: {e}")
return None
"""
Table operations
"""
def create_idx(self, index_name: str, dataset_id: str, vector_size: int):
if self.index_exist(index_name, dataset_id):
return True
try:
from elasticsearch.client import IndicesClient
return IndicesClient(self.es).create(index=index_name,
settings=self.mapping["settings"],
mappings=self.mapping["mappings"])
except Exception:
self.logger.exception("ESConnection.createIndex error %s" % index_name)
def delete_idx(self, index_name: str, dataset_id: str):
if len(dataset_id) > 0:
# The index need to be alive after any kb deletion since all kb under this tenant are in one index.
return
try:
self.es.indices.delete(index=index_name, allow_no_indices=True)
except NotFoundError:
pass
except Exception:
self.logger.exception("ESConnection.deleteIdx error %s" % index_name)
def index_exist(self, index_name: str, dataset_id: str = None) -> bool:
s = Index(index_name, self.es)
for i in range(ATTEMPT_TIME):
try:
return s.exists()
except ConnectionTimeout:
self.logger.exception("ES request timeout")
time.sleep(3)
self._connect()
continue
except Exception as e:
self.logger.exception(e)
break
return False
"""
CRUD operations
"""
def get(self, doc_id: str, index_name: str, dataset_ids: list[str]) -> dict | None:
for i in range(ATTEMPT_TIME):
try:
res = self.es.get(index=index_name,
id=doc_id, source=True, )
if str(res.get("timed_out", "")).lower() == "true":
raise Exception("Es Timeout.")
doc = res["_source"]
doc["id"] = doc_id
return doc
except NotFoundError:
return None
except Exception as e:
self.logger.exception(f"ESConnection.get({doc_id}) got exception")
raise e
self.logger.error(f"ESConnection.get timeout for {ATTEMPT_TIME} times!")
raise Exception("ESConnection.get timeout.")
@abstractmethod
def search(
self, select_fields: list[str],
highlight_fields: list[str],
condition: dict,
match_expressions: list[MatchExpr],
order_by: OrderByExpr,
offset: int,
limit: int,
index_names: str | list[str],
dataset_ids: list[str],
agg_fields: list[str] | None = None,
rank_feature: dict | None = None
):
raise NotImplementedError("Not implemented")
@abstractmethod
def insert(self, documents: list[dict], index_name: str, dataset_id: str = None) -> list[str]:
raise NotImplementedError("Not implemented")
@abstractmethod
def update(self, condition: dict, new_value: dict, index_name: str, dataset_id: str) -> bool:
raise NotImplementedError("Not implemented")
@abstractmethod
def delete(self, condition: dict, index_name: str, dataset_id: str) -> int:
raise NotImplementedError("Not implemented")
"""
Helper functions for search result
"""
def get_total(self, res):
if isinstance(res["hits"]["total"], type({})):
return res["hits"]["total"]["value"]
return res["hits"]["total"]
def get_doc_ids(self, res):
return [d["_id"] for d in res["hits"]["hits"]]
def _get_source(self, res):
rr = []
for d in res["hits"]["hits"]:
d["_source"]["id"] = d["_id"]
d["_source"]["_score"] = d["_score"]
rr.append(d["_source"])
return rr
@abstractmethod
def get_fields(self, res, fields: list[str]) -> dict[str, dict]:
raise NotImplementedError("Not implemented")
def get_highlight(self, res, keywords: list[str], field_name: str):
ans = {}
for d in res["hits"]["hits"]:
highlights = d.get("highlight")
if not highlights:
continue
txt = "...".join([a for a in list(highlights.items())[0][1]])
if not is_english(txt.split()):
ans[d["_id"]] = txt
continue
txt = d["_source"][field_name]
txt = re.sub(r"[\r\n]", " ", txt, flags=re.IGNORECASE | re.MULTILINE)
txt_list = []
for t in re.split(r"[.?!;\n]", txt):
for w in keywords:
t = re.sub(r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])" % re.escape(w), r"\1<em>\2</em>\3", t,
flags=re.IGNORECASE | re.MULTILINE)
if not re.search(r"<em>[^<>]+</em>", t, flags=re.IGNORECASE | re.MULTILINE):
continue
txt_list.append(t)
ans[d["_id"]] = "...".join(txt_list) if txt_list else "...".join([a for a in list(highlights.items())[0][1]])
return ans
def get_aggregation(self, res, field_name: str):
agg_field = "aggs_" + field_name
if "aggregations" not in res or agg_field not in res["aggregations"]:
return list()
buckets = res["aggregations"][agg_field]["buckets"]
return [(b["key"], b["doc_count"]) for b in buckets]
"""
SQL
"""
def sql(self, sql: str, fetch_size: int, format: str):
self.logger.debug(f"ESConnection.sql get sql: {sql}")
sql = re.sub(r"[ `]+", " ", sql)
sql = sql.replace("%", "")
replaces = []
for r in re.finditer(r" ([a-z_]+_l?tks)( like | ?= ?)'([^']+)'", sql):
fld, v = r.group(1), r.group(3)
match = " MATCH({}, '{}', 'operator=OR;minimum_should_match=30%') ".format(
fld, rag_tokenizer.fine_grained_tokenize(rag_tokenizer.tokenize(v)))
replaces.append(
("{}{}'{}'".format(
r.group(1),
r.group(2),
r.group(3)),
match))
for p, r in replaces:
sql = sql.replace(p, r, 1)
self.logger.debug(f"ESConnection.sql to es: {sql}")
for i in range(ATTEMPT_TIME):
try:
res = self.es.sql.query(body={"query": sql, "fetch_size": fetch_size}, format=format,
request_timeout="2s")
return res
except ConnectionTimeout:
self.logger.exception("ES request timeout")
time.sleep(3)
self._connect()
continue
except Exception as e:
self.logger.exception(f"ESConnection.sql got exception. SQL:\n{sql}")
raise Exception(f"SQL error: {e}\n\nSQL: {sql}")
self.logger.error(f"ESConnection.sql timeout for {ATTEMPT_TIME} times!")
return None

View File

@ -0,0 +1,84 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import time
from elasticsearch import Elasticsearch
from common import settings
from common.decorator import singleton
ATTEMPT_TIME = 2
@singleton
class ElasticSearchConnectionPool:
def __init__(self):
if hasattr(settings, "ES"):
self.ES_CONFIG = settings.ES
else:
self.ES_CONFIG = settings.get_base_config("es", {})
for _ in range(ATTEMPT_TIME):
try:
if self._connect():
break
except Exception as e:
logging.warning(f"{str(e)}. Waiting Elasticsearch {self.ES_CONFIG['hosts']} to be healthy.")
time.sleep(5)
if not hasattr(self, "es_conn") or not self.es_conn or not self.es_conn.ping():
msg = f"Elasticsearch {self.ES_CONFIG['hosts']} is unhealthy in 10s."
logging.error(msg)
raise Exception(msg)
v = self.info.get("version", {"number": "8.11.3"})
v = v["number"].split(".")[0]
if int(v) < 8:
msg = f"Elasticsearch version must be greater than or equal to 8, current version: {v}"
logging.error(msg)
raise Exception(msg)
def _connect(self):
self.es_conn = Elasticsearch(
self.ES_CONFIG["hosts"].split(","),
basic_auth=(self.ES_CONFIG["username"], self.ES_CONFIG[
"password"]) if "username" in self.ES_CONFIG and "password" in self.ES_CONFIG else None,
verify_certs= self.ES_CONFIG.get("verify_certs", False),
timeout=600 )
if self.es_conn:
self.info = self.es_conn.info()
return True
return False
def get_conn(self):
return self.es_conn
def refresh_conn(self):
if self.es_conn.ping():
return self.es_conn
else:
# close current if exist
if self.es_conn:
self.es_conn.close()
self._connect()
return self.es_conn
def __del__(self):
if hasattr(self, "es_conn") and self.es_conn:
self.es_conn.close()
ES_CONN = ElasticSearchConnectionPool()

View File

@ -0,0 +1,453 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import os
import re
import json
import time
from abc import abstractmethod
import infinity
from infinity.common import ConflictType
from infinity.index import IndexInfo, IndexType
from infinity.errors import ErrorCode
import pandas as pd
from common.file_utils import get_project_base_directory
from rag.nlp import is_english
from common import settings
from common.doc_store.doc_store_base import DocStoreConnection, MatchExpr, OrderByExpr
class InfinityConnectionBase(DocStoreConnection):
def __init__(self, mapping_file_name: str="infinity_mapping.json", logger_name: str="ragflow.infinity_conn"):
from common.doc_store.infinity_conn_pool import INFINITY_CONN
self.dbName = settings.INFINITY.get("db_name", "default_db")
self.mapping_file_name = mapping_file_name
self.logger = logging.getLogger(logger_name)
infinity_uri = settings.INFINITY["uri"]
if ":" in infinity_uri:
host, port = infinity_uri.split(":")
infinity_uri = infinity.common.NetworkAddress(host, int(port))
self.connPool = None
self.logger.info(f"Use Infinity {infinity_uri} as the doc engine.")
conn_pool = INFINITY_CONN.get_conn_pool()
for _ in range(24):
try:
inf_conn = conn_pool.get_conn()
res = inf_conn.show_current_node()
if res.error_code == ErrorCode.OK and res.server_status in ["started", "alive"]:
self._migrate_db(inf_conn)
self.connPool = conn_pool
conn_pool.release_conn(inf_conn)
break
conn_pool.release_conn(inf_conn)
self.logger.warning(f"Infinity status: {res.server_status}. Waiting Infinity {infinity_uri} to be healthy.")
time.sleep(5)
except Exception as e:
conn_pool = INFINITY_CONN.refresh_conn_pool()
self.logger.warning(f"{str(e)}. Waiting Infinity {infinity_uri} to be healthy.")
time.sleep(5)
if self.connPool is None:
msg = f"Infinity {infinity_uri} is unhealthy in 120s."
self.logger.error(msg)
raise Exception(msg)
self.logger.info(f"Infinity {infinity_uri} is healthy.")
def _migrate_db(self, inf_conn):
inf_db = inf_conn.create_database(self.dbName, ConflictType.Ignore)
fp_mapping = os.path.join(get_project_base_directory(), "conf", self.mapping_file_name)
if not os.path.exists(fp_mapping):
raise Exception(f"Mapping file not found at {fp_mapping}")
schema = json.load(open(fp_mapping))
table_names = inf_db.list_tables().table_names
for table_name in table_names:
inf_table = inf_db.get_table(table_name)
index_names = inf_table.list_indexes().index_names
if "q_vec_idx" not in index_names:
# Skip tables not created by me
continue
column_names = inf_table.show_columns()["name"]
column_names = set(column_names)
for field_name, field_info in schema.items():
if field_name in column_names:
continue
res = inf_table.add_columns({field_name: field_info})
assert res.error_code == infinity.ErrorCode.OK
self.logger.info(f"INFINITY added following column to table {table_name}: {field_name} {field_info}")
if field_info["type"] != "varchar" or "analyzer" not in field_info:
continue
analyzers = field_info["analyzer"]
if isinstance(analyzers, str):
analyzers = [analyzers]
for analyzer in analyzers:
inf_table.create_index(
f"ft_{re.sub(r'[^a-zA-Z0-9]', '_', field_name)}_{re.sub(r'[^a-zA-Z0-9]', '_', analyzer)}",
IndexInfo(field_name, IndexType.FullText, {"ANALYZER": analyzer}),
ConflictType.Ignore,
)
"""
Dataframe and fields convert
"""
@staticmethod
@abstractmethod
def field_keyword(field_name: str):
# judge keyword or not, such as "*_kwd" tag-like columns.
raise NotImplementedError("Not implemented")
@abstractmethod
def convert_select_fields(self, output_fields: list[str]) -> list[str]:
# rm _kwd, _tks, _sm_tks, _with_weight suffix in field name.
raise NotImplementedError("Not implemented")
@staticmethod
@abstractmethod
def convert_matching_field(field_weight_str: str) -> str:
# convert matching field to
raise NotImplementedError("Not implemented")
@staticmethod
def list2str(lst: str | list, sep: str = " ") -> str:
if isinstance(lst, str):
return lst
return sep.join(lst)
def equivalent_condition_to_str(self, condition: dict, table_instance=None) -> str | None:
assert "_id" not in condition
columns = {}
if table_instance:
for n, ty, de, _ in table_instance.show_columns().rows():
columns[n] = (ty, de)
def exists(cln):
nonlocal columns
assert cln in columns, f"'{cln}' should be in '{columns}'."
ty, de = columns[cln]
if ty.lower().find("cha"):
if not de:
de = ""
return f" {cln}!='{de}' "
return f"{cln}!={de}"
cond = list()
for k, v in condition.items():
if not isinstance(k, str) or not v:
continue
if self.field_keyword(k):
if isinstance(v, list):
inCond = list()
for item in v:
if isinstance(item, str):
item = item.replace("'", "''")
inCond.append(f"filter_fulltext('{self.convert_matching_field(k)}', '{item}')")
if inCond:
strInCond = " or ".join(inCond)
strInCond = f"({strInCond})"
cond.append(strInCond)
else:
cond.append(f"filter_fulltext('{self.convert_matching_field(k)}', '{v}')")
elif isinstance(v, list):
inCond = list()
for item in v:
if isinstance(item, str):
item = item.replace("'", "''")
inCond.append(f"'{item}'")
else:
inCond.append(str(item))
if inCond:
strInCond = ", ".join(inCond)
strInCond = f"{k} IN ({strInCond})"
cond.append(strInCond)
elif k == "must_not":
if isinstance(v, dict):
for kk, vv in v.items():
if kk == "exists":
cond.append("NOT (%s)" % exists(vv))
elif isinstance(v, str):
cond.append(f"{k}='{v}'")
elif k == "exists":
cond.append(exists(v))
else:
cond.append(f"{k}={str(v)}")
return " AND ".join(cond) if cond else "1=1"
@staticmethod
def concat_dataframes(df_list: list[pd.DataFrame], select_fields: list[str]) -> pd.DataFrame:
df_list2 = [df for df in df_list if not df.empty]
if df_list2:
return pd.concat(df_list2, axis=0).reset_index(drop=True)
schema = []
for field_name in select_fields:
if field_name == "score()": # Workaround: fix schema is changed to score()
schema.append("SCORE")
elif field_name == "similarity()": # Workaround: fix schema is changed to similarity()
schema.append("SIMILARITY")
else:
schema.append(field_name)
return pd.DataFrame(columns=schema)
"""
Database operations
"""
def db_type(self) -> str:
return "infinity"
def health(self) -> dict:
"""
Return the health status of the database.
"""
inf_conn = self.connPool.get_conn()
res = inf_conn.show_current_node()
self.connPool.release_conn(inf_conn)
res2 = {
"type": "infinity",
"status": "green" if res.error_code == 0 and res.server_status in ["started", "alive"] else "red",
"error": res.error_msg,
}
return res2
"""
Table operations
"""
def create_idx(self, index_name: str, dataset_id: str, vector_size: int):
table_name = f"{index_name}_{dataset_id}"
inf_conn = self.connPool.get_conn()
inf_db = inf_conn.create_database(self.dbName, ConflictType.Ignore)
fp_mapping = os.path.join(get_project_base_directory(), "conf", self.mapping_file_name)
if not os.path.exists(fp_mapping):
raise Exception(f"Mapping file not found at {fp_mapping}")
schema = json.load(open(fp_mapping))
vector_name = f"q_{vector_size}_vec"
schema[vector_name] = {"type": f"vector,{vector_size},float"}
inf_table = inf_db.create_table(
table_name,
schema,
ConflictType.Ignore,
)
inf_table.create_index(
"q_vec_idx",
IndexInfo(
vector_name,
IndexType.Hnsw,
{
"M": "16",
"ef_construction": "50",
"metric": "cosine",
"encode": "lvq",
},
),
ConflictType.Ignore,
)
for field_name, field_info in schema.items():
if field_info["type"] != "varchar" or "analyzer" not in field_info:
continue
analyzers = field_info["analyzer"]
if isinstance(analyzers, str):
analyzers = [analyzers]
for analyzer in analyzers:
inf_table.create_index(
f"ft_{re.sub(r'[^a-zA-Z0-9]', '_', field_name)}_{re.sub(r'[^a-zA-Z0-9]', '_', analyzer)}",
IndexInfo(field_name, IndexType.FullText, {"ANALYZER": analyzer}),
ConflictType.Ignore,
)
self.connPool.release_conn(inf_conn)
self.logger.info(f"INFINITY created table {table_name}, vector size {vector_size}")
return True
def delete_idx(self, index_name: str, dataset_id: str):
table_name = f"{index_name}_{dataset_id}"
inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName)
db_instance.drop_table(table_name, ConflictType.Ignore)
self.connPool.release_conn(inf_conn)
self.logger.info(f"INFINITY dropped table {table_name}")
def index_exist(self, index_name: str, dataset_id: str) -> bool:
table_name = f"{index_name}_{dataset_id}"
try:
inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName)
_ = db_instance.get_table(table_name)
self.connPool.release_conn(inf_conn)
return True
except Exception as e:
self.logger.warning(f"INFINITY indexExist {str(e)}")
return False
"""
CRUD operations
"""
@abstractmethod
def search(
self,
select_fields: list[str],
highlight_fields: list[str],
condition: dict,
match_expressions: list[MatchExpr],
order_by: OrderByExpr,
offset: int,
limit: int,
index_names: str | list[str],
dataset_ids: list[str],
agg_fields: list[str] | None = None,
rank_feature: dict | None = None,
) -> tuple[pd.DataFrame, int]:
raise NotImplementedError("Not implemented")
@abstractmethod
def get(self, doc_id: str, index_name: str, knowledgebase_ids: list[str]) -> dict | None:
raise NotImplementedError("Not implemented")
@abstractmethod
def insert(self, documents: list[dict], index_name: str, dataset_ids: str = None) -> list[str]:
raise NotImplementedError("Not implemented")
@abstractmethod
def update(self, condition: dict, new_value: dict, index_name: str, dataset_id: str) -> bool:
raise NotImplementedError("Not implemented")
def delete(self, condition: dict, index_name: str, dataset_id: str) -> int:
inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName)
table_name = f"{index_name}_{dataset_id}"
try:
table_instance = db_instance.get_table(table_name)
except Exception:
self.logger.warning(f"Skipped deleting from table {table_name} since the table doesn't exist.")
return 0
filter = self.equivalent_condition_to_str(condition, table_instance)
self.logger.debug(f"INFINITY delete table {table_name}, filter {filter}.")
res = table_instance.delete(filter)
self.connPool.release_conn(inf_conn)
return res.deleted_rows
"""
Helper functions for search result
"""
def get_total(self, res: tuple[pd.DataFrame, int] | pd.DataFrame) -> int:
if isinstance(res, tuple):
return res[1]
return len(res)
def get_doc_ids(self, res: tuple[pd.DataFrame, int] | pd.DataFrame) -> list[str]:
if isinstance(res, tuple):
res = res[0]
return list(res["id"])
@abstractmethod
def get_fields(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, fields: list[str]) -> dict[str, dict]:
raise NotImplementedError("Not implemented")
def get_highlight(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, keywords: list[str], field_name: str):
if isinstance(res, tuple):
res = res[0]
ans = {}
num_rows = len(res)
column_id = res["id"]
if field_name not in res:
return {}
for i in range(num_rows):
id = column_id[i]
txt = res[field_name][i]
if re.search(r"<em>[^<>]+</em>", txt, flags=re.IGNORECASE | re.MULTILINE):
ans[id] = txt
continue
txt = re.sub(r"[\r\n]", " ", txt, flags=re.IGNORECASE | re.MULTILINE)
txt_list = []
for t in re.split(r"[.?!;\n]", txt):
if is_english([t]):
for w in keywords:
t = re.sub(
r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])" % re.escape(w),
r"\1<em>\2</em>\3",
t,
flags=re.IGNORECASE | re.MULTILINE,
)
else:
for w in sorted(keywords, key=len, reverse=True):
t = re.sub(
re.escape(w),
f"<em>{w}</em>",
t,
flags=re.IGNORECASE | re.MULTILINE,
)
if not re.search(r"<em>[^<>]+</em>", t, flags=re.IGNORECASE | re.MULTILINE):
continue
txt_list.append(t)
if txt_list:
ans[id] = "...".join(txt_list)
else:
ans[id] = txt
return ans
def get_aggregation(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, field_name: str):
"""
Manual aggregation for tag fields since Infinity doesn't provide native aggregation
"""
from collections import Counter
# Extract DataFrame from result
if isinstance(res, tuple):
df, _ = res
else:
df = res
if df.empty or field_name not in df.columns:
return []
# Aggregate tag counts
tag_counter = Counter()
for value in df[field_name]:
if pd.isna(value) or not value:
continue
# Handle different tag formats
if isinstance(value, str):
# Split by ### for tag_kwd field or comma for other formats
if field_name == "tag_kwd" and "###" in value:
tags = [tag.strip() for tag in value.split("###") if tag.strip()]
else:
# Try comma separation as fallback
tags = [tag.strip() for tag in value.split(",") if tag.strip()]
for tag in tags:
if tag: # Only count non-empty tags
tag_counter[tag] += 1
elif isinstance(value, list):
# Handle list format
for tag in value:
if tag and isinstance(tag, str):
tag_counter[tag.strip()] += 1
# Return as list of [tag, count] pairs, sorted by count descending
return [[tag, count] for tag, count in tag_counter.most_common()]
"""
SQL
"""
def sql(self, sql: str, fetch_size: int, format: str):
raise NotImplementedError("Not implemented")

View File

@ -0,0 +1,85 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import time
import infinity
from infinity.connection_pool import ConnectionPool
from infinity.errors import ErrorCode
from common import settings
from common.decorator import singleton
@singleton
class InfinityConnectionPool:
def __init__(self):
if hasattr(settings, "INFINITY"):
self.INFINITY_CONFIG = settings.INFINITY
else:
self.INFINITY_CONFIG = settings.get_base_config("infinity", {"uri": "infinity:23817"})
infinity_uri = self.INFINITY_CONFIG["uri"]
if ":" in infinity_uri:
host, port = infinity_uri.split(":")
self.infinity_uri = infinity.common.NetworkAddress(host, int(port))
for _ in range(24):
try:
conn_pool = ConnectionPool(self.infinity_uri, max_size=4)
inf_conn = conn_pool.get_conn()
res = inf_conn.show_current_node()
if res.error_code == ErrorCode.OK and res.server_status in ["started", "alive"]:
self.conn_pool = conn_pool
conn_pool.release_conn(inf_conn)
break
except Exception as e:
logging.warning(f"{str(e)}. Waiting Infinity {infinity_uri} to be healthy.")
time.sleep(5)
if self.conn_pool is None:
msg = f"Infinity {infinity_uri} is unhealthy in 120s."
logging.error(msg)
raise Exception(msg)
logging.info(f"Infinity {infinity_uri} is healthy.")
def get_conn_pool(self):
return self.conn_pool
def refresh_conn_pool(self):
try:
inf_conn = self.conn_pool.get_conn()
res = inf_conn.show_current_node()
if res.error_code == ErrorCode.OK and res.server_status in ["started", "alive"]:
return self.conn_pool
else:
raise Exception(f"{res.error_code}: {res.server_status}")
except Exception as e:
logging.error(str(e))
if hasattr(self, "conn_pool") and self.conn_pool:
self.conn_pool.destroy()
self.conn_pool = ConnectionPool(self.infinity_uri, max_size=32)
return self.conn_pool
def __del__(self):
if hasattr(self, "conn_pool") and self.conn_pool:
self.conn_pool.destroy()
INFINITY_CONN = InfinityConnectionPool()

72
common/query_base.py Normal file
View File

@ -0,0 +1,72 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import re
from abc import ABC, abstractmethod
class QueryBase(ABC):
@staticmethod
def is_chinese(line):
arr = re.split(r"[ \t]+", line)
if len(arr) <= 3:
return True
e = 0
for t in arr:
if not re.match(r"[a-zA-Z]+$", t):
e += 1
return e * 1.0 / len(arr) >= 0.7
@staticmethod
def sub_special_char(line):
return re.sub(r"([:\{\}/\[\]\-\*\"\(\)\|\+~\^])", r"\\\1", line).strip()
@staticmethod
def rmWWW(txt):
patts = [
(
r"是*(怎么办|什么样的|哪家|一下|那家|请问|啥样|咋样了|什么时候|何时|何地|何人|是否|是不是|多少|哪里|怎么|哪儿|怎么样|如何|哪些|是啥|啥是|啊|吗|呢|吧|咋|什么|有没有|呀|谁|哪位|哪个)是*",
"",
),
(r"(^| )(what|who|how|which|where|why)('re|'s)? ", " "),
(
r"(^| )('s|'re|is|are|were|was|do|does|did|don't|doesn't|didn't|has|have|be|there|you|me|your|my|mine|just|please|may|i|should|would|wouldn't|will|won't|done|go|for|with|so|the|a|an|by|i'm|it's|he's|she's|they|they're|you're|as|by|on|in|at|up|out|down|of|to|or|and|if) ",
" ")
]
otxt = txt
for r, p in patts:
txt = re.sub(r, p, txt, flags=re.IGNORECASE)
if not txt:
txt = otxt
return txt
@staticmethod
def add_space_between_eng_zh(txt):
# (ENG/ENG+NUM) + ZH
txt = re.sub(r'([A-Za-z]+[0-9]+)([\u4e00-\u9fa5]+)', r'\1 \2', txt)
# ENG + ZH
txt = re.sub(r'([A-Za-z])([\u4e00-\u9fa5]+)', r'\1 \2', txt)
# ZH + (ENG/ENG+NUM)
txt = re.sub(r'([\u4e00-\u9fa5]+)([A-Za-z]+[0-9]+)', r'\1 \2', txt)
txt = re.sub(r'([\u4e00-\u9fa5]+)([A-Za-z])', r'\1 \2', txt)
return txt
@abstractmethod
def question(self, text, tbl, min_match):
"""
Returns a query object based on the input text, table, and minimum match criteria.
"""
raise NotImplementedError("Not implemented")

View File

@ -39,6 +39,9 @@ from rag.utils.oss_conn import RAGFlowOSS
from rag.nlp import search
import memory.utils.es_conn as memory_es_conn
import memory.utils.infinity_conn as memory_infinity_conn
LLM = None
LLM_FACTORY = None
LLM_BASE_URL = None
@ -79,6 +82,7 @@ DOC_ENGINE_INFINITY = (DOC_ENGINE.lower() == "infinity")
docStoreConn = None
msgStoreConn = None
retriever = None
kg_retriever = None
@ -256,6 +260,15 @@ def init_settings():
else:
raise Exception(f"Not supported doc engine: {DOC_ENGINE}")
global msgStoreConn
# use the same engine for message store
if DOC_ENGINE == "elasticsearch":
ES = get_base_config("es", {})
msgStoreConn = memory_es_conn.ESConnection()
elif DOC_ENGINE == "infinity":
INFINITY = get_base_config("infinity", {"uri": "infinity:23817"})
msgStoreConn = memory_infinity_conn.InfinityConnection()
global AZURE, S3, MINIO, OSS, GCS
if STORAGE_IMPL_TYPE in ['AZURE_SPN', 'AZURE_SAS']:
AZURE = get_base_config("azure", {})

View File

@ -14,6 +14,7 @@
# limitations under the License.
import datetime
import logging
import time
def current_timestamp():
@ -123,4 +124,31 @@ def delta_seconds(date_string: str):
3600.0 # If current time is 2024-01-01 13:00:00
"""
dt = datetime.datetime.strptime(date_string, "%Y-%m-%d %H:%M:%S")
return (datetime.datetime.now() - dt).total_seconds()
return (datetime.datetime.now() - dt).total_seconds()
def format_iso_8601_to_ymd_hms(time_str: str) -> str:
"""
Convert ISO 8601 formatted string to "YYYY-MM-DD HH:MM:SS" format.
Args:
time_str: ISO 8601 date string (e.g. "2024-01-01T12:00:00Z")
Returns:
str: Date string in "YYYY-MM-DD HH:MM:SS" format
Example:
>>> format_iso_8601_to_ymd_hms("2024-01-01T12:00:00Z")
'2024-01-01 12:00:00'
"""
from dateutil import parser
try:
if parser.isoparse(time_str):
dt = datetime.datetime.fromisoformat(time_str.replace("Z", "+00:00"))
return dt.strftime("%Y-%m-%d %H:%M:%S")
else:
return time_str
except Exception as e:
logging.error(str(e))
return time_str