mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-23 06:46:40 +08:00
Update embedding_model.py (#9083)
### What problem does this PR solve?

Reduce the logic scope for DefaultEmbedding.

### Type of change

- [x] Refactoring
This commit is contained in:
@ -60,7 +60,6 @@ class Base(ABC):
|
|||||||
|
|
||||||
class DefaultEmbedding(Base):
|
class DefaultEmbedding(Base):
|
||||||
_FACTORY_NAME = "BAAI"
|
_FACTORY_NAME = "BAAI"
|
||||||
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
|
||||||
_model = None
|
_model = None
|
||||||
_model_name = ""
|
_model_name = ""
|
||||||
_model_lock = threading.Lock()
|
_model_lock = threading.Lock()
|
||||||
@ -78,9 +77,13 @@ class DefaultEmbedding(Base):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
if not settings.LIGHTEN:
|
if not settings.LIGHTEN:
|
||||||
|
input_cuda_visible_devices = None
|
||||||
with DefaultEmbedding._model_lock:
|
with DefaultEmbedding._model_lock:
|
||||||
import torch
|
import torch
|
||||||
from FlagEmbedding import FlagModel
|
from FlagEmbedding import FlagModel
|
||||||
|
if "CUDA_VISIBLE_DEVICES" in os.environ:
|
||||||
|
input_cuda_visible_devices = os.environ["CUDA_VISIBLE_DEVICES"]
|
||||||
|
os.environ["CUDA_VISIBLE_DEVICES"] = "0" # handle some issues with multiple GPUs when initializing the model
|
||||||
|
|
||||||
if not DefaultEmbedding._model or model_name != DefaultEmbedding._model_name:
|
if not DefaultEmbedding._model or model_name != DefaultEmbedding._model_name:
|
||||||
try:
|
try:
|
||||||
@ -95,6 +98,10 @@ class DefaultEmbedding(Base):
|
|||||||
repo_id="BAAI/bge-large-zh-v1.5", local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), local_dir_use_symlinks=False
|
repo_id="BAAI/bge-large-zh-v1.5", local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), local_dir_use_symlinks=False
|
||||||
)
|
)
|
||||||
DefaultEmbedding._model = FlagModel(model_dir, query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:", use_fp16=torch.cuda.is_available())
|
DefaultEmbedding._model = FlagModel(model_dir, query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:", use_fp16=torch.cuda.is_available())
|
||||||
|
finally:
|
||||||
|
if input_cuda_visible_devices:
|
||||||
|
# restore CUDA_VISIBLE_DEVICES
|
||||||
|
os.environ["CUDA_VISIBLE_DEVICES"] = input_cuda_visible_devices
|
||||||
self._model = DefaultEmbedding._model
|
self._model = DefaultEmbedding._model
|
||||||
self._model_name = DefaultEmbedding._model_name
|
self._model_name = DefaultEmbedding._model_name
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user