feat: enhance OBConnection.search (#11876)

### What problem does this PR solve?

Enhance OBConnection.search for better performance. Main changes:
1. Pass the vector array to the distance function as a string literal, which
parses faster (see the sketch just below this list).
2. Set the connection pool size to max_connections explicitly instead of
relying on the default.
3. Resolve 'fulltext_search_columns' once at startup instead of on every call.
4. Cache positive results of the table-existence check (tables are never
dropped).
5. Remove the unused 'group_results' logic.
6. Add the `USE_FULLTEXT_FIRST_FUSION_SEARCH` flag, plus an alternative
fusion-search SQL path (a FULL OUTER JOIN of the full-text and vector
candidates) used when it is false.

### Type of change
- [x] Performance Improvement
He Wang
2025-12-10 19:13:37 +08:00
committed by GitHub
parent 3cb72377d7
commit badf33e3b9


@@ -17,13 +17,16 @@ import json
import logging
import os
import re
import threading
import time
from typing import Any, Optional
import numpy as np
from elasticsearch_dsl import Q, Search
from pydantic import BaseModel
from pymysql.converters import escape_string
from pyobvector import ObVecClient, FtsIndexParam, FtsParser, ARRAY, VECTOR
from pyobvector.client import ClusterVersionException
from pyobvector.client.hybrid_search import HybridSearch
from pyobvector.util import ObVersion
from sqlalchemy import text, Column, String, Integer, JSON, Double, Row, Table
@@ -106,17 +109,6 @@ index_columns: list[str] = [
"removed_kwd",
]
fulltext_search_columns: list[str] = [
"docnm_kwd",
"content_with_weight",
"title_tks",
"title_sm_tks",
"important_tks",
"question_tks",
"content_ltks",
"content_sm_ltks"
]
fts_columns_origin: list[str] = [
"docnm_kwd^10",
"content_with_weight",
@@ -138,7 +130,7 @@ fulltext_index_name_template = "fts_idx_%s"
# MATCH AGAINST: https://www.oceanbase.com/docs/common-oceanbase-database-cn-1000000002017607
fulltext_search_template = "MATCH (%s) AGAINST ('%s' IN NATURAL LANGUAGE MODE)"
# cosine_distance: https://www.oceanbase.com/docs/common-oceanbase-database-cn-1000000002012938
vector_search_template = "cosine_distance(%s, %s)"
vector_search_template = "cosine_distance(%s, '%s')"
class SearchResult(BaseModel):
@@ -362,18 +354,28 @@ class OBConnection(DocStoreConnection):
port = mysql_config.get("port", 2881)
self.username = mysql_config.get("user", "root@test")
self.password = mysql_config.get("password", "infini_rag_flow")
max_connections = mysql_config.get("max_connections", 300)
else:
logger.info("Use customized config to create OceanBase connection.")
host = ob_config.get("host", "localhost")
port = ob_config.get("port", 2881)
self.username = ob_config.get("user", "root@test")
self.password = ob_config.get("password", "infini_rag_flow")
max_connections = ob_config.get("max_connections", 300)
self.db_name = ob_config.get("db_name", "test")
self.uri = f"{host}:{port}"
logger.info(f"Use OceanBase '{self.uri}' as the doc engine.")
# Set the maximum number of connections that can be created above the pool_size.
# By default, this is half of max_connections, but at least 10.
# This allows the pool to handle temporary spikes in demand without exhausting resources.
max_overflow = int(os.environ.get("OB_MAX_OVERFLOW", max(max_connections // 2, 10)))
# Set the number of seconds to wait before giving up when trying to get a connection from the pool.
# Default is 30 seconds, but can be overridden with the OB_POOL_TIMEOUT environment variable.
pool_timeout = int(os.environ.get("OB_POOL_TIMEOUT", "30"))
for _ in range(ATTEMPT_TIME):
try:
self.client = ObVecClient(
@@ -383,6 +385,9 @@ class OBConnection(DocStoreConnection):
db_name=self.db_name,
pool_pre_ping=True,
pool_recycle=3600,
pool_size=max_connections,
max_overflow=max_overflow,
pool_timeout=pool_timeout,
)
break
except Exception as e:
@@ -398,6 +403,37 @@ class OBConnection(DocStoreConnection):
self._check_ob_version()
self._try_to_update_ob_query_timeout()
self.es = None
if self.enable_hybrid_search:
try:
self.es = HybridSearch(
uri=self.uri,
user=self.username,
password=self.password,
db_name=self.db_name,
pool_pre_ping=True,
pool_recycle=3600,
pool_size=max_connections,
max_overflow=max_overflow,
pool_timeout=pool_timeout,
)
logger.info("OceanBase Hybrid Search feature is enabled")
except ClusterVersionException as e:
logger.info("Failed to initialize HybridSearch client, fallback to use SQL", exc_info=e)
self.es = None
if self.es is not None and self.search_original_content:
logger.info("HybridSearch is enabled, forcing search_original_content to False")
self.search_original_content = False
# Decide which columns to use for full-text search:
# when HybridSearch is enabled (self.es is not None), search_original_content has been
# forced to False above, so the tokenized columns (fts_columns_tks) are used; otherwise
# the SEARCH_ORIGINAL_CONTENT environment variable controls the choice between the
# original-content columns (fts_columns_origin) and the tokenized ones.
self.fulltext_search_columns = fts_columns_origin if self.search_original_content else fts_columns_tks
self._table_exists_cache: set[str] = set()
self._table_exists_cache_lock = threading.RLock()
logger.info(f"OceanBase {self.uri} is healthy.")
def _check_ob_version(self):
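
As a side note on the pool settings above, here is a hedged sketch of how SQLAlchemy (which `ObVecClient` passes these arguments through to) interprets them; the DSN and numbers are made up:

```python
import os
from sqlalchemy import create_engine

max_connections = 300
# Same defaults as above: overflow is half the pool size, but at least 10.
max_overflow = int(os.environ.get("OB_MAX_OVERFLOW", max(max_connections // 2, 10)))
pool_timeout = int(os.environ.get("OB_POOL_TIMEOUT", "30"))

engine = create_engine(
    "mysql+pymysql://user:pwd@localhost:2881/test",  # hypothetical DSN
    pool_pre_ping=True,          # validate a connection before handing it out
    pool_recycle=3600,           # replace connections older than an hour
    pool_size=max_connections,   # connections kept open at steady state
    max_overflow=max_overflow,   # extra connections allowed under bursts
    pool_timeout=pool_timeout,   # seconds to wait for a free connection
)
# The hard cap is pool_size + max_overflow; a request beyond that waits up to
# pool_timeout seconds and then raises sqlalchemy.exc.TimeoutError.
```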
@@ -417,18 +453,6 @@ class OBConnection(DocStoreConnection):
f"The version of OceanBase needs to be higher than or equal to 4.3.5.1, current version is {version_str}"
)
self.es = None
if not ob_version < ObVersion.from_db_version_nums(4, 4, 1, 0) and self.enable_hybrid_search:
self.es = HybridSearch(
uri=self.uri,
user=self.username,
password=self.password,
db_name=self.db_name,
pool_pre_ping=True,
pool_recycle=3600,
)
logger.info("OceanBase Hybrid Search feature is enabled")
def _try_to_update_ob_query_timeout(self):
try:
val = self._get_variable_value("ob_query_timeout")
@@ -455,9 +479,19 @@ class OBConnection(DocStoreConnection):
return os.getenv(var, default).lower() in ['true', '1', 'yes', 'y']
self.enable_fulltext_search = is_true('ENABLE_FULLTEXT_SEARCH', 'true')
logger.info(f"ENABLE_FULLTEXT_SEARCH={self.enable_fulltext_search}")
self.use_fulltext_hint = is_true('USE_FULLTEXT_HINT', 'true')
logger.info(f"USE_FULLTEXT_HINT={self.use_fulltext_hint}")
self.search_original_content = is_true("SEARCH_ORIGINAL_CONTENT", 'true')
logger.info(f"SEARCH_ORIGINAL_CONTENT={self.search_original_content}")
self.enable_hybrid_search = is_true('ENABLE_HYBRID_SEARCH', 'false')
logger.info(f"ENABLE_HYBRID_SEARCH={self.enable_hybrid_search}")
self.use_fulltext_first_fusion_search = is_true('USE_FULLTEXT_FIRST_FUSION_SEARCH', 'true')
logger.info(f"USE_FULLTEXT_FIRST_FUSION_SEARCH={self.use_fulltext_first_fusion_search}")
"""
Database operations
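
For illustration, these flags are read from the environment at startup; `is_true` accepts 'true', '1', 'yes', or 'y', case-insensitive:

```python
import os

# Illustrative configuration: keep plain SQL search and switch fusion to the
# FULL OUTER JOIN plan introduced by this PR.
os.environ["ENABLE_HYBRID_SEARCH"] = "false"
os.environ["USE_FULLTEXT_FIRST_FUSION_SEARCH"] = "false"
os.environ["SEARCH_ORIGINAL_CONTENT"] = "yes"  # any of true/1/yes/y is truthy
```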
@@ -478,6 +512,43 @@ class OBConnection(DocStoreConnection):
return row[1]
raise Exception(f"Variable '{var_name}' not found.")
def _check_table_exists_cached(self, table_name: str) -> bool:
"""
Check table existence with a cache to reduce INFORMATION_SCHEMA queries under high concurrency.
Only positive results are cached; a missing table is re-checked on every call, since it may be created later.
Thread-safe: reads are lock-free (set membership is atomic under the GIL);
writes are guarded by an RLock to keep the cache consistent.
Args:
table_name: Table name
Returns:
Whether the table exists with all required indexes and columns
"""
if table_name in self._table_exists_cache:
return True
try:
if not self.client.check_table_exists(table_name):
return False
for column_name in index_columns:
if not self._index_exists(table_name, index_name_template % (table_name, column_name)):
return False
for fts_column in self.fulltext_search_columns:
column_name = fts_column.split("^")[0]
if not self._index_exists(table_name, fulltext_index_name_template % column_name):
return False
for column in [column_order_id, column_group_id]:
if not self._column_exist(table_name, column.name):
return False
except Exception as e:
raise Exception(f"OBConnection._check_table_exists_cached error: {str(e)}")
with self._table_exists_cache_lock:
if table_name not in self._table_exists_cache:
self._table_exists_cache.add(table_name)
return True
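
The caching pattern above, reduced to a standalone sketch (the names here are illustrative, not the actual class):

```python
import threading

class ExistsCache:
    """Positive-only, double-checked cache around an expensive existence probe."""

    def __init__(self, probe):
        self._probe = probe             # e.g. an INFORMATION_SCHEMA lookup
        self._known: set[str] = set()   # only names known to exist
        self._lock = threading.RLock()

    def exists(self, name: str) -> bool:
        if name in self._known:      # lock-free fast path (atomic under the GIL)
            return True
        if not self._probe(name):    # never cache a miss: the table may appear later
            return False
        with self._lock:             # double-checked insert keeps writers consistent
            if name not in self._known:
                self._known.add(name)
        return True
```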
"""
Table operations
"""
@@ -500,8 +571,7 @@ class OBConnection(DocStoreConnection):
process_func=lambda: self._add_index(indexName, column_name),
)
fts_columns = fts_columns_origin if self.search_original_content else fts_columns_tks
for fts_column in fts_columns:
for fts_column in self.fulltext_search_columns:
column_name = fts_column.split("^")[0]
_try_with_lock(
lock_name=f"ob_add_fulltext_idx_{indexName}_{column_name}",
@@ -546,24 +616,7 @@ class OBConnection(DocStoreConnection):
raise Exception(f"OBConnection.deleteIndex error: {str(e)}")
def indexExist(self, indexName: str, knowledgebaseId: str = None) -> bool:
try:
if not self.client.check_table_exists(indexName):
return False
for column_name in index_columns:
if not self._index_exists(indexName, index_name_template % (indexName, column_name)):
return False
fts_columns = fts_columns_origin if self.search_original_content else fts_columns_tks
for fts_column in fts_columns:
column_name = fts_column.split("^")[0]
if not self._index_exists(indexName, fulltext_index_name_template % column_name):
return False
for column in [column_order_id, column_group_id]:
if not self._column_exist(indexName, column.name):
return False
except Exception as e:
raise Exception(f"OBConnection.indexExist error: {str(e)}")
return True
return self._check_table_exists_cached(indexName)
def _get_count(self, table_name: str, filter_list: list[str] = None) -> int:
where_clause = "WHERE " + " AND ".join(filter_list) if len(filter_list) > 0 else ""
@@ -853,10 +906,8 @@ class OBConnection(DocStoreConnection):
fulltext_query = escape_string(fulltext_query.strip())
fulltext_topn = m.topn
fts_columns = fts_columns_origin if self.search_original_content else fts_columns_tks
# get fulltext match expression and weight values
for field in fts_columns:
for field in self.fulltext_search_columns:
parts = field.split("^")
column_name: str = parts[0]
column_weight: float = float(parts[1]) if (len(parts) > 1 and parts[1]) else 1.0
@@ -885,7 +936,8 @@ class OBConnection(DocStoreConnection):
fulltext_search_score_expr = f"({' + '.join(f'{expr} * {fulltext_search_weight.get(col, 0)}' for col, expr in fulltext_search_expr.items())})"
if vector_data:
vector_search_expr = vector_search_template % (vector_column_name, vector_data)
vector_data_str = "[" + ",".join([str(np.float32(v)) for v in vector_data]) + "]"
vector_search_expr = vector_search_template % (vector_column_name, vector_data_str)
# use (1 - cosine_distance) as the score: cosine_distance is in [0, 2], so the score is in [-1, 1]
# https://www.oceanbase.com/docs/common-oceanbase-database-standalone-1000000003577323
vector_search_score_expr = f"(1 - {vector_search_expr})"
@@ -910,11 +962,15 @@ class OBConnection(DocStoreConnection):
if search_type in ["fusion", "fulltext", "vector"] and "_score" not in output_fields:
output_fields.append("_score")
group_results = kwargs.get("group_results", False)
if limit:
if vector_topn is not None:
limit = min(vector_topn, limit)
if fulltext_topn is not None:
limit = min(fulltext_topn, limit)
for index_name in indexNames:
if not self.client.check_table_exists(index_name):
if not self._check_table_exists_cached(index_name):
continue
fulltext_search_hint = f"/*+ UNION_MERGE({index_name} {' '.join(fulltext_search_idx_list)}) */" if self.use_fulltext_hint else ""
@@ -922,29 +978,7 @@ class OBConnection(DocStoreConnection):
if search_type == "fusion":
# fusion search, usually for chat
num_candidates = vector_topn + fulltext_topn
if group_results:
count_sql = (
f"WITH fulltext_results AS ("
f" SELECT {fulltext_search_hint} *, {fulltext_search_score_expr} AS relevance"
f" FROM {index_name}"
f" WHERE {filters_expr} AND {fulltext_search_filter}"
f" ORDER BY relevance DESC"
f" LIMIT {num_candidates}"
f"),"
f" scored_results AS ("
f" SELECT *"
f" FROM fulltext_results"
f" WHERE {vector_search_filter}"
f"),"
f" group_results AS ("
f" SELECT *, ROW_NUMBER() OVER (PARTITION BY group_id) as rn"
f" FROM scored_results"
f")"
f" SELECT COUNT(*)"
f" FROM group_results"
f" WHERE rn = 1"
)
else:
if self.use_fulltext_first_fusion_search:
count_sql = (
f"WITH fulltext_results AS ("
f" SELECT {fulltext_search_hint} *, {fulltext_search_score_expr} AS relevance"
@@ -955,6 +989,22 @@ class OBConnection(DocStoreConnection):
f")"
f" SELECT COUNT(*) FROM fulltext_results WHERE {vector_search_filter}"
)
else:
count_sql = (
f"WITH fulltext_results AS ("
f" SELECT {fulltext_search_hint} id FROM {index_name}"
f" WHERE {filters_expr} AND {fulltext_search_filter}"
f" ORDER BY {fulltext_search_score_expr}"
f" LIMIT {fulltext_topn}"
f"),"
f"vector_results AS ("
f" SELECT id FROM {index_name}"
f" WHERE {filters_expr} AND {vector_search_filter}"
f" ORDER BY {vector_search_expr}"
f" APPROXIMATE LIMIT {vector_topn}"
f")"
f" SELECT COUNT(*) FROM fulltext_results f FULL OUTER JOIN vector_results v ON f.id = v.id"
)
logger.debug("OBConnection.search with count sql: %s", count_sql)
start_time = time.time()
@@ -976,32 +1026,8 @@ class OBConnection(DocStoreConnection):
if total_count == 0:
continue
if self.use_fulltext_first_fusion_search:
score_expr = f"(relevance * {1 - vector_similarity_weight} + {vector_search_score_expr} * {vector_similarity_weight} + {pagerank_score_expr})"
if group_results:
fusion_sql = (
f"WITH fulltext_results AS ("
f" SELECT {fulltext_search_hint} *, {fulltext_search_score_expr} AS relevance"
f" FROM {index_name}"
f" WHERE {filters_expr} AND {fulltext_search_filter}"
f" ORDER BY relevance DESC"
f" LIMIT {num_candidates}"
f"),"
f" scored_results AS ("
f" SELECT *, {score_expr} AS _score"
f" FROM fulltext_results"
f" WHERE {vector_search_filter}"
f"),"
f" group_results AS ("
f" SELECT *, ROW_NUMBER() OVER (PARTITION BY group_id ORDER BY _score DESC) as rn"
f" FROM scored_results"
f")"
f" SELECT {fields_expr}, _score"
f" FROM group_results"
f" WHERE rn = 1"
f" ORDER BY _score DESC"
f" LIMIT {offset}, {limit}"
)
else:
fusion_sql = (
f"WITH fulltext_results AS ("
f" SELECT {fulltext_search_hint} *, {fulltext_search_score_expr} AS relevance"
@@ -1016,6 +1042,38 @@ class OBConnection(DocStoreConnection):
f" ORDER BY _score DESC"
f" LIMIT {offset}, {limit}"
)
else:
pagerank_score_expr = f"(CAST(IFNULL(f.{PAGERANK_FLD}, 0) AS DECIMAL(10, 2)) / 100)"
score_expr = f"(f.relevance * {1 - vector_similarity_weight} + v.similarity * {vector_similarity_weight} + {pagerank_score_expr})"
fields_expr = ", ".join([f"t.{f} as {f}" for f in output_fields if f != "_score"])
fusion_sql = (
f"WITH fulltext_results AS ("
f" SELECT {fulltext_search_hint} id, pagerank_fea, {fulltext_search_score_expr} AS relevance"
f" FROM {index_name}"
f" WHERE {filters_expr} AND {fulltext_search_filter}"
f" ORDER BY relevance DESC"
f" LIMIT {fulltext_topn}"
f"),"
f"vector_results AS ("
f" SELECT id, pagerank_fea, {vector_search_score_expr} AS similarity"
f" FROM {index_name}"
f" WHERE {filters_expr} AND {vector_search_filter}"
f" ORDER BY {vector_search_expr}"
f" APPROXIMATE LIMIT {vector_topn}"
f"),"
f"combined_results AS ("
f" SELECT COALESCE(f.id, v.id) AS id, {score_expr} AS score"
f" FROM fulltext_results f"
f" FULL OUTER JOIN vector_results v"
f" ON f.id = v.id"
f")"
f" SELECT {fields_expr}, c.score as _score"
f" FROM combined_results c"
f" JOIN {index_name} t"
f" ON c.id = t.id"
f" ORDER BY score DESC"
f" LIMIT {offset}, {limit}"
)
logger.debug("OBConnection.search with fusion sql: %s", fusion_sql)
start_time = time.time()
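
A worked example of the combined score produced by the join above (the weights and per-arm scores are made up):

```python
# One row matched by both arms of the FULL OUTER JOIN:
vector_similarity_weight = 0.7
relevance = 8.5      # f.relevance from the fulltext arm
similarity = 0.62    # v.similarity = 1 - cosine_distance from the vector arm
pagerank = 10 / 100  # CAST(IFNULL(f.pagerank_fea, 0) AS DECIMAL(10, 2)) / 100

score = (relevance * (1 - vector_similarity_weight)
         + similarity * vector_similarity_weight
         + pagerank)
# = 8.5 * 0.3 + 0.62 * 0.7 + 0.1 = 3.084
```

Note that in standard SQL a row matched by only one arm leaves the other side of the FULL OUTER JOIN NULL, and NULL propagates through arithmetic: `COALESCE(f.id, v.id)` keeps the row's id, but the missing relevance or similarity term would make the score NULL unless it is defaulted (e.g. with IFNULL).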
@@ -1234,10 +1292,14 @@ class OBConnection(DocStoreConnection):
for row in rows:
result.chunks.append(self._row_to_entity(row, output_fields))
if result.total == 0:
result.total = len(result.chunks)
return result
def get(self, chunkId: str, indexName: str, knowledgebaseIds: list[str]) -> dict | None:
if not self.client.check_table_exists(indexName):
if not self._check_table_exists_cached(indexName):
return None
try:
@@ -1336,7 +1398,7 @@ class OBConnection(DocStoreConnection):
return res
def update(self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str) -> bool:
if not self.client.check_table_exists(indexName):
if not self._check_table_exists_cached(indexName):
return True
condition["kb_id"] = knowledgebaseId
@@ -1387,7 +1449,7 @@ class OBConnection(DocStoreConnection):
return False
def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int:
if not self.client.check_table_exists(indexName):
if not self._check_table_exists_cached(indexName):
return 0
condition["kb_id"] = knowledgebaseId