mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-02-03 17:15:08 +08:00
feat: add OceanBase memory store (#12955)
### What problem does this PR solve? Add OceanBase memory store and extracting base class `OBConnectionBase`. ### Type of change - [x] New Feature (non-breaking change which adds functionality) --------- Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
739
common/doc_store/ob_conn_base.py
Normal file
739
common/doc_store/ob_conn_base.py
Normal file
@ -0,0 +1,739 @@
|
|||||||
|
#
|
||||||
|
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from abc import abstractmethod
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from pymysql.converters import escape_string
|
||||||
|
from pyobvector import ObVecClient, FtsIndexParam, FtsParser, VECTOR
|
||||||
|
from sqlalchemy import Column, Table
|
||||||
|
|
||||||
|
from common.doc_store.doc_store_base import DocStoreConnection, MatchExpr, OrderByExpr
|
||||||
|
|
||||||
|
ATTEMPT_TIME = 2
|
||||||
|
|
||||||
|
# Common templates for OceanBase
|
||||||
|
index_name_template = "ix_%s_%s"
|
||||||
|
fulltext_index_name_template = "fts_idx_%s"
|
||||||
|
fulltext_search_template = "MATCH (%s) AGAINST ('%s' IN NATURAL LANGUAGE MODE)"
|
||||||
|
vector_search_template = "cosine_distance(%s, '%s')"
|
||||||
|
vector_column_pattern = re.compile(r"q_(?P<vector_size>\d+)_vec")
|
||||||
|
|
||||||
|
|
||||||
|
def get_value_str(value: Any) -> str:
|
||||||
|
"""Convert value to SQL string representation."""
|
||||||
|
if isinstance(value, str):
|
||||||
|
# escape_string already handles all necessary escaping for MySQL/OceanBase
|
||||||
|
# including backslashes, quotes, newlines, etc.
|
||||||
|
return f"'{escape_string(value)}'"
|
||||||
|
elif isinstance(value, bool):
|
||||||
|
return "true" if value else "false"
|
||||||
|
elif value is None:
|
||||||
|
return "NULL"
|
||||||
|
elif isinstance(value, (list, dict)):
|
||||||
|
json_str = json.dumps(value, ensure_ascii=False)
|
||||||
|
return f"'{escape_string(json_str)}'"
|
||||||
|
else:
|
||||||
|
return str(value)
|
||||||
|
|
||||||
|
|
||||||
|
def _try_with_lock(lock_name: str, process_func, check_func, timeout: int = None):
|
||||||
|
"""Execute function with distributed lock."""
|
||||||
|
if not timeout:
|
||||||
|
timeout = int(os.environ.get("OB_DDL_TIMEOUT", "60"))
|
||||||
|
|
||||||
|
if not check_func():
|
||||||
|
from rag.utils.redis_conn import RedisDistributedLock
|
||||||
|
lock = RedisDistributedLock(lock_name)
|
||||||
|
if lock.acquire():
|
||||||
|
try:
|
||||||
|
process_func()
|
||||||
|
return
|
||||||
|
except Exception as e:
|
||||||
|
if "Duplicate" in str(e):
|
||||||
|
return
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
lock.release()
|
||||||
|
|
||||||
|
if not check_func():
|
||||||
|
time.sleep(1)
|
||||||
|
count = 1
|
||||||
|
while count < timeout and not check_func():
|
||||||
|
count += 1
|
||||||
|
time.sleep(1)
|
||||||
|
if count >= timeout and not check_func():
|
||||||
|
raise Exception(f"Timeout to wait for process complete for {lock_name}.")
|
||||||
|
|
||||||
|
|
||||||
|
class OBConnectionBase(DocStoreConnection):
|
||||||
|
"""Base class for OceanBase document store connections."""
|
||||||
|
|
||||||
|
def __init__(self, logger_name: str = 'ragflow.ob_conn'):
|
||||||
|
from common.doc_store.ob_conn_pool import OB_CONN
|
||||||
|
|
||||||
|
self.logger = logging.getLogger(logger_name)
|
||||||
|
self.client: ObVecClient = OB_CONN.get_client()
|
||||||
|
self.es = OB_CONN.get_hybrid_search_client()
|
||||||
|
self.db_name = OB_CONN.get_db_name()
|
||||||
|
self.uri = OB_CONN.get_uri()
|
||||||
|
|
||||||
|
self._load_env_vars()
|
||||||
|
|
||||||
|
self._table_exists_cache: set[str] = set()
|
||||||
|
self._table_exists_cache_lock = threading.RLock()
|
||||||
|
|
||||||
|
# Cache for vector columns: stores (table_name, vector_size) tuples
|
||||||
|
self._vector_column_cache: set[tuple[str, int]] = set()
|
||||||
|
self._vector_column_cache_lock = threading.RLock()
|
||||||
|
|
||||||
|
self.logger.info(f"OceanBase {self.uri} connection initialized.")
|
||||||
|
|
||||||
|
def _load_env_vars(self):
|
||||||
|
def is_true(var: str, default: str) -> bool:
|
||||||
|
return os.getenv(var, default).lower() in ['true', '1', 'yes', 'y']
|
||||||
|
|
||||||
|
self.enable_fulltext_search = is_true('ENABLE_FULLTEXT_SEARCH', 'true')
|
||||||
|
self.use_fulltext_hint = is_true('USE_FULLTEXT_HINT', 'true')
|
||||||
|
self.search_original_content = is_true("SEARCH_ORIGINAL_CONTENT", 'true')
|
||||||
|
self.enable_hybrid_search = is_true('ENABLE_HYBRID_SEARCH', 'false')
|
||||||
|
self.use_fulltext_first_fusion_search = is_true('USE_FULLTEXT_FIRST_FUSION_SEARCH', 'true')
|
||||||
|
|
||||||
|
# Adjust settings based on hybrid search availability
|
||||||
|
if self.es is not None and self.search_original_content:
|
||||||
|
self.logger.info("HybridSearch is enabled, forcing search_original_content to False")
|
||||||
|
self.search_original_content = False
|
||||||
|
|
||||||
|
"""
|
||||||
|
Template methods - must be implemented by subclasses
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_index_columns(self) -> list[str]:
|
||||||
|
"""Return list of column names that need regular indexes."""
|
||||||
|
raise NotImplementedError("Not implemented")
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_fulltext_columns(self) -> list[str]:
|
||||||
|
"""Return list of column names that need fulltext indexes (without weight suffix)."""
|
||||||
|
raise NotImplementedError("Not implemented")
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_column_definitions(self) -> list[Column]:
|
||||||
|
"""Return list of column definitions for table creation."""
|
||||||
|
raise NotImplementedError("Not implemented")
|
||||||
|
|
||||||
|
def get_extra_columns(self) -> list[Column]:
|
||||||
|
"""Return list of extra columns to add after table creation. Override if needed."""
|
||||||
|
return []
|
||||||
|
|
||||||
|
def get_table_name(self, index_name: str, dataset_id: str) -> str:
|
||||||
|
"""Return the actual table name given index_name and dataset_id."""
|
||||||
|
return index_name
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_lock_prefix(self) -> str:
|
||||||
|
"""Return the lock name prefix for distributed locking."""
|
||||||
|
raise NotImplementedError("Not implemented")
|
||||||
|
|
||||||
|
"""
|
||||||
|
Database operations
|
||||||
|
"""
|
||||||
|
|
||||||
|
def db_type(self) -> str:
|
||||||
|
return "oceanbase"
|
||||||
|
|
||||||
|
def health(self) -> dict:
|
||||||
|
return {
|
||||||
|
"uri": self.uri,
|
||||||
|
"version_comment": self._get_variable_value("version_comment")
|
||||||
|
}
|
||||||
|
|
||||||
|
def _get_variable_value(self, var_name: str) -> Any:
|
||||||
|
rows = self.client.perform_raw_text_sql(f"SHOW VARIABLES LIKE '{var_name}'")
|
||||||
|
for row in rows:
|
||||||
|
return row[1]
|
||||||
|
raise Exception(f"Variable '{var_name}' not found.")
|
||||||
|
|
||||||
|
"""
|
||||||
|
Table operations - common implementation using template methods
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _check_table_exists_cached(self, table_name: str) -> bool:
|
||||||
|
"""
|
||||||
|
Check table existence with cache to reduce INFORMATION_SCHEMA queries.
|
||||||
|
Thread-safe implementation using RLock.
|
||||||
|
"""
|
||||||
|
if table_name in self._table_exists_cache:
|
||||||
|
return True
|
||||||
|
|
||||||
|
try:
|
||||||
|
if not self.client.check_table_exists(table_name):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check regular indexes
|
||||||
|
for column_name in self.get_index_columns():
|
||||||
|
if not self._index_exists(table_name, index_name_template % (table_name, column_name)):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check fulltext indexes
|
||||||
|
for column_name in self.get_fulltext_columns():
|
||||||
|
if not self._index_exists(table_name, fulltext_index_name_template % column_name):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check extra columns
|
||||||
|
for column in self.get_extra_columns():
|
||||||
|
if not self._column_exist(table_name, column.name):
|
||||||
|
return False
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(f"OBConnection._check_table_exists_cached error: {str(e)}")
|
||||||
|
|
||||||
|
with self._table_exists_cache_lock:
|
||||||
|
if table_name not in self._table_exists_cache:
|
||||||
|
self._table_exists_cache.add(table_name)
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _create_table(self, table_name: str):
|
||||||
|
"""Create table using column definitions from subclass."""
|
||||||
|
self._create_table_with_columns(table_name, self.get_column_definitions())
|
||||||
|
|
||||||
|
def create_idx(self, index_name: str, dataset_id: str, vector_size: int, parser_id: str = None):
|
||||||
|
"""Create index/table with all necessary indexes."""
|
||||||
|
table_name = self.get_table_name(index_name, dataset_id)
|
||||||
|
lock_prefix = self.get_lock_prefix()
|
||||||
|
|
||||||
|
try:
|
||||||
|
_try_with_lock(
|
||||||
|
lock_name=f"{lock_prefix}create_table_{table_name}",
|
||||||
|
check_func=lambda: self.client.check_table_exists(table_name),
|
||||||
|
process_func=lambda: self._create_table(table_name),
|
||||||
|
)
|
||||||
|
|
||||||
|
for column_name in self.get_index_columns():
|
||||||
|
_try_with_lock(
|
||||||
|
lock_name=f"{lock_prefix}add_idx_{table_name}_{column_name}",
|
||||||
|
check_func=lambda cn=column_name: self._index_exists(table_name,
|
||||||
|
index_name_template % (table_name, cn)),
|
||||||
|
process_func=lambda cn=column_name: self._add_index(table_name, cn),
|
||||||
|
)
|
||||||
|
|
||||||
|
for column_name in self.get_fulltext_columns():
|
||||||
|
_try_with_lock(
|
||||||
|
lock_name=f"{lock_prefix}add_fulltext_idx_{table_name}_{column_name}",
|
||||||
|
check_func=lambda cn=column_name: self._index_exists(table_name, fulltext_index_name_template % cn),
|
||||||
|
process_func=lambda cn=column_name: self._add_fulltext_index(table_name, cn),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add vector column and index (skip metadata refresh, will be done in finally)
|
||||||
|
self._ensure_vector_column_exists(table_name, vector_size, refresh_metadata=False)
|
||||||
|
|
||||||
|
# Add extra columns if any
|
||||||
|
for column in self.get_extra_columns():
|
||||||
|
_try_with_lock(
|
||||||
|
lock_name=f"{lock_prefix}add_{column.name}_{table_name}",
|
||||||
|
check_func=lambda c=column: self._column_exist(table_name, c.name),
|
||||||
|
process_func=lambda c=column: self._add_column(table_name, c),
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(f"OBConnection.create_idx error: {str(e)}")
|
||||||
|
finally:
|
||||||
|
self.client.refresh_metadata([table_name])
|
||||||
|
|
||||||
|
def create_doc_meta_idx(self, index_name: str):
|
||||||
|
"""
|
||||||
|
Create a document metadata table.
|
||||||
|
|
||||||
|
Table name pattern: ragflow_doc_meta_{tenant_id}
|
||||||
|
- Per-tenant metadata table for storing document metadata fields
|
||||||
|
"""
|
||||||
|
from sqlalchemy import JSON
|
||||||
|
from sqlalchemy.dialects.mysql import VARCHAR
|
||||||
|
|
||||||
|
table_name = index_name
|
||||||
|
lock_prefix = self.get_lock_prefix()
|
||||||
|
|
||||||
|
# Define columns for document metadata table
|
||||||
|
doc_meta_columns = [
|
||||||
|
Column("id", VARCHAR(256), primary_key=True, comment="document id"),
|
||||||
|
Column("kb_id", VARCHAR(256), nullable=False, comment="knowledge base id"),
|
||||||
|
Column("meta_fields", JSON, nullable=True, comment="document metadata fields"),
|
||||||
|
]
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Create table with distributed lock
|
||||||
|
_try_with_lock(
|
||||||
|
lock_name=f"{lock_prefix}create_doc_meta_table_{table_name}",
|
||||||
|
check_func=lambda: self.client.check_table_exists(table_name),
|
||||||
|
process_func=lambda: self._create_table_with_columns(table_name, doc_meta_columns),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create index on kb_id for better query performance
|
||||||
|
_try_with_lock(
|
||||||
|
lock_name=f"{lock_prefix}add_idx_{table_name}_kb_id",
|
||||||
|
check_func=lambda: self._index_exists(table_name, index_name_template % (table_name, "kb_id")),
|
||||||
|
process_func=lambda: self._add_index(table_name, "kb_id"),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.logger.info(f"Created document metadata table '{table_name}'.")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"OBConnection.create_doc_meta_idx error: {str(e)}")
|
||||||
|
return False
|
||||||
|
finally:
|
||||||
|
self.client.refresh_metadata([table_name])
|
||||||
|
|
||||||
|
def delete_idx(self, index_name: str, dataset_id: str):
|
||||||
|
"""Delete index/table."""
|
||||||
|
# For doc_meta tables, use index_name directly as table name
|
||||||
|
if index_name.startswith("ragflow_doc_meta_"):
|
||||||
|
table_name = index_name
|
||||||
|
else:
|
||||||
|
table_name = self.get_table_name(index_name, dataset_id)
|
||||||
|
try:
|
||||||
|
if self.client.check_table_exists(table_name=table_name):
|
||||||
|
self.client.drop_table_if_exist(table_name)
|
||||||
|
self.logger.info(f"Dropped table '{table_name}'.")
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(f"OBConnection.delete_idx error: {str(e)}")
|
||||||
|
|
||||||
|
def index_exist(self, index_name: str, dataset_id: str = None) -> bool:
|
||||||
|
"""Check if index/table exists."""
|
||||||
|
# For doc_meta tables, use index_name directly as table name
|
||||||
|
if index_name.startswith("ragflow_doc_meta_"):
|
||||||
|
table_name = index_name
|
||||||
|
else:
|
||||||
|
table_name = self.get_table_name(index_name, dataset_id) if dataset_id else index_name
|
||||||
|
return self._check_table_exists_cached(table_name)
|
||||||
|
|
||||||
|
"""
|
||||||
|
Table operations - helper methods
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _get_count(self, table_name: str, filter_list: list[str] = None) -> int:
|
||||||
|
where_clause = "WHERE " + " AND ".join(filter_list) if filter_list and len(filter_list) > 0 else ""
|
||||||
|
(count,) = self.client.perform_raw_text_sql(
|
||||||
|
f"SELECT COUNT(*) FROM {table_name} {where_clause}"
|
||||||
|
).fetchone()
|
||||||
|
return count
|
||||||
|
|
||||||
|
def _column_exist(self, table_name: str, column_name: str) -> bool:
|
||||||
|
return self._get_count(
|
||||||
|
table_name="INFORMATION_SCHEMA.COLUMNS",
|
||||||
|
filter_list=[
|
||||||
|
f"TABLE_SCHEMA = '{self.db_name}'",
|
||||||
|
f"TABLE_NAME = '{table_name}'",
|
||||||
|
f"COLUMN_NAME = '{column_name}'",
|
||||||
|
]) > 0
|
||||||
|
|
||||||
|
def _index_exists(self, table_name: str, idx_name: str) -> bool:
|
||||||
|
return self._get_count(
|
||||||
|
table_name="INFORMATION_SCHEMA.STATISTICS",
|
||||||
|
filter_list=[
|
||||||
|
f"TABLE_SCHEMA = '{self.db_name}'",
|
||||||
|
f"TABLE_NAME = '{table_name}'",
|
||||||
|
f"INDEX_NAME = '{idx_name}'",
|
||||||
|
]) > 0
|
||||||
|
|
||||||
|
def _create_table_with_columns(self, table_name: str, columns: list[Column]):
|
||||||
|
"""Create table with specified columns."""
|
||||||
|
if table_name in self.client.metadata_obj.tables:
|
||||||
|
self.client.metadata_obj.remove(Table(table_name, self.client.metadata_obj))
|
||||||
|
|
||||||
|
table_options = {
|
||||||
|
"mysql_charset": "utf8mb4",
|
||||||
|
"mysql_collate": "utf8mb4_unicode_ci",
|
||||||
|
"mysql_organization": "heap",
|
||||||
|
}
|
||||||
|
|
||||||
|
self.client.create_table(
|
||||||
|
table_name=table_name,
|
||||||
|
columns=[c.copy() for c in columns],
|
||||||
|
**table_options,
|
||||||
|
)
|
||||||
|
self.logger.info(f"Created table '{table_name}'.")
|
||||||
|
|
||||||
|
def _add_index(self, table_name: str, column_name: str):
|
||||||
|
idx_name = index_name_template % (table_name, column_name)
|
||||||
|
self.client.create_index(
|
||||||
|
table_name=table_name,
|
||||||
|
is_vec_index=False,
|
||||||
|
index_name=idx_name,
|
||||||
|
column_names=[column_name],
|
||||||
|
)
|
||||||
|
self.logger.info(f"Created index '{idx_name}' on table '{table_name}'.")
|
||||||
|
|
||||||
|
def _add_fulltext_index(self, table_name: str, column_name: str):
|
||||||
|
fulltext_idx_name = fulltext_index_name_template % column_name
|
||||||
|
self.client.create_fts_idx_with_fts_index_param(
|
||||||
|
table_name=table_name,
|
||||||
|
fts_idx_param=FtsIndexParam(
|
||||||
|
index_name=fulltext_idx_name,
|
||||||
|
field_names=[column_name],
|
||||||
|
parser_type=FtsParser.IK,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
self.logger.info(f"Created full text index '{fulltext_idx_name}' on table '{table_name}'.")
|
||||||
|
|
||||||
|
def _add_vector_column(self, table_name: str, vector_size: int):
|
||||||
|
vector_field_name = f"q_{vector_size}_vec"
|
||||||
|
self.client.add_columns(
|
||||||
|
table_name=table_name,
|
||||||
|
columns=[Column(vector_field_name, VECTOR(vector_size), nullable=True)],
|
||||||
|
)
|
||||||
|
self.logger.info(f"Added vector column '{vector_field_name}' to table '{table_name}'.")
|
||||||
|
|
||||||
|
def _add_vector_index(self, table_name: str, vector_field_name: str):
|
||||||
|
vector_idx_name = f"{vector_field_name}_idx"
|
||||||
|
self.client.create_index(
|
||||||
|
table_name=table_name,
|
||||||
|
is_vec_index=True,
|
||||||
|
index_name=vector_idx_name,
|
||||||
|
column_names=[vector_field_name],
|
||||||
|
vidx_params="distance=cosine, type=hnsw, lib=vsag",
|
||||||
|
)
|
||||||
|
self.logger.info(
|
||||||
|
f"Created vector index '{vector_idx_name}' on table '{table_name}' with column '{vector_field_name}'."
|
||||||
|
)
|
||||||
|
|
||||||
|
def _add_column(self, table_name: str, column: Column):
|
||||||
|
try:
|
||||||
|
self.client.add_columns(
|
||||||
|
table_name=table_name,
|
||||||
|
columns=[column.copy()],
|
||||||
|
)
|
||||||
|
self.logger.info(f"Added column '{column.name}' to table '{table_name}'.")
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.warning(f"Failed to add column '{column.name}' to table '{table_name}': {str(e)}")
|
||||||
|
|
||||||
|
def _ensure_vector_column_exists(self, table_name: str, vector_size: int, refresh_metadata: bool = True):
|
||||||
|
"""
|
||||||
|
Ensure vector column and index exist for the given vector size.
|
||||||
|
This method is safe to call multiple times - it will skip if already exists.
|
||||||
|
Uses cache to avoid repeated INFORMATION_SCHEMA queries.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
table_name: Name of the table
|
||||||
|
vector_size: Size of the vector column
|
||||||
|
refresh_metadata: Whether to refresh SQLAlchemy metadata after changes (default True)
|
||||||
|
"""
|
||||||
|
if vector_size <= 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
cache_key = (table_name, vector_size)
|
||||||
|
|
||||||
|
# Check cache first
|
||||||
|
if cache_key in self._vector_column_cache:
|
||||||
|
return
|
||||||
|
|
||||||
|
lock_prefix = self.get_lock_prefix()
|
||||||
|
vector_field_name = f"q_{vector_size}_vec"
|
||||||
|
vector_index_name = f"{vector_field_name}_idx"
|
||||||
|
|
||||||
|
# Check if already exists (may have been created by another process)
|
||||||
|
column_exists = self._column_exist(table_name, vector_field_name)
|
||||||
|
index_exists = self._index_exists(table_name, vector_index_name)
|
||||||
|
|
||||||
|
if column_exists and index_exists:
|
||||||
|
# Already exists, add to cache and return
|
||||||
|
with self._vector_column_cache_lock:
|
||||||
|
self._vector_column_cache.add(cache_key)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Create column if needed
|
||||||
|
if not column_exists:
|
||||||
|
_try_with_lock(
|
||||||
|
lock_name=f"{lock_prefix}add_vector_column_{table_name}_{vector_field_name}",
|
||||||
|
check_func=lambda: self._column_exist(table_name, vector_field_name),
|
||||||
|
process_func=lambda: self._add_vector_column(table_name, vector_size),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create index if needed
|
||||||
|
if not index_exists:
|
||||||
|
_try_with_lock(
|
||||||
|
lock_name=f"{lock_prefix}add_vector_idx_{table_name}_{vector_field_name}",
|
||||||
|
check_func=lambda: self._index_exists(table_name, vector_index_name),
|
||||||
|
process_func=lambda: self._add_vector_index(table_name, vector_field_name),
|
||||||
|
)
|
||||||
|
|
||||||
|
if refresh_metadata:
|
||||||
|
self.client.refresh_metadata([table_name])
|
||||||
|
|
||||||
|
# Add to cache after successful creation
|
||||||
|
with self._vector_column_cache_lock:
|
||||||
|
self._vector_column_cache.add(cache_key)
|
||||||
|
|
||||||
|
def _execute_search_sql(self, sql: str) -> tuple[list, float]:
|
||||||
|
start_time = time.time()
|
||||||
|
res = self.client.perform_raw_text_sql(sql)
|
||||||
|
rows = res.fetchall()
|
||||||
|
elapsed_time = time.time() - start_time
|
||||||
|
return rows, elapsed_time
|
||||||
|
|
||||||
|
def _parse_fulltext_columns(
|
||||||
|
self,
|
||||||
|
fulltext_query: str,
|
||||||
|
fulltext_columns: list[str]
|
||||||
|
) -> tuple[dict[str, str], dict[str, float]]:
|
||||||
|
"""
|
||||||
|
Parse fulltext search columns with optional weight suffix and build search expressions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fulltext_query: The escaped fulltext query string
|
||||||
|
fulltext_columns: List of column names, optionally with weight suffix (e.g., "col^0.5")
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (fulltext_search_expr dict, fulltext_search_weight dict)
|
||||||
|
where weights are normalized to 0~1
|
||||||
|
"""
|
||||||
|
fulltext_search_expr: dict[str, str] = {}
|
||||||
|
fulltext_search_weight: dict[str, float] = {}
|
||||||
|
|
||||||
|
# get fulltext match expression and weight values
|
||||||
|
for field in fulltext_columns:
|
||||||
|
parts = field.split("^")
|
||||||
|
column_name: str = parts[0]
|
||||||
|
column_weight: float = float(parts[1]) if (len(parts) > 1 and parts[1]) else 1.0
|
||||||
|
|
||||||
|
fulltext_search_weight[column_name] = column_weight
|
||||||
|
fulltext_search_expr[column_name] = fulltext_search_template % (column_name, fulltext_query)
|
||||||
|
|
||||||
|
# adjust the weight to 0~1
|
||||||
|
weight_sum = sum(fulltext_search_weight.values())
|
||||||
|
n = len(fulltext_search_weight)
|
||||||
|
if weight_sum <= 0 < n:
|
||||||
|
# All weights are 0 (e.g. "col^0"); use equal weights to avoid ZeroDivisionError
|
||||||
|
for column_name in fulltext_search_weight:
|
||||||
|
fulltext_search_weight[column_name] = 1.0 / n
|
||||||
|
else:
|
||||||
|
for column_name in fulltext_search_weight:
|
||||||
|
fulltext_search_weight[column_name] = fulltext_search_weight[column_name] / weight_sum
|
||||||
|
|
||||||
|
return fulltext_search_expr, fulltext_search_weight
|
||||||
|
|
||||||
|
def _build_vector_search_sql(
|
||||||
|
self,
|
||||||
|
table_name: str,
|
||||||
|
fields_expr: str,
|
||||||
|
vector_search_score_expr: str,
|
||||||
|
filters_expr: str,
|
||||||
|
vector_search_filter: str,
|
||||||
|
vector_search_expr: str,
|
||||||
|
limit: int,
|
||||||
|
vector_topn: int,
|
||||||
|
offset: int = 0
|
||||||
|
) -> str:
|
||||||
|
sql = (
|
||||||
|
f"SELECT {fields_expr}, {vector_search_score_expr} AS _score"
|
||||||
|
f" FROM {table_name}"
|
||||||
|
f" WHERE {filters_expr} AND {vector_search_filter}"
|
||||||
|
f" ORDER BY {vector_search_expr}"
|
||||||
|
f" APPROXIMATE LIMIT {limit if limit != 0 else vector_topn}"
|
||||||
|
)
|
||||||
|
if offset != 0:
|
||||||
|
sql += f" OFFSET {offset}"
|
||||||
|
return sql
|
||||||
|
|
||||||
|
def _build_fulltext_search_sql(
|
||||||
|
self,
|
||||||
|
table_name: str,
|
||||||
|
fields_expr: str,
|
||||||
|
fulltext_search_score_expr: str,
|
||||||
|
filters_expr: str,
|
||||||
|
fulltext_search_filter: str,
|
||||||
|
offset: int,
|
||||||
|
limit: int,
|
||||||
|
fulltext_topn: int,
|
||||||
|
hint: str = ""
|
||||||
|
) -> str:
|
||||||
|
hint_expr = f"{hint} " if hint else ""
|
||||||
|
return (
|
||||||
|
f"SELECT {hint_expr}{fields_expr}, {fulltext_search_score_expr} AS _score"
|
||||||
|
f" FROM {table_name}"
|
||||||
|
f" WHERE {filters_expr} AND {fulltext_search_filter}"
|
||||||
|
f" ORDER BY _score DESC"
|
||||||
|
f" LIMIT {offset}, {limit if limit != 0 else fulltext_topn}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _build_filter_search_sql(
|
||||||
|
self,
|
||||||
|
table_name: str,
|
||||||
|
fields_expr: str,
|
||||||
|
filters_expr: str,
|
||||||
|
order_by_expr: str = "",
|
||||||
|
limit_expr: str = ""
|
||||||
|
) -> str:
|
||||||
|
return (
|
||||||
|
f"SELECT {fields_expr}"
|
||||||
|
f" FROM {table_name}"
|
||||||
|
f" WHERE {filters_expr}"
|
||||||
|
f" {order_by_expr} {limit_expr}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _build_count_sql(
|
||||||
|
self,
|
||||||
|
table_name: str,
|
||||||
|
filters_expr: str,
|
||||||
|
extra_filter: str = "",
|
||||||
|
hint: str = ""
|
||||||
|
) -> str:
|
||||||
|
hint_expr = f"{hint} " if hint else ""
|
||||||
|
where_clause = f"{filters_expr} AND {extra_filter}" if extra_filter else filters_expr
|
||||||
|
return f"SELECT {hint_expr}COUNT(id) FROM {table_name} WHERE {where_clause}"
|
||||||
|
|
||||||
|
def _row_to_entity(self, data, fields: list[str]) -> dict:
|
||||||
|
entity = {}
|
||||||
|
for i, field in enumerate(fields):
|
||||||
|
value = data[i]
|
||||||
|
if value is None:
|
||||||
|
continue
|
||||||
|
entity[field] = value
|
||||||
|
return entity
|
||||||
|
|
||||||
|
def _get_dataset_id_field(self) -> str:
|
||||||
|
return "kb_id"
|
||||||
|
|
||||||
|
def _get_filters(self, condition: dict) -> list[str]:
|
||||||
|
filters: list[str] = []
|
||||||
|
for k, v in condition.items():
|
||||||
|
if not v:
|
||||||
|
continue
|
||||||
|
if k == "exists":
|
||||||
|
filters.append(f"{v} IS NOT NULL")
|
||||||
|
elif k == "must_not" and isinstance(v, dict) and "exists" in v:
|
||||||
|
filters.append(f"{v.get('exists')} IS NULL")
|
||||||
|
elif isinstance(v, list):
|
||||||
|
values: list[str] = []
|
||||||
|
for item in v:
|
||||||
|
values.append(get_value_str(item))
|
||||||
|
value = ", ".join(values)
|
||||||
|
filters.append(f"{k} IN ({value})")
|
||||||
|
else:
|
||||||
|
filters.append(f"{k} = {get_value_str(v)}")
|
||||||
|
return filters
|
||||||
|
|
||||||
|
def get(self, doc_id: str, index_name: str, dataset_ids: list[str]) -> dict | None:
|
||||||
|
if not self._check_table_exists_cached(index_name):
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
res = self.client.get(
|
||||||
|
table_name=index_name,
|
||||||
|
ids=[doc_id],
|
||||||
|
)
|
||||||
|
row = res.fetchone()
|
||||||
|
if row is None:
|
||||||
|
return None
|
||||||
|
return self._row_to_entity(row, fields=list(res.keys()))
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.exception(f"OBConnectionBase.get({doc_id}) got exception")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def delete(self, condition: dict, index_name: str, dataset_id: str) -> int:
|
||||||
|
if not self._check_table_exists_cached(index_name):
|
||||||
|
return 0
|
||||||
|
# For doc_meta tables, don't add dataset_id to condition
|
||||||
|
if not index_name.startswith("ragflow_doc_meta_"):
|
||||||
|
condition[self._get_dataset_id_field()] = dataset_id
|
||||||
|
try:
|
||||||
|
from sqlalchemy import text
|
||||||
|
res = self.client.get(
|
||||||
|
table_name=index_name,
|
||||||
|
ids=None,
|
||||||
|
where_clause=[text(f) for f in self._get_filters(condition)],
|
||||||
|
output_column_name=["id"],
|
||||||
|
)
|
||||||
|
rows = res.fetchall()
|
||||||
|
if len(rows) == 0:
|
||||||
|
return 0
|
||||||
|
ids = [row[0] for row in rows]
|
||||||
|
self.logger.debug(f"OBConnection.delete, filters: {condition}, ids: {ids}")
|
||||||
|
self.client.delete(
|
||||||
|
table_name=index_name,
|
||||||
|
ids=ids,
|
||||||
|
)
|
||||||
|
return len(ids)
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"OBConnection.delete error: {str(e)}")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
"""
|
||||||
|
Abstract CRUD methods that must be implemented by subclasses
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def search(
|
||||||
|
self,
|
||||||
|
select_fields: list[str],
|
||||||
|
highlight_fields: list[str],
|
||||||
|
condition: dict,
|
||||||
|
match_expressions: list[MatchExpr],
|
||||||
|
order_by: OrderByExpr,
|
||||||
|
offset: int,
|
||||||
|
limit: int,
|
||||||
|
index_names: str | list[str],
|
||||||
|
knowledgebase_ids: list[str],
|
||||||
|
agg_fields: list[str] | None = None,
|
||||||
|
rank_feature: dict | None = None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
raise NotImplementedError("Not implemented")
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def insert(self, documents: list[dict], index_name: str, dataset_id: str = None) -> list[str]:
|
||||||
|
raise NotImplementedError("Not implemented")
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def update(self, condition: dict, new_value: dict, index_name: str, dataset_id: str) -> bool:
|
||||||
|
raise NotImplementedError("Not implemented")
|
||||||
|
|
||||||
|
"""
|
||||||
|
Helper functions for search result - abstract methods
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_total(self, res) -> int:
|
||||||
|
raise NotImplementedError("Not implemented")
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_doc_ids(self, res) -> list[str]:
|
||||||
|
raise NotImplementedError("Not implemented")
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_fields(self, res, fields: list[str]) -> dict[str, dict]:
|
||||||
|
raise NotImplementedError("Not implemented")
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_highlight(self, res, keywords: list[str], field_name: str):
|
||||||
|
raise NotImplementedError("Not implemented")
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_aggregation(self, res, field_name: str):
|
||||||
|
raise NotImplementedError("Not implemented")
|
||||||
|
|
||||||
|
"""
|
||||||
|
SQL - can be overridden by subclasses
|
||||||
|
"""
|
||||||
|
|
||||||
|
def sql(self, sql: str, fetch_size: int, format: str):
|
||||||
|
"""Execute SQL query - default implementation."""
|
||||||
|
return None
|
||||||
191
common/doc_store/ob_conn_pool.py
Normal file
191
common/doc_store/ob_conn_pool.py
Normal file
@ -0,0 +1,191 @@
|
|||||||
|
#
|
||||||
|
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
|
||||||
|
from pyobvector import ObVecClient
|
||||||
|
from pyobvector.client import ClusterVersionException
|
||||||
|
from pyobvector.client.hybrid_search import HybridSearch
|
||||||
|
from pyobvector.util import ObVersion
|
||||||
|
|
||||||
|
from common import settings
|
||||||
|
from common.decorator import singleton
|
||||||
|
|
||||||
|
ATTEMPT_TIME = 2
|
||||||
|
OB_QUERY_TIMEOUT = int(os.environ.get("OB_QUERY_TIMEOUT", "100_000_000"))
|
||||||
|
|
||||||
|
logger = logging.getLogger('ragflow.ob_conn_pool')
|
||||||
|
|
||||||
|
|
||||||
|
@singleton
|
||||||
|
class OceanBaseConnectionPool:
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.client = None
|
||||||
|
self.es = None # HybridSearch client
|
||||||
|
|
||||||
|
if hasattr(settings, "OB"):
|
||||||
|
self.OB_CONFIG = settings.OB
|
||||||
|
else:
|
||||||
|
self.OB_CONFIG = settings.get_base_config("oceanbase", {})
|
||||||
|
|
||||||
|
scheme = self.OB_CONFIG.get("scheme")
|
||||||
|
ob_config = self.OB_CONFIG.get("config", {})
|
||||||
|
|
||||||
|
if scheme and scheme.lower() == "mysql":
|
||||||
|
mysql_config = settings.get_base_config("mysql", {})
|
||||||
|
logger.info("Use MySQL scheme to create OceanBase connection.")
|
||||||
|
host = mysql_config.get("host", "localhost")
|
||||||
|
port = mysql_config.get("port", 2881)
|
||||||
|
self.username = mysql_config.get("user", "root@test")
|
||||||
|
self.password = mysql_config.get("password", "infini_rag_flow")
|
||||||
|
max_connections = mysql_config.get("max_connections", 300)
|
||||||
|
else:
|
||||||
|
logger.info("Use customized config to create OceanBase connection.")
|
||||||
|
host = ob_config.get("host", "localhost")
|
||||||
|
port = ob_config.get("port", 2881)
|
||||||
|
self.username = ob_config.get("user", "root@test")
|
||||||
|
self.password = ob_config.get("password", "infini_rag_flow")
|
||||||
|
max_connections = ob_config.get("max_connections", 300)
|
||||||
|
|
||||||
|
self.db_name = ob_config.get("db_name", "test")
|
||||||
|
self.uri = f"{host}:{port}"
|
||||||
|
|
||||||
|
logger.info(f"Use OceanBase '{self.uri}' as the doc engine.")
|
||||||
|
|
||||||
|
max_overflow = int(os.environ.get("OB_MAX_OVERFLOW", max(max_connections // 2, 10)))
|
||||||
|
pool_timeout = int(os.environ.get("OB_POOL_TIMEOUT", "30"))
|
||||||
|
|
||||||
|
for _ in range(ATTEMPT_TIME):
|
||||||
|
try:
|
||||||
|
self.client = ObVecClient(
|
||||||
|
uri=self.uri,
|
||||||
|
user=self.username,
|
||||||
|
password=self.password,
|
||||||
|
db_name=self.db_name,
|
||||||
|
pool_pre_ping=True,
|
||||||
|
pool_recycle=3600,
|
||||||
|
pool_size=max_connections,
|
||||||
|
max_overflow=max_overflow,
|
||||||
|
pool_timeout=pool_timeout,
|
||||||
|
)
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"{str(e)}. Waiting OceanBase {self.uri} to be healthy.")
|
||||||
|
time.sleep(5)
|
||||||
|
|
||||||
|
if self.client is None:
|
||||||
|
msg = f"OceanBase {self.uri} connection failed after {ATTEMPT_TIME} attempts."
|
||||||
|
logger.error(msg)
|
||||||
|
raise Exception(msg)
|
||||||
|
|
||||||
|
self._check_ob_version()
|
||||||
|
self._try_to_update_ob_query_timeout()
|
||||||
|
self._init_hybrid_search(max_connections, max_overflow, pool_timeout)
|
||||||
|
|
||||||
|
logger.info(f"OceanBase {self.uri} is healthy.")
|
||||||
|
|
||||||
|
def _check_ob_version(self):
|
||||||
|
try:
|
||||||
|
res = self.client.perform_raw_text_sql("SELECT OB_VERSION() FROM DUAL").fetchone()
|
||||||
|
version_str = res[0] if res else None
|
||||||
|
logger.info(f"OceanBase {self.uri} version is {version_str}")
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(f"Failed to get OceanBase version from {self.uri}, error: {str(e)}")
|
||||||
|
|
||||||
|
if not version_str:
|
||||||
|
raise Exception(f"Failed to get OceanBase version from {self.uri}.")
|
||||||
|
|
||||||
|
ob_version = ObVersion.from_db_version_string(version_str)
|
||||||
|
if ob_version < ObVersion.from_db_version_nums(4, 3, 5, 1):
|
||||||
|
raise Exception(
|
||||||
|
f"The version of OceanBase needs to be higher than or equal to 4.3.5.1, current version is {version_str}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _try_to_update_ob_query_timeout(self):
|
||||||
|
try:
|
||||||
|
rows = self.client.perform_raw_text_sql("SHOW VARIABLES LIKE 'ob_query_timeout'")
|
||||||
|
for row in rows:
|
||||||
|
val = row[1]
|
||||||
|
if val and int(val) >= OB_QUERY_TIMEOUT:
|
||||||
|
return
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Failed to get 'ob_query_timeout' variable: %s", str(e))
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.client.perform_raw_text_sql(f"SET GLOBAL ob_query_timeout={OB_QUERY_TIMEOUT}")
|
||||||
|
logger.info("Set GLOBAL variable 'ob_query_timeout' to %d.", OB_QUERY_TIMEOUT)
|
||||||
|
self.client.engine.dispose()
|
||||||
|
logger.info("Disposed all connections in engine pool to refresh connection pool")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to set 'ob_query_timeout' variable: {str(e)}")
|
||||||
|
|
||||||
|
def _init_hybrid_search(self, max_connections, max_overflow, pool_timeout):
|
||||||
|
enable_hybrid_search = os.getenv('ENABLE_HYBRID_SEARCH', 'false').lower() in ['true', '1', 'yes', 'y']
|
||||||
|
if enable_hybrid_search:
|
||||||
|
try:
|
||||||
|
self.es = HybridSearch(
|
||||||
|
uri=self.uri,
|
||||||
|
user=self.username,
|
||||||
|
password=self.password,
|
||||||
|
db_name=self.db_name,
|
||||||
|
pool_pre_ping=True,
|
||||||
|
pool_recycle=3600,
|
||||||
|
pool_size=max_connections,
|
||||||
|
max_overflow=max_overflow,
|
||||||
|
pool_timeout=pool_timeout,
|
||||||
|
)
|
||||||
|
logger.info("OceanBase Hybrid Search feature is enabled")
|
||||||
|
except ClusterVersionException as e:
|
||||||
|
logger.info("Failed to initialize HybridSearch client, fallback to use SQL", exc_info=e)
|
||||||
|
self.es = None
|
||||||
|
|
||||||
|
def get_client(self) -> ObVecClient:
|
||||||
|
return self.client
|
||||||
|
|
||||||
|
def get_hybrid_search_client(self) -> HybridSearch | None:
|
||||||
|
return self.es
|
||||||
|
|
||||||
|
def get_db_name(self) -> str:
|
||||||
|
return self.db_name
|
||||||
|
|
||||||
|
def get_uri(self) -> str:
|
||||||
|
return self.uri
|
||||||
|
|
||||||
|
def refresh_client(self) -> ObVecClient:
|
||||||
|
try:
|
||||||
|
self.client.perform_raw_text_sql("SELECT 1 FROM DUAL")
|
||||||
|
return self.client
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"OceanBase connection unhealthy: {str(e)}, refreshing...")
|
||||||
|
self.client.engine.dispose()
|
||||||
|
return self.client
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
if hasattr(self, "client") and self.client:
|
||||||
|
try:
|
||||||
|
self.client.engine.dispose()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
if hasattr(self, "es") and self.es:
|
||||||
|
try:
|
||||||
|
self.es.engine.dispose()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
OB_CONN = OceanBaseConnectionPool()
|
||||||
@ -41,6 +41,7 @@ from rag.nlp import search
|
|||||||
|
|
||||||
import memory.utils.es_conn as memory_es_conn
|
import memory.utils.es_conn as memory_es_conn
|
||||||
import memory.utils.infinity_conn as memory_infinity_conn
|
import memory.utils.infinity_conn as memory_infinity_conn
|
||||||
|
import memory.utils.ob_conn as memory_ob_conn
|
||||||
|
|
||||||
LLM = None
|
LLM = None
|
||||||
LLM_FACTORY = None
|
LLM_FACTORY = None
|
||||||
@ -281,6 +282,8 @@ def init_settings():
|
|||||||
"db_name": "default_db"
|
"db_name": "default_db"
|
||||||
})
|
})
|
||||||
msgStoreConn = memory_infinity_conn.InfinityConnection()
|
msgStoreConn = memory_infinity_conn.InfinityConnection()
|
||||||
|
elif lower_case_doc_engine in ["oceanbase", "seekdb"]:
|
||||||
|
msgStoreConn = memory_ob_conn.OBConnection()
|
||||||
|
|
||||||
global AZURE, S3, MINIO, OSS, GCS
|
global AZURE, S3, MINIO, OSS, GCS
|
||||||
if STORAGE_IMPL_TYPE in ['AZURE_SPN', 'AZURE_SAS']:
|
if STORAGE_IMPL_TYPE in ['AZURE_SPN', 'AZURE_SAS']:
|
||||||
|
|||||||
613
memory/utils/ob_conn.py
Normal file
613
memory/utils/ob_conn.py
Normal file
@ -0,0 +1,613 @@
|
|||||||
|
#
|
||||||
|
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
|
||||||
|
import re
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from pymysql.converters import escape_string
|
||||||
|
from sqlalchemy import Column, String, Integer
|
||||||
|
from sqlalchemy.dialects.mysql import LONGTEXT
|
||||||
|
|
||||||
|
from common.decorator import singleton
|
||||||
|
from common.doc_store.doc_store_base import MatchExpr, OrderByExpr, FusionExpr, MatchTextExpr, MatchDenseExpr
|
||||||
|
from common.doc_store.ob_conn_base import OBConnectionBase, get_value_str, vector_search_template
|
||||||
|
from common.float_utils import get_float
|
||||||
|
from rag.nlp.rag_tokenizer import tokenize, fine_grained_tokenize
|
||||||
|
|
||||||
|
# Column definitions for memory message table
|
||||||
|
COLUMN_DEFINITIONS: list[Column] = [
|
||||||
|
Column("id", String(256), primary_key=True, comment="unique record id"),
|
||||||
|
Column("message_id", String(256), nullable=False, index=True, comment="message id"),
|
||||||
|
Column("message_type_kwd", String(64), nullable=True, comment="message type"),
|
||||||
|
Column("source_id", String(256), nullable=True, comment="source message id"),
|
||||||
|
Column("memory_id", String(256), nullable=False, index=True, comment="memory id"),
|
||||||
|
Column("user_id", String(256), nullable=True, comment="user id"),
|
||||||
|
Column("agent_id", String(256), nullable=True, comment="agent id"),
|
||||||
|
Column("session_id", String(256), nullable=True, comment="session id"),
|
||||||
|
Column("zone_id", Integer, nullable=True, server_default="0", comment="zone id"),
|
||||||
|
Column("valid_at", String(64), nullable=True, comment="valid at timestamp string"),
|
||||||
|
Column("invalid_at", String(64), nullable=True, comment="invalid at timestamp string"),
|
||||||
|
Column("forget_at", String(64), nullable=True, comment="forget at timestamp string"),
|
||||||
|
Column("status_int", Integer, nullable=False, server_default="1", comment="status: 1 for active, 0 for inactive"),
|
||||||
|
Column("content_ltks", LONGTEXT, nullable=True, comment="content with tokenization"),
|
||||||
|
Column("tokenized_content_ltks", LONGTEXT, nullable=True, comment="fine-grained tokenized content"),
|
||||||
|
]
|
||||||
|
|
||||||
|
COLUMN_NAMES: list[str] = [col.name for col in COLUMN_DEFINITIONS]
|
||||||
|
|
||||||
|
# Index columns for creating indexes
|
||||||
|
INDEX_COLUMNS: list[str] = [
|
||||||
|
"message_id",
|
||||||
|
"memory_id",
|
||||||
|
"status_int",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Full-text search columns
|
||||||
|
FTS_COLUMNS: list[str] = [
|
||||||
|
"content_ltks",
|
||||||
|
"tokenized_content_ltks",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class SearchResult(BaseModel):
|
||||||
|
total: int
|
||||||
|
messages: list[dict]
|
||||||
|
|
||||||
|
|
||||||
|
@singleton
|
||||||
|
class OBConnection(OBConnectionBase):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(logger_name='ragflow.memory_ob_conn')
|
||||||
|
self._fulltext_search_columns = FTS_COLUMNS
|
||||||
|
|
||||||
|
"""
|
||||||
|
Template method implementations
|
||||||
|
"""
|
||||||
|
|
||||||
|
def get_index_columns(self) -> list[str]:
|
||||||
|
return INDEX_COLUMNS
|
||||||
|
|
||||||
|
def get_fulltext_columns(self) -> list[str]:
|
||||||
|
"""Return list of column names that need fulltext indexes (without weight suffix)."""
|
||||||
|
return [col.split("^")[0] for col in self._fulltext_search_columns]
|
||||||
|
|
||||||
|
def get_column_definitions(self) -> list[Column]:
|
||||||
|
return COLUMN_DEFINITIONS
|
||||||
|
|
||||||
|
def get_lock_prefix(self) -> str:
|
||||||
|
return "ob_memory_"
|
||||||
|
|
||||||
|
def _get_dataset_id_field(self) -> str:
|
||||||
|
return "memory_id"
|
||||||
|
|
||||||
|
def _get_vector_column_name_from_table(self, table_name: str) -> Optional[str]:
|
||||||
|
"""Get the vector column name from the table (q_{size}_vec pattern)."""
|
||||||
|
sql = f"""
|
||||||
|
SELECT COLUMN_NAME
|
||||||
|
FROM INFORMATION_SCHEMA.COLUMNS
|
||||||
|
WHERE TABLE_SCHEMA = '{self.db_name}'
|
||||||
|
AND TABLE_NAME = '{table_name}'
|
||||||
|
AND COLUMN_NAME REGEXP '^q_[0-9]+_vec$'
|
||||||
|
LIMIT 1
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
res = self.client.perform_raw_text_sql(sql)
|
||||||
|
row = res.fetchone()
|
||||||
|
return row[0] if row else None
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
"""
|
||||||
|
Field conversion methods
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def convert_field_name(field_name: str, use_tokenized_content=False) -> str:
|
||||||
|
"""Convert message field name to database column name."""
|
||||||
|
match field_name:
|
||||||
|
case "message_type":
|
||||||
|
return "message_type_kwd"
|
||||||
|
case "status":
|
||||||
|
return "status_int"
|
||||||
|
case "content":
|
||||||
|
if use_tokenized_content:
|
||||||
|
return "tokenized_content_ltks"
|
||||||
|
return "content_ltks"
|
||||||
|
case _:
|
||||||
|
return field_name
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def map_message_to_ob_fields(message: dict) -> dict:
|
||||||
|
"""Map message dictionary fields to OceanBase document fields."""
|
||||||
|
storage_doc = {
|
||||||
|
"id": message.get("id"),
|
||||||
|
"message_id": message["message_id"],
|
||||||
|
"message_type_kwd": message["message_type"],
|
||||||
|
"source_id": message.get("source_id"),
|
||||||
|
"memory_id": message["memory_id"],
|
||||||
|
"user_id": message.get("user_id", ""),
|
||||||
|
"agent_id": message["agent_id"],
|
||||||
|
"session_id": message["session_id"],
|
||||||
|
"valid_at": message["valid_at"],
|
||||||
|
"invalid_at": message.get("invalid_at"),
|
||||||
|
"forget_at": message.get("forget_at"),
|
||||||
|
"status_int": 1 if message["status"] else 0,
|
||||||
|
"zone_id": message.get("zone_id", 0),
|
||||||
|
"content_ltks": message["content"],
|
||||||
|
"tokenized_content_ltks": fine_grained_tokenize(tokenize(message["content"])),
|
||||||
|
}
|
||||||
|
# Handle vector embedding
|
||||||
|
content_embed = message.get("content_embed", [])
|
||||||
|
if len(content_embed) > 0:
|
||||||
|
storage_doc[f"q_{len(content_embed)}_vec"] = content_embed
|
||||||
|
return storage_doc
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_message_from_ob_doc(doc: dict) -> dict:
|
||||||
|
"""Convert an OceanBase document back to a message dictionary."""
|
||||||
|
embd_field_name = next((key for key in doc.keys() if re.match(r"q_\d+_vec", key)), None)
|
||||||
|
content_embed = doc.get(embd_field_name, []) if embd_field_name else []
|
||||||
|
if isinstance(content_embed, np.ndarray):
|
||||||
|
content_embed = content_embed.tolist()
|
||||||
|
message = {
|
||||||
|
"message_id": doc.get("message_id"),
|
||||||
|
"message_type": doc.get("message_type_kwd"),
|
||||||
|
"source_id": doc.get("source_id") if doc.get("source_id") else None,
|
||||||
|
"memory_id": doc.get("memory_id"),
|
||||||
|
"user_id": doc.get("user_id", ""),
|
||||||
|
"agent_id": doc.get("agent_id"),
|
||||||
|
"session_id": doc.get("session_id"),
|
||||||
|
"zone_id": doc.get("zone_id", 0),
|
||||||
|
"valid_at": doc.get("valid_at"),
|
||||||
|
"invalid_at": doc.get("invalid_at", "-"),
|
||||||
|
"forget_at": doc.get("forget_at", "-"),
|
||||||
|
"status": bool(int(doc.get("status_int", 0))),
|
||||||
|
"content": doc.get("content_ltks", ""),
|
||||||
|
"content_embed": content_embed,
|
||||||
|
}
|
||||||
|
if doc.get("id"):
|
||||||
|
message["id"] = doc["id"]
|
||||||
|
return message
|
||||||
|
|
||||||
|
"""
|
||||||
|
CRUD operations
|
||||||
|
"""
|
||||||
|
|
||||||
|
def search(
|
||||||
|
self,
|
||||||
|
select_fields: list[str],
|
||||||
|
highlight_fields: list[str],
|
||||||
|
condition: dict,
|
||||||
|
match_expressions: list[MatchExpr],
|
||||||
|
order_by: OrderByExpr,
|
||||||
|
offset: int,
|
||||||
|
limit: int,
|
||||||
|
index_names: str | list[str],
|
||||||
|
memory_ids: list[str],
|
||||||
|
agg_fields: list[str] | None = None,
|
||||||
|
rank_feature: dict | None = None,
|
||||||
|
hide_forgotten: bool = True
|
||||||
|
):
|
||||||
|
"""Search messages in memory storage."""
|
||||||
|
if isinstance(index_names, str):
|
||||||
|
index_names = index_names.split(",")
|
||||||
|
assert isinstance(index_names, list) and len(index_names) > 0
|
||||||
|
|
||||||
|
result: SearchResult = SearchResult(total=0, messages=[])
|
||||||
|
|
||||||
|
output_fields = select_fields.copy()
|
||||||
|
if "id" not in output_fields:
|
||||||
|
output_fields = ["id"] + output_fields
|
||||||
|
if "_score" in output_fields:
|
||||||
|
output_fields.remove("_score")
|
||||||
|
|
||||||
|
# Handle content_embed field - resolve to actual vector column name
|
||||||
|
has_content_embed = "content_embed" in output_fields
|
||||||
|
actual_vector_column: Optional[str] = None
|
||||||
|
if has_content_embed:
|
||||||
|
output_fields = [f for f in output_fields if f != "content_embed"]
|
||||||
|
# Try to get vector column name from first available table
|
||||||
|
for idx_name in index_names:
|
||||||
|
if self._check_table_exists_cached(idx_name):
|
||||||
|
actual_vector_column = self._get_vector_column_name_from_table(idx_name)
|
||||||
|
if actual_vector_column:
|
||||||
|
output_fields.append(actual_vector_column)
|
||||||
|
break
|
||||||
|
|
||||||
|
if highlight_fields:
|
||||||
|
for field in highlight_fields:
|
||||||
|
field_name = self.convert_field_name(field)
|
||||||
|
if field_name not in output_fields:
|
||||||
|
output_fields.append(field_name)
|
||||||
|
|
||||||
|
db_output_fields = [self.convert_field_name(f) for f in output_fields]
|
||||||
|
fields_expr = ", ".join(db_output_fields)
|
||||||
|
|
||||||
|
condition["memory_id"] = memory_ids
|
||||||
|
if hide_forgotten:
|
||||||
|
condition["must_not"] = {"exists": "forget_at"}
|
||||||
|
|
||||||
|
condition_dict = {self.convert_field_name(k): v for k, v in condition.items()}
|
||||||
|
filters: list[str] = self._get_filters(condition_dict)
|
||||||
|
filters_expr = " AND ".join(filters) if filters else "1=1"
|
||||||
|
|
||||||
|
# Parse match expressions
|
||||||
|
fulltext_query: Optional[str] = None
|
||||||
|
fulltext_topn: Optional[int] = None
|
||||||
|
fulltext_search_expr: dict[str, str] = {}
|
||||||
|
fulltext_search_weight: dict[str, float] = {}
|
||||||
|
fulltext_search_filter: Optional[str] = None
|
||||||
|
fulltext_search_score_expr: Optional[str] = None
|
||||||
|
|
||||||
|
vector_column_name: Optional[str] = None
|
||||||
|
vector_data: Optional[list[float]] = None
|
||||||
|
vector_topn: Optional[int] = None
|
||||||
|
vector_similarity_threshold: Optional[float] = None
|
||||||
|
vector_similarity_weight: Optional[float] = None
|
||||||
|
vector_search_expr: Optional[str] = None
|
||||||
|
vector_search_score_expr: Optional[str] = None
|
||||||
|
vector_search_filter: Optional[str] = None
|
||||||
|
|
||||||
|
for m in match_expressions:
|
||||||
|
if isinstance(m, MatchTextExpr):
|
||||||
|
assert "original_query" in m.extra_options, "'original_query' is missing in extra_options."
|
||||||
|
fulltext_query = m.extra_options["original_query"]
|
||||||
|
fulltext_query = escape_string(fulltext_query.strip())
|
||||||
|
fulltext_topn = m.topn
|
||||||
|
|
||||||
|
fulltext_search_expr, fulltext_search_weight = self._parse_fulltext_columns(
|
||||||
|
fulltext_query, self._fulltext_search_columns
|
||||||
|
)
|
||||||
|
elif isinstance(m, MatchDenseExpr):
|
||||||
|
vector_column_name = m.vector_column_name
|
||||||
|
vector_data = m.embedding_data
|
||||||
|
vector_topn = m.topn
|
||||||
|
vector_similarity_threshold = m.extra_options.get("similarity", 0.0) if m.extra_options else 0.0
|
||||||
|
elif isinstance(m, FusionExpr):
|
||||||
|
weights = m.fusion_params.get("weights", "0.5,0.5") if m.fusion_params else "0.5,0.5"
|
||||||
|
vector_similarity_weight = get_float(weights.split(",")[1])
|
||||||
|
|
||||||
|
if fulltext_query:
|
||||||
|
fulltext_search_filter = f"({' OR '.join([expr for expr in fulltext_search_expr.values()])})"
|
||||||
|
fulltext_search_score_expr = f"({' + '.join(f'{expr} * {fulltext_search_weight.get(col, 0)}' for col, expr in fulltext_search_expr.items())})"
|
||||||
|
|
||||||
|
if vector_data:
|
||||||
|
vector_data_str = "[" + ",".join([str(np.float32(v)) for v in vector_data]) + "]"
|
||||||
|
vector_search_expr = vector_search_template % (vector_column_name, vector_data_str)
|
||||||
|
vector_search_score_expr = f"(1 - {vector_search_expr})"
|
||||||
|
vector_search_filter = f"{vector_search_score_expr} >= {vector_similarity_threshold}"
|
||||||
|
|
||||||
|
# Determine search type
|
||||||
|
if fulltext_query and vector_data:
|
||||||
|
search_type = "fusion"
|
||||||
|
elif fulltext_query:
|
||||||
|
search_type = "fulltext"
|
||||||
|
elif vector_data:
|
||||||
|
search_type = "vector"
|
||||||
|
else:
|
||||||
|
search_type = "filter"
|
||||||
|
|
||||||
|
if search_type in ["fusion", "fulltext", "vector"] and "_score" not in output_fields:
|
||||||
|
output_fields.append("_score")
|
||||||
|
|
||||||
|
if limit:
|
||||||
|
if vector_topn is not None:
|
||||||
|
limit = min(vector_topn, limit)
|
||||||
|
if fulltext_topn is not None:
|
||||||
|
limit = min(fulltext_topn, limit)
|
||||||
|
|
||||||
|
for index_name in index_names:
|
||||||
|
table_name = index_name
|
||||||
|
|
||||||
|
if not self._check_table_exists_cached(table_name):
|
||||||
|
continue
|
||||||
|
|
||||||
|
if search_type == "fusion":
|
||||||
|
num_candidates = (vector_topn or limit) + (fulltext_topn or limit)
|
||||||
|
score_expr = f"(relevance * {1 - vector_similarity_weight} + {vector_search_score_expr} * {vector_similarity_weight})"
|
||||||
|
fusion_sql = (
|
||||||
|
f"WITH fulltext_results AS ("
|
||||||
|
f" SELECT *, {fulltext_search_score_expr} AS relevance"
|
||||||
|
f" FROM {table_name}"
|
||||||
|
f" WHERE {filters_expr} AND {fulltext_search_filter}"
|
||||||
|
f" ORDER BY relevance DESC"
|
||||||
|
f" LIMIT {num_candidates}"
|
||||||
|
f")"
|
||||||
|
f" SELECT {fields_expr}, {score_expr} AS _score"
|
||||||
|
f" FROM fulltext_results"
|
||||||
|
f" WHERE {vector_search_filter}"
|
||||||
|
f" ORDER BY _score DESC"
|
||||||
|
f" LIMIT {offset}, {limit}"
|
||||||
|
)
|
||||||
|
self.logger.debug("OBConnection.search with fusion sql: %s", fusion_sql)
|
||||||
|
rows, elapsed_time = self._execute_search_sql(fusion_sql)
|
||||||
|
self.logger.info(
|
||||||
|
f"OBConnection.search table {table_name}, search type: fusion, elapsed time: {elapsed_time:.3f}s, rows: {len(rows)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
for row in rows:
|
||||||
|
result.messages.append(self._row_to_entity(row, db_output_fields + ["_score"]))
|
||||||
|
result.total += 1
|
||||||
|
|
||||||
|
elif search_type == "vector":
|
||||||
|
vector_sql = self._build_vector_search_sql(
|
||||||
|
table_name, fields_expr, vector_search_score_expr, filters_expr,
|
||||||
|
vector_search_filter, vector_search_expr, limit, vector_topn, offset
|
||||||
|
)
|
||||||
|
self.logger.debug("OBConnection.search with vector sql: %s", vector_sql)
|
||||||
|
rows, elapsed_time = self._execute_search_sql(vector_sql)
|
||||||
|
self.logger.info(
|
||||||
|
f"OBConnection.search table {table_name}, search type: vector, elapsed time: {elapsed_time:.3f}s, rows: {len(rows)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
for row in rows:
|
||||||
|
result.messages.append(self._row_to_entity(row, db_output_fields + ["_score"]))
|
||||||
|
result.total += 1
|
||||||
|
|
||||||
|
elif search_type == "fulltext":
|
||||||
|
fulltext_sql = self._build_fulltext_search_sql(
|
||||||
|
table_name, fields_expr, fulltext_search_score_expr, filters_expr,
|
||||||
|
fulltext_search_filter, offset, limit, fulltext_topn
|
||||||
|
)
|
||||||
|
self.logger.debug("OBConnection.search with fulltext sql: %s", fulltext_sql)
|
||||||
|
rows, elapsed_time = self._execute_search_sql(fulltext_sql)
|
||||||
|
self.logger.info(
|
||||||
|
f"OBConnection.search table {table_name}, search type: fulltext, elapsed time: {elapsed_time:.3f}s, rows: {len(rows)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
for row in rows:
|
||||||
|
result.messages.append(self._row_to_entity(row, db_output_fields + ["_score"]))
|
||||||
|
result.total += 1
|
||||||
|
|
||||||
|
            else:
                orders: list[str] = []
                if order_by and order_by.fields:
                    for field, order_dir in order_by.fields:
                        field_name = self.convert_field_name(field)
                        order_str = "ASC" if order_dir == 0 else "DESC"
                        orders.append(f"{field_name} {order_str}")

                order_by_expr = ("ORDER BY " + ", ".join(orders)) if orders else ""
                limit_expr = f"LIMIT {offset}, {limit}" if limit != 0 else ""
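                # Note (illustrative): this is MySQL-style pagination, where
                # "LIMIT offset, count" skips `offset` rows and returns up to `count`
                # rows; e.g. "LIMIT 20, 10" returns rows 21-30 of the filtered set.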
                filter_sql = self._build_filter_search_sql(
                    table_name, fields_expr, filters_expr, order_by_expr, limit_expr
                )
                self.logger.debug("OBConnection.search with filter sql: %s", filter_sql)
                rows, elapsed_time = self._execute_search_sql(filter_sql)
                self.logger.info(
                    f"OBConnection.search table {table_name}, search type: filter, elapsed time: {elapsed_time:.3f}s, rows: {len(rows)}"
                )

                for row in rows:
                    result.messages.append(self._row_to_entity(row, db_output_fields))
                    result.total += 1

        if result.total == 0:
            result.total = len(result.messages)

        return result, result.total

    def get_forgotten_messages(self, select_fields: list[str], index_name: str, memory_id: str, limit: int = 512):
        """Get forgotten messages (messages with forget_at set)."""
        if not self._check_table_exists_cached(index_name):
            return None

        db_output_fields = [self.convert_field_name(f) for f in select_fields]
        fields_expr = ", ".join(db_output_fields)

        sql = (
            f"SELECT {fields_expr}"
            f" FROM {index_name}"
            f" WHERE memory_id = {get_value_str(memory_id)} AND forget_at IS NOT NULL"
            f" ORDER BY forget_at ASC"
            f" LIMIT {limit}"
        )
        self.logger.debug("OBConnection.get_forgotten_messages sql: %s", sql)

        res = self.client.perform_raw_text_sql(sql)
        rows = res.fetchall()

        result = SearchResult(total=len(rows), messages=[])
        for row in rows:
            result.messages.append(self._row_to_entity(row, db_output_fields))

        return result

    def get_missing_field_message(self, select_fields: list[str], index_name: str, memory_id: str, field_name: str, limit: int = 512):
        """Get messages missing a specific field."""
        if not self._check_table_exists_cached(index_name):
            return None

        db_field_name = self.convert_field_name(field_name)
        db_output_fields = [self.convert_field_name(f) for f in select_fields]
        fields_expr = ", ".join(db_output_fields)

        sql = (
            f"SELECT {fields_expr}"
            f" FROM {index_name}"
            f" WHERE memory_id = {get_value_str(memory_id)} AND {db_field_name} IS NULL"
            f" ORDER BY valid_at ASC"
            f" LIMIT {limit}"
        )
        self.logger.debug("OBConnection.get_missing_field_message sql: %s", sql)

        res = self.client.perform_raw_text_sql(sql)
        rows = res.fetchall()

        result = SearchResult(total=len(rows), messages=[])
        for row in rows:
            result.messages.append(self._row_to_entity(row, db_output_fields))

        return result

    def get(self, doc_id: str, index_name: str, memory_ids: list[str]) -> dict | None:
        """Get single message by id."""
        doc = super().get(doc_id, index_name, memory_ids)
        if doc is None:
            return None
        return self.get_message_from_ob_doc(doc)

    def insert(self, documents: list[dict], index_name: str, memory_id: str = None) -> list[str]:
        """Insert messages into memory storage."""
        if not documents:
            return []

        vector_size = len(documents[0].get("content_embed", [])) if "content_embed" in documents[0] else 0

        if not self._check_table_exists_cached(index_name):
            if vector_size == 0:
                raise ValueError("Cannot infer vector size from documents")
            self.create_idx(index_name, memory_id, vector_size)
        elif vector_size > 0:
            # Table exists but may not have the required vector column
            self._ensure_vector_column_exists(index_name, vector_size)
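        # Note (illustrative assumption): the vector column width is inferred from the
        # first document's "content_embed" value, so a first message embedded with a
        # 1024-dimensional model would create or extend the table with a 1024-d vector
        # column (e.g. q_1024_vec, following vector_column_pattern).
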
        docs: list[dict] = []
        ids: list[str] = []

        for document in documents:
            d = self.map_message_to_ob_fields(document)
            ids.append(d["id"])

            for column_name in COLUMN_NAMES:
                if column_name not in d:
                    d[column_name] = None

            docs.append(d)

        self.logger.debug("OBConnection.insert messages: %s", ids)

        res = []
        try:
            self.client.upsert(index_name, docs)
        except Exception as e:
            self.logger.error(f"OBConnection.insert error: {str(e)}")
            res.append(str(e))
        return res

    def update(self, condition: dict, new_value: dict, index_name: str, memory_id: str) -> bool:
        """Update messages with given condition."""
        if not self._check_table_exists_cached(index_name):
            return True

        condition["memory_id"] = memory_id
        condition_dict = {self.convert_field_name(k): v for k, v in condition.items()}
        filters = self._get_filters(condition_dict)

        update_dict = {self.convert_field_name(k): v for k, v in new_value.items()}
        if "content_ltks" in update_dict:
            update_dict["tokenized_content_ltks"] = fine_grained_tokenize(tokenize(update_dict["content_ltks"]))
        update_dict.pop("id", None)

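        # Note (illustrative; keys are first passed through convert_field_name):
        # besides plain column assignments, new_value supports two special keys.
        # A "remove" entry names a column to clear, and "status" is stored in the
        # integer column status_int. For example, {"remove": "forget_at",
        # "status": True} becomes "SET forget_at = NULL, status_int = 1".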
        set_values: list[str] = []
        for k, v in update_dict.items():
            if k == "remove":
                if isinstance(v, str):
                    set_values.append(f"{v} = NULL")
            elif k == "status":
                set_values.append(f"status_int = {1 if v else 0}")
            else:
                set_values.append(f"{k} = {get_value_str(v)}")

        if not set_values:
            return True

        update_sql = (
            f"UPDATE {index_name}"
            f" SET {', '.join(set_values)}"
            f" WHERE {' AND '.join(filters)}"
        )
        self.logger.debug("OBConnection.update sql: %s", update_sql)

        try:
            self.client.perform_raw_text_sql(update_sql)
            return True
        except Exception as e:
            self.logger.error(f"OBConnection.update error: {str(e)}")
            return False

    def delete(self, condition: dict, index_name: str, memory_id: str) -> int:
        """Delete messages with given condition."""
        condition_dict = {self.convert_field_name(k): v for k, v in condition.items()}
        return super().delete(condition_dict, index_name, memory_id)

    """
    Helper functions for search result
    """

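    # The search() method above returns a (SearchResult, total) tuple, so the
    # helpers below accept either that tuple or a bare SearchResult object.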
    def get_total(self, res) -> int:
        if isinstance(res, tuple):
            return res[1]
        if hasattr(res, 'total'):
            return res.total
        return 0

    def get_doc_ids(self, res) -> list[str]:
        if isinstance(res, tuple):
            res = res[0]
        if hasattr(res, 'messages'):
            return [row.get("id") for row in res.messages if row.get("id")]
        return []

    def get_fields(self, res, fields: list[str]) -> dict[str, dict]:
        """Get fields from search result."""
        if isinstance(res, tuple):
            res = res[0]

        res_fields = {}
        if not fields:
            return {}

        messages = res.messages if hasattr(res, 'messages') else []

        for doc in messages:
            message = self.get_message_from_ob_doc(doc)
            m = {}
            for n, v in message.items():
                if n not in fields:
                    continue
                if isinstance(v, list):
                    m[n] = v
                    continue
                if n in ["message_id", "source_id", "valid_at", "invalid_at", "forget_at", "status"] and isinstance(v, (int, float, bool)):
                    m[n] = v
                    continue
                if not isinstance(v, str):
                    m[n] = str(v) if v is not None else ""
                else:
                    m[n] = v

            doc_id = doc.get("id") or message.get("id")
            if m and doc_id:
                res_fields[doc_id] = m

        return res_fields

    def get_highlight(self, res, keywords: list[str], field_name: str):
        """Get highlighted text for search results."""
        # TODO: Implement highlight functionality for OceanBase memory
        return {}

    def get_aggregation(self, res, field_name: str):
        """Get aggregation for search results."""
        # TODO: Implement aggregation functionality for OceanBase memory
        return []