From 0ce720a2479b591829b288bcbd0a613f482f53e5 Mon Sep 17 00:00:00 2001 From: KevinHuSh Date: Thu, 27 Jun 2024 14:57:24 +0800 Subject: [PATCH] fix mem leak for local reranker (#1295) ### What problem does this PR solve? #1288 ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --- rag/llm/rerank_model.py | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/rag/llm/rerank_model.py b/rag/llm/rerank_model.py index e449cb1fa..c56a3ccea 100644 --- a/rag/llm/rerank_model.py +++ b/rag/llm/rerank_model.py @@ -39,6 +39,7 @@ class Base(ABC): class DefaultRerank(Base): _model = None _model_lock = threading.Lock() + def __init__(self, key, model_name, **kwargs): """ If you have trouble downloading HuggingFace models, -_^ this might help!! @@ -102,19 +103,24 @@ class JinaRerank(Base): class YoudaoRerank(DefaultRerank): _model = None + _model_lock = threading.Lock() def __init__(self, key=None, model_name="maidalun1020/bce-reranker-base_v1", **kwargs): from BCEmbedding import RerankerModel if not YoudaoRerank._model: - try: - print("LOADING BCE...") - YoudaoRerank._model = RerankerModel(model_name_or_path=os.path.join( - get_home_cache_dir(), - re.sub(r"^[a-zA-Z]+/", "", model_name))) - except Exception as e: - YoudaoRerank._model = RerankerModel( - model_name_or_path=model_name.replace( - "maidalun1020", "InfiniFlow")) + with YoudaoRerank._model_lock: + if not YoudaoRerank._model: + try: + print("LOADING BCE...") + YoudaoRerank._model = RerankerModel(model_name_or_path=os.path.join( + get_home_cache_dir(), + re.sub(r"^[a-zA-Z]+/", "", model_name))) + except Exception as e: + YoudaoRerank._model = RerankerModel( + model_name_or_path=model_name.replace( + "maidalun1020", "InfiniFlow")) + + self._model = YoudaoRerank._model def similarity(self, query: str, texts: list): pairs = [(query, truncate(t, self._model.max_length)) for t in texts]