Mirror of https://github.com/infiniflow/ragflow.git, synced 2025-12-08 20:42:30 +08:00
Refa: automatic LLMs registration (#8651)
### What problem does this PR solve?

Support automatic LLMs registration.

### Type of change

- [x] Refactoring
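The mechanism behind the "automatic registration": instead of appending every new provider to a hand-maintained factory table, each model class now declares a `_FACTORY_NAME` class attribute (a string, or a list of strings when one class serves several factories, as with `OpenAI_APIEmbed` in the diff below), and the module can build its lookup table by scanning for that attribute. A minimal sketch of the idea, assuming a registry keyed by factory name; the `collect_factories` helper and the stub classes are illustrative, not the committed implementation:

```python
import inspect
import sys


class Base:
    pass


class OpenAIEmbed(Base):
    _FACTORY_NAME = "OpenAI"


class OpenAI_APIEmbed(OpenAIEmbed):
    # One class may back several factories, hence a list is also accepted.
    _FACTORY_NAME = ["VLLM", "OpenAI-API-Compatible"]


def collect_factories(module_name: str) -> dict[str, type]:
    """Map each declared factory name to its class (illustrative helper)."""
    registry: dict[str, type] = {}
    for _, cls in inspect.getmembers(sys.modules[module_name], inspect.isclass):
        # Read from cls.__dict__ so a subclass without its own _FACTORY_NAME
        # does not re-register under its parent's factory name.
        names = cls.__dict__.get("_FACTORY_NAME")
        if not names:
            continue
        for name in names if isinstance(names, list) else [names]:
            registry[name] = cls
    return registry


EmbeddingModel = collect_factories(__name__)
assert EmbeddingModel["VLLM"] is OpenAI_APIEmbed  # lookup by factory name
```

Adding a provider then reduces to defining its class with a `_FACTORY_NAME`, which is why the diff below consists almost entirely of one added attribute per class plus mechanical reformatting.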
```diff
@@ -13,28 +13,27 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import json
+import logging
+import os
+import re
+import threading
+from abc import ABC
+from urllib.parse import urljoin
+
+import dashscope
+import google.generativeai as genai
+import numpy as np
+import requests
+from huggingface_hub import snapshot_download
+from zhipuai import ZhipuAI
-import os
-from abc import ABC
 from ollama import Client
-import dashscope
 from openai import OpenAI
-import numpy as np
-import asyncio
-from zhipuai import ZhipuAI

 from api import settings
 from api.utils.file_utils import get_home_cache_dir
 from api.utils.log_utils import log_exception
 from rag.utils import num_tokens_from_string, truncate
-import google.generativeai as genai
-import json


 class Base(ABC):
@@ -60,7 +59,8 @@ class Base(ABC):


 class DefaultEmbedding(Base):
-    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+    _FACTORY_NAME = "BAAI"
+    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
     _model = None
     _model_name = ""
     _model_lock = threading.Lock()
@@ -79,21 +79,22 @@ class DefaultEmbedding(Base):
         """
         if not settings.LIGHTEN:
             with DefaultEmbedding._model_lock:
-                from FlagEmbedding import FlagModel
                 import torch
+                from FlagEmbedding import FlagModel
+
                 if not DefaultEmbedding._model or model_name != DefaultEmbedding._model_name:
                     try:
-                        DefaultEmbedding._model = FlagModel(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)),
-                                                            query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
-                                                            use_fp16=torch.cuda.is_available())
+                        DefaultEmbedding._model = FlagModel(
+                            os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)),
+                            query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
+                            use_fp16=torch.cuda.is_available(),
+                        )
                         DefaultEmbedding._model_name = model_name
                     except Exception:
-                        model_dir = snapshot_download(repo_id="BAAI/bge-large-zh-v1.5",
-                                                      local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)),
-                                                      local_dir_use_symlinks=False)
-                        DefaultEmbedding._model = FlagModel(model_dir,
-                                                            query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
-                                                            use_fp16=torch.cuda.is_available())
+                        model_dir = snapshot_download(
+                            repo_id="BAAI/bge-large-zh-v1.5", local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), local_dir_use_symlinks=False
+                        )
+                        DefaultEmbedding._model = FlagModel(model_dir, query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:", use_fp16=torch.cuda.is_available())
         self._model = DefaultEmbedding._model
         self._model_name = DefaultEmbedding._model_name

@@ -105,7 +106,7 @@ class DefaultEmbedding(Base):
             token_count += num_tokens_from_string(t)
         ress = []
         for i in range(0, len(texts), batch_size):
-            ress.extend(self._model.encode(texts[i:i + batch_size]).tolist())
+            ress.extend(self._model.encode(texts[i : i + batch_size]).tolist())
         return np.array(ress), token_count

     def encode_queries(self, text: str):
@@ -114,8 +115,9 @@ class DefaultEmbedding(Base):


 class OpenAIEmbed(Base):
-    def __init__(self, key, model_name="text-embedding-ada-002",
-                 base_url="https://api.openai.com/v1"):
+    _FACTORY_NAME = "OpenAI"
+
+    def __init__(self, key, model_name="text-embedding-ada-002", base_url="https://api.openai.com/v1"):
         if not base_url:
             base_url = "https://api.openai.com/v1"
         self.client = OpenAI(api_key=key, base_url=base_url)
@@ -128,8 +130,7 @@ class OpenAIEmbed(Base):
         ress = []
         total_tokens = 0
         for i in range(0, len(texts), batch_size):
-            res = self.client.embeddings.create(input=texts[i:i + batch_size],
-                                                model=self.model_name)
+            res = self.client.embeddings.create(input=texts[i : i + batch_size], model=self.model_name)
             try:
                 ress.extend([d.embedding for d in res.data])
                 total_tokens += self.total_token_count(res)
@@ -138,12 +139,13 @@ class OpenAIEmbed(Base):
         return np.array(ress), total_tokens

     def encode_queries(self, text):
-        res = self.client.embeddings.create(input=[truncate(text, 8191)],
-                                            model=self.model_name)
+        res = self.client.embeddings.create(input=[truncate(text, 8191)], model=self.model_name)
         return np.array(res.data[0].embedding), self.total_token_count(res)


 class LocalAIEmbed(Base):
+    _FACTORY_NAME = "LocalAI"
+
     def __init__(self, key, model_name, base_url):
         if not base_url:
             raise ValueError("Local embedding model url cannot be None")
@@ -155,7 +157,7 @@ class LocalAIEmbed(Base):
         batch_size = 16
         ress = []
         for i in range(0, len(texts), batch_size):
-            res = self.client.embeddings.create(input=texts[i:i + batch_size], model=self.model_name)
+            res = self.client.embeddings.create(input=texts[i : i + batch_size], model=self.model_name)
             try:
                 ress.extend([d.embedding for d in res.data])
             except Exception as _e:
@@ -169,41 +171,42 @@ class LocalAIEmbed(Base):


 class AzureEmbed(OpenAIEmbed):
+    _FACTORY_NAME = "Azure-OpenAI"
+
     def __init__(self, key, model_name, **kwargs):
         from openai.lib.azure import AzureOpenAI
-        api_key = json.loads(key).get('api_key', '')
-        api_version = json.loads(key).get('api_version', '2024-02-01')
+
+        api_key = json.loads(key).get("api_key", "")
+        api_version = json.loads(key).get("api_version", "2024-02-01")
         self.client = AzureOpenAI(api_key=api_key, azure_endpoint=kwargs["base_url"], api_version=api_version)
         self.model_name = model_name


 class BaiChuanEmbed(OpenAIEmbed):
-    def __init__(self, key,
-                 model_name='Baichuan-Text-Embedding',
-                 base_url='https://api.baichuan-ai.com/v1'):
+    _FACTORY_NAME = "BaiChuan"
+
+    def __init__(self, key, model_name="Baichuan-Text-Embedding", base_url="https://api.baichuan-ai.com/v1"):
         if not base_url:
             base_url = "https://api.baichuan-ai.com/v1"
         super().__init__(key, model_name, base_url)


 class QWenEmbed(Base):
+    _FACTORY_NAME = "Tongyi-Qianwen"
+
     def __init__(self, key, model_name="text_embedding_v2", **kwargs):
         self.key = key
         self.model_name = model_name

     def encode(self, texts: list):
         import dashscope
+
         batch_size = 4
         res = []
         token_count = 0
         texts = [truncate(t, 2048) for t in texts]
         for i in range(0, len(texts), batch_size):
-            resp = dashscope.TextEmbedding.call(
-                model=self.model_name,
-                input=texts[i:i + batch_size],
-                api_key=self.key,
-                text_type="document"
-            )
+            resp = dashscope.TextEmbedding.call(model=self.model_name, input=texts[i : i + batch_size], api_key=self.key, text_type="document")
             try:
                 embds = [[] for _ in range(len(resp["output"]["embeddings"]))]
                 for e in resp["output"]["embeddings"]:
@@ -216,20 +219,16 @@ class QWenEmbed(Base):
         return np.array(res), token_count

     def encode_queries(self, text):
-        resp = dashscope.TextEmbedding.call(
-            model=self.model_name,
-            input=text[:2048],
-            api_key=self.key,
-            text_type="query"
-        )
+        resp = dashscope.TextEmbedding.call(model=self.model_name, input=text[:2048], api_key=self.key, text_type="query")
         try:
-            return np.array(resp["output"]["embeddings"][0]
-                            ["embedding"]), self.total_token_count(resp)
+            return np.array(resp["output"]["embeddings"][0]["embedding"]), self.total_token_count(resp)
         except Exception as _e:
             log_exception(_e, resp)


 class ZhipuEmbed(Base):
+    _FACTORY_NAME = "ZHIPU-AI"
+
     def __init__(self, key, model_name="embedding-2", **kwargs):
         self.client = ZhipuAI(api_key=key)
         self.model_name = model_name
@@ -246,8 +245,7 @@ class ZhipuEmbed(Base):
         texts = [truncate(t, MAX_LEN) for t in texts]

         for txt in texts:
-            res = self.client.embeddings.create(input=txt,
-                                                model=self.model_name)
+            res = self.client.embeddings.create(input=txt, model=self.model_name)
             try:
                 arr.append(res.data[0].embedding)
                 tks_num += self.total_token_count(res)
@@ -256,8 +254,7 @@ class ZhipuEmbed(Base):
         return np.array(arr), tks_num

     def encode_queries(self, text):
-        res = self.client.embeddings.create(input=text,
-                                            model=self.model_name)
+        res = self.client.embeddings.create(input=text, model=self.model_name)
         try:
             return np.array(res.data[0].embedding), self.total_token_count(res)
         except Exception as _e:
@@ -265,18 +262,17 @@ class ZhipuEmbed(Base):


 class OllamaEmbed(Base):
+    _FACTORY_NAME = "Ollama"
+
     def __init__(self, key, model_name, **kwargs):
-        self.client = Client(host=kwargs["base_url"]) if not key or key == "x" else \
-            Client(host=kwargs["base_url"], headers={"Authorization": f"Bear {key}"})
+        self.client = Client(host=kwargs["base_url"]) if not key or key == "x" else Client(host=kwargs["base_url"], headers={"Authorization": f"Bear {key}"})
         self.model_name = model_name

     def encode(self, texts: list):
         arr = []
         tks_num = 0
         for txt in texts:
-            res = self.client.embeddings(prompt=txt,
-                                         model=self.model_name,
-                                         options={"use_mmap": True})
+            res = self.client.embeddings(prompt=txt, model=self.model_name, options={"use_mmap": True})
             try:
                 arr.append(res["embedding"])
             except Exception as _e:
@@ -285,9 +281,7 @@ class OllamaEmbed(Base):
         return np.array(arr), tks_num

     def encode_queries(self, text):
-        res = self.client.embeddings(prompt=text,
-                                     model=self.model_name,
-                                     options={"use_mmap": True})
+        res = self.client.embeddings(prompt=text, model=self.model_name, options={"use_mmap": True})
         try:
             return np.array(res["embedding"]), 128
         except Exception as _e:
@@ -295,27 +289,28 @@ class OllamaEmbed(Base):


 class FastEmbed(DefaultEmbedding):
-
+    _FACTORY_NAME = "FastEmbed"
+
     def __init__(
-            self,
-            key: str | None = None,
-            model_name: str = "BAAI/bge-small-en-v1.5",
-            cache_dir: str | None = None,
-            threads: int | None = None,
-            **kwargs,
+        self,
+        key: str | None = None,
+        model_name: str = "BAAI/bge-small-en-v1.5",
+        cache_dir: str | None = None,
+        threads: int | None = None,
+        **kwargs,
     ):
         if not settings.LIGHTEN:
             with FastEmbed._model_lock:
                 from fastembed import TextEmbedding

                 if not DefaultEmbedding._model or model_name != DefaultEmbedding._model_name:
                     try:
                         DefaultEmbedding._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
                         DefaultEmbedding._model_name = model_name
                     except Exception:
-                        cache_dir = snapshot_download(repo_id="BAAI/bge-small-en-v1.5",
-                                                      local_dir=os.path.join(get_home_cache_dir(),
-                                                                             re.sub(r"^[a-zA-Z0-9]+/", "", model_name)),
-                                                      local_dir_use_symlinks=False)
+                        cache_dir = snapshot_download(
+                            repo_id="BAAI/bge-small-en-v1.5", local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), local_dir_use_symlinks=False
+                        )
                         DefaultEmbedding._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
         self._model = DefaultEmbedding._model
         self._model_name = model_name
@@ -340,6 +335,8 @@ class FastEmbed(DefaultEmbedding):


 class XinferenceEmbed(Base):
+    _FACTORY_NAME = "Xinference"
+
     def __init__(self, key, model_name="", base_url=""):
         base_url = urljoin(base_url, "v1")
         self.client = OpenAI(api_key=key, base_url=base_url)
@@ -350,7 +347,7 @@ class XinferenceEmbed(Base):
         ress = []
         total_tokens = 0
         for i in range(0, len(texts), batch_size):
-            res = self.client.embeddings.create(input=texts[i:i + batch_size], model=self.model_name)
+            res = self.client.embeddings.create(input=texts[i : i + batch_size], model=self.model_name)
             try:
                 ress.extend([d.embedding for d in res.data])
                 total_tokens += self.total_token_count(res)
@@ -359,8 +356,7 @@ class XinferenceEmbed(Base):
         return np.array(ress), total_tokens

     def encode_queries(self, text):
-        res = self.client.embeddings.create(input=[text],
-                                            model=self.model_name)
+        res = self.client.embeddings.create(input=[text], model=self.model_name)
         try:
             return np.array(res.data[0].embedding), self.total_token_count(res)
         except Exception as _e:
@@ -368,20 +364,18 @@ class XinferenceEmbed(Base):


 class YoudaoEmbed(Base):
+    _FACTORY_NAME = "Youdao"
     _client = None

     def __init__(self, key=None, model_name="maidalun1020/bce-embedding-base_v1", **kwargs):
         if not settings.LIGHTEN and not YoudaoEmbed._client:
             from BCEmbedding import EmbeddingModel as qanthing
+
             try:
                 logging.info("LOADING BCE...")
-                YoudaoEmbed._client = qanthing(model_name_or_path=os.path.join(
-                    get_home_cache_dir(),
-                    "bce-embedding-base_v1"))
+                YoudaoEmbed._client = qanthing(model_name_or_path=os.path.join(get_home_cache_dir(), "bce-embedding-base_v1"))
             except Exception:
-                YoudaoEmbed._client = qanthing(
-                    model_name_or_path=model_name.replace(
-                        "maidalun1020", "InfiniFlow"))
+                YoudaoEmbed._client = qanthing(model_name_or_path=model_name.replace("maidalun1020", "InfiniFlow"))

     def encode(self, texts: list):
         batch_size = 10
@@ -390,7 +384,7 @@ class YoudaoEmbed(Base):
         for t in texts:
             token_count += num_tokens_from_string(t)
         for i in range(0, len(texts), batch_size):
-            embds = YoudaoEmbed._client.encode(texts[i:i + batch_size])
+            embds = YoudaoEmbed._client.encode(texts[i : i + batch_size])
             res.extend(embds)
         return np.array(res), token_count

@@ -400,14 +394,11 @@ class YoudaoEmbed(Base):


 class JinaEmbed(Base):
-    def __init__(self, key, model_name="jina-embeddings-v3",
-                 base_url="https://api.jina.ai/v1/embeddings"):
+    _FACTORY_NAME = "Jina"
+
+    def __init__(self, key, model_name="jina-embeddings-v3", base_url="https://api.jina.ai/v1/embeddings"):
         self.base_url = "https://api.jina.ai/v1/embeddings"
-        self.headers = {
-            "Content-Type": "application/json",
-            "Authorization": f"Bearer {key}"
-        }
+        self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
         self.model_name = model_name

     def encode(self, texts: list):
@@ -416,11 +407,7 @@ class JinaEmbed(Base):
         ress = []
         token_count = 0
         for i in range(0, len(texts), batch_size):
-            data = {
-                "model": self.model_name,
-                "input": texts[i:i + batch_size],
-                'encoding_type': 'float'
-            }
+            data = {"model": self.model_name, "input": texts[i : i + batch_size], "encoding_type": "float"}
             response = requests.post(self.base_url, headers=self.headers, json=data)
             try:
                 res = response.json()
@@ -435,50 +422,12 @@ class JinaEmbed(Base):
         return np.array(embds[0]), cnt


-class InfinityEmbed(Base):
-    _model = None
-
-    def __init__(
-        self,
-        model_names: list[str] = ("BAAI/bge-small-en-v1.5",),
-        engine_kwargs: dict = {},
-        key = None,
-    ):
-
-        from infinity_emb import EngineArgs
-        from infinity_emb.engine import AsyncEngineArray
-
-        self._default_model = model_names[0]
-        self.engine_array = AsyncEngineArray.from_args([EngineArgs(model_name_or_path = model_name, **engine_kwargs) for model_name in model_names])
-
-    async def _embed(self, sentences: list[str], model_name: str = ""):
-        if not model_name:
-            model_name = self._default_model
-        engine = self.engine_array[model_name]
-        was_already_running = engine.is_running
-        if not was_already_running:
-            await engine.astart()
-        embeddings, usage = await engine.embed(sentences=sentences)
-        if not was_already_running:
-            await engine.astop()
-        return embeddings, usage
-
-    def encode(self, texts: list[str], model_name: str = "") -> tuple[np.ndarray, int]:
-        # Using the internal tokenizer to encode the texts and get the total
-        # number of tokens
-        embeddings, usage = asyncio.run(self._embed(texts, model_name))
-        return np.array(embeddings), usage
-
-    def encode_queries(self, text: str) -> tuple[np.ndarray, int]:
-        # Using the internal tokenizer to encode the texts and get the total
-        # number of tokens
-        return self.encode([text])
-
-
 class MistralEmbed(Base):
-    def __init__(self, key, model_name="mistral-embed",
-                 base_url=None):
+    _FACTORY_NAME = "Mistral"
+
+    def __init__(self, key, model_name="mistral-embed", base_url=None):
         from mistralai.client import MistralClient
+
         self.client = MistralClient(api_key=key)
         self.model_name = model_name

@@ -488,8 +437,7 @@ class MistralEmbed(Base):
         ress = []
         token_count = 0
         for i in range(0, len(texts), batch_size):
-            res = self.client.embeddings(input=texts[i:i + batch_size],
-                                         model=self.model_name)
+            res = self.client.embeddings(input=texts[i : i + batch_size], model=self.model_name)
             try:
                 ress.extend([d.embedding for d in res.data])
                 token_count += self.total_token_count(res)
@@ -498,8 +446,7 @@ class MistralEmbed(Base):
         return np.array(ress), token_count

     def encode_queries(self, text):
-        res = self.client.embeddings(input=[truncate(text, 8196)],
-                                     model=self.model_name)
+        res = self.client.embeddings(input=[truncate(text, 8196)], model=self.model_name)
         try:
             return np.array(res.data[0].embedding), self.total_token_count(res)
         except Exception as _e:
@@ -507,30 +454,31 @@ class MistralEmbed(Base):


 class BedrockEmbed(Base):
-    def __init__(self, key, model_name,
-                 **kwargs):
+    _FACTORY_NAME = "Bedrock"
+
+    def __init__(self, key, model_name, **kwargs):
         import boto3
-        self.bedrock_ak = json.loads(key).get('bedrock_ak', '')
-        self.bedrock_sk = json.loads(key).get('bedrock_sk', '')
-        self.bedrock_region = json.loads(key).get('bedrock_region', '')
+
+        self.bedrock_ak = json.loads(key).get("bedrock_ak", "")
+        self.bedrock_sk = json.loads(key).get("bedrock_sk", "")
+        self.bedrock_region = json.loads(key).get("bedrock_region", "")
         self.model_name = model_name
-
-        if self.bedrock_ak == '' or self.bedrock_sk == '' or self.bedrock_region == '':
+
+        if self.bedrock_ak == "" or self.bedrock_sk == "" or self.bedrock_region == "":
             # Try to create a client using the default credentials (AWS_PROFILE, AWS_DEFAULT_REGION, etc.)
-            self.client = boto3.client('bedrock-runtime')
+            self.client = boto3.client("bedrock-runtime")
         else:
-            self.client = boto3.client(service_name='bedrock-runtime', region_name=self.bedrock_region,
-                                       aws_access_key_id=self.bedrock_ak, aws_secret_access_key=self.bedrock_sk)
+            self.client = boto3.client(service_name="bedrock-runtime", region_name=self.bedrock_region, aws_access_key_id=self.bedrock_ak, aws_secret_access_key=self.bedrock_sk)

     def encode(self, texts: list):
         texts = [truncate(t, 8196) for t in texts]
         embeddings = []
         token_count = 0
         for text in texts:
-            if self.model_name.split('.')[0] == 'amazon':
+            if self.model_name.split(".")[0] == "amazon":
                 body = {"inputText": text}
-            elif self.model_name.split('.')[0] == 'cohere':
-                body = {"texts": [text], "input_type": 'search_document'}
+            elif self.model_name.split(".")[0] == "cohere":
+                body = {"texts": [text], "input_type": "search_document"}

             response = self.client.invoke_model(modelId=self.model_name, body=json.dumps(body))
             try:
@@ -545,10 +493,10 @@ class BedrockEmbed(Base):
     def encode_queries(self, text):
         embeddings = []
         token_count = num_tokens_from_string(text)
-        if self.model_name.split('.')[0] == 'amazon':
+        if self.model_name.split(".")[0] == "amazon":
             body = {"inputText": truncate(text, 8196)}
-        elif self.model_name.split('.')[0] == 'cohere':
-            body = {"texts": [truncate(text, 8196)], "input_type": 'search_query'}
+        elif self.model_name.split(".")[0] == "cohere":
+            body = {"texts": [truncate(text, 8196)], "input_type": "search_query"}

         response = self.client.invoke_model(modelId=self.model_name, body=json.dumps(body))
         try:
@@ -561,11 +509,12 @@ class BedrockEmbed(Base):


 class GeminiEmbed(Base):
-    def __init__(self, key, model_name='models/text-embedding-004',
-                 **kwargs):
+    _FACTORY_NAME = "Gemini"
+
+    def __init__(self, key, model_name="models/text-embedding-004", **kwargs):
         self.key = key
-        self.model_name = 'models/' + model_name
-
+        self.model_name = "models/" + model_name
+
     def encode(self, texts: list):
         texts = [truncate(t, 2048) for t in texts]
         token_count = sum(num_tokens_from_string(text) for text in texts)
@@ -573,35 +522,27 @@ class GeminiEmbed(Base):
         batch_size = 16
         ress = []
         for i in range(0, len(texts), batch_size):
-            result = genai.embed_content(
-                model=self.model_name,
-                content=texts[i: i + batch_size],
-                task_type="retrieval_document",
-                title="Embedding of single string")
+            result = genai.embed_content(model=self.model_name, content=texts[i : i + batch_size], task_type="retrieval_document", title="Embedding of single string")
             try:
-                ress.extend(result['embedding'])
+                ress.extend(result["embedding"])
             except Exception as _e:
                 log_exception(_e, result)
-        return np.array(ress),token_count
-
+        return np.array(ress), token_count
+
     def encode_queries(self, text):
         genai.configure(api_key=self.key)
-        result = genai.embed_content(
-            model=self.model_name,
-            content=truncate(text,2048),
-            task_type="retrieval_document",
-            title="Embedding of single string")
+        result = genai.embed_content(model=self.model_name, content=truncate(text, 2048), task_type="retrieval_document", title="Embedding of single string")
         token_count = num_tokens_from_string(text)
         try:
-            return np.array(result['embedding']), token_count
+            return np.array(result["embedding"]), token_count
         except Exception as _e:
             log_exception(_e, result)


 class NvidiaEmbed(Base):
-    def __init__(
-        self, key, model_name, base_url="https://integrate.api.nvidia.com/v1/embeddings"
-    ):
+    _FACTORY_NAME = "NVIDIA"
+
+    def __init__(self, key, model_name, base_url="https://integrate.api.nvidia.com/v1/embeddings"):
         if not base_url:
             base_url = "https://integrate.api.nvidia.com/v1/embeddings"
         self.api_key = key
@@ -645,6 +586,8 @@ class NvidiaEmbed(Base):


 class LmStudioEmbed(LocalAIEmbed):
+    _FACTORY_NAME = "LM-Studio"
+
     def __init__(self, key, model_name, base_url):
         if not base_url:
             raise ValueError("Local llm url cannot be None")
@@ -654,6 +597,8 @@ class LmStudioEmbed(LocalAIEmbed):


 class OpenAI_APIEmbed(OpenAIEmbed):
+    _FACTORY_NAME = ["VLLM", "OpenAI-API-Compatible"]
+
     def __init__(self, key, model_name, base_url):
         if not base_url:
             raise ValueError("url cannot be None")
@@ -663,6 +608,8 @@ class OpenAI_APIEmbed(OpenAIEmbed):


 class CoHereEmbed(Base):
+    _FACTORY_NAME = "Cohere"
+
     def __init__(self, key, model_name, base_url=None):
         from cohere import Client

@@ -701,6 +648,8 @@ class CoHereEmbed(Base):


 class TogetherAIEmbed(OpenAIEmbed):
+    _FACTORY_NAME = "TogetherAI"
+
     def __init__(self, key, model_name, base_url="https://api.together.xyz/v1"):
         if not base_url:
             base_url = "https://api.together.xyz/v1"
@@ -708,6 +657,8 @@ class TogetherAIEmbed(OpenAIEmbed):


 class PerfXCloudEmbed(OpenAIEmbed):
+    _FACTORY_NAME = "PerfXCloud"
+
     def __init__(self, key, model_name, base_url="https://cloud.perfxlab.cn/v1"):
         if not base_url:
             base_url = "https://cloud.perfxlab.cn/v1"
@@ -715,6 +666,8 @@ class PerfXCloudEmbed(OpenAIEmbed):


 class UpstageEmbed(OpenAIEmbed):
+    _FACTORY_NAME = "Upstage"
+
     def __init__(self, key, model_name, base_url="https://api.upstage.ai/v1/solar"):
         if not base_url:
             base_url = "https://api.upstage.ai/v1/solar"
@@ -722,6 +675,8 @@ class UpstageEmbed(OpenAIEmbed):


 class SILICONFLOWEmbed(Base):
+    _FACTORY_NAME = "SILICONFLOW"
+
     def __init__(self, key, model_name, base_url="https://api.siliconflow.cn/v1/embeddings"):
         if not base_url:
             base_url = "https://api.siliconflow.cn/v1/embeddings"
@@ -769,6 +724,8 @@ class SILICONFLOWEmbed(Base):


 class ReplicateEmbed(Base):
+    _FACTORY_NAME = "Replicate"
+
     def __init__(self, key, model_name, base_url=None):
         from replicate.client import Client

@@ -790,6 +747,8 @@ class ReplicateEmbed(Base):


 class BaiduYiyanEmbed(Base):
+    _FACTORY_NAME = "BaiduYiyan"
+
     def __init__(self, key, model_name, base_url=None):
         import qianfan

@@ -821,6 +780,8 @@ class BaiduYiyanEmbed(Base):


 class VoyageEmbed(Base):
+    _FACTORY_NAME = "Voyage AI"
+
     def __init__(self, key, model_name, base_url=None):
         import voyageai

@@ -832,9 +793,7 @@ class VoyageEmbed(Base):
         ress = []
         token_count = 0
         for i in range(0, len(texts), batch_size):
-            res = self.client.embed(
-                texts=texts[i : i + batch_size], model=self.model_name, input_type="document"
-            )
+            res = self.client.embed(texts=texts[i : i + batch_size], model=self.model_name, input_type="document")
             try:
                 ress.extend(res.embeddings)
                 token_count += res.total_tokens
@@ -843,9 +802,7 @@ class VoyageEmbed(Base):
         return np.array(ress), token_count

     def encode_queries(self, text):
-        res = self.client.embed(
-            texts=text, model=self.model_name, input_type="query"
-        )
+        res = self.client.embed(texts=text, model=self.model_name, input_type="query")
         try:
             return np.array(res.embeddings)[0], res.total_tokens
         except Exception as _e:
@@ -853,6 +810,8 @@ class VoyageEmbed(Base):


 class HuggingFaceEmbed(Base):
+    _FACTORY_NAME = "HuggingFace"
+
     def __init__(self, key, model_name, base_url=None):
         if not model_name:
             raise ValueError("Model name cannot be None")
@@ -863,11 +822,7 @@ class HuggingFaceEmbed(Base):
     def encode(self, texts: list):
         embeddings = []
         for text in texts:
-            response = requests.post(
-                f"{self.base_url}/embed",
-                json={"inputs": text},
-                headers={'Content-Type': 'application/json'}
-            )
+            response = requests.post(f"{self.base_url}/embed", json={"inputs": text}, headers={"Content-Type": "application/json"})
             if response.status_code == 200:
                 embedding = response.json()
                 embeddings.append(embedding[0])
@@ -876,11 +831,7 @@ class HuggingFaceEmbed(Base):
         return np.array(embeddings), sum([num_tokens_from_string(text) for text in texts])

     def encode_queries(self, text):
-        response = requests.post(
-            f"{self.base_url}/embed",
-            json={"inputs": text},
-            headers={'Content-Type': 'application/json'}
-        )
+        response = requests.post(f"{self.base_url}/embed", json={"inputs": text}, headers={"Content-Type": "application/json"})
         if response.status_code == 200:
             embedding = response.json()
             return np.array(embedding[0]), num_tokens_from_string(text)
@@ -889,15 +840,19 @@ class HuggingFaceEmbed(Base):


 class VolcEngineEmbed(OpenAIEmbed):
+    _FACTORY_NAME = "VolcEngine"
+
     def __init__(self, key, model_name, base_url="https://ark.cn-beijing.volces.com/api/v3"):
         if not base_url:
             base_url = "https://ark.cn-beijing.volces.com/api/v3"
-        ark_api_key = json.loads(key).get('ark_api_key', '')
-        model_name = json.loads(key).get('ep_id', '') + json.loads(key).get('endpoint_id', '')
-        super().__init__(ark_api_key,model_name,base_url)
+        ark_api_key = json.loads(key).get("ark_api_key", "")
+        model_name = json.loads(key).get("ep_id", "") + json.loads(key).get("endpoint_id", "")
+        super().__init__(ark_api_key, model_name, base_url)


 class GPUStackEmbed(OpenAIEmbed):
+    _FACTORY_NAME = "GPUStack"
+
     def __init__(self, key, model_name, base_url):
         if not base_url:
             raise ValueError("url cannot be None")
@@ -908,6 +863,8 @@ class GPUStackEmbed(OpenAIEmbed):


 class NovitaEmbed(SILICONFLOWEmbed):
+    _FACTORY_NAME = "NovitaAI"
+
     def __init__(self, key, model_name, base_url="https://api.novita.ai/v3/openai/embeddings"):
         if not base_url:
             base_url = "https://api.novita.ai/v3/openai/embeddings"
@@ -915,7 +872,9 @@ class NovitaEmbed(SILICONFLOWEmbed):


 class GiteeEmbed(SILICONFLOWEmbed):
+    _FACTORY_NAME = "GiteeAI"
+
     def __init__(self, key, model_name, base_url="https://ai.gitee.com/v1/embeddings"):
         if not base_url:
             base_url = "https://ai.gitee.com/v1/embeddings"
-        super().__init__(key, model_name, base_url)
+        super().__init__(key, model_name, base_url)
```