mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-25 08:06:48 +08:00
Feat: message manage (#12083)
### What problem does this PR solve? Message CRUD. Issue #4213 ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@ -77,9 +77,9 @@ class Benchmark:
|
||||
def init_index(self, vector_size: int):
|
||||
if self.initialized_index:
|
||||
return
|
||||
if settings.docStoreConn.indexExist(self.index_name, self.kb_id):
|
||||
settings.docStoreConn.deleteIdx(self.index_name, self.kb_id)
|
||||
settings.docStoreConn.createIdx(self.index_name, self.kb_id, vector_size)
|
||||
if settings.docStoreConn.index_exist(self.index_name, self.kb_id):
|
||||
settings.docStoreConn.delete_idx(self.index_name, self.kb_id)
|
||||
settings.docStoreConn.create_idx(self.index_name, self.kb_id, vector_size)
|
||||
self.initialized_index = True
|
||||
|
||||
def ms_marco_index(self, file_path, index_name):
|
||||
|
||||
@ -19,11 +19,12 @@ import json
|
||||
import re
|
||||
from collections import defaultdict
|
||||
|
||||
from rag.utils.doc_store_conn import MatchTextExpr
|
||||
from common.query_base import QueryBase
|
||||
from common.doc_store.doc_store_base import MatchTextExpr
|
||||
from rag.nlp import rag_tokenizer, term_weight, synonym
|
||||
|
||||
|
||||
class FulltextQueryer:
|
||||
class FulltextQueryer(QueryBase):
|
||||
def __init__(self):
|
||||
self.tw = term_weight.Dealer()
|
||||
self.syn = synonym.Dealer()
|
||||
@ -37,64 +38,19 @@ class FulltextQueryer:
|
||||
"content_sm_ltks",
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def sub_special_char(line):
|
||||
return re.sub(r"([:\{\}/\[\]\-\*\"\(\)\|\+~\^])", r"\\\1", line).strip()
|
||||
|
||||
@staticmethod
|
||||
def is_chinese(line):
|
||||
arr = re.split(r"[ \t]+", line)
|
||||
if len(arr) <= 3:
|
||||
return True
|
||||
e = 0
|
||||
for t in arr:
|
||||
if not re.match(r"[a-zA-Z]+$", t):
|
||||
e += 1
|
||||
return e * 1.0 / len(arr) >= 0.7
|
||||
|
||||
@staticmethod
|
||||
def rmWWW(txt):
|
||||
patts = [
|
||||
(
|
||||
r"是*(怎么办|什么样的|哪家|一下|那家|请问|啥样|咋样了|什么时候|何时|何地|何人|是否|是不是|多少|哪里|怎么|哪儿|怎么样|如何|哪些|是啥|啥是|啊|吗|呢|吧|咋|什么|有没有|呀|谁|哪位|哪个)是*",
|
||||
"",
|
||||
),
|
||||
(r"(^| )(what|who|how|which|where|why)('re|'s)? ", " "),
|
||||
(
|
||||
r"(^| )('s|'re|is|are|were|was|do|does|did|don't|doesn't|didn't|has|have|be|there|you|me|your|my|mine|just|please|may|i|should|would|wouldn't|will|won't|done|go|for|with|so|the|a|an|by|i'm|it's|he's|she's|they|they're|you're|as|by|on|in|at|up|out|down|of|to|or|and|if) ",
|
||||
" ")
|
||||
]
|
||||
otxt = txt
|
||||
for r, p in patts:
|
||||
txt = re.sub(r, p, txt, flags=re.IGNORECASE)
|
||||
if not txt:
|
||||
txt = otxt
|
||||
return txt
|
||||
|
||||
@staticmethod
|
||||
def add_space_between_eng_zh(txt):
|
||||
# (ENG/ENG+NUM) + ZH
|
||||
txt = re.sub(r'([A-Za-z]+[0-9]+)([\u4e00-\u9fa5]+)', r'\1 \2', txt)
|
||||
# ENG + ZH
|
||||
txt = re.sub(r'([A-Za-z])([\u4e00-\u9fa5]+)', r'\1 \2', txt)
|
||||
# ZH + (ENG/ENG+NUM)
|
||||
txt = re.sub(r'([\u4e00-\u9fa5]+)([A-Za-z]+[0-9]+)', r'\1 \2', txt)
|
||||
txt = re.sub(r'([\u4e00-\u9fa5]+)([A-Za-z])', r'\1 \2', txt)
|
||||
return txt
|
||||
|
||||
def question(self, txt, tbl="qa", min_match: float = 0.6):
|
||||
original_query = txt
|
||||
txt = FulltextQueryer.add_space_between_eng_zh(txt)
|
||||
txt = self.add_space_between_eng_zh(txt)
|
||||
txt = re.sub(
|
||||
r"[ :|\r\n\t,,。??/`!!&^%%()\[\]{}<>]+",
|
||||
" ",
|
||||
rag_tokenizer.tradi2simp(rag_tokenizer.strQ2B(txt.lower())),
|
||||
).strip()
|
||||
otxt = txt
|
||||
txt = FulltextQueryer.rmWWW(txt)
|
||||
txt = self.rmWWW(txt)
|
||||
|
||||
if not self.is_chinese(txt):
|
||||
txt = FulltextQueryer.rmWWW(txt)
|
||||
txt = self.rmWWW(txt)
|
||||
tks = rag_tokenizer.tokenize(txt).split()
|
||||
keywords = [t for t in tks if t]
|
||||
tks_w = self.tw.weights(tks, preprocess=False)
|
||||
@ -138,7 +94,7 @@ class FulltextQueryer:
|
||||
return False
|
||||
return True
|
||||
|
||||
txt = FulltextQueryer.rmWWW(txt)
|
||||
txt = self.rmWWW(txt)
|
||||
qs, keywords = [], []
|
||||
for tt in self.tw.split(txt)[:256]: # .split():
|
||||
if not tt:
|
||||
@ -164,7 +120,7 @@ class FulltextQueryer:
|
||||
)
|
||||
for m in sm
|
||||
]
|
||||
sm = [FulltextQueryer.sub_special_char(m) for m in sm if len(m) > 1]
|
||||
sm = [self.sub_special_char(m) for m in sm if len(m) > 1]
|
||||
sm = [m for m in sm if len(m) > 1]
|
||||
|
||||
if len(keywords) < 32:
|
||||
@ -172,7 +128,7 @@ class FulltextQueryer:
|
||||
keywords.extend(sm)
|
||||
|
||||
tk_syns = self.syn.lookup(tk)
|
||||
tk_syns = [FulltextQueryer.sub_special_char(s) for s in tk_syns]
|
||||
tk_syns = [self.sub_special_char(s) for s in tk_syns]
|
||||
if len(keywords) < 32:
|
||||
keywords.extend([s for s in tk_syns if s])
|
||||
tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s]
|
||||
@ -181,7 +137,7 @@ class FulltextQueryer:
|
||||
if len(keywords) >= 32:
|
||||
break
|
||||
|
||||
tk = FulltextQueryer.sub_special_char(tk)
|
||||
tk = self.sub_special_char(tk)
|
||||
if tk.find(" ") > 0:
|
||||
tk = '"%s"' % tk
|
||||
if tk_syns:
|
||||
@ -199,7 +155,7 @@ class FulltextQueryer:
|
||||
syns = " OR ".join(
|
||||
[
|
||||
'"%s"'
|
||||
% rag_tokenizer.tokenize(FulltextQueryer.sub_special_char(s))
|
||||
% rag_tokenizer.tokenize(self.sub_special_char(s))
|
||||
for s in syns
|
||||
]
|
||||
)
|
||||
@ -264,10 +220,10 @@ class FulltextQueryer:
|
||||
keywords = [f'"{k.strip()}"' for k in keywords]
|
||||
for tk, w in sorted(tks_w, key=lambda x: x[1] * -1)[:keywords_topn]:
|
||||
tk_syns = self.syn.lookup(tk)
|
||||
tk_syns = [FulltextQueryer.sub_special_char(s) for s in tk_syns]
|
||||
tk_syns = [self.sub_special_char(s) for s in tk_syns]
|
||||
tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s]
|
||||
tk_syns = [f"\"{s}\"" if s.find(" ") > 0 else s for s in tk_syns]
|
||||
tk = FulltextQueryer.sub_special_char(tk)
|
||||
tk = self.sub_special_char(tk)
|
||||
if tk.find(" ") > 0:
|
||||
tk = '"%s"' % tk
|
||||
if tk_syns:
|
||||
|
||||
@ -24,7 +24,7 @@ from dataclasses import dataclass
|
||||
from rag.prompts.generator import relevant_chunks_with_toc
|
||||
from rag.nlp import rag_tokenizer, query
|
||||
import numpy as np
|
||||
from rag.utils.doc_store_conn import DocStoreConnection, MatchDenseExpr, FusionExpr, OrderByExpr
|
||||
from common.doc_store.doc_store_base import MatchDenseExpr, FusionExpr, OrderByExpr, DocStoreConnection
|
||||
from common.string_utils import remove_redundant_spaces
|
||||
from common.float_utils import get_float
|
||||
from common.constants import PAGERANK_FLD, TAG_FLD
|
||||
@ -155,7 +155,7 @@ class Dealer:
|
||||
kwds.add(kk)
|
||||
|
||||
logging.debug(f"TOTAL: {total}")
|
||||
ids = self.dataStore.get_chunk_ids(res)
|
||||
ids = self.dataStore.get_doc_ids(res)
|
||||
keywords = list(kwds)
|
||||
highlight = self.dataStore.get_highlight(res, keywords, "content_with_weight")
|
||||
aggs = self.dataStore.get_aggregation(res, "docnm_kwd")
|
||||
@ -545,7 +545,7 @@ class Dealer:
|
||||
return res
|
||||
|
||||
def all_tags(self, tenant_id: str, kb_ids: list[str], S=1000):
|
||||
if not self.dataStore.indexExist(index_name(tenant_id), kb_ids[0]):
|
||||
if not self.dataStore.index_exist(index_name(tenant_id), kb_ids[0]):
|
||||
return []
|
||||
res = self.dataStore.search([], [], {}, [], OrderByExpr(), 0, 0, index_name(tenant_id), kb_ids, ["tag_kwd"])
|
||||
return self.dataStore.get_aggregation(res, "tag_kwd")
|
||||
|
||||
@ -136,6 +136,19 @@ def kb_prompt(kbinfos, max_tokens, hash_id=False):
|
||||
return knowledges
|
||||
|
||||
|
||||
def memory_prompt(message_list, max_tokens):
|
||||
used_token_count = 0
|
||||
content_list = []
|
||||
for message in message_list:
|
||||
current_content_tokens = num_tokens_from_string(message["content"])
|
||||
if used_token_count + current_content_tokens > max_tokens * 0.97:
|
||||
logging.warning(f"Not all the retrieval into prompt: {len(content_list)}/{len(message_list)}")
|
||||
break
|
||||
content_list.append(message["content"])
|
||||
used_token_count += current_content_tokens
|
||||
return content_list
|
||||
|
||||
|
||||
CITATION_PROMPT_TEMPLATE = load_prompt("citation_prompt")
|
||||
CITATION_PLUS_TEMPLATE = load_prompt("citation_plus")
|
||||
CONTENT_TAGGING_PROMPT_TEMPLATE = load_prompt("content_tagging_prompt")
|
||||
|
||||
@ -506,7 +506,7 @@ def build_TOC(task, docs, progress_callback):
|
||||
|
||||
def init_kb(row, vector_size: int):
|
||||
idxnm = search.index_name(row["tenant_id"])
|
||||
return settings.docStoreConn.createIdx(idxnm, row.get("kb_id", ""), vector_size)
|
||||
return settings.docStoreConn.create_idx(idxnm, row.get("kb_id", ""), vector_size)
|
||||
|
||||
|
||||
async def embedding(docs, mdl, parser_config=None, callback=None):
|
||||
|
||||
@ -1,271 +0,0 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
import numpy as np
|
||||
|
||||
DEFAULT_MATCH_VECTOR_TOPN = 10
|
||||
DEFAULT_MATCH_SPARSE_TOPN = 10
|
||||
VEC = list | np.ndarray
|
||||
|
||||
|
||||
@dataclass
|
||||
class SparseVector:
|
||||
indices: list[int]
|
||||
values: list[float] | list[int] | None = None
|
||||
|
||||
def __post_init__(self):
|
||||
assert (self.values is None) or (len(self.indices) == len(self.values))
|
||||
|
||||
def to_dict_old(self):
|
||||
d = {"indices": self.indices}
|
||||
if self.values is not None:
|
||||
d["values"] = self.values
|
||||
return d
|
||||
|
||||
def to_dict(self):
|
||||
if self.values is None:
|
||||
raise ValueError("SparseVector.values is None")
|
||||
result = {}
|
||||
for i, v in zip(self.indices, self.values):
|
||||
result[str(i)] = v
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def from_dict(d):
|
||||
return SparseVector(d["indices"], d.get("values"))
|
||||
|
||||
def __str__(self):
|
||||
return f"SparseVector(indices={self.indices}{'' if self.values is None else f', values={self.values}'})"
|
||||
|
||||
def __repr__(self):
|
||||
return str(self)
|
||||
|
||||
|
||||
class MatchTextExpr(ABC):
|
||||
def __init__(
|
||||
self,
|
||||
fields: list[str],
|
||||
matching_text: str,
|
||||
topn: int,
|
||||
extra_options: dict = dict(),
|
||||
):
|
||||
self.fields = fields
|
||||
self.matching_text = matching_text
|
||||
self.topn = topn
|
||||
self.extra_options = extra_options
|
||||
|
||||
|
||||
class MatchDenseExpr(ABC):
|
||||
def __init__(
|
||||
self,
|
||||
vector_column_name: str,
|
||||
embedding_data: VEC,
|
||||
embedding_data_type: str,
|
||||
distance_type: str,
|
||||
topn: int = DEFAULT_MATCH_VECTOR_TOPN,
|
||||
extra_options: dict = dict(),
|
||||
):
|
||||
self.vector_column_name = vector_column_name
|
||||
self.embedding_data = embedding_data
|
||||
self.embedding_data_type = embedding_data_type
|
||||
self.distance_type = distance_type
|
||||
self.topn = topn
|
||||
self.extra_options = extra_options
|
||||
|
||||
|
||||
class MatchSparseExpr(ABC):
|
||||
def __init__(
|
||||
self,
|
||||
vector_column_name: str,
|
||||
sparse_data: SparseVector | dict,
|
||||
distance_type: str,
|
||||
topn: int,
|
||||
opt_params: dict | None = None,
|
||||
):
|
||||
self.vector_column_name = vector_column_name
|
||||
self.sparse_data = sparse_data
|
||||
self.distance_type = distance_type
|
||||
self.topn = topn
|
||||
self.opt_params = opt_params
|
||||
|
||||
|
||||
class MatchTensorExpr(ABC):
|
||||
def __init__(
|
||||
self,
|
||||
column_name: str,
|
||||
query_data: VEC,
|
||||
query_data_type: str,
|
||||
topn: int,
|
||||
extra_option: dict | None = None,
|
||||
):
|
||||
self.column_name = column_name
|
||||
self.query_data = query_data
|
||||
self.query_data_type = query_data_type
|
||||
self.topn = topn
|
||||
self.extra_option = extra_option
|
||||
|
||||
|
||||
class FusionExpr(ABC):
|
||||
def __init__(self, method: str, topn: int, fusion_params: dict | None = None):
|
||||
self.method = method
|
||||
self.topn = topn
|
||||
self.fusion_params = fusion_params
|
||||
|
||||
|
||||
MatchExpr = MatchTextExpr | MatchDenseExpr | MatchSparseExpr | MatchTensorExpr | FusionExpr
|
||||
|
||||
class OrderByExpr(ABC):
|
||||
def __init__(self):
|
||||
self.fields = list()
|
||||
def asc(self, field: str):
|
||||
self.fields.append((field, 0))
|
||||
return self
|
||||
def desc(self, field: str):
|
||||
self.fields.append((field, 1))
|
||||
return self
|
||||
def fields(self):
|
||||
return self.fields
|
||||
|
||||
class DocStoreConnection(ABC):
|
||||
"""
|
||||
Database operations
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def dbType(self) -> str:
|
||||
"""
|
||||
Return the type of the database.
|
||||
"""
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
@abstractmethod
|
||||
def health(self) -> dict:
|
||||
"""
|
||||
Return the health status of the database.
|
||||
"""
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
"""
|
||||
Table operations
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
|
||||
"""
|
||||
Create an index with given name
|
||||
"""
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
@abstractmethod
|
||||
def deleteIdx(self, indexName: str, knowledgebaseId: str):
|
||||
"""
|
||||
Delete an index with given name
|
||||
"""
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
@abstractmethod
|
||||
def indexExist(self, indexName: str, knowledgebaseId: str) -> bool:
|
||||
"""
|
||||
Check if an index with given name exists
|
||||
"""
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
"""
|
||||
CRUD operations
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def search(
|
||||
self, selectFields: list[str],
|
||||
highlightFields: list[str],
|
||||
condition: dict,
|
||||
matchExprs: list[MatchExpr],
|
||||
orderBy: OrderByExpr,
|
||||
offset: int,
|
||||
limit: int,
|
||||
indexNames: str|list[str],
|
||||
knowledgebaseIds: list[str],
|
||||
aggFields: list[str] = [],
|
||||
rank_feature: dict | None = None
|
||||
):
|
||||
"""
|
||||
Search with given conjunctive equivalent filtering condition and return all fields of matched documents
|
||||
"""
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
@abstractmethod
|
||||
def get(self, chunkId: str, indexName: str, knowledgebaseIds: list[str]) -> dict | None:
|
||||
"""
|
||||
Get single chunk with given id
|
||||
"""
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
@abstractmethod
|
||||
def insert(self, rows: list[dict], indexName: str, knowledgebaseId: str = None) -> list[str]:
|
||||
"""
|
||||
Update or insert a bulk of rows
|
||||
"""
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
@abstractmethod
|
||||
def update(self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str) -> bool:
|
||||
"""
|
||||
Update rows with given conjunctive equivalent filtering condition
|
||||
"""
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int:
|
||||
"""
|
||||
Delete rows with given conjunctive equivalent filtering condition
|
||||
"""
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
"""
|
||||
Helper functions for search result
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_total(self, res):
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
@abstractmethod
|
||||
def get_chunk_ids(self, res):
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
@abstractmethod
|
||||
def get_fields(self, res, fields: list[str]) -> dict[str, dict]:
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
@abstractmethod
|
||||
def get_highlight(self, res, keywords: list[str], fieldnm: str):
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
@abstractmethod
|
||||
def get_aggregation(self, res, fieldnm: str):
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
"""
|
||||
SQL
|
||||
"""
|
||||
@abstractmethod
|
||||
def sql(sql: str, fetch_size: int, format: str):
|
||||
"""
|
||||
Run the sql generated by text-to-sql
|
||||
"""
|
||||
raise NotImplementedError("Not implemented")
|
||||
@ -14,194 +14,92 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import logging
|
||||
import re
|
||||
import json
|
||||
import time
|
||||
import os
|
||||
|
||||
import copy
|
||||
from elasticsearch import Elasticsearch, NotFoundError
|
||||
from elasticsearch_dsl import UpdateByQuery, Q, Search, Index
|
||||
from elasticsearch_dsl import UpdateByQuery, Q, Search
|
||||
from elastic_transport import ConnectionTimeout
|
||||
from common.decorator import singleton
|
||||
from common.file_utils import get_project_base_directory
|
||||
from common.misc_utils import convert_bytes
|
||||
from rag.utils.doc_store_conn import DocStoreConnection, MatchExpr, OrderByExpr, MatchTextExpr, MatchDenseExpr, \
|
||||
FusionExpr
|
||||
from rag.nlp import is_english, rag_tokenizer
|
||||
from common.doc_store.doc_store_base import MatchTextExpr, OrderByExpr, MatchExpr, MatchDenseExpr, FusionExpr
|
||||
from common.doc_store.es_conn_base import ESConnectionBase
|
||||
from common.float_utils import get_float
|
||||
from common import settings
|
||||
from common.constants import PAGERANK_FLD, TAG_FLD
|
||||
|
||||
ATTEMPT_TIME = 2
|
||||
|
||||
logger = logging.getLogger('ragflow.es_conn')
|
||||
|
||||
|
||||
@singleton
|
||||
class ESConnection(DocStoreConnection):
|
||||
def __init__(self):
|
||||
self.info = {}
|
||||
logger.info(f"Use Elasticsearch {settings.ES['hosts']} as the doc engine.")
|
||||
for _ in range(ATTEMPT_TIME):
|
||||
try:
|
||||
if self._connect():
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning(f"{str(e)}. Waiting Elasticsearch {settings.ES['hosts']} to be healthy.")
|
||||
time.sleep(5)
|
||||
|
||||
if not self.es.ping():
|
||||
msg = f"Elasticsearch {settings.ES['hosts']} is unhealthy in 120s."
|
||||
logger.error(msg)
|
||||
raise Exception(msg)
|
||||
v = self.info.get("version", {"number": "8.11.3"})
|
||||
v = v["number"].split(".")[0]
|
||||
if int(v) < 8:
|
||||
msg = f"Elasticsearch version must be greater than or equal to 8, current version: {v}"
|
||||
logger.error(msg)
|
||||
raise Exception(msg)
|
||||
fp_mapping = os.path.join(get_project_base_directory(), "conf", "mapping.json")
|
||||
if not os.path.exists(fp_mapping):
|
||||
msg = f"Elasticsearch mapping file not found at {fp_mapping}"
|
||||
logger.error(msg)
|
||||
raise Exception(msg)
|
||||
self.mapping = json.load(open(fp_mapping, "r"))
|
||||
logger.info(f"Elasticsearch {settings.ES['hosts']} is healthy.")
|
||||
|
||||
def _connect(self):
|
||||
self.es = Elasticsearch(
|
||||
settings.ES["hosts"].split(","),
|
||||
basic_auth=(settings.ES["username"], settings.ES[
|
||||
"password"]) if "username" in settings.ES and "password" in settings.ES else None,
|
||||
verify_certs= settings.ES.get("verify_certs", False),
|
||||
timeout=600 )
|
||||
if self.es:
|
||||
self.info = self.es.info()
|
||||
return True
|
||||
return False
|
||||
|
||||
"""
|
||||
Database operations
|
||||
"""
|
||||
|
||||
def dbType(self) -> str:
|
||||
return "elasticsearch"
|
||||
|
||||
def health(self) -> dict:
|
||||
health_dict = dict(self.es.cluster.health())
|
||||
health_dict["type"] = "elasticsearch"
|
||||
return health_dict
|
||||
|
||||
"""
|
||||
Table operations
|
||||
"""
|
||||
|
||||
def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
|
||||
if self.indexExist(indexName, knowledgebaseId):
|
||||
return True
|
||||
try:
|
||||
from elasticsearch.client import IndicesClient
|
||||
return IndicesClient(self.es).create(index=indexName,
|
||||
settings=self.mapping["settings"],
|
||||
mappings=self.mapping["mappings"])
|
||||
except Exception:
|
||||
logger.exception("ESConnection.createIndex error %s" % (indexName))
|
||||
|
||||
def deleteIdx(self, indexName: str, knowledgebaseId: str):
|
||||
if len(knowledgebaseId) > 0:
|
||||
# The index need to be alive after any kb deletion since all kb under this tenant are in one index.
|
||||
return
|
||||
try:
|
||||
self.es.indices.delete(index=indexName, allow_no_indices=True)
|
||||
except NotFoundError:
|
||||
pass
|
||||
except Exception:
|
||||
logger.exception("ESConnection.deleteIdx error %s" % (indexName))
|
||||
|
||||
def indexExist(self, indexName: str, knowledgebaseId: str = None) -> bool:
|
||||
s = Index(indexName, self.es)
|
||||
for i in range(ATTEMPT_TIME):
|
||||
try:
|
||||
return s.exists()
|
||||
except ConnectionTimeout:
|
||||
logger.exception("ES request timeout")
|
||||
time.sleep(3)
|
||||
self._connect()
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
break
|
||||
return False
|
||||
class ESConnection(ESConnectionBase):
|
||||
|
||||
"""
|
||||
CRUD operations
|
||||
"""
|
||||
|
||||
def search(
|
||||
self, selectFields: list[str],
|
||||
highlightFields: list[str],
|
||||
self, select_fields: list[str],
|
||||
highlight_fields: list[str],
|
||||
condition: dict,
|
||||
matchExprs: list[MatchExpr],
|
||||
orderBy: OrderByExpr,
|
||||
match_expressions: list[MatchExpr],
|
||||
order_by: OrderByExpr,
|
||||
offset: int,
|
||||
limit: int,
|
||||
indexNames: str | list[str],
|
||||
knowledgebaseIds: list[str],
|
||||
aggFields: list[str] = [],
|
||||
index_names: str | list[str],
|
||||
knowledgebase_ids: list[str],
|
||||
agg_fields: list[str] | None = None,
|
||||
rank_feature: dict | None = None
|
||||
):
|
||||
"""
|
||||
Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl.html
|
||||
"""
|
||||
if isinstance(indexNames, str):
|
||||
indexNames = indexNames.split(",")
|
||||
assert isinstance(indexNames, list) and len(indexNames) > 0
|
||||
if isinstance(index_names, str):
|
||||
index_names = index_names.split(",")
|
||||
assert isinstance(index_names, list) and len(index_names) > 0
|
||||
assert "_id" not in condition
|
||||
|
||||
bqry = Q("bool", must=[])
|
||||
condition["kb_id"] = knowledgebaseIds
|
||||
bool_query = Q("bool", must=[])
|
||||
condition["kb_id"] = knowledgebase_ids
|
||||
for k, v in condition.items():
|
||||
if k == "available_int":
|
||||
if v == 0:
|
||||
bqry.filter.append(Q("range", available_int={"lt": 1}))
|
||||
bool_query.filter.append(Q("range", available_int={"lt": 1}))
|
||||
else:
|
||||
bqry.filter.append(
|
||||
bool_query.filter.append(
|
||||
Q("bool", must_not=Q("range", available_int={"lt": 1})))
|
||||
continue
|
||||
if not v:
|
||||
continue
|
||||
if isinstance(v, list):
|
||||
bqry.filter.append(Q("terms", **{k: v}))
|
||||
bool_query.filter.append(Q("terms", **{k: v}))
|
||||
elif isinstance(v, str) or isinstance(v, int):
|
||||
bqry.filter.append(Q("term", **{k: v}))
|
||||
bool_query.filter.append(Q("term", **{k: v}))
|
||||
else:
|
||||
raise Exception(
|
||||
f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")
|
||||
|
||||
s = Search()
|
||||
vector_similarity_weight = 0.5
|
||||
for m in matchExprs:
|
||||
for m in match_expressions:
|
||||
if isinstance(m, FusionExpr) and m.method == "weighted_sum" and "weights" in m.fusion_params:
|
||||
assert len(matchExprs) == 3 and isinstance(matchExprs[0], MatchTextExpr) and isinstance(matchExprs[1],
|
||||
MatchDenseExpr) and isinstance(
|
||||
matchExprs[2], FusionExpr)
|
||||
assert len(match_expressions) == 3 and isinstance(match_expressions[0], MatchTextExpr) and isinstance(match_expressions[1],
|
||||
MatchDenseExpr) and isinstance(
|
||||
match_expressions[2], FusionExpr)
|
||||
weights = m.fusion_params["weights"]
|
||||
vector_similarity_weight = get_float(weights.split(",")[1])
|
||||
for m in matchExprs:
|
||||
for m in match_expressions:
|
||||
if isinstance(m, MatchTextExpr):
|
||||
minimum_should_match = m.extra_options.get("minimum_should_match", 0.0)
|
||||
if isinstance(minimum_should_match, float):
|
||||
minimum_should_match = str(int(minimum_should_match * 100)) + "%"
|
||||
bqry.must.append(Q("query_string", fields=m.fields,
|
||||
bool_query.must.append(Q("query_string", fields=m.fields,
|
||||
type="best_fields", query=m.matching_text,
|
||||
minimum_should_match=minimum_should_match,
|
||||
boost=1))
|
||||
bqry.boost = 1.0 - vector_similarity_weight
|
||||
bool_query.boost = 1.0 - vector_similarity_weight
|
||||
|
||||
elif isinstance(m, MatchDenseExpr):
|
||||
assert (bqry is not None)
|
||||
assert (bool_query is not None)
|
||||
similarity = 0.0
|
||||
if "similarity" in m.extra_options:
|
||||
similarity = m.extra_options["similarity"]
|
||||
@ -209,24 +107,24 @@ class ESConnection(DocStoreConnection):
|
||||
m.topn,
|
||||
m.topn * 2,
|
||||
query_vector=list(m.embedding_data),
|
||||
filter=bqry.to_dict(),
|
||||
filter=bool_query.to_dict(),
|
||||
similarity=similarity,
|
||||
)
|
||||
|
||||
if bqry and rank_feature:
|
||||
if bool_query and rank_feature:
|
||||
for fld, sc in rank_feature.items():
|
||||
if fld != PAGERANK_FLD:
|
||||
fld = f"{TAG_FLD}.{fld}"
|
||||
bqry.should.append(Q("rank_feature", field=fld, linear={}, boost=sc))
|
||||
bool_query.should.append(Q("rank_feature", field=fld, linear={}, boost=sc))
|
||||
|
||||
if bqry:
|
||||
s = s.query(bqry)
|
||||
for field in highlightFields:
|
||||
if bool_query:
|
||||
s = s.query(bool_query)
|
||||
for field in highlight_fields:
|
||||
s = s.highlight(field)
|
||||
|
||||
if orderBy:
|
||||
if order_by:
|
||||
orders = list()
|
||||
for field, order in orderBy.fields:
|
||||
for field, order in order_by.fields:
|
||||
order = "asc" if order == 0 else "desc"
|
||||
if field in ["page_num_int", "top_int"]:
|
||||
order_info = {"order": order, "unmapped_type": "float",
|
||||
@ -237,19 +135,19 @@ class ESConnection(DocStoreConnection):
|
||||
order_info = {"order": order, "unmapped_type": "text"}
|
||||
orders.append({field: order_info})
|
||||
s = s.sort(*orders)
|
||||
|
||||
for fld in aggFields:
|
||||
s.aggs.bucket(f'aggs_{fld}', 'terms', field=fld, size=1000000)
|
||||
if agg_fields:
|
||||
for fld in agg_fields:
|
||||
s.aggs.bucket(f'aggs_{fld}', 'terms', field=fld, size=1000000)
|
||||
|
||||
if limit > 0:
|
||||
s = s[offset:offset + limit]
|
||||
q = s.to_dict()
|
||||
logger.debug(f"ESConnection.search {str(indexNames)} query: " + json.dumps(q))
|
||||
self.logger.debug(f"ESConnection.search {str(index_names)} query: " + json.dumps(q))
|
||||
|
||||
for i in range(ATTEMPT_TIME):
|
||||
try:
|
||||
#print(json.dumps(q, ensure_ascii=False))
|
||||
res = self.es.search(index=indexNames,
|
||||
res = self.es.search(index=index_names,
|
||||
body=q,
|
||||
timeout="600s",
|
||||
# search_type="dfs_query_then_fetch",
|
||||
@ -257,55 +155,37 @@ class ESConnection(DocStoreConnection):
|
||||
_source=True)
|
||||
if str(res.get("timed_out", "")).lower() == "true":
|
||||
raise Exception("Es Timeout.")
|
||||
logger.debug(f"ESConnection.search {str(indexNames)} res: " + str(res))
|
||||
self.logger.debug(f"ESConnection.search {str(index_names)} res: " + str(res))
|
||||
return res
|
||||
except ConnectionTimeout:
|
||||
logger.exception("ES request timeout")
|
||||
self.logger.exception("ES request timeout")
|
||||
self._connect()
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.exception(f"ESConnection.search {str(indexNames)} query: " + str(q) + str(e))
|
||||
self.logger.exception(f"ESConnection.search {str(index_names)} query: " + str(q) + str(e))
|
||||
raise e
|
||||
|
||||
logger.error(f"ESConnection.search timeout for {ATTEMPT_TIME} times!")
|
||||
self.logger.error(f"ESConnection.search timeout for {ATTEMPT_TIME} times!")
|
||||
raise Exception("ESConnection.search timeout.")
|
||||
|
||||
def get(self, chunkId: str, indexName: str, knowledgebaseIds: list[str]) -> dict | None:
|
||||
for i in range(ATTEMPT_TIME):
|
||||
try:
|
||||
res = self.es.get(index=(indexName),
|
||||
id=chunkId, source=True, )
|
||||
if str(res.get("timed_out", "")).lower() == "true":
|
||||
raise Exception("Es Timeout.")
|
||||
chunk = res["_source"]
|
||||
chunk["id"] = chunkId
|
||||
return chunk
|
||||
except NotFoundError:
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.exception(f"ESConnection.get({chunkId}) got exception")
|
||||
raise e
|
||||
logger.error(f"ESConnection.get timeout for {ATTEMPT_TIME} times!")
|
||||
raise Exception("ESConnection.get timeout.")
|
||||
|
||||
def insert(self, documents: list[dict], indexName: str, knowledgebaseId: str = None) -> list[str]:
|
||||
def insert(self, documents: list[dict], index_name: str, knowledgebase_id: str = None) -> list[str]:
|
||||
# Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/docs-bulk.html
|
||||
operations = []
|
||||
for d in documents:
|
||||
assert "_id" not in d
|
||||
assert "id" in d
|
||||
d_copy = copy.deepcopy(d)
|
||||
d_copy["kb_id"] = knowledgebaseId
|
||||
d_copy["kb_id"] = knowledgebase_id
|
||||
meta_id = d_copy.pop("id", "")
|
||||
operations.append(
|
||||
{"index": {"_index": indexName, "_id": meta_id}})
|
||||
{"index": {"_index": index_name, "_id": meta_id}})
|
||||
operations.append(d_copy)
|
||||
|
||||
res = []
|
||||
for _ in range(ATTEMPT_TIME):
|
||||
try:
|
||||
res = []
|
||||
r = self.es.bulk(index=(indexName), operations=operations,
|
||||
r = self.es.bulk(index=index_name, operations=operations,
|
||||
refresh=False, timeout="60s")
|
||||
if re.search(r"False", str(r["errors"]), re.IGNORECASE):
|
||||
return res
|
||||
@ -316,58 +196,58 @@ class ESConnection(DocStoreConnection):
|
||||
res.append(str(item[action]["_id"]) + ":" + str(item[action]["error"]))
|
||||
return res
|
||||
except ConnectionTimeout:
|
||||
logger.exception("ES request timeout")
|
||||
self.logger.exception("ES request timeout")
|
||||
time.sleep(3)
|
||||
self._connect()
|
||||
continue
|
||||
except Exception as e:
|
||||
res.append(str(e))
|
||||
logger.warning("ESConnection.insert got exception: " + str(e))
|
||||
self.logger.warning("ESConnection.insert got exception: " + str(e))
|
||||
|
||||
return res
|
||||
|
||||
def update(self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str) -> bool:
|
||||
doc = copy.deepcopy(newValue)
|
||||
def update(self, condition: dict, new_value: dict, index_name: str, knowledgebase_id: str) -> bool:
|
||||
doc = copy.deepcopy(new_value)
|
||||
doc.pop("id", None)
|
||||
condition["kb_id"] = knowledgebaseId
|
||||
condition["kb_id"] = knowledgebase_id
|
||||
if "id" in condition and isinstance(condition["id"], str):
|
||||
# update specific single document
|
||||
chunkId = condition["id"]
|
||||
chunk_id = condition["id"]
|
||||
for i in range(ATTEMPT_TIME):
|
||||
for k in doc.keys():
|
||||
if "feas" != k.split("_")[-1]:
|
||||
continue
|
||||
try:
|
||||
self.es.update(index=indexName, id=chunkId, script=f"ctx._source.remove(\"{k}\");")
|
||||
self.es.update(index=index_name, id=chunk_id, script=f"ctx._source.remove(\"{k}\");")
|
||||
except Exception:
|
||||
logger.exception(f"ESConnection.update(index={indexName}, id={chunkId}, doc={json.dumps(condition, ensure_ascii=False)}) got exception")
|
||||
self.logger.exception(f"ESConnection.update(index={index_name}, id={chunk_id}, doc={json.dumps(condition, ensure_ascii=False)}) got exception")
|
||||
try:
|
||||
self.es.update(index=indexName, id=chunkId, doc=doc)
|
||||
self.es.update(index=index_name, id=chunk_id, doc=doc)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"ESConnection.update(index={indexName}, id={chunkId}, doc={json.dumps(condition, ensure_ascii=False)}) got exception: "+str(e))
|
||||
self.logger.exception(
|
||||
f"ESConnection.update(index={index_name}, id={chunk_id}, doc={json.dumps(condition, ensure_ascii=False)}) got exception: " + str(e))
|
||||
break
|
||||
return False
|
||||
|
||||
# update unspecific maybe-multiple documents
|
||||
bqry = Q("bool")
|
||||
bool_query = Q("bool")
|
||||
for k, v in condition.items():
|
||||
if not isinstance(k, str) or not v:
|
||||
continue
|
||||
if k == "exists":
|
||||
bqry.filter.append(Q("exists", field=v))
|
||||
bool_query.filter.append(Q("exists", field=v))
|
||||
continue
|
||||
if isinstance(v, list):
|
||||
bqry.filter.append(Q("terms", **{k: v}))
|
||||
bool_query.filter.append(Q("terms", **{k: v}))
|
||||
elif isinstance(v, str) or isinstance(v, int):
|
||||
bqry.filter.append(Q("term", **{k: v}))
|
||||
bool_query.filter.append(Q("term", **{k: v}))
|
||||
else:
|
||||
raise Exception(
|
||||
f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")
|
||||
scripts = []
|
||||
params = {}
|
||||
for k, v in newValue.items():
|
||||
for k, v in new_value.items():
|
||||
if k == "remove":
|
||||
if isinstance(v, str):
|
||||
scripts.append(f"ctx._source.remove('{v}');")
|
||||
@ -397,8 +277,8 @@ class ESConnection(DocStoreConnection):
|
||||
raise Exception(
|
||||
f"newValue `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str.")
|
||||
ubq = UpdateByQuery(
|
||||
index=indexName).using(
|
||||
self.es).query(bqry)
|
||||
index=index_name).using(
|
||||
self.es).query(bool_query)
|
||||
ubq = ubq.script(source="".join(scripts), params=params)
|
||||
ubq = ubq.params(refresh=True)
|
||||
ubq = ubq.params(slices=5)
|
||||
@ -409,19 +289,18 @@ class ESConnection(DocStoreConnection):
|
||||
_ = ubq.execute()
|
||||
return True
|
||||
except ConnectionTimeout:
|
||||
logger.exception("ES request timeout")
|
||||
self.logger.exception("ES request timeout")
|
||||
time.sleep(3)
|
||||
self._connect()
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error("ESConnection.update got exception: " + str(e) + "\n".join(scripts))
|
||||
self.logger.error("ESConnection.update got exception: " + str(e) + "\n".join(scripts))
|
||||
break
|
||||
return False
|
||||
|
||||
def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int:
|
||||
qry = None
|
||||
def delete(self, condition: dict, index_name: str, knowledgebase_id: str) -> int:
|
||||
assert "_id" not in condition
|
||||
condition["kb_id"] = knowledgebaseId
|
||||
condition["kb_id"] = knowledgebase_id
|
||||
if "id" in condition:
|
||||
chunk_ids = condition["id"]
|
||||
if not isinstance(chunk_ids, list):
|
||||
@ -448,21 +327,21 @@ class ESConnection(DocStoreConnection):
|
||||
qry.must.append(Q("term", **{k: v}))
|
||||
else:
|
||||
raise Exception("Condition value must be int, str or list.")
|
||||
logger.debug("ESConnection.delete query: " + json.dumps(qry.to_dict()))
|
||||
self.logger.debug("ESConnection.delete query: " + json.dumps(qry.to_dict()))
|
||||
for _ in range(ATTEMPT_TIME):
|
||||
try:
|
||||
res = self.es.delete_by_query(
|
||||
index=indexName,
|
||||
index=index_name,
|
||||
body=Search().query(qry).to_dict(),
|
||||
refresh=True)
|
||||
return res["deleted"]
|
||||
except ConnectionTimeout:
|
||||
logger.exception("ES request timeout")
|
||||
self.logger.exception("ES request timeout")
|
||||
time.sleep(3)
|
||||
self._connect()
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.warning("ESConnection.delete got exception: " + str(e))
|
||||
self.logger.warning("ESConnection.delete got exception: " + str(e))
|
||||
if re.search(r"(not_found)", str(e), re.IGNORECASE):
|
||||
return 0
|
||||
return 0
|
||||
@ -471,27 +350,11 @@ class ESConnection(DocStoreConnection):
|
||||
Helper functions for search result
|
||||
"""
|
||||
|
||||
def get_total(self, res):
|
||||
if isinstance(res["hits"]["total"], type({})):
|
||||
return res["hits"]["total"]["value"]
|
||||
return res["hits"]["total"]
|
||||
|
||||
def get_chunk_ids(self, res):
|
||||
return [d["_id"] for d in res["hits"]["hits"]]
|
||||
|
||||
def __getSource(self, res):
|
||||
rr = []
|
||||
for d in res["hits"]["hits"]:
|
||||
d["_source"]["id"] = d["_id"]
|
||||
d["_source"]["_score"] = d["_score"]
|
||||
rr.append(d["_source"])
|
||||
return rr
|
||||
|
||||
def get_fields(self, res, fields: list[str]) -> dict[str, dict]:
|
||||
res_fields = {}
|
||||
if not fields:
|
||||
return {}
|
||||
for d in self.__getSource(res):
|
||||
for d in self._get_source(res):
|
||||
m = {n: d.get(n) for n in fields if d.get(n) is not None}
|
||||
for n, v in m.items():
|
||||
if isinstance(v, list):
|
||||
@ -508,124 +371,3 @@ class ESConnection(DocStoreConnection):
|
||||
if m:
|
||||
res_fields[d["id"]] = m
|
||||
return res_fields
|
||||
|
||||
def get_highlight(self, res, keywords: list[str], fieldnm: str):
|
||||
ans = {}
|
||||
for d in res["hits"]["hits"]:
|
||||
hlts = d.get("highlight")
|
||||
if not hlts:
|
||||
continue
|
||||
txt = "...".join([a for a in list(hlts.items())[0][1]])
|
||||
if not is_english(txt.split()):
|
||||
ans[d["_id"]] = txt
|
||||
continue
|
||||
|
||||
txt = d["_source"][fieldnm]
|
||||
txt = re.sub(r"[\r\n]", " ", txt, flags=re.IGNORECASE | re.MULTILINE)
|
||||
txts = []
|
||||
for t in re.split(r"[.?!;\n]", txt):
|
||||
for w in keywords:
|
||||
t = re.sub(r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])" % re.escape(w), r"\1<em>\2</em>\3", t,
|
||||
flags=re.IGNORECASE | re.MULTILINE)
|
||||
if not re.search(r"<em>[^<>]+</em>", t, flags=re.IGNORECASE | re.MULTILINE):
|
||||
continue
|
||||
txts.append(t)
|
||||
ans[d["_id"]] = "...".join(txts) if txts else "...".join([a for a in list(hlts.items())[0][1]])
|
||||
|
||||
return ans
|
||||
|
||||
def get_aggregation(self, res, fieldnm: str):
|
||||
agg_field = "aggs_" + fieldnm
|
||||
if "aggregations" not in res or agg_field not in res["aggregations"]:
|
||||
return list()
|
||||
bkts = res["aggregations"][agg_field]["buckets"]
|
||||
return [(b["key"], b["doc_count"]) for b in bkts]
|
||||
|
||||
"""
|
||||
SQL
|
||||
"""
|
||||
|
||||
def sql(self, sql: str, fetch_size: int, format: str):
|
||||
logger.debug(f"ESConnection.sql get sql: {sql}")
|
||||
sql = re.sub(r"[ `]+", " ", sql)
|
||||
sql = sql.replace("%", "")
|
||||
replaces = []
|
||||
for r in re.finditer(r" ([a-z_]+_l?tks)( like | ?= ?)'([^']+)'", sql):
|
||||
fld, v = r.group(1), r.group(3)
|
||||
match = " MATCH({}, '{}', 'operator=OR;minimum_should_match=30%') ".format(
|
||||
fld, rag_tokenizer.fine_grained_tokenize(rag_tokenizer.tokenize(v)))
|
||||
replaces.append(
|
||||
("{}{}'{}'".format(
|
||||
r.group(1),
|
||||
r.group(2),
|
||||
r.group(3)),
|
||||
match))
|
||||
|
||||
for p, r in replaces:
|
||||
sql = sql.replace(p, r, 1)
|
||||
logger.debug(f"ESConnection.sql to es: {sql}")
|
||||
|
||||
for i in range(ATTEMPT_TIME):
|
||||
try:
|
||||
res = self.es.sql.query(body={"query": sql, "fetch_size": fetch_size}, format=format,
|
||||
request_timeout="2s")
|
||||
return res
|
||||
except ConnectionTimeout:
|
||||
logger.exception("ES request timeout")
|
||||
time.sleep(3)
|
||||
self._connect()
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.exception(f"ESConnection.sql got exception. SQL:\n{sql}")
|
||||
raise Exception(f"SQL error: {e}\n\nSQL: {sql}")
|
||||
logger.error(f"ESConnection.sql timeout for {ATTEMPT_TIME} times!")
|
||||
return None
|
||||
|
||||
def get_cluster_stats(self):
|
||||
"""
|
||||
curl -XGET "http://{es_host}/_cluster/stats" -H "kbn-xsrf: reporting" to view raw stats.
|
||||
"""
|
||||
raw_stats = self.es.cluster.stats()
|
||||
logger.debug(f"ESConnection.get_cluster_stats: {raw_stats}")
|
||||
try:
|
||||
res = {
|
||||
'cluster_name': raw_stats['cluster_name'],
|
||||
'status': raw_stats['status']
|
||||
}
|
||||
indices_status = raw_stats['indices']
|
||||
res.update({
|
||||
'indices': indices_status['count'],
|
||||
'indices_shards': indices_status['shards']['total']
|
||||
})
|
||||
doc_info = indices_status['docs']
|
||||
res.update({
|
||||
'docs': doc_info['count'],
|
||||
'docs_deleted': doc_info['deleted']
|
||||
})
|
||||
store_info = indices_status['store']
|
||||
res.update({
|
||||
'store_size': convert_bytes(store_info['size_in_bytes']),
|
||||
'total_dataset_size': convert_bytes(store_info['total_data_set_size_in_bytes'])
|
||||
})
|
||||
mappings_info = indices_status['mappings']
|
||||
res.update({
|
||||
'mappings_fields': mappings_info['total_field_count'],
|
||||
'mappings_deduplicated_fields': mappings_info['total_deduplicated_field_count'],
|
||||
'mappings_deduplicated_size': convert_bytes(mappings_info['total_deduplicated_mapping_size_in_bytes'])
|
||||
})
|
||||
node_info = raw_stats['nodes']
|
||||
res.update({
|
||||
'nodes': node_info['count']['total'],
|
||||
'nodes_version': node_info['versions'],
|
||||
'os_mem': convert_bytes(node_info['os']['mem']['total_in_bytes']),
|
||||
'os_mem_used': convert_bytes(node_info['os']['mem']['used_in_bytes']),
|
||||
'os_mem_used_percent': node_info['os']['mem']['used_percent'],
|
||||
'jvm_versions': node_info['jvm']['versions'][0]['vm_version'],
|
||||
'jvm_heap_used': convert_bytes(node_info['jvm']['mem']['heap_used_in_bytes']),
|
||||
'jvm_heap_max': convert_bytes(node_info['jvm']['mem']['heap_max_in_bytes'])
|
||||
})
|
||||
return res
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"ESConnection.get_cluster_stats: {e}")
|
||||
return None
|
||||
|
||||
@ -14,365 +14,125 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import json
|
||||
import time
|
||||
import copy
|
||||
import infinity
|
||||
from infinity.common import ConflictType, InfinityException, SortType
|
||||
from infinity.index import IndexInfo, IndexType
|
||||
from infinity.connection_pool import ConnectionPool
|
||||
from infinity.common import InfinityException, SortType
|
||||
from infinity.errors import ErrorCode
|
||||
from common.decorator import singleton
|
||||
import pandas as pd
|
||||
from common.file_utils import get_project_base_directory
|
||||
from rag.nlp import is_english
|
||||
from common.constants import PAGERANK_FLD, TAG_FLD
|
||||
from common import settings
|
||||
from rag.utils.doc_store_conn import (
|
||||
DocStoreConnection,
|
||||
MatchExpr,
|
||||
MatchTextExpr,
|
||||
MatchDenseExpr,
|
||||
FusionExpr,
|
||||
OrderByExpr,
|
||||
)
|
||||
|
||||
logger = logging.getLogger("ragflow.infinity_conn")
|
||||
|
||||
|
||||
def field_keyword(field_name: str):
|
||||
# Treat "*_kwd" tag-like columns as keyword lists except knowledge_graph_kwd; source_id is also keyword-like.
|
||||
if field_name == "source_id" or (field_name.endswith("_kwd") and field_name not in ["knowledge_graph_kwd", "docnm_kwd", "important_kwd", "question_kwd"]):
|
||||
return True
|
||||
return False
|
||||
|
||||
def convert_select_fields(output_fields: list[str]) -> list[str]:
|
||||
for i, field in enumerate(output_fields):
|
||||
if field in ["docnm_kwd", "title_tks", "title_sm_tks"]:
|
||||
output_fields[i] = "docnm"
|
||||
elif field in ["important_kwd", "important_tks"]:
|
||||
output_fields[i] = "important_keywords"
|
||||
elif field in ["question_kwd", "question_tks"]:
|
||||
output_fields[i] = "questions"
|
||||
elif field in ["content_with_weight", "content_ltks", "content_sm_ltks"]:
|
||||
output_fields[i] = "content"
|
||||
elif field in ["authors_tks", "authors_sm_tks"]:
|
||||
output_fields[i] = "authors"
|
||||
return list(set(output_fields))
|
||||
|
||||
def convert_matching_field(field_weightstr: str) -> str:
|
||||
tokens = field_weightstr.split("^")
|
||||
field = tokens[0]
|
||||
if field == "docnm_kwd" or field == "title_tks":
|
||||
field = "docnm@ft_docnm_rag_coarse"
|
||||
elif field == "title_sm_tks":
|
||||
field = "docnm@ft_docnm_rag_fine"
|
||||
elif field == "important_kwd":
|
||||
field = "important_keywords@ft_important_keywords_rag_coarse"
|
||||
elif field == "important_tks":
|
||||
field = "important_keywords@ft_important_keywords_rag_fine"
|
||||
elif field == "question_kwd":
|
||||
field = "questions@ft_questions_rag_coarse"
|
||||
elif field == "question_tks":
|
||||
field = "questions@ft_questions_rag_fine"
|
||||
elif field == "content_with_weight" or field == "content_ltks":
|
||||
field = "content@ft_content_rag_coarse"
|
||||
elif field == "content_sm_ltks":
|
||||
field = "content@ft_content_rag_fine"
|
||||
elif field == "authors_tks":
|
||||
field = "authors@ft_authors_rag_coarse"
|
||||
elif field == "authors_sm_tks":
|
||||
field = "authors@ft_authors_rag_fine"
|
||||
tokens[0] = field
|
||||
return "^".join(tokens)
|
||||
|
||||
def list2str(lst: str|list, sep: str = " ") -> str:
|
||||
if isinstance(lst, str):
|
||||
return lst
|
||||
return sep.join(lst)
|
||||
|
||||
|
||||
def equivalent_condition_to_str(condition: dict, table_instance=None) -> str | None:
|
||||
assert "_id" not in condition
|
||||
clmns = {}
|
||||
if table_instance:
|
||||
for n, ty, de, _ in table_instance.show_columns().rows():
|
||||
clmns[n] = (ty, de)
|
||||
|
||||
def exists(cln):
|
||||
nonlocal clmns
|
||||
assert cln in clmns, f"'{cln}' should be in '{clmns}'."
|
||||
ty, de = clmns[cln]
|
||||
if ty.lower().find("cha"):
|
||||
if not de:
|
||||
de = ""
|
||||
return f" {cln}!='{de}' "
|
||||
return f"{cln}!={de}"
|
||||
|
||||
cond = list()
|
||||
for k, v in condition.items():
|
||||
if not isinstance(k, str) or not v:
|
||||
continue
|
||||
if field_keyword(k):
|
||||
if isinstance(v, list):
|
||||
inCond = list()
|
||||
for item in v:
|
||||
if isinstance(item, str):
|
||||
item = item.replace("'", "''")
|
||||
inCond.append(f"filter_fulltext('{convert_matching_field(k)}', '{item}')")
|
||||
if inCond:
|
||||
strInCond = " or ".join(inCond)
|
||||
strInCond = f"({strInCond})"
|
||||
cond.append(strInCond)
|
||||
else:
|
||||
cond.append(f"filter_fulltext('{convert_matching_field(k)}', '{v}')")
|
||||
elif isinstance(v, list):
|
||||
inCond = list()
|
||||
for item in v:
|
||||
if isinstance(item, str):
|
||||
item = item.replace("'", "''")
|
||||
inCond.append(f"'{item}'")
|
||||
else:
|
||||
inCond.append(str(item))
|
||||
if inCond:
|
||||
strInCond = ", ".join(inCond)
|
||||
strInCond = f"{k} IN ({strInCond})"
|
||||
cond.append(strInCond)
|
||||
elif k == "must_not":
|
||||
if isinstance(v, dict):
|
||||
for kk, vv in v.items():
|
||||
if kk == "exists":
|
||||
cond.append("NOT (%s)" % exists(vv))
|
||||
elif isinstance(v, str):
|
||||
cond.append(f"{k}='{v}'")
|
||||
elif k == "exists":
|
||||
cond.append(exists(v))
|
||||
else:
|
||||
cond.append(f"{k}={str(v)}")
|
||||
return " AND ".join(cond) if cond else "1=1"
|
||||
|
||||
|
||||
def concat_dataframes(df_list: list[pd.DataFrame], selectFields: list[str]) -> pd.DataFrame:
|
||||
df_list2 = [df for df in df_list if not df.empty]
|
||||
if df_list2:
|
||||
return pd.concat(df_list2, axis=0).reset_index(drop=True)
|
||||
|
||||
schema = []
|
||||
for field_name in selectFields:
|
||||
if field_name == "score()": # Workaround: fix schema is changed to score()
|
||||
schema.append("SCORE")
|
||||
elif field_name == "similarity()": # Workaround: fix schema is changed to similarity()
|
||||
schema.append("SIMILARITY")
|
||||
else:
|
||||
schema.append(field_name)
|
||||
return pd.DataFrame(columns=schema)
|
||||
from common.doc_store.doc_store_base import MatchExpr, MatchTextExpr, MatchDenseExpr, FusionExpr, OrderByExpr
|
||||
from common.doc_store.infinity_conn_base import InfinityConnectionBase
|
||||
|
||||
|
||||
@singleton
|
||||
class InfinityConnection(DocStoreConnection):
|
||||
def __init__(self):
|
||||
self.dbName = settings.INFINITY.get("db_name", "default_db")
|
||||
infinity_uri = settings.INFINITY["uri"]
|
||||
if ":" in infinity_uri:
|
||||
host, port = infinity_uri.split(":")
|
||||
infinity_uri = infinity.common.NetworkAddress(host, int(port))
|
||||
self.connPool = None
|
||||
logger.info(f"Use Infinity {infinity_uri} as the doc engine.")
|
||||
for _ in range(24):
|
||||
try:
|
||||
connPool = ConnectionPool(infinity_uri, max_size=4)
|
||||
inf_conn = connPool.get_conn()
|
||||
res = inf_conn.show_current_node()
|
||||
if res.error_code == ErrorCode.OK and res.server_status in ["started", "alive"]:
|
||||
self._migrate_db(inf_conn)
|
||||
self.connPool = connPool
|
||||
connPool.release_conn(inf_conn)
|
||||
break
|
||||
connPool.release_conn(inf_conn)
|
||||
logger.warn(f"Infinity status: {res.server_status}. Waiting Infinity {infinity_uri} to be healthy.")
|
||||
time.sleep(5)
|
||||
except Exception as e:
|
||||
logger.warning(f"{str(e)}. Waiting Infinity {infinity_uri} to be healthy.")
|
||||
time.sleep(5)
|
||||
if self.connPool is None:
|
||||
msg = f"Infinity {infinity_uri} is unhealthy in 120s."
|
||||
logger.error(msg)
|
||||
raise Exception(msg)
|
||||
logger.info(f"Infinity {infinity_uri} is healthy.")
|
||||
|
||||
def _migrate_db(self, inf_conn):
|
||||
inf_db = inf_conn.create_database(self.dbName, ConflictType.Ignore)
|
||||
fp_mapping = os.path.join(get_project_base_directory(), "conf", "infinity_mapping.json")
|
||||
if not os.path.exists(fp_mapping):
|
||||
raise Exception(f"Mapping file not found at {fp_mapping}")
|
||||
schema = json.load(open(fp_mapping))
|
||||
table_names = inf_db.list_tables().table_names
|
||||
for table_name in table_names:
|
||||
inf_table = inf_db.get_table(table_name)
|
||||
index_names = inf_table.list_indexes().index_names
|
||||
if "q_vec_idx" not in index_names:
|
||||
# Skip tables not created by me
|
||||
continue
|
||||
column_names = inf_table.show_columns()["name"]
|
||||
column_names = set(column_names)
|
||||
for field_name, field_info in schema.items():
|
||||
if field_name in column_names:
|
||||
continue
|
||||
res = inf_table.add_columns({field_name: field_info})
|
||||
assert res.error_code == infinity.ErrorCode.OK
|
||||
logger.info(f"INFINITY added following column to table {table_name}: {field_name} {field_info}")
|
||||
if field_info["type"] != "varchar" or "analyzer" not in field_info:
|
||||
continue
|
||||
analyzers = field_info["analyzer"]
|
||||
if isinstance(analyzers, str):
|
||||
analyzers = [analyzers]
|
||||
for analyzer in analyzers:
|
||||
inf_table.create_index(
|
||||
f"ft_{re.sub(r'[^a-zA-Z0-9]', '_', field_name)}_{re.sub(r'[^a-zA-Z0-9]', '_', analyzer)}",
|
||||
IndexInfo(field_name, IndexType.FullText, {"ANALYZER": analyzer}),
|
||||
ConflictType.Ignore,
|
||||
)
|
||||
class InfinityConnection(InfinityConnectionBase):
|
||||
|
||||
"""
|
||||
Database operations
|
||||
Dataframe and fields convert
|
||||
"""
|
||||
|
||||
def dbType(self) -> str:
|
||||
return "infinity"
|
||||
|
||||
def health(self) -> dict:
|
||||
"""
|
||||
Return the health status of the database.
|
||||
"""
|
||||
inf_conn = self.connPool.get_conn()
|
||||
res = inf_conn.show_current_node()
|
||||
self.connPool.release_conn(inf_conn)
|
||||
res2 = {
|
||||
"type": "infinity",
|
||||
"status": "green" if res.error_code == 0 and res.server_status in ["started", "alive"] else "red",
|
||||
"error": res.error_msg,
|
||||
}
|
||||
return res2
|
||||
|
||||
"""
|
||||
Table operations
|
||||
"""
|
||||
|
||||
def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
|
||||
table_name = f"{indexName}_{knowledgebaseId}"
|
||||
inf_conn = self.connPool.get_conn()
|
||||
inf_db = inf_conn.create_database(self.dbName, ConflictType.Ignore)
|
||||
|
||||
fp_mapping = os.path.join(get_project_base_directory(), "conf", "infinity_mapping.json")
|
||||
if not os.path.exists(fp_mapping):
|
||||
raise Exception(f"Mapping file not found at {fp_mapping}")
|
||||
schema = json.load(open(fp_mapping))
|
||||
vector_name = f"q_{vectorSize}_vec"
|
||||
schema[vector_name] = {"type": f"vector,{vectorSize},float"}
|
||||
inf_table = inf_db.create_table(
|
||||
table_name,
|
||||
schema,
|
||||
ConflictType.Ignore,
|
||||
)
|
||||
inf_table.create_index(
|
||||
"q_vec_idx",
|
||||
IndexInfo(
|
||||
vector_name,
|
||||
IndexType.Hnsw,
|
||||
{
|
||||
"M": "16",
|
||||
"ef_construction": "50",
|
||||
"metric": "cosine",
|
||||
"encode": "lvq",
|
||||
},
|
||||
),
|
||||
ConflictType.Ignore,
|
||||
)
|
||||
for field_name, field_info in schema.items():
|
||||
if field_info["type"] != "varchar" or "analyzer" not in field_info:
|
||||
continue
|
||||
analyzers = field_info["analyzer"]
|
||||
if isinstance(analyzers, str):
|
||||
analyzers = [analyzers]
|
||||
for analyzer in analyzers:
|
||||
inf_table.create_index(
|
||||
f"ft_{re.sub(r'[^a-zA-Z0-9]', '_', field_name)}_{re.sub(r'[^a-zA-Z0-9]', '_', analyzer)}",
|
||||
IndexInfo(field_name, IndexType.FullText, {"ANALYZER": analyzer}),
|
||||
ConflictType.Ignore,
|
||||
)
|
||||
self.connPool.release_conn(inf_conn)
|
||||
logger.info(f"INFINITY created table {table_name}, vector size {vectorSize}")
|
||||
|
||||
def deleteIdx(self, indexName: str, knowledgebaseId: str):
|
||||
table_name = f"{indexName}_{knowledgebaseId}"
|
||||
inf_conn = self.connPool.get_conn()
|
||||
db_instance = inf_conn.get_database(self.dbName)
|
||||
db_instance.drop_table(table_name, ConflictType.Ignore)
|
||||
self.connPool.release_conn(inf_conn)
|
||||
logger.info(f"INFINITY dropped table {table_name}")
|
||||
|
||||
def indexExist(self, indexName: str, knowledgebaseId: str) -> bool:
|
||||
table_name = f"{indexName}_{knowledgebaseId}"
|
||||
try:
|
||||
inf_conn = self.connPool.get_conn()
|
||||
db_instance = inf_conn.get_database(self.dbName)
|
||||
_ = db_instance.get_table(table_name)
|
||||
self.connPool.release_conn(inf_conn)
|
||||
@staticmethod
|
||||
def field_keyword(field_name: str):
|
||||
# Treat "*_kwd" tag-like columns as keyword lists except knowledge_graph_kwd; source_id is also keyword-like.
|
||||
if field_name == "source_id" or (
|
||||
field_name.endswith("_kwd") and field_name not in ["knowledge_graph_kwd", "docnm_kwd", "important_kwd",
|
||||
"question_kwd"]):
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.warning(f"INFINITY indexExist {str(e)}")
|
||||
return False
|
||||
|
||||
def convert_select_fields(self, output_fields: list[str]) -> list[str]:
|
||||
for i, field in enumerate(output_fields):
|
||||
if field in ["docnm_kwd", "title_tks", "title_sm_tks"]:
|
||||
output_fields[i] = "docnm"
|
||||
elif field in ["important_kwd", "important_tks"]:
|
||||
output_fields[i] = "important_keywords"
|
||||
elif field in ["question_kwd", "question_tks"]:
|
||||
output_fields[i] = "questions"
|
||||
elif field in ["content_with_weight", "content_ltks", "content_sm_ltks"]:
|
||||
output_fields[i] = "content"
|
||||
elif field in ["authors_tks", "authors_sm_tks"]:
|
||||
output_fields[i] = "authors"
|
||||
return list(set(output_fields))
|
||||
|
||||
@staticmethod
|
||||
def convert_matching_field(field_weight_str: str) -> str:
|
||||
tokens = field_weight_str.split("^")
|
||||
field = tokens[0]
|
||||
if field == "docnm_kwd" or field == "title_tks":
|
||||
field = "docnm@ft_docnm_rag_coarse"
|
||||
elif field == "title_sm_tks":
|
||||
field = "docnm@ft_docnm_rag_fine"
|
||||
elif field == "important_kwd":
|
||||
field = "important_keywords@ft_important_keywords_rag_coarse"
|
||||
elif field == "important_tks":
|
||||
field = "important_keywords@ft_important_keywords_rag_fine"
|
||||
elif field == "question_kwd":
|
||||
field = "questions@ft_questions_rag_coarse"
|
||||
elif field == "question_tks":
|
||||
field = "questions@ft_questions_rag_fine"
|
||||
elif field == "content_with_weight" or field == "content_ltks":
|
||||
field = "content@ft_content_rag_coarse"
|
||||
elif field == "content_sm_ltks":
|
||||
field = "content@ft_content_rag_fine"
|
||||
elif field == "authors_tks":
|
||||
field = "authors@ft_authors_rag_coarse"
|
||||
elif field == "authors_sm_tks":
|
||||
field = "authors@ft_authors_rag_fine"
|
||||
tokens[0] = field
|
||||
return "^".join(tokens)
|
||||
|
||||
|
||||
"""
|
||||
CRUD operations
|
||||
"""
|
||||
|
||||
def search(
|
||||
self,
|
||||
selectFields: list[str],
|
||||
highlightFields: list[str],
|
||||
select_fields: list[str],
|
||||
highlight_fields: list[str],
|
||||
condition: dict,
|
||||
matchExprs: list[MatchExpr],
|
||||
orderBy: OrderByExpr,
|
||||
match_expressions: list[MatchExpr],
|
||||
order_by: OrderByExpr,
|
||||
offset: int,
|
||||
limit: int,
|
||||
indexNames: str | list[str],
|
||||
knowledgebaseIds: list[str],
|
||||
aggFields: list[str] = [],
|
||||
index_names: str | list[str],
|
||||
knowledgebase_ids: list[str],
|
||||
agg_fields: list[str] | None = None,
|
||||
rank_feature: dict | None = None,
|
||||
) -> tuple[pd.DataFrame, int]:
|
||||
"""
|
||||
BUG: Infinity returns empty for a highlight field if the query string doesn't use that field.
|
||||
"""
|
||||
if isinstance(indexNames, str):
|
||||
indexNames = indexNames.split(",")
|
||||
assert isinstance(indexNames, list) and len(indexNames) > 0
|
||||
if isinstance(index_names, str):
|
||||
index_names = index_names.split(",")
|
||||
assert isinstance(index_names, list) and len(index_names) > 0
|
||||
inf_conn = self.connPool.get_conn()
|
||||
db_instance = inf_conn.get_database(self.dbName)
|
||||
df_list = list()
|
||||
table_list = list()
|
||||
output = selectFields.copy()
|
||||
output = convert_select_fields(output)
|
||||
for essential_field in ["id"] + aggFields:
|
||||
output = select_fields.copy()
|
||||
output = self.convert_select_fields(output)
|
||||
if agg_fields is None:
|
||||
agg_fields = []
|
||||
for essential_field in ["id"] + agg_fields:
|
||||
if essential_field not in output:
|
||||
output.append(essential_field)
|
||||
score_func = ""
|
||||
score_column = ""
|
||||
for matchExpr in matchExprs:
|
||||
for matchExpr in match_expressions:
|
||||
if isinstance(matchExpr, MatchTextExpr):
|
||||
score_func = "score()"
|
||||
score_column = "SCORE"
|
||||
break
|
||||
if not score_func:
|
||||
for matchExpr in matchExprs:
|
||||
for matchExpr in match_expressions:
|
||||
if isinstance(matchExpr, MatchDenseExpr):
|
||||
score_func = "similarity()"
|
||||
score_column = "SIMILARITY"
|
||||
break
|
||||
if matchExprs:
|
||||
if match_expressions:
|
||||
if score_func not in output:
|
||||
output.append(score_func)
|
||||
if PAGERANK_FLD not in output:
|
||||
@ -387,11 +147,11 @@ class InfinityConnection(DocStoreConnection):
|
||||
filter_fulltext = ""
|
||||
if condition:
|
||||
table_found = False
|
||||
for indexName in indexNames:
|
||||
for kb_id in knowledgebaseIds:
|
||||
for indexName in index_names:
|
||||
for kb_id in knowledgebase_ids:
|
||||
table_name = f"{indexName}_{kb_id}"
|
||||
try:
|
||||
filter_cond = equivalent_condition_to_str(condition, db_instance.get_table(table_name))
|
||||
filter_cond = self.equivalent_condition_to_str(condition, db_instance.get_table(table_name))
|
||||
table_found = True
|
||||
break
|
||||
except Exception:
|
||||
@ -399,14 +159,14 @@ class InfinityConnection(DocStoreConnection):
|
||||
if table_found:
|
||||
break
|
||||
if not table_found:
|
||||
logger.error(f"No valid tables found for indexNames {indexNames} and knowledgebaseIds {knowledgebaseIds}")
|
||||
self.logger.error(f"No valid tables found for indexNames {index_names} and knowledgebaseIds {knowledgebase_ids}")
|
||||
return pd.DataFrame(), 0
|
||||
|
||||
for matchExpr in matchExprs:
|
||||
for matchExpr in match_expressions:
|
||||
if isinstance(matchExpr, MatchTextExpr):
|
||||
if filter_cond and "filter" not in matchExpr.extra_options:
|
||||
matchExpr.extra_options.update({"filter": filter_cond})
|
||||
matchExpr.fields = [convert_matching_field(field) for field in matchExpr.fields]
|
||||
matchExpr.fields = [self.convert_matching_field(field) for field in matchExpr.fields]
|
||||
fields = ",".join(matchExpr.fields)
|
||||
filter_fulltext = f"filter_fulltext('{fields}', '{matchExpr.matching_text}')"
|
||||
if filter_cond:
|
||||
@ -430,7 +190,7 @@ class InfinityConnection(DocStoreConnection):
|
||||
for k, v in matchExpr.extra_options.items():
|
||||
if not isinstance(v, str):
|
||||
matchExpr.extra_options[k] = str(v)
|
||||
logger.debug(f"INFINITY search MatchTextExpr: {json.dumps(matchExpr.__dict__)}")
|
||||
self.logger.debug(f"INFINITY search MatchTextExpr: {json.dumps(matchExpr.__dict__)}")
|
||||
elif isinstance(matchExpr, MatchDenseExpr):
|
||||
if filter_fulltext and "filter" not in matchExpr.extra_options:
|
||||
matchExpr.extra_options.update({"filter": filter_fulltext})
|
||||
@ -441,16 +201,16 @@ class InfinityConnection(DocStoreConnection):
|
||||
if similarity:
|
||||
matchExpr.extra_options["threshold"] = similarity
|
||||
del matchExpr.extra_options["similarity"]
|
||||
logger.debug(f"INFINITY search MatchDenseExpr: {json.dumps(matchExpr.__dict__)}")
|
||||
self.logger.debug(f"INFINITY search MatchDenseExpr: {json.dumps(matchExpr.__dict__)}")
|
||||
elif isinstance(matchExpr, FusionExpr):
|
||||
if matchExpr.method == "weighted_sum":
|
||||
# The default is "minmax" which gives a zero score for the last doc.
|
||||
matchExpr.fusion_params["normalize"] = "atan"
|
||||
logger.debug(f"INFINITY search FusionExpr: {json.dumps(matchExpr.__dict__)}")
|
||||
self.logger.debug(f"INFINITY search FusionExpr: {json.dumps(matchExpr.__dict__)}")
|
||||
|
||||
order_by_expr_list = list()
|
||||
if orderBy.fields:
|
||||
for order_field in orderBy.fields:
|
||||
if order_by.fields:
|
||||
for order_field in order_by.fields:
|
||||
if order_field[1] == 0:
|
||||
order_by_expr_list.append((order_field[0], SortType.Asc))
|
||||
else:
|
||||
@ -458,8 +218,8 @@ class InfinityConnection(DocStoreConnection):
|
||||
|
||||
total_hits_count = 0
|
||||
# Scatter search tables and gather the results
|
||||
for indexName in indexNames:
|
||||
for knowledgebaseId in knowledgebaseIds:
|
||||
for indexName in index_names:
|
||||
for knowledgebaseId in knowledgebase_ids:
|
||||
table_name = f"{indexName}_{knowledgebaseId}"
|
||||
try:
|
||||
table_instance = db_instance.get_table(table_name)
|
||||
@ -467,8 +227,8 @@ class InfinityConnection(DocStoreConnection):
|
||||
continue
|
||||
table_list.append(table_name)
|
||||
builder = table_instance.output(output)
|
||||
if len(matchExprs) > 0:
|
||||
for matchExpr in matchExprs:
|
||||
if len(match_expressions) > 0:
|
||||
for matchExpr in match_expressions:
|
||||
if isinstance(matchExpr, MatchTextExpr):
|
||||
fields = ",".join(matchExpr.fields)
|
||||
builder = builder.match_text(
|
||||
@ -491,53 +251,52 @@ class InfinityConnection(DocStoreConnection):
|
||||
else:
|
||||
if filter_cond and len(filter_cond) > 0:
|
||||
builder.filter(filter_cond)
|
||||
if orderBy.fields:
|
||||
if order_by.fields:
|
||||
builder.sort(order_by_expr_list)
|
||||
builder.offset(offset).limit(limit)
|
||||
kb_res, extra_result = builder.option({"total_hits_count": True}).to_df()
|
||||
if extra_result:
|
||||
total_hits_count += int(extra_result["total_hits_count"])
|
||||
logger.debug(f"INFINITY search table: {str(table_name)}, result: {str(kb_res)}")
|
||||
self.logger.debug(f"INFINITY search table: {str(table_name)}, result: {str(kb_res)}")
|
||||
df_list.append(kb_res)
|
||||
self.connPool.release_conn(inf_conn)
|
||||
res = concat_dataframes(df_list, output)
|
||||
if matchExprs:
|
||||
res = self.concat_dataframes(df_list, output)
|
||||
if match_expressions:
|
||||
res["_score"] = res[score_column] + res[PAGERANK_FLD]
|
||||
res = res.sort_values(by="_score", ascending=False).reset_index(drop=True)
|
||||
res = res.head(limit)
|
||||
logger.debug(f"INFINITY search final result: {str(res)}")
|
||||
self.logger.debug(f"INFINITY search final result: {str(res)}")
|
||||
return res, total_hits_count
|
||||
|
||||
def get(self, chunkId: str, indexName: str, knowledgebaseIds: list[str]) -> dict | None:
|
||||
def get(self, chunk_id: str, index_name: str, knowledgebase_ids: list[str]) -> dict | None:
|
||||
inf_conn = self.connPool.get_conn()
|
||||
db_instance = inf_conn.get_database(self.dbName)
|
||||
df_list = list()
|
||||
assert isinstance(knowledgebaseIds, list)
|
||||
assert isinstance(knowledgebase_ids, list)
|
||||
table_list = list()
|
||||
for knowledgebaseId in knowledgebaseIds:
|
||||
table_name = f"{indexName}_{knowledgebaseId}"
|
||||
for knowledgebaseId in knowledgebase_ids:
|
||||
table_name = f"{index_name}_{knowledgebaseId}"
|
||||
table_list.append(table_name)
|
||||
table_instance = None
|
||||
try:
|
||||
table_instance = db_instance.get_table(table_name)
|
||||
except Exception:
|
||||
logger.warning(f"Table not found: {table_name}, this dataset isn't created in Infinity. Maybe it is created in other document engine.")
|
||||
self.logger.warning(f"Table not found: {table_name}, this dataset isn't created in Infinity. Maybe it is created in other document engine.")
|
||||
continue
|
||||
kb_res, _ = table_instance.output(["*"]).filter(f"id = '{chunkId}'").to_df()
|
||||
logger.debug(f"INFINITY get table: {str(table_list)}, result: {str(kb_res)}")
|
||||
kb_res, _ = table_instance.output(["*"]).filter(f"id = '{chunk_id}'").to_df()
|
||||
self.logger.debug(f"INFINITY get table: {str(table_list)}, result: {str(kb_res)}")
|
||||
df_list.append(kb_res)
|
||||
self.connPool.release_conn(inf_conn)
|
||||
res = concat_dataframes(df_list, ["id"])
|
||||
res = self.concat_dataframes(df_list, ["id"])
|
||||
fields = set(res.columns.tolist())
|
||||
for field in ["docnm_kwd", "title_tks", "title_sm_tks", "important_kwd", "important_tks", "question_kwd", "question_tks","content_with_weight", "content_ltks", "content_sm_ltks", "authors_tks", "authors_sm_tks"]:
|
||||
fields.add(field)
|
||||
res_fields = self.get_fields(res, list(fields))
|
||||
return res_fields.get(chunkId, None)
|
||||
return res_fields.get(chunk_id, None)
|
||||
|
||||
def insert(self, documents: list[dict], indexName: str, knowledgebaseId: str = None) -> list[str]:
|
||||
def insert(self, documents: list[dict], index_name: str, knowledgebase_id: str = None) -> list[str]:
|
||||
inf_conn = self.connPool.get_conn()
|
||||
db_instance = inf_conn.get_database(self.dbName)
|
||||
table_name = f"{indexName}_{knowledgebaseId}"
|
||||
table_name = f"{index_name}_{knowledgebase_id}"
|
||||
try:
|
||||
table_instance = db_instance.get_table(table_name)
|
||||
except InfinityException as e:
|
||||
@ -553,7 +312,7 @@ class InfinityConnection(DocStoreConnection):
|
||||
break
|
||||
if vector_size == 0:
|
||||
raise ValueError("Cannot infer vector size from documents")
|
||||
self.createIdx(indexName, knowledgebaseId, vector_size)
|
||||
self.create_idx(index_name, knowledgebase_id, vector_size)
|
||||
table_instance = db_instance.get_table(table_name)
|
||||
|
||||
# embedding fields can't have a default value....
|
||||
@ -574,12 +333,12 @@ class InfinityConnection(DocStoreConnection):
|
||||
d["docnm"] = v
|
||||
elif k == "title_kwd":
|
||||
if not d.get("docnm_kwd"):
|
||||
d["docnm"] = list2str(v)
|
||||
d["docnm"] = self.list2str(v)
|
||||
elif k == "title_sm_tks":
|
||||
if not d.get("docnm_kwd"):
|
||||
d["docnm"] = list2str(v)
|
||||
d["docnm"] = self.list2str(v)
|
||||
elif k == "important_kwd":
|
||||
d["important_keywords"] = list2str(v)
|
||||
d["important_keywords"] = self.list2str(v)
|
||||
elif k == "important_tks":
|
||||
if not d.get("important_kwd"):
|
||||
d["important_keywords"] = v
|
||||
@ -597,11 +356,11 @@ class InfinityConnection(DocStoreConnection):
|
||||
if not d.get("authors_tks"):
|
||||
d["authors"] = v
|
||||
elif k == "question_kwd":
|
||||
d["questions"] = list2str(v, "\n")
|
||||
d["questions"] = self.list2str(v, "\n")
|
||||
elif k == "question_tks":
|
||||
if not d.get("question_kwd"):
|
||||
d["questions"] = list2str(v)
|
||||
elif field_keyword(k):
|
||||
d["questions"] = self.list2str(v)
|
||||
elif self.field_keyword(k):
|
||||
if isinstance(v, list):
|
||||
d[k] = "###".join(v)
|
||||
else:
|
||||
@ -637,15 +396,15 @@ class InfinityConnection(DocStoreConnection):
|
||||
# logger.info(f"InfinityConnection.insert {json.dumps(documents)}")
|
||||
table_instance.insert(docs)
|
||||
self.connPool.release_conn(inf_conn)
|
||||
logger.debug(f"INFINITY inserted into {table_name} {str_ids}.")
|
||||
self.logger.debug(f"INFINITY inserted into {table_name} {str_ids}.")
|
||||
return []
|
||||
|
||||
def update(self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str) -> bool:
|
||||
def update(self, condition: dict, new_value: dict, index_name: str, knowledgebase_id: str) -> bool:
|
||||
# if 'position_int' in newValue:
|
||||
# logger.info(f"update position_int: {newValue['position_int']}")
|
||||
inf_conn = self.connPool.get_conn()
|
||||
db_instance = inf_conn.get_database(self.dbName)
|
||||
table_name = f"{indexName}_{knowledgebaseId}"
|
||||
table_name = f"{index_name}_{knowledgebase_id}"
|
||||
table_instance = db_instance.get_table(table_name)
|
||||
# if "exists" in condition:
|
||||
# del condition["exists"]
|
||||
@ -654,57 +413,57 @@ class InfinityConnection(DocStoreConnection):
|
||||
if table_instance:
|
||||
for n, ty, de, _ in table_instance.show_columns().rows():
|
||||
clmns[n] = (ty, de)
|
||||
filter = equivalent_condition_to_str(condition, table_instance)
|
||||
filter = self.equivalent_condition_to_str(condition, table_instance)
|
||||
removeValue = {}
|
||||
for k, v in list(newValue.items()):
|
||||
for k, v in list(new_value.items()):
|
||||
if k == "docnm_kwd":
|
||||
newValue["docnm"] = list2str(v)
|
||||
new_value["docnm"] = self.list2str(v)
|
||||
elif k == "title_kwd":
|
||||
if not newValue.get("docnm_kwd"):
|
||||
newValue["docnm"] = list2str(v)
|
||||
if not new_value.get("docnm_kwd"):
|
||||
new_value["docnm"] = self.list2str(v)
|
||||
elif k == "title_sm_tks":
|
||||
if not newValue.get("docnm_kwd"):
|
||||
newValue["docnm"] = v
|
||||
if not new_value.get("docnm_kwd"):
|
||||
new_value["docnm"] = v
|
||||
elif k == "important_kwd":
|
||||
newValue["important_keywords"] = list2str(v)
|
||||
new_value["important_keywords"] = self.list2str(v)
|
||||
elif k == "important_tks":
|
||||
if not newValue.get("important_kwd"):
|
||||
newValue["important_keywords"] = v
|
||||
if not new_value.get("important_kwd"):
|
||||
new_value["important_keywords"] = v
|
||||
elif k == "content_with_weight":
|
||||
newValue["content"] = v
|
||||
new_value["content"] = v
|
||||
elif k == "content_ltks":
|
||||
if not newValue.get("content_with_weight"):
|
||||
newValue["content"] = v
|
||||
if not new_value.get("content_with_weight"):
|
||||
new_value["content"] = v
|
||||
elif k == "content_sm_ltks":
|
||||
if not newValue.get("content_with_weight"):
|
||||
newValue["content"] = v
|
||||
if not new_value.get("content_with_weight"):
|
||||
new_value["content"] = v
|
||||
elif k == "authors_tks":
|
||||
newValue["authors"] = v
|
||||
new_value["authors"] = v
|
||||
elif k == "authors_sm_tks":
|
||||
if not newValue.get("authors_tks"):
|
||||
newValue["authors"] = v
|
||||
if not new_value.get("authors_tks"):
|
||||
new_value["authors"] = v
|
||||
elif k == "question_kwd":
|
||||
newValue["questions"] = "\n".join(v)
|
||||
new_value["questions"] = "\n".join(v)
|
||||
elif k == "question_tks":
|
||||
if not newValue.get("question_kwd"):
|
||||
newValue["questions"] = list2str(v)
|
||||
elif field_keyword(k):
|
||||
if not new_value.get("question_kwd"):
|
||||
new_value["questions"] = self.list2str(v)
|
||||
elif self.field_keyword(k):
|
||||
if isinstance(v, list):
|
||||
newValue[k] = "###".join(v)
|
||||
new_value[k] = "###".join(v)
|
||||
else:
|
||||
newValue[k] = v
|
||||
new_value[k] = v
|
||||
elif re.search(r"_feas$", k):
|
||||
newValue[k] = json.dumps(v)
|
||||
new_value[k] = json.dumps(v)
|
||||
elif k == "kb_id":
|
||||
if isinstance(newValue[k], list):
|
||||
newValue[k] = newValue[k][0] # since d[k] is a list, but we need a str
|
||||
if isinstance(new_value[k], list):
|
||||
new_value[k] = new_value[k][0] # since d[k] is a list, but we need a str
|
||||
elif k == "position_int":
|
||||
assert isinstance(v, list)
|
||||
arr = [num for row in v for num in row]
|
||||
newValue[k] = "_".join(f"{num:08x}" for num in arr)
|
||||
new_value[k] = "_".join(f"{num:08x}" for num in arr)
|
||||
elif k in ["page_num_int", "top_int"]:
|
||||
assert isinstance(v, list)
|
||||
newValue[k] = "_".join(f"{num:08x}" for num in v)
|
||||
new_value[k] = "_".join(f"{num:08x}" for num in v)
|
||||
elif k == "remove":
|
||||
if isinstance(v, str):
|
||||
assert v in clmns, f"'{v}' should be in '{clmns}'."
|
||||
@ -712,22 +471,22 @@ class InfinityConnection(DocStoreConnection):
|
||||
if ty.lower().find("cha"):
|
||||
if not de:
|
||||
de = ""
|
||||
newValue[v] = de
|
||||
new_value[v] = de
|
||||
else:
|
||||
for kk, vv in v.items():
|
||||
removeValue[kk] = vv
|
||||
del newValue[k]
|
||||
del new_value[k]
|
||||
else:
|
||||
newValue[k] = v
|
||||
new_value[k] = v
|
||||
for k in ["docnm_kwd", "title_tks", "title_sm_tks", "important_kwd", "important_tks", "content_with_weight", "content_ltks", "content_sm_ltks", "authors_tks", "authors_sm_tks", "question_kwd", "question_tks"]:
|
||||
if k in newValue:
|
||||
del newValue[k]
|
||||
if k in new_value:
|
||||
del new_value[k]
|
||||
|
||||
remove_opt = {} # "[k,new_value]": [id_to_update, ...]
|
||||
if removeValue:
|
||||
col_to_remove = list(removeValue.keys())
|
||||
row_to_opt = table_instance.output(col_to_remove + ["id"]).filter(filter).to_df()
|
||||
logger.debug(f"INFINITY search table {str(table_name)}, filter {filter}, result: {str(row_to_opt[0])}")
|
||||
self.logger.debug(f"INFINITY search table {str(table_name)}, filter {filter}, result: {str(row_to_opt[0])}")
|
||||
row_to_opt = self.get_fields(row_to_opt, col_to_remove)
|
||||
for id, old_v in row_to_opt.items():
|
||||
for k, remove_v in removeValue.items():
|
||||
@ -740,78 +499,53 @@ class InfinityConnection(DocStoreConnection):
|
||||
else:
|
||||
remove_opt[kv_key].append(id)
|
||||
|
||||
logger.debug(f"INFINITY update table {table_name}, filter {filter}, newValue {newValue}.")
|
||||
self.logger.debug(f"INFINITY update table {table_name}, filter {filter}, newValue {new_value}.")
|
||||
for update_kv, ids in remove_opt.items():
|
||||
k, v = json.loads(update_kv)
|
||||
table_instance.update(filter + " AND id in ({0})".format(",".join([f"'{id}'" for id in ids])), {k: "###".join(v)})
|
||||
|
||||
table_instance.update(filter, newValue)
|
||||
table_instance.update(filter, new_value)
|
||||
self.connPool.release_conn(inf_conn)
|
||||
return True
|
||||
|
||||
def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int:
|
||||
inf_conn = self.connPool.get_conn()
|
||||
db_instance = inf_conn.get_database(self.dbName)
|
||||
table_name = f"{indexName}_{knowledgebaseId}"
|
||||
try:
|
||||
table_instance = db_instance.get_table(table_name)
|
||||
except Exception:
|
||||
logger.warning(f"Skipped deleting from table {table_name} since the table doesn't exist.")
|
||||
return 0
|
||||
filter = equivalent_condition_to_str(condition, table_instance)
|
||||
logger.debug(f"INFINITY delete table {table_name}, filter {filter}.")
|
||||
res = table_instance.delete(filter)
|
||||
self.connPool.release_conn(inf_conn)
|
||||
return res.deleted_rows
|
||||
|
||||
"""
|
||||
Helper functions for search result
|
||||
"""
|
||||
|
||||
def get_total(self, res: tuple[pd.DataFrame, int] | pd.DataFrame) -> int:
|
||||
if isinstance(res, tuple):
|
||||
return res[1]
|
||||
return len(res)
|
||||
|
||||
def get_chunk_ids(self, res: tuple[pd.DataFrame, int] | pd.DataFrame) -> list[str]:
|
||||
if isinstance(res, tuple):
|
||||
res = res[0]
|
||||
return list(res["id"])
|
||||
|
||||
def get_fields(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, fields: list[str]) -> dict[str, dict]:
|
||||
if isinstance(res, tuple):
|
||||
res = res[0]
|
||||
if not fields:
|
||||
return {}
|
||||
fieldsAll = fields.copy()
|
||||
fieldsAll.append("id")
|
||||
fieldsAll = set(fieldsAll)
|
||||
fields_all = fields.copy()
|
||||
fields_all.append("id")
|
||||
fields_all = set(fields_all)
|
||||
if "docnm" in res.columns:
|
||||
for field in ["docnm_kwd", "title_tks", "title_sm_tks"]:
|
||||
if field in fieldsAll:
|
||||
if field in fields_all:
|
||||
res[field] = res["docnm"]
|
||||
if "important_keywords" in res.columns:
|
||||
if "important_kwd" in fieldsAll:
|
||||
if "important_kwd" in fields_all:
|
||||
res["important_kwd"] = res["important_keywords"].apply(lambda v: v.split())
|
||||
if "important_tks" in fieldsAll:
|
||||
if "important_tks" in fields_all:
|
||||
res["important_tks"] = res["important_keywords"]
|
||||
if "questions" in res.columns:
|
||||
if "question_kwd" in fieldsAll:
|
||||
if "question_kwd" in fields_all:
|
||||
res["question_kwd"] = res["questions"].apply(lambda v: v.splitlines())
|
||||
if "question_tks" in fieldsAll:
|
||||
if "question_tks" in fields_all:
|
||||
res["question_tks"] = res["questions"]
|
||||
if "content" in res.columns:
|
||||
for field in ["content_with_weight", "content_ltks", "content_sm_ltks"]:
|
||||
if field in fieldsAll:
|
||||
if field in fields_all:
|
||||
res[field] = res["content"]
|
||||
if "authors" in res.columns:
|
||||
for field in ["authors_tks", "authors_sm_tks"]:
|
||||
if field in fieldsAll:
|
||||
if field in fields_all:
|
||||
res[field] = res["authors"]
|
||||
|
||||
column_map = {col.lower(): col for col in res.columns}
|
||||
matched_columns = {column_map[col.lower()]: col for col in fieldsAll if col.lower() in column_map}
|
||||
none_columns = [col for col in fieldsAll if col.lower() not in column_map]
|
||||
matched_columns = {column_map[col.lower()]: col for col in fields_all if col.lower() in column_map}
|
||||
none_columns = [col for col in fields_all if col.lower() not in column_map]
|
||||
|
||||
res2 = res[matched_columns.keys()]
|
||||
res2 = res2.rename(columns=matched_columns)
|
||||
@ -819,7 +553,7 @@ class InfinityConnection(DocStoreConnection):
|
||||
|
||||
for column in list(res2.columns):
|
||||
k = column.lower()
|
||||
if field_keyword(k):
|
||||
if self.field_keyword(k):
|
||||
res2[column] = res2[column].apply(lambda v: [kwd for kwd in v.split("###") if kwd])
|
||||
elif re.search(r"_feas$", k):
|
||||
res2[column] = res2[column].apply(lambda v: json.loads(v) if v else {})
|
||||
@ -844,95 +578,3 @@ class InfinityConnection(DocStoreConnection):
|
||||
res2[column] = None
|
||||
|
||||
return res2.set_index("id").to_dict(orient="index")
|
||||
|
||||
def get_highlight(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, keywords: list[str], fieldnm: str):
|
||||
if isinstance(res, tuple):
|
||||
res = res[0]
|
||||
ans = {}
|
||||
num_rows = len(res)
|
||||
column_id = res["id"]
|
||||
if fieldnm not in res:
|
||||
return {}
|
||||
for i in range(num_rows):
|
||||
id = column_id[i]
|
||||
txt = res[fieldnm][i]
|
||||
if re.search(r"<em>[^<>]+</em>", txt, flags=re.IGNORECASE | re.MULTILINE):
|
||||
ans[id] = txt
|
||||
continue
|
||||
txt = re.sub(r"[\r\n]", " ", txt, flags=re.IGNORECASE | re.MULTILINE)
|
||||
txts = []
|
||||
for t in re.split(r"[.?!;\n]", txt):
|
||||
if is_english([t]):
|
||||
for w in keywords:
|
||||
t = re.sub(
|
||||
r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])" % re.escape(w),
|
||||
r"\1<em>\2</em>\3",
|
||||
t,
|
||||
flags=re.IGNORECASE | re.MULTILINE,
|
||||
)
|
||||
else:
|
||||
for w in sorted(keywords, key=len, reverse=True):
|
||||
t = re.sub(
|
||||
re.escape(w),
|
||||
f"<em>{w}</em>",
|
||||
t,
|
||||
flags=re.IGNORECASE | re.MULTILINE,
|
||||
)
|
||||
if not re.search(r"<em>[^<>]+</em>", t, flags=re.IGNORECASE | re.MULTILINE):
|
||||
continue
|
||||
txts.append(t)
|
||||
if txts:
|
||||
ans[id] = "...".join(txts)
|
||||
else:
|
||||
ans[id] = txt
|
||||
return ans
|
||||
|
||||
def get_aggregation(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, fieldnm: str):
|
||||
"""
|
||||
Manual aggregation for tag fields since Infinity doesn't provide native aggregation
|
||||
"""
|
||||
from collections import Counter
|
||||
|
||||
# Extract DataFrame from result
|
||||
if isinstance(res, tuple):
|
||||
df, _ = res
|
||||
else:
|
||||
df = res
|
||||
|
||||
if df.empty or fieldnm not in df.columns:
|
||||
return []
|
||||
|
||||
# Aggregate tag counts
|
||||
tag_counter = Counter()
|
||||
|
||||
for value in df[fieldnm]:
|
||||
if pd.isna(value) or not value:
|
||||
continue
|
||||
|
||||
# Handle different tag formats
|
||||
if isinstance(value, str):
|
||||
# Split by ### for tag_kwd field or comma for other formats
|
||||
if fieldnm == "tag_kwd" and "###" in value:
|
||||
tags = [tag.strip() for tag in value.split("###") if tag.strip()]
|
||||
else:
|
||||
# Try comma separation as fallback
|
||||
tags = [tag.strip() for tag in value.split(",") if tag.strip()]
|
||||
|
||||
for tag in tags:
|
||||
if tag: # Only count non-empty tags
|
||||
tag_counter[tag] += 1
|
||||
elif isinstance(value, list):
|
||||
# Handle list format
|
||||
for tag in value:
|
||||
if tag and isinstance(tag, str):
|
||||
tag_counter[tag.strip()] += 1
|
||||
|
||||
# Return as list of [tag, count] pairs, sorted by count descending
|
||||
return [[tag, count] for tag, count in tag_counter.most_common()]
|
||||
|
||||
"""
|
||||
SQL
|
||||
"""
|
||||
|
||||
def sql(sql: str, fetch_size: int, format: str):
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
@ -37,9 +37,8 @@ from common import settings
|
||||
from common.constants import PAGERANK_FLD, TAG_FLD
|
||||
from common.decorator import singleton
|
||||
from common.float_utils import get_float
|
||||
from common.doc_store.doc_store_base import DocStoreConnection, MatchExpr, OrderByExpr, FusionExpr, MatchTextExpr, MatchDenseExpr
|
||||
from rag.nlp import rag_tokenizer
|
||||
from rag.utils.doc_store_conn import DocStoreConnection, MatchExpr, OrderByExpr, FusionExpr, MatchTextExpr, \
|
||||
MatchDenseExpr
|
||||
|
||||
ATTEMPT_TIME = 2
|
||||
OB_QUERY_TIMEOUT = int(os.environ.get("OB_QUERY_TIMEOUT", "100_000_000"))
|
||||
@ -497,7 +496,7 @@ class OBConnection(DocStoreConnection):
|
||||
Database operations
|
||||
"""
|
||||
|
||||
def dbType(self) -> str:
|
||||
def db_type(self) -> str:
|
||||
return "oceanbase"
|
||||
|
||||
def health(self) -> dict:
|
||||
@ -553,7 +552,7 @@ class OBConnection(DocStoreConnection):
|
||||
Table operations
|
||||
"""
|
||||
|
||||
def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
|
||||
def create_idx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
|
||||
vector_field_name = f"q_{vectorSize}_vec"
|
||||
vector_index_name = f"{vector_field_name}_idx"
|
||||
|
||||
@ -604,7 +603,7 @@ class OBConnection(DocStoreConnection):
|
||||
# always refresh metadata to make sure it contains the latest table structure
|
||||
self.client.refresh_metadata([indexName])
|
||||
|
||||
def deleteIdx(self, indexName: str, knowledgebaseId: str):
|
||||
def delete_idx(self, indexName: str, knowledgebaseId: str):
|
||||
if len(knowledgebaseId) > 0:
|
||||
# The index need to be alive after any kb deletion since all kb under this tenant are in one index.
|
||||
return
|
||||
@ -615,7 +614,7 @@ class OBConnection(DocStoreConnection):
|
||||
except Exception as e:
|
||||
raise Exception(f"OBConnection.deleteIndex error: {str(e)}")
|
||||
|
||||
def indexExist(self, indexName: str, knowledgebaseId: str = None) -> bool:
|
||||
def index_exist(self, indexName: str, knowledgebaseId: str = None) -> bool:
|
||||
return self._check_table_exists_cached(indexName)
|
||||
|
||||
def _get_count(self, table_name: str, filter_list: list[str] = None) -> int:
|
||||
@ -1500,7 +1499,7 @@ class OBConnection(DocStoreConnection):
|
||||
def get_total(self, res) -> int:
|
||||
return res.total
|
||||
|
||||
def get_chunk_ids(self, res) -> list[str]:
|
||||
def get_doc_ids(self, res) -> list[str]:
|
||||
return [row["id"] for row in res.chunks]
|
||||
|
||||
def get_fields(self, res, fields: list[str]) -> dict[str, dict]:
|
||||
|
||||
@ -26,8 +26,7 @@ from opensearchpy import UpdateByQuery, Q, Search, Index
|
||||
from opensearchpy import ConnectionTimeout
|
||||
from common.decorator import singleton
|
||||
from common.file_utils import get_project_base_directory
|
||||
from rag.utils.doc_store_conn import DocStoreConnection, MatchExpr, OrderByExpr, MatchTextExpr, MatchDenseExpr, \
|
||||
FusionExpr
|
||||
from common.doc_store.doc_store_base import DocStoreConnection, MatchExpr, OrderByExpr, MatchTextExpr, MatchDenseExpr, FusionExpr
|
||||
from rag.nlp import is_english, rag_tokenizer
|
||||
from common.constants import PAGERANK_FLD, TAG_FLD
|
||||
from common import settings
|
||||
@ -79,7 +78,7 @@ class OSConnection(DocStoreConnection):
|
||||
Database operations
|
||||
"""
|
||||
|
||||
def dbType(self) -> str:
|
||||
def db_type(self) -> str:
|
||||
return "opensearch"
|
||||
|
||||
def health(self) -> dict:
|
||||
@ -91,8 +90,8 @@ class OSConnection(DocStoreConnection):
|
||||
Table operations
|
||||
"""
|
||||
|
||||
def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
|
||||
if self.indexExist(indexName, knowledgebaseId):
|
||||
def create_idx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
|
||||
if self.index_exist(indexName, knowledgebaseId):
|
||||
return True
|
||||
try:
|
||||
from opensearchpy.client import IndicesClient
|
||||
@ -101,7 +100,7 @@ class OSConnection(DocStoreConnection):
|
||||
except Exception:
|
||||
logger.exception("OSConnection.createIndex error %s" % (indexName))
|
||||
|
||||
def deleteIdx(self, indexName: str, knowledgebaseId: str):
|
||||
def delete_idx(self, indexName: str, knowledgebaseId: str):
|
||||
if len(knowledgebaseId) > 0:
|
||||
# The index need to be alive after any kb deletion since all kb under this tenant are in one index.
|
||||
return
|
||||
@ -112,7 +111,7 @@ class OSConnection(DocStoreConnection):
|
||||
except Exception:
|
||||
logger.exception("OSConnection.deleteIdx error %s" % (indexName))
|
||||
|
||||
def indexExist(self, indexName: str, knowledgebaseId: str = None) -> bool:
|
||||
def index_exist(self, indexName: str, knowledgebaseId: str = None) -> bool:
|
||||
s = Index(indexName, self.os)
|
||||
for i in range(ATTEMPT_TIME):
|
||||
try:
|
||||
@ -460,7 +459,7 @@ class OSConnection(DocStoreConnection):
|
||||
return res["hits"]["total"]["value"]
|
||||
return res["hits"]["total"]
|
||||
|
||||
def get_chunk_ids(self, res):
|
||||
def get_doc_ids(self, res):
|
||||
return [d["_id"] for d in res["hits"]["hits"]]
|
||||
|
||||
def __getSource(self, res):
|
||||
|
||||
@ -272,6 +272,49 @@ class RedisDB:
|
||||
self.__open__()
|
||||
return None
|
||||
|
||||
def generate_auto_increment_id(self, key_prefix: str = "id_generator", namespace: str = "default", increment: int = 1, ensure_minimum: int | None = None) -> int:
|
||||
redis_key = f"{key_prefix}:{namespace}"
|
||||
|
||||
try:
|
||||
# Use pipeline for atomicity
|
||||
pipe = self.REDIS.pipeline()
|
||||
|
||||
# Check if key exists
|
||||
pipe.exists(redis_key)
|
||||
|
||||
# Get/Increment
|
||||
if ensure_minimum is not None:
|
||||
# Ensure minimum value
|
||||
pipe.get(redis_key)
|
||||
results = pipe.execute()
|
||||
|
||||
if results[0] == 0: # Key doesn't exist
|
||||
start_id = max(1, ensure_minimum)
|
||||
pipe.set(redis_key, start_id)
|
||||
pipe.execute()
|
||||
return start_id
|
||||
else:
|
||||
current = int(results[1])
|
||||
if current < ensure_minimum:
|
||||
pipe.set(redis_key, ensure_minimum)
|
||||
pipe.execute()
|
||||
return ensure_minimum
|
||||
|
||||
# Increment operation
|
||||
next_id = self.REDIS.incrby(redis_key, increment)
|
||||
|
||||
# If it's the first time, set a reasonable initial value
|
||||
if next_id == increment:
|
||||
self.REDIS.set(redis_key, 1 + increment)
|
||||
return 1 + increment
|
||||
|
||||
return next_id
|
||||
|
||||
except Exception as e:
|
||||
logging.warning("RedisDB.generate_auto_increment_id got exception: " + str(e))
|
||||
self.__open__()
|
||||
return -1
|
||||
|
||||
def transaction(self, key, value, exp=3600):
|
||||
try:
|
||||
pipeline = self.REDIS.pipeline(transaction=True)
|
||||
|
||||
Reference in New Issue
Block a user