From badf33e3b9939bce42011cece8338e496ab81f66 Mon Sep 17 00:00:00 2001
From: He Wang
Date: Wed, 10 Dec 2025 19:13:37 +0800
Subject: [PATCH] feat: enhance OBConnection.search (#11876)

### What problem does this PR solve?

Enhance OBConnection.search for better performance. Main changes:

1. Pass the vector to the distance function as a string literal, which parses faster than a typed array.
2. Set `max_connections` as the pool size explicitly instead of relying on the default.
3. Resolve `fulltext_search_columns` once at startup.
4. Cache the results of the table-existence check (we never drop the table).
5. Remove the unused `group_results` logic.
6. Add the `USE_FULLTEXT_FIRST_FUSION_SEARCH` flag and a corresponding full-outer-join fusion-search SQL path used when it is false.

### Type of change

- [x] Performance Improvement
---
 rag/utils/ob_conn.py | 266 ++++++++++++++++++++++++++-----------------
 1 file changed, 164 insertions(+), 102 deletions(-)

diff --git a/rag/utils/ob_conn.py b/rag/utils/ob_conn.py
index 6218a8c4e..3c00be421 100644
--- a/rag/utils/ob_conn.py
+++ b/rag/utils/ob_conn.py
@@ -17,13 +17,16 @@ import json
 import logging
 import os
 import re
+import threading
 import time
 from typing import Any, Optional
 
+import numpy as np
 from elasticsearch_dsl import Q, Search
 from pydantic import BaseModel
 from pymysql.converters import escape_string
 from pyobvector import ObVecClient, FtsIndexParam, FtsParser, ARRAY, VECTOR
+from pyobvector.client import ClusterVersionException
 from pyobvector.client.hybrid_search import HybridSearch
 from pyobvector.util import ObVersion
 from sqlalchemy import text, Column, String, Integer, JSON, Double, Row, Table
@@ -106,17 +109,6 @@ index_columns: list[str] = [
     "removed_kwd",
 ]
 
-fulltext_search_columns: list[str] = [
-    "docnm_kwd",
-    "content_with_weight",
-    "title_tks",
-    "title_sm_tks",
-    "important_tks",
-    "question_tks",
-    "content_ltks",
-    "content_sm_ltks"
-]
-
 fts_columns_origin: list[str] = [
     "docnm_kwd^10",
     "content_with_weight",
@@ -138,7 +130,7 @@ fulltext_index_name_template = "fts_idx_%s"
 # MATCH AGAINST: https://www.oceanbase.com/docs/common-oceanbase-database-cn-1000000002017607
 fulltext_search_template = "MATCH (%s) AGAINST ('%s' IN NATURAL LANGUAGE MODE)"
 # cosine_distance: https://www.oceanbase.com/docs/common-oceanbase-database-cn-1000000002012938
-vector_search_template = "cosine_distance(%s, %s)"
+vector_search_template = "cosine_distance(%s, '%s')"
 
 
 class SearchResult(BaseModel):
@@ -362,18 +354,28 @@
             port = mysql_config.get("port", 2881)
             self.username = mysql_config.get("user", "root@test")
             self.password = mysql_config.get("password", "infini_rag_flow")
+            max_connections = mysql_config.get("max_connections", 300)
         else:
             logger.info("Use customized config to create OceanBase connection.")
             host = ob_config.get("host", "localhost")
             port = ob_config.get("port", 2881)
             self.username = ob_config.get("user", "root@test")
             self.password = ob_config.get("password", "infini_rag_flow")
+            max_connections = ob_config.get("max_connections", 300)
 
         self.db_name = ob_config.get("db_name", "test")
         self.uri = f"{host}:{port}"
 
         logger.info(f"Use OceanBase '{self.uri}' as the doc engine.")
 
+        # Set the maximum number of connections that can be created above the pool_size.
+        # By default this is half of max_connections, but at least 10, so the pool can
+        # absorb temporary spikes in demand without exhausting resources.
+        max_overflow = int(os.environ.get("OB_MAX_OVERFLOW", max(max_connections // 2, 10)))
+        # Set the number of seconds to wait before giving up when trying to get a connection
+        # from the pool. Defaults to 30 seconds; override with the OB_POOL_TIMEOUT environment variable.
+        pool_timeout = int(os.environ.get("OB_POOL_TIMEOUT", "30"))
+
         for _ in range(ATTEMPT_TIME):
             try:
                 self.client = ObVecClient(
@@ -383,6 +385,9 @@ class OBConnection(DocStoreConnection):
                     db_name=self.db_name,
                     pool_pre_ping=True,
                     pool_recycle=3600,
+                    pool_size=max_connections,
+                    max_overflow=max_overflow,
+                    pool_timeout=pool_timeout,
                 )
                 break
             except Exception as e:
@@ -398,6 +403,37 @@
         self._check_ob_version()
         self._try_to_update_ob_query_timeout()
 
+        self.es = None
+        if self.enable_hybrid_search:
+            try:
+                self.es = HybridSearch(
+                    uri=self.uri,
+                    user=self.username,
+                    password=self.password,
+                    db_name=self.db_name,
+                    pool_pre_ping=True,
+                    pool_recycle=3600,
+                    pool_size=max_connections,
+                    max_overflow=max_overflow,
+                    pool_timeout=pool_timeout,
+                )
+                logger.info("OceanBase Hybrid Search feature is enabled")
+            except ClusterVersionException as e:
+                logger.info("Failed to initialize the HybridSearch client, falling back to SQL", exc_info=e)
+                self.es = None
+
+        if self.es is not None and self.search_original_content:
+            logger.info("HybridSearch is enabled, forcing search_original_content to False")
+            self.search_original_content = False
+
+        # Determine which columns to use for full-text search:
+        # when HybridSearch is enabled (self.es is not None), the tokenized columns
+        # (fts_columns_tks) must be used for compatibility with HybridSearch; otherwise
+        # SEARCH_ORIGINAL_CONTENT chooses between the original content columns
+        # (fts_columns_origin) and the tokenized ones.
+        self.fulltext_search_columns = fts_columns_origin if self.search_original_content else fts_columns_tks
+
+        self._table_exists_cache: set[str] = set()
+        self._table_exists_cache_lock = threading.RLock()
+
         logger.info(f"OceanBase {self.uri} is healthy.")
 
     def _check_ob_version(self):
@@ -417,18 +453,6 @@
                 f"The version of OceanBase needs to be higher than or equal to 4.3.5.1, current version is {version_str}"
             )
 
-        self.es = None
-        if not ob_version < ObVersion.from_db_version_nums(4, 4, 1, 0) and self.enable_hybrid_search:
-            self.es = HybridSearch(
-                uri=self.uri,
-                user=self.username,
-                password=self.password,
-                db_name=self.db_name,
-                pool_pre_ping=True,
-                pool_recycle=3600,
-            )
-            logger.info("OceanBase Hybrid Search feature is enabled")
-
     def _try_to_update_ob_query_timeout(self):
         try:
             val = self._get_variable_value("ob_query_timeout")
@@ -455,9 +479,19 @@
             return os.getenv(var, default).lower() in ['true', '1', 'yes', 'y']
 
         self.enable_fulltext_search = is_true('ENABLE_FULLTEXT_SEARCH', 'true')
+        logger.info(f"ENABLE_FULLTEXT_SEARCH={self.enable_fulltext_search}")
+
         self.use_fulltext_hint = is_true('USE_FULLTEXT_HINT', 'true')
+        logger.info(f"USE_FULLTEXT_HINT={self.use_fulltext_hint}")
+
         self.search_original_content = is_true("SEARCH_ORIGINAL_CONTENT", 'true')
+        logger.info(f"SEARCH_ORIGINAL_CONTENT={self.search_original_content}")
+
         self.enable_hybrid_search = is_true('ENABLE_HYBRID_SEARCH', 'false')
+        logger.info(f"ENABLE_HYBRID_SEARCH={self.enable_hybrid_search}")
+
+        self.use_fulltext_first_fusion_search = is_true('USE_FULLTEXT_FIRST_FUSION_SEARCH', 'true')
+        logger.info(f"USE_FULLTEXT_FIRST_FUSION_SEARCH={self.use_fulltext_first_fusion_search}")
 
     """
     Database operations
     """
@@ -478,6 +512,43 @@
                 return row[1]
         raise Exception(f"Variable '{var_name}' not found.")
 
+    def _check_table_exists_cached(self, table_name: str) -> bool:
+        """
+        Check table existence with a cache to reduce INFORMATION_SCHEMA queries under high concurrency.
+        Only positive results are cached: a table that exists is never dropped, so a cached "exists"
+        stays valid, while a missing table is re-checked on every call.
+        Thread-safe: reads are lock-free (set membership is atomic under the GIL);
+        writes are protected by an RLock to keep the cache consistent.
+
+        Args:
+            table_name: Table name
+
+        Returns:
+            Whether the table exists with all required indexes and columns
+        """
+        if table_name in self._table_exists_cache:
+            return True
+
+        try:
+            if not self.client.check_table_exists(table_name):
+                return False
+            for column_name in index_columns:
+                if not self._index_exists(table_name, index_name_template % (table_name, column_name)):
+                    return False
+            for fts_column in self.fulltext_search_columns:
+                column_name = fts_column.split("^")[0]
+                if not self._index_exists(table_name, fulltext_index_name_template % column_name):
+                    return False
+            for column in [column_order_id, column_group_id]:
+                if not self._column_exist(table_name, column.name):
+                    return False
+        except Exception as e:
+            raise Exception(f"OBConnection._check_table_exists_cached error: {str(e)}")
+
+        with self._table_exists_cache_lock:
+            if table_name not in self._table_exists_cache:
+                self._table_exists_cache.add(table_name)
+        return True
+
     """
     Table operations
     """
@@ -500,8 +571,7 @@
                 process_func=lambda: self._add_index(indexName, column_name),
             )
 
-        fts_columns = fts_columns_origin if self.search_original_content else fts_columns_tks
-        for fts_column in fts_columns:
+        for fts_column in self.fulltext_search_columns:
             column_name = fts_column.split("^")[0]
             _try_with_lock(
                 lock_name=f"ob_add_fulltext_idx_{indexName}_{column_name}",
@@ -546,24 +616,7 @@
             raise Exception(f"OBConnection.deleteIndex error: {str(e)}")
 
     def indexExist(self, indexName: str, knowledgebaseId: str = None) -> bool:
-        try:
-            if not self.client.check_table_exists(indexName):
-                return False
-            for column_name in index_columns:
-                if not self._index_exists(indexName, index_name_template % (indexName, column_name)):
-                    return False
-            fts_columns = fts_columns_origin if self.search_original_content else fts_columns_tks
-            for fts_column in fts_columns:
-                column_name = fts_column.split("^")[0]
-                if not self._index_exists(indexName, fulltext_index_name_template % column_name):
-                    return False
-            for column in [column_order_id, column_group_id]:
-                if not self._column_exist(indexName, column.name):
-                    return False
-        except Exception as e:
-            raise Exception(f"OBConnection.indexExist error: {str(e)}")
-
-        return True
+        return self._check_table_exists_cached(indexName)
 
     def _get_count(self, table_name: str, filter_list: list[str] = None) -> int:
         where_clause = "WHERE " + " AND ".join(filter_list) if len(filter_list) > 0 else ""
@@ -853,10 +906,8 @@
             fulltext_query = escape_string(fulltext_query.strip())
             fulltext_topn = m.topn
 
-            fts_columns = fts_columns_origin if self.search_original_content else fts_columns_tks
-
             # get fulltext match expression and weight values
-            for field in fts_columns:
+            for field in self.fulltext_search_columns:
                 parts = field.split("^")
                 column_name: str = parts[0]
                 column_weight: float = float(parts[1]) if (len(parts) > 1 and parts[1]) else 1.0
@@ -885,7 +936,8 @@
             fulltext_search_score_expr = f"({' + '.join(f'{expr} * {fulltext_search_weight.get(col, 0)}' for col, expr in fulltext_search_expr.items())})"
 
         if vector_data:
-            vector_search_expr = vector_search_template % (vector_column_name, vector_data)
+            vector_data_str = "[" + ",".join([str(np.float32(v)) for v in vector_data]) + "]"
+            vector_search_expr = vector_search_template % (vector_column_name, vector_data_str)
             # use (1 - cosine_distance) as score, which should be [-1, 1]
             # https://www.oceanbase.com/docs/common-oceanbase-database-standalone-1000000003577323
             vector_search_score_expr = f"(1 - {vector_search_expr})"
@@ -910,11 +962,15 @@
         if search_type in ["fusion", "fulltext", "vector"] and "_score" not in output_fields:
             output_fields.append("_score")
 
-        group_results = kwargs.get("group_results", False)
+        if limit:
+            if vector_topn is not None:
+                limit = min(vector_topn, limit)
+            if fulltext_topn is not None:
+                limit = min(fulltext_topn, limit)
 
         for index_name in indexNames:
-            if not self.client.check_table_exists(index_name):
+            if not self._check_table_exists_cached(index_name):
                 continue
 
             fulltext_search_hint = f"/*+ UNION_MERGE({index_name} {' '.join(fulltext_search_idx_list)}) */" if self.use_fulltext_hint else ""
@@ -922,29 +978,7 @@
             if search_type == "fusion":
                 # fusion search, usually for chat
                 num_candidates = vector_topn + fulltext_topn
-                if group_results:
-                    count_sql = (
-                        f"WITH fulltext_results AS ("
-                        f" SELECT {fulltext_search_hint} *, {fulltext_search_score_expr} AS relevance"
-                        f" FROM {index_name}"
-                        f" WHERE {filters_expr} AND {fulltext_search_filter}"
-                        f" ORDER BY relevance DESC"
-                        f" LIMIT {num_candidates}"
-                        f"),"
-                        f" scored_results AS ("
-                        f" SELECT *"
-                        f" FROM fulltext_results"
-                        f" WHERE {vector_search_filter}"
-                        f"),"
-                        f" group_results AS ("
-                        f" SELECT *, ROW_NUMBER() OVER (PARTITION BY group_id) as rn"
-                        f" FROM scored_results"
-                        f")"
-                        f" SELECT COUNT(*)"
-                        f" FROM group_results"
-                        f" WHERE rn = 1"
-                    )
-                else:
+                if self.use_fulltext_first_fusion_search:
                     count_sql = (
                         f"WITH fulltext_results AS ("
                         f" SELECT {fulltext_search_hint} *, {fulltext_search_score_expr} AS relevance"
@@ -955,6 +989,22 @@
                         f")"
                         f" SELECT COUNT(*) FROM fulltext_results WHERE {vector_search_filter}"
                     )
+                else:
+                    count_sql = (
+                        f"WITH fulltext_results AS ("
+                        f" SELECT {fulltext_search_hint} id FROM {index_name}"
+                        f" WHERE {filters_expr} AND {fulltext_search_filter}"
+                        f" ORDER BY {fulltext_search_score_expr} DESC"
+                        f" LIMIT {fulltext_topn}"
+                        f"),"
+                        f"vector_results AS ("
+                        f" SELECT id FROM {index_name}"
+                        f" WHERE {filters_expr} AND {vector_search_filter}"
+                        f" ORDER BY {vector_search_expr}"
+                        f" APPROXIMATE LIMIT {vector_topn}"
+                        f")"
+                        f" SELECT COUNT(*) FROM fulltext_results f FULL OUTER JOIN vector_results v ON f.id = v.id"
+                    )
 
                 logger.debug("OBConnection.search with count sql: %s", count_sql)
                 start_time = time.time()
@@ -976,32 +1026,8 @@
                 if total_count == 0:
                     continue
 
-                score_expr = f"(relevance * {1 - vector_similarity_weight} + {vector_search_score_expr} * {vector_similarity_weight} + {pagerank_score_expr})"
-                if group_results:
-                    fusion_sql = (
-                        f"WITH fulltext_results AS ("
-                        f" SELECT {fulltext_search_hint} *, {fulltext_search_score_expr} AS relevance"
-                        f" FROM {index_name}"
-                        f" WHERE {filters_expr} AND {fulltext_search_filter}"
-                        f" ORDER BY relevance DESC"
-                        f" LIMIT {num_candidates}"
-                        f"),"
-                        f" scored_results AS ("
-                        f" SELECT *, {score_expr} AS _score"
-                        f" FROM fulltext_results"
-                        f" WHERE {vector_search_filter}"
-                        f"),"
-                        f" group_results AS ("
-                        f" SELECT *, ROW_NUMBER() OVER (PARTITION BY group_id ORDER BY _score DESC) as rn"
-                        f" FROM scored_results"
-                        f")"
-                        f" SELECT {fields_expr}, _score"
-                        f" FROM group_results"
-                        f" WHERE rn = 1"
-                        f" ORDER BY _score DESC"
-                        f" LIMIT {offset}, {limit}"
-                    )
-                else:
+                if self.use_fulltext_first_fusion_search:
+                    score_expr = f"(relevance * {1 - vector_similarity_weight} + {vector_search_score_expr} * {vector_similarity_weight} + {pagerank_score_expr})"
                     fusion_sql = (
                         f"WITH fulltext_results AS ("
                         f" SELECT {fulltext_search_hint} *, {fulltext_search_score_expr} AS relevance"
@@ -1016,6 +1042,38 @@
                         f" ORDER BY _score DESC"
                         f" LIMIT {offset}, {limit}"
                     )
+                else:
+                    pagerank_score_expr = f"(CAST(COALESCE(f.{PAGERANK_FLD}, v.{PAGERANK_FLD}, 0) AS DECIMAL(10, 2)) / 100)"
+                    score_expr = f"(IFNULL(f.relevance, 0) * {1 - vector_similarity_weight} + IFNULL(v.similarity, 0) * {vector_similarity_weight} + {pagerank_score_expr})"
+                    fields_expr = ", ".join([f"t.{f} as {f}" for f in output_fields if f != "_score"])
+                    fusion_sql = (
+                        f"WITH fulltext_results AS ("
+                        f" SELECT {fulltext_search_hint} id, {PAGERANK_FLD}, {fulltext_search_score_expr} AS relevance"
+                        f" FROM {index_name}"
+                        f" WHERE {filters_expr} AND {fulltext_search_filter}"
+                        f" ORDER BY relevance DESC"
+                        f" LIMIT {fulltext_topn}"
+                        f"),"
+                        f"vector_results AS ("
+                        f" SELECT id, {PAGERANK_FLD}, {vector_search_score_expr} AS similarity"
+                        f" FROM {index_name}"
+                        f" WHERE {filters_expr} AND {vector_search_filter}"
+                        f" ORDER BY {vector_search_expr}"
+                        f" APPROXIMATE LIMIT {vector_topn}"
+                        f"),"
+                        f"combined_results AS ("
+                        f" SELECT COALESCE(f.id, v.id) AS id, {score_expr} AS score"
+                        f" FROM fulltext_results f"
+                        f" FULL OUTER JOIN vector_results v"
+                        f" ON f.id = v.id"
+                        f")"
+                        f" SELECT {fields_expr}, c.score as _score"
+                        f" FROM combined_results c"
+                        f" JOIN {index_name} t"
+                        f" ON c.id = t.id"
+                        f" ORDER BY score DESC"
+                        f" LIMIT {offset}, {limit}"
+                    )
 
                 logger.debug("OBConnection.search with fusion sql: %s", fusion_sql)
                 start_time = time.time()
@@ -1234,10 +1292,14 @@
 
         for row in rows:
             result.chunks.append(self._row_to_entity(row, output_fields))
+
+        if result.total == 0:
+            result.total = len(result.chunks)
+
         return result
 
     def get(self, chunkId: str, indexName: str, knowledgebaseIds: list[str]) -> dict | None:
-        if not self.client.check_table_exists(indexName):
+        if not self._check_table_exists_cached(indexName):
             return None
 
         try:
@@ -1336,7 +1398,7 @@
         return res
 
     def update(self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str) -> bool:
-        if not self.client.check_table_exists(indexName):
+        if not self._check_table_exists_cached(indexName):
             return True
 
         condition["kb_id"] = knowledgebaseId
@@ -1387,7 +1449,7 @@
             return False
 
     def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int:
-        if not self.client.check_table_exists(indexName):
+        if not self._check_table_exists_cached(indexName):
            return 0

         condition["kb_id"] = knowledgebaseId
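
To make change 1 concrete, here is a minimal, self-contained sketch of the vector-literal serialization the patch performs inline in `search`. The helper name `to_vector_literal` and the column name `q_1024_vec` are illustrative, not part of the patch:

```python
import numpy as np

def to_vector_literal(vector_data: list[float]) -> str:
    # Render the query vector as one string literal such as '[0.1,0.2,0.3]',
    # so the server parses a single literal instead of a large bound array.
    # np.float32 matches the stored VECTOR precision and keeps the literal compact.
    return "[" + ",".join(str(np.float32(v)) for v in vector_data) + "]"

# cosine_distance: https://www.oceanbase.com/docs/common-oceanbase-database-cn-1000000002012938
vector_search_template = "cosine_distance(%s, '%s')"

# 'q_1024_vec' is a made-up column name for illustration.
print(vector_search_template % ("q_1024_vec", to_vector_literal([0.1, 0.2, 0.3])))
# -> cosine_distance(q_1024_vec, '[0.1,0.2,0.3]')
```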
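For change 2, working the defaults through: with `max_connections = 300` from the config fallback, the pool settings resolve as below. The values mirror the patch; the arithmetic is just the defaults spelled out:

```python
import os

max_connections = 300  # pool_size passed to ObVecClient / HybridSearch
# Burst headroom above pool_size: half of max_connections, but at least 10.
max_overflow = int(os.environ.get("OB_MAX_OVERFLOW", max(max_connections // 2, 10)))  # -> 150
# Seconds a checkout waits for a free connection before raising.
pool_timeout = int(os.environ.get("OB_POOL_TIMEOUT", "30"))  # -> 30

# Worst case each client may hold pool_size + max_overflow = 450 open
# connections before a checkout has to wait (and, after pool_timeout
# seconds, fail).
print(max_connections, max_overflow, pool_timeout)  # 300 150 30
```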
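Change 4's caching pattern, extracted into a standalone sketch. The class and parameter names here are illustrative; `check_fn` stands in for the expensive INFORMATION_SCHEMA probe that `_check_table_exists_cached` performs:

```python
import threading

class PositiveExistenceCache:
    """Cache only positive existence results: an existing table is never
    dropped, so 'exists' stays true forever, while 'missing' must be
    re-checked on every call."""

    def __init__(self, check_fn):
        self._check_fn = check_fn       # expensive existence probe
        self._cache: set[str] = set()
        self._lock = threading.RLock()  # guards writers; readers stay lock-free

    def exists(self, name: str) -> bool:
        if name in self._cache:         # fast path: set membership is atomic under the GIL
            return True
        if not self._check_fn(name):    # negative result: do not cache
            return False
        with self._lock:                # slow path: record the positive result once
            self._cache.add(name)
        return True
```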