From 91b609447da15e56a23ff2add95395df5ace244c Mon Sep 17 00:00:00 2001 From: buua436 <66937541+buua436@users.noreply.github.com> Date: Thu, 18 Sep 2025 14:49:47 +0800 Subject: [PATCH] Fix: embedding model failure in CometAPI (#10137) ### What problem does this PR solve? Related PR: Feat: add CometAPI to LLMFactory and update related mappings #10119 Change: Fixes the issue where the embedding model in CometAPI was not being called correctly ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) - [x] New Feature (non-breaking change which adds functionality) --------- Co-authored-by: TensorNull --- rag/llm/embedding_model.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py index 904a45d9d..2e78b3dbc 100644 --- a/rag/llm/embedding_model.py +++ b/rag/llm/embedding_model.py @@ -86,9 +86,10 @@ class DefaultEmbedding(Base): with DefaultEmbedding._model_lock: import torch from FlagEmbedding import FlagModel + if "CUDA_VISIBLE_DEVICES" in os.environ: input_cuda_visible_devices = os.environ["CUDA_VISIBLE_DEVICES"] - os.environ["CUDA_VISIBLE_DEVICES"] = "0" # handle some issues with multiple GPUs when initializing the model + os.environ["CUDA_VISIBLE_DEVICES"] = "0" # handle some issues with multiple GPUs when initializing the model if not DefaultEmbedding._model or model_name != DefaultEmbedding._model_name: try: @@ -472,6 +473,7 @@ class MistralEmbed(Base): def encode(self, texts: list): import time import random + texts = [truncate(t, 8196) for t in texts] batch_size = 16 ress = [] @@ -495,6 +497,7 @@ class MistralEmbed(Base): def encode_queries(self, text): import time import random + retry_max = 5 while retry_max > 0: try: @@ -942,6 +945,7 @@ class GiteeEmbed(SILICONFLOWEmbed): base_url = "https://ai.gitee.com/v1/embeddings" super().__init__(key, model_name, base_url) + class DeepInfraEmbed(OpenAIEmbed): _FACTORY_NAME = "DeepInfra" @@ -960,10 +964,10 @@ class Ai302Embed(Base): super().__init__(key, model_name, base_url) -class CometEmbed(Base): +class CometEmbed(OpenAIEmbed): _FACTORY_NAME = "CometAPI" - def __init__(self, key, model_name, base_url="https://api.cometapi.com/v1/embeddings"): + def __init__(self, key, model_name, base_url="https://api.cometapi.com/v1"): if not base_url: - base_url = "https://api.cometapi.com/v1/embeddings" + base_url = "https://api.cometapi.com/v1" super().__init__(key, model_name, base_url)