# # Copyright 2025 The InfiniFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import json import logging import os import re import time from typing import Any, Optional from elasticsearch_dsl import Q, Search from pydantic import BaseModel from pymysql.converters import escape_string from pyobvector import ObVecClient, FtsIndexParam, FtsParser, ARRAY, VECTOR from pyobvector.client.hybrid_search import HybridSearch from pyobvector.util import ObVersion from sqlalchemy import text, Column, String, Integer, JSON, Double, Row, Table from sqlalchemy.dialects.mysql import LONGTEXT, TEXT from sqlalchemy.sql.type_api import TypeEngine from common import settings from common.constants import PAGERANK_FLD, TAG_FLD from common.decorator import singleton from common.float_utils import get_float from rag.nlp import rag_tokenizer from rag.utils.doc_store_conn import DocStoreConnection, MatchExpr, OrderByExpr, FusionExpr, MatchTextExpr, \ MatchDenseExpr ATTEMPT_TIME = 2 OB_QUERY_TIMEOUT = int(os.environ.get("OB_QUERY_TIMEOUT", "100_000_000")) logger = logging.getLogger('ragflow.ob_conn') column_order_id = Column("_order_id", Integer, nullable=True, comment="chunk order id for maintaining sequence") column_group_id = Column("group_id", String(256), nullable=True, comment="group id for external retrieval") column_definitions: list[Column] = [ Column("id", String(256), primary_key=True, comment="chunk id"), Column("kb_id", String(256), nullable=False, index=True, comment="knowledge base id"), Column("doc_id", String(256), nullable=True, index=True, comment="document id"), Column("docnm_kwd", String(256), nullable=True, comment="document name"), Column("doc_type_kwd", String(256), nullable=True, comment="document type"), Column("title_tks", String(256), nullable=True, comment="title tokens"), Column("title_sm_tks", String(256), nullable=True, comment="fine-grained (small) title tokens"), Column("content_with_weight", LONGTEXT, nullable=True, comment="the original content"), Column("content_ltks", LONGTEXT, nullable=True, comment="long text tokens derived from content_with_weight"), Column("content_sm_ltks", LONGTEXT, nullable=True, comment="fine-grained (small) tokens derived from content_ltks"), Column("pagerank_fea", Integer, nullable=True, comment="page rank priority, usually set in kb level"), Column("important_kwd", ARRAY(String(256)), nullable=True, comment="keywords"), Column("important_tks", TEXT, nullable=True, comment="keyword tokens"), Column("question_kwd", ARRAY(String(1024)), nullable=True, comment="questions"), Column("question_tks", TEXT, nullable=True, comment="question tokens"), Column("tag_kwd", ARRAY(String(256)), nullable=True, comment="tags"), Column("tag_feas", JSON, nullable=True, comment="tag features used for 'rank_feature', format: [tag -> relevance score]"), Column("available_int", Integer, nullable=False, index=True, server_default="1", comment="status of availability, 0 for unavailable, 1 for available"), Column("create_time", String(19), nullable=True, comment="creation time in YYYY-MM-DD HH:MM:SS format"), Column("create_timestamp_flt", Double, nullable=True, comment="creation timestamp in float format"), Column("img_id", String(128), nullable=True, comment="image id"), Column("position_int", ARRAY(ARRAY(Integer)), nullable=True, comment="position"), Column("page_num_int", ARRAY(Integer), nullable=True, comment="page number"), Column("top_int", ARRAY(Integer), nullable=True, comment="rank from the top"), Column("knowledge_graph_kwd", String(256), nullable=True, index=True, comment="knowledge graph chunk type"), Column("source_id", ARRAY(String(256)), nullable=True, comment="source document id"), Column("entity_kwd", String(256), nullable=True, comment="entity name"), Column("entity_type_kwd", String(256), nullable=True, index=True, comment="entity type"), Column("from_entity_kwd", String(256), nullable=True, comment="the source entity of this edge"), Column("to_entity_kwd", String(256), nullable=True, comment="the target entity of this edge"), Column("weight_int", Integer, nullable=True, comment="the weight of this edge"), Column("weight_flt", Double, nullable=True, comment="the weight of community report"), Column("entities_kwd", ARRAY(String(256)), nullable=True, comment="node ids of entities"), Column("rank_flt", Double, nullable=True, comment="rank of this entity"), Column("removed_kwd", String(256), nullable=True, index=True, server_default="'N'", comment="whether it has been deleted"), Column("metadata", JSON, nullable=True, comment="metadata for this chunk"), Column("extra", JSON, nullable=True, comment="extra information of non-general chunk"), column_order_id, column_group_id, ] column_names: list[str] = [col.name for col in column_definitions] column_types: dict[str, TypeEngine] = {col.name: col.type for col in column_definitions} array_columns: list[str] = [col.name for col in column_definitions if isinstance(col.type, ARRAY)] vector_column_pattern = re.compile(r"q_(?P\d+)_vec") index_columns: list[str] = [ "kb_id", "doc_id", "available_int", "knowledge_graph_kwd", "entity_type_kwd", "removed_kwd", ] fulltext_search_columns: list[str] = [ "docnm_kwd", "content_with_weight", "title_tks", "title_sm_tks", "important_tks", "question_tks", "content_ltks", "content_sm_ltks" ] fts_columns_origin: list[str] = [ "docnm_kwd^10", "content_with_weight", "important_tks^20", "question_tks^20", ] fts_columns_tks: list[str] = [ "title_tks^10", "title_sm_tks^5", "important_tks^20", "question_tks^20", "content_ltks^2", "content_sm_ltks", ] index_name_template = "ix_%s_%s" fulltext_index_name_template = "fts_idx_%s" # MATCH AGAINST: https://www.oceanbase.com/docs/common-oceanbase-database-cn-1000000002017607 fulltext_search_template = "MATCH (%s) AGAINST ('%s' IN NATURAL LANGUAGE MODE)" # cosine_distance: https://www.oceanbase.com/docs/common-oceanbase-database-cn-1000000002012938 vector_search_template = "cosine_distance(%s, %s)" class SearchResult(BaseModel): total: int chunks: list[dict] def get_column_value(column_name: str, value: Any) -> Any: if column_name in column_types: column_type = column_types[column_name] if isinstance(column_type, String): return str(value) elif isinstance(column_type, Integer): return int(value) elif isinstance(column_type, Double): return float(value) elif isinstance(column_type, ARRAY) or isinstance(column_type, JSON): if isinstance(value, str): try: return json.loads(value) except json.JSONDecodeError: return value else: return value else: raise ValueError(f"Unsupported column type for column '{column_name}': {column_type}") elif vector_column_pattern.match(column_name): if isinstance(value, str): try: return json.loads(value) except json.JSONDecodeError: return value else: return value elif column_name == "_score": return float(value) else: raise ValueError(f"Unknown column '{column_name}' with value '{value}'.") def get_default_value(column_name: str) -> Any: if column_name == "available_int": return 1 elif column_name == "removed_kwd": return "N" elif column_name == "_order_id": return 0 else: return None def get_value_str(value: Any) -> str: if isinstance(value, str): cleaned_str = value.replace('\\', '\\\\') cleaned_str = cleaned_str.replace('\n', '\\n') cleaned_str = cleaned_str.replace('\r', '\\r') cleaned_str = cleaned_str.replace('\t', '\\t') return f"'{escape_string(cleaned_str)}'" elif isinstance(value, bool): return "true" if value else "false" elif value is None: return "NULL" elif isinstance(value, (list, dict)): json_str = json.dumps(value, ensure_ascii=False) return f"'{escape_string(json_str)}'" else: return str(value) def get_metadata_filter_expression(metadata_filtering_conditions: dict) -> str: """ Convert metadata filtering conditions to MySQL JSON path expression. Args: metadata_filtering_conditions: dict with 'conditions' and 'logical_operator' keys Returns: MySQL JSON path expression string """ if not metadata_filtering_conditions: return "" conditions = metadata_filtering_conditions.get("conditions", []) logical_operator = metadata_filtering_conditions.get("logical_operator", "and").upper() if not conditions: return "" if logical_operator not in ["AND", "OR"]: raise ValueError(f"Unsupported logical operator: {logical_operator}. Only 'and' and 'or' are supported.") metadata_filters = [] for condition in conditions: name = condition.get("name") comparison_operator = condition.get("comparison_operator") value = condition.get("value") if not all([name, comparison_operator]): continue expr = f"JSON_EXTRACT(metadata, '$.{name}')" value_str = get_value_str(value) if value else "" # Convert comparison operator to MySQL JSON path syntax if comparison_operator == "is": # JSON_EXTRACT(metadata, '$.field_name') = 'value' metadata_filters.append(f"{expr} = {value_str}") elif comparison_operator == "is not": metadata_filters.append(f"{expr} != {value_str}") elif comparison_operator == "contains": metadata_filters.append(f"JSON_CONTAINS({expr}, {value_str})") elif comparison_operator == "not contains": metadata_filters.append(f"NOT JSON_CONTAINS({expr}, {value_str})") elif comparison_operator == "start with": metadata_filters.append(f"{expr} LIKE CONCAT({value_str}, '%')") elif comparison_operator == "end with": metadata_filters.append(f"{expr} LIKE CONCAT('%', {value_str})") elif comparison_operator == "empty": metadata_filters.append(f"({expr} IS NULL OR {expr} = '' OR {expr} = '[]' OR {expr} = '{{}}')") elif comparison_operator == "not empty": metadata_filters.append(f"({expr} IS NOT NULL AND {expr} != '' AND {expr} != '[]' AND {expr} != '{{}}')") # Number operators elif comparison_operator == "=": metadata_filters.append(f"CAST({expr} AS DECIMAL(20,10)) = {value_str}") elif comparison_operator == "≠": metadata_filters.append(f"CAST({expr} AS DECIMAL(20,10)) != {value_str}") elif comparison_operator == ">": metadata_filters.append(f"CAST({expr} AS DECIMAL(20,10)) > {value_str}") elif comparison_operator == "<": metadata_filters.append(f"CAST({expr} AS DECIMAL(20,10)) < {value_str}") elif comparison_operator == "≥": metadata_filters.append(f"CAST({expr} AS DECIMAL(20,10)) >= {value_str}") elif comparison_operator == "≤": metadata_filters.append(f"CAST({expr} AS DECIMAL(20,10)) <= {value_str}") # Time operators elif comparison_operator == "before": metadata_filters.append(f"CAST({expr} AS DATETIME) < {value_str}") elif comparison_operator == "after": metadata_filters.append(f"CAST({expr} AS DATETIME) > {value_str}") else: logger.warning(f"Unsupported comparison operator: {comparison_operator}") continue if not metadata_filters: return "" return f"({f' {logical_operator} '.join(metadata_filters)})" def get_filters(condition: dict) -> list[str]: filters: list[str] = [] for k, v in condition.items(): if not v: continue if k == "exists": filters.append(f"{v} IS NOT NULL") elif k == "must_not" and isinstance(v, dict) and "exists" in v: filters.append(f"{v.get('exists')} IS NULL") elif k == "metadata_filtering_conditions": # Handle metadata filtering conditions metadata_filter = get_metadata_filter_expression(v) if metadata_filter: filters.append(metadata_filter) elif k in array_columns: if isinstance(v, list): array_filters = [] for vv in v: array_filters.append(f"array_contains({k}, {get_value_str(vv)})") array_filter = " OR ".join(array_filters) filters.append(f"({array_filter})") else: filters.append(f"array_contains({k}, {get_value_str(v)})") elif isinstance(v, list): values: list[str] = [] for item in v: values.append(get_value_str(item)) value = ", ".join(values) filters.append(f"{k} IN ({value})") else: filters.append(f"{k} = {get_value_str(v)}") return filters def _try_with_lock(lock_name: str, process_func, check_func, timeout: int = None): if not timeout: timeout = int(os.environ.get("OB_DDL_TIMEOUT", "60")) if not check_func(): from rag.utils.redis_conn import RedisDistributedLock lock = RedisDistributedLock(lock_name) if lock.acquire(): logger.info(f"acquired lock success: {lock_name}, start processing.") try: process_func() return finally: lock.release() if not check_func(): logger.info(f"Waiting for process complete for {lock_name} on other task executors.") time.sleep(1) count = 1 while count < timeout and not check_func(): count += 1 time.sleep(1) if count >= timeout and not check_func(): raise Exception(f"Timeout to wait for process complete for {lock_name}.") @singleton class OBConnection(DocStoreConnection): def __init__(self): scheme: str = settings.OB.get("scheme") ob_config = settings.OB.get("config", {}) if scheme and scheme.lower() == "mysql": mysql_config = settings.get_base_config("mysql", {}) logger.info("Use MySQL scheme to create OceanBase connection.") host = mysql_config.get("host", "localhost") port = mysql_config.get("port", 2881) self.username = mysql_config.get("user", "root@test") self.password = mysql_config.get("password", "infini_rag_flow") else: logger.info("Use customized config to create OceanBase connection.") host = ob_config.get("host", "localhost") port = ob_config.get("port", 2881) self.username = ob_config.get("user", "root@test") self.password = ob_config.get("password", "infini_rag_flow") self.db_name = ob_config.get("db_name", "test") self.uri = f"{host}:{port}" logger.info(f"Use OceanBase '{self.uri}' as the doc engine.") for _ in range(ATTEMPT_TIME): try: self.client = ObVecClient( uri=self.uri, user=self.username, password=self.password, db_name=self.db_name, pool_pre_ping=True, pool_recycle=3600, ) break except Exception as e: logger.warning(f"{str(e)}. Waiting OceanBase {self.uri} to be healthy.") time.sleep(5) if self.client is None: msg = f"OceanBase {self.uri} connection failed after {ATTEMPT_TIME} attempts." logger.error(msg) raise Exception(msg) self._load_env_vars() self._check_ob_version() self._try_to_update_ob_query_timeout() logger.info(f"OceanBase {self.uri} is healthy.") def _check_ob_version(self): try: res = self.client.perform_raw_text_sql("SELECT OB_VERSION() FROM DUAL").fetchone() version_str = res[0] if res else None logger.info(f"OceanBase {self.uri} version is {version_str}") except Exception as e: raise Exception(f"Failed to get OceanBase version from {self.uri}, error: {str(e)}") if not version_str: raise Exception(f"Failed to get OceanBase version from {self.uri}.") ob_version = ObVersion.from_db_version_string(version_str) if ob_version < ObVersion.from_db_version_nums(4, 3, 5, 1): raise Exception( f"The version of OceanBase needs to be higher than or equal to 4.3.5.1, current version is {version_str}" ) self.es = None if not ob_version < ObVersion.from_db_version_nums(4, 4, 1, 0) and self.enable_hybrid_search: self.es = HybridSearch( uri=self.uri, user=self.username, password=self.password, db_name=self.db_name, pool_pre_ping=True, pool_recycle=3600, ) logger.info("OceanBase Hybrid Search feature is enabled") def _try_to_update_ob_query_timeout(self): try: val = self._get_variable_value("ob_query_timeout") if val and int(val) >= OB_QUERY_TIMEOUT: return except Exception as e: logger.warning("Failed to get 'ob_query_timeout' variable: %s", str(e)) try: self.client.perform_raw_text_sql(f"SET GLOBAL ob_query_timeout={OB_QUERY_TIMEOUT}") logger.info("Set GLOBAL variable 'ob_query_timeout' to %d.", OB_QUERY_TIMEOUT) # refresh connection pool to ensure 'ob_query_timeout' has taken effect self.client.engine.dispose() if self.es is not None: self.es.engine.dispose() logger.info("Disposed all connections in engine pool to refresh connection pool") except Exception as e: logger.warning(f"Failed to set 'ob_query_timeout' variable: {str(e)}") def _load_env_vars(self): def is_true(var: str, default: str) -> bool: return os.getenv(var, default).lower() in ['true', '1', 'yes', 'y'] self.enable_fulltext_search = is_true('ENABLE_FULLTEXT_SEARCH', 'true') self.use_fulltext_hint = is_true('USE_FULLTEXT_HINT', 'true') self.search_original_content = is_true("SEARCH_ORIGINAL_CONTENT", 'true') self.enable_hybrid_search = is_true('ENABLE_HYBRID_SEARCH', 'false') """ Database operations """ def dbType(self) -> str: return "oceanbase" def health(self) -> dict: return { "uri": self.uri, "version_comment": self._get_variable_value("version_comment") } def _get_variable_value(self, var_name: str) -> Any: rows = self.client.perform_raw_text_sql(f"SHOW VARIABLES LIKE '{var_name}'") for row in rows: return row[1] raise Exception(f"Variable '{var_name}' not found.") """ Table operations """ def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int): vector_field_name = f"q_{vectorSize}_vec" vector_index_name = f"{vector_field_name}_idx" try: _try_with_lock( lock_name=f"ob_create_table_{indexName}", check_func=lambda: self.client.check_table_exists(indexName), process_func=lambda: self._create_table(indexName), ) for column_name in index_columns: _try_with_lock( lock_name=f"ob_add_idx_{indexName}_{column_name}", check_func=lambda: self._index_exists(indexName, index_name_template % (indexName, column_name)), process_func=lambda: self._add_index(indexName, column_name), ) fts_columns = fts_columns_origin if self.search_original_content else fts_columns_tks for fts_column in fts_columns: column_name = fts_column.split("^")[0] _try_with_lock( lock_name=f"ob_add_fulltext_idx_{indexName}_{column_name}", check_func=lambda: self._index_exists(indexName, fulltext_index_name_template % column_name), process_func=lambda: self._add_fulltext_index(indexName, column_name), ) _try_with_lock( lock_name=f"ob_add_vector_column_{indexName}_{vector_field_name}", check_func=lambda: self._column_exist(indexName, vector_field_name), process_func=lambda: self._add_vector_column(indexName, vectorSize), ) _try_with_lock( lock_name=f"ob_add_vector_idx_{indexName}_{vector_field_name}", check_func=lambda: self._index_exists(indexName, vector_index_name), process_func=lambda: self._add_vector_index(indexName, vector_field_name), ) # new columns migration for column in [column_order_id, column_group_id]: _try_with_lock( lock_name=f"ob_add_{column.name}_{indexName}", check_func=lambda: self._column_exist(indexName, column.name), process_func=lambda: self._add_column(indexName, column), ) except Exception as e: raise Exception(f"OBConnection.createIndex error: {str(e)}") finally: # always refresh metadata to make sure it contains the latest table structure self.client.refresh_metadata([indexName]) def deleteIdx(self, indexName: str, knowledgebaseId: str): if len(knowledgebaseId) > 0: # The index need to be alive after any kb deletion since all kb under this tenant are in one index. return try: if self.client.check_table_exists(table_name=indexName): self.client.drop_table_if_exist(indexName) logger.info(f"Dropped table '{indexName}'.") except Exception as e: raise Exception(f"OBConnection.deleteIndex error: {str(e)}") def indexExist(self, indexName: str, knowledgebaseId: str = None) -> bool: try: if not self.client.check_table_exists(indexName): return False for column_name in index_columns: if not self._index_exists(indexName, index_name_template % (indexName, column_name)): return False fts_columns = fts_columns_origin if self.search_original_content else fts_columns_tks for fts_column in fts_columns: column_name = fts_column.split("^")[0] if not self._index_exists(indexName, fulltext_index_name_template % column_name): return False for column in [column_order_id, column_group_id]: if not self._column_exist(indexName, column.name): return False except Exception as e: raise Exception(f"OBConnection.indexExist error: {str(e)}") return True def _get_count(self, table_name: str, filter_list: list[str] = None) -> int: where_clause = "WHERE " + " AND ".join(filter_list) if len(filter_list) > 0 else "" (count,) = self.client.perform_raw_text_sql( f"SELECT COUNT(*) FROM {table_name} {where_clause}" ).fetchone() return count def _column_exist(self, table_name: str, column_name: str) -> bool: return self._get_count( table_name="INFORMATION_SCHEMA.COLUMNS", filter_list=[ f"TABLE_SCHEMA = '{self.db_name}'", f"TABLE_NAME = '{table_name}'", f"COLUMN_NAME = '{column_name}'", ]) > 0 def _index_exists(self, table_name: str, index_name: str) -> bool: return self._get_count( table_name="INFORMATION_SCHEMA.STATISTICS", filter_list=[ f"TABLE_SCHEMA = '{self.db_name}'", f"TABLE_NAME = '{table_name}'", f"INDEX_NAME = '{index_name}'", ]) > 0 def _create_table(self, table_name: str): # remove outdated metadata for external changes if table_name in self.client.metadata_obj.tables: self.client.metadata_obj.remove(Table(table_name, self.client.metadata_obj)) table_options = { "mysql_charset": "utf8mb4", "mysql_collate": "utf8mb4_unicode_ci", "mysql_organization": "heap", } self.client.create_table( table_name=table_name, columns=column_definitions, **table_options, ) logger.info(f"Created table '{table_name}'.") def _add_index(self, table_name: str, column_name: str): index_name = index_name_template % (table_name, column_name) self.client.create_index( table_name=table_name, is_vec_index=False, index_name=index_name, column_names=[column_name], ) logger.info(f"Created index '{index_name}' on table '{table_name}'.") def _add_fulltext_index(self, table_name: str, column_name: str): fulltext_index_name = fulltext_index_name_template % column_name self.client.create_fts_idx_with_fts_index_param( table_name=table_name, fts_idx_param=FtsIndexParam( index_name=fulltext_index_name, field_names=[column_name], parser_type=FtsParser.IK, ), ) logger.info(f"Created full text index '{fulltext_index_name}' on table '{table_name}'.") def _add_vector_column(self, table_name: str, vector_size: int): vector_field_name = f"q_{vector_size}_vec" self.client.add_columns( table_name=table_name, columns=[Column(vector_field_name, VECTOR(vector_size), nullable=True)], ) logger.info(f"Added vector column '{vector_field_name}' to table '{table_name}'.") def _add_vector_index(self, table_name: str, vector_field_name: str): vector_index_name = f"{vector_field_name}_idx" self.client.create_index( table_name=table_name, is_vec_index=True, index_name=vector_index_name, column_names=[vector_field_name], vidx_params="distance=cosine, type=hnsw, lib=vsag", ) logger.info( f"Created vector index '{vector_index_name}' on table '{table_name}' with column '{vector_field_name}'." ) def _add_column(self, table_name: str, column: Column): try: self.client.add_columns( table_name=table_name, columns=[column], ) logger.info(f"Added column '{column.name}' to table '{table_name}'.") except Exception as e: logger.warning(f"Failed to add column '{column.name}' to table '{table_name}': {str(e)}") """ CRUD operations """ def search( self, selectFields: list[str], highlightFields: list[str], condition: dict, matchExprs: list[MatchExpr], orderBy: OrderByExpr, offset: int, limit: int, indexNames: str | list[str], knowledgebaseIds: list[str], aggFields: list[str] = [], rank_feature: dict | None = None, **kwargs, ): if isinstance(indexNames, str): indexNames = indexNames.split(",") assert isinstance(indexNames, list) and len(indexNames) > 0 indexNames = list(set(indexNames)) if len(matchExprs) == 3: if not self.enable_fulltext_search: # disable fulltext search in fusion search, which means fallback to vector search matchExprs = [m for m in matchExprs if isinstance(m, MatchDenseExpr)] else: for m in matchExprs: if isinstance(m, FusionExpr): weights = m.fusion_params["weights"] vector_similarity_weight = get_float(weights.split(",")[1]) # skip the search if its weight is zero if vector_similarity_weight <= 0.0: matchExprs = [m for m in matchExprs if isinstance(m, MatchTextExpr)] elif vector_similarity_weight >= 1.0: matchExprs = [m for m in matchExprs if isinstance(m, MatchDenseExpr)] result: SearchResult = SearchResult( total=0, chunks=[], ) # copied from es_conn.py if len(matchExprs) == 3 and self.es: bqry = Q("bool", must=[]) condition["kb_id"] = knowledgebaseIds for k, v in condition.items(): if k == "available_int": if v == 0: bqry.filter.append(Q("range", available_int={"lt": 1})) else: bqry.filter.append( Q("bool", must_not=Q("range", available_int={"lt": 1}))) continue if not v: continue if isinstance(v, list): bqry.filter.append(Q("terms", **{k: v})) elif isinstance(v, str) or isinstance(v, int): bqry.filter.append(Q("term", **{k: v})) else: raise Exception( f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.") s = Search() vector_similarity_weight = 0.5 for m in matchExprs: if isinstance(m, FusionExpr) and m.method == "weighted_sum" and "weights" in m.fusion_params: assert len(matchExprs) == 3 and isinstance(matchExprs[0], MatchTextExpr) and isinstance( matchExprs[1], MatchDenseExpr) and isinstance( matchExprs[2], FusionExpr) weights = m.fusion_params["weights"] vector_similarity_weight = get_float(weights.split(",")[1]) for m in matchExprs: if isinstance(m, MatchTextExpr): minimum_should_match = m.extra_options.get("minimum_should_match", 0.0) if isinstance(minimum_should_match, float): minimum_should_match = str(int(minimum_should_match * 100)) + "%" bqry.must.append(Q("query_string", fields=fts_columns_tks, type="best_fields", query=m.matching_text, minimum_should_match=minimum_should_match, boost=1)) bqry.boost = 1.0 - vector_similarity_weight elif isinstance(m, MatchDenseExpr): assert (bqry is not None) similarity = 0.0 if "similarity" in m.extra_options: similarity = m.extra_options["similarity"] s = s.knn(m.vector_column_name, m.topn, m.topn * 2, query_vector=list(m.embedding_data), filter=bqry.to_dict(), similarity=similarity, ) if bqry and rank_feature: for fld, sc in rank_feature.items(): if fld != PAGERANK_FLD: fld = f"{TAG_FLD}.{fld}" bqry.should.append(Q("rank_feature", field=fld, linear={}, boost=sc)) if bqry: s = s.query(bqry) # for field in highlightFields: # s = s.highlight(field) if orderBy: orders = list() for field, order in orderBy.fields: order = "asc" if order == 0 else "desc" if field in ["page_num_int", "top_int"]: order_info = {"order": order, "unmapped_type": "float", "mode": "avg", "numeric_type": "double"} elif field.endswith("_int") or field.endswith("_flt"): order_info = {"order": order, "unmapped_type": "float"} else: order_info = {"order": order, "unmapped_type": "text"} orders.append({field: order_info}) s = s.sort(*orders) for fld in aggFields: s.aggs.bucket(f'aggs_{fld}', 'terms', field=fld, size=1000000) if limit > 0: s = s[offset:offset + limit] q = s.to_dict() logger.debug(f"OBConnection.hybrid_search {str(indexNames)} query: " + json.dumps(q)) for index_name in indexNames: start_time = time.time() res = self.es.search(index=index_name, body=q, timeout="600s", track_total_hits=True, _source=True) elapsed_time = time.time() - start_time logger.info( f"OBConnection.search table {index_name}, search type: hybrid, elapsed time: {elapsed_time:.3f} seconds," f" got count: {len(res)}" ) for chunk in res: result.chunks.append(self._es_row_to_entity(chunk)) result.total = result.total + 1 return result output_fields = selectFields.copy() if "id" not in output_fields: output_fields = ["id"] + output_fields if "_score" in output_fields: output_fields.remove("_score") if highlightFields: for field in highlightFields: if field not in output_fields: output_fields.append(field) fields_expr = ", ".join(output_fields) condition["kb_id"] = knowledgebaseIds filters: list[str] = get_filters(condition) filters_expr = " AND ".join(filters) fulltext_query: Optional[str] = None fulltext_topn: Optional[int] = None fulltext_search_weight: dict[str, float] = {} fulltext_search_expr: dict[str, str] = {} fulltext_search_idx_list: list[str] = [] fulltext_search_score_expr: Optional[str] = None fulltext_search_filter: Optional[str] = None vector_column_name: Optional[str] = None vector_data: Optional[list[float]] = None vector_topn: Optional[int] = None vector_similarity_threshold: Optional[float] = None vector_similarity_weight: Optional[float] = None vector_search_expr: Optional[str] = None vector_search_score_expr: Optional[str] = None vector_search_filter: Optional[str] = None for m in matchExprs: if isinstance(m, MatchTextExpr): assert "original_query" in m.extra_options, "'original_query' is missing in extra_options." fulltext_query = m.extra_options["original_query"] fulltext_query = escape_string(fulltext_query.strip()) fulltext_topn = m.topn fts_columns = fts_columns_origin if self.search_original_content else fts_columns_tks # get fulltext match expression and weight values for field in fts_columns: parts = field.split("^") column_name: str = parts[0] column_weight: float = float(parts[1]) if (len(parts) > 1 and parts[1]) else 1.0 fulltext_search_weight[column_name] = column_weight fulltext_search_expr[column_name] = fulltext_search_template % (column_name, fulltext_query) fulltext_search_idx_list.append(fulltext_index_name_template % column_name) # adjust the weight to 0~1 weight_sum = sum(fulltext_search_weight.values()) for column_name in fulltext_search_weight.keys(): fulltext_search_weight[column_name] = fulltext_search_weight[column_name] / weight_sum elif isinstance(m, MatchDenseExpr): assert m.embedding_data_type == "float", f"embedding data type '{m.embedding_data_type}' is not float." vector_column_name = m.vector_column_name vector_data = m.embedding_data vector_topn = m.topn vector_similarity_threshold = m.extra_options.get("similarity", 0.0) elif isinstance(m, FusionExpr): weights = m.fusion_params["weights"] vector_similarity_weight = get_float(weights.split(",")[1]) if fulltext_query: fulltext_search_filter = f"({' OR '.join([expr for expr in fulltext_search_expr.values()])})" fulltext_search_score_expr = f"({' + '.join(f'{expr} * {fulltext_search_weight.get(col, 0)}' for col, expr in fulltext_search_expr.items())})" if vector_data: vector_search_expr = vector_search_template % (vector_column_name, vector_data) # use (1 - cosine_distance) as score, which should be [-1, 1] # https://www.oceanbase.com/docs/common-oceanbase-database-standalone-1000000003577323 vector_search_score_expr = f"(1 - {vector_search_expr})" vector_search_filter = f"{vector_search_score_expr} >= {vector_similarity_threshold}" pagerank_score_expr = f"(CAST(IFNULL({PAGERANK_FLD}, 0) AS DECIMAL(10, 2)) / 100)" # TODO use tag rank_feature in sorting # tag_rank_fea = {k: float(v) for k, v in (rank_feature or {}).items() if k != PAGERANK_FLD} if fulltext_query and vector_data: search_type = "fusion" elif fulltext_query: search_type = "fulltext" elif vector_data: search_type = "vector" elif len(aggFields) > 0: search_type = "aggregation" else: search_type = "filter" if search_type in ["fusion", "fulltext", "vector"] and "_score" not in output_fields: output_fields.append("_score") group_results = kwargs.get("group_results", False) for index_name in indexNames: if not self.client.check_table_exists(index_name): continue fulltext_search_hint = f"/*+ UNION_MERGE({index_name} {' '.join(fulltext_search_idx_list)}) */" if self.use_fulltext_hint else "" if search_type == "fusion": # fusion search, usually for chat num_candidates = vector_topn + fulltext_topn if group_results: count_sql = ( f"WITH fulltext_results AS (" f" SELECT {fulltext_search_hint} *, {fulltext_search_score_expr} AS relevance" f" FROM {index_name}" f" WHERE {filters_expr} AND {fulltext_search_filter}" f" ORDER BY relevance DESC" f" LIMIT {num_candidates}" f")," f" scored_results AS (" f" SELECT *" f" FROM fulltext_results" f" WHERE {vector_search_filter}" f")," f" group_results AS (" f" SELECT *, ROW_NUMBER() OVER (PARTITION BY group_id) as rn" f" FROM scored_results" f")" f" SELECT COUNT(*)" f" FROM group_results" f" WHERE rn = 1" ) else: count_sql = ( f"WITH fulltext_results AS (" f" SELECT {fulltext_search_hint} *, {fulltext_search_score_expr} AS relevance" f" FROM {index_name}" f" WHERE {filters_expr} AND {fulltext_search_filter}" f" ORDER BY relevance DESC" f" LIMIT {num_candidates}" f")" f" SELECT COUNT(*) FROM fulltext_results WHERE {vector_search_filter}" ) logger.debug("OBConnection.search with count sql: %s", count_sql) start_time = time.time() res = self.client.perform_raw_text_sql(count_sql) total_count = res.fetchone()[0] if res else 0 result.total += total_count elapsed_time = time.time() - start_time logger.info( f"OBConnection.search table {index_name}, search type: fusion, step: 1-count, elapsed time: {elapsed_time:.3f} seconds," f" vector column: '{vector_column_name}'," f" query text: '{fulltext_query}'," f" condition: '{condition}'," f" vector_similarity_threshold: {vector_similarity_threshold}," f" got count: {total_count}" ) if total_count == 0: continue score_expr = f"(relevance * {1 - vector_similarity_weight} + {vector_search_score_expr} * {vector_similarity_weight} + {pagerank_score_expr})" if group_results: fusion_sql = ( f"WITH fulltext_results AS (" f" SELECT {fulltext_search_hint} *, {fulltext_search_score_expr} AS relevance" f" FROM {index_name}" f" WHERE {filters_expr} AND {fulltext_search_filter}" f" ORDER BY relevance DESC" f" LIMIT {num_candidates}" f")," f" scored_results AS (" f" SELECT *, {score_expr} AS _score" f" FROM fulltext_results" f" WHERE {vector_search_filter}" f")," f" group_results AS (" f" SELECT *, ROW_NUMBER() OVER (PARTITION BY group_id ORDER BY _score DESC) as rn" f" FROM scored_results" f")" f" SELECT {fields_expr}, _score" f" FROM group_results" f" WHERE rn = 1" f" ORDER BY _score DESC" f" LIMIT {offset}, {limit}" ) else: fusion_sql = ( f"WITH fulltext_results AS (" f" SELECT {fulltext_search_hint} *, {fulltext_search_score_expr} AS relevance" f" FROM {index_name}" f" WHERE {filters_expr} AND {fulltext_search_filter}" f" ORDER BY relevance DESC" f" LIMIT {num_candidates}" f")" f" SELECT {fields_expr}, {score_expr} AS _score" f" FROM fulltext_results" f" WHERE {vector_search_filter}" f" ORDER BY _score DESC" f" LIMIT {offset}, {limit}" ) logger.debug("OBConnection.search with fusion sql: %s", fusion_sql) start_time = time.time() res = self.client.perform_raw_text_sql(fusion_sql) rows = res.fetchall() elapsed_time = time.time() - start_time logger.info( f"OBConnection.search table {index_name}, search type: fusion, step: 2-query, elapsed time: {elapsed_time:.3f} seconds," f" select fields: '{output_fields}'," f" vector column: '{vector_column_name}'," f" query text: '{fulltext_query}'," f" condition: '{condition}'," f" vector_similarity_threshold: {vector_similarity_threshold}," f" vector_similarity_weight: {vector_similarity_weight}," f" return rows count: {len(rows)}" ) for row in rows: result.chunks.append(self._row_to_entity(row, output_fields)) elif search_type == "vector": # vector search, usually used for graph search count_sql = f"SELECT COUNT(id) FROM {index_name} WHERE {filters_expr} AND {vector_search_filter}" logger.debug("OBConnection.search with vector count sql: %s", count_sql) start_time = time.time() res = self.client.perform_raw_text_sql(count_sql) total_count = res.fetchone()[0] if res else 0 result.total += total_count elapsed_time = time.time() - start_time logger.info( f"OBConnection.search table {index_name}, search type: vector, step: 1-count, elapsed time: {elapsed_time:.3f} seconds," f" vector column: '{vector_column_name}'," f" condition: '{condition}'," f" vector_similarity_threshold: {vector_similarity_threshold}," f" got count: {total_count}" ) if total_count == 0: continue vector_sql = ( f"SELECT {fields_expr}, {vector_search_score_expr} AS _score" f" FROM {index_name}" f" WHERE {filters_expr} AND {vector_search_filter}" f" ORDER BY {vector_search_expr}" f" APPROXIMATE LIMIT {limit if limit != 0 else vector_topn}" ) if offset != 0: vector_sql += f" OFFSET {offset}" logger.debug("OBConnection.search with vector sql: %s", vector_sql) start_time = time.time() res = self.client.perform_raw_text_sql(vector_sql) rows = res.fetchall() elapsed_time = time.time() - start_time logger.info( f"OBConnection.search table {index_name}, search type: vector, step: 2-query, elapsed time: {elapsed_time:.3f} seconds," f" select fields: '{output_fields}'," f" vector column: '{vector_column_name}'," f" condition: '{condition}'," f" vector_similarity_threshold: {vector_similarity_threshold}," f" return rows count: {len(rows)}" ) for row in rows: result.chunks.append(self._row_to_entity(row, output_fields)) elif search_type == "fulltext": # fulltext search, usually used to search chunks in one dataset count_sql = f"SELECT {fulltext_search_hint} COUNT(id) FROM {index_name} WHERE {filters_expr} AND {fulltext_search_filter}" logger.debug("OBConnection.search with fulltext count sql: %s", count_sql) start_time = time.time() res = self.client.perform_raw_text_sql(count_sql) total_count = res.fetchone()[0] if res else 0 result.total += total_count elapsed_time = time.time() - start_time logger.info( f"OBConnection.search table {index_name}, search type: fulltext, step: 1-count, elapsed time: {elapsed_time:.3f} seconds," f" query text: '{fulltext_query}'," f" condition: '{condition}'," f" got count: {total_count}" ) if total_count == 0: continue fulltext_sql = ( f"SELECT {fulltext_search_hint} {fields_expr}, {fulltext_search_score_expr} AS _score" f" FROM {index_name}" f" WHERE {filters_expr} AND {fulltext_search_filter}" f" ORDER BY _score DESC" f" LIMIT {offset}, {limit if limit != 0 else fulltext_topn}" ) logger.debug("OBConnection.search with fulltext sql: %s", fulltext_sql) start_time = time.time() res = self.client.perform_raw_text_sql(fulltext_sql) rows = res.fetchall() elapsed_time = time.time() - start_time logger.info( f"OBConnection.search table {index_name}, search type: fulltext, step: 2-query, elapsed time: {elapsed_time:.3f} seconds," f" select fields: '{output_fields}'," f" query text: '{fulltext_query}'," f" condition: '{condition}'," f" return rows count: {len(rows)}" ) for row in rows: result.chunks.append(self._row_to_entity(row, output_fields)) elif search_type == "aggregation": # aggregation search assert len(aggFields) == 1, "Only one aggregation field is supported in OceanBase." agg_field = aggFields[0] if agg_field in array_columns: res = self.client.perform_raw_text_sql( f"SELECT {agg_field} FROM {index_name}" f" WHERE {agg_field} IS NOT NULL AND {filters_expr}" ) counts = {} for row in res: if row[0]: if isinstance(row[0], str): try: arr = json.loads(row[0]) except json.JSONDecodeError: logger.warning(f"Failed to parse JSON array: {row[0]}") continue else: arr = row[0] if isinstance(arr, list): for v in arr: if isinstance(v, str) and v.strip(): counts[v] = counts.get(v, 0) + 1 for v, count in counts.items(): result.chunks.append({ "value": v, "count": count, }) result.total += len(counts) else: res = self.client.perform_raw_text_sql( f"SELECT {agg_field}, COUNT(*) as count FROM {index_name}" f" WHERE {agg_field} IS NOT NULL AND {filters_expr}" f" GROUP BY {agg_field}" ) for row in res: result.chunks.append({ "value": row[0], "count": int(row[1]), }) result.total += 1 else: # only filter orders: list[str] = [] if orderBy: for field, order in orderBy.fields: if isinstance(column_types[field], ARRAY): f = field + "_sort" fields_expr += f", array_to_string({field}, ',') AS {f}" field = f order = "ASC" if order == 0 else "DESC" orders.append(f"{field} {order}") count_sql = f"SELECT COUNT(id) FROM {index_name} WHERE {filters_expr}" logger.debug("OBConnection.search with normal count sql: %s", count_sql) start_time = time.time() res = self.client.perform_raw_text_sql(count_sql) total_count = res.fetchone()[0] if res else 0 result.total += total_count elapsed_time = time.time() - start_time logger.info( f"OBConnection.search table {index_name}, search type: normal, step: 1-count, elapsed time: {elapsed_time:.3f} seconds," f" condition: '{condition}'," f" got count: {total_count}" ) if total_count == 0: continue order_by_expr = ("ORDER BY " + ", ".join(orders)) if len(orders) > 0 else "" limit_expr = f"LIMIT {offset}, {limit}" if limit != 0 else "" filter_sql = ( f"SELECT {fields_expr}" f" FROM {index_name}" f" WHERE {filters_expr}" f" {order_by_expr} {limit_expr}" ) logger.debug("OBConnection.search with normal sql: %s", filter_sql) start_time = time.time() res = self.client.perform_raw_text_sql(filter_sql) rows = res.fetchall() elapsed_time = time.time() - start_time logger.info( f"OBConnection.search table {index_name}, search type: normal, step: 2-query, elapsed time: {elapsed_time:.3f} seconds," f" select fields: '{output_fields}'," f" condition: '{condition}'," f" return rows count: {len(rows)}" ) for row in rows: result.chunks.append(self._row_to_entity(row, output_fields)) return result def get(self, chunkId: str, indexName: str, knowledgebaseIds: list[str]) -> dict | None: if not self.client.check_table_exists(indexName): return None try: res = self.client.get( table_name=indexName, ids=[chunkId], ) row = res.fetchone() if row is None: raise Exception(f"ChunkId {chunkId} not found in index {indexName}.") return self._row_to_entity(row, fields=list(res.keys())) except json.JSONDecodeError as e: logger.error(f"JSON decode error when getting chunk {chunkId}: {str(e)}") return { "id": chunkId, "error": f"Failed to parse chunk data due to invalid JSON: {str(e)}" } except Exception as e: logger.error(f"Error getting chunk {chunkId}: {str(e)}") raise def insert(self, documents: list[dict], indexName: str, knowledgebaseId: str = None) -> list[str]: if not documents: return [] docs: list[dict] = [] ids: list[str] = [] for document in documents: d: dict = {} for k, v in document.items(): if vector_column_pattern.match(k): d[k] = v continue if k not in column_names: if "extra" not in d: d["extra"] = {} d["extra"][k] = v continue if v is None: d[k] = get_default_value(k) continue if k == "kb_id" and isinstance(v, list): d[k] = v[0] elif k == "content_with_weight" and isinstance(v, dict): d[k] = json.dumps(v, ensure_ascii=False) elif k == "position_int": d[k] = json.dumps([list(vv) for vv in v], ensure_ascii=False) elif isinstance(v, list): # remove characters like '\t' for JSON dump and clean special characters cleaned_v = [] for vv in v: if isinstance(vv, str): cleaned_str = vv.strip() cleaned_str = cleaned_str.replace('\\', '\\\\') cleaned_str = cleaned_str.replace('\n', '\\n') cleaned_str = cleaned_str.replace('\r', '\\r') cleaned_str = cleaned_str.replace('\t', '\\t') cleaned_v.append(cleaned_str) else: cleaned_v.append(vv) d[k] = json.dumps(cleaned_v, ensure_ascii=False) else: d[k] = v ids.append(d["id"]) # this is to fix https://github.com/sqlalchemy/sqlalchemy/issues/9703 for column_name in column_names: if column_name not in d: d[column_name] = get_default_value(column_name) metadata = d.get("metadata", {}) if metadata is None: metadata = {} group_id = metadata.get("_group_id") title = metadata.get("_title") if d.get("doc_id"): if group_id: d["group_id"] = group_id else: d["group_id"] = d["doc_id"] if title: d["docnm_kwd"] = title docs.append(d) logger.debug("OBConnection.insert chunks: %s", docs) res = [] try: self.client.upsert(indexName, docs) except Exception as e: logger.error(f"OBConnection.insert error: {str(e)}") res.append(str(e)) return res def update(self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str) -> bool: if not self.client.check_table_exists(indexName): return True condition["kb_id"] = knowledgebaseId filters = get_filters(condition) set_values: list[str] = [] for k, v in newValue.items(): if k == "remove": if isinstance(v, str): set_values.append(f"{v} = NULL") else: assert isinstance(v, dict), f"Expected str or dict for 'remove', got {type(newValue[k])}." for kk, vv in v.items(): assert kk in array_columns, f"Column '{kk}' is not an array column." set_values.append(f"{kk} = array_remove({kk}, {get_value_str(vv)})") elif k == "add": assert isinstance(v, dict), f"Expected str or dict for 'add', got {type(newValue[k])}." for kk, vv in v.items(): assert kk in array_columns, f"Column '{kk}' is not an array column." set_values.append(f"{kk} = array_append({kk}, {get_value_str(vv)})") elif k == "metadata": assert isinstance(v, dict), f"Expected dict for 'metadata', got {type(newValue[k])}" set_values.append(f"{k} = {get_value_str(v)}") if v and "doc_id" in condition: group_id = v.get("_group_id") title = v.get("_title") if group_id: set_values.append(f"group_id = {get_value_str(group_id)}") if title: set_values.append(f"docnm_kwd = {get_value_str(title)}") else: set_values.append(f"{k} = {get_value_str(v)}") if not set_values: return True update_sql = ( f"UPDATE {indexName}" f" SET {', '.join(set_values)}" f" WHERE {' AND '.join(filters)}" ) logger.debug("OBConnection.update sql: %s", update_sql) try: self.client.perform_raw_text_sql(update_sql) return True except Exception as e: logger.error(f"OBConnection.update error: {str(e)}") return False def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int: if not self.client.check_table_exists(indexName): return 0 condition["kb_id"] = knowledgebaseId try: res = self.client.get( table_name=indexName, ids=None, where_clause=[text(f) for f in get_filters(condition)], output_column_name=["id"], ) rows = res.fetchall() if len(rows) == 0: return 0 ids = [row[0] for row in rows] logger.debug(f"OBConnection.delete chunks, filters: {condition}, ids: {ids}") self.client.delete( table_name=indexName, ids=ids, ) return len(ids) except Exception as e: logger.error(f"OBConnection.delete error: {str(e)}") return 0 @staticmethod def _row_to_entity(data: Row, fields: list[str]) -> dict: entity = {} for i, field in enumerate(fields): value = data[i] if value is None: continue entity[field] = get_column_value(field, value) return entity @staticmethod def _es_row_to_entity(data: dict) -> dict: entity = {} for k, v in data.items(): if v is None: continue entity[k] = get_column_value(k, v) return entity """ Helper functions for search result """ def get_total(self, res) -> int: return res.total def get_chunk_ids(self, res) -> list[str]: return [row["id"] for row in res.chunks] def get_fields(self, res, fields: list[str]) -> dict[str, dict]: result = {} for row in res.chunks: data = {} for field in fields: v = row.get(field) if v is not None: data[field] = v result[row["id"]] = data return result # copied from query.FulltextQueryer def is_chinese(self, line): arr = re.split(r"[ \t]+", line) if len(arr) <= 3: return True e = 0 for t in arr: if not re.match(r"[a-zA-Z]+$", t): e += 1 return e * 1.0 / len(arr) >= 0.7 def highlight(self, txt: str, tks: str, question: str, keywords: list[str]) -> Optional[str]: if not txt or not keywords: return None highlighted_txt = txt if question and not self.is_chinese(question): highlighted_txt = re.sub( r"(^|\W)(%s)(\W|$)" % re.escape(question), r"\1\2\3", highlighted_txt, flags=re.IGNORECASE | re.MULTILINE, ) if re.search(r"[^<>]+", highlighted_txt, flags=re.IGNORECASE | re.MULTILINE): return highlighted_txt for keyword in keywords: highlighted_txt = re.sub( r"(^|\W)(%s)(\W|$)" % re.escape(keyword), r"\1\2\3", highlighted_txt, flags=re.IGNORECASE | re.MULTILINE, ) if len(re.findall(r'', highlighted_txt)) > 0 or len( re.findall(r'\s*', highlighted_txt)) > 0: return highlighted_txt else: return None if not tks: tks = rag_tokenizer.tokenize(txt) tokens = tks.split() if not tokens: return None last_pos = len(txt) for i in range(len(tokens) - 1, -1, -1): token = tokens[i] token_pos = highlighted_txt.rfind(token, 0, last_pos) if token_pos != -1: if token in keywords: highlighted_txt = ( highlighted_txt[:token_pos] + f'{token}' + highlighted_txt[token_pos + len(token):] ) last_pos = token_pos return re.sub(r'', '', highlighted_txt) def get_highlight(self, res, keywords: list[str], fieldnm: str): ans = {} if len(res.chunks) == 0 or len(keywords) == 0: return ans for d in res.chunks: txt = d.get(fieldnm) if not txt: continue tks = d.get("content_ltks") if fieldnm == "content_with_weight" else "" highlighted_txt = self.highlight(txt, tks, " ".join(keywords), keywords) if highlighted_txt: ans[d["id"]] = highlighted_txt return ans def get_aggregation(self, res, fieldnm: str): if len(res.chunks) == 0: return [] counts = {} result = [] for d in res.chunks: if "value" in d and "count" in d: # directly use the aggregation result result.append((d["value"], d["count"])) elif fieldnm in d: # aggregate the values of specific field v = d[fieldnm] if isinstance(v, list): for vv in v: if isinstance(vv, str) and vv.strip(): counts[vv] = counts.get(vv, 0) + 1 elif isinstance(v, str) and v.strip(): counts[v] = counts.get(v, 0) + 1 if len(counts) > 0: for k, v in counts.items(): result.append((k, v)) return result """ SQL """ def sql(sql: str, fetch_size: int, format: str): # TODO: execute the sql generated by text-to-sql return None