build python version rag-flow (#21)

* clean rust version project

* clean rust version project

* build python version rag-flow
This commit is contained in:
KevinHuSh
2024-01-15 08:46:22 +08:00
committed by GitHub
parent db8cae3f1e
commit 30791976d5
123 changed files with 4985 additions and 4239 deletions

0
rag/nlp/__init__.py Normal file
View File

435
rag/nlp/huchunk.py Normal file
View File

@ -0,0 +1,435 @@
import re
import os
import copy
import base64
import magic
from dataclasses import dataclass
from typing import List
import numpy as np
from io import BytesIO
class HuChunker:
def __init__(self):
self.MAX_LVL = 12
self.proj_patt = [
(r"第[零一二三四五六七八九十百]+章", 1),
(r"第[零一二三四五六七八九十百]+[条节]", 2),
(r"[零一二三四五六七八九十百]+[、  ]", 3),
(r"[\(][零一二三四五六七八九十百]+[\)]", 4),
(r"[0-9]+(、|\.[  ]|\.[^0-9])", 5),
(r"[0-9]+\.[0-9]+(、|[  ]|[^0-9])", 6),
(r"[0-9]+\.[0-9]+\.[0-9]+(、|[  ]|[^0-9])", 7),
(r"[0-9]+\.[0-9]+\.[0-9]+\.[0-9]+(、|[  ]|[^0-9])", 8),
(r".{,48}[:?]@", 9),
(r"[0-9]+", 10),
(r"[\(][0-9]+[\)]", 11),
(r"[零一二三四五六七八九十百]+是", 12),
(r"[⚫•➢✓ ]", 12)
]
self.lines = []
def _garbage(self, txt):
patt = [
r"(在此保证|不得以任何形式翻版|请勿传阅|仅供内部使用|未经事先书面授权)",
r"(版权(归本公司)*所有|免责声明|保留一切权力|承担全部责任|特别声明|报告中涉及)",
r"(不承担任何责任|投资者的通知事项:|任何机构和个人|本报告仅为|不构成投资)",
r"(不构成对任何个人或机构投资建议|联系其所在国家|本报告由从事证券交易)",
r"(本研究报告由|「认可投资者」|所有研究报告均以|请发邮件至)",
r"(本报告仅供|市场有风险,投资需谨慎|本报告中提及的)",
r"(本报告反映|此信息仅供|证券分析师承诺|具备证券投资咨询业务资格)",
r"^(时间|签字|签章)[:]",
r"(参考文献|目录索引|图表索引)",
r"[ ]*年[ ]+月[ ]+日",
r"^(中国证券业协会|[0-9]+年[0-9]+月[0-9]+日)$",
r"\.{10,}",
r"(———————END|帮我转发|欢迎收藏|快来关注我吧)"
]
return any([re.search(p, txt) for p in patt])
def _proj_match(self, line):
for p, j in self.proj_patt:
if re.match(p, line):
return j
return
def _does_proj_match(self):
mat = [None for _ in range(len(self.lines))]
for i in range(len(self.lines)):
mat[i] = self._proj_match(self.lines[i])
return mat
def naive_text_chunk(self, text, ti="", MAX_LEN=612):
if text:
self.lines = [l.strip().replace(u'\u3000', u' ')
.replace(u'\xa0', u'')
for l in text.split("\n\n")]
self.lines = [l for l in self.lines if not self._garbage(l)]
self.lines = [re.sub(r"([ ]+| )", " ", l)
for l in self.lines if l]
if not self.lines:
return []
arr = self.lines
res = [""]
i = 0
while i < len(arr):
a = arr[i]
if not a:
i += 1
continue
if len(a) > MAX_LEN:
a_ = a.split("\n")
if len(a_) >= 2:
arr.pop(i)
for j in range(2, len(a_) + 1):
if len("\n".join(a_[:j])) >= MAX_LEN:
arr.insert(i, "\n".join(a_[:j - 1]))
arr.insert(i + 1, "\n".join(a_[j - 1:]))
break
else:
assert False, f"Can't split: {a}"
continue
if len(res[-1]) < MAX_LEN / 3:
res[-1] += "\n" + a
else:
res.append(a)
i += 1
if ti:
for i in range(len(res)):
if res[i].find("——来自") >= 0:
continue
res[i] += f"\t——来自“{ti}"
return res
def _merge(self):
# merge continuous same level text
lines = [self.lines[0]] if self.lines else []
for i in range(1, len(self.lines)):
if self.mat[i] == self.mat[i - 1] \
and len(lines[-1]) < 256 \
and len(self.lines[i]) < 256:
lines[-1] += "\n" + self.lines[i]
continue
lines.append(self.lines[i])
self.lines = lines
self.mat = self._does_proj_match()
return self.mat
def text_chunks(self, text):
if text:
self.lines = [l.strip().replace(u'\u3000', u' ')
.replace(u'\xa0', u'')
for l in re.split(r"[\r\n]", text)]
self.lines = [l for l in self.lines if not self._garbage(l)]
self.lines = [l for l in self.lines if l]
self.mat = self._does_proj_match()
mat = self._merge()
tree = []
for i in range(len(self.lines)):
tree.append({"proj": mat[i],
"children": [],
"read": False})
# find all children
for i in range(len(self.lines) - 1):
if tree[i]["proj"] is None:
continue
ed = i + 1
while ed < len(tree) and (tree[ed]["proj"] is None or
tree[ed]["proj"] > tree[i]["proj"]):
ed += 1
nxt = tree[i]["proj"] + 1
st = set([p["proj"] for p in tree[i + 1: ed] if p["proj"]])
while nxt not in st:
nxt += 1
if nxt > self.MAX_LVL:
break
if nxt <= self.MAX_LVL:
for j in range(i + 1, ed):
if tree[j]["proj"] is not None:
break
tree[i]["children"].append(j)
for j in range(i + 1, ed):
if tree[j]["proj"] != nxt:
continue
tree[i]["children"].append(j)
else:
for j in range(i + 1, ed):
tree[i]["children"].append(j)
# get DFS combinations, find all the paths to leaf
paths = []
def dfs(i, path):
nonlocal tree, paths
path.append(i)
tree[i]["read"] = True
if len(self.lines[i]) > 256:
paths.append(path)
return
if not tree[i]["children"]:
if len(path) > 1 or len(self.lines[i]) >= 32:
paths.append(path)
return
for j in tree[i]["children"]:
dfs(j, copy.deepcopy(path))
for i, t in enumerate(tree):
if t["read"]:
continue
dfs(i, [])
# concat txt on the path for all paths
res = []
lines = np.array(self.lines)
for p in paths:
if len(p) < 2:
tree[p[0]]["read"] = False
continue
txt = "\n".join(lines[p[:-1]]) + "\n" + lines[p[-1]]
res.append(txt)
# concat continuous orphans
assert len(tree) == len(lines)
ii = 0
while ii < len(tree):
if tree[ii]["read"]:
ii += 1
continue
txt = lines[ii]
e = ii + 1
while e < len(tree) and not tree[e]["read"] and len(txt) < 256:
txt += "\n" + lines[e]
e += 1
res.append(txt)
ii = e
# if the node has not been read, find its daddy
def find_daddy(st):
nonlocal lines, tree
proj = tree[st]["proj"]
if len(self.lines[st]) > 512:
return [st]
if proj is None:
proj = self.MAX_LVL + 1
for i in range(st - 1, -1, -1):
if tree[i]["proj"] and tree[i]["proj"] < proj:
a = [st] + find_daddy(i)
return a
return []
return res
class PdfChunker(HuChunker):
@dataclass
class Fields:
text_chunks: List = None
table_chunks: List = None
def __init__(self, pdf_parser):
self.pdf = pdf_parser
super().__init__()
def tableHtmls(self, pdfnm):
_, tbls = self.pdf(pdfnm, return_html=True)
res = []
for img, arr in tbls:
if arr[0].find("<table>") < 0:
continue
buffered = BytesIO()
if img:
img.save(buffered, format="JPEG")
img_str = base64.b64encode(
buffered.getvalue()).decode('utf-8') if img else ""
res.append({"table": arr[0], "image": img_str})
return res
def html(self, pdfnm):
txts, tbls = self.pdf(pdfnm, return_html=True)
res = []
txt_cks = self.text_chunks(txts)
for txt, img in [(self.pdf.remove_tag(c), self.pdf.crop(c))
for c in txt_cks]:
buffered = BytesIO()
if img:
img.save(buffered, format="JPEG")
img_str = base64.b64encode(
buffered.getvalue()).decode('utf-8') if img else ""
res.append({"table": "<p>%s</p>" % txt.replace("\n", "<br/>"),
"image": img_str})
for img, arr in tbls:
if not arr:
continue
buffered = BytesIO()
if img:
img.save(buffered, format="JPEG")
img_str = base64.b64encode(
buffered.getvalue()).decode('utf-8') if img else ""
res.append({"table": arr[0], "image": img_str})
return res
def __call__(self, pdfnm, return_image=True, naive_chunk=False):
flds = self.Fields()
text, tbls = self.pdf(pdfnm)
fnm = pdfnm
txt_cks = self.text_chunks(text) if not naive_chunk else \
self.naive_text_chunk(text, ti=fnm if isinstance(fnm, str) else "")
flds.text_chunks = [(self.pdf.remove_tag(c),
self.pdf.crop(c) if return_image else None) for c in txt_cks]
flds.table_chunks = [(arr, img if return_image else None)
for img, arr in tbls]
return flds
class DocxChunker(HuChunker):
@dataclass
class Fields:
text_chunks: List = None
table_chunks: List = None
def __init__(self, doc_parser):
self.doc = doc_parser
super().__init__()
def _does_proj_match(self):
mat = []
for s in self.styles:
s = s.split(" ")[-1]
try:
mat.append(int(s))
except Exception as e:
mat.append(None)
return mat
def _merge(self):
i = 1
while i < len(self.lines):
if self.mat[i] == self.mat[i - 1] \
and len(self.lines[i - 1]) < 256 \
and len(self.lines[i]) < 256:
self.lines[i - 1] += "\n" + self.lines[i]
self.styles.pop(i)
self.lines.pop(i)
self.mat.pop(i)
continue
i += 1
self.mat = self._does_proj_match()
return self.mat
def __call__(self, fnm):
flds = self.Fields()
flds.title = os.path.splitext(
os.path.basename(fnm))[0] if isinstance(
fnm, type("")) else ""
secs, tbls = self.doc(fnm)
self.lines = [l for l, s in secs]
self.styles = [s for l, s in secs]
txt_cks = self.text_chunks("")
flds.text_chunks = [(t, None) for t in txt_cks if not self._garbage(t)]
flds.table_chunks = [(tb, None) for tb in tbls for t in tb if t]
return flds
class ExcelChunker(HuChunker):
@dataclass
class Fields:
text_chunks: List = None
table_chunks: List = None
def __init__(self, excel_parser):
self.excel = excel_parser
super().__init__()
def __call__(self, fnm):
flds = self.Fields()
flds.text_chunks = [(t, None) for t in self.excel(fnm)]
flds.table_chunks = []
return flds
class PptChunker(HuChunker):
@dataclass
class Fields:
text_chunks: List = None
table_chunks: List = None
def __init__(self):
super().__init__()
def __call__(self, fnm):
from pptx import Presentation
ppt = Presentation(fnm) if isinstance(
fnm, str) else Presentation(
BytesIO(fnm))
flds = self.Fields()
flds.text_chunks = []
for slide in ppt.slides:
for shape in slide.shapes:
if hasattr(shape, "text"):
flds.text_chunks.append((shape.text, None))
flds.table_chunks = []
return flds
class TextChunker(HuChunker):
@dataclass
class Fields:
text_chunks: List = None
table_chunks: List = None
def __init__(self):
super().__init__()
@staticmethod
def is_binary_file(file_path):
mime = magic.Magic(mime=True)
if isinstance(file_path, str):
file_type = mime.from_file(file_path)
else:
file_type = mime.from_buffer(file_path)
if 'text' in file_type:
return False
else:
return True
def __call__(self, fnm):
flds = self.Fields()
if self.is_binary_file(fnm):
return flds
with open(fnm, "r") as f:
txt = f.read()
flds.text_chunks = [(c, None) for c in self.naive_text_chunk(txt)]
flds.table_chunks = []
return flds
if __name__ == "__main__":
import sys
sys.path.append(os.path.dirname(__file__) + "/../")
if sys.argv[1].split(".")[-1].lower() == "pdf":
from parser import PdfParser
ckr = PdfChunker(PdfParser())
if sys.argv[1].split(".")[-1].lower().find("doc") >= 0:
from parser import DocxParser
ckr = DocxChunker(DocxParser())
if sys.argv[1].split(".")[-1].lower().find("xlsx") >= 0:
from parser import ExcelParser
ckr = ExcelChunker(ExcelParser())
# ckr.html(sys.argv[1])
print(ckr(sys.argv[1]))

406
rag/nlp/huqie.py Normal file
View File

@ -0,0 +1,406 @@
# -*- coding: utf-8 -*-
import copy
import datrie
import math
import os
import re
import string
import sys
from hanziconv import HanziConv
from web_server.utils.file_utils import get_project_base_directory
class Huqie:
def key_(self, line):
return str(line.lower().encode("utf-8"))[2:-1]
def rkey_(self, line):
return str(("DD" + (line[::-1].lower())).encode("utf-8"))[2:-1]
def loadDict_(self, fnm):
print("[HUQIE]:Build trie", fnm, file=sys.stderr)
try:
of = open(fnm, "r")
while True:
line = of.readline()
if not line:
break
line = re.sub(r"[\r\n]+", "", line)
line = re.split(r"[ \t]", line)
k = self.key_(line[0])
F = int(math.log(float(line[1]) / self.DENOMINATOR) + .5)
if k not in self.trie_ or self.trie_[k][0] < F:
self.trie_[self.key_(line[0])] = (F, line[2])
self.trie_[self.rkey_(line[0])] = 1
self.trie_.save(fnm + ".trie")
of.close()
except Exception as e:
print("[HUQIE]:Faild to build trie, ", fnm, e, file=sys.stderr)
def __init__(self, debug=False):
self.DEBUG = debug
self.DENOMINATOR = 1000000
self.trie_ = datrie.Trie(string.printable)
self.DIR_ = os.path.join(get_project_base_directory(), "rag/res", "huqie")
self.SPLIT_CHAR = r"([ ,\.<>/?;'\[\]\\`!@#$%^&*\(\)\{\}\|_+=《》,。?、;‘’:“”【】~!¥%……()——-]+|[a-z\.-]+|[0-9,\.-]+)"
try:
self.trie_ = datrie.Trie.load(self.DIR_ + ".txt.trie")
return
except Exception as e:
print("[HUQIE]:Build default trie", file=sys.stderr)
self.trie_ = datrie.Trie(string.printable)
self.loadDict_(self.DIR_ + ".txt")
def loadUserDict(self, fnm):
try:
self.trie_ = datrie.Trie.load(fnm + ".trie")
return
except Exception as e:
self.trie_ = datrie.Trie(string.printable)
self.loadDict_(fnm)
def addUserDict(self, fnm):
self.loadDict_(fnm)
def _strQ2B(self, ustring):
"""把字符串全角转半角"""
rstring = ""
for uchar in ustring:
inside_code = ord(uchar)
if inside_code == 0x3000:
inside_code = 0x0020
else:
inside_code -= 0xfee0
if inside_code < 0x0020 or inside_code > 0x7e: # 转完之后不是半角字符返回原来的字符
rstring += uchar
else:
rstring += chr(inside_code)
return rstring
def _tradi2simp(self, line):
return HanziConv.toSimplified(line)
def dfs_(self, chars, s, preTks, tkslist):
MAX_L = 10
res = s
# if s > MAX_L or s>= len(chars):
if s >= len(chars):
tkslist.append(preTks)
return res
# pruning
S = s + 1
if s + 2 <= len(chars):
t1, t2 = "".join(chars[s:s + 1]), "".join(chars[s:s + 2])
if self.trie_.has_keys_with_prefix(self.key_(t1)) and not self.trie_.has_keys_with_prefix(
self.key_(t2)):
S = s + 2
if len(preTks) > 2 and len(
preTks[-1][0]) == 1 and len(preTks[-2][0]) == 1 and len(preTks[-3][0]) == 1:
t1 = preTks[-1][0] + "".join(chars[s:s + 1])
if self.trie_.has_keys_with_prefix(self.key_(t1)):
S = s + 2
################
for e in range(S, len(chars) + 1):
t = "".join(chars[s:e])
k = self.key_(t)
if e > s + 1 and not self.trie_.has_keys_with_prefix(k):
break
if k in self.trie_:
pretks = copy.deepcopy(preTks)
if k in self.trie_:
pretks.append((t, self.trie_[k]))
else:
pretks.append((t, (-12, '')))
res = max(res, self.dfs_(chars, e, pretks, tkslist))
if res > s:
return res
t = "".join(chars[s:s + 1])
k = self.key_(t)
if k in self.trie_:
preTks.append((t, self.trie_[k]))
else:
preTks.append((t, (-12, '')))
return self.dfs_(chars, s + 1, preTks, tkslist)
def freq(self, tk):
k = self.key_(tk)
if k not in self.trie_:
return 0
return int(math.exp(self.trie_[k][0]) * self.DENOMINATOR + 0.5)
def tag(self, tk):
k = self.key_(tk)
if k not in self.trie_:
return ""
return self.trie_[k][1]
def score_(self, tfts):
B = 30
F, L, tks = 0, 0, []
for tk, (freq, tag) in tfts:
F += freq
L += 0 if len(tk) < 2 else 1
tks.append(tk)
F /= len(tks)
L /= len(tks)
if self.DEBUG:
print("[SC]", tks, len(tks), L, F, B / len(tks) + L + F)
return tks, B / len(tks) + L + F
def sortTks_(self, tkslist):
res = []
for tfts in tkslist:
tks, s = self.score_(tfts)
res.append((tks, s))
return sorted(res, key=lambda x: x[1], reverse=True)
def merge_(self, tks):
patts = [
(r"[ ]+", " "),
(r"([0-9\+\.,%\*=-]) ([0-9\+\.,%\*=-])", r"\1\2"),
]
# for p,s in patts: tks = re.sub(p, s, tks)
# if split chars is part of token
res = []
tks = re.sub(r"[ ]+", " ", tks).split(" ")
s = 0
while True:
if s >= len(tks):
break
E = s + 1
for e in range(s + 2, min(len(tks) + 2, s + 6)):
tk = "".join(tks[s:e])
if re.search(self.SPLIT_CHAR, tk) and self.freq(tk):
E = e
res.append("".join(tks[s:E]))
s = E
return " ".join(res)
def maxForward_(self, line):
res = []
s = 0
while s < len(line):
e = s + 1
t = line[s:e]
while e < len(line) and self.trie_.has_keys_with_prefix(
self.key_(t)):
e += 1
t = line[s:e]
while e - 1 > s and self.key_(t) not in self.trie_:
e -= 1
t = line[s:e]
if self.key_(t) in self.trie_:
res.append((t, self.trie_[self.key_(t)]))
else:
res.append((t, (0, '')))
s = e
return self.score_(res)
def maxBackward_(self, line):
res = []
s = len(line) - 1
while s >= 0:
e = s + 1
t = line[s:e]
while s > 0 and self.trie_.has_keys_with_prefix(self.rkey_(t)):
s -= 1
t = line[s:e]
while s + 1 < e and self.key_(t) not in self.trie_:
s += 1
t = line[s:e]
if self.key_(t) in self.trie_:
res.append((t, self.trie_[self.key_(t)]))
else:
res.append((t, (0, '')))
s -= 1
return self.score_(res[::-1])
def qie(self, line):
line = self._strQ2B(line).lower()
line = self._tradi2simp(line)
arr = re.split(self.SPLIT_CHAR, line)
res = []
for L in arr:
if len(L) < 2 or re.match(
r"[a-z\.-]+$", L) or re.match(r"[0-9\.-]+$", L):
res.append(L)
continue
# print(L)
# use maxforward for the first time
tks, s = self.maxForward_(L)
tks1, s1 = self.maxBackward_(L)
if self.DEBUG:
print("[FW]", tks, s)
print("[BW]", tks1, s1)
diff = [0 for _ in range(max(len(tks1), len(tks)))]
for i in range(min(len(tks1), len(tks))):
if tks[i] != tks1[i]:
diff[i] = 1
if s1 > s:
tks = tks1
i = 0
while i < len(tks):
s = i
while s < len(tks) and diff[s] == 0:
s += 1
if s == len(tks):
res.append(" ".join(tks[i:]))
break
if s > i:
res.append(" ".join(tks[i:s]))
e = s
while e < len(tks) and e - s < 5 and diff[e] == 1:
e += 1
tkslist = []
self.dfs_("".join(tks[s:e + 1]), 0, [], tkslist)
res.append(" ".join(self.sortTks_(tkslist)[0][0]))
i = e + 1
res = " ".join(res)
if self.DEBUG:
print("[TKS]", self.merge_(res))
return self.merge_(res)
def qieqie(self, tks):
res = []
for tk in tks.split(" "):
if len(tk) < 3 or re.match(r"[0-9,\.-]+$", tk):
res.append(tk)
continue
tkslist = []
if len(tk) > 10:
tkslist.append(tk)
else:
self.dfs_(tk, 0, [], tkslist)
if len(tkslist) < 2:
res.append(tk)
continue
stk = self.sortTks_(tkslist)[1][0]
if len(stk) == len(tk):
stk = tk
else:
if re.match(r"[a-z\.-]+$", tk):
for t in stk:
if len(t) < 3:
stk = tk
break
else:
stk = " ".join(stk)
else:
stk = " ".join(stk)
res.append(stk)
return " ".join(res)
def is_chinese(s):
if s >= u'\u4e00' and s <= u'\u9fa5':
return True
else:
return False
def is_number(s):
if s >= u'\u0030' and s <= u'\u0039':
return True
else:
return False
def is_alphabet(s):
if (s >= u'\u0041' and s <= u'\u005a') or (
s >= u'\u0061' and s <= u'\u007a'):
return True
else:
return False
def naiveQie(txt):
tks = []
for t in txt.split(" "):
if tks and re.match(r".*[a-zA-Z]$", tks[-1]
) and re.match(r".*[a-zA-Z]$", t):
tks.append(" ")
tks.append(t)
return tks
hq = Huqie()
qie = hq.qie
qieqie = hq.qieqie
tag = hq.tag
freq = hq.freq
loadUserDict = hq.loadUserDict
addUserDict = hq.addUserDict
tradi2simp = hq._tradi2simp
strQ2B = hq._strQ2B
if __name__ == '__main__':
huqie = Huqie(debug=True)
# huqie.addUserDict("/tmp/tmp.new.tks.dict")
tks = huqie.qie(
"哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈")
print(huqie.qieqie(tks))
tks = huqie.qie(
"公开征求意见稿提出,境外投资者可使用自有人民币或外汇投资。使用外汇投资的,可通过债券持有人在香港人民币业务清算行及香港地区经批准可进入境内银行间外汇市场进行交易的境外人民币业务参加行(以下统称香港结算行)办理外汇资金兑换。香港结算行由此所产生的头寸可到境内银行间外汇市场平盘。使用外汇投资的,在其投资的债券到期或卖出后,原则上应兑换回外汇。")
print(huqie.qieqie(tks))
tks = huqie.qie(
"多校划片就是一个小区对应多个小学初中,让买了学区房的家庭也不确定到底能上哪个学校。目的是通过这种方式为学区房降温,把就近入学落到实处。南京市长江大桥")
print(huqie.qieqie(tks))
tks = huqie.qie(
"实际上当时他们已经将业务中心偏移到安全部门和针对政府企业的部门 Scripts are compiled and cached aaaaaaaaa")
print(huqie.qieqie(tks))
tks = huqie.qie("虽然我不怎么玩")
print(huqie.qieqie(tks))
tks = huqie.qie("蓝月亮如何在外资夹击中生存,那是全宇宙最有意思的")
print(huqie.qieqie(tks))
tks = huqie.qie(
"涡轮增压发动机num最大功率,不像别的共享买车锁电子化的手段,我们接过来是否有意义,黄黄爱美食,不过,今天阿奇要讲到的这家农贸市场,说实话,还真蛮有特色的!不仅环境好,还打出了")
print(huqie.qieqie(tks))
tks = huqie.qie("这周日你去吗?这周日你有空吗?")
print(huqie.qieqie(tks))
tks = huqie.qie("Unity3D开发经验 测试开发工程师 c++双11双11 985 211 ")
print(huqie.qieqie(tks))
tks = huqie.qie(
"数据分析项目经理|数据分析挖掘|数据分析方向|商品数据分析|搜索数据分析 sql python hive tableau Cocos2d-")
print(huqie.qieqie(tks))
if len(sys.argv) < 2:
sys.exit()
huqie.DEBUG = False
huqie.loadUserDict(sys.argv[1])
of = open(sys.argv[2], "r")
while True:
line = of.readline()
if not line:
break
print(huqie.qie(line))
of.close()

167
rag/nlp/query.py Normal file
View File

@ -0,0 +1,167 @@
# -*- coding: utf-8 -*-
import json
import re
import logging
import copy
import math
from elasticsearch_dsl import Q, Search
from rag.nlp import huqie, term_weight, synonym
class EsQueryer:
def __init__(self, es):
self.tw = term_weight.Dealer()
self.es = es
self.syn = synonym.Dealer(None)
self.flds = ["ask_tks^10", "ask_small_tks"]
@staticmethod
def subSpecialChar(line):
return re.sub(r"([:\{\}/\[\]\-\*\"\(\)\|~\^])", r"\\\1", line).strip()
@staticmethod
def isChinese(line):
arr = re.split(r"[ \t]+", line)
if len(arr) <= 3:
return True
e = 0
for t in arr:
if not re.match(r"[a-zA-Z]+$", t):
e += 1
return e * 1. / len(arr) >= 0.8
@staticmethod
def rmWWW(txt):
txt = re.sub(
r"是*(什么样的|哪家|那家|啥样|咋样了|什么时候|何时|何地|何人|是否|是不是|多少|哪里|怎么|哪儿|怎么样|如何|哪些|是啥|啥是|啊|吗|呢|吧|咋|什么|有没有|呀)是*",
"",
txt)
return re.sub(
r"(what|who|how|which|where|why|(is|are|were|was) there) (is|are|were|was)*", "", txt, re.IGNORECASE)
def question(self, txt, tbl="qa", min_match="60%"):
txt = re.sub(
r"[ \t,,。??/`!&]+",
" ",
huqie.tradi2simp(
huqie.strQ2B(
txt.lower()))).strip()
txt = EsQueryer.rmWWW(txt)
if not self.isChinese(txt):
tks = txt.split(" ")
q = []
for i in range(1, len(tks)):
q.append("\"%s %s\"~2" % (tks[i - 1], tks[i]))
if not q:
q.append(txt)
return Q("bool",
must=Q("query_string", fields=self.flds,
type="best_fields", query=" OR ".join(q),
boost=1, minimum_should_match="60%")
), txt.split(" ")
def needQieqie(tk):
if len(tk) < 4:
return False
if re.match(r"[0-9a-z\.\+#_\*-]+$", tk):
return False
return True
qs, keywords = [], []
for tt in self.tw.split(txt): # .split(" "):
if not tt:
continue
twts = self.tw.weights([tt])
syns = self.syn.lookup(tt)
logging.info(json.dumps(twts, ensure_ascii=False))
tms = []
for tk, w in sorted(twts, key=lambda x: x[1] * -1):
sm = huqie.qieqie(tk).split(" ") if needQieqie(tk) else []
sm = [
re.sub(
r"[ ,\./;'\[\]\\`~!@#$%\^&\*\(\)=\+_<>\?:\"\{\}\|,。;‘’【】、!¥……()——《》?:“”-]+",
"",
m) for m in sm]
sm = [EsQueryer.subSpecialChar(m) for m in sm if len(m) > 1]
sm = [m for m in sm if len(m) > 1]
if len(sm) < 2:
sm = []
keywords.append(re.sub(r"[ \\\"']+", "", tk))
tk_syns = self.syn.lookup(tk)
tk = EsQueryer.subSpecialChar(tk)
if tk.find(" ") > 0:
tk = "\"%s\"" % tk
if tk_syns:
tk = f"({tk} %s)" % " ".join(tk_syns)
if sm:
tk = f"{tk} OR \"%s\" OR (\"%s\"~2)^0.5" % (
" ".join(sm), " ".join(sm))
tms.append((tk, w))
tms = " ".join([f"({t})^{w}" for t, w in tms])
if len(twts) > 1:
tms += f" (\"%s\"~4)^1.5" % (" ".join([t for t, _ in twts]))
if re.match(r"[0-9a-z ]+$", tt):
tms = f"(\"{tt}\" OR \"%s\")" % huqie.qie(tt)
syns = " OR ".join(
["\"%s\"^0.7" % EsQueryer.subSpecialChar(huqie.qie(s)) for s in syns])
if syns:
tms = f"({tms})^5 OR ({syns})^0.7"
qs.append(tms)
flds = copy.deepcopy(self.flds)
mst = []
if qs:
mst.append(
Q("query_string", fields=flds, type="best_fields",
query=" OR ".join([f"({t})" for t in qs if t]), boost=1, minimum_should_match=min_match)
)
return Q("bool",
must=mst,
), keywords
def hybrid_similarity(self, avec, bvecs, atks, btkss, tkweight=0.3,
vtweight=0.7):
from sklearn.metrics.pairwise import cosine_similarity as CosineSimilarity
import numpy as np
sims = CosineSimilarity([avec], bvecs)
def toDict(tks):
d = {}
if isinstance(tks, type("")):
tks = tks.split(" ")
for t, c in self.tw.weights(tks):
if t not in d:
d[t] = 0
d[t] += c
return d
atks = toDict(atks)
btkss = [toDict(tks) for tks in btkss]
tksim = [self.similarity(atks, btks) for btks in btkss]
return np.array(sims[0]) * vtweight + np.array(tksim) * tkweight
def similarity(self, qtwt, dtwt):
if isinstance(dtwt, type("")):
dtwt = {t: w for t, w in self.tw.weights(self.tw.split(dtwt))}
if isinstance(qtwt, type("")):
qtwt = {t: w for t, w in self.tw.weights(self.tw.split(qtwt))}
s = 1e-9
for k, v in qtwt.items():
if k in dtwt:
s += v * dtwt[k]
q = 1e-9
for k, v in qtwt.items():
q += v * v
d = 1e-9
for k, v in dtwt.items():
d += v * v
return s / math.sqrt(q) / math.sqrt(d)

250
rag/nlp/search.py Normal file
View File

@ -0,0 +1,250 @@
# -*- coding: utf-8 -*-
import re
from elasticsearch_dsl import Q, Search, A
from typing import List, Optional, Tuple, Dict, Union
from dataclasses import dataclass
from rag.utils import rmSpace
from rag.nlp import huqie, query
import numpy as np
def index_name(uid): return f"docgpt_{uid}"
class Dealer:
def __init__(self, es, emb_mdl):
self.qryr = query.EsQueryer(es)
self.qryr.flds = [
"title_tks^10",
"title_sm_tks^5",
"content_ltks^2",
"content_sm_ltks"]
self.es = es
self.emb_mdl = emb_mdl
@dataclass
class SearchResult:
total: int
ids: List[str]
query_vector: List[float] = None
field: Optional[Dict] = None
highlight: Optional[Dict] = None
aggregation: Union[List, Dict, None] = None
keywords: Optional[List[str]] = None
group_docs: List[List] = None
def _vector(self, txt, sim=0.8, topk=10):
return {
"field": "q_vec",
"k": topk,
"similarity": sim,
"num_candidates": 1000,
"query_vector": self.emb_mdl.encode_queries(txt)
}
def search(self, req, idxnm, tks_num=3):
keywords = []
qst = req.get("question", "")
bqry, keywords = self.qryr.question(qst)
if req.get("kb_ids"):
bqry.filter.append(Q("terms", kb_id=req["kb_ids"]))
bqry.filter.append(Q("exists", field="q_tks"))
bqry.boost = 0.05
print(bqry)
s = Search()
pg = int(req.get("page", 1)) - 1
ps = int(req.get("size", 1000))
src = req.get("field", ["docnm_kwd", "content_ltks", "kb_id",
"image_id", "doc_id", "q_vec"])
s = s.query(bqry)[pg * ps:(pg + 1) * ps]
s = s.highlight("content_ltks")
s = s.highlight("title_ltks")
if not qst:
s = s.sort(
{"create_time": {"order": "desc", "unmapped_type": "date"}})
s = s.highlight_options(
fragment_size=120,
number_of_fragments=5,
boundary_scanner_locale="zh-CN",
boundary_scanner="SENTENCE",
boundary_chars=",./;:\\!(),。?:!……()——、"
)
s = s.to_dict()
q_vec = []
if req.get("vector"):
s["knn"] = self._vector(qst, req.get("similarity", 0.4), ps)
s["knn"]["filter"] = bqry.to_dict()
del s["highlight"]
q_vec = s["knn"]["query_vector"]
res = self.es.search(s, idxnm=idxnm, timeout="600s", src=src)
print("TOTAL: ", self.es.getTotal(res))
if self.es.getTotal(res) == 0 and "knn" in s:
bqry, _ = self.qryr.question(qst, min_match="10%")
if req.get("kb_ids"):
bqry.filter.append(Q("terms", kb_id=req["kb_ids"]))
s["query"] = bqry.to_dict()
s["knn"]["filter"] = bqry.to_dict()
s["knn"]["similarity"] = 0.7
res = self.es.search(s, idxnm=idxnm, timeout="600s", src=src)
kwds = set([])
for k in keywords:
kwds.add(k)
for kk in huqie.qieqie(k).split(" "):
if len(kk) < 2:
continue
if kk in kwds:
continue
kwds.add(kk)
aggs = self.getAggregation(res, "docnm_kwd")
return self.SearchResult(
total=self.es.getTotal(res),
ids=self.es.getDocIds(res),
query_vector=q_vec,
aggregation=aggs,
highlight=self.getHighlight(res),
field=self.getFields(res, ["docnm_kwd", "content_ltks",
"kb_id", "image_id", "doc_id", "q_vec"]),
keywords=list(kwds)
)
def getAggregation(self, res, g):
if not "aggregations" in res or "aggs_" + g not in res["aggregations"]:
return
bkts = res["aggregations"]["aggs_" + g]["buckets"]
return [(b["key"], b["doc_count"]) for b in bkts]
def getHighlight(self, res):
def rmspace(line):
eng = set(list("qwertyuioplkjhgfdsazxcvbnm"))
r = []
for t in line.split(" "):
if not t:
continue
if len(r) > 0 and len(
t) > 0 and r[-1][-1] in eng and t[0] in eng:
r.append(" ")
r.append(t)
r = "".join(r)
return r
ans = {}
for d in res["hits"]["hits"]:
hlts = d.get("highlight")
if not hlts:
continue
ans[d["_id"]] = "".join([a for a in list(hlts.items())[0][1]])
return ans
def getFields(self, sres, flds):
res = {}
if not flds:
return {}
for d in self.es.getSource(sres):
m = {n: d.get(n) for n in flds if d.get(n) is not None}
for n, v in m.items():
if isinstance(v, type([])):
m[n] = "\t".join([str(vv) for vv in v])
continue
if not isinstance(v, type("")):
m[n] = str(m[n])
m[n] = rmSpace(m[n])
if m:
res[d["id"]] = m
return res
@staticmethod
def trans2floats(txt):
return [float(t) for t in txt.split("\t")]
def insert_citations(self, ans, top_idx, sres,
vfield="q_vec", cfield="content_ltks"):
ins_embd = [Dealer.trans2floats(
sres.field[sres.ids[i]][vfield]) for i in top_idx]
ins_tw = [sres.field[sres.ids[i]][cfield].split(" ") for i in top_idx]
s = 0
e = 0
res = ""
def citeit():
nonlocal s, e, ans, res
if not ins_embd:
return
embd = self.emb_mdl.encode(ans[s: e])
sim = self.qryr.hybrid_similarity(embd,
ins_embd,
huqie.qie(ans[s:e]).split(" "),
ins_tw)
print(ans[s: e], sim)
mx = np.max(sim) * 0.99
if mx < 0.55:
return
cita = list(set([top_idx[i]
for i in range(len(ins_embd)) if sim[i] > mx]))[:4]
for i in cita:
res += f"@?{i}?@"
return cita
punct = set(";。?!")
if not self.qryr.isChinese(ans):
punct.add("?")
punct.add(".")
while e < len(ans):
if e - s < 12 or ans[e] not in punct:
e += 1
continue
if ans[e] == "." and e + \
1 < len(ans) and re.match(r"[0-9]", ans[e + 1]):
e += 1
continue
if ans[e] == "." and e - 2 >= 0 and ans[e - 2] == "\n":
e += 1
continue
res += ans[s: e]
citeit()
res += ans[e]
e += 1
s = e
if s < len(ans):
res += ans[s:]
citeit()
return res
def rerank(self, sres, query, tkweight=0.3, vtweight=0.7,
vfield="q_vec", cfield="content_ltks"):
ins_embd = [
Dealer.trans2floats(
sres.field[i]["q_vec"]) for i in sres.ids]
if not ins_embd:
return []
ins_tw = [sres.field[i][cfield].split(" ") for i in sres.ids]
# return CosineSimilarity([sres.query_vector], ins_embd)[0]
sim = self.qryr.hybrid_similarity(sres.query_vector,
ins_embd,
huqie.qie(query).split(" "),
ins_tw, tkweight, vtweight)
return sim
if __name__ == "__main__":
from util import es_conn
SE = Dealer(es_conn.HuEs("infiniflow"))
qs = [
"胡凯",
""
]
for q in qs:
print(">>>>>>>>>>>>>>>>>>>>", q)
print(SE.search(
{"question": q, "kb_ids": "64f072a75f3b97c865718c4a"}, "infiniflow_*"))

64
rag/nlp/synonym.py Normal file
View File

@ -0,0 +1,64 @@
import json
import os
import time
import logging
import re
from web_server.utils.file_utils import get_project_base_directory
class Dealer:
def __init__(self, redis=None):
self.lookup_num = 100000000
self.load_tm = time.time() - 1000000
self.dictionary = None
path = os.path.join(get_project_base_directory(), "rag/res", "synonym.json")
try:
self.dictionary = json.load(open(path, 'r'))
except Exception as e:
logging.warn("Miss synonym.json")
self.dictionary = {}
if not redis:
logging.warning(
"Realtime synonym is disabled, since no redis connection.")
if not len(self.dictionary.keys()):
logging.warning(f"Fail to load synonym")
self.redis = redis
self.load()
def load(self):
if not self.redis:
return
if self.lookup_num < 100:
return
tm = time.time()
if tm - self.load_tm < 3600:
return
self.load_tm = time.time()
self.lookup_num = 0
d = self.redis.get("kevin_synonyms")
if not d:
return
try:
d = json.loads(d)
self.dictionary = d
except Exception as e:
logging.error("Fail to load synonym!" + str(e))
def lookup(self, tk):
self.lookup_num += 1
self.load()
res = self.dictionary.get(re.sub(r"[ \t]+", " ", tk.lower()), [])
if isinstance(res, str):
res = [res]
return res
if __name__ == '__main__':
dl = Dealer()
print(dl.dictionary)

216
rag/nlp/term_weight.py Normal file
View File

@ -0,0 +1,216 @@
# -*- coding: utf-8 -*-
import math
import json
import re
import os
import numpy as np
from rag.nlp import huqie
from web_server.utils.file_utils import get_project_base_directory
class Dealer:
def __init__(self):
self.stop_words = set(["请问",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"#",
"什么",
"怎么",
"哪个",
"哪些",
"",
"相关"])
def load_dict(fnm):
res = {}
f = open(fnm, "r")
while True:
l = f.readline()
if not l:
break
arr = l.replace("\n", "").split("\t")
if len(arr) < 2:
res[arr[0]] = 0
else:
res[arr[0]] = int(arr[1])
c = 0
for _, v in res.items():
c += v
if c == 0:
return set(res.keys())
return res
fnm = os.path.join(get_project_base_directory(), "res")
self.ne, self.df = {}, {}
try:
self.ne = json.load(open(os.path.join(fnm, "ner.json"), "r"))
except Exception as e:
print("[WARNING] Load ner.json FAIL!")
try:
self.df = load_dict(os.path.join(fnm, "term.freq"))
except Exception as e:
print("[WARNING] Load term.freq FAIL!")
def pretoken(self, txt, num=False, stpwd=True):
patt = [
r"[~—\t @#%!<>,\.\?\":;'\{\}\[\]_=\(\)\|,。?》•●○↓《;‘’:“”【¥ 】…¥!、·()×`&\\/「」\\]"
]
rewt = [
]
for p, r in rewt:
txt = re.sub(p, r, txt)
res = []
for t in huqie.qie(txt).split(" "):
tk = t
if (stpwd and tk in self.stop_words) or (
re.match(r"[0-9]$", tk) and not num):
continue
for p in patt:
if re.match(p, t):
tk = "#"
break
tk = re.sub(r"([\+\\-])", r"\\\1", tk)
if tk != "#" and tk:
res.append(tk)
return res
def tokenMerge(self, tks):
def oneTerm(t): return len(t) == 1 or re.match(r"[0-9a-z]{1,2}$", t)
res, i = [], 0
while i < len(tks):
j = i
if i == 0 and oneTerm(tks[i]) and len(
tks) > 1 and len(tks[i + 1]) > 1: # 多 工位
res.append(" ".join(tks[0:2]))
i = 2
continue
while j < len(
tks) and tks[j] and tks[j] not in self.stop_words and oneTerm(tks[j]):
j += 1
if j - i > 1:
if j - i < 5:
res.append(" ".join(tks[i:j]))
i = j
else:
res.append(" ".join(tks[i:i + 2]))
i = i + 2
else:
if len(tks[i]) > 0:
res.append(tks[i])
i += 1
return [t for t in res if t]
def ner(self, t):
if not self.ne:
return ""
res = self.ne.get(t, "")
if res:
return res
def split(self, txt):
tks = []
for t in re.sub(r"[ \t]+", " ", txt).split(" "):
if tks and re.match(r".*[a-zA-Z]$", tks[-1]) and \
re.match(r".*[a-zA-Z]$", t) and tks and \
self.ne.get(t, "") != "func" and self.ne.get(tks[-1], "") != "func":
tks[-1] = tks[-1] + " " + t
else:
tks.append(t)
return tks
def weights(self, tks):
def skill(t):
if t not in self.sk:
return 1
return 6
def ner(t):
if not self.ne or t not in self.ne:
return 1
m = {"toxic": 2, "func": 1, "corp": 3, "loca": 3, "sch": 3, "stock": 3,
"firstnm": 1}
return m[self.ne[t]]
def postag(t):
t = huqie.tag(t)
if t in set(["r", "c", "d"]):
return 0.3
if t in set(["ns", "nt"]):
return 3
if t in set(["n"]):
return 2
if re.match(r"[0-9-]+", t):
return 2
return 1
def freq(t):
if re.match(r"[0-9\. -]+$", t):
return 10000
s = huqie.freq(t)
if not s and re.match(r"[a-z\. -]+$", t):
return 10
if not s:
s = 0
if not s and len(t) >= 4:
s = [tt for tt in huqie.qieqie(t).split(" ") if len(tt) > 1]
if len(s) > 1:
s = np.min([freq(tt) for tt in s]) / 6.
else:
s = 0
return max(s, 10)
def df(t):
if re.match(r"[0-9\. -]+$", t):
return 100000
if t in self.df:
return self.df[t] + 3
elif re.match(r"[a-z\. -]+$", t):
return 3
elif len(t) >= 4:
s = [tt for tt in huqie.qieqie(t).split(" ") if len(tt) > 1]
if len(s) > 1:
return max(3, np.min([df(tt) for tt in s]) / 6.)
return 3
def idf(s, N): return math.log10(10 + ((N - s + 0.5) / (s + 0.5)))
tw = []
for tk in tks:
tt = self.tokenMerge(self.pretoken(tk, True))
idf1 = np.array([idf(freq(t), 10000000) for t in tt])
idf2 = np.array([idf(df(t), 1000000000) for t in tt])
wts = (0.3 * idf1 + 0.7 * idf2) * \
np.array([ner(t) * postag(t) for t in tt])
tw.extend(zip(tt, wts))
S = np.sum([s for _, s in tw])
return [(t, s / S) for t, s in tw]