Support snapshot download from a local directory (#153)

* Support snapshot download from a local directory

* Let snapshot download use the local copy first
This commit is contained in:
KevinHuSh
2024-03-27 09:53:42 +08:00
committed by GitHub
parent da21320b88
commit 979b3a5b4b
12 changed files with 109 additions and 24 deletions

View File

@ -14,6 +14,10 @@
import copy
import time
import os
from huggingface_hub import snapshot_download
from api.utils.file_utils import get_project_base_directory
from .operators import *
import numpy as np
import onnxruntime as ort
@ -21,6 +25,7 @@ import onnxruntime as ort
from .postprocess import build_post_process
from rag.settings import cron_logger
def transform(data, ops=None):
""" transform """
if ops is None:
@ -66,9 +71,15 @@ def load_model(model_dir, nm):
options.intra_op_num_threads = 2
options.inter_op_num_threads = 2
if False and ort.get_device() == "GPU":
sess = ort.InferenceSession(model_file_path, options=options, providers=['CUDAExecutionProvider'])
sess = ort.InferenceSession(
model_file_path,
options=options,
providers=['CUDAExecutionProvider'])
else:
sess = ort.InferenceSession(model_file_path, options=options, providers=['CPUExecutionProvider'])
sess = ort.InferenceSession(
model_file_path,
options=options,
providers=['CPUExecutionProvider'])
return sess, sess.get_inputs()[0]
@ -331,7 +342,8 @@ class TextRecognizer(object):
outputs = self.predictor.run(None, input_dict)
break
except Exception as e:
if i >= 3: raise e
if i >= 3:
raise e
time.sleep(5)
preds = outputs[0]
rec_result = self.postprocess_op(preds)
@ -442,7 +454,8 @@ class TextDetector(object):
outputs = self.predictor.run(None, input_dict)
break
except Exception as e:
if i >= 3: raise e
if i >= 3:
raise e
time.sleep(5)
post_result = self.postprocess_op({"maps": outputs[0]}, shape_list)
@ -466,7 +479,15 @@ class OCR(object):
"""
if not model_dir:
model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc")
try:
model_dir = snapshot_download(
repo_id="InfiniFlow/deepdoc",
local_dir=os.path.join(
get_project_base_directory(),
"rag/res/deepdoc"),
local_files_only=True)
except Exception as e:
model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc")
self.text_detector = TextDetector(model_dir)
self.text_recognizer = TextRecognizer(model_dir)
@ -548,14 +569,16 @@ class OCR(object):
cron_logger.debug("dt_boxes num : {}, elapsed : {}".format(
len(dt_boxes), elapse))
return zip(self.sorted_boxes(dt_boxes), [("",0) for _ in range(len(dt_boxes))])
return zip(self.sorted_boxes(dt_boxes), [
("", 0) for _ in range(len(dt_boxes))])
def recognize(self, ori_im, box):
    """Recognize the text contained in a single box of an image.

    The region described by ``box`` is cut out of ``ori_im`` with
    ``self.get_rotate_crop_image`` and handed, as a one-element list, to
    ``self.text_recognizer``. Only the first ``(text, score)`` pair of the
    recognizer's result is used.

    Args:
        ori_im: Source image, passed through to
            ``self.get_rotate_crop_image`` unchanged.
        box: Region to crop, passed through to
            ``self.get_rotate_crop_image`` unchanged.

    Returns:
        The recognized text, or ``""`` when the recognition score is
        strictly below ``self.drop_score``.
    """
    img_crop = self.get_rotate_crop_image(ori_im, box)

    # The recognizer also returns a second value (named `elapse` in the
    # original); it is not needed here.
    rec_res, _ = self.text_recognizer([img_crop])
    text, score = rec_res[0]

    # Single low-confidence guard: the rendered diff carried both the old
    # one-line form and the new two-line form of this same check.
    if score < self.drop_score:
        return ""
    return text
def __call__(self, img, cls=True):
@ -600,8 +623,7 @@ class OCR(object):
end = time.time()
time_dict['all'] = end - start
#for bno in range(len(img_crop_list)):
# for bno in range(len(img_crop_list)):
# print(f"{bno}, {rec_res[bno]}")
return list(zip([a.tolist() for a in filter_boxes], filter_rec_res))