Fix: embedding model failure in CometAPI (#10137)

### What problem does this PR solve?

Related PR:
Feat: add CometAPI to LLMFactory and update related mappings #10119 

Change:
Fixes the issue where the embedding model in CometAPI was not being
called correctly

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: TensorNull <tensor.null@gmail.com>
This commit is contained in:
buua436
2025-09-18 14:49:47 +08:00
committed by GitHub
parent c353840244
commit 91b609447d

View File

@ -86,9 +86,10 @@ class DefaultEmbedding(Base):
with DefaultEmbedding._model_lock:
import torch
from FlagEmbedding import FlagModel
if "CUDA_VISIBLE_DEVICES" in os.environ:
input_cuda_visible_devices = os.environ["CUDA_VISIBLE_DEVICES"]
os.environ["CUDA_VISIBLE_DEVICES"] = "0" # handle some issues with multiple GPUs when initializing the model
os.environ["CUDA_VISIBLE_DEVICES"] = "0" # handle some issues with multiple GPUs when initializing the model
if not DefaultEmbedding._model or model_name != DefaultEmbedding._model_name:
try:
@ -472,6 +473,7 @@ class MistralEmbed(Base):
def encode(self, texts: list):
import time
import random
texts = [truncate(t, 8196) for t in texts]
batch_size = 16
ress = []
@ -495,6 +497,7 @@ class MistralEmbed(Base):
def encode_queries(self, text):
import time
import random
retry_max = 5
while retry_max > 0:
try:
@ -942,6 +945,7 @@ class GiteeEmbed(SILICONFLOWEmbed):
base_url = "https://ai.gitee.com/v1/embeddings"
super().__init__(key, model_name, base_url)
class DeepInfraEmbed(OpenAIEmbed):
_FACTORY_NAME = "DeepInfra"
@ -960,10 +964,10 @@ class Ai302Embed(Base):
super().__init__(key, model_name, base_url)
class CometEmbed(Base):
class CometEmbed(OpenAIEmbed):
_FACTORY_NAME = "CometAPI"
def __init__(self, key, model_name, base_url="https://api.cometapi.com/v1/embeddings"):
def __init__(self, key, model_name, base_url="https://api.cometapi.com/v1"):
if not base_url:
base_url = "https://api.cometapi.com/v1/embeddings"
base_url = "https://api.cometapi.com/v1"
super().__init__(key, model_name, base_url)