From 47926a95aec756c445e69ec800c666a5c77175c6 Mon Sep 17 00:00:00 2001 From: zhuhao <37029601+hwzhuhao@users.noreply.github.com> Date: Thu, 27 Jun 2024 14:48:49 +0800 Subject: [PATCH] Fix: ragflow may encounter an OOM (Out Of Memory) when there are a lot of conversations (#1292) ### What problem does this PR solve? Fixes an issue where ragflow may encounter an OOM (Out Of Memory) error when there are a lot of conversations. #1288 ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) Co-authored-by: zhuhao --- rag/llm/embedding_model.py | 28 ++++++++++++++++------------ rag/llm/rerank_model.py | 23 ++++++++++++----------- 2 files changed, 28 insertions(+), 23 deletions(-) diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py index 6b45b6035..596c1e353 100644 --- a/rag/llm/embedding_model.py +++ b/rag/llm/embedding_model.py @@ -15,6 +15,7 @@ # import re from typing import Optional +import threading import requests from huggingface_hub import snapshot_download from zhipuai import ZhipuAI @@ -44,7 +45,7 @@ class Base(ABC): class DefaultEmbedding(Base): _model = None - + _model_lock = threading.Lock() def __init__(self, key, model_name, **kwargs): """ If you have trouble downloading HuggingFace models, -_^ this might help!! 
@@ -58,17 +59,20 @@ class DefaultEmbedding(Base): """ if not DefaultEmbedding._model: - try: - self._model = FlagModel(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)), - query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:", - use_fp16=torch.cuda.is_available()) - except Exception as e: - model_dir = snapshot_download(repo_id="BAAI/bge-large-zh-v1.5", - local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)), - local_dir_use_symlinks=False) - self._model = FlagModel(model_dir, - query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:", - use_fp16=torch.cuda.is_available()) + with DefaultEmbedding._model_lock: + if not DefaultEmbedding._model: + try: + DefaultEmbedding._model = FlagModel(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)), + query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:", + use_fp16=torch.cuda.is_available()) + except Exception as e: + model_dir = snapshot_download(repo_id="BAAI/bge-large-zh-v1.5", + local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)), + local_dir_use_symlinks=False) + DefaultEmbedding._model = FlagModel(model_dir, + query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:", + use_fp16=torch.cuda.is_available()) + self._model = DefaultEmbedding._model def encode(self, texts: list, batch_size=32): texts = [truncate(t, 2048) for t in texts] diff --git a/rag/llm/rerank_model.py b/rag/llm/rerank_model.py index ba81273ce..e449cb1fa 100644 --- a/rag/llm/rerank_model.py +++ b/rag/llm/rerank_model.py @@ -14,6 +14,7 @@ # limitations under the License. # import re +import threading import requests import torch from FlagEmbedding import FlagReranker @@ -37,7 +38,7 @@ class Base(ABC): class DefaultRerank(Base): _model = None - + _model_lock = threading.Lock() def __init__(self, key, model_name, **kwargs): """ If you have trouble downloading HuggingFace models, -_^ this might help!! 
@@ -51,16 +52,16 @@ class DefaultRerank(Base): """ if not DefaultRerank._model: - try: - self._model = FlagReranker(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)), - use_fp16=torch.cuda.is_available()) - except Exception as e: - self._model = snapshot_download(repo_id=model_name, - local_dir=os.path.join(get_home_cache_dir(), - re.sub(r"^[a-zA-Z]+/", "", model_name)), - local_dir_use_symlinks=False) - self._model = FlagReranker(os.path.join(get_home_cache_dir(), model_name), - use_fp16=torch.cuda.is_available()) + with DefaultRerank._model_lock: + if not DefaultRerank._model: + try: + DefaultRerank._model = FlagReranker(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)), use_fp16=torch.cuda.is_available()) + except Exception as e: + model_dir = snapshot_download(repo_id= model_name, + local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)), + local_dir_use_symlinks=False) + DefaultRerank._model = FlagReranker(model_dir, use_fp16=torch.cuda.is_available()) + self._model = DefaultRerank._model def similarity(self, query: str, texts: list): pairs = [(query,truncate(t, 2048)) for t in texts]