From 296476ab89a9d9141a82347d61a11806cfaab5ed Mon Sep 17 00:00:00 2001
From: Jin Hai
Date: Wed, 12 Nov 2025 19:00:15 +0800
Subject: [PATCH] Refactor function name (#11210)

### What problem does this PR solve?

Rename camelCase helper methods to snake_case for consistency with PEP 8: the doc-store connection helpers `getTotal`, `getChunkIds`, `getFields`, `getHighlight`, and `getAggregation` become `get_total`, `get_chunk_ids`, `get_fields`, `get_highlight`, and `get_aggregation`, with matching renames in the NLP utilities (e.g. `subSpecialChar` → `sub_special_char`, `tokenMerge` → `token_merge`, `loadUserDict` → `load_user_dict`). Bare `return` statements in the storage and Redis helpers are also made explicit (`return None` / `return False`). No behavior change is intended. A brief caller-side usage sketch follows the diff.

### Type of change

- [x] Refactoring

---------

Signed-off-by: Jin Hai
---
 api/apps/kb_app.py | 4 +--
 api/db/services/document_service.py | 4 +--
 graphrag/search.py | 6 ++---
 graphrag/utils.py | 6 ++---
 rag/nlp/query.py | 28 ++++++++++----------
 rag/nlp/rag_tokenizer.py | 41 ++++++++++++++---------------
 rag/nlp/search.py | 30 ++++++++++-----------
 rag/nlp/term_weight.py | 10 +++----
 rag/svr/cache_file_svr.py | 2 +-
 rag/svr/task_executor.py | 7 ++---
 rag/utils/azure_spn_conn.py | 6 +++--
 rag/utils/doc_store_conn.py | 10 +++----
 rag/utils/es_conn.py | 10 +++----
 rag/utils/infinity_conn.py | 14 +++++-----
 rag/utils/minio_conn.py | 4 +--
 rag/utils/opendal_conn.py | 3 +--
 rag/utils/opensearch_conn.py | 10 +++----
 rag/utils/oss_conn.py | 4 +--
 rag/utils/redis_conn.py | 5 ++--
 rag/utils/s3_conn.py | 4 +--
 20 files changed, 105 insertions(+), 103 deletions(-)

diff --git a/api/apps/kb_app.py b/api/apps/kb_app.py index 7094c28d7..4546b2586 100644 --- a/api/apps/kb_app.py +++ b/api/apps/kb_app.py @@ -807,7 +807,7 @@ def check_embedding(): offset=0, limit=1, indexNames=index_nm, knowledgebaseIds=[kb_id] ) - total = docStoreConn.getTotal(res0) + total = docStoreConn.get_total(res0) if total <= 0: return [] @@ -824,7 +824,7 @@ def check_embedding(): offset=off, limit=1, indexNames=index_nm, knowledgebaseIds=[kb_id] ) - ids = docStoreConn.getChunkIds(res1) + ids = docStoreConn.get_chunk_ids(res1) if not ids: continue diff --git a/api/db/services/document_service.py b/api/db/services/document_service.py index a64ae16de..530133164 100644 --- a/api/db/services/document_service.py +++ b/api/db/services/document_service.py @@ -309,7 +309,7 @@ class DocumentService(CommonService): chunks = settings.docStoreConn.search(["img_id"], [], {"doc_id": doc.id}, [], OrderByExpr(), page * page_size, page_size, search.index_name(tenant_id), [doc.kb_id]) - chunk_ids = settings.docStoreConn.getChunkIds(chunks) + chunk_ids = settings.docStoreConn.get_chunk_ids(chunks) if not chunk_ids: break all_chunk_ids.extend(chunk_ids) @@ -322,7 +322,7 @@ class DocumentService(CommonService): settings.STORAGE_IMPL.rm(doc.kb_id, doc.thumbnail) settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id) - graph_source = settings.docStoreConn.getFields( + graph_source = settings.docStoreConn.get_fields( settings.docStoreConn.search(["source_id"], [], {"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]}, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [doc.kb_id]), ["source_id"] ) if len(graph_source) > 0 and doc.id in list(graph_source.values())[0]["source_id"]: diff --git a/graphrag/search.py b/graphrag/search.py index 860f14bcb..b3a0104e1 100644 --- a/graphrag/search.py +++ b/graphrag/search.py @@ -69,7 +69,7 @@ class KGSearch(Dealer): def _ent_info_from_(self, es_res, sim_thr=0.3): res = {} flds = ["content_with_weight", "_score", "entity_kwd", "rank_flt", "n_hop_with_weight"] - es_res = self.dataStore.getFields(es_res, flds) + es_res = self.dataStore.get_fields(es_res, flds) for _, ent in es_res.items(): for f in flds: if f in ent and ent[f] is None: @@ -88,7 +88,7 @@ class KGSearch(Dealer): def _relation_info_from_(self, es_res, sim_thr=0.3): res = {} - es_res = self.dataStore.getFields(es_res, ["content_with_weight", "_score", "from_entity_kwd", "to_entity_kwd", + es_res = 
self.dataStore.get_fields(es_res, ["content_with_weight", "_score", "from_entity_kwd", "to_entity_kwd", "weight_int"]) for _, ent in es_res.items(): if get_float(ent["_score"]) < sim_thr: @@ -300,7 +300,7 @@ class KGSearch(Dealer): fltr["entities_kwd"] = entities comm_res = self.dataStore.search(fields, [], fltr, [], OrderByExpr(), 0, topn, idxnms, kb_ids) - comm_res_fields = self.dataStore.getFields(comm_res, fields) + comm_res_fields = self.dataStore.get_fields(comm_res, fields) txts = [] for ii, (_, row) in enumerate(comm_res_fields.items()): obj = json.loads(row["content_with_weight"]) diff --git a/graphrag/utils.py b/graphrag/utils.py index 6a8df1e40..51a9c1abc 100644 --- a/graphrag/utils.py +++ b/graphrag/utils.py @@ -382,7 +382,7 @@ async def does_graph_contains(tenant_id, kb_id, doc_id): "removed_kwd": "N", } res = await trio.to_thread.run_sync(lambda: settings.docStoreConn.search(fields, [], condition, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [kb_id])) - fields2 = settings.docStoreConn.getFields(res, fields) + fields2 = settings.docStoreConn.get_fields(res, fields) graph_doc_ids = set() for chunk_id in fields2.keys(): graph_doc_ids = set(fields2[chunk_id]["source_id"]) @@ -591,8 +591,8 @@ async def rebuild_graph(tenant_id, kb_id, exclude_rebuild=None): es_res = await trio.to_thread.run_sync( lambda: settings.docStoreConn.search(flds, [], {"kb_id": kb_id, "knowledge_graph_kwd": ["subgraph"]}, [], OrderByExpr(), i, bs, search.index_name(tenant_id), [kb_id]) ) - # tot = settings.docStoreConn.getTotal(es_res) - es_res = settings.docStoreConn.getFields(es_res, flds) + # tot = settings.docStoreConn.get_total(es_res) + es_res = settings.docStoreConn.get_fields(es_res, flds) if len(es_res) == 0: break diff --git a/rag/nlp/query.py b/rag/nlp/query.py index 68d2d2979..ec3628525 100644 --- a/rag/nlp/query.py +++ b/rag/nlp/query.py @@ -38,11 +38,11 @@ class FulltextQueryer: ] @staticmethod - def subSpecialChar(line): + def sub_special_char(line): return re.sub(r"([:\{\}/\[\]\-\*\"\(\)\|\+~\^])", r"\\\1", line).strip() @staticmethod - def isChinese(line): + def is_chinese(line): arr = re.split(r"[ \t]+", line) if len(arr) <= 3: return True @@ -92,7 +92,7 @@ class FulltextQueryer: otxt = txt txt = FulltextQueryer.rmWWW(txt) - if not self.isChinese(txt): + if not self.is_chinese(txt): txt = FulltextQueryer.rmWWW(txt) tks = rag_tokenizer.tokenize(txt).split() keywords = [t for t in tks if t] @@ -163,7 +163,7 @@ class FulltextQueryer: ) for m in sm ] - sm = [FulltextQueryer.subSpecialChar(m) for m in sm if len(m) > 1] + sm = [FulltextQueryer.sub_special_char(m) for m in sm if len(m) > 1] sm = [m for m in sm if len(m) > 1] if len(keywords) < 32: @@ -171,7 +171,7 @@ class FulltextQueryer: keywords.extend(sm) tk_syns = self.syn.lookup(tk) - tk_syns = [FulltextQueryer.subSpecialChar(s) for s in tk_syns] + tk_syns = [FulltextQueryer.sub_special_char(s) for s in tk_syns] if len(keywords) < 32: keywords.extend([s for s in tk_syns if s]) tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s] @@ -180,7 +180,7 @@ class FulltextQueryer: if len(keywords) >= 32: break - tk = FulltextQueryer.subSpecialChar(tk) + tk = FulltextQueryer.sub_special_char(tk) if tk.find(" ") > 0: tk = '"%s"' % tk if tk_syns: @@ -198,7 +198,7 @@ class FulltextQueryer: syns = " OR ".join( [ '"%s"' - % rag_tokenizer.tokenize(FulltextQueryer.subSpecialChar(s)) + % rag_tokenizer.tokenize(FulltextQueryer.sub_special_char(s)) for s in syns ] ) @@ -217,17 +217,17 @@ class FulltextQueryer: return None, keywords 
def hybrid_similarity(self, avec, bvecs, atks, btkss, tkweight=0.3, vtweight=0.7): - from sklearn.metrics.pairwise import cosine_similarity as CosineSimilarity + from sklearn.metrics.pairwise import cosine_similarity import numpy as np - sims = CosineSimilarity([avec], bvecs) + sims = cosine_similarity([avec], bvecs) tksim = self.token_similarity(atks, btkss) if np.sum(sims[0]) == 0: return np.array(tksim), tksim, sims[0] return np.array(sims[0]) * vtweight + np.array(tksim) * tkweight, tksim, sims[0] def token_similarity(self, atks, btkss): - def toDict(tks): + def to_dict(tks): if isinstance(tks, str): tks = tks.split() d = defaultdict(int) @@ -236,8 +236,8 @@ class FulltextQueryer: d[t] += c return d - atks = toDict(atks) - btkss = [toDict(tks) for tks in btkss] + atks = to_dict(atks) + btkss = [to_dict(tks) for tks in btkss] return [self.similarity(atks, btks) for btks in btkss] def similarity(self, qtwt, dtwt): @@ -262,10 +262,10 @@ class FulltextQueryer: keywords = [f'"{k.strip()}"' for k in keywords] for tk, w in sorted(tks_w, key=lambda x: x[1] * -1)[:keywords_topn]: tk_syns = self.syn.lookup(tk) - tk_syns = [FulltextQueryer.subSpecialChar(s) for s in tk_syns] + tk_syns = [FulltextQueryer.sub_special_char(s) for s in tk_syns] tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s] tk_syns = [f"\"{s}\"" if s.find(" ") > 0 else s for s in tk_syns] - tk = FulltextQueryer.subSpecialChar(tk) + tk = FulltextQueryer.sub_special_char(tk) if tk.find(" ") > 0: tk = '"%s"' % tk if tk_syns: diff --git a/rag/nlp/rag_tokenizer.py b/rag/nlp/rag_tokenizer.py index 3c4b97833..c95c18e74 100644 --- a/rag/nlp/rag_tokenizer.py +++ b/rag/nlp/rag_tokenizer.py @@ -35,7 +35,7 @@ class RagTokenizer: def rkey_(self, line): return str(("DD" + (line[::-1].lower())).encode("utf-8"))[2:-1] - def loadDict_(self, fnm): + def _load_dict(self, fnm): logging.info(f"[HUQIE]:Build trie from {fnm}") try: of = open(fnm, "r", encoding='utf-8') @@ -85,18 +85,18 @@ class RagTokenizer: self.trie_ = datrie.Trie(string.printable) # load data from dict file and save to trie file - self.loadDict_(self.DIR_ + ".txt") + self._load_dict(self.DIR_ + ".txt") - def loadUserDict(self, fnm): + def load_user_dict(self, fnm): try: self.trie_ = datrie.Trie.load(fnm + ".trie") return except Exception: self.trie_ = datrie.Trie(string.printable) - self.loadDict_(fnm) + self._load_dict(fnm) - def addUserDict(self, fnm): - self.loadDict_(fnm) + def add_user_dict(self, fnm): + self._load_dict(fnm) def _strQ2B(self, ustring): """Convert full-width characters to half-width characters""" @@ -221,7 +221,7 @@ class RagTokenizer: logging.debug("[SC] {} {} {} {} {}".format(tks, len(tks), L, F, B / len(tks) + L + F)) return tks, B / len(tks) + L + F - def sortTks_(self, tkslist): + def _sort_tokens(self, tkslist): res = [] for tfts in tkslist: tks, s = self.score_(tfts) @@ -246,7 +246,7 @@ class RagTokenizer: return " ".join(res) - def maxForward_(self, line): + def _max_forward(self, line): res = [] s = 0 while s < len(line): @@ -270,7 +270,7 @@ class RagTokenizer: return self.score_(res) - def maxBackward_(self, line): + def _max_backward(self, line): res = [] s = len(line) - 1 while s >= 0: @@ -336,8 +336,8 @@ class RagTokenizer: continue # use maxforward for the first time - tks, s = self.maxForward_(L) - tks1, s1 = self.maxBackward_(L) + tks, s = self._max_forward(L) + tks1, s1 = self._max_backward(L) if self.DEBUG: logging.debug("[FW] {} {}".format(tks, s)) logging.debug("[BW] {} {}".format(tks1, s1)) @@ -369,7 +369,7 @@ class 
RagTokenizer: # backward tokens from_i to i are different from forward tokens from _j to j. tkslist = [] self.dfs_("".join(tks[_j:j]), 0, [], tkslist) - res.append(" ".join(self.sortTks_(tkslist)[0][0])) + res.append(" ".join(self._sort_tokens(tkslist)[0][0])) same = 1 while i + same < len(tks1) and j + same < len(tks) and tks1[i + same] == tks[j + same]: @@ -385,7 +385,7 @@ class RagTokenizer: assert "".join(tks1[_i:]) == "".join(tks[_j:]) tkslist = [] self.dfs_("".join(tks[_j:]), 0, [], tkslist) - res.append(" ".join(self.sortTks_(tkslist)[0][0])) + res.append(" ".join(self._sort_tokens(tkslist)[0][0])) res = " ".join(res) logging.debug("[TKS] {}".format(self.merge_(res))) @@ -413,7 +413,7 @@ class RagTokenizer: if len(tkslist) < 2: res.append(tk) continue - stk = self.sortTks_(tkslist)[1][0] + stk = self._sort_tokens(tkslist)[1][0] if len(stk) == len(tk): stk = tk else: @@ -447,14 +447,13 @@ def is_number(s): def is_alphabet(s): - if (s >= u'\u0041' and s <= u'\u005a') or ( - s >= u'\u0061' and s <= u'\u007a'): + if (u'\u0041' <= s <= u'\u005a') or (u'\u0061' <= s <= u'\u007a'): return True else: return False -def naiveQie(txt): +def naive_qie(txt): tks = [] for t in txt.split(): if tks and re.match(r".*[a-zA-Z]$", tks[-1] @@ -469,14 +468,14 @@ tokenize = tokenizer.tokenize fine_grained_tokenize = tokenizer.fine_grained_tokenize tag = tokenizer.tag freq = tokenizer.freq -loadUserDict = tokenizer.loadUserDict -addUserDict = tokenizer.addUserDict +load_user_dict = tokenizer.load_user_dict +add_user_dict = tokenizer.add_user_dict tradi2simp = tokenizer._tradi2simp strQ2B = tokenizer._strQ2B if __name__ == '__main__': tknzr = RagTokenizer(debug=True) - # huqie.addUserDict("/tmp/tmp.new.tks.dict") + # huqie.add_user_dict("/tmp/tmp.new.tks.dict") tks = tknzr.tokenize( "哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈") logging.info(tknzr.fine_grained_tokenize(tks)) @@ -506,7 +505,7 @@ if __name__ == '__main__': if len(sys.argv) < 2: sys.exit() tknzr.DEBUG = False - tknzr.loadUserDict(sys.argv[1]) + tknzr.load_user_dict(sys.argv[1]) of = open(sys.argv[2], "r") while True: line = of.readline() diff --git a/rag/nlp/search.py b/rag/nlp/search.py index 1bf0abe04..f8b3d513f 100644 --- a/rag/nlp/search.py +++ b/rag/nlp/search.py @@ -102,7 +102,7 @@ class Dealer: orderBy.asc("top_int") orderBy.desc("create_timestamp_flt") res = self.dataStore.search(src, [], filters, [], orderBy, offset, limit, idx_names, kb_ids) - total = self.dataStore.getTotal(res) + total = self.dataStore.get_total(res) logging.debug("Dealer.search TOTAL: {}".format(total)) else: highlightFields = ["content_ltks", "title_tks"] @@ -115,7 +115,7 @@ class Dealer: matchExprs = [matchText] res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit, idx_names, kb_ids, rank_feature=rank_feature) - total = self.dataStore.getTotal(res) + total = self.dataStore.get_total(res) logging.debug("Dealer.search TOTAL: {}".format(total)) else: matchDense = self.get_vector(qst, emb_mdl, topk, req.get("similarity", 0.1)) @@ -127,20 +127,20 @@ class Dealer: res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit, idx_names, kb_ids, rank_feature=rank_feature) - total = self.dataStore.getTotal(res) + total = self.dataStore.get_total(res) logging.debug("Dealer.search TOTAL: {}".format(total)) # If result is empty, try again with lower min_match if total == 0: if filters.get("doc_id"): res = self.dataStore.search(src, [], filters, [], orderBy, offset, 
limit, idx_names, kb_ids) - total = self.dataStore.getTotal(res) + total = self.dataStore.get_total(res) else: matchText, _ = self.qryr.question(qst, min_match=0.1) matchDense.extra_options["similarity"] = 0.17 res = self.dataStore.search(src, highlightFields, filters, [matchText, matchDense, fusionExpr], orderBy, offset, limit, idx_names, kb_ids, rank_feature=rank_feature) - total = self.dataStore.getTotal(res) + total = self.dataStore.get_total(res) logging.debug("Dealer.search 2 TOTAL: {}".format(total)) for k in keywords: @@ -153,17 +153,17 @@ class Dealer: kwds.add(kk) logging.debug(f"TOTAL: {total}") - ids = self.dataStore.getChunkIds(res) + ids = self.dataStore.get_chunk_ids(res) keywords = list(kwds) - highlight = self.dataStore.getHighlight(res, keywords, "content_with_weight") - aggs = self.dataStore.getAggregation(res, "docnm_kwd") + highlight = self.dataStore.get_highlight(res, keywords, "content_with_weight") + aggs = self.dataStore.get_aggregation(res, "docnm_kwd") return self.SearchResult( total=total, ids=ids, query_vector=q_vec, aggregation=aggs, highlight=highlight, - field=self.dataStore.getFields(res, src + ["_score"]), + field=self.dataStore.get_fields(res, src + ["_score"]), keywords=keywords ) @@ -488,7 +488,7 @@ class Dealer: for p in range(offset, max_count, bs): es_res = self.dataStore.search(fields, [], condition, [], orderBy, p, bs, index_name(tenant_id), kb_ids) - dict_chunks = self.dataStore.getFields(es_res, fields) + dict_chunks = self.dataStore.get_fields(es_res, fields) for id, doc in dict_chunks.items(): doc["id"] = id if dict_chunks: @@ -501,11 +501,11 @@ class Dealer: if not self.dataStore.indexExist(index_name(tenant_id), kb_ids[0]): return [] res = self.dataStore.search([], [], {}, [], OrderByExpr(), 0, 0, index_name(tenant_id), kb_ids, ["tag_kwd"]) - return self.dataStore.getAggregation(res, "tag_kwd") + return self.dataStore.get_aggregation(res, "tag_kwd") def all_tags_in_portion(self, tenant_id: str, kb_ids: list[str], S=1000): res = self.dataStore.search([], [], {}, [], OrderByExpr(), 0, 0, index_name(tenant_id), kb_ids, ["tag_kwd"]) - res = self.dataStore.getAggregation(res, "tag_kwd") + res = self.dataStore.get_aggregation(res, "tag_kwd") total = np.sum([c for _, c in res]) return {t: (c + 1) / (total + S) for t, c in res} @@ -513,7 +513,7 @@ class Dealer: idx_nm = index_name(tenant_id) match_txt = self.qryr.paragraph(doc["title_tks"] + " " + doc["content_ltks"], doc.get("important_kwd", []), keywords_topn) res = self.dataStore.search([], [], {}, [match_txt], OrderByExpr(), 0, 0, idx_nm, kb_ids, ["tag_kwd"]) - aggs = self.dataStore.getAggregation(res, "tag_kwd") + aggs = self.dataStore.get_aggregation(res, "tag_kwd") if not aggs: return False cnt = np.sum([c for _, c in aggs]) @@ -529,7 +529,7 @@ class Dealer: idx_nms = [index_name(tid) for tid in tenant_ids] match_txt, _ = self.qryr.question(question, min_match=0.0) res = self.dataStore.search([], [], {}, [match_txt], OrderByExpr(), 0, 0, idx_nms, kb_ids, ["tag_kwd"]) - aggs = self.dataStore.getAggregation(res, "tag_kwd") + aggs = self.dataStore.get_aggregation(res, "tag_kwd") if not aggs: return {} cnt = np.sum([c for _, c in aggs]) @@ -552,7 +552,7 @@ class Dealer: es_res = self.dataStore.search(["content_with_weight"], [], {"doc_id": doc_id, "toc_kwd": "toc"}, [], OrderByExpr(), 0, 128, idx_nms, kb_ids) toc = [] - dict_chunks = self.dataStore.getFields(es_res, ["content_with_weight"]) + dict_chunks = self.dataStore.get_fields(es_res, ["content_with_weight"]) for _, doc in 
dict_chunks.items(): try: toc.extend(json.loads(doc["content_with_weight"])) diff --git a/rag/nlp/term_weight.py b/rag/nlp/term_weight.py index 392117c18..28ed585ee 100644 --- a/rag/nlp/term_weight.py +++ b/rag/nlp/term_weight.py @@ -113,20 +113,20 @@ class Dealer: res.append(tk) return res - def tokenMerge(self, tks): - def oneTerm(t): return len(t) == 1 or re.match(r"[0-9a-z]{1,2}$", t) + def token_merge(self, tks): + def one_term(t): return len(t) == 1 or re.match(r"[0-9a-z]{1,2}$", t) res, i = [], 0 while i < len(tks): j = i - if i == 0 and oneTerm(tks[i]) and len( + if i == 0 and one_term(tks[i]) and len( tks) > 1 and (len(tks[i + 1]) > 1 and not re.match(r"[0-9a-zA-Z]", tks[i + 1])): # 多 工位 res.append(" ".join(tks[0:2])) i = 2 continue while j < len( - tks) and tks[j] and tks[j] not in self.stop_words and oneTerm(tks[j]): + tks) and tks[j] and tks[j] not in self.stop_words and one_term(tks[j]): j += 1 if j - i > 1: if j - i < 5: @@ -232,7 +232,7 @@ class Dealer: tw = list(zip(tks, wts)) else: for tk in tks: - tt = self.tokenMerge(self.pretoken(tk, True)) + tt = self.token_merge(self.pretoken(tk, True)) idf1 = np.array([idf(freq(t), 10000000) for t in tt]) idf2 = np.array([idf(df(t), 1000000000) for t in tt]) wts = (0.3 * idf1 + 0.7 * idf2) * \ diff --git a/rag/svr/cache_file_svr.py b/rag/svr/cache_file_svr.py index 89ab8b75f..3744c04ea 100644 --- a/rag/svr/cache_file_svr.py +++ b/rag/svr/cache_file_svr.py @@ -28,7 +28,7 @@ def collect(): logging.debug(doc_locations) if len(doc_locations) == 0: time.sleep(1) - return + return None return doc_locations diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py index af8dfc186..d926415e5 100644 --- a/rag/svr/task_executor.py +++ b/rag/svr/task_executor.py @@ -359,7 +359,7 @@ async def build_chunks(task, progress_callback): task_canceled = has_canceled(task["id"]) if task_canceled: progress_callback(-1, msg="Task has been canceled.") - return + return None if settings.retriever.tag_content(tenant_id, kb_ids, d, all_tags, topn_tags=topn_tags, S=S) and len(d[TAG_FLD]) > 0: examples.append({"content": d["content_with_weight"], TAG_FLD: d[TAG_FLD]}) else: @@ -417,6 +417,7 @@ def build_TOC(task, docs, progress_callback): d["page_num_int"] = [100000000] d["id"] = xxhash.xxh64((d["content_with_weight"] + str(d["doc_id"])).encode("utf-8", "surrogatepass")).hexdigest() return d + return None def init_kb(row, vector_size: int): @@ -719,7 +720,7 @@ async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_c task_canceled = has_canceled(task_id) if task_canceled: progress_callback(-1, msg="Task has been canceled.") - return + return False if b % 128 == 0: progress_callback(prog=0.8 + 0.1 * (b + 1) / len(chunks), msg="") if doc_store_result: @@ -737,7 +738,7 @@ async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_c for chunk_id in chunk_ids: nursery.start_soon(delete_image, task_dataset_id, chunk_id) progress_callback(-1, msg=f"Chunk updates failed since task {task_id} is unknown.") - return + return False return True diff --git a/rag/utils/azure_spn_conn.py b/rag/utils/azure_spn_conn.py index f47470d67..005d3ba6b 100644 --- a/rag/utils/azure_spn_conn.py +++ b/rag/utils/azure_spn_conn.py @@ -67,6 +67,8 @@ class RAGFlowAzureSpnBlob: logging.exception(f"Fail put {bucket}/{fnm}") self.__open__() time.sleep(1) + return None + return None def rm(self, bucket, fnm): try: @@ -84,7 +86,7 @@ class RAGFlowAzureSpnBlob: logging.exception(f"fail get {bucket}/{fnm}") self.__open__() time.sleep(1) - return + 
return None def obj_exist(self, bucket, fnm): try: @@ -102,4 +104,4 @@ class RAGFlowAzureSpnBlob: logging.exception(f"fail get {bucket}/{fnm}") self.__open__() time.sleep(1) - return \ No newline at end of file + return None \ No newline at end of file diff --git a/rag/utils/doc_store_conn.py b/rag/utils/doc_store_conn.py index c3fa61b0c..33f030011 100644 --- a/rag/utils/doc_store_conn.py +++ b/rag/utils/doc_store_conn.py @@ -241,23 +241,23 @@ class DocStoreConnection(ABC): """ @abstractmethod - def getTotal(self, res): + def get_total(self, res): raise NotImplementedError("Not implemented") @abstractmethod - def getChunkIds(self, res): + def get_chunk_ids(self, res): raise NotImplementedError("Not implemented") @abstractmethod - def getFields(self, res, fields: list[str]) -> dict[str, dict]: + def get_fields(self, res, fields: list[str]) -> dict[str, dict]: raise NotImplementedError("Not implemented") @abstractmethod - def getHighlight(self, res, keywords: list[str], fieldnm: str): + def get_highlight(self, res, keywords: list[str], fieldnm: str): raise NotImplementedError("Not implemented") @abstractmethod - def getAggregation(self, res, fieldnm: str): + def get_aggregation(self, res, fieldnm: str): raise NotImplementedError("Not implemented") """ diff --git a/rag/utils/es_conn.py b/rag/utils/es_conn.py index e99ee1375..5971950cf 100644 --- a/rag/utils/es_conn.py +++ b/rag/utils/es_conn.py @@ -471,12 +471,12 @@ class ESConnection(DocStoreConnection): Helper functions for search result """ - def getTotal(self, res): + def get_total(self, res): if isinstance(res["hits"]["total"], type({})): return res["hits"]["total"]["value"] return res["hits"]["total"] - def getChunkIds(self, res): + def get_chunk_ids(self, res): return [d["_id"] for d in res["hits"]["hits"]] def __getSource(self, res): @@ -487,7 +487,7 @@ class ESConnection(DocStoreConnection): rr.append(d["_source"]) return rr - def getFields(self, res, fields: list[str]) -> dict[str, dict]: + def get_fields(self, res, fields: list[str]) -> dict[str, dict]: res_fields = {} if not fields: return {} @@ -509,7 +509,7 @@ class ESConnection(DocStoreConnection): res_fields[d["id"]] = m return res_fields - def getHighlight(self, res, keywords: list[str], fieldnm: str): + def get_highlight(self, res, keywords: list[str], fieldnm: str): ans = {} for d in res["hits"]["hits"]: hlts = d.get("highlight") @@ -534,7 +534,7 @@ class ESConnection(DocStoreConnection): return ans - def getAggregation(self, res, fieldnm: str): + def get_aggregation(self, res, fieldnm: str): agg_field = "aggs_" + fieldnm if "aggregations" not in res or agg_field not in res["aggregations"]: return list() diff --git a/rag/utils/infinity_conn.py b/rag/utils/infinity_conn.py index 03251e72c..ab575f9bc 100644 --- a/rag/utils/infinity_conn.py +++ b/rag/utils/infinity_conn.py @@ -470,7 +470,7 @@ class InfinityConnection(DocStoreConnection): df_list.append(kb_res) self.connPool.release_conn(inf_conn) res = concat_dataframes(df_list, ["id"]) - res_fields = self.getFields(res, res.columns.tolist()) + res_fields = self.get_fields(res, res.columns.tolist()) return res_fields.get(chunkId, None) def insert(self, documents: list[dict], indexName: str, knowledgebaseId: str = None) -> list[str]: @@ -599,7 +599,7 @@ class InfinityConnection(DocStoreConnection): col_to_remove = list(removeValue.keys()) row_to_opt = table_instance.output(col_to_remove + ["id"]).filter(filter).to_df() logger.debug(f"INFINITY search table {str(table_name)}, filter {filter}, result: {str(row_to_opt[0])}") - 
row_to_opt = self.getFields(row_to_opt, col_to_remove) + row_to_opt = self.get_fields(row_to_opt, col_to_remove) for id, old_v in row_to_opt.items(): for k, remove_v in removeValue.items(): if remove_v in old_v[k]: @@ -639,17 +639,17 @@ class InfinityConnection(DocStoreConnection): Helper functions for search result """ - def getTotal(self, res: tuple[pd.DataFrame, int] | pd.DataFrame) -> int: + def get_total(self, res: tuple[pd.DataFrame, int] | pd.DataFrame) -> int: if isinstance(res, tuple): return res[1] return len(res) - def getChunkIds(self, res: tuple[pd.DataFrame, int] | pd.DataFrame) -> list[str]: + def get_chunk_ids(self, res: tuple[pd.DataFrame, int] | pd.DataFrame) -> list[str]: if isinstance(res, tuple): res = res[0] return list(res["id"]) - def getFields(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, fields: list[str]) -> dict[str, dict]: + def get_fields(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, fields: list[str]) -> dict[str, dict]: if isinstance(res, tuple): res = res[0] if not fields: @@ -690,7 +690,7 @@ class InfinityConnection(DocStoreConnection): return res2.set_index("id").to_dict(orient="index") - def getHighlight(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, keywords: list[str], fieldnm: str): + def get_highlight(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, keywords: list[str], fieldnm: str): if isinstance(res, tuple): res = res[0] ans = {} @@ -732,7 +732,7 @@ class InfinityConnection(DocStoreConnection): ans[id] = txt return ans - def getAggregation(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, fieldnm: str): + def get_aggregation(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, fieldnm: str): """ Manual aggregation for tag fields since Infinity doesn't provide native aggregation """ diff --git a/rag/utils/minio_conn.py b/rag/utils/minio_conn.py index 75cd2725b..e0913e98b 100644 --- a/rag/utils/minio_conn.py +++ b/rag/utils/minio_conn.py @@ -92,7 +92,7 @@ class RAGFlowMinio: logging.exception(f"Fail to get {bucket}/{filename}") self.__open__() time.sleep(1) - return + return None def obj_exist(self, bucket, filename, tenant_id=None): try: @@ -130,7 +130,7 @@ class RAGFlowMinio: logging.exception(f"Fail to get_presigned {bucket}/{fnm}:") self.__open__() time.sleep(1) - return + return None def remove_bucket(self, bucket): try: diff --git a/rag/utils/opendal_conn.py b/rag/utils/opendal_conn.py index 54650b54b..c6cebf9ca 100644 --- a/rag/utils/opendal_conn.py +++ b/rag/utils/opendal_conn.py @@ -62,8 +62,7 @@ class OpenDALStorage: def health(self): bucket, fnm, binary = "txtxtxtxt1", "txtxtxtxt1", b"_t@@@1" - r = self._operator.write(f"{bucket}/{fnm}", binary) - return r + return self._operator.write(f"{bucket}/{fnm}", binary) def put(self, bucket, fnm, binary, tenant_id=None): self._operator.write(f"{bucket}/{fnm}", binary) diff --git a/rag/utils/opensearch_conn.py b/rag/utils/opensearch_conn.py index c862b52e9..2df1d65ee 100644 --- a/rag/utils/opensearch_conn.py +++ b/rag/utils/opensearch_conn.py @@ -455,12 +455,12 @@ class OSConnection(DocStoreConnection): Helper functions for search result """ - def getTotal(self, res): + def get_total(self, res): if isinstance(res["hits"]["total"], type({})): return res["hits"]["total"]["value"] return res["hits"]["total"] - def getChunkIds(self, res): + def get_chunk_ids(self, res): return [d["_id"] for d in res["hits"]["hits"]] def __getSource(self, res): @@ -471,7 +471,7 @@ class OSConnection(DocStoreConnection): rr.append(d["_source"]) return rr - def getFields(self, res, fields: list[str]) 
-> dict[str, dict]: + def get_fields(self, res, fields: list[str]) -> dict[str, dict]: res_fields = {} if not fields: return {} @@ -490,7 +490,7 @@ class OSConnection(DocStoreConnection): res_fields[d["id"]] = m return res_fields - def getHighlight(self, res, keywords: list[str], fieldnm: str): + def get_highlight(self, res, keywords: list[str], fieldnm: str): ans = {} for d in res["hits"]["hits"]: hlts = d.get("highlight") @@ -515,7 +515,7 @@ class OSConnection(DocStoreConnection): return ans - def getAggregation(self, res, fieldnm: str): + def get_aggregation(self, res, fieldnm: str): agg_field = "aggs_" + fieldnm if "aggregations" not in res or agg_field not in res["aggregations"]: return list() diff --git a/rag/utils/oss_conn.py b/rag/utils/oss_conn.py index 20cea0b94..b0114f668 100644 --- a/rag/utils/oss_conn.py +++ b/rag/utils/oss_conn.py @@ -141,7 +141,7 @@ class RAGFlowOSS: logging.exception(f"fail get {bucket}/{fnm}") self.__open__() time.sleep(1) - return + return None @use_prefix_path @use_default_bucket @@ -170,5 +170,5 @@ class RAGFlowOSS: logging.exception(f"fail get url {bucket}/{fnm}") self.__open__() time.sleep(1) - return + return None diff --git a/rag/utils/redis_conn.py b/rag/utils/redis_conn.py index 3c6565230..58b0fe15b 100644 --- a/rag/utils/redis_conn.py +++ b/rag/utils/redis_conn.py @@ -104,6 +104,7 @@ class RedisDB: if self.REDIS.get(a) == b: return True + return False def info(self): info = self.REDIS.info() @@ -124,7 +125,7 @@ class RedisDB: def exist(self, k): if not self.REDIS: - return + return None try: return self.REDIS.exists(k) except Exception as e: @@ -133,7 +134,7 @@ class RedisDB: def get(self, k): if not self.REDIS: - return + return None try: return self.REDIS.get(k) except Exception as e: diff --git a/rag/utils/s3_conn.py b/rag/utils/s3_conn.py index 9006fa586..11ac65cee 100644 --- a/rag/utils/s3_conn.py +++ b/rag/utils/s3_conn.py @@ -164,7 +164,7 @@ class RAGFlowS3: logging.exception(f"fail get {bucket}/{fnm}") self.__open__() time.sleep(1) - return + return None @use_prefix_path @use_default_bucket @@ -193,7 +193,7 @@ class RAGFlowS3: logging.exception(f"fail get url {bucket}/{fnm}") self.__open__() time.sleep(1) - return + return None @use_default_bucket def rm_bucket(self, bucket, *args, **kwargs):
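
As referenced in the PR description, here is a minimal, self-contained sketch of the caller-facing change: the result helpers on the doc-store connection are now snake_case (`get_total`, `get_chunk_ids`, `get_fields` instead of `getTotal`, `getChunkIds`, `getFields`). The `FakeStore` class and its canned data below are hypothetical stand-ins for illustration only; only the method names come from this patch.

```python
# Hypothetical stand-in mirroring the renamed DocStoreConnection helpers.
class FakeStore:
    def search(self, *args, **kwargs):
        # Canned "search result"; real connectors return Elasticsearch/Infinity responses.
        return {"c1": {"content_with_weight": "hello", "_score": 0.9}}

    def get_total(self, res):           # was getTotal
        return len(res)

    def get_chunk_ids(self, res):       # was getChunkIds
        return list(res.keys())

    def get_fields(self, res, fields):  # was getFields
        return {cid: {f: doc.get(f) for f in fields} for cid, doc in res.items()}


store = FakeStore()
res = store.search()
print(store.get_total(res))                            # -> 1
print(store.get_chunk_ids(res))                        # -> ['c1']
print(store.get_fields(res, ["content_with_weight"]))  # -> {'c1': {'content_with_weight': 'hello'}}
```

The renames are drop-in: signatures and return values are unchanged, so callers only need to update the method names.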