From a5384446e399bbab31a56a3124117fd8a0652ae7 Mon Sep 17 00:00:00 2001 From: KevinHuSh Date: Thu, 28 Mar 2024 16:10:47 +0800 Subject: [PATCH] let's load model from local (#163) --- deepdoc/parser/pdf_parser.py | 16 ++++++++-------- deepdoc/vision/layout_recognizer.py | 12 ++++-------- deepdoc/vision/ocr.py | 13 ++++++------- deepdoc/vision/recognizer.py | 13 +++++-------- deepdoc/vision/table_structure_recognizer.py | 12 +++--------- rag/llm/embedding_model.py | 11 ++++------- rag/nlp/search.py | 2 +- 7 files changed, 31 insertions(+), 48 deletions(-) diff --git a/deepdoc/parser/pdf_parser.py b/deepdoc/parser/pdf_parser.py index 10257ec66..dfd3756d2 100644 --- a/deepdoc/parser/pdf_parser.py +++ b/deepdoc/parser/pdf_parser.py @@ -18,7 +18,7 @@ from api.utils.file_utils import get_project_base_directory from deepdoc.vision import OCR, Recognizer, LayoutRecognizer, TableStructureRecognizer from rag.nlp import huqie from copy import deepcopy -from huggingface_hub import hf_hub_download, snapshot_download +from huggingface_hub import snapshot_download logging.getLogger("pdfminer").setLevel(logging.WARNING) @@ -36,18 +36,18 @@ class HuParser: if torch.cuda.is_available(): self.updown_cnt_mdl.set_param({"device": "cuda"}) try: - model_dir = snapshot_download( - repo_id="InfiniFlow/text_concat_xgb_v1.0", - local_dir=os.path.join( + model_dir = os.path.join( get_project_base_directory(), - "rag/res/deepdoc"), - local_files_only=True) + "rag/res/deepdoc") + self.updown_cnt_mdl.load_model(os.path.join( + model_dir, "updown_concat_xgb.model")) except Exception as e: model_dir = snapshot_download( repo_id="InfiniFlow/text_concat_xgb_v1.0") + self.updown_cnt_mdl.load_model(os.path.join( + model_dir, "updown_concat_xgb.model")) + - self.updown_cnt_mdl.load_model(os.path.join( - model_dir, "updown_concat_xgb.model")) self.page_from = 0 """ If you have trouble downloading HuggingFace models, -_^ this might help!! diff --git a/deepdoc/vision/layout_recognizer.py b/deepdoc/vision/layout_recognizer.py index 7b87622e5..917ee6ed8 100644 --- a/deepdoc/vision/layout_recognizer.py +++ b/deepdoc/vision/layout_recognizer.py @@ -17,7 +17,6 @@ from copy import deepcopy import numpy as np from huggingface_hub import snapshot_download -from api.db import ParserType from api.utils.file_utils import get_project_base_directory from deepdoc.vision import Recognizer @@ -39,17 +38,14 @@ class LayoutRecognizer(Recognizer): def __init__(self, domain): try: - model_dir = snapshot_download( - repo_id="InfiniFlow/deepdoc", - local_dir=os.path.join( + model_dir = os.path.join( get_project_base_directory(), - "rag/res/deepdoc"), - local_files_only=True) + "rag/res/deepdoc") + super().__init__(self.labels, domain, model_dir) except Exception as e: model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc") + super().__init__(self.labels, domain, model_dir) - # os.path.join(get_project_base_directory(), "rag/res/deepdoc/")) - super().__init__(self.labels, domain, model_dir) self.garbage_layouts = ["footer", "header", "reference"] def __call__(self, image_list, ocr_res, scale_factor=3, diff --git a/deepdoc/vision/ocr.py b/deepdoc/vision/ocr.py index dd34e4b26..b55024ed4 100644 --- a/deepdoc/vision/ocr.py +++ b/deepdoc/vision/ocr.py @@ -480,17 +480,16 @@ class OCR(object): """ if not model_dir: try: - model_dir = snapshot_download( - repo_id="InfiniFlow/deepdoc", - local_dir=os.path.join( + model_dir = os.path.join( get_project_base_directory(), - "rag/res/deepdoc"), - local_files_only=True) + "rag/res/deepdoc") + self.text_detector = TextDetector(model_dir) + self.text_recognizer = TextRecognizer(model_dir) except Exception as e: model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc") + self.text_detector = TextDetector(model_dir) + self.text_recognizer = TextRecognizer(model_dir) - self.text_detector = TextDetector(model_dir) - self.text_recognizer = TextRecognizer(model_dir) self.drop_score = 0.5 self.crop_image_res_index = 0 diff --git a/deepdoc/vision/recognizer.py b/deepdoc/vision/recognizer.py index 4f619660a..ad8b9ba24 100644 --- a/deepdoc/vision/recognizer.py +++ b/deepdoc/vision/recognizer.py @@ -36,17 +36,14 @@ class Recognizer(object): """ if not model_dir: - try: - model_dir = snapshot_download( - repo_id="InfiniFlow/deepdoc", - local_dir=os.path.join( + model_dir = os.path.join( get_project_base_directory(), - "rag/res/deepdoc"), - local_files_only=True) - except Exception as e: + "rag/res/deepdoc") + model_file_path = os.path.join(model_dir, task_name + ".onnx") + if not os.path.exists(model_file_path): model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc") + model_file_path = os.path.join(model_dir, task_name + ".onnx") - model_file_path = os.path.join(model_dir, task_name + ".onnx") if not os.path.exists(model_file_path): raise ValueError("not find model file path {}".format( model_file_path)) diff --git a/deepdoc/vision/table_structure_recognizer.py b/deepdoc/vision/table_structure_recognizer.py index ebd57a6c8..6779137d8 100644 --- a/deepdoc/vision/table_structure_recognizer.py +++ b/deepdoc/vision/table_structure_recognizer.py @@ -35,17 +35,11 @@ class TableStructureRecognizer(Recognizer): def __init__(self): try: - model_dir = snapshot_download( - repo_id="InfiniFlow/deepdoc", - local_dir=os.path.join( + super().__init__(self.labels, "tsr", os.path.join( get_project_base_directory(), - "rag/res/deepdoc"), - local_files_only=True) + "rag/res/deepdoc")) except Exception as e: - model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc") - - # os.path.join(get_project_base_directory(), "rag/res/deepdoc/")) - super().__init__(self.labels, "tsr", model_dir) + super().__init__(self.labels, "tsr", snapshot_download(repo_id="InfiniFlow/deepdoc")) def __call__(self, images, thr=0.2): tbls = super().__call__(images, thr) diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py index 68a6e0aa4..169a4df83 100644 --- a/rag/llm/embedding_model.py +++ b/rag/llm/embedding_model.py @@ -28,16 +28,13 @@ from api.utils.file_utils import get_project_base_directory from rag.utils import num_tokens_from_string try: - model_dir = snapshot_download( - repo_id="BAAI/bge-large-zh-v1.5", - local_dir=os.path.join( + flag_model = FlagModel(os.path.join( get_project_base_directory(), "rag/res/bge-large-zh-v1.5"), - local_files_only=True) + query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:", + use_fp16=torch.cuda.is_available()) except Exception as e: - model_dir = snapshot_download(repo_id="BAAI/bge-large-zh-v1.5") - -flag_model = FlagModel(model_dir, + flag_model = FlagModel("BAAI/bge-large-zh-v1.5", query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:", use_fp16=torch.cuda.is_available()) diff --git a/rag/nlp/search.py b/rag/nlp/search.py index ac92853dd..01564bba9 100644 --- a/rag/nlp/search.py +++ b/rag/nlp/search.py @@ -247,7 +247,7 @@ class Dealer: for ck in chunks] cites = {} thr = 0.63 - while len(cites.keys()) == 0 and pieces_ and chunks_tks: + while thr>0.3 and len(cites.keys()) == 0 and pieces_ and chunks_tks: for i, a in enumerate(pieces_): sim, tksim, vtsim = self.qryr.hybrid_similarity(ans_v[i], chunk_v,