#
#  Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#

import logging
import os
import re
import json
import time
from abc import abstractmethod

import infinity
from infinity.common import ConflictType
from infinity.index import IndexInfo, IndexType
from infinity.connection_pool import ConnectionPool
from infinity.errors import ErrorCode

import pandas as pd
from common.file_utils import get_project_base_directory
from rag.nlp import is_english
from common import settings
from common.doc_store.doc_store_base import DocStoreConnection, MatchExpr, OrderByExpr


class InfinityConnectionBase(DocStoreConnection):
    """Common part of the Infinity-backed document store connections.

    The class owns the connection pool, the schema migration performed at
    start-up, table (index) management, deletion and a few helpers that work
    on search results. Searching, fetching, inserting, updating and the field
    name conversions are left abstract for subclasses.
    """

    def __init__(self, mapping_file_name: str = "infinity_mapping.json", logger_name: str = "ragflow.infinity_conn"):
        """Connect to Infinity, waiting for it to become healthy.

        Args:
            mapping_file_name: File name (inside the project's ``conf``
                directory) of the JSON schema used for tables.
            logger_name: Name of the logger used by this connection.

        Raises:
            Exception: If Infinity is not healthy after 24 attempts spaced
                5 seconds apart.
        """
        self.dbName = settings.INFINITY.get("db_name", "default_db")
        self.mapping_file_name = mapping_file_name
        self.logger = logging.getLogger(logger_name)
        infinity_uri = settings.INFINITY["uri"]
        if ":" in infinity_uri:
            host, port = infinity_uri.split(":")
            infinity_uri = infinity.common.NetworkAddress(host, int(port))
        self.connPool = None
        self.logger.info(f"Use Infinity {infinity_uri} as the doc engine.")
        for _ in range(24):
            try:
                conn_pool = ConnectionPool(infinity_uri, max_size=4)
                inf_conn = conn_pool.get_conn()
                try:
                    res = inf_conn.show_current_node()
                    healthy = res.error_code == ErrorCode.OK and res.server_status in ["started", "alive"]
                    if healthy:
                        self._migrate_db(inf_conn)
                finally:
                    # Hand the probe connection back even when the status
                    # query or the migration raised.
                    conn_pool.release_conn(inf_conn)
                if healthy:
                    self.connPool = conn_pool
                    break
                self.logger.warning(f"Infinity status: {res.server_status}. Waiting Infinity {infinity_uri} to be healthy.")
                time.sleep(5)
            except Exception as e:
                self.logger.warning(f"{str(e)}. Waiting Infinity {infinity_uri} to be healthy.")
                time.sleep(5)
        if self.connPool is None:
            msg = f"Infinity {infinity_uri} is unhealthy in 120s."
            self.logger.error(msg)
            raise Exception(msg)
        self.logger.info(f"Infinity {infinity_uri} is healthy.")

    def _load_mapping(self) -> dict:
        """Load the table schema from ``conf/<mapping_file_name>``.

        Raises:
            Exception: If the mapping file does not exist.
        """
        fp_mapping = os.path.join(get_project_base_directory(), "conf", self.mapping_file_name)
        if not os.path.exists(fp_mapping):
            raise Exception(f"Mapping file not found at {fp_mapping}")
        with open(fp_mapping, encoding="utf-8") as f:
            return json.load(f)

    @staticmethod
    def _create_fulltext_indexes(inf_table, field_name: str, field_info: dict) -> None:
        """Create one full-text index per analyzer declared for a varchar field.

        Fields that are not ``varchar`` or declare no ``analyzer`` are left
        untouched. Existing indexes are kept (``ConflictType.Ignore``).
        """
        if field_info["type"] != "varchar" or "analyzer" not in field_info:
            return
        analyzers = field_info["analyzer"]
        if isinstance(analyzers, str):
            analyzers = [analyzers]
        for analyzer in analyzers:
            inf_table.create_index(
                f"ft_{re.sub(r'[^a-zA-Z0-9]', '_', field_name)}_{re.sub(r'[^a-zA-Z0-9]', '_', analyzer)}",
                IndexInfo(field_name, IndexType.FullText, {"ANALYZER": analyzer}),
                ConflictType.Ignore,
            )

    def _migrate_db(self, inf_conn):
        """Add columns present in the mapping file but missing from existing tables.

        Only tables that carry a ``q_vec_idx`` index are touched. For every
        column that gets added, the matching full-text indexes are created.
        """
        inf_db = inf_conn.create_database(self.dbName, ConflictType.Ignore)
        schema = self._load_mapping()
        table_names = inf_db.list_tables().table_names
        for table_name in table_names:
            inf_table = inf_db.get_table(table_name)
            index_names = inf_table.list_indexes().index_names
            if "q_vec_idx" not in index_names:
                # Skip tables not created by me
                continue
            column_names = set(inf_table.show_columns()["name"])
            for field_name, field_info in schema.items():
                if field_name in column_names:
                    continue
                res = inf_table.add_columns({field_name: field_info})
                assert res.error_code == infinity.ErrorCode.OK
                self.logger.info(f"INFINITY added following column to table {table_name}: {field_name} {field_info}")
                self._create_fulltext_indexes(inf_table, field_name, field_info)

    # ------------------------------------------------------------------
    # Dataframe and fields convert
    # ------------------------------------------------------------------

    @staticmethod
    @abstractmethod
    def field_keyword(field_name: str):
        # judge keyword or not, such as "*_kwd" tag-like columns.
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def convert_select_fields(self, output_fields: list[str]) -> list[str]:
        # rm _kwd, _tks, _sm_tks, _with_weight suffix in field name.
        raise NotImplementedError("Not implemented")

    @staticmethod
    @abstractmethod
    def convert_matching_field(field_weight_str: str) -> str:
        # convert matching field to
        raise NotImplementedError("Not implemented")

    @staticmethod
    def list2str(lst: str | list, sep: str = " ") -> str:
        """Return ``lst`` unchanged when it is a string, else join it with ``sep``."""
        if isinstance(lst, str):
            return lst
        return sep.join(lst)

    def equivalent_condition_to_str(self, condition: dict, table_instance=None) -> str | None:
        """Translate a condition dict into an Infinity filter expression.

        Args:
            condition: Mapping of field name to the value(s) to match. The
                special keys ``"exists"`` and ``"must_not"`` (with a nested
                ``{"exists": column}``) test a column against its default.
                Entries with a non-string key or a falsy value are ignored.
            table_instance: Table whose column types/defaults are consulted
                for the ``exists`` tests.

        Returns:
            The clauses joined with ``" AND "``, or ``"1=1"`` when there is
            nothing to filter on.
        """
        assert "_id" not in condition
        columns = {}
        if table_instance:
            for n, ty, de, _ in table_instance.show_columns().rows():
                columns[n] = (ty, de)

        def exists(cln):
            assert cln in columns, f"'{cln}' should be in '{columns}'."
            ty, de = columns[cln]
            # Substring test: str.find() returns -1 (truthy) when absent, so
            # its result must not be used as a boolean.
            if "cha" in ty.lower():
                if not de:
                    de = ""
                return f" {cln}!='{de}' "
            return f"{cln}!={de}"

        def escape(value):
            # Double single quotes so string values cannot break the literal.
            return value.replace("'", "''") if isinstance(value, str) else value

        cond = []
        for k, v in condition.items():
            if not isinstance(k, str) or not v:
                continue
            if self.field_keyword(k):
                matching_field = self.convert_matching_field(k)
                if isinstance(v, list):
                    in_cond = [f"filter_fulltext('{matching_field}', '{escape(item)}')" for item in v]
                    if in_cond:
                        cond.append(f"({' or '.join(in_cond)})")
                else:
                    cond.append(f"filter_fulltext('{matching_field}', '{escape(v)}')")
            elif isinstance(v, list):
                in_cond = [f"'{escape(item)}'" if isinstance(item, str) else str(item) for item in v]
                if in_cond:
                    cond.append(f"{k} IN ({', '.join(in_cond)})")
            elif k == "must_not":
                if isinstance(v, dict):
                    for kk, vv in v.items():
                        if kk == "exists":
                            cond.append("NOT (%s)" % exists(vv))
            elif k == "exists":
                # Must be tested before the generic string branch, otherwise a
                # column name given as a string would be emitted as exists='...'.
                cond.append(exists(v))
            elif isinstance(v, str):
                cond.append(f"{k}='{escape(v)}'")
            else:
                cond.append(f"{k}={str(v)}")
        return " AND ".join(cond) if cond else "1=1"

    @staticmethod
    def concat_dataframes(df_list: list[pd.DataFrame], select_fields: list[str]) -> pd.DataFrame:
        """Concatenate the non-empty frames; build an empty frame otherwise.

        When every frame is empty, the returned frame has one column per
        entry of ``select_fields``, with ``score()`` and ``similarity()``
        renamed to ``SCORE`` and ``SIMILARITY``.
        """
        df_list2 = [df for df in df_list if not df.empty]
        if df_list2:
            return pd.concat(df_list2, axis=0).reset_index(drop=True)

        schema = []
        for field_name in select_fields:
            if field_name == "score()":  # Workaround: fix schema is changed to score()
                schema.append("SCORE")
            elif field_name == "similarity()":  # Workaround: fix schema is changed to similarity()
                schema.append("SIMILARITY")
            else:
                schema.append(field_name)
        return pd.DataFrame(columns=schema)

    # ------------------------------------------------------------------
    # Database operations
    # ------------------------------------------------------------------

    def db_type(self) -> str:
        """Return the identifier of this document engine."""
        return "infinity"

    def health(self) -> dict:
        """
        Return the health status of the database.
        """
        inf_conn = self.connPool.get_conn()
        try:
            res = inf_conn.show_current_node()
        finally:
            self.connPool.release_conn(inf_conn)
        res2 = {
            "type": "infinity",
            "status": "green" if res.error_code == 0 and res.server_status in ["started", "alive"] else "red",
            "error": res.error_msg,
        }
        return res2

    # ------------------------------------------------------------------
    # Table operations
    # ------------------------------------------------------------------

    def create_idx(self, index_name: str, dataset_id: str, vector_size: int):
        """Create table ``<index_name>_<dataset_id>`` with its indexes.

        The table uses the mapping-file schema plus a ``q_<vector_size>_vec``
        vector column, gets an HNSW index named ``q_vec_idx`` on that column
        and full-text indexes for varchar fields that declare analyzers.
        Existing objects are left in place (``ConflictType.Ignore``).

        Returns:
            ``True`` once the table and indexes have been requested.

        Raises:
            Exception: If the mapping file does not exist.
        """
        table_name = f"{index_name}_{dataset_id}"
        inf_conn = self.connPool.get_conn()
        try:
            inf_db = inf_conn.create_database(self.dbName, ConflictType.Ignore)
            schema = self._load_mapping()
            vector_name = f"q_{vector_size}_vec"
            schema[vector_name] = {"type": f"vector,{vector_size},float"}
            inf_table = inf_db.create_table(
                table_name,
                schema,
                ConflictType.Ignore,
            )
            inf_table.create_index(
                "q_vec_idx",
                IndexInfo(
                    vector_name,
                    IndexType.Hnsw,
                    {
                        "M": "16",
                        "ef_construction": "50",
                        "metric": "cosine",
                        "encode": "lvq",
                    },
                ),
                ConflictType.Ignore,
            )
            for field_name, field_info in schema.items():
                self._create_fulltext_indexes(inf_table, field_name, field_info)
        finally:
            self.connPool.release_conn(inf_conn)
        self.logger.info(f"INFINITY created table {table_name}, vector size {vector_size}")
        return True

    def delete_idx(self, index_name: str, dataset_id: str):
        """Drop table ``<index_name>_<dataset_id>``; a missing table is ignored."""
        table_name = f"{index_name}_{dataset_id}"
        inf_conn = self.connPool.get_conn()
        try:
            db_instance = inf_conn.get_database(self.dbName)
            db_instance.drop_table(table_name, ConflictType.Ignore)
        finally:
            self.connPool.release_conn(inf_conn)
        self.logger.info(f"INFINITY dropped table {table_name}")

    def index_exist(self, index_name: str, dataset_id: str) -> bool:
        """Return whether table ``<index_name>_<dataset_id>`` can be opened.

        Any exception raised while looking the table up is logged as a
        warning and reported as ``False``.
        """
        table_name = f"{index_name}_{dataset_id}"
        inf_conn = None
        try:
            inf_conn = self.connPool.get_conn()
            db_instance = inf_conn.get_database(self.dbName)
            db_instance.get_table(table_name)
            return True
        except Exception as e:
            self.logger.warning(f"INFINITY indexExist {str(e)}")
            return False
        finally:
            # Release on both the success and the failure path.
            if inf_conn is not None:
                self.connPool.release_conn(inf_conn)

    # ------------------------------------------------------------------
    # CRUD operations
    # ------------------------------------------------------------------

    @abstractmethod
    def search(
        self,
        select_fields: list[str],
        highlight_fields: list[str],
        condition: dict,
        match_expressions: list[MatchExpr],
        order_by: OrderByExpr,
        offset: int,
        limit: int,
        index_names: str | list[str],
        dataset_ids: list[str],
        agg_fields: list[str] | None = None,
        rank_feature: dict | None = None,
    ) -> tuple[pd.DataFrame, int]:
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def get(self, doc_id: str, index_name: str, knowledgebase_ids: list[str]) -> dict | None:
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def insert(self, documents: list[dict], index_name: str, dataset_ids: str = None) -> list[str]:
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def update(self, condition: dict, new_value: dict, index_name: str, dataset_id: str) -> bool:
        raise NotImplementedError("Not implemented")

    def delete(self, condition: dict, index_name: str, dataset_id: str) -> int:
        """Delete the rows of ``<index_name>_<dataset_id>`` matching ``condition``.

        Returns:
            The number of deleted rows, or ``0`` when the table cannot be
            opened.
        """
        table_name = f"{index_name}_{dataset_id}"
        inf_conn = self.connPool.get_conn()
        try:
            db_instance = inf_conn.get_database(self.dbName)
            try:
                table_instance = db_instance.get_table(table_name)
            except Exception:
                self.logger.warning(f"Skipped deleting from table {table_name} since the table doesn't exist.")
                return 0
            filter_str = self.equivalent_condition_to_str(condition, table_instance)
            self.logger.debug(f"INFINITY delete table {table_name}, filter {filter_str}.")
            res = table_instance.delete(filter_str)
            return res.deleted_rows
        finally:
            # Also covers the early return above, which used to leak the
            # pooled connection.
            self.connPool.release_conn(inf_conn)

    # ------------------------------------------------------------------
    # Helper functions for search result
    # ------------------------------------------------------------------

    def get_total(self, res: tuple[pd.DataFrame, int] | pd.DataFrame) -> int:
        """Return the total carried by a ``(frame, total)`` tuple, else the frame length."""
        if isinstance(res, tuple):
            return res[1]
        return len(res)

    def get_doc_ids(self, res: tuple[pd.DataFrame, int] | pd.DataFrame) -> list[str]:
        """Return the values of the ``id`` column as a list."""
        if isinstance(res, tuple):
            res = res[0]
        return list(res["id"])

    @abstractmethod
    def get_fields(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, fields: list[str]) -> dict[str, dict]:
        raise NotImplementedError("Not implemented")

    def get_highlight(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, keywords: list[str], field_name: str):
        """Wrap keyword occurrences of ``field_name`` in ``<em>`` tags.

        Args:
            res: Search result, either the frame or a ``(frame, total)`` tuple.
            keywords: Keywords to highlight; empty strings are ignored.
            field_name: Column holding the text to highlight.

        Returns:
            Mapping of row ``id`` to text. Text that already contains an
            ``<em>...</em>`` span is returned unchanged. Otherwise the text is
            split into sentences, the sentences containing a keyword are
            highlighted and joined with ``"..."``; when no sentence matches,
            the text (with line breaks replaced by spaces) is returned. Rows
            whose value is not a string are skipped. An empty dict is returned
            when ``field_name`` is not a column of the frame.
        """
        # NOTE(review): the <em> markup in the literals below was restored on
        # the assumption that it had been stripped from this file; without it
        # the substitutions are no-ops. Confirm against the consumers.
        if isinstance(res, tuple):
            res = res[0]
        if field_name not in res:
            return {}
        flags = re.IGNORECASE | re.MULTILINE
        # An empty keyword would match at every position.
        keywords = [w for w in keywords if w]
        ans = {}
        # zip() walks the two columns positionally, so a frame whose index is
        # not 0..n-1 is handled as well.
        for doc_id, txt in zip(res["id"], res[field_name]):
            if not isinstance(txt, str):
                continue
            if re.search(r"<em>[^<>]+</em>", txt, flags=flags):
                ans[doc_id] = txt
                continue
            txt = re.sub(r"[\r\n]", " ", txt, flags=flags)
            txt_list = []
            for t in re.split(r"[.?!;\n]", txt):
                if is_english([t]):
                    # Only whole words, i.e. keywords delimited by spaces or
                    # punctuation, are highlighted.
                    for w in keywords:
                        t = re.sub(
                            r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])" % re.escape(w),
                            r"\1<em>\2</em>\3",
                            t,
                            flags=flags,
                        )
                else:
                    # Longest keywords first so that a keyword contained in
                    # another one does not split the longer match.
                    for w in sorted(keywords, key=len, reverse=True):
                        t = re.sub(
                            re.escape(w),
                            # Callable replacement: the keyword is inserted
                            # verbatim rather than parsed as a template.
                            lambda _m, w=w: f"<em>{w}</em>",
                            t,
                            flags=flags,
                        )
                if not re.search(r"<em>[^<>]+</em>", t, flags=flags):
                    continue
                txt_list.append(t)
            ans[doc_id] = "...".join(txt_list) if txt_list else txt
        return ans

    def get_aggregation(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, field_name: str):
        """
        Manual aggregation for tag fields since Infinity doesn't provide native aggregation

        Returns a list of ``[tag, count]`` pairs sorted by count descending,
        or an empty list when the frame is empty or lacks ``field_name``.
        String cells are split on ``###`` (for ``tag_kwd`` cells containing
        it) or on commas; list cells contribute their string items. Other
        cell types are ignored.
        """
        from collections import Counter

        # Extract DataFrame from result
        if isinstance(res, tuple):
            df, _ = res
        else:
            df = res

        if df.empty or field_name not in df.columns:
            return []

        # Aggregate tag counts
        tag_counter = Counter()
        for value in df[field_name]:
            # Lists are handled before any scalar test: pd.isna() on a list
            # returns an array whose truth value is ambiguous.
            if isinstance(value, list):
                for tag in value:
                    if isinstance(tag, str) and tag.strip():
                        tag_counter[tag.strip()] += 1
                continue
            if not isinstance(value, str) or not value:
                continue
            # Split by ### for tag_kwd field or comma for other formats
            separator = "###" if field_name == "tag_kwd" and "###" in value else ","
            for tag in value.split(separator):
                tag = tag.strip()
                if tag:  # Only count non-empty tags
                    tag_counter[tag] += 1

        # Return as list of [tag, count] pairs, sorted by count descending
        return [[tag, count] for tag, count in tag_counter.most_common()]

    # ------------------------------------------------------------------
    # SQL
    # ------------------------------------------------------------------

    def sql(self, sql: str, fetch_size: int, format: str):
        """Not supported by this base class; always raises ``NotImplementedError``."""
        raise NotImplementedError("Not implemented")