Load models from the local directory first, falling back to a HuggingFace download (#163)

This commit is contained in:
KevinHuSh
2024-03-28 16:10:47 +08:00
committed by GitHub
parent f3477202fe
commit a5384446e3
7 changed files with 31 additions and 48 deletions

View File

@ -18,7 +18,7 @@ from api.utils.file_utils import get_project_base_directory
from deepdoc.vision import OCR, Recognizer, LayoutRecognizer, TableStructureRecognizer from deepdoc.vision import OCR, Recognizer, LayoutRecognizer, TableStructureRecognizer
from rag.nlp import huqie from rag.nlp import huqie
from copy import deepcopy from copy import deepcopy
from huggingface_hub import hf_hub_download, snapshot_download from huggingface_hub import snapshot_download
logging.getLogger("pdfminer").setLevel(logging.WARNING) logging.getLogger("pdfminer").setLevel(logging.WARNING)
@ -36,18 +36,18 @@ class HuParser:
if torch.cuda.is_available(): if torch.cuda.is_available():
self.updown_cnt_mdl.set_param({"device": "cuda"}) self.updown_cnt_mdl.set_param({"device": "cuda"})
try: try:
model_dir = snapshot_download( model_dir = os.path.join(
repo_id="InfiniFlow/text_concat_xgb_v1.0",
local_dir=os.path.join(
get_project_base_directory(), get_project_base_directory(),
"rag/res/deepdoc"), "rag/res/deepdoc")
local_files_only=True) self.updown_cnt_mdl.load_model(os.path.join(
model_dir, "updown_concat_xgb.model"))
except Exception as e: except Exception as e:
model_dir = snapshot_download( model_dir = snapshot_download(
repo_id="InfiniFlow/text_concat_xgb_v1.0") repo_id="InfiniFlow/text_concat_xgb_v1.0")
self.updown_cnt_mdl.load_model(os.path.join( self.updown_cnt_mdl.load_model(os.path.join(
model_dir, "updown_concat_xgb.model")) model_dir, "updown_concat_xgb.model"))
self.page_from = 0 self.page_from = 0
""" """
If you have trouble downloading HuggingFace models, -_^ this might help!! If you have trouble downloading HuggingFace models, -_^ this might help!!

View File

@ -17,7 +17,6 @@ from copy import deepcopy
import numpy as np import numpy as np
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from api.db import ParserType
from api.utils.file_utils import get_project_base_directory from api.utils.file_utils import get_project_base_directory
from deepdoc.vision import Recognizer from deepdoc.vision import Recognizer
@ -39,17 +38,14 @@ class LayoutRecognizer(Recognizer):
def __init__(self, domain): def __init__(self, domain):
try: try:
model_dir = snapshot_download( model_dir = os.path.join(
repo_id="InfiniFlow/deepdoc",
local_dir=os.path.join(
get_project_base_directory(), get_project_base_directory(),
"rag/res/deepdoc"), "rag/res/deepdoc")
local_files_only=True) super().__init__(self.labels, domain, model_dir)
except Exception as e: except Exception as e:
model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc") model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc")
# os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
super().__init__(self.labels, domain, model_dir) super().__init__(self.labels, domain, model_dir)
self.garbage_layouts = ["footer", "header", "reference"] self.garbage_layouts = ["footer", "header", "reference"]
def __call__(self, image_list, ocr_res, scale_factor=3, def __call__(self, image_list, ocr_res, scale_factor=3,

View File

@ -480,17 +480,16 @@ class OCR(object):
""" """
if not model_dir: if not model_dir:
try: try:
model_dir = snapshot_download( model_dir = os.path.join(
repo_id="InfiniFlow/deepdoc",
local_dir=os.path.join(
get_project_base_directory(), get_project_base_directory(),
"rag/res/deepdoc"), "rag/res/deepdoc")
local_files_only=True)
except Exception as e:
model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc")
self.text_detector = TextDetector(model_dir) self.text_detector = TextDetector(model_dir)
self.text_recognizer = TextRecognizer(model_dir) self.text_recognizer = TextRecognizer(model_dir)
except Exception as e:
model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc")
self.text_detector = TextDetector(model_dir)
self.text_recognizer = TextRecognizer(model_dir)
self.drop_score = 0.5 self.drop_score = 0.5
self.crop_image_res_index = 0 self.crop_image_res_index = 0

View File

@ -36,17 +36,14 @@ class Recognizer(object):
""" """
if not model_dir: if not model_dir:
try: model_dir = os.path.join(
model_dir = snapshot_download(
repo_id="InfiniFlow/deepdoc",
local_dir=os.path.join(
get_project_base_directory(), get_project_base_directory(),
"rag/res/deepdoc"), "rag/res/deepdoc")
local_files_only=True)
except Exception as e:
model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc")
model_file_path = os.path.join(model_dir, task_name + ".onnx") model_file_path = os.path.join(model_dir, task_name + ".onnx")
if not os.path.exists(model_file_path):
model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc")
model_file_path = os.path.join(model_dir, task_name + ".onnx")
if not os.path.exists(model_file_path): if not os.path.exists(model_file_path):
raise ValueError("not find model file path {}".format( raise ValueError("not find model file path {}".format(
model_file_path)) model_file_path))

View File

@ -35,17 +35,11 @@ class TableStructureRecognizer(Recognizer):
def __init__(self): def __init__(self):
try: try:
model_dir = snapshot_download( super().__init__(self.labels, "tsr", os.path.join(
repo_id="InfiniFlow/deepdoc",
local_dir=os.path.join(
get_project_base_directory(), get_project_base_directory(),
"rag/res/deepdoc"), "rag/res/deepdoc"))
local_files_only=True)
except Exception as e: except Exception as e:
model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc") super().__init__(self.labels, "tsr", snapshot_download(repo_id="InfiniFlow/deepdoc"))
# os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
super().__init__(self.labels, "tsr", model_dir)
def __call__(self, images, thr=0.2): def __call__(self, images, thr=0.2):
tbls = super().__call__(images, thr) tbls = super().__call__(images, thr)

View File

@ -28,16 +28,13 @@ from api.utils.file_utils import get_project_base_directory
from rag.utils import num_tokens_from_string from rag.utils import num_tokens_from_string
try: try:
model_dir = snapshot_download( flag_model = FlagModel(os.path.join(
repo_id="BAAI/bge-large-zh-v1.5",
local_dir=os.path.join(
get_project_base_directory(), get_project_base_directory(),
"rag/res/bge-large-zh-v1.5"), "rag/res/bge-large-zh-v1.5"),
local_files_only=True) query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
use_fp16=torch.cuda.is_available())
except Exception as e: except Exception as e:
model_dir = snapshot_download(repo_id="BAAI/bge-large-zh-v1.5") flag_model = FlagModel("BAAI/bge-large-zh-v1.5",
flag_model = FlagModel(model_dir,
query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:", query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
use_fp16=torch.cuda.is_available()) use_fp16=torch.cuda.is_available())

View File

@ -247,7 +247,7 @@ class Dealer:
for ck in chunks] for ck in chunks]
cites = {} cites = {}
thr = 0.63 thr = 0.63
while len(cites.keys()) == 0 and pieces_ and chunks_tks: while thr>0.3 and len(cites.keys()) == 0 and pieces_ and chunks_tks:
for i, a in enumerate(pieces_): for i, a in enumerate(pieces_):
sim, tksim, vtsim = self.qryr.hybrid_similarity(ans_v[i], sim, tksim, vtsim = self.qryr.hybrid_similarity(ans_v[i],
chunk_v, chunk_v,