Refactor function name (#11210)

### What problem does this PR solve?

Rename camelCase helper methods to snake_case across the document-store connectors and the RAG search/query/tokenizer code, e.g. `getTotal` → `get_total`, `getChunkIds` → `get_chunk_ids`, `getFields` → `get_fields`, `getHighlight` → `get_highlight`, `getAggregation` → `get_aggregation`, `subSpecialChar` → `sub_special_char`, `loadUserDict` → `load_user_dict`, `tokenMerge` → `token_merge`. Also make bare `return` statements explicit (`return None` / `return False`) in the storage, Redis, and task-executor code paths.
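
For reference, a minimal before/after sketch of a typical call site (method names are taken from the diff below; the surrounding variables are illustrative only):

```python
# Before: camelCase helpers on the document-store connection
total = docStoreConn.getTotal(res)
ids = docStoreConn.getChunkIds(res)
fields = docStoreConn.getFields(res, ["content_with_weight"])

# After: snake_case, in line with PEP 8 naming
total = docStoreConn.get_total(res)
ids = docStoreConn.get_chunk_ids(res)
fields = docStoreConn.get_fields(res, ["content_with_weight"])
```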

### Type of change

- [x] Refactoring

---------

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
Author: Jin Hai
Committed: 2025-11-12 19:00:15 +08:00 (committed by GitHub)
Parent: a36a0fe71c
Commit: 296476ab89
20 changed files with 105 additions and 103 deletions

View File

@@ -807,7 +807,7 @@ def check_embedding():
 offset=0, limit=1,
 indexNames=index_nm, knowledgebaseIds=[kb_id]
 )
-total = docStoreConn.getTotal(res0)
+total = docStoreConn.get_total(res0)
 if total <= 0:
 return []
@@ -824,7 +824,7 @@ def check_embedding():
 offset=off, limit=1,
 indexNames=index_nm, knowledgebaseIds=[kb_id]
 )
-ids = docStoreConn.getChunkIds(res1)
+ids = docStoreConn.get_chunk_ids(res1)
 if not ids:
 continue

View File

@@ -309,7 +309,7 @@ class DocumentService(CommonService):
 chunks = settings.docStoreConn.search(["img_id"], [], {"doc_id": doc.id}, [], OrderByExpr(),
 page * page_size, page_size, search.index_name(tenant_id),
 [doc.kb_id])
-chunk_ids = settings.docStoreConn.getChunkIds(chunks)
+chunk_ids = settings.docStoreConn.get_chunk_ids(chunks)
 if not chunk_ids:
 break
 all_chunk_ids.extend(chunk_ids)
@@ -322,7 +322,7 @@ class DocumentService(CommonService):
 settings.STORAGE_IMPL.rm(doc.kb_id, doc.thumbnail)
 settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
-graph_source = settings.docStoreConn.getFields(
+graph_source = settings.docStoreConn.get_fields(
 settings.docStoreConn.search(["source_id"], [], {"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]}, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [doc.kb_id]), ["source_id"]
 )
 if len(graph_source) > 0 and doc.id in list(graph_source.values())[0]["source_id"]:

View File

@@ -69,7 +69,7 @@ class KGSearch(Dealer):
 def _ent_info_from_(self, es_res, sim_thr=0.3):
 res = {}
 flds = ["content_with_weight", "_score", "entity_kwd", "rank_flt", "n_hop_with_weight"]
-es_res = self.dataStore.getFields(es_res, flds)
+es_res = self.dataStore.get_fields(es_res, flds)
 for _, ent in es_res.items():
 for f in flds:
 if f in ent and ent[f] is None:
@@ -88,7 +88,7 @@ class KGSearch(Dealer):
 def _relation_info_from_(self, es_res, sim_thr=0.3):
 res = {}
-es_res = self.dataStore.getFields(es_res, ["content_with_weight", "_score", "from_entity_kwd", "to_entity_kwd",
+es_res = self.dataStore.get_fields(es_res, ["content_with_weight", "_score", "from_entity_kwd", "to_entity_kwd",
 "weight_int"])
 for _, ent in es_res.items():
 if get_float(ent["_score"]) < sim_thr:
@@ -300,7 +300,7 @@ class KGSearch(Dealer):
 fltr["entities_kwd"] = entities
 comm_res = self.dataStore.search(fields, [], fltr, [],
 OrderByExpr(), 0, topn, idxnms, kb_ids)
-comm_res_fields = self.dataStore.getFields(comm_res, fields)
+comm_res_fields = self.dataStore.get_fields(comm_res, fields)
 txts = []
 for ii, (_, row) in enumerate(comm_res_fields.items()):
 obj = json.loads(row["content_with_weight"])

View File

@@ -382,7 +382,7 @@ async def does_graph_contains(tenant_id, kb_id, doc_id):
 "removed_kwd": "N",
 }
 res = await trio.to_thread.run_sync(lambda: settings.docStoreConn.search(fields, [], condition, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [kb_id]))
-fields2 = settings.docStoreConn.getFields(res, fields)
+fields2 = settings.docStoreConn.get_fields(res, fields)
 graph_doc_ids = set()
 for chunk_id in fields2.keys():
 graph_doc_ids = set(fields2[chunk_id]["source_id"])
@@ -591,8 +591,8 @@ async def rebuild_graph(tenant_id, kb_id, exclude_rebuild=None):
 es_res = await trio.to_thread.run_sync(
 lambda: settings.docStoreConn.search(flds, [], {"kb_id": kb_id, "knowledge_graph_kwd": ["subgraph"]}, [], OrderByExpr(), i, bs, search.index_name(tenant_id), [kb_id])
 )
-# tot = settings.docStoreConn.getTotal(es_res)
-es_res = settings.docStoreConn.getFields(es_res, flds)
+# tot = settings.docStoreConn.get_total(es_res)
+es_res = settings.docStoreConn.get_fields(es_res, flds)
 if len(es_res) == 0:
 break

View File

@@ -38,11 +38,11 @@ class FulltextQueryer:
 ]
 @staticmethod
-def subSpecialChar(line):
+def sub_special_char(line):
 return re.sub(r"([:\{\}/\[\]\-\*\"\(\)\|\+~\^])", r"\\\1", line).strip()
 @staticmethod
-def isChinese(line):
+def is_chinese(line):
 arr = re.split(r"[ \t]+", line)
 if len(arr) <= 3:
 return True
@@ -92,7 +92,7 @@ class FulltextQueryer:
 otxt = txt
 txt = FulltextQueryer.rmWWW(txt)
-if not self.isChinese(txt):
+if not self.is_chinese(txt):
 txt = FulltextQueryer.rmWWW(txt)
 tks = rag_tokenizer.tokenize(txt).split()
 keywords = [t for t in tks if t]
@@ -163,7 +163,7 @@ class FulltextQueryer:
 )
 for m in sm
 ]
-sm = [FulltextQueryer.subSpecialChar(m) for m in sm if len(m) > 1]
+sm = [FulltextQueryer.sub_special_char(m) for m in sm if len(m) > 1]
 sm = [m for m in sm if len(m) > 1]
 if len(keywords) < 32:
@@ -171,7 +171,7 @@ class FulltextQueryer:
 keywords.extend(sm)
 tk_syns = self.syn.lookup(tk)
-tk_syns = [FulltextQueryer.subSpecialChar(s) for s in tk_syns]
+tk_syns = [FulltextQueryer.sub_special_char(s) for s in tk_syns]
 if len(keywords) < 32:
 keywords.extend([s for s in tk_syns if s])
 tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s]
@@ -180,7 +180,7 @@ class FulltextQueryer:
 if len(keywords) >= 32:
 break
-tk = FulltextQueryer.subSpecialChar(tk)
+tk = FulltextQueryer.sub_special_char(tk)
 if tk.find(" ") > 0:
 tk = '"%s"' % tk
 if tk_syns:
@@ -198,7 +198,7 @@ class FulltextQueryer:
 syns = " OR ".join(
 [
 '"%s"'
-% rag_tokenizer.tokenize(FulltextQueryer.subSpecialChar(s))
+% rag_tokenizer.tokenize(FulltextQueryer.sub_special_char(s))
 for s in syns
 ]
 )
@@ -217,17 +217,17 @@ class FulltextQueryer:
 return None, keywords
 def hybrid_similarity(self, avec, bvecs, atks, btkss, tkweight=0.3, vtweight=0.7):
-from sklearn.metrics.pairwise import cosine_similarity as CosineSimilarity
+from sklearn.metrics.pairwise import cosine_similarity
 import numpy as np
-sims = CosineSimilarity([avec], bvecs)
+sims = cosine_similarity([avec], bvecs)
 tksim = self.token_similarity(atks, btkss)
 if np.sum(sims[0]) == 0:
 return np.array(tksim), tksim, sims[0]
 return np.array(sims[0]) * vtweight + np.array(tksim) * tkweight, tksim, sims[0]
 def token_similarity(self, atks, btkss):
-def toDict(tks):
+def to_dict(tks):
 if isinstance(tks, str):
 tks = tks.split()
 d = defaultdict(int)
@@ -236,8 +236,8 @@ class FulltextQueryer:
 d[t] += c
 return d
-atks = toDict(atks)
-btkss = [toDict(tks) for tks in btkss]
+atks = to_dict(atks)
+btkss = [to_dict(tks) for tks in btkss]
 return [self.similarity(atks, btks) for btks in btkss]
 def similarity(self, qtwt, dtwt):
@@ -262,10 +262,10 @@ class FulltextQueryer:
 keywords = [f'"{k.strip()}"' for k in keywords]
 for tk, w in sorted(tks_w, key=lambda x: x[1] * -1)[:keywords_topn]:
 tk_syns = self.syn.lookup(tk)
-tk_syns = [FulltextQueryer.subSpecialChar(s) for s in tk_syns]
+tk_syns = [FulltextQueryer.sub_special_char(s) for s in tk_syns]
 tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s]
 tk_syns = [f"\"{s}\"" if s.find(" ") > 0 else s for s in tk_syns]
-tk = FulltextQueryer.subSpecialChar(tk)
+tk = FulltextQueryer.sub_special_char(tk)
 if tk.find(" ") > 0:
 tk = '"%s"' % tk
 if tk_syns:
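
The hunk above also cleans up the `cosine_similarity` import used by `hybrid_similarity`. As a side note, a minimal standalone sketch of the blend that method computes: vector cosine similarity weighted by `vtweight` plus a token-overlap score weighted by `tkweight`. The token overlap below is a simplified stand-in (an assumption) for the class's own `similarity(qtwt, dtwt)` term weighting:

```python
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity


def token_overlap(a_tokens, b_tokens):
    # Simplified stand-in for FulltextQueryer.similarity(): the fraction of
    # query tokens that also appear in the candidate.
    a, b = set(a_tokens), set(b_tokens)
    return len(a & b) / max(1, len(a))


def hybrid_score(avec, bvecs, atks, btkss, tkweight=0.3, vtweight=0.7):
    vec_sims = cosine_similarity([avec], bvecs)[0]               # one score per candidate
    tok_sims = np.array([token_overlap(atks, b) for b in btkss])
    return vtweight * vec_sims + tkweight * tok_sims


# Toy usage: one query vector against two candidates.
print(hybrid_score([0.1, 0.9], [[0.1, 0.9], [0.9, 0.1]],
                   ["deep", "learning"],
                   [["deep", "learning"], ["stock", "market"]]))
```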

View File

@@ -35,7 +35,7 @@ class RagTokenizer:
 def rkey_(self, line):
 return str(("DD" + (line[::-1].lower())).encode("utf-8"))[2:-1]
-def loadDict_(self, fnm):
+def _load_dict(self, fnm):
 logging.info(f"[HUQIE]:Build trie from {fnm}")
 try:
 of = open(fnm, "r", encoding='utf-8')
@@ -85,18 +85,18 @@ class RagTokenizer:
 self.trie_ = datrie.Trie(string.printable)
 # load data from dict file and save to trie file
-self.loadDict_(self.DIR_ + ".txt")
+self._load_dict(self.DIR_ + ".txt")
-def loadUserDict(self, fnm):
+def load_user_dict(self, fnm):
 try:
 self.trie_ = datrie.Trie.load(fnm + ".trie")
 return
 except Exception:
 self.trie_ = datrie.Trie(string.printable)
-self.loadDict_(fnm)
+self._load_dict(fnm)
-def addUserDict(self, fnm):
-self.loadDict_(fnm)
+def add_user_dict(self, fnm):
+self._load_dict(fnm)
 def _strQ2B(self, ustring):
 """Convert full-width characters to half-width characters"""
@@ -221,7 +221,7 @@ class RagTokenizer:
 logging.debug("[SC] {} {} {} {} {}".format(tks, len(tks), L, F, B / len(tks) + L + F))
 return tks, B / len(tks) + L + F
-def sortTks_(self, tkslist):
+def _sort_tokens(self, tkslist):
 res = []
 for tfts in tkslist:
 tks, s = self.score_(tfts)
@@ -246,7 +246,7 @@ class RagTokenizer:
 return " ".join(res)
-def maxForward_(self, line):
+def _max_forward(self, line):
 res = []
 s = 0
 while s < len(line):
@@ -270,7 +270,7 @@ class RagTokenizer:
 return self.score_(res)
-def maxBackward_(self, line):
+def _max_backward(self, line):
 res = []
 s = len(line) - 1
 while s >= 0:
@@ -336,8 +336,8 @@ class RagTokenizer:
 continue
 # use maxforward for the first time
-tks, s = self.maxForward_(L)
-tks1, s1 = self.maxBackward_(L)
+tks, s = self._max_forward(L)
+tks1, s1 = self._max_backward(L)
 if self.DEBUG:
 logging.debug("[FW] {} {}".format(tks, s))
 logging.debug("[BW] {} {}".format(tks1, s1))
@@ -369,7 +369,7 @@ class RagTokenizer:
 # backward tokens from_i to i are different from forward tokens from _j to j.
 tkslist = []
 self.dfs_("".join(tks[_j:j]), 0, [], tkslist)
-res.append(" ".join(self.sortTks_(tkslist)[0][0]))
+res.append(" ".join(self._sort_tokens(tkslist)[0][0]))
 same = 1
 while i + same < len(tks1) and j + same < len(tks) and tks1[i + same] == tks[j + same]:
@@ -385,7 +385,7 @@ class RagTokenizer:
 assert "".join(tks1[_i:]) == "".join(tks[_j:])
 tkslist = []
 self.dfs_("".join(tks[_j:]), 0, [], tkslist)
-res.append(" ".join(self.sortTks_(tkslist)[0][0]))
+res.append(" ".join(self._sort_tokens(tkslist)[0][0]))
 res = " ".join(res)
 logging.debug("[TKS] {}".format(self.merge_(res)))
@@ -413,7 +413,7 @@ class RagTokenizer:
 if len(tkslist) < 2:
 res.append(tk)
 continue
-stk = self.sortTks_(tkslist)[1][0]
+stk = self._sort_tokens(tkslist)[1][0]
 if len(stk) == len(tk):
 stk = tk
 else:
@@ -447,14 +447,13 @@ def is_number(s):
 def is_alphabet(s):
-if (s >= u'\u0041' and s <= u'\u005a') or (
-s >= u'\u0061' and s <= u'\u007a'):
+if (u'\u0041' <= s <= u'\u005a') or (u'\u0061' <= s <= u'\u007a'):
 return True
 else:
 return False
-def naiveQie(txt):
+def naive_qie(txt):
 tks = []
 for t in txt.split():
 if tks and re.match(r".*[a-zA-Z]$", tks[-1]
@@ -469,14 +468,14 @@ tokenize = tokenizer.tokenize
 fine_grained_tokenize = tokenizer.fine_grained_tokenize
 tag = tokenizer.tag
 freq = tokenizer.freq
-loadUserDict = tokenizer.loadUserDict
-addUserDict = tokenizer.addUserDict
+load_user_dict = tokenizer.load_user_dict
+add_user_dict = tokenizer.add_user_dict
 tradi2simp = tokenizer._tradi2simp
 strQ2B = tokenizer._strQ2B
 if __name__ == '__main__':
 tknzr = RagTokenizer(debug=True)
-# huqie.addUserDict("/tmp/tmp.new.tks.dict")
+# huqie.add_user_dict("/tmp/tmp.new.tks.dict")
 tks = tknzr.tokenize(
 "哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈")
 logging.info(tknzr.fine_grained_tokenize(tks))
@@ -506,7 +505,7 @@ if __name__ == '__main__':
 if len(sys.argv) < 2:
 sys.exit()
 tknzr.DEBUG = False
-tknzr.loadUserDict(sys.argv[1])
+tknzr.load_user_dict(sys.argv[1])
 of = open(sys.argv[2], "r")
 while True:
 line = of.readline()
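
The renamed `_max_forward`/`_max_backward` pair implements bidirectional maximum matching over the `datrie` dictionary, with `_sort_tokens` re-scoring the disagreements. A rough sketch of the forward pass alone, with a plain Python set standing in for the trie (names and the toy dictionary are assumptions, not code from this PR):

```python
def max_forward(line: str, vocab: set[str], max_word_len: int = 6) -> list[str]:
    """Greedy forward maximum matching: always take the longest dictionary word."""
    tokens, s = [], 0
    while s < len(line):
        for length in range(min(max_word_len, len(line) - s), 0, -1):
            cand = line[s:s + length]
            if length == 1 or cand in vocab:   # single characters always pass
                tokens.append(cand)
                s += length
                break
    return tokens


print(max_forward("研究生命起源", {"研究", "研究生", "生命", "起源"}))
# ['研究生', '命', '起源'] — forward matching alone can mis-segment, which is
# why the tokenizer also runs a backward pass and reconciles via _sort_tokens.
```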

View File

@@ -102,7 +102,7 @@ class Dealer:
 orderBy.asc("top_int")
 orderBy.desc("create_timestamp_flt")
 res = self.dataStore.search(src, [], filters, [], orderBy, offset, limit, idx_names, kb_ids)
-total = self.dataStore.getTotal(res)
+total = self.dataStore.get_total(res)
 logging.debug("Dealer.search TOTAL: {}".format(total))
 else:
 highlightFields = ["content_ltks", "title_tks"]
@@ -115,7 +115,7 @@ class Dealer:
 matchExprs = [matchText]
 res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit,
 idx_names, kb_ids, rank_feature=rank_feature)
-total = self.dataStore.getTotal(res)
+total = self.dataStore.get_total(res)
 logging.debug("Dealer.search TOTAL: {}".format(total))
 else:
 matchDense = self.get_vector(qst, emb_mdl, topk, req.get("similarity", 0.1))
@@ -127,20 +127,20 @@ class Dealer:
 res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit,
 idx_names, kb_ids, rank_feature=rank_feature)
-total = self.dataStore.getTotal(res)
+total = self.dataStore.get_total(res)
 logging.debug("Dealer.search TOTAL: {}".format(total))
 # If result is empty, try again with lower min_match
 if total == 0:
 if filters.get("doc_id"):
 res = self.dataStore.search(src, [], filters, [], orderBy, offset, limit, idx_names, kb_ids)
-total = self.dataStore.getTotal(res)
+total = self.dataStore.get_total(res)
 else:
 matchText, _ = self.qryr.question(qst, min_match=0.1)
 matchDense.extra_options["similarity"] = 0.17
 res = self.dataStore.search(src, highlightFields, filters, [matchText, matchDense, fusionExpr],
 orderBy, offset, limit, idx_names, kb_ids, rank_feature=rank_feature)
-total = self.dataStore.getTotal(res)
+total = self.dataStore.get_total(res)
 logging.debug("Dealer.search 2 TOTAL: {}".format(total))
 for k in keywords:
@@ -153,17 +153,17 @@ class Dealer:
 kwds.add(kk)
 logging.debug(f"TOTAL: {total}")
-ids = self.dataStore.getChunkIds(res)
+ids = self.dataStore.get_chunk_ids(res)
 keywords = list(kwds)
-highlight = self.dataStore.getHighlight(res, keywords, "content_with_weight")
-aggs = self.dataStore.getAggregation(res, "docnm_kwd")
+highlight = self.dataStore.get_highlight(res, keywords, "content_with_weight")
+aggs = self.dataStore.get_aggregation(res, "docnm_kwd")
 return self.SearchResult(
 total=total,
 ids=ids,
 query_vector=q_vec,
 aggregation=aggs,
 highlight=highlight,
-field=self.dataStore.getFields(res, src + ["_score"]),
+field=self.dataStore.get_fields(res, src + ["_score"]),
 keywords=keywords
 )
@@ -488,7 +488,7 @@ class Dealer:
 for p in range(offset, max_count, bs):
 es_res = self.dataStore.search(fields, [], condition, [], orderBy, p, bs, index_name(tenant_id),
 kb_ids)
-dict_chunks = self.dataStore.getFields(es_res, fields)
+dict_chunks = self.dataStore.get_fields(es_res, fields)
 for id, doc in dict_chunks.items():
 doc["id"] = id
 if dict_chunks:
@@ -501,11 +501,11 @@ class Dealer:
 if not self.dataStore.indexExist(index_name(tenant_id), kb_ids[0]):
 return []
 res = self.dataStore.search([], [], {}, [], OrderByExpr(), 0, 0, index_name(tenant_id), kb_ids, ["tag_kwd"])
-return self.dataStore.getAggregation(res, "tag_kwd")
+return self.dataStore.get_aggregation(res, "tag_kwd")
 def all_tags_in_portion(self, tenant_id: str, kb_ids: list[str], S=1000):
 res = self.dataStore.search([], [], {}, [], OrderByExpr(), 0, 0, index_name(tenant_id), kb_ids, ["tag_kwd"])
-res = self.dataStore.getAggregation(res, "tag_kwd")
+res = self.dataStore.get_aggregation(res, "tag_kwd")
 total = np.sum([c for _, c in res])
 return {t: (c + 1) / (total + S) for t, c in res}
@@ -513,7 +513,7 @@ class Dealer:
 idx_nm = index_name(tenant_id)
 match_txt = self.qryr.paragraph(doc["title_tks"] + " " + doc["content_ltks"], doc.get("important_kwd", []), keywords_topn)
 res = self.dataStore.search([], [], {}, [match_txt], OrderByExpr(), 0, 0, idx_nm, kb_ids, ["tag_kwd"])
-aggs = self.dataStore.getAggregation(res, "tag_kwd")
+aggs = self.dataStore.get_aggregation(res, "tag_kwd")
 if not aggs:
 return False
 cnt = np.sum([c for _, c in aggs])
@@ -529,7 +529,7 @@ class Dealer:
 idx_nms = [index_name(tid) for tid in tenant_ids]
 match_txt, _ = self.qryr.question(question, min_match=0.0)
 res = self.dataStore.search([], [], {}, [match_txt], OrderByExpr(), 0, 0, idx_nms, kb_ids, ["tag_kwd"])
-aggs = self.dataStore.getAggregation(res, "tag_kwd")
+aggs = self.dataStore.get_aggregation(res, "tag_kwd")
 if not aggs:
 return {}
 cnt = np.sum([c for _, c in aggs])
@@ -552,7 +552,7 @@ class Dealer:
 es_res = self.dataStore.search(["content_with_weight"], [], {"doc_id": doc_id, "toc_kwd": "toc"}, [], OrderByExpr(), 0, 128, idx_nms,
 kb_ids)
 toc = []
-dict_chunks = self.dataStore.getFields(es_res, ["content_with_weight"])
+dict_chunks = self.dataStore.get_fields(es_res, ["content_with_weight"])
 for _, doc in dict_chunks.items():
 try:
 toc.extend(json.loads(doc["content_with_weight"]))
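
Unrelated to the renames themselves, `all_tags_in_portion` above turns the aggregated tag counts into an add-one (Laplace) smoothed distribution, `(c + 1) / (total + S)`, so unseen tags still get a small nonzero weight. A tiny worked example with made-up counts:

```python
S = 1000
aggs = [("ai", 3), ("rag", 1)]           # hypothetical get_aggregation output
total = sum(c for _, c in aggs)          # 4
portions = {t: (c + 1) / (total + S) for t, c in aggs}
print(portions)                          # {'ai': 0.00398..., 'rag': 0.00199...}
```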

View File

@@ -113,20 +113,20 @@ class Dealer:
 res.append(tk)
 return res
-def tokenMerge(self, tks):
-def oneTerm(t): return len(t) == 1 or re.match(r"[0-9a-z]{1,2}$", t)
+def token_merge(self, tks):
+def one_term(t): return len(t) == 1 or re.match(r"[0-9a-z]{1,2}$", t)
 res, i = [], 0
 while i < len(tks):
 j = i
-if i == 0 and oneTerm(tks[i]) and len(
+if i == 0 and one_term(tks[i]) and len(
 tks) > 1 and (len(tks[i + 1]) > 1 and not re.match(r"[0-9a-zA-Z]", tks[i + 1])): # 多 工位
 res.append(" ".join(tks[0:2]))
 i = 2
 continue
 while j < len(
-tks) and tks[j] and tks[j] not in self.stop_words and oneTerm(tks[j]):
+tks) and tks[j] and tks[j] not in self.stop_words and one_term(tks[j]):
 j += 1
 if j - i > 1:
 if j - i < 5:
@@ -232,7 +232,7 @@ class Dealer:
 tw = list(zip(tks, wts))
 else:
 for tk in tks:
-tt = self.tokenMerge(self.pretoken(tk, True))
+tt = self.token_merge(self.pretoken(tk, True))
 idf1 = np.array([idf(freq(t), 10000000) for t in tt])
 idf2 = np.array([idf(df(t), 1000000000) for t in tt])
 wts = (0.3 * idf1 + 0.7 * idf2) * \

View File

@@ -28,7 +28,7 @@ def collect():
 logging.debug(doc_locations)
 if len(doc_locations) == 0:
 time.sleep(1)
-return
+return None
 return doc_locations

View File

@@ -359,7 +359,7 @@ async def build_chunks(task, progress_callback):
 task_canceled = has_canceled(task["id"])
 if task_canceled:
 progress_callback(-1, msg="Task has been canceled.")
-return
+return None
 if settings.retriever.tag_content(tenant_id, kb_ids, d, all_tags, topn_tags=topn_tags, S=S) and len(d[TAG_FLD]) > 0:
 examples.append({"content": d["content_with_weight"], TAG_FLD: d[TAG_FLD]})
 else:
@@ -417,6 +417,7 @@ def build_TOC(task, docs, progress_callback):
 d["page_num_int"] = [100000000]
 d["id"] = xxhash.xxh64((d["content_with_weight"] + str(d["doc_id"])).encode("utf-8", "surrogatepass")).hexdigest()
 return d
+return None
 def init_kb(row, vector_size: int):
@@ -719,7 +720,7 @@ async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_c
 task_canceled = has_canceled(task_id)
 if task_canceled:
 progress_callback(-1, msg="Task has been canceled.")
-return
+return False
 if b % 128 == 0:
 progress_callback(prog=0.8 + 0.1 * (b + 1) / len(chunks), msg="")
 if doc_store_result:
@@ -737,7 +738,7 @@ async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_c
 for chunk_id in chunk_ids:
 nursery.start_soon(delete_image, task_dataset_id, chunk_id)
 progress_callback(-1, msg=f"Chunk updates failed since task {task_id} is unknown.")
-return
+return False
 return True

View File

@@ -67,6 +67,8 @@ class RAGFlowAzureSpnBlob:
 logging.exception(f"Fail put {bucket}/{fnm}")
 self.__open__()
 time.sleep(1)
+return None
+return None
 def rm(self, bucket, fnm):
 try:
@@ -84,7 +86,7 @@ class RAGFlowAzureSpnBlob:
 logging.exception(f"fail get {bucket}/{fnm}")
 self.__open__()
 time.sleep(1)
-return
+return None
 def obj_exist(self, bucket, fnm):
 try:
@@ -102,4 +104,4 @@ class RAGFlowAzureSpnBlob:
 logging.exception(f"fail get {bucket}/{fnm}")
 self.__open__()
 time.sleep(1)
-return
+return None

View File

@@ -241,23 +241,23 @@ class DocStoreConnection(ABC):
 """
 @abstractmethod
-def getTotal(self, res):
+def get_total(self, res):
 raise NotImplementedError("Not implemented")
 @abstractmethod
-def getChunkIds(self, res):
+def get_chunk_ids(self, res):
 raise NotImplementedError("Not implemented")
 @abstractmethod
-def getFields(self, res, fields: list[str]) -> dict[str, dict]:
+def get_fields(self, res, fields: list[str]) -> dict[str, dict]:
 raise NotImplementedError("Not implemented")
 @abstractmethod
-def getHighlight(self, res, keywords: list[str], fieldnm: str):
+def get_highlight(self, res, keywords: list[str], fieldnm: str):
 raise NotImplementedError("Not implemented")
 @abstractmethod
-def getAggregation(self, res, fieldnm: str):
+def get_aggregation(self, res, fieldnm: str):
 raise NotImplementedError("Not implemented")
 """

View File

@@ -471,12 +471,12 @@ class ESConnection(DocStoreConnection):
 Helper functions for search result
 """
-def getTotal(self, res):
+def get_total(self, res):
 if isinstance(res["hits"]["total"], type({})):
 return res["hits"]["total"]["value"]
 return res["hits"]["total"]
-def getChunkIds(self, res):
+def get_chunk_ids(self, res):
 return [d["_id"] for d in res["hits"]["hits"]]
 def __getSource(self, res):
@@ -487,7 +487,7 @@ class ESConnection(DocStoreConnection):
 rr.append(d["_source"])
 return rr
-def getFields(self, res, fields: list[str]) -> dict[str, dict]:
+def get_fields(self, res, fields: list[str]) -> dict[str, dict]:
 res_fields = {}
 if not fields:
 return {}
@@ -509,7 +509,7 @@ class ESConnection(DocStoreConnection):
 res_fields[d["id"]] = m
 return res_fields
-def getHighlight(self, res, keywords: list[str], fieldnm: str):
+def get_highlight(self, res, keywords: list[str], fieldnm: str):
 ans = {}
 for d in res["hits"]["hits"]:
 hlts = d.get("highlight")
@@ -534,7 +534,7 @@ class ESConnection(DocStoreConnection):
 return ans
-def getAggregation(self, res, fieldnm: str):
+def get_aggregation(self, res, fieldnm: str):
 agg_field = "aggs_" + fieldnm
 if "aggregations" not in res or agg_field not in res["aggregations"]:
 return list()
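
For context on what these ES helpers consume: Elasticsearch reports `hits.total` either as a plain integer or as an object with a `value` field, which is why `get_total` checks both shapes. A small self-contained check against a mocked response body (the mock is an assumption, simplified from a real ES response):

```python
fake_res = {
    "hits": {
        "total": {"value": 2, "relation": "eq"},
        "hits": [
            {"_id": "chunk-1", "_source": {"docnm_kwd": "a.pdf"}},
            {"_id": "chunk-2", "_source": {"docnm_kwd": "b.pdf"}},
        ],
    }
}


def get_total(res):
    total = res["hits"]["total"]
    return total["value"] if isinstance(total, dict) else total


def get_chunk_ids(res):
    return [d["_id"] for d in res["hits"]["hits"]]


assert get_total(fake_res) == 2
assert get_chunk_ids(fake_res) == ["chunk-1", "chunk-2"]
```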

View File

@@ -470,7 +470,7 @@ class InfinityConnection(DocStoreConnection):
 df_list.append(kb_res)
 self.connPool.release_conn(inf_conn)
 res = concat_dataframes(df_list, ["id"])
-res_fields = self.getFields(res, res.columns.tolist())
+res_fields = self.get_fields(res, res.columns.tolist())
 return res_fields.get(chunkId, None)
 def insert(self, documents: list[dict], indexName: str, knowledgebaseId: str = None) -> list[str]:
@@ -599,7 +599,7 @@ class InfinityConnection(DocStoreConnection):
 col_to_remove = list(removeValue.keys())
 row_to_opt = table_instance.output(col_to_remove + ["id"]).filter(filter).to_df()
 logger.debug(f"INFINITY search table {str(table_name)}, filter {filter}, result: {str(row_to_opt[0])}")
-row_to_opt = self.getFields(row_to_opt, col_to_remove)
+row_to_opt = self.get_fields(row_to_opt, col_to_remove)
 for id, old_v in row_to_opt.items():
 for k, remove_v in removeValue.items():
 if remove_v in old_v[k]:
@@ -639,17 +639,17 @@ class InfinityConnection(DocStoreConnection):
 Helper functions for search result
 """
-def getTotal(self, res: tuple[pd.DataFrame, int] | pd.DataFrame) -> int:
+def get_total(self, res: tuple[pd.DataFrame, int] | pd.DataFrame) -> int:
 if isinstance(res, tuple):
 return res[1]
 return len(res)
-def getChunkIds(self, res: tuple[pd.DataFrame, int] | pd.DataFrame) -> list[str]:
+def get_chunk_ids(self, res: tuple[pd.DataFrame, int] | pd.DataFrame) -> list[str]:
 if isinstance(res, tuple):
 res = res[0]
 return list(res["id"])
-def getFields(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, fields: list[str]) -> dict[str, dict]:
+def get_fields(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, fields: list[str]) -> dict[str, dict]:
 if isinstance(res, tuple):
 res = res[0]
 if not fields:
@@ -690,7 +690,7 @@ class InfinityConnection(DocStoreConnection):
 return res2.set_index("id").to_dict(orient="index")
-def getHighlight(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, keywords: list[str], fieldnm: str):
+def get_highlight(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, keywords: list[str], fieldnm: str):
 if isinstance(res, tuple):
 res = res[0]
 ans = {}
@@ -732,7 +732,7 @@ class InfinityConnection(DocStoreConnection):
 ans[id] = txt
 return ans
-def getAggregation(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, fieldnm: str):
+def get_aggregation(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, fieldnm: str):
 """
 Manual aggregation for tag fields since Infinity doesn't provide native aggregation
 """

View File

@@ -92,7 +92,7 @@ class RAGFlowMinio:
 logging.exception(f"Fail to get {bucket}/{filename}")
 self.__open__()
 time.sleep(1)
-return
+return None
 def obj_exist(self, bucket, filename, tenant_id=None):
 try:
@@ -130,7 +130,7 @@ class RAGFlowMinio:
 logging.exception(f"Fail to get_presigned {bucket}/{fnm}:")
 self.__open__()
 time.sleep(1)
-return
+return None
 def remove_bucket(self, bucket):
 try:

View File

@@ -62,8 +62,7 @@ class OpenDALStorage:
 def health(self):
 bucket, fnm, binary = "txtxtxtxt1", "txtxtxtxt1", b"_t@@@1"
-r = self._operator.write(f"{bucket}/{fnm}", binary)
-return r
+return self._operator.write(f"{bucket}/{fnm}", binary)
 def put(self, bucket, fnm, binary, tenant_id=None):
 self._operator.write(f"{bucket}/{fnm}", binary)

View File

@@ -455,12 +455,12 @@ class OSConnection(DocStoreConnection):
 Helper functions for search result
 """
-def getTotal(self, res):
+def get_total(self, res):
 if isinstance(res["hits"]["total"], type({})):
 return res["hits"]["total"]["value"]
 return res["hits"]["total"]
-def getChunkIds(self, res):
+def get_chunk_ids(self, res):
 return [d["_id"] for d in res["hits"]["hits"]]
 def __getSource(self, res):
@@ -471,7 +471,7 @@ class OSConnection(DocStoreConnection):
 rr.append(d["_source"])
 return rr
-def getFields(self, res, fields: list[str]) -> dict[str, dict]:
+def get_fields(self, res, fields: list[str]) -> dict[str, dict]:
 res_fields = {}
 if not fields:
 return {}
@@ -490,7 +490,7 @@ class OSConnection(DocStoreConnection):
 res_fields[d["id"]] = m
 return res_fields
-def getHighlight(self, res, keywords: list[str], fieldnm: str):
+def get_highlight(self, res, keywords: list[str], fieldnm: str):
 ans = {}
 for d in res["hits"]["hits"]:
 hlts = d.get("highlight")
@@ -515,7 +515,7 @@ class OSConnection(DocStoreConnection):
 return ans
-def getAggregation(self, res, fieldnm: str):
+def get_aggregation(self, res, fieldnm: str):
 agg_field = "aggs_" + fieldnm
 if "aggregations" not in res or agg_field not in res["aggregations"]:
 return list()

View File

@@ -141,7 +141,7 @@ class RAGFlowOSS:
 logging.exception(f"fail get {bucket}/{fnm}")
 self.__open__()
 time.sleep(1)
-return
+return None
 @use_prefix_path
 @use_default_bucket
@@ -170,5 +170,5 @@ class RAGFlowOSS:
 logging.exception(f"fail get url {bucket}/{fnm}")
 self.__open__()
 time.sleep(1)
-return
+return None

View File

@@ -104,6 +104,7 @@ class RedisDB:
 if self.REDIS.get(a) == b:
 return True
+return False
 def info(self):
 info = self.REDIS.info()
@@ -124,7 +125,7 @@ class RedisDB:
 def exist(self, k):
 if not self.REDIS:
-return
+return None
 try:
 return self.REDIS.exists(k)
 except Exception as e:
@@ -133,7 +134,7 @@ class RedisDB:
 def get(self, k):
 if not self.REDIS:
-return
+return None
 try:
 return self.REDIS.get(k)
 except Exception as e:

View File

@@ -164,7 +164,7 @@ class RAGFlowS3:
 logging.exception(f"fail get {bucket}/{fnm}")
 self.__open__()
 time.sleep(1)
-return
+return None
 @use_prefix_path
 @use_default_bucket
@@ -193,7 +193,7 @@ class RAGFlowS3:
 logging.exception(f"fail get url {bucket}/{fnm}")
 self.__open__()
 time.sleep(1)
-return
+return None
 @use_default_bucket
 def rm_bucket(self, bucket, *args, **kwargs):