diff --git a/rag/app/laws.py b/rag/app/laws.py index bbc99a925..947f913be 100644 --- a/rag/app/laws.py +++ b/rag/app/laws.py @@ -103,7 +103,6 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca if not l:break txt += l sections = txt.split("\n") - sections = txt.split("\n") sections = [l for l in sections if l] callback(0.8, "Finish parsing.") else: raise NotImplementedError("file type not supported yet(docx, pdf, txt supported)") diff --git a/rag/nlp/__init__.py b/rag/nlp/__init__.py index b3738f7f6..00a41b21c 100644 --- a/rag/nlp/__init__.py +++ b/rag/nlp/__init__.py @@ -1,13 +1,14 @@ +import random +from rag.utils import num_tokens_from_string +from . import huqie +from nltk import word_tokenize +import re import copy from nltk.stem import PorterStemmer + stemmer = PorterStemmer() -import re -from nltk import word_tokenize -from . import huqie -from rag.utils import num_tokens_from_string -import random BULLET_PATTERN = [[ r"第[零一二三四五六七八九十百0-9]+(分?编|部分)", @@ -54,7 +55,8 @@ def bullets_category(sections): maxium = 0 res = -1 for i, h in enumerate(hits): - if h <= maxium: continue + if h <= maxium: + continue res = i maxium = h return res @@ -74,7 +76,8 @@ def tokenize(d, t, eng): d["content_with_weight"] = t if eng: t = re.sub(r"([a-z])-([a-z])", r"\1\2", t) - d["content_ltks"] = " ".join([stemmer.stem(w) for w in word_tokenize(t)]) + d["content_ltks"] = " ".join([stemmer.stem(w) + for w in word_tokenize(t)]) else: d["content_ltks"] = huqie.qie(t) d["content_sm_ltks"] = huqie.qieqie(d["content_ltks"]) @@ -84,7 +87,8 @@ def tokenize_table(tbls, doc, eng, batch_size=10): res = [] # add tables for (img, rows), poss in tbls: - if not rows:continue + if not rows: + continue if isinstance(rows, str): d = copy.deepcopy(doc) r = re.sub(r"<[^<>]{,12}>", "", rows) @@ -106,14 +110,15 @@ def tokenize_table(tbls, doc, eng, batch_size=10): def add_positions(d, poss): - if not poss:return + if not poss: + return d["page_num_int"] = [] d["position_int"] = [] d["top_int"] = [] for pn, left, right, top, bottom in poss: - d["page_num_int"].append(pn+1) + d["page_num_int"].append(pn + 1) d["top_int"].append(top) - d["position_int"].append((pn+1, left, right, top, bottom)) + d["position_int"].append((pn + 1, left, right, top, bottom)) d["top_int"] = d["top_int"][:1] @@ -122,31 +127,38 @@ def remove_contents_table(sections, eng=False): while i < len(sections): def get(i): nonlocal sections - return (sections[i] if type(sections[i]) == type("") else sections[i][0]).strip() + return (sections[i] if isinstance(sections[i], + type("")) else sections[i][0]).strip() if not re.match(r"(contents|目录|目次|table of contents|致谢|acknowledge)$", re.sub(r"( | |\u3000)+", "", get(i).split("@@")[0], re.IGNORECASE)): i += 1 continue sections.pop(i) - if i >= len(sections): break + if i >= len(sections): + break prefix = get(i)[:3] if not eng else " ".join(get(i).split(" ")[:2]) while not prefix: sections.pop(i) - if i >= len(sections): break + if i >= len(sections): + break prefix = get(i)[:3] if not eng else " ".join(get(i).split(" ")[:2]) sections.pop(i) - if i >= len(sections) or not prefix: break + if i >= len(sections) or not prefix: + break for j in range(i, min(i + 128, len(sections))): if not re.match(prefix, get(j)): continue - for _ in range(i, j): sections.pop(i) + for _ in range(i, j): + sections.pop(i) break def make_colon_as_title(sections): - if not sections: return [] - if type(sections[0]) == type(""): return sections + if not sections: + return [] + if isinstance(sections[0], type("")): + return sections i = 0 while i < len(sections): txt, layout = sections[i] @@ -165,20 +177,25 @@ def make_colon_as_title(sections): def hierarchical_merge(bull, sections, depth): - if not sections or bull < 0: return [] - if type(sections[0]) == type(""): sections = [(s, "") for s in sections] - sections = [(t,o) for t, o in sections if t and len(t.split("@")[0].strip()) > 1 and not re.match(r"[0-9]+$", t.split("@")[0].strip())] + if not sections or bull < 0: + return [] + if isinstance(sections[0], type("")): + sections = [(s, "") for s in sections] + sections = [(t, o) for t, o in sections if + t and len(t.split("@")[0].strip()) > 1 and not re.match(r"[0-9]+$", t.split("@")[0].strip())] bullets_size = len(BULLET_PATTERN[bull]) levels = [[] for _ in range(bullets_size + 2)] def not_title(txt): - if re.match(r"第[零一二三四五六七八九十百0-9]+条", txt): return False - if len(txt.split(" "))>12 or (txt.find(" ")<0 and len(txt)) >= 32: return True + if re.match(r"第[零一二三四五六七八九十百0-9]+条", txt): + return False + if len(txt.split(" ")) > 12 or (txt.find(" ") < 0 and len(txt) >= 32): + return True return re.search(r"[,;,。;!!]", txt) for i, (txt, layout) in enumerate(sections): for j, p in enumerate(BULLET_PATTERN[bull]): - if re.match(p, txt.strip()) and not not_title(txt): + if re.match(p, txt.strip()): levels[j].append(i) break else: @@ -187,12 +204,16 @@ def hierarchical_merge(bull, sections, depth): else: levels[bullets_size + 1].append(i) sections = [t for t, _ in sections] - #for s in sections: print("--", s) + + # for s in sections: print("--", s) def binary_search(arr, target): - if not arr: return -1 - if target > arr[-1]: return len(arr) - 1 - if target < arr[0]: return -1 + if not arr: + return -1 + if target > arr[-1]: + return len(arr) - 1 + if target < arr[0]: + return -1 s, e = 0, len(arr) while e - s > 1: i = (e + s) // 2 @@ -211,18 +232,24 @@ def hierarchical_merge(bull, sections, depth): levels = levels[::-1] for i, arr in enumerate(levels[:depth]): for j in arr: - if readed[j]: continue + if readed[j]: + continue readed[j] = True cks.append([j]) - if i + 1 == len(levels) - 1: continue + if i + 1 == len(levels) - 1: + continue for ii in range(i + 1, len(levels)): jj = binary_search(levels[ii], j) - if jj < 0: continue - if jj > cks[-1][-1]: cks[-1].pop(-1) + if jj < 0: + continue + if jj > cks[-1][-1]: + cks[-1].pop(-1) cks[-1].append(levels[ii][jj]) - for ii in cks[-1]: readed[ii] = True + for ii in cks[-1]: + readed[ii] = True - if not cks:return cks + if not cks: + return cks for i in range(len(cks)): cks[i] = [sections[j] for j in cks[i][::-1]] @@ -247,20 +274,26 @@ def hierarchical_merge(bull, sections, depth): def naive_merge(sections, chunk_token_num=128, delimiter="\n。;!?"): - if not sections: return [] - if type(sections[0]) == type(""): sections = [(s, "") for s in sections] + if not sections: + return [] + if isinstance(sections[0], type("")): + sections = [(s, "") for s in sections] cks = [""] tk_nums = [0] + def add_chunk(t, pos): nonlocal cks, tk_nums, delimiter tnum = num_tokens_from_string(t) - if tnum < 8: pos = "" + if tnum < 8: + pos = "" if tk_nums[-1] > chunk_token_num: - if t.find(pos) < 0: t += pos + if t.find(pos) < 0: + t += pos cks.append(t) tk_nums.append(tnum) else: - if cks[-1].find(pos) < 0: t += pos + if cks[-1].find(pos) < 0: + t += pos cks[-1] += t tk_nums[-1] += tnum @@ -270,12 +303,12 @@ def naive_merge(sections, chunk_token_num=128, delimiter="\n。;!?"): s, e = 0, 1 while e < len(sec): if sec[e] in delimiter: - add_chunk(sec[s: e+1], pos) + add_chunk(sec[s: e + 1], pos) s = e + 1 e = s + 1 else: e += 1 - if s < e: add_chunk(sec[s: e], pos) + if s < e: + add_chunk(sec[s: e], pos) return cks - diff --git a/rag/nlp/search.py b/rag/nlp/search.py index 04d4588b3..39ab8e038 100644 --- a/rag/nlp/search.py +++ b/rag/nlp/search.py @@ -82,8 +82,8 @@ class Dealer: ) else: s = s.sort( - {"page_num_int": {"order": "asc", "unmapped_type": "float"}}, - {"top_int": {"order": "asc", "unmapped_type": "float", "mode" : "avg"}}, + {"page_num_int": {"order": "asc", "unmapped_type": "float", "mode" : "avg"}}, + {"top_int": {"order": "asc", "unmapped_type": "float", "mode": "avg"}}, {"create_time": {"order": "desc", "unmapped_type": "date"}}, {"create_timestamp_flt": {"order": "desc", "unmapped_type": "float"}} )