mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-19 12:06:42 +08:00
build python version rag-flow (#21)

* clean rust version project
* clean rust version project
* build python version rag-flow
29
python/Dockerfile
Normal file
@@ -0,0 +1,29 @@
FROM ubuntu:22.04 as base

RUN apt-get update

ENV TZ="Asia/Taipei"
RUN apt-get install -yq \
    build-essential \
    curl \
    libncursesw5-dev \
    libssl-dev \
    libsqlite3-dev \
    libgdbm-dev \
    libc6-dev \
    libbz2-dev \
    software-properties-common \
    python3.11 python3.11-dev python3-pip

RUN apt-get install -yq git
RUN pip3 config set global.index-url https://mirror.baidu.com/pypi/simple
RUN pip3 config set global.trusted-host mirror.baidu.com
RUN pip3 install --upgrade pip
RUN pip3 install torch==2.0.1
RUN pip3 install torch-model-archiver==0.8.2
RUN pip3 install torchvision==0.15.2
COPY requirements.txt .

WORKDIR /docgpt
ENV PYTHONPATH=/docgpt/

@@ -1,22 +0,0 @@
```shell
docker pull postgres

LOCAL_POSTGRES_DATA=./postgres-data

docker run \
  --name docass-postgres \
  -p 5455:5432 \
  -v $LOCAL_POSTGRES_DATA:/var/lib/postgresql/data \
  -e POSTGRES_USER=root \
  -e POSTGRES_PASSWORD=infiniflow_docass \
  -e POSTGRES_DB=docass \
  -d \
  postgres

docker network create elastic
docker pull elasticsearch:8.11.3
docker pull docker.elastic.co/kibana/kibana:8.11.3
```

63
python/]
Normal file
@@ -0,0 +1,63 @@
from abc import ABC
from openai import OpenAI
import os
import base64
from io import BytesIO


class Base(ABC):
    def describe(self, image, max_tokens=300):
        raise NotImplementedError("Please implement encode method!")


class GptV4(Base):
    def __init__(self):
        self.client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])

    def describe(self, image, max_tokens=300):
        buffered = BytesIO()
        try:
            image.save(buffered, format="JPEG")
        except Exception as e:
            image.save(buffered, format="PNG")
        b64 = base64.b64encode(buffered.getvalue()).decode("utf-8")

        res = self.client.chat.completions.create(
            model="gpt-4-vision-preview",
            messages=[
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等。",
                        },
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/jpeg;base64,{b64}"
                            },
                        },
                    ],
                }
            ],
            max_tokens=max_tokens,
        )
        return res.choices[0].message.content.strip()


class QWen(Base):
    def chat(self, system, history, gen_conf):
        from http import HTTPStatus
        from dashscope import Generation
        from dashscope.api_entities.dashscope_response import Role
        # export DASHSCOPE_API_KEY=YOUR_DASHSCOPE_API_KEY
        history.insert(0, {"role": "system", "content": system})
        response = Generation.call(
            Generation.Models.qwen_turbo,
            messages=history,
            result_format='message'
        )
        if response.status_code == HTTPStatus.OK:
            return response.output.choices[0]['message']['content']
        return response.message

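A minimal usage sketch (not part of the commit) for the QWen wrapper above; it assumes DASHSCOPE_API_KEY is exported before running, and the question text is just an illustration.

```python
# Minimal usage sketch (assumption: DASHSCOPE_API_KEY is set in the environment).
qwen = QWen()
answer = qwen.chat(
    "You are a helpful assistant.",
    [{"role": "user", "content": "用一句话解释RAG。"}],
    {},  # gen_conf is accepted but unused by this wrapper
)
print(answer)
```
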
@@ -1,41 +0,0 @@
{
    "version": 1,
    "disable_existing_loggers": false,
    "formatters": {
        "simple": {
            "format": "%(asctime)s - %(name)s - %(levelname)s - %(filename)s - %(lineno)d - %(message)s"
        }
    },
    "handlers": {
        "console": {
            "class": "logging.StreamHandler",
            "level": "DEBUG",
            "formatter": "simple",
            "stream": "ext://sys.stdout"
        },
        "info_file_handler": {
            "class": "logging.handlers.TimedRotatingFileHandler",
            "level": "INFO",
            "formatter": "simple",
            "filename": "log/info.log",
            "when": "MIDNIGHT",
            "interval": 1,
            "backupCount": 30,
            "encoding": "utf8"
        },
        "error_file_handler": {
            "class": "logging.handlers.TimedRotatingFileHandler",
            "level": "ERROR",
            "formatter": "simple",
            "filename": "log/errors.log",
            "when": "MIDNIGHT",
            "interval": 1,
            "backupCount": 30,
            "encoding": "utf8"
        }
    },
    "root": {
        "level": "DEBUG",
        "handlers": ["console", "info_file_handler", "error_file_handler"]
    }
}

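A minimal sketch (not part of the commit) of how a dict-style logging config like the one above is loaded with the standard library; the "logging.json" filename is hypothetical.

```python
# Minimal sketch: load the dictConfig-style JSON above with the stdlib.
import json
import logging
import logging.config
import os

os.makedirs("log", exist_ok=True)   # the file handlers write into log/
with open("logging.json") as f:     # hypothetical path for this config file
    logging.config.dictConfig(json.load(f))

logging.getLogger(__name__).info("logging configured")
```
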
@@ -1,139 +0,0 @@
{
  "settings": {
    "index": {
      "number_of_shards": 4,
      "number_of_replicas": 0,
      "refresh_interval": "1000ms"
    },
    "similarity": {
      "scripted_sim": {
        "type": "scripted",
        "script": {
          "source": "double idf = Math.log(1+(field.docCount-term.docFreq+0.5)/(term.docFreq + 0.5))/Math.log(1+((field.docCount-0.5)/1.5)); return query.boost * idf * Math.min(doc.freq, 1);"
        }
      }
    }
  },
  "mappings": {
    "properties": {
      "lat_lon": {"type": "geo_point", "store": "true"}
    },
    "date_detection": "true",
    "dynamic_templates": [
      {
        "int": {
          "match": "*_int",
          "mapping": {
            "type": "integer",
            "store": "true"
          }
        }
      },
      {
        "numeric": {
          "match": "*_flt",
          "mapping": {
            "type": "float",
            "store": true
          }
        }
      },
      {
        "tks": {
          "match": "*_tks",
          "mapping": {
            "type": "text",
            "similarity": "scripted_sim",
            "analyzer": "whitespace",
            "store": true
          }
        }
      },
      {
        "ltks": {
          "match": "*_ltks",
          "mapping": {
            "type": "text",
            "analyzer": "whitespace",
            "store": true
          }
        }
      },
      {
        "kwd": {
          "match_pattern": "regex",
          "match": "^(.*_(kwd|id|ids|uid|uids)|uid)$",
          "mapping": {
            "type": "keyword",
            "similarity": "boolean",
            "store": true
          }
        }
      },
      {
        "dt": {
          "match_pattern": "regex",
          "match": "^.*(_dt|_time|_at)$",
          "mapping": {
            "type": "date",
            "format": "yyyy-MM-dd HH:mm:ss||yyyy-MM-dd||yyyy-MM-dd_HH:mm:ss",
            "store": true
          }
        }
      },
      {
        "nested": {
          "match": "*_nst",
          "mapping": {
            "type": "nested"
          }
        }
      },
      {
        "object": {
          "match": "*_obj",
          "mapping": {
            "type": "object",
            "dynamic": "true"
          }
        }
      },
      {
        "string": {
          "match": "*_with_weight",
          "mapping": {
            "type": "text",
            "index": "false",
            "store": true
          }
        }
      },
      {
        "string": {
          "match": "*_fea",
          "mapping": {
            "type": "rank_feature"
          }
        }
      },
      {
        "dense_vector": {
          "match": "*_vec",
          "mapping": {
            "type": "dense_vector",
            "index": true,
            "similarity": "cosine"
          }
        }
      },
      {
        "binary": {
          "match": "*_bin",
          "mapping": {
            "type": "binary"
          }
        }
      }
    ]
  }
}

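A minimal sketch (not part of the commit) of how the dynamic templates above drive field typing by suffix; the concrete field values, the "create_time_dt" name, and the index name are hypothetical, and the commented client call is only an illustration.

```python
# Field suffixes select the mapping: *_kwd -> keyword, *_tks/_ltks -> whitespace
# text, *_dt -> date, *_vec -> dense_vector (cosine).
doc = {
    "docnm_kwd": "report.pdf",                 # keyword (kwd template)
    "title_tks": "annual report 2023",         # text with scripted_sim (tks template)
    "content_ltks": "chunk tokens here",       # text (ltks template)
    "create_time_dt": "2024-01-05 09:30:00",   # date (dt template); hypothetical field
    "q_vec": [0.01, 0.02, 0.03],               # dense_vector (vec template)
}
# Indexing it with the official client could look like:
# from elasticsearch import Elasticsearch
# es = Elasticsearch("http://es01:9200")
# es.index(index="docgpt_demo", document=doc)   # hypothetical index name
```
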
@@ -1,9 +0,0 @@
[infiniflow]
es=http://es01:9200
postgres_user=root
postgres_password=infiniflow_docgpt
postgres_host=postgres
postgres_port=5432
minio_host=minio:9000
minio_user=infiniflow
minio_password=infiniflow_docgpt

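A minimal sketch (not part of the commit) of reading the INI section above with the standard library; the "sys.cnf" filename is an assumption.

```python
# Minimal sketch: read the [infiniflow] section with configparser.
from configparser import ConfigParser

cfg = ConfigParser()
cfg.read("sys.cnf")  # hypothetical filename for the config above
es_url = cfg.get("infiniflow", "es")
pg_port = cfg.getint("infiniflow", "postgres_port")
print(es_url, pg_port)
```
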
@@ -1,21 +0,0 @@
import os
from .embedding_model import *
from .chat_model import *
from .cv_model import *

EmbeddingModel = None
ChatModel = None
CvModel = None


if os.environ.get("OPENAI_API_KEY"):
    EmbeddingModel = GptEmbed()
    ChatModel = GptTurbo()
    CvModel = GptV4()

elif os.environ.get("DASHSCOPE_API_KEY"):
    EmbeddingModel = QWenEmbd()
    ChatModel = QWenChat()
    CvModel = QWenCV()
else:
    EmbeddingModel = HuEmbedding()

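A minimal usage sketch (not part of the commit) of the factory above, which selects backends from environment variables; it assumes the package is importable as `llm` on PYTHONPATH.

```python
# Minimal sketch: callers import the pre-built singletons instead of classes.
from llm import EmbeddingModel, ChatModel  # assumes the package name is "llm"

vecs = EmbeddingModel.encode(["hello world"])
if ChatModel:  # None when no OPENAI_API_KEY / DASHSCOPE_API_KEY is set
    print(ChatModel.chat("You are terse.",
                         [{"role": "user", "content": "hi"}], {}))
```
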
@@ -1,37 +0,0 @@
from abc import ABC
from openai import OpenAI
import os


class Base(ABC):
    def chat(self, system, history, gen_conf):
        raise NotImplementedError("Please implement encode method!")


class GptTurbo(Base):
    def __init__(self):
        self.client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])

    def chat(self, system, history, gen_conf):
        history.insert(0, {"role": "system", "content": system})
        res = self.client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=history,
            **gen_conf)
        return res.choices[0].message.content.strip()


class QWenChat(Base):
    def chat(self, system, history, gen_conf):
        from http import HTTPStatus
        from dashscope import Generation
        # export DASHSCOPE_API_KEY=YOUR_DASHSCOPE_API_KEY
        history.insert(0, {"role": "system", "content": system})
        response = Generation.call(
            Generation.Models.qwen_turbo,
            messages=history,
            result_format='message'
        )
        if response.status_code == HTTPStatus.OK:
            return response.output.choices[0]['message']['content']
        return response.message

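A minimal usage sketch (not part of the commit) for GptTurbo above; it assumes OPENAI_API_KEY is set, and the gen_conf values are just illustrative keyword arguments passed through to the API.

```python
# Minimal sketch: gen_conf is splatted into chat.completions.create(**gen_conf).
mdl = GptTurbo()
answer = mdl.chat(
    "You are a helpful assistant.",
    [{"role": "user", "content": "Summarize RAG in one sentence."}],
    {"temperature": 0.2, "max_tokens": 128},
)
print(answer)
```
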
@@ -1,66 +0,0 @@
from abc import ABC
from openai import OpenAI
import os
import base64
from io import BytesIO


class Base(ABC):
    def describe(self, image, max_tokens=300):
        raise NotImplementedError("Please implement encode method!")

    def image2base64(self, image):
        if isinstance(image, BytesIO):
            return base64.b64encode(image.getvalue()).decode("utf-8")
        buffered = BytesIO()
        try:
            image.save(buffered, format="JPEG")
        except Exception as e:
            image.save(buffered, format="PNG")
        return base64.b64encode(buffered.getvalue()).decode("utf-8")

    def prompt(self, b64):
        return [
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等。",
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/jpeg;base64,{b64}"
                        },
                    },
                ],
            }
        ]


class GptV4(Base):
    def __init__(self):
        self.client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])

    def describe(self, image, max_tokens=300):
        b64 = self.image2base64(image)

        res = self.client.chat.completions.create(
            model="gpt-4-vision-preview",
            messages=self.prompt(b64),
            max_tokens=max_tokens,
        )
        return res.choices[0].message.content.strip()


class QWenCV(Base):
    def describe(self, image, max_tokens=300):
        from http import HTTPStatus
        from dashscope import MultiModalConversation
        # export DASHSCOPE_API_KEY=YOUR_DASHSCOPE_API_KEY
        response = MultiModalConversation.call(model=MultiModalConversation.Models.qwen_vl_chat_v1,
                                               messages=self.prompt(self.image2base64(image)))
        if response.status_code == HTTPStatus.OK:
            return response.output.choices[0]['message']['content']
        return response.message

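A minimal usage sketch (not part of the commit) for the GptV4 vision wrapper above; it assumes OPENAI_API_KEY is set, Pillow is installed, and "sample.jpg" is a hypothetical path.

```python
# Minimal sketch: describe() accepts a PIL image (or a BytesIO buffer).
from PIL import Image

cv_mdl = GptV4()
with Image.open("sample.jpg") as img:  # hypothetical image path
    print(cv_mdl.describe(img, max_tokens=256))
```
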
@@ -1,61 +0,0 @@
from abc import ABC
from openai import OpenAI
from FlagEmbedding import FlagModel
import torch
import os
import numpy as np


class Base(ABC):
    def encode(self, texts: list, batch_size=32):
        raise NotImplementedError("Please implement encode method!")


class HuEmbedding(Base):
    def __init__(self):
        """
        If you have trouble downloading HuggingFace models, -_^ this might help!!

        For Linux:
        export HF_ENDPOINT=https://hf-mirror.com

        For Windows:
        Good luck
        ^_-

        """
        self.model = FlagModel("BAAI/bge-large-zh-v1.5",
                               query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
                               use_fp16=torch.cuda.is_available())

    def encode(self, texts: list, batch_size=32):
        res = []
        for i in range(0, len(texts), batch_size):
            res.extend(self.model.encode(texts[i:i + batch_size]).tolist())
        return np.array(res)


class GptEmbed(Base):
    def __init__(self):
        self.client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])

    def encode(self, texts: list, batch_size=32):
        res = self.client.embeddings.create(input=texts,
                                            model="text-embedding-ada-002")
        return [d.embedding for d in res.data]


class QWenEmbd(Base):
    def encode(self, texts: list, batch_size=32, text_type="document"):
        # export DASHSCOPE_API_KEY=YOUR_DASHSCOPE_API_KEY
        import dashscope
        from http import HTTPStatus
        res = []
        for txt in texts:
            resp = dashscope.TextEmbedding.call(
                model=dashscope.TextEmbedding.Models.text_embedding_v2,
                input=txt[:2048],
                text_type=text_type
            )
            res.append(resp["output"]["embeddings"][0]["embedding"])
        return res

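A minimal usage sketch (not part of the commit) for the default HuEmbedding model above; it downloads BAAI/bge-large-zh-v1.5 on first use, and the exact output dimensionality depends on that model.

```python
# Minimal sketch: encode() batches the texts and returns a NumPy array.
emb_mdl = HuEmbedding()
vecs = emb_mdl.encode([
    "什么是RAG?",
    "Retrieval augmented generation combines search with LLMs.",
])
print(vecs.shape)  # (2, embedding_dim); typically 1024 for bge-large-zh-v1.5
```
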
@@ -1,435 +0,0 @@
import re
import os
import copy
import base64
import magic
from dataclasses import dataclass
from typing import List
import numpy as np
from io import BytesIO


class HuChunker:

    def __init__(self):
        self.MAX_LVL = 12
        self.proj_patt = [
            (r"第[零一二三四五六七八九十百]+章", 1),
            (r"第[零一二三四五六七八九十百]+[条节]", 2),
            (r"[零一二三四五六七八九十百]+[、 ]", 3),
            (r"[\((][零一二三四五六七八九十百]+[)\)]", 4),
            (r"[0-9]+(、|\.[ ]|\.[^0-9])", 5),
            (r"[0-9]+\.[0-9]+(、|[ ]|[^0-9])", 6),
            (r"[0-9]+\.[0-9]+\.[0-9]+(、|[ ]|[^0-9])", 7),
            (r"[0-9]+\.[0-9]+\.[0-9]+\.[0-9]+(、|[ ]|[^0-9])", 8),
            (r".{,48}[::??]@", 9),
            (r"[0-9]+\)", 10),
            (r"[\((][0-9]+[)\)]", 11),
            (r"[零一二三四五六七八九十百]+是", 12),
            (r"[⚫•➢✓ ]", 12)
        ]
        self.lines = []

    def _garbage(self, txt):
        patt = [
            r"(在此保证|不得以任何形式翻版|请勿传阅|仅供内部使用|未经事先书面授权)",
            r"(版权(归本公司)*所有|免责声明|保留一切权力|承担全部责任|特别声明|报告中涉及)",
            r"(不承担任何责任|投资者的通知事项:|任何机构和个人|本报告仅为|不构成投资)",
            r"(不构成对任何个人或机构投资建议|联系其所在国家|本报告由从事证券交易)",
            r"(本研究报告由|「认可投资者」|所有研究报告均以|请发邮件至)",
            r"(本报告仅供|市场有风险,投资需谨慎|本报告中提及的)",
            r"(本报告反映|此信息仅供|证券分析师承诺|具备证券投资咨询业务资格)",
            r"^(时间|签字|签章)[::]",
            r"(参考文献|目录索引|图表索引)",
            r"[ ]*年[ ]+月[ ]+日",
            r"^(中国证券业协会|[0-9]+年[0-9]+月[0-9]+日)$",
            r"\.{10,}",
            r"(———————END|帮我转发|欢迎收藏|快来关注我吧)"
        ]
        return any([re.search(p, txt) for p in patt])

    def _proj_match(self, line):
        for p, j in self.proj_patt:
            if re.match(p, line):
                return j
        return

    def _does_proj_match(self):
        mat = [None for _ in range(len(self.lines))]
        for i in range(len(self.lines)):
            mat[i] = self._proj_match(self.lines[i])
        return mat

    def naive_text_chunk(self, text, ti="", MAX_LEN=612):
        if text:
            self.lines = [l.strip().replace(u'\u3000', u' ')
                          .replace(u'\xa0', u'')
                          for l in text.split("\n\n")]
            self.lines = [l for l in self.lines if not self._garbage(l)]
            self.lines = [re.sub(r"([ ]+| )", " ", l)
                          for l in self.lines if l]
            if not self.lines:
                return []
        arr = self.lines

        res = [""]
        i = 0
        while i < len(arr):
            a = arr[i]
            if not a:
                i += 1
                continue
            if len(a) > MAX_LEN:
                a_ = a.split("\n")
                if len(a_) >= 2:
                    arr.pop(i)
                    for j in range(2, len(a_) + 1):
                        if len("\n".join(a_[:j])) >= MAX_LEN:
                            arr.insert(i, "\n".join(a_[:j - 1]))
                            arr.insert(i + 1, "\n".join(a_[j - 1:]))
                            break
                else:
                    assert False, f"Can't split: {a}"
                continue

            if len(res[-1]) < MAX_LEN / 3:
                res[-1] += "\n" + a
            else:
                res.append(a)
            i += 1

        if ti:
            for i in range(len(res)):
                if res[i].find("——来自") >= 0:
                    continue
                res[i] += f"\t——来自“{ti}”"

        return res

    def _merge(self):
        # merge continuous same level text
        lines = [self.lines[0]] if self.lines else []
        for i in range(1, len(self.lines)):
            if self.mat[i] == self.mat[i - 1] \
                    and len(lines[-1]) < 256 \
                    and len(self.lines[i]) < 256:
                lines[-1] += "\n" + self.lines[i]
                continue
            lines.append(self.lines[i])
        self.lines = lines
        self.mat = self._does_proj_match()
        return self.mat

    def text_chunks(self, text):
        if text:
            self.lines = [l.strip().replace(u'\u3000', u' ')
                          .replace(u'\xa0', u'')
                          for l in re.split(r"[\r\n]", text)]
            self.lines = [l for l in self.lines if not self._garbage(l)]
            self.lines = [l for l in self.lines if l]
        self.mat = self._does_proj_match()
        mat = self._merge()

        tree = []
        for i in range(len(self.lines)):
            tree.append({"proj": mat[i],
                         "children": [],
                         "read": False})
        # find all children
        for i in range(len(self.lines) - 1):
            if tree[i]["proj"] is None:
                continue
            ed = i + 1
            while ed < len(tree) and (tree[ed]["proj"] is None or
                                      tree[ed]["proj"] > tree[i]["proj"]):
                ed += 1

            nxt = tree[i]["proj"] + 1
            st = set([p["proj"] for p in tree[i + 1: ed] if p["proj"]])
            while nxt not in st:
                nxt += 1
                if nxt > self.MAX_LVL:
                    break
            if nxt <= self.MAX_LVL:
                for j in range(i + 1, ed):
                    if tree[j]["proj"] is not None:
                        break
                    tree[i]["children"].append(j)
                for j in range(i + 1, ed):
                    if tree[j]["proj"] != nxt:
                        continue
                    tree[i]["children"].append(j)
            else:
                for j in range(i + 1, ed):
                    tree[i]["children"].append(j)

        # get DFS combinations, find all the paths to leaf
        paths = []

        def dfs(i, path):
            nonlocal tree, paths
            path.append(i)
            tree[i]["read"] = True
            if len(self.lines[i]) > 256:
                paths.append(path)
                return
            if not tree[i]["children"]:
                if len(path) > 1 or len(self.lines[i]) >= 32:
                    paths.append(path)
                return
            for j in tree[i]["children"]:
                dfs(j, copy.deepcopy(path))

        for i, t in enumerate(tree):
            if t["read"]:
                continue
            dfs(i, [])

        # concat txt on the path for all paths
        res = []
        lines = np.array(self.lines)
        for p in paths:
            if len(p) < 2:
                tree[p[0]]["read"] = False
                continue
            txt = "\n".join(lines[p[:-1]]) + "\n" + lines[p[-1]]
            res.append(txt)
        # concat continuous orphans
        assert len(tree) == len(lines)
        ii = 0
        while ii < len(tree):
            if tree[ii]["read"]:
                ii += 1
                continue
            txt = lines[ii]
            e = ii + 1
            while e < len(tree) and not tree[e]["read"] and len(txt) < 256:
                txt += "\n" + lines[e]
                e += 1
            res.append(txt)
            ii = e

        # if the node has not been read, find its daddy
        def find_daddy(st):
            nonlocal lines, tree
            proj = tree[st]["proj"]
            if len(self.lines[st]) > 512:
                return [st]
            if proj is None:
                proj = self.MAX_LVL + 1
            for i in range(st - 1, -1, -1):
                if tree[i]["proj"] and tree[i]["proj"] < proj:
                    a = [st] + find_daddy(i)
                    return a
            return []

        return res


class PdfChunker(HuChunker):

    @dataclass
    class Fields:
        text_chunks: List = None
        table_chunks: List = None

    def __init__(self, pdf_parser):
        self.pdf = pdf_parser
        super().__init__()

    def tableHtmls(self, pdfnm):
        _, tbls = self.pdf(pdfnm, return_html=True)
        res = []
        for img, arr in tbls:
            if arr[0].find("<table>") < 0:
                continue
            buffered = BytesIO()
            if img:
                img.save(buffered, format="JPEG")
            img_str = base64.b64encode(
                buffered.getvalue()).decode('utf-8') if img else ""
            res.append({"table": arr[0], "image": img_str})
        return res

    def html(self, pdfnm):
        txts, tbls = self.pdf(pdfnm, return_html=True)
        res = []
        txt_cks = self.text_chunks(txts)
        for txt, img in [(self.pdf.remove_tag(c), self.pdf.crop(c))
                         for c in txt_cks]:
            buffered = BytesIO()
            if img:
                img.save(buffered, format="JPEG")
            img_str = base64.b64encode(
                buffered.getvalue()).decode('utf-8') if img else ""
            res.append({"table": "<p>%s</p>" % txt.replace("\n", "<br/>"),
                        "image": img_str})

        for img, arr in tbls:
            if not arr:
                continue
            buffered = BytesIO()
            if img:
                img.save(buffered, format="JPEG")
            img_str = base64.b64encode(
                buffered.getvalue()).decode('utf-8') if img else ""
            res.append({"table": arr[0], "image": img_str})

        return res

    def __call__(self, pdfnm, return_image=True, naive_chunk=False):
        flds = self.Fields()
        text, tbls = self.pdf(pdfnm)
        fnm = pdfnm
        txt_cks = self.text_chunks(text) if not naive_chunk else \
            self.naive_text_chunk(text, ti=fnm if isinstance(fnm, str) else "")
        flds.text_chunks = [(self.pdf.remove_tag(c),
                             self.pdf.crop(c) if return_image else None) for c in txt_cks]

        flds.table_chunks = [(arr, img if return_image else None)
                             for img, arr in tbls]
        return flds


class DocxChunker(HuChunker):

    @dataclass
    class Fields:
        text_chunks: List = None
        table_chunks: List = None

    def __init__(self, doc_parser):
        self.doc = doc_parser
        super().__init__()

    def _does_proj_match(self):
        mat = []
        for s in self.styles:
            s = s.split(" ")[-1]
            try:
                mat.append(int(s))
            except Exception as e:
                mat.append(None)
        return mat

    def _merge(self):
        i = 1
        while i < len(self.lines):
            if self.mat[i] == self.mat[i - 1] \
                    and len(self.lines[i - 1]) < 256 \
                    and len(self.lines[i]) < 256:
                self.lines[i - 1] += "\n" + self.lines[i]
                self.styles.pop(i)
                self.lines.pop(i)
                self.mat.pop(i)
                continue
            i += 1
        self.mat = self._does_proj_match()
        return self.mat

    def __call__(self, fnm):
        flds = self.Fields()
        flds.title = os.path.splitext(
            os.path.basename(fnm))[0] if isinstance(
            fnm, type("")) else ""
        secs, tbls = self.doc(fnm)
        self.lines = [l for l, s in secs]
        self.styles = [s for l, s in secs]

        txt_cks = self.text_chunks("")
        flds.text_chunks = [(t, None) for t in txt_cks if not self._garbage(t)]
        flds.table_chunks = [(tb, None) for tb in tbls for t in tb if t]
        return flds


class ExcelChunker(HuChunker):

    @dataclass
    class Fields:
        text_chunks: List = None
        table_chunks: List = None

    def __init__(self, excel_parser):
        self.excel = excel_parser
        super().__init__()

    def __call__(self, fnm):
        flds = self.Fields()
        flds.text_chunks = [(t, None) for t in self.excel(fnm)]
        flds.table_chunks = []
        return flds


class PptChunker(HuChunker):

    @dataclass
    class Fields:
        text_chunks: List = None
        table_chunks: List = None

    def __init__(self):
        super().__init__()

    def __call__(self, fnm):
        from pptx import Presentation
        ppt = Presentation(fnm) if isinstance(
            fnm, str) else Presentation(
            BytesIO(fnm))
        flds = self.Fields()
        flds.text_chunks = []
        for slide in ppt.slides:
            for shape in slide.shapes:
                if hasattr(shape, "text"):
                    flds.text_chunks.append((shape.text, None))
        flds.table_chunks = []
        return flds


class TextChunker(HuChunker):

    @dataclass
    class Fields:
        text_chunks: List = None
        table_chunks: List = None

    def __init__(self):
        super().__init__()

    @staticmethod
    def is_binary_file(file_path):
        mime = magic.Magic(mime=True)
        if isinstance(file_path, str):
            file_type = mime.from_file(file_path)
        else:
            file_type = mime.from_buffer(file_path)
        if 'text' in file_type:
            return False
        else:
            return True

    def __call__(self, fnm):
        flds = self.Fields()
        if self.is_binary_file(fnm):
            return flds
        with open(fnm, "r") as f:
            txt = f.read()
        flds.text_chunks = [(c, None) for c in self.naive_text_chunk(txt)]
        flds.table_chunks = []
        return flds


if __name__ == "__main__":
    import sys
    sys.path.append(os.path.dirname(__file__) + "/../")
    if sys.argv[1].split(".")[-1].lower() == "pdf":
        from parser import PdfParser
        ckr = PdfChunker(PdfParser())
    if sys.argv[1].split(".")[-1].lower().find("doc") >= 0:
        from parser import DocxParser
        ckr = DocxChunker(DocxParser())
    if sys.argv[1].split(".")[-1].lower().find("xlsx") >= 0:
        from parser import ExcelParser
        ckr = ExcelChunker(ExcelParser())

    # ckr.html(sys.argv[1])
    print(ckr(sys.argv[1]))

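A minimal usage sketch (not part of the commit) for TextChunker above; it assumes python-magic can inspect the file and "notes.txt" is a hypothetical path.

```python
# Minimal sketch: naive paragraph chunking of a plain-text file.
ck = TextChunker()
fields = ck("notes.txt")  # hypothetical path
for chunk, _ in fields.text_chunks:
    print(len(chunk), chunk[:40].replace("\n", " "))
```
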
@@ -1,411 +0,0 @@
# -*- coding: utf-8 -*-

import copy
import datrie
import math
import os
import re
import string
import sys
from hanziconv import HanziConv


class Huqie:
    def key_(self, line):
        return str(line.lower().encode("utf-8"))[2:-1]

    def rkey_(self, line):
        return str(("DD" + (line[::-1].lower())).encode("utf-8"))[2:-1]

    def loadDict_(self, fnm):
        print("[HUQIE]:Build trie", fnm, file=sys.stderr)
        try:
            of = open(fnm, "r")
            while True:
                line = of.readline()
                if not line:
                    break
                line = re.sub(r"[\r\n]+", "", line)
                line = re.split(r"[ \t]", line)
                k = self.key_(line[0])
                F = int(math.log(float(line[1]) / self.DENOMINATOR) + .5)
                if k not in self.trie_ or self.trie_[k][0] < F:
                    self.trie_[self.key_(line[0])] = (F, line[2])
                self.trie_[self.rkey_(line[0])] = 1
            self.trie_.save(fnm + ".trie")
            of.close()
        except Exception as e:
            print("[HUQIE]:Failed to build trie, ", fnm, e, file=sys.stderr)

    def __init__(self, debug=False):
        self.DEBUG = debug
        self.DENOMINATOR = 1000000
        self.trie_ = datrie.Trie(string.printable)
        self.DIR_ = ""
        if os.path.exists("../res/huqie.txt"):
            self.DIR_ = "../res/huqie"
        if os.path.exists("./res/huqie.txt"):
            self.DIR_ = "./res/huqie"
        if os.path.exists("./huqie.txt"):
            self.DIR_ = "./huqie"
        assert self.DIR_, "【Can't find huqie】"

        self.SPLIT_CHAR = r"([ ,\.<>/?;'\[\]\\`!@#$%^&*\(\)\{\}\|_+=《》,。?、;‘’:“”【】~!¥%……()——-]+|[a-z\.-]+|[0-9,\.-]+)"
        try:
            self.trie_ = datrie.Trie.load(self.DIR_ + ".txt.trie")
            return
        except Exception as e:
            print("[HUQIE]:Build default trie", file=sys.stderr)
            self.trie_ = datrie.Trie(string.printable)

        self.loadDict_(self.DIR_ + ".txt")

    def loadUserDict(self, fnm):
        try:
            self.trie_ = datrie.Trie.load(fnm + ".trie")
            return
        except Exception as e:
            self.trie_ = datrie.Trie(string.printable)
        self.loadDict_(fnm)

    def addUserDict(self, fnm):
        self.loadDict_(fnm)

    def _strQ2B(self, ustring):
        """Convert full-width characters in a string to half-width."""
        rstring = ""
        for uchar in ustring:
            inside_code = ord(uchar)
            if inside_code == 0x3000:
                inside_code = 0x0020
            else:
                inside_code -= 0xfee0
            if inside_code < 0x0020 or inside_code > 0x7e:  # not a half-width char after conversion; keep the original
                rstring += uchar
            else:
                rstring += chr(inside_code)
        return rstring

    def _tradi2simp(self, line):
        return HanziConv.toSimplified(line)

    def dfs_(self, chars, s, preTks, tkslist):
        MAX_L = 10
        res = s
        # if s > MAX_L or s >= len(chars):
        if s >= len(chars):
            tkslist.append(preTks)
            return res

        # pruning
        S = s + 1
        if s + 2 <= len(chars):
            t1, t2 = "".join(chars[s:s + 1]), "".join(chars[s:s + 2])
            if self.trie_.has_keys_with_prefix(self.key_(t1)) and not self.trie_.has_keys_with_prefix(
                    self.key_(t2)):
                S = s + 2
        if len(preTks) > 2 and len(
                preTks[-1][0]) == 1 and len(preTks[-2][0]) == 1 and len(preTks[-3][0]) == 1:
            t1 = preTks[-1][0] + "".join(chars[s:s + 1])
            if self.trie_.has_keys_with_prefix(self.key_(t1)):
                S = s + 2

        ################
        for e in range(S, len(chars) + 1):
            t = "".join(chars[s:e])
            k = self.key_(t)

            if e > s + 1 and not self.trie_.has_keys_with_prefix(k):
                break

            if k in self.trie_:
                pretks = copy.deepcopy(preTks)
                if k in self.trie_:
                    pretks.append((t, self.trie_[k]))
                else:
                    pretks.append((t, (-12, '')))
                res = max(res, self.dfs_(chars, e, pretks, tkslist))

        if res > s:
            return res

        t = "".join(chars[s:s + 1])
        k = self.key_(t)
        if k in self.trie_:
            preTks.append((t, self.trie_[k]))
        else:
            preTks.append((t, (-12, '')))

        return self.dfs_(chars, s + 1, preTks, tkslist)

    def freq(self, tk):
        k = self.key_(tk)
        if k not in self.trie_:
            return 0
        return int(math.exp(self.trie_[k][0]) * self.DENOMINATOR + 0.5)

    def tag(self, tk):
        k = self.key_(tk)
        if k not in self.trie_:
            return ""
        return self.trie_[k][1]

    def score_(self, tfts):
        B = 30
        F, L, tks = 0, 0, []
        for tk, (freq, tag) in tfts:
            F += freq
            L += 0 if len(tk) < 2 else 1
            tks.append(tk)
        F /= len(tks)
        L /= len(tks)
        if self.DEBUG:
            print("[SC]", tks, len(tks), L, F, B / len(tks) + L + F)
        return tks, B / len(tks) + L + F

    def sortTks_(self, tkslist):
        res = []
        for tfts in tkslist:
            tks, s = self.score_(tfts)
            res.append((tks, s))
        return sorted(res, key=lambda x: x[1], reverse=True)

    def merge_(self, tks):
        patts = [
            (r"[ ]+", " "),
            (r"([0-9\+\.,%\*=-]) ([0-9\+\.,%\*=-])", r"\1\2"),
        ]
        # for p, s in patts: tks = re.sub(p, s, tks)

        # if split chars is part of token
        res = []
        tks = re.sub(r"[ ]+", " ", tks).split(" ")
        s = 0
        while True:
            if s >= len(tks):
                break
            E = s + 1
            for e in range(s + 2, min(len(tks) + 2, s + 6)):
                tk = "".join(tks[s:e])
                if re.search(self.SPLIT_CHAR, tk) and self.freq(tk):
                    E = e
            res.append("".join(tks[s:E]))
            s = E

        return " ".join(res)

    def maxForward_(self, line):
        res = []
        s = 0
        while s < len(line):
            e = s + 1
            t = line[s:e]
            while e < len(line) and self.trie_.has_keys_with_prefix(
                    self.key_(t)):
                e += 1
                t = line[s:e]

            while e - 1 > s and self.key_(t) not in self.trie_:
                e -= 1
                t = line[s:e]

            if self.key_(t) in self.trie_:
                res.append((t, self.trie_[self.key_(t)]))
            else:
                res.append((t, (0, '')))

            s = e

        return self.score_(res)

    def maxBackward_(self, line):
        res = []
        s = len(line) - 1
        while s >= 0:
            e = s + 1
            t = line[s:e]
            while s > 0 and self.trie_.has_keys_with_prefix(self.rkey_(t)):
                s -= 1
                t = line[s:e]

            while s + 1 < e and self.key_(t) not in self.trie_:
                s += 1
                t = line[s:e]

            if self.key_(t) in self.trie_:
                res.append((t, self.trie_[self.key_(t)]))
            else:
                res.append((t, (0, '')))

            s -= 1

        return self.score_(res[::-1])

    def qie(self, line):
        line = self._strQ2B(line).lower()
        line = self._tradi2simp(line)
        arr = re.split(self.SPLIT_CHAR, line)
        res = []
        for L in arr:
            if len(L) < 2 or re.match(
                    r"[a-z\.-]+$", L) or re.match(r"[0-9\.-]+$", L):
                res.append(L)
                continue
            # print(L)

            # use maxforward for the first time
            tks, s = self.maxForward_(L)
            tks1, s1 = self.maxBackward_(L)
            if self.DEBUG:
                print("[FW]", tks, s)
                print("[BW]", tks1, s1)

            diff = [0 for _ in range(max(len(tks1), len(tks)))]
            for i in range(min(len(tks1), len(tks))):
                if tks[i] != tks1[i]:
                    diff[i] = 1

            if s1 > s:
                tks = tks1

            i = 0
            while i < len(tks):
                s = i
                while s < len(tks) and diff[s] == 0:
                    s += 1
                if s == len(tks):
                    res.append(" ".join(tks[i:]))
                    break
                if s > i:
                    res.append(" ".join(tks[i:s]))

                e = s
                while e < len(tks) and e - s < 5 and diff[e] == 1:
                    e += 1

                tkslist = []
                self.dfs_("".join(tks[s:e + 1]), 0, [], tkslist)
                res.append(" ".join(self.sortTks_(tkslist)[0][0]))

                i = e + 1

        res = " ".join(res)
        if self.DEBUG:
            print("[TKS]", self.merge_(res))
        return self.merge_(res)

    def qieqie(self, tks):
        res = []
        for tk in tks.split(" "):
            if len(tk) < 3 or re.match(r"[0-9,\.-]+$", tk):
                res.append(tk)
                continue
            tkslist = []
            if len(tk) > 10:
                tkslist.append(tk)
            else:
                self.dfs_(tk, 0, [], tkslist)
            if len(tkslist) < 2:
                res.append(tk)
                continue
            stk = self.sortTks_(tkslist)[1][0]
            if len(stk) == len(tk):
                stk = tk
            else:
                if re.match(r"[a-z\.-]+$", tk):
                    for t in stk:
                        if len(t) < 3:
                            stk = tk
                            break
                    else:
                        stk = " ".join(stk)
                else:
                    stk = " ".join(stk)

            res.append(stk)

        return " ".join(res)


def is_chinese(s):
    if s >= u'\u4e00' and s <= u'\u9fa5':
        return True
    else:
        return False


def is_number(s):
    if s >= u'\u0030' and s <= u'\u0039':
        return True
    else:
        return False


def is_alphabet(s):
    if (s >= u'\u0041' and s <= u'\u005a') or (
            s >= u'\u0061' and s <= u'\u007a'):
        return True
    else:
        return False


def naiveQie(txt):
    tks = []
    for t in txt.split(" "):
        if tks and re.match(r".*[a-zA-Z]$", tks[-1]
                            ) and re.match(r".*[a-zA-Z]$", t):
            tks.append(" ")
        tks.append(t)
    return tks


hq = Huqie()
qie = hq.qie
qieqie = hq.qieqie
tag = hq.tag
freq = hq.freq
loadUserDict = hq.loadUserDict
addUserDict = hq.addUserDict
tradi2simp = hq._tradi2simp
strQ2B = hq._strQ2B

if __name__ == '__main__':
    huqie = Huqie(debug=True)
    # huqie.addUserDict("/tmp/tmp.new.tks.dict")
    tks = huqie.qie(
        "哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈")
    print(huqie.qieqie(tks))
    tks = huqie.qie(
        "公开征求意见稿提出,境外投资者可使用自有人民币或外汇投资。使用外汇投资的,可通过债券持有人在香港人民币业务清算行及香港地区经批准可进入境内银行间外汇市场进行交易的境外人民币业务参加行(以下统称香港结算行)办理外汇资金兑换。香港结算行由此所产生的头寸可到境内银行间外汇市场平盘。使用外汇投资的,在其投资的债券到期或卖出后,原则上应兑换回外汇。")
    print(huqie.qieqie(tks))
    tks = huqie.qie(
        "多校划片就是一个小区对应多个小学初中,让买了学区房的家庭也不确定到底能上哪个学校。目的是通过这种方式为学区房降温,把就近入学落到实处。南京市长江大桥")
    print(huqie.qieqie(tks))
    tks = huqie.qie(
        "实际上当时他们已经将业务中心偏移到安全部门和针对政府企业的部门 Scripts are compiled and cached aaaaaaaaa")
    print(huqie.qieqie(tks))
    tks = huqie.qie("虽然我不怎么玩")
    print(huqie.qieqie(tks))
    tks = huqie.qie("蓝月亮如何在外资夹击中生存,那是全宇宙最有意思的")
    print(huqie.qieqie(tks))
    tks = huqie.qie(
        "涡轮增压发动机num最大功率,不像别的共享买车锁电子化的手段,我们接过来是否有意义,黄黄爱美食,不过,今天阿奇要讲到的这家农贸市场,说实话,还真蛮有特色的!不仅环境好,还打出了")
    print(huqie.qieqie(tks))
    tks = huqie.qie("这周日你去吗?这周日你有空吗?")
    print(huqie.qieqie(tks))
    tks = huqie.qie("Unity3D开发经验 测试开发工程师 c++双11双11 985 211 ")
    print(huqie.qieqie(tks))
    tks = huqie.qie(
        "数据分析项目经理|数据分析挖掘|数据分析方向|商品数据分析|搜索数据分析 sql python hive tableau Cocos2d-")
    print(huqie.qieqie(tks))
    if len(sys.argv) < 2:
        sys.exit()
    huqie.DEBUG = False
    huqie.loadUserDict(sys.argv[1])
    of = open(sys.argv[2], "r")
    while True:
        line = of.readline()
        if not line:
            break
        print(huqie.qie(line))
    of.close()

@@ -1,167 +0,0 @@
import json
import re
import sys
import os
import logging
import copy
import math
from elasticsearch_dsl import Q, Search
from nlp import huqie, term_weight, synonym


class EsQueryer:
    def __init__(self, es):
        self.tw = term_weight.Dealer()
        self.es = es
        self.syn = synonym.Dealer(None)
        self.flds = ["ask_tks^10", "ask_small_tks"]

    @staticmethod
    def subSpecialChar(line):
        return re.sub(r"([:\{\}/\[\]\-\*\"\(\)\|~\^])", r"\\\1", line).strip()

    @staticmethod
    def isChinese(line):
        arr = re.split(r"[ \t]+", line)
        if len(arr) <= 3:
            return True
        e = 0
        for t in arr:
            if not re.match(r"[a-zA-Z]+$", t):
                e += 1
        return e * 1. / len(arr) >= 0.8

    @staticmethod
    def rmWWW(txt):
        txt = re.sub(
            r"是*(什么样的|哪家|那家|啥样|咋样了|什么时候|何时|何地|何人|是否|是不是|多少|哪里|怎么|哪儿|怎么样|如何|哪些|是啥|啥是|啊|吗|呢|吧|咋|什么|有没有|呀)是*",
            "",
            txt)
        return re.sub(
            r"(what|who|how|which|where|why|(is|are|were|was) there) (is|are|were|was)*", "", txt,
            flags=re.IGNORECASE)

    def question(self, txt, tbl="qa", min_match="60%"):
        txt = re.sub(
            r"[ \t,,。??/`!!&]+",
            " ",
            huqie.tradi2simp(
                huqie.strQ2B(
                    txt.lower()))).strip()
        txt = EsQueryer.rmWWW(txt)

        if not self.isChinese(txt):
            tks = txt.split(" ")
            q = []
            for i in range(1, len(tks)):
                q.append("\"%s %s\"~2" % (tks[i - 1], tks[i]))
            if not q:
                q.append(txt)
            return Q("bool",
                     must=Q("query_string", fields=self.flds,
                            type="best_fields", query=" OR ".join(q),
                            boost=1, minimum_should_match="60%")
                     ), txt.split(" ")

        def needQieqie(tk):
            if len(tk) < 4:
                return False
            if re.match(r"[0-9a-z\.\+#_\*-]+$", tk):
                return False
            return True

        qs, keywords = [], []
        for tt in self.tw.split(txt):  # .split(" "):
            if not tt:
                continue
            twts = self.tw.weights([tt])
            syns = self.syn.lookup(tt)
            logging.info(json.dumps(twts, ensure_ascii=False))
            tms = []
            for tk, w in sorted(twts, key=lambda x: x[1] * -1):
                sm = huqie.qieqie(tk).split(" ") if needQieqie(tk) else []
                sm = [
                    re.sub(
                        r"[ ,\./;'\[\]\\`~!@#$%\^&\*\(\)=\+_<>\?:\"\{\}\|,。;‘’【】、!¥……()——《》?:“”-]+",
                        "",
                        m) for m in sm]
                sm = [EsQueryer.subSpecialChar(m) for m in sm if len(m) > 1]
                sm = [m for m in sm if len(m) > 1]
                if len(sm) < 2:
                    sm = []

                keywords.append(re.sub(r"[ \\\"']+", "", tk))

                tk_syns = self.syn.lookup(tk)
                tk = EsQueryer.subSpecialChar(tk)
                if tk.find(" ") > 0:
                    tk = "\"%s\"" % tk
                if tk_syns:
                    tk = f"({tk} %s)" % " ".join(tk_syns)
                if sm:
                    tk = f"{tk} OR \"%s\" OR (\"%s\"~2)^0.5" % (
                        " ".join(sm), " ".join(sm))
                tms.append((tk, w))

            tms = " ".join([f"({t})^{w}" for t, w in tms])

            if len(twts) > 1:
                tms += f" (\"%s\"~4)^1.5" % (" ".join([t for t, _ in twts]))
            if re.match(r"[0-9a-z ]+$", tt):
                tms = f"(\"{tt}\" OR \"%s\")" % huqie.qie(tt)

            syns = " OR ".join(
                ["\"%s\"^0.7" % EsQueryer.subSpecialChar(huqie.qie(s)) for s in syns])
            if syns:
                tms = f"({tms})^5 OR ({syns})^0.7"

            qs.append(tms)

        flds = copy.deepcopy(self.flds)
        mst = []
        if qs:
            mst.append(
                Q("query_string", fields=flds, type="best_fields",
                  query=" OR ".join([f"({t})" for t in qs if t]), boost=1, minimum_should_match=min_match)
            )

        return Q("bool",
                 must=mst,
                 ), keywords

    def hybrid_similarity(self, avec, bvecs, atks, btkss, tkweight=0.3,
                          vtweight=0.7):
        from sklearn.metrics.pairwise import cosine_similarity as CosineSimilarity
        import numpy as np
        sims = CosineSimilarity([avec], bvecs)

        def toDict(tks):
            d = {}
            if isinstance(tks, type("")):
                tks = tks.split(" ")
            for t, c in self.tw.weights(tks):
                if t not in d:
                    d[t] = 0
                d[t] += c
            return d

        atks = toDict(atks)
        btkss = [toDict(tks) for tks in btkss]
        tksim = [self.similarity(atks, btks) for btks in btkss]
        return np.array(sims[0]) * vtweight + np.array(tksim) * tkweight

    def similarity(self, qtwt, dtwt):
        if isinstance(dtwt, type("")):
            dtwt = {t: w for t, w in self.tw.weights(self.tw.split(dtwt))}
        if isinstance(qtwt, type("")):
            qtwt = {t: w for t, w in self.tw.weights(self.tw.split(qtwt))}
        s = 1e-9
        for k, v in qtwt.items():
            if k in dtwt:
                s += v * dtwt[k]
        q = 1e-9
        for k, v in qtwt.items():
            q += v * v
        d = 1e-9
        for k, v in dtwt.items():
            d += v * v
        return s / math.sqrt(q) / math.sqrt(d)

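A minimal worked example (not part of the commit) of the hybrid score computed by hybrid_similarity above: the final score is vtweight times the embedding cosine plus tkweight times the weighted term-overlap cosine. The numbers below are illustrative only.

```python
# final = vtweight * cosine(query_vec, chunk_vec) + tkweight * term_similarity
vec_sim = 0.82           # cosine similarity of the embeddings (toy value)
tok_sim = 0.40           # weighted term overlap from EsQueryer.similarity (toy value)
tkweight, vtweight = 0.3, 0.7
print(vtweight * vec_sim + tkweight * tok_sim)  # 0.694
```
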
@@ -1,252 +0,0 @@
import re
from elasticsearch_dsl import Q, Search, A
from typing import List, Optional, Tuple, Dict, Union
from dataclasses import dataclass
from util import setup_logging, rmSpace
from nlp import huqie, query
from datetime import datetime
from sklearn.metrics.pairwise import cosine_similarity as CosineSimilarity
import numpy as np
from copy import deepcopy


def index_name(uid): return f"docgpt_{uid}"


class Dealer:
    def __init__(self, es, emb_mdl):
        self.qryr = query.EsQueryer(es)
        self.qryr.flds = [
            "title_tks^10",
            "title_sm_tks^5",
            "content_ltks^2",
            "content_sm_ltks"]
        self.es = es
        self.emb_mdl = emb_mdl

    @dataclass
    class SearchResult:
        total: int
        ids: List[str]
        query_vector: List[float] = None
        field: Optional[Dict] = None
        highlight: Optional[Dict] = None
        aggregation: Union[List, Dict, None] = None
        keywords: Optional[List[str]] = None
        group_docs: List[List] = None

    def _vector(self, txt, sim=0.8, topk=10):
        return {
            "field": "q_vec",
            "k": topk,
            "similarity": sim,
            "num_candidates": 1000,
            "query_vector": self.emb_mdl.encode_queries(txt)
        }

    def search(self, req, idxnm, tks_num=3):
        keywords = []
        qst = req.get("question", "")

        bqry, keywords = self.qryr.question(qst)
        if req.get("kb_ids"):
            bqry.filter.append(Q("terms", kb_id=req["kb_ids"]))
        bqry.filter.append(Q("exists", field="q_tks"))
        bqry.boost = 0.05
        print(bqry)

        s = Search()
        pg = int(req.get("page", 1)) - 1
        ps = int(req.get("size", 1000))
        src = req.get("field", ["docnm_kwd", "content_ltks", "kb_id",
                                "image_id", "doc_id", "q_vec"])

        s = s.query(bqry)[pg * ps:(pg + 1) * ps]
        s = s.highlight("content_ltks")
        s = s.highlight("title_ltks")
        if not qst:
            s = s.sort(
                {"create_time": {"order": "desc", "unmapped_type": "date"}})

        s = s.highlight_options(
            fragment_size=120,
            number_of_fragments=5,
            boundary_scanner_locale="zh-CN",
            boundary_scanner="SENTENCE",
            boundary_chars=",./;:\\!(),。?:!……()——、"
        )
        s = s.to_dict()
        q_vec = []
        if req.get("vector"):
            s["knn"] = self._vector(qst, req.get("similarity", 0.4), ps)
            s["knn"]["filter"] = bqry.to_dict()
            del s["highlight"]
            q_vec = s["knn"]["query_vector"]
        res = self.es.search(s, idxnm=idxnm, timeout="600s", src=src)
        print("TOTAL: ", self.es.getTotal(res))
        if self.es.getTotal(res) == 0 and "knn" in s:
            bqry, _ = self.qryr.question(qst, min_match="10%")
            if req.get("kb_ids"):
                bqry.filter.append(Q("terms", kb_id=req["kb_ids"]))
            s["query"] = bqry.to_dict()
            s["knn"]["filter"] = bqry.to_dict()
            s["knn"]["similarity"] = 0.7
            res = self.es.search(s, idxnm=idxnm, timeout="600s", src=src)

        kwds = set([])
        for k in keywords:
            kwds.add(k)
            for kk in huqie.qieqie(k).split(" "):
                if len(kk) < 2:
                    continue
                if kk in kwds:
                    continue
                kwds.add(kk)

        aggs = self.getAggregation(res, "docnm_kwd")

        return self.SearchResult(
            total=self.es.getTotal(res),
            ids=self.es.getDocIds(res),
            query_vector=q_vec,
            aggregation=aggs,
            highlight=self.getHighlight(res),
            field=self.getFields(res, ["docnm_kwd", "content_ltks",
                                       "kb_id", "image_id", "doc_id", "q_vec"]),
            keywords=list(kwds)
        )

    def getAggregation(self, res, g):
        if not "aggregations" in res or "aggs_" + g not in res["aggregations"]:
            return
        bkts = res["aggregations"]["aggs_" + g]["buckets"]
        return [(b["key"], b["doc_count"]) for b in bkts]

    def getHighlight(self, res):
        def rmspace(line):
            eng = set(list("qwertyuioplkjhgfdsazxcvbnm"))
            r = []
            for t in line.split(" "):
                if not t:
                    continue
                if len(r) > 0 and len(
                        t) > 0 and r[-1][-1] in eng and t[0] in eng:
                    r.append(" ")
                r.append(t)
            r = "".join(r)
            return r

        ans = {}
        for d in res["hits"]["hits"]:
            hlts = d.get("highlight")
            if not hlts:
                continue
            ans[d["_id"]] = "".join([a for a in list(hlts.items())[0][1]])
        return ans

    def getFields(self, sres, flds):
        res = {}
        if not flds:
            return {}
        for d in self.es.getSource(sres):
            m = {n: d.get(n) for n in flds if d.get(n) is not None}
            for n, v in m.items():
                if isinstance(v, type([])):
                    m[n] = "\t".join([str(vv) for vv in v])
                    continue
                if not isinstance(v, type("")):
                    m[n] = str(m[n])
                m[n] = rmSpace(m[n])

            if m:
                res[d["id"]] = m
        return res

    @staticmethod
    def trans2floats(txt):
        return [float(t) for t in txt.split("\t")]

    def insert_citations(self, ans, top_idx, sres,
                         vfield="q_vec", cfield="content_ltks"):

        ins_embd = [Dealer.trans2floats(
            sres.field[sres.ids[i]][vfield]) for i in top_idx]
        ins_tw = [sres.field[sres.ids[i]][cfield].split(" ") for i in top_idx]
        s = 0
        e = 0
        res = ""

        def citeit():
            nonlocal s, e, ans, res
            if not ins_embd:
                return
            embd = self.emb_mdl.encode(ans[s: e])
            sim = self.qryr.hybrid_similarity(embd,
                                              ins_embd,
                                              huqie.qie(ans[s:e]).split(" "),
                                              ins_tw)
            print(ans[s: e], sim)
            mx = np.max(sim) * 0.99
            if mx < 0.55:
                return
            cita = list(set([top_idx[i]
                             for i in range(len(ins_embd)) if sim[i] > mx]))[:4]
            for i in cita:
                res += f"@?{i}?@"

            return cita

        punct = set(";。?!!")
        if not self.qryr.isChinese(ans):
            punct.add("?")
            punct.add(".")
        while e < len(ans):
            if e - s < 12 or ans[e] not in punct:
                e += 1
                continue
            if ans[e] == "." and e + \
                    1 < len(ans) and re.match(r"[0-9]", ans[e + 1]):
                e += 1
                continue
            if ans[e] == "." and e - 2 >= 0 and ans[e - 2] == "\n":
                e += 1
                continue
            res += ans[s: e]
            citeit()
            res += ans[e]
            e += 1
            s = e

        if s < len(ans):
            res += ans[s:]
            citeit()

        return res

    def rerank(self, sres, query, tkweight=0.3, vtweight=0.7,
               vfield="q_vec", cfield="content_ltks"):
        ins_embd = [
            Dealer.trans2floats(
                sres.field[i]["q_vec"]) for i in sres.ids]
        if not ins_embd:
            return []
        ins_tw = [sres.field[i][cfield].split(" ") for i in sres.ids]
        # return CosineSimilarity([sres.query_vector], ins_embd)[0]
        sim = self.qryr.hybrid_similarity(sres.query_vector,
                                          ins_embd,
                                          huqie.qie(query).split(" "),
                                          ins_tw, tkweight, vtweight)
        return sim


if __name__ == "__main__":
    from util import es_conn
    SE = Dealer(es_conn.HuEs("infiniflow"))
    qs = [
        "胡凯",
        ""
    ]
    for q in qs:
        print(">>>>>>>>>>>>>>>>>>>>", q)
        print(SE.search(
            {"question": q, "kb_ids": "64f072a75f3b97c865718c4a"}, "infiniflow_*"))

@@ -1,67 +0,0 @@
import json
import time
import logging
import re


class Dealer:
    def __init__(self, redis=None):

        self.lookup_num = 100000000
        self.load_tm = time.time() - 1000000
        self.dictionary = None
        try:
            self.dictionary = json.load(open("./synonym.json", 'r'))
        except Exception as e:
            pass
        try:
            self.dictionary = json.load(open("./res/synonym.json", 'r'))
        except Exception as e:
            try:
                self.dictionary = json.load(open("../res/synonym.json", 'r'))
            except Exception as e:
                logging.warning("Missing synonym.json")
                self.dictionary = {}

        if not redis:
            logging.warning(
                "Realtime synonym is disabled, since no redis connection.")
        if not len(self.dictionary.keys()):
            logging.warning("Failed to load synonym")

        self.redis = redis
        self.load()

    def load(self):
        if not self.redis:
            return

        if self.lookup_num < 100:
            return
        tm = time.time()
        if tm - self.load_tm < 3600:
            return

        self.load_tm = time.time()
        self.lookup_num = 0
        d = self.redis.get("kevin_synonyms")
        if not d:
            return
        try:
            d = json.loads(d)
            self.dictionary = d
        except Exception as e:
            logging.error("Failed to load synonym!" + str(e))

    def lookup(self, tk):
        self.lookup_num += 1
        self.load()
        res = self.dictionary.get(re.sub(r"[ \t]+", " ", tk.lower()), [])
        if isinstance(res, str):
            res = [res]
        return res


if __name__ == '__main__':
    dl = Dealer()
    print(dl.dictionary)

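A minimal usage sketch (not part of the commit) for the synonym Dealer above; without a Redis connection it only reads synonym.json from the local paths it probes, and the lookup term and result shown are hypothetical.

```python
# Minimal sketch: lookup() returns a list of synonyms, or [] when none are known.
syn = Dealer()               # no Redis: dictionary comes from synonym.json if found
print(syn.lookup("电脑"))    # e.g. ["计算机"] if that entry exists in the dictionary
```
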
@ -1,216 +0,0 @@
|
||||
import math
|
||||
import json
|
||||
import re
|
||||
import os
|
||||
import numpy as np
|
||||
from nlp import huqie
|
||||
|
||||
|
||||
class Dealer:
|
||||
def __init__(self):
|
||||
self.stop_words = set(["请问",
|
||||
"您",
|
||||
"你",
|
||||
"我",
|
||||
"他",
|
||||
"是",
|
||||
"的",
|
||||
"就",
|
||||
"有",
|
||||
"于",
|
||||
"及",
|
||||
"即",
|
||||
"在",
|
||||
"为",
|
||||
"最",
|
||||
"有",
|
||||
"从",
|
||||
"以",
|
||||
"了",
|
||||
"将",
|
||||
"与",
|
||||
"吗",
|
||||
"吧",
|
||||
"中",
|
||||
"#",
|
||||
"什么",
|
||||
"怎么",
|
||||
"哪个",
|
||||
"哪些",
|
||||
"啥",
|
||||
"相关"])
|
||||
|
||||
def load_dict(fnm):
|
||||
res = {}
|
||||
f = open(fnm, "r")
|
||||
while True:
|
||||
l = f.readline()
|
||||
if not l:
|
||||
break
|
||||
arr = l.replace("\n", "").split("\t")
|
||||
if len(arr) < 2:
|
||||
res[arr[0]] = 0
|
||||
else:
|
||||
res[arr[0]] = int(arr[1])
|
||||
|
||||
c = 0
|
||||
for _, v in res.items():
|
||||
c += v
|
||||
if c == 0:
|
||||
return set(res.keys())
|
||||
return res
|
||||
|
||||
fnm = os.path.join(os.path.dirname(__file__), '../res/')
|
||||
if not os.path.exists(fnm):
|
||||
fnm = os.path.join(os.path.dirname(__file__), '../../res/')
|
||||
self.ne, self.df = {}, {}
|
||||
try:
|
||||
self.ne = json.load(open(fnm + "ner.json", "r"))
|
||||
except Exception as e:
|
||||
print("[WARNING] Load ner.json FAIL!")
|
||||
try:
|
||||
self.df = load_dict(fnm + "term.freq")
|
||||
except Exception as e:
|
||||
print("[WARNING] Load term.freq FAIL!")
|
||||
|
||||
def pretoken(self, txt, num=False, stpwd=True):
|
||||
patt = [
|
||||
r"[~—\t @#%!<>,\.\?\":;'\{\}\[\]_=\(\)\|,。?》•●○↓《;‘’:“”【¥ 】…¥!、·()×`&\\/「」\\]"
|
||||
]
|
||||
rewt = [
|
||||
]
|
||||
for p, r in rewt:
|
||||
txt = re.sub(p, r, txt)
|
||||
|
||||
res = []
|
||||
for t in huqie.qie(txt).split(" "):
|
||||
tk = t
|
||||
if (stpwd and tk in self.stop_words) or (
|
||||
re.match(r"[0-9]$", tk) and not num):
|
||||
continue
|
||||
for p in patt:
|
||||
if re.match(p, t):
|
||||
tk = "#"
|
||||
break
|
||||
tk = re.sub(r"([\+\\-])", r"\\\1", tk)
|
||||
if tk != "#" and tk:
|
||||
res.append(tk)
|
||||
return res
|
||||
|
||||
def tokenMerge(self, tks):
|
||||
def oneTerm(t): return len(t) == 1 or re.match(r"[0-9a-z]{1,2}$", t)
|
||||
|
||||
res, i = [], 0
|
||||
while i < len(tks):
|
||||
j = i
|
||||
if i == 0 and oneTerm(tks[i]) and len(
|
||||
tks) > 1 and len(tks[i + 1]) > 1: # 多 工位
|
||||
res.append(" ".join(tks[0:2]))
|
||||
i = 2
|
||||
continue
|
||||
|
||||
while j < len(
|
||||
tks) and tks[j] and tks[j] not in self.stop_words and oneTerm(tks[j]):
|
||||
j += 1
|
||||
if j - i > 1:
|
||||
if j - i < 5:
|
||||
res.append(" ".join(tks[i:j]))
|
||||
i = j
|
||||
else:
|
||||
res.append(" ".join(tks[i:i + 2]))
|
||||
i = i + 2
|
||||
else:
|
||||
if len(tks[i]) > 0:
|
||||
res.append(tks[i])
|
||||
i += 1
|
||||
return [t for t in res if t]
|
||||
|
||||
def ner(self, t):
|
||||
if not self.ne:
|
||||
return ""
|
||||
res = self.ne.get(t, "")
|
||||
if res:
|
||||
return res
|
||||
|
||||
def split(self, txt):
|
||||
tks = []
|
||||
for t in re.sub(r"[ \t]+", " ", txt).split(" "):
|
||||
if tks and re.match(r".*[a-zA-Z]$", tks[-1]) and \
|
||||
re.match(r".*[a-zA-Z]$", t) and tks and \
|
||||
self.ne.get(t, "") != "func" and self.ne.get(tks[-1], "") != "func":
|
||||
tks[-1] = tks[-1] + " " + t
|
||||
else:
|
||||
tks.append(t)
|
||||
return tks
|
||||
|
||||
def weights(self, tks):
|
||||
def skill(t):
|
||||
if t not in self.sk:
|
||||
return 1
|
||||
return 6
|
||||
|
||||
def ner(t):
|
||||
if not self.ne or t not in self.ne:
|
||||
return 1
|
||||
m = {"toxic": 2, "func": 1, "corp": 3, "loca": 3, "sch": 3, "stock": 3,
|
||||
"firstnm": 1}
|
||||
return m[self.ne[t]]
|
||||
|
||||
def postag(t):
|
||||
t = huqie.tag(t)
|
||||
if t in set(["r", "c", "d"]):
|
||||
return 0.3
|
||||
if t in set(["ns", "nt"]):
|
||||
return 3
|
||||
if t in set(["n"]):
|
||||
return 2
|
||||
if re.match(r"[0-9-]+", t):
|
||||
return 2
|
||||
return 1
|
||||
|
||||
def freq(t):
|
||||
if re.match(r"[0-9\. -]+$", t):
|
||||
return 10000
|
||||
s = huqie.freq(t)
|
||||
if not s and re.match(r"[a-z\. -]+$", t):
|
||||
return 10
|
||||
if not s:
|
||||
s = 0
|
||||
|
||||
if not s and len(t) >= 4:
|
||||
s = [tt for tt in huqie.qieqie(t).split(" ") if len(tt) > 1]
|
||||
if len(s) > 1:
|
||||
s = np.min([freq(tt) for tt in s]) / 6.
|
||||
else:
|
||||
s = 0
|
||||
|
||||
return max(s, 10)
|
||||
|
||||
def df(t):
|
||||
if re.match(r"[0-9\. -]+$", t):
|
||||
return 100000
|
||||
if t in self.df:
|
||||
return self.df[t] + 3
|
||||
elif re.match(r"[a-z\. -]+$", t):
|
||||
return 3
|
||||
elif len(t) >= 4:
|
||||
s = [tt for tt in huqie.qieqie(t).split(" ") if len(tt) > 1]
|
||||
if len(s) > 1:
|
||||
return max(3, np.min([df(tt) for tt in s]) / 6.)
|
||||
|
||||
return 3
|
||||
|
||||
def idf(s, N): return math.log10(10 + ((N - s + 0.5) / (s + 0.5)))
|
||||
|
||||
tw = []
|
||||
for tk in tks:
|
||||
tt = self.tokenMerge(self.pretoken(tk, True))
|
||||
idf1 = np.array([idf(freq(t), 10000000) for t in tt])
|
||||
idf2 = np.array([idf(df(t), 1000000000) for t in tt])
|
||||
wts = (0.3 * idf1 + 0.7 * idf2) * \
|
||||
np.array([ner(t) * postag(t) for t in tt])
|
||||
|
||||
tw.extend(zip(tt, wts))
|
||||
|
||||
S = np.sum([s for _, s in tw])
|
||||
return [(t, s / S) for t, s in tw]
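weights() blends two smoothed IDF estimates, one from the corpus term-frequency table and one from the document-frequency table, scales them by NER/POS factors, and normalizes the result. A small numeric sketch of just the IDF blend, with made-up frequency counts (the 0.3/0.7 mix and the N values mirror the code above):

```python
import math
import numpy as np

def idf(s, N):
    # same smoothed IDF as above: log10(10 + (N - s + 0.5) / (s + 0.5))
    return math.log10(10 + ((N - s + 0.5) / (s + 0.5)))

terms     = ["机床", "的", "工位"]                 # illustrative tokens
term_freq = np.array([1200.0, 900000.0, 300.0])    # hypothetical corpus term frequencies
doc_freq  = np.array([800.0, 500000.0, 150.0])     # hypothetical document frequencies

idf1 = np.array([idf(f, 10000000) for f in term_freq])
idf2 = np.array([idf(f, 1000000000) for f in doc_freq])
wts = 0.3 * idf1 + 0.7 * idf2                      # the common token ends up with the lowest weight
wts = wts / wts.sum()

for t, w in zip(terms, wts):
    print(f"{t}\t{w:.3f}")
```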
python/output/ToPDF.pdf (new file)
@ -1,3 +0,0 @@
|
||||
from .pdf_parser import HuParser as PdfParser
|
||||
from .docx_parser import HuDocxParser as DocxParser
|
||||
from .excel_parser import HuExcelParser as ExcelParser
|
||||
@ -1,104 +0,0 @@
|
||||
from docx import Document
|
||||
import re
|
||||
import pandas as pd
|
||||
from collections import Counter
|
||||
from nlp import huqie
|
||||
from io import BytesIO
|
||||
|
||||
|
||||
class HuDocxParser:
|
||||
|
||||
def __extract_table_content(self, tb):
|
||||
df = []
|
||||
for row in tb.rows:
|
||||
df.append([c.text for c in row.cells])
|
||||
return self.__compose_table_content(pd.DataFrame(df))
|
||||
|
||||
def __compose_table_content(self, df):
|
||||
|
||||
def blockType(b):
|
||||
patt = [
|
||||
("^(20|19)[0-9]{2}[年/-][0-9]{1,2}[月/-][0-9]{1,2}日*$", "Dt"),
|
||||
(r"^(20|19)[0-9]{2}年$", "Dt"),
|
||||
(r"^(20|19)[0-9]{2}[年/-][0-9]{1,2}月*$", "Dt"),
|
||||
("^[0-9]{1,2}[月/-][0-9]{1,2}日*$", "Dt"),
|
||||
(r"^第*[一二三四1-4]季度$", "Dt"),
|
||||
(r"^(20|19)[0-9]{2}年*[一二三四1-4]季度$", "Dt"),
|
||||
(r"^(20|19)[0-9]{2}[ABCDE]$", "DT"),
|
||||
("^[0-9.,+%/ -]+$", "Nu"),
|
||||
(r"^[0-9A-Z/\._~-]+$", "Ca"),
|
||||
(r"^[A-Z]*[a-z' -]+$", "En"),
|
||||
(r"^[0-9.,+-]+[0-9A-Za-z/$¥%<>()()' -]+$", "NE"),
|
||||
(r"^.{1}$", "Sg")
|
||||
]
|
||||
for p, n in patt:
|
||||
if re.search(p, b):
|
||||
return n
|
||||
tks = [t for t in huqie.qie(b).split(" ") if len(t) > 1]
|
||||
if len(tks) > 3:
|
||||
if len(tks) < 12:
|
||||
return "Tx"
|
||||
else:
|
||||
return "Lx"
|
||||
|
||||
if len(tks) == 1 and huqie.tag(tks[0]) == "nr":
|
||||
return "Nr"
|
||||
|
||||
return "Ot"
|
||||
|
||||
if len(df) < 2:
|
||||
return []
|
||||
max_type = Counter([blockType(str(df.iloc[i, j])) for i in range(
|
||||
1, len(df)) for j in range(len(df.iloc[i, :]))])
|
||||
max_type = max(max_type.items(), key=lambda x: x[1])[0]
|
||||
|
||||
colnm = len(df.iloc[0, :])
|
||||
hdrows = [0]  # the header does not necessarily appear in the first row
|
||||
if max_type == "Nu":
|
||||
for r in range(1, len(df)):
|
||||
tys = Counter([blockType(str(df.iloc[r, j]))
|
||||
for j in range(len(df.iloc[r, :]))])
|
||||
tys = max(tys.items(), key=lambda x: x[1])[0]
|
||||
if tys != max_type:
|
||||
hdrows.append(r)
|
||||
|
||||
lines = []
|
||||
for i in range(1, len(df)):
|
||||
if i in hdrows:
|
||||
continue
|
||||
hr = [r - i for r in hdrows]
|
||||
hr = [r for r in hr if r < 0]
|
||||
t = len(hr) - 1
|
||||
while t > 0:
|
||||
if hr[t] - hr[t - 1] > 1:
|
||||
hr = hr[t:]
|
||||
break
|
||||
t -= 1
|
||||
headers = []
|
||||
for j in range(len(df.iloc[i, :])):
|
||||
t = []
|
||||
for h in hr:
|
||||
x = str(df.iloc[i + h, j]).strip()
|
||||
if x in t:
|
||||
continue
|
||||
t.append(x)
|
||||
t = ",".join(t)
|
||||
if t:
|
||||
t += ": "
|
||||
headers.append(t)
|
||||
cells = []
|
||||
for j in range(len(df.iloc[i, :])):
|
||||
if not str(df.iloc[i, j]):
|
||||
continue
|
||||
cells.append(headers[j] + str(df.iloc[i, j]))
|
||||
lines.append(";".join(cells))
|
||||
|
||||
if colnm > 3:
|
||||
return lines
|
||||
return ["\n".join(lines)]
|
||||
|
||||
def __call__(self, fnm):
|
||||
self.doc = Document(fnm) if isinstance(fnm, str) else Document(BytesIO(fnm))
|
||||
secs = [(p.text, p.style.name) for p in self.doc.paragraphs]
|
||||
tbls = [self.__extract_table_content(tb) for tb in self.doc.tables]
|
||||
return secs, tbls
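A hedged usage sketch for the docx parser: the import path assumes the python/parser package layout used elsewhere in this commit, and "sample.docx" is a placeholder file.

```python
from parser import DocxParser  # HuDocxParser, re-exported by parser/__init__.py

docx = DocxParser()
sections, tables = docx("sample.docx")      # a str path or raw bytes are both accepted
for text, style in sections[:5]:
    print(style, "=>", text[:60])
for rows in tables:
    for line in rows[:3]:
        print("table row:", line)
```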
|
||||
@ -1,25 +0,0 @@
|
||||
from openpyxl import load_workbook
|
||||
import sys
|
||||
from io import BytesIO
|
||||
|
||||
|
||||
class HuExcelParser:
|
||||
def __call__(self, fnm):
|
||||
if isinstance(fnm, str):
|
||||
wb = load_workbook(fnm)
|
||||
else:
|
||||
wb = load_workbook(BytesIO(fnm))
|
||||
res = []
|
||||
for sheetname in wb.sheetnames:
|
||||
ws = wb[sheetname]
|
||||
lines = []
|
||||
for r in ws.rows:
|
||||
lines.append(
|
||||
"\t".join([str(c.value) if c.value is not None else "" for c in r]))
|
||||
res.append(f"《{sheetname}》\n" + "\n".join(lines))
|
||||
return res
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
psr = HuExcelParser()
|
||||
psr(sys.argv[1])
|
||||
File diff suppressed because it is too large
@ -1,194 +0,0 @@
|
||||
accelerate==0.24.1
|
||||
addict==2.4.0
|
||||
aiobotocore==2.7.0
|
||||
aiofiles==23.2.1
|
||||
aiohttp==3.8.6
|
||||
aioitertools==0.11.0
|
||||
aiosignal==1.3.1
|
||||
aliyun-python-sdk-core==2.14.0
|
||||
aliyun-python-sdk-kms==2.16.2
|
||||
altair==5.1.2
|
||||
anyio==3.7.1
|
||||
astor==0.8.1
|
||||
async-timeout==4.0.3
|
||||
attrdict==2.0.1
|
||||
attrs==23.1.0
|
||||
Babel==2.13.1
|
||||
bce-python-sdk==0.8.92
|
||||
beautifulsoup4==4.12.2
|
||||
bitsandbytes==0.41.1
|
||||
blinker==1.7.0
|
||||
botocore==1.31.64
|
||||
cachetools==5.3.2
|
||||
certifi==2023.7.22
|
||||
cffi==1.16.0
|
||||
charset-normalizer==3.3.2
|
||||
click==8.1.7
|
||||
cloudpickle==3.0.0
|
||||
contourpy==1.2.0
|
||||
crcmod==1.7
|
||||
cryptography==41.0.5
|
||||
cssselect==1.2.0
|
||||
cssutils==2.9.0
|
||||
cycler==0.12.1
|
||||
Cython==3.0.5
|
||||
datasets==2.13.0
|
||||
datrie==0.8.2
|
||||
decorator==5.1.1
|
||||
defusedxml==0.7.1
|
||||
dill==0.3.6
|
||||
einops==0.7.0
|
||||
elastic-transport==8.10.0
|
||||
elasticsearch==8.10.1
|
||||
elasticsearch-dsl==8.9.0
|
||||
et-xmlfile==1.1.0
|
||||
fastapi==0.104.1
|
||||
ffmpy==0.3.1
|
||||
filelock==3.13.1
|
||||
fire==0.5.0
|
||||
FlagEmbedding==1.1.5
|
||||
Flask==3.0.0
|
||||
flask-babel==4.0.0
|
||||
fonttools==4.44.0
|
||||
frozenlist==1.4.0
|
||||
fsspec==2023.10.0
|
||||
future==0.18.3
|
||||
gast==0.5.4
|
||||
-e git+https://github.com/ggerganov/llama.cpp.git@5f6e0c0dff1e7a89331e6b25eca9a9fd71324069#egg=gguf&subdirectory=gguf-py
|
||||
gradio==3.50.2
|
||||
gradio_client==0.6.1
|
||||
greenlet==3.0.1
|
||||
h11==0.14.0
|
||||
hanziconv==0.3.2
|
||||
httpcore==1.0.1
|
||||
httpx==0.25.1
|
||||
huggingface-hub==0.17.3
|
||||
idna==3.4
|
||||
imageio==2.31.6
|
||||
imgaug==0.4.0
|
||||
importlib-metadata==6.8.0
|
||||
importlib-resources==6.1.0
|
||||
install==1.3.5
|
||||
itsdangerous==2.1.2
|
||||
Jinja2==3.1.2
|
||||
jmespath==0.10.0
|
||||
joblib==1.3.2
|
||||
jsonschema==4.19.2
|
||||
jsonschema-specifications==2023.7.1
|
||||
kiwisolver==1.4.5
|
||||
lazy_loader==0.3
|
||||
lmdb==1.4.1
|
||||
lxml==4.9.3
|
||||
MarkupSafe==2.1.3
|
||||
matplotlib==3.8.1
|
||||
modelscope==1.9.4
|
||||
mpmath==1.3.0
|
||||
multidict==6.0.4
|
||||
multiprocess==0.70.14
|
||||
networkx==3.2.1
|
||||
nltk==3.8.1
|
||||
numpy==1.24.4
|
||||
nvidia-cublas-cu12==12.1.3.1
|
||||
nvidia-cuda-cupti-cu12==12.1.105
|
||||
nvidia-cuda-nvrtc-cu12==12.1.105
|
||||
nvidia-cuda-runtime-cu12==12.1.105
|
||||
nvidia-cudnn-cu12==8.9.2.26
|
||||
nvidia-cufft-cu12==11.0.2.54
|
||||
nvidia-curand-cu12==10.3.2.106
|
||||
nvidia-cusolver-cu12==11.4.5.107
|
||||
nvidia-cusparse-cu12==12.1.0.106
|
||||
nvidia-nccl-cu12==2.18.1
|
||||
nvidia-nvjitlink-cu12==12.3.52
|
||||
nvidia-nvtx-cu12==12.1.105
|
||||
opencv-contrib-python==4.6.0.66
|
||||
opencv-python==4.6.0.66
|
||||
openpyxl==3.1.2
|
||||
opt-einsum==3.3.0
|
||||
orjson==3.9.10
|
||||
oss2==2.18.3
|
||||
packaging==23.2
|
||||
paddleocr==2.7.0.3
|
||||
paddlepaddle-gpu==2.5.2.post120
|
||||
pandas==2.1.2
|
||||
pdf2docx==0.5.5
|
||||
pdfminer.six==20221105
|
||||
pdfplumber==0.10.3
|
||||
Pillow==10.0.1
|
||||
platformdirs==3.11.0
|
||||
premailer==3.10.0
|
||||
protobuf==4.25.0
|
||||
psutil==5.9.6
|
||||
pyarrow==14.0.0
|
||||
pyclipper==1.3.0.post5
|
||||
pycocotools==2.0.7
|
||||
pycparser==2.21
|
||||
pycryptodome==3.19.0
|
||||
pydantic==1.10.13
|
||||
pydub==0.25.1
|
||||
PyMuPDF==1.20.2
|
||||
pyparsing==3.1.1
|
||||
pypdfium2==4.23.1
|
||||
python-dateutil==2.8.2
|
||||
python-docx==1.1.0
|
||||
python-multipart==0.0.6
|
||||
pytz==2023.3.post1
|
||||
PyYAML==6.0.1
|
||||
rapidfuzz==3.5.2
|
||||
rarfile==4.1
|
||||
referencing==0.30.2
|
||||
regex==2023.10.3
|
||||
requests==2.31.0
|
||||
rpds-py==0.12.0
|
||||
s3fs==2023.10.0
|
||||
safetensors==0.4.0
|
||||
scikit-image==0.22.0
|
||||
scikit-learn==1.3.2
|
||||
scipy==1.11.3
|
||||
semantic-version==2.10.0
|
||||
sentence-transformers==2.2.2
|
||||
sentencepiece==0.1.98
|
||||
shapely==2.0.2
|
||||
simplejson==3.19.2
|
||||
six==1.16.0
|
||||
sniffio==1.3.0
|
||||
sortedcontainers==2.4.0
|
||||
soupsieve==2.5
|
||||
SQLAlchemy==2.0.23
|
||||
starlette==0.27.0
|
||||
sympy==1.12
|
||||
tabulate==0.9.0
|
||||
tblib==3.0.0
|
||||
termcolor==2.3.0
|
||||
threadpoolctl==3.2.0
|
||||
tifffile==2023.9.26
|
||||
tiktoken==0.5.1
|
||||
timm==0.9.10
|
||||
tokenizers==0.13.3
|
||||
tomli==2.0.1
|
||||
toolz==0.12.0
|
||||
torch==2.1.0
|
||||
torchaudio==2.1.0
|
||||
torchvision==0.16.0
|
||||
tornado==6.3.3
|
||||
tqdm==4.66.1
|
||||
transformers==4.33.0
|
||||
transformers-stream-generator==0.0.4
|
||||
triton==2.1.0
|
||||
typing_extensions==4.8.0
|
||||
tzdata==2023.3
|
||||
urllib3==2.0.7
|
||||
uvicorn==0.24.0
|
||||
uvloop==0.19.0
|
||||
visualdl==2.5.3
|
||||
websockets==11.0.3
|
||||
Werkzeug==3.0.1
|
||||
wrapt==1.15.0
|
||||
xgboost==2.0.1
|
||||
xinference==0.6.0
|
||||
xorbits==0.7.0
|
||||
xoscar==0.1.3
|
||||
xxhash==3.4.1
|
||||
yapf==0.40.2
|
||||
yarl==1.9.2
|
||||
zipp==3.17.0
|
||||
python/res/1-0.tm (new file, 8 lines)
@ -0,0 +1,8 @@
|
||||
2023-12-20 11:44:08.791336+00:00
|
||||
2023-12-20 11:44:08.853249+00:00
|
||||
2023-12-20 11:44:08.909933+00:00
|
||||
2023-12-21 00:47:09.996757+00:00
|
||||
2023-12-20 11:44:08.965855+00:00
|
||||
2023-12-20 11:44:09.011682+00:00
|
||||
2023-12-21 00:47:10.063326+00:00
|
||||
2023-12-20 11:44:09.069486+00:00
|
||||
python/res/huqie.txt (555629 lines; diff suppressed because it is too large)
python/res/ner.json (12519 lines; diff suppressed because it is too large)
python/res/synonym.json (10539 lines; diff suppressed because it is too large)
python/res/thumbnail-1-0.tm (new file, 3 lines)
@ -0,0 +1,3 @@
|
||||
2023-12-27 08:21:49.309802+00:00
|
||||
2023-12-27 08:37:22.407772+00:00
|
||||
2023-12-27 08:59:18.845627+00:00
|
||||
@ -1,118 +0,0 @@
|
||||
import sys, datetime, random, re, cv2
|
||||
from os.path import dirname, realpath
|
||||
sys.path.append(dirname(realpath(__file__)) + "/../")
|
||||
from util.db_conn import Postgres
|
||||
from util.minio_conn import HuMinio
|
||||
from util import findMaxDt
|
||||
import base64
|
||||
from io import BytesIO
|
||||
import pandas as pd
|
||||
from PIL import Image
|
||||
import pdfplumber
|
||||
|
||||
|
||||
PG = Postgres("infiniflow", "docgpt")
|
||||
MINIO = HuMinio("infiniflow")
|
||||
def set_thumbnail(did, b64):
sql = f"""
update doc_info set thumbnail_base64='{b64}'
where
did={did}
"""
|
||||
PG.update(sql)
|
||||
|
||||
|
||||
def collect(comm, mod, tm):
|
||||
sql = f"""
|
||||
select
|
||||
did, uid, doc_name, location, updated_at
|
||||
from doc_info
|
||||
where
|
||||
updated_at >= '{tm}'
|
||||
and MOD(did, {comm}) = {mod}
|
||||
and is_deleted=false
|
||||
and type <> 'folder'
|
||||
and thumbnail_base64=''
|
||||
order by updated_at asc
|
||||
limit 10
|
||||
"""
|
||||
docs = PG.select(sql)
|
||||
if len(docs) == 0:return pd.DataFrame()
|
||||
|
||||
mtm = str(docs["updated_at"].max())[:19]
|
||||
print("TOTAL:", len(docs), "To: ", mtm)
|
||||
return docs
|
||||
|
||||
|
||||
def build(row):
|
||||
if not re.search(r"\.(pdf|jpg|jpeg|png|gif|svg|apng|icon|ico|webp|mpg|mpeg|avi|rm|rmvb|mov|wmv|mp4)$",
|
||||
row["doc_name"].lower().strip()):
|
||||
set_thumbnail(row["did"], "_")
|
||||
return
|
||||
|
||||
def thumbnail(img, SIZE=128):
|
||||
w,h = img.size
|
||||
p = SIZE/max(w, h)
|
||||
w, h = int(w*p), int(h*p)
|
||||
img.thumbnail((w, h))
|
||||
buffered = BytesIO()
|
||||
try:
|
||||
img.save(buffered, format="JPEG")
|
||||
except Exception as e:
|
||||
try:
|
||||
img.save(buffered, format="PNG")
|
||||
except Exception as ee:
|
||||
pass
|
||||
return base64.b64encode(buffered.getvalue()).decode("utf-8")
|
||||
|
||||
|
||||
iobytes = BytesIO(MINIO.get("%s-upload"%str(row["uid"]), row["location"]))
|
||||
if re.search(r"\.pdf$", row["doc_name"].lower().strip()):
|
||||
pdf = pdfplumber.open(iobytes)
|
||||
img = pdf.pages[0].to_image().annotated
|
||||
set_thumbnail(row["did"], thumbnail(img))
|
||||
|
||||
if re.search(r"\.(jpg|jpeg|png|gif|svg|apng|webp|icon|ico)$", row["doc_name"].lower().strip()):
|
||||
img = Image.open(iobytes)
|
||||
set_thumbnail(row["did"], thumbnail(img))
|
||||
|
||||
if re.search(r"\.(mpg|mpeg|avi|rm|rmvb|mov|wmv|mp4)$", row["doc_name"].lower().strip()):
|
||||
url = MINIO.get_presigned_url("%s-upload"%str(row["uid"]),
|
||||
row["location"],
|
||||
expires=datetime.timedelta(seconds=60)
|
||||
)
|
||||
cap = cv2.VideoCapture(url)
|
||||
succ = cap.isOpened()
|
||||
i = random.randint(1, 11)
|
||||
while succ:
|
||||
ret, frame = cap.read()
|
||||
if not ret: break
|
||||
if i > 0:
|
||||
i -= 1
|
||||
continue
|
||||
img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
|
||||
print(img.size)
|
||||
set_thumbnail(row["did"], thumbnail(img))
|
||||
cap.release()
|
||||
cv2.destroyAllWindows()
|
||||
|
||||
|
||||
def main(comm, mod):
|
||||
global model
|
||||
tm_fnm = f"res/thumbnail-{comm}-{mod}.tm"
|
||||
tm = findMaxDt(tm_fnm)
|
||||
rows = collect(comm, mod, tm)
|
||||
if len(rows) == 0:return
|
||||
|
||||
tmf = open(tm_fnm, "a+")
|
||||
for _, r in rows.iterrows():
|
||||
build(r)
|
||||
tmf.write(str(r["updated_at"]) + "\n")
|
||||
tmf.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from mpi4py import MPI
|
||||
comm = MPI.COMM_WORLD
|
||||
main(comm.Get_size(), comm.Get_rank())
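The worker above shards documents across MPI ranks and stores a base64 JPEG thumbnail back into Postgres. A standalone sketch of just the resize-and-encode step, with no database or MinIO involved ("photo.jpg" is a placeholder path; the 128-pixel bound mirrors SIZE=128 above):

```python
import base64
from io import BytesIO
from PIL import Image

def thumbnail_b64(path, size=128):
    img = Image.open(path)
    w, h = img.size
    p = size / max(w, h)
    img.thumbnail((int(w * p), int(h * p)))
    buf = BytesIO()
    img.convert("RGB").save(buf, format="JPEG")  # convert() sidesteps RGBA-to-JPEG errors
    return base64.b64encode(buf.getvalue()).decode("utf-8")

if __name__ == "__main__":
    print(len(thumbnail_b64("photo.jpg")))
```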
|
||||
|
||||
@ -1,165 +0,0 @@
|
||||
#-*- coding:utf-8 -*-
|
||||
import sys, os, re,inspect,json,traceback,logging,argparse, copy
|
||||
sys.path.append(os.path.realpath(os.path.dirname(inspect.getfile(inspect.currentframe())))+"/../")
|
||||
from tornado.web import RequestHandler,Application
|
||||
from tornado.ioloop import IOLoop
|
||||
from tornado.httpserver import HTTPServer
|
||||
from tornado.options import define,options
|
||||
from util import es_conn, setup_logging
|
||||
from sklearn.metrics.pairwise import cosine_similarity as CosineSimilarity
|
||||
from nlp import huqie
|
||||
from nlp import query as Query
|
||||
from nlp import search
|
||||
from llm import HuEmbedding, GptTurbo
|
||||
import numpy as np
|
||||
from io import BytesIO
|
||||
from util import config
|
||||
from timeit import default_timer as timer
|
||||
from collections import OrderedDict
|
||||
from llm import ChatModel, EmbeddingModel
|
||||
|
||||
SE = None
|
||||
CFIELD="content_ltks"
|
||||
EMBEDDING = EmbeddingModel
|
||||
LLM = ChatModel
|
||||
|
||||
def get_QA_pairs(hists):
|
||||
pa = []
|
||||
for h in hists:
|
||||
for k in ["user", "assistant"]:
|
||||
if h.get(k):
|
||||
pa.append({
|
||||
"content": h[k],
|
||||
"role": k,
|
||||
})
|
||||
|
||||
for p in pa[:-1]: assert len(p) == 2, p
|
||||
return pa
|
||||
|
||||
|
||||
|
||||
def get_instruction(sres, top_i, max_len=8096, fld="content_ltks"):
|
||||
max_len //= len(top_i)
|
||||
# add instruction to prompt
|
||||
instructions = [re.sub(r"[\r\n]+", " ", sres.field[sres.ids[i]][fld]) for i in top_i]
|
||||
if len(instructions)>2:
|
||||
# LLMs are said to be most sensitive to the first and the last reference,
# so rearrange the order of references accordingly
|
||||
instructions.append(copy.deepcopy(instructions[1]))
|
||||
instructions.pop(1)
|
||||
|
||||
def token_num(txt):
|
||||
c = 0
|
||||
for tk in re.split(r"[,。/?‘’”“:;:;!!]", txt):
|
||||
if re.match(r"[a-zA-Z-]+$", tk):
|
||||
c += 1
|
||||
continue
|
||||
c += len(tk)
|
||||
return c
|
||||
|
||||
_inst = ""
|
||||
for ins in instructions:
|
||||
if token_num(_inst) > 4096:
|
||||
_inst += "\n知识库:" + instructions[-1][:max_len]
|
||||
break
|
||||
_inst += "\n知识库:" + ins[:max_len]
|
||||
return _inst
|
||||
|
||||
|
||||
def prompt_and_answer(history, inst):
|
||||
hist = get_QA_pairs(history)
|
||||
chks = []
|
||||
for s in re.split(r"[::;;。\n\r]+", inst):
|
||||
if s: chks.append(s)
|
||||
chks = len(set(chks))/(0.1+len(chks))
|
||||
print("Duplication portion:", chks)
|
||||
|
||||
system = """
|
||||
你是一个智能助手,请总结知识库的内容来回答问题,请列举知识库中的数据详细回答%s。当所有知识库内容都与问题无关时,你的回答必须包括"知识库中未找到您要的答案!这是我所知道的,仅作参考。"这句话。回答需要考虑聊天历史。
|
||||
以下是知识库:
|
||||
%s
|
||||
以上是知识库。
|
||||
"""%((",最好总结成表格" if chks<0.6 and chks>0 else ""), inst)
|
||||
|
||||
print("【PROMPT】:", system)
|
||||
start = timer()
|
||||
response = LLM.chat(system, hist, {"temperature": 0.2, "max_tokens": 512})
|
||||
print("GENERATE: ", timer()-start)
|
||||
print("===>>", response)
|
||||
return response
|
||||
|
||||
|
||||
class Handler(RequestHandler):
|
||||
def post(self):
|
||||
global SE,MUST_TK_NUM
|
||||
param = json.loads(self.request.body.decode('utf-8'))
|
||||
try:
|
||||
question = param.get("history",[{"user": "Hi!"}])[-1]["user"]
|
||||
res = SE.search({
|
||||
"question": question,
|
||||
"kb_ids": param.get("kb_ids", []),
|
||||
"size": param.get("topn", 15)},
|
||||
search.index_name(param["uid"])
|
||||
)
|
||||
|
||||
sim = SE.rerank(res, question)
|
||||
rk_idx = np.argsort(sim*-1)
|
||||
topidx = [i for i in rk_idx if sim[i] >= param.get("similarity", 0.5)][:param.get("topn", 12)]
|
||||
inst = get_instruction(res, topidx)
|
||||
|
||||
ans = prompt_and_answer(param["history"], inst)
|
||||
ans = SE.insert_citations(ans, topidx, res)
|
||||
|
||||
refer = OrderedDict()
|
||||
docnms = {}
|
||||
for i in rk_idx:
|
||||
did = res.field[res.ids[i]]["doc_id"]
|
||||
if did not in docnms: docnms[did] = res.field[res.ids[i]]["docnm_kwd"]
|
||||
if did not in refer: refer[did] = []
|
||||
refer[did].append({
|
||||
"chunk_id": res.ids[i],
|
||||
"content": res.field[res.ids[i]]["content_ltks"],
|
||||
"image": ""
|
||||
})
|
||||
|
||||
print("::::::::::::::", ans)
|
||||
self.write(json.dumps({
|
||||
"code":0,
|
||||
"msg":"success",
|
||||
"data":{
|
||||
"uid": param["uid"],
|
||||
"dialog_id": param["dialog_id"],
|
||||
"assistant": ans,
|
||||
"refer": [{
|
||||
"did": did,
|
||||
"doc_name": docnms[did],
|
||||
"chunks": chunks
|
||||
} for did, chunks in refer.items()]
|
||||
}
|
||||
}))
|
||||
logging.info("SUCCESS[%d]"%(res.total)+json.dumps(param, ensure_ascii=False))
|
||||
|
||||
except Exception as e:
|
||||
logging.error("Request 500: "+str(e))
|
||||
self.write(json.dumps({
|
||||
"code":500,
|
||||
"msg":str(e),
|
||||
"data":{}
|
||||
}))
|
||||
print(traceback.format_exc())
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--port", default=4455, type=int, help="Port used for service")
|
||||
ARGS = parser.parse_args()
|
||||
|
||||
SE = search.Dealer(es_conn.HuEs("infiniflow"), EMBEDDING)
|
||||
|
||||
app = Application([(r'/v1/chat/completions', Handler)],debug=False)
|
||||
http_server = HTTPServer(app)
|
||||
http_server.bind(ARGS.port)
|
||||
http_server.start(3)
|
||||
|
||||
IOLoop.current().start()
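A hedged client-side sketch for the endpoint registered above (/v1/chat/completions, default port 4455). The uid, dialog_id and kb_ids values are placeholders; the field names mirror what Handler.post() reads from the request body.

```python
import json
import requests

payload = {
    "uid": 1,
    "dialog_id": 1,
    "kb_ids": [1],
    "topn": 12,
    "similarity": 0.5,
    "history": [{"user": "知识库里有哪些文档?"}],
}
r = requests.post("http://127.0.0.1:4455/v1/chat/completions",
                  data=json.dumps(payload, ensure_ascii=False).encode("utf-8"))
resp = r.json()
print(resp["code"], resp["msg"])
print(resp["data"]["assistant"])
```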
|
||||
|
||||
@ -1,258 +0,0 @@
|
||||
import json, os, sys, hashlib, copy, time, random, re
|
||||
from os.path import dirname, realpath
|
||||
sys.path.append(dirname(realpath(__file__)) + "/../")
|
||||
from util.es_conn import HuEs
|
||||
from util.db_conn import Postgres
|
||||
from util.minio_conn import HuMinio
|
||||
from util import rmSpace, findMaxDt
|
||||
from FlagEmbedding import FlagModel
|
||||
from nlp import huchunk, huqie, search
|
||||
from io import BytesIO
|
||||
import pandas as pd
|
||||
from elasticsearch_dsl import Q
|
||||
from PIL import Image
|
||||
from parser import (
|
||||
PdfParser,
|
||||
DocxParser,
|
||||
ExcelParser
|
||||
)
|
||||
from nlp.huchunk import (
|
||||
PdfChunker,
|
||||
DocxChunker,
|
||||
ExcelChunker,
|
||||
PptChunker,
|
||||
TextChunker
|
||||
)
|
||||
|
||||
ES = HuEs("infiniflow")
|
||||
BATCH_SIZE = 64
|
||||
PG = Postgres("infiniflow", "docgpt")
|
||||
MINIO = HuMinio("infiniflow")
|
||||
|
||||
PDF = PdfChunker(PdfParser())
|
||||
DOC = DocxChunker(DocxParser())
|
||||
EXC = ExcelChunker(ExcelParser())
|
||||
PPT = PptChunker()
|
||||
|
||||
def chuck_doc(name, binary):
|
||||
suff = os.path.split(name)[-1].lower().split(".")[-1]
|
||||
if suff.find("pdf") >= 0: return PDF(binary)
|
||||
if suff.find("doc") >= 0: return DOC(binary)
|
||||
if re.match(r"(xlsx|xlsm|xltx|xltm)", suff): return EXC(binary)
|
||||
if suff.find("ppt") >= 0: return PPT(binary)
|
||||
if os.environ.get("PARSE_IMAGE") \
and re.search(r"\.(jpg|jpeg|png|tif|gif|pcx|tga|exif|fpx|svg|psd|cdr|pcd|dxf|ufo|eps|ai|raw|WMF|webp|avif|apng|icon|ico)$",
name.lower()):
from llm import CvModel
txt = CvModel.describe(binary)
field = TextChunker.Fields()
field.text_chunks = [(txt, binary)]
field.table_chunks = []
return field
|
||||
|
||||
|
||||
return TextChunker()(binary)
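chuck_doc() routes a file to a chunker purely by its suffix, falling back to the text chunker. A dependency-free sketch of the same dispatch, with plain strings standing in for the PDF/DOC/EXC/PPT chunker objects:

```python
import os
import re

def pick_chunker(name):
    suff = os.path.split(name)[-1].lower().split(".")[-1]
    if "pdf" in suff:
        return "PDF"
    if "doc" in suff:
        return "DOC"
    if re.match(r"(xlsx|xlsm|xltx|xltm)", suff):
        return "EXC"
    if "ppt" in suff:
        return "PPT"
    return "TEXT"

for fnm in ["report.pdf", "notes.docx", "sheet.xlsx", "slides.pptx", "readme.md"]:
    print(fnm, "->", pick_chunker(fnm))
```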
|
||||
|
||||
|
||||
def collect(comm, mod, tm):
|
||||
sql = f"""
|
||||
select
|
||||
id as kb2doc_id,
|
||||
kb_id,
|
||||
did,
|
||||
updated_at,
|
||||
is_deleted
|
||||
from kb2_doc
|
||||
where
|
||||
updated_at >= '{tm}'
|
||||
and kb_progress = 0
|
||||
and MOD(did, {comm}) = {mod}
|
||||
order by updated_at asc
|
||||
limit 1000
|
||||
"""
|
||||
kb2doc = PG.select(sql)
|
||||
if len(kb2doc) == 0:return pd.DataFrame()
|
||||
|
||||
sql = """
|
||||
select
|
||||
did,
|
||||
uid,
|
||||
doc_name,
|
||||
location,
|
||||
size
|
||||
from doc_info
|
||||
where
|
||||
did in (%s)
|
||||
"""%",".join([str(i) for i in kb2doc["did"].unique()])
|
||||
docs = PG.select(sql)
|
||||
docs = docs.fillna("")
|
||||
docs = docs.join(kb2doc.set_index("did"), on="did", how="left")
|
||||
|
||||
mtm = str(docs["updated_at"].max())[:19]
|
||||
print("TOTAL:", len(docs), "To: ", mtm)
|
||||
return docs
|
||||
|
||||
|
||||
def set_progress(kb2doc_id, prog, msg="Processing..."):
|
||||
sql = f"""
|
||||
update kb2_doc set kb_progress={prog}, kb_progress_msg='{msg}'
|
||||
where
|
||||
id={kb2doc_id}
|
||||
"""
|
||||
PG.update(sql)
|
||||
|
||||
|
||||
def build(row):
|
||||
if row["size"] > 256000000:
|
||||
set_progress(row["kb2doc_id"], -1, "File size exceeds( <= 256Mb )")
|
||||
return []
|
||||
res = ES.search(Q("term", doc_id=row["did"]))
|
||||
if ES.getTotal(res) > 0:
|
||||
ES.updateScriptByQuery(Q("term", doc_id=row["did"]),
|
||||
scripts="""
|
||||
if(!ctx._source.kb_id.contains('%s'))
|
||||
ctx._source.kb_id.add('%s');
|
||||
"""%(str(row["kb_id"]), str(row["kb_id"])),
|
||||
idxnm = search.index_name(row["uid"])
|
||||
)
|
||||
set_progress(row["kb2doc_id"], 1, "Done")
|
||||
return []
|
||||
|
||||
random.seed(time.time())
|
||||
set_progress(row["kb2doc_id"], random.randint(0, 20)/100., "Finished preparing! Start to slice file!")
|
||||
try:
|
||||
obj = chuck_doc(row["doc_name"], MINIO.get("%s-upload"%str(row["uid"]), row["location"]))
|
||||
except Exception as e:
|
||||
if re.search("(No such file|not found)", str(e)):
|
||||
set_progress(row["kb2doc_id"], -1, "Can not find file <%s>"%row["doc_name"])
|
||||
else:
|
||||
set_progress(row["kb2doc_id"], -1, f"Internal system error: %s"%str(e).replace("'", ""))
|
||||
return []
|
||||
|
||||
if not obj.text_chunks and not obj.table_chunks:
|
||||
set_progress(row["kb2doc_id"], 1, "Nothing added! Mostly, file type unsupported yet.")
|
||||
return []
|
||||
|
||||
set_progress(row["kb2doc_id"], random.randint(20, 60)/100., "Finished slicing files. Start to embedding the content.")
|
||||
|
||||
doc = {
|
||||
"doc_id": row["did"],
|
||||
"kb_id": [str(row["kb_id"])],
|
||||
"docnm_kwd": os.path.split(row["location"])[-1],
|
||||
"title_tks": huqie.qie(os.path.split(row["location"])[-1]),
|
||||
"updated_at": str(row["updated_at"]).replace("T", " ")[:19]
|
||||
}
|
||||
doc["title_sm_tks"] = huqie.qieqie(doc["title_tks"])
|
||||
output_buffer = BytesIO()
|
||||
docs = []
|
||||
md5 = hashlib.md5()
|
||||
for txt, img in obj.text_chunks:
|
||||
d = copy.deepcopy(doc)
|
||||
md5.update((txt + str(d["doc_id"])).encode("utf-8"))
|
||||
d["_id"] = md5.hexdigest()
|
||||
d["content_ltks"] = huqie.qie(txt)
|
||||
d["content_sm_ltks"] = huqie.qieqie(d["content_ltks"])
|
||||
if not img:
|
||||
docs.append(d)
|
||||
continue
|
||||
|
||||
if isinstance(img, Image.Image): img.save(output_buffer, format='JPEG')
|
||||
else: output_buffer = BytesIO(img)
|
||||
|
||||
MINIO.put("{}-{}".format(row["uid"], row["kb_id"]), d["_id"],
|
||||
output_buffer.getvalue())
|
||||
d["img_id"] = "{}-{}".format(row["uid"], row["kb_id"])
|
||||
docs.append(d)
|
||||
|
||||
for arr, img in obj.table_chunks:
|
||||
for i, txt in enumerate(arr):
|
||||
d = copy.deepcopy(doc)
|
||||
d["content_ltks"] = huqie.qie(txt)
|
||||
md5.update((txt + str(d["doc_id"])).encode("utf-8"))
|
||||
d["_id"] = md5.hexdigest()
|
||||
if not img:
|
||||
docs.append(d)
|
||||
continue
|
||||
img.save(output_buffer, format='JPEG')
|
||||
MINIO.put("{}-{}".format(row["uid"], row["kb_id"]), d["_id"],
|
||||
output_buffer.getvalue())
|
||||
d["img_id"] = "{}-{}".format(row["uid"], row["kb_id"])
|
||||
docs.append(d)
|
||||
set_progress(row["kb2doc_id"], random.randint(60, 70)/100., "Continue embedding the content.")
|
||||
|
||||
return docs
|
||||
|
||||
|
||||
def init_kb(row):
|
||||
idxnm = search.index_name(row["uid"])
|
||||
if ES.indexExist(idxnm): return
|
||||
return ES.createIdx(idxnm, json.load(open("conf/mapping.json", "r")))
|
||||
|
||||
|
||||
model = None
|
||||
def embedding(docs):
|
||||
global model
|
||||
tts = model.encode([rmSpace(d["title_tks"]) for d in docs])
|
||||
cnts = model.encode([rmSpace(d["content_ltks"]) for d in docs])
|
||||
vects = 0.1 * tts + 0.9 * cnts
|
||||
assert len(vects) == len(docs)
|
||||
for i,d in enumerate(docs):d["q_vec"] = vects[i].tolist()
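Each chunk's query vector is a fixed 0.1/0.9 blend of the title and content embeddings. A sketch with random vectors standing in for the embedding model's output (the shapes and the blend mirror embedding() above):

```python
import numpy as np

rng = np.random.default_rng(0)
title_vecs = rng.normal(size=(3, 8))       # one row per chunk; 8 dims only for illustration
content_vecs = rng.normal(size=(3, 8))

q_vecs = 0.1 * title_vecs + 0.9 * content_vecs
print(q_vecs.shape)                        # (3, 8): one q_vec per chunk
print([round(x, 3) for x in q_vecs[0][:4]])
```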
|
||||
|
||||
|
||||
def rm_doc_from_kb(df):
|
||||
if len(df) == 0:return
|
||||
for _,r in df.iterrows():
|
||||
ES.updateScriptByQuery(Q("term", doc_id=r["did"]),
|
||||
scripts="""
|
||||
if(ctx._source.kb_id.contains('%s'))
|
||||
ctx._source.kb_id.remove(
|
||||
ctx._source.kb_id.indexOf('%s')
|
||||
);
|
||||
"""%(str(r["kb_id"]),str(r["kb_id"])),
|
||||
idxnm = search.index_name(r["uid"])
|
||||
)
|
||||
if len(df) == 0:return
|
||||
sql = """
|
||||
delete from kb2_doc where id in (%s)
|
||||
"""%",".join([str(i) for i in df["kb2doc_id"]])
|
||||
PG.update(sql)
|
||||
|
||||
|
||||
def main(comm, mod):
|
||||
global model
|
||||
from llm import HuEmbedding
|
||||
model = HuEmbedding()
|
||||
tm_fnm = f"res/{comm}-{mod}.tm"
|
||||
tm = findMaxDt(tm_fnm)
|
||||
rows = collect(comm, mod, tm)
|
||||
if len(rows) == 0:return
|
||||
|
||||
rm_doc_from_kb(rows.loc[rows.is_deleted == True])
|
||||
rows = rows.loc[rows.is_deleted == False].reset_index(drop=True)
|
||||
if len(rows) == 0:return
|
||||
tmf = open(tm_fnm, "a+")
|
||||
for _, r in rows.iterrows():
|
||||
cks = build(r)
|
||||
if not cks:
|
||||
tmf.write(str(r["updated_at"]) + "\n")
|
||||
continue
|
||||
## TODO: exception handler
|
||||
## set_progress(r["did"], -1, "ERROR: ")
|
||||
embedding(cks)
|
||||
|
||||
set_progress(r["kb2doc_id"], random.randint(70, 95)/100.,
|
||||
"Finished embedding! Start to build index!")
|
||||
init_kb(r)
|
||||
es_r = ES.bulk(cks, search.index_name(r["uid"]))
|
||||
if es_r:
|
||||
set_progress(r["kb2doc_id"], -1, "Index failure!")
|
||||
print(es_r)
|
||||
else: set_progress(r["kb2doc_id"], 1., "Done!")
|
||||
tmf.write(str(r["updated_at"]) + "\n")
|
||||
tmf.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from mpi4py import MPI
|
||||
comm = MPI.COMM_WORLD
|
||||
main(comm.Get_size(), comm.Get_rank())
|
||||
|
||||
python/tmp.log (new file, 15 lines)
@ -0,0 +1,15 @@
|
||||
|
||||
Fetching 6 files: 0%| | 0/6 [00:00<?, ?it/s]
|
||||
Fetching 6 files: 100%|██████████| 6/6 [00:00<00:00, 106184.91it/s]
|
||||
----------- Model Configuration -----------
|
||||
Model Arch: GFL
|
||||
Transform Order:
|
||||
--transform op: Resize
|
||||
--transform op: NormalizeImage
|
||||
--transform op: Permute
|
||||
--transform op: PadStride
|
||||
--------------------------------------------
|
||||
Could not find image processor class in the image processor config or the model config. Loading based on pattern matching with the model's feature extractor configuration.
|
||||
The `max_size` parameter is deprecated and will be removed in v4.26. Please specify in `size['longest_edge'] instead`.
|
||||
Some weights of the model checkpoint at microsoft/table-transformer-structure-recognition were not used when initializing TableTransformerForObjectDetection: ['model.backbone.conv_encoder.model.layer3.0.downsample.1.num_batches_tracked', 'model.backbone.conv_encoder.model.layer2.0.downsample.1.num_batches_tracked', 'model.backbone.conv_encoder.model.layer4.0.downsample.1.num_batches_tracked']
|
||||
- This IS expected if you are initializing TableTransformerForObjectDetection from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
|
||||
@ -1,24 +0,0 @@
|
||||
import re
|
||||
|
||||
|
||||
def rmSpace(txt):
|
||||
txt = re.sub(r"([^a-z0-9.,]) +([^ ])", r"\1\2", txt)
|
||||
return re.sub(r"([^ ]) +([^a-z0-9.,])", r"\1\2", txt)
|
||||
|
||||
|
||||
def findMaxDt(fnm):
|
||||
m = "1970-01-01 00:00:00"
|
||||
try:
|
||||
with open(fnm, "r") as f:
|
||||
while True:
|
||||
l = f.readline()
|
||||
if not l:
|
||||
break
|
||||
l = l.strip("\n")
|
||||
if l == 'nan':
|
||||
continue
|
||||
if l > m:
|
||||
m = l
|
||||
except Exception as e:
|
||||
print("WARNING: can't find " + fnm)
|
||||
return m
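findMaxDt() reads one of the append-only res/*.tm checkpoint files shown earlier in this commit and returns the latest timestamp seen, falling back to the epoch when the file is missing. A round-trip sketch meant to run next to the definition above ("res/demo.tm" is a placeholder path):

```python
import os

os.makedirs("res", exist_ok=True)
with open("res/demo.tm", "a+") as f:
    f.write("2023-12-20 11:44:08\n")
    f.write("2023-12-21 00:47:09\n")

print(findMaxDt("res/demo.tm"))      # -> 2023-12-21 00:47:09
print(findMaxDt("res/missing.tm"))   # -> 1970-01-01 00:00:00 (with a warning)
```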
|
||||
@ -1,31 +0,0 @@
|
||||
from configparser import ConfigParser
|
||||
import os
|
||||
import inspect
|
||||
|
||||
CF = ConfigParser()
|
||||
__fnm = os.path.join(os.path.dirname(__file__), '../conf/sys.cnf')
|
||||
if not os.path.exists(__fnm):
|
||||
__fnm = os.path.join(os.path.dirname(__file__), '../../conf/sys.cnf')
|
||||
assert os.path.exists(
|
||||
__fnm), f"【EXCEPTION】can't find {__fnm}." + os.path.dirname(__file__)
|
||||
if not os.path.exists(__fnm):
|
||||
__fnm = "./sys.cnf"
|
||||
|
||||
CF.read(__fnm)
|
||||
|
||||
|
||||
class Config:
|
||||
def __init__(self, env):
|
||||
self.env = env
|
||||
if env == "spark":
|
||||
CF.read("./cv.cnf")
|
||||
|
||||
def get(self, key, default=None):
|
||||
global CF
|
||||
return os.environ.get(key.upper(),
|
||||
CF[self.env].get(key, default)
|
||||
)
|
||||
|
||||
|
||||
def init(env):
|
||||
return Config(env)
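Configuration values come from a section of conf/sys.cnf, and an upper-cased environment variable of the same name always wins. A hedged usage sketch meant to run next to the module above; the section and key names are illustrative but match the keys used by the Postgres and MinIO helpers in this commit:

```python
import os

os.environ["POSTGRES_HOST"] = "127.0.0.1"   # environment overrides the .cnf value
cfg = init("infiniflow")                     # reads the [infiniflow] section of sys.cnf
print(cfg.get("postgres_host"))              # -> 127.0.0.1 (from the environment)
print(cfg.get("postgres_port", "5432"))      # falls back to the .cnf value or the default
```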
|
||||
@ -1,70 +0,0 @@
|
||||
import logging
|
||||
import time
|
||||
from util import config
|
||||
import pandas as pd
|
||||
|
||||
|
||||
class Postgres(object):
|
||||
def __init__(self, env, dbnm):
|
||||
self.config = config.init(env)
|
||||
self.conn = None
|
||||
self.dbnm = dbnm
|
||||
self.__open__()
|
||||
|
||||
def __open__(self):
|
||||
import psycopg2
|
||||
try:
|
||||
if self.conn:
|
||||
self.__close__()
|
||||
del self.conn
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
try:
|
||||
self.conn = psycopg2.connect(f"""dbname={self.dbnm}
|
||||
user={self.config.get('postgres_user')}
|
||||
password={self.config.get('postgres_password')}
|
||||
host={self.config.get('postgres_host')}
|
||||
port={self.config.get('postgres_port')}""")
|
||||
except Exception as e:
|
||||
logging.error(
|
||||
"Fail to connect %s " %
|
||||
self.config.get("pgdb_host") + str(e))
|
||||
|
||||
def __close__(self):
|
||||
try:
|
||||
self.conn.close()
|
||||
except Exception as e:
|
||||
logging.error(
|
||||
"Fail to close %s " %
|
||||
self.config.get("pgdb_host") + str(e))
|
||||
|
||||
def select(self, sql):
|
||||
for _ in range(10):
|
||||
try:
|
||||
return pd.read_sql(sql, self.conn)
|
||||
except Exception as e:
|
||||
logging.error(f"Fail to exec {sql} " + str(e))
|
||||
self.__open__()
|
||||
time.sleep(1)
|
||||
|
||||
return pd.DataFrame()
|
||||
|
||||
def update(self, sql):
|
||||
for _ in range(10):
|
||||
try:
|
||||
cur = self.conn.cursor()
|
||||
cur.execute(sql)
|
||||
updated_rows = cur.rowcount
|
||||
self.conn.commit()
|
||||
cur.close()
|
||||
return updated_rows
|
||||
except Exception as e:
|
||||
logging.error(f"Fail to exec {sql} " + str(e))
|
||||
self.__open__()
|
||||
time.sleep(1)
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
Postgres("infiniflow", "docgpt")
|
||||
@ -1,428 +0,0 @@
|
||||
import re
|
||||
import logging
|
||||
import json
|
||||
import time
|
||||
import copy
|
||||
import elasticsearch
|
||||
from elasticsearch import Elasticsearch
|
||||
from elasticsearch_dsl import UpdateByQuery, Search, Index, Q
|
||||
from util import config
|
||||
|
||||
logging.info("Elasticsearch version: ", elasticsearch.__version__)
|
||||
|
||||
|
||||
def instance(env):
|
||||
CF = config.init(env)
|
||||
ES_DRESS = CF.get("es").split(",")
|
||||
|
||||
ES = Elasticsearch(
|
||||
ES_DRESS,
|
||||
timeout=600
|
||||
)
|
||||
|
||||
logging.info("ES: ", ES_DRESS, ES.info())
|
||||
|
||||
return ES
|
||||
|
||||
|
||||
class HuEs:
|
||||
def __init__(self, env):
|
||||
self.env = env
|
||||
self.info = {}
|
||||
self.config = config.init(env)
|
||||
self.conn()
|
||||
self.idxnm = self.config.get("idx_nm", "")
|
||||
if not self.es.ping():
|
||||
raise Exception("Can't connect to ES cluster")
|
||||
|
||||
def conn(self):
|
||||
for _ in range(10):
|
||||
try:
|
||||
c = instance(self.env)
|
||||
if c:
|
||||
self.es = c
|
||||
self.info = c.info()
|
||||
logging.info("Connect to es.")
|
||||
break
|
||||
except Exception as e:
|
||||
logging.error("Fail to connect to es: " + str(e))
|
||||
time.sleep(1)
|
||||
|
||||
def version(self):
|
||||
v = self.info.get("version", {"number": "5.6"})
|
||||
v = v["number"].split(".")[0]
|
||||
return int(v) >= 7
|
||||
|
||||
def upsert(self, df, idxnm=""):
|
||||
res = []
|
||||
for d in df:
|
||||
id = d["id"]
|
||||
del d["id"]
|
||||
d = {"doc": d, "doc_as_upsert": "true"}
|
||||
T = False
|
||||
for _ in range(10):
|
||||
try:
|
||||
if not self.version():
|
||||
r = self.es.update(
|
||||
index=(
|
||||
self.idxnm if not idxnm else idxnm),
|
||||
body=d,
|
||||
id=id,
|
||||
doc_type="doc",
|
||||
refresh=False,
|
||||
retry_on_conflict=100)
|
||||
else:
|
||||
r = self.es.update(
|
||||
index=(
|
||||
self.idxnm if not idxnm else idxnm),
|
||||
body=d,
|
||||
id=id,
|
||||
refresh=False,
|
||||
doc_type="_doc",
|
||||
retry_on_conflict=100)
|
||||
logging.info("Successfully upsert: %s" % id)
|
||||
T = True
|
||||
break
|
||||
except Exception as e:
|
||||
logging.warning("Fail to index: " +
|
||||
json.dumps(d, ensure_ascii=False) + str(e))
|
||||
if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
|
||||
time.sleep(3)
|
||||
continue
|
||||
self.conn()
|
||||
T = False
|
||||
|
||||
if not T:
|
||||
res.append(d)
|
||||
logging.error(
|
||||
"Fail to index: " +
|
||||
re.sub(
|
||||
"[\r\n]",
|
||||
"",
|
||||
json.dumps(
|
||||
d,
|
||||
ensure_ascii=False)))
|
||||
d["id"] = id
|
||||
d["_index"] = self.idxnm
|
||||
|
||||
if not res:
|
||||
return True
|
||||
return False
|
||||
|
||||
def bulk(self, df, idx_nm=None):
|
||||
ids, acts = {}, []
|
||||
for d in df:
|
||||
id = d["id"] if "id" in d else d["_id"]
|
||||
ids[id] = copy.deepcopy(d)
|
||||
ids[id]["_index"] = self.idxnm if not idx_nm else idx_nm
|
||||
if "id" in d:
|
||||
del d["id"]
|
||||
if "_id" in d:
|
||||
del d["_id"]
|
||||
acts.append(
|
||||
{"update": {"_id": id, "_index": ids[id]["_index"]}, "retry_on_conflict": 100})
|
||||
acts.append({"doc": d, "doc_as_upsert": "true"})
|
||||
|
||||
res = []
|
||||
for _ in range(100):
|
||||
try:
|
||||
if elasticsearch.__version__[0] < 8:
|
||||
r = self.es.bulk(
|
||||
index=(
|
||||
self.idxnm if not idx_nm else idx_nm),
|
||||
body=acts,
|
||||
refresh=False,
|
||||
timeout="600s")
|
||||
else:
|
||||
r = self.es.bulk(index=(self.idxnm if not idx_nm else
|
||||
idx_nm), operations=acts,
|
||||
refresh=False, timeout="600s")
|
||||
if re.search(r"False", str(r["errors"]), re.IGNORECASE):
|
||||
return res
|
||||
|
||||
for it in r["items"]:
|
||||
if "error" in it["update"]:
|
||||
res.append(str(it["update"]["_id"]) +
|
||||
":" + str(it["update"]["error"]))
|
||||
|
||||
return res
|
||||
except Exception as e:
|
||||
logging.warn("Fail to bulk: " + str(e))
|
||||
if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
|
||||
time.sleep(3)
|
||||
continue
|
||||
self.conn()
|
||||
|
||||
return res
|
||||
|
||||
def bulk4script(self, df):
|
||||
ids, acts = {}, []
|
||||
for d in df:
|
||||
id = d["id"]
|
||||
ids[id] = copy.deepcopy(d["raw"])
|
||||
acts.append({"update": {"_id": id, "_index": self.idxnm}})
|
||||
acts.append(d["script"])
|
||||
logging.info("bulk upsert: %s" % id)
|
||||
|
||||
res = []
|
||||
for _ in range(10):
|
||||
try:
|
||||
if not self.version():
|
||||
r = self.es.bulk(
|
||||
index=self.idxnm,
|
||||
body=acts,
|
||||
refresh=False,
|
||||
timeout="600s",
|
||||
doc_type="doc")
|
||||
else:
|
||||
r = self.es.bulk(
|
||||
index=self.idxnm,
|
||||
body=acts,
|
||||
refresh=False,
|
||||
timeout="600s")
|
||||
if re.search(r"False", str(r["errors"]), re.IGNORECASE):
|
||||
return res
|
||||
|
||||
for it in r["items"]:
|
||||
if "error" in it["update"]:
|
||||
res.append(str(it["update"]["_id"]))
|
||||
|
||||
return res
|
||||
except Exception as e:
|
||||
logging.warning("Fail to bulk: " + str(e))
|
||||
if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
|
||||
time.sleep(3)
|
||||
continue
|
||||
self.conn()
|
||||
|
||||
return res
|
||||
|
||||
def rm(self, d):
|
||||
for _ in range(10):
|
||||
try:
|
||||
if not self.version():
|
||||
r = self.es.delete(
|
||||
index=self.idxnm,
|
||||
id=d["id"],
|
||||
doc_type="doc",
|
||||
refresh=True)
|
||||
else:
|
||||
r = self.es.delete(
|
||||
index=self.idxnm,
|
||||
id=d["id"],
|
||||
refresh=True,
|
||||
doc_type="_doc")
|
||||
logging.info("Remove %s" % d["id"])
|
||||
return True
|
||||
except Exception as e:
|
||||
logging.warn("Fail to delete: " + str(d) + str(e))
|
||||
if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
|
||||
time.sleep(3)
|
||||
continue
|
||||
if re.search(r"(not_found)", str(e), re.IGNORECASE):
|
||||
return True
|
||||
self.conn()
|
||||
|
||||
logging.error("Fail to delete: " + str(d))
|
||||
|
||||
return False
|
||||
|
||||
def search(self, q, idxnm=None, src=False, timeout="2s"):
|
||||
if not isinstance(q, dict):
|
||||
q = Search().query(q).to_dict()
|
||||
for i in range(3):
|
||||
try:
|
||||
res = self.es.search(index=(self.idxnm if not idxnm else idxnm),
|
||||
body=q,
|
||||
timeout=timeout,
|
||||
# search_type="dfs_query_then_fetch",
|
||||
track_total_hits=True,
|
||||
_source=src)
|
||||
if str(res.get("timed_out", "")).lower() == "true":
|
||||
raise Exception("Es Timeout.")
|
||||
return res
|
||||
except Exception as e:
|
||||
logging.error(
|
||||
"ES search exception: " +
|
||||
str(e) +
|
||||
"【Q】:" +
|
||||
str(q))
|
||||
if str(e).find("Timeout") > 0:
|
||||
continue
|
||||
raise e
|
||||
logging.error("ES search timeout for 3 times!")
|
||||
raise Exception("ES search timeout.")
|
||||
|
||||
def updateByQuery(self, q, d):
|
||||
ubq = UpdateByQuery(index=self.idxnm).using(self.es).query(q)
|
||||
scripts = ""
|
||||
for k, v in d.items():
|
||||
scripts += "ctx._source.%s = params.%s;" % (str(k), str(k))
|
||||
ubq = ubq.script(source=scripts, params=d)
|
||||
ubq = ubq.params(refresh=False)
|
||||
ubq = ubq.params(slices=5)
|
||||
ubq = ubq.params(conflicts="proceed")
|
||||
for i in range(3):
|
||||
try:
|
||||
r = ubq.execute()
|
||||
return True
|
||||
except Exception as e:
|
||||
logging.error("ES updateByQuery exception: " +
|
||||
str(e) + "【Q】:" + str(q.to_dict()))
|
||||
if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
|
||||
continue
|
||||
self.conn()
|
||||
|
||||
return False
|
||||
|
||||
def updateScriptByQuery(self, q, scripts, idxnm=None):
|
||||
ubq = UpdateByQuery(
|
||||
index=self.idxnm if not idxnm else idxnm).using(
|
||||
self.es).query(q)
|
||||
ubq = ubq.script(source=scripts)
|
||||
ubq = ubq.params(refresh=True)
|
||||
ubq = ubq.params(slices=5)
|
||||
ubq = ubq.params(conflicts="proceed")
|
||||
for i in range(3):
|
||||
try:
|
||||
r = ubq.execute()
|
||||
return True
|
||||
except Exception as e:
|
||||
logging.error("ES updateByQuery exception: " +
|
||||
str(e) + "【Q】:" + str(q.to_dict()))
|
||||
if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
|
||||
continue
|
||||
self.conn()
|
||||
|
||||
return False
|
||||
|
||||
def deleteByQuery(self, query, idxnm=""):
|
||||
for i in range(3):
|
||||
try:
|
||||
r = self.es.delete_by_query(
|
||||
index=idxnm if idxnm else self.idxnm,
|
||||
body=Search().query(query).to_dict())
|
||||
return True
|
||||
except Exception as e:
|
||||
logging.error("ES updateByQuery deleteByQuery: " +
|
||||
str(e) + "【Q】:" + str(query.to_dict()))
|
||||
if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
|
||||
continue
|
||||
|
||||
return False
|
||||
|
||||
def update(self, id, script, routing=None):
|
||||
for i in range(3):
|
||||
try:
|
||||
if not self.version():
|
||||
r = self.es.update(
|
||||
index=self.idxnm,
|
||||
id=id,
|
||||
body=json.dumps(
|
||||
script,
|
||||
ensure_ascii=False),
|
||||
doc_type="doc",
|
||||
routing=routing,
|
||||
refresh=False)
|
||||
else:
|
||||
r = self.es.update(index=self.idxnm, id=id, body=json.dumps(script, ensure_ascii=False),
|
||||
routing=routing, refresh=False) # , doc_type="_doc")
|
||||
return True
|
||||
except Exception as e:
|
||||
logging.error("ES update exception: " + str(e) + " id:" + str(id) + ", version:" + str(self.version()) +
|
||||
json.dumps(script, ensure_ascii=False))
|
||||
if str(e).find("Timeout") > 0:
|
||||
continue
|
||||
|
||||
return False
|
||||
|
||||
def indexExist(self, idxnm):
|
||||
s = Index(idxnm if idxnm else self.idxnm, self.es)
|
||||
for i in range(3):
|
||||
try:
|
||||
return s.exists()
|
||||
except Exception as e:
|
||||
logging.error("ES updateByQuery indexExist: " + str(e))
|
||||
if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
|
||||
continue
|
||||
|
||||
return False
|
||||
|
||||
def docExist(self, docid, idxnm=None):
|
||||
for i in range(3):
|
||||
try:
|
||||
return self.es.exists(index=(idxnm if idxnm else self.idxnm),
|
||||
id=docid)
|
||||
except Exception as e:
|
||||
logging.error("ES Doc Exist: " + str(e))
|
||||
if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
|
||||
continue
|
||||
return False
|
||||
|
||||
def createIdx(self, idxnm, mapping):
|
||||
try:
|
||||
if elasticsearch.__version__[0] < 8:
|
||||
return self.es.indices.create(idxnm, body=mapping)
|
||||
from elasticsearch.client import IndicesClient
|
||||
return IndicesClient(self.es).create(index=idxnm,
|
||||
settings=mapping["settings"],
|
||||
mappings=mapping["mappings"])
|
||||
except Exception as e:
|
||||
logging.error("ES create index error %s ----%s" % (idxnm, str(e)))
|
||||
|
||||
def deleteIdx(self, idxnm):
|
||||
try:
|
||||
return self.es.indices.delete(idxnm, allow_no_indices=True)
|
||||
except Exception as e:
|
||||
logging.error("ES delete index error %s ----%s" % (idxnm, str(e)))
|
||||
|
||||
def getTotal(self, res):
|
||||
if isinstance(res["hits"]["total"], type({})):
|
||||
return res["hits"]["total"]["value"]
|
||||
return res["hits"]["total"]
|
||||
|
||||
def getDocIds(self, res):
|
||||
return [d["_id"] for d in res["hits"]["hits"]]
|
||||
|
||||
def getSource(self, res):
|
||||
rr = []
|
||||
for d in res["hits"]["hits"]:
|
||||
d["_source"]["id"] = d["_id"]
|
||||
d["_source"]["_score"] = d["_score"]
|
||||
rr.append(d["_source"])
|
||||
return rr
|
||||
|
||||
def scrollIter(self, pagesize=100, scroll_time='2m', q={
|
||||
"query": {"match_all": {}}, "sort": [{"updated_at": {"order": "desc"}}]}):
|
||||
for _ in range(100):
|
||||
try:
|
||||
page = self.es.search(
|
||||
index=self.idxnm,
|
||||
scroll=scroll_time,
|
||||
size=pagesize,
|
||||
body=q,
|
||||
_source=None
|
||||
)
|
||||
break
|
||||
except Exception as e:
|
||||
logging.error("ES scrolling fail. " + str(e))
|
||||
time.sleep(3)
|
||||
|
||||
sid = page['_scroll_id']
|
||||
scroll_size = page['hits']['total']["value"]
|
||||
logging.info("[TOTAL]%d" % scroll_size)
|
||||
# Start scrolling
|
||||
while scroll_size > 0:
|
||||
yield page["hits"]["hits"]
|
||||
for _ in range(100):
|
||||
try:
|
||||
page = self.es.scroll(scroll_id=sid, scroll=scroll_time)
|
||||
break
|
||||
except Exception as e:
|
||||
logging.error("ES scrolling fail. " + str(e))
|
||||
time.sleep(3)
|
||||
|
||||
# Update the scroll ID
|
||||
sid = page['_scroll_id']
|
||||
# Get the number of results that we returned in the last scroll
|
||||
scroll_size = len(page['hits']['hits'])
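scrollIter() walks an index with the Elasticsearch scroll API and yields raw hit batches until the cursor is exhausted. A hedged usage sketch; the "infiniflow" env name and the query are illustrative and assume a reachable cluster plus a conf/sys.cnf with an `es` entry:

```python
es = HuEs("infiniflow")
q = {"query": {"match_all": {}}, "sort": [{"updated_at": {"order": "desc"}}]}

seen = 0
for hits in es.scrollIter(pagesize=200, scroll_time="2m", q=q):
    seen += len(hits)
    for h in hits[:2]:
        print(h["_id"], h["_source"].get("docnm_kwd"))
print("total scrolled:", seen)
```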
|
||||
@ -1,84 +0,0 @@
|
||||
import logging
|
||||
import time
|
||||
from util import config
|
||||
from minio import Minio
|
||||
from io import BytesIO
|
||||
|
||||
|
||||
class HuMinio(object):
|
||||
def __init__(self, env):
|
||||
self.config = config.init(env)
|
||||
self.conn = None
|
||||
self.__open__()
|
||||
|
||||
def __open__(self):
|
||||
try:
|
||||
if self.conn:
|
||||
self.__close__()
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
try:
|
||||
self.conn = Minio(self.config.get("minio_host"),
|
||||
access_key=self.config.get("minio_user"),
|
||||
secret_key=self.config.get("minio_password"),
|
||||
secure=False
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(
|
||||
"Fail to connect %s " %
|
||||
self.config.get("minio_host") + str(e))
|
||||
|
||||
def __close__(self):
|
||||
del self.conn
|
||||
self.conn = None
|
||||
|
||||
def put(self, bucket, fnm, binary):
|
||||
for _ in range(10):
|
||||
try:
|
||||
if not self.conn.bucket_exists(bucket):
|
||||
self.conn.make_bucket(bucket)
|
||||
|
||||
r = self.conn.put_object(bucket, fnm,
|
||||
BytesIO(binary),
|
||||
len(binary)
|
||||
)
|
||||
return r
|
||||
except Exception as e:
|
||||
logging.error(f"Fail put {bucket}/{fnm}: " + str(e))
|
||||
self.__open__()
|
||||
time.sleep(1)
|
||||
|
||||
def get(self, bucket, fnm):
|
||||
for _ in range(10):
|
||||
try:
|
||||
r = self.conn.get_object(bucket, fnm)
|
||||
return r.read()
|
||||
except Exception as e:
|
||||
logging.error(f"fail get {bucket}/{fnm}: " + str(e))
|
||||
self.__open__()
|
||||
time.sleep(1)
|
||||
return
|
||||
|
||||
def get_presigned_url(self, bucket, fnm, expires):
|
||||
for _ in range(10):
|
||||
try:
|
||||
return self.conn.get_presigned_url("GET", bucket, fnm, expires)
|
||||
except Exception as e:
|
||||
logging.error(f"fail get {bucket}/{fnm}: " + str(e))
|
||||
self.__open__()
|
||||
time.sleep(1)
|
||||
return
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
conn = HuMinio("infiniflow")
|
||||
fnm = "/opt/home/kevinhu/docgpt/upload/13/11-408.jpg"
|
||||
from PIL import Image
|
||||
img = Image.open(fnm)
|
||||
buff = BytesIO()
|
||||
img.save(buff, format='JPEG')
|
||||
print(conn.put("test", "11-408.jpg", buff.getvalue()))
|
||||
bts = conn.get("test", "11-408.jpg")
|
||||
img = Image.open(BytesIO(bts))
|
||||
img.save("test.jpg")
|
||||
@ -1,36 +0,0 @@
|
||||
import json
|
||||
import logging.config
|
||||
import os
|
||||
|
||||
|
||||
def log_dir():
|
||||
fnm = os.path.join(os.path.dirname(__file__), '../log/')
|
||||
if not os.path.exists(fnm):
|
||||
fnm = os.path.join(os.path.dirname(__file__), '../../log/')
|
||||
assert os.path.exists(fnm), f"Can't locate log dir: {fnm}"
|
||||
return fnm
|
||||
|
||||
|
||||
def setup_logging(default_path="conf/logging.json",
|
||||
default_level=logging.INFO,
|
||||
env_key="LOG_CFG"):
|
||||
path = default_path
|
||||
value = os.getenv(env_key, None)
|
||||
if value:
|
||||
path = value
|
||||
if os.path.exists(path):
|
||||
with open(path, "r") as f:
|
||||
config = json.load(f)
|
||||
fnm = log_dir()
|
||||
|
||||
config["handlers"]["info_file_handler"]["filename"] = fnm + "info.log"
|
||||
config["handlers"]["error_file_handler"]["filename"] = fnm + "error.log"
|
||||
logging.config.dictConfig(config)
|
||||
else:
|
||||
logging.basicConfig(level=default_level)
|
||||
|
||||
|
||||
__fnm = os.path.join(os.path.dirname(__file__), 'conf/logging.json')
|
||||
if not os.path.exists(__fnm):
|
||||
__fnm = os.path.join(os.path.dirname(__file__), '../../conf/logging.json')
|
||||
setup_logging(__fnm)
|
||||