build python version rag-flow (#21)

* clean rust version project

* clean rust version project

* build python version rag-flow
This commit is contained in:
KevinHuSh
2024-01-15 08:46:22 +08:00
committed by GitHub
parent db8cae3f1e
commit 30791976d5
123 changed files with 4985 additions and 4239 deletions

29
python/Dockerfile Normal file
View File

@ -0,0 +1,29 @@
FROM ubuntu:22.04 as base
RUN apt-get update
ENV TZ="Asia/Taipei"
RUN apt-get install -yq \
build-essential \
curl \
libncursesw5-dev \
libssl-dev \
libsqlite3-dev \
libgdbm-dev \
libc6-dev \
libbz2-dev \
software-properties-common \
python3.11 python3.11-dev python3-pip
RUN apt-get install -yq git
RUN pip3 config set global.index-url https://mirror.baidu.com/pypi/simple
RUN pip3 config set global.trusted-host mirror.baidu.com
RUN pip3 install --upgrade pip
RUN pip3 install torch==2.0.1
RUN pip3 install torch-model-archiver==0.8.2
RUN pip3 install torchvision==0.15.2
COPY requirements.txt .
WORKDIR /docgpt
ENV PYTHONPATH=/docgpt/

View File

@ -1,22 +0,0 @@
```shell
docker pull postgres
LOCAL_POSTGRES_DATA=./postgres-data
docker run
--name docass-postgres
-p 5455:5432
-v $LOCAL_POSTGRES_DATA:/var/lib/postgresql/data
-e POSTGRES_USER=root
-e POSTGRES_PASSWORD=infiniflow_docass
-e POSTGRES_DB=docass
-d
postgres
docker network create elastic
docker pull elasticsearch:8.11.3;
docker pull docker.elastic.co/kibana/kibana:8.11.3
```

63
python/] Normal file
View File

@ -0,0 +1,63 @@
from abc import ABC
from openai import OpenAI
import os
import base64
from io import BytesIO
class Base(ABC):
def describe(self, image, max_tokens=300):
raise NotImplementedError("Please implement encode method!")
class GptV4(Base):
def __init__(self):
import openapi
openapi.api_key = os.environ["OPENAPI_KEY"]
self.client = OpenAI()
def describe(self, image, max_tokens=300):
buffered = BytesIO()
try:
image.save(buffered, format="JPEG")
except Exception as e:
image.save(buffered, format="PNG")
b64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
res = self.client.chat.completions.create(
model="gpt-4-vision-preview",
messages=[
{
"role": "user",
"content": [
{
"type": "text",
"text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等。",
},
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{b64}"
},
},
],
}
],
max_tokens=max_tokens,
)
return res.choices[0].message.content.strip()
class QWen(Base):
def chat(self, system, history, gen_conf):
from http import HTTPStatus
from dashscope import Generation
from dashscope.api_entities.dashscope_response import Role
# export DASHSCOPE_API_KEY=YOUR_DASHSCOPE_API_KEY
response = Generation.call(
Generation.Models.qwen_turbo,
messages=messages,
result_format='message'
)
if response.status_code == HTTPStatus.OK:
return response.output.choices[0]['message']['content']
return response.message

View File

@ -1,41 +0,0 @@
{
"version":1,
"disable_existing_loggers":false,
"formatters":{
"simple":{
"format":"%(asctime)s - %(name)s - %(levelname)s - %(filename)s - %(lineno)d - %(message)s"
}
},
"handlers":{
"console":{
"class":"logging.StreamHandler",
"level":"DEBUG",
"formatter":"simple",
"stream":"ext://sys.stdout"
},
"info_file_handler":{
"class":"logging.handlers.TimedRotatingFileHandler",
"level":"INFO",
"formatter":"simple",
"filename":"log/info.log",
"when": "MIDNIGHT",
"interval":1,
"backupCount":30,
"encoding":"utf8"
},
"error_file_handler":{
"class":"logging.handlers.TimedRotatingFileHandler",
"level":"ERROR",
"formatter":"simple",
"filename":"log/errors.log",
"when": "MIDNIGHT",
"interval":1,
"backupCount":30,
"encoding":"utf8"
}
},
"root":{
"level":"DEBUG",
"handlers":["console","info_file_handler","error_file_handler"]
}
}

View File

@ -1,139 +0,0 @@
{
"settings": {
"index": {
"number_of_shards": 4,
"number_of_replicas": 0,
"refresh_interval" : "1000ms"
},
"similarity": {
"scripted_sim": {
"type": "scripted",
"script": {
"source": "double idf = Math.log(1+(field.docCount-term.docFreq+0.5)/(term.docFreq + 0.5))/Math.log(1+((field.docCount-0.5)/1.5)); return query.boost * idf * Math.min(doc.freq, 1);"
}
}
}
},
"mappings": {
"properties": {
"lat_lon": {"type": "geo_point", "store":"true"}
},
"date_detection": "true",
"dynamic_templates": [
{
"int": {
"match": "*_int",
"mapping": {
"type": "integer",
"store": "true"
}
}
},
{
"numeric": {
"match": "*_flt",
"mapping": {
"type": "float",
"store": true
}
}
},
{
"tks": {
"match": "*_tks",
"mapping": {
"type": "text",
"similarity": "scripted_sim",
"analyzer": "whitespace",
"store": true
}
}
},
{
"ltks":{
"match": "*_ltks",
"mapping": {
"type": "text",
"analyzer": "whitespace",
"store": true
}
}
},
{
"kwd": {
"match_pattern": "regex",
"match": "^(.*_(kwd|id|ids|uid|uids)|uid)$",
"mapping": {
"type": "keyword",
"similarity": "boolean",
"store": true
}
}
},
{
"dt": {
"match_pattern": "regex",
"match": "^.*(_dt|_time|_at)$",
"mapping": {
"type": "date",
"format": "yyyy-MM-dd HH:mm:ss||yyyy-MM-dd||yyyy-MM-dd_HH:mm:ss",
"store": true
}
}
},
{
"nested": {
"match": "*_nst",
"mapping": {
"type": "nested"
}
}
},
{
"object": {
"match": "*_obj",
"mapping": {
"type": "object",
"dynamic": "true"
}
}
},
{
"string": {
"match": "*_with_weight",
"mapping": {
"type": "text",
"index": "false",
"store": true
}
}
},
{
"string": {
"match": "*_fea",
"mapping": {
"type": "rank_feature"
}
}
},
{
"dense_vector": {
"match": "*_vec",
"mapping": {
"type": "dense_vector",
"index": true,
"similarity": "cosine"
}
}
},
{
"binary": {
"match": "*_bin",
"mapping": {
"type": "binary"
}
}
}
]
}
}

View File

@ -1,9 +0,0 @@
[infiniflow]
es=http://es01:9200
postgres_user=root
postgres_password=infiniflow_docgpt
postgres_host=postgres
postgres_port=5432
minio_host=minio:9000
minio_user=infiniflow
minio_password=infiniflow_docgpt

View File

@ -1,21 +0,0 @@
import os
from .embedding_model import *
from .chat_model import *
from .cv_model import *
EmbeddingModel = None
ChatModel = None
CvModel = None
if os.environ.get("OPENAI_API_KEY"):
EmbeddingModel = GptEmbed()
ChatModel = GptTurbo()
CvModel = GptV4()
elif os.environ.get("DASHSCOPE_API_KEY"):
EmbeddingModel = QWenEmbd()
ChatModel = QWenChat()
CvModel = QWenCV()
else:
EmbeddingModel = HuEmbedding()

View File

@ -1,37 +0,0 @@
from abc import ABC
from openai import OpenAI
import os
class Base(ABC):
def chat(self, system, history, gen_conf):
raise NotImplementedError("Please implement encode method!")
class GptTurbo(Base):
def __init__(self):
self.client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])
def chat(self, system, history, gen_conf):
history.insert(0, {"role": "system", "content": system})
res = self.client.chat.completions.create(
model="gpt-3.5-turbo",
messages=history,
**gen_conf)
return res.choices[0].message.content.strip()
class QWenChat(Base):
def chat(self, system, history, gen_conf):
from http import HTTPStatus
from dashscope import Generation
# export DASHSCOPE_API_KEY=YOUR_DASHSCOPE_API_KEY
history.insert(0, {"role": "system", "content": system})
response = Generation.call(
Generation.Models.qwen_turbo,
messages=history,
result_format='message'
)
if response.status_code == HTTPStatus.OK:
return response.output.choices[0]['message']['content']
return response.message

View File

@ -1,66 +0,0 @@
from abc import ABC
from openai import OpenAI
import os
import base64
from io import BytesIO
class Base(ABC):
def describe(self, image, max_tokens=300):
raise NotImplementedError("Please implement encode method!")
def image2base64(self, image):
if isinstance(image, BytesIO):
return base64.b64encode(image.getvalue()).decode("utf-8")
buffered = BytesIO()
try:
image.save(buffered, format="JPEG")
except Exception as e:
image.save(buffered, format="PNG")
return base64.b64encode(buffered.getvalue()).decode("utf-8")
def prompt(self, b64):
return [
{
"role": "user",
"content": [
{
"type": "text",
"text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等。",
},
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{b64}"
},
},
],
}
]
class GptV4(Base):
def __init__(self):
self.client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])
def describe(self, image, max_tokens=300):
b64 = self.image2base64(image)
res = self.client.chat.completions.create(
model="gpt-4-vision-preview",
messages=self.prompt(b64),
max_tokens=max_tokens,
)
return res.choices[0].message.content.strip()
class QWenCV(Base):
def describe(self, image, max_tokens=300):
from http import HTTPStatus
from dashscope import MultiModalConversation
# export DASHSCOPE_API_KEY=YOUR_DASHSCOPE_API_KEY
response = MultiModalConversation.call(model=MultiModalConversation.Models.qwen_vl_chat_v1,
messages=self.prompt(self.image2base64(image)))
if response.status_code == HTTPStatus.OK:
return response.output.choices[0]['message']['content']
return response.message

View File

@ -1,61 +0,0 @@
from abc import ABC
from openai import OpenAI
from FlagEmbedding import FlagModel
import torch
import os
import numpy as np
class Base(ABC):
def encode(self, texts: list, batch_size=32):
raise NotImplementedError("Please implement encode method!")
class HuEmbedding(Base):
def __init__(self):
"""
If you have trouble downloading HuggingFace models, -_^ this might help!!
For Linux:
export HF_ENDPOINT=https://hf-mirror.com
For Windows:
Good luck
^_-
"""
self.model = FlagModel("BAAI/bge-large-zh-v1.5",
query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
use_fp16=torch.cuda.is_available())
def encode(self, texts: list, batch_size=32):
res = []
for i in range(0, len(texts), batch_size):
res.extend(self.model.encode(texts[i:i + batch_size]).tolist())
return np.array(res)
class GptEmbed(Base):
def __init__(self):
self.client = OpenAI(api_key=os.envirement["OPENAI_API_KEY"])
def encode(self, texts: list, batch_size=32):
res = self.client.embeddings.create(input=texts,
model="text-embedding-ada-002")
return [d["embedding"] for d in res["data"]]
class QWenEmbd(Base):
def encode(self, texts: list, batch_size=32, text_type="document"):
# export DASHSCOPE_API_KEY=YOUR_DASHSCOPE_API_KEY
import dashscope
from http import HTTPStatus
res = []
for txt in texts:
resp = dashscope.TextEmbedding.call(
model=dashscope.TextEmbedding.Models.text_embedding_v2,
input=txt[:2048],
text_type=text_type
)
res.append(resp["output"]["embeddings"][0]["embedding"])
return res

View File

@ -1,435 +0,0 @@
import re
import os
import copy
import base64
import magic
from dataclasses import dataclass
from typing import List
import numpy as np
from io import BytesIO
class HuChunker:
def __init__(self):
self.MAX_LVL = 12
self.proj_patt = [
(r"第[零一二三四五六七八九十百]+章", 1),
(r"第[零一二三四五六七八九十百]+[条节]", 2),
(r"[零一二三四五六七八九十百]+[、  ]", 3),
(r"[\(][零一二三四五六七八九十百]+[\)]", 4),
(r"[0-9]+(、|\.[  ]|\.[^0-9])", 5),
(r"[0-9]+\.[0-9]+(、|[  ]|[^0-9])", 6),
(r"[0-9]+\.[0-9]+\.[0-9]+(、|[  ]|[^0-9])", 7),
(r"[0-9]+\.[0-9]+\.[0-9]+\.[0-9]+(、|[  ]|[^0-9])", 8),
(r".{,48}[:?]@", 9),
(r"[0-9]+", 10),
(r"[\(][0-9]+[\)]", 11),
(r"[零一二三四五六七八九十百]+是", 12),
(r"[⚫•➢✓ ]", 12)
]
self.lines = []
def _garbage(self, txt):
patt = [
r"(在此保证|不得以任何形式翻版|请勿传阅|仅供内部使用|未经事先书面授权)",
r"(版权(归本公司)*所有|免责声明|保留一切权力|承担全部责任|特别声明|报告中涉及)",
r"(不承担任何责任|投资者的通知事项:|任何机构和个人|本报告仅为|不构成投资)",
r"(不构成对任何个人或机构投资建议|联系其所在国家|本报告由从事证券交易)",
r"(本研究报告由|「认可投资者」|所有研究报告均以|请发邮件至)",
r"(本报告仅供|市场有风险,投资需谨慎|本报告中提及的)",
r"(本报告反映|此信息仅供|证券分析师承诺|具备证券投资咨询业务资格)",
r"^(时间|签字|签章)[:]",
r"(参考文献|目录索引|图表索引)",
r"[ ]*年[ ]+月[ ]+日",
r"^(中国证券业协会|[0-9]+年[0-9]+月[0-9]+日)$",
r"\.{10,}",
r"(———————END|帮我转发|欢迎收藏|快来关注我吧)"
]
return any([re.search(p, txt) for p in patt])
def _proj_match(self, line):
for p, j in self.proj_patt:
if re.match(p, line):
return j
return
def _does_proj_match(self):
mat = [None for _ in range(len(self.lines))]
for i in range(len(self.lines)):
mat[i] = self._proj_match(self.lines[i])
return mat
def naive_text_chunk(self, text, ti="", MAX_LEN=612):
if text:
self.lines = [l.strip().replace(u'\u3000', u' ')
.replace(u'\xa0', u'')
for l in text.split("\n\n")]
self.lines = [l for l in self.lines if not self._garbage(l)]
self.lines = [re.sub(r"([ ]+| )", " ", l)
for l in self.lines if l]
if not self.lines:
return []
arr = self.lines
res = [""]
i = 0
while i < len(arr):
a = arr[i]
if not a:
i += 1
continue
if len(a) > MAX_LEN:
a_ = a.split("\n")
if len(a_) >= 2:
arr.pop(i)
for j in range(2, len(a_) + 1):
if len("\n".join(a_[:j])) >= MAX_LEN:
arr.insert(i, "\n".join(a_[:j - 1]))
arr.insert(i + 1, "\n".join(a_[j - 1:]))
break
else:
assert False, f"Can't split: {a}"
continue
if len(res[-1]) < MAX_LEN / 3:
res[-1] += "\n" + a
else:
res.append(a)
i += 1
if ti:
for i in range(len(res)):
if res[i].find("——来自") >= 0:
continue
res[i] += f"\t——来自“{ti}"
return res
def _merge(self):
# merge continuous same level text
lines = [self.lines[0]] if self.lines else []
for i in range(1, len(self.lines)):
if self.mat[i] == self.mat[i - 1] \
and len(lines[-1]) < 256 \
and len(self.lines[i]) < 256:
lines[-1] += "\n" + self.lines[i]
continue
lines.append(self.lines[i])
self.lines = lines
self.mat = self._does_proj_match()
return self.mat
def text_chunks(self, text):
if text:
self.lines = [l.strip().replace(u'\u3000', u' ')
.replace(u'\xa0', u'')
for l in re.split(r"[\r\n]", text)]
self.lines = [l for l in self.lines if not self._garbage(l)]
self.lines = [l for l in self.lines if l]
self.mat = self._does_proj_match()
mat = self._merge()
tree = []
for i in range(len(self.lines)):
tree.append({"proj": mat[i],
"children": [],
"read": False})
# find all children
for i in range(len(self.lines) - 1):
if tree[i]["proj"] is None:
continue
ed = i + 1
while ed < len(tree) and (tree[ed]["proj"] is None or
tree[ed]["proj"] > tree[i]["proj"]):
ed += 1
nxt = tree[i]["proj"] + 1
st = set([p["proj"] for p in tree[i + 1: ed] if p["proj"]])
while nxt not in st:
nxt += 1
if nxt > self.MAX_LVL:
break
if nxt <= self.MAX_LVL:
for j in range(i + 1, ed):
if tree[j]["proj"] is not None:
break
tree[i]["children"].append(j)
for j in range(i + 1, ed):
if tree[j]["proj"] != nxt:
continue
tree[i]["children"].append(j)
else:
for j in range(i + 1, ed):
tree[i]["children"].append(j)
# get DFS combinations, find all the paths to leaf
paths = []
def dfs(i, path):
nonlocal tree, paths
path.append(i)
tree[i]["read"] = True
if len(self.lines[i]) > 256:
paths.append(path)
return
if not tree[i]["children"]:
if len(path) > 1 or len(self.lines[i]) >= 32:
paths.append(path)
return
for j in tree[i]["children"]:
dfs(j, copy.deepcopy(path))
for i, t in enumerate(tree):
if t["read"]:
continue
dfs(i, [])
# concat txt on the path for all paths
res = []
lines = np.array(self.lines)
for p in paths:
if len(p) < 2:
tree[p[0]]["read"] = False
continue
txt = "\n".join(lines[p[:-1]]) + "\n" + lines[p[-1]]
res.append(txt)
# concat continuous orphans
assert len(tree) == len(lines)
ii = 0
while ii < len(tree):
if tree[ii]["read"]:
ii += 1
continue
txt = lines[ii]
e = ii + 1
while e < len(tree) and not tree[e]["read"] and len(txt) < 256:
txt += "\n" + lines[e]
e += 1
res.append(txt)
ii = e
# if the node has not been read, find its daddy
def find_daddy(st):
nonlocal lines, tree
proj = tree[st]["proj"]
if len(self.lines[st]) > 512:
return [st]
if proj is None:
proj = self.MAX_LVL + 1
for i in range(st - 1, -1, -1):
if tree[i]["proj"] and tree[i]["proj"] < proj:
a = [st] + find_daddy(i)
return a
return []
return res
class PdfChunker(HuChunker):
@dataclass
class Fields:
text_chunks: List = None
table_chunks: List = None
def __init__(self, pdf_parser):
self.pdf = pdf_parser
super().__init__()
def tableHtmls(self, pdfnm):
_, tbls = self.pdf(pdfnm, return_html=True)
res = []
for img, arr in tbls:
if arr[0].find("<table>") < 0:
continue
buffered = BytesIO()
if img:
img.save(buffered, format="JPEG")
img_str = base64.b64encode(
buffered.getvalue()).decode('utf-8') if img else ""
res.append({"table": arr[0], "image": img_str})
return res
def html(self, pdfnm):
txts, tbls = self.pdf(pdfnm, return_html=True)
res = []
txt_cks = self.text_chunks(txts)
for txt, img in [(self.pdf.remove_tag(c), self.pdf.crop(c))
for c in txt_cks]:
buffered = BytesIO()
if img:
img.save(buffered, format="JPEG")
img_str = base64.b64encode(
buffered.getvalue()).decode('utf-8') if img else ""
res.append({"table": "<p>%s</p>" % txt.replace("\n", "<br/>"),
"image": img_str})
for img, arr in tbls:
if not arr:
continue
buffered = BytesIO()
if img:
img.save(buffered, format="JPEG")
img_str = base64.b64encode(
buffered.getvalue()).decode('utf-8') if img else ""
res.append({"table": arr[0], "image": img_str})
return res
def __call__(self, pdfnm, return_image=True, naive_chunk=False):
flds = self.Fields()
text, tbls = self.pdf(pdfnm)
fnm = pdfnm
txt_cks = self.text_chunks(text) if not naive_chunk else \
self.naive_text_chunk(text, ti=fnm if isinstance(fnm, str) else "")
flds.text_chunks = [(self.pdf.remove_tag(c),
self.pdf.crop(c) if return_image else None) for c in txt_cks]
flds.table_chunks = [(arr, img if return_image else None)
for img, arr in tbls]
return flds
class DocxChunker(HuChunker):
@dataclass
class Fields:
text_chunks: List = None
table_chunks: List = None
def __init__(self, doc_parser):
self.doc = doc_parser
super().__init__()
def _does_proj_match(self):
mat = []
for s in self.styles:
s = s.split(" ")[-1]
try:
mat.append(int(s))
except Exception as e:
mat.append(None)
return mat
def _merge(self):
i = 1
while i < len(self.lines):
if self.mat[i] == self.mat[i - 1] \
and len(self.lines[i - 1]) < 256 \
and len(self.lines[i]) < 256:
self.lines[i - 1] += "\n" + self.lines[i]
self.styles.pop(i)
self.lines.pop(i)
self.mat.pop(i)
continue
i += 1
self.mat = self._does_proj_match()
return self.mat
def __call__(self, fnm):
flds = self.Fields()
flds.title = os.path.splitext(
os.path.basename(fnm))[0] if isinstance(
fnm, type("")) else ""
secs, tbls = self.doc(fnm)
self.lines = [l for l, s in secs]
self.styles = [s for l, s in secs]
txt_cks = self.text_chunks("")
flds.text_chunks = [(t, None) for t in txt_cks if not self._garbage(t)]
flds.table_chunks = [(tb, None) for tb in tbls for t in tb if t]
return flds
class ExcelChunker(HuChunker):
@dataclass
class Fields:
text_chunks: List = None
table_chunks: List = None
def __init__(self, excel_parser):
self.excel = excel_parser
super().__init__()
def __call__(self, fnm):
flds = self.Fields()
flds.text_chunks = [(t, None) for t in self.excel(fnm)]
flds.table_chunks = []
return flds
class PptChunker(HuChunker):
@dataclass
class Fields:
text_chunks: List = None
table_chunks: List = None
def __init__(self):
super().__init__()
def __call__(self, fnm):
from pptx import Presentation
ppt = Presentation(fnm) if isinstance(
fnm, str) else Presentation(
BytesIO(fnm))
flds = self.Fields()
flds.text_chunks = []
for slide in ppt.slides:
for shape in slide.shapes:
if hasattr(shape, "text"):
flds.text_chunks.append((shape.text, None))
flds.table_chunks = []
return flds
class TextChunker(HuChunker):
@dataclass
class Fields:
text_chunks: List = None
table_chunks: List = None
def __init__(self):
super().__init__()
@staticmethod
def is_binary_file(file_path):
mime = magic.Magic(mime=True)
if isinstance(file_path, str):
file_type = mime.from_file(file_path)
else:
file_type = mime.from_buffer(file_path)
if 'text' in file_type:
return False
else:
return True
def __call__(self, fnm):
flds = self.Fields()
if self.is_binary_file(fnm):
return flds
with open(fnm, "r") as f:
txt = f.read()
flds.text_chunks = [(c, None) for c in self.naive_text_chunk(txt)]
flds.table_chunks = []
return flds
if __name__ == "__main__":
import sys
sys.path.append(os.path.dirname(__file__) + "/../")
if sys.argv[1].split(".")[-1].lower() == "pdf":
from parser import PdfParser
ckr = PdfChunker(PdfParser())
if sys.argv[1].split(".")[-1].lower().find("doc") >= 0:
from parser import DocxParser
ckr = DocxChunker(DocxParser())
if sys.argv[1].split(".")[-1].lower().find("xlsx") >= 0:
from parser import ExcelParser
ckr = ExcelChunker(ExcelParser())
# ckr.html(sys.argv[1])
print(ckr(sys.argv[1]))

View File

@ -1,411 +0,0 @@
# -*- coding: utf-8 -*-
import copy
import datrie
import math
import os
import re
import string
import sys
from hanziconv import HanziConv
class Huqie:
def key_(self, line):
return str(line.lower().encode("utf-8"))[2:-1]
def rkey_(self, line):
return str(("DD" + (line[::-1].lower())).encode("utf-8"))[2:-1]
def loadDict_(self, fnm):
print("[HUQIE]:Build trie", fnm, file=sys.stderr)
try:
of = open(fnm, "r")
while True:
line = of.readline()
if not line:
break
line = re.sub(r"[\r\n]+", "", line)
line = re.split(r"[ \t]", line)
k = self.key_(line[0])
F = int(math.log(float(line[1]) / self.DENOMINATOR) + .5)
if k not in self.trie_ or self.trie_[k][0] < F:
self.trie_[self.key_(line[0])] = (F, line[2])
self.trie_[self.rkey_(line[0])] = 1
self.trie_.save(fnm + ".trie")
of.close()
except Exception as e:
print("[HUQIE]:Faild to build trie, ", fnm, e, file=sys.stderr)
def __init__(self, debug=False):
self.DEBUG = debug
self.DENOMINATOR = 1000000
self.trie_ = datrie.Trie(string.printable)
self.DIR_ = ""
if os.path.exists("../res/huqie.txt"):
self.DIR_ = "../res/huqie"
if os.path.exists("./res/huqie.txt"):
self.DIR_ = "./res/huqie"
if os.path.exists("./huqie.txt"):
self.DIR_ = "./huqie"
assert self.DIR_, f"【Can't find huqie】"
self.SPLIT_CHAR = r"([ ,\.<>/?;'\[\]\\`!@#$%^&*\(\)\{\}\|_+=《》,。?、;‘’:“”【】~!¥%……()——-]+|[a-z\.-]+|[0-9,\.-]+)"
try:
self.trie_ = datrie.Trie.load(self.DIR_ + ".txt.trie")
return
except Exception as e:
print("[HUQIE]:Build default trie", file=sys.stderr)
self.trie_ = datrie.Trie(string.printable)
self.loadDict_(self.DIR_ + ".txt")
def loadUserDict(self, fnm):
try:
self.trie_ = datrie.Trie.load(fnm + ".trie")
return
except Exception as e:
self.trie_ = datrie.Trie(string.printable)
self.loadDict_(fnm)
def addUserDict(self, fnm):
self.loadDict_(fnm)
def _strQ2B(self, ustring):
"""把字符串全角转半角"""
rstring = ""
for uchar in ustring:
inside_code = ord(uchar)
if inside_code == 0x3000:
inside_code = 0x0020
else:
inside_code -= 0xfee0
if inside_code < 0x0020 or inside_code > 0x7e: # 转完之后不是半角字符返回原来的字符
rstring += uchar
else:
rstring += chr(inside_code)
return rstring
def _tradi2simp(self, line):
return HanziConv.toSimplified(line)
def dfs_(self, chars, s, preTks, tkslist):
MAX_L = 10
res = s
# if s > MAX_L or s>= len(chars):
if s >= len(chars):
tkslist.append(preTks)
return res
# pruning
S = s + 1
if s + 2 <= len(chars):
t1, t2 = "".join(chars[s:s + 1]), "".join(chars[s:s + 2])
if self.trie_.has_keys_with_prefix(self.key_(t1)) and not self.trie_.has_keys_with_prefix(
self.key_(t2)):
S = s + 2
if len(preTks) > 2 and len(
preTks[-1][0]) == 1 and len(preTks[-2][0]) == 1 and len(preTks[-3][0]) == 1:
t1 = preTks[-1][0] + "".join(chars[s:s + 1])
if self.trie_.has_keys_with_prefix(self.key_(t1)):
S = s + 2
################
for e in range(S, len(chars) + 1):
t = "".join(chars[s:e])
k = self.key_(t)
if e > s + 1 and not self.trie_.has_keys_with_prefix(k):
break
if k in self.trie_:
pretks = copy.deepcopy(preTks)
if k in self.trie_:
pretks.append((t, self.trie_[k]))
else:
pretks.append((t, (-12, '')))
res = max(res, self.dfs_(chars, e, pretks, tkslist))
if res > s:
return res
t = "".join(chars[s:s + 1])
k = self.key_(t)
if k in self.trie_:
preTks.append((t, self.trie_[k]))
else:
preTks.append((t, (-12, '')))
return self.dfs_(chars, s + 1, preTks, tkslist)
def freq(self, tk):
k = self.key_(tk)
if k not in self.trie_:
return 0
return int(math.exp(self.trie_[k][0]) * self.DENOMINATOR + 0.5)
def tag(self, tk):
k = self.key_(tk)
if k not in self.trie_:
return ""
return self.trie_[k][1]
def score_(self, tfts):
B = 30
F, L, tks = 0, 0, []
for tk, (freq, tag) in tfts:
F += freq
L += 0 if len(tk) < 2 else 1
tks.append(tk)
F /= len(tks)
L /= len(tks)
if self.DEBUG:
print("[SC]", tks, len(tks), L, F, B / len(tks) + L + F)
return tks, B / len(tks) + L + F
def sortTks_(self, tkslist):
res = []
for tfts in tkslist:
tks, s = self.score_(tfts)
res.append((tks, s))
return sorted(res, key=lambda x: x[1], reverse=True)
def merge_(self, tks):
patts = [
(r"[ ]+", " "),
(r"([0-9\+\.,%\*=-]) ([0-9\+\.,%\*=-])", r"\1\2"),
]
# for p,s in patts: tks = re.sub(p, s, tks)
# if split chars is part of token
res = []
tks = re.sub(r"[ ]+", " ", tks).split(" ")
s = 0
while True:
if s >= len(tks):
break
E = s + 1
for e in range(s + 2, min(len(tks) + 2, s + 6)):
tk = "".join(tks[s:e])
if re.search(self.SPLIT_CHAR, tk) and self.freq(tk):
E = e
res.append("".join(tks[s:E]))
s = E
return " ".join(res)
def maxForward_(self, line):
res = []
s = 0
while s < len(line):
e = s + 1
t = line[s:e]
while e < len(line) and self.trie_.has_keys_with_prefix(
self.key_(t)):
e += 1
t = line[s:e]
while e - 1 > s and self.key_(t) not in self.trie_:
e -= 1
t = line[s:e]
if self.key_(t) in self.trie_:
res.append((t, self.trie_[self.key_(t)]))
else:
res.append((t, (0, '')))
s = e
return self.score_(res)
def maxBackward_(self, line):
res = []
s = len(line) - 1
while s >= 0:
e = s + 1
t = line[s:e]
while s > 0 and self.trie_.has_keys_with_prefix(self.rkey_(t)):
s -= 1
t = line[s:e]
while s + 1 < e and self.key_(t) not in self.trie_:
s += 1
t = line[s:e]
if self.key_(t) in self.trie_:
res.append((t, self.trie_[self.key_(t)]))
else:
res.append((t, (0, '')))
s -= 1
return self.score_(res[::-1])
def qie(self, line):
line = self._strQ2B(line).lower()
line = self._tradi2simp(line)
arr = re.split(self.SPLIT_CHAR, line)
res = []
for L in arr:
if len(L) < 2 or re.match(
r"[a-z\.-]+$", L) or re.match(r"[0-9\.-]+$", L):
res.append(L)
continue
# print(L)
# use maxforward for the first time
tks, s = self.maxForward_(L)
tks1, s1 = self.maxBackward_(L)
if self.DEBUG:
print("[FW]", tks, s)
print("[BW]", tks1, s1)
diff = [0 for _ in range(max(len(tks1), len(tks)))]
for i in range(min(len(tks1), len(tks))):
if tks[i] != tks1[i]:
diff[i] = 1
if s1 > s:
tks = tks1
i = 0
while i < len(tks):
s = i
while s < len(tks) and diff[s] == 0:
s += 1
if s == len(tks):
res.append(" ".join(tks[i:]))
break
if s > i:
res.append(" ".join(tks[i:s]))
e = s
while e < len(tks) and e - s < 5 and diff[e] == 1:
e += 1
tkslist = []
self.dfs_("".join(tks[s:e + 1]), 0, [], tkslist)
res.append(" ".join(self.sortTks_(tkslist)[0][0]))
i = e + 1
res = " ".join(res)
if self.DEBUG:
print("[TKS]", self.merge_(res))
return self.merge_(res)
def qieqie(self, tks):
res = []
for tk in tks.split(" "):
if len(tk) < 3 or re.match(r"[0-9,\.-]+$", tk):
res.append(tk)
continue
tkslist = []
if len(tk) > 10:
tkslist.append(tk)
else:
self.dfs_(tk, 0, [], tkslist)
if len(tkslist) < 2:
res.append(tk)
continue
stk = self.sortTks_(tkslist)[1][0]
if len(stk) == len(tk):
stk = tk
else:
if re.match(r"[a-z\.-]+$", tk):
for t in stk:
if len(t) < 3:
stk = tk
break
else:
stk = " ".join(stk)
else:
stk = " ".join(stk)
res.append(stk)
return " ".join(res)
def is_chinese(s):
if s >= u'\u4e00' and s <= u'\u9fa5':
return True
else:
return False
def is_number(s):
if s >= u'\u0030' and s <= u'\u0039':
return True
else:
return False
def is_alphabet(s):
if (s >= u'\u0041' and s <= u'\u005a') or (
s >= u'\u0061' and s <= u'\u007a'):
return True
else:
return False
def naiveQie(txt):
tks = []
for t in txt.split(" "):
if tks and re.match(r".*[a-zA-Z]$", tks[-1]
) and re.match(r".*[a-zA-Z]$", t):
tks.append(" ")
tks.append(t)
return tks
hq = Huqie()
qie = hq.qie
qieqie = hq.qieqie
tag = hq.tag
freq = hq.freq
loadUserDict = hq.loadUserDict
addUserDict = hq.addUserDict
tradi2simp = hq._tradi2simp
strQ2B = hq._strQ2B
if __name__ == '__main__':
huqie = Huqie(debug=True)
# huqie.addUserDict("/tmp/tmp.new.tks.dict")
tks = huqie.qie(
"哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈")
print(huqie.qieqie(tks))
tks = huqie.qie(
"公开征求意见稿提出,境外投资者可使用自有人民币或外汇投资。使用外汇投资的,可通过债券持有人在香港人民币业务清算行及香港地区经批准可进入境内银行间外汇市场进行交易的境外人民币业务参加行(以下统称香港结算行)办理外汇资金兑换。香港结算行由此所产生的头寸可到境内银行间外汇市场平盘。使用外汇投资的,在其投资的债券到期或卖出后,原则上应兑换回外汇。")
print(huqie.qieqie(tks))
tks = huqie.qie(
"多校划片就是一个小区对应多个小学初中,让买了学区房的家庭也不确定到底能上哪个学校。目的是通过这种方式为学区房降温,把就近入学落到实处。南京市长江大桥")
print(huqie.qieqie(tks))
tks = huqie.qie(
"实际上当时他们已经将业务中心偏移到安全部门和针对政府企业的部门 Scripts are compiled and cached aaaaaaaaa")
print(huqie.qieqie(tks))
tks = huqie.qie("虽然我不怎么玩")
print(huqie.qieqie(tks))
tks = huqie.qie("蓝月亮如何在外资夹击中生存,那是全宇宙最有意思的")
print(huqie.qieqie(tks))
tks = huqie.qie(
"涡轮增压发动机num最大功率,不像别的共享买车锁电子化的手段,我们接过来是否有意义,黄黄爱美食,不过,今天阿奇要讲到的这家农贸市场,说实话,还真蛮有特色的!不仅环境好,还打出了")
print(huqie.qieqie(tks))
tks = huqie.qie("这周日你去吗?这周日你有空吗?")
print(huqie.qieqie(tks))
tks = huqie.qie("Unity3D开发经验 测试开发工程师 c++双11双11 985 211 ")
print(huqie.qieqie(tks))
tks = huqie.qie(
"数据分析项目经理|数据分析挖掘|数据分析方向|商品数据分析|搜索数据分析 sql python hive tableau Cocos2d-")
print(huqie.qieqie(tks))
if len(sys.argv) < 2:
sys.exit()
huqie.DEBUG = False
huqie.loadUserDict(sys.argv[1])
of = open(sys.argv[2], "r")
while True:
line = of.readline()
if not line:
break
print(huqie.qie(line))
of.close()

View File

@ -1,167 +0,0 @@
import json
import re
import sys
import os
import logging
import copy
import math
from elasticsearch_dsl import Q, Search
from nlp import huqie, term_weight, synonym
class EsQueryer:
def __init__(self, es):
self.tw = term_weight.Dealer()
self.es = es
self.syn = synonym.Dealer(None)
self.flds = ["ask_tks^10", "ask_small_tks"]
@staticmethod
def subSpecialChar(line):
return re.sub(r"([:\{\}/\[\]\-\*\"\(\)\|~\^])", r"\\\1", line).strip()
@staticmethod
def isChinese(line):
arr = re.split(r"[ \t]+", line)
if len(arr) <= 3:
return True
e = 0
for t in arr:
if not re.match(r"[a-zA-Z]+$", t):
e += 1
return e * 1. / len(arr) >= 0.8
@staticmethod
def rmWWW(txt):
txt = re.sub(
r"是*(什么样的|哪家|那家|啥样|咋样了|什么时候|何时|何地|何人|是否|是不是|多少|哪里|怎么|哪儿|怎么样|如何|哪些|是啥|啥是|啊|吗|呢|吧|咋|什么|有没有|呀)是*",
"",
txt)
return re.sub(
r"(what|who|how|which|where|why|(is|are|were|was) there) (is|are|were|was)*", "", txt, re.IGNORECASE)
def question(self, txt, tbl="qa", min_match="60%"):
txt = re.sub(
r"[ \t,,。??/`!&]+",
" ",
huqie.tradi2simp(
huqie.strQ2B(
txt.lower()))).strip()
txt = EsQueryer.rmWWW(txt)
if not self.isChinese(txt):
tks = txt.split(" ")
q = []
for i in range(1, len(tks)):
q.append("\"%s %s\"~2" % (tks[i - 1], tks[i]))
if not q:
q.append(txt)
return Q("bool",
must=Q("query_string", fields=self.flds,
type="best_fields", query=" OR ".join(q),
boost=1, minimum_should_match="60%")
), txt.split(" ")
def needQieqie(tk):
if len(tk) < 4:
return False
if re.match(r"[0-9a-z\.\+#_\*-]+$", tk):
return False
return True
qs, keywords = [], []
for tt in self.tw.split(txt): # .split(" "):
if not tt:
continue
twts = self.tw.weights([tt])
syns = self.syn.lookup(tt)
logging.info(json.dumps(twts, ensure_ascii=False))
tms = []
for tk, w in sorted(twts, key=lambda x: x[1] * -1):
sm = huqie.qieqie(tk).split(" ") if needQieqie(tk) else []
sm = [
re.sub(
r"[ ,\./;'\[\]\\`~!@#$%\^&\*\(\)=\+_<>\?:\"\{\}\|,。;‘’【】、!¥……()——《》?:“”-]+",
"",
m) for m in sm]
sm = [EsQueryer.subSpecialChar(m) for m in sm if len(m) > 1]
sm = [m for m in sm if len(m) > 1]
if len(sm) < 2:
sm = []
keywords.append(re.sub(r"[ \\\"']+", "", tk))
tk_syns = self.syn.lookup(tk)
tk = EsQueryer.subSpecialChar(tk)
if tk.find(" ") > 0:
tk = "\"%s\"" % tk
if tk_syns:
tk = f"({tk} %s)" % " ".join(tk_syns)
if sm:
tk = f"{tk} OR \"%s\" OR (\"%s\"~2)^0.5" % (
" ".join(sm), " ".join(sm))
tms.append((tk, w))
tms = " ".join([f"({t})^{w}" for t, w in tms])
if len(twts) > 1:
tms += f" (\"%s\"~4)^1.5" % (" ".join([t for t, _ in twts]))
if re.match(r"[0-9a-z ]+$", tt):
tms = f"(\"{tt}\" OR \"%s\")" % huqie.qie(tt)
syns = " OR ".join(
["\"%s\"^0.7" % EsQueryer.subSpecialChar(huqie.qie(s)) for s in syns])
if syns:
tms = f"({tms})^5 OR ({syns})^0.7"
qs.append(tms)
flds = copy.deepcopy(self.flds)
mst = []
if qs:
mst.append(
Q("query_string", fields=flds, type="best_fields",
query=" OR ".join([f"({t})" for t in qs if t]), boost=1, minimum_should_match=min_match)
)
return Q("bool",
must=mst,
), keywords
def hybrid_similarity(self, avec, bvecs, atks, btkss, tkweight=0.3,
vtweight=0.7):
from sklearn.metrics.pairwise import cosine_similarity as CosineSimilarity
import numpy as np
sims = CosineSimilarity([avec], bvecs)
def toDict(tks):
d = {}
if isinstance(tks, type("")):
tks = tks.split(" ")
for t, c in self.tw.weights(tks):
if t not in d:
d[t] = 0
d[t] += c
return d
atks = toDict(atks)
btkss = [toDict(tks) for tks in btkss]
tksim = [self.similarity(atks, btks) for btks in btkss]
return np.array(sims[0]) * vtweight + np.array(tksim) * tkweight
def similarity(self, qtwt, dtwt):
if isinstance(dtwt, type("")):
dtwt = {t: w for t, w in self.tw.weights(self.tw.split(dtwt))}
if isinstance(qtwt, type("")):
qtwt = {t: w for t, w in self.tw.weights(self.tw.split(qtwt))}
s = 1e-9
for k, v in qtwt.items():
if k in dtwt:
s += v * dtwt[k]
q = 1e-9
for k, v in qtwt.items():
q += v * v
d = 1e-9
for k, v in dtwt.items():
d += v * v
return s / math.sqrt(q) / math.sqrt(d)

View File

@ -1,252 +0,0 @@
import re
from elasticsearch_dsl import Q, Search, A
from typing import List, Optional, Tuple, Dict, Union
from dataclasses import dataclass
from util import setup_logging, rmSpace
from nlp import huqie, query
from datetime import datetime
from sklearn.metrics.pairwise import cosine_similarity as CosineSimilarity
import numpy as np
from copy import deepcopy
def index_name(uid): return f"docgpt_{uid}"
class Dealer:
def __init__(self, es, emb_mdl):
self.qryr = query.EsQueryer(es)
self.qryr.flds = [
"title_tks^10",
"title_sm_tks^5",
"content_ltks^2",
"content_sm_ltks"]
self.es = es
self.emb_mdl = emb_mdl
@dataclass
class SearchResult:
total: int
ids: List[str]
query_vector: List[float] = None
field: Optional[Dict] = None
highlight: Optional[Dict] = None
aggregation: Union[List, Dict, None] = None
keywords: Optional[List[str]] = None
group_docs: List[List] = None
def _vector(self, txt, sim=0.8, topk=10):
return {
"field": "q_vec",
"k": topk,
"similarity": sim,
"num_candidates": 1000,
"query_vector": self.emb_mdl.encode_queries(txt)
}
def search(self, req, idxnm, tks_num=3):
keywords = []
qst = req.get("question", "")
bqry, keywords = self.qryr.question(qst)
if req.get("kb_ids"):
bqry.filter.append(Q("terms", kb_id=req["kb_ids"]))
bqry.filter.append(Q("exists", field="q_tks"))
bqry.boost = 0.05
print(bqry)
s = Search()
pg = int(req.get("page", 1)) - 1
ps = int(req.get("size", 1000))
src = req.get("field", ["docnm_kwd", "content_ltks", "kb_id",
"image_id", "doc_id", "q_vec"])
s = s.query(bqry)[pg * ps:(pg + 1) * ps]
s = s.highlight("content_ltks")
s = s.highlight("title_ltks")
if not qst:
s = s.sort(
{"create_time": {"order": "desc", "unmapped_type": "date"}})
s = s.highlight_options(
fragment_size=120,
number_of_fragments=5,
boundary_scanner_locale="zh-CN",
boundary_scanner="SENTENCE",
boundary_chars=",./;:\\!(),。?:!……()——、"
)
s = s.to_dict()
q_vec = []
if req.get("vector"):
s["knn"] = self._vector(qst, req.get("similarity", 0.4), ps)
s["knn"]["filter"] = bqry.to_dict()
del s["highlight"]
q_vec = s["knn"]["query_vector"]
res = self.es.search(s, idxnm=idxnm, timeout="600s", src=src)
print("TOTAL: ", self.es.getTotal(res))
if self.es.getTotal(res) == 0 and "knn" in s:
bqry, _ = self.qryr.question(qst, min_match="10%")
if req.get("kb_ids"):
bqry.filter.append(Q("terms", kb_id=req["kb_ids"]))
s["query"] = bqry.to_dict()
s["knn"]["filter"] = bqry.to_dict()
s["knn"]["similarity"] = 0.7
res = self.es.search(s, idxnm=idxnm, timeout="600s", src=src)
kwds = set([])
for k in keywords:
kwds.add(k)
for kk in huqie.qieqie(k).split(" "):
if len(kk) < 2:
continue
if kk in kwds:
continue
kwds.add(kk)
aggs = self.getAggregation(res, "docnm_kwd")
return self.SearchResult(
total=self.es.getTotal(res),
ids=self.es.getDocIds(res),
query_vector=q_vec,
aggregation=aggs,
highlight=self.getHighlight(res),
field=self.getFields(res, ["docnm_kwd", "content_ltks",
"kb_id", "image_id", "doc_id", "q_vec"]),
keywords=list(kwds)
)
def getAggregation(self, res, g):
if not "aggregations" in res or "aggs_" + g not in res["aggregations"]:
return
bkts = res["aggregations"]["aggs_" + g]["buckets"]
return [(b["key"], b["doc_count"]) for b in bkts]
def getHighlight(self, res):
def rmspace(line):
eng = set(list("qwertyuioplkjhgfdsazxcvbnm"))
r = []
for t in line.split(" "):
if not t:
continue
if len(r) > 0 and len(
t) > 0 and r[-1][-1] in eng and t[0] in eng:
r.append(" ")
r.append(t)
r = "".join(r)
return r
ans = {}
for d in res["hits"]["hits"]:
hlts = d.get("highlight")
if not hlts:
continue
ans[d["_id"]] = "".join([a for a in list(hlts.items())[0][1]])
return ans
def getFields(self, sres, flds):
res = {}
if not flds:
return {}
for d in self.es.getSource(sres):
m = {n: d.get(n) for n in flds if d.get(n) is not None}
for n, v in m.items():
if isinstance(v, type([])):
m[n] = "\t".join([str(vv) for vv in v])
continue
if not isinstance(v, type("")):
m[n] = str(m[n])
m[n] = rmSpace(m[n])
if m:
res[d["id"]] = m
return res
@staticmethod
def trans2floats(txt):
return [float(t) for t in txt.split("\t")]
def insert_citations(self, ans, top_idx, sres,
vfield="q_vec", cfield="content_ltks"):
ins_embd = [Dealer.trans2floats(
sres.field[sres.ids[i]][vfield]) for i in top_idx]
ins_tw = [sres.field[sres.ids[i]][cfield].split(" ") for i in top_idx]
s = 0
e = 0
res = ""
def citeit():
nonlocal s, e, ans, res
if not ins_embd:
return
embd = self.emb_mdl.encode(ans[s: e])
sim = self.qryr.hybrid_similarity(embd,
ins_embd,
huqie.qie(ans[s:e]).split(" "),
ins_tw)
print(ans[s: e], sim)
mx = np.max(sim) * 0.99
if mx < 0.55:
return
cita = list(set([top_idx[i]
for i in range(len(ins_embd)) if sim[i] > mx]))[:4]
for i in cita:
res += f"@?{i}?@"
return cita
punct = set(";。?!")
if not self.qryr.isChinese(ans):
punct.add("?")
punct.add(".")
while e < len(ans):
if e - s < 12 or ans[e] not in punct:
e += 1
continue
if ans[e] == "." and e + \
1 < len(ans) and re.match(r"[0-9]", ans[e + 1]):
e += 1
continue
if ans[e] == "." and e - 2 >= 0 and ans[e - 2] == "\n":
e += 1
continue
res += ans[s: e]
citeit()
res += ans[e]
e += 1
s = e
if s < len(ans):
res += ans[s:]
citeit()
return res
def rerank(self, sres, query, tkweight=0.3, vtweight=0.7,
vfield="q_vec", cfield="content_ltks"):
ins_embd = [
Dealer.trans2floats(
sres.field[i]["q_vec"]) for i in sres.ids]
if not ins_embd:
return []
ins_tw = [sres.field[i][cfield].split(" ") for i in sres.ids]
# return CosineSimilarity([sres.query_vector], ins_embd)[0]
sim = self.qryr.hybrid_similarity(sres.query_vector,
ins_embd,
huqie.qie(query).split(" "),
ins_tw, tkweight, vtweight)
return sim
if __name__ == "__main__":
from util import es_conn
SE = Dealer(es_conn.HuEs("infiniflow"))
qs = [
"胡凯",
""
]
for q in qs:
print(">>>>>>>>>>>>>>>>>>>>", q)
print(SE.search(
{"question": q, "kb_ids": "64f072a75f3b97c865718c4a"}, "infiniflow_*"))

View File

@ -1,67 +0,0 @@
import json
import time
import logging
import re
class Dealer:
def __init__(self, redis=None):
self.lookup_num = 100000000
self.load_tm = time.time() - 1000000
self.dictionary = None
try:
self.dictionary = json.load(open("./synonym.json", 'r'))
except Exception as e:
pass
try:
self.dictionary = json.load(open("./res/synonym.json", 'r'))
except Exception as e:
try:
self.dictionary = json.load(open("../res/synonym.json", 'r'))
except Exception as e:
logging.warn("Miss synonym.json")
self.dictionary = {}
if not redis:
logging.warning(
"Realtime synonym is disabled, since no redis connection.")
if not len(self.dictionary.keys()):
logging.warning(f"Fail to load synonym")
self.redis = redis
self.load()
def load(self):
if not self.redis:
return
if self.lookup_num < 100:
return
tm = time.time()
if tm - self.load_tm < 3600:
return
self.load_tm = time.time()
self.lookup_num = 0
d = self.redis.get("kevin_synonyms")
if not d:
return
try:
d = json.loads(d)
self.dictionary = d
except Exception as e:
logging.error("Fail to load synonym!" + str(e))
def lookup(self, tk):
self.lookup_num += 1
self.load()
res = self.dictionary.get(re.sub(r"[ \t]+", " ", tk.lower()), [])
if isinstance(res, str):
res = [res]
return res
if __name__ == '__main__':
dl = Dealer()
print(dl.dictionary)

View File

@ -1,216 +0,0 @@
import math
import json
import re
import os
import numpy as np
from nlp import huqie
class Dealer:
def __init__(self):
self.stop_words = set(["请问",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"#",
"什么",
"怎么",
"哪个",
"哪些",
"",
"相关"])
def load_dict(fnm):
res = {}
f = open(fnm, "r")
while True:
l = f.readline()
if not l:
break
arr = l.replace("\n", "").split("\t")
if len(arr) < 2:
res[arr[0]] = 0
else:
res[arr[0]] = int(arr[1])
c = 0
for _, v in res.items():
c += v
if c == 0:
return set(res.keys())
return res
fnm = os.path.join(os.path.dirname(__file__), '../res/')
if not os.path.exists(fnm):
fnm = os.path.join(os.path.dirname(__file__), '../../res/')
self.ne, self.df = {}, {}
try:
self.ne = json.load(open(fnm + "ner.json", "r"))
except Exception as e:
print("[WARNING] Load ner.json FAIL!")
try:
self.df = load_dict(fnm + "term.freq")
except Exception as e:
print("[WARNING] Load term.freq FAIL!")
def pretoken(self, txt, num=False, stpwd=True):
patt = [
r"[~—\t @#%!<>,\.\?\":;'\{\}\[\]_=\(\)\|,。?》•●○↓《;‘’:“”【¥ 】…¥!、·()×`&\\/「」\\]"
]
rewt = [
]
for p, r in rewt:
txt = re.sub(p, r, txt)
res = []
for t in huqie.qie(txt).split(" "):
tk = t
if (stpwd and tk in self.stop_words) or (
re.match(r"[0-9]$", tk) and not num):
continue
for p in patt:
if re.match(p, t):
tk = "#"
break
tk = re.sub(r"([\+\\-])", r"\\\1", tk)
if tk != "#" and tk:
res.append(tk)
return res
def tokenMerge(self, tks):
def oneTerm(t): return len(t) == 1 or re.match(r"[0-9a-z]{1,2}$", t)
res, i = [], 0
while i < len(tks):
j = i
if i == 0 and oneTerm(tks[i]) and len(
tks) > 1 and len(tks[i + 1]) > 1: # 多 工位
res.append(" ".join(tks[0:2]))
i = 2
continue
while j < len(
tks) and tks[j] and tks[j] not in self.stop_words and oneTerm(tks[j]):
j += 1
if j - i > 1:
if j - i < 5:
res.append(" ".join(tks[i:j]))
i = j
else:
res.append(" ".join(tks[i:i + 2]))
i = i + 2
else:
if len(tks[i]) > 0:
res.append(tks[i])
i += 1
return [t for t in res if t]
def ner(self, t):
if not self.ne:
return ""
res = self.ne.get(t, "")
if res:
return res
def split(self, txt):
tks = []
for t in re.sub(r"[ \t]+", " ", txt).split(" "):
if tks and re.match(r".*[a-zA-Z]$", tks[-1]) and \
re.match(r".*[a-zA-Z]$", t) and tks and \
self.ne.get(t, "") != "func" and self.ne.get(tks[-1], "") != "func":
tks[-1] = tks[-1] + " " + t
else:
tks.append(t)
return tks
def weights(self, tks):
def skill(t):
if t not in self.sk:
return 1
return 6
def ner(t):
if not self.ne or t not in self.ne:
return 1
m = {"toxic": 2, "func": 1, "corp": 3, "loca": 3, "sch": 3, "stock": 3,
"firstnm": 1}
return m[self.ne[t]]
def postag(t):
t = huqie.tag(t)
if t in set(["r", "c", "d"]):
return 0.3
if t in set(["ns", "nt"]):
return 3
if t in set(["n"]):
return 2
if re.match(r"[0-9-]+", t):
return 2
return 1
def freq(t):
if re.match(r"[0-9\. -]+$", t):
return 10000
s = huqie.freq(t)
if not s and re.match(r"[a-z\. -]+$", t):
return 10
if not s:
s = 0
if not s and len(t) >= 4:
s = [tt for tt in huqie.qieqie(t).split(" ") if len(tt) > 1]
if len(s) > 1:
s = np.min([freq(tt) for tt in s]) / 6.
else:
s = 0
return max(s, 10)
def df(t):
if re.match(r"[0-9\. -]+$", t):
return 100000
if t in self.df:
return self.df[t] + 3
elif re.match(r"[a-z\. -]+$", t):
return 3
elif len(t) >= 4:
s = [tt for tt in huqie.qieqie(t).split(" ") if len(tt) > 1]
if len(s) > 1:
return max(3, np.min([df(tt) for tt in s]) / 6.)
return 3
def idf(s, N): return math.log10(10 + ((N - s + 0.5) / (s + 0.5)))
tw = []
for tk in tks:
tt = self.tokenMerge(self.pretoken(tk, True))
idf1 = np.array([idf(freq(t), 10000000) for t in tt])
idf2 = np.array([idf(df(t), 1000000000) for t in tt])
wts = (0.3 * idf1 + 0.7 * idf2) * \
np.array([ner(t) * postag(t) for t in tt])
tw.extend(zip(tt, wts))
S = np.sum([s for _, s in tw])
return [(t, s / S) for t, s in tw]

0
python/output/ToPDF.pdf Normal file
View File

View File

@ -1,3 +0,0 @@
from .pdf_parser import HuParser as PdfParser
from .docx_parser import HuDocxParser as DocxParser
from .excel_parser import HuExcelParser as ExcelParser

View File

@ -1,104 +0,0 @@
from docx import Document
import re
import pandas as pd
from collections import Counter
from nlp import huqie
from io import BytesIO
class HuDocxParser:
def __extract_table_content(self, tb):
df = []
for row in tb.rows:
df.append([c.text for c in row.cells])
return self.__compose_table_content(pd.DataFrame(df))
def __compose_table_content(self, df):
def blockType(b):
patt = [
("^(20|19)[0-9]{2}[年/-][0-9]{1,2}[月/-][0-9]{1,2}日*$", "Dt"),
(r"^(20|19)[0-9]{2}年$", "Dt"),
(r"^(20|19)[0-9]{2}[年/-][0-9]{1,2}月*$", "Dt"),
("^[0-9]{1,2}[月/-][0-9]{1,2}日*$", "Dt"),
(r"^第*[一二三四1-4]季度$", "Dt"),
(r"^(20|19)[0-9]{2}年*[一二三四1-4]季度$", "Dt"),
(r"^(20|19)[0-9]{2}[ABCDE]$", "DT"),
("^[0-9.,+%/ -]+$", "Nu"),
(r"^[0-9A-Z/\._~-]+$", "Ca"),
(r"^[A-Z]*[a-z' -]+$", "En"),
(r"^[0-9.,+-]+[0-9A-Za-z/$¥%<>()' -]+$", "NE"),
(r"^.{1}$", "Sg")
]
for p, n in patt:
if re.search(p, b):
return n
tks = [t for t in huqie.qie(b).split(" ") if len(t) > 1]
if len(tks) > 3:
if len(tks) < 12:
return "Tx"
else:
return "Lx"
if len(tks) == 1 and huqie.tag(tks[0]) == "nr":
return "Nr"
return "Ot"
if len(df) < 2:
return []
max_type = Counter([blockType(str(df.iloc[i, j])) for i in range(
1, len(df)) for j in range(len(df.iloc[i, :]))])
max_type = max(max_type.items(), key=lambda x: x[1])[0]
colnm = len(df.iloc[0, :])
hdrows = [0] # header is not nessesarily appear in the first line
if max_type == "Nu":
for r in range(1, len(df)):
tys = Counter([blockType(str(df.iloc[r, j]))
for j in range(len(df.iloc[r, :]))])
tys = max(tys.items(), key=lambda x: x[1])[0]
if tys != max_type:
hdrows.append(r)
lines = []
for i in range(1, len(df)):
if i in hdrows:
continue
hr = [r - i for r in hdrows]
hr = [r for r in hr if r < 0]
t = len(hr) - 1
while t > 0:
if hr[t] - hr[t - 1] > 1:
hr = hr[t:]
break
t -= 1
headers = []
for j in range(len(df.iloc[i, :])):
t = []
for h in hr:
x = str(df.iloc[i + h, j]).strip()
if x in t:
continue
t.append(x)
t = ",".join(t)
if t:
t += ": "
headers.append(t)
cells = []
for j in range(len(df.iloc[i, :])):
if not str(df.iloc[i, j]):
continue
cells.append(headers[j] + str(df.iloc[i, j]))
lines.append(";".join(cells))
if colnm > 3:
return lines
return ["\n".join(lines)]
def __call__(self, fnm):
self.doc = Document(fnm) if isinstance(fnm, str) else Document(BytesIO(fnm))
secs = [(p.text, p.style.name) for p in self.doc.paragraphs]
tbls = [self.__extract_table_content(tb) for tb in self.doc.tables]
return secs, tbls

View File

@ -1,25 +0,0 @@
from openpyxl import load_workbook
import sys
from io import BytesIO
class HuExcelParser:
def __call__(self, fnm):
if isinstance(fnm, str):
wb = load_workbook(fnm)
else:
wb = load_workbook(BytesIO(fnm))
res = []
for sheetname in wb.sheetnames:
ws = wb[sheetname]
lines = []
for r in ws.rows:
lines.append(
"\t".join([str(c.value) if c.value is not None else "" for c in r]))
res.append(f"{sheetname}\n" + "\n".join(lines))
return res
if __name__ == "__main__":
psr = HuExcelParser()
psr(sys.argv[1])

File diff suppressed because it is too large Load Diff

View File

@ -1,194 +0,0 @@
accelerate==0.24.1
addict==2.4.0
aiobotocore==2.7.0
aiofiles==23.2.1
aiohttp==3.8.6
aioitertools==0.11.0
aiosignal==1.3.1
aliyun-python-sdk-core==2.14.0
aliyun-python-sdk-kms==2.16.2
altair==5.1.2
anyio==3.7.1
astor==0.8.1
async-timeout==4.0.3
attrdict==2.0.1
attrs==23.1.0
Babel==2.13.1
bce-python-sdk==0.8.92
beautifulsoup4==4.12.2
bitsandbytes==0.41.1
blinker==1.7.0
botocore==1.31.64
cachetools==5.3.2
certifi==2023.7.22
cffi==1.16.0
charset-normalizer==3.3.2
click==8.1.7
cloudpickle==3.0.0
contourpy==1.2.0
crcmod==1.7
cryptography==41.0.5
cssselect==1.2.0
cssutils==2.9.0
cycler==0.12.1
Cython==3.0.5
datasets==2.13.0
datrie==0.8.2
decorator==5.1.1
defusedxml==0.7.1
dill==0.3.6
einops==0.7.0
elastic-transport==8.10.0
elasticsearch==8.10.1
elasticsearch-dsl==8.9.0
et-xmlfile==1.1.0
fastapi==0.104.1
ffmpy==0.3.1
filelock==3.13.1
fire==0.5.0
FlagEmbedding==1.1.5
Flask==3.0.0
flask-babel==4.0.0
fonttools==4.44.0
frozenlist==1.4.0
fsspec==2023.10.0
future==0.18.3
gast==0.5.4
-e
git+https://github.com/ggerganov/llama.cpp.git@5f6e0c0dff1e7a89331e6b25eca9a9fd71324069#egg=gguf&subdirectory=gguf-py
gradio==3.50.2
gradio_client==0.6.1
greenlet==3.0.1
h11==0.14.0
hanziconv==0.3.2
httpcore==1.0.1
httpx==0.25.1
huggingface-hub==0.17.3
idna==3.4
imageio==2.31.6
imgaug==0.4.0
importlib-metadata==6.8.0
importlib-resources==6.1.0
install==1.3.5
itsdangerous==2.1.2
Jinja2==3.1.2
jmespath==0.10.0
joblib==1.3.2
jsonschema==4.19.2
jsonschema-specifications==2023.7.1
kiwisolver==1.4.5
lazy_loader==0.3
lmdb==1.4.1
lxml==4.9.3
MarkupSafe==2.1.3
matplotlib==3.8.1
modelscope==1.9.4
mpmath==1.3.0
multidict==6.0.4
multiprocess==0.70.14
networkx==3.2.1
nltk==3.8.1
numpy==1.24.4
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.18.1
nvidia-nvjitlink-cu12==12.3.52
nvidia-nvtx-cu12==12.1.105
opencv-contrib-python==4.6.0.66
opencv-python==4.6.0.66
openpyxl==3.1.2
opt-einsum==3.3.0
orjson==3.9.10
oss2==2.18.3
packaging==23.2
paddleocr==2.7.0.3
paddlepaddle-gpu==2.5.2.post120
pandas==2.1.2
pdf2docx==0.5.5
pdfminer.six==20221105
pdfplumber==0.10.3
Pillow==10.0.1
platformdirs==3.11.0
premailer==3.10.0
protobuf==4.25.0
psutil==5.9.6
pyarrow==14.0.0
pyclipper==1.3.0.post5
pycocotools==2.0.7
pycparser==2.21
pycryptodome==3.19.0
pydantic==1.10.13
pydub==0.25.1
PyMuPDF==1.20.2
pyparsing==3.1.1
pypdfium2==4.23.1
python-dateutil==2.8.2
python-docx==1.1.0
python-multipart==0.0.6
pytz==2023.3.post1
PyYAML==6.0.1
rapidfuzz==3.5.2
rarfile==4.1
referencing==0.30.2
regex==2023.10.3
requests==2.31.0
rpds-py==0.12.0
s3fs==2023.10.0
safetensors==0.4.0
scikit-image==0.22.0
scikit-learn==1.3.2
scipy==1.11.3
semantic-version==2.10.0
sentence-transformers==2.2.2
sentencepiece==0.1.98
shapely==2.0.2
simplejson==3.19.2
six==1.16.0
sniffio==1.3.0
sortedcontainers==2.4.0
soupsieve==2.5
SQLAlchemy==2.0.23
starlette==0.27.0
sympy==1.12
tabulate==0.9.0
tblib==3.0.0
termcolor==2.3.0
threadpoolctl==3.2.0
tifffile==2023.9.26
tiktoken==0.5.1
timm==0.9.10
tokenizers==0.13.3
tomli==2.0.1
toolz==0.12.0
torch==2.1.0
torchaudio==2.1.0
torchvision==0.16.0
tornado==6.3.3
tqdm==4.66.1
transformers==4.33.0
transformers-stream-generator==0.0.4
triton==2.1.0
typing_extensions==4.8.0
tzdata==2023.3
urllib3==2.0.7
uvicorn==0.24.0
uvloop==0.19.0
visualdl==2.5.3
websockets==11.0.3
Werkzeug==3.0.1
wrapt==1.15.0
xgboost==2.0.1
xinference==0.6.0
xorbits==0.7.0
xoscar==0.1.3
xxhash==3.4.1
yapf==0.40.2
yarl==1.9.2
zipp==3.17.0

8
python/res/1-0.tm Normal file
View File

@ -0,0 +1,8 @@
2023-12-20 11:44:08.791336+00:00
2023-12-20 11:44:08.853249+00:00
2023-12-20 11:44:08.909933+00:00
2023-12-21 00:47:09.996757+00:00
2023-12-20 11:44:08.965855+00:00
2023-12-20 11:44:09.011682+00:00
2023-12-21 00:47:10.063326+00:00
2023-12-20 11:44:09.069486+00:00

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,3 @@
2023-12-27 08:21:49.309802+00:00
2023-12-27 08:37:22.407772+00:00
2023-12-27 08:59:18.845627+00:00

View File

@ -1,118 +0,0 @@
import sys, datetime, random, re, cv2
from os.path import dirname, realpath
sys.path.append(dirname(realpath(__file__)) + "/../")
from util.db_conn import Postgres
from util.minio_conn import HuMinio
from util import findMaxDt
import base64
from io import BytesIO
import pandas as pd
from PIL import Image
import pdfplumber
PG = Postgres("infiniflow", "docgpt")
MINIO = HuMinio("infiniflow")
def set_thumbnail(did, base64):
sql = f"""
update doc_info set thumbnail_base64='{base64}'
where
did={did}
"""
PG.update(sql)
def collect(comm, mod, tm):
sql = f"""
select
did, uid, doc_name, location, updated_at
from doc_info
where
updated_at >= '{tm}'
and MOD(did, {comm}) = {mod}
and is_deleted=false
and type <> 'folder'
and thumbnail_base64=''
order by updated_at asc
limit 10
"""
docs = PG.select(sql)
if len(docs) == 0:return pd.DataFrame()
mtm = str(docs["updated_at"].max())[:19]
print("TOTAL:", len(docs), "To: ", mtm)
return docs
def build(row):
if not re.search(r"\.(pdf|jpg|jpeg|png|gif|svg|apng|icon|ico|webp|mpg|mpeg|avi|rm|rmvb|mov|wmv|mp4)$",
row["doc_name"].lower().strip()):
set_thumbnail(row["did"], "_")
return
def thumbnail(img, SIZE=128):
w,h = img.size
p = SIZE/max(w, h)
w, h = int(w*p), int(h*p)
img.thumbnail((w, h))
buffered = BytesIO()
try:
img.save(buffered, format="JPEG")
except Exception as e:
try:
img.save(buffered, format="PNG")
except Exception as ee:
pass
return base64.b64encode(buffered.getvalue()).decode("utf-8")
iobytes = BytesIO(MINIO.get("%s-upload"%str(row["uid"]), row["location"]))
if re.search(r"\.pdf$", row["doc_name"].lower().strip()):
pdf = pdfplumber.open(iobytes)
img = pdf.pages[0].to_image().annotated
set_thumbnail(row["did"], thumbnail(img))
if re.search(r"\.(jpg|jpeg|png|gif|svg|apng|webp|icon|ico)$", row["doc_name"].lower().strip()):
img = Image.open(iobytes)
set_thumbnail(row["did"], thumbnail(img))
if re.search(r"\.(mpg|mpeg|avi|rm|rmvb|mov|wmv|mp4)$", row["doc_name"].lower().strip()):
url = MINIO.get_presigned_url("%s-upload"%str(row["uid"]),
row["location"],
expires=datetime.timedelta(seconds=60)
)
cap = cv2.VideoCapture(url)
succ = cap.isOpened()
i = random.randint(1, 11)
while succ:
ret, frame = cap.read()
if not ret: break
if i > 0:
i -= 1
continue
img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
print(img.size)
set_thumbnail(row["did"], thumbnail(img))
cap.release()
cv2.destroyAllWindows()
def main(comm, mod):
global model
tm_fnm = f"res/thumbnail-{comm}-{mod}.tm"
tm = findMaxDt(tm_fnm)
rows = collect(comm, mod, tm)
if len(rows) == 0:return
tmf = open(tm_fnm, "a+")
for _, r in rows.iterrows():
build(r)
tmf.write(str(r["updated_at"]) + "\n")
tmf.close()
if __name__ == "__main__":
from mpi4py import MPI
comm = MPI.COMM_WORLD
main(comm.Get_size(), comm.Get_rank())

View File

@ -1,165 +0,0 @@
#-*- coding:utf-8 -*-
import sys, os, re,inspect,json,traceback,logging,argparse, copy
sys.path.append(os.path.realpath(os.path.dirname(inspect.getfile(inspect.currentframe())))+"/../")
from tornado.web import RequestHandler,Application
from tornado.ioloop import IOLoop
from tornado.httpserver import HTTPServer
from tornado.options import define,options
from util import es_conn, setup_logging
from sklearn.metrics.pairwise import cosine_similarity as CosineSimilarity
from nlp import huqie
from nlp import query as Query
from nlp import search
from llm import HuEmbedding, GptTurbo
import numpy as np
from io import BytesIO
from util import config
from timeit import default_timer as timer
from collections import OrderedDict
from llm import ChatModel, EmbeddingModel
SE = None
CFIELD="content_ltks"
EMBEDDING = EmbeddingModel
LLM = ChatModel
def get_QA_pairs(hists):
pa = []
for h in hists:
for k in ["user", "assistant"]:
if h.get(k):
pa.append({
"content": h[k],
"role": k,
})
for p in pa[:-1]: assert len(p) == 2, p
return pa
def get_instruction(sres, top_i, max_len=8096, fld="content_ltks"):
max_len //= len(top_i)
# add instruction to prompt
instructions = [re.sub(r"[\r\n]+", " ", sres.field[sres.ids[i]][fld]) for i in top_i]
if len(instructions)>2:
# Said that LLM is sensitive to the first and the last one, so
# rearrange the order of references
instructions.append(copy.deepcopy(instructions[1]))
instructions.pop(1)
def token_num(txt):
c = 0
for tk in re.split(r"[,。/?‘’”“:;:;!]", txt):
if re.match(r"[a-zA-Z-]+$", tk):
c += 1
continue
c += len(tk)
return c
_inst = ""
for ins in instructions:
if token_num(_inst) > 4096:
_inst += "\n知识库:" + instructions[-1][:max_len]
break
_inst += "\n知识库:" + ins[:max_len]
return _inst
def prompt_and_answer(history, inst):
hist = get_QA_pairs(history)
chks = []
for s in re.split(r"[:;;。\n\r]+", inst):
if s: chks.append(s)
chks = len(set(chks))/(0.1+len(chks))
print("Duplication portion:", chks)
system = """
你是一个智能助手,请总结知识库的内容来回答问题,请列举知识库中的数据详细回答%s。当所有知识库内容都与问题无关时,你的回答必须包括"知识库中未找到您要的答案!这是我所知道的,仅作参考。"这句话。回答需要考虑聊天历史。
以下是知识库:
%s
以上是知识库。
"""%((",最好总结成表格" if chks<0.6 and chks>0 else ""), inst)
print("【PROMPT】:", system)
start = timer()
response = LLM.chat(system, hist, {"temperature": 0.2, "max_tokens": 512})
print("GENERATE: ", timer()-start)
print("===>>", response)
return response
class Handler(RequestHandler):
def post(self):
global SE,MUST_TK_NUM
param = json.loads(self.request.body.decode('utf-8'))
try:
question = param.get("history",[{"user": "Hi!"}])[-1]["user"]
res = SE.search({
"question": question,
"kb_ids": param.get("kb_ids", []),
"size": param.get("topn", 15)},
search.index_name(param["uid"])
)
sim = SE.rerank(res, question)
rk_idx = np.argsort(sim*-1)
topidx = [i for i in rk_idx if sim[i] >= aram.get("similarity", 0.5)][:param.get("topn",12)]
inst = get_instruction(res, topidx)
ans, topidx = prompt_and_answer(param["history"], inst)
ans = SE.insert_citations(ans, topidx, res)
refer = OrderedDict()
docnms = {}
for i in rk_idx:
did = res.field[res.ids[i]]["doc_id"]
if did not in docnms: docnms[did] = res.field[res.ids[i]]["docnm_kwd"]
if did not in refer: refer[did] = []
refer[did].append({
"chunk_id": res.ids[i],
"content": res.field[res.ids[i]]["content_ltks"],
"image": ""
})
print("::::::::::::::", ans)
self.write(json.dumps({
"code":0,
"msg":"success",
"data":{
"uid": param["uid"],
"dialog_id": param["dialog_id"],
"assistant": ans,
"refer": [{
"did": did,
"doc_name": docnms[did],
"chunks": chunks
} for did, chunks in refer.items()]
}
}))
logging.info("SUCCESS[%d]"%(res.total)+json.dumps(param, ensure_ascii=False))
except Exception as e:
logging.error("Request 500: "+str(e))
self.write(json.dumps({
"code":500,
"msg":str(e),
"data":{}
}))
print(traceback.format_exc())
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--port", default=4455, type=int, help="Port used for service")
ARGS = parser.parse_args()
SE = search.Dealer(es_conn.HuEs("infiniflow"), EMBEDDING)
app = Application([(r'/v1/chat/completions', Handler)],debug=False)
http_server = HTTPServer(app)
http_server.bind(ARGS.port)
http_server.start(3)
IOLoop.current().start()

View File

@ -1,258 +0,0 @@
import json, os, sys, hashlib, copy, time, random, re
from os.path import dirname, realpath
sys.path.append(dirname(realpath(__file__)) + "/../")
from util.es_conn import HuEs
from util.db_conn import Postgres
from util.minio_conn import HuMinio
from util import rmSpace, findMaxDt
from FlagEmbedding import FlagModel
from nlp import huchunk, huqie, search
from io import BytesIO
import pandas as pd
from elasticsearch_dsl import Q
from PIL import Image
from parser import (
PdfParser,
DocxParser,
ExcelParser
)
from nlp.huchunk import (
PdfChunker,
DocxChunker,
ExcelChunker,
PptChunker,
TextChunker
)
ES = HuEs("infiniflow")
BATCH_SIZE = 64
PG = Postgres("infiniflow", "docgpt")
MINIO = HuMinio("infiniflow")
PDF = PdfChunker(PdfParser())
DOC = DocxChunker(DocxParser())
EXC = ExcelChunker(ExcelParser())
PPT = PptChunker()
def chuck_doc(name, binary):
suff = os.path.split(name)[-1].lower().split(".")[-1]
if suff.find("pdf") >= 0: return PDF(binary)
if suff.find("doc") >= 0: return DOC(binary)
if re.match(r"(xlsx|xlsm|xltx|xltm)", suff): return EXC(binary)
if suff.find("ppt") >= 0: return PPT(binary)
if os.envirement.get("PARSE_IMAGE") \
and re.search(r"\.(jpg|jpeg|png|tif|gif|pcx|tga|exif|fpx|svg|psd|cdr|pcd|dxf|ufo|eps|ai|raw|WMF|webp|avif|apng|icon|ico)$",
name.lower()):
from llm import CvModel
txt = CvModel.describe(binary)
field = TextChunker.Fields()
field.text_chunks = [(txt, binary)]
field.table_chunks = []
return TextChunker()(binary)
def collect(comm, mod, tm):
sql = f"""
select
id as kb2doc_id,
kb_id,
did,
updated_at,
is_deleted
from kb2_doc
where
updated_at >= '{tm}'
and kb_progress = 0
and MOD(did, {comm}) = {mod}
order by updated_at asc
limit 1000
"""
kb2doc = PG.select(sql)
if len(kb2doc) == 0:return pd.DataFrame()
sql = """
select
did,
uid,
doc_name,
location,
size
from doc_info
where
did in (%s)
"""%",".join([str(i) for i in kb2doc["did"].unique()])
docs = PG.select(sql)
docs = docs.fillna("")
docs = docs.join(kb2doc.set_index("did"), on="did", how="left")
mtm = str(docs["updated_at"].max())[:19]
print("TOTAL:", len(docs), "To: ", mtm)
return docs
def set_progress(kb2doc_id, prog, msg="Processing..."):
sql = f"""
update kb2_doc set kb_progress={prog}, kb_progress_msg='{msg}'
where
id={kb2doc_id}
"""
PG.update(sql)
def build(row):
if row["size"] > 256000000:
set_progress(row["kb2doc_id"], -1, "File size exceeds( <= 256Mb )")
return []
res = ES.search(Q("term", doc_id=row["did"]))
if ES.getTotal(res) > 0:
ES.updateScriptByQuery(Q("term", doc_id=row["did"]),
scripts="""
if(!ctx._source.kb_id.contains('%s'))
ctx._source.kb_id.add('%s');
"""%(str(row["kb_id"]), str(row["kb_id"])),
idxnm = search.index_name(row["uid"])
)
set_progress(row["kb2doc_id"], 1, "Done")
return []
random.seed(time.time())
set_progress(row["kb2doc_id"], random.randint(0, 20)/100., "Finished preparing! Start to slice file!")
try:
obj = chuck_doc(row["doc_name"], MINIO.get("%s-upload"%str(row["uid"]), row["location"]))
except Exception as e:
if re.search("(No such file|not found)", str(e)):
set_progress(row["kb2doc_id"], -1, "Can not find file <%s>"%row["doc_name"])
else:
set_progress(row["kb2doc_id"], -1, f"Internal system error: %s"%str(e).replace("'", ""))
return []
if not obj.text_chunks and not obj.table_chunks:
set_progress(row["kb2doc_id"], 1, "Nothing added! Mostly, file type unsupported yet.")
return []
set_progress(row["kb2doc_id"], random.randint(20, 60)/100., "Finished slicing files. Start to embedding the content.")
doc = {
"doc_id": row["did"],
"kb_id": [str(row["kb_id"])],
"docnm_kwd": os.path.split(row["location"])[-1],
"title_tks": huqie.qie(os.path.split(row["location"])[-1]),
"updated_at": str(row["updated_at"]).replace("T", " ")[:19]
}
doc["title_sm_tks"] = huqie.qieqie(doc["title_tks"])
output_buffer = BytesIO()
docs = []
md5 = hashlib.md5()
for txt, img in obj.text_chunks:
d = copy.deepcopy(doc)
md5.update((txt + str(d["doc_id"])).encode("utf-8"))
d["_id"] = md5.hexdigest()
d["content_ltks"] = huqie.qie(txt)
d["content_sm_ltks"] = huqie.qieqie(d["content_ltks"])
if not img:
docs.append(d)
continue
if isinstance(img, Image): img.save(output_buffer, format='JPEG')
else: output_buffer = BytesIO(img)
MINIO.put("{}-{}".format(row["uid"], row["kb_id"]), d["_id"],
output_buffer.getvalue())
d["img_id"] = "{}-{}".format(row["uid"], row["kb_id"])
docs.append(d)
for arr, img in obj.table_chunks:
for i, txt in enumerate(arr):
d = copy.deepcopy(doc)
d["content_ltks"] = huqie.qie(txt)
md5.update((txt + str(d["doc_id"])).encode("utf-8"))
d["_id"] = md5.hexdigest()
if not img:
docs.append(d)
continue
img.save(output_buffer, format='JPEG')
MINIO.put("{}-{}".format(row["uid"], row["kb_id"]), d["_id"],
output_buffer.getvalue())
d["img_id"] = "{}-{}".format(row["uid"], row["kb_id"])
docs.append(d)
set_progress(row["kb2doc_id"], random.randint(60, 70)/100., "Continue embedding the content.")
return docs
def init_kb(row):
idxnm = search.index_name(row["uid"])
if ES.indexExist(idxnm): return
return ES.createIdx(idxnm, json.load(open("conf/mapping.json", "r")))
model = None
def embedding(docs):
global model
tts = model.encode([rmSpace(d["title_tks"]) for d in docs])
cnts = model.encode([rmSpace(d["content_ltks"]) for d in docs])
vects = 0.1 * tts + 0.9 * cnts
assert len(vects) == len(docs)
for i,d in enumerate(docs):d["q_vec"] = vects[i].tolist()
def rm_doc_from_kb(df):
if len(df) == 0:return
for _,r in df.iterrows():
ES.updateScriptByQuery(Q("term", doc_id=r["did"]),
scripts="""
if(ctx._source.kb_id.contains('%s'))
ctx._source.kb_id.remove(
ctx._source.kb_id.indexOf('%s')
);
"""%(str(r["kb_id"]),str(r["kb_id"])),
idxnm = search.index_name(r["uid"])
)
if len(df) == 0:return
sql = """
delete from kb2_doc where id in (%s)
"""%",".join([str(i) for i in df["kb2doc_id"]])
PG.update(sql)
def main(comm, mod):
global model
from llm import HuEmbedding
model = HuEmbedding()
tm_fnm = f"res/{comm}-{mod}.tm"
tm = findMaxDt(tm_fnm)
rows = collect(comm, mod, tm)
if len(rows) == 0:return
rm_doc_from_kb(rows.loc[rows.is_deleted == True])
rows = rows.loc[rows.is_deleted == False].reset_index(drop=True)
if len(rows) == 0:return
tmf = open(tm_fnm, "a+")
for _, r in rows.iterrows():
cks = build(r)
if not cks:
tmf.write(str(r["updated_at"]) + "\n")
continue
## TODO: exception handler
## set_progress(r["did"], -1, "ERROR: ")
embedding(cks)
set_progress(r["kb2doc_id"], random.randint(70, 95)/100.,
"Finished embedding! Start to build index!")
init_kb(r)
es_r = ES.bulk(cks, search.index_name(r["uid"]))
if es_r:
set_progress(r["kb2doc_id"], -1, "Index failure!")
print(es_r)
else: set_progress(r["kb2doc_id"], 1., "Done!")
tmf.write(str(r["updated_at"]) + "\n")
tmf.close()
if __name__ == "__main__":
from mpi4py import MPI
comm = MPI.COMM_WORLD
main(comm.Get_size(), comm.Get_rank())

15
python/tmp.log Normal file
View File

@ -0,0 +1,15 @@
Fetching 6 files: 0%| | 0/6 [00:00<?, ?it/s]
Fetching 6 files: 100%|██████████| 6/6 [00:00<00:00, 106184.91it/s]
----------- Model Configuration -----------
Model Arch: GFL
Transform Order:
--transform op: Resize
--transform op: NormalizeImage
--transform op: Permute
--transform op: PadStride
--------------------------------------------
Could not find image processor class in the image processor config or the model config. Loading based on pattern matching with the model's feature extractor configuration.
The `max_size` parameter is deprecated and will be removed in v4.26. Please specify in `size['longest_edge'] instead`.
Some weights of the model checkpoint at microsoft/table-transformer-structure-recognition were not used when initializing TableTransformerForObjectDetection: ['model.backbone.conv_encoder.model.layer3.0.downsample.1.num_batches_tracked', 'model.backbone.conv_encoder.model.layer2.0.downsample.1.num_batches_tracked', 'model.backbone.conv_encoder.model.layer4.0.downsample.1.num_batches_tracked']
- This IS expected if you are initializing TableTransformerForObjectDetection from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).

View File

@ -1,24 +0,0 @@
import re
def rmSpace(txt):
txt = re.sub(r"([^a-z0-9.,]) +([^ ])", r"\1\2", txt)
return re.sub(r"([^ ]) +([^a-z0-9.,])", r"\1\2", txt)
def findMaxDt(fnm):
m = "1970-01-01 00:00:00"
try:
with open(fnm, "r") as f:
while True:
l = f.readline()
if not l:
break
l = l.strip("\n")
if l == 'nan':
continue
if l > m:
m = l
except Exception as e:
print("WARNING: can't find " + fnm)
return m

View File

@ -1,31 +0,0 @@
from configparser import ConfigParser
import os
import inspect
CF = ConfigParser()
__fnm = os.path.join(os.path.dirname(__file__), '../conf/sys.cnf')
if not os.path.exists(__fnm):
__fnm = os.path.join(os.path.dirname(__file__), '../../conf/sys.cnf')
assert os.path.exists(
__fnm), f"【EXCEPTION】can't find {__fnm}." + os.path.dirname(__file__)
if not os.path.exists(__fnm):
__fnm = "./sys.cnf"
CF.read(__fnm)
class Config:
def __init__(self, env):
self.env = env
if env == "spark":
CF.read("./cv.cnf")
def get(self, key, default=None):
global CF
return os.environ.get(key.upper(),
CF[self.env].get(key, default)
)
def init(env):
return Config(env)

View File

@ -1,70 +0,0 @@
import logging
import time
from util import config
import pandas as pd
class Postgres(object):
def __init__(self, env, dbnm):
self.config = config.init(env)
self.conn = None
self.dbnm = dbnm
self.__open__()
def __open__(self):
import psycopg2
try:
if self.conn:
self.__close__()
del self.conn
except Exception as e:
pass
try:
self.conn = psycopg2.connect(f"""dbname={self.dbnm}
user={self.config.get('postgres_user')}
password={self.config.get('postgres_password')}
host={self.config.get('postgres_host')}
port={self.config.get('postgres_port')}""")
except Exception as e:
logging.error(
"Fail to connect %s " %
self.config.get("pgdb_host") + str(e))
def __close__(self):
try:
self.conn.close()
except Exception as e:
logging.error(
"Fail to close %s " %
self.config.get("pgdb_host") + str(e))
def select(self, sql):
for _ in range(10):
try:
return pd.read_sql(sql, self.conn)
except Exception as e:
logging.error(f"Fail to exec {sql} " + str(e))
self.__open__()
time.sleep(1)
return pd.DataFrame()
def update(self, sql):
for _ in range(10):
try:
cur = self.conn.cursor()
cur.execute(sql)
updated_rows = cur.rowcount
self.conn.commit()
cur.close()
return updated_rows
except Exception as e:
logging.error(f"Fail to exec {sql} " + str(e))
self.__open__()
time.sleep(1)
return 0
if __name__ == "__main__":
Postgres("infiniflow", "docgpt")

View File

@ -1,428 +0,0 @@
import re
import logging
import json
import time
import copy
import elasticsearch
from elasticsearch import Elasticsearch
from elasticsearch_dsl import UpdateByQuery, Search, Index, Q
from util import config
logging.info("Elasticsearch version: ", elasticsearch.__version__)
def instance(env):
CF = config.init(env)
ES_DRESS = CF.get("es").split(",")
ES = Elasticsearch(
ES_DRESS,
timeout=600
)
logging.info("ES: ", ES_DRESS, ES.info())
return ES
class HuEs:
def __init__(self, env):
self.env = env
self.info = {}
self.config = config.init(env)
self.conn()
self.idxnm = self.config.get("idx_nm", "")
if not self.es.ping():
raise Exception("Can't connect to ES cluster")
def conn(self):
for _ in range(10):
try:
c = instance(self.env)
if c:
self.es = c
self.info = c.info()
logging.info("Connect to es.")
break
except Exception as e:
logging.error("Fail to connect to es: " + str(e))
time.sleep(1)
def version(self):
v = self.info.get("version", {"number": "5.6"})
v = v["number"].split(".")[0]
return int(v) >= 7
def upsert(self, df, idxnm=""):
res = []
for d in df:
id = d["id"]
del d["id"]
d = {"doc": d, "doc_as_upsert": "true"}
T = False
for _ in range(10):
try:
if not self.version():
r = self.es.update(
index=(
self.idxnm if not idxnm else idxnm),
body=d,
id=id,
doc_type="doc",
refresh=False,
retry_on_conflict=100)
else:
r = self.es.update(
index=(
self.idxnm if not idxnm else idxnm),
body=d,
id=id,
refresh=False,
doc_type="_doc",
retry_on_conflict=100)
logging.info("Successfully upsert: %s" % id)
T = True
break
except Exception as e:
logging.warning("Fail to index: " +
json.dumps(d, ensure_ascii=False) + str(e))
if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
time.sleep(3)
continue
self.conn()
T = False
if not T:
res.append(d)
logging.error(
"Fail to index: " +
re.sub(
"[\r\n]",
"",
json.dumps(
d,
ensure_ascii=False)))
d["id"] = id
d["_index"] = self.idxnm
if not res:
return True
return False
def bulk(self, df, idx_nm=None):
ids, acts = {}, []
for d in df:
id = d["id"] if "id" in d else d["_id"]
ids[id] = copy.deepcopy(d)
ids[id]["_index"] = self.idxnm if not idx_nm else idx_nm
if "id" in d:
del d["id"]
if "_id" in d:
del d["_id"]
acts.append(
{"update": {"_id": id, "_index": ids[id]["_index"]}, "retry_on_conflict": 100})
acts.append({"doc": d, "doc_as_upsert": "true"})
res = []
for _ in range(100):
try:
if elasticsearch.__version__[0] < 8:
r = self.es.bulk(
index=(
self.idxnm if not idx_nm else idx_nm),
body=acts,
refresh=False,
timeout="600s")
else:
r = self.es.bulk(index=(self.idxnm if not idx_nm else
idx_nm), operations=acts,
refresh=False, timeout="600s")
if re.search(r"False", str(r["errors"]), re.IGNORECASE):
return res
for it in r["items"]:
if "error" in it["update"]:
res.append(str(it["update"]["_id"]) +
":" + str(it["update"]["error"]))
return res
except Exception as e:
logging.warn("Fail to bulk: " + str(e))
if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
time.sleep(3)
continue
self.conn()
return res
def bulk4script(self, df):
ids, acts = {}, []
for d in df:
id = d["id"]
ids[id] = copy.deepcopy(d["raw"])
acts.append({"update": {"_id": id, "_index": self.idxnm}})
acts.append(d["script"])
logging.info("bulk upsert: %s" % id)
res = []
for _ in range(10):
try:
if not self.version():
r = self.es.bulk(
index=self.idxnm,
body=acts,
refresh=False,
timeout="600s",
doc_type="doc")
else:
r = self.es.bulk(
index=self.idxnm,
body=acts,
refresh=False,
timeout="600s")
if re.search(r"False", str(r["errors"]), re.IGNORECASE):
return res
for it in r["items"]:
if "error" in it["update"]:
res.append(str(it["update"]["_id"]))
return res
except Exception as e:
logging.warning("Fail to bulk: " + str(e))
if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
time.sleep(3)
continue
self.conn()
return res
def rm(self, d):
for _ in range(10):
try:
if not self.version():
r = self.es.delete(
index=self.idxnm,
id=d["id"],
doc_type="doc",
refresh=True)
else:
r = self.es.delete(
index=self.idxnm,
id=d["id"],
refresh=True,
doc_type="_doc")
logging.info("Remove %s" % d["id"])
return True
except Exception as e:
logging.warn("Fail to delete: " + str(d) + str(e))
if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
time.sleep(3)
continue
if re.search(r"(not_found)", str(e), re.IGNORECASE):
return True
self.conn()
logging.error("Fail to delete: " + str(d))
return False
def search(self, q, idxnm=None, src=False, timeout="2s"):
if not isinstance(q, dict):
q = Search().query(q).to_dict()
for i in range(3):
try:
res = self.es.search(index=(self.idxnm if not idxnm else idxnm),
body=q,
timeout=timeout,
# search_type="dfs_query_then_fetch",
track_total_hits=True,
_source=src)
if str(res.get("timed_out", "")).lower() == "true":
raise Exception("Es Timeout.")
return res
except Exception as e:
logging.error(
"ES search exception: " +
str(e) +
"【Q】" +
str(q))
if str(e).find("Timeout") > 0:
continue
raise e
logging.error("ES search timeout for 3 times!")
raise Exception("ES search timeout.")
def updateByQuery(self, q, d):
ubq = UpdateByQuery(index=self.idxnm).using(self.es).query(q)
scripts = ""
for k, v in d.items():
scripts += "ctx._source.%s = params.%s;" % (str(k), str(k))
ubq = ubq.script(source=scripts, params=d)
ubq = ubq.params(refresh=False)
ubq = ubq.params(slices=5)
ubq = ubq.params(conflicts="proceed")
for i in range(3):
try:
r = ubq.execute()
return True
except Exception as e:
logging.error("ES updateByQuery exception: " +
str(e) + "【Q】" + str(q.to_dict()))
if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
continue
self.conn()
return False
def updateScriptByQuery(self, q, scripts, idxnm=None):
ubq = UpdateByQuery(
index=self.idxnm if not idxnm else idxnm).using(
self.es).query(q)
ubq = ubq.script(source=scripts)
ubq = ubq.params(refresh=True)
ubq = ubq.params(slices=5)
ubq = ubq.params(conflicts="proceed")
for i in range(3):
try:
r = ubq.execute()
return True
except Exception as e:
logging.error("ES updateByQuery exception: " +
str(e) + "【Q】" + str(q.to_dict()))
if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
continue
self.conn()
return False
def deleteByQuery(self, query, idxnm=""):
for i in range(3):
try:
r = self.es.delete_by_query(
index=idxnm if idxnm else self.idxnm,
body=Search().query(query).to_dict())
return True
except Exception as e:
logging.error("ES updateByQuery deleteByQuery: " +
str(e) + "【Q】" + str(query.to_dict()))
if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
continue
return False
def update(self, id, script, routing=None):
for i in range(3):
try:
if not self.version():
r = self.es.update(
index=self.idxnm,
id=id,
body=json.dumps(
script,
ensure_ascii=False),
doc_type="doc",
routing=routing,
refresh=False)
else:
r = self.es.update(index=self.idxnm, id=id, body=json.dumps(script, ensure_ascii=False),
routing=routing, refresh=False) # , doc_type="_doc")
return True
except Exception as e:
logging.error("ES update exception: " + str(e) + " id" + str(id) + ", version:" + str(self.version()) +
json.dumps(script, ensure_ascii=False))
if str(e).find("Timeout") > 0:
continue
return False
def indexExist(self, idxnm):
s = Index(idxnm if idxnm else self.idxnm, self.es)
for i in range(3):
try:
return s.exists()
except Exception as e:
logging.error("ES updateByQuery indexExist: " + str(e))
if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
continue
return False
def docExist(self, docid, idxnm=None):
for i in range(3):
try:
return self.es.exists(index=(idxnm if idxnm else self.idxnm),
id=docid)
except Exception as e:
logging.error("ES Doc Exist: " + str(e))
if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
continue
return False
def createIdx(self, idxnm, mapping):
try:
if elasticsearch.__version__[0] < 8:
return self.es.indices.create(idxnm, body=mapping)
from elasticsearch.client import IndicesClient
return IndicesClient(self.es).create(index=idxnm,
settings=mapping["settings"],
mappings=mapping["mappings"])
except Exception as e:
logging.error("ES create index error %s ----%s" % (idxnm, str(e)))
def deleteIdx(self, idxnm):
try:
return self.es.indices.delete(idxnm, allow_no_indices=True)
except Exception as e:
logging.error("ES delete index error %s ----%s" % (idxnm, str(e)))
def getTotal(self, res):
if isinstance(res["hits"]["total"], type({})):
return res["hits"]["total"]["value"]
return res["hits"]["total"]
def getDocIds(self, res):
return [d["_id"] for d in res["hits"]["hits"]]
def getSource(self, res):
rr = []
for d in res["hits"]["hits"]:
d["_source"]["id"] = d["_id"]
d["_source"]["_score"] = d["_score"]
rr.append(d["_source"])
return rr
def scrollIter(self, pagesize=100, scroll_time='2m', q={
"query": {"match_all": {}}, "sort": [{"updated_at": {"order": "desc"}}]}):
for _ in range(100):
try:
page = self.es.search(
index=self.idxnm,
scroll=scroll_time,
size=pagesize,
body=q,
_source=None
)
break
except Exception as e:
logging.error("ES scrolling fail. " + str(e))
time.sleep(3)
sid = page['_scroll_id']
scroll_size = page['hits']['total']["value"]
logging.info("[TOTAL]%d" % scroll_size)
# Start scrolling
while scroll_size > 0:
yield page["hits"]["hits"]
for _ in range(100):
try:
page = self.es.scroll(scroll_id=sid, scroll=scroll_time)
break
except Exception as e:
logging.error("ES scrolling fail. " + str(e))
time.sleep(3)
# Update the scroll ID
sid = page['_scroll_id']
# Get the number of results that we returned in the last scroll
scroll_size = len(page['hits']['hits'])

View File

@ -1,84 +0,0 @@
import logging
import time
from util import config
from minio import Minio
from io import BytesIO
class HuMinio(object):
def __init__(self, env):
self.config = config.init(env)
self.conn = None
self.__open__()
def __open__(self):
try:
if self.conn:
self.__close__()
except Exception as e:
pass
try:
self.conn = Minio(self.config.get("minio_host"),
access_key=self.config.get("minio_user"),
secret_key=self.config.get("minio_password"),
secure=False
)
except Exception as e:
logging.error(
"Fail to connect %s " %
self.config.get("minio_host") + str(e))
def __close__(self):
del self.conn
self.conn = None
def put(self, bucket, fnm, binary):
for _ in range(10):
try:
if not self.conn.bucket_exists(bucket):
self.conn.make_bucket(bucket)
r = self.conn.put_object(bucket, fnm,
BytesIO(binary),
len(binary)
)
return r
except Exception as e:
logging.error(f"Fail put {bucket}/{fnm}: " + str(e))
self.__open__()
time.sleep(1)
def get(self, bucket, fnm):
for _ in range(10):
try:
r = self.conn.get_object(bucket, fnm)
return r.read()
except Exception as e:
logging.error(f"fail get {bucket}/{fnm}: " + str(e))
self.__open__()
time.sleep(1)
return
def get_presigned_url(self, bucket, fnm, expires):
for _ in range(10):
try:
return self.conn.get_presigned_url("GET", bucket, fnm, expires)
except Exception as e:
logging.error(f"fail get {bucket}/{fnm}: " + str(e))
self.__open__()
time.sleep(1)
return
if __name__ == "__main__":
conn = HuMinio("infiniflow")
fnm = "/opt/home/kevinhu/docgpt/upload/13/11-408.jpg"
from PIL import Image
img = Image.open(fnm)
buff = BytesIO()
img.save(buff, format='JPEG')
print(conn.put("test", "11-408.jpg", buff.getvalue()))
bts = conn.get("test", "11-408.jpg")
img = Image.open(BytesIO(bts))
img.save("test.jpg")

View File

@ -1,36 +0,0 @@
import json
import logging.config
import os
def log_dir():
fnm = os.path.join(os.path.dirname(__file__), '../log/')
if not os.path.exists(fnm):
fnm = os.path.join(os.path.dirname(__file__), '../../log/')
assert os.path.exists(fnm), f"Can't locate log dir: {fnm}"
return fnm
def setup_logging(default_path="conf/logging.json",
default_level=logging.INFO,
env_key="LOG_CFG"):
path = default_path
value = os.getenv(env_key, None)
if value:
path = value
if os.path.exists(path):
with open(path, "r") as f:
config = json.load(f)
fnm = log_dir()
config["handlers"]["info_file_handler"]["filename"] = fnm + "info.log"
config["handlers"]["error_file_handler"]["filename"] = fnm + "error.log"
logging.config.dictConfig(config)
else:
logging.basicConfig(level=default_level)
__fnm = os.path.join(os.path.dirname(__file__), 'conf/logging.json')
if not os.path.exists(__fnm):
__fnm = os.path.join(os.path.dirname(__file__), '../../conf/logging.json')
setup_logging(__fnm)