Mirror of https://github.com/infiniflow/ragflow.git, synced 2025-12-08 20:42:30 +08:00.
Update embedding_model.py (#9083)
### What problem does this PR solve?

Reduce the logic scope for `DefaultEmbedding`.

### Type of change

- [x] Refactoring
This commit is contained in:
@ -60,7 +60,6 @@ class Base(ABC):
|
||||
|
||||
class DefaultEmbedding(Base):
|
||||
_FACTORY_NAME = "BAAI"
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
||||
_model = None
|
||||
_model_name = ""
|
||||
_model_lock = threading.Lock()
|
||||
@ -78,9 +77,13 @@ class DefaultEmbedding(Base):
|
||||
|
||||
"""
|
||||
if not settings.LIGHTEN:
|
||||
input_cuda_visible_devices = None
|
||||
with DefaultEmbedding._model_lock:
|
||||
import torch
|
||||
from FlagEmbedding import FlagModel
|
||||
if "CUDA_VISIBLE_DEVICES" in os.environ:
|
||||
input_cuda_visible_devices = os.environ["CUDA_VISIBLE_DEVICES"]
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "0" # handle some issues with multiple GPUs when initializing the model
|
||||
|
||||
if not DefaultEmbedding._model or model_name != DefaultEmbedding._model_name:
|
||||
try:
|
||||
@ -95,6 +98,10 @@ class DefaultEmbedding(Base):
|
||||
repo_id="BAAI/bge-large-zh-v1.5", local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), local_dir_use_symlinks=False
|
||||
)
|
||||
DefaultEmbedding._model = FlagModel(model_dir, query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:", use_fp16=torch.cuda.is_available())
|
||||
finally:
|
||||
if input_cuda_visible_devices:
|
||||
# restore CUDA_VISIBLE_DEVICES
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = input_cuda_visible_devices
|
||||
self._model = DefaultEmbedding._model
|
||||
self._model_name = DefaultEmbedding._model_name
|
||||
|
||||
|
||||
Reference in New Issue
Block a user