support snapshot download from local (#153)
* support snapshot download from local
* let snapshot download from local
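Every model-loading class touched below follows the same pattern: first try to resolve the HuggingFace snapshot from a directory inside the project tree with local_files_only=True, and only fall back to downloading from the Hub when that fails. A minimal sketch of the idea, with the helper name local_or_remote_snapshot being hypothetical (the commit inlines the try/except in each class rather than factoring it out):

import os

from huggingface_hub import snapshot_download

from api.utils.file_utils import get_project_base_directory


def local_or_remote_snapshot(repo_id):
    """Hypothetical helper: prefer a pre-populated local copy of a model repo."""
    try:
        # Use only files that already exist under rag/res/deepdoc; raises if they are missing.
        return snapshot_download(
            repo_id=repo_id,
            local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"),
            local_files_only=True)
    except Exception:
        # Fall back to downloading the snapshot from the Hub into the default cache.
        return snapshot_download(repo_id=repo_id)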
@@ -1,4 +1,5 @@
 # -*- coding: utf-8 -*-
+import os
 import random

 import fitz
@@ -12,10 +13,12 @@ from PIL import Image, ImageDraw
 import numpy as np

 from PyPDF2 import PdfReader as pdf2_read

+from api.utils.file_utils import get_project_base_directory
 from deepdoc.vision import OCR, Recognizer, LayoutRecognizer, TableStructureRecognizer
 from rag.nlp import huqie
 from copy import deepcopy
-from huggingface_hub import hf_hub_download
+from huggingface_hub import hf_hub_download, snapshot_download

 logging.getLogger("pdfminer").setLevel(logging.WARNING)
@@ -32,8 +35,17 @@ class HuParser:
         self.updown_cnt_mdl = xgb.Booster()
         if torch.cuda.is_available():
             self.updown_cnt_mdl.set_param({"device": "cuda"})
-        self.updown_cnt_mdl.load_model(hf_hub_download(repo_id="InfiniFlow/text_concat_xgb_v1.0",
-                                                       filename="updown_concat_xgb.model"))
+        try:
+            model_dir = snapshot_download(
+                repo_id="InfiniFlow/text_concat_xgb_v1.0",
+                local_dir=os.path.join(
+                    get_project_base_directory(),
+                    "rag/res/deepdoc"),
+                local_files_only=True)
+        except Exception as e:
+            model_dir = snapshot_download(repo_id="InfiniFlow/text_concat_xgb_v1.0")
+
+        self.updown_cnt_mdl.load_model(os.path.join(model_dir, "updown_concat_xgb.model"))
         self.page_from = 0
         """
             If you have trouble downloading HuggingFace models, -_^ this might help!!
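For the local_files_only branch above to succeed, rag/res/deepdoc has to be populated beforehand, for example on a machine that does have network access. A one-off pre-fetch could look like the sketch below; running such a step is an assumption on top of this commit, which only changes how the snapshot is resolved:

import os

from huggingface_hub import snapshot_download

from api.utils.file_utils import get_project_base_directory

# Download both repos used above into the project tree so later runs can stay local.
target = os.path.join(get_project_base_directory(), "rag/res/deepdoc")
for repo_id in ("InfiniFlow/text_concat_xgb_v1.0", "InfiniFlow/deepdoc"):
    snapshot_download(repo_id=repo_id, local_dir=target)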
@@ -37,7 +37,16 @@ class LayoutRecognizer(Recognizer):
         "Equation",
     ]
     def __init__(self, domain):
-        model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc")
+        try:
+            model_dir = snapshot_download(
+                repo_id="InfiniFlow/deepdoc",
+                local_dir=os.path.join(
+                    get_project_base_directory(),
+                    "rag/res/deepdoc"),
+                local_files_only=True)
+        except Exception as e:
+            model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc")
+
         super().__init__(self.labels, domain, model_dir)#os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
         self.garbage_layouts = ["footer", "header", "reference"]
@@ -14,6 +14,10 @@
 import copy
 import time
+import os
+
 from huggingface_hub import snapshot_download
+
+from api.utils.file_utils import get_project_base_directory
 from .operators import *
 import numpy as np
 import onnxruntime as ort
@@ -21,6 +25,7 @@ import onnxruntime as ort
 from .postprocess import build_post_process
 from rag.settings import cron_logger


 def transform(data, ops=None):
     """ transform """
     if ops is None:
@@ -66,9 +71,15 @@ def load_model(model_dir, nm):
     options.intra_op_num_threads = 2
     options.inter_op_num_threads = 2
     if False and ort.get_device() == "GPU":
-        sess = ort.InferenceSession(model_file_path, options=options, providers=['CUDAExecutionProvider'])
+        sess = ort.InferenceSession(
+            model_file_path,
+            options=options,
+            providers=['CUDAExecutionProvider'])
     else:
-        sess = ort.InferenceSession(model_file_path, options=options, providers=['CPUExecutionProvider'])
+        sess = ort.InferenceSession(
+            model_file_path,
+            options=options,
+            providers=['CPUExecutionProvider'])
     return sess, sess.get_inputs()[0]
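Because of the `if False and ort.get_device() == "GPU"` guard, the CUDA branch above is never taken and sessions are always created with CPUExecutionProvider; the commit only re-wraps the calls across several lines. If GPU selection were wanted, it could be keyed off the available providers, as in the sketch below (an alternative, not part of this commit; it uses the sess_options keyword documented by onnxruntime):

import onnxruntime as ort


def make_session(model_file_path):
    """Sketch: choose CUDA when onnxruntime reports it, otherwise fall back to CPU."""
    options = ort.SessionOptions()
    options.intra_op_num_threads = 2
    options.inter_op_num_threads = 2
    if "CUDAExecutionProvider" in ort.get_available_providers():
        providers = ["CUDAExecutionProvider"]
    else:
        providers = ["CPUExecutionProvider"]
    return ort.InferenceSession(model_file_path, sess_options=options, providers=providers)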
@@ -331,7 +342,8 @@ class TextRecognizer(object):
                     outputs = self.predictor.run(None, input_dict)
                     break
                 except Exception as e:
-                    if i >= 3: raise e
+                    if i >= 3:
+                        raise e
                     time.sleep(5)
             preds = outputs[0]
             rec_result = self.postprocess_op(preds)
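The change above only splits `if i >= 3: raise e` onto two lines; the behaviour is unchanged: rerun the predictor on failure, sleep five seconds between attempts, and give up once the fourth attempt fails. The same pattern as a standalone sketch, assuming any exception is worth retrying:

import time


def run_with_retries(run, attempts=4, delay=5.0):
    """Sketch of the retry loop used around self.predictor.run in this file."""
    for i in range(attempts):
        try:
            return run()
        except Exception:
            if i >= attempts - 1:
                raise
            time.sleep(delay)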
@@ -442,7 +454,8 @@ class TextDetector(object):
                 outputs = self.predictor.run(None, input_dict)
                 break
             except Exception as e:
-                if i >= 3: raise e
+                if i >= 3:
+                    raise e
                 time.sleep(5)

         post_result = self.postprocess_op({"maps": outputs[0]}, shape_list)
@@ -466,7 +479,15 @@ class OCR(object):

         """
         if not model_dir:
-            model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc")
+            try:
+                model_dir = snapshot_download(
+                    repo_id="InfiniFlow/deepdoc",
+                    local_dir=os.path.join(
+                        get_project_base_directory(),
+                        "rag/res/deepdoc"),
+                    local_files_only=True)
+            except Exception as e:
+                model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc")

         self.text_detector = TextDetector(model_dir)
         self.text_recognizer = TextRecognizer(model_dir)
@@ -548,14 +569,16 @@ class OCR(object):
         cron_logger.debug("dt_boxes num : {}, elapsed : {}".format(
             len(dt_boxes), elapse))

-        return zip(self.sorted_boxes(dt_boxes), [("",0) for _ in range(len(dt_boxes))])
+        return zip(self.sorted_boxes(dt_boxes), [
+            ("", 0) for _ in range(len(dt_boxes))])

     def recognize(self, ori_im, box):
         img_crop = self.get_rotate_crop_image(ori_im, box)

         rec_res, elapse = self.text_recognizer([img_crop])
         text, score = rec_res[0]
-        if score < self.drop_score:return ""
+        if score < self.drop_score:
+            return ""
         return text

     def __call__(self, img, cls=True):
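Taken together, detect, recognize and __call__ give three entry points: boxes only, text for one box, or boxes plus text in a single pass. A usage sketch, assuming the image is loaded as a NumPy array with OpenCV (an assumption; only the shape of __call__'s return value, a list of (box, (text, score)) pairs, comes from the code above):

import cv2  # assumption: OpenCV is available for reading the image

from deepdoc.vision import OCR

ocr = OCR()                    # resolves the model snapshot as shown above
img = cv2.imread("page.png")   # hypothetical input file

for box, (text, score) in ocr(img):
    print(box, text, score)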
@@ -600,8 +623,7 @@ class OCR(object):
         end = time.time()
         time_dict['all'] = end - start

-
-        #for bno in range(len(img_crop_list)):
+        # for bno in range(len(img_crop_list)):
         #     print(f"{bno}, {rec_res[bno]}")

         return list(zip([a.tolist() for a in filter_boxes], filter_rec_res))
@@ -17,6 +17,7 @@ from copy import deepcopy
 import onnxruntime as ort
 from huggingface_hub import snapshot_download

+from api.utils.file_utils import get_project_base_directory
 from .operators import *
 from rag.settings import cron_logger
@@ -35,7 +36,15 @@ class Recognizer(object):

         """
         if not model_dir:
-            model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc")
+            try:
+                model_dir = snapshot_download(
+                    repo_id="InfiniFlow/deepdoc",
+                    local_dir=os.path.join(
+                        get_project_base_directory(),
+                        "rag/res/deepdoc"),
+                    local_files_only=True)
+            except Exception as e:
+                model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc")

         model_file_path = os.path.join(model_dir, task_name + ".onnx")
         if not os.path.exists(model_file_path):
@@ -34,7 +34,16 @@ class TableStructureRecognizer(Recognizer):
     ]

     def __init__(self):
-        model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc")
+        try:
+            model_dir = snapshot_download(
+                repo_id="InfiniFlow/deepdoc",
+                local_dir=os.path.join(
+                    get_project_base_directory(),
+                    "rag/res/deepdoc"),
+                local_files_only=True)
+        except Exception as e:
+            model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc")
+
         super().__init__(self.labels, "tsr", model_dir)#os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))

     def __call__(self, images, thr=0.2):