diff --git a/common/doc_store/ob_conn_base.py b/common/doc_store/ob_conn_base.py
new file mode 100644
index 000000000..0b95770ca
--- /dev/null
+++ b/common/doc_store/ob_conn_base.py
@@ -0,0 +1,739 @@
+#
+# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import json
+import logging
+import os
+import re
+import threading
+import time
+from abc import abstractmethod
+from typing import Any
+
+from pymysql.converters import escape_string
+from pyobvector import ObVecClient, FtsIndexParam, FtsParser, VECTOR
+from sqlalchemy import Column, Table
+
+from common.doc_store.doc_store_base import DocStoreConnection, MatchExpr, OrderByExpr
+
+ATTEMPT_TIME = 2
+
+# Common templates for OceanBase
+index_name_template = "ix_%s_%s"
+fulltext_index_name_template = "fts_idx_%s"
+fulltext_search_template = "MATCH (%s) AGAINST ('%s' IN NATURAL LANGUAGE MODE)"
+vector_search_template = "cosine_distance(%s, '%s')"
+vector_column_pattern = re.compile(r"q_(?P<vector_size>\d+)_vec")
+
+
+def get_value_str(value: Any) -> str:
+    """Convert value to SQL string representation."""
+    if isinstance(value, str):
+        # escape_string already handles all necessary escaping for MySQL/OceanBase,
+        # including backslashes, quotes, newlines, etc.
+ return f"'{escape_string(value)}'" + elif isinstance(value, bool): + return "true" if value else "false" + elif value is None: + return "NULL" + elif isinstance(value, (list, dict)): + json_str = json.dumps(value, ensure_ascii=False) + return f"'{escape_string(json_str)}'" + else: + return str(value) + + +def _try_with_lock(lock_name: str, process_func, check_func, timeout: int = None): + """Execute function with distributed lock.""" + if not timeout: + timeout = int(os.environ.get("OB_DDL_TIMEOUT", "60")) + + if not check_func(): + from rag.utils.redis_conn import RedisDistributedLock + lock = RedisDistributedLock(lock_name) + if lock.acquire(): + try: + process_func() + return + except Exception as e: + if "Duplicate" in str(e): + return + raise + finally: + lock.release() + + if not check_func(): + time.sleep(1) + count = 1 + while count < timeout and not check_func(): + count += 1 + time.sleep(1) + if count >= timeout and not check_func(): + raise Exception(f"Timeout to wait for process complete for {lock_name}.") + + +class OBConnectionBase(DocStoreConnection): + """Base class for OceanBase document store connections.""" + + def __init__(self, logger_name: str = 'ragflow.ob_conn'): + from common.doc_store.ob_conn_pool import OB_CONN + + self.logger = logging.getLogger(logger_name) + self.client: ObVecClient = OB_CONN.get_client() + self.es = OB_CONN.get_hybrid_search_client() + self.db_name = OB_CONN.get_db_name() + self.uri = OB_CONN.get_uri() + + self._load_env_vars() + + self._table_exists_cache: set[str] = set() + self._table_exists_cache_lock = threading.RLock() + + # Cache for vector columns: stores (table_name, vector_size) tuples + self._vector_column_cache: set[tuple[str, int]] = set() + self._vector_column_cache_lock = threading.RLock() + + self.logger.info(f"OceanBase {self.uri} connection initialized.") + + def _load_env_vars(self): + def is_true(var: str, default: str) -> bool: + return os.getenv(var, default).lower() in ['true', '1', 'yes', 'y'] + + self.enable_fulltext_search = is_true('ENABLE_FULLTEXT_SEARCH', 'true') + self.use_fulltext_hint = is_true('USE_FULLTEXT_HINT', 'true') + self.search_original_content = is_true("SEARCH_ORIGINAL_CONTENT", 'true') + self.enable_hybrid_search = is_true('ENABLE_HYBRID_SEARCH', 'false') + self.use_fulltext_first_fusion_search = is_true('USE_FULLTEXT_FIRST_FUSION_SEARCH', 'true') + + # Adjust settings based on hybrid search availability + if self.es is not None and self.search_original_content: + self.logger.info("HybridSearch is enabled, forcing search_original_content to False") + self.search_original_content = False + + """ + Template methods - must be implemented by subclasses + """ + + @abstractmethod + def get_index_columns(self) -> list[str]: + """Return list of column names that need regular indexes.""" + raise NotImplementedError("Not implemented") + + @abstractmethod + def get_fulltext_columns(self) -> list[str]: + """Return list of column names that need fulltext indexes (without weight suffix).""" + raise NotImplementedError("Not implemented") + + @abstractmethod + def get_column_definitions(self) -> list[Column]: + """Return list of column definitions for table creation.""" + raise NotImplementedError("Not implemented") + + def get_extra_columns(self) -> list[Column]: + """Return list of extra columns to add after table creation. 
Override if needed.""" + return [] + + def get_table_name(self, index_name: str, dataset_id: str) -> str: + """Return the actual table name given index_name and dataset_id.""" + return index_name + + @abstractmethod + def get_lock_prefix(self) -> str: + """Return the lock name prefix for distributed locking.""" + raise NotImplementedError("Not implemented") + + """ + Database operations + """ + + def db_type(self) -> str: + return "oceanbase" + + def health(self) -> dict: + return { + "uri": self.uri, + "version_comment": self._get_variable_value("version_comment") + } + + def _get_variable_value(self, var_name: str) -> Any: + rows = self.client.perform_raw_text_sql(f"SHOW VARIABLES LIKE '{var_name}'") + for row in rows: + return row[1] + raise Exception(f"Variable '{var_name}' not found.") + + """ + Table operations - common implementation using template methods + """ + + def _check_table_exists_cached(self, table_name: str) -> bool: + """ + Check table existence with cache to reduce INFORMATION_SCHEMA queries. + Thread-safe implementation using RLock. + """ + if table_name in self._table_exists_cache: + return True + + try: + if not self.client.check_table_exists(table_name): + return False + + # Check regular indexes + for column_name in self.get_index_columns(): + if not self._index_exists(table_name, index_name_template % (table_name, column_name)): + return False + + # Check fulltext indexes + for column_name in self.get_fulltext_columns(): + if not self._index_exists(table_name, fulltext_index_name_template % column_name): + return False + + # Check extra columns + for column in self.get_extra_columns(): + if not self._column_exist(table_name, column.name): + return False + + except Exception as e: + raise Exception(f"OBConnection._check_table_exists_cached error: {str(e)}") + + with self._table_exists_cache_lock: + if table_name not in self._table_exists_cache: + self._table_exists_cache.add(table_name) + return True + + def _create_table(self, table_name: str): + """Create table using column definitions from subclass.""" + self._create_table_with_columns(table_name, self.get_column_definitions()) + + def create_idx(self, index_name: str, dataset_id: str, vector_size: int, parser_id: str = None): + """Create index/table with all necessary indexes.""" + table_name = self.get_table_name(index_name, dataset_id) + lock_prefix = self.get_lock_prefix() + + try: + _try_with_lock( + lock_name=f"{lock_prefix}create_table_{table_name}", + check_func=lambda: self.client.check_table_exists(table_name), + process_func=lambda: self._create_table(table_name), + ) + + for column_name in self.get_index_columns(): + _try_with_lock( + lock_name=f"{lock_prefix}add_idx_{table_name}_{column_name}", + check_func=lambda cn=column_name: self._index_exists(table_name, + index_name_template % (table_name, cn)), + process_func=lambda cn=column_name: self._add_index(table_name, cn), + ) + + for column_name in self.get_fulltext_columns(): + _try_with_lock( + lock_name=f"{lock_prefix}add_fulltext_idx_{table_name}_{column_name}", + check_func=lambda cn=column_name: self._index_exists(table_name, fulltext_index_name_template % cn), + process_func=lambda cn=column_name: self._add_fulltext_index(table_name, cn), + ) + + # Add vector column and index (skip metadata refresh, will be done in finally) + self._ensure_vector_column_exists(table_name, vector_size, refresh_metadata=False) + + # Add extra columns if any + for column in self.get_extra_columns(): + _try_with_lock( + 
lock_name=f"{lock_prefix}add_{column.name}_{table_name}", + check_func=lambda c=column: self._column_exist(table_name, c.name), + process_func=lambda c=column: self._add_column(table_name, c), + ) + + except Exception as e: + raise Exception(f"OBConnection.create_idx error: {str(e)}") + finally: + self.client.refresh_metadata([table_name]) + + def create_doc_meta_idx(self, index_name: str): + """ + Create a document metadata table. + + Table name pattern: ragflow_doc_meta_{tenant_id} + - Per-tenant metadata table for storing document metadata fields + """ + from sqlalchemy import JSON + from sqlalchemy.dialects.mysql import VARCHAR + + table_name = index_name + lock_prefix = self.get_lock_prefix() + + # Define columns for document metadata table + doc_meta_columns = [ + Column("id", VARCHAR(256), primary_key=True, comment="document id"), + Column("kb_id", VARCHAR(256), nullable=False, comment="knowledge base id"), + Column("meta_fields", JSON, nullable=True, comment="document metadata fields"), + ] + + try: + # Create table with distributed lock + _try_with_lock( + lock_name=f"{lock_prefix}create_doc_meta_table_{table_name}", + check_func=lambda: self.client.check_table_exists(table_name), + process_func=lambda: self._create_table_with_columns(table_name, doc_meta_columns), + ) + + # Create index on kb_id for better query performance + _try_with_lock( + lock_name=f"{lock_prefix}add_idx_{table_name}_kb_id", + check_func=lambda: self._index_exists(table_name, index_name_template % (table_name, "kb_id")), + process_func=lambda: self._add_index(table_name, "kb_id"), + ) + + self.logger.info(f"Created document metadata table '{table_name}'.") + return True + + except Exception as e: + self.logger.error(f"OBConnection.create_doc_meta_idx error: {str(e)}") + return False + finally: + self.client.refresh_metadata([table_name]) + + def delete_idx(self, index_name: str, dataset_id: str): + """Delete index/table.""" + # For doc_meta tables, use index_name directly as table name + if index_name.startswith("ragflow_doc_meta_"): + table_name = index_name + else: + table_name = self.get_table_name(index_name, dataset_id) + try: + if self.client.check_table_exists(table_name=table_name): + self.client.drop_table_if_exist(table_name) + self.logger.info(f"Dropped table '{table_name}'.") + except Exception as e: + raise Exception(f"OBConnection.delete_idx error: {str(e)}") + + def index_exist(self, index_name: str, dataset_id: str = None) -> bool: + """Check if index/table exists.""" + # For doc_meta tables, use index_name directly as table name + if index_name.startswith("ragflow_doc_meta_"): + table_name = index_name + else: + table_name = self.get_table_name(index_name, dataset_id) if dataset_id else index_name + return self._check_table_exists_cached(table_name) + + """ + Table operations - helper methods + """ + + def _get_count(self, table_name: str, filter_list: list[str] = None) -> int: + where_clause = "WHERE " + " AND ".join(filter_list) if filter_list and len(filter_list) > 0 else "" + (count,) = self.client.perform_raw_text_sql( + f"SELECT COUNT(*) FROM {table_name} {where_clause}" + ).fetchone() + return count + + def _column_exist(self, table_name: str, column_name: str) -> bool: + return self._get_count( + table_name="INFORMATION_SCHEMA.COLUMNS", + filter_list=[ + f"TABLE_SCHEMA = '{self.db_name}'", + f"TABLE_NAME = '{table_name}'", + f"COLUMN_NAME = '{column_name}'", + ]) > 0 + + def _index_exists(self, table_name: str, idx_name: str) -> bool: + return self._get_count( + 
table_name="INFORMATION_SCHEMA.STATISTICS", + filter_list=[ + f"TABLE_SCHEMA = '{self.db_name}'", + f"TABLE_NAME = '{table_name}'", + f"INDEX_NAME = '{idx_name}'", + ]) > 0 + + def _create_table_with_columns(self, table_name: str, columns: list[Column]): + """Create table with specified columns.""" + if table_name in self.client.metadata_obj.tables: + self.client.metadata_obj.remove(Table(table_name, self.client.metadata_obj)) + + table_options = { + "mysql_charset": "utf8mb4", + "mysql_collate": "utf8mb4_unicode_ci", + "mysql_organization": "heap", + } + + self.client.create_table( + table_name=table_name, + columns=[c.copy() for c in columns], + **table_options, + ) + self.logger.info(f"Created table '{table_name}'.") + + def _add_index(self, table_name: str, column_name: str): + idx_name = index_name_template % (table_name, column_name) + self.client.create_index( + table_name=table_name, + is_vec_index=False, + index_name=idx_name, + column_names=[column_name], + ) + self.logger.info(f"Created index '{idx_name}' on table '{table_name}'.") + + def _add_fulltext_index(self, table_name: str, column_name: str): + fulltext_idx_name = fulltext_index_name_template % column_name + self.client.create_fts_idx_with_fts_index_param( + table_name=table_name, + fts_idx_param=FtsIndexParam( + index_name=fulltext_idx_name, + field_names=[column_name], + parser_type=FtsParser.IK, + ), + ) + self.logger.info(f"Created full text index '{fulltext_idx_name}' on table '{table_name}'.") + + def _add_vector_column(self, table_name: str, vector_size: int): + vector_field_name = f"q_{vector_size}_vec" + self.client.add_columns( + table_name=table_name, + columns=[Column(vector_field_name, VECTOR(vector_size), nullable=True)], + ) + self.logger.info(f"Added vector column '{vector_field_name}' to table '{table_name}'.") + + def _add_vector_index(self, table_name: str, vector_field_name: str): + vector_idx_name = f"{vector_field_name}_idx" + self.client.create_index( + table_name=table_name, + is_vec_index=True, + index_name=vector_idx_name, + column_names=[vector_field_name], + vidx_params="distance=cosine, type=hnsw, lib=vsag", + ) + self.logger.info( + f"Created vector index '{vector_idx_name}' on table '{table_name}' with column '{vector_field_name}'." + ) + + def _add_column(self, table_name: str, column: Column): + try: + self.client.add_columns( + table_name=table_name, + columns=[column.copy()], + ) + self.logger.info(f"Added column '{column.name}' to table '{table_name}'.") + except Exception as e: + self.logger.warning(f"Failed to add column '{column.name}' to table '{table_name}': {str(e)}") + + def _ensure_vector_column_exists(self, table_name: str, vector_size: int, refresh_metadata: bool = True): + """ + Ensure vector column and index exist for the given vector size. + This method is safe to call multiple times - it will skip if already exists. + Uses cache to avoid repeated INFORMATION_SCHEMA queries. 
+ + Args: + table_name: Name of the table + vector_size: Size of the vector column + refresh_metadata: Whether to refresh SQLAlchemy metadata after changes (default True) + """ + if vector_size <= 0: + return + + cache_key = (table_name, vector_size) + + # Check cache first + if cache_key in self._vector_column_cache: + return + + lock_prefix = self.get_lock_prefix() + vector_field_name = f"q_{vector_size}_vec" + vector_index_name = f"{vector_field_name}_idx" + + # Check if already exists (may have been created by another process) + column_exists = self._column_exist(table_name, vector_field_name) + index_exists = self._index_exists(table_name, vector_index_name) + + if column_exists and index_exists: + # Already exists, add to cache and return + with self._vector_column_cache_lock: + self._vector_column_cache.add(cache_key) + return + + # Create column if needed + if not column_exists: + _try_with_lock( + lock_name=f"{lock_prefix}add_vector_column_{table_name}_{vector_field_name}", + check_func=lambda: self._column_exist(table_name, vector_field_name), + process_func=lambda: self._add_vector_column(table_name, vector_size), + ) + + # Create index if needed + if not index_exists: + _try_with_lock( + lock_name=f"{lock_prefix}add_vector_idx_{table_name}_{vector_field_name}", + check_func=lambda: self._index_exists(table_name, vector_index_name), + process_func=lambda: self._add_vector_index(table_name, vector_field_name), + ) + + if refresh_metadata: + self.client.refresh_metadata([table_name]) + + # Add to cache after successful creation + with self._vector_column_cache_lock: + self._vector_column_cache.add(cache_key) + + def _execute_search_sql(self, sql: str) -> tuple[list, float]: + start_time = time.time() + res = self.client.perform_raw_text_sql(sql) + rows = res.fetchall() + elapsed_time = time.time() - start_time + return rows, elapsed_time + + def _parse_fulltext_columns( + self, + fulltext_query: str, + fulltext_columns: list[str] + ) -> tuple[dict[str, str], dict[str, float]]: + """ + Parse fulltext search columns with optional weight suffix and build search expressions. + + Args: + fulltext_query: The escaped fulltext query string + fulltext_columns: List of column names, optionally with weight suffix (e.g., "col^0.5") + + Returns: + Tuple of (fulltext_search_expr dict, fulltext_search_weight dict) + where weights are normalized to 0~1 + """ + fulltext_search_expr: dict[str, str] = {} + fulltext_search_weight: dict[str, float] = {} + + # get fulltext match expression and weight values + for field in fulltext_columns: + parts = field.split("^") + column_name: str = parts[0] + column_weight: float = float(parts[1]) if (len(parts) > 1 and parts[1]) else 1.0 + + fulltext_search_weight[column_name] = column_weight + fulltext_search_expr[column_name] = fulltext_search_template % (column_name, fulltext_query) + + # adjust the weight to 0~1 + weight_sum = sum(fulltext_search_weight.values()) + n = len(fulltext_search_weight) + if weight_sum <= 0 < n: + # All weights are 0 (e.g. 
"col^0"); use equal weights to avoid ZeroDivisionError + for column_name in fulltext_search_weight: + fulltext_search_weight[column_name] = 1.0 / n + else: + for column_name in fulltext_search_weight: + fulltext_search_weight[column_name] = fulltext_search_weight[column_name] / weight_sum + + return fulltext_search_expr, fulltext_search_weight + + def _build_vector_search_sql( + self, + table_name: str, + fields_expr: str, + vector_search_score_expr: str, + filters_expr: str, + vector_search_filter: str, + vector_search_expr: str, + limit: int, + vector_topn: int, + offset: int = 0 + ) -> str: + sql = ( + f"SELECT {fields_expr}, {vector_search_score_expr} AS _score" + f" FROM {table_name}" + f" WHERE {filters_expr} AND {vector_search_filter}" + f" ORDER BY {vector_search_expr}" + f" APPROXIMATE LIMIT {limit if limit != 0 else vector_topn}" + ) + if offset != 0: + sql += f" OFFSET {offset}" + return sql + + def _build_fulltext_search_sql( + self, + table_name: str, + fields_expr: str, + fulltext_search_score_expr: str, + filters_expr: str, + fulltext_search_filter: str, + offset: int, + limit: int, + fulltext_topn: int, + hint: str = "" + ) -> str: + hint_expr = f"{hint} " if hint else "" + return ( + f"SELECT {hint_expr}{fields_expr}, {fulltext_search_score_expr} AS _score" + f" FROM {table_name}" + f" WHERE {filters_expr} AND {fulltext_search_filter}" + f" ORDER BY _score DESC" + f" LIMIT {offset}, {limit if limit != 0 else fulltext_topn}" + ) + + def _build_filter_search_sql( + self, + table_name: str, + fields_expr: str, + filters_expr: str, + order_by_expr: str = "", + limit_expr: str = "" + ) -> str: + return ( + f"SELECT {fields_expr}" + f" FROM {table_name}" + f" WHERE {filters_expr}" + f" {order_by_expr} {limit_expr}" + ) + + def _build_count_sql( + self, + table_name: str, + filters_expr: str, + extra_filter: str = "", + hint: str = "" + ) -> str: + hint_expr = f"{hint} " if hint else "" + where_clause = f"{filters_expr} AND {extra_filter}" if extra_filter else filters_expr + return f"SELECT {hint_expr}COUNT(id) FROM {table_name} WHERE {where_clause}" + + def _row_to_entity(self, data, fields: list[str]) -> dict: + entity = {} + for i, field in enumerate(fields): + value = data[i] + if value is None: + continue + entity[field] = value + return entity + + def _get_dataset_id_field(self) -> str: + return "kb_id" + + def _get_filters(self, condition: dict) -> list[str]: + filters: list[str] = [] + for k, v in condition.items(): + if not v: + continue + if k == "exists": + filters.append(f"{v} IS NOT NULL") + elif k == "must_not" and isinstance(v, dict) and "exists" in v: + filters.append(f"{v.get('exists')} IS NULL") + elif isinstance(v, list): + values: list[str] = [] + for item in v: + values.append(get_value_str(item)) + value = ", ".join(values) + filters.append(f"{k} IN ({value})") + else: + filters.append(f"{k} = {get_value_str(v)}") + return filters + + def get(self, doc_id: str, index_name: str, dataset_ids: list[str]) -> dict | None: + if not self._check_table_exists_cached(index_name): + return None + try: + res = self.client.get( + table_name=index_name, + ids=[doc_id], + ) + row = res.fetchone() + if row is None: + return None + return self._row_to_entity(row, fields=list(res.keys())) + except Exception as e: + self.logger.exception(f"OBConnectionBase.get({doc_id}) got exception") + raise e + + def delete(self, condition: dict, index_name: str, dataset_id: str) -> int: + if not self._check_table_exists_cached(index_name): + return 0 + # For doc_meta tables, don't add 
dataset_id to condition + if not index_name.startswith("ragflow_doc_meta_"): + condition[self._get_dataset_id_field()] = dataset_id + try: + from sqlalchemy import text + res = self.client.get( + table_name=index_name, + ids=None, + where_clause=[text(f) for f in self._get_filters(condition)], + output_column_name=["id"], + ) + rows = res.fetchall() + if len(rows) == 0: + return 0 + ids = [row[0] for row in rows] + self.logger.debug(f"OBConnection.delete, filters: {condition}, ids: {ids}") + self.client.delete( + table_name=index_name, + ids=ids, + ) + return len(ids) + except Exception as e: + self.logger.error(f"OBConnection.delete error: {str(e)}") + return 0 + + """ + Abstract CRUD methods that must be implemented by subclasses + """ + + @abstractmethod + def search( + self, + select_fields: list[str], + highlight_fields: list[str], + condition: dict, + match_expressions: list[MatchExpr], + order_by: OrderByExpr, + offset: int, + limit: int, + index_names: str | list[str], + knowledgebase_ids: list[str], + agg_fields: list[str] | None = None, + rank_feature: dict | None = None, + **kwargs, + ): + raise NotImplementedError("Not implemented") + + @abstractmethod + def insert(self, documents: list[dict], index_name: str, dataset_id: str = None) -> list[str]: + raise NotImplementedError("Not implemented") + + @abstractmethod + def update(self, condition: dict, new_value: dict, index_name: str, dataset_id: str) -> bool: + raise NotImplementedError("Not implemented") + + """ + Helper functions for search result - abstract methods + """ + + @abstractmethod + def get_total(self, res) -> int: + raise NotImplementedError("Not implemented") + + @abstractmethod + def get_doc_ids(self, res) -> list[str]: + raise NotImplementedError("Not implemented") + + @abstractmethod + def get_fields(self, res, fields: list[str]) -> dict[str, dict]: + raise NotImplementedError("Not implemented") + + @abstractmethod + def get_highlight(self, res, keywords: list[str], field_name: str): + raise NotImplementedError("Not implemented") + + @abstractmethod + def get_aggregation(self, res, field_name: str): + raise NotImplementedError("Not implemented") + + """ + SQL - can be overridden by subclasses + """ + + def sql(self, sql: str, fetch_size: int, format: str): + """Execute SQL query - default implementation.""" + return None diff --git a/common/doc_store/ob_conn_pool.py b/common/doc_store/ob_conn_pool.py new file mode 100644 index 000000000..5cb995edb --- /dev/null +++ b/common/doc_store/ob_conn_pool.py @@ -0,0 +1,191 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import logging +import os +import time + +from pyobvector import ObVecClient +from pyobvector.client import ClusterVersionException +from pyobvector.client.hybrid_search import HybridSearch +from pyobvector.util import ObVersion + +from common import settings +from common.decorator import singleton + +ATTEMPT_TIME = 2 +OB_QUERY_TIMEOUT = int(os.environ.get("OB_QUERY_TIMEOUT", "100_000_000")) + +logger = logging.getLogger('ragflow.ob_conn_pool') + + +@singleton +class OceanBaseConnectionPool: + + def __init__(self): + self.client = None + self.es = None # HybridSearch client + + if hasattr(settings, "OB"): + self.OB_CONFIG = settings.OB + else: + self.OB_CONFIG = settings.get_base_config("oceanbase", {}) + + scheme = self.OB_CONFIG.get("scheme") + ob_config = self.OB_CONFIG.get("config", {}) + + if scheme and scheme.lower() == "mysql": + mysql_config = settings.get_base_config("mysql", {}) + logger.info("Use MySQL scheme to create OceanBase connection.") + host = mysql_config.get("host", "localhost") + port = mysql_config.get("port", 2881) + self.username = mysql_config.get("user", "root@test") + self.password = mysql_config.get("password", "infini_rag_flow") + max_connections = mysql_config.get("max_connections", 300) + else: + logger.info("Use customized config to create OceanBase connection.") + host = ob_config.get("host", "localhost") + port = ob_config.get("port", 2881) + self.username = ob_config.get("user", "root@test") + self.password = ob_config.get("password", "infini_rag_flow") + max_connections = ob_config.get("max_connections", 300) + + self.db_name = ob_config.get("db_name", "test") + self.uri = f"{host}:{port}" + + logger.info(f"Use OceanBase '{self.uri}' as the doc engine.") + + max_overflow = int(os.environ.get("OB_MAX_OVERFLOW", max(max_connections // 2, 10))) + pool_timeout = int(os.environ.get("OB_POOL_TIMEOUT", "30")) + + for _ in range(ATTEMPT_TIME): + try: + self.client = ObVecClient( + uri=self.uri, + user=self.username, + password=self.password, + db_name=self.db_name, + pool_pre_ping=True, + pool_recycle=3600, + pool_size=max_connections, + max_overflow=max_overflow, + pool_timeout=pool_timeout, + ) + break + except Exception as e: + logger.warning(f"{str(e)}. Waiting OceanBase {self.uri} to be healthy.") + time.sleep(5) + + if self.client is None: + msg = f"OceanBase {self.uri} connection failed after {ATTEMPT_TIME} attempts." 
+ logger.error(msg) + raise Exception(msg) + + self._check_ob_version() + self._try_to_update_ob_query_timeout() + self._init_hybrid_search(max_connections, max_overflow, pool_timeout) + + logger.info(f"OceanBase {self.uri} is healthy.") + + def _check_ob_version(self): + try: + res = self.client.perform_raw_text_sql("SELECT OB_VERSION() FROM DUAL").fetchone() + version_str = res[0] if res else None + logger.info(f"OceanBase {self.uri} version is {version_str}") + except Exception as e: + raise Exception(f"Failed to get OceanBase version from {self.uri}, error: {str(e)}") + + if not version_str: + raise Exception(f"Failed to get OceanBase version from {self.uri}.") + + ob_version = ObVersion.from_db_version_string(version_str) + if ob_version < ObVersion.from_db_version_nums(4, 3, 5, 1): + raise Exception( + f"The version of OceanBase needs to be higher than or equal to 4.3.5.1, current version is {version_str}" + ) + + def _try_to_update_ob_query_timeout(self): + try: + rows = self.client.perform_raw_text_sql("SHOW VARIABLES LIKE 'ob_query_timeout'") + for row in rows: + val = row[1] + if val and int(val) >= OB_QUERY_TIMEOUT: + return + except Exception as e: + logger.warning("Failed to get 'ob_query_timeout' variable: %s", str(e)) + + try: + self.client.perform_raw_text_sql(f"SET GLOBAL ob_query_timeout={OB_QUERY_TIMEOUT}") + logger.info("Set GLOBAL variable 'ob_query_timeout' to %d.", OB_QUERY_TIMEOUT) + self.client.engine.dispose() + logger.info("Disposed all connections in engine pool to refresh connection pool") + except Exception as e: + logger.warning(f"Failed to set 'ob_query_timeout' variable: {str(e)}") + + def _init_hybrid_search(self, max_connections, max_overflow, pool_timeout): + enable_hybrid_search = os.getenv('ENABLE_HYBRID_SEARCH', 'false').lower() in ['true', '1', 'yes', 'y'] + if enable_hybrid_search: + try: + self.es = HybridSearch( + uri=self.uri, + user=self.username, + password=self.password, + db_name=self.db_name, + pool_pre_ping=True, + pool_recycle=3600, + pool_size=max_connections, + max_overflow=max_overflow, + pool_timeout=pool_timeout, + ) + logger.info("OceanBase Hybrid Search feature is enabled") + except ClusterVersionException as e: + logger.info("Failed to initialize HybridSearch client, fallback to use SQL", exc_info=e) + self.es = None + + def get_client(self) -> ObVecClient: + return self.client + + def get_hybrid_search_client(self) -> HybridSearch | None: + return self.es + + def get_db_name(self) -> str: + return self.db_name + + def get_uri(self) -> str: + return self.uri + + def refresh_client(self) -> ObVecClient: + try: + self.client.perform_raw_text_sql("SELECT 1 FROM DUAL") + return self.client + except Exception as e: + logger.warning(f"OceanBase connection unhealthy: {str(e)}, refreshing...") + self.client.engine.dispose() + return self.client + + def __del__(self): + if hasattr(self, "client") and self.client: + try: + self.client.engine.dispose() + except Exception: + pass + if hasattr(self, "es") and self.es: + try: + self.es.engine.dispose() + except Exception: + pass + + +OB_CONN = OceanBaseConnectionPool() diff --git a/common/settings.py b/common/settings.py index 221e4a909..97be3c521 100644 --- a/common/settings.py +++ b/common/settings.py @@ -41,6 +41,7 @@ from rag.nlp import search import memory.utils.es_conn as memory_es_conn import memory.utils.infinity_conn as memory_infinity_conn +import memory.utils.ob_conn as memory_ob_conn LLM = None LLM_FACTORY = None @@ -281,6 +282,8 @@ def init_settings(): "db_name": "default_db" }) 
msgStoreConn = memory_infinity_conn.InfinityConnection() + elif lower_case_doc_engine in ["oceanbase", "seekdb"]: + msgStoreConn = memory_ob_conn.OBConnection() global AZURE, S3, MINIO, OSS, GCS if STORAGE_IMPL_TYPE in ['AZURE_SPN', 'AZURE_SAS']: diff --git a/memory/utils/ob_conn.py b/memory/utils/ob_conn.py new file mode 100644 index 000000000..bf8ac4005 --- /dev/null +++ b/memory/utils/ob_conn.py @@ -0,0 +1,613 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import re +from typing import Optional + +import numpy as np +from pydantic import BaseModel +from pymysql.converters import escape_string +from sqlalchemy import Column, String, Integer +from sqlalchemy.dialects.mysql import LONGTEXT + +from common.decorator import singleton +from common.doc_store.doc_store_base import MatchExpr, OrderByExpr, FusionExpr, MatchTextExpr, MatchDenseExpr +from common.doc_store.ob_conn_base import OBConnectionBase, get_value_str, vector_search_template +from common.float_utils import get_float +from rag.nlp.rag_tokenizer import tokenize, fine_grained_tokenize + +# Column definitions for memory message table +COLUMN_DEFINITIONS: list[Column] = [ + Column("id", String(256), primary_key=True, comment="unique record id"), + Column("message_id", String(256), nullable=False, index=True, comment="message id"), + Column("message_type_kwd", String(64), nullable=True, comment="message type"), + Column("source_id", String(256), nullable=True, comment="source message id"), + Column("memory_id", String(256), nullable=False, index=True, comment="memory id"), + Column("user_id", String(256), nullable=True, comment="user id"), + Column("agent_id", String(256), nullable=True, comment="agent id"), + Column("session_id", String(256), nullable=True, comment="session id"), + Column("zone_id", Integer, nullable=True, server_default="0", comment="zone id"), + Column("valid_at", String(64), nullable=True, comment="valid at timestamp string"), + Column("invalid_at", String(64), nullable=True, comment="invalid at timestamp string"), + Column("forget_at", String(64), nullable=True, comment="forget at timestamp string"), + Column("status_int", Integer, nullable=False, server_default="1", comment="status: 1 for active, 0 for inactive"), + Column("content_ltks", LONGTEXT, nullable=True, comment="content with tokenization"), + Column("tokenized_content_ltks", LONGTEXT, nullable=True, comment="fine-grained tokenized content"), +] + +COLUMN_NAMES: list[str] = [col.name for col in COLUMN_DEFINITIONS] + +# Index columns for creating indexes +INDEX_COLUMNS: list[str] = [ + "message_id", + "memory_id", + "status_int", +] + +# Full-text search columns +FTS_COLUMNS: list[str] = [ + "content_ltks", + "tokenized_content_ltks", +] + + +class SearchResult(BaseModel): + total: int + messages: list[dict] + + +@singleton +class OBConnection(OBConnectionBase): + def __init__(self): + super().__init__(logger_name='ragflow.memory_ob_conn') + self._fulltext_search_columns = 
FTS_COLUMNS + + """ + Template method implementations + """ + + def get_index_columns(self) -> list[str]: + return INDEX_COLUMNS + + def get_fulltext_columns(self) -> list[str]: + """Return list of column names that need fulltext indexes (without weight suffix).""" + return [col.split("^")[0] for col in self._fulltext_search_columns] + + def get_column_definitions(self) -> list[Column]: + return COLUMN_DEFINITIONS + + def get_lock_prefix(self) -> str: + return "ob_memory_" + + def _get_dataset_id_field(self) -> str: + return "memory_id" + + def _get_vector_column_name_from_table(self, table_name: str) -> Optional[str]: + """Get the vector column name from the table (q_{size}_vec pattern).""" + sql = f""" + SELECT COLUMN_NAME + FROM INFORMATION_SCHEMA.COLUMNS + WHERE TABLE_SCHEMA = '{self.db_name}' + AND TABLE_NAME = '{table_name}' + AND COLUMN_NAME REGEXP '^q_[0-9]+_vec$' + LIMIT 1 + """ + try: + res = self.client.perform_raw_text_sql(sql) + row = res.fetchone() + return row[0] if row else None + except Exception: + return None + + """ + Field conversion methods + """ + + @staticmethod + def convert_field_name(field_name: str, use_tokenized_content=False) -> str: + """Convert message field name to database column name.""" + match field_name: + case "message_type": + return "message_type_kwd" + case "status": + return "status_int" + case "content": + if use_tokenized_content: + return "tokenized_content_ltks" + return "content_ltks" + case _: + return field_name + + @staticmethod + def map_message_to_ob_fields(message: dict) -> dict: + """Map message dictionary fields to OceanBase document fields.""" + storage_doc = { + "id": message.get("id"), + "message_id": message["message_id"], + "message_type_kwd": message["message_type"], + "source_id": message.get("source_id"), + "memory_id": message["memory_id"], + "user_id": message.get("user_id", ""), + "agent_id": message["agent_id"], + "session_id": message["session_id"], + "valid_at": message["valid_at"], + "invalid_at": message.get("invalid_at"), + "forget_at": message.get("forget_at"), + "status_int": 1 if message["status"] else 0, + "zone_id": message.get("zone_id", 0), + "content_ltks": message["content"], + "tokenized_content_ltks": fine_grained_tokenize(tokenize(message["content"])), + } + # Handle vector embedding + content_embed = message.get("content_embed", []) + if len(content_embed) > 0: + storage_doc[f"q_{len(content_embed)}_vec"] = content_embed + return storage_doc + + @staticmethod + def get_message_from_ob_doc(doc: dict) -> dict: + """Convert an OceanBase document back to a message dictionary.""" + embd_field_name = next((key for key in doc.keys() if re.match(r"q_\d+_vec", key)), None) + content_embed = doc.get(embd_field_name, []) if embd_field_name else [] + if isinstance(content_embed, np.ndarray): + content_embed = content_embed.tolist() + message = { + "message_id": doc.get("message_id"), + "message_type": doc.get("message_type_kwd"), + "source_id": doc.get("source_id") if doc.get("source_id") else None, + "memory_id": doc.get("memory_id"), + "user_id": doc.get("user_id", ""), + "agent_id": doc.get("agent_id"), + "session_id": doc.get("session_id"), + "zone_id": doc.get("zone_id", 0), + "valid_at": doc.get("valid_at"), + "invalid_at": doc.get("invalid_at", "-"), + "forget_at": doc.get("forget_at", "-"), + "status": bool(int(doc.get("status_int", 0))), + "content": doc.get("content_ltks", ""), + "content_embed": content_embed, + } + if doc.get("id"): + message["id"] = doc["id"] + return message + + """ + CRUD operations + 
""" + + def search( + self, + select_fields: list[str], + highlight_fields: list[str], + condition: dict, + match_expressions: list[MatchExpr], + order_by: OrderByExpr, + offset: int, + limit: int, + index_names: str | list[str], + memory_ids: list[str], + agg_fields: list[str] | None = None, + rank_feature: dict | None = None, + hide_forgotten: bool = True + ): + """Search messages in memory storage.""" + if isinstance(index_names, str): + index_names = index_names.split(",") + assert isinstance(index_names, list) and len(index_names) > 0 + + result: SearchResult = SearchResult(total=0, messages=[]) + + output_fields = select_fields.copy() + if "id" not in output_fields: + output_fields = ["id"] + output_fields + if "_score" in output_fields: + output_fields.remove("_score") + + # Handle content_embed field - resolve to actual vector column name + has_content_embed = "content_embed" in output_fields + actual_vector_column: Optional[str] = None + if has_content_embed: + output_fields = [f for f in output_fields if f != "content_embed"] + # Try to get vector column name from first available table + for idx_name in index_names: + if self._check_table_exists_cached(idx_name): + actual_vector_column = self._get_vector_column_name_from_table(idx_name) + if actual_vector_column: + output_fields.append(actual_vector_column) + break + + if highlight_fields: + for field in highlight_fields: + field_name = self.convert_field_name(field) + if field_name not in output_fields: + output_fields.append(field_name) + + db_output_fields = [self.convert_field_name(f) for f in output_fields] + fields_expr = ", ".join(db_output_fields) + + condition["memory_id"] = memory_ids + if hide_forgotten: + condition["must_not"] = {"exists": "forget_at"} + + condition_dict = {self.convert_field_name(k): v for k, v in condition.items()} + filters: list[str] = self._get_filters(condition_dict) + filters_expr = " AND ".join(filters) if filters else "1=1" + + # Parse match expressions + fulltext_query: Optional[str] = None + fulltext_topn: Optional[int] = None + fulltext_search_expr: dict[str, str] = {} + fulltext_search_weight: dict[str, float] = {} + fulltext_search_filter: Optional[str] = None + fulltext_search_score_expr: Optional[str] = None + + vector_column_name: Optional[str] = None + vector_data: Optional[list[float]] = None + vector_topn: Optional[int] = None + vector_similarity_threshold: Optional[float] = None + vector_similarity_weight: Optional[float] = None + vector_search_expr: Optional[str] = None + vector_search_score_expr: Optional[str] = None + vector_search_filter: Optional[str] = None + + for m in match_expressions: + if isinstance(m, MatchTextExpr): + assert "original_query" in m.extra_options, "'original_query' is missing in extra_options." 
+ fulltext_query = m.extra_options["original_query"] + fulltext_query = escape_string(fulltext_query.strip()) + fulltext_topn = m.topn + + fulltext_search_expr, fulltext_search_weight = self._parse_fulltext_columns( + fulltext_query, self._fulltext_search_columns + ) + elif isinstance(m, MatchDenseExpr): + vector_column_name = m.vector_column_name + vector_data = m.embedding_data + vector_topn = m.topn + vector_similarity_threshold = m.extra_options.get("similarity", 0.0) if m.extra_options else 0.0 + elif isinstance(m, FusionExpr): + weights = m.fusion_params.get("weights", "0.5,0.5") if m.fusion_params else "0.5,0.5" + vector_similarity_weight = get_float(weights.split(",")[1]) + + if fulltext_query: + fulltext_search_filter = f"({' OR '.join([expr for expr in fulltext_search_expr.values()])})" + fulltext_search_score_expr = f"({' + '.join(f'{expr} * {fulltext_search_weight.get(col, 0)}' for col, expr in fulltext_search_expr.items())})" + + if vector_data: + vector_data_str = "[" + ",".join([str(np.float32(v)) for v in vector_data]) + "]" + vector_search_expr = vector_search_template % (vector_column_name, vector_data_str) + vector_search_score_expr = f"(1 - {vector_search_expr})" + vector_search_filter = f"{vector_search_score_expr} >= {vector_similarity_threshold}" + + # Determine search type + if fulltext_query and vector_data: + search_type = "fusion" + elif fulltext_query: + search_type = "fulltext" + elif vector_data: + search_type = "vector" + else: + search_type = "filter" + + if search_type in ["fusion", "fulltext", "vector"] and "_score" not in output_fields: + output_fields.append("_score") + + if limit: + if vector_topn is not None: + limit = min(vector_topn, limit) + if fulltext_topn is not None: + limit = min(fulltext_topn, limit) + + for index_name in index_names: + table_name = index_name + + if not self._check_table_exists_cached(table_name): + continue + + if search_type == "fusion": + num_candidates = (vector_topn or limit) + (fulltext_topn or limit) + score_expr = f"(relevance * {1 - vector_similarity_weight} + {vector_search_score_expr} * {vector_similarity_weight})" + fusion_sql = ( + f"WITH fulltext_results AS (" + f" SELECT *, {fulltext_search_score_expr} AS relevance" + f" FROM {table_name}" + f" WHERE {filters_expr} AND {fulltext_search_filter}" + f" ORDER BY relevance DESC" + f" LIMIT {num_candidates}" + f")" + f" SELECT {fields_expr}, {score_expr} AS _score" + f" FROM fulltext_results" + f" WHERE {vector_search_filter}" + f" ORDER BY _score DESC" + f" LIMIT {offset}, {limit}" + ) + self.logger.debug("OBConnection.search with fusion sql: %s", fusion_sql) + rows, elapsed_time = self._execute_search_sql(fusion_sql) + self.logger.info( + f"OBConnection.search table {table_name}, search type: fusion, elapsed time: {elapsed_time:.3f}s, rows: {len(rows)}" + ) + + for row in rows: + result.messages.append(self._row_to_entity(row, db_output_fields + ["_score"])) + result.total += 1 + + elif search_type == "vector": + vector_sql = self._build_vector_search_sql( + table_name, fields_expr, vector_search_score_expr, filters_expr, + vector_search_filter, vector_search_expr, limit, vector_topn, offset + ) + self.logger.debug("OBConnection.search with vector sql: %s", vector_sql) + rows, elapsed_time = self._execute_search_sql(vector_sql) + self.logger.info( + f"OBConnection.search table {table_name}, search type: vector, elapsed time: {elapsed_time:.3f}s, rows: {len(rows)}" + ) + + for row in rows: + result.messages.append(self._row_to_entity(row, db_output_fields + 
["_score"])) + result.total += 1 + + elif search_type == "fulltext": + fulltext_sql = self._build_fulltext_search_sql( + table_name, fields_expr, fulltext_search_score_expr, filters_expr, + fulltext_search_filter, offset, limit, fulltext_topn + ) + self.logger.debug("OBConnection.search with fulltext sql: %s", fulltext_sql) + rows, elapsed_time = self._execute_search_sql(fulltext_sql) + self.logger.info( + f"OBConnection.search table {table_name}, search type: fulltext, elapsed time: {elapsed_time:.3f}s, rows: {len(rows)}" + ) + + for row in rows: + result.messages.append(self._row_to_entity(row, db_output_fields + ["_score"])) + result.total += 1 + + else: + orders: list[str] = [] + if order_by and order_by.fields: + for field, order_dir in order_by.fields: + field_name = self.convert_field_name(field) + order_str = "ASC" if order_dir == 0 else "DESC" + orders.append(f"{field_name} {order_str}") + + order_by_expr = ("ORDER BY " + ", ".join(orders)) if orders else "" + limit_expr = f"LIMIT {offset}, {limit}" if limit != 0 else "" + filter_sql = self._build_filter_search_sql( + table_name, fields_expr, filters_expr, order_by_expr, limit_expr + ) + self.logger.debug("OBConnection.search with filter sql: %s", filter_sql) + rows, elapsed_time = self._execute_search_sql(filter_sql) + self.logger.info( + f"OBConnection.search table {table_name}, search type: filter, elapsed time: {elapsed_time:.3f}s, rows: {len(rows)}" + ) + + for row in rows: + result.messages.append(self._row_to_entity(row, db_output_fields)) + result.total += 1 + + if result.total == 0: + result.total = len(result.messages) + + return result, result.total + + def get_forgotten_messages(self, select_fields: list[str], index_name: str, memory_id: str, limit: int = 512): + """Get forgotten messages (messages with forget_at set).""" + if not self._check_table_exists_cached(index_name): + return None + + db_output_fields = [self.convert_field_name(f) for f in select_fields] + fields_expr = ", ".join(db_output_fields) + + sql = ( + f"SELECT {fields_expr}" + f" FROM {index_name}" + f" WHERE memory_id = {get_value_str(memory_id)} AND forget_at IS NOT NULL" + f" ORDER BY forget_at ASC" + f" LIMIT {limit}" + ) + self.logger.debug("OBConnection.get_forgotten_messages sql: %s", sql) + + res = self.client.perform_raw_text_sql(sql) + rows = res.fetchall() + + result = SearchResult(total=len(rows), messages=[]) + for row in rows: + result.messages.append(self._row_to_entity(row, db_output_fields)) + + return result + + def get_missing_field_message(self, select_fields: list[str], index_name: str, memory_id: str, field_name: str, + limit: int = 512): + """Get messages missing a specific field.""" + if not self._check_table_exists_cached(index_name): + return None + + db_field_name = self.convert_field_name(field_name) + db_output_fields = [self.convert_field_name(f) for f in select_fields] + fields_expr = ", ".join(db_output_fields) + + sql = ( + f"SELECT {fields_expr}" + f" FROM {index_name}" + f" WHERE memory_id = {get_value_str(memory_id)} AND {db_field_name} IS NULL" + f" ORDER BY valid_at ASC" + f" LIMIT {limit}" + ) + self.logger.debug("OBConnection.get_missing_field_message sql: %s", sql) + + res = self.client.perform_raw_text_sql(sql) + rows = res.fetchall() + + result = SearchResult(total=len(rows), messages=[]) + for row in rows: + result.messages.append(self._row_to_entity(row, db_output_fields)) + + return result + + def get(self, doc_id: str, index_name: str, memory_ids: list[str]) -> dict | None: + """Get single message by 
id.""" + doc = super().get(doc_id, index_name, memory_ids) + if doc is None: + return None + return self.get_message_from_ob_doc(doc) + + def insert(self, documents: list[dict], index_name: str, memory_id: str = None) -> list[str]: + """Insert messages into memory storage.""" + if not documents: + return [] + + vector_size = len(documents[0].get("content_embed", [])) if "content_embed" in documents[0] else 0 + + if not self._check_table_exists_cached(index_name): + if vector_size == 0: + raise ValueError("Cannot infer vector size from documents") + self.create_idx(index_name, memory_id, vector_size) + elif vector_size > 0: + # Table exists but may not have the required vector column + self._ensure_vector_column_exists(index_name, vector_size) + + docs: list[dict] = [] + ids: list[str] = [] + + for document in documents: + d = self.map_message_to_ob_fields(document) + ids.append(d["id"]) + + for column_name in COLUMN_NAMES: + if column_name not in d: + d[column_name] = None + + docs.append(d) + + self.logger.debug("OBConnection.insert messages: %s", ids) + + res = [] + try: + self.client.upsert(index_name, docs) + except Exception as e: + self.logger.error(f"OBConnection.insert error: {str(e)}") + res.append(str(e)) + return res + + def update(self, condition: dict, new_value: dict, index_name: str, memory_id: str) -> bool: + """Update messages with given condition.""" + if not self._check_table_exists_cached(index_name): + return True + + condition["memory_id"] = memory_id + condition_dict = {self.convert_field_name(k): v for k, v in condition.items()} + filters = self._get_filters(condition_dict) + + update_dict = {self.convert_field_name(k): v for k, v in new_value.items()} + if "content_ltks" in update_dict: + update_dict["tokenized_content_ltks"] = fine_grained_tokenize(tokenize(update_dict["content_ltks"])) + update_dict.pop("id", None) + + set_values: list[str] = [] + for k, v in update_dict.items(): + if k == "remove": + if isinstance(v, str): + set_values.append(f"{v} = NULL") + elif k == "status": + set_values.append(f"status_int = {1 if v else 0}") + else: + set_values.append(f"{k} = {get_value_str(v)}") + + if not set_values: + return True + + update_sql = ( + f"UPDATE {index_name}" + f" SET {', '.join(set_values)}" + f" WHERE {' AND '.join(filters)}" + ) + self.logger.debug("OBConnection.update sql: %s", update_sql) + + try: + self.client.perform_raw_text_sql(update_sql) + return True + except Exception as e: + self.logger.error(f"OBConnection.update error: {str(e)}") + return False + + def delete(self, condition: dict, index_name: str, memory_id: str) -> int: + """Delete messages with given condition.""" + condition_dict = {self.convert_field_name(k): v for k, v in condition.items()} + return super().delete(condition_dict, index_name, memory_id) + + """ + Helper functions for search result + """ + + def get_total(self, res) -> int: + if isinstance(res, tuple): + return res[1] + if hasattr(res, 'total'): + return res.total + return 0 + + def get_doc_ids(self, res) -> list[str]: + if isinstance(res, tuple): + res = res[0] + if hasattr(res, 'messages'): + return [row.get("id") for row in res.messages if row.get("id")] + return [] + + def get_fields(self, res, fields: list[str]) -> dict[str, dict]: + """Get fields from search result.""" + if isinstance(res, tuple): + res = res[0] + + res_fields = {} + if not fields: + return {} + + messages = res.messages if hasattr(res, 'messages') else [] + + for doc in messages: + message = self.get_message_from_ob_doc(doc) + m = {} + for n, v 
in message.items(): + if n not in fields: + continue + if isinstance(v, list): + m[n] = v + continue + if n in ["message_id", "source_id", "valid_at", "invalid_at", "forget_at", "status"] and isinstance(v, + (int, + float, + bool)): + m[n] = v + continue + if not isinstance(v, str): + m[n] = str(v) if v is not None else "" + else: + m[n] = v + + doc_id = doc.get("id") or message.get("id") + if m and doc_id: + res_fields[doc_id] = m + + return res_fields + + def get_highlight(self, res, keywords: list[str], field_name: str): + """Get highlighted text for search results.""" + # TODO: Implement highlight functionality for OceanBase memory + return {} + + def get_aggregation(self, res, field_name: str): + """Get aggregation for search results.""" + # TODO: Implement aggregation functionality for OceanBase memory + return [] diff --git a/rag/utils/ob_conn.py b/rag/utils/ob_conn.py index 08187c78c..e20f8993e 100644 --- a/rag/utils/ob_conn.py +++ b/rag/utils/ob_conn.py @@ -15,9 +15,7 @@ # import json import logging -import os import re -import threading import time from typing import Any, Optional @@ -25,25 +23,22 @@ import numpy as np from elasticsearch_dsl import Q, Search from pydantic import BaseModel from pymysql.converters import escape_string -from pyobvector import ObVecClient, FtsIndexParam, FtsParser, ARRAY, VECTOR -from pyobvector.client import ClusterVersionException -from pyobvector.client.hybrid_search import HybridSearch -from pyobvector.util import ObVersion -from sqlalchemy import text, Column, String, Integer, JSON, Double, Row, Table +from pyobvector import ARRAY +from sqlalchemy import Column, String, Integer, JSON, Double, Row from sqlalchemy.dialects.mysql import LONGTEXT, TEXT from sqlalchemy.sql.type_api import TypeEngine -from common import settings from common.constants import PAGERANK_FLD, TAG_FLD from common.decorator import singleton +from common.doc_store.doc_store_base import MatchExpr, OrderByExpr, FusionExpr, MatchTextExpr, MatchDenseExpr +from common.doc_store.ob_conn_base import ( + OBConnectionBase, get_value_str, + vector_search_template, vector_column_pattern, + fulltext_index_name_template, +) from common.float_utils import get_float -from common.doc_store.doc_store_base import DocStoreConnection, MatchExpr, OrderByExpr, FusionExpr, MatchTextExpr, \ - MatchDenseExpr from rag.nlp import rag_tokenizer -ATTEMPT_TIME = 2 -OB_QUERY_TIMEOUT = int(os.environ.get("OB_QUERY_TIMEOUT", "100_000_000")) - logger = logging.getLogger('ragflow.ob_conn') column_order_id = Column("_order_id", Integer, nullable=True, comment="chunk order id for maintaining sequence") @@ -102,9 +97,8 @@ column_names: list[str] = [col.name for col in column_definitions] column_types: dict[str, TypeEngine] = {col.name: col.type for col in column_definitions} array_columns: list[str] = [col.name for col in column_definitions if isinstance(col.type, ARRAY)] -vector_column_pattern = re.compile(r"q_(?P\d+)_vec") - -index_columns: list[str] = [ +# Index columns for RAG chunk table +INDEX_COLUMNS: list[str] = [ "kb_id", "doc_id", "available_int", @@ -113,14 +107,16 @@ index_columns: list[str] = [ "removed_kwd", ] -fts_columns_origin: list[str] = [ +# Full-text search columns (with weight) - original content +FTS_COLUMNS_ORIGIN: list[str] = [ "docnm_kwd^10", "content_with_weight", "important_tks^20", "question_tks^20", ] -fts_columns_tks: list[str] = [ +# Full-text search columns (with weight) - tokenized content +FTS_COLUMNS_TKS: list[str] = [ "title_tks^10", "title_sm_tks^5", "important_tks^20", @@ 
-129,12 +125,8 @@ fts_columns_tks: list[str] = [ "content_sm_ltks", ] -index_name_template = "ix_%s_%s" -fulltext_index_name_template = "fts_idx_%s" -# MATCH AGAINST: https://www.oceanbase.com/docs/common-oceanbase-database-cn-1000000002017607 -fulltext_search_template = "MATCH (%s) AGAINST ('%s' IN NATURAL LANGUAGE MODE)" -# cosine_distance: https://www.oceanbase.com/docs/common-oceanbase-database-cn-1000000002012938 -vector_search_template = "cosine_distance(%s, '%s')" +# Extra columns to add after table creation (for migration) +EXTRA_COLUMNS: list[Column] = [column_order_id, column_group_id, column_mom_id] class SearchResult(BaseModel): @@ -186,24 +178,6 @@ def get_default_value(column_name: str) -> Any: return None -def get_value_str(value: Any) -> str: - if isinstance(value, str): - cleaned_str = value.replace('\\', '\\\\') - cleaned_str = cleaned_str.replace('\n', '\\n') - cleaned_str = cleaned_str.replace('\r', '\\r') - cleaned_str = cleaned_str.replace('\t', '\\t') - return f"'{escape_string(cleaned_str)}'" - elif isinstance(value, bool): - return "true" if value else "false" - elif value is None: - return "NULL" - elif isinstance(value, (list, dict)): - json_str = json.dumps(value, ensure_ascii=False) - return f"'{escape_string(json_str)}'" - else: - return str(value) - - def get_metadata_filter_expression(metadata_filtering_conditions: dict) -> str: """ Convert metadata filtering conditions to MySQL JSON path expression. @@ -319,225 +293,50 @@ def get_filters(condition: dict) -> list[str]: return filters -def _try_with_lock(lock_name: str, process_func, check_func, timeout: int = None): - if not timeout: - timeout = int(os.environ.get("OB_DDL_TIMEOUT", "60")) - - if not check_func(): - from rag.utils.redis_conn import RedisDistributedLock - lock = RedisDistributedLock(lock_name) - if lock.acquire(): - logger.info(f"acquired lock success: {lock_name}, start processing.") - try: - process_func() - return - except Exception as e: - if "Duplicate" in str(e): - # In some cases, the schema may change after the lock is acquired, so if the error message - # indicates that the column or index is duplicated, it should be assumed that 'process_func' - # has been executed correctly. 
- logger.warning(f"Skip processing {lock_name} due to duplication: {str(e)}") - return - raise - finally: - lock.release() - - if not check_func(): - logger.info(f"Waiting for process complete for {lock_name} on other task executors.") - time.sleep(1) - count = 1 - while count < timeout and not check_func(): - count += 1 - time.sleep(1) - if count >= timeout and not check_func(): - raise Exception(f"Timeout to wait for process complete for {lock_name}.") - - @singleton -class OBConnection(DocStoreConnection): +class OBConnection(OBConnectionBase): def __init__(self): - scheme: str = settings.OB.get("scheme") - ob_config = settings.OB.get("config", {}) - - if scheme and scheme.lower() == "mysql": - mysql_config = settings.get_base_config("mysql", {}) - logger.info("Use MySQL scheme to create OceanBase connection.") - host = mysql_config.get("host", "localhost") - port = mysql_config.get("port", 2881) - self.username = mysql_config.get("user", "root@test") - self.password = mysql_config.get("password", "infini_rag_flow") - max_connections = mysql_config.get("max_connections", 300) - else: - logger.info("Use customized config to create OceanBase connection.") - host = ob_config.get("host", "localhost") - port = ob_config.get("port", 2881) - self.username = ob_config.get("user", "root@test") - self.password = ob_config.get("password", "infini_rag_flow") - max_connections = ob_config.get("max_connections", 300) - - self.db_name = ob_config.get("db_name", "test") - self.uri = f"{host}:{port}" - - logger.info(f"Use OceanBase '{self.uri}' as the doc engine.") - - # Set the maximum number of connections that can be created above the pool_size. - # By default, this is half of max_connections, but at least 10. - # This allows the pool to handle temporary spikes in demand without exhausting resources. - max_overflow = int(os.environ.get("OB_MAX_OVERFLOW", max(max_connections // 2, 10))) - # Set the number of seconds to wait before giving up when trying to get a connection from the pool. - # Default is 30 seconds, but can be overridden with the OB_POOL_TIMEOUT environment variable. - pool_timeout = int(os.environ.get("OB_POOL_TIMEOUT", "30")) - - for _ in range(ATTEMPT_TIME): - try: - self.client = ObVecClient( - uri=self.uri, - user=self.username, - password=self.password, - db_name=self.db_name, - pool_pre_ping=True, - pool_recycle=3600, - pool_size=max_connections, - max_overflow=max_overflow, - pool_timeout=pool_timeout, - ) - break - except Exception as e: - logger.warning(f"{str(e)}. Waiting OceanBase {self.uri} to be healthy.") - time.sleep(5) - - if self.client is None: - msg = f"OceanBase {self.uri} connection failed after {ATTEMPT_TIME} attempts." 
- logger.error(msg) - raise Exception(msg) - - self._load_env_vars() - self._check_ob_version() - self._try_to_update_ob_query_timeout() - - self.es = None - if self.enable_hybrid_search: - try: - self.es = HybridSearch( - uri=self.uri, - user=self.username, - password=self.password, - db_name=self.db_name, - pool_pre_ping=True, - pool_recycle=3600, - pool_size=max_connections, - max_overflow=max_overflow, - pool_timeout=pool_timeout, - ) - logger.info("OceanBase Hybrid Search feature is enabled") - except ClusterVersionException as e: - logger.info("Failed to initialize HybridSearch client, fallback to use SQL", exc_info=e) - self.es = None - - if self.es is not None and self.search_original_content: - logger.info("HybridSearch is enabled, forcing search_original_content to False") - self.search_original_content = False - # Determine which columns to use for full-text search dynamically: - # If HybridSearch is enabled (self.es is not None), we must use tokenized columns (fts_columns_tks) - # for compatibility and performance with HybridSearch. Otherwise, we use the original content columns - # (fts_columns_origin), which may be controlled by an environment variable. - self.fulltext_search_columns = fts_columns_origin if self.search_original_content else fts_columns_tks - - self._table_exists_cache: set[str] = set() - self._table_exists_cache_lock = threading.RLock() - - logger.info(f"OceanBase {self.uri} is healthy.") - - def _check_ob_version(self): - try: - res = self.client.perform_raw_text_sql("SELECT OB_VERSION() FROM DUAL").fetchone() - version_str = res[0] if res else None - logger.info(f"OceanBase {self.uri} version is {version_str}") - except Exception as e: - raise Exception(f"Failed to get OceanBase version from {self.uri}, error: {str(e)}") - - if not version_str: - raise Exception(f"Failed to get OceanBase version from {self.uri}.") - - ob_version = ObVersion.from_db_version_string(version_str) - if ob_version < ObVersion.from_db_version_nums(4, 3, 5, 1): - raise Exception( - f"The version of OceanBase needs to be higher than or equal to 4.3.5.1, current version is {version_str}" - ) - - def _try_to_update_ob_query_timeout(self): - try: - val = self._get_variable_value("ob_query_timeout") - if val and int(val) >= OB_QUERY_TIMEOUT: - return - except Exception as e: - logger.warning("Failed to get 'ob_query_timeout' variable: %s", str(e)) - - try: - self.client.perform_raw_text_sql(f"SET GLOBAL ob_query_timeout={OB_QUERY_TIMEOUT}") - logger.info("Set GLOBAL variable 'ob_query_timeout' to %d.", OB_QUERY_TIMEOUT) - - # refresh connection pool to ensure 'ob_query_timeout' has taken effect - self.client.engine.dispose() - if self.es is not None: - self.es.engine.dispose() - logger.info("Disposed all connections in engine pool to refresh connection pool") - except Exception as e: - logger.warning(f"Failed to set 'ob_query_timeout' variable: {str(e)}") - - def _load_env_vars(self): - - def is_true(var: str, default: str) -> bool: - return os.getenv(var, default).lower() in ['true', '1', 'yes', 'y'] - - self.enable_fulltext_search = is_true('ENABLE_FULLTEXT_SEARCH', 'true') - logger.info(f"ENABLE_FULLTEXT_SEARCH={self.enable_fulltext_search}") - - self.use_fulltext_hint = is_true('USE_FULLTEXT_HINT', 'true') - logger.info(f"USE_FULLTEXT_HINT={self.use_fulltext_hint}") - - self.search_original_content = is_true("SEARCH_ORIGINAL_CONTENT", 'true') - logger.info(f"SEARCH_ORIGINAL_CONTENT={self.search_original_content}") - - self.enable_hybrid_search = is_true('ENABLE_HYBRID_SEARCH', 
'false') - logger.info(f"ENABLE_HYBRID_SEARCH={self.enable_hybrid_search}") - - self.use_fulltext_first_fusion_search = is_true('USE_FULLTEXT_FIRST_FUSION_SEARCH', 'true') - logger.info(f"USE_FULLTEXT_FIRST_FUSION_SEARCH={self.use_fulltext_first_fusion_search}") + super().__init__(logger_name='ragflow.ob_conn') + # Determine which columns to use for full-text search dynamically + self._fulltext_search_columns = FTS_COLUMNS_ORIGIN if self.search_original_content else FTS_COLUMNS_TKS """ - Database operations + Template method implementations """ - def db_type(self) -> str: - return "oceanbase" + def get_index_columns(self) -> list[str]: + return INDEX_COLUMNS + + def get_column_definitions(self) -> list[Column]: + return column_definitions + + def get_extra_columns(self) -> list[Column]: + return EXTRA_COLUMNS + + def get_lock_prefix(self) -> str: + return "ob_" + + def _get_filters(self, condition: dict) -> list[str]: + return get_filters(condition) + + def get_fulltext_columns(self) -> list[str]: + """Return list of column names that need fulltext indexes (without weight suffix).""" + return [col.split("^")[0] for col in self._fulltext_search_columns] + + def delete_idx(self, index_name: str, dataset_id: str): + if dataset_id: + # The index need to be alive after any kb deletion since all kb under this tenant are in one index. + return + super().delete_idx(index_name, dataset_id) + + """ + Performance monitoring + """ - def health(self) -> dict: - """ - Check OceanBase health status with basic connection information. - - Returns: - dict: Health status with URI and version information - """ - try: - return { - "uri": self.uri, - "version_comment": self._get_variable_value("version_comment"), - "status": "healthy", - "connection": "connected" - } - except Exception as e: - return { - "uri": self.uri, - "status": "unhealthy", - "connection": "disconnected", - "error": str(e) - } - def get_performance_metrics(self) -> dict: """ Get comprehensive performance metrics for OceanBase. - + Returns: dict: Performance metrics including latency, storage, QPS, and slow queries """ @@ -551,53 +350,52 @@ class OBConnection(DocStoreConnection): "active_connections": 0, "max_connections": 0 } - + try: # Measure connection latency - import time start_time = time.time() self.client.perform_raw_text_sql("SELECT 1").fetchone() metrics["latency_ms"] = round((time.time() - start_time) * 1000, 2) - + # Get storage information try: storage_info = self._get_storage_info() metrics.update(storage_info) except Exception as e: logger.warning(f"Failed to get storage info: {str(e)}") - + # Get connection pool statistics try: pool_stats = self._get_connection_pool_stats() metrics.update(pool_stats) except Exception as e: logger.warning(f"Failed to get connection pool stats: {str(e)}") - + # Get slow query statistics try: slow_queries = self._get_slow_query_count() metrics["slow_queries"] = slow_queries except Exception as e: logger.warning(f"Failed to get slow query count: {str(e)}") - + # Get QPS (Queries Per Second) - approximate from processlist try: qps = self._estimate_qps() metrics["query_per_second"] = qps except Exception as e: logger.warning(f"Failed to estimate QPS: {str(e)}") - + except Exception as e: metrics["connection"] = "disconnected" metrics["error"] = str(e) logger.error(f"Failed to get OceanBase performance metrics: {str(e)}") - + return metrics - + def _get_storage_info(self) -> dict: """ Get storage space usage information. 
- + Returns: dict: Storage information with used and total space """ @@ -607,9 +405,9 @@ class OBConnection(DocStoreConnection): f"SELECT ROUND(SUM(data_length + index_length) / 1024 / 1024, 2) AS 'size_mb' " f"FROM information_schema.tables WHERE table_schema = '{self.db_name}'" ).fetchone() - + size_mb = float(result[0]) if result and result[0] else 0.0 - + # Try to get total available space (may not be available in all OceanBase versions) try: result = self.client.perform_raw_text_sql( @@ -620,7 +418,7 @@ class OBConnection(DocStoreConnection): except Exception: # Fallback: estimate total space (100GB default if not available) total_gb = 100.0 - + return { "storage_used": f"{size_mb:.2f}MB", "storage_total": f"{total_gb:.2f}GB" if total_gb else "N/A" @@ -631,11 +429,11 @@ class OBConnection(DocStoreConnection): "storage_used": "N/A", "storage_total": "N/A" } - + def _get_connection_pool_stats(self) -> dict: """ Get connection pool statistics. - + Returns: dict: Connection pool statistics """ @@ -643,16 +441,16 @@ class OBConnection(DocStoreConnection): # Get active connections from processlist result = self.client.perform_raw_text_sql("SHOW PROCESSLIST") active_connections = len(list(result.fetchall())) - + # Get max_connections setting max_conn_result = self.client.perform_raw_text_sql( "SHOW VARIABLES LIKE 'max_connections'" ).fetchone() max_connections = int(max_conn_result[1]) if max_conn_result and max_conn_result[1] else 0 - + # Get pool size from client if available pool_size = getattr(self.client, 'pool_size', None) or 0 - + return { "active_connections": active_connections, "max_connections": max_connections if max_connections > 0 else pool_size, @@ -665,14 +463,14 @@ class OBConnection(DocStoreConnection): "max_connections": 0, "pool_size": 0 } - + def _get_slow_query_count(self, threshold_seconds: int = 1) -> int: """ Get count of slow queries (queries taking longer than threshold). - + Args: threshold_seconds: Threshold in seconds for slow queries (default: 1) - + Returns: int: Number of slow queries """ @@ -685,11 +483,11 @@ class OBConnection(DocStoreConnection): except Exception as e: logger.warning(f"Failed to get slow query count: {str(e)}") return 0 - + def _estimate_qps(self) -> int: """ Estimate queries per second from processlist. - + Returns: int: Estimated queries per second """ @@ -699,263 +497,54 @@ class OBConnection(DocStoreConnection): "SELECT COUNT(*) FROM information_schema.processlist WHERE command != 'Sleep'" ).fetchone() active_queries = int(result[0]) if result and result[0] else 0 - + # Rough estimate: assume average query takes 0.1 seconds # This is a simplified estimation estimated_qps = max(0, active_queries * 10) - + return estimated_qps except Exception as e: logger.warning(f"Failed to estimate QPS: {str(e)}") return 0 - def _get_variable_value(self, var_name: str) -> Any: - rows = self.client.perform_raw_text_sql(f"SHOW VARIABLES LIKE '{var_name}'") - for row in rows: - return row[1] - raise Exception(f"Variable '{var_name}' not found.") - - def _check_table_exists_cached(self, table_name: str) -> bool: - """ - Check table existence with cache to reduce INFORMATION_SCHEMA queries under high concurrency. - Only caches when table exists. Does not cache when table does not exist. - Thread-safe implementation: read operations are lock-free (GIL-protected), - write operations are protected by RLock to ensure cache consistency. 
- - Args: - table_name: Table name - - Returns: - Whether the table exists with all required indexes and columns - """ - if table_name in self._table_exists_cache: - return True - - try: - if not self.client.check_table_exists(table_name): - return False - for column_name in index_columns: - if not self._index_exists(table_name, index_name_template % (table_name, column_name)): - return False - for fts_column in self.fulltext_search_columns: - column_name = fts_column.split("^")[0] - if not self._index_exists(table_name, fulltext_index_name_template % column_name): - return False - for column in [column_order_id, column_group_id, column_mom_id]: - if not self._column_exist(table_name, column.name): - return False - except Exception as e: - raise Exception(f"OBConnection._check_table_exists_cached error: {str(e)}") - - with self._table_exists_cache_lock: - if table_name not in self._table_exists_cache: - self._table_exists_cache.add(table_name) - return True - - """ - Table operations - """ - - def create_idx(self, indexName: str, knowledgebaseId: str, vectorSize: int, parser_id: str = None): - vector_field_name = f"q_{vectorSize}_vec" - vector_index_name = f"{vector_field_name}_idx" - - try: - _try_with_lock( - lock_name=f"ob_create_table_{indexName}", - check_func=lambda: self.client.check_table_exists(indexName), - process_func=lambda: self._create_table(indexName), - ) - - for column_name in index_columns: - _try_with_lock( - lock_name=f"ob_add_idx_{indexName}_{column_name}", - check_func=lambda: self._index_exists(indexName, index_name_template % (indexName, column_name)), - process_func=lambda: self._add_index(indexName, column_name), - ) - - for fts_column in self.fulltext_search_columns: - column_name = fts_column.split("^")[0] - _try_with_lock( - lock_name=f"ob_add_fulltext_idx_{indexName}_{column_name}", - check_func=lambda: self._index_exists(indexName, fulltext_index_name_template % column_name), - process_func=lambda: self._add_fulltext_index(indexName, column_name), - ) - - _try_with_lock( - lock_name=f"ob_add_vector_column_{indexName}_{vector_field_name}", - check_func=lambda: self._column_exist(indexName, vector_field_name), - process_func=lambda: self._add_vector_column(indexName, vectorSize), - ) - - _try_with_lock( - lock_name=f"ob_add_vector_idx_{indexName}_{vector_field_name}", - check_func=lambda: self._index_exists(indexName, vector_index_name), - process_func=lambda: self._add_vector_index(indexName, vector_field_name), - ) - - # new columns migration - for column in [column_chunk_data, column_order_id, column_group_id, column_mom_id]: - _try_with_lock( - lock_name=f"ob_add_{column.name}_{indexName}", - check_func=lambda: self._column_exist(indexName, column.name), - process_func=lambda: self._add_column(indexName, column), - ) - except Exception as e: - raise Exception(f"OBConnection.createIndex error: {str(e)}") - finally: - # always refresh metadata to make sure it contains the latest table structure - self.client.refresh_metadata([indexName]) - - def delete_idx(self, indexName: str, knowledgebaseId: str): - if len(knowledgebaseId) > 0: - # The index need to be alive after any kb deletion since all kb under this tenant are in one index. 
- return - try: - if self.client.check_table_exists(table_name=indexName): - self.client.drop_table_if_exist(indexName) - logger.info(f"Dropped table '{indexName}'.") - except Exception as e: - raise Exception(f"OBConnection.deleteIndex error: {str(e)}") - - def index_exist(self, indexName: str, knowledgebaseId: str = None) -> bool: - return self._check_table_exists_cached(indexName) - - def _get_count(self, table_name: str, filter_list: list[str] = None) -> int: - where_clause = "WHERE " + " AND ".join(filter_list) if len(filter_list) > 0 else "" - (count,) = self.client.perform_raw_text_sql( - f"SELECT COUNT(*) FROM {table_name} {where_clause}" - ).fetchone() - return count - - def _column_exist(self, table_name: str, column_name: str) -> bool: - return self._get_count( - table_name="INFORMATION_SCHEMA.COLUMNS", - filter_list=[ - f"TABLE_SCHEMA = '{self.db_name}'", - f"TABLE_NAME = '{table_name}'", - f"COLUMN_NAME = '{column_name}'", - ]) > 0 - - def _index_exists(self, table_name: str, index_name: str) -> bool: - return self._get_count( - table_name="INFORMATION_SCHEMA.STATISTICS", - filter_list=[ - f"TABLE_SCHEMA = '{self.db_name}'", - f"TABLE_NAME = '{table_name}'", - f"INDEX_NAME = '{index_name}'", - ]) > 0 - - def _create_table(self, table_name: str): - # remove outdated metadata for external changes - if table_name in self.client.metadata_obj.tables: - self.client.metadata_obj.remove(Table(table_name, self.client.metadata_obj)) - - table_options = { - "mysql_charset": "utf8mb4", - "mysql_collate": "utf8mb4_unicode_ci", - "mysql_organization": "heap", - } - - self.client.create_table( - table_name=table_name, - columns=[c.copy() for c in column_definitions], - **table_options, - ) - logger.info(f"Created table '{table_name}'.") - - def _add_index(self, table_name: str, column_name: str): - index_name = index_name_template % (table_name, column_name) - self.client.create_index( - table_name=table_name, - is_vec_index=False, - index_name=index_name, - column_names=[column_name], - ) - logger.info(f"Created index '{index_name}' on table '{table_name}'.") - - def _add_fulltext_index(self, table_name: str, column_name: str): - fulltext_index_name = fulltext_index_name_template % column_name - self.client.create_fts_idx_with_fts_index_param( - table_name=table_name, - fts_idx_param=FtsIndexParam( - index_name=fulltext_index_name, - field_names=[column_name], - parser_type=FtsParser.IK, - ), - ) - logger.info(f"Created full text index '{fulltext_index_name}' on table '{table_name}'.") - - def _add_vector_column(self, table_name: str, vector_size: int): - vector_field_name = f"q_{vector_size}_vec" - - self.client.add_columns( - table_name=table_name, - columns=[Column(vector_field_name, VECTOR(vector_size), nullable=True)], - ) - logger.info(f"Added vector column '{vector_field_name}' to table '{table_name}'.") - - def _add_vector_index(self, table_name: str, vector_field_name: str): - vector_index_name = f"{vector_field_name}_idx" - self.client.create_index( - table_name=table_name, - is_vec_index=True, - index_name=vector_index_name, - column_names=[vector_field_name], - vidx_params="distance=cosine, type=hnsw, lib=vsag", - ) - logger.info( - f"Created vector index '{vector_index_name}' on table '{table_name}' with column '{vector_field_name}'." 
- ) - - def _add_column(self, table_name: str, column: Column): - try: - self.client.add_columns( - table_name=table_name, - columns=[column.copy()], - ) - logger.info(f"Added column '{column.name}' to table '{table_name}'.") - except Exception as e: - logger.warning(f"Failed to add column '{column.name}' to table '{table_name}': {str(e)}") - """ CRUD operations """ def search( - self, - selectFields: list[str], - highlightFields: list[str], - condition: dict, - matchExprs: list[MatchExpr], - orderBy: OrderByExpr, - offset: int, - limit: int, - indexNames: str | list[str], - knowledgebaseIds: list[str], - aggFields: list[str] = [], - rank_feature: dict | None = None, - **kwargs, + self, + select_fields: list[str], + highlight_fields: list[str], + condition: dict, + match_expressions: list[MatchExpr], + order_by: OrderByExpr, + offset: int, + limit: int, + index_names: str | list[str], + knowledgebase_ids: list[str], + agg_fields: list[str] = [], + rank_feature: dict | None = None, + **kwargs, ): - if isinstance(indexNames, str): - indexNames = indexNames.split(",") - assert isinstance(indexNames, list) and len(indexNames) > 0 - indexNames = list(set(indexNames)) + if isinstance(index_names, str): + index_names = index_names.split(",") + assert isinstance(index_names, list) and len(index_names) > 0 + index_names = list(set(index_names)) - if len(matchExprs) == 3: + if len(match_expressions) == 3: if not self.enable_fulltext_search: # disable fulltext search in fusion search, which means fallback to vector search - matchExprs = [m for m in matchExprs if isinstance(m, MatchDenseExpr)] + match_expressions = [m for m in match_expressions if isinstance(m, MatchDenseExpr)] else: - for m in matchExprs: + for m in match_expressions: if isinstance(m, FusionExpr): weights = m.fusion_params["weights"] vector_similarity_weight = get_float(weights.split(",")[1]) # skip the search if its weight is zero if vector_similarity_weight <= 0.0: - matchExprs = [m for m in matchExprs if isinstance(m, MatchTextExpr)] + match_expressions = [m for m in match_expressions if isinstance(m, MatchTextExpr)] elif vector_similarity_weight >= 1.0: - matchExprs = [m for m in matchExprs if isinstance(m, MatchDenseExpr)] + match_expressions = [m for m in match_expressions if isinstance(m, MatchDenseExpr)] result: SearchResult = SearchResult( total=0, @@ -963,9 +552,9 @@ class OBConnection(DocStoreConnection): ) # copied from es_conn.py - if len(matchExprs) == 3 and self.es: + if len(match_expressions) == 3 and self.es: bqry = Q("bool", must=[]) - condition["kb_id"] = knowledgebaseIds + condition["kb_id"] = knowledgebase_ids for k, v in condition.items(): if k == "available_int": if v == 0: @@ -986,20 +575,20 @@ class OBConnection(DocStoreConnection): s = Search() vector_similarity_weight = 0.5 - for m in matchExprs: + for m in match_expressions: if isinstance(m, FusionExpr) and m.method == "weighted_sum" and "weights" in m.fusion_params: - assert len(matchExprs) == 3 and isinstance(matchExprs[0], MatchTextExpr) and isinstance( - matchExprs[1], + assert len(match_expressions) == 3 and isinstance(match_expressions[0], MatchTextExpr) and isinstance( + match_expressions[1], MatchDenseExpr) and isinstance( - matchExprs[2], FusionExpr) + match_expressions[2], FusionExpr) weights = m.fusion_params["weights"] vector_similarity_weight = get_float(weights.split(",")[1]) - for m in matchExprs: + for m in match_expressions: if isinstance(m, MatchTextExpr): minimum_should_match = m.extra_options.get("minimum_should_match", 0.0) if 
isinstance(minimum_should_match, float): minimum_should_match = str(int(minimum_should_match * 100)) + "%" - bqry.must.append(Q("query_string", fields=fts_columns_tks, + bqry.must.append(Q("query_string", fields=FTS_COLUMNS_TKS, type="best_fields", query=m.matching_text, minimum_should_match=minimum_should_match, boost=1)) @@ -1029,9 +618,9 @@ class OBConnection(DocStoreConnection): # for field in highlightFields: # s = s.highlight(field) - if orderBy: + if order_by: orders = list() - for field, order in orderBy.fields: + for field, order in order_by.fields: order = "asc" if order == 0 else "desc" if field in ["page_num_int", "top_int"]: order_info = {"order": order, "unmapped_type": "float", @@ -1043,15 +632,15 @@ class OBConnection(DocStoreConnection): orders.append({field: order_info}) s = s.sort(*orders) - for fld in aggFields: + for fld in agg_fields: s.aggs.bucket(f'aggs_{fld}', 'terms', field=fld, size=1000000) if limit > 0: s = s[offset:offset + limit] q = s.to_dict() - logger.debug(f"OBConnection.hybrid_search {str(indexNames)} query: " + json.dumps(q)) + logger.debug(f"OBConnection.hybrid_search {str(index_names)} query: " + json.dumps(q)) - for index_name in indexNames: + for index_name in index_names: start_time = time.time() res = self.es.search(index=index_name, body=q, @@ -1068,20 +657,20 @@ class OBConnection(DocStoreConnection): result.total = result.total + 1 return result - output_fields = selectFields.copy() + output_fields = select_fields.copy() if "id" not in output_fields: output_fields = ["id"] + output_fields if "_score" in output_fields: output_fields.remove("_score") - if highlightFields: - for field in highlightFields: + if highlight_fields: + for field in highlight_fields: if field not in output_fields: output_fields.append(field) fields_expr = ", ".join(output_fields) - condition["kb_id"] = knowledgebaseIds + condition["kb_id"] = knowledgebase_ids filters: list[str] = get_filters(condition) filters_expr = " AND ".join(filters) @@ -1102,34 +691,19 @@ class OBConnection(DocStoreConnection): vector_search_score_expr: Optional[str] = None vector_search_filter: Optional[str] = None - for m in matchExprs: + for m in match_expressions: if isinstance(m, MatchTextExpr): assert "original_query" in m.extra_options, "'original_query' is missing in extra_options." fulltext_query = m.extra_options["original_query"] fulltext_query = escape_string(fulltext_query.strip()) fulltext_topn = m.topn - # get fulltext match expression and weight values - for field in self.fulltext_search_columns: - parts = field.split("^") - column_name: str = parts[0] - column_weight: float = float(parts[1]) if (len(parts) > 1 and parts[1]) else 1.0 - - fulltext_search_weight[column_name] = column_weight - fulltext_search_expr[column_name] = fulltext_search_template % (column_name, fulltext_query) + fulltext_search_expr, fulltext_search_weight = self._parse_fulltext_columns( + fulltext_query, self._fulltext_search_columns + ) + for column_name in fulltext_search_expr.keys(): fulltext_search_idx_list.append(fulltext_index_name_template % column_name) - # adjust the weight to 0~1 - weight_sum = sum(fulltext_search_weight.values()) - n = len(fulltext_search_weight) - if weight_sum <= 0 and n > 0: - # All weights are 0 (e.g. 
"col^0"); use equal weights to avoid ZeroDivisionError - for column_name in fulltext_search_weight: - fulltext_search_weight[column_name] = 1.0 / n - else: - for column_name in fulltext_search_weight: - fulltext_search_weight[column_name] = fulltext_search_weight[column_name] / weight_sum - elif isinstance(m, MatchDenseExpr): assert m.embedding_data_type == "float", f"embedding data type '{m.embedding_data_type}' is not float." vector_column_name = m.vector_column_name @@ -1163,7 +737,7 @@ class OBConnection(DocStoreConnection): search_type = "fulltext" elif vector_data: search_type = "vector" - elif len(aggFields) > 0: + elif len(agg_fields) > 0: search_type = "aggregation" else: search_type = "filter" @@ -1177,7 +751,7 @@ class OBConnection(DocStoreConnection): if fulltext_topn is not None: limit = min(fulltext_topn, limit) - for index_name in indexNames: + for index_name in index_names: if not self._check_table_exists_cached(index_name): continue @@ -1215,14 +789,9 @@ class OBConnection(DocStoreConnection): f" SELECT COUNT(*) FROM fulltext_results f FULL OUTER JOIN vector_results v ON f.id = v.id" ) logger.debug("OBConnection.search with count sql: %s", count_sql) - - start_time = time.time() - - res = self.client.perform_raw_text_sql(count_sql) - total_count = res.fetchone()[0] if res else 0 + rows, elapsed_time = self._execute_search_sql(count_sql) + total_count = rows[0][0] if rows else 0 result.total += total_count - - elapsed_time = time.time() - start_time logger.info( f"OBConnection.search table {index_name}, search type: fusion, step: 1-count, elapsed time: {elapsed_time:.3f} seconds," f" vector column: '{vector_column_name}'," @@ -1284,13 +853,7 @@ class OBConnection(DocStoreConnection): f" LIMIT {offset}, {limit}" ) logger.debug("OBConnection.search with fusion sql: %s", fusion_sql) - - start_time = time.time() - - res = self.client.perform_raw_text_sql(fusion_sql) - rows = res.fetchall() - - elapsed_time = time.time() - start_time + rows, elapsed_time = self._execute_search_sql(fusion_sql) logger.info( f"OBConnection.search table {index_name}, search type: fusion, step: 2-query, elapsed time: {elapsed_time:.3f} seconds," f" select fields: '{output_fields}'," @@ -1306,16 +869,11 @@ class OBConnection(DocStoreConnection): result.chunks.append(self._row_to_entity(row, output_fields)) elif search_type == "vector": # vector search, usually used for graph search - count_sql = f"SELECT COUNT(id) FROM {index_name} WHERE {filters_expr} AND {vector_search_filter}" + count_sql = self._build_count_sql(index_name, filters_expr, vector_search_filter) logger.debug("OBConnection.search with vector count sql: %s", count_sql) - - start_time = time.time() - - res = self.client.perform_raw_text_sql(count_sql) - total_count = res.fetchone()[0] if res else 0 + rows, elapsed_time = self._execute_search_sql(count_sql) + total_count = rows[0][0] if rows else 0 result.total += total_count - - elapsed_time = time.time() - start_time logger.info( f"OBConnection.search table {index_name}, search type: vector, step: 1-count, elapsed time: {elapsed_time:.3f} seconds," f" vector column: '{vector_column_name}'," @@ -1327,23 +885,12 @@ class OBConnection(DocStoreConnection): if total_count == 0: continue - vector_sql = ( - f"SELECT {fields_expr}, {vector_search_score_expr} AS _score" - f" FROM {index_name}" - f" WHERE {filters_expr} AND {vector_search_filter}" - f" ORDER BY {vector_search_expr}" - f" APPROXIMATE LIMIT {limit if limit != 0 else vector_topn}" + vector_sql = self._build_vector_search_sql( + 
index_name, fields_expr, vector_search_score_expr, filters_expr, + vector_search_filter, vector_search_expr, limit, vector_topn, offset ) - if offset != 0: - vector_sql += f" OFFSET {offset}" logger.debug("OBConnection.search with vector sql: %s", vector_sql) - - start_time = time.time() - - res = self.client.perform_raw_text_sql(vector_sql) - rows = res.fetchall() - - elapsed_time = time.time() - start_time + rows, elapsed_time = self._execute_search_sql(vector_sql) logger.info( f"OBConnection.search table {index_name}, search type: vector, step: 2-query, elapsed time: {elapsed_time:.3f} seconds," f" select fields: '{output_fields}'," @@ -1357,16 +904,11 @@ class OBConnection(DocStoreConnection): result.chunks.append(self._row_to_entity(row, output_fields)) elif search_type == "fulltext": # fulltext search, usually used to search chunks in one dataset - count_sql = f"SELECT {fulltext_search_hint} COUNT(id) FROM {index_name} WHERE {filters_expr} AND {fulltext_search_filter}" + count_sql = self._build_count_sql(index_name, filters_expr, fulltext_search_filter, fulltext_search_hint) logger.debug("OBConnection.search with fulltext count sql: %s", count_sql) - - start_time = time.time() - - res = self.client.perform_raw_text_sql(count_sql) - total_count = res.fetchone()[0] if res else 0 + rows, elapsed_time = self._execute_search_sql(count_sql) + total_count = rows[0][0] if rows else 0 result.total += total_count - - elapsed_time = time.time() - start_time logger.info( f"OBConnection.search table {index_name}, search type: fulltext, step: 1-count, elapsed time: {elapsed_time:.3f} seconds," f" query text: '{fulltext_query}'," @@ -1377,21 +919,12 @@ class OBConnection(DocStoreConnection): if total_count == 0: continue - fulltext_sql = ( - f"SELECT {fulltext_search_hint} {fields_expr}, {fulltext_search_score_expr} AS _score" - f" FROM {index_name}" - f" WHERE {filters_expr} AND {fulltext_search_filter}" - f" ORDER BY _score DESC" - f" LIMIT {offset}, {limit if limit != 0 else fulltext_topn}" + fulltext_sql = self._build_fulltext_search_sql( + index_name, fields_expr, fulltext_search_score_expr, filters_expr, + fulltext_search_filter, offset, limit, fulltext_topn, fulltext_search_hint ) logger.debug("OBConnection.search with fulltext sql: %s", fulltext_sql) - - start_time = time.time() - - res = self.client.perform_raw_text_sql(fulltext_sql) - rows = res.fetchall() - - elapsed_time = time.time() - start_time + rows, elapsed_time = self._execute_search_sql(fulltext_sql) logger.info( f"OBConnection.search table {index_name}, search type: fulltext, step: 2-query, elapsed time: {elapsed_time:.3f} seconds," f" select fields: '{output_fields}'," @@ -1404,8 +937,8 @@ class OBConnection(DocStoreConnection): result.chunks.append(self._row_to_entity(row, output_fields)) elif search_type == "aggregation": # aggregation search - assert len(aggFields) == 1, "Only one aggregation field is supported in OceanBase." - agg_field = aggFields[0] + assert len(agg_fields) == 1, "Only one aggregation field is supported in OceanBase." 
+ agg_field = agg_fields[0] if agg_field in array_columns: res = self.client.perform_raw_text_sql( f"SELECT {agg_field} FROM {index_name}" @@ -1449,24 +982,19 @@ class OBConnection(DocStoreConnection): else: # only filter orders: list[str] = [] - if orderBy: - for field, order in orderBy.fields: + if order_by: + for field, order in order_by.fields: if isinstance(column_types[field], ARRAY): f = field + "_sort" fields_expr += f", array_to_string({field}, ',') AS {f}" field = f order = "ASC" if order == 0 else "DESC" orders.append(f"{field} {order}") - count_sql = f"SELECT COUNT(id) FROM {index_name} WHERE {filters_expr}" + count_sql = self._build_count_sql(index_name, filters_expr) logger.debug("OBConnection.search with normal count sql: %s", count_sql) - - start_time = time.time() - - res = self.client.perform_raw_text_sql(count_sql) - total_count = res.fetchone()[0] if res else 0 + rows, elapsed_time = self._execute_search_sql(count_sql) + total_count = rows[0][0] if rows else 0 result.total += total_count - - elapsed_time = time.time() - start_time logger.info( f"OBConnection.search table {index_name}, search type: normal, step: 1-count, elapsed time: {elapsed_time:.3f} seconds," f" condition: '{condition}'," @@ -1478,20 +1006,11 @@ class OBConnection(DocStoreConnection): order_by_expr = ("ORDER BY " + ", ".join(orders)) if len(orders) > 0 else "" limit_expr = f"LIMIT {offset}, {limit}" if limit != 0 else "" - filter_sql = ( - f"SELECT {fields_expr}" - f" FROM {index_name}" - f" WHERE {filters_expr}" - f" {order_by_expr} {limit_expr}" + filter_sql = self._build_filter_search_sql( + index_name, fields_expr, filters_expr, order_by_expr, limit_expr ) logger.debug("OBConnection.search with normal sql: %s", filter_sql) - - start_time = time.time() - - res = self.client.perform_raw_text_sql(filter_sql) - rows = res.fetchall() - - elapsed_time = time.time() - start_time + rows, elapsed_time = self._execute_search_sql(filter_sql) logger.info( f"OBConnection.search table {index_name}, search type: normal, step: 2-query, elapsed time: {elapsed_time:.3f} seconds," f" select fields: '{output_fields}'," @@ -1507,34 +1026,30 @@ class OBConnection(DocStoreConnection): return result - def get(self, chunkId: str, indexName: str, knowledgebaseIds: list[str]) -> dict | None: - if not self._check_table_exists_cached(indexName): - return None - + def get(self, chunk_id: str, index_name: str, knowledgebase_ids: list[str]) -> dict | None: try: - res = self.client.get( - table_name=indexName, - ids=[chunkId], - ) - row = res.fetchone() - if row is None: - raise Exception(f"ChunkId {chunkId} not found in index {indexName}.") - - return self._row_to_entity(row, fields=list(res.keys())) + doc = super().get(chunk_id, index_name, knowledgebase_ids) + if doc is None: + return None + return doc except json.JSONDecodeError as e: - logger.error(f"JSON decode error when getting chunk {chunkId}: {str(e)}") + logger.error(f"JSON decode error when getting chunk {chunk_id}: {str(e)}") return { - "id": chunkId, + "id": chunk_id, "error": f"Failed to parse chunk data due to invalid JSON: {str(e)}" } except Exception as e: - logger.error(f"Error getting chunk {chunkId}: {str(e)}") - raise + logger.exception(f"OBConnection.get({chunk_id}) got exception") + raise e - def insert(self, documents: list[dict], indexName: str, knowledgebaseId: str = None) -> list[str]: + def insert(self, documents: list[dict], index_name: str, knowledgebase_id: str = None) -> list[str]: if not documents: return [] + # For doc_meta tables, use simple 
insert without field transformation + if index_name.startswith("ragflow_doc_meta_"): + return self._insert_doc_meta(documents, index_name) + docs: list[dict] = [] ids: list[str] = [] for document in documents: @@ -1600,35 +1115,68 @@ class OBConnection(DocStoreConnection): res = [] try: - self.client.upsert(indexName, docs) + self.client.upsert(index_name, docs) except Exception as e: logger.error(f"OBConnection.insert error: {str(e)}") res.append(str(e)) return res - def update(self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str) -> bool: - if not self._check_table_exists_cached(indexName): + def _insert_doc_meta(self, documents: list[dict], index_name: str) -> list[str]: + """Insert documents into doc_meta table with simple field handling.""" + docs: list[dict] = [] + for document in documents: + d = { + "id": document.get("id"), + "kb_id": document.get("kb_id"), + } + # Handle meta_fields - store as JSON + meta_fields = document.get("meta_fields") + if meta_fields is not None: + if isinstance(meta_fields, dict): + d["meta_fields"] = json.dumps(meta_fields, ensure_ascii=False) + elif isinstance(meta_fields, str): + d["meta_fields"] = meta_fields + else: + d["meta_fields"] = "{}" + else: + d["meta_fields"] = "{}" + docs.append(d) + + logger.debug("OBConnection._insert_doc_meta: %s", docs) + + res = [] + try: + self.client.upsert(index_name, docs) + except Exception as e: + logger.error(f"OBConnection._insert_doc_meta error: {str(e)}") + res.append(str(e)) + return res + + def update(self, condition: dict, new_value: dict, index_name: str, knowledgebase_id: str) -> bool: + if not self._check_table_exists_cached(index_name): return True - condition["kb_id"] = knowledgebaseId + # For doc_meta tables, don't force kb_id in condition + if not index_name.startswith("ragflow_doc_meta_"): + condition["kb_id"] = knowledgebase_id filters = get_filters(condition) set_values: list[str] = [] - for k, v in newValue.items(): + for k, v in new_value.items(): if k == "remove": if isinstance(v, str): set_values.append(f"{v} = NULL") else: - assert isinstance(v, dict), f"Expected str or dict for 'remove', got {type(newValue[k])}." + assert isinstance(v, dict), f"Expected str or dict for 'remove', got {type(new_value[k])}." for kk, vv in v.items(): assert kk in array_columns, f"Column '{kk}' is not an array column." set_values.append(f"{kk} = array_remove({kk}, {get_value_str(vv)})") elif k == "add": - assert isinstance(v, dict), f"Expected str or dict for 'add', got {type(newValue[k])}." + assert isinstance(v, dict), f"Expected str or dict for 'add', got {type(new_value[k])}." for kk, vv in v.items(): assert kk in array_columns, f"Column '{kk}' is not an array column." 
set_values.append(f"{kk} = array_append({kk}, {get_value_str(vv)})") elif k == "metadata": - assert isinstance(v, dict), f"Expected dict for 'metadata', got {type(newValue[k])}" + assert isinstance(v, dict), f"Expected dict for 'metadata', got {type(new_value[k])}" set_values.append(f"{k} = {get_value_str(v)}") if v and "doc_id" in condition: group_id = v.get("_group_id") @@ -1644,7 +1192,7 @@ class OBConnection(DocStoreConnection): return True update_sql = ( - f"UPDATE {indexName}" + f"UPDATE {index_name}" f" SET {', '.join(set_values)}" f" WHERE {' AND '.join(filters)}" ) @@ -1657,34 +1205,7 @@ class OBConnection(DocStoreConnection): logger.error(f"OBConnection.update error: {str(e)}") return False - def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int: - if not self._check_table_exists_cached(indexName): - return 0 - - condition["kb_id"] = knowledgebaseId - try: - res = self.client.get( - table_name=indexName, - ids=None, - where_clause=[text(f) for f in get_filters(condition)], - output_column_name=["id"], - ) - rows = res.fetchall() - if len(rows) == 0: - return 0 - ids = [row[0] for row in rows] - logger.debug(f"OBConnection.delete chunks, filters: {condition}, ids: {ids}") - self.client.delete( - table_name=indexName, - ids=ids, - ) - return len(ids) - except Exception as e: - logger.error(f"OBConnection.delete error: {str(e)}") - return 0 - - @staticmethod - def _row_to_entity(data: Row, fields: list[str]) -> dict: + def _row_to_entity(self, data: Row, fields: list[str]) -> dict: entity = {} for i, field in enumerate(fields): value = data[i] @@ -1756,7 +1277,7 @@ class OBConnection(DocStoreConnection): flags=re.IGNORECASE | re.MULTILINE, ) if len(re.findall(r'', highlighted_txt)) > 0 or len( - re.findall(r'\s*', highlighted_txt)) > 0: + re.findall(r'\s*', highlighted_txt)) > 0: return highlighted_txt else: return None @@ -1775,9 +1296,9 @@ class OBConnection(DocStoreConnection): if token_pos != -1: if token in keywords: highlighted_txt = ( - highlighted_txt[:token_pos] + - f'{token}' + - highlighted_txt[token_pos + len(token):] + highlighted_txt[:token_pos] + + f'{token}' + + highlighted_txt[token_pos + len(token):] ) last_pos = token_pos return re.sub(r'', '', highlighted_txt)
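Note on the extracted full-text weight parsing: OBConnection.search() now delegates the "column^weight" handling to self._parse_fulltext_columns() on the base class, whose body is not shown in this hunk. Below is a minimal standalone sketch of what that helper presumably does, reconstructed from the inline loop and weight normalization removed above; the function name, signature, and the assumption that the query arrives already escape_string()-ed (as at the call site) are illustrative, not the actual base-class API.

# Sketch only -- mirrors the removed inline logic, not necessarily OBConnectionBase's code.
FULLTEXT_SEARCH_TEMPLATE = "MATCH (%s) AGAINST ('%s' IN NATURAL LANGUAGE MODE)"

def parse_fulltext_columns(escaped_query: str, weighted_columns: list[str]) -> tuple[dict[str, str], dict[str, float]]:
    """Expand 'name^weight' entries into per-column MATCH expressions and weights normalized to 0~1."""
    exprs: dict[str, str] = {}
    weights: dict[str, float] = {}
    for field in weighted_columns:
        # "docnm_kwd^10" -> ("docnm_kwd", 10.0); a bare "content_with_weight" defaults to weight 1.0
        name, _, raw_weight = field.partition("^")
        weights[name] = float(raw_weight) if raw_weight else 1.0
        exprs[name] = FULLTEXT_SEARCH_TEMPLATE % (name, escaped_query)
    total = sum(weights.values())
    if total <= 0 and weights:
        # all weights were zero (e.g. "col^0"): use equal weights to avoid ZeroDivisionError
        equal = 1.0 / len(weights)
        for name in weights:
            weights[name] = equal
    else:
        for name in weights:
            weights[name] /= total
    return exprs, weights

# Example (hypothetical call): parse_fulltext_columns("cloud native database", FTS_COLUMNS_ORIGIN)
# yields one MATCH ... AGAINST expression per column and weights 10/51, 1/51, 20/51, 20/51
# for docnm_kwd, content_with_weight, important_tks and question_tks respectively.

Normalizing the per-column weights so they sum to 1 keeps the combined full-text score on a 0~1 scale, which is what lets the fusion step blend it with the cosine-distance-based vector score using the text/vector weights carried by FusionExpr.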