add ocr and recognizer demo, update README (#74)

Author: KevinHuSh
Date: 2024-02-26 19:51:35 +08:00
Committed by: GitHub
Parent: d1417102b6
Commit: d1c600d5d3
9 changed files with 525 additions and 73 deletions

deepdoc/vision/__init__.py

@@ -1,4 +1,49 @@
from .ocr import OCR
from .recognizer import Recognizer
from .layout_recognizer import LayoutRecognizer
from .table_structure_recognizer import TableStructureRecognizer
def init_in_out(args):
from PIL import Image
import fitz
import os
import traceback
from api.utils.file_utils import traversal_files
images = []
outputs = []
if not os.path.exists(args.output_dir):
os.mkdir(args.output_dir)
def pdf_pages(fnm, zoomin=3):
nonlocal outputs, images
pdf = fitz.open(fnm)
mat = fitz.Matrix(zoomin, zoomin)
for i, page in enumerate(pdf):
pix = page.get_pixmap(matrix=mat)
img = Image.frombytes("RGB", [pix.width, pix.height],
pix.samples)
images.append(img)
outputs.append(os.path.split(fnm)[-1] + f"_{i}.jpg")
def images_and_outputs(fnm):
nonlocal outputs, images
if fnm.split(".")[-1].lower() == "pdf":
pdf_pages(fnm)
return
try:
images.append(Image.open(fnm))
outputs.append(os.path.split(fnm)[-1])
except Exception as e:
traceback.print_exc()
if os.path.isdir(args.inputs):
for fnm in traversal_files(args.inputs):
images_and_outputs(fnm)
else:
images_and_outputs(args.inputs)
for i in range(len(outputs)): outputs[i] = os.path.join(args.output_dir, outputs[i])
return images, outputs
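For reference, init_in_out only reads args.inputs and args.output_dir, so it can be driven without argparse. A minimal sketch (the SimpleNamespace stand-in and the sample paths are illustrative, not part of this commit):

# Minimal sketch: drive init_in_out with a plain namespace instead of argparse.
# Assumes the repository root is on PYTHONPATH; paths are placeholders.
from types import SimpleNamespace
from deepdoc.vision import init_in_out

args = SimpleNamespace(inputs="./docs_or_images", output_dir="./demo_outputs")
images, outputs = init_in_out(args)  # PIL images paired with output file paths
for img, out in zip(images, outputs):
    print(out, img.size)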

deepdoc/vision/layout_recognizer.py

@@ -1,17 +1,26 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import re
from collections import Counter
from copy import deepcopy
import numpy as np
from api.utils.file_utils import get_project_base_directory
-from .recognizer import Recognizer
+from deepdoc.vision import Recognizer
class LayoutRecognizer(Recognizer):
-    def __init__(self, domain):
-        self.layout_labels = [
+    labels = [
"_background_",
"Text",
"Title",
@@ -24,7 +33,8 @@ class LayoutRecognizer(Recognizer):
"Reference",
"Equation",
]
-        super().__init__(self.layout_labels, domain,
+    def __init__(self, domain):
+        super().__init__(self.labels, domain,
os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
def __call__(self, image_list, ocr_res, scale_factor=3, thr=0.7, batch_size=16):
@@ -37,7 +47,7 @@ class LayoutRecognizer(Recognizer):
return any([re.search(p, b["text"]) for p in patt])
layouts = super().__call__(image_list, thr, batch_size)
-        # save_results(image_list, layouts, self.layout_labels, output_dir='output/', threshold=0.7)
+        # save_results(image_list, layouts, self.labels, output_dir='output/', threshold=0.7)
assert len(image_list) == len(ocr_res)
# Tag layout type
boxes = []
@@ -117,3 +127,5 @@ class LayoutRecognizer(Recognizer):
ocr_res = [b for b in ocr_res if b["text"].strip() not in garbag_set]
return ocr_res, page_layout
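With the label list promoted from an instance attribute to a class attribute, callers can read it without constructing a recognizer, which is what the new t_recognizer.py demo relies on. A small sketch:

# Sketch: the layout label set is now reachable without instantiation.
from deepdoc.vision import LayoutRecognizer

print(LayoutRecognizer.labels)  # ['_background_', 'Text', 'Title', ...]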

deepdoc/vision/recognizer.py

@@ -17,7 +17,6 @@ from copy import deepcopy
import onnxruntime as ort
from huggingface_hub import snapshot_download
-from . import seeit
from .operators import *
from rag.settings import cron_logger
@@ -36,7 +35,7 @@ class Recognizer(object):
"""
if not model_dir:
model_dir = snapshot_download(repo_id="InfiniFlow/ocr")
model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc")
model_file_path = os.path.join(model_dir, task_name + ".onnx")
if not os.path.exists(model_file_path):
@@ -46,6 +45,9 @@ class Recognizer(object):
self.ort_sess = ort.InferenceSession(model_file_path, providers=['CUDAExecutionProvider'])
else:
self.ort_sess = ort.InferenceSession(model_file_path, providers=['CPUExecutionProvider'])
self.input_names = [node.name for node in self.ort_sess.get_inputs()]
self.output_names = [node.name for node in self.ort_sess.get_outputs()]
self.input_shape = self.ort_sess.get_inputs()[0].shape[2:4]
self.label_list = label_list
@staticmethod
@@ -275,23 +277,131 @@ class Recognizer(object):
return max_overlaped_i
def preprocess(self, image_list):
-        preprocess_ops = []
-        for op_info in [
-            {'interp': 2, 'keep_ratio': False, 'target_size': [800, 608], 'type': 'LinearResize'},
-            {'is_scale': True, 'mean': [0.485, 0.456, 0.406], 'std': [0.229, 0.224, 0.225], 'type': 'StandardizeImage'},
-            {'type': 'Permute'},
-            {'stride': 32, 'type': 'PadStride'}
-        ]:
-            new_op_info = op_info.copy()
-            op_type = new_op_info.pop('type')
-            preprocess_ops.append(eval(op_type)(**new_op_info))
inputs = []
-        for im_path in image_list:
-            im, im_info = preprocess(im_path, preprocess_ops)
-            inputs.append({"image": np.array((im,)).astype('float32'), "scale_factor": np.array((im_info["scale_factor"],)).astype('float32')})
if "scale_factor" in self.input_names:
preprocess_ops = []
for op_info in [
{'interp': 2, 'keep_ratio': False, 'target_size': [800, 608], 'type': 'LinearResize'},
{'is_scale': True, 'mean': [0.485, 0.456, 0.406], 'std': [0.229, 0.224, 0.225], 'type': 'StandardizeImage'},
{'type': 'Permute'},
{'stride': 32, 'type': 'PadStride'}
]:
new_op_info = op_info.copy()
op_type = new_op_info.pop('type')
preprocess_ops.append(eval(op_type)(**new_op_info))
for im_path in image_list:
im, im_info = preprocess(im_path, preprocess_ops)
inputs.append({"image": np.array((im,)).astype('float32'),
"scale_factor": np.array((im_info["scale_factor"],)).astype('float32')})
else:
hh, ww = self.input_shape
for img in image_list:
h, w = img.shape[:2]
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = cv2.resize(np.array(img).astype('float32'), (ww, hh))
# Scale input pixel values to 0 to 1
img /= 255.0
img = img.transpose(2, 0, 1)
img = img[np.newaxis, :, :, :].astype(np.float32)
inputs.append({self.input_names[0]: img, "scale_factor": [w/ww, h/hh]})
return inputs
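The new else branch serves models that lack a scale_factor input (YOLO-style heads): resize to the session's input shape, scale pixels to [0, 1], reorder HWC to NCHW, and keep the width/height ratios needed to map boxes back to page pixels. A standalone sketch of that arithmetic, assuming a 608x800 input shape and a synthetic image:

# Sketch of the else branch above for a single image; values are illustrative.
import numpy as np
import cv2

hh, ww = 608, 800                                  # assumed model input shape
img = np.zeros((1000, 1300, 3), np.uint8)          # stand-in BGR page image
h, w = img.shape[:2]
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = cv2.resize(img.astype('float32'), (ww, hh)) / 255.0    # scale to [0, 1]
img = img.transpose(2, 0, 1)[np.newaxis].astype(np.float32)  # HWC -> NCHW
scale_factor = [w / ww, h / hh]                    # restores page coordinates
print(img.shape, scale_factor)                     # (1, 3, 608, 800) [1.625, ...]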
def postprocess(self, boxes, inputs, thr):
if "scale_factor" in self.input_names:
bb = []
for b in boxes:
clsid, bbox, score = int(b[0]), b[2:], b[1]
if score < thr:
continue
if clsid >= len(self.label_list):
cron_logger.warning(f"bad category id")
continue
bb.append({
"type": self.label_list[clsid].lower(),
"bbox": [float(t) for t in bbox.tolist()],
"score": float(score)
})
return bb
def xywh2xyxy(x):
# [x, y, w, h] to [x1, y1, x2, y2]
y = np.copy(x)
y[:, 0] = x[:, 0] - x[:, 2] / 2
y[:, 1] = x[:, 1] - x[:, 3] / 2
y[:, 2] = x[:, 0] + x[:, 2] / 2
y[:, 3] = x[:, 1] + x[:, 3] / 2
return y
def compute_iou(box, boxes):
# Compute xmin, ymin, xmax, ymax for both boxes
xmin = np.maximum(box[0], boxes[:, 0])
ymin = np.maximum(box[1], boxes[:, 1])
xmax = np.minimum(box[2], boxes[:, 2])
ymax = np.minimum(box[3], boxes[:, 3])
# Compute intersection area
intersection_area = np.maximum(0, xmax - xmin) * np.maximum(0, ymax - ymin)
# Compute union area
box_area = (box[2] - box[0]) * (box[3] - box[1])
boxes_area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
union_area = box_area + boxes_area - intersection_area
# Compute IoU
iou = intersection_area / union_area
return iou
def iou_filter(boxes, scores, iou_threshold):
sorted_indices = np.argsort(scores)[::-1]
keep_boxes = []
while sorted_indices.size > 0:
# Pick the last box
box_id = sorted_indices[0]
keep_boxes.append(box_id)
# Compute IoU of the picked box with the rest
ious = compute_iou(boxes[box_id, :], boxes[sorted_indices[1:], :])
# Remove boxes with IoU over the threshold
keep_indices = np.where(ious < iou_threshold)[0]
# print(keep_indices.shape, sorted_indices.shape)
sorted_indices = sorted_indices[keep_indices + 1]
return keep_boxes
boxes = np.squeeze(boxes).T
# Filter out object confidence scores below threshold
scores = np.max(boxes[:, 4:], axis=1)
boxes = boxes[scores > thr, :]
scores = scores[scores > thr]
if len(boxes) == 0: return []
# Get the class with the highest confidence
class_ids = np.argmax(boxes[:, 4:], axis=1)
boxes = boxes[:, :4]
input_shape = np.array([inputs["scale_factor"][0], inputs["scale_factor"][1], inputs["scale_factor"][0], inputs["scale_factor"][1]])
boxes = np.multiply(boxes, input_shape, dtype=np.float32)
boxes = xywh2xyxy(boxes)
unique_class_ids = np.unique(class_ids)
indices = []
for class_id in unique_class_ids:
class_indices = np.where(class_ids == class_id)[0]
class_boxes = boxes[class_indices, :]
class_scores = scores[class_indices]
class_keep_boxes = iou_filter(class_boxes, class_scores, 0.2)
indices.extend(class_indices[class_keep_boxes])
return [{
"type": self.label_list[class_ids[i]].lower(),
"bbox": [float(t) for t in boxes[i].tolist()],
"score": float(scores[i])
} for i in indices]
def __call__(self, image_list, thr=0.7, batch_size=16):
res = []
imgs = []
@@ -306,22 +416,14 @@ class Recognizer(object):
end_index = min((i + 1) * batch_size, len(imgs))
batch_image_list = imgs[start_index:end_index]
inputs = self.preprocess(batch_image_list)
print("preprocess")
for ins in inputs:
-            bb = []
-            for b in self.ort_sess.run(None, ins)[0]:
-                clsid, bbox, score = int(b[0]), b[2:], b[1]
-                if score < thr:
-                    continue
-                if clsid >= len(self.label_list):
-                    cron_logger.warning(f"bad category id")
-                    continue
-                bb.append({
-                    "type": self.label_list[clsid].lower(),
-                    "bbox": [float(t) for t in bbox.tolist()],
-                    "score": float(score)
-                })
+            bb = self.postprocess(self.ort_sess.run(None, {k:v for k,v in ins.items() if k in self.input_names})[0], ins, thr)
res.append(bb)
#seeit.save_results(image_list, res, self.label_list, threshold=thr)
return res
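The new postprocess path decodes YOLO-style raw output: rows of (cx, cy, w, h) plus per-class scores are thresholded, converted to corner form, rescaled by scale_factor, and pruned with per-class IoU filtering. A self-contained toy run of the same helper logic (the helpers are re-declared so the snippet runs on its own):

import numpy as np

def xywh2xyxy(x):
    # (cx, cy, w, h) -> (x1, y1, x2, y2)
    y = np.copy(x)
    y[:, 0] = x[:, 0] - x[:, 2] / 2
    y[:, 1] = x[:, 1] - x[:, 3] / 2
    y[:, 2] = x[:, 0] + x[:, 2] / 2
    y[:, 3] = x[:, 1] + x[:, 3] / 2
    return y

def compute_iou(box, boxes):
    # IoU of one corner-form box against an array of corner-form boxes
    xmin = np.maximum(box[0], boxes[:, 0])
    ymin = np.maximum(box[1], boxes[:, 1])
    xmax = np.minimum(box[2], boxes[:, 2])
    ymax = np.minimum(box[3], boxes[:, 3])
    inter = np.maximum(0, xmax - xmin) * np.maximum(0, ymax - ymin)
    area = (box[2] - box[0]) * (box[3] - box[1])
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    return inter / (area + areas - inter)

# Two near-duplicate detections and one distant one, in (cx, cy, w, h) form:
raw = np.array([[50, 50, 40, 40], [52, 50, 40, 40], [200, 200, 30, 30]], float)
corners = xywh2xyxy(raw)
print(compute_iou(corners[0], corners[1:]))  # ~[0.90, 0.0]: at the 0.2 IoU
# threshold used above, the near-duplicate is suppressed, the distant box kept.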

deepdoc/vision/t_ocr.py (new file, 47 lines)

@@ -0,0 +1,47 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os, sys
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), '../../')))
import numpy as np
import argparse
from deepdoc.vision import OCR, init_in_out
from deepdoc.vision.seeit import draw_box
def main(args):
ocr = OCR()
images, outputs = init_in_out(args)
for i, img in enumerate(images):
bxs = ocr(np.array(img))
bxs = [(line[0], line[1][0]) for line in bxs]
bxs = [{
"text": t,
"bbox": [b[0][0], b[0][1], b[1][0], b[-1][1]],
"type": "ocr",
"score": 1} for b, t in bxs if b[0][0] <= b[1][0] and b[0][1] <= b[-1][1]]
img = draw_box(images[i], bxs, ["ocr"], 1.)
img.save(outputs[i], quality=95)
with open(outputs[i] + ".txt", "w+") as f: f.write("\n".join([o["text"] for o in bxs]))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--inputs',
help="Directory where to store images or PDFs, or a file path to a single image or PDF",
required=True)
parser.add_argument('--output_dir', help="Directory where to store the output images. Default: './ocr_outputs'",
default="./ocr_outputs")
args = parser.parse_args()
main(args)
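Run from the repository root, this corresponds to: python deepdoc/vision/t_ocr.py --inputs=<path> --output_dir=<dir>. A hypothetical programmatic equivalent (the namespace mirrors the two flags defined above; the sample path is a placeholder):

# Hypothetical driver for the demo above; assumes the repository root is on
# PYTHONPATH and that main() reads only the two attributes defined by argparse.
from types import SimpleNamespace
from deepdoc.vision import t_ocr

t_ocr.main(SimpleNamespace(inputs="./sample.pdf", output_dir="./ocr_outputs"))
# Writes one annotated image plus a .txt transcript per page into ./ocr_outputs.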

deepdoc/vision/t_recognizer.py (new file, 173 lines)

@@ -0,0 +1,173 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os, sys
import re
import numpy as np
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), '../../')))
import argparse
from api.utils.file_utils import get_project_base_directory
from deepdoc.vision import Recognizer, LayoutRecognizer, TableStructureRecognizer, OCR, init_in_out
from deepdoc.vision.seeit import draw_box
def main(args):
images, outputs = init_in_out(args)
if args.mode.lower() == "layout":
labels = LayoutRecognizer.labels
detr = Recognizer(labels, "layout.paper", os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
if args.mode.lower() == "tsr":
labels = TableStructureRecognizer.labels
detr = TableStructureRecognizer()
ocr = OCR()
layouts = detr(images, float(args.threshold))
for i, lyt in enumerate(layouts):
if args.mode.lower() == "tsr":
#lyt = [t for t in lyt if t["type"] == "table column"]
html = get_table_html(images[i], lyt, ocr)
with open(outputs[i]+".html", "w+") as f: f.write(html)
lyt = [{
"type": t["label"],
"bbox": [t["x0"], t["top"], t["x1"], t["bottom"]],
"score": t["score"]
} for t in lyt]
img = draw_box(images[i], lyt, labels, float(args.threshold))
img.save(outputs[i], quality=95)
print("save result to: " + outputs[i])
def get_table_html(img, tb_cpns, ocr):
boxes = ocr(np.array(img))
boxes = Recognizer.sort_Y_firstly(
[{"x0": b[0][0], "x1": b[1][0],
"top": b[0][1], "text": t[0],
"bottom": b[-1][1],
"layout_type": "table",
"page_number": 0} for b, t in boxes if b[0][0] <= b[1][0] and b[0][1] <= b[-1][1]],
np.mean([b[-1][1]-b[0][1] for b,_ in boxes]) / 3
)
def gather(kwd, fzy=10, ption=0.6):
nonlocal boxes
eles = Recognizer.sort_Y_firstly(
[r for r in tb_cpns if re.match(kwd, r["label"])], fzy)
eles = Recognizer.layouts_cleanup(boxes, eles, 5, ption)
return Recognizer.sort_Y_firstly(eles, 0)
headers = gather(r".*header$")
rows = gather(r".* (row|header)")
spans = gather(r".*spanning")
clmns = sorted([r for r in tb_cpns if re.match(
r"table column$", r["label"])], key=lambda x: x["x0"])
clmns = Recognizer.layouts_cleanup(boxes, clmns, 5, 0.5)
for b in boxes:
ii = Recognizer.find_overlapped_with_threashold(b, rows, thr=0.3)
if ii is not None:
b["R"] = ii
b["R_top"] = rows[ii]["top"]
b["R_bott"] = rows[ii]["bottom"]
ii = Recognizer.find_overlapped_with_threashold(b, headers, thr=0.3)
if ii is not None:
b["H_top"] = headers[ii]["top"]
b["H_bott"] = headers[ii]["bottom"]
b["H_left"] = headers[ii]["x0"]
b["H_right"] = headers[ii]["x1"]
b["H"] = ii
ii = Recognizer.find_overlapped_with_threashold(b, clmns, thr=0.3)
if ii is not None:
b["C"] = ii
b["C_left"] = clmns[ii]["x0"]
b["C_right"] = clmns[ii]["x1"]
ii = Recognizer.find_overlapped_with_threashold(b, spans, thr=0.3)
if ii is not None:
b["H_top"] = spans[ii]["top"]
b["H_bott"] = spans[ii]["bottom"]
b["H_left"] = spans[ii]["x0"]
b["H_right"] = spans[ii]["x1"]
b["SP"] = ii
html = """
<html>
<head>
<style>
._table_1nkzy_11 {
margin: auto;
width: 70%%;
padding: 10px;
}
._table_1nkzy_11 p {
margin-bottom: 50px;
border: 1px solid #e1e1e1;
}
caption {
color: #6ac1ca;
font-size: 20px;
height: 50px;
line-height: 50px;
font-weight: 600;
margin-bottom: 10px;
}
._table_1nkzy_11 table {
width: 100%%;
border-collapse: collapse;
}
th {
color: #fff;
background-color: #6ac1ca;
}
td:hover {
background: #c1e8e8;
}
tr:nth-child(even) {
background-color: #f2f2f2;
}
._table_1nkzy_11 th,
._table_1nkzy_11 td {
text-align: center;
border: 1px solid #ddd;
padding: 8px;
}
</style>
</head>
<body>
%s
</body>
</html>
"""% TableStructureRecognizer.construct_table(boxes, html=True)
return html
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--inputs',
help="Directory where to store images or PDFs, or a file path to a single image or PDF",
required=True)
parser.add_argument('--output_dir', help="Directory where to store the output images. Default: './layouts_outputs'",
default="./layouts_outputs")
parser.add_argument('--threshold', help="A threshold to filter out detections. Default: 0.5", default=0.5)
parser.add_argument('--mode', help="Task mode: layout recognition or table structure recognition", choices=["layout", "tsr"],
default="layout")
args = parser.parse_args()
main(args)
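Likewise, python deepdoc/vision/t_recognizer.py --inputs=<path> --mode=layout|tsr drives this demo. A hypothetical programmatic equivalent, assuming main() reads only the four attributes defined above:

# Hypothetical driver: layout detection first, then table structure recognition.
from types import SimpleNamespace
from deepdoc.vision import t_recognizer

common = dict(output_dir="./layouts_outputs", threshold=0.5)
t_recognizer.main(SimpleNamespace(inputs="./paper.pdf", mode="layout", **common))
t_recognizer.main(SimpleNamespace(inputs="./table.png", mode="tsr", **common))
# "tsr" mode additionally writes an .html reconstruction of each detected table.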

deepdoc/vision/table_structure_recognizer.py

@@ -1,3 +1,15 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import os
import re
@@ -12,15 +24,16 @@ from .recognizer import Recognizer
class TableStructureRecognizer(Recognizer):
+    labels = [
+        "table",
+        "table column",
+        "table row",
+        "table column header",
+        "table projected row header",
+        "table spanning cell",
+    ]
def __init__(self):
-        self.labels = [
-            "table",
-            "table column",
-            "table row",
-            "table column header",
-            "table projected row header",
-            "table spanning cell",
-        ]
super().__init__(self.labels, "tsr",
os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
@@ -79,7 +92,8 @@ class TableStructureRecognizer(Recognizer):
return True
return False
-    def __blockType(self, b):
+    @staticmethod
+    def blockType(b):
patt = [
("^(20|19)[0-9]{2}[年/-][0-9]{1,2}[月/-][0-9]{1,2}日*$", "Dt"),
(r"^(20|19)[0-9]{2}年$", "Dt"),
@@ -109,11 +123,12 @@ class TableStructureRecognizer(Recognizer):
return "Ot"
-    def construct_table(self, boxes, is_english=False, html=False):
+    @staticmethod
+    def construct_table(boxes, is_english=False, html=False):
cap = ""
i = 0
while i < len(boxes):
-            if self.is_caption(boxes[i]):
+            if TableStructureRecognizer.is_caption(boxes[i]):
cap += boxes[i]["text"]
boxes.pop(i)
i -= 1
@@ -122,14 +137,15 @@ class TableStructureRecognizer(Recognizer):
if not boxes:
return []
for b in boxes:
b["btype"] = self.__blockType(b)
b["btype"] = TableStructureRecognizer.blockType(b)
max_type = Counter([b["btype"] for b in boxes]).items()
max_type = max(max_type, key=lambda x: x[1])[0] if max_type else ""
logging.debug("MAXTYPE: " + max_type)
rowh = [b["R_bott"] - b["R_top"] for b in boxes if "R" in b]
rowh = np.min(rowh) if rowh else 0
-        boxes = self.sort_R_firstly(boxes, rowh / 2)
+        boxes = Recognizer.sort_R_firstly(boxes, rowh / 2)
#for b in boxes:print(b)
boxes[0]["rn"] = 0
rows = [[boxes[0]]]
btm = boxes[0]["bottom"]
@@ -150,9 +166,9 @@ class TableStructureRecognizer(Recognizer):
colwm = np.min(colwm) if colwm else 0
crosspage = len(set([b["page_number"] for b in boxes])) > 1
if crosspage:
-            boxes = self.sort_X_firstly(boxes, colwm / 2, False)
+            boxes = Recognizer.sort_X_firstly(boxes, colwm / 2, False)
else:
-            boxes = self.sort_C_firstly(boxes, colwm / 2)
+            boxes = Recognizer.sort_C_firstly(boxes, colwm / 2)
boxes[0]["cn"] = 0
cols = [[boxes[0]]]
right = boxes[0]["x1"]
@@ -313,16 +329,18 @@ class TableStructureRecognizer(Recognizer):
hdset.add(i)
if html:
-            return [self.__html_table(cap, hdset,
-                                      self.__cal_spans(boxes, rows,
-                                                       cols, tbl, True)
-                                      )]
+            return TableStructureRecognizer.__html_table(cap, hdset,
+                                                         TableStructureRecognizer.__cal_spans(boxes, rows,
+                                                                                              cols, tbl, True)
+                                                         )
-        return self.__desc_table(cap, hdset,
-                                 self.__cal_spans(boxes, rows, cols, tbl, False),
-                                 is_english)
+        return TableStructureRecognizer.__desc_table(cap, hdset,
+                                                     TableStructureRecognizer.__cal_spans(boxes, rows, cols, tbl,
+                                                                                          False),
+                                                     is_english)
-    def __html_table(self, cap, hdset, tbl):
+    @staticmethod
+    def __html_table(cap, hdset, tbl):
# construct HTML
html = "<table>"
if cap:
@@ -339,8 +357,8 @@ class TableStructureRecognizer(Recognizer):
txt = ""
if arr:
h = min(np.min([c["bottom"] - c["top"] for c in arr]) / 2, 10)
txt = "".join([c["text"]
for c in self.sort_Y_firstly(arr, h)])
txt = " ".join([c["text"]
for c in Recognizer.sort_Y_firstly(arr, h)])
txts.append(txt)
sp = ""
if arr[0].get("colspan"):
@@ -366,7 +384,8 @@ class TableStructureRecognizer(Recognizer):
html += "\n</table>"
return html
-    def __desc_table(self, cap, hdr_rowno, tbl, is_english):
+    @staticmethod
+    def __desc_table(cap, hdr_rowno, tbl, is_english):
# get the text of every column in the header row to become the header text
clmno = len(tbl[0])
rowno = len(tbl)
@@ -469,7 +488,8 @@ class TableStructureRecognizer(Recognizer):
row_txt = [t + f"\t——{from_}{cap}" for t in row_txt]
return row_txt
-    def __cal_spans(self, boxes, rows, cols, tbl, html=True):
+    @staticmethod
+    def __cal_spans(boxes, rows, cols, tbl, html=True):
# calculate spans
clft = [np.mean([c.get("C_left", c["x0"]) for c in cln])
for cln in cols]
@@ -553,4 +573,3 @@ class TableStructureRecognizer(Recognizer):
tbl[rowspan[0]][colspan[0]] = arr
return tbl
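Since blockType is now a public staticmethod, the cell-typing heuristics can be exercised in isolation. A toy check, assuming the method classifies purely from the cell's "text" field using the regex table shown above:

# Toy check of the now-public static method; assumes classification is driven
# by the "text" field alone, per the pattern list in blockType.
from deepdoc.vision import TableStructureRecognizer

print(TableStructureRecognizer.blockType({"text": "2021-09-15"}))  # "Dt", date pattern
print(TableStructureRecognizer.blockType({"text": "2021年"}))  # "Dt", year pattern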