### What problem does this PR solve?

This PR addresses critical memory and CPU resource management issues in high-concurrency environments (multi-worker setups):

- **GPU memory exhaustion (OOM):** onnxruntime-gpu currently uses an aggressive memory arena that does not effectively release VRAM back to the system after a task completes. In multi-process worker setups (`$WS` > 4), this leads to BFCArena allocation failures and OOM errors because workers "hoard" VRAM even when idle. This PR introduces an optional GPU memory arena shrinkage toggle to mitigate the issue (see the configuration example below).
- **CPU oversubscription:** The ONNX `intra_op` and `inter_op` thread counts are currently hardcoded to 2. When running many workers, this causes significant CPU context-switching overhead and degrades performance. This PR makes these values configurable to match the host's actual CPU core density.
- **Multi-GPU support:** The memory management logic now dynamically targets the correct `device_id`, ensuring stability on systems with multiple GPUs.
- **Transparency:** Added detailed initialization logs to help administrators verify and troubleshoot their ONNX session configurations.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

Co-authored-by: shakeel <shakeel@lollylaw.com>
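### Configuration example

A minimal sketch of how the new knobs might be set before the OCR module creates its sessions; the variable names are the ones read in `load_model` below, and the values shown are illustrative, not recommendations:

```python
import os

# Thread counts per ONNX session (both default to "2").
os.environ.setdefault("OCR_INTRA_OP_NUM_THREADS", "4")
os.environ.setdefault("OCR_INTER_OP_NUM_THREADS", "2")
# Per-session CUDA arena cap in MB (defaults to "2048").
os.environ.setdefault("OCR_GPU_MEM_LIMIT_MB", "2048")
# Arena growth strategy passed to CUDAExecutionProvider
# (defaults to "kNextPowerOfTwo").
os.environ.setdefault("OCR_ARENA_EXTEND_STRATEGY", "kNextPowerOfTwo")
# Opt in to GPU memory arena shrinkage after each run.
os.environ.setdefault("OCR_GPUMEM_ARENA_SHRINKAGE", "1")
```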
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import gc
import logging
import copy
import time
import os

from huggingface_hub import snapshot_download

from common.file_utils import get_project_base_directory
from common.misc_utils import pip_install_torch
from common import settings
from .operators import *  # noqa: F403
from . import operators
import math
import numpy as np
import cv2
import onnxruntime as ort

from .postprocess import build_post_process

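# Cache of (InferenceSession, RunOptions) tuples, keyed by model file path
# (suffixed with the device_id when one is given), so each process loads a
# given model at most once per device.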
loaded_models = {}


def transform(data, ops=None):
    """ transform """
    if ops is None:
        ops = []
    for op in ops:
        data = op(data)
        if data is None:
            return None
    return data


def create_operators(op_param_list, global_config=None):
    """
    Create operators based on the config.

    Args:
        op_param_list(list): a list of single-key dicts, each describing one operator
    """
    assert isinstance(op_param_list, list), 'operator config should be a list'
    ops = []
    for operator in op_param_list:
        assert isinstance(operator, dict) and len(operator) == 1, "yaml format error"
        op_name = list(operator)[0]
        param = {} if operator[op_name] is None else operator[op_name]
        if global_config is not None:
            param.update(global_config)
        op = getattr(operators, op_name)(**param)
        ops.append(op)
    return ops


def load_model(model_dir, nm, device_id: int | None = None):
    model_file_path = os.path.join(model_dir, nm + ".onnx")
    model_cached_tag = model_file_path + str(device_id) if device_id is not None else model_file_path

    global loaded_models
    loaded_model = loaded_models.get(model_cached_tag)
    if loaded_model:
        logging.info(f"load_model {model_file_path} reuses cached model")
        return loaded_model

    if not os.path.exists(model_file_path):
        raise ValueError("model file path not found: {}".format(model_file_path))

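    # Probe CUDA availability for the requested device; any failure
    # (torch missing, no GPU, device_id out of range) falls back to CPU.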
    def cuda_is_available():
        try:
            pip_install_torch()
            import torch
            target_id = 0 if device_id is None else device_id
            if torch.cuda.is_available() and torch.cuda.device_count() > target_id:
                return True
        except Exception:
            return False
        return False

    options = ort.SessionOptions()
    options.enable_cpu_mem_arena = False
    options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
    # Prevent CPU oversubscription by allowing explicit thread control in multi-worker environments
    options.intra_op_num_threads = int(os.environ.get("OCR_INTRA_OP_NUM_THREADS", "2"))
    options.inter_op_num_threads = int(os.environ.get("OCR_INTER_OP_NUM_THREADS", "2"))

    # https://github.com/microsoft/onnxruntime/issues/9509#issuecomment-951546580
    # Shrink GPU memory after execution
    run_options = ort.RunOptions()
    if cuda_is_available():
        gpu_mem_limit_mb = int(os.environ.get("OCR_GPU_MEM_LIMIT_MB", "2048"))
        arena_strategy = os.environ.get("OCR_ARENA_EXTEND_STRATEGY", "kNextPowerOfTwo")
        provider_device_id = 0 if device_id is None else device_id
        cuda_provider_options = {
            "device_id": provider_device_id,  # Use specific GPU
            "gpu_mem_limit": max(gpu_mem_limit_mb, 0) * 1024 * 1024,
            "arena_extend_strategy": arena_strategy,  # gpu memory allocation strategy
        }
        sess = ort.InferenceSession(
            model_file_path,
            options=options,
            providers=['CUDAExecutionProvider'],
            provider_options=[cuda_provider_options]
        )
        # Explicit arena shrinkage for GPU to release VRAM back to the system after each run
        if os.environ.get("OCR_GPUMEM_ARENA_SHRINKAGE") == "1":
            run_options.add_run_config_entry("memory.enable_memory_arena_shrinkage", f"gpu:{provider_device_id}")
            logging.info(
                f"load_model {model_file_path} enabled GPU memory arena shrinkage on device {provider_device_id}")
        logging.info(f"load_model {model_file_path} uses GPU (device {provider_device_id}, gpu_mem_limit={cuda_provider_options['gpu_mem_limit']}, arena_strategy={arena_strategy})")
    else:
        sess = ort.InferenceSession(
            model_file_path,
            options=options,
            providers=['CPUExecutionProvider'])
        run_options.add_run_config_entry("memory.enable_memory_arena_shrinkage", "cpu")
        logging.info(f"load_model {model_file_path} uses CPU")
    loaded_model = (sess, run_options)
    loaded_models[model_cached_tag] = loaded_model
    return loaded_model


class TextRecognizer:
    def __init__(self, model_dir, device_id: int | None = None):
        self.rec_image_shape = [int(v) for v in "3, 48, 320".split(",")]
        self.rec_batch_num = 16
        postprocess_params = {
            'name': 'CTCLabelDecode',
            "character_dict_path": os.path.join(model_dir, "ocr.res"),
            "use_space_char": True
        }
        self.postprocess_op = build_post_process(postprocess_params)
        self.predictor, self.run_options = load_model(model_dir, 'rec', device_id)
        self.input_tensor = self.predictor.get_inputs()[0]

    def resize_norm_img(self, img, max_wh_ratio):
        imgC, imgH, imgW = self.rec_image_shape

        assert imgC == img.shape[2]
        imgW = int((imgH * max_wh_ratio))
        w = self.input_tensor.shape[3:][0]
        if isinstance(w, str):
            pass
        elif w is not None and w > 0:
            imgW = w
        h, w = img.shape[:2]
        ratio = w / float(h)
        if math.ceil(imgH * ratio) > imgW:
            resized_w = imgW
        else:
            resized_w = int(math.ceil(imgH * ratio))

        resized_image = cv2.resize(img, (resized_w, imgH))
        resized_image = resized_image.astype('float32')
        resized_image = resized_image.transpose((2, 0, 1)) / 255
        resized_image -= 0.5
        resized_image /= 0.5
        padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
        padding_im[:, :, 0:resized_w] = resized_image
        return padding_im

    def resize_norm_img_vl(self, img, image_shape):
        imgC, imgH, imgW = image_shape
        img = img[:, :, ::-1]  # bgr2rgb
        resized_image = cv2.resize(
            img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
        resized_image = resized_image.astype('float32')
        resized_image = resized_image.transpose((2, 0, 1)) / 255
        return resized_image

    def resize_norm_img_srn(self, img, image_shape):
        imgC, imgH, imgW = image_shape

        img_black = np.zeros((imgH, imgW))
        im_hei = img.shape[0]
        im_wid = img.shape[1]

        if im_wid <= im_hei * 1:
            img_new = cv2.resize(img, (imgH * 1, imgH))
        elif im_wid <= im_hei * 2:
            img_new = cv2.resize(img, (imgH * 2, imgH))
        elif im_wid <= im_hei * 3:
            img_new = cv2.resize(img, (imgH * 3, imgH))
        else:
            img_new = cv2.resize(img, (imgW, imgH))

        img_np = np.asarray(img_new)
        img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
        img_black[:, 0:img_np.shape[1]] = img_np
        img_black = img_black[:, :, np.newaxis]

        row, col, c = img_black.shape
        c = 1

        return np.reshape(img_black, (c, row, col)).astype(np.float32)

    def srn_other_inputs(self, image_shape, num_heads, max_text_length):
        imgC, imgH, imgW = image_shape
        feature_dim = int((imgH / 8) * (imgW / 8))

        encoder_word_pos = np.array(range(0, feature_dim)).reshape(
            (feature_dim, 1)).astype('int64')
        gsrm_word_pos = np.array(range(0, max_text_length)).reshape(
            (max_text_length, 1)).astype('int64')

        gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length))
        gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape(
            [-1, 1, max_text_length, max_text_length])
        gsrm_slf_attn_bias1 = np.tile(
            gsrm_slf_attn_bias1,
            [1, num_heads, 1, 1]).astype('float32') * [-1e9]

        gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape(
            [-1, 1, max_text_length, max_text_length])
        gsrm_slf_attn_bias2 = np.tile(
            gsrm_slf_attn_bias2,
            [1, num_heads, 1, 1]).astype('float32') * [-1e9]

        encoder_word_pos = encoder_word_pos[np.newaxis, :]
        gsrm_word_pos = gsrm_word_pos[np.newaxis, :]

        return [
            encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
            gsrm_slf_attn_bias2
        ]

    def process_image_srn(self, img, image_shape, num_heads, max_text_length):
        norm_img = self.resize_norm_img_srn(img, image_shape)
        norm_img = norm_img[np.newaxis, :]

        [encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = \
            self.srn_other_inputs(image_shape, num_heads, max_text_length)

        gsrm_slf_attn_bias1 = gsrm_slf_attn_bias1.astype(np.float32)
        gsrm_slf_attn_bias2 = gsrm_slf_attn_bias2.astype(np.float32)
        encoder_word_pos = encoder_word_pos.astype(np.int64)
        gsrm_word_pos = gsrm_word_pos.astype(np.int64)

        return (norm_img, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
                gsrm_slf_attn_bias2)

    def resize_norm_img_sar(self, img, image_shape,
                            width_downsample_ratio=0.25):
        imgC, imgH, imgW_min, imgW_max = image_shape
        h = img.shape[0]
        w = img.shape[1]
        valid_ratio = 1.0
        # make sure new_width is an integral multiple of width_divisor.
        width_divisor = int(1 / width_downsample_ratio)
        # resize
        ratio = w / float(h)
        resize_w = math.ceil(imgH * ratio)
        if resize_w % width_divisor != 0:
            resize_w = round(resize_w / width_divisor) * width_divisor
        if imgW_min is not None:
            resize_w = max(imgW_min, resize_w)
        if imgW_max is not None:
            valid_ratio = min(1.0, 1.0 * resize_w / imgW_max)
            resize_w = min(imgW_max, resize_w)
        resized_image = cv2.resize(img, (resize_w, imgH))
        resized_image = resized_image.astype('float32')
        # norm
        if image_shape[0] == 1:
            resized_image = resized_image / 255
            resized_image = resized_image[np.newaxis, :]
        else:
            resized_image = resized_image.transpose((2, 0, 1)) / 255
        resized_image -= 0.5
        resized_image /= 0.5
        resize_shape = resized_image.shape
        padding_im = -1.0 * np.ones((imgC, imgH, imgW_max), dtype=np.float32)
        padding_im[:, :, 0:resize_w] = resized_image
        pad_shape = padding_im.shape

        return padding_im, resize_shape, pad_shape, valid_ratio

    def resize_norm_img_spin(self, img):
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        img = cv2.resize(img, tuple([100, 32]), cv2.INTER_CUBIC)
        img = np.array(img, np.float32)
        img = np.expand_dims(img, -1)
        img = img.transpose((2, 0, 1))
        mean = [127.5]
        std = [127.5]
        mean = np.array(mean, dtype=np.float32)
        std = np.array(std, dtype=np.float32)
        mean = np.float32(mean.reshape(1, -1))
        stdinv = 1 / np.float32(std.reshape(1, -1))
        img -= mean
        img *= stdinv
        return img

    def resize_norm_img_svtr(self, img, image_shape):
        imgC, imgH, imgW = image_shape
        resized_image = cv2.resize(
            img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
        resized_image = resized_image.astype('float32')
        resized_image = resized_image.transpose((2, 0, 1)) / 255
        resized_image -= 0.5
        resized_image /= 0.5
        return resized_image

    def resize_norm_img_abinet(self, img, image_shape):
        imgC, imgH, imgW = image_shape

        resized_image = cv2.resize(
            img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
        resized_image = resized_image.astype('float32')
        resized_image = resized_image / 255.

        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        resized_image = (
            resized_image - mean[None, None, ...]) / std[None, None, ...]
        resized_image = resized_image.transpose((2, 0, 1))
        resized_image = resized_image.astype('float32')

        return resized_image

    def norm_img_can(self, img, image_shape):
        img = cv2.cvtColor(
            img, cv2.COLOR_BGR2GRAY)  # CAN only predicts grayscale images

        if self.rec_image_shape[0] == 1:
            h, w = img.shape
            _, imgH, imgW = self.rec_image_shape
            if h < imgH or w < imgW:
                padding_h = max(imgH - h, 0)
                padding_w = max(imgW - w, 0)
                img_padded = np.pad(img, ((0, padding_h), (0, padding_w)),
                                    'constant',
                                    constant_values=(255))
                img = img_padded

        img = np.expand_dims(img, 0) / 255.0  # h,w,c -> c,h,w
        img = img.astype('float32')

        return img

    def close(self):
        # close session and release manually
        logging.info('Close text recognizer.')
        if hasattr(self, "predictor"):
            del self.predictor
        gc.collect()

    def __call__(self, img_list):
        img_num = len(img_list)
        # Calculate the aspect ratio of all text bars
        width_list = []
        for img in img_list:
            width_list.append(img.shape[1] / float(img.shape[0]))
        # Sorting can speed up the recognition process
        indices = np.argsort(np.array(width_list))
        rec_res = [['', 0.0]] * img_num
        batch_num = self.rec_batch_num
        st = time.time()

        for beg_img_no in range(0, img_num, batch_num):
            end_img_no = min(img_num, beg_img_no + batch_num)
            norm_img_batch = []
            imgC, imgH, imgW = self.rec_image_shape[:3]
            max_wh_ratio = imgW / imgH
            for ino in range(beg_img_no, end_img_no):
                h, w = img_list[indices[ino]].shape[0:2]
                wh_ratio = w * 1.0 / h
                max_wh_ratio = max(max_wh_ratio, wh_ratio)
            for ino in range(beg_img_no, end_img_no):
                norm_img = self.resize_norm_img(img_list[indices[ino]],
                                                max_wh_ratio)
                norm_img = norm_img[np.newaxis, :]
                norm_img_batch.append(norm_img)
            norm_img_batch = np.concatenate(norm_img_batch)
            norm_img_batch = norm_img_batch.copy()

            input_dict = {}
            input_dict[self.input_tensor.name] = norm_img_batch
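            # Retry transient inference failures a few times before giving up;
            # the generous loop bound is never reached because the raise below
            # exits after the fourth consecutive failure.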
            for i in range(100000):
                try:
                    outputs = self.predictor.run(None, input_dict, self.run_options)
                    break
                except Exception as e:
                    if i >= 3:
                        raise e
                    time.sleep(5)
            preds = outputs[0]
            rec_result = self.postprocess_op(preds)
            for rno in range(len(rec_result)):
                rec_res[indices[beg_img_no + rno]] = rec_result[rno]

        return rec_res, time.time() - st

    def __del__(self):
        self.close()


class TextDetector:
    def __init__(self, model_dir, device_id: int | None = None):
        pre_process_list = [{
            'DetResizeForTest': {
                'limit_side_len': 960,
                'limit_type': "max",
            }
        }, {
            'NormalizeImage': {
                'std': [0.229, 0.224, 0.225],
                'mean': [0.485, 0.456, 0.406],
                'scale': '1./255.',
                'order': 'hwc'
            }
        }, {
            'ToCHWImage': None
        }, {
            'KeepKeys': {
                'keep_keys': ['image', 'shape']
            }
        }]
        postprocess_params = {"name": "DBPostProcess", "thresh": 0.3, "box_thresh": 0.5, "max_candidates": 1000,
                              "unclip_ratio": 1.5, "use_dilation": False, "score_mode": "fast", "box_type": "quad"}

        self.postprocess_op = build_post_process(postprocess_params)
        self.predictor, self.run_options = load_model(model_dir, 'det', device_id)
        self.input_tensor = self.predictor.get_inputs()[0]

        img_h, img_w = self.input_tensor.shape[2:]
        if isinstance(img_h, str) or isinstance(img_w, str):
            pass
        elif img_h is not None and img_w is not None and img_h > 0 and img_w > 0:
            pre_process_list[0] = {
                'DetResizeForTest': {
                    'image_shape': [img_h, img_w]
                }
            }
        self.preprocess_op = create_operators(pre_process_list)

    def order_points_clockwise(self, pts):
        rect = np.zeros((4, 2), dtype="float32")
        s = pts.sum(axis=1)
        rect[0] = pts[np.argmin(s)]
        rect[2] = pts[np.argmax(s)]
        tmp = np.delete(pts, (np.argmin(s), np.argmax(s)), axis=0)
        diff = np.diff(np.array(tmp), axis=1)
        rect[1] = tmp[np.argmin(diff)]
        rect[3] = tmp[np.argmax(diff)]
        return rect

    def clip_det_res(self, points, img_height, img_width):
        for pno in range(points.shape[0]):
            points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1))
            points[pno, 1] = int(min(max(points[pno, 1], 0), img_height - 1))
        return points

    def filter_tag_det_res(self, dt_boxes, image_shape):
        img_height, img_width = image_shape[0:2]
        dt_boxes_new = []
        for box in dt_boxes:
            if isinstance(box, list):
                box = np.array(box)
            box = self.order_points_clockwise(box)
            box = self.clip_det_res(box, img_height, img_width)
            rect_width = int(np.linalg.norm(box[0] - box[1]))
            rect_height = int(np.linalg.norm(box[0] - box[3]))
            if rect_width <= 3 or rect_height <= 3:
                continue
            dt_boxes_new.append(box)
        dt_boxes = np.array(dt_boxes_new)
        return dt_boxes

    def filter_tag_det_res_only_clip(self, dt_boxes, image_shape):
        img_height, img_width = image_shape[0:2]
        dt_boxes_new = []
        for box in dt_boxes:
            if isinstance(box, list):
                box = np.array(box)
            box = self.clip_det_res(box, img_height, img_width)
            dt_boxes_new.append(box)
        dt_boxes = np.array(dt_boxes_new)
        return dt_boxes

    def close(self):
        logging.info("Close text detector.")
        if hasattr(self, "predictor"):
            del self.predictor
        gc.collect()

    def __call__(self, img):
        ori_im = img.copy()
        data = {'image': img}

        st = time.time()
        data = transform(data, self.preprocess_op)
        img, shape_list = data
        if img is None:
            return None, 0
        img = np.expand_dims(img, axis=0)
        shape_list = np.expand_dims(shape_list, axis=0)
        img = img.copy()
        input_dict = {}
        input_dict[self.input_tensor.name] = img
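        # Retry transient inference failures a few times before giving up,
        # mirroring the retry loop in TextRecognizer.__call__.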
        for i in range(100000):
            try:
                outputs = self.predictor.run(None, input_dict, self.run_options)
                break
            except Exception as e:
                if i >= 3:
                    raise e
                time.sleep(5)

        post_result = self.postprocess_op({"maps": outputs[0]}, shape_list)
        dt_boxes = post_result[0]['points']
        dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape)

        return dt_boxes, time.time() - st

    def __del__(self):
        self.close()


class OCR:
    def __init__(self, model_dir=None):
        """
        If you have trouble downloading HuggingFace models, -_^ this might help!!

        For Linux:
        export HF_ENDPOINT=https://hf-mirror.com

        For Windows:
        Good luck
        ^_-

        """
        if not model_dir:
            try:
                model_dir = os.path.join(
                    get_project_base_directory(),
                    "rag/res/deepdoc")

                # Build one detector/recognizer per device for multi-GPU setups
                if settings.PARALLEL_DEVICES > 0:
                    self.text_detector = []
                    self.text_recognizer = []
                    for device_id in range(settings.PARALLEL_DEVICES):
                        self.text_detector.append(TextDetector(model_dir, device_id))
                        self.text_recognizer.append(TextRecognizer(model_dir, device_id))
                else:
                    self.text_detector = [TextDetector(model_dir)]
                    self.text_recognizer = [TextRecognizer(model_dir)]

            except Exception:
                model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc",
                                              local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"),
                                              local_dir_use_symlinks=False)

                if settings.PARALLEL_DEVICES > 0:
                    self.text_detector = []
                    self.text_recognizer = []
                    for device_id in range(settings.PARALLEL_DEVICES):
                        self.text_detector.append(TextDetector(model_dir, device_id))
                        self.text_recognizer.append(TextRecognizer(model_dir, device_id))
                else:
                    self.text_detector = [TextDetector(model_dir)]
                    self.text_recognizer = [TextRecognizer(model_dir)]

        self.drop_score = 0.5
        self.crop_image_res_index = 0

    def get_rotate_crop_image(self, img, points):
        """
        img_height, img_width = img.shape[0:2]
        left = int(np.min(points[:, 0]))
        right = int(np.max(points[:, 0]))
        top = int(np.min(points[:, 1]))
        bottom = int(np.max(points[:, 1]))
        img_crop = img[top:bottom, left:right, :].copy()
        points[:, 0] = points[:, 0] - left
        points[:, 1] = points[:, 1] - top
        """
        assert len(points) == 4, "shape of points must be 4*2"
        img_crop_width = int(
            max(
                np.linalg.norm(points[0] - points[1]),
                np.linalg.norm(points[2] - points[3])))
        img_crop_height = int(
            max(
                np.linalg.norm(points[0] - points[3]),
                np.linalg.norm(points[1] - points[2])))
        pts_std = np.float32([[0, 0], [img_crop_width, 0],
                              [img_crop_width, img_crop_height],
                              [0, img_crop_height]])
        M = cv2.getPerspectiveTransform(points, pts_std)
        dst_img = cv2.warpPerspective(
            img,
            M, (img_crop_width, img_crop_height),
            borderMode=cv2.BORDER_REPLICATE,
            flags=cv2.INTER_CUBIC)
        dst_img_height, dst_img_width = dst_img.shape[0:2]
        if dst_img_height * 1.0 / dst_img_width >= 1.5:
            # Try original orientation
            rec_result = self.text_recognizer[0]([dst_img])
            text, score = rec_result[0][0]
            best_score = score
            best_img = dst_img

            # Try clockwise 90° rotation
            rotated_cw = np.rot90(dst_img, k=3)
            rec_result = self.text_recognizer[0]([rotated_cw])
            rotated_cw_text, rotated_cw_score = rec_result[0][0]
            if rotated_cw_score > best_score:
                best_score = rotated_cw_score
                best_img = rotated_cw

            # Try counter-clockwise 90° rotation
            rotated_ccw = np.rot90(dst_img, k=1)
            rec_result = self.text_recognizer[0]([rotated_ccw])
            rotated_ccw_text, rotated_ccw_score = rec_result[0][0]
            if rotated_ccw_score > best_score:
                best_score = rotated_ccw_score
                best_img = rotated_ccw

            # Use the best image
            dst_img = best_img
        return dst_img

    def sorted_boxes(self, dt_boxes):
        """
        Sort text boxes in order from top to bottom, left to right
        args:
            dt_boxes(array): detected text boxes with shape [4, 2]
        return:
            sorted boxes(array) with shape [4, 2]
        """
        num_boxes = dt_boxes.shape[0]
        sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
        _boxes = list(sorted_boxes)

        for i in range(num_boxes - 1):
            for j in range(i, -1, -1):
                if abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10 and \
                        (_boxes[j + 1][0][0] < _boxes[j][0][0]):
                    tmp = _boxes[j]
                    _boxes[j] = _boxes[j + 1]
                    _boxes[j + 1] = tmp
                else:
                    break
        return _boxes

    def detect(self, img, device_id: int | None = None):
        if device_id is None:
            device_id = 0

        time_dict = {'det': 0, 'rec': 0, 'cls': 0, 'all': 0}

        if img is None:
            return None, None, time_dict

        start = time.time()
        dt_boxes, elapse = self.text_detector[device_id](img)
        time_dict['det'] = elapse

        if dt_boxes is None:
            end = time.time()
            time_dict['all'] = end - start
            return None, None, time_dict

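        # Pair each sorted box with an empty-text placeholder; recognition is
        # performed separately via recognize()/recognize_batch().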
        return zip(self.sorted_boxes(dt_boxes), [
            ("", 0) for _ in range(len(dt_boxes))])

    def recognize(self, ori_im, box, device_id: int | None = None):
        if device_id is None:
            device_id = 0

        img_crop = self.get_rotate_crop_image(ori_im, box)

        rec_res, elapse = self.text_recognizer[device_id]([img_crop])
        text, score = rec_res[0]
        if score < self.drop_score:
            return ""
        return text

    def recognize_batch(self, img_list, device_id: int | None = None):
        if device_id is None:
            device_id = 0
        rec_res, elapse = self.text_recognizer[device_id](img_list)
        texts = []
        for i in range(len(rec_res)):
            text, score = rec_res[i]
            if score < self.drop_score:
                text = ""
            texts.append(text)
        return texts

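    # Full pipeline: detect boxes, sort them top-to-bottom/left-to-right,
    # crop each with perspective correction, recognize, then drop results
    # scoring below drop_score.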
    def __call__(self, img, device_id=0, cls=True):
        time_dict = {'det': 0, 'rec': 0, 'cls': 0, 'all': 0}
        if device_id is None:
            device_id = 0

        if img is None:
            return None, None, time_dict

        start = time.time()
        ori_im = img.copy()
        dt_boxes, elapse = self.text_detector[device_id](img)
        time_dict['det'] = elapse

        if dt_boxes is None:
            end = time.time()
            time_dict['all'] = end - start
            return None, None, time_dict

        img_crop_list = []

        dt_boxes = self.sorted_boxes(dt_boxes)

        for bno in range(len(dt_boxes)):
            tmp_box = copy.deepcopy(dt_boxes[bno])
            img_crop = self.get_rotate_crop_image(ori_im, tmp_box)
            img_crop_list.append(img_crop)

        rec_res, elapse = self.text_recognizer[device_id](img_crop_list)

        time_dict['rec'] = elapse

        filter_boxes, filter_rec_res = [], []
        for box, rec_result in zip(dt_boxes, rec_res):
            text, score = rec_result
            if score >= self.drop_score:
                filter_boxes.append(box)
                filter_rec_res.append(rec_result)
        end = time.time()
        time_dict['all'] = end - start

        return list(zip([a.tolist() for a in filter_boxes], filter_rec_res))