mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Feat: add multi-GPU and parallel processing support for OCR (#5972)
### What problem does this PR solve? Add multi-GPU and parallel processing support for OCR. ### Type of change - [x] New Feature (non-breaking change which adds functionality) @yuzhichang I've tried to resolve the comments in #5697. OCR jobs can now run on both CPU and GPU. (By the way, I encountered a “Generate embedding error” issue, #5954, which might be caused by my outdated GPUs; I'm not sure.) Please review it and give me suggestions. GPU:   CPU: 
This commit is contained in:
@ -66,10 +66,12 @@ def create_operators(op_param_list, global_config=None):
|
||||
return ops
|
||||
|
||||
|
||||
def load_model(model_dir, nm):
|
||||
def load_model(model_dir, nm, device_id: int | None = None):
|
||||
model_file_path = os.path.join(model_dir, nm + ".onnx")
|
||||
model_cached_tag = model_file_path + str(device_id) if device_id is not None else model_file_path
|
||||
|
||||
global loaded_models
|
||||
loaded_model = loaded_models.get(model_file_path)
|
||||
loaded_model = loaded_models.get(model_cached_tag)
|
||||
if loaded_model:
|
||||
logging.info(f"load_model {model_file_path} reuses cached model")
|
||||
return loaded_model
|
||||
@ -81,7 +83,7 @@ def load_model(model_dir, nm):
|
||||
def cuda_is_available():
|
||||
try:
|
||||
import torch
|
||||
if torch.cuda.is_available():
|
||||
if torch.cuda.is_available() and torch.cuda.device_count() > device_id:
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
@ -98,7 +100,7 @@ def load_model(model_dir, nm):
|
||||
run_options = ort.RunOptions()
|
||||
if cuda_is_available():
|
||||
cuda_provider_options = {
|
||||
"device_id": 0, # Use specific GPU
|
||||
"device_id": device_id, # Use specific GPU
|
||||
"gpu_mem_limit": 512 * 1024 * 1024, # Limit gpu memory
|
||||
"arena_extend_strategy": "kNextPowerOfTwo", # gpu memory allocation strategy
|
||||
}
|
||||
@ -108,7 +110,7 @@ def load_model(model_dir, nm):
|
||||
providers=['CUDAExecutionProvider'],
|
||||
provider_options=[cuda_provider_options]
|
||||
)
|
||||
run_options.add_run_config_entry("memory.enable_memory_arena_shrinkage", "gpu:0")
|
||||
run_options.add_run_config_entry("memory.enable_memory_arena_shrinkage", "gpu:" + str(device_id))
|
||||
logging.info(f"load_model {model_file_path} uses GPU")
|
||||
else:
|
||||
sess = ort.InferenceSession(
|
||||
@ -118,12 +120,12 @@ def load_model(model_dir, nm):
|
||||
run_options.add_run_config_entry("memory.enable_memory_arena_shrinkage", "cpu")
|
||||
logging.info(f"load_model {model_file_path} uses CPU")
|
||||
loaded_model = (sess, run_options)
|
||||
loaded_models[model_file_path] = loaded_model
|
||||
loaded_models[model_cached_tag] = loaded_model
|
||||
return loaded_model
|
||||
|
||||
|
||||
class TextRecognizer:
|
||||
def __init__(self, model_dir):
|
||||
def __init__(self, model_dir, device_id: int | None = None):
|
||||
self.rec_image_shape = [int(v) for v in "3, 48, 320".split(",")]
|
||||
self.rec_batch_num = 16
|
||||
postprocess_params = {
|
||||
@ -132,7 +134,7 @@ class TextRecognizer:
|
||||
"use_space_char": True
|
||||
}
|
||||
self.postprocess_op = build_post_process(postprocess_params)
|
||||
self.predictor, self.run_options = load_model(model_dir, 'rec')
|
||||
self.predictor, self.run_options = load_model(model_dir, 'rec', device_id)
|
||||
self.input_tensor = self.predictor.get_inputs()[0]
|
||||
|
||||
def resize_norm_img(self, img, max_wh_ratio):
|
||||
@ -394,7 +396,7 @@ class TextRecognizer:
|
||||
|
||||
|
||||
class TextDetector:
|
||||
def __init__(self, model_dir):
|
||||
def __init__(self, model_dir, device_id: int | None = None):
|
||||
pre_process_list = [{
|
||||
'DetResizeForTest': {
|
||||
'limit_side_len': 960,
|
||||
@ -418,7 +420,7 @@ class TextDetector:
|
||||
"unclip_ratio": 1.5, "use_dilation": False, "score_mode": "fast", "box_type": "quad"}
|
||||
|
||||
self.postprocess_op = build_post_process(postprocess_params)
|
||||
self.predictor, self.run_options = load_model(model_dir, 'det')
|
||||
self.predictor, self.run_options = load_model(model_dir, 'det', device_id)
|
||||
self.input_tensor = self.predictor.get_inputs()[0]
|
||||
|
||||
img_h, img_w = self.input_tensor.shape[2:]
|
||||
@ -507,7 +509,7 @@ class TextDetector:
|
||||
|
||||
|
||||
class OCR:
|
||||
def __init__(self, model_dir=None):
|
||||
def __init__(self, model_dir=None, parallel_devices: int | None = None):
|
||||
"""
|
||||
If you have trouble downloading HuggingFace models, -_^ this might help!!
|
||||
|
||||
@ -524,14 +526,33 @@ class OCR:
|
||||
model_dir = os.path.join(
|
||||
get_project_base_directory(),
|
||||
"rag/res/deepdoc")
|
||||
self.text_detector = TextDetector(model_dir)
|
||||
self.text_recognizer = TextRecognizer(model_dir)
|
||||
|
||||
# Append multi-GPU tasks to the list
|
||||
if parallel_devices is not None and parallel_devices > 0:
|
||||
self.text_detector = []
|
||||
self.text_recognizer = []
|
||||
for device_id in range(parallel_devices):
|
||||
self.text_detector.append(TextDetector(model_dir, device_id))
|
||||
self.text_recognizer.append(TextRecognizer(model_dir, device_id))
|
||||
else:
|
||||
self.text_detector = [TextDetector(model_dir, 0)]
|
||||
self.text_recognizer = [TextRecognizer(model_dir, 0)]
|
||||
|
||||
except Exception:
|
||||
model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc",
|
||||
local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"),
|
||||
local_dir_use_symlinks=False)
|
||||
self.text_detector = TextDetector(model_dir)
|
||||
self.text_recognizer = TextRecognizer(model_dir)
|
||||
|
||||
if parallel_devices is not None:
|
||||
assert parallel_devices > 0 , "Number of devices must be >= 1"
|
||||
self.text_detector = []
|
||||
self.text_recognizer = []
|
||||
for device_id in range(parallel_devices):
|
||||
self.text_detector.append(TextDetector(model_dir, device_id))
|
||||
self.text_recognizer.append(TextRecognizer(model_dir, device_id))
|
||||
else:
|
||||
self.text_detector = [TextDetector(model_dir, 0)]
|
||||
self.text_recognizer = [TextRecognizer(model_dir, 0)]
|
||||
|
||||
self.drop_score = 0.5
|
||||
self.crop_image_res_index = 0
|
||||
@ -593,14 +614,17 @@ class OCR:
|
||||
break
|
||||
return _boxes
|
||||
|
||||
def detect(self, img):
|
||||
def detect(self, img, device_id: int | None = None):
|
||||
if device_id is None:
|
||||
device_id = 0
|
||||
|
||||
time_dict = {'det': 0, 'rec': 0, 'cls': 0, 'all': 0}
|
||||
|
||||
if img is None:
|
||||
return None, None, time_dict
|
||||
|
||||
start = time.time()
|
||||
dt_boxes, elapse = self.text_detector(img)
|
||||
dt_boxes, elapse = self.text_detector[device_id](img)
|
||||
time_dict['det'] = elapse
|
||||
|
||||
if dt_boxes is None:
|
||||
@ -611,17 +635,22 @@ class OCR:
|
||||
return zip(self.sorted_boxes(dt_boxes), [
|
||||
("", 0) for _ in range(len(dt_boxes))])
|
||||
|
||||
def recognize(self, ori_im, box):
|
||||
def recognize(self, ori_im, box, device_id: int | None = None):
|
||||
if device_id is None:
|
||||
device_id = 0
|
||||
|
||||
img_crop = self.get_rotate_crop_image(ori_im, box)
|
||||
|
||||
rec_res, elapse = self.text_recognizer([img_crop])
|
||||
rec_res, elapse = self.text_recognizer[device_id]([img_crop])
|
||||
text, score = rec_res[0]
|
||||
if score < self.drop_score:
|
||||
return ""
|
||||
return text
|
||||
|
||||
def recognize_batch(self, img_list):
|
||||
rec_res, elapse = self.text_recognizer(img_list)
|
||||
def recognize_batch(self, img_list, device_id: int | None = None):
|
||||
if device_id is None:
|
||||
device_id = 0
|
||||
rec_res, elapse = self.text_recognizer[device_id](img_list)
|
||||
texts = []
|
||||
for i in range(len(rec_res)):
|
||||
text, score = rec_res[i]
|
||||
@ -630,15 +659,17 @@ class OCR:
|
||||
texts.append(text)
|
||||
return texts
|
||||
|
||||
def __call__(self, img, cls=True):
|
||||
def __call__(self, img, device_id = 0, cls=True):
|
||||
time_dict = {'det': 0, 'rec': 0, 'cls': 0, 'all': 0}
|
||||
if device_id is None:
|
||||
device_id = 0
|
||||
|
||||
if img is None:
|
||||
return None, None, time_dict
|
||||
|
||||
start = time.time()
|
||||
ori_im = img.copy()
|
||||
dt_boxes, elapse = self.text_detector(img)
|
||||
dt_boxes, elapse = self.text_detector[device_id](img)
|
||||
time_dict['det'] = elapse
|
||||
|
||||
if dt_boxes is None:
|
||||
@ -655,7 +686,7 @@ class OCR:
|
||||
img_crop = self.get_rotate_crop_image(ori_im, tmp_box)
|
||||
img_crop_list.append(img_crop)
|
||||
|
||||
rec_res, elapse = self.text_recognizer(img_crop_list)
|
||||
rec_res, elapse = self.text_recognizer[device_id](img_crop_list)
|
||||
|
||||
time_dict['rec'] = elapse
|
||||
|
||||
|
||||
Reference in New Issue
Block a user