remove unused code, separate layout detection out as a new API. Add new RAG method 'table' (#55)

KevinHuSh
2024-02-05 18:08:17 +08:00
committed by GitHub
parent f305776217
commit 407b2523b6
33 changed files with 306 additions and 505 deletions
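
At a glance: the new 'table' method turns each spreadsheet/CSV row into one retrievable chunk, and the task executor picks it up through its FACTORY registry (see the task_executor hunks below). A minimal usage sketch, assuming a hypothetical contacts.xlsx; the callback signature matches what the chunkers actually call:

import sys
from rag.app import table

def progress(prog=None, msg=""):
    print(prog, msg)

# Each row becomes one dict with typed fields (e.g. a text column gets a
# "_tks" suffix, an int column "_int") plus tokenized full-row text.
chunks = table.chunk("contacts.xlsx", callback=progress)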

View File

@@ -3,7 +3,7 @@ import random
 import re
 import numpy as np
 from rag.parser import bullets_category, BULLET_PATTERN, is_english, tokenize, remove_contents_table, \
-    hierarchical_merge, make_colon_as_title, naive_merge
+    hierarchical_merge, make_colon_as_title, naive_merge, random_choices
 from rag.nlp import huqie
 from rag.parser.docx_parser import HuDocxParser
 from rag.parser.pdf_parser import HuParser
@@ -51,7 +51,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None, **k
         doc_parser = HuDocxParser()
         # TODO: table of contents need to be removed
         sections, tbls = doc_parser(binary if binary else filename, from_page=from_page, to_page=to_page)
-        remove_contents_table(sections, eng=is_english(random.choices([t for t,_ in sections], k=200)))
+        remove_contents_table(sections, eng=is_english(random_choices([t for t,_ in sections], k=200)))
         callback(0.8, "Finish parsing.")
     elif re.search(r"\.pdf$", filename, re.IGNORECASE):
         pdf_parser = Pdf()
@@ -67,20 +67,20 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None, **k
                 l = f.readline()
                 if not l:break
                 txt += l
         sections = txt.split("\n")
         sections = [(l,"") for l in sections if l]
-        remove_contents_table(sections, eng = is_english(random.choices([t for t,_ in sections], k=200)))
+        remove_contents_table(sections, eng = is_english(random_choices([t for t,_ in sections], k=200)))
         callback(0.8, "Finish parsing.")
     else: raise NotImplementedError("file type not supported yet(docx, pdf, txt supported)")

     make_colon_as_title(sections)
-    bull = bullets_category([t for t in random.choices([t for t,_ in sections], k=100)])
+    bull = bullets_category([t for t in random_choices([t for t,_ in sections], k=100)])
     if bull >= 0: cks = hierarchical_merge(bull, sections, 3)
     else: cks = naive_merge(sections, kwargs.get("chunk_token_num", 256), kwargs.get("delimer", "\n。;!?"))

     sections = [t for t, _ in sections]
     # is it English
-    eng = is_english(random.choices(sections, k=218))
+    eng = is_english(random_choices(sections, k=218))

     res = []
     # add tables

View File

@@ -86,7 +86,8 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None, **k
                 l = f.readline()
                 if not l:break
                 txt += l
         sections = txt.split("\n")
         sections = [l for l in sections if l]
         callback(0.8, "Finish parsing.")
     else: raise NotImplementedError("file type not supported yet(docx, pdf, txt supported)")

View File

@@ -52,7 +52,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None, **k
                 l = f.readline()
                 if not l:break
                 txt += l
         sections = txt.split("\n")
         sections = [(l,"") for l in sections if l]
         callback(0.8, "Finish parsing.")
     else: raise NotImplementedError("file type not supported yet(docx, pdf, txt supported)")

View File

@@ -1,6 +1,9 @@
 import copy
 import re
+from collections import Counter
+from api.db import ParserType
 from rag.cv.ppdetection import PPDet
 from rag.parser import tokenize
 from rag.nlp import huqie
 from rag.parser.pdf_parser import HuParser
@@ -9,6 +12,10 @@ from rag.utils import num_tokens_from_string
 class Pdf(HuParser):
+    def __init__(self):
+        self.model_speciess = ParserType.PAPER.value
+        super().__init__()
+
     def __call__(self, filename, binary=None, from_page=0,
                  to_page=100000, zoomin=3, callback=None):
         self.__images__(
@@ -63,6 +70,15 @@ class Pdf(HuParser):
             "[0-9. 一、i]*(introduction|abstract|摘要|引言|keywords|key words|关键词|background|背景|目录|前言|contents)",
             txt.lower().strip())
+        if from_page > 0:
+            return {
+                "title":"",
+                "authors": "",
+                "abstract": "",
+                "lines": [(b["text"] + self._line_tag(b, zoomin), b.get("layoutno", "")) for b in self.boxes[i:] if
+                          re.match(r"(text|title)", b.get("layoutno", "text"))],
+                "tables": tbls
+            }
         # get title and authors
         title = ""
         authors = []
@@ -115,18 +131,13 @@
 def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None, **kwargs):
     pdf_parser = None
     paper = {}
     if re.search(r"\.pdf$", filename, re.IGNORECASE):
         pdf_parser = Pdf()
         paper = pdf_parser(filename if not binary else binary,
                            from_page=from_page, to_page=to_page, callback=callback)
     else: raise NotImplementedError("file type not supported yet(pdf supported)")
-    doc = {
-        "docnm_kwd": paper["title"] if paper["title"] else filename,
-        "authors_tks": paper["authors"]
-    }
-    doc["title_tks"] = huqie.qie(re.sub(r"\.[a-zA-Z]+$", "", doc["docnm_kwd"]))
+    doc = {"docnm_kwd": filename, "authors_tks": paper["authors"],
+           "title_tks": huqie.qie(paper["title"] if paper["title"] else filename)}
     doc["title_sm_tks"] = huqie.qieqie(doc["title_tks"])
     doc["authors_sm_tks"] = huqie.qieqie(doc["authors_tks"])
     # is it English

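The `from_page > 0` early return above skips title/author extraction on continuation pages. A sketch of the returned shape (the line text and tags shown are hypothetical):

paper = {
    "title": "",
    "authors": "",
    "abstract": "",
    # (text + position tag, layout label) pairs, text/title boxes only
    "lines": [("Some body paragraph from page 5", "text")],
    "tables": [],
}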
View File

@@ -3,7 +3,7 @@ import re
 from io import BytesIO
 from nltk import word_tokenize
 from openpyxl import load_workbook
-from rag.parser import is_english
+from rag.parser import is_english, random_choices
 from rag.nlp import huqie, stemmer
@@ -33,9 +33,9 @@ class Excel(object):
                 if len(res) % 999 == 0:
                     callback(len(res)*0.6/total, ("Extract Q&A: {}".format(len(res)) + (f"{len(fails)} failure, line: %s..."%(",".join(fails[:3])) if fails else "")))

-        callback(0.6, ("Extract Q&A: {}".format(len(res)) + (
+        callback(0.6, ("Extract Q&A: {}. ".format(len(res)) + (
             f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
-        self.is_english = is_english([rmPrefix(q) for q, _ in random.choices(res, k=30) if len(q)>1])
+        self.is_english = is_english([rmPrefix(q) for q, _ in random_choices(res, k=30) if len(q)>1])
         return res

rag/app/table.py Normal file (170 lines)
View File

@@ -0,0 +1,170 @@
import copy
import random
import re
from io import BytesIO
from xpinyin import Pinyin
import numpy as np
import pandas as pd
from nltk import word_tokenize
from openpyxl import load_workbook
from dateutil.parser import parse as datetime_parse
from rag.parser import is_english, tokenize
from rag.nlp import huqie, stemmer


class Excel(object):
    def __call__(self, fnm, binary=None, callback=None):
        if not binary:
            wb = load_workbook(fnm)
        else:
            wb = load_workbook(BytesIO(binary))
        total = 0
        for sheetname in wb.sheetnames:
            total += len(list(wb[sheetname].rows))

        res, fails, done = [], [], 0
        for sheetname in wb.sheetnames:
            ws = wb[sheetname]
            rows = list(ws.rows)
            headers = [cell.value for cell in rows[0]]
            missed = set([i for i, h in enumerate(headers) if h is None])
            headers = [cell.value for i, cell in enumerate(rows[0]) if i not in missed]
            data = []
            for i, r in enumerate(rows[1:]):
                row = [cell.value for ii, cell in enumerate(r) if ii not in missed]
                if len(row) != len(headers):
                    fails.append(str(i))
                    continue
                data.append(row)
                done += 1
                if done % 999 == 0:
                    callback(done * 0.6 / total, ("Extract records: {}".format(len(res)) + (
                        f"{len(fails)} failure({sheetname}), line: %s..." % (",".join(fails[:3])) if fails else "")))
            res.append(pd.DataFrame(np.array(data), columns=headers))

        callback(0.6, ("Extract records: {}. ".format(done) + (
            f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
        return res


def trans_datatime(s):
    try:
        return datetime_parse(s.strip()).strftime("%Y-%m-%dT%H:%M:%S")
    except Exception as e:
        pass


def trans_bool(s):
    if re.match(r"(true|yes|是)$", str(s).strip(), flags=re.IGNORECASE): return ["yes", ""]
    if re.match(r"(false|no|否)$", str(s).strip(), flags=re.IGNORECASE): return ["no", ""]


def column_data_type(arr):
    uni = len(set([a for a in arr if a is not None]))
    counts = {"int": 0, "float": 0, "text": 0, "datetime": 0, "bool": 0}
    trans = {t: f for f, t in [(int, "int"), (float, "float"), (trans_datatime, "datetime"), (trans_bool, "bool"), (str, "text")]}
    for a in arr:
        if a is None: continue
        if re.match(r"[+-]?[0-9]+(\.0+)?$", str(a).replace("%%", "")):
            counts["int"] += 1
        elif re.match(r"[+-]?[0-9.]+$", str(a).replace("%%", "")):
            counts["float"] += 1
        elif re.match(r"(true|false|yes|no|是|否)$", str(a), flags=re.IGNORECASE):
            counts["bool"] += 1
        elif trans_datatime(str(a)):
            counts["datetime"] += 1
        else:
            counts["text"] += 1
    counts = sorted(counts.items(), key=lambda x: x[1] * -1)
    ty = counts[0][0]
    for i in range(len(arr)):
        if arr[i] is None: continue
        try:
            arr[i] = trans[ty](str(arr[i]))
        except Exception as e:
            arr[i] = None
    if ty == "text":
        if len(arr) > 128 and uni / len(arr) < 0.1:
            ty = "keyword"
    return arr, ty


def chunk(filename, binary=None, callback=None, **kwargs):
    dfs = []
    if re.search(r"\.xlsx?$", filename, re.IGNORECASE):
        callback(0.1, "Start to parse.")
        excel_parser = Excel()
        dfs = excel_parser(filename, binary, callback)
    elif re.search(r"\.(txt|csv)$", filename, re.IGNORECASE):
        callback(0.1, "Start to parse.")
        txt = ""
        if binary:
            txt = binary.decode("utf-8")
        else:
            with open(filename, "r") as f:
                while True:
                    l = f.readline()
                    if not l: break
                    txt += l
        lines = txt.split("\n")
        fails = []
        headers = lines[0].split(kwargs.get("delimiter", "\t"))
        rows = []
        for i, line in enumerate(lines[1:]):
            row = [l for l in line.split(kwargs.get("delimiter", "\t"))]
            if len(row) != len(headers):
                fails.append(str(i))
                continue
            rows.append(row)
            if len(rows) % 999 == 0:
                callback(len(rows) * 0.6 / len(lines), ("Extract records: {}".format(len(rows)) + (
                    f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
        callback(0.6, ("Extract records: {}".format(len(rows)) + (
            f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
        dfs = [pd.DataFrame(np.array(rows), columns=headers)]
    else: raise NotImplementedError("file type not supported yet(excel, text, csv supported)")

    res = []
    PY = Pinyin()
    fieds_map = {"text": "_tks", "int": "_int", "keyword": "_kwd", "float": "_flt", "datetime": "_dt", "bool": "_kwd"}
    for df in dfs:
        for n in ["id", "_id", "index", "idx"]:
            if n in df.columns: del df[n]
        clmns = df.columns.values
        txts = list(copy.deepcopy(clmns))
        py_clmns = [PY.get_pinyins(n)[0].replace("-", "_") for n in clmns]
        clmn_tys = []
        for j in range(len(clmns)):
            cln, ty = column_data_type(df[clmns[j]])
            clmn_tys.append(ty)
            df[clmns[j]] = cln
            if ty == "text": txts.extend([str(c) for c in cln if c])
        clmns_map = [(py_clmns[j] + fieds_map[clmn_tys[j]], clmns[j]) for j in range(len(clmns))]
        # TODO: set this column map to KB parser configuration
        eng = is_english(txts)
        for ii, row in df.iterrows():
            d = {}
            row_txt = []
            for j in range(len(clmns)):
                if row[clmns[j]] is None: continue
                fld = clmns_map[j][0]
                d[fld] = row[clmns[j]] if clmn_tys[j] != "text" else huqie.qie(row[clmns[j]])
                row_txt.append("{}:{}".format(clmns[j], row[clmns[j]]))
            if not row_txt: continue
            tokenize(d, "; ".join(row_txt), eng)
            print(d)
            res.append(d)
        callback(0.6, "")

    return res


if __name__ == "__main__":
    import sys

    def dummy(a, b):
        pass
    chunk(sys.argv[1], callback=dummy)

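A quick sketch of the type inference above (sample values hypothetical): the dominant pattern among non-null values wins, then every value is coerced with the matching converter and failures become None.

vals = ["1", "2", "x", None, "4"]
vals, ty = column_data_type(vals)
print(ty)    # "int": three of the four non-null values look like integers
print(vals)  # [1, 2, None, None, 4] -- "x" fails int() and is nulled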
View File

@@ -67,7 +67,7 @@ class Dealer:
         ps = int(req.get("size", 1000))
         src = req.get("fields", ["docnm_kwd", "content_ltks", "kb_id", "img_id",
                                  "image_id", "doc_id", "q_512_vec", "q_768_vec",
-                                 "q_1024_vec", "q_1536_vec", "available_int"])
+                                 "q_1024_vec", "q_1536_vec", "available_int", "content_with_weight"])

         s = s.query(bqry)[pg * ps:(pg + 1) * ps]
         s = s.highlight("content_ltks")
@@ -234,7 +234,7 @@ class Dealer:
             sres.field[i].get("q_%d_vec" % len(sres.query_vector), "\t".join(["0"] * len(sres.query_vector)))) for i in sres.ids]
         if not ins_embd:
             return [], [], []
-        ins_tw = [huqie.qie(sres.field[i][cfield]).split(" ")
+        ins_tw = [sres.field[i][cfield].split(" ")
                   for i in sres.ids]
         sim, tksim, vtsim = self.qryr.hybrid_similarity(sres.query_vector,
                                                         ins_embd,
@@ -281,6 +281,7 @@ class Dealer:
             d = {
                 "chunk_id": id,
                 "content_ltks": sres.field[id]["content_ltks"],
+                "content_with_weight": sres.field[id]["content_with_weight"],
                 "doc_id": sres.field[id]["doc_id"],
                 "docnm_kwd": dnm,
                 "kb_id": sres.field[id]["kb_id"],

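Net effect of the rerank hunk above: token similarity now reads the stored field directly instead of re-running huqie.qie per query. A sketch, assuming `cfield` names a pre-tokenized, space-joined field such as content_ltks:

field = {"content_ltks": "检索 增强 生成 rag"}   # hypothetical stored value
ins_tw = field["content_ltks"].split(" ")
# ['检索', '增强', '生成', 'rag'] -- no re-tokenization at query time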
View File

@@ -1,4 +1,5 @@
 import copy
+import random

 from .pdf_parser import HuParser as PdfParser
 from .docx_parser import HuDocxParser as DocxParser
@@ -38,6 +39,9 @@ BULLET_PATTERN = [[
     ]
 ]

+def random_choices(arr, k):
+    k = min(len(arr), k)
+    return random.choices(arr, k=k)

 def bullets_category(sections):
     global BULLET_PATTERN

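Why the clamp: random.choices samples with replacement, so k larger than the population just pads the sample with duplicates (and an empty population raises IndexError). The helper caps k at len(arr):

from rag.parser import random_choices

sections = ["Chapter 1", "Chapter 2"]
sample = random_choices(sections, k=200)  # returns 2 items, not 200
assert len(sample) == 2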
View File

@@ -1,7 +1,10 @@
 # -*- coding: utf-8 -*-
+import os
 import random
+from functools import partial
 import fitz
+import requests
 import xgboost as xgb
 from io import BytesIO
 import torch
@@ -10,13 +13,14 @@ import pdfplumber
 import logging
 from PIL import Image
 import numpy as np
+from api.db import ParserType
 from rag.nlp import huqie
 from collections import Counter
 from copy import deepcopy
-from rag.cv.table_recognize import TableTransformer
-from rag.cv.ppdetection import PPDet
 from huggingface_hub import hf_hub_download
@@ -25,8 +29,10 @@ class HuParser:
         from paddleocr import PaddleOCR
         logging.getLogger("ppocr").setLevel(logging.ERROR)
         self.ocr = PaddleOCR(use_angle_cls=False, lang="ch")
-        self.layouter = PPDet("/data/newpeak/medical-gpt/res/ppdet")
-        self.tbl_det = PPDet("/data/newpeak/medical-gpt/res/ppdet.tbl")
+        if not hasattr(self, "model_speciess"):
+            self.model_speciess = ParserType.GENERAL.value
+        self.layouter = partial(self.__remote_call, self.model_speciess)
+        self.tbl_det = partial(self.__remote_call, "table_component")

         self.updown_cnt_mdl = xgb.Booster()
         if torch.cuda.is_available():
@@ -45,6 +51,38 @@ class HuParser:
         """

+    def __remote_call(self, species, images, thr=0.7):
+        url = os.environ.get("INFINIFLOW_SERVER")
+        if not url:raise EnvironmentError("Please set environment variable: 'INFINIFLOW_SERVER'")
+        token = os.environ.get("INFINIFLOW_TOKEN")
+        if not token:raise EnvironmentError("Please set environment variable: 'INFINIFLOW_TOKEN'")
+
+        def convert_image_to_bytes(PILimage):
+            image = BytesIO()
+            PILimage.save(image, format='png')
+            image.seek(0)
+            return image.getvalue()
+
+        images = [convert_image_to_bytes(img) for img in images]
+
+        def remote_call():
+            nonlocal images, thr
+            res = requests.post(url+"/v1/layout/detect/"+species, files=[("image", img) for img in images], data={"threashold": thr},
+                                headers={"Authorization": token}, timeout=len(images) * 10)
+            res = res.json()
+            if res["retcode"] != 0: raise RuntimeError(res["retmsg"])
+            return res["data"]
+
+        for _ in range(3):
+            try:
+                return remote_call()
+            except RuntimeError as e:
+                raise e
+            except Exception as e:
+                logging.error("layout_predict:"+str(e))
+        return remote_call()
+
     def __char_width(self, c):
         return (c["x1"] - c["x0"]) // len(c["text"])
@@ -344,7 +382,7 @@ class HuParser:
         return layouts

     def __table_paddle(self, images):
-        tbls = self.tbl_det([np.array(img) for img in images], thr=0.5)
+        tbls = self.tbl_det(images, thr=0.5)
         res = []
         # align left&right for rows, align top&bottom for columns
         for tbl in tbls:
@@ -522,7 +560,7 @@ class HuParser:
         assert len(self.page_images) == len(self.boxes)
         # Tag layout type
         boxes = []
-        layouts = self.layouter([np.array(img) for img in self.page_images])
+        layouts = self.layouter(self.page_images)
         assert len(self.page_images) == len(layouts)
         for pn, lts in enumerate(layouts):
             bxs = self.boxes[pn]
@@ -1705,7 +1743,8 @@ class HuParser:
             self.__ocr_paddle(i + 1, img, chars, zoomin)
         if not self.is_english and not any([c for c in self.page_chars]) and self.boxes:
-            self.is_english = re.search(r"[\na-zA-Z0-9,/¸;:'\[\]\(\)!@#$%^&*\"?<>._-]{30,}", "".join([b["text"] for b in random.choices([b for bxs in self.boxes for b in bxs], k=30)]))
+            bxes = [b for bxs in self.boxes for b in bxs]
+            self.is_english = re.search(r"[\na-zA-Z0-9,/¸;:'\[\]\(\)!@#$%^&*\"?<>._-]{30,}", "".join([b["text"] for b in random.choices(bxes, k=min(30, len(bxes)))]))
         logging.info("Is it English:", self.is_english)

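For reference, the request contract of the new detection service as implied by __remote_call above — a hedged sketch, not a documented API; the "threashold" field name (sic) is exactly what the client sends, and "paper" stands in for one of the ParserType species values:

import os
import requests

url = os.environ["INFINIFLOW_SERVER"]
token = os.environ["INFINIFLOW_TOKEN"]

with open("page.png", "rb") as f:           # hypothetical page image
    res = requests.post(url + "/v1/layout/detect/paper",
                        files=[("image", f.read())],
                        data={"threashold": 0.7},
                        headers={"Authorization": token},
                        timeout=10).json()
if res["retcode"] != 0:
    raise RuntimeError(res["retmsg"])
layouts = res["data"]                       # one list of layout boxes per image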
View File

@@ -134,5 +134,5 @@ if __name__ == "__main__":
     while True:
         dispatch()
-        time.sleep(3)
+        time.sleep(1)
         update_progress()

View File

@@ -36,7 +36,7 @@ from rag.nlp import search
 from io import BytesIO
 import pandas as pd

-from rag.app import laws, paper, presentation, manual, qa
+from rag.app import laws, paper, presentation, manual, qa, table, book
 from api.db import LLMType, ParserType
 from api.db.services.document_service import DocumentService
@@ -49,10 +49,12 @@ BATCH_SIZE = 64

 FACTORY = {
     ParserType.GENERAL.value: laws,
     ParserType.PAPER.value: paper,
+    ParserType.BOOK.value: book,
     ParserType.PRESENTATION.value: presentation,
     ParserType.MANUAL.value: manual,
     ParserType.LAWS.value: laws,
     ParserType.QA.value: qa,
+    ParserType.TABLE.value: table,
 }
@@ -66,7 +68,7 @@ def set_progress(task_id, from_page, to_page, prog=None, msg="Processing..."):
     d = {"progress_msg": msg}
     if prog is not None: d["progress"] = prog
     try:
-        TaskService.update_by_id(task_id, d)
+        TaskService.update_progress(task_id, d)
     except Exception as e:
         cron_logger.error("set_progress:({}), {}".format(task_id, str(e)))
@@ -113,7 +115,7 @@ def build(row, cvmdl):
         return []

     callback = partial(set_progress, row["id"], row["from_page"], row["to_page"])
-    chunker = FACTORY[row["parser_id"]]
+    chunker = FACTORY[row["parser_id"].lower()]
     try:
         cron_logger.info("Chunkking {}/{}".format(row["location"], row["name"]))
         cks = chunker.chunk(row["name"], MINIO.get(row["kb_id"], row["location"]), row["from_page"], row["to_page"],
@@ -154,6 +156,7 @@ def build(row, cvmdl):
         MINIO.put(row["kb_id"], d["_id"], output_buffer.getvalue())
         d["img_id"] = "{}-{}".format(row["kb_id"], d["_id"])
         del d["image"]
         docs.append(d)

     return docs
@@ -168,7 +171,7 @@ def init_kb(row):

 def embedding(docs, mdl):
-    tts, cnts = [d["docnm_kwd"] for d in docs if d.get("docnm_kwd")], [d["content_with_weight"] for d in docs]
+    tts, cnts = [rmSpace(d["title_tks"]) for d in docs if d.get("title_tks")], [d["content_with_weight"] for d in docs]
     tk_count = 0
     if len(tts) == len(cnts):
         tts, c = mdl.encode(tts)
@@ -207,6 +210,7 @@ def main(comm, mod):
         cks = build(r, cv_mdl)
         if not cks:
             tmf.write(str(r["update_time"]) + "\n")
+            callback(1., "No chunk! Done!")
             continue
         # TODO: exception handler
         ## set_progress(r["did"], -1, "ERROR: ")
@@ -215,7 +219,6 @@ def main(comm, mod):
         except Exception as e:
             callback(-1, "Embedding error:{}".format(str(e)))
             cron_logger.error(str(e))
-            continue

         callback(msg="Finished embedding! Start to build index!")
         init_kb(r)
@@ -227,6 +230,7 @@ def main(comm, mod):
         else:
             if TaskService.do_cancel(r["id"]):
                 ELASTICSEARCH.deleteByQuery(Q("match", doc_id=r["doc_id"]), idxnm=search.index_name(r["tenant_id"]))
+                continue
             callback(1., "Done!")
             DocumentService.increment_chunk_num(r["doc_id"], r["kb_id"], tk_count, chunk_count, 0)
             cron_logger.info("Chunk doc({}), token({}), chunks({})".format(r["id"], tk_count, len(cks)))