Use Infinity single-field-multi-index (#11444)

### What problem does this PR solve?

Use Infinity single-field-multi-index

### Type of change

- [x] Refactoring
- [x] Performance Improvement
This commit is contained in:
Zhichang Yu
2025-11-26 11:06:37 +08:00
committed by GitHub
parent a28c672695
commit 40e84ca41a
22 changed files with 577 additions and 140 deletions

View File

@ -26,6 +26,7 @@ from hanziconv import HanziConv
from nltk import word_tokenize
from nltk.stem import PorterStemmer, WordNetLemmatizer
from common.file_utils import get_project_base_directory
from common import settings
class RagTokenizer:
@ -38,7 +39,7 @@ class RagTokenizer:
def _load_dict(self, fnm):
logging.info(f"[HUQIE]:Build trie from {fnm}")
try:
of = open(fnm, "r", encoding='utf-8')
of = open(fnm, "r", encoding="utf-8")
while True:
line = of.readline()
if not line:
@ -46,7 +47,7 @@ class RagTokenizer:
line = re.sub(r"[\r\n]+", "", line)
line = re.split(r"[ \t]", line)
k = self.key_(line[0])
F = int(math.log(float(line[1]) / self.DENOMINATOR) + .5)
F = int(math.log(float(line[1]) / self.DENOMINATOR) + 0.5)
if k not in self.trie_ or self.trie_[k][0] < F:
self.trie_[self.key_(line[0])] = (F, line[2])
self.trie_[self.rkey_(line[0])] = 1
@ -106,8 +107,8 @@ class RagTokenizer:
if inside_code == 0x3000:
inside_code = 0x0020
else:
inside_code -= 0xfee0
if inside_code < 0x0020 or inside_code > 0x7e: # After the conversion, if it's not a half-width character, return the original character.
inside_code -= 0xFEE0
if inside_code < 0x0020 or inside_code > 0x7E: # After the conversion, if it's not a half-width character, return the original character.
rstring += uchar
else:
rstring += chr(inside_code)
@ -124,14 +125,14 @@ class RagTokenizer:
if s < len(chars):
copy_pretks = copy.deepcopy(preTks)
remaining = "".join(chars[s:])
copy_pretks.append((remaining, (-12, '')))
copy_pretks.append((remaining, (-12, "")))
tkslist.append(copy_pretks)
return s
state_key = (s, tuple(tk[0] for tk in preTks)) if preTks else (s, None)
if state_key in _memo:
return _memo[state_key]
res = s
if s >= len(chars):
tkslist.append(preTks)
@ -155,23 +156,23 @@ class RagTokenizer:
if k in self.trie_:
copy_pretks.append((t, self.trie_[k]))
else:
copy_pretks.append((t, (-12, '')))
copy_pretks.append((t, (-12, "")))
next_res = self.dfs_(chars, mid, copy_pretks, tkslist, _depth + 1, _memo)
res = max(res, next_res)
_memo[state_key] = res
return res
S = s + 1
if s + 2 <= len(chars):
t1 = "".join(chars[s:s + 1])
t2 = "".join(chars[s:s + 2])
t1 = "".join(chars[s : s + 1])
t2 = "".join(chars[s : s + 2])
if self.trie_.has_keys_with_prefix(self.key_(t1)) and not self.trie_.has_keys_with_prefix(self.key_(t2)):
S = s + 2
if len(preTks) > 2 and len(preTks[-1][0]) == 1 and len(preTks[-2][0]) == 1 and len(preTks[-3][0]) == 1:
t1 = preTks[-1][0] + "".join(chars[s:s + 1])
t1 = preTks[-1][0] + "".join(chars[s : s + 1])
if self.trie_.has_keys_with_prefix(self.key_(t1)):
S = s + 2
for e in range(S, len(chars) + 1):
t = "".join(chars[s:e])
k = self.key_(t)
@ -181,18 +182,18 @@ class RagTokenizer:
pretks = copy.deepcopy(preTks)
pretks.append((t, self.trie_[k]))
res = max(res, self.dfs_(chars, e, pretks, tkslist, _depth + 1, _memo))
if res > s:
_memo[state_key] = res
return res
t = "".join(chars[s:s + 1])
t = "".join(chars[s : s + 1])
k = self.key_(t)
copy_pretks = copy.deepcopy(preTks)
if k in self.trie_:
copy_pretks.append((t, self.trie_[k]))
else:
copy_pretks.append((t, (-12, '')))
copy_pretks.append((t, (-12, "")))
result = self.dfs_(chars, s + 1, copy_pretks, tkslist, _depth + 1, _memo)
_memo[state_key] = result
return result
@ -216,7 +217,7 @@ class RagTokenizer:
F += freq
L += 0 if len(tk) < 2 else 1
tks.append(tk)
#F /= len(tks)
# F /= len(tks)
L /= len(tks)
logging.debug("[SC] {} {} {} {} {}".format(tks, len(tks), L, F, B / len(tks) + L + F))
return tks, B / len(tks) + L + F
@ -252,8 +253,7 @@ class RagTokenizer:
while s < len(line):
e = s + 1
t = line[s:e]
while e < len(line) and self.trie_.has_keys_with_prefix(
self.key_(t)):
while e < len(line) and self.trie_.has_keys_with_prefix(self.key_(t)):
e += 1
t = line[s:e]
@ -264,7 +264,7 @@ class RagTokenizer:
if self.key_(t) in self.trie_:
res.append((t, self.trie_[self.key_(t)]))
else:
res.append((t, (0, '')))
res.append((t, (0, "")))
s = e
@ -287,7 +287,7 @@ class RagTokenizer:
if self.key_(t) in self.trie_:
res.append((t, self.trie_[self.key_(t)]))
else:
res.append((t, (0, '')))
res.append((t, (0, "")))
s -= 1
@ -310,28 +310,29 @@ class RagTokenizer:
if _zh == zh:
e += 1
continue
txt_lang_pairs.append((a[s: e], zh))
txt_lang_pairs.append((a[s:e], zh))
s = e
e = s + 1
zh = _zh
if s >= len(a):
continue
txt_lang_pairs.append((a[s: e], zh))
txt_lang_pairs.append((a[s:e], zh))
return txt_lang_pairs
def tokenize(self, line):
def tokenize(self, line: str) -> str:
if settings.DOC_ENGINE_INFINITY:
return line
line = re.sub(r"\W+", " ", line)
line = self._strQ2B(line).lower()
line = self._tradi2simp(line)
arr = self._split_by_lang(line)
res = []
for L,lang in arr:
for L, lang in arr:
if not lang:
res.extend([self.stemmer.stem(self.lemmatizer.lemmatize(t)) for t in word_tokenize(L)])
continue
if len(L) < 2 or re.match(
r"[a-z\.-]+$", L) or re.match(r"[0-9\.-]+$", L):
if len(L) < 2 or re.match(r"[a-z\.-]+$", L) or re.match(r"[0-9\.-]+$", L):
res.append(L)
continue
@ -347,7 +348,7 @@ class RagTokenizer:
while i + same < len(tks1) and j + same < len(tks) and tks1[i + same] == tks[j + same]:
same += 1
if same > 0:
res.append(" ".join(tks[j: j + same]))
res.append(" ".join(tks[j : j + same]))
_i = i + same
_j = j + same
j = _j + 1
@ -374,7 +375,7 @@ class RagTokenizer:
same = 1
while i + same < len(tks1) and j + same < len(tks) and tks1[i + same] == tks[j + same]:
same += 1
res.append(" ".join(tks[j: j + same]))
res.append(" ".join(tks[j : j + same]))
_i = i + same
_j = j + same
j = _j + 1
@ -391,7 +392,9 @@ class RagTokenizer:
logging.debug("[TKS] {}".format(self.merge_(res)))
return self.merge_(res)
def fine_grained_tokenize(self, tks):
def fine_grained_tokenize(self, tks: str) -> str:
if settings.DOC_ENGINE_INFINITY:
return tks
tks = tks.split()
zh_num = len([1 for c in tks if c and is_chinese(c[0])])
if zh_num < len(tks) * 0.2:
@ -433,21 +436,21 @@ class RagTokenizer:
def is_chinese(s):
if s >= u'\u4e00' and s <= u'\u9fa5':
if s >= "\u4e00" and s <= "\u9fa5":
return True
else:
return False
def is_number(s):
if s >= u'\u0030' and s <= u'\u0039':
if s >= "\u0030" and s <= "\u0039":
return True
else:
return False
def is_alphabet(s):
if (u'\u0041' <= s <= u'\u005a') or (u'\u0061' <= s <= u'\u007a'):
if ("\u0041" <= s <= "\u005a") or ("\u0061" <= s <= "\u007a"):
return True
else:
return False
@ -456,8 +459,7 @@ def is_alphabet(s):
def naive_qie(txt):
tks = []
for t in txt.split():
if tks and re.match(r".*[a-zA-Z]$", tks[-1]
) and re.match(r".*[a-zA-Z]$", t):
if tks and re.match(r".*[a-zA-Z]$", tks[-1]) and re.match(r".*[a-zA-Z]$", t):
tks.append(" ")
tks.append(t)
return tks
@ -473,43 +475,35 @@ add_user_dict = tokenizer.add_user_dict
tradi2simp = tokenizer._tradi2simp
strQ2B = tokenizer._strQ2B
if __name__ == '__main__':
if __name__ == "__main__":
tknzr = RagTokenizer(debug=True)
# huqie.add_user_dict("/tmp/tmp.new.tks.dict")
tks = tknzr.tokenize(
"哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈")
logging.info(tknzr.fine_grained_tokenize(tks))
tks = tknzr.tokenize(
"公开征求意见稿提出,境外投资者可使用自有人民币或外汇投资。使用外汇投资的,可通过债券持有人在香港人民币业务清算行及香港地区经批准可进入境内银行间外汇市场进行交易的境外人民币业务参加行(以下统称香港结算行)办理外汇资金兑换。香港结算行由此所产生的头寸可到境内银行间外汇市场平盘。使用外汇投资的,在其投资的债券到期或卖出后,原则上应兑换回外汇。")
logging.info(tknzr.fine_grained_tokenize(tks))
tks = tknzr.tokenize(
"多校划片就是一个小区对应多个小学初中,让买了学区房的家庭也不确定到底能上哪个学校。目的是通过这种方式为学区房降温,把就近入学落到实处。南京市长江大桥")
logging.info(tknzr.fine_grained_tokenize(tks))
tks = tknzr.tokenize(
"实际上当时他们已经将业务中心偏移到安全部门和针对政府企业的部门 Scripts are compiled and cached aaaaaaaaa")
logging.info(tknzr.fine_grained_tokenize(tks))
tks = tknzr.tokenize("虽然我不怎么玩")
logging.info(tknzr.fine_grained_tokenize(tks))
tks = tknzr.tokenize("蓝月亮如何在外资夹击中生存,那是全宇宙最有意思的")
logging.info(tknzr.fine_grained_tokenize(tks))
tks = tknzr.tokenize(
"涡轮增压发动机num最大功率,不像别的共享买车锁电子化的手段,我们接过来是否有意义,黄黄爱美食,不过,今天阿奇要讲到的这家农贸市场,说实话,还真蛮有特色的!不仅环境好,还打出了")
logging.info(tknzr.fine_grained_tokenize(tks))
tks = tknzr.tokenize("这周日你去吗?这周日你有空吗?")
logging.info(tknzr.fine_grained_tokenize(tks))
tks = tknzr.tokenize("Unity3D开发经验 测试开发工程师 c++双11双11 985 211 ")
logging.info(tknzr.fine_grained_tokenize(tks))
tks = tknzr.tokenize(
"数据分析项目经理|数据分析挖掘|数据分析方向|商品数据分析|搜索数据分析 sql python hive tableau Cocos2d-")
logging.info(tknzr.fine_grained_tokenize(tks))
texts = [
"over_the_past.pdf",
"哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈",
"公开征求意见稿提出,境外投资者可使用自有人民币或外汇投资。使用外汇投资的,可通过债券持有人在香港人民币业务清算行及香港地区经批准可进入境内银行间外汇市场进行交易的境外人民币业务参加行(以下统称香港结算行)办理外汇资金兑换。香港结算行由此所产生的头寸可到境内银行间外汇市场平盘。使用外汇投资的,在其投资的债券到期或卖出后,原则上应兑换回外汇。",
"多校划片就是一个小区对应多个小学初中,让买了学区房的家庭也不确定到底能上哪个学校。目的是通过这种方式为学区房降温,把就近入学落到实处。南京市长江大桥",
"实际上当时他们已经将业务中心偏移到安全部门和针对政府企业的部门 Scripts are compiled and cached aaaaaaaaa",
"虽然我不怎么玩",
"蓝月亮如何在外资夹击中生存,那是全宇宙最有意思的",
"涡轮增压发动机num最大功率,不像别的共享买车锁电子化的手段,我们接过来是否有意义,黄黄爱美食,不过,今天阿奇要讲到的这家农贸市场,说实话,还真蛮有特色的!不仅环境好,还打出了",
"这周日你去吗?这周日你有空吗?",
"Unity3D开发经验 测试开发工程师 c++双11双11 985 211 ",
"数据分析项目经理|数据分析挖掘|数据分析方向|商品数据分析|搜索数据分析 sql python hive tableau Cocos2d-",
]
for text in texts:
print(text)
tks1 = tknzr.tokenize(text)
tks2 = tknzr.fine_grained_tokenize(tks1)
print(tks1)
print(tks2)
if len(sys.argv) < 2:
sys.exit()
tknzr.DEBUG = False
tknzr.load_user_dict(sys.argv[1])
of = open(sys.argv[2], "r")
while True:
line = of.readline()
if not line:
break
logging.info(tknzr.tokenize(line))
print(tknzr.tokenize(line))
of.close()

View File

@ -17,7 +17,6 @@ import json
import logging
import re
import math
import os
from collections import OrderedDict
from dataclasses import dataclass
@ -28,6 +27,7 @@ from rag.utils.doc_store_conn import DocStoreConnection, MatchDenseExpr, FusionE
from common.string_utils import remove_redundant_spaces
from common.float_utils import get_float
from common.constants import PAGERANK_FLD, TAG_FLD
from common import settings
def index_name(uid): return f"ragflow_{uid}"
@ -120,7 +120,8 @@ class Dealer:
else:
matchDense = self.get_vector(qst, emb_mdl, topk, req.get("similarity", 0.1))
q_vec = matchDense.embedding_data
src.append(f"q_{len(q_vec)}_vec")
if not settings.DOC_ENGINE_INFINITY:
src.append(f"q_{len(q_vec)}_vec")
fusionExpr = FusionExpr("weighted_sum", topk, {"weights": "0.05,0.95"})
matchExprs = [matchText, matchDense, fusionExpr]
@ -405,8 +406,13 @@ class Dealer:
rank_feature=rank_feature,
)
else:
lower_case_doc_engine = os.getenv("DOC_ENGINE", "elasticsearch")
if lower_case_doc_engine in ["elasticsearch", "opensearch"]:
if settings.DOC_ENGINE_INFINITY:
# Don't need rerank here since Infinity normalizes each way score before fusion.
sim = [sres.field[id].get("_score", 0.0) for id in sres.ids]
sim = [s if s is not None else 0.0 for s in sim]
tsim = sim
vsim = sim
else:
# ElasticSearch doesn't normalize each way score before fusion.
sim, tsim, vsim = self.rerank(
sres,
@ -415,12 +421,6 @@ class Dealer:
vector_similarity_weight,
rank_feature=rank_feature,
)
else:
# Don't need rerank here since Infinity normalizes each way score before fusion.
sim = [sres.field[id].get("_score", 0.0) for id in sres.ids]
sim = [s if s is not None else 0.0 for s in sim]
tsim = sim
vsim = sim
sim_np = np.array(sim, dtype=np.float64)
if sim_np.size == 0:

View File

@ -44,11 +44,56 @@ logger = logging.getLogger("ragflow.infinity_conn")
def field_keyword(field_name: str):
# The "docnm_kwd" field is always a string, not list.
if field_name == "source_id" or (field_name.endswith("_kwd") and field_name != "docnm_kwd" and field_name != "knowledge_graph_kwd"):
# Treat "*_kwd" tag-like columns as keyword lists except knowledge_graph_kwd; source_id is also keyword-like.
if field_name == "source_id" or (field_name.endswith("_kwd") and field_name not in ["knowledge_graph_kwd", "docnm_kwd", "important_kwd", "question_kwd"]):
return True
return False
def convert_select_fields(output_fields: list[str]) -> list[str]:
for i, field in enumerate(output_fields):
if field in ["docnm_kwd", "title_tks", "title_sm_tks"]:
output_fields[i] = "docnm"
elif field in ["important_kwd", "important_tks"]:
output_fields[i] = "important_keywords"
elif field in ["question_kwd", "question_tks"]:
output_fields[i] = "questions"
elif field in ["content_with_weight", "content_ltks", "content_sm_ltks"]:
output_fields[i] = "content"
elif field in ["authors_tks", "authors_sm_tks"]:
output_fields[i] = "authors"
return list(set(output_fields))
def convert_matching_field(field_weightstr: str) -> str:
tokens = field_weightstr.split("^")
field = tokens[0]
if field == "docnm_kwd" or field == "title_tks":
field = "docnm@ft_docnm_rag_coarse"
elif field == "title_sm_tks":
field = "docnm@ft_title_rag_fine"
elif field == "important_kwd":
field = "important_keywords@ft_important_keywords_rag_coarse"
elif field == "important_tks":
field = "important_keywords@ft_important_keywords_rag_fine"
elif field == "question_kwd":
field = "questions@ft_questions_rag_coarse"
elif field == "question_tks":
field = "questions@ft_questions_rag_fine"
elif field == "content_with_weight" or field == "content_ltks":
field = "content@ft_content_rag_coarse"
elif field == "content_sm_ltks":
field = "content@ft_content_rag_fine"
elif field == "authors_tks":
field = "authors@ft_authors_rag_coarse"
elif field == "authors_sm_tks":
field = "authors@ft_authors_rag_fine"
tokens[0] = field
return "^".join(tokens)
def list2str(lst: str|list, sep: str = " ") -> str:
if isinstance(lst, str):
return lst
return sep.join(lst)
def equivalent_condition_to_str(condition: dict, table_instance=None) -> str | None:
assert "_id" not in condition
@ -77,13 +122,13 @@ def equivalent_condition_to_str(condition: dict, table_instance=None) -> str | N
for item in v:
if isinstance(item, str):
item = item.replace("'", "''")
inCond.append(f"filter_fulltext('{k}', '{item}')")
inCond.append(f"filter_fulltext('{convert_matching_field(k)}', '{item}')")
if inCond:
strInCond = " or ".join(inCond)
strInCond = f"({strInCond})"
cond.append(strInCond)
else:
cond.append(f"filter_fulltext('{k}', '{v}')")
cond.append(f"filter_fulltext('{convert_matching_field(k)}', '{v}')")
elif isinstance(v, list):
inCond = list()
for item in v:
@ -181,11 +226,15 @@ class InfinityConnection(DocStoreConnection):
logger.info(f"INFINITY added following column to table {table_name}: {field_name} {field_info}")
if field_info["type"] != "varchar" or "analyzer" not in field_info:
continue
inf_table.create_index(
f"text_idx_{field_name}",
IndexInfo(field_name, IndexType.FullText, {"ANALYZER": field_info["analyzer"]}),
ConflictType.Ignore,
)
analyzers = field_info["analyzer"]
if isinstance(analyzers, str):
analyzers = [analyzers]
for analyzer in analyzers:
inf_table.create_index(
f"ft_{re.sub(r'[^a-zA-Z0-9]', '_', field_name)}_{re.sub(r'[^a-zA-Z0-9]', '_', analyzer)}",
IndexInfo(field_name, IndexType.FullText, {"ANALYZER": analyzer}),
ConflictType.Ignore,
)
"""
Database operations
@ -245,11 +294,15 @@ class InfinityConnection(DocStoreConnection):
for field_name, field_info in schema.items():
if field_info["type"] != "varchar" or "analyzer" not in field_info:
continue
inf_table.create_index(
f"text_idx_{field_name}",
IndexInfo(field_name, IndexType.FullText, {"ANALYZER": field_info["analyzer"]}),
ConflictType.Ignore,
)
analyzers = field_info["analyzer"]
if isinstance(analyzers, str):
analyzers = [analyzers]
for analyzer in analyzers:
inf_table.create_index(
f"ft_{re.sub(r'[^a-zA-Z0-9]', '_', field_name)}_{re.sub(r'[^a-zA-Z0-9]', '_', analyzer)}",
IndexInfo(field_name, IndexType.FullText, {"ANALYZER": analyzer}),
ConflictType.Ignore,
)
self.connPool.release_conn(inf_conn)
logger.info(f"INFINITY created table {table_name}, vector size {vectorSize}")
@ -302,6 +355,7 @@ class InfinityConnection(DocStoreConnection):
df_list = list()
table_list = list()
output = selectFields.copy()
output = convert_select_fields(output)
for essential_field in ["id"] + aggFields:
if essential_field not in output:
output.append(essential_field)
@ -352,6 +406,7 @@ class InfinityConnection(DocStoreConnection):
if isinstance(matchExpr, MatchTextExpr):
if filter_cond and "filter" not in matchExpr.extra_options:
matchExpr.extra_options.update({"filter": filter_cond})
matchExpr.fields = [convert_matching_field(field) for field in matchExpr.fields]
fields = ",".join(matchExpr.fields)
filter_fulltext = f"filter_fulltext('{fields}', '{matchExpr.matching_text}')"
if filter_cond:
@ -470,7 +525,10 @@ class InfinityConnection(DocStoreConnection):
df_list.append(kb_res)
self.connPool.release_conn(inf_conn)
res = concat_dataframes(df_list, ["id"])
res_fields = self.get_fields(res, res.columns.tolist())
fields = set(res.columns.tolist())
for field in ["docnm_kwd", "title_tks", "title_sm_tks", "important_kwd", "important_tks", "question_kwd", "question_tks","content_with_weight", "content_ltks", "content_sm_ltks", "authors_tks", "authors_sm_tks"]:
fields.add(field)
res_fields = self.get_fields(res, list(fields))
return res_fields.get(chunkId, None)
def insert(self, documents: list[dict], indexName: str, knowledgebaseId: str = None) -> list[str]:
@ -508,8 +566,39 @@ class InfinityConnection(DocStoreConnection):
for d in docs:
assert "_id" not in d
assert "id" in d
for k, v in d.items():
if field_keyword(k):
for k, v in list(d.items()):
if k == "docnm_kwd":
d["docnm"] = v
elif k == "title_kwd":
if not d.get("docnm_kwd"):
d["docnm"] = list2str(v)
elif k == "title_sm_tks":
if not d.get("docnm_kwd"):
d["docnm"] = list2str(v)
elif k == "important_kwd":
d["important_keywords"] = list2str(v)
elif k == "important_tks":
if not d.get("important_kwd"):
d["important_keywords"] = v
elif k == "content_with_weight":
d["content"] = v
elif k == "content_ltks":
if not d.get("content_with_weight"):
d["content"] = v
elif k == "content_sm_ltks":
if not d.get("content_with_weight"):
d["content"] = v
elif k == "authors_tks":
d["authors"] = v
elif k == "authors_sm_tks":
if not d.get("authors_tks"):
d["authors"] = v
elif k == "question_kwd":
d["questions"] = list2str(v, "\n")
elif k == "question_tks":
if not d.get("question_kwd"):
d["questions"] = list2str(v)
elif field_keyword(k):
if isinstance(v, list):
d[k] = "###".join(v)
else:
@ -528,6 +617,9 @@ class InfinityConnection(DocStoreConnection):
d[k] = "_".join(f"{num:08x}" for num in v)
else:
d[k] = v
for k in ["docnm_kwd", "title_tks", "title_sm_tks", "important_kwd", "important_tks", "content_with_weight", "content_ltks", "content_sm_ltks", "authors_tks", "authors_sm_tks", "question_kwd", "question_tks"]:
if k in d:
del d[k]
for n, vs in embedding_clmns:
if n in d:
@ -562,7 +654,38 @@ class InfinityConnection(DocStoreConnection):
filter = equivalent_condition_to_str(condition, table_instance)
removeValue = {}
for k, v in list(newValue.items()):
if field_keyword(k):
if k == "docnm_kwd":
newValue["docnm"] = list2str(v)
elif k == "title_kwd":
if not newValue.get("docnm_kwd"):
newValue["docnm"] = list2str(v)
elif k == "title_sm_tks":
if not newValue.get("docnm_kwd"):
newValue["docnm"] = v
elif k == "important_kwd":
newValue["important_keywords"] = list2str(v)
elif k == "important_tks":
if not newValue.get("important_kwd"):
newValue["important_keywords"] = v
elif k == "content_with_weight":
newValue["content"] = v
elif k == "content_ltks":
if not newValue.get("content_with_weight"):
newValue["content"] = v
elif k == "content_sm_ltks":
if not newValue.get("content_with_weight"):
newValue["content"] = v
elif k == "authors_tks":
newValue["authors"] = v
elif k == "authors_sm_tks":
if not newValue.get("authors_tks"):
newValue["authors"] = v
elif k == "question_kwd":
newValue["questions"] = "\n".join(v)
elif k == "question_tks":
if not newValue.get("question_kwd"):
newValue["questions"] = list2str(v)
elif field_keyword(k):
if isinstance(v, list):
newValue[k] = "###".join(v)
else:
@ -593,6 +716,9 @@ class InfinityConnection(DocStoreConnection):
del newValue[k]
else:
newValue[k] = v
for k in ["docnm_kwd", "title_tks", "title_sm_tks", "important_kwd", "important_tks", "content_with_weight", "content_ltks", "content_sm_ltks", "authors_tks", "authors_sm_tks", "question_kwd", "question_tks"]:
if k in newValue:
del newValue[k]
remove_opt = {} # "[k,new_value]": [id_to_update, ...]
if removeValue:
@ -656,22 +782,45 @@ class InfinityConnection(DocStoreConnection):
return {}
fieldsAll = fields.copy()
fieldsAll.append("id")
fieldsAll = set(fieldsAll)
if "docnm" in res.columns:
for field in ["docnm_kwd", "title_tks", "title_sm_tks"]:
if field in fieldsAll:
res[field] = res["docnm"]
if "important_keywords" in res.columns:
if "important_kwd" in fieldsAll:
res["important_kwd"] = res["important_keywords"].apply(lambda v: v.split())
if "important_tks" in fieldsAll:
res["important_tks"] = res["important_keywords"]
if "questions" in res.columns:
if "question_kwd" in fieldsAll:
res["question_kwd"] = res["questions"].apply(lambda v: v.splitlines())
if "question_tks" in fieldsAll:
res["question_tks"] = res["questions"]
if "content" in res.columns:
for field in ["content_with_weight", "content_ltks", "content_sm_ltks"]:
if field in fieldsAll:
res[field] = res["content"]
if "authors" in res.columns:
for field in ["authors_tks", "authors_sm_tks"]:
if field in fieldsAll:
res[field] = res["authors"]
column_map = {col.lower(): col for col in res.columns}
matched_columns = {column_map[col.lower()]: col for col in set(fieldsAll) if col.lower() in column_map}
none_columns = [col for col in set(fieldsAll) if col.lower() not in column_map]
matched_columns = {column_map[col.lower()]: col for col in fieldsAll if col.lower() in column_map}
none_columns = [col for col in fieldsAll if col.lower() not in column_map]
res2 = res[matched_columns.keys()]
res2 = res2.rename(columns=matched_columns)
res2.drop_duplicates(subset=["id"], inplace=True)
for column in res2.columns:
for column in list(res2.columns):
k = column.lower()
if field_keyword(k):
res2[column] = res2[column].apply(lambda v: [kwd for kwd in v.split("###") if kwd])
elif re.search(r"_feas$", k):
res2[column] = res2[column].apply(lambda v: json.loads(v) if v else {})
elif k == "position_int":
def to_position_int(v):
if v:
arr = [int(hex_val, 16) for hex_val in v.split("_")]
@ -685,6 +834,9 @@ class InfinityConnection(DocStoreConnection):
res2[column] = res2[column].apply(lambda v: [int(hex_val, 16) for hex_val in v.split("_")] if v else [])
else:
pass
for column in ["docnm", "important_keywords", "questions", "content", "authors"]:
if column in res2:
del res2[column]
for column in none_columns:
res2[column] = None