feat: enhance OBConnection.search (#11876)

### What problem does this PR solve?

Enhance `OBConnection.search` for better performance. Main changes:
1. Pass the vector array to the distance function as a string literal, which parses faster (see the sketch after this list).
2. Set the connection pool size from `max_connections` explicitly instead of relying on the default (see the pool note in the diff below).
3. Resolve `fulltext_search_columns` once at startup.
4. Cache the results of the table-existence check; we never drop tables, so a positive result stays valid (a stripped-down sketch follows the new method in the diff).
5. Remove the unused `group_results` logic.
6. Add the `USE_FULLTEXT_FIRST_FUSION_SEARCH` flag, and a corresponding fusion-search SQL path for when it is false (a worked scoring example follows the new fusion SQL in the diff).
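
For change 1, a minimal sketch of the idea, mirroring the `vector_data_str` construction in the diff: the embedding is serialized into one `'[v1,v2,...]'` vector literal, so `cosine_distance` receives a single quoted string rather than raw interpolated values. The column name here is hypothetical.

```python
import numpy as np

# Template from the diff: the second placeholder is now quoted as a string literal.
vector_search_template = "cosine_distance(%s, '%s')"

def build_vector_expr(column_name: str, vector_data: list[float]) -> str:
    # Serialize the embedding as a single '[v1,v2,...]' literal; casting through
    # np.float32 keeps the textual form compact and consistent.
    vector_data_str = "[" + ",".join(str(np.float32(v)) for v in vector_data) + "]"
    return vector_search_template % (column_name, vector_data_str)

# 'q_1024_vec' is a hypothetical embedding column name.
print(build_vector_expr("q_1024_vec", [0.12, -0.5, 0.33]))
# -> cosine_distance(q_1024_vec, '[0.12,-0.5,0.33]')
```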

### Type of change
- [x] Performance Improvement
Author: He Wang
Date: 2025-12-10 19:13:37 +08:00
Committed by: GitHub
parent 3cb72377d7
commit badf33e3b9


@@ -17,13 +17,16 @@ import json
 import logging
 import os
 import re
+import threading
 import time
 from typing import Any, Optional
+import numpy as np
 from elasticsearch_dsl import Q, Search
 from pydantic import BaseModel
 from pymysql.converters import escape_string
 from pyobvector import ObVecClient, FtsIndexParam, FtsParser, ARRAY, VECTOR
+from pyobvector.client import ClusterVersionException
 from pyobvector.client.hybrid_search import HybridSearch
 from pyobvector.util import ObVersion
 from sqlalchemy import text, Column, String, Integer, JSON, Double, Row, Table
@@ -106,17 +109,6 @@ index_columns: list[str] = [
     "removed_kwd",
 ]
-fulltext_search_columns: list[str] = [
-    "docnm_kwd",
-    "content_with_weight",
-    "title_tks",
-    "title_sm_tks",
-    "important_tks",
-    "question_tks",
-    "content_ltks",
-    "content_sm_ltks"
-]
 fts_columns_origin: list[str] = [
     "docnm_kwd^10",
     "content_with_weight",
@@ -138,7 +130,7 @@ fulltext_index_name_template = "fts_idx_%s"
 # MATCH AGAINST: https://www.oceanbase.com/docs/common-oceanbase-database-cn-1000000002017607
 fulltext_search_template = "MATCH (%s) AGAINST ('%s' IN NATURAL LANGUAGE MODE)"
 # cosine_distance: https://www.oceanbase.com/docs/common-oceanbase-database-cn-1000000002012938
-vector_search_template = "cosine_distance(%s, %s)"
+vector_search_template = "cosine_distance(%s, '%s')"

 class SearchResult(BaseModel):
@@ -362,18 +354,28 @@ class OBConnection(DocStoreConnection):
             port = mysql_config.get("port", 2881)
             self.username = mysql_config.get("user", "root@test")
             self.password = mysql_config.get("password", "infini_rag_flow")
+            max_connections = mysql_config.get("max_connections", 300)
         else:
             logger.info("Use customized config to create OceanBase connection.")
             host = ob_config.get("host", "localhost")
             port = ob_config.get("port", 2881)
             self.username = ob_config.get("user", "root@test")
             self.password = ob_config.get("password", "infini_rag_flow")
+            max_connections = ob_config.get("max_connections", 300)
         self.db_name = ob_config.get("db_name", "test")
         self.uri = f"{host}:{port}"
         logger.info(f"Use OceanBase '{self.uri}' as the doc engine.")
+
+        # Set the maximum number of connections that can be created above the pool_size.
+        # By default, this is half of max_connections, but at least 10.
+        # This allows the pool to handle temporary spikes in demand without exhausting resources.
+        max_overflow = int(os.environ.get("OB_MAX_OVERFLOW", max(max_connections // 2, 10)))
+        # Set the number of seconds to wait before giving up when trying to get a connection from the pool.
+        # Default is 30 seconds, but can be overridden with the OB_POOL_TIMEOUT environment variable.
+        pool_timeout = int(os.environ.get("OB_POOL_TIMEOUT", "30"))
+
         for _ in range(ATTEMPT_TIME):
             try:
                 self.client = ObVecClient(
@@ -383,6 +385,9 @@ class OBConnection(DocStoreConnection):
                     db_name=self.db_name,
                     pool_pre_ping=True,
                     pool_recycle=3600,
+                    pool_size=max_connections,
+                    max_overflow=max_overflow,
+                    pool_timeout=pool_timeout,
                 )
                 break
             except Exception as e:
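
A note on the new pool settings (change 2): with SQLAlchemy's default QueuePool, which backs these clients, `pool_size` caps the persistent connections, `max_overflow` allows short-lived extras under bursts, and `pool_timeout` bounds how long a checkout waits. A minimal sketch with plain SQLAlchemy; the DSN and credentials are placeholders (OceanBase speaks the MySQL protocol):

```python
from sqlalchemy import create_engine, text

engine = create_engine(
    "mysql+pymysql://user:password@localhost:2881/test",  # placeholder DSN
    pool_pre_ping=True,  # validate a connection before handing it out
    pool_recycle=3600,   # recycle connections older than one hour
    pool_size=300,       # persistent connections kept in the pool
    max_overflow=150,    # temporary extra connections allowed under load
    pool_timeout=30,     # seconds to wait for a free connection before erroring
)

with engine.connect() as conn:
    conn.execute(text("SELECT 1"))
```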
@@ -398,6 +403,37 @@ class OBConnection(DocStoreConnection):
         self._check_ob_version()
         self._try_to_update_ob_query_timeout()

+        self.es = None
+        if self.enable_hybrid_search:
+            try:
+                self.es = HybridSearch(
+                    uri=self.uri,
+                    user=self.username,
+                    password=self.password,
+                    db_name=self.db_name,
+                    pool_pre_ping=True,
+                    pool_recycle=3600,
+                    pool_size=max_connections,
+                    max_overflow=max_overflow,
+                    pool_timeout=pool_timeout,
+                )
+                logger.info("OceanBase Hybrid Search feature is enabled")
+            except ClusterVersionException as e:
+                logger.info("Failed to initialize HybridSearch client, fallback to use SQL", exc_info=e)
+                self.es = None
+
+        if self.es is not None and self.search_original_content:
+            logger.info("HybridSearch is enabled, forcing search_original_content to False")
+            self.search_original_content = False
+
+        # Determine which columns to use for full-text search dynamically:
+        # If HybridSearch is enabled (self.es is not None), we must use tokenized columns (fts_columns_tks)
+        # for compatibility and performance with HybridSearch. Otherwise, we use the original content columns
+        # (fts_columns_origin), which may be controlled by an environment variable.
+        self.fulltext_search_columns = fts_columns_origin if self.search_original_content else fts_columns_tks
+
+        self._table_exists_cache: set[str] = set()
+        self._table_exists_cache_lock = threading.RLock()
+
         logger.info(f"OceanBase {self.uri} is healthy.")

     def _check_ob_version(self):
@@ -417,18 +453,6 @@ class OBConnection(DocStoreConnection):
                 f"The version of OceanBase needs to be higher than or equal to 4.3.5.1, current version is {version_str}"
             )

-        self.es = None
-        if not ob_version < ObVersion.from_db_version_nums(4, 4, 1, 0) and self.enable_hybrid_search:
-            self.es = HybridSearch(
-                uri=self.uri,
-                user=self.username,
-                password=self.password,
-                db_name=self.db_name,
-                pool_pre_ping=True,
-                pool_recycle=3600,
-            )
-            logger.info("OceanBase Hybrid Search feature is enabled")
-
     def _try_to_update_ob_query_timeout(self):
         try:
             val = self._get_variable_value("ob_query_timeout")
@@ -455,9 +479,19 @@ class OBConnection(DocStoreConnection):
             return os.getenv(var, default).lower() in ['true', '1', 'yes', 'y']

         self.enable_fulltext_search = is_true('ENABLE_FULLTEXT_SEARCH', 'true')
+        logger.info(f"ENABLE_FULLTEXT_SEARCH={self.enable_fulltext_search}")
+
         self.use_fulltext_hint = is_true('USE_FULLTEXT_HINT', 'true')
+        logger.info(f"USE_FULLTEXT_HINT={self.use_fulltext_hint}")
+
         self.search_original_content = is_true("SEARCH_ORIGINAL_CONTENT", 'true')
+        logger.info(f"SEARCH_ORIGINAL_CONTENT={self.search_original_content}")
+
         self.enable_hybrid_search = is_true('ENABLE_HYBRID_SEARCH', 'false')
+        logger.info(f"ENABLE_HYBRID_SEARCH={self.enable_hybrid_search}")
+
+        self.use_fulltext_first_fusion_search = is_true('USE_FULLTEXT_FIRST_FUSION_SEARCH', 'true')
+        logger.info(f"USE_FULLTEXT_FIRST_FUSION_SEARCH={self.use_fulltext_first_fusion_search}")

     """
     Database operations
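
All of these flags go through the file's `is_true` helper shown in the hunk above; any of `true`, `1`, `yes`, `y` (case-insensitive) enables a flag, and anything else disables it. A quick usage check:

```python
import os

def is_true(var: str, default: str) -> bool:
    # Same accepted spellings as the helper in the diff above.
    return os.getenv(var, default).lower() in ['true', '1', 'yes', 'y']

os.environ["USE_FULLTEXT_FIRST_FUSION_SEARCH"] = "no"
print(is_true('USE_FULLTEXT_FIRST_FUSION_SEARCH', 'true'))  # False: "no" is not a truthy spelling
```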
@@ -478,6 +512,43 @@ class OBConnection(DocStoreConnection):
                 return row[1]
         raise Exception(f"Variable '{var_name}' not found.")

+    def _check_table_exists_cached(self, table_name: str) -> bool:
+        """
+        Check table existence with cache to reduce INFORMATION_SCHEMA queries under high concurrency.
+        Only caches when table exists. Does not cache when table does not exist.
+        Thread-safe implementation: read operations are lock-free (GIL-protected),
+        write operations are protected by RLock to ensure cache consistency.
+
+        Args:
+            table_name: Table name
+
+        Returns:
+            Whether the table exists with all required indexes and columns
+        """
+        if table_name in self._table_exists_cache:
+            return True
+        try:
+            if not self.client.check_table_exists(table_name):
+                return False
+            for column_name in index_columns:
+                if not self._index_exists(table_name, index_name_template % (table_name, column_name)):
+                    return False
+            for fts_column in self.fulltext_search_columns:
+                column_name = fts_column.split("^")[0]
+                if not self._index_exists(table_name, fulltext_index_name_template % column_name):
+                    return False
+            for column in [column_order_id, column_group_id]:
+                if not self._column_exist(table_name, column.name):
+                    return False
+        except Exception as e:
+            raise Exception(f"OBConnection._check_table_exists_cached error: {str(e)}")
+        with self._table_exists_cache_lock:
+            if table_name not in self._table_exists_cache:
+                self._table_exists_cache.add(table_name)
+        return True
+
     """
     Table operations
     """
@@ -500,8 +571,7 @@ class OBConnection(DocStoreConnection):
                 process_func=lambda: self._add_index(indexName, column_name),
             )

-        fts_columns = fts_columns_origin if self.search_original_content else fts_columns_tks
-        for fts_column in fts_columns:
+        for fts_column in self.fulltext_search_columns:
             column_name = fts_column.split("^")[0]
             _try_with_lock(
                 lock_name=f"ob_add_fulltext_idx_{indexName}_{column_name}",
@@ -546,24 +616,7 @@ class OBConnection(DocStoreConnection):
             raise Exception(f"OBConnection.deleteIndex error: {str(e)}")

     def indexExist(self, indexName: str, knowledgebaseId: str = None) -> bool:
-        try:
-            if not self.client.check_table_exists(indexName):
-                return False
-            for column_name in index_columns:
-                if not self._index_exists(indexName, index_name_template % (indexName, column_name)):
-                    return False
-            fts_columns = fts_columns_origin if self.search_original_content else fts_columns_tks
-            for fts_column in fts_columns:
-                column_name = fts_column.split("^")[0]
-                if not self._index_exists(indexName, fulltext_index_name_template % column_name):
-                    return False
-            for column in [column_order_id, column_group_id]:
-                if not self._column_exist(indexName, column.name):
-                    return False
-        except Exception as e:
-            raise Exception(f"OBConnection.indexExist error: {str(e)}")
-        return True
+        return self._check_table_exists_cached(indexName)

     def _get_count(self, table_name: str, filter_list: list[str] = None) -> int:
         where_clause = "WHERE " + " AND ".join(filter_list) if len(filter_list) > 0 else ""
@@ -853,10 +906,8 @@ class OBConnection(DocStoreConnection):
             fulltext_query = escape_string(fulltext_query.strip())
             fulltext_topn = m.topn

-            fts_columns = fts_columns_origin if self.search_original_content else fts_columns_tks
-
             # get fulltext match expression and weight values
-            for field in fts_columns:
+            for field in self.fulltext_search_columns:
                 parts = field.split("^")
                 column_name: str = parts[0]
                 column_weight: float = float(parts[1]) if (len(parts) > 1 and parts[1]) else 1.0
@@ -885,7 +936,8 @@ class OBConnection(DocStoreConnection):
         fulltext_search_score_expr = f"({' + '.join(f'{expr} * {fulltext_search_weight.get(col, 0)}' for col, expr in fulltext_search_expr.items())})"

         if vector_data:
-            vector_search_expr = vector_search_template % (vector_column_name, vector_data)
+            vector_data_str = "[" + ",".join([str(np.float32(v)) for v in vector_data]) + "]"
+            vector_search_expr = vector_search_template % (vector_column_name, vector_data_str)
             # use (1 - cosine_distance) as score, which should be [-1, 1]
             # https://www.oceanbase.com/docs/common-oceanbase-database-standalone-1000000003577323
             vector_search_score_expr = f"(1 - {vector_search_expr})"
@@ -910,11 +962,15 @@ class OBConnection(DocStoreConnection):
         if search_type in ["fusion", "fulltext", "vector"] and "_score" not in output_fields:
             output_fields.append("_score")

-        group_results = kwargs.get("group_results", False)
+        if limit:
+            if vector_topn is not None:
+                limit = min(vector_topn, limit)
+            if fulltext_topn is not None:
+                limit = min(fulltext_topn, limit)

         for index_name in indexNames:
-            if not self.client.check_table_exists(index_name):
+            if not self._check_table_exists_cached(index_name):
                 continue

             fulltext_search_hint = f"/*+ UNION_MERGE({index_name} {' '.join(fulltext_search_idx_list)}) */" if self.use_fulltext_hint else ""
@@ -922,29 +978,7 @@ class OBConnection(DocStoreConnection):
             if search_type == "fusion":
                 # fusion search, usually for chat
                 num_candidates = vector_topn + fulltext_topn
-                if group_results:
-                    count_sql = (
-                        f"WITH fulltext_results AS ("
-                        f" SELECT {fulltext_search_hint} *, {fulltext_search_score_expr} AS relevance"
-                        f" FROM {index_name}"
-                        f" WHERE {filters_expr} AND {fulltext_search_filter}"
-                        f" ORDER BY relevance DESC"
-                        f" LIMIT {num_candidates}"
-                        f"),"
-                        f" scored_results AS ("
-                        f" SELECT *"
-                        f" FROM fulltext_results"
-                        f" WHERE {vector_search_filter}"
-                        f"),"
-                        f" group_results AS ("
-                        f" SELECT *, ROW_NUMBER() OVER (PARTITION BY group_id) as rn"
-                        f" FROM scored_results"
-                        f")"
-                        f" SELECT COUNT(*)"
-                        f" FROM group_results"
-                        f" WHERE rn = 1"
-                    )
-                else:
+                if self.use_fulltext_first_fusion_search:
                     count_sql = (
                         f"WITH fulltext_results AS ("
                         f" SELECT {fulltext_search_hint} *, {fulltext_search_score_expr} AS relevance"
@@ -955,6 +989,22 @@
                         f")"
                         f" SELECT COUNT(*) FROM fulltext_results WHERE {vector_search_filter}"
                     )
+                else:
+                    count_sql = (
+                        f"WITH fulltext_results AS ("
+                        f" SELECT {fulltext_search_hint} id FROM {index_name}"
+                        f" WHERE {filters_expr} AND {fulltext_search_filter}"
+                        f" ORDER BY {fulltext_search_score_expr}"
+                        f" LIMIT {fulltext_topn}"
+                        f"),"
+                        f"vector_results AS ("
+                        f" SELECT id FROM {index_name}"
+                        f" WHERE {filters_expr} AND {vector_search_filter}"
+                        f" ORDER BY {vector_search_expr}"
+                        f" APPROXIMATE LIMIT {vector_topn}"
+                        f")"
+                        f" SELECT COUNT(*) FROM fulltext_results f FULL OUTER JOIN vector_results v ON f.id = v.id"
+                    )

                 logger.debug("OBConnection.search with count sql: %s", count_sql)
                 start_time = time.time()
@@ -976,32 +1026,8 @@ class OBConnection(DocStoreConnection):
                 if total_count == 0:
                     continue

-                score_expr = f"(relevance * {1 - vector_similarity_weight} + {vector_search_score_expr} * {vector_similarity_weight} + {pagerank_score_expr})"
-                if group_results:
-                    fusion_sql = (
-                        f"WITH fulltext_results AS ("
-                        f" SELECT {fulltext_search_hint} *, {fulltext_search_score_expr} AS relevance"
-                        f" FROM {index_name}"
-                        f" WHERE {filters_expr} AND {fulltext_search_filter}"
-                        f" ORDER BY relevance DESC"
-                        f" LIMIT {num_candidates}"
-                        f"),"
-                        f" scored_results AS ("
-                        f" SELECT *, {score_expr} AS _score"
-                        f" FROM fulltext_results"
-                        f" WHERE {vector_search_filter}"
-                        f"),"
-                        f" group_results AS ("
-                        f" SELECT *, ROW_NUMBER() OVER (PARTITION BY group_id ORDER BY _score DESC) as rn"
-                        f" FROM scored_results"
-                        f")"
-                        f" SELECT {fields_expr}, _score"
-                        f" FROM group_results"
-                        f" WHERE rn = 1"
-                        f" ORDER BY _score DESC"
-                        f" LIMIT {offset}, {limit}"
-                    )
-                else:
+                if self.use_fulltext_first_fusion_search:
+                    score_expr = f"(relevance * {1 - vector_similarity_weight} + {vector_search_score_expr} * {vector_similarity_weight} + {pagerank_score_expr})"
                     fusion_sql = (
                         f"WITH fulltext_results AS ("
                         f" SELECT {fulltext_search_hint} *, {fulltext_search_score_expr} AS relevance"
@@ -1016,6 +1042,38 @@ class OBConnection(DocStoreConnection):
                         f" ORDER BY _score DESC"
                         f" LIMIT {offset}, {limit}"
                     )
+                else:
+                    pagerank_score_expr = f"(CAST(IFNULL(f.{PAGERANK_FLD}, 0) AS DECIMAL(10, 2)) / 100)"
+                    score_expr = f"(f.relevance * {1 - vector_similarity_weight} + v.similarity * {vector_similarity_weight} + {pagerank_score_expr})"
+                    fields_expr = ", ".join([f"t.{f} as {f}" for f in output_fields if f != "_score"])
+                    fusion_sql = (
+                        f"WITH fulltext_results AS ("
+                        f" SELECT {fulltext_search_hint} id, pagerank_fea, {fulltext_search_score_expr} AS relevance"
+                        f" FROM {index_name}"
+                        f" WHERE {filters_expr} AND {fulltext_search_filter}"
+                        f" ORDER BY relevance DESC"
+                        f" LIMIT {fulltext_topn}"
+                        f"),"
+                        f"vector_results AS ("
+                        f" SELECT id, pagerank_fea, {vector_search_score_expr} AS similarity"
+                        f" FROM {index_name}"
+                        f" WHERE {filters_expr} AND {vector_search_filter}"
+                        f" ORDER BY {vector_search_expr}"
+                        f" APPROXIMATE LIMIT {vector_topn}"
+                        f"),"
+                        f"combined_results AS ("
+                        f" SELECT COALESCE(f.id, v.id) AS id, {score_expr} AS score"
+                        f" FROM fulltext_results f"
+                        f" FULL OUTER JOIN vector_results v"
+                        f" ON f.id = v.id"
+                        f")"
+                        f" SELECT {fields_expr}, c.score as _score"
+                        f" FROM combined_results c"
+                        f" JOIN {index_name} t"
+                        f" ON c.id = t.id"
+                        f" ORDER BY score DESC"
+                        f" LIMIT {offset}, {limit}"
+                    )
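
To make the fused scoring concrete: both branches compute the score as full-text relevance weighted by `1 - vector_similarity_weight`, plus vector similarity weighted by `vector_similarity_weight`, plus the stored pagerank value scaled down by 100. A worked example with hypothetical per-chunk values:

```python
# Hypothetical inputs for one chunk, to trace the fused _score by hand.
vector_similarity_weight = 0.3
relevance = 0.82    # weighted MATCH ... AGAINST score over the full-text columns
similarity = 0.91   # 1 - cosine_distance(embedding_column, query_vector)
pagerank = 20       # stored as an integer; divided by 100 in the SQL expression

_score = (relevance * (1 - vector_similarity_weight)
          + similarity * vector_similarity_weight
          + pagerank / 100)
print(round(_score, 3))  # 1.047 = 0.574 + 0.273 + 0.2
```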
logger.debug("OBConnection.search with fusion sql: %s", fusion_sql) logger.debug("OBConnection.search with fusion sql: %s", fusion_sql)
start_time = time.time() start_time = time.time()
@@ -1234,10 +1292,14 @@ class OBConnection(DocStoreConnection):
         for row in rows:
             result.chunks.append(self._row_to_entity(row, output_fields))

+        if result.total == 0:
+            result.total = len(result.chunks)
+
         return result

     def get(self, chunkId: str, indexName: str, knowledgebaseIds: list[str]) -> dict | None:
-        if not self.client.check_table_exists(indexName):
+        if not self._check_table_exists_cached(indexName):
             return None

         try:
@@ -1336,7 +1398,7 @@ class OBConnection(DocStoreConnection):
         return res

     def update(self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str) -> bool:
-        if not self.client.check_table_exists(indexName):
+        if not self._check_table_exists_cached(indexName):
             return True

         condition["kb_id"] = knowledgebaseId
@@ -1387,7 +1449,7 @@ class OBConnection(DocStoreConnection):
         return False

     def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int:
-        if not self.client.check_table_exists(indexName):
+        if not self._check_table_exists_cached(indexName):
             return 0

         condition["kb_id"] = knowledgebaseId