# # Copyright 2025 The InfiniFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import json import logging import re import time from typing import Any, Optional import numpy as np from elasticsearch_dsl import Q, Search from pydantic import BaseModel from pymysql.converters import escape_string from pyobvector import ARRAY from sqlalchemy import Column, String, Integer, JSON, Double, Row from sqlalchemy.dialects.mysql import LONGTEXT, TEXT from sqlalchemy.sql.type_api import TypeEngine from common.constants import PAGERANK_FLD, TAG_FLD from common.decorator import singleton from common.doc_store.doc_store_base import MatchExpr, OrderByExpr, FusionExpr, MatchTextExpr, MatchDenseExpr from common.doc_store.ob_conn_base import ( OBConnectionBase, get_value_str, vector_search_template, vector_column_pattern, fulltext_index_name_template, ) from common.float_utils import get_float from rag.nlp import rag_tokenizer logger = logging.getLogger('ragflow.ob_conn') column_order_id = Column("_order_id", Integer, nullable=True, comment="chunk order id for maintaining sequence") column_group_id = Column("group_id", String(256), nullable=True, comment="group id for external retrieval") column_mom_id = Column("mom_id", String(256), nullable=True, comment="parent chunk id") column_chunk_data = Column("chunk_data", JSON, nullable=True, comment="table parser row data") column_definitions: list[Column] = [ Column("id", String(256), primary_key=True, comment="chunk id"), Column("kb_id", String(256), nullable=False, index=True, comment="knowledge base id"), Column("doc_id", String(256), nullable=True, index=True, comment="document id"), Column("docnm_kwd", String(256), nullable=True, comment="document name"), Column("doc_type_kwd", String(256), nullable=True, comment="document type"), Column("title_tks", String(256), nullable=True, comment="title tokens"), Column("title_sm_tks", String(256), nullable=True, comment="fine-grained (small) title tokens"), Column("content_with_weight", LONGTEXT, nullable=True, comment="the original content"), Column("content_ltks", LONGTEXT, nullable=True, comment="long text tokens derived from content_with_weight"), Column("content_sm_ltks", LONGTEXT, nullable=True, comment="fine-grained (small) tokens derived from content_ltks"), Column("pagerank_fea", Integer, nullable=True, comment="page rank priority, usually set in kb level"), Column("important_kwd", ARRAY(String(256)), nullable=True, comment="keywords"), Column("important_tks", TEXT, nullable=True, comment="keyword tokens"), Column("question_kwd", ARRAY(String(1024)), nullable=True, comment="questions"), Column("question_tks", TEXT, nullable=True, comment="question tokens"), Column("tag_kwd", ARRAY(String(256)), nullable=True, comment="tags"), Column("tag_feas", JSON, nullable=True, comment="tag features used for 'rank_feature', format: [tag -> relevance score]"), Column("available_int", Integer, nullable=False, index=True, server_default="1", comment="status of availability, 0 for unavailable, 1 for 
available"), Column("create_time", String(19), nullable=True, comment="creation time in YYYY-MM-DD HH:MM:SS format"), Column("create_timestamp_flt", Double, nullable=True, comment="creation timestamp in float format"), Column("img_id", String(128), nullable=True, comment="image id"), Column("position_int", ARRAY(ARRAY(Integer)), nullable=True, comment="position"), Column("page_num_int", ARRAY(Integer), nullable=True, comment="page number"), Column("top_int", ARRAY(Integer), nullable=True, comment="rank from the top"), Column("knowledge_graph_kwd", String(256), nullable=True, index=True, comment="knowledge graph chunk type"), Column("source_id", ARRAY(String(256)), nullable=True, comment="source document id"), Column("entity_kwd", String(256), nullable=True, comment="entity name"), Column("entity_type_kwd", String(256), nullable=True, index=True, comment="entity type"), Column("from_entity_kwd", String(256), nullable=True, comment="the source entity of this edge"), Column("to_entity_kwd", String(256), nullable=True, comment="the target entity of this edge"), Column("weight_int", Integer, nullable=True, comment="the weight of this edge"), Column("weight_flt", Double, nullable=True, comment="the weight of community report"), Column("entities_kwd", ARRAY(String(256)), nullable=True, comment="node ids of entities"), Column("rank_flt", Double, nullable=True, comment="rank of this entity"), Column("removed_kwd", String(256), nullable=True, index=True, server_default="'N'", comment="whether it has been deleted"), column_chunk_data, Column("metadata", JSON, nullable=True, comment="metadata for this chunk"), Column("extra", JSON, nullable=True, comment="extra information of non-general chunk"), column_order_id, column_group_id, column_mom_id, ] column_names: list[str] = [col.name for col in column_definitions] column_types: dict[str, TypeEngine] = {col.name: col.type for col in column_definitions} array_columns: list[str] = [col.name for col in column_definitions if isinstance(col.type, ARRAY)] # Index columns for RAG chunk table INDEX_COLUMNS: list[str] = [ "kb_id", "doc_id", "available_int", "knowledge_graph_kwd", "entity_type_kwd", "removed_kwd", ] # Full-text search columns (with weight) - original content FTS_COLUMNS_ORIGIN: list[str] = [ "docnm_kwd^10", "content_with_weight", "important_tks^20", "question_tks^20", ] # Full-text search columns (with weight) - tokenized content FTS_COLUMNS_TKS: list[str] = [ "title_tks^10", "title_sm_tks^5", "important_tks^20", "question_tks^20", "content_ltks^2", "content_sm_ltks", ] # Extra columns to add after table creation (for migration) EXTRA_COLUMNS: list[Column] = [column_order_id, column_group_id, column_mom_id] class SearchResult(BaseModel): total: int chunks: list[dict] def get_column_value(column_name: str, value: Any) -> Any: if column_name in column_types: column_type = column_types[column_name] if isinstance(column_type, String): return str(value) elif isinstance(column_type, Integer): return int(value) elif isinstance(column_type, Double): return float(value) elif isinstance(column_type, ARRAY) or isinstance(column_type, JSON): if isinstance(value, str): try: return json.loads(value) except json.JSONDecodeError: return value else: return value else: raise ValueError(f"Unsupported column type for column '{column_name}': {column_type}") elif vector_column_pattern.match(column_name): if isinstance(value, str): try: return json.loads(value) except json.JSONDecodeError: return value else: return value elif column_name == "_score": return float(value) 
else: raise ValueError(f"Unknown column '{column_name}' with value '{value}'.") def get_default_value(column_name: str) -> Any: if column_name == "available_int": return 1 elif column_name == "removed_kwd": return "N" elif column_name == "_order_id": return 0 else: return None def get_metadata_filter_expression(metadata_filtering_conditions: dict) -> str: """ Convert metadata filtering conditions to MySQL JSON path expression. Args: metadata_filtering_conditions: dict with 'conditions' and 'logical_operator' keys Returns: MySQL JSON path expression string """ if not metadata_filtering_conditions: return "" conditions = metadata_filtering_conditions.get("conditions", []) logical_operator = metadata_filtering_conditions.get("logical_operator", "and").upper() if not conditions: return "" if logical_operator not in ["AND", "OR"]: raise ValueError(f"Unsupported logical operator: {logical_operator}. Only 'and' and 'or' are supported.") metadata_filters = [] for condition in conditions: name = condition.get("name") comparison_operator = condition.get("comparison_operator") value = condition.get("value") if not all([name, comparison_operator]): continue expr = f"JSON_EXTRACT(metadata, '$.{name}')" value_str = get_value_str(value) # Convert comparison operator to MySQL JSON path syntax if comparison_operator == "is": # JSON_EXTRACT(metadata, '$.field_name') = 'value' metadata_filters.append(f"{expr} = {value_str}") elif comparison_operator == "is not": metadata_filters.append(f"{expr} != {value_str}") elif comparison_operator == "contains": metadata_filters.append(f"JSON_CONTAINS({expr}, {value_str})") elif comparison_operator == "not contains": metadata_filters.append(f"NOT JSON_CONTAINS({expr}, {value_str})") elif comparison_operator == "start with": metadata_filters.append(f"{expr} LIKE CONCAT({value_str}, '%')") elif comparison_operator == "end with": metadata_filters.append(f"{expr} LIKE CONCAT('%', {value_str})") elif comparison_operator == "empty": metadata_filters.append(f"({expr} IS NULL OR {expr} = '' OR {expr} = '[]' OR {expr} = '{{}}')") elif comparison_operator == "not empty": metadata_filters.append(f"({expr} IS NOT NULL AND {expr} != '' AND {expr} != '[]' AND {expr} != '{{}}')") # Number operators elif comparison_operator == "=": metadata_filters.append(f"CAST({expr} AS DECIMAL(20,10)) = {value_str}") elif comparison_operator == "≠": metadata_filters.append(f"CAST({expr} AS DECIMAL(20,10)) != {value_str}") elif comparison_operator == ">": metadata_filters.append(f"CAST({expr} AS DECIMAL(20,10)) > {value_str}") elif comparison_operator == "<": metadata_filters.append(f"CAST({expr} AS DECIMAL(20,10)) < {value_str}") elif comparison_operator == "≥": metadata_filters.append(f"CAST({expr} AS DECIMAL(20,10)) >= {value_str}") elif comparison_operator == "≤": metadata_filters.append(f"CAST({expr} AS DECIMAL(20,10)) <= {value_str}") # Time operators elif comparison_operator == "before": metadata_filters.append(f"CAST({expr} AS DATETIME) < {value_str}") elif comparison_operator == "after": metadata_filters.append(f"CAST({expr} AS DATETIME) > {value_str}") else: logger.warning(f"Unsupported comparison operator: {comparison_operator}") continue if not metadata_filters: return "" return f"({f' {logical_operator} '.join(metadata_filters)})" def get_filters(condition: dict) -> list[str]: filters: list[str] = [] for k, v in condition.items(): if not v: continue if k == "exists": filters.append(f"{v} IS NOT NULL") elif k == "must_not" and isinstance(v, dict) and "exists" in v: 
filters.append(f"{v.get('exists')} IS NULL") elif k == "metadata_filtering_conditions": # Handle metadata filtering conditions metadata_filter = get_metadata_filter_expression(v) if metadata_filter: filters.append(metadata_filter) elif k in array_columns: if isinstance(v, list): array_filters = [] for vv in v: array_filters.append(f"array_contains({k}, {get_value_str(vv)})") array_filter = " OR ".join(array_filters) filters.append(f"({array_filter})") else: filters.append(f"array_contains({k}, {get_value_str(v)})") elif isinstance(v, list): values: list[str] = [] for item in v: values.append(get_value_str(item)) value = ", ".join(values) filters.append(f"{k} IN ({value})") else: filters.append(f"{k} = {get_value_str(v)}") return filters @singleton class OBConnection(OBConnectionBase): def __init__(self): super().__init__(logger_name='ragflow.ob_conn') # Determine which columns to use for full-text search dynamically self._fulltext_search_columns = FTS_COLUMNS_ORIGIN if self.search_original_content else FTS_COLUMNS_TKS """ Template method implementations """ def get_index_columns(self) -> list[str]: return INDEX_COLUMNS def get_column_definitions(self) -> list[Column]: return column_definitions def get_extra_columns(self) -> list[Column]: return EXTRA_COLUMNS def get_lock_prefix(self) -> str: return "ob_" def _get_filters(self, condition: dict) -> list[str]: return get_filters(condition) def get_fulltext_columns(self) -> list[str]: """Return list of column names that need fulltext indexes (without weight suffix).""" return [col.split("^")[0] for col in self._fulltext_search_columns] def delete_idx(self, index_name: str, dataset_id: str): if dataset_id: # The index need to be alive after any kb deletion since all kb under this tenant are in one index. return super().delete_idx(index_name, dataset_id) """ Performance monitoring """ def get_performance_metrics(self) -> dict: """ Get comprehensive performance metrics for OceanBase. Returns: dict: Performance metrics including latency, storage, QPS, and slow queries """ metrics = { "connection": "connected", "latency_ms": 0.0, "storage_used": "0B", "storage_total": "0B", "query_per_second": 0, "slow_queries": 0, "active_connections": 0, "max_connections": 0 } try: # Measure connection latency start_time = time.time() self.client.perform_raw_text_sql("SELECT 1").fetchone() metrics["latency_ms"] = round((time.time() - start_time) * 1000, 2) # Get storage information try: storage_info = self._get_storage_info() metrics.update(storage_info) except Exception as e: logger.warning(f"Failed to get storage info: {str(e)}") # Get connection pool statistics try: pool_stats = self._get_connection_pool_stats() metrics.update(pool_stats) except Exception as e: logger.warning(f"Failed to get connection pool stats: {str(e)}") # Get slow query statistics try: slow_queries = self._get_slow_query_count() metrics["slow_queries"] = slow_queries except Exception as e: logger.warning(f"Failed to get slow query count: {str(e)}") # Get QPS (Queries Per Second) - approximate from processlist try: qps = self._estimate_qps() metrics["query_per_second"] = qps except Exception as e: logger.warning(f"Failed to estimate QPS: {str(e)}") except Exception as e: metrics["connection"] = "disconnected" metrics["error"] = str(e) logger.error(f"Failed to get OceanBase performance metrics: {str(e)}") return metrics def _get_storage_info(self) -> dict: """ Get storage space usage information. 
Returns: dict: Storage information with used and total space """ try: # Get database size result = self.client.perform_raw_text_sql( f"SELECT ROUND(SUM(data_length + index_length) / 1024 / 1024, 2) AS 'size_mb' " f"FROM information_schema.tables WHERE table_schema = '{self.db_name}'" ).fetchone() size_mb = float(result[0]) if result and result[0] else 0.0 # Try to get total available space (may not be available in all OceanBase versions) try: result = self.client.perform_raw_text_sql( "SELECT ROUND(SUM(total_size) / 1024 / 1024 / 1024, 2) AS 'total_gb' " "FROM oceanbase.__all_disk_stat" ).fetchone() total_gb = float(result[0]) if result and result[0] else None except Exception: # Fallback: estimate total space (100GB default if not available) total_gb = 100.0 return { "storage_used": f"{size_mb:.2f}MB", "storage_total": f"{total_gb:.2f}GB" if total_gb else "N/A" } except Exception as e: logger.warning(f"Failed to get storage info: {str(e)}") return { "storage_used": "N/A", "storage_total": "N/A" } def _get_connection_pool_stats(self) -> dict: """ Get connection pool statistics. Returns: dict: Connection pool statistics """ try: # Get active connections from processlist result = self.client.perform_raw_text_sql("SHOW PROCESSLIST") active_connections = len(list(result.fetchall())) # Get max_connections setting max_conn_result = self.client.perform_raw_text_sql( "SHOW VARIABLES LIKE 'max_connections'" ).fetchone() max_connections = int(max_conn_result[1]) if max_conn_result and max_conn_result[1] else 0 # Get pool size from client if available pool_size = getattr(self.client, 'pool_size', None) or 0 return { "active_connections": active_connections, "max_connections": max_connections if max_connections > 0 else pool_size, "pool_size": pool_size } except Exception as e: logger.warning(f"Failed to get connection pool stats: {str(e)}") return { "active_connections": 0, "max_connections": 0, "pool_size": 0 } def _get_slow_query_count(self, threshold_seconds: int = 1) -> int: """ Get count of slow queries (queries taking longer than threshold). Args: threshold_seconds: Threshold in seconds for slow queries (default: 1) Returns: int: Number of slow queries """ try: result = self.client.perform_raw_text_sql( f"SELECT COUNT(*) FROM information_schema.processlist " f"WHERE time > {threshold_seconds} AND command != 'Sleep'" ).fetchone() return int(result[0]) if result and result[0] else 0 except Exception as e: logger.warning(f"Failed to get slow query count: {str(e)}") return 0 def _estimate_qps(self) -> int: """ Estimate queries per second from processlist. 
Returns: int: Estimated queries per second """ try: # Count active queries (non-Sleep commands) result = self.client.perform_raw_text_sql( "SELECT COUNT(*) FROM information_schema.processlist WHERE command != 'Sleep'" ).fetchone() active_queries = int(result[0]) if result and result[0] else 0 # Rough estimate: assume average query takes 0.1 seconds # This is a simplified estimation estimated_qps = max(0, active_queries * 10) return estimated_qps except Exception as e: logger.warning(f"Failed to estimate QPS: {str(e)}") return 0 """ CRUD operations """ def search( self, select_fields: list[str], highlight_fields: list[str], condition: dict, match_expressions: list[MatchExpr], order_by: OrderByExpr, offset: int, limit: int, index_names: str | list[str], knowledgebase_ids: list[str], agg_fields: list[str] = [], rank_feature: dict | None = None, **kwargs, ): if isinstance(index_names, str): index_names = index_names.split(",") assert isinstance(index_names, list) and len(index_names) > 0 index_names = list(set(index_names)) if len(match_expressions) == 3: if not self.enable_fulltext_search: # disable fulltext search in fusion search, which means fallback to vector search match_expressions = [m for m in match_expressions if isinstance(m, MatchDenseExpr)] else: for m in match_expressions: if isinstance(m, FusionExpr): weights = m.fusion_params["weights"] vector_similarity_weight = get_float(weights.split(",")[1]) # skip the search if its weight is zero if vector_similarity_weight <= 0.0: match_expressions = [m for m in match_expressions if isinstance(m, MatchTextExpr)] elif vector_similarity_weight >= 1.0: match_expressions = [m for m in match_expressions if isinstance(m, MatchDenseExpr)] result: SearchResult = SearchResult( total=0, chunks=[], ) # copied from es_conn.py if len(match_expressions) == 3 and self.es: bqry = Q("bool", must=[]) condition["kb_id"] = knowledgebase_ids for k, v in condition.items(): if k == "available_int": if v == 0: bqry.filter.append(Q("range", available_int={"lt": 1})) else: bqry.filter.append( Q("bool", must_not=Q("range", available_int={"lt": 1}))) continue if not v: continue if isinstance(v, list): bqry.filter.append(Q("terms", **{k: v})) elif isinstance(v, str) or isinstance(v, int): bqry.filter.append(Q("term", **{k: v})) else: raise Exception( f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.") s = Search() vector_similarity_weight = 0.5 for m in match_expressions: if isinstance(m, FusionExpr) and m.method == "weighted_sum" and "weights" in m.fusion_params: assert len(match_expressions) == 3 and isinstance(match_expressions[0], MatchTextExpr) and isinstance( match_expressions[1], MatchDenseExpr) and isinstance( match_expressions[2], FusionExpr) weights = m.fusion_params["weights"] vector_similarity_weight = get_float(weights.split(",")[1]) for m in match_expressions: if isinstance(m, MatchTextExpr): minimum_should_match = m.extra_options.get("minimum_should_match", 0.0) if isinstance(minimum_should_match, float): minimum_should_match = str(int(minimum_should_match * 100)) + "%" bqry.must.append(Q("query_string", fields=FTS_COLUMNS_TKS, type="best_fields", query=m.matching_text, minimum_should_match=minimum_should_match, boost=1)) bqry.boost = 1.0 - vector_similarity_weight elif isinstance(m, MatchDenseExpr): assert (bqry is not None) similarity = 0.0 if "similarity" in m.extra_options: similarity = m.extra_options["similarity"] s = s.knn(m.vector_column_name, m.topn, m.topn * 2, query_vector=list(m.embedding_data), 
filter=bqry.to_dict(), similarity=similarity, ) if bqry and rank_feature: for fld, sc in rank_feature.items(): if fld != PAGERANK_FLD: fld = f"{TAG_FLD}.{fld}" bqry.should.append(Q("rank_feature", field=fld, linear={}, boost=sc)) if bqry: s = s.query(bqry) # for field in highlightFields: # s = s.highlight(field) if order_by: orders = list() for field, order in order_by.fields: order = "asc" if order == 0 else "desc" if field in ["page_num_int", "top_int"]: order_info = {"order": order, "unmapped_type": "float", "mode": "avg", "numeric_type": "double"} elif field.endswith("_int") or field.endswith("_flt"): order_info = {"order": order, "unmapped_type": "float"} else: order_info = {"order": order, "unmapped_type": "text"} orders.append({field: order_info}) s = s.sort(*orders) for fld in agg_fields: s.aggs.bucket(f'aggs_{fld}', 'terms', field=fld, size=1000000) if limit > 0: s = s[offset:offset + limit] q = s.to_dict() logger.debug(f"OBConnection.hybrid_search {str(index_names)} query: " + json.dumps(q)) for index_name in index_names: start_time = time.time() res = self.es.search(index=index_name, body=q, timeout="600s", track_total_hits=True, _source=True) elapsed_time = time.time() - start_time logger.info( f"OBConnection.search table {index_name}, search type: hybrid, elapsed time: {elapsed_time:.3f} seconds," f" got count: {len(res)}" ) for chunk in res: result.chunks.append(self._es_row_to_entity(chunk)) result.total = result.total + 1 return result output_fields = select_fields.copy() if "id" not in output_fields: output_fields = ["id"] + output_fields if "_score" in output_fields: output_fields.remove("_score") if highlight_fields: for field in highlight_fields: if field not in output_fields: output_fields.append(field) fields_expr = ", ".join(output_fields) condition["kb_id"] = knowledgebase_ids filters: list[str] = get_filters(condition) filters_expr = " AND ".join(filters) fulltext_query: Optional[str] = None fulltext_topn: Optional[int] = None fulltext_search_weight: dict[str, float] = {} fulltext_search_expr: dict[str, str] = {} fulltext_search_idx_list: list[str] = [] fulltext_search_score_expr: Optional[str] = None fulltext_search_filter: Optional[str] = None vector_column_name: Optional[str] = None vector_data: Optional[list[float]] = None vector_topn: Optional[int] = None vector_similarity_threshold: Optional[float] = None vector_similarity_weight: Optional[float] = None vector_search_expr: Optional[str] = None vector_search_score_expr: Optional[str] = None vector_search_filter: Optional[str] = None for m in match_expressions: if isinstance(m, MatchTextExpr): assert "original_query" in m.extra_options, "'original_query' is missing in extra_options." fulltext_query = m.extra_options["original_query"] fulltext_query = escape_string(fulltext_query.strip()) fulltext_topn = m.topn fulltext_search_expr, fulltext_search_weight = self._parse_fulltext_columns( fulltext_query, self._fulltext_search_columns ) for column_name in fulltext_search_expr.keys(): fulltext_search_idx_list.append(fulltext_index_name_template % column_name) elif isinstance(m, MatchDenseExpr): assert m.embedding_data_type == "float", f"embedding data type '{m.embedding_data_type}' is not float." 
vector_column_name = m.vector_column_name vector_data = m.embedding_data vector_topn = m.topn vector_similarity_threshold = m.extra_options.get("similarity", 0.0) elif isinstance(m, FusionExpr): weights = m.fusion_params["weights"] vector_similarity_weight = get_float(weights.split(",")[1]) if fulltext_query: fulltext_search_filter = f"({' OR '.join([expr for expr in fulltext_search_expr.values()])})" fulltext_search_score_expr = f"({' + '.join(f'{expr} * {fulltext_search_weight.get(col, 0)}' for col, expr in fulltext_search_expr.items())})" if vector_data: vector_data_str = "[" + ",".join([str(np.float32(v)) for v in vector_data]) + "]" vector_search_expr = vector_search_template % (vector_column_name, vector_data_str) # use (1 - cosine_distance) as score, which should be [-1, 1] # https://www.oceanbase.com/docs/common-oceanbase-database-standalone-1000000003577323 vector_search_score_expr = f"(1 - {vector_search_expr})" vector_search_filter = f"{vector_search_score_expr} >= {vector_similarity_threshold}" pagerank_score_expr = f"(CAST(IFNULL({PAGERANK_FLD}, 0) AS DECIMAL(10, 2)) / 100)" # TODO use tag rank_feature in sorting # tag_rank_fea = {k: float(v) for k, v in (rank_feature or {}).items() if k != PAGERANK_FLD} if fulltext_query and vector_data: search_type = "fusion" elif fulltext_query: search_type = "fulltext" elif vector_data: search_type = "vector" elif len(agg_fields) > 0: search_type = "aggregation" else: search_type = "filter" if search_type in ["fusion", "fulltext", "vector"] and "_score" not in output_fields: output_fields.append("_score") if limit: if vector_topn is not None: limit = min(vector_topn, limit) if fulltext_topn is not None: limit = min(fulltext_topn, limit) for index_name in index_names: if not self._check_table_exists_cached(index_name): continue fulltext_search_hint = f"/*+ UNION_MERGE({index_name} {' '.join(fulltext_search_idx_list)}) */" if self.use_fulltext_hint else "" if search_type == "fusion": # fusion search, usually for chat num_candidates = vector_topn + fulltext_topn if self.use_fulltext_first_fusion_search: count_sql = ( f"WITH fulltext_results AS (" f" SELECT {fulltext_search_hint} *, {fulltext_search_score_expr} AS relevance" f" FROM {index_name}" f" WHERE {filters_expr} AND {fulltext_search_filter}" f" ORDER BY relevance DESC" f" LIMIT {num_candidates}" f")" f" SELECT COUNT(*) FROM fulltext_results WHERE {vector_search_filter}" ) else: count_sql = ( f"WITH fulltext_results AS (" f" SELECT {fulltext_search_hint} id FROM {index_name}" f" WHERE {filters_expr} AND {fulltext_search_filter}" f" ORDER BY {fulltext_search_score_expr}" f" LIMIT {fulltext_topn}" f")," f"vector_results AS (" f" SELECT id FROM {index_name}" f" WHERE {filters_expr} AND {vector_search_filter}" f" ORDER BY {vector_search_expr}" f" APPROXIMATE LIMIT {vector_topn}" f")" f" SELECT COUNT(*) FROM fulltext_results f FULL OUTER JOIN vector_results v ON f.id = v.id" ) logger.debug("OBConnection.search with count sql: %s", count_sql) rows, elapsed_time = self._execute_search_sql(count_sql) total_count = rows[0][0] if rows else 0 result.total += total_count logger.info( f"OBConnection.search table {index_name}, search type: fusion, step: 1-count, elapsed time: {elapsed_time:.3f} seconds," f" vector column: '{vector_column_name}'," f" query text: '{fulltext_query}'," f" condition: '{condition}'," f" vector_similarity_threshold: {vector_similarity_threshold}," f" got count: {total_count}" ) if total_count == 0: continue if self.use_fulltext_first_fusion_search: score_expr = f"(relevance 
* {1 - vector_similarity_weight} + {vector_search_score_expr} * {vector_similarity_weight} + {pagerank_score_expr})" fusion_sql = ( f"WITH fulltext_results AS (" f" SELECT {fulltext_search_hint} *, {fulltext_search_score_expr} AS relevance" f" FROM {index_name}" f" WHERE {filters_expr} AND {fulltext_search_filter}" f" ORDER BY relevance DESC" f" LIMIT {num_candidates}" f")" f" SELECT {fields_expr}, {score_expr} AS _score" f" FROM fulltext_results" f" WHERE {vector_search_filter}" f" ORDER BY _score DESC" f" LIMIT {offset}, {limit}" ) else: pagerank_score_expr = f"(CAST(IFNULL(f.{PAGERANK_FLD}, 0) AS DECIMAL(10, 2)) / 100)" score_expr = f"(f.relevance * {1 - vector_similarity_weight} + v.similarity * {vector_similarity_weight} + {pagerank_score_expr})" fields_expr = ", ".join([f"t.{f} as {f}" for f in output_fields if f != "_score"]) fusion_sql = ( f"WITH fulltext_results AS (" f" SELECT {fulltext_search_hint} id, pagerank_fea, {fulltext_search_score_expr} AS relevance" f" FROM {index_name}" f" WHERE {filters_expr} AND {fulltext_search_filter}" f" ORDER BY relevance DESC" f" LIMIT {fulltext_topn}" f")," f"vector_results AS (" f" SELECT id, pagerank_fea, {vector_search_score_expr} AS similarity" f" FROM {index_name}" f" WHERE {filters_expr} AND {vector_search_filter}" f" ORDER BY {vector_search_expr}" f" APPROXIMATE LIMIT {vector_topn}" f")," f"combined_results AS (" f" SELECT COALESCE(f.id, v.id) AS id, {score_expr} AS score" f" FROM fulltext_results f" f" FULL OUTER JOIN vector_results v" f" ON f.id = v.id" f")" f" SELECT {fields_expr}, c.score as _score" f" FROM combined_results c" f" JOIN {index_name} t" f" ON c.id = t.id" f" ORDER BY score DESC" f" LIMIT {offset}, {limit}" ) logger.debug("OBConnection.search with fusion sql: %s", fusion_sql) rows, elapsed_time = self._execute_search_sql(fusion_sql) logger.info( f"OBConnection.search table {index_name}, search type: fusion, step: 2-query, elapsed time: {elapsed_time:.3f} seconds," f" select fields: '{output_fields}'," f" vector column: '{vector_column_name}'," f" query text: '{fulltext_query}'," f" condition: '{condition}'," f" vector_similarity_threshold: {vector_similarity_threshold}," f" vector_similarity_weight: {vector_similarity_weight}," f" return rows count: {len(rows)}" ) for row in rows: result.chunks.append(self._row_to_entity(row, output_fields)) elif search_type == "vector": # vector search, usually used for graph search count_sql = self._build_count_sql(index_name, filters_expr, vector_search_filter) logger.debug("OBConnection.search with vector count sql: %s", count_sql) rows, elapsed_time = self._execute_search_sql(count_sql) total_count = rows[0][0] if rows else 0 result.total += total_count logger.info( f"OBConnection.search table {index_name}, search type: vector, step: 1-count, elapsed time: {elapsed_time:.3f} seconds," f" vector column: '{vector_column_name}'," f" condition: '{condition}'," f" vector_similarity_threshold: {vector_similarity_threshold}," f" got count: {total_count}" ) if total_count == 0: continue vector_sql = self._build_vector_search_sql( index_name, fields_expr, vector_search_score_expr, filters_expr, vector_search_filter, vector_search_expr, limit, vector_topn, offset ) logger.debug("OBConnection.search with vector sql: %s", vector_sql) rows, elapsed_time = self._execute_search_sql(vector_sql) logger.info( f"OBConnection.search table {index_name}, search type: vector, step: 2-query, elapsed time: {elapsed_time:.3f} seconds," f" select fields: '{output_fields}'," f" vector column: 
'{vector_column_name}'," f" condition: '{condition}'," f" vector_similarity_threshold: {vector_similarity_threshold}," f" return rows count: {len(rows)}" ) for row in rows: result.chunks.append(self._row_to_entity(row, output_fields)) elif search_type == "fulltext": # fulltext search, usually used to search chunks in one dataset count_sql = self._build_count_sql(index_name, filters_expr, fulltext_search_filter, fulltext_search_hint) logger.debug("OBConnection.search with fulltext count sql: %s", count_sql) rows, elapsed_time = self._execute_search_sql(count_sql) total_count = rows[0][0] if rows else 0 result.total += total_count logger.info( f"OBConnection.search table {index_name}, search type: fulltext, step: 1-count, elapsed time: {elapsed_time:.3f} seconds," f" query text: '{fulltext_query}'," f" condition: '{condition}'," f" got count: {total_count}" ) if total_count == 0: continue fulltext_sql = self._build_fulltext_search_sql( index_name, fields_expr, fulltext_search_score_expr, filters_expr, fulltext_search_filter, offset, limit, fulltext_topn, fulltext_search_hint ) logger.debug("OBConnection.search with fulltext sql: %s", fulltext_sql) rows, elapsed_time = self._execute_search_sql(fulltext_sql) logger.info( f"OBConnection.search table {index_name}, search type: fulltext, step: 2-query, elapsed time: {elapsed_time:.3f} seconds," f" select fields: '{output_fields}'," f" query text: '{fulltext_query}'," f" condition: '{condition}'," f" return rows count: {len(rows)}" ) for row in rows: result.chunks.append(self._row_to_entity(row, output_fields)) elif search_type == "aggregation": # aggregation search assert len(agg_fields) == 1, "Only one aggregation field is supported in OceanBase." agg_field = agg_fields[0] if agg_field in array_columns: res = self.client.perform_raw_text_sql( f"SELECT {agg_field} FROM {index_name}" f" WHERE {agg_field} IS NOT NULL AND {filters_expr}" ) counts = {} for row in res: if row[0]: if isinstance(row[0], str): try: arr = json.loads(row[0]) except json.JSONDecodeError: logger.warning(f"Failed to parse JSON array: {row[0]}") continue else: arr = row[0] if isinstance(arr, list): for v in arr: if isinstance(v, str) and v.strip(): counts[v] = counts.get(v, 0) + 1 for v, count in counts.items(): result.chunks.append({ "value": v, "count": count, }) result.total += len(counts) else: res = self.client.perform_raw_text_sql( f"SELECT {agg_field}, COUNT(*) as count FROM {index_name}" f" WHERE {agg_field} IS NOT NULL AND {filters_expr}" f" GROUP BY {agg_field}" ) for row in res: result.chunks.append({ "value": row[0], "count": int(row[1]), }) result.total += 1 else: # only filter orders: list[str] = [] if order_by: for field, order in order_by.fields: if isinstance(column_types[field], ARRAY): f = field + "_sort" fields_expr += f", array_to_string({field}, ',') AS {f}" field = f order = "ASC" if order == 0 else "DESC" orders.append(f"{field} {order}") count_sql = self._build_count_sql(index_name, filters_expr) logger.debug("OBConnection.search with normal count sql: %s", count_sql) rows, elapsed_time = self._execute_search_sql(count_sql) total_count = rows[0][0] if rows else 0 result.total += total_count logger.info( f"OBConnection.search table {index_name}, search type: normal, step: 1-count, elapsed time: {elapsed_time:.3f} seconds," f" condition: '{condition}'," f" got count: {total_count}" ) if total_count == 0: continue order_by_expr = ("ORDER BY " + ", ".join(orders)) if len(orders) > 0 else "" limit_expr = f"LIMIT {offset}, {limit}" if limit != 0 else "" 
filter_sql = self._build_filter_search_sql( index_name, fields_expr, filters_expr, order_by_expr, limit_expr ) logger.debug("OBConnection.search with normal sql: %s", filter_sql) rows, elapsed_time = self._execute_search_sql(filter_sql) logger.info( f"OBConnection.search table {index_name}, search type: normal, step: 2-query, elapsed time: {elapsed_time:.3f} seconds," f" select fields: '{output_fields}'," f" condition: '{condition}'," f" return rows count: {len(rows)}" ) for row in rows: result.chunks.append(self._row_to_entity(row, output_fields)) if result.total == 0: result.total = len(result.chunks) return result def get(self, chunk_id: str, index_name: str, knowledgebase_ids: list[str]) -> dict | None: try: doc = super().get(chunk_id, index_name, knowledgebase_ids) if doc is None: return None return doc except json.JSONDecodeError as e: logger.error(f"JSON decode error when getting chunk {chunk_id}: {str(e)}") return { "id": chunk_id, "error": f"Failed to parse chunk data due to invalid JSON: {str(e)}" } except Exception as e: logger.exception(f"OBConnection.get({chunk_id}) got exception") raise e def insert(self, documents: list[dict], index_name: str, knowledgebase_id: str = None) -> list[str]: if not documents: return [] # For doc_meta tables, use simple insert without field transformation if index_name.startswith("ragflow_doc_meta_"): return self._insert_doc_meta(documents, index_name) docs: list[dict] = [] ids: list[str] = [] for document in documents: d: dict = {} for k, v in document.items(): if vector_column_pattern.match(k): d[k] = v continue if k not in column_names: if "extra" not in d: d["extra"] = {} d["extra"][k] = v continue if v is None: d[k] = get_default_value(k) continue if k == "kb_id" and isinstance(v, list): d[k] = v[0] elif k == "content_with_weight" and isinstance(v, dict): d[k] = json.dumps(v, ensure_ascii=False) elif k == "position_int": d[k] = json.dumps([list(vv) for vv in v], ensure_ascii=False) elif isinstance(v, list): # remove characters like '\t' for JSON dump and clean special characters cleaned_v = [] for vv in v: if isinstance(vv, str): cleaned_str = vv.strip() cleaned_str = cleaned_str.replace('\\', '\\\\') cleaned_str = cleaned_str.replace('\n', '\\n') cleaned_str = cleaned_str.replace('\r', '\\r') cleaned_str = cleaned_str.replace('\t', '\\t') cleaned_v.append(cleaned_str) else: cleaned_v.append(vv) d[k] = json.dumps(cleaned_v, ensure_ascii=False) else: d[k] = v ids.append(d["id"]) # this is to fix https://github.com/sqlalchemy/sqlalchemy/issues/9703 for column_name in column_names: if column_name not in d: d[column_name] = get_default_value(column_name) metadata = d.get("metadata", {}) if metadata is None: metadata = {} group_id = metadata.get("_group_id") title = metadata.get("_title") if d.get("doc_id"): if group_id: d["group_id"] = group_id else: d["group_id"] = d["doc_id"] if title: d["docnm_kwd"] = title docs.append(d) logger.debug("OBConnection.insert chunks: %s", docs) res = [] try: self.client.upsert(index_name, docs) except Exception as e: logger.error(f"OBConnection.insert error: {str(e)}") res.append(str(e)) return res def _insert_doc_meta(self, documents: list[dict], index_name: str) -> list[str]: """Insert documents into doc_meta table with simple field handling.""" docs: list[dict] = [] for document in documents: d = { "id": document.get("id"), "kb_id": document.get("kb_id"), } # Handle meta_fields - store as JSON meta_fields = document.get("meta_fields") if meta_fields is not None: if isinstance(meta_fields, dict): 
d["meta_fields"] = json.dumps(meta_fields, ensure_ascii=False) elif isinstance(meta_fields, str): d["meta_fields"] = meta_fields else: d["meta_fields"] = "{}" else: d["meta_fields"] = "{}" docs.append(d) logger.debug("OBConnection._insert_doc_meta: %s", docs) res = [] try: self.client.upsert(index_name, docs) except Exception as e: logger.error(f"OBConnection._insert_doc_meta error: {str(e)}") res.append(str(e)) return res def update(self, condition: dict, new_value: dict, index_name: str, knowledgebase_id: str) -> bool: if not self._check_table_exists_cached(index_name): return True # For doc_meta tables, don't force kb_id in condition if not index_name.startswith("ragflow_doc_meta_"): condition["kb_id"] = knowledgebase_id filters = get_filters(condition) set_values: list[str] = [] for k, v in new_value.items(): if k == "remove": if isinstance(v, str): set_values.append(f"{v} = NULL") else: assert isinstance(v, dict), f"Expected str or dict for 'remove', got {type(new_value[k])}." for kk, vv in v.items(): assert kk in array_columns, f"Column '{kk}' is not an array column." set_values.append(f"{kk} = array_remove({kk}, {get_value_str(vv)})") elif k == "add": assert isinstance(v, dict), f"Expected str or dict for 'add', got {type(new_value[k])}." for kk, vv in v.items(): assert kk in array_columns, f"Column '{kk}' is not an array column." set_values.append(f"{kk} = array_append({kk}, {get_value_str(vv)})") elif k == "metadata": assert isinstance(v, dict), f"Expected dict for 'metadata', got {type(new_value[k])}" set_values.append(f"{k} = {get_value_str(v)}") if v and "doc_id" in condition: group_id = v.get("_group_id") title = v.get("_title") if group_id: set_values.append(f"group_id = {get_value_str(group_id)}") if title: set_values.append(f"docnm_kwd = {get_value_str(title)}") else: set_values.append(f"{k} = {get_value_str(v)}") if not set_values: return True update_sql = ( f"UPDATE {index_name}" f" SET {', '.join(set_values)}" f" WHERE {' AND '.join(filters)}" ) logger.debug("OBConnection.update sql: %s", update_sql) try: self.client.perform_raw_text_sql(update_sql) return True except Exception as e: logger.error(f"OBConnection.update error: {str(e)}") return False def _row_to_entity(self, data: Row, fields: list[str]) -> dict: entity = {} for i, field in enumerate(fields): value = data[i] if value is None: continue entity[field] = get_column_value(field, value) return entity @staticmethod def _es_row_to_entity(data: dict) -> dict: entity = {} for k, v in data.items(): if v is None: continue entity[k] = get_column_value(k, v) return entity """ Helper functions for search result """ def get_total(self, res) -> int: return res.total def get_doc_ids(self, res) -> list[str]: return [row["id"] for row in res.chunks] def get_fields(self, res, fields: list[str]) -> dict[str, dict]: result = {} for row in res.chunks: data = {} for field in fields: v = row.get(field) if v is not None: data[field] = v result[row["id"]] = data return result # copied from query.FulltextQueryer def is_chinese(self, line): arr = re.split(r"[ \t]+", line) if len(arr) <= 3: return True e = 0 for t in arr: if not re.match(r"[a-zA-Z]+$", t): e += 1 return e * 1.0 / len(arr) >= 0.7 def highlight(self, txt: str, tks: str, question: str, keywords: list[str]) -> Optional[str]: if not txt or not keywords: return None highlighted_txt = txt if question and not self.is_chinese(question): highlighted_txt = re.sub( r"(^|\W)(%s)(\W|$)" % re.escape(question), r"\1\2\3", highlighted_txt, flags=re.IGNORECASE | re.MULTILINE, ) if 
re.search(r"[^<>]+", highlighted_txt, flags=re.IGNORECASE | re.MULTILINE): return highlighted_txt for keyword in keywords: highlighted_txt = re.sub( r"(^|\W)(%s)(\W|$)" % re.escape(keyword), r"\1\2\3", highlighted_txt, flags=re.IGNORECASE | re.MULTILINE, ) if len(re.findall(r'', highlighted_txt)) > 0 or len( re.findall(r'\s*', highlighted_txt)) > 0: return highlighted_txt else: return None if not tks: tks = rag_tokenizer.tokenize(txt) tokens = tks.split() if not tokens: return None last_pos = len(txt) for i in range(len(tokens) - 1, -1, -1): token = tokens[i] token_pos = highlighted_txt.rfind(token, 0, last_pos) if token_pos != -1: if token in keywords: highlighted_txt = ( highlighted_txt[:token_pos] + f'{token}' + highlighted_txt[token_pos + len(token):] ) last_pos = token_pos return re.sub(r'', '', highlighted_txt) def get_highlight(self, res, keywords: list[str], fieldnm: str): ans = {} if len(res.chunks) == 0 or len(keywords) == 0: return ans for d in res.chunks: txt = d.get(fieldnm) if not txt: continue tks = d.get("content_ltks") if fieldnm == "content_with_weight" else "" highlighted_txt = self.highlight(txt, tks, " ".join(keywords), keywords) if highlighted_txt: ans[d["id"]] = highlighted_txt return ans def get_aggregation(self, res, fieldnm: str): if len(res.chunks) == 0: return [] counts = {} result = [] for d in res.chunks: if "value" in d and "count" in d: # directly use the aggregation result result.append((d["value"], d["count"])) elif fieldnm in d: # aggregate the values of specific field v = d[fieldnm] if isinstance(v, list): for vv in v: if isinstance(vv, str) and vv.strip(): counts[vv] = counts.get(vv, 0) + 1 elif isinstance(v, str) and v.strip(): counts[v] = counts.get(v, 0) + 1 if len(counts) > 0: for k, v in counts.items(): result.append((k, v)) return result """ SQL """ def sql(self, sql: str, fetch_size: int = 1024, format: str = "json"): logger.debug("OBConnection.sql get sql: %s", sql) def normalize_sql(sql_text: str) -> str: cleaned = sql_text.strip().rstrip(";") cleaned = re.sub(r"[`]+", "", cleaned) cleaned = re.sub( r"json_extract_string\s*\(\s*([^,]+?)\s*,\s*([^)]+?)\s*\)", r"JSON_UNQUOTE(JSON_EXTRACT(\1, \2))", cleaned, flags=re.IGNORECASE, ) cleaned = re.sub( r"json_extract_isnull\s*\(\s*([^,]+?)\s*,\s*([^)]+?)\s*\)", r"(JSON_EXTRACT(\1, \2) IS NULL)", cleaned, flags=re.IGNORECASE, ) return cleaned def coerce_value(value: Any) -> Any: if isinstance(value, np.generic): return value.item() if isinstance(value, bytes): return value.decode("utf-8", errors="ignore") return value sql_text = normalize_sql(sql) if fetch_size and fetch_size > 0: sql_lower = sql_text.lstrip().lower() if re.match(r"^(select|with)\b", sql_lower) and not re.search(r"\blimit\b", sql_lower): sql_text = f"{sql_text} LIMIT {int(fetch_size)}" logger.debug("OBConnection.sql to ob: %s", sql_text) try: res = self.client.perform_raw_text_sql(sql_text) except Exception: logger.exception("OBConnection.sql got exception") raise if res is None: return None columns = list(res.keys()) if hasattr(res, "keys") else [] try: rows = res.fetchmany(fetch_size) if fetch_size and fetch_size > 0 else res.fetchall() except Exception: rows = res.fetchall() rows_list = [[coerce_value(v) for v in list(row)] for row in rows] result = { "columns": [{"name": col, "type": "text"} for col in columns], "rows": rows_list, } if format == "markdown": header = "|" + "|".join(columns) + "|" if columns else "" separator = "|" + "|".join(["---" for _ in columns]) + "|" if columns else "" body = "\n".join(["|" + "|".join([str(v) 
for v in row]) + "|" for row in rows_list]) result["markdown"] = "\n".join([line for line in [header, separator, body] if line]) return result