mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Fix: embedding model failure in CometAPI (#10137)
### What problem does this PR solve?

Related PR: Feat: add CometAPI to LLMFactory and update related mappings #10119

Change: Fixes the issue where the embedding model in CometAPI was not being called correctly.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: TensorNull <tensor.null@gmail.com>
This commit is contained in:
@ -86,9 +86,10 @@ class DefaultEmbedding(Base):
|
|||||||
with DefaultEmbedding._model_lock:
|
with DefaultEmbedding._model_lock:
|
||||||
import torch
|
import torch
|
||||||
from FlagEmbedding import FlagModel
|
from FlagEmbedding import FlagModel
|
||||||
|
|
||||||
if "CUDA_VISIBLE_DEVICES" in os.environ:
|
if "CUDA_VISIBLE_DEVICES" in os.environ:
|
||||||
input_cuda_visible_devices = os.environ["CUDA_VISIBLE_DEVICES"]
|
input_cuda_visible_devices = os.environ["CUDA_VISIBLE_DEVICES"]
|
||||||
os.environ["CUDA_VISIBLE_DEVICES"] = "0" # handle some issues with multiple GPUs when initializing the model
|
os.environ["CUDA_VISIBLE_DEVICES"] = "0" # handle some issues with multiple GPUs when initializing the model
|
||||||
|
|
||||||
if not DefaultEmbedding._model or model_name != DefaultEmbedding._model_name:
|
if not DefaultEmbedding._model or model_name != DefaultEmbedding._model_name:
|
||||||
try:
|
try:
|
||||||
@ -472,6 +473,7 @@ class MistralEmbed(Base):
|
|||||||
def encode(self, texts: list):
|
def encode(self, texts: list):
|
||||||
import time
|
import time
|
||||||
import random
|
import random
|
||||||
|
|
||||||
texts = [truncate(t, 8196) for t in texts]
|
texts = [truncate(t, 8196) for t in texts]
|
||||||
batch_size = 16
|
batch_size = 16
|
||||||
ress = []
|
ress = []
|
||||||
@ -495,6 +497,7 @@ class MistralEmbed(Base):
|
|||||||
def encode_queries(self, text):
|
def encode_queries(self, text):
|
||||||
import time
|
import time
|
||||||
import random
|
import random
|
||||||
|
|
||||||
retry_max = 5
|
retry_max = 5
|
||||||
while retry_max > 0:
|
while retry_max > 0:
|
||||||
try:
|
try:
|
||||||
@ -942,6 +945,7 @@ class GiteeEmbed(SILICONFLOWEmbed):
|
|||||||
base_url = "https://ai.gitee.com/v1/embeddings"
|
base_url = "https://ai.gitee.com/v1/embeddings"
|
||||||
super().__init__(key, model_name, base_url)
|
super().__init__(key, model_name, base_url)
|
||||||
|
|
||||||
|
|
||||||
class DeepInfraEmbed(OpenAIEmbed):
|
class DeepInfraEmbed(OpenAIEmbed):
|
||||||
_FACTORY_NAME = "DeepInfra"
|
_FACTORY_NAME = "DeepInfra"
|
||||||
|
|
||||||
@ -960,10 +964,10 @@ class Ai302Embed(Base):
|
|||||||
super().__init__(key, model_name, base_url)
|
super().__init__(key, model_name, base_url)
|
||||||
|
|
||||||
|
|
||||||
class CometEmbed(Base):
|
class CometEmbed(OpenAIEmbed):
|
||||||
_FACTORY_NAME = "CometAPI"
|
_FACTORY_NAME = "CometAPI"
|
||||||
|
|
||||||
def __init__(self, key, model_name, base_url="https://api.cometapi.com/v1/embeddings"):
|
def __init__(self, key, model_name, base_url="https://api.cometapi.com/v1"):
|
||||||
if not base_url:
|
if not base_url:
|
||||||
base_url = "https://api.cometapi.com/v1/embeddings"
|
base_url = "https://api.cometapi.com/v1"
|
||||||
super().__init__(key, model_name, base_url)
|
super().__init__(key, model_name, base_url)
|
||||||
|
|||||||
Reference in New Issue
Block a user