Refactor: automatic LLM registration (#8651)

### What problem does this PR solve?

Support automatic LLMs registration.

### Type of change

- [x] Refactoring
This commit is contained in:
Yongteng Lei
2025-07-03 19:05:31 +08:00
committed by GitHub
parent 3234a15aae
commit f8a6987f1e
7 changed files with 619 additions and 876 deletions

View File

@ -13,28 +13,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import json
import logging
import os
import re
import threading
from abc import ABC
from urllib.parse import urljoin

import dashscope
import google.generativeai as genai
import numpy as np
import requests
from huggingface_hub import snapshot_download
from ollama import Client
from openai import OpenAI
from zhipuai import ZhipuAI

from api import settings
from api.utils.file_utils import get_home_cache_dir
from api.utils.log_utils import log_exception
from rag.utils import num_tokens_from_string, truncate
class Base(ABC):
@ -60,7 +59,8 @@ class Base(ABC):
class DefaultEmbedding(Base):
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
_FACTORY_NAME = "BAAI"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
_model = None
_model_name = ""
_model_lock = threading.Lock()
@ -79,21 +79,22 @@ class DefaultEmbedding(Base):
"""
if not settings.LIGHTEN:
with DefaultEmbedding._model_lock:
from FlagEmbedding import FlagModel
import torch
from FlagEmbedding import FlagModel
if not DefaultEmbedding._model or model_name != DefaultEmbedding._model_name:
try:
DefaultEmbedding._model = FlagModel(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)),
query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
use_fp16=torch.cuda.is_available())
DefaultEmbedding._model = FlagModel(
os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)),
query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
use_fp16=torch.cuda.is_available(),
)
DefaultEmbedding._model_name = model_name
except Exception:
model_dir = snapshot_download(repo_id="BAAI/bge-large-zh-v1.5",
local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)),
local_dir_use_symlinks=False)
DefaultEmbedding._model = FlagModel(model_dir,
query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
use_fp16=torch.cuda.is_available())
model_dir = snapshot_download(
repo_id="BAAI/bge-large-zh-v1.5", local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), local_dir_use_symlinks=False
)
DefaultEmbedding._model = FlagModel(model_dir, query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:", use_fp16=torch.cuda.is_available())
self._model = DefaultEmbedding._model
self._model_name = DefaultEmbedding._model_name
@ -105,7 +106,7 @@ class DefaultEmbedding(Base):
token_count += num_tokens_from_string(t)
ress = []
for i in range(0, len(texts), batch_size):
ress.extend(self._model.encode(texts[i:i + batch_size]).tolist())
ress.extend(self._model.encode(texts[i : i + batch_size]).tolist())
return np.array(ress), token_count
def encode_queries(self, text: str):
@ -114,8 +115,9 @@ class DefaultEmbedding(Base):
class OpenAIEmbed(Base):
def __init__(self, key, model_name="text-embedding-ada-002",
base_url="https://api.openai.com/v1"):
_FACTORY_NAME = "OpenAI"
def __init__(self, key, model_name="text-embedding-ada-002", base_url="https://api.openai.com/v1"):
if not base_url:
base_url = "https://api.openai.com/v1"
self.client = OpenAI(api_key=key, base_url=base_url)
@ -128,8 +130,7 @@ class OpenAIEmbed(Base):
ress = []
total_tokens = 0
for i in range(0, len(texts), batch_size):
res = self.client.embeddings.create(input=texts[i:i + batch_size],
model=self.model_name)
res = self.client.embeddings.create(input=texts[i : i + batch_size], model=self.model_name)
try:
ress.extend([d.embedding for d in res.data])
total_tokens += self.total_token_count(res)
@ -138,12 +139,13 @@ class OpenAIEmbed(Base):
return np.array(ress), total_tokens
def encode_queries(self, text):
res = self.client.embeddings.create(input=[truncate(text, 8191)],
model=self.model_name)
res = self.client.embeddings.create(input=[truncate(text, 8191)], model=self.model_name)
return np.array(res.data[0].embedding), self.total_token_count(res)
class LocalAIEmbed(Base):
_FACTORY_NAME = "LocalAI"
def __init__(self, key, model_name, base_url):
if not base_url:
raise ValueError("Local embedding model url cannot be None")
@ -155,7 +157,7 @@ class LocalAIEmbed(Base):
batch_size = 16
ress = []
for i in range(0, len(texts), batch_size):
res = self.client.embeddings.create(input=texts[i:i + batch_size], model=self.model_name)
res = self.client.embeddings.create(input=texts[i : i + batch_size], model=self.model_name)
try:
ress.extend([d.embedding for d in res.data])
except Exception as _e:
@ -169,41 +171,42 @@ class LocalAIEmbed(Base):
class AzureEmbed(OpenAIEmbed):
    """Azure OpenAI embeddings; credentials arrive packed in the JSON `key`."""

    _FACTORY_NAME = "Azure-OpenAI"

    def __init__(self, key, model_name, **kwargs):
        from openai.lib.azure import AzureOpenAI

        # Fix: parse the JSON credential blob once instead of once per field.
        creds = json.loads(key)
        api_key = creds.get("api_key", "")
        api_version = creds.get("api_version", "2024-02-01")
        self.client = AzureOpenAI(api_key=api_key, azure_endpoint=kwargs["base_url"], api_version=api_version)
        self.model_name = model_name
class BaiChuanEmbed(OpenAIEmbed):
    """Baichuan embeddings via the OpenAI-compatible endpoint."""

    _FACTORY_NAME = "BaiChuan"

    def __init__(self, key, model_name="Baichuan-Text-Embedding", base_url="https://api.baichuan-ai.com/v1"):
        # An explicit empty/None base_url falls back to the public endpoint.
        if not base_url:
            base_url = "https://api.baichuan-ai.com/v1"
        super().__init__(key, model_name, base_url)
class QWenEmbed(Base):
_FACTORY_NAME = "Tongyi-Qianwen"
def __init__(self, key, model_name="text_embedding_v2", **kwargs):
self.key = key
self.model_name = model_name
def encode(self, texts: list):
import dashscope
batch_size = 4
res = []
token_count = 0
texts = [truncate(t, 2048) for t in texts]
for i in range(0, len(texts), batch_size):
resp = dashscope.TextEmbedding.call(
model=self.model_name,
input=texts[i:i + batch_size],
api_key=self.key,
text_type="document"
)
resp = dashscope.TextEmbedding.call(model=self.model_name, input=texts[i : i + batch_size], api_key=self.key, text_type="document")
try:
embds = [[] for _ in range(len(resp["output"]["embeddings"]))]
for e in resp["output"]["embeddings"]:
@ -216,20 +219,16 @@ class QWenEmbed(Base):
return np.array(res), token_count
def encode_queries(self, text):
resp = dashscope.TextEmbedding.call(
model=self.model_name,
input=text[:2048],
api_key=self.key,
text_type="query"
)
resp = dashscope.TextEmbedding.call(model=self.model_name, input=text[:2048], api_key=self.key, text_type="query")
try:
return np.array(resp["output"]["embeddings"][0]
["embedding"]), self.total_token_count(resp)
return np.array(resp["output"]["embeddings"][0]["embedding"]), self.total_token_count(resp)
except Exception as _e:
log_exception(_e, resp)
class ZhipuEmbed(Base):
_FACTORY_NAME = "ZHIPU-AI"
def __init__(self, key, model_name="embedding-2", **kwargs):
self.client = ZhipuAI(api_key=key)
self.model_name = model_name
@ -246,8 +245,7 @@ class ZhipuEmbed(Base):
texts = [truncate(t, MAX_LEN) for t in texts]
for txt in texts:
res = self.client.embeddings.create(input=txt,
model=self.model_name)
res = self.client.embeddings.create(input=txt, model=self.model_name)
try:
arr.append(res.data[0].embedding)
tks_num += self.total_token_count(res)
@ -256,8 +254,7 @@ class ZhipuEmbed(Base):
return np.array(arr), tks_num
def encode_queries(self, text):
res = self.client.embeddings.create(input=text,
model=self.model_name)
res = self.client.embeddings.create(input=text, model=self.model_name)
try:
return np.array(res.data[0].embedding), self.total_token_count(res)
except Exception as _e:
@ -265,18 +262,17 @@ class ZhipuEmbed(Base):
class OllamaEmbed(Base):
_FACTORY_NAME = "Ollama"
def __init__(self, key, model_name, **kwargs):
self.client = Client(host=kwargs["base_url"]) if not key or key == "x" else \
Client(host=kwargs["base_url"], headers={"Authorization": f"Bear {key}"})
self.client = Client(host=kwargs["base_url"]) if not key or key == "x" else Client(host=kwargs["base_url"], headers={"Authorization": f"Bear {key}"})
self.model_name = model_name
def encode(self, texts: list):
arr = []
tks_num = 0
for txt in texts:
res = self.client.embeddings(prompt=txt,
model=self.model_name,
options={"use_mmap": True})
res = self.client.embeddings(prompt=txt, model=self.model_name, options={"use_mmap": True})
try:
arr.append(res["embedding"])
except Exception as _e:
@ -285,9 +281,7 @@ class OllamaEmbed(Base):
return np.array(arr), tks_num
def encode_queries(self, text):
res = self.client.embeddings(prompt=text,
model=self.model_name,
options={"use_mmap": True})
res = self.client.embeddings(prompt=text, model=self.model_name, options={"use_mmap": True})
try:
return np.array(res["embedding"]), 128
except Exception as _e:
@ -295,27 +289,28 @@ class OllamaEmbed(Base):
class FastEmbed(DefaultEmbedding):
_FACTORY_NAME = "FastEmbed"
def __init__(
self,
key: str | None = None,
model_name: str = "BAAI/bge-small-en-v1.5",
cache_dir: str | None = None,
threads: int | None = None,
**kwargs,
self,
key: str | None = None,
model_name: str = "BAAI/bge-small-en-v1.5",
cache_dir: str | None = None,
threads: int | None = None,
**kwargs,
):
if not settings.LIGHTEN:
with FastEmbed._model_lock:
from fastembed import TextEmbedding
if not DefaultEmbedding._model or model_name != DefaultEmbedding._model_name:
try:
DefaultEmbedding._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
DefaultEmbedding._model_name = model_name
except Exception:
cache_dir = snapshot_download(repo_id="BAAI/bge-small-en-v1.5",
local_dir=os.path.join(get_home_cache_dir(),
re.sub(r"^[a-zA-Z0-9]+/", "", model_name)),
local_dir_use_symlinks=False)
cache_dir = snapshot_download(
repo_id="BAAI/bge-small-en-v1.5", local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), local_dir_use_symlinks=False
)
DefaultEmbedding._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
self._model = DefaultEmbedding._model
self._model_name = model_name
@ -340,6 +335,8 @@ class FastEmbed(DefaultEmbedding):
class XinferenceEmbed(Base):
_FACTORY_NAME = "Xinference"
def __init__(self, key, model_name="", base_url=""):
base_url = urljoin(base_url, "v1")
self.client = OpenAI(api_key=key, base_url=base_url)
@ -350,7 +347,7 @@ class XinferenceEmbed(Base):
ress = []
total_tokens = 0
for i in range(0, len(texts), batch_size):
res = self.client.embeddings.create(input=texts[i:i + batch_size], model=self.model_name)
res = self.client.embeddings.create(input=texts[i : i + batch_size], model=self.model_name)
try:
ress.extend([d.embedding for d in res.data])
total_tokens += self.total_token_count(res)
@ -359,8 +356,7 @@ class XinferenceEmbed(Base):
return np.array(ress), total_tokens
def encode_queries(self, text):
res = self.client.embeddings.create(input=[text],
model=self.model_name)
res = self.client.embeddings.create(input=[text], model=self.model_name)
try:
return np.array(res.data[0].embedding), self.total_token_count(res)
except Exception as _e:
@ -368,20 +364,18 @@ class XinferenceEmbed(Base):
class YoudaoEmbed(Base):
_FACTORY_NAME = "Youdao"
_client = None
def __init__(self, key=None, model_name="maidalun1020/bce-embedding-base_v1", **kwargs):
if not settings.LIGHTEN and not YoudaoEmbed._client:
from BCEmbedding import EmbeddingModel as qanthing
try:
logging.info("LOADING BCE...")
YoudaoEmbed._client = qanthing(model_name_or_path=os.path.join(
get_home_cache_dir(),
"bce-embedding-base_v1"))
YoudaoEmbed._client = qanthing(model_name_or_path=os.path.join(get_home_cache_dir(), "bce-embedding-base_v1"))
except Exception:
YoudaoEmbed._client = qanthing(
model_name_or_path=model_name.replace(
"maidalun1020", "InfiniFlow"))
YoudaoEmbed._client = qanthing(model_name_or_path=model_name.replace("maidalun1020", "InfiniFlow"))
def encode(self, texts: list):
batch_size = 10
@ -390,7 +384,7 @@ class YoudaoEmbed(Base):
for t in texts:
token_count += num_tokens_from_string(t)
for i in range(0, len(texts), batch_size):
embds = YoudaoEmbed._client.encode(texts[i:i + batch_size])
embds = YoudaoEmbed._client.encode(texts[i : i + batch_size])
res.extend(embds)
return np.array(res), token_count
@ -400,14 +394,11 @@ class YoudaoEmbed(Base):
class JinaEmbed(Base):
def __init__(self, key, model_name="jina-embeddings-v3",
base_url="https://api.jina.ai/v1/embeddings"):
_FACTORY_NAME = "Jina"
def __init__(self, key, model_name="jina-embeddings-v3", base_url="https://api.jina.ai/v1/embeddings"):
self.base_url = "https://api.jina.ai/v1/embeddings"
self.headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {key}"
}
self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
self.model_name = model_name
def encode(self, texts: list):
@ -416,11 +407,7 @@ class JinaEmbed(Base):
ress = []
token_count = 0
for i in range(0, len(texts), batch_size):
data = {
"model": self.model_name,
"input": texts[i:i + batch_size],
'encoding_type': 'float'
}
data = {"model": self.model_name, "input": texts[i : i + batch_size], "encoding_type": "float"}
response = requests.post(self.base_url, headers=self.headers, json=data)
try:
res = response.json()
@ -435,50 +422,12 @@ class JinaEmbed(Base):
return np.array(embds[0]), cnt
class InfinityEmbed(Base):
    """Embeddings via the `infinity_emb` async engine.

    One instance can serve several models; `encode` selects by name and
    defaults to the first model in `model_names`.
    """

    _model = None

    def __init__(
        self,
        model_names: list[str] = ("BAAI/bge-small-en-v1.5",),
        engine_kwargs: dict | None = None,
        key=None,
    ):
        from infinity_emb import EngineArgs
        from infinity_emb.engine import AsyncEngineArray

        # Fix: `engine_kwargs: dict = {}` was a mutable default argument shared
        # across all instances; use a None sentinel instead.
        engine_kwargs = engine_kwargs or {}
        self._default_model = model_names[0]
        self.engine_array = AsyncEngineArray.from_args([EngineArgs(model_name_or_path=model_name, **engine_kwargs) for model_name in model_names])

    async def _embed(self, sentences: list[str], model_name: str = ""):
        # Start the engine only if it is not already running, and stop it again
        # afterwards so a one-shot call does not leak a running engine.
        if not model_name:
            model_name = self._default_model
        engine = self.engine_array[model_name]
        was_already_running = engine.is_running
        if not was_already_running:
            await engine.astart()
        embeddings, usage = await engine.embed(sentences=sentences)
        if not was_already_running:
            await engine.astop()
        return embeddings, usage

    def encode(self, texts: list[str], model_name: str = "") -> tuple[np.ndarray, int]:
        """Embed `texts`; returns (embedding matrix, token usage from the engine)."""
        embeddings, usage = asyncio.run(self._embed(texts, model_name))
        return np.array(embeddings), usage

    def encode_queries(self, text: str) -> tuple[np.ndarray, int]:
        # NOTE(review): returns a (1, dim) array rather than a flat vector like
        # the other providers — kept as-is to preserve behavior.
        return self.encode([text])
class MistralEmbed(Base):
def __init__(self, key, model_name="mistral-embed",
base_url=None):
_FACTORY_NAME = "Mistral"
def __init__(self, key, model_name="mistral-embed", base_url=None):
from mistralai.client import MistralClient
self.client = MistralClient(api_key=key)
self.model_name = model_name
@ -488,8 +437,7 @@ class MistralEmbed(Base):
ress = []
token_count = 0
for i in range(0, len(texts), batch_size):
res = self.client.embeddings(input=texts[i:i + batch_size],
model=self.model_name)
res = self.client.embeddings(input=texts[i : i + batch_size], model=self.model_name)
try:
ress.extend([d.embedding for d in res.data])
token_count += self.total_token_count(res)
@ -498,8 +446,7 @@ class MistralEmbed(Base):
return np.array(ress), token_count
def encode_queries(self, text):
res = self.client.embeddings(input=[truncate(text, 8196)],
model=self.model_name)
res = self.client.embeddings(input=[truncate(text, 8196)], model=self.model_name)
try:
return np.array(res.data[0].embedding), self.total_token_count(res)
except Exception as _e:
@ -507,30 +454,31 @@ class MistralEmbed(Base):
class BedrockEmbed(Base):
def __init__(self, key, model_name,
**kwargs):
_FACTORY_NAME = "Bedrock"
def __init__(self, key, model_name, **kwargs):
import boto3
self.bedrock_ak = json.loads(key).get('bedrock_ak', '')
self.bedrock_sk = json.loads(key).get('bedrock_sk', '')
self.bedrock_region = json.loads(key).get('bedrock_region', '')
self.bedrock_ak = json.loads(key).get("bedrock_ak", "")
self.bedrock_sk = json.loads(key).get("bedrock_sk", "")
self.bedrock_region = json.loads(key).get("bedrock_region", "")
self.model_name = model_name
if self.bedrock_ak == '' or self.bedrock_sk == '' or self.bedrock_region == '':
if self.bedrock_ak == "" or self.bedrock_sk == "" or self.bedrock_region == "":
# Try to create a client using the default credentials (AWS_PROFILE, AWS_DEFAULT_REGION, etc.)
self.client = boto3.client('bedrock-runtime')
self.client = boto3.client("bedrock-runtime")
else:
self.client = boto3.client(service_name='bedrock-runtime', region_name=self.bedrock_region,
aws_access_key_id=self.bedrock_ak, aws_secret_access_key=self.bedrock_sk)
self.client = boto3.client(service_name="bedrock-runtime", region_name=self.bedrock_region, aws_access_key_id=self.bedrock_ak, aws_secret_access_key=self.bedrock_sk)
def encode(self, texts: list):
texts = [truncate(t, 8196) for t in texts]
embeddings = []
token_count = 0
for text in texts:
if self.model_name.split('.')[0] == 'amazon':
if self.model_name.split(".")[0] == "amazon":
body = {"inputText": text}
elif self.model_name.split('.')[0] == 'cohere':
body = {"texts": [text], "input_type": 'search_document'}
elif self.model_name.split(".")[0] == "cohere":
body = {"texts": [text], "input_type": "search_document"}
response = self.client.invoke_model(modelId=self.model_name, body=json.dumps(body))
try:
@ -545,10 +493,10 @@ class BedrockEmbed(Base):
def encode_queries(self, text):
embeddings = []
token_count = num_tokens_from_string(text)
if self.model_name.split('.')[0] == 'amazon':
if self.model_name.split(".")[0] == "amazon":
body = {"inputText": truncate(text, 8196)}
elif self.model_name.split('.')[0] == 'cohere':
body = {"texts": [truncate(text, 8196)], "input_type": 'search_query'}
elif self.model_name.split(".")[0] == "cohere":
body = {"texts": [truncate(text, 8196)], "input_type": "search_query"}
response = self.client.invoke_model(modelId=self.model_name, body=json.dumps(body))
try:
@ -561,11 +509,12 @@ class BedrockEmbed(Base):
class GeminiEmbed(Base):
def __init__(self, key, model_name='models/text-embedding-004',
**kwargs):
_FACTORY_NAME = "Gemini"
def __init__(self, key, model_name="models/text-embedding-004", **kwargs):
self.key = key
self.model_name = 'models/' + model_name
self.model_name = "models/" + model_name
def encode(self, texts: list):
texts = [truncate(t, 2048) for t in texts]
token_count = sum(num_tokens_from_string(text) for text in texts)
@ -573,35 +522,27 @@ class GeminiEmbed(Base):
batch_size = 16
ress = []
for i in range(0, len(texts), batch_size):
result = genai.embed_content(
model=self.model_name,
content=texts[i: i + batch_size],
task_type="retrieval_document",
title="Embedding of single string")
result = genai.embed_content(model=self.model_name, content=texts[i : i + batch_size], task_type="retrieval_document", title="Embedding of single string")
try:
ress.extend(result['embedding'])
ress.extend(result["embedding"])
except Exception as _e:
log_exception(_e, result)
return np.array(ress),token_count
return np.array(ress), token_count
def encode_queries(self, text):
genai.configure(api_key=self.key)
result = genai.embed_content(
model=self.model_name,
content=truncate(text,2048),
task_type="retrieval_document",
title="Embedding of single string")
result = genai.embed_content(model=self.model_name, content=truncate(text, 2048), task_type="retrieval_document", title="Embedding of single string")
token_count = num_tokens_from_string(text)
try:
return np.array(result['embedding']), token_count
return np.array(result["embedding"]), token_count
except Exception as _e:
log_exception(_e, result)
class NvidiaEmbed(Base):
def __init__(
self, key, model_name, base_url="https://integrate.api.nvidia.com/v1/embeddings"
):
_FACTORY_NAME = "NVIDIA"
def __init__(self, key, model_name, base_url="https://integrate.api.nvidia.com/v1/embeddings"):
if not base_url:
base_url = "https://integrate.api.nvidia.com/v1/embeddings"
self.api_key = key
@ -645,6 +586,8 @@ class NvidiaEmbed(Base):
class LmStudioEmbed(LocalAIEmbed):
_FACTORY_NAME = "LM-Studio"
def __init__(self, key, model_name, base_url):
if not base_url:
raise ValueError("Local llm url cannot be None")
@ -654,6 +597,8 @@ class LmStudioEmbed(LocalAIEmbed):
class OpenAI_APIEmbed(OpenAIEmbed):
_FACTORY_NAME = ["VLLM", "OpenAI-API-Compatible"]
def __init__(self, key, model_name, base_url):
if not base_url:
raise ValueError("url cannot be None")
@ -663,6 +608,8 @@ class OpenAI_APIEmbed(OpenAIEmbed):
class CoHereEmbed(Base):
_FACTORY_NAME = "Cohere"
def __init__(self, key, model_name, base_url=None):
from cohere import Client
@ -701,6 +648,8 @@ class CoHereEmbed(Base):
class TogetherAIEmbed(OpenAIEmbed):
_FACTORY_NAME = "TogetherAI"
def __init__(self, key, model_name, base_url="https://api.together.xyz/v1"):
if not base_url:
base_url = "https://api.together.xyz/v1"
@ -708,6 +657,8 @@ class TogetherAIEmbed(OpenAIEmbed):
class PerfXCloudEmbed(OpenAIEmbed):
_FACTORY_NAME = "PerfXCloud"
def __init__(self, key, model_name, base_url="https://cloud.perfxlab.cn/v1"):
if not base_url:
base_url = "https://cloud.perfxlab.cn/v1"
@ -715,6 +666,8 @@ class PerfXCloudEmbed(OpenAIEmbed):
class UpstageEmbed(OpenAIEmbed):
_FACTORY_NAME = "Upstage"
def __init__(self, key, model_name, base_url="https://api.upstage.ai/v1/solar"):
if not base_url:
base_url = "https://api.upstage.ai/v1/solar"
@ -722,6 +675,8 @@ class UpstageEmbed(OpenAIEmbed):
class SILICONFLOWEmbed(Base):
_FACTORY_NAME = "SILICONFLOW"
def __init__(self, key, model_name, base_url="https://api.siliconflow.cn/v1/embeddings"):
if not base_url:
base_url = "https://api.siliconflow.cn/v1/embeddings"
@ -769,6 +724,8 @@ class SILICONFLOWEmbed(Base):
class ReplicateEmbed(Base):
_FACTORY_NAME = "Replicate"
def __init__(self, key, model_name, base_url=None):
from replicate.client import Client
@ -790,6 +747,8 @@ class ReplicateEmbed(Base):
class BaiduYiyanEmbed(Base):
_FACTORY_NAME = "BaiduYiyan"
def __init__(self, key, model_name, base_url=None):
import qianfan
@ -821,6 +780,8 @@ class BaiduYiyanEmbed(Base):
class VoyageEmbed(Base):
_FACTORY_NAME = "Voyage AI"
def __init__(self, key, model_name, base_url=None):
import voyageai
@ -832,9 +793,7 @@ class VoyageEmbed(Base):
ress = []
token_count = 0
for i in range(0, len(texts), batch_size):
res = self.client.embed(
texts=texts[i : i + batch_size], model=self.model_name, input_type="document"
)
res = self.client.embed(texts=texts[i : i + batch_size], model=self.model_name, input_type="document")
try:
ress.extend(res.embeddings)
token_count += res.total_tokens
@ -843,9 +802,7 @@ class VoyageEmbed(Base):
return np.array(ress), token_count
def encode_queries(self, text):
res = self.client.embed(
texts=text, model=self.model_name, input_type="query"
)
res = self.client.embed(texts=text, model=self.model_name, input_type="query")
try:
return np.array(res.embeddings)[0], res.total_tokens
except Exception as _e:
@ -853,6 +810,8 @@ class VoyageEmbed(Base):
class HuggingFaceEmbed(Base):
_FACTORY_NAME = "HuggingFace"
def __init__(self, key, model_name, base_url=None):
if not model_name:
raise ValueError("Model name cannot be None")
@ -863,11 +822,7 @@ class HuggingFaceEmbed(Base):
def encode(self, texts: list):
embeddings = []
for text in texts:
response = requests.post(
f"{self.base_url}/embed",
json={"inputs": text},
headers={'Content-Type': 'application/json'}
)
response = requests.post(f"{self.base_url}/embed", json={"inputs": text}, headers={"Content-Type": "application/json"})
if response.status_code == 200:
embedding = response.json()
embeddings.append(embedding[0])
@ -876,11 +831,7 @@ class HuggingFaceEmbed(Base):
return np.array(embeddings), sum([num_tokens_from_string(text) for text in texts])
def encode_queries(self, text):
response = requests.post(
f"{self.base_url}/embed",
json={"inputs": text},
headers={'Content-Type': 'application/json'}
)
response = requests.post(f"{self.base_url}/embed", json={"inputs": text}, headers={"Content-Type": "application/json"})
if response.status_code == 200:
embedding = response.json()
return np.array(embedding[0]), num_tokens_from_string(text)
@ -889,15 +840,19 @@ class HuggingFaceEmbed(Base):
class VolcEngineEmbed(OpenAIEmbed):
    """VolcEngine Ark embeddings; the real key and endpoint id are packed in the JSON `key`."""

    _FACTORY_NAME = "VolcEngine"

    def __init__(self, key, model_name, base_url="https://ark.cn-beijing.volces.com/api/v3"):
        if not base_url:
            base_url = "https://ark.cn-beijing.volces.com/api/v3"
        # Fix: parse the credential blob once instead of once per field.
        creds = json.loads(key)
        ark_api_key = creds.get("ark_api_key", "")
        # Either "ep_id" or "endpoint_id" is set; the other defaults to "",
        # so concatenation yields whichever one is present.
        model_name = creds.get("ep_id", "") + creds.get("endpoint_id", "")
        super().__init__(ark_api_key, model_name, base_url)
class GPUStackEmbed(OpenAIEmbed):
_FACTORY_NAME = "GPUStack"
def __init__(self, key, model_name, base_url):
if not base_url:
raise ValueError("url cannot be None")
@ -908,6 +863,8 @@ class GPUStackEmbed(OpenAIEmbed):
class NovitaEmbed(SILICONFLOWEmbed):
_FACTORY_NAME = "NovitaAI"
def __init__(self, key, model_name, base_url="https://api.novita.ai/v3/openai/embeddings"):
if not base_url:
base_url = "https://api.novita.ai/v3/openai/embeddings"
@ -915,7 +872,9 @@ class NovitaEmbed(SILICONFLOWEmbed):
class GiteeEmbed(SILICONFLOWEmbed):
    """GiteeAI embeddings via the SiliconFlow-compatible endpoint."""

    _FACTORY_NAME = "GiteeAI"

    def __init__(self, key, model_name, base_url="https://ai.gitee.com/v1/embeddings"):
        if not base_url:
            base_url = "https://ai.gitee.com/v1/embeddings"
        # Fix: super().__init__ was called twice (duplicated line), initializing
        # the parent class redundantly.
        super().__init__(key, model_name, base_url)