mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-01-04 03:25:30 +08:00
build python version rag-flow (#21)
* clean rust version project * clean rust version project * build python version rag-flow
This commit is contained in:
0
rag/__init__.py
Normal file
0
rag/__init__.py
Normal file
32
rag/llm/__init__.py
Normal file
32
rag/llm/__init__.py
Normal file
@ -0,0 +1,32 @@
|
||||
#
|
||||
# Copyright 2019 The FATE Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from .embedding_model import *
|
||||
from .chat_model import *
|
||||
from .cv_model import *
|
||||
|
||||
|
||||
EmbeddingModel = {
|
||||
"local": HuEmbedding,
|
||||
"OpenAI": OpenAIEmbed,
|
||||
"通义千问": QWenEmbed,
|
||||
}
|
||||
|
||||
|
||||
CvModel = {
|
||||
"OpenAI": GptV4,
|
||||
"通义千问": QWenCV,
|
||||
}
|
||||
|
||||
52
rag/llm/chat_model.py
Normal file
52
rag/llm/chat_model.py
Normal file
@ -0,0 +1,52 @@
|
||||
#
|
||||
# Copyright 2019 The FATE Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from abc import ABC
|
||||
from openai import OpenAI
|
||||
import os
|
||||
|
||||
|
||||
class Base(ABC):
|
||||
def chat(self, system, history, gen_conf):
|
||||
raise NotImplementedError("Please implement encode method!")
|
||||
|
||||
|
||||
class GptTurbo(Base):
|
||||
def __init__(self):
|
||||
self.client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])
|
||||
|
||||
def chat(self, system, history, gen_conf):
|
||||
history.insert(0, {"role": "system", "content": system})
|
||||
res = self.client.chat.completions.create(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=history,
|
||||
**gen_conf)
|
||||
return res.choices[0].message.content.strip()
|
||||
|
||||
|
||||
class QWenChat(Base):
|
||||
def chat(self, system, history, gen_conf):
|
||||
from http import HTTPStatus
|
||||
from dashscope import Generation
|
||||
# export DASHSCOPE_API_KEY=YOUR_DASHSCOPE_API_KEY
|
||||
history.insert(0, {"role": "system", "content": system})
|
||||
response = Generation.call(
|
||||
Generation.Models.qwen_turbo,
|
||||
messages=history,
|
||||
result_format='message'
|
||||
)
|
||||
if response.status_code == HTTPStatus.OK:
|
||||
return response.output.choices[0]['message']['content']
|
||||
return response.message
|
||||
89
rag/llm/cv_model.py
Normal file
89
rag/llm/cv_model.py
Normal file
@ -0,0 +1,89 @@
|
||||
#
|
||||
# Copyright 2019 The FATE Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from abc import ABC
|
||||
from openai import OpenAI
|
||||
import os
|
||||
import base64
|
||||
from io import BytesIO
|
||||
|
||||
|
||||
class Base(ABC):
|
||||
def __init__(self, key, model_name):
|
||||
pass
|
||||
|
||||
def describe(self, image, max_tokens=300):
|
||||
raise NotImplementedError("Please implement encode method!")
|
||||
|
||||
def image2base64(self, image):
|
||||
if isinstance(image, BytesIO):
|
||||
return base64.b64encode(image.getvalue()).decode("utf-8")
|
||||
buffered = BytesIO()
|
||||
try:
|
||||
image.save(buffered, format="JPEG")
|
||||
except Exception as e:
|
||||
image.save(buffered, format="PNG")
|
||||
return base64.b64encode(buffered.getvalue()).decode("utf-8")
|
||||
|
||||
def prompt(self, b64):
|
||||
return [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等。",
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/jpeg;base64,{b64}"
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
class GptV4(Base):
|
||||
def __init__(self, key, model_name="gpt-4-vision-preview"):
|
||||
self.client = OpenAI(key)
|
||||
self.model_name = model_name
|
||||
|
||||
def describe(self, image, max_tokens=300):
|
||||
b64 = self.image2base64(image)
|
||||
|
||||
res = self.client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
messages=self.prompt(b64),
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
return res.choices[0].message.content.strip()
|
||||
|
||||
|
||||
class QWenCV(Base):
|
||||
def __init__(self, key, model_name="qwen-vl-chat-v1"):
|
||||
import dashscope
|
||||
dashscope.api_key = key
|
||||
self.model_name = model_name
|
||||
|
||||
def describe(self, image, max_tokens=300):
|
||||
from http import HTTPStatus
|
||||
from dashscope import MultiModalConversation
|
||||
response = MultiModalConversation.call(model=self.model_name,
|
||||
messages=self.prompt(self.image2base64(image)))
|
||||
if response.status_code == HTTPStatus.OK:
|
||||
return response.output.choices[0]['message']['content']
|
||||
return response.message
|
||||
94
rag/llm/embedding_model.py
Normal file
94
rag/llm/embedding_model.py
Normal file
@ -0,0 +1,94 @@
|
||||
#
|
||||
# Copyright 2019 The FATE Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from abc import ABC
|
||||
|
||||
import dashscope
|
||||
from openai import OpenAI
|
||||
from FlagEmbedding import FlagModel
|
||||
import torch
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
from rag.utils import num_tokens_from_string
|
||||
|
||||
|
||||
class Base(ABC):
|
||||
def __init__(self, key, model_name):
|
||||
pass
|
||||
|
||||
|
||||
def encode(self, texts: list, batch_size=32):
|
||||
raise NotImplementedError("Please implement encode method!")
|
||||
|
||||
|
||||
class HuEmbedding(Base):
|
||||
def __init__(self):
|
||||
"""
|
||||
If you have trouble downloading HuggingFace models, -_^ this might help!!
|
||||
|
||||
For Linux:
|
||||
export HF_ENDPOINT=https://hf-mirror.com
|
||||
|
||||
For Windows:
|
||||
Good luck
|
||||
^_-
|
||||
|
||||
"""
|
||||
self.model = FlagModel("BAAI/bge-large-zh-v1.5",
|
||||
query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
|
||||
use_fp16=torch.cuda.is_available())
|
||||
|
||||
|
||||
def encode(self, texts: list, batch_size=32):
|
||||
token_count = 0
|
||||
for t in texts: token_count += num_tokens_from_string(t)
|
||||
res = []
|
||||
for i in range(0, len(texts), batch_size):
|
||||
res.extend(self.model.encode(texts[i:i + batch_size]).tolist())
|
||||
return np.array(res), token_count
|
||||
|
||||
|
||||
class OpenAIEmbed(Base):
|
||||
def __init__(self, key, model_name="text-embedding-ada-002"):
|
||||
self.client = OpenAI(key)
|
||||
self.model_name = model_name
|
||||
|
||||
def encode(self, texts: list, batch_size=32):
|
||||
token_count = 0
|
||||
for t in texts: token_count += num_tokens_from_string(t)
|
||||
res = self.client.embeddings.create(input=texts,
|
||||
model=self.model_name)
|
||||
return [d["embedding"] for d in res["data"]], token_count
|
||||
|
||||
|
||||
class QWenEmbed(Base):
|
||||
def __init__(self, key, model_name="text_embedding_v2"):
|
||||
dashscope.api_key = key
|
||||
self.model_name = model_name
|
||||
|
||||
def encode(self, texts: list, batch_size=32, text_type="document"):
|
||||
import dashscope
|
||||
res = []
|
||||
token_count = 0
|
||||
for txt in texts:
|
||||
resp = dashscope.TextEmbedding.call(
|
||||
model=self.model_name,
|
||||
input=txt[:2048],
|
||||
text_type=text_type
|
||||
)
|
||||
res.append(resp["output"]["embeddings"][0]["embedding"])
|
||||
token_count += resp["usage"]["total_tokens"]
|
||||
return res, token_count
|
||||
0
rag/nlp/__init__.py
Normal file
0
rag/nlp/__init__.py
Normal file
435
rag/nlp/huchunk.py
Normal file
435
rag/nlp/huchunk.py
Normal file
@ -0,0 +1,435 @@
|
||||
import re
|
||||
import os
|
||||
import copy
|
||||
import base64
|
||||
import magic
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
import numpy as np
|
||||
from io import BytesIO
|
||||
|
||||
|
||||
class HuChunker:
|
||||
|
||||
def __init__(self):
|
||||
self.MAX_LVL = 12
|
||||
self.proj_patt = [
|
||||
(r"第[零一二三四五六七八九十百]+章", 1),
|
||||
(r"第[零一二三四五六七八九十百]+[条节]", 2),
|
||||
(r"[零一二三四五六七八九十百]+[、 ]", 3),
|
||||
(r"[\((][零一二三四五六七八九十百]+[)\)]", 4),
|
||||
(r"[0-9]+(、|\.[ ]|\.[^0-9])", 5),
|
||||
(r"[0-9]+\.[0-9]+(、|[ ]|[^0-9])", 6),
|
||||
(r"[0-9]+\.[0-9]+\.[0-9]+(、|[ ]|[^0-9])", 7),
|
||||
(r"[0-9]+\.[0-9]+\.[0-9]+\.[0-9]+(、|[ ]|[^0-9])", 8),
|
||||
(r".{,48}[::??]@", 9),
|
||||
(r"[0-9]+)", 10),
|
||||
(r"[\((][0-9]+[)\)]", 11),
|
||||
(r"[零一二三四五六七八九十百]+是", 12),
|
||||
(r"[⚫•➢✓ ]", 12)
|
||||
]
|
||||
self.lines = []
|
||||
|
||||
def _garbage(self, txt):
|
||||
patt = [
|
||||
r"(在此保证|不得以任何形式翻版|请勿传阅|仅供内部使用|未经事先书面授权)",
|
||||
r"(版权(归本公司)*所有|免责声明|保留一切权力|承担全部责任|特别声明|报告中涉及)",
|
||||
r"(不承担任何责任|投资者的通知事项:|任何机构和个人|本报告仅为|不构成投资)",
|
||||
r"(不构成对任何个人或机构投资建议|联系其所在国家|本报告由从事证券交易)",
|
||||
r"(本研究报告由|「认可投资者」|所有研究报告均以|请发邮件至)",
|
||||
r"(本报告仅供|市场有风险,投资需谨慎|本报告中提及的)",
|
||||
r"(本报告反映|此信息仅供|证券分析师承诺|具备证券投资咨询业务资格)",
|
||||
r"^(时间|签字|签章)[::]",
|
||||
r"(参考文献|目录索引|图表索引)",
|
||||
r"[ ]*年[ ]+月[ ]+日",
|
||||
r"^(中国证券业协会|[0-9]+年[0-9]+月[0-9]+日)$",
|
||||
r"\.{10,}",
|
||||
r"(———————END|帮我转发|欢迎收藏|快来关注我吧)"
|
||||
]
|
||||
return any([re.search(p, txt) for p in patt])
|
||||
|
||||
def _proj_match(self, line):
|
||||
for p, j in self.proj_patt:
|
||||
if re.match(p, line):
|
||||
return j
|
||||
return
|
||||
|
||||
def _does_proj_match(self):
|
||||
mat = [None for _ in range(len(self.lines))]
|
||||
for i in range(len(self.lines)):
|
||||
mat[i] = self._proj_match(self.lines[i])
|
||||
return mat
|
||||
|
||||
def naive_text_chunk(self, text, ti="", MAX_LEN=612):
|
||||
if text:
|
||||
self.lines = [l.strip().replace(u'\u3000', u' ')
|
||||
.replace(u'\xa0', u'')
|
||||
for l in text.split("\n\n")]
|
||||
self.lines = [l for l in self.lines if not self._garbage(l)]
|
||||
self.lines = [re.sub(r"([ ]+| )", " ", l)
|
||||
for l in self.lines if l]
|
||||
if not self.lines:
|
||||
return []
|
||||
arr = self.lines
|
||||
|
||||
res = [""]
|
||||
i = 0
|
||||
while i < len(arr):
|
||||
a = arr[i]
|
||||
if not a:
|
||||
i += 1
|
||||
continue
|
||||
if len(a) > MAX_LEN:
|
||||
a_ = a.split("\n")
|
||||
if len(a_) >= 2:
|
||||
arr.pop(i)
|
||||
for j in range(2, len(a_) + 1):
|
||||
if len("\n".join(a_[:j])) >= MAX_LEN:
|
||||
arr.insert(i, "\n".join(a_[:j - 1]))
|
||||
arr.insert(i + 1, "\n".join(a_[j - 1:]))
|
||||
break
|
||||
else:
|
||||
assert False, f"Can't split: {a}"
|
||||
continue
|
||||
|
||||
if len(res[-1]) < MAX_LEN / 3:
|
||||
res[-1] += "\n" + a
|
||||
else:
|
||||
res.append(a)
|
||||
i += 1
|
||||
|
||||
if ti:
|
||||
for i in range(len(res)):
|
||||
if res[i].find("——来自") >= 0:
|
||||
continue
|
||||
res[i] += f"\t——来自“{ti}”"
|
||||
|
||||
return res
|
||||
|
||||
def _merge(self):
|
||||
# merge continuous same level text
|
||||
lines = [self.lines[0]] if self.lines else []
|
||||
for i in range(1, len(self.lines)):
|
||||
if self.mat[i] == self.mat[i - 1] \
|
||||
and len(lines[-1]) < 256 \
|
||||
and len(self.lines[i]) < 256:
|
||||
lines[-1] += "\n" + self.lines[i]
|
||||
continue
|
||||
lines.append(self.lines[i])
|
||||
self.lines = lines
|
||||
self.mat = self._does_proj_match()
|
||||
return self.mat
|
||||
|
||||
def text_chunks(self, text):
|
||||
if text:
|
||||
self.lines = [l.strip().replace(u'\u3000', u' ')
|
||||
.replace(u'\xa0', u'')
|
||||
for l in re.split(r"[\r\n]", text)]
|
||||
self.lines = [l for l in self.lines if not self._garbage(l)]
|
||||
self.lines = [l for l in self.lines if l]
|
||||
self.mat = self._does_proj_match()
|
||||
mat = self._merge()
|
||||
|
||||
tree = []
|
||||
for i in range(len(self.lines)):
|
||||
tree.append({"proj": mat[i],
|
||||
"children": [],
|
||||
"read": False})
|
||||
# find all children
|
||||
for i in range(len(self.lines) - 1):
|
||||
if tree[i]["proj"] is None:
|
||||
continue
|
||||
ed = i + 1
|
||||
while ed < len(tree) and (tree[ed]["proj"] is None or
|
||||
tree[ed]["proj"] > tree[i]["proj"]):
|
||||
ed += 1
|
||||
|
||||
nxt = tree[i]["proj"] + 1
|
||||
st = set([p["proj"] for p in tree[i + 1: ed] if p["proj"]])
|
||||
while nxt not in st:
|
||||
nxt += 1
|
||||
if nxt > self.MAX_LVL:
|
||||
break
|
||||
if nxt <= self.MAX_LVL:
|
||||
for j in range(i + 1, ed):
|
||||
if tree[j]["proj"] is not None:
|
||||
break
|
||||
tree[i]["children"].append(j)
|
||||
for j in range(i + 1, ed):
|
||||
if tree[j]["proj"] != nxt:
|
||||
continue
|
||||
tree[i]["children"].append(j)
|
||||
else:
|
||||
for j in range(i + 1, ed):
|
||||
tree[i]["children"].append(j)
|
||||
|
||||
# get DFS combinations, find all the paths to leaf
|
||||
paths = []
|
||||
|
||||
def dfs(i, path):
|
||||
nonlocal tree, paths
|
||||
path.append(i)
|
||||
tree[i]["read"] = True
|
||||
if len(self.lines[i]) > 256:
|
||||
paths.append(path)
|
||||
return
|
||||
if not tree[i]["children"]:
|
||||
if len(path) > 1 or len(self.lines[i]) >= 32:
|
||||
paths.append(path)
|
||||
return
|
||||
for j in tree[i]["children"]:
|
||||
dfs(j, copy.deepcopy(path))
|
||||
|
||||
for i, t in enumerate(tree):
|
||||
if t["read"]:
|
||||
continue
|
||||
dfs(i, [])
|
||||
|
||||
# concat txt on the path for all paths
|
||||
res = []
|
||||
lines = np.array(self.lines)
|
||||
for p in paths:
|
||||
if len(p) < 2:
|
||||
tree[p[0]]["read"] = False
|
||||
continue
|
||||
txt = "\n".join(lines[p[:-1]]) + "\n" + lines[p[-1]]
|
||||
res.append(txt)
|
||||
# concat continuous orphans
|
||||
assert len(tree) == len(lines)
|
||||
ii = 0
|
||||
while ii < len(tree):
|
||||
if tree[ii]["read"]:
|
||||
ii += 1
|
||||
continue
|
||||
txt = lines[ii]
|
||||
e = ii + 1
|
||||
while e < len(tree) and not tree[e]["read"] and len(txt) < 256:
|
||||
txt += "\n" + lines[e]
|
||||
e += 1
|
||||
res.append(txt)
|
||||
ii = e
|
||||
|
||||
# if the node has not been read, find its daddy
|
||||
def find_daddy(st):
|
||||
nonlocal lines, tree
|
||||
proj = tree[st]["proj"]
|
||||
if len(self.lines[st]) > 512:
|
||||
return [st]
|
||||
if proj is None:
|
||||
proj = self.MAX_LVL + 1
|
||||
for i in range(st - 1, -1, -1):
|
||||
if tree[i]["proj"] and tree[i]["proj"] < proj:
|
||||
a = [st] + find_daddy(i)
|
||||
return a
|
||||
return []
|
||||
|
||||
return res
|
||||
|
||||
|
||||
class PdfChunker(HuChunker):
|
||||
|
||||
@dataclass
|
||||
class Fields:
|
||||
text_chunks: List = None
|
||||
table_chunks: List = None
|
||||
|
||||
def __init__(self, pdf_parser):
|
||||
self.pdf = pdf_parser
|
||||
super().__init__()
|
||||
|
||||
def tableHtmls(self, pdfnm):
|
||||
_, tbls = self.pdf(pdfnm, return_html=True)
|
||||
res = []
|
||||
for img, arr in tbls:
|
||||
if arr[0].find("<table>") < 0:
|
||||
continue
|
||||
buffered = BytesIO()
|
||||
if img:
|
||||
img.save(buffered, format="JPEG")
|
||||
img_str = base64.b64encode(
|
||||
buffered.getvalue()).decode('utf-8') if img else ""
|
||||
res.append({"table": arr[0], "image": img_str})
|
||||
return res
|
||||
|
||||
def html(self, pdfnm):
|
||||
txts, tbls = self.pdf(pdfnm, return_html=True)
|
||||
res = []
|
||||
txt_cks = self.text_chunks(txts)
|
||||
for txt, img in [(self.pdf.remove_tag(c), self.pdf.crop(c))
|
||||
for c in txt_cks]:
|
||||
buffered = BytesIO()
|
||||
if img:
|
||||
img.save(buffered, format="JPEG")
|
||||
img_str = base64.b64encode(
|
||||
buffered.getvalue()).decode('utf-8') if img else ""
|
||||
res.append({"table": "<p>%s</p>" % txt.replace("\n", "<br/>"),
|
||||
"image": img_str})
|
||||
|
||||
for img, arr in tbls:
|
||||
if not arr:
|
||||
continue
|
||||
buffered = BytesIO()
|
||||
if img:
|
||||
img.save(buffered, format="JPEG")
|
||||
img_str = base64.b64encode(
|
||||
buffered.getvalue()).decode('utf-8') if img else ""
|
||||
res.append({"table": arr[0], "image": img_str})
|
||||
|
||||
return res
|
||||
|
||||
def __call__(self, pdfnm, return_image=True, naive_chunk=False):
|
||||
flds = self.Fields()
|
||||
text, tbls = self.pdf(pdfnm)
|
||||
fnm = pdfnm
|
||||
txt_cks = self.text_chunks(text) if not naive_chunk else \
|
||||
self.naive_text_chunk(text, ti=fnm if isinstance(fnm, str) else "")
|
||||
flds.text_chunks = [(self.pdf.remove_tag(c),
|
||||
self.pdf.crop(c) if return_image else None) for c in txt_cks]
|
||||
|
||||
flds.table_chunks = [(arr, img if return_image else None)
|
||||
for img, arr in tbls]
|
||||
return flds
|
||||
|
||||
|
||||
class DocxChunker(HuChunker):
|
||||
|
||||
@dataclass
|
||||
class Fields:
|
||||
text_chunks: List = None
|
||||
table_chunks: List = None
|
||||
|
||||
def __init__(self, doc_parser):
|
||||
self.doc = doc_parser
|
||||
super().__init__()
|
||||
|
||||
def _does_proj_match(self):
|
||||
mat = []
|
||||
for s in self.styles:
|
||||
s = s.split(" ")[-1]
|
||||
try:
|
||||
mat.append(int(s))
|
||||
except Exception as e:
|
||||
mat.append(None)
|
||||
return mat
|
||||
|
||||
def _merge(self):
|
||||
i = 1
|
||||
while i < len(self.lines):
|
||||
if self.mat[i] == self.mat[i - 1] \
|
||||
and len(self.lines[i - 1]) < 256 \
|
||||
and len(self.lines[i]) < 256:
|
||||
self.lines[i - 1] += "\n" + self.lines[i]
|
||||
self.styles.pop(i)
|
||||
self.lines.pop(i)
|
||||
self.mat.pop(i)
|
||||
continue
|
||||
i += 1
|
||||
self.mat = self._does_proj_match()
|
||||
return self.mat
|
||||
|
||||
def __call__(self, fnm):
|
||||
flds = self.Fields()
|
||||
flds.title = os.path.splitext(
|
||||
os.path.basename(fnm))[0] if isinstance(
|
||||
fnm, type("")) else ""
|
||||
secs, tbls = self.doc(fnm)
|
||||
self.lines = [l for l, s in secs]
|
||||
self.styles = [s for l, s in secs]
|
||||
|
||||
txt_cks = self.text_chunks("")
|
||||
flds.text_chunks = [(t, None) for t in txt_cks if not self._garbage(t)]
|
||||
flds.table_chunks = [(tb, None) for tb in tbls for t in tb if t]
|
||||
return flds
|
||||
|
||||
|
||||
class ExcelChunker(HuChunker):
|
||||
|
||||
@dataclass
|
||||
class Fields:
|
||||
text_chunks: List = None
|
||||
table_chunks: List = None
|
||||
|
||||
def __init__(self, excel_parser):
|
||||
self.excel = excel_parser
|
||||
super().__init__()
|
||||
|
||||
def __call__(self, fnm):
|
||||
flds = self.Fields()
|
||||
flds.text_chunks = [(t, None) for t in self.excel(fnm)]
|
||||
flds.table_chunks = []
|
||||
return flds
|
||||
|
||||
|
||||
class PptChunker(HuChunker):
|
||||
|
||||
@dataclass
|
||||
class Fields:
|
||||
text_chunks: List = None
|
||||
table_chunks: List = None
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def __call__(self, fnm):
|
||||
from pptx import Presentation
|
||||
ppt = Presentation(fnm) if isinstance(
|
||||
fnm, str) else Presentation(
|
||||
BytesIO(fnm))
|
||||
flds = self.Fields()
|
||||
flds.text_chunks = []
|
||||
for slide in ppt.slides:
|
||||
for shape in slide.shapes:
|
||||
if hasattr(shape, "text"):
|
||||
flds.text_chunks.append((shape.text, None))
|
||||
flds.table_chunks = []
|
||||
return flds
|
||||
|
||||
|
||||
class TextChunker(HuChunker):
|
||||
|
||||
@dataclass
|
||||
class Fields:
|
||||
text_chunks: List = None
|
||||
table_chunks: List = None
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@staticmethod
|
||||
def is_binary_file(file_path):
|
||||
mime = magic.Magic(mime=True)
|
||||
if isinstance(file_path, str):
|
||||
file_type = mime.from_file(file_path)
|
||||
else:
|
||||
file_type = mime.from_buffer(file_path)
|
||||
if 'text' in file_type:
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
def __call__(self, fnm):
|
||||
flds = self.Fields()
|
||||
if self.is_binary_file(fnm):
|
||||
return flds
|
||||
with open(fnm, "r") as f:
|
||||
txt = f.read()
|
||||
flds.text_chunks = [(c, None) for c in self.naive_text_chunk(txt)]
|
||||
flds.table_chunks = []
|
||||
return flds
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
sys.path.append(os.path.dirname(__file__) + "/../")
|
||||
if sys.argv[1].split(".")[-1].lower() == "pdf":
|
||||
from parser import PdfParser
|
||||
ckr = PdfChunker(PdfParser())
|
||||
if sys.argv[1].split(".")[-1].lower().find("doc") >= 0:
|
||||
from parser import DocxParser
|
||||
ckr = DocxChunker(DocxParser())
|
||||
if sys.argv[1].split(".")[-1].lower().find("xlsx") >= 0:
|
||||
from parser import ExcelParser
|
||||
ckr = ExcelChunker(ExcelParser())
|
||||
|
||||
# ckr.html(sys.argv[1])
|
||||
print(ckr(sys.argv[1]))
|
||||
406
rag/nlp/huqie.py
Normal file
406
rag/nlp/huqie.py
Normal file
@ -0,0 +1,406 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import copy
|
||||
import datrie
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
import string
|
||||
import sys
|
||||
from hanziconv import HanziConv
|
||||
|
||||
from web_server.utils.file_utils import get_project_base_directory
|
||||
|
||||
|
||||
class Huqie:
|
||||
def key_(self, line):
|
||||
return str(line.lower().encode("utf-8"))[2:-1]
|
||||
|
||||
def rkey_(self, line):
|
||||
return str(("DD" + (line[::-1].lower())).encode("utf-8"))[2:-1]
|
||||
|
||||
def loadDict_(self, fnm):
|
||||
print("[HUQIE]:Build trie", fnm, file=sys.stderr)
|
||||
try:
|
||||
of = open(fnm, "r")
|
||||
while True:
|
||||
line = of.readline()
|
||||
if not line:
|
||||
break
|
||||
line = re.sub(r"[\r\n]+", "", line)
|
||||
line = re.split(r"[ \t]", line)
|
||||
k = self.key_(line[0])
|
||||
F = int(math.log(float(line[1]) / self.DENOMINATOR) + .5)
|
||||
if k not in self.trie_ or self.trie_[k][0] < F:
|
||||
self.trie_[self.key_(line[0])] = (F, line[2])
|
||||
self.trie_[self.rkey_(line[0])] = 1
|
||||
self.trie_.save(fnm + ".trie")
|
||||
of.close()
|
||||
except Exception as e:
|
||||
print("[HUQIE]:Faild to build trie, ", fnm, e, file=sys.stderr)
|
||||
|
||||
def __init__(self, debug=False):
|
||||
self.DEBUG = debug
|
||||
self.DENOMINATOR = 1000000
|
||||
self.trie_ = datrie.Trie(string.printable)
|
||||
self.DIR_ = os.path.join(get_project_base_directory(), "rag/res", "huqie")
|
||||
|
||||
self.SPLIT_CHAR = r"([ ,\.<>/?;'\[\]\\`!@#$%^&*\(\)\{\}\|_+=《》,。?、;‘’:“”【】~!¥%……()——-]+|[a-z\.-]+|[0-9,\.-]+)"
|
||||
try:
|
||||
self.trie_ = datrie.Trie.load(self.DIR_ + ".txt.trie")
|
||||
return
|
||||
except Exception as e:
|
||||
print("[HUQIE]:Build default trie", file=sys.stderr)
|
||||
self.trie_ = datrie.Trie(string.printable)
|
||||
|
||||
self.loadDict_(self.DIR_ + ".txt")
|
||||
|
||||
def loadUserDict(self, fnm):
|
||||
try:
|
||||
self.trie_ = datrie.Trie.load(fnm + ".trie")
|
||||
return
|
||||
except Exception as e:
|
||||
self.trie_ = datrie.Trie(string.printable)
|
||||
self.loadDict_(fnm)
|
||||
|
||||
def addUserDict(self, fnm):
|
||||
self.loadDict_(fnm)
|
||||
|
||||
def _strQ2B(self, ustring):
|
||||
"""把字符串全角转半角"""
|
||||
rstring = ""
|
||||
for uchar in ustring:
|
||||
inside_code = ord(uchar)
|
||||
if inside_code == 0x3000:
|
||||
inside_code = 0x0020
|
||||
else:
|
||||
inside_code -= 0xfee0
|
||||
if inside_code < 0x0020 or inside_code > 0x7e: # 转完之后不是半角字符返回原来的字符
|
||||
rstring += uchar
|
||||
else:
|
||||
rstring += chr(inside_code)
|
||||
return rstring
|
||||
|
||||
def _tradi2simp(self, line):
|
||||
return HanziConv.toSimplified(line)
|
||||
|
||||
def dfs_(self, chars, s, preTks, tkslist):
|
||||
MAX_L = 10
|
||||
res = s
|
||||
# if s > MAX_L or s>= len(chars):
|
||||
if s >= len(chars):
|
||||
tkslist.append(preTks)
|
||||
return res
|
||||
|
||||
# pruning
|
||||
S = s + 1
|
||||
if s + 2 <= len(chars):
|
||||
t1, t2 = "".join(chars[s:s + 1]), "".join(chars[s:s + 2])
|
||||
if self.trie_.has_keys_with_prefix(self.key_(t1)) and not self.trie_.has_keys_with_prefix(
|
||||
self.key_(t2)):
|
||||
S = s + 2
|
||||
if len(preTks) > 2 and len(
|
||||
preTks[-1][0]) == 1 and len(preTks[-2][0]) == 1 and len(preTks[-3][0]) == 1:
|
||||
t1 = preTks[-1][0] + "".join(chars[s:s + 1])
|
||||
if self.trie_.has_keys_with_prefix(self.key_(t1)):
|
||||
S = s + 2
|
||||
|
||||
################
|
||||
for e in range(S, len(chars) + 1):
|
||||
t = "".join(chars[s:e])
|
||||
k = self.key_(t)
|
||||
|
||||
if e > s + 1 and not self.trie_.has_keys_with_prefix(k):
|
||||
break
|
||||
|
||||
if k in self.trie_:
|
||||
pretks = copy.deepcopy(preTks)
|
||||
if k in self.trie_:
|
||||
pretks.append((t, self.trie_[k]))
|
||||
else:
|
||||
pretks.append((t, (-12, '')))
|
||||
res = max(res, self.dfs_(chars, e, pretks, tkslist))
|
||||
|
||||
if res > s:
|
||||
return res
|
||||
|
||||
t = "".join(chars[s:s + 1])
|
||||
k = self.key_(t)
|
||||
if k in self.trie_:
|
||||
preTks.append((t, self.trie_[k]))
|
||||
else:
|
||||
preTks.append((t, (-12, '')))
|
||||
|
||||
return self.dfs_(chars, s + 1, preTks, tkslist)
|
||||
|
||||
def freq(self, tk):
|
||||
k = self.key_(tk)
|
||||
if k not in self.trie_:
|
||||
return 0
|
||||
return int(math.exp(self.trie_[k][0]) * self.DENOMINATOR + 0.5)
|
||||
|
||||
def tag(self, tk):
|
||||
k = self.key_(tk)
|
||||
if k not in self.trie_:
|
||||
return ""
|
||||
return self.trie_[k][1]
|
||||
|
||||
def score_(self, tfts):
|
||||
B = 30
|
||||
F, L, tks = 0, 0, []
|
||||
for tk, (freq, tag) in tfts:
|
||||
F += freq
|
||||
L += 0 if len(tk) < 2 else 1
|
||||
tks.append(tk)
|
||||
F /= len(tks)
|
||||
L /= len(tks)
|
||||
if self.DEBUG:
|
||||
print("[SC]", tks, len(tks), L, F, B / len(tks) + L + F)
|
||||
return tks, B / len(tks) + L + F
|
||||
|
||||
def sortTks_(self, tkslist):
|
||||
res = []
|
||||
for tfts in tkslist:
|
||||
tks, s = self.score_(tfts)
|
||||
res.append((tks, s))
|
||||
return sorted(res, key=lambda x: x[1], reverse=True)
|
||||
|
||||
def merge_(self, tks):
|
||||
patts = [
|
||||
(r"[ ]+", " "),
|
||||
(r"([0-9\+\.,%\*=-]) ([0-9\+\.,%\*=-])", r"\1\2"),
|
||||
]
|
||||
# for p,s in patts: tks = re.sub(p, s, tks)
|
||||
|
||||
# if split chars is part of token
|
||||
res = []
|
||||
tks = re.sub(r"[ ]+", " ", tks).split(" ")
|
||||
s = 0
|
||||
while True:
|
||||
if s >= len(tks):
|
||||
break
|
||||
E = s + 1
|
||||
for e in range(s + 2, min(len(tks) + 2, s + 6)):
|
||||
tk = "".join(tks[s:e])
|
||||
if re.search(self.SPLIT_CHAR, tk) and self.freq(tk):
|
||||
E = e
|
||||
res.append("".join(tks[s:E]))
|
||||
s = E
|
||||
|
||||
return " ".join(res)
|
||||
|
||||
def maxForward_(self, line):
|
||||
res = []
|
||||
s = 0
|
||||
while s < len(line):
|
||||
e = s + 1
|
||||
t = line[s:e]
|
||||
while e < len(line) and self.trie_.has_keys_with_prefix(
|
||||
self.key_(t)):
|
||||
e += 1
|
||||
t = line[s:e]
|
||||
|
||||
while e - 1 > s and self.key_(t) not in self.trie_:
|
||||
e -= 1
|
||||
t = line[s:e]
|
||||
|
||||
if self.key_(t) in self.trie_:
|
||||
res.append((t, self.trie_[self.key_(t)]))
|
||||
else:
|
||||
res.append((t, (0, '')))
|
||||
|
||||
s = e
|
||||
|
||||
return self.score_(res)
|
||||
|
||||
def maxBackward_(self, line):
|
||||
res = []
|
||||
s = len(line) - 1
|
||||
while s >= 0:
|
||||
e = s + 1
|
||||
t = line[s:e]
|
||||
while s > 0 and self.trie_.has_keys_with_prefix(self.rkey_(t)):
|
||||
s -= 1
|
||||
t = line[s:e]
|
||||
|
||||
while s + 1 < e and self.key_(t) not in self.trie_:
|
||||
s += 1
|
||||
t = line[s:e]
|
||||
|
||||
if self.key_(t) in self.trie_:
|
||||
res.append((t, self.trie_[self.key_(t)]))
|
||||
else:
|
||||
res.append((t, (0, '')))
|
||||
|
||||
s -= 1
|
||||
|
||||
return self.score_(res[::-1])
|
||||
|
||||
def qie(self, line):
|
||||
line = self._strQ2B(line).lower()
|
||||
line = self._tradi2simp(line)
|
||||
arr = re.split(self.SPLIT_CHAR, line)
|
||||
res = []
|
||||
for L in arr:
|
||||
if len(L) < 2 or re.match(
|
||||
r"[a-z\.-]+$", L) or re.match(r"[0-9\.-]+$", L):
|
||||
res.append(L)
|
||||
continue
|
||||
# print(L)
|
||||
|
||||
# use maxforward for the first time
|
||||
tks, s = self.maxForward_(L)
|
||||
tks1, s1 = self.maxBackward_(L)
|
||||
if self.DEBUG:
|
||||
print("[FW]", tks, s)
|
||||
print("[BW]", tks1, s1)
|
||||
|
||||
diff = [0 for _ in range(max(len(tks1), len(tks)))]
|
||||
for i in range(min(len(tks1), len(tks))):
|
||||
if tks[i] != tks1[i]:
|
||||
diff[i] = 1
|
||||
|
||||
if s1 > s:
|
||||
tks = tks1
|
||||
|
||||
i = 0
|
||||
while i < len(tks):
|
||||
s = i
|
||||
while s < len(tks) and diff[s] == 0:
|
||||
s += 1
|
||||
if s == len(tks):
|
||||
res.append(" ".join(tks[i:]))
|
||||
break
|
||||
if s > i:
|
||||
res.append(" ".join(tks[i:s]))
|
||||
|
||||
e = s
|
||||
while e < len(tks) and e - s < 5 and diff[e] == 1:
|
||||
e += 1
|
||||
|
||||
tkslist = []
|
||||
self.dfs_("".join(tks[s:e + 1]), 0, [], tkslist)
|
||||
res.append(" ".join(self.sortTks_(tkslist)[0][0]))
|
||||
|
||||
i = e + 1
|
||||
|
||||
res = " ".join(res)
|
||||
if self.DEBUG:
|
||||
print("[TKS]", self.merge_(res))
|
||||
return self.merge_(res)
|
||||
|
||||
def qieqie(self, tks):
|
||||
res = []
|
||||
for tk in tks.split(" "):
|
||||
if len(tk) < 3 or re.match(r"[0-9,\.-]+$", tk):
|
||||
res.append(tk)
|
||||
continue
|
||||
tkslist = []
|
||||
if len(tk) > 10:
|
||||
tkslist.append(tk)
|
||||
else:
|
||||
self.dfs_(tk, 0, [], tkslist)
|
||||
if len(tkslist) < 2:
|
||||
res.append(tk)
|
||||
continue
|
||||
stk = self.sortTks_(tkslist)[1][0]
|
||||
if len(stk) == len(tk):
|
||||
stk = tk
|
||||
else:
|
||||
if re.match(r"[a-z\.-]+$", tk):
|
||||
for t in stk:
|
||||
if len(t) < 3:
|
||||
stk = tk
|
||||
break
|
||||
else:
|
||||
stk = " ".join(stk)
|
||||
else:
|
||||
stk = " ".join(stk)
|
||||
|
||||
res.append(stk)
|
||||
|
||||
return " ".join(res)
|
||||
|
||||
|
||||
def is_chinese(s):
|
||||
if s >= u'\u4e00' and s <= u'\u9fa5':
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def is_number(s):
|
||||
if s >= u'\u0030' and s <= u'\u0039':
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def is_alphabet(s):
|
||||
if (s >= u'\u0041' and s <= u'\u005a') or (
|
||||
s >= u'\u0061' and s <= u'\u007a'):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def naiveQie(txt):
|
||||
tks = []
|
||||
for t in txt.split(" "):
|
||||
if tks and re.match(r".*[a-zA-Z]$", tks[-1]
|
||||
) and re.match(r".*[a-zA-Z]$", t):
|
||||
tks.append(" ")
|
||||
tks.append(t)
|
||||
return tks
|
||||
|
||||
|
||||
hq = Huqie()
|
||||
qie = hq.qie
|
||||
qieqie = hq.qieqie
|
||||
tag = hq.tag
|
||||
freq = hq.freq
|
||||
loadUserDict = hq.loadUserDict
|
||||
addUserDict = hq.addUserDict
|
||||
tradi2simp = hq._tradi2simp
|
||||
strQ2B = hq._strQ2B
|
||||
|
||||
if __name__ == '__main__':
|
||||
huqie = Huqie(debug=True)
|
||||
# huqie.addUserDict("/tmp/tmp.new.tks.dict")
|
||||
tks = huqie.qie(
|
||||
"哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈")
|
||||
print(huqie.qieqie(tks))
|
||||
tks = huqie.qie(
|
||||
"公开征求意见稿提出,境外投资者可使用自有人民币或外汇投资。使用外汇投资的,可通过债券持有人在香港人民币业务清算行及香港地区经批准可进入境内银行间外汇市场进行交易的境外人民币业务参加行(以下统称香港结算行)办理外汇资金兑换。香港结算行由此所产生的头寸可到境内银行间外汇市场平盘。使用外汇投资的,在其投资的债券到期或卖出后,原则上应兑换回外汇。")
|
||||
print(huqie.qieqie(tks))
|
||||
tks = huqie.qie(
|
||||
"多校划片就是一个小区对应多个小学初中,让买了学区房的家庭也不确定到底能上哪个学校。目的是通过这种方式为学区房降温,把就近入学落到实处。南京市长江大桥")
|
||||
print(huqie.qieqie(tks))
|
||||
tks = huqie.qie(
|
||||
"实际上当时他们已经将业务中心偏移到安全部门和针对政府企业的部门 Scripts are compiled and cached aaaaaaaaa")
|
||||
print(huqie.qieqie(tks))
|
||||
tks = huqie.qie("虽然我不怎么玩")
|
||||
print(huqie.qieqie(tks))
|
||||
tks = huqie.qie("蓝月亮如何在外资夹击中生存,那是全宇宙最有意思的")
|
||||
print(huqie.qieqie(tks))
|
||||
tks = huqie.qie(
|
||||
"涡轮增压发动机num最大功率,不像别的共享买车锁电子化的手段,我们接过来是否有意义,黄黄爱美食,不过,今天阿奇要讲到的这家农贸市场,说实话,还真蛮有特色的!不仅环境好,还打出了")
|
||||
print(huqie.qieqie(tks))
|
||||
tks = huqie.qie("这周日你去吗?这周日你有空吗?")
|
||||
print(huqie.qieqie(tks))
|
||||
tks = huqie.qie("Unity3D开发经验 测试开发工程师 c++双11双11 985 211 ")
|
||||
print(huqie.qieqie(tks))
|
||||
tks = huqie.qie(
|
||||
"数据分析项目经理|数据分析挖掘|数据分析方向|商品数据分析|搜索数据分析 sql python hive tableau Cocos2d-")
|
||||
print(huqie.qieqie(tks))
|
||||
if len(sys.argv) < 2:
|
||||
sys.exit()
|
||||
huqie.DEBUG = False
|
||||
huqie.loadUserDict(sys.argv[1])
|
||||
of = open(sys.argv[2], "r")
|
||||
while True:
|
||||
line = of.readline()
|
||||
if not line:
|
||||
break
|
||||
print(huqie.qie(line))
|
||||
of.close()
|
||||
167
rag/nlp/query.py
Normal file
167
rag/nlp/query.py
Normal file
@ -0,0 +1,167 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import json
|
||||
import re
|
||||
import logging
|
||||
import copy
|
||||
import math
|
||||
from elasticsearch_dsl import Q, Search
|
||||
from rag.nlp import huqie, term_weight, synonym
|
||||
|
||||
|
||||
class EsQueryer:
|
||||
def __init__(self, es):
|
||||
self.tw = term_weight.Dealer()
|
||||
self.es = es
|
||||
self.syn = synonym.Dealer(None)
|
||||
self.flds = ["ask_tks^10", "ask_small_tks"]
|
||||
|
||||
@staticmethod
|
||||
def subSpecialChar(line):
|
||||
return re.sub(r"([:\{\}/\[\]\-\*\"\(\)\|~\^])", r"\\\1", line).strip()
|
||||
|
||||
@staticmethod
|
||||
def isChinese(line):
|
||||
arr = re.split(r"[ \t]+", line)
|
||||
if len(arr) <= 3:
|
||||
return True
|
||||
e = 0
|
||||
for t in arr:
|
||||
if not re.match(r"[a-zA-Z]+$", t):
|
||||
e += 1
|
||||
return e * 1. / len(arr) >= 0.8
|
||||
|
||||
@staticmethod
|
||||
def rmWWW(txt):
|
||||
txt = re.sub(
|
||||
r"是*(什么样的|哪家|那家|啥样|咋样了|什么时候|何时|何地|何人|是否|是不是|多少|哪里|怎么|哪儿|怎么样|如何|哪些|是啥|啥是|啊|吗|呢|吧|咋|什么|有没有|呀)是*",
|
||||
"",
|
||||
txt)
|
||||
return re.sub(
|
||||
r"(what|who|how|which|where|why|(is|are|were|was) there) (is|are|were|was)*", "", txt, re.IGNORECASE)
|
||||
|
||||
def question(self, txt, tbl="qa", min_match="60%"):
|
||||
txt = re.sub(
|
||||
r"[ \t,,。??/`!!&]+",
|
||||
" ",
|
||||
huqie.tradi2simp(
|
||||
huqie.strQ2B(
|
||||
txt.lower()))).strip()
|
||||
txt = EsQueryer.rmWWW(txt)
|
||||
|
||||
if not self.isChinese(txt):
|
||||
tks = txt.split(" ")
|
||||
q = []
|
||||
for i in range(1, len(tks)):
|
||||
q.append("\"%s %s\"~2" % (tks[i - 1], tks[i]))
|
||||
if not q:
|
||||
q.append(txt)
|
||||
return Q("bool",
|
||||
must=Q("query_string", fields=self.flds,
|
||||
type="best_fields", query=" OR ".join(q),
|
||||
boost=1, minimum_should_match="60%")
|
||||
), txt.split(" ")
|
||||
|
||||
def needQieqie(tk):
|
||||
if len(tk) < 4:
|
||||
return False
|
||||
if re.match(r"[0-9a-z\.\+#_\*-]+$", tk):
|
||||
return False
|
||||
return True
|
||||
|
||||
qs, keywords = [], []
|
||||
for tt in self.tw.split(txt): # .split(" "):
|
||||
if not tt:
|
||||
continue
|
||||
twts = self.tw.weights([tt])
|
||||
syns = self.syn.lookup(tt)
|
||||
logging.info(json.dumps(twts, ensure_ascii=False))
|
||||
tms = []
|
||||
for tk, w in sorted(twts, key=lambda x: x[1] * -1):
|
||||
sm = huqie.qieqie(tk).split(" ") if needQieqie(tk) else []
|
||||
sm = [
|
||||
re.sub(
|
||||
r"[ ,\./;'\[\]\\`~!@#$%\^&\*\(\)=\+_<>\?:\"\{\}\|,。;‘’【】、!¥……()——《》?:“”-]+",
|
||||
"",
|
||||
m) for m in sm]
|
||||
sm = [EsQueryer.subSpecialChar(m) for m in sm if len(m) > 1]
|
||||
sm = [m for m in sm if len(m) > 1]
|
||||
if len(sm) < 2:
|
||||
sm = []
|
||||
|
||||
keywords.append(re.sub(r"[ \\\"']+", "", tk))
|
||||
|
||||
tk_syns = self.syn.lookup(tk)
|
||||
tk = EsQueryer.subSpecialChar(tk)
|
||||
if tk.find(" ") > 0:
|
||||
tk = "\"%s\"" % tk
|
||||
if tk_syns:
|
||||
tk = f"({tk} %s)" % " ".join(tk_syns)
|
||||
if sm:
|
||||
tk = f"{tk} OR \"%s\" OR (\"%s\"~2)^0.5" % (
|
||||
" ".join(sm), " ".join(sm))
|
||||
tms.append((tk, w))
|
||||
|
||||
tms = " ".join([f"({t})^{w}" for t, w in tms])
|
||||
|
||||
if len(twts) > 1:
|
||||
tms += f" (\"%s\"~4)^1.5" % (" ".join([t for t, _ in twts]))
|
||||
if re.match(r"[0-9a-z ]+$", tt):
|
||||
tms = f"(\"{tt}\" OR \"%s\")" % huqie.qie(tt)
|
||||
|
||||
syns = " OR ".join(
|
||||
["\"%s\"^0.7" % EsQueryer.subSpecialChar(huqie.qie(s)) for s in syns])
|
||||
if syns:
|
||||
tms = f"({tms})^5 OR ({syns})^0.7"
|
||||
|
||||
qs.append(tms)
|
||||
|
||||
flds = copy.deepcopy(self.flds)
|
||||
mst = []
|
||||
if qs:
|
||||
mst.append(
|
||||
Q("query_string", fields=flds, type="best_fields",
|
||||
query=" OR ".join([f"({t})" for t in qs if t]), boost=1, minimum_should_match=min_match)
|
||||
)
|
||||
|
||||
return Q("bool",
|
||||
must=mst,
|
||||
), keywords
|
||||
|
||||
def hybrid_similarity(self, avec, bvecs, atks, btkss, tkweight=0.3,
|
||||
vtweight=0.7):
|
||||
from sklearn.metrics.pairwise import cosine_similarity as CosineSimilarity
|
||||
import numpy as np
|
||||
sims = CosineSimilarity([avec], bvecs)
|
||||
|
||||
def toDict(tks):
|
||||
d = {}
|
||||
if isinstance(tks, type("")):
|
||||
tks = tks.split(" ")
|
||||
for t, c in self.tw.weights(tks):
|
||||
if t not in d:
|
||||
d[t] = 0
|
||||
d[t] += c
|
||||
return d
|
||||
|
||||
atks = toDict(atks)
|
||||
btkss = [toDict(tks) for tks in btkss]
|
||||
tksim = [self.similarity(atks, btks) for btks in btkss]
|
||||
return np.array(sims[0]) * vtweight + np.array(tksim) * tkweight
|
||||
|
||||
def similarity(self, qtwt, dtwt):
|
||||
if isinstance(dtwt, type("")):
|
||||
dtwt = {t: w for t, w in self.tw.weights(self.tw.split(dtwt))}
|
||||
if isinstance(qtwt, type("")):
|
||||
qtwt = {t: w for t, w in self.tw.weights(self.tw.split(qtwt))}
|
||||
s = 1e-9
|
||||
for k, v in qtwt.items():
|
||||
if k in dtwt:
|
||||
s += v * dtwt[k]
|
||||
q = 1e-9
|
||||
for k, v in qtwt.items():
|
||||
q += v * v
|
||||
d = 1e-9
|
||||
for k, v in dtwt.items():
|
||||
d += v * v
|
||||
return s / math.sqrt(q) / math.sqrt(d)
|
||||
250
rag/nlp/search.py
Normal file
250
rag/nlp/search.py
Normal file
@ -0,0 +1,250 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import re
|
||||
from elasticsearch_dsl import Q, Search, A
|
||||
from typing import List, Optional, Tuple, Dict, Union
|
||||
from dataclasses import dataclass
|
||||
from rag.utils import rmSpace
|
||||
from rag.nlp import huqie, query
|
||||
import numpy as np
|
||||
|
||||
|
||||
def index_name(uid): return f"docgpt_{uid}"
|
||||
|
||||
|
||||
class Dealer:
|
||||
def __init__(self, es, emb_mdl):
|
||||
self.qryr = query.EsQueryer(es)
|
||||
self.qryr.flds = [
|
||||
"title_tks^10",
|
||||
"title_sm_tks^5",
|
||||
"content_ltks^2",
|
||||
"content_sm_ltks"]
|
||||
self.es = es
|
||||
self.emb_mdl = emb_mdl
|
||||
|
||||
@dataclass
|
||||
class SearchResult:
|
||||
total: int
|
||||
ids: List[str]
|
||||
query_vector: List[float] = None
|
||||
field: Optional[Dict] = None
|
||||
highlight: Optional[Dict] = None
|
||||
aggregation: Union[List, Dict, None] = None
|
||||
keywords: Optional[List[str]] = None
|
||||
group_docs: List[List] = None
|
||||
|
||||
def _vector(self, txt, sim=0.8, topk=10):
|
||||
return {
|
||||
"field": "q_vec",
|
||||
"k": topk,
|
||||
"similarity": sim,
|
||||
"num_candidates": 1000,
|
||||
"query_vector": self.emb_mdl.encode_queries(txt)
|
||||
}
|
||||
|
||||
def search(self, req, idxnm, tks_num=3):
|
||||
keywords = []
|
||||
qst = req.get("question", "")
|
||||
|
||||
bqry, keywords = self.qryr.question(qst)
|
||||
if req.get("kb_ids"):
|
||||
bqry.filter.append(Q("terms", kb_id=req["kb_ids"]))
|
||||
bqry.filter.append(Q("exists", field="q_tks"))
|
||||
bqry.boost = 0.05
|
||||
print(bqry)
|
||||
|
||||
s = Search()
|
||||
pg = int(req.get("page", 1)) - 1
|
||||
ps = int(req.get("size", 1000))
|
||||
src = req.get("field", ["docnm_kwd", "content_ltks", "kb_id",
|
||||
"image_id", "doc_id", "q_vec"])
|
||||
|
||||
s = s.query(bqry)[pg * ps:(pg + 1) * ps]
|
||||
s = s.highlight("content_ltks")
|
||||
s = s.highlight("title_ltks")
|
||||
if not qst:
|
||||
s = s.sort(
|
||||
{"create_time": {"order": "desc", "unmapped_type": "date"}})
|
||||
|
||||
s = s.highlight_options(
|
||||
fragment_size=120,
|
||||
number_of_fragments=5,
|
||||
boundary_scanner_locale="zh-CN",
|
||||
boundary_scanner="SENTENCE",
|
||||
boundary_chars=",./;:\\!(),。?:!……()——、"
|
||||
)
|
||||
s = s.to_dict()
|
||||
q_vec = []
|
||||
if req.get("vector"):
|
||||
s["knn"] = self._vector(qst, req.get("similarity", 0.4), ps)
|
||||
s["knn"]["filter"] = bqry.to_dict()
|
||||
del s["highlight"]
|
||||
q_vec = s["knn"]["query_vector"]
|
||||
res = self.es.search(s, idxnm=idxnm, timeout="600s", src=src)
|
||||
print("TOTAL: ", self.es.getTotal(res))
|
||||
if self.es.getTotal(res) == 0 and "knn" in s:
|
||||
bqry, _ = self.qryr.question(qst, min_match="10%")
|
||||
if req.get("kb_ids"):
|
||||
bqry.filter.append(Q("terms", kb_id=req["kb_ids"]))
|
||||
s["query"] = bqry.to_dict()
|
||||
s["knn"]["filter"] = bqry.to_dict()
|
||||
s["knn"]["similarity"] = 0.7
|
||||
res = self.es.search(s, idxnm=idxnm, timeout="600s", src=src)
|
||||
|
||||
kwds = set([])
|
||||
for k in keywords:
|
||||
kwds.add(k)
|
||||
for kk in huqie.qieqie(k).split(" "):
|
||||
if len(kk) < 2:
|
||||
continue
|
||||
if kk in kwds:
|
||||
continue
|
||||
kwds.add(kk)
|
||||
|
||||
aggs = self.getAggregation(res, "docnm_kwd")
|
||||
|
||||
return self.SearchResult(
|
||||
total=self.es.getTotal(res),
|
||||
ids=self.es.getDocIds(res),
|
||||
query_vector=q_vec,
|
||||
aggregation=aggs,
|
||||
highlight=self.getHighlight(res),
|
||||
field=self.getFields(res, ["docnm_kwd", "content_ltks",
|
||||
"kb_id", "image_id", "doc_id", "q_vec"]),
|
||||
keywords=list(kwds)
|
||||
)
|
||||
|
||||
def getAggregation(self, res, g):
|
||||
if not "aggregations" in res or "aggs_" + g not in res["aggregations"]:
|
||||
return
|
||||
bkts = res["aggregations"]["aggs_" + g]["buckets"]
|
||||
return [(b["key"], b["doc_count"]) for b in bkts]
|
||||
|
||||
def getHighlight(self, res):
|
||||
def rmspace(line):
|
||||
eng = set(list("qwertyuioplkjhgfdsazxcvbnm"))
|
||||
r = []
|
||||
for t in line.split(" "):
|
||||
if not t:
|
||||
continue
|
||||
if len(r) > 0 and len(
|
||||
t) > 0 and r[-1][-1] in eng and t[0] in eng:
|
||||
r.append(" ")
|
||||
r.append(t)
|
||||
r = "".join(r)
|
||||
return r
|
||||
|
||||
ans = {}
|
||||
for d in res["hits"]["hits"]:
|
||||
hlts = d.get("highlight")
|
||||
if not hlts:
|
||||
continue
|
||||
ans[d["_id"]] = "".join([a for a in list(hlts.items())[0][1]])
|
||||
return ans
|
||||
|
||||
def getFields(self, sres, flds):
|
||||
res = {}
|
||||
if not flds:
|
||||
return {}
|
||||
for d in self.es.getSource(sres):
|
||||
m = {n: d.get(n) for n in flds if d.get(n) is not None}
|
||||
for n, v in m.items():
|
||||
if isinstance(v, type([])):
|
||||
m[n] = "\t".join([str(vv) for vv in v])
|
||||
continue
|
||||
if not isinstance(v, type("")):
|
||||
m[n] = str(m[n])
|
||||
m[n] = rmSpace(m[n])
|
||||
|
||||
if m:
|
||||
res[d["id"]] = m
|
||||
return res
|
||||
|
||||
@staticmethod
|
||||
def trans2floats(txt):
|
||||
return [float(t) for t in txt.split("\t")]
|
||||
|
||||
def insert_citations(self, ans, top_idx, sres,
|
||||
vfield="q_vec", cfield="content_ltks"):
|
||||
|
||||
ins_embd = [Dealer.trans2floats(
|
||||
sres.field[sres.ids[i]][vfield]) for i in top_idx]
|
||||
ins_tw = [sres.field[sres.ids[i]][cfield].split(" ") for i in top_idx]
|
||||
s = 0
|
||||
e = 0
|
||||
res = ""
|
||||
|
||||
def citeit():
|
||||
nonlocal s, e, ans, res
|
||||
if not ins_embd:
|
||||
return
|
||||
embd = self.emb_mdl.encode(ans[s: e])
|
||||
sim = self.qryr.hybrid_similarity(embd,
|
||||
ins_embd,
|
||||
huqie.qie(ans[s:e]).split(" "),
|
||||
ins_tw)
|
||||
print(ans[s: e], sim)
|
||||
mx = np.max(sim) * 0.99
|
||||
if mx < 0.55:
|
||||
return
|
||||
cita = list(set([top_idx[i]
|
||||
for i in range(len(ins_embd)) if sim[i] > mx]))[:4]
|
||||
for i in cita:
|
||||
res += f"@?{i}?@"
|
||||
|
||||
return cita
|
||||
|
||||
punct = set(";。?!!")
|
||||
if not self.qryr.isChinese(ans):
|
||||
punct.add("?")
|
||||
punct.add(".")
|
||||
while e < len(ans):
|
||||
if e - s < 12 or ans[e] not in punct:
|
||||
e += 1
|
||||
continue
|
||||
if ans[e] == "." and e + \
|
||||
1 < len(ans) and re.match(r"[0-9]", ans[e + 1]):
|
||||
e += 1
|
||||
continue
|
||||
if ans[e] == "." and e - 2 >= 0 and ans[e - 2] == "\n":
|
||||
e += 1
|
||||
continue
|
||||
res += ans[s: e]
|
||||
citeit()
|
||||
res += ans[e]
|
||||
e += 1
|
||||
s = e
|
||||
|
||||
if s < len(ans):
|
||||
res += ans[s:]
|
||||
citeit()
|
||||
|
||||
return res
|
||||
|
||||
def rerank(self, sres, query, tkweight=0.3, vtweight=0.7,
|
||||
vfield="q_vec", cfield="content_ltks"):
|
||||
ins_embd = [
|
||||
Dealer.trans2floats(
|
||||
sres.field[i]["q_vec"]) for i in sres.ids]
|
||||
if not ins_embd:
|
||||
return []
|
||||
ins_tw = [sres.field[i][cfield].split(" ") for i in sres.ids]
|
||||
# return CosineSimilarity([sres.query_vector], ins_embd)[0]
|
||||
sim = self.qryr.hybrid_similarity(sres.query_vector,
|
||||
ins_embd,
|
||||
huqie.qie(query).split(" "),
|
||||
ins_tw, tkweight, vtweight)
|
||||
return sim
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from util import es_conn
|
||||
SE = Dealer(es_conn.HuEs("infiniflow"))
|
||||
qs = [
|
||||
"胡凯",
|
||||
""
|
||||
]
|
||||
for q in qs:
|
||||
print(">>>>>>>>>>>>>>>>>>>>", q)
|
||||
print(SE.search(
|
||||
{"question": q, "kb_ids": "64f072a75f3b97c865718c4a"}, "infiniflow_*"))
|
||||
64
rag/nlp/synonym.py
Normal file
64
rag/nlp/synonym.py
Normal file
@ -0,0 +1,64 @@
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import logging
|
||||
import re
|
||||
|
||||
from web_server.utils.file_utils import get_project_base_directory
|
||||
|
||||
|
||||
class Dealer:
|
||||
def __init__(self, redis=None):
|
||||
|
||||
self.lookup_num = 100000000
|
||||
self.load_tm = time.time() - 1000000
|
||||
self.dictionary = None
|
||||
path = os.path.join(get_project_base_directory(), "rag/res", "synonym.json")
|
||||
try:
|
||||
self.dictionary = json.load(open(path, 'r'))
|
||||
except Exception as e:
|
||||
logging.warn("Miss synonym.json")
|
||||
self.dictionary = {}
|
||||
|
||||
if not redis:
|
||||
logging.warning(
|
||||
"Realtime synonym is disabled, since no redis connection.")
|
||||
if not len(self.dictionary.keys()):
|
||||
logging.warning(f"Fail to load synonym")
|
||||
|
||||
self.redis = redis
|
||||
self.load()
|
||||
|
||||
def load(self):
|
||||
if not self.redis:
|
||||
return
|
||||
|
||||
if self.lookup_num < 100:
|
||||
return
|
||||
tm = time.time()
|
||||
if tm - self.load_tm < 3600:
|
||||
return
|
||||
|
||||
self.load_tm = time.time()
|
||||
self.lookup_num = 0
|
||||
d = self.redis.get("kevin_synonyms")
|
||||
if not d:
|
||||
return
|
||||
try:
|
||||
d = json.loads(d)
|
||||
self.dictionary = d
|
||||
except Exception as e:
|
||||
logging.error("Fail to load synonym!" + str(e))
|
||||
|
||||
def lookup(self, tk):
|
||||
self.lookup_num += 1
|
||||
self.load()
|
||||
res = self.dictionary.get(re.sub(r"[ \t]+", " ", tk.lower()), [])
|
||||
if isinstance(res, str):
|
||||
res = [res]
|
||||
return res
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
dl = Dealer()
|
||||
print(dl.dictionary)
|
||||
216
rag/nlp/term_weight.py
Normal file
216
rag/nlp/term_weight.py
Normal file
@ -0,0 +1,216 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import math
|
||||
import json
|
||||
import re
|
||||
import os
|
||||
import numpy as np
|
||||
from rag.nlp import huqie
|
||||
from web_server.utils.file_utils import get_project_base_directory
|
||||
|
||||
|
||||
class Dealer:
|
||||
def __init__(self):
|
||||
self.stop_words = set(["请问",
|
||||
"您",
|
||||
"你",
|
||||
"我",
|
||||
"他",
|
||||
"是",
|
||||
"的",
|
||||
"就",
|
||||
"有",
|
||||
"于",
|
||||
"及",
|
||||
"即",
|
||||
"在",
|
||||
"为",
|
||||
"最",
|
||||
"有",
|
||||
"从",
|
||||
"以",
|
||||
"了",
|
||||
"将",
|
||||
"与",
|
||||
"吗",
|
||||
"吧",
|
||||
"中",
|
||||
"#",
|
||||
"什么",
|
||||
"怎么",
|
||||
"哪个",
|
||||
"哪些",
|
||||
"啥",
|
||||
"相关"])
|
||||
|
||||
def load_dict(fnm):
|
||||
res = {}
|
||||
f = open(fnm, "r")
|
||||
while True:
|
||||
l = f.readline()
|
||||
if not l:
|
||||
break
|
||||
arr = l.replace("\n", "").split("\t")
|
||||
if len(arr) < 2:
|
||||
res[arr[0]] = 0
|
||||
else:
|
||||
res[arr[0]] = int(arr[1])
|
||||
|
||||
c = 0
|
||||
for _, v in res.items():
|
||||
c += v
|
||||
if c == 0:
|
||||
return set(res.keys())
|
||||
return res
|
||||
|
||||
fnm = os.path.join(get_project_base_directory(), "res")
|
||||
self.ne, self.df = {}, {}
|
||||
try:
|
||||
self.ne = json.load(open(os.path.join(fnm, "ner.json"), "r"))
|
||||
except Exception as e:
|
||||
print("[WARNING] Load ner.json FAIL!")
|
||||
try:
|
||||
self.df = load_dict(os.path.join(fnm, "term.freq"))
|
||||
except Exception as e:
|
||||
print("[WARNING] Load term.freq FAIL!")
|
||||
|
||||
def pretoken(self, txt, num=False, stpwd=True):
|
||||
patt = [
|
||||
r"[~—\t @#%!<>,\.\?\":;'\{\}\[\]_=\(\)\|,。?》•●○↓《;‘’:“”【¥ 】…¥!、·()×`&\\/「」\\]"
|
||||
]
|
||||
rewt = [
|
||||
]
|
||||
for p, r in rewt:
|
||||
txt = re.sub(p, r, txt)
|
||||
|
||||
res = []
|
||||
for t in huqie.qie(txt).split(" "):
|
||||
tk = t
|
||||
if (stpwd and tk in self.stop_words) or (
|
||||
re.match(r"[0-9]$", tk) and not num):
|
||||
continue
|
||||
for p in patt:
|
||||
if re.match(p, t):
|
||||
tk = "#"
|
||||
break
|
||||
tk = re.sub(r"([\+\\-])", r"\\\1", tk)
|
||||
if tk != "#" and tk:
|
||||
res.append(tk)
|
||||
return res
|
||||
|
||||
def tokenMerge(self, tks):
|
||||
def oneTerm(t): return len(t) == 1 or re.match(r"[0-9a-z]{1,2}$", t)
|
||||
|
||||
res, i = [], 0
|
||||
while i < len(tks):
|
||||
j = i
|
||||
if i == 0 and oneTerm(tks[i]) and len(
|
||||
tks) > 1 and len(tks[i + 1]) > 1: # 多 工位
|
||||
res.append(" ".join(tks[0:2]))
|
||||
i = 2
|
||||
continue
|
||||
|
||||
while j < len(
|
||||
tks) and tks[j] and tks[j] not in self.stop_words and oneTerm(tks[j]):
|
||||
j += 1
|
||||
if j - i > 1:
|
||||
if j - i < 5:
|
||||
res.append(" ".join(tks[i:j]))
|
||||
i = j
|
||||
else:
|
||||
res.append(" ".join(tks[i:i + 2]))
|
||||
i = i + 2
|
||||
else:
|
||||
if len(tks[i]) > 0:
|
||||
res.append(tks[i])
|
||||
i += 1
|
||||
return [t for t in res if t]
|
||||
|
||||
def ner(self, t):
|
||||
if not self.ne:
|
||||
return ""
|
||||
res = self.ne.get(t, "")
|
||||
if res:
|
||||
return res
|
||||
|
||||
def split(self, txt):
|
||||
tks = []
|
||||
for t in re.sub(r"[ \t]+", " ", txt).split(" "):
|
||||
if tks and re.match(r".*[a-zA-Z]$", tks[-1]) and \
|
||||
re.match(r".*[a-zA-Z]$", t) and tks and \
|
||||
self.ne.get(t, "") != "func" and self.ne.get(tks[-1], "") != "func":
|
||||
tks[-1] = tks[-1] + " " + t
|
||||
else:
|
||||
tks.append(t)
|
||||
return tks
|
||||
|
||||
def weights(self, tks):
|
||||
def skill(t):
|
||||
if t not in self.sk:
|
||||
return 1
|
||||
return 6
|
||||
|
||||
def ner(t):
|
||||
if not self.ne or t not in self.ne:
|
||||
return 1
|
||||
m = {"toxic": 2, "func": 1, "corp": 3, "loca": 3, "sch": 3, "stock": 3,
|
||||
"firstnm": 1}
|
||||
return m[self.ne[t]]
|
||||
|
||||
def postag(t):
|
||||
t = huqie.tag(t)
|
||||
if t in set(["r", "c", "d"]):
|
||||
return 0.3
|
||||
if t in set(["ns", "nt"]):
|
||||
return 3
|
||||
if t in set(["n"]):
|
||||
return 2
|
||||
if re.match(r"[0-9-]+", t):
|
||||
return 2
|
||||
return 1
|
||||
|
||||
def freq(t):
|
||||
if re.match(r"[0-9\. -]+$", t):
|
||||
return 10000
|
||||
s = huqie.freq(t)
|
||||
if not s and re.match(r"[a-z\. -]+$", t):
|
||||
return 10
|
||||
if not s:
|
||||
s = 0
|
||||
|
||||
if not s and len(t) >= 4:
|
||||
s = [tt for tt in huqie.qieqie(t).split(" ") if len(tt) > 1]
|
||||
if len(s) > 1:
|
||||
s = np.min([freq(tt) for tt in s]) / 6.
|
||||
else:
|
||||
s = 0
|
||||
|
||||
return max(s, 10)
|
||||
|
||||
def df(t):
|
||||
if re.match(r"[0-9\. -]+$", t):
|
||||
return 100000
|
||||
if t in self.df:
|
||||
return self.df[t] + 3
|
||||
elif re.match(r"[a-z\. -]+$", t):
|
||||
return 3
|
||||
elif len(t) >= 4:
|
||||
s = [tt for tt in huqie.qieqie(t).split(" ") if len(tt) > 1]
|
||||
if len(s) > 1:
|
||||
return max(3, np.min([df(tt) for tt in s]) / 6.)
|
||||
|
||||
return 3
|
||||
|
||||
def idf(s, N): return math.log10(10 + ((N - s + 0.5) / (s + 0.5)))
|
||||
|
||||
tw = []
|
||||
for tk in tks:
|
||||
tt = self.tokenMerge(self.pretoken(tk, True))
|
||||
idf1 = np.array([idf(freq(t), 10000000) for t in tt])
|
||||
idf2 = np.array([idf(df(t), 1000000000) for t in tt])
|
||||
wts = (0.3 * idf1 + 0.7 * idf2) * \
|
||||
np.array([ner(t) * postag(t) for t in tt])
|
||||
|
||||
tw.extend(zip(tt, wts))
|
||||
|
||||
S = np.sum([s for _, s in tw])
|
||||
return [(t, s / S) for t, s in tw]
|
||||
3
rag/parser/__init__.py
Normal file
3
rag/parser/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
from .pdf_parser import HuParser as PdfParser
|
||||
from .docx_parser import HuDocxParser as DocxParser
|
||||
from .excel_parser import HuExcelParser as ExcelParser
|
||||
105
rag/parser/docx_parser.py
Normal file
105
rag/parser/docx_parser.py
Normal file
@ -0,0 +1,105 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from docx import Document
|
||||
import re
|
||||
import pandas as pd
|
||||
from collections import Counter
|
||||
from rag.nlp import huqie
|
||||
from io import BytesIO
|
||||
|
||||
|
||||
class HuDocxParser:
|
||||
|
||||
def __extract_table_content(self, tb):
|
||||
df = []
|
||||
for row in tb.rows:
|
||||
df.append([c.text for c in row.cells])
|
||||
return self.__compose_table_content(pd.DataFrame(df))
|
||||
|
||||
def __compose_table_content(self, df):
|
||||
|
||||
def blockType(b):
|
||||
patt = [
|
||||
("^(20|19)[0-9]{2}[年/-][0-9]{1,2}[月/-][0-9]{1,2}日*$", "Dt"),
|
||||
(r"^(20|19)[0-9]{2}年$", "Dt"),
|
||||
(r"^(20|19)[0-9]{2}[年/-][0-9]{1,2}月*$", "Dt"),
|
||||
("^[0-9]{1,2}[月/-][0-9]{1,2}日*$", "Dt"),
|
||||
(r"^第*[一二三四1-4]季度$", "Dt"),
|
||||
(r"^(20|19)[0-9]{2}年*[一二三四1-4]季度$", "Dt"),
|
||||
(r"^(20|19)[0-9]{2}[ABCDE]$", "DT"),
|
||||
("^[0-9.,+%/ -]+$", "Nu"),
|
||||
(r"^[0-9A-Z/\._~-]+$", "Ca"),
|
||||
(r"^[A-Z]*[a-z' -]+$", "En"),
|
||||
(r"^[0-9.,+-]+[0-9A-Za-z/$¥%<>()()' -]+$", "NE"),
|
||||
(r"^.{1}$", "Sg")
|
||||
]
|
||||
for p, n in patt:
|
||||
if re.search(p, b):
|
||||
return n
|
||||
tks = [t for t in huqie.qie(b).split(" ") if len(t) > 1]
|
||||
if len(tks) > 3:
|
||||
if len(tks) < 12:
|
||||
return "Tx"
|
||||
else:
|
||||
return "Lx"
|
||||
|
||||
if len(tks) == 1 and huqie.tag(tks[0]) == "nr":
|
||||
return "Nr"
|
||||
|
||||
return "Ot"
|
||||
|
||||
if len(df) < 2:
|
||||
return []
|
||||
max_type = Counter([blockType(str(df.iloc[i, j])) for i in range(
|
||||
1, len(df)) for j in range(len(df.iloc[i, :]))])
|
||||
max_type = max(max_type.items(), key=lambda x: x[1])[0]
|
||||
|
||||
colnm = len(df.iloc[0, :])
|
||||
hdrows = [0] # header is not nessesarily appear in the first line
|
||||
if max_type == "Nu":
|
||||
for r in range(1, len(df)):
|
||||
tys = Counter([blockType(str(df.iloc[r, j]))
|
||||
for j in range(len(df.iloc[r, :]))])
|
||||
tys = max(tys.items(), key=lambda x: x[1])[0]
|
||||
if tys != max_type:
|
||||
hdrows.append(r)
|
||||
|
||||
lines = []
|
||||
for i in range(1, len(df)):
|
||||
if i in hdrows:
|
||||
continue
|
||||
hr = [r - i for r in hdrows]
|
||||
hr = [r for r in hr if r < 0]
|
||||
t = len(hr) - 1
|
||||
while t > 0:
|
||||
if hr[t] - hr[t - 1] > 1:
|
||||
hr = hr[t:]
|
||||
break
|
||||
t -= 1
|
||||
headers = []
|
||||
for j in range(len(df.iloc[i, :])):
|
||||
t = []
|
||||
for h in hr:
|
||||
x = str(df.iloc[i + h, j]).strip()
|
||||
if x in t:
|
||||
continue
|
||||
t.append(x)
|
||||
t = ",".join(t)
|
||||
if t:
|
||||
t += ": "
|
||||
headers.append(t)
|
||||
cells = []
|
||||
for j in range(len(df.iloc[i, :])):
|
||||
if not str(df.iloc[i, j]):
|
||||
continue
|
||||
cells.append(headers[j] + str(df.iloc[i, j]))
|
||||
lines.append(";".join(cells))
|
||||
|
||||
if colnm > 3:
|
||||
return lines
|
||||
return ["\n".join(lines)]
|
||||
|
||||
def __call__(self, fnm):
|
||||
self.doc = Document(fnm) if isinstance(fnm, str) else Document(BytesIO(fnm))
|
||||
secs = [(p.text, p.style.name) for p in self.doc.paragraphs]
|
||||
tbls = [self.__extract_table_content(tb) for tb in self.doc.tables]
|
||||
return secs, tbls
|
||||
33
rag/parser/excel_parser.py
Normal file
33
rag/parser/excel_parser.py
Normal file
@ -0,0 +1,33 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from openpyxl import load_workbook
|
||||
import sys
|
||||
from io import BytesIO
|
||||
|
||||
|
||||
class HuExcelParser:
|
||||
def __call__(self, fnm):
|
||||
if isinstance(fnm, str):
|
||||
wb = load_workbook(fnm)
|
||||
else:
|
||||
wb = load_workbook(BytesIO(fnm))
|
||||
res = []
|
||||
for sheetname in wb.sheetnames:
|
||||
ws = wb[sheetname]
|
||||
rows = list(ws.rows)
|
||||
ti = list(rows[0])
|
||||
for r in list(rows[1:]):
|
||||
l = []
|
||||
for i,c in enumerate(r):
|
||||
if not c.value:continue
|
||||
t = str(ti[i].value) if i < len(ti) else ""
|
||||
t += (":" if t else "") + str(c.value)
|
||||
l.append(t)
|
||||
l = "; ".join(l)
|
||||
if sheetname.lower().find("sheet") <0: l += " ——"+sheetname
|
||||
res.append(l)
|
||||
return res
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
psr = HuExcelParser()
|
||||
psr(sys.argv[1])
|
||||
1638
rag/parser/pdf_parser.py
Normal file
1638
rag/parser/pdf_parser.py
Normal file
File diff suppressed because it is too large
Load Diff
555629
rag/res/huqie.txt
Normal file
555629
rag/res/huqie.txt
Normal file
File diff suppressed because it is too large
Load Diff
12519
rag/res/ner.json
Normal file
12519
rag/res/ner.json
Normal file
File diff suppressed because it is too large
Load Diff
10539
rag/res/synonym.json
Normal file
10539
rag/res/synonym.json
Normal file
File diff suppressed because it is too large
Load Diff
37
rag/settings.py
Normal file
37
rag/settings.py
Normal file
@ -0,0 +1,37 @@
|
||||
#
|
||||
# Copyright 2019 The FATE Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import os
|
||||
from web_server.utils import get_base_config,decrypt_database_config
|
||||
from web_server.utils.file_utils import get_project_base_directory
|
||||
from web_server.utils.log_utils import LoggerFactory, getLogger
|
||||
|
||||
|
||||
# Server
|
||||
RAG_CONF_PATH = os.path.join(get_project_base_directory(), "conf")
|
||||
SUBPROCESS_STD_LOG_NAME = "std.log"
|
||||
|
||||
ES = get_base_config("es", {})
|
||||
MINIO = decrypt_database_config(name="minio")
|
||||
DOC_MAXIMUM_SIZE = 64 * 1024 * 1024
|
||||
|
||||
# Logger
|
||||
LoggerFactory.set_directory(os.path.join(get_project_base_directory(), "logs", "rag"))
|
||||
# {CRITICAL: 50, FATAL:50, ERROR:40, WARNING:30, WARN:30, INFO:20, DEBUG:10, NOTSET:0}
|
||||
LoggerFactory.LEVEL = 10
|
||||
|
||||
es_logger = getLogger("es")
|
||||
minio_logger = getLogger("minio")
|
||||
cron_logger = getLogger("cron_logger")
|
||||
279
rag/svr/parse_user_docs.py
Normal file
279
rag/svr/parse_user_docs.py
Normal file
@ -0,0 +1,279 @@
|
||||
#
|
||||
# Copyright 2019 The FATE Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import json
|
||||
import os
|
||||
import hashlib
|
||||
import copy
|
||||
import time
|
||||
import random
|
||||
import re
|
||||
from timeit import default_timer as timer
|
||||
|
||||
from rag.llm import EmbeddingModel, CvModel
|
||||
from rag.settings import cron_logger, DOC_MAXIMUM_SIZE
|
||||
from rag.utils import ELASTICSEARCH, num_tokens_from_string
|
||||
from rag.utils import MINIO
|
||||
from rag.utils import rmSpace, findMaxDt
|
||||
from rag.nlp import huchunk, huqie, search
|
||||
from io import BytesIO
|
||||
import pandas as pd
|
||||
from elasticsearch_dsl import Q
|
||||
from PIL import Image
|
||||
from rag.parser import (
|
||||
PdfParser,
|
||||
DocxParser,
|
||||
ExcelParser
|
||||
)
|
||||
from rag.nlp.huchunk import (
|
||||
PdfChunker,
|
||||
DocxChunker,
|
||||
ExcelChunker,
|
||||
PptChunker,
|
||||
TextChunker
|
||||
)
|
||||
from web_server.db import LLMType
|
||||
from web_server.db.services.document_service import DocumentService
|
||||
from web_server.db.services.llm_service import TenantLLMService
|
||||
from web_server.utils import get_format_time
|
||||
from web_server.utils.file_utils import get_project_base_directory
|
||||
|
||||
BATCH_SIZE = 64
|
||||
|
||||
PDF = PdfChunker(PdfParser())
|
||||
DOC = DocxChunker(DocxParser())
|
||||
EXC = ExcelChunker(ExcelParser())
|
||||
PPT = PptChunker()
|
||||
|
||||
|
||||
def chuck_doc(name, binary, cvmdl=None):
|
||||
suff = os.path.split(name)[-1].lower().split(".")[-1]
|
||||
if suff.find("pdf") >= 0:
|
||||
return PDF(binary)
|
||||
if suff.find("doc") >= 0:
|
||||
return DOC(binary)
|
||||
if re.match(r"(xlsx|xlsm|xltx|xltm)", suff):
|
||||
return EXC(binary)
|
||||
if suff.find("ppt") >= 0:
|
||||
return PPT(binary)
|
||||
if cvmdl and re.search(r"\.(jpg|jpeg|png|tif|gif|pcx|tga|exif|fpx|svg|psd|cdr|pcd|dxf|ufo|eps|ai|raw|WMF|webp|avif|apng|icon|ico)$",
|
||||
name.lower()):
|
||||
txt = cvmdl.describe(binary)
|
||||
field = TextChunker.Fields()
|
||||
field.text_chunks = [(txt, binary)]
|
||||
field.table_chunks = []
|
||||
|
||||
return TextChunker()(binary)
|
||||
|
||||
|
||||
def collect(comm, mod, tm):
|
||||
docs = DocumentService.get_newly_uploaded(tm, mod, comm)
|
||||
if len(docs) == 0:
|
||||
return pd.DataFrame()
|
||||
docs = pd.DataFrame(docs)
|
||||
mtm = str(docs["update_time"].max())[:19]
|
||||
cron_logger.info("TOTAL:{}, To:{}".format(len(docs), mtm))
|
||||
return docs
|
||||
|
||||
|
||||
def set_progress(docid, prog, msg="Processing...", begin=False):
|
||||
d = {"progress": prog, "progress_msg": msg}
|
||||
if begin:
|
||||
d["process_begin_at"] = get_format_time()
|
||||
try:
|
||||
DocumentService.update_by_id(
|
||||
docid, {"progress": prog, "progress_msg": msg})
|
||||
except Exception as e:
|
||||
cron_logger.error("set_progress:({}), {}".format(docid, str(e)))
|
||||
|
||||
|
||||
def build(row):
|
||||
if row["size"] > DOC_MAXIMUM_SIZE:
|
||||
set_progress(row["id"], -1, "File size exceeds( <= %dMb )" %
|
||||
(int(DOC_MAXIMUM_SIZE / 1024 / 1024)))
|
||||
return []
|
||||
res = ELASTICSEARCH.search(Q("term", doc_id=row["id"]))
|
||||
if ELASTICSEARCH.getTotal(res) > 0:
|
||||
ELASTICSEARCH.updateScriptByQuery(Q("term", doc_id=row["id"]),
|
||||
scripts="""
|
||||
if(!ctx._source.kb_id.contains('%s'))
|
||||
ctx._source.kb_id.add('%s');
|
||||
""" % (str(row["kb_id"]), str(row["kb_id"])),
|
||||
idxnm=search.index_name(row["tenant_id"])
|
||||
)
|
||||
set_progress(row["id"], 1, "Done")
|
||||
return []
|
||||
|
||||
random.seed(time.time())
|
||||
set_progress(row["id"], random.randint(0, 20) /
|
||||
100., "Finished preparing! Start to slice file!", True)
|
||||
try:
|
||||
obj = chuck_doc(row["name"], MINIO.get(row["kb_id"], row["location"]))
|
||||
except Exception as e:
|
||||
if re.search("(No such file|not found)", str(e)):
|
||||
set_progress(
|
||||
row["id"], -1, "Can not find file <%s>" %
|
||||
row["doc_name"])
|
||||
else:
|
||||
set_progress(
|
||||
row["id"], -1, f"Internal server error: %s" %
|
||||
str(e).replace(
|
||||
"'", ""))
|
||||
return []
|
||||
|
||||
if not obj.text_chunks and not obj.table_chunks:
|
||||
set_progress(
|
||||
row["id"],
|
||||
1,
|
||||
"Nothing added! Mostly, file type unsupported yet.")
|
||||
return []
|
||||
|
||||
set_progress(row["id"], random.randint(20, 60) / 100.,
|
||||
"Finished slicing files. Start to embedding the content.")
|
||||
|
||||
doc = {
|
||||
"doc_id": row["did"],
|
||||
"kb_id": [str(row["kb_id"])],
|
||||
"docnm_kwd": os.path.split(row["location"])[-1],
|
||||
"title_tks": huqie.qie(row["name"]),
|
||||
"updated_at": str(row["update_time"]).replace("T", " ")[:19]
|
||||
}
|
||||
doc["title_sm_tks"] = huqie.qieqie(doc["title_tks"])
|
||||
output_buffer = BytesIO()
|
||||
docs = []
|
||||
md5 = hashlib.md5()
|
||||
for txt, img in obj.text_chunks:
|
||||
d = copy.deepcopy(doc)
|
||||
md5.update((txt + str(d["doc_id"])).encode("utf-8"))
|
||||
d["_id"] = md5.hexdigest()
|
||||
d["content_ltks"] = huqie.qie(txt)
|
||||
d["content_sm_ltks"] = huqie.qieqie(d["content_ltks"])
|
||||
if not img:
|
||||
docs.append(d)
|
||||
continue
|
||||
|
||||
if isinstance(img, Image):
|
||||
img.save(output_buffer, format='JPEG')
|
||||
else:
|
||||
output_buffer = BytesIO(img)
|
||||
|
||||
MINIO.put(row["kb_id"], d["_id"], output_buffer.getvalue())
|
||||
d["img_id"] = "{}-{}".format(row["kb_id"], d["_id"])
|
||||
docs.append(d)
|
||||
|
||||
for arr, img in obj.table_chunks:
|
||||
for i, txt in enumerate(arr):
|
||||
d = copy.deepcopy(doc)
|
||||
d["content_ltks"] = huqie.qie(txt)
|
||||
md5.update((txt + str(d["doc_id"])).encode("utf-8"))
|
||||
d["_id"] = md5.hexdigest()
|
||||
if not img:
|
||||
docs.append(d)
|
||||
continue
|
||||
img.save(output_buffer, format='JPEG')
|
||||
MINIO.put(row["kb_id"], d["_id"], output_buffer.getvalue())
|
||||
d["img_id"] = "{}-{}".format(row["kb_id"], d["_id"])
|
||||
docs.append(d)
|
||||
set_progress(row["id"], random.randint(60, 70) /
|
||||
100., "Continue embedding the content.")
|
||||
|
||||
return docs
|
||||
|
||||
|
||||
def init_kb(row):
|
||||
idxnm = search.index_name(row["tenant_id"])
|
||||
if ELASTICSEARCH.indexExist(idxnm):
|
||||
return
|
||||
return ELASTICSEARCH.createIdx(idxnm, json.load(
|
||||
open(os.path.join(get_project_base_directory(), "conf", "mapping.json"), "r")))
|
||||
|
||||
|
||||
def embedding(docs, mdl):
|
||||
tts, cnts = [rmSpace(d["title_tks"]) for d in docs], [rmSpace(d["content_ltks"]) for d in docs]
|
||||
tk_count = 0
|
||||
tts, c = mdl.encode(tts)
|
||||
tk_count += c
|
||||
cnts, c = mdl.encode(cnts)
|
||||
tk_count += c
|
||||
vects = 0.1 * tts + 0.9 * cnts
|
||||
assert len(vects) == len(docs)
|
||||
for i, d in enumerate(docs):
|
||||
d["q_vec"] = vects[i].tolist()
|
||||
return tk_count
|
||||
|
||||
|
||||
def model_instance(tenant_id, llm_type):
|
||||
model_config = TenantLLMService.query(tenant_id=tenant_id, model_type=LLMType.EMBEDDING)
|
||||
if not model_config:return
|
||||
model_config = model_config[0]
|
||||
if llm_type == LLMType.EMBEDDING:
|
||||
if model_config.llm_factory not in EmbeddingModel: return
|
||||
return EmbeddingModel[model_config.llm_factory](model_config.api_key, model_config.llm_name)
|
||||
if llm_type == LLMType.IMAGE2TEXT:
|
||||
if model_config.llm_factory not in CvModel: return
|
||||
return CvModel[model_config.llm_factory](model_config.api_key, model_config.llm_name)
|
||||
|
||||
|
||||
def main(comm, mod):
|
||||
global model
|
||||
from rag.llm import HuEmbedding
|
||||
model = HuEmbedding()
|
||||
tm_fnm = os.path.join(get_project_base_directory(), "rag/res", f"{comm}-{mod}.tm")
|
||||
tm = findMaxDt(tm_fnm)
|
||||
rows = collect(comm, mod, tm)
|
||||
if len(rows) == 0:
|
||||
return
|
||||
|
||||
tmf = open(tm_fnm, "a+")
|
||||
for _, r in rows.iterrows():
|
||||
embd_mdl = model_instance(r["tenant_id"], LLMType.EMBEDDING)
|
||||
if not embd_mdl:
|
||||
set_progress(r["id"], -1, "Can't find embedding model!")
|
||||
cron_logger.error("Tenant({}) can't find embedding model!".format(r["tenant_id"]))
|
||||
continue
|
||||
cv_mdl = model_instance(r["tenant_id"], LLMType.IMAGE2TEXT)
|
||||
st_tm = timer()
|
||||
cks = build(r, cv_mdl)
|
||||
if not cks:
|
||||
tmf.write(str(r["updated_at"]) + "\n")
|
||||
continue
|
||||
# TODO: exception handler
|
||||
## set_progress(r["did"], -1, "ERROR: ")
|
||||
try:
|
||||
tk_count = embedding(cks, embd_mdl)
|
||||
except Exception as e:
|
||||
set_progress(r["id"], -1, "Embedding error:{}".format(str(e)))
|
||||
cron_logger.error(str(e))
|
||||
continue
|
||||
|
||||
|
||||
set_progress(r["id"], random.randint(70, 95) / 100.,
|
||||
"Finished embedding! Start to build index!")
|
||||
init_kb(r)
|
||||
es_r = ELASTICSEARCH.bulk(cks, search.index_name(r["tenant_id"]))
|
||||
if es_r:
|
||||
set_progress(r["id"], -1, "Index failure!")
|
||||
cron_logger.error(str(es_r))
|
||||
else:
|
||||
set_progress(r["id"], 1., "Done!")
|
||||
DocumentService.update_by_id(r["id"], {"token_num": tk_count, "chunk_num": len(cks), "process_duation": timer()-st_tm})
|
||||
tmf.write(str(r["update_time"]) + "\n")
|
||||
tmf.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from mpi4py import MPI
|
||||
comm = MPI.COMM_WORLD
|
||||
main(comm.Get_size(), comm.Get_rank())
|
||||
47
rag/utils/__init__.py
Normal file
47
rag/utils/__init__.py
Normal file
@ -0,0 +1,47 @@
|
||||
import os
|
||||
import re
|
||||
import tiktoken
|
||||
|
||||
|
||||
def singleton(cls, *args, **kw):
|
||||
instances = {}
|
||||
|
||||
def _singleton():
|
||||
key = str(cls) + str(os.getpid())
|
||||
if key not in instances:
|
||||
instances[key] = cls(*args, **kw)
|
||||
return instances[key]
|
||||
|
||||
return _singleton
|
||||
|
||||
|
||||
from .minio_conn import MINIO
|
||||
from .es_conn import ELASTICSEARCH
|
||||
|
||||
def rmSpace(txt):
|
||||
txt = re.sub(r"([^a-z0-9.,]) +([^ ])", r"\1\2", txt)
|
||||
return re.sub(r"([^ ]) +([^a-z0-9.,])", r"\1\2", txt)
|
||||
|
||||
|
||||
def findMaxDt(fnm):
|
||||
m = "1970-01-01 00:00:00"
|
||||
try:
|
||||
with open(fnm, "r") as f:
|
||||
while True:
|
||||
l = f.readline()
|
||||
if not l:
|
||||
break
|
||||
l = l.strip("\n")
|
||||
if l == 'nan':
|
||||
continue
|
||||
if l > m:
|
||||
m = l
|
||||
except Exception as e:
|
||||
print("WARNING: can't find " + fnm)
|
||||
return m
|
||||
|
||||
def num_tokens_from_string(string: str) -> int:
|
||||
"""Returns the number of tokens in a text string."""
|
||||
encoding = tiktoken.get_encoding('cl100k_base')
|
||||
num_tokens = len(encoding.encode(string))
|
||||
return num_tokens
|
||||
420
rag/utils/es_conn.py
Normal file
420
rag/utils/es_conn.py
Normal file
@ -0,0 +1,420 @@
|
||||
import re
|
||||
import json
|
||||
import time
|
||||
import copy
|
||||
import elasticsearch
|
||||
from elasticsearch import Elasticsearch
|
||||
from elasticsearch_dsl import UpdateByQuery, Search, Index
|
||||
from rag.settings import es_logger
|
||||
from rag import settings
|
||||
from rag.utils import singleton
|
||||
|
||||
es_logger.info("Elasticsearch version: "+ str(elasticsearch.__version__))
|
||||
|
||||
|
||||
@singleton
|
||||
class HuEs:
|
||||
def __init__(self):
|
||||
self.info = {}
|
||||
self.conn()
|
||||
self.idxnm = settings.ES.get("index_name", "")
|
||||
if not self.es.ping():
|
||||
raise Exception("Can't connect to ES cluster")
|
||||
|
||||
def conn(self):
|
||||
for _ in range(10):
|
||||
try:
|
||||
self.es = Elasticsearch(
|
||||
settings.ES["hosts"].split(","),
|
||||
timeout=600
|
||||
)
|
||||
if self.es:
|
||||
self.info = self.es.info()
|
||||
es_logger.info("Connect to es.")
|
||||
break
|
||||
except Exception as e:
|
||||
es_logger.error("Fail to connect to es: " + str(e))
|
||||
time.sleep(1)
|
||||
|
||||
def version(self):
|
||||
v = self.info.get("version", {"number": "5.6"})
|
||||
v = v["number"].split(".")[0]
|
||||
return int(v) >= 7
|
||||
|
||||
def upsert(self, df, idxnm=""):
|
||||
res = []
|
||||
for d in df:
|
||||
id = d["id"]
|
||||
del d["id"]
|
||||
d = {"doc": d, "doc_as_upsert": "true"}
|
||||
T = False
|
||||
for _ in range(10):
|
||||
try:
|
||||
if not self.version():
|
||||
r = self.es.update(
|
||||
index=(
|
||||
self.idxnm if not idxnm else idxnm),
|
||||
body=d,
|
||||
id=id,
|
||||
doc_type="doc",
|
||||
refresh=False,
|
||||
retry_on_conflict=100)
|
||||
else:
|
||||
r = self.es.update(
|
||||
index=(
|
||||
self.idxnm if not idxnm else idxnm),
|
||||
body=d,
|
||||
id=id,
|
||||
refresh=False,
|
||||
doc_type="_doc",
|
||||
retry_on_conflict=100)
|
||||
es_logger.info("Successfully upsert: %s" % id)
|
||||
T = True
|
||||
break
|
||||
except Exception as e:
|
||||
es_logger.warning("Fail to index: " +
|
||||
json.dumps(d, ensure_ascii=False) + str(e))
|
||||
if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
|
||||
time.sleep(3)
|
||||
continue
|
||||
self.conn()
|
||||
T = False
|
||||
|
||||
if not T:
|
||||
res.append(d)
|
||||
es_logger.error(
|
||||
"Fail to index: " +
|
||||
re.sub(
|
||||
"[\r\n]",
|
||||
"",
|
||||
json.dumps(
|
||||
d,
|
||||
ensure_ascii=False)))
|
||||
d["id"] = id
|
||||
d["_index"] = self.idxnm
|
||||
|
||||
if not res:
|
||||
return True
|
||||
return False
|
||||
|
||||
def bulk(self, df, idx_nm=None):
|
||||
ids, acts = {}, []
|
||||
for d in df:
|
||||
id = d["id"] if "id" in d else d["_id"]
|
||||
ids[id] = copy.deepcopy(d)
|
||||
ids[id]["_index"] = self.idxnm if not idx_nm else idx_nm
|
||||
if "id" in d:
|
||||
del d["id"]
|
||||
if "_id" in d:
|
||||
del d["_id"]
|
||||
acts.append(
|
||||
{"update": {"_id": id, "_index": ids[id]["_index"]}, "retry_on_conflict": 100})
|
||||
acts.append({"doc": d, "doc_as_upsert": "true"})
|
||||
|
||||
res = []
|
||||
for _ in range(100):
|
||||
try:
|
||||
if elasticsearch.__version__[0] < 8:
|
||||
r = self.es.bulk(
|
||||
index=(
|
||||
self.idxnm if not idx_nm else idx_nm),
|
||||
body=acts,
|
||||
refresh=False,
|
||||
timeout="600s")
|
||||
else:
|
||||
r = self.es.bulk(index=(self.idxnm if not idx_nm else
|
||||
idx_nm), operations=acts,
|
||||
refresh=False, timeout="600s")
|
||||
if re.search(r"False", str(r["errors"]), re.IGNORECASE):
|
||||
return res
|
||||
|
||||
for it in r["items"]:
|
||||
if "error" in it["update"]:
|
||||
res.append(str(it["update"]["_id"]) +
|
||||
":" + str(it["update"]["error"]))
|
||||
|
||||
return res
|
||||
except Exception as e:
|
||||
es_logger.warn("Fail to bulk: " + str(e))
|
||||
if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
|
||||
time.sleep(3)
|
||||
continue
|
||||
self.conn()
|
||||
|
||||
return res
|
||||
|
||||
def bulk4script(self, df):
|
||||
ids, acts = {}, []
|
||||
for d in df:
|
||||
id = d["id"]
|
||||
ids[id] = copy.deepcopy(d["raw"])
|
||||
acts.append({"update": {"_id": id, "_index": self.idxnm}})
|
||||
acts.append(d["script"])
|
||||
es_logger.info("bulk upsert: %s" % id)
|
||||
|
||||
res = []
|
||||
for _ in range(10):
|
||||
try:
|
||||
if not self.version():
|
||||
r = self.es.bulk(
|
||||
index=self.idxnm,
|
||||
body=acts,
|
||||
refresh=False,
|
||||
timeout="600s",
|
||||
doc_type="doc")
|
||||
else:
|
||||
r = self.es.bulk(
|
||||
index=self.idxnm,
|
||||
body=acts,
|
||||
refresh=False,
|
||||
timeout="600s")
|
||||
if re.search(r"False", str(r["errors"]), re.IGNORECASE):
|
||||
return res
|
||||
|
||||
for it in r["items"]:
|
||||
if "error" in it["update"]:
|
||||
res.append(str(it["update"]["_id"]))
|
||||
|
||||
return res
|
||||
except Exception as e:
|
||||
es_logger.warning("Fail to bulk: " + str(e))
|
||||
if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
|
||||
time.sleep(3)
|
||||
continue
|
||||
self.conn()
|
||||
|
||||
return res
|
||||
|
||||
def rm(self, d):
|
||||
for _ in range(10):
|
||||
try:
|
||||
if not self.version():
|
||||
r = self.es.delete(
|
||||
index=self.idxnm,
|
||||
id=d["id"],
|
||||
doc_type="doc",
|
||||
refresh=True)
|
||||
else:
|
||||
r = self.es.delete(
|
||||
index=self.idxnm,
|
||||
id=d["id"],
|
||||
refresh=True,
|
||||
doc_type="_doc")
|
||||
es_logger.info("Remove %s" % d["id"])
|
||||
return True
|
||||
except Exception as e:
|
||||
es_logger.warn("Fail to delete: " + str(d) + str(e))
|
||||
if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
|
||||
time.sleep(3)
|
||||
continue
|
||||
if re.search(r"(not_found)", str(e), re.IGNORECASE):
|
||||
return True
|
||||
self.conn()
|
||||
|
||||
es_logger.error("Fail to delete: " + str(d))
|
||||
|
||||
return False
|
||||
|
||||
def search(self, q, idxnm=None, src=False, timeout="2s"):
|
||||
if not isinstance(q, dict):
|
||||
q = Search().query(q).to_dict()
|
||||
for i in range(3):
|
||||
try:
|
||||
res = self.es.search(index=(self.idxnm if not idxnm else idxnm),
|
||||
body=q,
|
||||
timeout=timeout,
|
||||
# search_type="dfs_query_then_fetch",
|
||||
track_total_hits=True,
|
||||
_source=src)
|
||||
if str(res.get("timed_out", "")).lower() == "true":
|
||||
raise Exception("Es Timeout.")
|
||||
return res
|
||||
except Exception as e:
|
||||
es_logger.error(
|
||||
"ES search exception: " +
|
||||
str(e) +
|
||||
"【Q】:" +
|
||||
str(q))
|
||||
if str(e).find("Timeout") > 0:
|
||||
continue
|
||||
raise e
|
||||
es_logger.error("ES search timeout for 3 times!")
|
||||
raise Exception("ES search timeout.")
|
||||
|
||||
def updateByQuery(self, q, d):
|
||||
ubq = UpdateByQuery(index=self.idxnm).using(self.es).query(q)
|
||||
scripts = ""
|
||||
for k, v in d.items():
|
||||
scripts += "ctx._source.%s = params.%s;" % (str(k), str(k))
|
||||
ubq = ubq.script(source=scripts, params=d)
|
||||
ubq = ubq.params(refresh=False)
|
||||
ubq = ubq.params(slices=5)
|
||||
ubq = ubq.params(conflicts="proceed")
|
||||
for i in range(3):
|
||||
try:
|
||||
r = ubq.execute()
|
||||
return True
|
||||
except Exception as e:
|
||||
es_logger.error("ES updateByQuery exception: " +
|
||||
str(e) + "【Q】:" + str(q.to_dict()))
|
||||
if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
|
||||
continue
|
||||
self.conn()
|
||||
|
||||
return False
|
||||
|
||||
def updateScriptByQuery(self, q, scripts, idxnm=None):
|
||||
ubq = UpdateByQuery(
|
||||
index=self.idxnm if not idxnm else idxnm).using(
|
||||
self.es).query(q)
|
||||
ubq = ubq.script(source=scripts)
|
||||
ubq = ubq.params(refresh=True)
|
||||
ubq = ubq.params(slices=5)
|
||||
ubq = ubq.params(conflicts="proceed")
|
||||
for i in range(3):
|
||||
try:
|
||||
r = ubq.execute()
|
||||
return True
|
||||
except Exception as e:
|
||||
es_logger.error("ES updateByQuery exception: " +
|
||||
str(e) + "【Q】:" + str(q.to_dict()))
|
||||
if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
|
||||
continue
|
||||
self.conn()
|
||||
|
||||
return False
|
||||
|
||||
def deleteByQuery(self, query, idxnm=""):
|
||||
for i in range(3):
|
||||
try:
|
||||
r = self.es.delete_by_query(
|
||||
index=idxnm if idxnm else self.idxnm,
|
||||
body=Search().query(query).to_dict())
|
||||
return True
|
||||
except Exception as e:
|
||||
es_logger.error("ES updateByQuery deleteByQuery: " +
|
||||
str(e) + "【Q】:" + str(query.to_dict()))
|
||||
if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
|
||||
continue
|
||||
|
||||
return False
|
||||
|
||||
def update(self, id, script, routing=None):
|
||||
for i in range(3):
|
||||
try:
|
||||
if not self.version():
|
||||
r = self.es.update(
|
||||
index=self.idxnm,
|
||||
id=id,
|
||||
body=json.dumps(
|
||||
script,
|
||||
ensure_ascii=False),
|
||||
doc_type="doc",
|
||||
routing=routing,
|
||||
refresh=False)
|
||||
else:
|
||||
r = self.es.update(index=self.idxnm, id=id, body=json.dumps(script, ensure_ascii=False),
|
||||
routing=routing, refresh=False) # , doc_type="_doc")
|
||||
return True
|
||||
except Exception as e:
|
||||
es_logger.error(
|
||||
"ES update exception: " + str(e) + " id:" + str(id) + ", version:" + str(self.version()) +
|
||||
json.dumps(script, ensure_ascii=False))
|
||||
if str(e).find("Timeout") > 0:
|
||||
continue
|
||||
|
||||
return False
|
||||
|
||||
def indexExist(self, idxnm):
|
||||
s = Index(idxnm if idxnm else self.idxnm, self.es)
|
||||
for i in range(3):
|
||||
try:
|
||||
return s.exists()
|
||||
except Exception as e:
|
||||
es_logger.error("ES updateByQuery indexExist: " + str(e))
|
||||
if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
|
||||
continue
|
||||
|
||||
return False
|
||||
|
||||
def docExist(self, docid, idxnm=None):
|
||||
for i in range(3):
|
||||
try:
|
||||
return self.es.exists(index=(idxnm if idxnm else self.idxnm),
|
||||
id=docid)
|
||||
except Exception as e:
|
||||
es_logger.error("ES Doc Exist: " + str(e))
|
||||
if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
|
||||
continue
|
||||
return False
|
||||
|
||||
def createIdx(self, idxnm, mapping):
|
||||
try:
|
||||
if elasticsearch.__version__[0] < 8:
|
||||
return self.es.indices.create(idxnm, body=mapping)
|
||||
from elasticsearch.client import IndicesClient
|
||||
return IndicesClient(self.es).create(index=idxnm,
|
||||
settings=mapping["settings"],
|
||||
mappings=mapping["mappings"])
|
||||
except Exception as e:
|
||||
es_logger.error("ES create index error %s ----%s" % (idxnm, str(e)))
|
||||
|
||||
def deleteIdx(self, idxnm):
|
||||
try:
|
||||
return self.es.indices.delete(idxnm, allow_no_indices=True)
|
||||
except Exception as e:
|
||||
es_logger.error("ES delete index error %s ----%s" % (idxnm, str(e)))
|
||||
|
||||
def getTotal(self, res):
|
||||
if isinstance(res["hits"]["total"], type({})):
|
||||
return res["hits"]["total"]["value"]
|
||||
return res["hits"]["total"]
|
||||
|
||||
def getDocIds(self, res):
|
||||
return [d["_id"] for d in res["hits"]["hits"]]
|
||||
|
||||
def getSource(self, res):
|
||||
rr = []
|
||||
for d in res["hits"]["hits"]:
|
||||
d["_source"]["id"] = d["_id"]
|
||||
d["_source"]["_score"] = d["_score"]
|
||||
rr.append(d["_source"])
|
||||
return rr
|
||||
|
||||
def scrollIter(self, pagesize=100, scroll_time='2m', q={
|
||||
"query": {"match_all": {}}, "sort": [{"updated_at": {"order": "desc"}}]}):
|
||||
for _ in range(100):
|
||||
try:
|
||||
page = self.es.search(
|
||||
index=self.idxnm,
|
||||
scroll=scroll_time,
|
||||
size=pagesize,
|
||||
body=q,
|
||||
_source=None
|
||||
)
|
||||
break
|
||||
except Exception as e:
|
||||
es_logger.error("ES scrolling fail. " + str(e))
|
||||
time.sleep(3)
|
||||
|
||||
sid = page['_scroll_id']
|
||||
scroll_size = page['hits']['total']["value"]
|
||||
es_logger.info("[TOTAL]%d" % scroll_size)
|
||||
# Start scrolling
|
||||
while scroll_size > 0:
|
||||
yield page["hits"]["hits"]
|
||||
for _ in range(100):
|
||||
try:
|
||||
page = self.es.scroll(scroll_id=sid, scroll=scroll_time)
|
||||
break
|
||||
except Exception as e:
|
||||
es_logger.error("ES scrolling fail. " + str(e))
|
||||
time.sleep(3)
|
||||
|
||||
# Update the scroll ID
|
||||
sid = page['_scroll_id']
|
||||
# Get the number of results that we returned in the last scroll
|
||||
scroll_size = len(page['hits']['hits'])
|
||||
|
||||
|
||||
ELASTICSEARCH = HuEs()
|
||||
102
rag/utils/minio_conn.py
Normal file
102
rag/utils/minio_conn.py
Normal file
@ -0,0 +1,102 @@
|
||||
import os
|
||||
import time
|
||||
from minio import Minio
|
||||
from io import BytesIO
|
||||
from rag import settings
|
||||
from rag.settings import minio_logger
|
||||
from rag.utils import singleton
|
||||
|
||||
|
||||
@singleton
|
||||
class HuMinio(object):
|
||||
def __init__(self):
|
||||
self.conn = None
|
||||
self.__open__()
|
||||
|
||||
def __open__(self):
|
||||
try:
|
||||
if self.conn:
|
||||
self.__close__()
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
try:
|
||||
self.conn = Minio(settings.MINIO["host"],
|
||||
access_key=settings.MINIO["user"],
|
||||
secret_key=settings.MINIO["passwd"],
|
||||
secure=False
|
||||
)
|
||||
except Exception as e:
|
||||
minio_logger.error(
|
||||
"Fail to connect %s " % settings.MINIO["host"] + str(e))
|
||||
|
||||
def __close__(self):
|
||||
del self.conn
|
||||
self.conn = None
|
||||
|
||||
def put(self, bucket, fnm, binary):
|
||||
for _ in range(10):
|
||||
try:
|
||||
if not self.conn.bucket_exists(bucket):
|
||||
self.conn.make_bucket(bucket)
|
||||
|
||||
r = self.conn.put_object(bucket, fnm,
|
||||
BytesIO(binary),
|
||||
len(binary)
|
||||
)
|
||||
return r
|
||||
except Exception as e:
|
||||
minio_logger.error(f"Fail put {bucket}/{fnm}: " + str(e))
|
||||
self.__open__()
|
||||
time.sleep(1)
|
||||
|
||||
def rm(self, bucket, fnm):
|
||||
try:
|
||||
self.conn.remove_object(bucket, fnm)
|
||||
except Exception as e:
|
||||
minio_logger.error(f"Fail rm {bucket}/{fnm}: " + str(e))
|
||||
|
||||
|
||||
def get(self, bucket, fnm):
|
||||
for _ in range(10):
|
||||
try:
|
||||
r = self.conn.get_object(bucket, fnm)
|
||||
return r.read()
|
||||
except Exception as e:
|
||||
minio_logger.error(f"fail get {bucket}/{fnm}: " + str(e))
|
||||
self.__open__()
|
||||
time.sleep(1)
|
||||
return
|
||||
|
||||
def obj_exist(self, bucket, fnm):
|
||||
try:
|
||||
if self.conn.stat_object(bucket, fnm):return True
|
||||
return False
|
||||
except Exception as e:
|
||||
minio_logger.error(f"Fail put {bucket}/{fnm}: " + str(e))
|
||||
return False
|
||||
|
||||
|
||||
def get_presigned_url(self, bucket, fnm, expires):
|
||||
for _ in range(10):
|
||||
try:
|
||||
return self.conn.get_presigned_url("GET", bucket, fnm, expires)
|
||||
except Exception as e:
|
||||
minio_logger.error(f"fail get {bucket}/{fnm}: " + str(e))
|
||||
self.__open__()
|
||||
time.sleep(1)
|
||||
return
|
||||
|
||||
MINIO = HuMinio()
|
||||
|
||||
if __name__ == "__main__":
|
||||
conn = HuMinio()
|
||||
fnm = "/opt/home/kevinhu/docgpt/upload/13/11-408.jpg"
|
||||
from PIL import Image
|
||||
img = Image.open(fnm)
|
||||
buff = BytesIO()
|
||||
img.save(buff, format='JPEG')
|
||||
print(conn.put("test", "11-408.jpg", buff.getvalue()))
|
||||
bts = conn.get("test", "11-408.jpg")
|
||||
img = Image.open(BytesIO(bts))
|
||||
img.save("test.jpg")
|
||||
Reference in New Issue
Block a user