mirror of https://github.com/infiniflow/ragflow.git (synced 2025-12-19 12:06:42 +08:00)
feat: enhance OBConnection.search (#11876)
### What problem does this PR solve?

Enhance OBConnection.search for better performance. Main changes:

1. Pass the vector array as a string literal to the distance function for better parsing performance.
2. Manually set max_connections as the pool size instead of using the default value.
3. Set 'fulltext_search_columns' once at startup.
4. Cache the results of the table existence check (we will never drop the table).
5. Remove the unused 'group_results' logic.
6. Add the `USE_FULLTEXT_FIRST_FUSION_SEARCH` flag, plus the corresponding fusion search SQL that runs when it is false.

### Type of change

- [x] Performance Improvement
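Change 1 in isolation: the updated `vector_search_template` quotes its second placeholder, so OceanBase receives the query vector as one string literal instead of a raw array expression. A minimal sketch of how the expression is built, following the diff below (the column name here is illustrative, not from the PR):

```python
import numpy as np

# Template from this PR: note the quotes around the second placeholder.
vector_search_template = "cosine_distance(%s, '%s')"

def build_vector_expr(column_name: str, vector_data: list[float]) -> str:
    # Render the query vector as "[v1,v2,...]"; np.float32 matches the
    # stored precision and keeps the literal compact.
    vector_str = "[" + ",".join(str(np.float32(v)) for v in vector_data) + "]"
    return vector_search_template % (column_name, vector_str)

print(build_vector_expr("q_1024_vec", [0.12, -0.34, 0.56]))
# cosine_distance(q_1024_vec, '[0.12,-0.34,0.56]')
```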
@@ -17,13 +17,16 @@ import json
import logging
import os
import re
import threading
import time
from typing import Any, Optional

import numpy as np
from elasticsearch_dsl import Q, Search
from pydantic import BaseModel
from pymysql.converters import escape_string
from pyobvector import ObVecClient, FtsIndexParam, FtsParser, ARRAY, VECTOR
from pyobvector.client import ClusterVersionException
from pyobvector.client.hybrid_search import HybridSearch
from pyobvector.util import ObVersion
from sqlalchemy import text, Column, String, Integer, JSON, Double, Row, Table
@@ -106,17 +109,6 @@ index_columns: list[str] = [
    "removed_kwd",
]

fulltext_search_columns: list[str] = [
    "docnm_kwd",
    "content_with_weight",
    "title_tks",
    "title_sm_tks",
    "important_tks",
    "question_tks",
    "content_ltks",
    "content_sm_ltks"
]

fts_columns_origin: list[str] = [
    "docnm_kwd^10",
    "content_with_weight",
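The `^` suffix on entries like "docnm_kwd^10" encodes a per-column weight. A small sketch of how such an entry splits into a column name and weight, mirroring the parsing that appears later in this diff:

```python
def parse_weighted_column(field: str) -> tuple[str, float]:
    # "docnm_kwd^10" -> ("docnm_kwd", 10.0); no suffix means weight 1.0
    parts = field.split("^")
    column_name = parts[0]
    column_weight = float(parts[1]) if (len(parts) > 1 and parts[1]) else 1.0
    return column_name, column_weight

print(parse_weighted_column("docnm_kwd^10"))  # ('docnm_kwd', 10.0)
print(parse_weighted_column("content_ltks"))  # ('content_ltks', 1.0)
```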
@@ -138,7 +130,7 @@ fulltext_index_name_template = "fts_idx_%s"
# MATCH AGAINST: https://www.oceanbase.com/docs/common-oceanbase-database-cn-1000000002017607
fulltext_search_template = "MATCH (%s) AGAINST ('%s' IN NATURAL LANGUAGE MODE)"
# cosine_distance: https://www.oceanbase.com/docs/common-oceanbase-database-cn-1000000002012938
vector_search_template = "cosine_distance(%s, %s)"
vector_search_template = "cosine_distance(%s, '%s')"


class SearchResult(BaseModel):
@@ -362,18 +354,28 @@ class OBConnection(DocStoreConnection):
            port = mysql_config.get("port", 2881)
            self.username = mysql_config.get("user", "root@test")
            self.password = mysql_config.get("password", "infini_rag_flow")
            max_connections = mysql_config.get("max_connections", 300)
        else:
            logger.info("Use customized config to create OceanBase connection.")
            host = ob_config.get("host", "localhost")
            port = ob_config.get("port", 2881)
            self.username = ob_config.get("user", "root@test")
            self.password = ob_config.get("password", "infini_rag_flow")
            max_connections = ob_config.get("max_connections", 300)

        self.db_name = ob_config.get("db_name", "test")
        self.uri = f"{host}:{port}"

        logger.info(f"Use OceanBase '{self.uri}' as the doc engine.")

        # Set the maximum number of connections that can be created above the pool_size.
        # By default, this is half of max_connections, but at least 10.
        # This allows the pool to handle temporary spikes in demand without exhausting resources.
        max_overflow = int(os.environ.get("OB_MAX_OVERFLOW", max(max_connections // 2, 10)))
        # Set the number of seconds to wait before giving up when trying to get a connection from the pool.
        # Default is 30 seconds, but can be overridden with the OB_POOL_TIMEOUT environment variable.
        pool_timeout = int(os.environ.get("OB_POOL_TIMEOUT", "30"))

        for _ in range(ATTEMPT_TIME):
            try:
                self.client = ObVecClient(
@@ -383,6 +385,9 @@ class OBConnection(DocStoreConnection):
                    db_name=self.db_name,
                    pool_pre_ping=True,
                    pool_recycle=3600,
                    pool_size=max_connections,
                    max_overflow=max_overflow,
                    pool_timeout=pool_timeout,
                )
                break
            except Exception as e:
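These pool arguments correspond to SQLAlchemy's standard QueuePool settings. A hedged sketch of the same knobs on a bare engine, to illustrate what they mean; the connection URL and credentials are placeholders, and it is an assumption here that ObVecClient forwards these keyword arguments to its internal engine:

```python
from sqlalchemy import create_engine

# Sketch only: not the project's connection code.
engine = create_engine(
    "mysql+pymysql://user:password@localhost:2881/test",
    pool_pre_ping=True,   # validate a connection before handing it out
    pool_recycle=3600,    # recycle connections older than an hour
    pool_size=300,        # hold max_connections in the pool itself
    max_overflow=150,     # spike capacity above pool_size (max(300 // 2, 10))
    pool_timeout=30,      # seconds to wait for a free connection
)
```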
@@ -398,6 +403,37 @@ class OBConnection(DocStoreConnection):
        self._check_ob_version()
        self._try_to_update_ob_query_timeout()

        self.es = None
        if self.enable_hybrid_search:
            try:
                self.es = HybridSearch(
                    uri=self.uri,
                    user=self.username,
                    password=self.password,
                    db_name=self.db_name,
                    pool_pre_ping=True,
                    pool_recycle=3600,
                    pool_size=max_connections,
                    max_overflow=max_overflow,
                    pool_timeout=pool_timeout,
                )
                logger.info("OceanBase Hybrid Search feature is enabled")
            except ClusterVersionException as e:
                logger.info("Failed to initialize HybridSearch client, fallback to use SQL", exc_info=e)
                self.es = None

        if self.es is not None and self.search_original_content:
            logger.info("HybridSearch is enabled, forcing search_original_content to False")
            self.search_original_content = False
        # Determine which columns to use for full-text search dynamically:
        # If HybridSearch is enabled (self.es is not None), we must use tokenized columns (fts_columns_tks)
        # for compatibility and performance with HybridSearch. Otherwise, we use the original content columns
        # (fts_columns_origin), which may be controlled by an environment variable.
        self.fulltext_search_columns = fts_columns_origin if self.search_original_content else fts_columns_tks

        self._table_exists_cache: set[str] = set()
        self._table_exists_cache_lock = threading.RLock()

        logger.info(f"OceanBase {self.uri} is healthy.")

    def _check_ob_version(self):
@@ -417,18 +453,6 @@ class OBConnection(DocStoreConnection):
                f"The version of OceanBase needs to be higher than or equal to 4.3.5.1, current version is {version_str}"
            )

        self.es = None
        if not ob_version < ObVersion.from_db_version_nums(4, 4, 1, 0) and self.enable_hybrid_search:
            self.es = HybridSearch(
                uri=self.uri,
                user=self.username,
                password=self.password,
                db_name=self.db_name,
                pool_pre_ping=True,
                pool_recycle=3600,
            )
            logger.info("OceanBase Hybrid Search feature is enabled")

    def _try_to_update_ob_query_timeout(self):
        try:
            val = self._get_variable_value("ob_query_timeout")
@@ -455,9 +479,19 @@ class OBConnection(DocStoreConnection):
            return os.getenv(var, default).lower() in ['true', '1', 'yes', 'y']

        self.enable_fulltext_search = is_true('ENABLE_FULLTEXT_SEARCH', 'true')
        logger.info(f"ENABLE_FULLTEXT_SEARCH={self.enable_fulltext_search}")

        self.use_fulltext_hint = is_true('USE_FULLTEXT_HINT', 'true')
        logger.info(f"USE_FULLTEXT_HINT={self.use_fulltext_hint}")

        self.search_original_content = is_true("SEARCH_ORIGINAL_CONTENT", 'true')
        logger.info(f"SEARCH_ORIGINAL_CONTENT={self.search_original_content}")

        self.enable_hybrid_search = is_true('ENABLE_HYBRID_SEARCH', 'false')
        logger.info(f"ENABLE_HYBRID_SEARCH={self.enable_hybrid_search}")

        self.use_fulltext_first_fusion_search = is_true('USE_FULLTEXT_FIRST_FUSION_SEARCH', 'true')
        logger.info(f"USE_FULLTEXT_FIRST_FUSION_SEARCH={self.use_fulltext_first_fusion_search}")

    """
    Database operations
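These flags are read once at startup via `is_true`. A hedged example of configuring them through the environment before the process starts; the flag names are from the diff, the values are purely illustrative:

```python
import os

# Accepted truthy spellings, per is_true above: 'true', '1', 'yes', 'y'.
os.environ["ENABLE_FULLTEXT_SEARCH"] = "true"
os.environ["USE_FULLTEXT_HINT"] = "yes"
os.environ["SEARCH_ORIGINAL_CONTENT"] = "false"
os.environ["ENABLE_HYBRID_SEARCH"] = "1"
os.environ["USE_FULLTEXT_FIRST_FUSION_SEARCH"] = "false"  # select the FULL OUTER JOIN fusion SQL
```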
@@ -478,6 +512,43 @@ class OBConnection(DocStoreConnection):
                return row[1]
        raise Exception(f"Variable '{var_name}' not found.")

    def _check_table_exists_cached(self, table_name: str) -> bool:
        """
        Check table existence with cache to reduce INFORMATION_SCHEMA queries under high concurrency.
        Only caches when table exists. Does not cache when table does not exist.
        Thread-safe implementation: read operations are lock-free (GIL-protected),
        write operations are protected by RLock to ensure cache consistency.

        Args:
            table_name: Table name

        Returns:
            Whether the table exists with all required indexes and columns
        """
        if table_name in self._table_exists_cache:
            return True

        try:
            if not self.client.check_table_exists(table_name):
                return False
            for column_name in index_columns:
                if not self._index_exists(table_name, index_name_template % (table_name, column_name)):
                    return False
            for fts_column in self.fulltext_search_columns:
                column_name = fts_column.split("^")[0]
                if not self._index_exists(table_name, fulltext_index_name_template % column_name):
                    return False
            for column in [column_order_id, column_group_id]:
                if not self._column_exist(table_name, column.name):
                    return False
        except Exception as e:
            raise Exception(f"OBConnection._check_table_exists_cached error: {str(e)}")

        with self._table_exists_cache_lock:
            if table_name not in self._table_exists_cache:
                self._table_exists_cache.add(table_name)
        return True

    """
    Table operations
    """
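The caching pattern above, shown in isolation: membership reads skip the lock (a set lookup is atomic under the GIL), only the insert takes it, and negative results are never cached so a table created later is still found. A sketch under those assumptions, not the project's API:

```python
import threading

_cache: set[str] = set()
_lock = threading.RLock()

def exists_cached(name: str, check) -> bool:
    if name in _cache:       # fast path: lock-free read
        return True
    if not check(name):      # expensive INFORMATION_SCHEMA-style lookup
        return False         # do not cache a negative result
    with _lock:
        _cache.add(name)     # idempotent; a concurrent double-add is harmless
    return True
```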
@@ -500,8 +571,7 @@ class OBConnection(DocStoreConnection):
                process_func=lambda: self._add_index(indexName, column_name),
            )

        fts_columns = fts_columns_origin if self.search_original_content else fts_columns_tks
        for fts_column in fts_columns:
        for fts_column in self.fulltext_search_columns:
            column_name = fts_column.split("^")[0]
            _try_with_lock(
                lock_name=f"ob_add_fulltext_idx_{indexName}_{column_name}",
@@ -546,24 +616,7 @@ class OBConnection(DocStoreConnection):
            raise Exception(f"OBConnection.deleteIndex error: {str(e)}")

    def indexExist(self, indexName: str, knowledgebaseId: str = None) -> bool:
        try:
            if not self.client.check_table_exists(indexName):
                return False
            for column_name in index_columns:
                if not self._index_exists(indexName, index_name_template % (indexName, column_name)):
                    return False
            fts_columns = fts_columns_origin if self.search_original_content else fts_columns_tks
            for fts_column in fts_columns:
                column_name = fts_column.split("^")[0]
                if not self._index_exists(indexName, fulltext_index_name_template % column_name):
                    return False
            for column in [column_order_id, column_group_id]:
                if not self._column_exist(indexName, column.name):
                    return False
        except Exception as e:
            raise Exception(f"OBConnection.indexExist error: {str(e)}")

        return True
        return self._check_table_exists_cached(indexName)

    def _get_count(self, table_name: str, filter_list: list[str] = None) -> int:
        where_clause = "WHERE " + " AND ".join(filter_list) if len(filter_list) > 0 else ""
@@ -853,10 +906,8 @@ class OBConnection(DocStoreConnection):
            fulltext_query = escape_string(fulltext_query.strip())
            fulltext_topn = m.topn

            fts_columns = fts_columns_origin if self.search_original_content else fts_columns_tks

            # get fulltext match expression and weight values
            for field in fts_columns:
            for field in self.fulltext_search_columns:
                parts = field.split("^")
                column_name: str = parts[0]
                column_weight: float = float(parts[1]) if (len(parts) > 1 and parts[1]) else 1.0
@@ -885,7 +936,8 @@ class OBConnection(DocStoreConnection):
            fulltext_search_score_expr = f"({' + '.join(f'{expr} * {fulltext_search_weight.get(col, 0)}' for col, expr in fulltext_search_expr.items())})"

        if vector_data:
            vector_search_expr = vector_search_template % (vector_column_name, vector_data)
            vector_data_str = "[" + ",".join([str(np.float32(v)) for v in vector_data]) + "]"
            vector_search_expr = vector_search_template % (vector_column_name, vector_data_str)
            # use (1 - cosine_distance) as score, which should be [-1, 1]
            # https://www.oceanbase.com/docs/common-oceanbase-database-standalone-1000000003577323
            vector_search_score_expr = f"(1 - {vector_search_expr})"
@@ -910,11 +962,15 @@ class OBConnection(DocStoreConnection):
        if search_type in ["fusion", "fulltext", "vector"] and "_score" not in output_fields:
            output_fields.append("_score")

        group_results = kwargs.get("group_results", False)
        if limit:
            if vector_topn is not None:
                limit = min(vector_topn, limit)
            if fulltext_topn is not None:
                limit = min(fulltext_topn, limit)

        for index_name in indexNames:

            if not self.client.check_table_exists(index_name):
            if not self._check_table_exists_cached(index_name):
                continue

            fulltext_search_hint = f"/*+ UNION_MERGE({index_name} {' '.join(fulltext_search_idx_list)}) */" if self.use_fulltext_hint else ""
@@ -922,29 +978,7 @@ class OBConnection(DocStoreConnection):
            if search_type == "fusion":
                # fusion search, usually for chat
                num_candidates = vector_topn + fulltext_topn
                if group_results:
                    count_sql = (
                        f"WITH fulltext_results AS ("
                        f" SELECT {fulltext_search_hint} *, {fulltext_search_score_expr} AS relevance"
                        f" FROM {index_name}"
                        f" WHERE {filters_expr} AND {fulltext_search_filter}"
                        f" ORDER BY relevance DESC"
                        f" LIMIT {num_candidates}"
                        f"),"
                        f" scored_results AS ("
                        f" SELECT *"
                        f" FROM fulltext_results"
                        f" WHERE {vector_search_filter}"
                        f"),"
                        f" group_results AS ("
                        f" SELECT *, ROW_NUMBER() OVER (PARTITION BY group_id) as rn"
                        f" FROM scored_results"
                        f")"
                        f" SELECT COUNT(*)"
                        f" FROM group_results"
                        f" WHERE rn = 1"
                    )
                else:
                if self.use_fulltext_first_fusion_search:
                    count_sql = (
                        f"WITH fulltext_results AS ("
                        f" SELECT {fulltext_search_hint} *, {fulltext_search_score_expr} AS relevance"
@@ -955,6 +989,22 @@ class OBConnection(DocStoreConnection):
                        f")"
                        f" SELECT COUNT(*) FROM fulltext_results WHERE {vector_search_filter}"
                    )
                else:
                    count_sql = (
                        f"WITH fulltext_results AS ("
                        f" SELECT {fulltext_search_hint} id FROM {index_name}"
                        f" WHERE {filters_expr} AND {fulltext_search_filter}"
                        f" ORDER BY {fulltext_search_score_expr}"
                        f" LIMIT {fulltext_topn}"
                        f"),"
                        f"vector_results AS ("
                        f" SELECT id FROM {index_name}"
                        f" WHERE {filters_expr} AND {vector_search_filter}"
                        f" ORDER BY {vector_search_expr}"
                        f" APPROXIMATE LIMIT {vector_topn}"
                        f")"
                        f" SELECT COUNT(*) FROM fulltext_results f FULL OUTER JOIN vector_results v ON f.id = v.id"
                    )
                logger.debug("OBConnection.search with count sql: %s", count_sql)

                start_time = time.time()
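In set terms, the two count strategies above differ in where the vector condition runs: fulltext-first applies the vector similarity filter inside the full-text candidate set, while the FULL OUTER JOIN variant counts the union of two independently retrieved top-n candidate sets. A toy Python restatement of that difference, with made-up ids for illustration only:

```python
fulltext_ids = {1, 2, 3, 4}        # top fulltext_topn candidates by relevance
passes_vector_filter = {2, 3, 7}   # ids meeting the vector similarity threshold
vector_topn_ids = {3, 4, 5, 6}     # top vector_topn candidates (approximate search)

# USE_FULLTEXT_FIRST_FUSION_SEARCH=true:
# count full-text candidates that also pass the vector filter
fulltext_first_count = len(fulltext_ids & passes_vector_filter)

# USE_FULLTEXT_FIRST_FUSION_SEARCH=false:
# count the union of both candidate sets (FULL OUTER JOIN on id)
fusion_count = len(fulltext_ids | vector_topn_ids)

print(fulltext_first_count, fusion_count)  # 2 6
```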
@@ -976,32 +1026,8 @@ class OBConnection(DocStoreConnection):
                if total_count == 0:
                    continue

                score_expr = f"(relevance * {1 - vector_similarity_weight} + {vector_search_score_expr} * {vector_similarity_weight} + {pagerank_score_expr})"
                if group_results:
                    fusion_sql = (
                        f"WITH fulltext_results AS ("
                        f" SELECT {fulltext_search_hint} *, {fulltext_search_score_expr} AS relevance"
                        f" FROM {index_name}"
                        f" WHERE {filters_expr} AND {fulltext_search_filter}"
                        f" ORDER BY relevance DESC"
                        f" LIMIT {num_candidates}"
                        f"),"
                        f" scored_results AS ("
                        f" SELECT *, {score_expr} AS _score"
                        f" FROM fulltext_results"
                        f" WHERE {vector_search_filter}"
                        f"),"
                        f" group_results AS ("
                        f" SELECT *, ROW_NUMBER() OVER (PARTITION BY group_id ORDER BY _score DESC) as rn"
                        f" FROM scored_results"
                        f")"
                        f" SELECT {fields_expr}, _score"
                        f" FROM group_results"
                        f" WHERE rn = 1"
                        f" ORDER BY _score DESC"
                        f" LIMIT {offset}, {limit}"
                    )
                else:
                if self.use_fulltext_first_fusion_search:
                    score_expr = f"(relevance * {1 - vector_similarity_weight} + {vector_search_score_expr} * {vector_similarity_weight} + {pagerank_score_expr})"
                    fusion_sql = (
                        f"WITH fulltext_results AS ("
                        f" SELECT {fulltext_search_hint} *, {fulltext_search_score_expr} AS relevance"
@@ -1016,6 +1042,38 @@ class OBConnection(DocStoreConnection):
                        f" ORDER BY _score DESC"
                        f" LIMIT {offset}, {limit}"
                    )
                else:
                    pagerank_score_expr = f"(CAST(IFNULL(f.{PAGERANK_FLD}, 0) AS DECIMAL(10, 2)) / 100)"
                    score_expr = f"(f.relevance * {1 - vector_similarity_weight} + v.similarity * {vector_similarity_weight} + {pagerank_score_expr})"
                    fields_expr = ", ".join([f"t.{f} as {f}" for f in output_fields if f != "_score"])
                    fusion_sql = (
                        f"WITH fulltext_results AS ("
                        f" SELECT {fulltext_search_hint} id, pagerank_fea, {fulltext_search_score_expr} AS relevance"
                        f" FROM {index_name}"
                        f" WHERE {filters_expr} AND {fulltext_search_filter}"
                        f" ORDER BY relevance DESC"
                        f" LIMIT {fulltext_topn}"
                        f"),"
                        f"vector_results AS ("
                        f" SELECT id, pagerank_fea, {vector_search_score_expr} AS similarity"
                        f" FROM {index_name}"
                        f" WHERE {filters_expr} AND {vector_search_filter}"
                        f" ORDER BY {vector_search_expr}"
                        f" APPROXIMATE LIMIT {vector_topn}"
                        f"),"
                        f"combined_results AS ("
                        f" SELECT COALESCE(f.id, v.id) AS id, {score_expr} AS score"
                        f" FROM fulltext_results f"
                        f" FULL OUTER JOIN vector_results v"
                        f" ON f.id = v.id"
                        f")"
                        f" SELECT {fields_expr}, c.score as _score"
                        f" FROM combined_results c"
                        f" JOIN {index_name} t"
                        f" ON c.id = t.id"
                        f" ORDER BY score DESC"
                        f" LIMIT {offset}, {limit}"
                    )
                logger.debug("OBConnection.search with fusion sql: %s", fusion_sql)

                start_time = time.time()
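Both `score_expr` variants above compute the same blend: full-text relevance and vector similarity weighted by `vector_similarity_weight`, plus a pagerank bonus (stored on a 0 to 100 scale, divided by 100). A plain Python restatement of that formula, with illustrative input values:

```python
def fusion_score(relevance: float, similarity: float, pagerank: float,
                 vector_similarity_weight: float) -> float:
    # relevance: weighted full-text score; similarity: 1 - cosine_distance;
    # vector_similarity_weight: blend factor in [0, 1].
    return (relevance * (1 - vector_similarity_weight)
            + similarity * vector_similarity_weight
            + pagerank / 100)

print(fusion_score(relevance=2.5, similarity=0.8, pagerank=10,
                   vector_similarity_weight=0.7))  # 1.41
```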
@@ -1234,10 +1292,14 @@ class OBConnection(DocStoreConnection):

        for row in rows:
            result.chunks.append(self._row_to_entity(row, output_fields))

        if result.total == 0:
            result.total = len(result.chunks)

        return result

    def get(self, chunkId: str, indexName: str, knowledgebaseIds: list[str]) -> dict | None:
        if not self.client.check_table_exists(indexName):
        if not self._check_table_exists_cached(indexName):
            return None

        try:
@@ -1336,7 +1398,7 @@ class OBConnection(DocStoreConnection):
        return res

    def update(self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str) -> bool:
        if not self.client.check_table_exists(indexName):
        if not self._check_table_exists_cached(indexName):
            return True

        condition["kb_id"] = knowledgebaseId
@@ -1387,7 +1449,7 @@ class OBConnection(DocStoreConnection):
        return False

    def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int:
        if not self.client.check_table_exists(indexName):
        if not self._check_table_exists_cached(indexName):
            return 0

        condition["kb_id"] = knowledgebaseId