# feat: enhance OBConnection.search (#11876)
### What problem does this PR solve?

Enhance `OBConnection.search` for better performance. Main changes:

1. Pass the vector array as a string literal to the distance function for better parsing performance (see the sketch just below).
2. Manually set `max_connections` as the pool size instead of using the default value.
3. Resolve `fulltext_search_columns` once at startup.
4. Cache the results of the table existence check (we never drop these tables).
5. Remove the unused `group_results` logic.
6. Add the `USE_FULLTEXT_FIRST_FUSION_SEARCH` flag, and the corresponding fusion search SQL for when it is false.

### Type of change

- [x] Performance Improvement
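Change 1 renders the query vector as a bracketed string literal that is quoted into the SQL (note the quoted `'%s'` in the updated `vector_search_template` below), instead of interpolating the raw Python list. A standalone sketch of the formatting; the column name and embedding values are illustrative, not taken from the patch:

```python
import numpy as np

# Same template and literal construction as in the diff below.
vector_search_template = "cosine_distance(%s, '%s')"

vector_data = [0.12, -0.5, 0.33]  # illustrative query embedding
vector_data_str = "[" + ",".join([str(np.float32(v)) for v in vector_data]) + "]"
expr = vector_search_template % ("q_768_vec", vector_data_str)  # hypothetical vector column
print(expr)  # cosine_distance(q_768_vec, '[0.12,-0.5,0.33]')
```

Casting through `np.float32` keeps each literal short (float32 precision), so the server parses a compact string rather than a full float64 repr.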
```diff
@@ -17,13 +17,16 @@ import json
 import logging
 import os
 import re
+import threading
 import time
 from typing import Any, Optional
 
+import numpy as np
 from elasticsearch_dsl import Q, Search
 from pydantic import BaseModel
 from pymysql.converters import escape_string
 from pyobvector import ObVecClient, FtsIndexParam, FtsParser, ARRAY, VECTOR
+from pyobvector.client import ClusterVersionException
 from pyobvector.client.hybrid_search import HybridSearch
 from pyobvector.util import ObVersion
 from sqlalchemy import text, Column, String, Integer, JSON, Double, Row, Table
```
```diff
@@ -106,17 +109,6 @@ index_columns: list[str] = [
     "removed_kwd",
 ]
 
-fulltext_search_columns: list[str] = [
-    "docnm_kwd",
-    "content_with_weight",
-    "title_tks",
-    "title_sm_tks",
-    "important_tks",
-    "question_tks",
-    "content_ltks",
-    "content_sm_ltks"
-]
-
 fts_columns_origin: list[str] = [
     "docnm_kwd^10",
     "content_with_weight",
```
```diff
@@ -138,7 +130,7 @@ fulltext_index_name_template = "fts_idx_%s"
 # MATCH AGAINST: https://www.oceanbase.com/docs/common-oceanbase-database-cn-1000000002017607
 fulltext_search_template = "MATCH (%s) AGAINST ('%s' IN NATURAL LANGUAGE MODE)"
 # cosine_distance: https://www.oceanbase.com/docs/common-oceanbase-database-cn-1000000002012938
-vector_search_template = "cosine_distance(%s, %s)"
+vector_search_template = "cosine_distance(%s, '%s')"
 
 
 class SearchResult(BaseModel):
```
```diff
@@ -362,18 +354,28 @@ class OBConnection(DocStoreConnection):
             port = mysql_config.get("port", 2881)
             self.username = mysql_config.get("user", "root@test")
             self.password = mysql_config.get("password", "infini_rag_flow")
+            max_connections = mysql_config.get("max_connections", 300)
         else:
             logger.info("Use customized config to create OceanBase connection.")
             host = ob_config.get("host", "localhost")
             port = ob_config.get("port", 2881)
             self.username = ob_config.get("user", "root@test")
             self.password = ob_config.get("password", "infini_rag_flow")
+            max_connections = ob_config.get("max_connections", 300)
 
         self.db_name = ob_config.get("db_name", "test")
         self.uri = f"{host}:{port}"
 
         logger.info(f"Use OceanBase '{self.uri}' as the doc engine.")
 
+        # Set the maximum number of connections that can be created above the pool_size.
+        # By default, this is half of max_connections, but at least 10.
+        # This allows the pool to handle temporary spikes in demand without exhausting resources.
+        max_overflow = int(os.environ.get("OB_MAX_OVERFLOW", max(max_connections // 2, 10)))
+        # Set the number of seconds to wait before giving up when trying to get a connection from the pool.
+        # Default is 30 seconds, but can be overridden with the OB_POOL_TIMEOUT environment variable.
+        pool_timeout = int(os.environ.get("OB_POOL_TIMEOUT", "30"))
+
         for _ in range(ATTEMPT_TIME):
             try:
                 self.client = ObVecClient(
```
```diff
@@ -383,6 +385,9 @@ class OBConnection(DocStoreConnection):
                     db_name=self.db_name,
                     pool_pre_ping=True,
                     pool_recycle=3600,
+                    pool_size=max_connections,
+                    max_overflow=max_overflow,
+                    pool_timeout=pool_timeout,
                 )
                 break
             except Exception as e:
```
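These are standard SQLAlchemy pool arguments, which pyobvector's client presumably passes through to `create_engine`. A minimal sketch of their semantics under that assumption (the DSN and numbers are illustrative): the pool keeps up to `pool_size` connections, may temporarily open `max_overflow` more, and a checkout waits at most `pool_timeout` seconds before raising a timeout error.

```python
import os
from sqlalchemy import create_engine, text

max_connections = 300  # in the patch this comes from the service config
max_overflow = int(os.environ.get("OB_MAX_OVERFLOW", max(max_connections // 2, 10)))
pool_timeout = int(os.environ.get("OB_POOL_TIMEOUT", "30"))

# Hypothetical DSN; the hard cap on concurrent connections is pool_size + max_overflow.
engine = create_engine(
    "mysql+pymysql://user:pass@127.0.0.1:2881/test",
    pool_pre_ping=True,
    pool_recycle=3600,
    pool_size=max_connections,
    max_overflow=max_overflow,
    pool_timeout=pool_timeout,
)
with engine.connect() as conn:
    conn.execute(text("SELECT 1"))
```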
```diff
@@ -398,6 +403,37 @@ class OBConnection(DocStoreConnection):
         self._check_ob_version()
         self._try_to_update_ob_query_timeout()
 
+        self.es = None
+        if self.enable_hybrid_search:
+            try:
+                self.es = HybridSearch(
+                    uri=self.uri,
+                    user=self.username,
+                    password=self.password,
+                    db_name=self.db_name,
+                    pool_pre_ping=True,
+                    pool_recycle=3600,
+                    pool_size=max_connections,
+                    max_overflow=max_overflow,
+                    pool_timeout=pool_timeout,
+                )
+                logger.info("OceanBase Hybrid Search feature is enabled")
+            except ClusterVersionException as e:
+                logger.info("Failed to initialize HybridSearch client, fallback to use SQL", exc_info=e)
+                self.es = None
+
+        if self.es is not None and self.search_original_content:
+            logger.info("HybridSearch is enabled, forcing search_original_content to False")
+            self.search_original_content = False
+        # Determine which columns to use for full-text search dynamically:
+        # If HybridSearch is enabled (self.es is not None), we must use tokenized columns (fts_columns_tks)
+        # for compatibility and performance with HybridSearch. Otherwise, we use the original content columns
+        # (fts_columns_origin), which may be controlled by an environment variable.
+        self.fulltext_search_columns = fts_columns_origin if self.search_original_content else fts_columns_tks
+
+        self._table_exists_cache: set[str] = set()
+        self._table_exists_cache_lock = threading.RLock()
+
         logger.info(f"OceanBase {self.uri} is healthy.")
 
     def _check_ob_version(self):
```
```diff
@@ -417,18 +453,6 @@ class OBConnection(DocStoreConnection):
                 f"The version of OceanBase needs to be higher than or equal to 4.3.5.1, current version is {version_str}"
             )
 
-        self.es = None
-        if not ob_version < ObVersion.from_db_version_nums(4, 4, 1, 0) and self.enable_hybrid_search:
-            self.es = HybridSearch(
-                uri=self.uri,
-                user=self.username,
-                password=self.password,
-                db_name=self.db_name,
-                pool_pre_ping=True,
-                pool_recycle=3600,
-            )
-            logger.info("OceanBase Hybrid Search feature is enabled")
-
     def _try_to_update_ob_query_timeout(self):
         try:
             val = self._get_variable_value("ob_query_timeout")
```
```diff
@@ -455,9 +479,19 @@ class OBConnection(DocStoreConnection):
             return os.getenv(var, default).lower() in ['true', '1', 'yes', 'y']
 
         self.enable_fulltext_search = is_true('ENABLE_FULLTEXT_SEARCH', 'true')
+        logger.info(f"ENABLE_FULLTEXT_SEARCH={self.enable_fulltext_search}")
+
         self.use_fulltext_hint = is_true('USE_FULLTEXT_HINT', 'true')
+        logger.info(f"USE_FULLTEXT_HINT={self.use_fulltext_hint}")
+
         self.search_original_content = is_true("SEARCH_ORIGINAL_CONTENT", 'true')
+        logger.info(f"SEARCH_ORIGINAL_CONTENT={self.search_original_content}")
+
         self.enable_hybrid_search = is_true('ENABLE_HYBRID_SEARCH', 'false')
+        logger.info(f"ENABLE_HYBRID_SEARCH={self.enable_hybrid_search}")
+
+        self.use_fulltext_first_fusion_search = is_true('USE_FULLTEXT_FIRST_FUSION_SEARCH', 'true')
+        logger.info(f"USE_FULLTEXT_FIRST_FUSION_SEARCH={self.use_fulltext_first_fusion_search}")
 
     """
     Database operations
```
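Note that `is_true` recognizes only `'true'`, `'1'`, `'yes'`, and `'y'` (case-insensitive); any other value, including `'false'`, `'0'`, or a typo, turns the flag off. A quick illustration of the membership test:

```python
def is_true(val: str) -> bool:
    # Same membership test as the is_true helper above (env lookup elided).
    return val.lower() in ['true', '1', 'yes', 'y']

assert is_true("TRUE") and is_true("1") and is_true("y")
assert not is_true("false") and not is_true("on")  # 'on' is not in the list
```

So exporting `USE_FULLTEXT_FIRST_FUSION_SEARCH=false` (or `0`) selects the join-based fusion SQL added later in this diff.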
```diff
@@ -478,6 +512,43 @@ class OBConnection(DocStoreConnection):
                 return row[1]
         raise Exception(f"Variable '{var_name}' not found.")
 
+    def _check_table_exists_cached(self, table_name: str) -> bool:
+        """
+        Check table existence with cache to reduce INFORMATION_SCHEMA queries under high concurrency.
+        Only caches when table exists. Does not cache when table does not exist.
+        Thread-safe implementation: read operations are lock-free (GIL-protected),
+        write operations are protected by RLock to ensure cache consistency.
+
+        Args:
+            table_name: Table name
+
+        Returns:
+            Whether the table exists with all required indexes and columns
+        """
+        if table_name in self._table_exists_cache:
+            return True
+
+        try:
+            if not self.client.check_table_exists(table_name):
+                return False
+            for column_name in index_columns:
+                if not self._index_exists(table_name, index_name_template % (table_name, column_name)):
+                    return False
+            for fts_column in self.fulltext_search_columns:
+                column_name = fts_column.split("^")[0]
+                if not self._index_exists(table_name, fulltext_index_name_template % column_name):
+                    return False
+            for column in [column_order_id, column_group_id]:
+                if not self._column_exist(table_name, column.name):
+                    return False
+        except Exception as e:
+            raise Exception(f"OBConnection._check_table_exists_cached error: {str(e)}")
+
+        with self._table_exists_cache_lock:
+            if table_name not in self._table_exists_cache:
+                self._table_exists_cache.add(table_name)
+        return True
+
     """
     Table operations
     """
```
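The cache is deliberately one-way: only positive results are stored, which is safe because, per the PR description, these tables are never dropped. A stripped-down sketch of the same positive-only, double-checked pattern (`expensive_check` is a hypothetical stand-in for the INFORMATION_SCHEMA probes):

```python
import threading

_cache: set[str] = set()
_lock = threading.RLock()

def exists(name: str, expensive_check) -> bool:
    # Lock-free fast path: set membership is atomic under the GIL.
    if name in _cache:
        return True
    if not expensive_check(name):
        return False  # never cache negatives; the table may be created later
    with _lock:  # double-checked insert keeps writes consistent across threads
        if name not in _cache:
            _cache.add(name)
    return True
```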
```diff
@@ -500,8 +571,7 @@ class OBConnection(DocStoreConnection):
                 process_func=lambda: self._add_index(indexName, column_name),
             )
 
-        fts_columns = fts_columns_origin if self.search_original_content else fts_columns_tks
-        for fts_column in fts_columns:
+        for fts_column in self.fulltext_search_columns:
             column_name = fts_column.split("^")[0]
             _try_with_lock(
                 lock_name=f"ob_add_fulltext_idx_{indexName}_{column_name}",
```
```diff
@@ -546,24 +616,7 @@ class OBConnection(DocStoreConnection):
             raise Exception(f"OBConnection.deleteIndex error: {str(e)}")
 
     def indexExist(self, indexName: str, knowledgebaseId: str = None) -> bool:
-        try:
-            if not self.client.check_table_exists(indexName):
-                return False
-            for column_name in index_columns:
-                if not self._index_exists(indexName, index_name_template % (indexName, column_name)):
-                    return False
-            fts_columns = fts_columns_origin if self.search_original_content else fts_columns_tks
-            for fts_column in fts_columns:
-                column_name = fts_column.split("^")[0]
-                if not self._index_exists(indexName, fulltext_index_name_template % column_name):
-                    return False
-            for column in [column_order_id, column_group_id]:
-                if not self._column_exist(indexName, column.name):
-                    return False
-        except Exception as e:
-            raise Exception(f"OBConnection.indexExist error: {str(e)}")
-
-        return True
+        return self._check_table_exists_cached(indexName)
 
     def _get_count(self, table_name: str, filter_list: list[str] = None) -> int:
         where_clause = "WHERE " + " AND ".join(filter_list) if len(filter_list) > 0 else ""
```
```diff
@@ -853,10 +906,8 @@ class OBConnection(DocStoreConnection):
         fulltext_query = escape_string(fulltext_query.strip())
         fulltext_topn = m.topn
 
-        fts_columns = fts_columns_origin if self.search_original_content else fts_columns_tks
-
         # get fulltext match expression and weight values
-        for field in fts_columns:
+        for field in self.fulltext_search_columns:
             parts = field.split("^")
             column_name: str = parts[0]
             column_weight: float = float(parts[1]) if (len(parts) > 1 and parts[1]) else 1.0
```
```diff
@@ -885,7 +936,8 @@ class OBConnection(DocStoreConnection):
         fulltext_search_score_expr = f"({' + '.join(f'{expr} * {fulltext_search_weight.get(col, 0)}' for col, expr in fulltext_search_expr.items())})"
 
         if vector_data:
-            vector_search_expr = vector_search_template % (vector_column_name, vector_data)
+            vector_data_str = "[" + ",".join([str(np.float32(v)) for v in vector_data]) + "]"
+            vector_search_expr = vector_search_template % (vector_column_name, vector_data_str)
             # use (1 - cosine_distance) as score, which should be [-1, 1]
             # https://www.oceanbase.com/docs/common-oceanbase-database-standalone-1000000003577323
             vector_search_score_expr = f"(1 - {vector_search_expr})"
```
```diff
@@ -910,11 +962,15 @@ class OBConnection(DocStoreConnection):
         if search_type in ["fusion", "fulltext", "vector"] and "_score" not in output_fields:
             output_fields.append("_score")
 
-        group_results = kwargs.get("group_results", False)
+        if limit:
+            if vector_topn is not None:
+                limit = min(vector_topn, limit)
+            if fulltext_topn is not None:
+                limit = min(fulltext_topn, limit)
 
         for index_name in indexNames:
 
-            if not self.client.check_table_exists(index_name):
+            if not self._check_table_exists_cached(index_name):
                 continue
 
             fulltext_search_hint = f"/*+ UNION_MERGE({index_name} {' '.join(fulltext_search_idx_list)}) */" if self.use_fulltext_hint else ""
```
|
|||||||
if search_type == "fusion":
|
if search_type == "fusion":
|
||||||
# fusion search, usually for chat
|
# fusion search, usually for chat
|
||||||
num_candidates = vector_topn + fulltext_topn
|
num_candidates = vector_topn + fulltext_topn
|
||||||
if group_results:
|
if self.use_fulltext_first_fusion_search:
|
||||||
count_sql = (
|
|
||||||
f"WITH fulltext_results AS ("
|
|
||||||
f" SELECT {fulltext_search_hint} *, {fulltext_search_score_expr} AS relevance"
|
|
||||||
f" FROM {index_name}"
|
|
||||||
f" WHERE {filters_expr} AND {fulltext_search_filter}"
|
|
||||||
f" ORDER BY relevance DESC"
|
|
||||||
f" LIMIT {num_candidates}"
|
|
||||||
f"),"
|
|
||||||
f" scored_results AS ("
|
|
||||||
f" SELECT *"
|
|
||||||
f" FROM fulltext_results"
|
|
||||||
f" WHERE {vector_search_filter}"
|
|
||||||
f"),"
|
|
||||||
f" group_results AS ("
|
|
||||||
f" SELECT *, ROW_NUMBER() OVER (PARTITION BY group_id) as rn"
|
|
||||||
f" FROM scored_results"
|
|
||||||
f")"
|
|
||||||
f" SELECT COUNT(*)"
|
|
||||||
f" FROM group_results"
|
|
||||||
f" WHERE rn = 1"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
count_sql = (
|
count_sql = (
|
||||||
f"WITH fulltext_results AS ("
|
f"WITH fulltext_results AS ("
|
||||||
f" SELECT {fulltext_search_hint} *, {fulltext_search_score_expr} AS relevance"
|
f" SELECT {fulltext_search_hint} *, {fulltext_search_score_expr} AS relevance"
|
||||||
```diff
@@ -955,6 +989,22 @@ class OBConnection(DocStoreConnection):
                         f")"
                         f" SELECT COUNT(*) FROM fulltext_results WHERE {vector_search_filter}"
                     )
+                else:
+                    count_sql = (
+                        f"WITH fulltext_results AS ("
+                        f" SELECT {fulltext_search_hint} id FROM {index_name}"
+                        f" WHERE {filters_expr} AND {fulltext_search_filter}"
+                        f" ORDER BY {fulltext_search_score_expr}"
+                        f" LIMIT {fulltext_topn}"
+                        f"),"
+                        f"vector_results AS ("
+                        f" SELECT id FROM {index_name}"
+                        f" WHERE {filters_expr} AND {vector_search_filter}"
+                        f" ORDER BY {vector_search_expr}"
+                        f" APPROXIMATE LIMIT {vector_topn}"
+                        f")"
+                        f" SELECT COUNT(*) FROM fulltext_results f FULL OUTER JOIN vector_results v ON f.id = v.id"
+                    )
                 logger.debug("OBConnection.search with count sql: %s", count_sql)
 
                 start_time = time.time()
```
```diff
@@ -976,32 +1026,8 @@ class OBConnection(DocStoreConnection):
                 if total_count == 0:
                     continue
 
-                score_expr = f"(relevance * {1 - vector_similarity_weight} + {vector_search_score_expr} * {vector_similarity_weight} + {pagerank_score_expr})"
-                if group_results:
-                    fusion_sql = (
-                        f"WITH fulltext_results AS ("
-                        f" SELECT {fulltext_search_hint} *, {fulltext_search_score_expr} AS relevance"
-                        f" FROM {index_name}"
-                        f" WHERE {filters_expr} AND {fulltext_search_filter}"
-                        f" ORDER BY relevance DESC"
-                        f" LIMIT {num_candidates}"
-                        f"),"
-                        f" scored_results AS ("
-                        f" SELECT *, {score_expr} AS _score"
-                        f" FROM fulltext_results"
-                        f" WHERE {vector_search_filter}"
-                        f"),"
-                        f" group_results AS ("
-                        f" SELECT *, ROW_NUMBER() OVER (PARTITION BY group_id ORDER BY _score DESC) as rn"
-                        f" FROM scored_results"
-                        f")"
-                        f" SELECT {fields_expr}, _score"
-                        f" FROM group_results"
-                        f" WHERE rn = 1"
-                        f" ORDER BY _score DESC"
-                        f" LIMIT {offset}, {limit}"
-                    )
-                else:
+                if self.use_fulltext_first_fusion_search:
+                    score_expr = f"(relevance * {1 - vector_similarity_weight} + {vector_search_score_expr} * {vector_similarity_weight} + {pagerank_score_expr})"
                     fusion_sql = (
                         f"WITH fulltext_results AS ("
                         f" SELECT {fulltext_search_hint} *, {fulltext_search_score_expr} AS relevance"
```
```diff
@@ -1016,6 +1042,38 @@ class OBConnection(DocStoreConnection):
                         f" ORDER BY _score DESC"
                         f" LIMIT {offset}, {limit}"
                     )
+                else:
+                    pagerank_score_expr = f"(CAST(IFNULL(f.{PAGERANK_FLD}, 0) AS DECIMAL(10, 2)) / 100)"
+                    score_expr = f"(f.relevance * {1 - vector_similarity_weight} + v.similarity * {vector_similarity_weight} + {pagerank_score_expr})"
+                    fields_expr = ", ".join([f"t.{f} as {f}" for f in output_fields if f != "_score"])
+                    fusion_sql = (
+                        f"WITH fulltext_results AS ("
+                        f" SELECT {fulltext_search_hint} id, pagerank_fea, {fulltext_search_score_expr} AS relevance"
+                        f" FROM {index_name}"
+                        f" WHERE {filters_expr} AND {fulltext_search_filter}"
+                        f" ORDER BY relevance DESC"
+                        f" LIMIT {fulltext_topn}"
+                        f"),"
+                        f"vector_results AS ("
+                        f" SELECT id, pagerank_fea, {vector_search_score_expr} AS similarity"
+                        f" FROM {index_name}"
+                        f" WHERE {filters_expr} AND {vector_search_filter}"
+                        f" ORDER BY {vector_search_expr}"
+                        f" APPROXIMATE LIMIT {vector_topn}"
+                        f"),"
+                        f"combined_results AS ("
+                        f" SELECT COALESCE(f.id, v.id) AS id, {score_expr} AS score"
+                        f" FROM fulltext_results f"
+                        f" FULL OUTER JOIN vector_results v"
+                        f" ON f.id = v.id"
+                        f")"
+                        f" SELECT {fields_expr}, c.score as _score"
+                        f" FROM combined_results c"
+                        f" JOIN {index_name} t"
+                        f" ON c.id = t.id"
+                        f" ORDER BY score DESC"
+                        f" LIMIT {offset}, {limit}"
+                    )
                 logger.debug("OBConnection.search with fusion sql: %s", fusion_sql)
 
                 start_time = time.time()
```
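With `USE_FULLTEXT_FIRST_FUSION_SEARCH=true` (the default), vector scoring is applied only to the fulltext candidate set, so a chunk must pass the fulltext filter to be retrieved at all. With `false`, the two searches run independently and meet in a `FULL OUTER JOIN` on `id`, so a chunk matching only one modality still surfaces. A schematic of the rendered query shape for the `false` branch (angle-bracket names are placeholders, not the exact rendering):

```python
# Schematic only; <...> placeholders stand for the f-string expressions above.
fusion_sql_shape = """
WITH fulltext_results AS (
    SELECT id, pagerank_fea, <fulltext_score> AS relevance FROM <index_table>
    WHERE <filters> AND <fulltext_filter> ORDER BY relevance DESC LIMIT <fulltext_topn>),
vector_results AS (
    SELECT id, pagerank_fea, <vector_score> AS similarity FROM <index_table>
    WHERE <filters> AND <vector_filter> ORDER BY <vector_distance> APPROXIMATE LIMIT <vector_topn>),
combined_results AS (
    SELECT COALESCE(f.id, v.id) AS id, <weighted_score> AS score
    FROM fulltext_results f FULL OUTER JOIN vector_results v ON f.id = v.id)
SELECT <fields>, c.score AS _score
FROM combined_results c JOIN <index_table> t ON c.id = t.id
ORDER BY score DESC LIMIT <offset>, <limit>
"""
```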
```diff
@@ -1234,10 +1292,14 @@ class OBConnection(DocStoreConnection):
 
         for row in rows:
             result.chunks.append(self._row_to_entity(row, output_fields))
 
+        if result.total == 0:
+            result.total = len(result.chunks)
+
         return result
 
     def get(self, chunkId: str, indexName: str, knowledgebaseIds: list[str]) -> dict | None:
-        if not self.client.check_table_exists(indexName):
+        if not self._check_table_exists_cached(indexName):
             return None
 
         try:
```
```diff
@@ -1336,7 +1398,7 @@ class OBConnection(DocStoreConnection):
         return res
 
     def update(self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str) -> bool:
-        if not self.client.check_table_exists(indexName):
+        if not self._check_table_exists_cached(indexName):
             return True
 
         condition["kb_id"] = knowledgebaseId
```
```diff
@@ -1387,7 +1449,7 @@ class OBConnection(DocStoreConnection):
         return False
 
     def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int:
-        if not self.client.check_table_exists(indexName):
+        if not self._check_table_exists_cached(indexName):
             return 0
 
         condition["kb_id"] = knowledgebaseId
```