diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py
index 7e8494b96..d8de3e0de 100644
--- a/rag/llm/embedding_model.py
+++ b/rag/llm/embedding_model.py
@@ -60,7 +60,6 @@ class Base(ABC):
 
 class DefaultEmbedding(Base):
     _FACTORY_NAME = "BAAI"
-    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
     _model = None
     _model_name = ""
     _model_lock = threading.Lock()
@@ -78,9 +77,13 @@ class DefaultEmbedding(Base):
 
         """
         if not settings.LIGHTEN:
+            input_cuda_visible_devices = None
             with DefaultEmbedding._model_lock:
                 import torch
                 from FlagEmbedding import FlagModel
+                if "CUDA_VISIBLE_DEVICES" in os.environ:
+                    input_cuda_visible_devices = os.environ["CUDA_VISIBLE_DEVICES"]
+                os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # handle some issues with multiple GPUs when initializing the model
 
                 if not DefaultEmbedding._model or model_name != DefaultEmbedding._model_name:
                     try:
@@ -95,6 +98,10 @@ class DefaultEmbedding(Base):
                             repo_id="BAAI/bge-large-zh-v1.5", local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), local_dir_use_symlinks=False
                         )
                         DefaultEmbedding._model = FlagModel(model_dir, query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:", use_fp16=torch.cuda.is_available())
+                    finally:
+                        if input_cuda_visible_devices:
+                            # restore CUDA_VISIBLE_DEVICES
+                            os.environ["CUDA_VISIBLE_DEVICES"] = input_cuda_visible_devices
         self._model = DefaultEmbedding._model
         self._model_name = DefaultEmbedding._model_name
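
Review note: the patch applies the usual save/override/restore pattern for a process-wide environment variable, guarded by the model lock so concurrent initializers do not clobber each other's saved value. One edge worth flagging: the finally block restores only when input_cuda_visible_devices is truthy, so if CUDA_VISIBLE_DEVICES was unset (or set to an empty string) before init, it stays pinned to "0" afterwards. Below is a minimal sketch of the same pattern as a reusable context manager that also restores the unset case; scoped_env is a hypothetical helper for illustration, not part of this patch.

import os
from contextlib import contextmanager

@contextmanager
def scoped_env(name: str, value: str):
    """Temporarily set one environment variable, restoring its prior
    state (including "previously unset") even if the body raises."""
    sentinel = object()
    previous = os.environ.get(name, sentinel)
    os.environ[name] = value
    try:
        yield
    finally:
        if previous is sentinel:
            # the variable did not exist before; remove it again
            os.environ.pop(name, None)
        else:
            os.environ[name] = previous

# Usage: pin model initialization to GPU 0 without leaking the
# override into the rest of the process.
with scoped_env("CUDA_VISIBLE_DEVICES", "0"):
    ...  # load the embedding model here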