Mirror of https://github.com/infiniflow/ragflow.git (synced 2026-02-04 09:35:06 +08:00)
feat: add OceanBase memory store (#12955)
### What problem does this PR solve?

Add an OceanBase memory store and extract the shared base class `OBConnectionBase`.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: Cursor <cursoragent@cursor.com>
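For reviewers, a minimal usage sketch of the new store follows. It is illustrative only: the table name, ids, timestamp string, 4-dimensional embedding, and the module import path are made-up assumptions, and it presumes OceanBase connection settings are already configured for `OBConnectionBase`.

```python
# Illustrative sketch only; names and values below are hypothetical.
from memory.utils.ob_conn import OBConnection

store = OBConnection()  # singleton; connection settings come from OBConnectionBase

message = {
    "id": "rec-001",
    "message_id": "msg-001",
    "message_type": "user",
    "memory_id": "mem-001",
    "agent_id": "agent-001",
    "session_id": "sess-001",
    "valid_at": "2025-01-01 00:00:00",
    "status": True,
    "content": "The user prefers concise answers.",
    "content_embed": [0.1, 0.2, 0.3, 0.4],  # a q_4_vec column is created from this length
}

# Creates the table (and its vector column) on first insert, then upserts the row.
errors = store.insert([message], index_name="ragflow_memory_example", memory_id="mem-001")
assert not errors

# Filter-only search (no match expressions): returns (SearchResult, total).
result, total = store.search(
    select_fields=["id", "content", "valid_at"],
    highlight_fields=[],
    condition={"session_id": "sess-001"},
    match_expressions=[],
    order_by=None,
    offset=0,
    limit=10,
    index_names="ragflow_memory_example",
    memory_ids=["mem-001"],
)
print(total, result.messages)
```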
memory/utils/ob_conn.py (new file, +613 lines)

@@ -0,0 +1,613 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import re
from typing import Optional

import numpy as np
from pydantic import BaseModel
from pymysql.converters import escape_string
from sqlalchemy import Column, String, Integer
from sqlalchemy.dialects.mysql import LONGTEXT

from common.decorator import singleton
from common.doc_store.doc_store_base import MatchExpr, OrderByExpr, FusionExpr, MatchTextExpr, MatchDenseExpr
from common.doc_store.ob_conn_base import OBConnectionBase, get_value_str, vector_search_template
from common.float_utils import get_float
from rag.nlp.rag_tokenizer import tokenize, fine_grained_tokenize

# Column definitions for memory message table
COLUMN_DEFINITIONS: list[Column] = [
    Column("id", String(256), primary_key=True, comment="unique record id"),
    Column("message_id", String(256), nullable=False, index=True, comment="message id"),
    Column("message_type_kwd", String(64), nullable=True, comment="message type"),
    Column("source_id", String(256), nullable=True, comment="source message id"),
    Column("memory_id", String(256), nullable=False, index=True, comment="memory id"),
    Column("user_id", String(256), nullable=True, comment="user id"),
    Column("agent_id", String(256), nullable=True, comment="agent id"),
    Column("session_id", String(256), nullable=True, comment="session id"),
    Column("zone_id", Integer, nullable=True, server_default="0", comment="zone id"),
    Column("valid_at", String(64), nullable=True, comment="valid at timestamp string"),
    Column("invalid_at", String(64), nullable=True, comment="invalid at timestamp string"),
    Column("forget_at", String(64), nullable=True, comment="forget at timestamp string"),
    Column("status_int", Integer, nullable=False, server_default="1", comment="status: 1 for active, 0 for inactive"),
    Column("content_ltks", LONGTEXT, nullable=True, comment="content with tokenization"),
    Column("tokenized_content_ltks", LONGTEXT, nullable=True, comment="fine-grained tokenized content"),
]

COLUMN_NAMES: list[str] = [col.name for col in COLUMN_DEFINITIONS]

# Index columns for creating indexes
INDEX_COLUMNS: list[str] = [
    "message_id",
    "memory_id",
    "status_int",
]

# Full-text search columns
FTS_COLUMNS: list[str] = [
    "content_ltks",
    "tokenized_content_ltks",
]


class SearchResult(BaseModel):
    total: int
    messages: list[dict]


@singleton
class OBConnection(OBConnectionBase):
    def __init__(self):
        super().__init__(logger_name='ragflow.memory_ob_conn')
        self._fulltext_search_columns = FTS_COLUMNS

    """
    Template method implementations
    """

    def get_index_columns(self) -> list[str]:
        return INDEX_COLUMNS

    def get_fulltext_columns(self) -> list[str]:
        """Return list of column names that need fulltext indexes (without weight suffix)."""
        return [col.split("^")[0] for col in self._fulltext_search_columns]

    def get_column_definitions(self) -> list[Column]:
        return COLUMN_DEFINITIONS

    def get_lock_prefix(self) -> str:
        return "ob_memory_"

    def _get_dataset_id_field(self) -> str:
        return "memory_id"

    def _get_vector_column_name_from_table(self, table_name: str) -> Optional[str]:
        """Get the vector column name from the table (q_{size}_vec pattern)."""
        sql = f"""
            SELECT COLUMN_NAME
            FROM INFORMATION_SCHEMA.COLUMNS
            WHERE TABLE_SCHEMA = '{self.db_name}'
              AND TABLE_NAME = '{table_name}'
              AND COLUMN_NAME REGEXP '^q_[0-9]+_vec$'
            LIMIT 1
        """
        try:
            res = self.client.perform_raw_text_sql(sql)
            row = res.fetchone()
            return row[0] if row else None
        except Exception:
            return None

"""
|
||||
Field conversion methods
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def convert_field_name(field_name: str, use_tokenized_content=False) -> str:
|
||||
"""Convert message field name to database column name."""
|
||||
match field_name:
|
||||
case "message_type":
|
||||
return "message_type_kwd"
|
||||
case "status":
|
||||
return "status_int"
|
||||
case "content":
|
||||
if use_tokenized_content:
|
||||
return "tokenized_content_ltks"
|
||||
return "content_ltks"
|
||||
case _:
|
||||
return field_name
|
||||
|
||||
@staticmethod
|
||||
def map_message_to_ob_fields(message: dict) -> dict:
|
||||
"""Map message dictionary fields to OceanBase document fields."""
|
||||
storage_doc = {
|
||||
"id": message.get("id"),
|
||||
"message_id": message["message_id"],
|
||||
"message_type_kwd": message["message_type"],
|
||||
"source_id": message.get("source_id"),
|
||||
"memory_id": message["memory_id"],
|
||||
"user_id": message.get("user_id", ""),
|
||||
"agent_id": message["agent_id"],
|
||||
"session_id": message["session_id"],
|
||||
"valid_at": message["valid_at"],
|
||||
"invalid_at": message.get("invalid_at"),
|
||||
"forget_at": message.get("forget_at"),
|
||||
"status_int": 1 if message["status"] else 0,
|
||||
"zone_id": message.get("zone_id", 0),
|
||||
"content_ltks": message["content"],
|
||||
"tokenized_content_ltks": fine_grained_tokenize(tokenize(message["content"])),
|
||||
}
|
||||
# Handle vector embedding
|
||||
content_embed = message.get("content_embed", [])
|
||||
if len(content_embed) > 0:
|
||||
storage_doc[f"q_{len(content_embed)}_vec"] = content_embed
|
||||
return storage_doc
|
||||
|
||||
@staticmethod
|
||||
def get_message_from_ob_doc(doc: dict) -> dict:
|
||||
"""Convert an OceanBase document back to a message dictionary."""
|
||||
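        # The embedding column name is dynamic (q_{size}_vec), so locate it by pattern rather than by a fixed name.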
        embd_field_name = next((key for key in doc.keys() if re.match(r"q_\d+_vec", key)), None)
        content_embed = doc.get(embd_field_name, []) if embd_field_name else []
        if isinstance(content_embed, np.ndarray):
            content_embed = content_embed.tolist()
        message = {
            "message_id": doc.get("message_id"),
            "message_type": doc.get("message_type_kwd"),
            "source_id": doc.get("source_id") if doc.get("source_id") else None,
            "memory_id": doc.get("memory_id"),
            "user_id": doc.get("user_id", ""),
            "agent_id": doc.get("agent_id"),
            "session_id": doc.get("session_id"),
            "zone_id": doc.get("zone_id", 0),
            "valid_at": doc.get("valid_at"),
            "invalid_at": doc.get("invalid_at", "-"),
            "forget_at": doc.get("forget_at", "-"),
            "status": bool(int(doc.get("status_int", 0))),
            "content": doc.get("content_ltks", ""),
            "content_embed": content_embed,
        }
        if doc.get("id"):
            message["id"] = doc["id"]
        return message

    """
    CRUD operations
    """

    def search(
        self,
        select_fields: list[str],
        highlight_fields: list[str],
        condition: dict,
        match_expressions: list[MatchExpr],
        order_by: OrderByExpr,
        offset: int,
        limit: int,
        index_names: str | list[str],
        memory_ids: list[str],
        agg_fields: list[str] | None = None,
        rank_feature: dict | None = None,
        hide_forgotten: bool = True
    ):
        """Search messages in memory storage."""
        if isinstance(index_names, str):
            index_names = index_names.split(",")
        assert isinstance(index_names, list) and len(index_names) > 0

        result: SearchResult = SearchResult(total=0, messages=[])

        output_fields = select_fields.copy()
        if "id" not in output_fields:
            output_fields = ["id"] + output_fields
        if "_score" in output_fields:
            output_fields.remove("_score")

        # Handle content_embed field - resolve to actual vector column name
        has_content_embed = "content_embed" in output_fields
        actual_vector_column: Optional[str] = None
        if has_content_embed:
            output_fields = [f for f in output_fields if f != "content_embed"]
            # Try to get vector column name from first available table
            for idx_name in index_names:
                if self._check_table_exists_cached(idx_name):
                    actual_vector_column = self._get_vector_column_name_from_table(idx_name)
                    if actual_vector_column:
                        output_fields.append(actual_vector_column)
                        break

        if highlight_fields:
            for field in highlight_fields:
                field_name = self.convert_field_name(field)
                if field_name not in output_fields:
                    output_fields.append(field_name)

        db_output_fields = [self.convert_field_name(f) for f in output_fields]
        fields_expr = ", ".join(db_output_fields)

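        # Always scope the query to the given memory ids; by default, messages already marked as forgotten are hidden.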
condition["memory_id"] = memory_ids
|
||||
if hide_forgotten:
|
||||
condition["must_not"] = {"exists": "forget_at"}
|
||||
|
||||
condition_dict = {self.convert_field_name(k): v for k, v in condition.items()}
|
||||
filters: list[str] = self._get_filters(condition_dict)
|
||||
filters_expr = " AND ".join(filters) if filters else "1=1"
|
||||
|
||||
# Parse match expressions
|
||||
fulltext_query: Optional[str] = None
|
||||
fulltext_topn: Optional[int] = None
|
||||
fulltext_search_expr: dict[str, str] = {}
|
||||
fulltext_search_weight: dict[str, float] = {}
|
||||
fulltext_search_filter: Optional[str] = None
|
||||
fulltext_search_score_expr: Optional[str] = None
|
||||
|
||||
vector_column_name: Optional[str] = None
|
||||
vector_data: Optional[list[float]] = None
|
||||
vector_topn: Optional[int] = None
|
||||
vector_similarity_threshold: Optional[float] = None
|
||||
vector_similarity_weight: Optional[float] = None
|
||||
vector_search_expr: Optional[str] = None
|
||||
vector_search_score_expr: Optional[str] = None
|
||||
vector_search_filter: Optional[str] = None
|
||||
|
||||
for m in match_expressions:
|
||||
if isinstance(m, MatchTextExpr):
|
||||
assert "original_query" in m.extra_options, "'original_query' is missing in extra_options."
|
||||
fulltext_query = m.extra_options["original_query"]
|
||||
fulltext_query = escape_string(fulltext_query.strip())
|
||||
fulltext_topn = m.topn
|
||||
|
||||
fulltext_search_expr, fulltext_search_weight = self._parse_fulltext_columns(
|
||||
fulltext_query, self._fulltext_search_columns
|
||||
)
|
||||
elif isinstance(m, MatchDenseExpr):
|
||||
vector_column_name = m.vector_column_name
|
||||
vector_data = m.embedding_data
|
||||
vector_topn = m.topn
|
||||
vector_similarity_threshold = m.extra_options.get("similarity", 0.0) if m.extra_options else 0.0
|
||||
elif isinstance(m, FusionExpr):
|
||||
weights = m.fusion_params.get("weights", "0.5,0.5") if m.fusion_params else "0.5,0.5"
|
||||
vector_similarity_weight = get_float(weights.split(",")[1])
|
||||
|
||||
if fulltext_query:
|
||||
fulltext_search_filter = f"({' OR '.join([expr for expr in fulltext_search_expr.values()])})"
|
||||
fulltext_search_score_expr = f"({' + '.join(f'{expr} * {fulltext_search_weight.get(col, 0)}' for col, expr in fulltext_search_expr.items())})"
|
||||
|
||||
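        # Build the vector match: the embedding is rendered as a bracketed float32 list, a score is derived as (1 - expression value), and results must meet the similarity threshold.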
        if vector_data:
            vector_data_str = "[" + ",".join([str(np.float32(v)) for v in vector_data]) + "]"
            vector_search_expr = vector_search_template % (vector_column_name, vector_data_str)
            vector_search_score_expr = f"(1 - {vector_search_expr})"
            vector_search_filter = f"{vector_search_score_expr} >= {vector_similarity_threshold}"

        # Determine search type
        if fulltext_query and vector_data:
            search_type = "fusion"
        elif fulltext_query:
            search_type = "fulltext"
        elif vector_data:
            search_type = "vector"
        else:
            search_type = "filter"

        if search_type in ["fusion", "fulltext", "vector"] and "_score" not in output_fields:
            output_fields.append("_score")

        if limit:
            if vector_topn is not None:
                limit = min(vector_topn, limit)
            if fulltext_topn is not None:
                limit = min(fulltext_topn, limit)

        for index_name in index_names:
            table_name = index_name

            if not self._check_table_exists_cached(table_name):
                continue

            if search_type == "fusion":
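                # Hybrid retrieval: the CTE first ranks candidates by full-text relevance, then the vector similarity filter is applied and the final score blends both signals by vector_similarity_weight.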
                num_candidates = (vector_topn or limit) + (fulltext_topn or limit)
                score_expr = f"(relevance * {1 - vector_similarity_weight} + {vector_search_score_expr} * {vector_similarity_weight})"
                fusion_sql = (
                    f"WITH fulltext_results AS ("
                    f" SELECT *, {fulltext_search_score_expr} AS relevance"
                    f" FROM {table_name}"
                    f" WHERE {filters_expr} AND {fulltext_search_filter}"
                    f" ORDER BY relevance DESC"
                    f" LIMIT {num_candidates}"
                    f")"
                    f" SELECT {fields_expr}, {score_expr} AS _score"
                    f" FROM fulltext_results"
                    f" WHERE {vector_search_filter}"
                    f" ORDER BY _score DESC"
                    f" LIMIT {offset}, {limit}"
                )
                self.logger.debug("OBConnection.search with fusion sql: %s", fusion_sql)
                rows, elapsed_time = self._execute_search_sql(fusion_sql)
                self.logger.info(
                    f"OBConnection.search table {table_name}, search type: fusion, elapsed time: {elapsed_time:.3f}s, rows: {len(rows)}"
                )

                for row in rows:
                    result.messages.append(self._row_to_entity(row, db_output_fields + ["_score"]))
                    result.total += 1

            elif search_type == "vector":
                vector_sql = self._build_vector_search_sql(
                    table_name, fields_expr, vector_search_score_expr, filters_expr,
                    vector_search_filter, vector_search_expr, limit, vector_topn, offset
                )
                self.logger.debug("OBConnection.search with vector sql: %s", vector_sql)
                rows, elapsed_time = self._execute_search_sql(vector_sql)
                self.logger.info(
                    f"OBConnection.search table {table_name}, search type: vector, elapsed time: {elapsed_time:.3f}s, rows: {len(rows)}"
                )

                for row in rows:
                    result.messages.append(self._row_to_entity(row, db_output_fields + ["_score"]))
                    result.total += 1

            elif search_type == "fulltext":
                fulltext_sql = self._build_fulltext_search_sql(
                    table_name, fields_expr, fulltext_search_score_expr, filters_expr,
                    fulltext_search_filter, offset, limit, fulltext_topn
                )
                self.logger.debug("OBConnection.search with fulltext sql: %s", fulltext_sql)
                rows, elapsed_time = self._execute_search_sql(fulltext_sql)
                self.logger.info(
                    f"OBConnection.search table {table_name}, search type: fulltext, elapsed time: {elapsed_time:.3f}s, rows: {len(rows)}"
                )

                for row in rows:
                    result.messages.append(self._row_to_entity(row, db_output_fields + ["_score"]))
                    result.total += 1

            else:
                orders: list[str] = []
                if order_by and order_by.fields:
                    for field, order_dir in order_by.fields:
                        field_name = self.convert_field_name(field)
                        order_str = "ASC" if order_dir == 0 else "DESC"
                        orders.append(f"{field_name} {order_str}")

                order_by_expr = ("ORDER BY " + ", ".join(orders)) if orders else ""
                limit_expr = f"LIMIT {offset}, {limit}" if limit != 0 else ""
                filter_sql = self._build_filter_search_sql(
                    table_name, fields_expr, filters_expr, order_by_expr, limit_expr
                )
                self.logger.debug("OBConnection.search with filter sql: %s", filter_sql)
                rows, elapsed_time = self._execute_search_sql(filter_sql)
                self.logger.info(
                    f"OBConnection.search table {table_name}, search type: filter, elapsed time: {elapsed_time:.3f}s, rows: {len(rows)}"
                )

                for row in rows:
                    result.messages.append(self._row_to_entity(row, db_output_fields))
                    result.total += 1

        if result.total == 0:
            result.total = len(result.messages)

        return result, result.total

    def get_forgotten_messages(self, select_fields: list[str], index_name: str, memory_id: str, limit: int = 512):
        """Get forgotten messages (messages with forget_at set)."""
        if not self._check_table_exists_cached(index_name):
            return None

        db_output_fields = [self.convert_field_name(f) for f in select_fields]
        fields_expr = ", ".join(db_output_fields)

        sql = (
            f"SELECT {fields_expr}"
            f" FROM {index_name}"
            f" WHERE memory_id = {get_value_str(memory_id)} AND forget_at IS NOT NULL"
            f" ORDER BY forget_at ASC"
            f" LIMIT {limit}"
        )
        self.logger.debug("OBConnection.get_forgotten_messages sql: %s", sql)

        res = self.client.perform_raw_text_sql(sql)
        rows = res.fetchall()

        result = SearchResult(total=len(rows), messages=[])
        for row in rows:
            result.messages.append(self._row_to_entity(row, db_output_fields))

        return result

    def get_missing_field_message(self, select_fields: list[str], index_name: str, memory_id: str, field_name: str,
                                  limit: int = 512):
        """Get messages missing a specific field."""
        if not self._check_table_exists_cached(index_name):
            return None

        db_field_name = self.convert_field_name(field_name)
        db_output_fields = [self.convert_field_name(f) for f in select_fields]
        fields_expr = ", ".join(db_output_fields)

        sql = (
            f"SELECT {fields_expr}"
            f" FROM {index_name}"
            f" WHERE memory_id = {get_value_str(memory_id)} AND {db_field_name} IS NULL"
            f" ORDER BY valid_at ASC"
            f" LIMIT {limit}"
        )
        self.logger.debug("OBConnection.get_missing_field_message sql: %s", sql)

        res = self.client.perform_raw_text_sql(sql)
        rows = res.fetchall()

        result = SearchResult(total=len(rows), messages=[])
        for row in rows:
            result.messages.append(self._row_to_entity(row, db_output_fields))

        return result

    def get(self, doc_id: str, index_name: str, memory_ids: list[str]) -> dict | None:
        """Get single message by id."""
        doc = super().get(doc_id, index_name, memory_ids)
        if doc is None:
            return None
        return self.get_message_from_ob_doc(doc)

    def insert(self, documents: list[dict], index_name: str, memory_id: str = None) -> list[str]:
        """Insert messages into memory storage."""
        if not documents:
            return []

        vector_size = len(documents[0].get("content_embed", [])) if "content_embed" in documents[0] else 0

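        # Create the table on first insert; the vector column (q_{size}_vec) size is inferred from the first document's embedding.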
        if not self._check_table_exists_cached(index_name):
            if vector_size == 0:
                raise ValueError("Cannot infer vector size from documents")
            self.create_idx(index_name, memory_id, vector_size)
        elif vector_size > 0:
            # Table exists but may not have the required vector column
            self._ensure_vector_column_exists(index_name, vector_size)

        docs: list[dict] = []
        ids: list[str] = []

        for document in documents:
            d = self.map_message_to_ob_fields(document)
            ids.append(d["id"])

            for column_name in COLUMN_NAMES:
                if column_name not in d:
                    d[column_name] = None

            docs.append(d)

        self.logger.debug("OBConnection.insert messages: %s", ids)

        res = []
        try:
            self.client.upsert(index_name, docs)
        except Exception as e:
            self.logger.error(f"OBConnection.insert error: {str(e)}")
            res.append(str(e))
        return res

    def update(self, condition: dict, new_value: dict, index_name: str, memory_id: str) -> bool:
        """Update messages with given condition."""
        if not self._check_table_exists_cached(index_name):
            return True

        condition["memory_id"] = memory_id
        condition_dict = {self.convert_field_name(k): v for k, v in condition.items()}
        filters = self._get_filters(condition_dict)

        update_dict = {self.convert_field_name(k): v for k, v in new_value.items()}
        if "content_ltks" in update_dict:
            update_dict["tokenized_content_ltks"] = fine_grained_tokenize(tokenize(update_dict["content_ltks"]))
        update_dict.pop("id", None)

        set_values: list[str] = []
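        # Build SET clauses; the special "remove" key clears the named column (sets it to NULL), other values are rendered as SQL literals.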
        for k, v in update_dict.items():
            if k == "remove":
                if isinstance(v, str):
                    set_values.append(f"{v} = NULL")
            elif k == "status":
                set_values.append(f"status_int = {1 if v else 0}")
            else:
                set_values.append(f"{k} = {get_value_str(v)}")

        if not set_values:
            return True

        update_sql = (
            f"UPDATE {index_name}"
            f" SET {', '.join(set_values)}"
            f" WHERE {' AND '.join(filters)}"
        )
        self.logger.debug("OBConnection.update sql: %s", update_sql)

        try:
            self.client.perform_raw_text_sql(update_sql)
            return True
        except Exception as e:
            self.logger.error(f"OBConnection.update error: {str(e)}")
            return False

    def delete(self, condition: dict, index_name: str, memory_id: str) -> int:
        """Delete messages with given condition."""
        condition_dict = {self.convert_field_name(k): v for k, v in condition.items()}
        return super().delete(condition_dict, index_name, memory_id)

    """
    Helper functions for search result
    """

    def get_total(self, res) -> int:
        if isinstance(res, tuple):
            return res[1]
        if hasattr(res, 'total'):
            return res.total
        return 0

    def get_doc_ids(self, res) -> list[str]:
        if isinstance(res, tuple):
            res = res[0]
        if hasattr(res, 'messages'):
            return [row.get("id") for row in res.messages if row.get("id")]
        return []

    def get_fields(self, res, fields: list[str]) -> dict[str, dict]:
        """Get fields from search result."""
        if isinstance(res, tuple):
            res = res[0]

        res_fields = {}
        if not fields:
            return {}

        messages = res.messages if hasattr(res, 'messages') else []

        for doc in messages:
            message = self.get_message_from_ob_doc(doc)
            m = {}
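            # Keep list values as-is, preserve numeric/boolean values for the known scalar fields, and stringify everything else.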
            for n, v in message.items():
                if n not in fields:
                    continue
                if isinstance(v, list):
                    m[n] = v
                    continue
                if n in ["message_id", "source_id", "valid_at", "invalid_at", "forget_at", "status"] and isinstance(v, (int, float, bool)):
                    m[n] = v
                    continue
                if not isinstance(v, str):
                    m[n] = str(v) if v is not None else ""
                else:
                    m[n] = v

            doc_id = doc.get("id") or message.get("id")
            if m and doc_id:
                res_fields[doc_id] = m

        return res_fields

    def get_highlight(self, res, keywords: list[str], field_name: str):
        """Get highlighted text for search results."""
        # TODO: Implement highlight functionality for OceanBase memory
        return {}

    def get_aggregation(self, res, field_name: str):
        """Get aggregation for search results."""
        # TODO: Implement aggregation functionality for OceanBase memory
        return []