Refa: automatic LLMs registration (#8651)

### What problem does this PR solve?

Support automatic registration of LLM providers. Each model class now declares the factory name(s) it serves through a `_FACTORY_NAME` class attribute, and the package `__init__` builds the `ChatModel`, `CvModel`, `EmbeddingModel`, `RerankModel`, `Seq2txtModel`, and `TTSModel` mappings by scanning the submodules, replacing the hand-maintained import lists and dictionaries.
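A minimal, self-contained sketch of the mechanism (the class and factory names mirror entries in the diff below; the real code scans each submodule with `importlib` and `inspect`, while this sketch scans its own namespace):

```python
import inspect


class Base:
    """Stand-in for the per-module Base class each provider inherits from."""


class GptTurbo(Base):
    _FACTORY_NAME = "OpenAI"


class OpenAI_APIChat(Base):
    # A list value registers the same class under several factory names.
    _FACTORY_NAME = ["VLLM", "OpenAI-API-Compatible"]


ChatModel = {}

# The real __init__.py walks each submodule with importlib and
# inspect.getmembers; here we scan this module's own namespace instead.
for _name, obj in list(globals().items()):
    if inspect.isclass(obj) and issubclass(obj, Base) and obj is not Base and hasattr(obj, "_FACTORY_NAME"):
        names = obj._FACTORY_NAME if isinstance(obj._FACTORY_NAME, list) else [obj._FACTORY_NAME]
        for factory_name in names:
            ChatModel[factory_name] = obj

assert ChatModel["OpenAI"] is GptTurbo
assert ChatModel["VLLM"] is ChatModel["OpenAI-API-Compatible"] is OpenAI_APIChat
```

A list-valued `_FACTORY_NAME` registers one implementation under several factory names, which is how `OpenAI_APIChat` and `OpenAI_APICV` serve both "VLLM" and "OpenAI-API-Compatible".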

### Type of change

- [x] Refactoring
Yongteng Lei
2025-07-03 19:05:31 +08:00
committed by GitHub
parent 3234a15aae
commit f8a6987f1e
7 changed files with 619 additions and 876 deletions

View File

@@ -15,289 +15,53 @@
 #
 # AFTER UPDATING THIS FILE, PLEASE ENSURE THAT docs/references/supported_models.mdx IS ALSO UPDATED for consistency!
 #
-from .embedding_model import (
-    OllamaEmbed,
-    LocalAIEmbed,
-    OpenAIEmbed,
-    AzureEmbed,
-    XinferenceEmbed,
-    QWenEmbed,
-    ZhipuEmbed,
-    FastEmbed,
-    YoudaoEmbed,
-    BaiChuanEmbed,
-    JinaEmbed,
-    DefaultEmbedding,
-    MistralEmbed,
-    BedrockEmbed,
-    GeminiEmbed,
-    NvidiaEmbed,
-    LmStudioEmbed,
-    OpenAI_APIEmbed,
-    CoHereEmbed,
-    TogetherAIEmbed,
-    PerfXCloudEmbed,
-    UpstageEmbed,
-    SILICONFLOWEmbed,
-    ReplicateEmbed,
-    BaiduYiyanEmbed,
-    VoyageEmbed,
-    HuggingFaceEmbed,
-    VolcEngineEmbed,
-    GPUStackEmbed,
-    NovitaEmbed,
-    GiteeEmbed
-)
-from .chat_model import (
-    GptTurbo,
-    AzureChat,
-    ZhipuChat,
-    QWenChat,
-    OllamaChat,
-    LocalAIChat,
-    XinferenceChat,
-    MoonshotChat,
-    DeepSeekChat,
-    VolcEngineChat,
-    BaiChuanChat,
-    MiniMaxChat,
-    MistralChat,
-    GeminiChat,
-    BedrockChat,
-    GroqChat,
-    OpenRouterChat,
-    StepFunChat,
-    NvidiaChat,
-    LmStudioChat,
-    OpenAI_APIChat,
-    CoHereChat,
-    LeptonAIChat,
-    TogetherAIChat,
-    PerfXCloudChat,
-    UpstageChat,
-    NovitaAIChat,
-    SILICONFLOWChat,
-    PPIOChat,
-    YiChat,
-    ReplicateChat,
-    HunyuanChat,
-    SparkChat,
-    BaiduYiyanChat,
-    AnthropicChat,
-    GoogleChat,
-    HuggingFaceChat,
-    GPUStackChat,
-    ModelScopeChat,
-    GiteeChat
-)
-from .cv_model import (
-    GptV4,
-    AzureGptV4,
-    OllamaCV,
-    XinferenceCV,
-    QWenCV,
-    Zhipu4V,
-    LocalCV,
-    GeminiCV,
-    OpenRouterCV,
-    LocalAICV,
-    NvidiaCV,
-    LmStudioCV,
-    StepFunCV,
-    OpenAI_APICV,
-    TogetherAICV,
-    YiCV,
-    HunyuanCV,
-    AnthropicCV,
-    SILICONFLOWCV,
-    GPUStackCV,
-    GoogleCV,
-)
-from .rerank_model import (
-    LocalAIRerank,
-    DefaultRerank,
-    JinaRerank,
-    YoudaoRerank,
-    XInferenceRerank,
-    NvidiaRerank,
-    LmStudioRerank,
-    OpenAI_APIRerank,
-    CoHereRerank,
-    TogetherAIRerank,
-    SILICONFLOWRerank,
-    BaiduYiyanRerank,
-    VoyageRerank,
-    QWenRerank,
-    GPUStackRerank,
-    HuggingfaceRerank,
-    NovitaRerank,
-    GiteeRerank
-)
-from .sequence2txt_model import (
-    GPTSeq2txt,
-    QWenSeq2txt,
-    AzureSeq2txt,
-    XinferenceSeq2txt,
-    TencentCloudSeq2txt,
-    GPUStackSeq2txt,
-    GiteeSeq2txt
-)
-from .tts_model import (
-    FishAudioTTS,
-    QwenTTS,
-    OpenAITTS,
-    SparkTTS,
-    XinferenceTTS,
-    GPUStackTTS,
-    SILICONFLOWTTS,
-)
-
-EmbeddingModel = {
-    "Ollama": OllamaEmbed,
-    "LocalAI": LocalAIEmbed,
-    "OpenAI": OpenAIEmbed,
-    "Azure-OpenAI": AzureEmbed,
-    "Xinference": XinferenceEmbed,
-    "Tongyi-Qianwen": QWenEmbed,
-    "ZHIPU-AI": ZhipuEmbed,
-    "FastEmbed": FastEmbed,
-    "Youdao": YoudaoEmbed,
-    "BaiChuan": BaiChuanEmbed,
-    "Jina": JinaEmbed,
-    "BAAI": DefaultEmbedding,
-    "Mistral": MistralEmbed,
-    "Bedrock": BedrockEmbed,
-    "Gemini": GeminiEmbed,
-    "NVIDIA": NvidiaEmbed,
-    "LM-Studio": LmStudioEmbed,
-    "OpenAI-API-Compatible": OpenAI_APIEmbed,
-    "VLLM": OpenAI_APIEmbed,
-    "Cohere": CoHereEmbed,
-    "TogetherAI": TogetherAIEmbed,
-    "PerfXCloud": PerfXCloudEmbed,
-    "Upstage": UpstageEmbed,
-    "SILICONFLOW": SILICONFLOWEmbed,
-    "Replicate": ReplicateEmbed,
-    "BaiduYiyan": BaiduYiyanEmbed,
-    "Voyage AI": VoyageEmbed,
-    "HuggingFace": HuggingFaceEmbed,
-    "VolcEngine": VolcEngineEmbed,
-    "GPUStack": GPUStackEmbed,
-    "NovitaAI": NovitaEmbed,
-    "GiteeAI": GiteeEmbed
-}
-
-CvModel = {
-    "OpenAI": GptV4,
-    "Azure-OpenAI": AzureGptV4,
-    "Ollama": OllamaCV,
-    "Xinference": XinferenceCV,
-    "Tongyi-Qianwen": QWenCV,
-    "ZHIPU-AI": Zhipu4V,
-    "Moonshot": LocalCV,
-    "Gemini": GeminiCV,
-    "OpenRouter": OpenRouterCV,
-    "LocalAI": LocalAICV,
-    "NVIDIA": NvidiaCV,
-    "LM-Studio": LmStudioCV,
-    "StepFun": StepFunCV,
-    "OpenAI-API-Compatible": OpenAI_APICV,
-    "VLLM": OpenAI_APICV,
-    "TogetherAI": TogetherAICV,
-    "01.AI": YiCV,
-    "Tencent Hunyuan": HunyuanCV,
-    "Anthropic": AnthropicCV,
-    "SILICONFLOW": SILICONFLOWCV,
-    "GPUStack": GPUStackCV,
-    "Google Cloud": GoogleCV
-}
-
-ChatModel = {
-    "OpenAI": GptTurbo,
-    "Azure-OpenAI": AzureChat,
-    "ZHIPU-AI": ZhipuChat,
-    "Tongyi-Qianwen": QWenChat,
-    "Ollama": OllamaChat,
-    "LocalAI": LocalAIChat,
-    "Xinference": XinferenceChat,
-    "Moonshot": MoonshotChat,
-    "DeepSeek": DeepSeekChat,
-    "VolcEngine": VolcEngineChat,
-    "BaiChuan": BaiChuanChat,
-    "MiniMax": MiniMaxChat,
-    "Mistral": MistralChat,
-    "Gemini": GeminiChat,
-    "Bedrock": BedrockChat,
-    "Groq": GroqChat,
-    "OpenRouter": OpenRouterChat,
-    "StepFun": StepFunChat,
-    "NVIDIA": NvidiaChat,
-    "LM-Studio": LmStudioChat,
-    "OpenAI-API-Compatible": OpenAI_APIChat,
-    "VLLM": OpenAI_APIChat,
-    "Cohere": CoHereChat,
-    "LeptonAI": LeptonAIChat,
-    "TogetherAI": TogetherAIChat,
-    "PerfXCloud": PerfXCloudChat,
-    "Upstage": UpstageChat,
-    "NovitaAI": NovitaAIChat,
-    "SILICONFLOW": SILICONFLOWChat,
-    "PPIO": PPIOChat,
-    "01.AI": YiChat,
-    "Replicate": ReplicateChat,
-    "Tencent Hunyuan": HunyuanChat,
-    "XunFei Spark": SparkChat,
-    "BaiduYiyan": BaiduYiyanChat,
-    "Anthropic": AnthropicChat,
-    "Google Cloud": GoogleChat,
-    "HuggingFace": HuggingFaceChat,
-    "GPUStack": GPUStackChat,
-    "ModelScope":ModelScopeChat,
-    "GiteeAI": GiteeChat
-}
-
-RerankModel = {
-    "LocalAI": LocalAIRerank,
-    "BAAI": DefaultRerank,
-    "Jina": JinaRerank,
-    "Youdao": YoudaoRerank,
-    "Xinference": XInferenceRerank,
-    "NVIDIA": NvidiaRerank,
-    "LM-Studio": LmStudioRerank,
-    "OpenAI-API-Compatible": OpenAI_APIRerank,
-    "VLLM": CoHereRerank,
-    "Cohere": CoHereRerank,
-    "TogetherAI": TogetherAIRerank,
-    "SILICONFLOW": SILICONFLOWRerank,
-    "BaiduYiyan": BaiduYiyanRerank,
-    "Voyage AI": VoyageRerank,
-    "Tongyi-Qianwen": QWenRerank,
-    "GPUStack": GPUStackRerank,
-    "HuggingFace": HuggingfaceRerank,
-    "NovitaAI": NovitaRerank,
-    "GiteeAI": GiteeRerank
-}
-
-Seq2txtModel = {
-    "OpenAI": GPTSeq2txt,
-    "Tongyi-Qianwen": QWenSeq2txt,
-    "Azure-OpenAI": AzureSeq2txt,
-    "Xinference": XinferenceSeq2txt,
-    "Tencent Cloud": TencentCloudSeq2txt,
-    "GPUStack": GPUStackSeq2txt,
-    "GiteeAI": GiteeSeq2txt
-}
-
-TTSModel = {
-    "Fish Audio": FishAudioTTS,
-    "Tongyi-Qianwen": QwenTTS,
-    "OpenAI": OpenAITTS,
-    "XunFei Spark": SparkTTS,
-    "Xinference": XinferenceTTS,
-    "GPUStack": GPUStackTTS,
-    "SILICONFLOW": SILICONFLOWTTS,
-}
+import importlib
+import inspect
+
+ChatModel = globals().get("ChatModel", {})
+CvModel = globals().get("CvModel", {})
+EmbeddingModel = globals().get("EmbeddingModel", {})
+RerankModel = globals().get("RerankModel", {})
+Seq2txtModel = globals().get("Seq2txtModel", {})
+TTSModel = globals().get("TTSModel", {})
+
+MODULE_MAPPING = {
+    "chat_model": ChatModel,
+    "cv_model": CvModel,
+    "embedding_model": EmbeddingModel,
+    "rerank_model": RerankModel,
+    "sequence2txt_model": Seq2txtModel,
+    "tts_model": TTSModel,
+}
+
+package_name = __name__
+
+for module_name, mapping_dict in MODULE_MAPPING.items():
+    full_module_name = f"{package_name}.{module_name}"
+    module = importlib.import_module(full_module_name)
+
+    base_class = None
+    for name, obj in inspect.getmembers(module):
+        if inspect.isclass(obj) and name == "Base":
+            base_class = obj
+            break
+    if base_class is None:
+        continue
+
+    for _, obj in inspect.getmembers(module):
+        if inspect.isclass(obj) and issubclass(obj, base_class) and obj is not base_class and hasattr(obj, "_FACTORY_NAME"):
+            if isinstance(obj._FACTORY_NAME, list):
+                for factory_name in obj._FACTORY_NAME:
+                    mapping_dict[factory_name] = obj
+            else:
+                mapping_dict[obj._FACTORY_NAME] = obj
+
+__all__ = [
+    "ChatModel",
+    "CvModel",
+    "EmbeddingModel",
+    "RerankModel",
+    "Seq2txtModel",
+    "TTSModel",
+]
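Call sites are unaffected by the refactor: the six mapping dicts keep their names and shapes. A hypothetical lookup (the `rag.llm` import path and the key are assumptions for illustration):

```python
# Hypothetical call-site sketch; the import path and key are placeholders.
from rag.llm import ChatModel

chat_cls = ChatModel["DeepSeek"]  # registered via DeepSeekChat._FACTORY_NAME = "DeepSeek"
chat = chat_cls(key="YOUR_API_KEY", model_name="deepseek-chat")
```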

View File

@@ -142,11 +142,7 @@ class Base(ABC):
         return f"{ERROR_PREFIX}: {error_code} - {str(e)}"

     def _verbose_tool_use(self, name, args, res):
-        return "<tool_call>" + json.dumps({
-            "name": name,
-            "args": args,
-            "result": res
-        }, ensure_ascii=False, indent=2) + "</tool_call>"
+        return "<tool_call>" + json.dumps({"name": name, "args": args, "result": res}, ensure_ascii=False, indent=2) + "</tool_call>"

     def _append_history(self, hist, tool_call, tool_res):
         hist.append(

@@ -430,6 +426,8 @@ class Base(ABC):
 class GptTurbo(Base):
+    _FACTORY_NAME = "OpenAI"
+
     def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1", **kwargs):
         if not base_url:
             base_url = "https://api.openai.com/v1"

@@ -437,6 +435,8 @@ class GptTurbo(Base):
 class MoonshotChat(Base):
+    _FACTORY_NAME = "Moonshot"
+
     def __init__(self, key, model_name="moonshot-v1-8k", base_url="https://api.moonshot.cn/v1", **kwargs):
         if not base_url:
             base_url = "https://api.moonshot.cn/v1"

@@ -444,6 +444,8 @@ class MoonshotChat(Base):
 class XinferenceChat(Base):
+    _FACTORY_NAME = "Xinference"
+
     def __init__(self, key=None, model_name="", base_url="", **kwargs):
         if not base_url:
             raise ValueError("Local llm url cannot be None")

@@ -452,6 +454,8 @@ class XinferenceChat(Base):
 class HuggingFaceChat(Base):
+    _FACTORY_NAME = "HuggingFace"
+
     def __init__(self, key=None, model_name="", base_url="", **kwargs):
         if not base_url:
             raise ValueError("Local llm url cannot be None")

@@ -460,6 +464,8 @@ class HuggingFaceChat(Base):
 class ModelScopeChat(Base):
+    _FACTORY_NAME = "ModelScope"
+
     def __init__(self, key=None, model_name="", base_url="", **kwargs):
         if not base_url:
             raise ValueError("Local llm url cannot be None")

@@ -468,6 +474,8 @@ class ModelScopeChat(Base):
 class DeepSeekChat(Base):
+    _FACTORY_NAME = "DeepSeek"
+
     def __init__(self, key, model_name="deepseek-chat", base_url="https://api.deepseek.com/v1", **kwargs):
         if not base_url:
             base_url = "https://api.deepseek.com/v1"

@@ -475,6 +483,8 @@ class DeepSeekChat(Base):
 class AzureChat(Base):
+    _FACTORY_NAME = "Azure-OpenAI"
+
     def __init__(self, key, model_name, base_url, **kwargs):
         api_key = json.loads(key).get("api_key", "")
         api_version = json.loads(key).get("api_version", "2024-02-01")

@@ -484,6 +494,8 @@ class AzureChat(Base):
 class BaiChuanChat(Base):
+    _FACTORY_NAME = "BaiChuan"
+
     def __init__(self, key, model_name="Baichuan3-Turbo", base_url="https://api.baichuan-ai.com/v1", **kwargs):
         if not base_url:
             base_url = "https://api.baichuan-ai.com/v1"

@@ -557,6 +569,8 @@ class BaiChuanChat(Base):
 class QWenChat(Base):
+    _FACTORY_NAME = "Tongyi-Qianwen"
+
     def __init__(self, key, model_name=Generation.Models.qwen_turbo, base_url=None, **kwargs):
         if not base_url:
             base_url = "https://dashscope.aliyuncs.com/compatible-mode/v1"

@@ -565,6 +579,8 @@ class QWenChat(Base):
 class ZhipuChat(Base):
+    _FACTORY_NAME = "ZHIPU-AI"
+
     def __init__(self, key, model_name="glm-3-turbo", base_url=None, **kwargs):
         super().__init__(key, model_name, base_url=base_url, **kwargs)

@@ -630,6 +646,8 @@ class ZhipuChat(Base):
 class OllamaChat(Base):
+    _FACTORY_NAME = "Ollama"
+
     def __init__(self, key, model_name, base_url=None, **kwargs):
         super().__init__(key, model_name, base_url=base_url, **kwargs)

@@ -694,6 +712,8 @@ class OllamaChat(Base):
 class LocalAIChat(Base):
+    _FACTORY_NAME = "LocalAI"
+
     def __init__(self, key, model_name, base_url=None, **kwargs):
         super().__init__(key, model_name, base_url=base_url, **kwargs)

@@ -752,6 +772,8 @@ class LocalLLM(Base):
 class VolcEngineChat(Base):
+    _FACTORY_NAME = "VolcEngine"
+
     def __init__(self, key, model_name, base_url="https://ark.cn-beijing.volces.com/api/v3", **kwargs):
         """
         Since do not want to modify the original database fields, and the VolcEngine authentication method is quite special,

@@ -765,6 +787,8 @@ class VolcEngineChat(Base):
 class MiniMaxChat(Base):
+    _FACTORY_NAME = "MiniMax"
+
     def __init__(self, key, model_name, base_url="https://api.minimax.chat/v1/text/chatcompletion_v2", **kwargs):
         super().__init__(key, model_name, base_url=base_url, **kwargs)

@@ -843,6 +867,8 @@ class MiniMaxChat(Base):
 class MistralChat(Base):
+    _FACTORY_NAME = "Mistral"
+
     def __init__(self, key, model_name, base_url=None, **kwargs):
         super().__init__(key, model_name, base_url=base_url, **kwargs)

@@ -896,6 +922,8 @@ class MistralChat(Base):
 class BedrockChat(Base):
+    _FACTORY_NAME = "Bedrock"
+
     def __init__(self, key, model_name, base_url=None, **kwargs):
         super().__init__(key, model_name, base_url=base_url, **kwargs)

@@ -978,6 +1006,8 @@ class BedrockChat(Base):
 class GeminiChat(Base):
+    _FACTORY_NAME = "Gemini"
+
     def __init__(self, key, model_name, base_url=None, **kwargs):
         super().__init__(key, model_name, base_url=base_url, **kwargs)

@@ -997,6 +1027,7 @@ class GeminiChat(Base):
     def _chat(self, history, gen_conf):
         from google.generativeai.types import content_types
+
         system = history[0]["content"] if history and history[0]["role"] == "system" else ""
         hist = []
         for item in history:

@@ -1019,6 +1050,7 @@ class GeminiChat(Base):
     def chat_streamly(self, system, history, gen_conf):
         from google.generativeai.types import content_types
+
         gen_conf = self._clean_conf(gen_conf)
         if system:
             self.model._system_instruction = content_types.to_content(system)

@@ -1042,6 +1074,8 @@ class GeminiChat(Base):
 class GroqChat(Base):
+    _FACTORY_NAME = "Groq"
+
     def __init__(self, key, model_name, base_url=None, **kwargs):
         super().__init__(key, model_name, base_url=base_url, **kwargs)

@@ -1086,6 +1120,8 @@ class GroqChat(Base):
 ## openrouter
 class OpenRouterChat(Base):
+    _FACTORY_NAME = "OpenRouter"
+
     def __init__(self, key, model_name, base_url="https://openrouter.ai/api/v1", **kwargs):
         if not base_url:
             base_url = "https://openrouter.ai/api/v1"

@@ -1093,6 +1129,8 @@ class OpenRouterChat(Base):
 class StepFunChat(Base):
+    _FACTORY_NAME = "StepFun"
+
     def __init__(self, key, model_name, base_url="https://api.stepfun.com/v1", **kwargs):
         if not base_url:
             base_url = "https://api.stepfun.com/v1"

@@ -1100,6 +1138,8 @@ class StepFunChat(Base):
 class NvidiaChat(Base):
+    _FACTORY_NAME = "NVIDIA"
+
     def __init__(self, key, model_name, base_url="https://integrate.api.nvidia.com/v1", **kwargs):
         if not base_url:
             base_url = "https://integrate.api.nvidia.com/v1"

@@ -1107,6 +1147,8 @@ class NvidiaChat(Base):
 class LmStudioChat(Base):
+    _FACTORY_NAME = "LM-Studio"
+
     def __init__(self, key, model_name, base_url, **kwargs):
         if not base_url:
             raise ValueError("Local llm url cannot be None")

@@ -1117,6 +1159,8 @@ class LmStudioChat(Base):
 class OpenAI_APIChat(Base):
+    _FACTORY_NAME = ["VLLM", "OpenAI-API-Compatible"]
+
     def __init__(self, key, model_name, base_url):
         if not base_url:
             raise ValueError("url cannot be None")

@@ -1125,6 +1169,8 @@ class OpenAI_APIChat(Base):
 class PPIOChat(Base):
+    _FACTORY_NAME = "PPIO"
+
     def __init__(self, key, model_name, base_url="https://api.ppinfra.com/v3/openai", **kwargs):
         if not base_url:
             base_url = "https://api.ppinfra.com/v3/openai"

@@ -1132,6 +1178,8 @@ class PPIOChat(Base):
 class CoHereChat(Base):
+    _FACTORY_NAME = "Cohere"
+
     def __init__(self, key, model_name, base_url=None, **kwargs):
         super().__init__(key, model_name, base_url=base_url, **kwargs)

@@ -1207,6 +1255,8 @@ class CoHereChat(Base):
 class LeptonAIChat(Base):
+    _FACTORY_NAME = "LeptonAI"
+
     def __init__(self, key, model_name, base_url=None, **kwargs):
         if not base_url:
             base_url = urljoin("https://" + model_name + ".lepton.run", "api/v1")

@@ -1214,6 +1264,8 @@ class LeptonAIChat(Base):
 class TogetherAIChat(Base):
+    _FACTORY_NAME = "TogetherAI"
+
     def __init__(self, key, model_name, base_url="https://api.together.xyz/v1", **kwargs):
         if not base_url:
             base_url = "https://api.together.xyz/v1"

@@ -1221,6 +1273,8 @@ class TogetherAIChat(Base):
 class PerfXCloudChat(Base):
+    _FACTORY_NAME = "PerfXCloud"
+
     def __init__(self, key, model_name, base_url="https://cloud.perfxlab.cn/v1", **kwargs):
         if not base_url:
             base_url = "https://cloud.perfxlab.cn/v1"

@@ -1228,6 +1282,8 @@ class PerfXCloudChat(Base):
 class UpstageChat(Base):
+    _FACTORY_NAME = "Upstage"
+
     def __init__(self, key, model_name, base_url="https://api.upstage.ai/v1/solar", **kwargs):
         if not base_url:
             base_url = "https://api.upstage.ai/v1/solar"

@@ -1235,6 +1291,8 @@ class UpstageChat(Base):
 class NovitaAIChat(Base):
+    _FACTORY_NAME = "NovitaAI"
+
     def __init__(self, key, model_name, base_url="https://api.novita.ai/v3/openai", **kwargs):
         if not base_url:
             base_url = "https://api.novita.ai/v3/openai"

@@ -1242,6 +1300,8 @@ class NovitaAIChat(Base):
 class SILICONFLOWChat(Base):
+    _FACTORY_NAME = "SILICONFLOW"
+
     def __init__(self, key, model_name, base_url="https://api.siliconflow.cn/v1", **kwargs):
         if not base_url:
             base_url = "https://api.siliconflow.cn/v1"

@@ -1249,6 +1309,8 @@ class SILICONFLOWChat(Base):
 class YiChat(Base):
+    _FACTORY_NAME = "01.AI"
+
     def __init__(self, key, model_name, base_url="https://api.lingyiwanwu.com/v1", **kwargs):
         if not base_url:
             base_url = "https://api.lingyiwanwu.com/v1"

@@ -1256,6 +1318,8 @@ class YiChat(Base):
 class GiteeChat(Base):
+    _FACTORY_NAME = "GiteeAI"
+
     def __init__(self, key, model_name, base_url="https://ai.gitee.com/v1/", **kwargs):
         if not base_url:
             base_url = "https://ai.gitee.com/v1/"

@@ -1263,6 +1327,8 @@ class GiteeChat(Base):
 class ReplicateChat(Base):
+    _FACTORY_NAME = "Replicate"
+
     def __init__(self, key, model_name, base_url=None, **kwargs):
         super().__init__(key, model_name, base_url=base_url, **kwargs)

@@ -1302,6 +1368,8 @@ class ReplicateChat(Base):
 class HunyuanChat(Base):
+    _FACTORY_NAME = "Tencent Hunyuan"
+
     def __init__(self, key, model_name, base_url=None, **kwargs):
         super().__init__(key, model_name, base_url=base_url, **kwargs)

@@ -1378,6 +1446,8 @@ class HunyuanChat(Base):
 class SparkChat(Base):
+    _FACTORY_NAME = "XunFei Spark"
+
     def __init__(self, key, model_name, base_url="https://spark-api-open.xf-yun.com/v1", **kwargs):
         if not base_url:
             base_url = "https://spark-api-open.xf-yun.com/v1"

@@ -1398,6 +1468,8 @@ class SparkChat(Base):
 class BaiduYiyanChat(Base):
+    _FACTORY_NAME = "BaiduYiyan"
+
     def __init__(self, key, model_name, base_url=None, **kwargs):
         super().__init__(key, model_name, base_url=base_url, **kwargs)

@@ -1444,6 +1516,8 @@ class BaiduYiyanChat(Base):
 class AnthropicChat(Base):
+    _FACTORY_NAME = "Anthropic"
+
     def __init__(self, key, model_name, base_url="https://api.anthropic.com/v1/", **kwargs):
         if not base_url:
             base_url = "https://api.anthropic.com/v1/"

@@ -1451,6 +1525,8 @@ class AnthropicChat(Base):
 class GoogleChat(Base):
+    _FACTORY_NAME = "Google Cloud"
+
     def __init__(self, key, model_name, base_url=None, **kwargs):
         super().__init__(key, model_name, base_url=base_url, **kwargs)

@@ -1529,9 +1605,11 @@ class GoogleChat(Base):
                 if "role" in item and item["role"] == "assistant":
                     item["role"] = "model"
                 if "content" in item:
-                    item["parts"] = [{
-                        "text": item.pop("content"),
-                    }]
+                    item["parts"] = [
+                        {
+                            "text": item.pop("content"),
+                        }
+                    ]

             response = self.client.generate_content(hist, generation_config=gen_conf)
             ans = response.text

@@ -1587,6 +1665,8 @@ class GoogleChat(Base):
 class GPUStackChat(Base):
+    _FACTORY_NAME = "GPUStack"
+
     def __init__(self, key=None, model_name="", base_url="", **kwargs):
         if not base_url:
             raise ValueError("Local llm url cannot be None")
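Taken together, the `_FACTORY_NAME` attributes above mean a new chat provider no longer requires an edit to the package `__init__`: subclass the module's `Base`, declare the factory name, and the import-time scan picks it up. A hypothetical provider following the same pattern (class name, factory name, and URL are invented for illustration):

```python
class ExampleChat(Base):
    # Hypothetical provider; _FACTORY_NAME becomes its key in ChatModel.
    _FACTORY_NAME = "ExampleAI"

    def __init__(self, key, model_name="example-chat", base_url="https://api.example.com/v1", **kwargs):
        if not base_url:
            base_url = "https://api.example.com/v1"
        super().__init__(key, model_name, base_url=base_url, **kwargs)
```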

View File

@@ -57,7 +57,7 @@ class Base(ABC):
                 model=self.model_name,
                 messages=history,
                 temperature=gen_conf.get("temperature", 0.3),
-                top_p=gen_conf.get("top_p", 0.7)
+                top_p=gen_conf.get("top_p", 0.7),
             )
             return response.choices[0].message.content.strip(), response.usage.total_tokens
         except Exception as e:

@@ -79,7 +79,7 @@ class Base(ABC):
                 messages=history,
                 temperature=gen_conf.get("temperature", 0.3),
                 top_p=gen_conf.get("top_p", 0.7),
-                stream=True
+                stream=True,
             )
             for resp in response:
                 if not resp.choices[0].delta.content:

@@ -87,8 +87,7 @@ class Base(ABC):
                 delta = resp.choices[0].delta.content
                 ans += delta
                 if resp.choices[0].finish_reason == "length":
-                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
-                        [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
+                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                     tk_count = resp.usage.total_tokens
                 if resp.choices[0].finish_reason == "stop":
                     tk_count = resp.usage.total_tokens

@@ -117,13 +116,12 @@ class Base(ABC):
                 "content": [
                     {
                         "type": "image_url",
-                        "image_url": {
-                            "url": f"data:image/jpeg;base64,{b64}"
-                        },
+                        "image_url": {"url": f"data:image/jpeg;base64,{b64}"},
                     },
                     {
-                        "text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else
-                        "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
+                        "text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
+                        if self.lang.lower() == "chinese"
+                        else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
                     },
                 ],
             }

@@ -136,9 +134,7 @@ class Base(ABC):
                 "content": [
                     {
                         "type": "image_url",
-                        "image_url": {
-                            "url": f"data:image/jpeg;base64,{b64}"
-                        },
+                        "image_url": {"url": f"data:image/jpeg;base64,{b64}"},
                     },
                     {
                         "type": "text",

@@ -156,14 +152,13 @@ class Base(ABC):
                     "url": f"data:image/jpeg;base64,{b64}",
                 },
             },
-            {
-                "type": "text",
-                "text": text
-            },
+            {"type": "text", "text": text},
         ]


 class GptV4(Base):
+    _FACTORY_NAME = "OpenAI"
+
     def __init__(self, key, model_name="gpt-4-vision-preview", lang="Chinese", base_url="https://api.openai.com/v1"):
         if not base_url:
             base_url = "https://api.openai.com/v1"

@@ -181,7 +176,7 @@ class GptV4(Base):
         res = self.client.chat.completions.create(
             model=self.model_name,
-            messages=prompt
+            messages=prompt,
         )
         return res.choices[0].message.content.strip(), res.usage.total_tokens

@@ -197,9 +192,11 @@ class GptV4(Base):
 class AzureGptV4(Base):
+    _FACTORY_NAME = "Azure-OpenAI"
+
     def __init__(self, key, model_name, lang="Chinese", **kwargs):
-        api_key = json.loads(key).get('api_key', '')
-        api_version = json.loads(key).get('api_version', '2024-02-01')
+        api_key = json.loads(key).get("api_key", "")
+        api_version = json.loads(key).get("api_version", "2024-02-01")
         self.client = AzureOpenAI(api_key=api_key, azure_endpoint=kwargs["base_url"], api_version=api_version)
         self.model_name = model_name
         self.lang = lang

@@ -212,10 +209,7 @@ class AzureGptV4(Base):
             if "text" in c:
                 c["type"] = "text"
-        res = self.client.chat.completions.create(
-            model=self.model_name,
-            messages=prompt
-        )
+        res = self.client.chat.completions.create(model=self.model_name, messages=prompt)
         return res.choices[0].message.content.strip(), res.usage.total_tokens

     def describe_with_prompt(self, image, prompt=None):

@@ -230,8 +224,11 @@ class AzureGptV4(Base):
 class QWenCV(Base):
+    _FACTORY_NAME = "Tongyi-Qianwen"
+
     def __init__(self, key, model_name="qwen-vl-chat-v1", lang="Chinese", **kwargs):
         import dashscope
+
         dashscope.api_key = key
         self.model_name = model_name
         self.lang = lang

@@ -247,12 +244,11 @@ class QWenCV(Base):
             {
                 "role": "user",
                 "content": [
+                    {"image": f"file://{path}"},
                     {
-                        "image": f"file://{path}"
-                    },
-                    {
-                        "text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else
-                        "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
+                        "text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
+                        if self.lang.lower() == "chinese"
+                        else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
                     },
                 ],
             }

@@ -270,9 +266,7 @@ class QWenCV(Base):
             {
                 "role": "user",
                 "content": [
-                    {
-                        "image": f"file://{path}"
-                    },
+                    {"image": f"file://{path}"},
                     {
                         "text": prompt if prompt else vision_llm_describe_prompt(),
                     },

@@ -290,9 +284,10 @@ class QWenCV(Base):
         from http import HTTPStatus
         from dashscope import MultiModalConversation
+
         response = MultiModalConversation.call(model=self.model_name, messages=self.prompt(image))
         if response.status_code == HTTPStatus.OK:
-            return response.output.choices[0]['message']['content'][0]["text"], response.usage.output_tokens
+            return response.output.choices[0]["message"]["content"][0]["text"], response.usage.output_tokens
         return response.message, 0

     def describe_with_prompt(self, image, prompt=None):

@@ -303,33 +298,36 @@ class QWenCV(Base):
         vision_prompt = self.vision_llm_prompt(image, prompt) if prompt else self.vision_llm_prompt(image)
         response = MultiModalConversation.call(model=self.model_name, messages=vision_prompt)
         if response.status_code == HTTPStatus.OK:
-            return response.output.choices[0]['message']['content'][0]["text"], response.usage.output_tokens
+            return response.output.choices[0]["message"]["content"][0]["text"], response.usage.output_tokens
         return response.message, 0

     def chat(self, system, history, gen_conf, image=""):
         from http import HTTPStatus
         from dashscope import MultiModalConversation
+
         if system:
             history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]

         for his in history:
             if his["role"] == "user":
                 his["content"] = self.chat_prompt(his["content"], image)
-        response = MultiModalConversation.call(model=self.model_name, messages=history,
-                                               temperature=gen_conf.get("temperature", 0.3),
-                                               top_p=gen_conf.get("top_p", 0.7))
+        response = MultiModalConversation.call(
+            model=self.model_name,
+            messages=history,
+            temperature=gen_conf.get("temperature", 0.3),
+            top_p=gen_conf.get("top_p", 0.7),
+        )

         ans = ""
         tk_count = 0
         if response.status_code == HTTPStatus.OK:
-            ans = response.output.choices[0]['message']['content']
+            ans = response.output.choices[0]["message"]["content"]
             if isinstance(ans, list):
                 ans = ans[0]["text"] if ans else ""
             tk_count += response.usage.total_tokens
             if response.output.choices[0].get("finish_reason", "") == "length":
-                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
-                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
+                ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
             return ans, tk_count

         return "**ERROR**: " + response.message, tk_count

@@ -338,6 +336,7 @@ class QWenCV(Base):
         from http import HTTPStatus
         from dashscope import MultiModalConversation
+
         if system:
             history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]

@@ -348,24 +347,25 @@ class QWenCV(Base):
         ans = ""
         tk_count = 0
         try:
-            response = MultiModalConversation.call(model=self.model_name, messages=history,
-                                                   temperature=gen_conf.get("temperature", 0.3),
-                                                   top_p=gen_conf.get("top_p", 0.7),
-                                                   stream=True)
+            response = MultiModalConversation.call(
+                model=self.model_name,
+                messages=history,
+                temperature=gen_conf.get("temperature", 0.3),
+                top_p=gen_conf.get("top_p", 0.7),
+                stream=True,
+            )
             for resp in response:
                 if resp.status_code == HTTPStatus.OK:
-                    cnt = resp.output.choices[0]['message']['content']
+                    cnt = resp.output.choices[0]["message"]["content"]
                     if isinstance(cnt, list):
                         cnt = cnt[0]["text"] if ans else ""
                     ans += cnt
                     tk_count = resp.usage.total_tokens
                     if resp.output.choices[0].get("finish_reason", "") == "length":
-                        ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
-                            [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
+                        ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                     yield ans
                 else:
-                    yield ans + "\n**ERROR**: " + resp.message if str(resp.message).find(
-                        "Access") < 0 else "Out of credit. Please set the API key in **settings > Model providers.**"
+                    yield ans + "\n**ERROR**: " + resp.message if str(resp.message).find("Access") < 0 else "Out of credit. Please set the API key in **settings > Model providers.**"
         except Exception as e:
             yield ans + "\n**ERROR**: " + str(e)

@@ -373,6 +373,8 @@ class QWenCV(Base):
 class Zhipu4V(Base):
+    _FACTORY_NAME = "ZHIPU-AI"
+
     def __init__(self, key, model_name="glm-4v", lang="Chinese", **kwargs):
         self.client = ZhipuAI(api_key=key)
         self.model_name = model_name

@@ -394,10 +396,7 @@ class Zhipu4V(Base):
         b64 = self.image2base64(image)
         vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64)
-        res = self.client.chat.completions.create(
-            model=self.model_name,
-            messages=vision_prompt
-        )
+        res = self.client.chat.completions.create(model=self.model_name, messages=vision_prompt)
         return res.choices[0].message.content.strip(), res.usage.total_tokens

     def chat(self, system, history, gen_conf, image=""):

@@ -412,7 +411,7 @@ class Zhipu4V(Base):
                 model=self.model_name,
                 messages=history,
                 temperature=gen_conf.get("temperature", 0.3),
-                top_p=gen_conf.get("top_p", 0.7)
+                top_p=gen_conf.get("top_p", 0.7),
             )
             return response.choices[0].message.content.strip(), response.usage.total_tokens
         except Exception as e:

@@ -434,7 +433,7 @@ class Zhipu4V(Base):
                 messages=history,
                 temperature=gen_conf.get("temperature", 0.3),
                 top_p=gen_conf.get("top_p", 0.7),
-                stream=True
+                stream=True,
             )
             for resp in response:
                 if not resp.choices[0].delta.content:

@@ -442,8 +441,7 @@ class Zhipu4V(Base):
                 delta = resp.choices[0].delta.content
                 ans += delta
                 if resp.choices[0].finish_reason == "length":
-                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
-                        [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
+                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                     tk_count = resp.usage.total_tokens
                 if resp.choices[0].finish_reason == "stop":
                     tk_count = resp.usage.total_tokens

@@ -455,6 +453,8 @@ class Zhipu4V(Base):
 class OllamaCV(Base):
+    _FACTORY_NAME = "Ollama"
+
     def __init__(self, key, model_name, lang="Chinese", **kwargs):
         self.client = Client(host=kwargs["base_url"])
         self.model_name = model_name

@@ -466,7 +466,7 @@ class OllamaCV(Base):
             response = self.client.generate(
                 model=self.model_name,
                 prompt=prompt[0]["content"][1]["text"],
-                images=[image]
+                images=[image],
             )
             ans = response["response"].strip()
             return ans, 128

@@ -507,7 +507,7 @@ class OllamaCV(Base):
                 model=self.model_name,
                 messages=history,
                 options=options,
-                keep_alive=-1
+                keep_alive=-1,
             )

             ans = response["message"]["content"].strip()

@@ -538,7 +538,7 @@ class OllamaCV(Base):
                 messages=history,
                 stream=True,
                 options=options,
-                keep_alive=-1
+                keep_alive=-1,
             )
             for resp in response:
                 if resp["done"]:

@@ -551,6 +551,8 @@ class OllamaCV(Base):
 class LocalAICV(GptV4):
+    _FACTORY_NAME = "LocalAI"
+
     def __init__(self, key, model_name, base_url, lang="Chinese"):
         if not base_url:
             raise ValueError("Local cv model url cannot be None")

@@ -561,6 +563,8 @@ class LocalAICV(GptV4):
 class XinferenceCV(Base):
+    _FACTORY_NAME = "Xinference"
+
     def __init__(self, key, model_name="", lang="Chinese", base_url=""):
         base_url = urljoin(base_url, "v1")
         self.client = OpenAI(api_key=key, base_url=base_url)

@@ -570,10 +574,7 @@ class XinferenceCV(Base):
     def describe(self, image):
         b64 = self.image2base64(image)
-        res = self.client.chat.completions.create(
-            model=self.model_name,
-            messages=self.prompt(b64)
-        )
+        res = self.client.chat.completions.create(model=self.model_name, messages=self.prompt(b64))
         return res.choices[0].message.content.strip(), res.usage.total_tokens

     def describe_with_prompt(self, image, prompt=None):

@@ -588,8 +589,11 @@ class XinferenceCV(Base):
 class GeminiCV(Base):
+    _FACTORY_NAME = "Gemini"
+
     def __init__(self, key, model_name="gemini-1.0-pro-vision-latest", lang="Chinese", **kwargs):
         from google.generativeai import GenerativeModel, client
+
         client.configure(api_key=key)
         _client = client.get_default_generative_client()
         self.model_name = model_name

@@ -599,18 +603,21 @@ class GeminiCV(Base):
     def describe(self, image):
         from PIL.Image import open

-        prompt = "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else \
-            "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
+        prompt = (
+            "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
+            if self.lang.lower() == "chinese"
+            else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
+        )
         b64 = self.image2base64(image)
         img = open(BytesIO(base64.b64decode(b64)))
         input = [prompt, img]
-        res = self.model.generate_content(
-            input
-        )
+        res = self.model.generate_content(input)
         return res.text, res.usage_metadata.total_token_count

     def describe_with_prompt(self, image, prompt=None):
         from PIL.Image import open
+
         b64 = self.image2base64(image)
         vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64)
         img = open(BytesIO(base64.b64decode(b64)))

@@ -622,6 +629,7 @@ class GeminiCV(Base):
     def chat(self, system, history, gen_conf, image=""):
         from transformers import GenerationConfig
+
         if system:
             history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
         try:

@@ -635,9 +643,7 @@ class GeminiCV(Base):
                 his.pop("content")
             history[-1]["parts"].append("data:image/jpeg;base64," + image)

-            response = self.model.generate_content(history, generation_config=GenerationConfig(
-                temperature=gen_conf.get("temperature", 0.3),
-                top_p=gen_conf.get("top_p", 0.7)))
+            response = self.model.generate_content(history, generation_config=GenerationConfig(temperature=gen_conf.get("temperature", 0.3), top_p=gen_conf.get("top_p", 0.7)))

             ans = response.text
             return ans, response.usage_metadata.total_token_count

@@ -646,6 +652,7 @@ class GeminiCV(Base):
     def chat_streamly(self, system, history, gen_conf, image=""):
         from transformers import GenerationConfig
+
         if system:
             history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]

@@ -661,9 +668,11 @@ class GeminiCV(Base):
                 his.pop("content")
             history[-1]["parts"].append("data:image/jpeg;base64," + image)

-            response = self.model.generate_content(history, generation_config=GenerationConfig(
-                temperature=gen_conf.get("temperature", 0.3),
-                top_p=gen_conf.get("top_p", 0.7)), stream=True)
+            response = self.model.generate_content(
+                history,
+                generation_config=GenerationConfig(temperature=gen_conf.get("temperature", 0.3), top_p=gen_conf.get("top_p", 0.7)),
+                stream=True,
+            )

             for resp in response:
                 if not resp.text:

@@ -677,6 +686,8 @@ class GeminiCV(Base):
 class OpenRouterCV(GptV4):
+    _FACTORY_NAME = "OpenRouter"
+
     def __init__(
         self,
         key,

@@ -692,6 +703,8 @@ class OpenRouterCV(GptV4):
 class LocalCV(Base):
+    _FACTORY_NAME = "Moonshot"
+
     def __init__(self, key, model_name="glm-4v", lang="Chinese", **kwargs):
         pass

@@ -700,6 +713,8 @@ class LocalCV(Base):
 class NvidiaCV(Base):
+    _FACTORY_NAME = "NVIDIA"
+
     def __init__(
         self,
         key,

@@ -726,9 +741,7 @@ class NvidiaCV(Base):
                 "content-type": "application/json",
                 "Authorization": f"Bearer {self.key}",
             },
-            json={
-                "messages": self.prompt(b64)
-            },
+            json={"messages": self.prompt(b64)},
         )
         response = response.json()
         return (

@@ -774,10 +787,7 @@ class NvidiaCV(Base):
         return [
             {
                 "role": "user",
-                "content": (
-                    prompt if prompt else vision_llm_describe_prompt()
-                )
-                + f' <img src="data:image/jpeg;base64,{b64}"/>',
+                "content": (prompt if prompt else vision_llm_describe_prompt()) + f' <img src="data:image/jpeg;base64,{b64}"/>',
             }
         ]

@@ -791,6 +801,8 @@ class NvidiaCV(Base):
 class StepFunCV(GptV4):
+    _FACTORY_NAME = "StepFun"
+
     def __init__(self, key, model_name="step-1v-8k", lang="Chinese", base_url="https://api.stepfun.com/v1"):
         if not base_url:
             base_url = "https://api.stepfun.com/v1"

@@ -800,6 +812,8 @@ class StepFunCV(GptV4):
 class LmStudioCV(GptV4):
+    _FACTORY_NAME = "LM-Studio"
+
     def __init__(self, key, model_name, lang="Chinese", base_url=""):
         if not base_url:
             raise ValueError("Local llm url cannot be None")

@@ -810,6 +824,8 @@ class LmStudioCV(GptV4):
 class OpenAI_APICV(GptV4):
+    _FACTORY_NAME = ["VLLM", "OpenAI-API-Compatible"]
+
     def __init__(self, key, model_name, lang="Chinese", base_url=""):
         if not base_url:
             raise ValueError("url cannot be None")

@@ -820,6 +836,8 @@ class OpenAI_APICV(GptV4):
 class TogetherAICV(GptV4):
+    _FACTORY_NAME = "TogetherAI"
+
     def __init__(self, key, model_name, lang="Chinese", base_url="https://api.together.xyz/v1"):
         if not base_url:
             base_url = "https://api.together.xyz/v1"

@@ -827,20 +845,38 @@ class TogetherAICV(GptV4):
 class YiCV(GptV4):
-    def __init__(self, key, model_name, lang="Chinese", base_url="https://api.lingyiwanwu.com/v1",):
+    _FACTORY_NAME = "01.AI"
+
+    def __init__(
+        self,
+        key,
+        model_name,
+        lang="Chinese",
+        base_url="https://api.lingyiwanwu.com/v1",
+    ):
         if not base_url:
             base_url = "https://api.lingyiwanwu.com/v1"
         super().__init__(key, model_name, lang, base_url)


 class SILICONFLOWCV(GptV4):
-    def __init__(self, key, model_name, lang="Chinese", base_url="https://api.siliconflow.cn/v1",):
+    _FACTORY_NAME = "SILICONFLOW"
+
+    def __init__(
+        self,
+        key,
+        model_name,
+        lang="Chinese",
+        base_url="https://api.siliconflow.cn/v1",
+    ):
         if not base_url:
             base_url = "https://api.siliconflow.cn/v1"
         super().__init__(key, model_name, lang, base_url)


 class HunyuanCV(Base):
+    _FACTORY_NAME = "Tencent Hunyuan"
+
     def __init__(self, key, model_name, lang="Chinese", base_url=None):
         from tencentcloud.common import credential
         from tencentcloud.hunyuan.v20230901 import hunyuan_client

@@ -895,14 +931,13 @@ class HunyuanCV(Base):
             "Contents": [
                 {
                     "Type": "image_url",
-                    "ImageUrl": {
-                        "Url": f"data:image/jpeg;base64,{b64}"
-                    },
+                    "ImageUrl": {"Url": f"data:image/jpeg;base64,{b64}"},
                 },
                 {
                     "Type": "text",
-                    "Text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else
-                    "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
+                    "Text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
+                    if self.lang.lower() == "chinese"
+                    else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
                 },
             ],
         }

@@ -910,6 +945,8 @@ class HunyuanCV(Base):
 class AnthropicCV(Base):
+    _FACTORY_NAME = "Anthropic"
+
     def __init__(self, key, model_name, base_url=None):
         import anthropic

@@ -933,37 +970,28 @@ class AnthropicCV(Base):
                         "data": b64,
                     },
                 },
-                {
-                    "type": "text",
-                    "text": prompt
-                }
+                {"type": "text", "text": prompt},
             ],
         }
     ]

     def describe(self, image):
         b64 = self.image2base64(image)
-        prompt = self.prompt(b64,
-                             "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else
-                             "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
-                             )
+        prompt = self.prompt(
+            b64,
+            "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
+            if self.lang.lower() == "chinese"
+            else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
+        )

-        response = self.client.messages.create(
-            model=self.model_name,
-            max_tokens=self.max_tokens,
-            messages=prompt
-        )
+        response = self.client.messages.create(model=self.model_name, max_tokens=self.max_tokens, messages=prompt)
         return response["content"][0]["text"].strip(), response["usage"]["input_tokens"] + response["usage"]["output_tokens"]

     def describe_with_prompt(self, image, prompt=None):
         b64 = self.image2base64(image)
         prompt = self.prompt(b64, prompt if prompt else vision_llm_describe_prompt())

-        response = self.client.messages.create(
-            model=self.model_name,
-            max_tokens=self.max_tokens,
-            messages=prompt
-        )
+        response = self.client.messages.create(model=self.model_name, max_tokens=self.max_tokens, messages=prompt)
        return response["content"][0]["text"].strip(), response["usage"]["input_tokens"] + response["usage"]["output_tokens"]

     def chat(self, system, history, gen_conf):

@@ -984,11 +1012,7 @@ class AnthropicCV(Base):
             ).to_dict()
             ans = response["content"][0]["text"]
             if response["stop_reason"] == "max_tokens":
-                ans += (
-                    "...\nFor the content length reason, it stopped, continue?"
-                    if is_english([ans])
-                    else "······\n由于长度的原因,回答被截断了,要继续吗?"
-                )
+                ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
             return (
                 ans,
                 response["usage"]["input_tokens"] + response["usage"]["output_tokens"],

@@ -1014,7 +1038,7 @@ class AnthropicCV(Base):
                 **gen_conf,
             )
             for res in response:
-                if res.type == 'content_block_delta':
+                if res.type == "content_block_delta":
                     if res.delta.type == "thinking_delta" and res.delta.thinking:
                         if ans.find("<think>") < 0:
                             ans += "<think>"

@@ -1030,7 +1054,10 @@ class AnthropicCV(Base):
             yield total_tokens


 class GPUStackCV(GptV4):
+    _FACTORY_NAME = "GPUStack"
+
     def __init__(self, key, model_name, lang="Chinese", base_url=""):
         if not base_url:
             raise ValueError("Local llm url cannot be None")

@@ -1041,6 +1068,8 @@ class GPUStackCV(GptV4):
 class GoogleCV(Base):
+    _FACTORY_NAME = "Google Cloud"
+
     def __init__(self, key, model_name, lang="Chinese", base_url=None, **kwargs):
         import base64

@@ -1079,8 +1108,11 @@ class GoogleCV(Base):
         self.client = glm.GenerativeModel(model_name=self.model_name)

     def describe(self, image):
-        prompt = "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else \
-            "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
+        prompt = (
+            "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
+            if self.lang.lower() == "chinese"
+            else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
+        )

         if "claude" in self.model_name:
             b64 = self.image2base64(image)

@@ -1096,17 +1128,14 @@
                             "data": b64,
                         },
                     },
-                    {
-                        "type": "text",
-                        "text": prompt
-                    }
+                    {"type": "text", "text": prompt},
                 ],
             }
         ]
             response = self.client.messages.create(
                 model=self.model_name,
                 max_tokens=8192,
-                messages=vision_prompt
+                messages=vision_prompt,
             )
             return response.content[0].text.strip(), response.usage.input_tokens + response.usage.output_tokens
         else:

@@ -1114,10 +1143,7 @@
             b64 = self.image2base64(image)
             # Create proper image part for Gemini
-            image_part = glm.Part.from_data(
-                data=base64.b64decode(b64),
-                mime_type="image/jpeg"
-            )
+            image_part = glm.Part.from_data(data=base64.b64decode(b64), mime_type="image/jpeg")
             input = [prompt, image_part]
             res = self.client.generate_content(input)
             return res.text, res.usage_metadata.total_token_count

@@ -1137,18 +1163,11 @@
                             "data": b64,
                         },
                     },
-                    {
-                        "type": "text",
-                        "text": prompt if prompt else vision_llm_describe_prompt()
-                    }
+                    {"type": "text", "text": prompt if prompt else vision_llm_describe_prompt()},
                 ],
             }
         ]
-            response = self.client.messages.create(
-                model=self.model_name,
-                max_tokens=8192,
-                messages=vision_prompt
-            )
+            response = self.client.messages.create(model=self.model_name, max_tokens=8192, messages=vision_prompt)
             return response.content[0].text.strip(), response.usage.input_tokens + response.usage.output_tokens
         else:
             import vertexai.generative_models as glm

@@ -1156,10 +1175,7 @@
             b64 = self.image2base64(image)
             vision_prompt = prompt if prompt else vision_llm_describe_prompt()
             # Create proper image part for Gemini
-            image_part = glm.Part.from_data(
-                data=base64.b64decode(b64),
-                mime_type="image/jpeg"
-            )
+            image_part = glm.Part.from_data(data=base64.b64decode(b64), mime_type="image/jpeg")
             input = [vision_prompt, image_part]
             res = self.client.generate_content(input)
             return res.text, res.usage_metadata.total_token_count

@@ -1180,25 +1196,17 @@
                             "data": image,
                         },
                     },
-                    {
-                        "type": "text",
-                        "text": his["content"]
-                    }
+                    {"type": "text", "text": his["content"]},
                 ]

-            response = self.client.messages.create(
+            response = self.client.messages.create(model=self.model_name, max_tokens=8192, messages=history, temperature=gen_conf.get("temperature", 0.3), top_p=gen_conf.get("top_p", 0.7))
model=self.model_name,
max_tokens=8192,
messages=history,
temperature=gen_conf.get("temperature", 0.3),
top_p=gen_conf.get("top_p", 0.7)
)
return response.content[0].text.strip(), response.usage.input_tokens + response.usage.output_tokens return response.content[0].text.strip(), response.usage.input_tokens + response.usage.output_tokens
except Exception as e: except Exception as e:
return "**ERROR**: " + str(e), 0 return "**ERROR**: " + str(e), 0
else: else:
import vertexai.generative_models as glm import vertexai.generative_models as glm
from transformers import GenerationConfig from transformers import GenerationConfig
if system: if system:
history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"] history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
try: try:
@ -1213,15 +1221,10 @@ class GoogleCV(Base):
# Create proper image part for Gemini # Create proper image part for Gemini
img_bytes = base64.b64decode(image) img_bytes = base64.b64decode(image)
image_part = glm.Part.from_data( image_part = glm.Part.from_data(data=img_bytes, mime_type="image/jpeg")
data=img_bytes,
mime_type="image/jpeg"
)
history[-1]["parts"].append(image_part) history[-1]["parts"].append(image_part)
response = self.client.generate_content(history, generation_config=GenerationConfig( response = self.client.generate_content(history, generation_config=GenerationConfig(temperature=gen_conf.get("temperature", 0.3), top_p=gen_conf.get("top_p", 0.7)))
temperature=gen_conf.get("temperature", 0.3),
top_p=gen_conf.get("top_p", 0.7)))
ans = response.text ans = response.text
return ans, response.usage_metadata.total_token_count return ans, response.usage_metadata.total_token_count
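
Each CV class above now carries a _FACTORY_NAME class attribute (e.g. "GPUStack", "Google Cloud") instead of being enumerated by hand. Below is a minimal sketch of how such tags can be collected into a name-to-class registry; the helper name register_factories and the importlib/inspect scan are illustrative assumptions, not the exact loader shipped by this commit.

import importlib
import inspect


def register_factories(module_name: str) -> dict:
    """Map every _FACTORY_NAME declared in module_name to its class (sketch)."""
    registry = {}
    module = importlib.import_module(module_name)
    for _, cls in inspect.getmembers(module, inspect.isclass):
        tag = getattr(cls, "_FACTORY_NAME", None)
        if not tag:
            continue  # base/helper classes carry no factory tag
        # _FACTORY_NAME may be a single string or a list of aliases
        for name in tag if isinstance(tag, list) else [tag]:
            registry[name] = cls
    return registry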

View File

@@ -13,28 +13,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import logging
import os
import re
import threading
from abc import ABC
from urllib.parse import urljoin

import dashscope
import google.generativeai as genai
import numpy as np
import requests
from huggingface_hub import snapshot_download
from ollama import Client
from openai import OpenAI
from zhipuai import ZhipuAI

from api import settings
from api.utils.file_utils import get_home_cache_dir
from api.utils.log_utils import log_exception
from rag.utils import num_tokens_from_string, truncate


class Base(ABC):

@@ -60,7 +59,8 @@ class Base(ABC):
class DefaultEmbedding(Base):
    _FACTORY_NAME = "BAAI"
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    _model = None
    _model_name = ""
    _model_lock = threading.Lock()
@@ -79,21 +79,22 @@ class DefaultEmbedding(Base):
        """
        if not settings.LIGHTEN:
            with DefaultEmbedding._model_lock:
                import torch
                from FlagEmbedding import FlagModel

                if not DefaultEmbedding._model or model_name != DefaultEmbedding._model_name:
                    try:
                        DefaultEmbedding._model = FlagModel(
                            os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)),
                            query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
                            use_fp16=torch.cuda.is_available(),
                        )
                        DefaultEmbedding._model_name = model_name
                    except Exception:
                        model_dir = snapshot_download(
                            repo_id="BAAI/bge-large-zh-v1.5", local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), local_dir_use_symlinks=False
                        )
                        DefaultEmbedding._model = FlagModel(model_dir, query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:", use_fp16=torch.cuda.is_available())
        self._model = DefaultEmbedding._model
        self._model_name = DefaultEmbedding._model_name

@@ -114,8 +115,9 @@ class DefaultEmbedding(Base):
class OpenAIEmbed(Base):
    _FACTORY_NAME = "OpenAI"

    def __init__(self, key, model_name="text-embedding-ada-002", base_url="https://api.openai.com/v1"):
        if not base_url:
            base_url = "https://api.openai.com/v1"
        self.client = OpenAI(api_key=key, base_url=base_url)

@@ -128,8 +130,7 @@ class OpenAIEmbed(Base):
        ress = []
        total_tokens = 0
        for i in range(0, len(texts), batch_size):
            res = self.client.embeddings.create(input=texts[i : i + batch_size], model=self.model_name)
            try:
                ress.extend([d.embedding for d in res.data])
                total_tokens += self.total_token_count(res)

@@ -138,12 +139,13 @@ class OpenAIEmbed(Base):
        return np.array(ress), total_tokens

    def encode_queries(self, text):
        res = self.client.embeddings.create(input=[truncate(text, 8191)], model=self.model_name)
        return np.array(res.data[0].embedding), self.total_token_count(res)


class LocalAIEmbed(Base):
    _FACTORY_NAME = "LocalAI"

    def __init__(self, key, model_name, base_url):
        if not base_url:
            raise ValueError("Local embedding model url cannot be None")
@@ -169,41 +171,42 @@ class LocalAIEmbed(Base):
class AzureEmbed(OpenAIEmbed):
    _FACTORY_NAME = "Azure-OpenAI"

    def __init__(self, key, model_name, **kwargs):
        from openai.lib.azure import AzureOpenAI

        api_key = json.loads(key).get("api_key", "")
        api_version = json.loads(key).get("api_version", "2024-02-01")
        self.client = AzureOpenAI(api_key=api_key, azure_endpoint=kwargs["base_url"], api_version=api_version)
        self.model_name = model_name


class BaiChuanEmbed(OpenAIEmbed):
    _FACTORY_NAME = "BaiChuan"

    def __init__(self, key, model_name="Baichuan-Text-Embedding", base_url="https://api.baichuan-ai.com/v1"):
        if not base_url:
            base_url = "https://api.baichuan-ai.com/v1"
        super().__init__(key, model_name, base_url)


class QWenEmbed(Base):
    _FACTORY_NAME = "Tongyi-Qianwen"

    def __init__(self, key, model_name="text_embedding_v2", **kwargs):
        self.key = key
        self.model_name = model_name

    def encode(self, texts: list):
        import dashscope

        batch_size = 4
        res = []
        token_count = 0
        texts = [truncate(t, 2048) for t in texts]
        for i in range(0, len(texts), batch_size):
            resp = dashscope.TextEmbedding.call(model=self.model_name, input=texts[i : i + batch_size], api_key=self.key, text_type="document")
            try:
                embds = [[] for _ in range(len(resp["output"]["embeddings"]))]
                for e in resp["output"]["embeddings"]:

@@ -216,20 +219,16 @@ class QWenEmbed(Base):
        return np.array(res), token_count

    def encode_queries(self, text):
        resp = dashscope.TextEmbedding.call(model=self.model_name, input=text[:2048], api_key=self.key, text_type="query")
        try:
            return np.array(resp["output"]["embeddings"][0]["embedding"]), self.total_token_count(resp)
        except Exception as _e:
            log_exception(_e, resp)


class ZhipuEmbed(Base):
    _FACTORY_NAME = "ZHIPU-AI"

    def __init__(self, key, model_name="embedding-2", **kwargs):
        self.client = ZhipuAI(api_key=key)
        self.model_name = model_name

@@ -246,8 +245,7 @@ class ZhipuEmbed(Base):
        texts = [truncate(t, MAX_LEN) for t in texts]
        for txt in texts:
            res = self.client.embeddings.create(input=txt, model=self.model_name)
            try:
                arr.append(res.data[0].embedding)
                tks_num += self.total_token_count(res)

@@ -256,8 +254,7 @@ class ZhipuEmbed(Base):
        return np.array(arr), tks_num

    def encode_queries(self, text):
        res = self.client.embeddings.create(input=text, model=self.model_name)
        try:
            return np.array(res.data[0].embedding), self.total_token_count(res)
        except Exception as _e:
@@ -265,18 +262,17 @@ class ZhipuEmbed(Base):
class OllamaEmbed(Base):
    _FACTORY_NAME = "Ollama"

    def __init__(self, key, model_name, **kwargs):
        self.client = Client(host=kwargs["base_url"]) if not key or key == "x" else Client(host=kwargs["base_url"], headers={"Authorization": f"Bear {key}"})
        self.model_name = model_name

    def encode(self, texts: list):
        arr = []
        tks_num = 0
        for txt in texts:
            res = self.client.embeddings(prompt=txt, model=self.model_name, options={"use_mmap": True})
            try:
                arr.append(res["embedding"])
            except Exception as _e:

@@ -285,9 +281,7 @@ class OllamaEmbed(Base):
        return np.array(arr), tks_num

    def encode_queries(self, text):
        res = self.client.embeddings(prompt=text, model=self.model_name, options={"use_mmap": True})
        try:
            return np.array(res["embedding"]), 128
        except Exception as _e:

@@ -295,6 +289,7 @@ class OllamaEmbed(Base):
class FastEmbed(DefaultEmbedding):
    _FACTORY_NAME = "FastEmbed"

    def __init__(
        self,

@@ -307,15 +302,15 @@ class FastEmbed(DefaultEmbedding):
        if not settings.LIGHTEN:
            with FastEmbed._model_lock:
                from fastembed import TextEmbedding

                if not DefaultEmbedding._model or model_name != DefaultEmbedding._model_name:
                    try:
                        DefaultEmbedding._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
                        DefaultEmbedding._model_name = model_name
                    except Exception:
                        cache_dir = snapshot_download(
                            repo_id="BAAI/bge-small-en-v1.5", local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), local_dir_use_symlinks=False
                        )
                        DefaultEmbedding._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
        self._model = DefaultEmbedding._model
        self._model_name = model_name

@@ -340,6 +335,8 @@ class FastEmbed(DefaultEmbedding):
class XinferenceEmbed(Base):
    _FACTORY_NAME = "Xinference"

    def __init__(self, key, model_name="", base_url=""):
        base_url = urljoin(base_url, "v1")
        self.client = OpenAI(api_key=key, base_url=base_url)

@@ -359,8 +356,7 @@ class XinferenceEmbed(Base):
        return np.array(ress), total_tokens

    def encode_queries(self, text):
        res = self.client.embeddings.create(input=[text], model=self.model_name)
        try:
            return np.array(res.data[0].embedding), self.total_token_count(res)
        except Exception as _e:

@@ -368,20 +364,18 @@ class XinferenceEmbed(Base):
class YoudaoEmbed(Base):
    _FACTORY_NAME = "Youdao"
    _client = None

    def __init__(self, key=None, model_name="maidalun1020/bce-embedding-base_v1", **kwargs):
        if not settings.LIGHTEN and not YoudaoEmbed._client:
            from BCEmbedding import EmbeddingModel as qanthing

            try:
                logging.info("LOADING BCE...")
                YoudaoEmbed._client = qanthing(model_name_or_path=os.path.join(get_home_cache_dir(), "bce-embedding-base_v1"))
            except Exception:
                YoudaoEmbed._client = qanthing(model_name_or_path=model_name.replace("maidalun1020", "InfiniFlow"))

    def encode(self, texts: list):
        batch_size = 10

@@ -400,14 +394,11 @@ class YoudaoEmbed(Base):
class JinaEmbed(Base):
    _FACTORY_NAME = "Jina"

    def __init__(self, key, model_name="jina-embeddings-v3", base_url="https://api.jina.ai/v1/embeddings"):
        self.base_url = "https://api.jina.ai/v1/embeddings"
        self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
        self.model_name = model_name

    def encode(self, texts: list):

@@ -416,11 +407,7 @@ class JinaEmbed(Base):
        ress = []
        token_count = 0
        for i in range(0, len(texts), batch_size):
            data = {"model": self.model_name, "input": texts[i : i + batch_size], "encoding_type": "float"}
            response = requests.post(self.base_url, headers=self.headers, json=data)
            try:
                res = response.json()

@@ -435,50 +422,12 @@ class JinaEmbed(Base):
        return np.array(embds[0]), cnt

class InfinityEmbed(Base):
    _model = None

    def __init__(
        self,
        model_names: list[str] = ("BAAI/bge-small-en-v1.5",),
        engine_kwargs: dict = {},
        key=None,
    ):
        from infinity_emb import EngineArgs
        from infinity_emb.engine import AsyncEngineArray

        self._default_model = model_names[0]
        self.engine_array = AsyncEngineArray.from_args([EngineArgs(model_name_or_path=model_name, **engine_kwargs) for model_name in model_names])

    async def _embed(self, sentences: list[str], model_name: str = ""):
        if not model_name:
            model_name = self._default_model
        engine = self.engine_array[model_name]
        was_already_running = engine.is_running
        if not was_already_running:
            await engine.astart()
        embeddings, usage = await engine.embed(sentences=sentences)
        if not was_already_running:
            await engine.astop()
        return embeddings, usage

    def encode(self, texts: list[str], model_name: str = "") -> tuple[np.ndarray, int]:
        # Using the internal tokenizer to encode the texts and get the total number of tokens
        embeddings, usage = asyncio.run(self._embed(texts, model_name))
        return np.array(embeddings), usage

    def encode_queries(self, text: str) -> tuple[np.ndarray, int]:
        # Using the internal tokenizer to encode the texts and get the total number of tokens
        return self.encode([text])

class MistralEmbed(Base):
    _FACTORY_NAME = "Mistral"

    def __init__(self, key, model_name="mistral-embed", base_url=None):
        from mistralai.client import MistralClient

        self.client = MistralClient(api_key=key)
        self.model_name = model_name

@@ -488,8 +437,7 @@ class MistralEmbed(Base):
        ress = []
        token_count = 0
        for i in range(0, len(texts), batch_size):
            res = self.client.embeddings(input=texts[i : i + batch_size], model=self.model_name)
            try:
                ress.extend([d.embedding for d in res.data])
                token_count += self.total_token_count(res)

@@ -498,8 +446,7 @@ class MistralEmbed(Base):
        return np.array(ress), token_count

    def encode_queries(self, text):
        res = self.client.embeddings(input=[truncate(text, 8196)], model=self.model_name)
        try:
            return np.array(res.data[0].embedding), self.total_token_count(res)
        except Exception as _e:

@@ -507,30 +454,31 @@ class MistralEmbed(Base):
class BedrockEmbed(Base):
    _FACTORY_NAME = "Bedrock"

    def __init__(self, key, model_name, **kwargs):
        import boto3

        self.bedrock_ak = json.loads(key).get("bedrock_ak", "")
        self.bedrock_sk = json.loads(key).get("bedrock_sk", "")
        self.bedrock_region = json.loads(key).get("bedrock_region", "")
        self.model_name = model_name

        if self.bedrock_ak == "" or self.bedrock_sk == "" or self.bedrock_region == "":
            # Try to create a client using the default credentials (AWS_PROFILE, AWS_DEFAULT_REGION, etc.)
            self.client = boto3.client("bedrock-runtime")
        else:
            self.client = boto3.client(service_name="bedrock-runtime", region_name=self.bedrock_region, aws_access_key_id=self.bedrock_ak, aws_secret_access_key=self.bedrock_sk)

    def encode(self, texts: list):
        texts = [truncate(t, 8196) for t in texts]
        embeddings = []
        token_count = 0
        for text in texts:
            if self.model_name.split(".")[0] == "amazon":
                body = {"inputText": text}
            elif self.model_name.split(".")[0] == "cohere":
                body = {"texts": [text], "input_type": "search_document"}
            response = self.client.invoke_model(modelId=self.model_name, body=json.dumps(body))
            try:

@@ -545,10 +493,10 @@ class BedrockEmbed(Base):
    def encode_queries(self, text):
        embeddings = []
        token_count = num_tokens_from_string(text)
        if self.model_name.split(".")[0] == "amazon":
            body = {"inputText": truncate(text, 8196)}
        elif self.model_name.split(".")[0] == "cohere":
            body = {"texts": [truncate(text, 8196)], "input_type": "search_query"}
        response = self.client.invoke_model(modelId=self.model_name, body=json.dumps(body))
        try:

@@ -561,10 +509,11 @@ class BedrockEmbed(Base):
class GeminiEmbed(Base):
    _FACTORY_NAME = "Gemini"

    def __init__(self, key, model_name="models/text-embedding-004", **kwargs):
        self.key = key
        self.model_name = "models/" + model_name

    def encode(self, texts: list):
        texts = [truncate(t, 2048) for t in texts]

@@ -573,35 +522,27 @@ class GeminiEmbed(Base):
        batch_size = 16
        ress = []
        for i in range(0, len(texts), batch_size):
            result = genai.embed_content(model=self.model_name, content=texts[i : i + batch_size], task_type="retrieval_document", title="Embedding of single string")
            try:
                ress.extend(result["embedding"])
            except Exception as _e:
                log_exception(_e, result)
        return np.array(ress), token_count

    def encode_queries(self, text):
        genai.configure(api_key=self.key)
        result = genai.embed_content(model=self.model_name, content=truncate(text, 2048), task_type="retrieval_document", title="Embedding of single string")
        token_count = num_tokens_from_string(text)
        try:
            return np.array(result["embedding"]), token_count
        except Exception as _e:
            log_exception(_e, result)


class NvidiaEmbed(Base):
    _FACTORY_NAME = "NVIDIA"

    def __init__(self, key, model_name, base_url="https://integrate.api.nvidia.com/v1/embeddings"):
        if not base_url:
            base_url = "https://integrate.api.nvidia.com/v1/embeddings"
        self.api_key = key
@@ -645,6 +586,8 @@ class NvidiaEmbed(Base):
class LmStudioEmbed(LocalAIEmbed):
    _FACTORY_NAME = "LM-Studio"

    def __init__(self, key, model_name, base_url):
        if not base_url:
            raise ValueError("Local llm url cannot be None")

@@ -654,6 +597,8 @@ class LmStudioEmbed(LocalAIEmbed):
class OpenAI_APIEmbed(OpenAIEmbed):
    _FACTORY_NAME = ["VLLM", "OpenAI-API-Compatible"]

    def __init__(self, key, model_name, base_url):
        if not base_url:
            raise ValueError("url cannot be None")

@@ -663,6 +608,8 @@ class OpenAI_APIEmbed(OpenAIEmbed):
class CoHereEmbed(Base):
    _FACTORY_NAME = "Cohere"

    def __init__(self, key, model_name, base_url=None):
        from cohere import Client

@@ -701,6 +648,8 @@ class CoHereEmbed(Base):
class TogetherAIEmbed(OpenAIEmbed):
    _FACTORY_NAME = "TogetherAI"

    def __init__(self, key, model_name, base_url="https://api.together.xyz/v1"):
        if not base_url:
            base_url = "https://api.together.xyz/v1"

@@ -708,6 +657,8 @@ class TogetherAIEmbed(OpenAIEmbed):
class PerfXCloudEmbed(OpenAIEmbed):
    _FACTORY_NAME = "PerfXCloud"

    def __init__(self, key, model_name, base_url="https://cloud.perfxlab.cn/v1"):
        if not base_url:
            base_url = "https://cloud.perfxlab.cn/v1"

@@ -715,6 +666,8 @@ class PerfXCloudEmbed(OpenAIEmbed):
class UpstageEmbed(OpenAIEmbed):
    _FACTORY_NAME = "Upstage"

    def __init__(self, key, model_name, base_url="https://api.upstage.ai/v1/solar"):
        if not base_url:
            base_url = "https://api.upstage.ai/v1/solar"

@@ -722,6 +675,8 @@ class UpstageEmbed(OpenAIEmbed):
class SILICONFLOWEmbed(Base):
    _FACTORY_NAME = "SILICONFLOW"

    def __init__(self, key, model_name, base_url="https://api.siliconflow.cn/v1/embeddings"):
        if not base_url:
            base_url = "https://api.siliconflow.cn/v1/embeddings"

@@ -769,6 +724,8 @@ class SILICONFLOWEmbed(Base):
class ReplicateEmbed(Base):
    _FACTORY_NAME = "Replicate"

    def __init__(self, key, model_name, base_url=None):
        from replicate.client import Client

@@ -790,6 +747,8 @@ class ReplicateEmbed(Base):
class BaiduYiyanEmbed(Base):
    _FACTORY_NAME = "BaiduYiyan"

    def __init__(self, key, model_name, base_url=None):
        import qianfan

@@ -821,6 +780,8 @@ class BaiduYiyanEmbed(Base):
class VoyageEmbed(Base):
    _FACTORY_NAME = "Voyage AI"

    def __init__(self, key, model_name, base_url=None):
        import voyageai

@@ -832,9 +793,7 @@ class VoyageEmbed(Base):
        ress = []
        token_count = 0
        for i in range(0, len(texts), batch_size):
            res = self.client.embed(texts=texts[i : i + batch_size], model=self.model_name, input_type="document")
            try:
                ress.extend(res.embeddings)
                token_count += res.total_tokens

@@ -843,9 +802,7 @@ class VoyageEmbed(Base):
        return np.array(ress), token_count

    def encode_queries(self, text):
        res = self.client.embed(texts=text, model=self.model_name, input_type="query")
        try:
            return np.array(res.embeddings)[0], res.total_tokens
        except Exception as _e:

@@ -853,6 +810,8 @@ class VoyageEmbed(Base):
class HuggingFaceEmbed(Base):
    _FACTORY_NAME = "HuggingFace"

    def __init__(self, key, model_name, base_url=None):
        if not model_name:
            raise ValueError("Model name cannot be None")

@@ -863,11 +822,7 @@ class HuggingFaceEmbed(Base):
    def encode(self, texts: list):
        embeddings = []
        for text in texts:
            response = requests.post(f"{self.base_url}/embed", json={"inputs": text}, headers={"Content-Type": "application/json"})
            if response.status_code == 200:
                embedding = response.json()
                embeddings.append(embedding[0])

@@ -876,11 +831,7 @@ class HuggingFaceEmbed(Base):
        return np.array(embeddings), sum([num_tokens_from_string(text) for text in texts])

    def encode_queries(self, text):
        response = requests.post(f"{self.base_url}/embed", json={"inputs": text}, headers={"Content-Type": "application/json"})
        if response.status_code == 200:
            embedding = response.json()
            return np.array(embedding[0]), num_tokens_from_string(text)

@@ -889,15 +840,19 @@ class HuggingFaceEmbed(Base):
class VolcEngineEmbed(OpenAIEmbed):
    _FACTORY_NAME = "VolcEngine"

    def __init__(self, key, model_name, base_url="https://ark.cn-beijing.volces.com/api/v3"):
        if not base_url:
            base_url = "https://ark.cn-beijing.volces.com/api/v3"
        ark_api_key = json.loads(key).get("ark_api_key", "")
        model_name = json.loads(key).get("ep_id", "") + json.loads(key).get("endpoint_id", "")
        super().__init__(ark_api_key, model_name, base_url)


class GPUStackEmbed(OpenAIEmbed):
    _FACTORY_NAME = "GPUStack"

    def __init__(self, key, model_name, base_url):
        if not base_url:
            raise ValueError("url cannot be None")

@@ -908,6 +863,8 @@ class GPUStackEmbed(OpenAIEmbed):
class NovitaEmbed(SILICONFLOWEmbed):
    _FACTORY_NAME = "NovitaAI"

    def __init__(self, key, model_name, base_url="https://api.novita.ai/v3/openai/embeddings"):
        if not base_url:
            base_url = "https://api.novita.ai/v3/openai/embeddings"

@@ -915,6 +872,8 @@ class NovitaEmbed(SILICONFLOWEmbed):
class GiteeEmbed(SILICONFLOWEmbed):
    _FACTORY_NAME = "GiteeAI"

    def __init__(self, key, model_name, base_url="https://ai.gitee.com/v1/embeddings"):
        if not base_url:
            base_url = "https://ai.gitee.com/v1/embeddings"

View File

@@ -13,24 +13,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import os
import re
import threading
from abc import ABC
from collections.abc import Iterable
from urllib.parse import urljoin

import httpx
import numpy as np
import requests
from huggingface_hub import snapshot_download
from yarl import URL

from api import settings
from api.utils.file_utils import get_home_cache_dir
from api.utils.log_utils import log_exception
from rag.utils import num_tokens_from_string, truncate


def sigmoid(x):

@@ -57,6 +57,7 @@ class Base(ABC):
class DefaultRerank(Base):
    _FACTORY_NAME = "BAAI"
    _model = None
    _model_lock = threading.Lock()
@@ -75,17 +76,13 @@ class DefaultRerank(Base):
        if not settings.LIGHTEN and not DefaultRerank._model:
            import torch
            from FlagEmbedding import FlagReranker

            with DefaultRerank._model_lock:
                if not DefaultRerank._model:
                    try:
                        DefaultRerank._model = FlagReranker(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), use_fp16=torch.cuda.is_available())
                    except Exception:
                        model_dir = snapshot_download(repo_id=model_name, local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), local_dir_use_symlinks=False)
                        DefaultRerank._model = FlagReranker(model_dir, use_fp16=torch.cuda.is_available())
        self._model = DefaultRerank._model
        self._dynamic_batch_size = 8

@@ -94,6 +91,7 @@ class DefaultRerank(Base):
    def torch_empty_cache(self):
        try:
            import torch

            torch.cuda.empty_cache()
        except Exception as e:
            print(f"Error emptying cache: {e}")

@@ -152,23 +150,16 @@ class DefaultRerank(Base):
class JinaRerank(Base):
    _FACTORY_NAME = "Jina"

    def __init__(self, key, model_name="jina-reranker-v2-base-multilingual", base_url="https://api.jina.ai/v1/rerank"):
        self.base_url = "https://api.jina.ai/v1/rerank"
        self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
        self.model_name = model_name

    def similarity(self, query: str, texts: list):
        texts = [truncate(t, 8196) for t in texts]
        data = {"model": self.model_name, "query": query, "documents": texts, "top_n": len(texts)}
        res = requests.post(self.base_url, headers=self.headers, json=data).json()
        rank = np.zeros(len(texts), dtype=float)
        try:

@@ -180,22 +171,20 @@ class JinaRerank(Base):
class YoudaoRerank(DefaultRerank):
    _FACTORY_NAME = "Youdao"
    _model = None
    _model_lock = threading.Lock()

    def __init__(self, key=None, model_name="maidalun1020/bce-reranker-base_v1", **kwargs):
        if not settings.LIGHTEN and not YoudaoRerank._model:
            from BCEmbedding import RerankerModel

            with YoudaoRerank._model_lock:
                if not YoudaoRerank._model:
                    try:
                        YoudaoRerank._model = RerankerModel(model_name_or_path=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)))
                    except Exception:
                        YoudaoRerank._model = RerankerModel(model_name_or_path=model_name.replace("maidalun1020", "InfiniFlow"))

        self._model = YoudaoRerank._model
        self._dynamic_batch_size = 8

@@ -212,6 +201,8 @@ class YoudaoRerank(DefaultRerank):
class XInferenceRerank(Base):
    _FACTORY_NAME = "Xinference"

    def __init__(self, key="x", model_name="", base_url=""):
        if base_url.find("/v1") == -1:
            base_url = urljoin(base_url, "/v1/rerank")

@@ -219,10 +210,7 @@ class XInferenceRerank(Base):
            base_url = urljoin(base_url, "/v1/rerank")
        self.model_name = model_name
        self.base_url = base_url
        self.headers = {"Content-Type": "application/json", "accept": "application/json"}
        if key and key != "x":
            self.headers["Authorization"] = f"Bearer {key}"

@@ -233,13 +221,7 @@ class XInferenceRerank(Base):
        token_count = 0
        for _, t in pairs:
            token_count += num_tokens_from_string(t)
        data = {"model": self.model_name, "query": query, "return_documents": "true", "return_len": "true", "documents": texts}
        res = requests.post(self.base_url, headers=self.headers, json=data).json()
        rank = np.zeros(len(texts), dtype=float)
        try:
@@ -251,15 +233,14 @@ class XInferenceRerank(Base):
class LocalAIRerank(Base):
    _FACTORY_NAME = "LocalAI"

    def __init__(self, key, model_name, base_url):
        if base_url.find("/rerank") == -1:
            self.base_url = urljoin(base_url, "/rerank")
        else:
            self.base_url = base_url
        self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
        self.model_name = model_name.split("___")[0]

    def similarity(self, query: str, texts: list):

@@ -296,16 +277,15 @@ class LocalAIRerank(Base):
class NvidiaRerank(Base):
    _FACTORY_NAME = "NVIDIA"

    def __init__(self, key, model_name, base_url="https://ai.api.nvidia.com/v1/retrieval/nvidia/"):
        if not base_url:
            base_url = "https://ai.api.nvidia.com/v1/retrieval/nvidia/"
        self.model_name = model_name

        if self.model_name == "nvidia/nv-rerankqa-mistral-4b-v3":
            self.base_url = urljoin(base_url, "nv-rerankqa-mistral-4b-v3/reranking")

        if self.model_name == "nvidia/rerank-qa-mistral-4b":
            self.base_url = urljoin(base_url, "reranking")

@@ -318,9 +298,7 @@ class NvidiaRerank(Base):
        }

    def similarity(self, query: str, texts: list):
        token_count = num_tokens_from_string(query) + sum([num_tokens_from_string(t) for t in texts])
        data = {
            "model": self.model_name,
            "query": {"text": query},

@@ -339,6 +317,8 @@ class NvidiaRerank(Base):
class LmStudioRerank(Base):
    _FACTORY_NAME = "LM-Studio"

    def __init__(self, key, model_name, base_url):
        pass

@@ -347,15 +327,14 @@ class LmStudioRerank(Base):
class OpenAI_APIRerank(Base):
    _FACTORY_NAME = "OpenAI-API-Compatible"

    def __init__(self, key, model_name, base_url):
        if base_url.find("/rerank") == -1:
            self.base_url = urljoin(base_url, "/rerank")
        else:
            self.base_url = base_url
        self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
        self.model_name = model_name.split("___")[0]

    def similarity(self, query: str, texts: list):

@@ -392,6 +371,8 @@ class OpenAI_APIRerank(Base):
class CoHereRerank(Base):
    _FACTORY_NAME = ["Cohere", "VLLM"]

    def __init__(self, key, model_name, base_url=None):
        from cohere import Client

@@ -399,9 +380,7 @@ class CoHereRerank(Base):
        self.model_name = model_name.split("___")[0]

    def similarity(self, query: str, texts: list):
        token_count = num_tokens_from_string(query) + sum([num_tokens_from_string(t) for t in texts])
        res = self.client.rerank(
            model=self.model_name,
            query=query,

@@ -419,6 +398,8 @@ class CoHereRerank(Base):
class TogetherAIRerank(Base):
    _FACTORY_NAME = "TogetherAI"

    def __init__(self, key, model_name, base_url):
        pass

@@ -427,9 +408,9 @@ class TogetherAIRerank(Base):
class SILICONFLOWRerank(Base):
    _FACTORY_NAME = "SILICONFLOW"

    def __init__(self, key, model_name, base_url="https://api.siliconflow.cn/v1/rerank"):
        if not base_url:
            base_url = "https://api.siliconflow.cn/v1/rerank"
        self.model_name = model_name

@@ -450,9 +431,7 @@ class SILICONFLOWRerank(Base):
            "max_chunks_per_doc": 1024,
            "overlap_tokens": 80,
        }
        response = requests.post(self.base_url, json=payload, headers=self.headers).json()
        rank = np.zeros(len(texts), dtype=float)
        try:
            for d in response["results"]:

@@ -466,6 +445,8 @@ class SILICONFLOWRerank(Base):
class BaiduYiyanRerank(Base):
    _FACTORY_NAME = "BaiduYiyan"

    def __init__(self, key, model_name, base_url=None):
        from qianfan.resources import Reranker
@@ -492,6 +473,8 @@ class BaiduYiyanRerank(Base):
class VoyageRerank(Base):
    _FACTORY_NAME = "Voyage AI"

    def __init__(self, key, model_name, base_url=None):
        import voyageai

@@ -502,9 +485,7 @@ class VoyageRerank(Base):
        rank = np.zeros(len(texts), dtype=float)
        if not texts:
            return rank, 0
        res = self.client.rerank(query=query, documents=texts, model=self.model_name, top_k=len(texts))
        try:
            for r in res.results:
                rank[r.index] = r.relevance_score

@@ -514,22 +495,20 @@ class VoyageRerank(Base):
class QWenRerank(Base):
    _FACTORY_NAME = "Tongyi-Qianwen"

    def __init__(self, key, model_name="gte-rerank", base_url=None, **kwargs):
        import dashscope

        self.api_key = key
        self.model_name = dashscope.TextReRank.Models.gte_rerank if model_name is None else model_name

    def similarity(self, query: str, texts: list):
        from http import HTTPStatus

        import dashscope

        resp = dashscope.TextReRank.call(api_key=self.api_key, model=self.model_name, query=query, documents=texts, top_n=len(texts), return_documents=False)
        rank = np.zeros(len(texts), dtype=float)
        if resp.status_code == HTTPStatus.OK:
            try:

@@ -543,6 +522,8 @@ class QWenRerank(Base):
class HuggingfaceRerank(DefaultRerank):
    _FACTORY_NAME = "HuggingFace"

    @staticmethod
    def post(query: str, texts: list, url="127.0.0.1"):
        exc = None

@@ -550,9 +531,9 @@ class HuggingfaceRerank(DefaultRerank):
        batch_size = 8
        for i in range(0, len(texts), batch_size):
            try:
                res = requests.post(
                    f"http://{url}/rerank", headers={"Content-Type": "application/json"}, json={"query": query, "texts": texts[i : i + batch_size], "raw_scores": False, "truncate": True}
                )
                for o in res.json():
                    scores[o["index"] + i] = o["score"]

@@ -577,9 +558,9 @@ class HuggingfaceRerank(DefaultRerank):
class GPUStackRerank(Base):
    _FACTORY_NAME = "GPUStack"

    def __init__(self, key, model_name, base_url):
        if not base_url:
            raise ValueError("url cannot be None")

@@ -600,9 +581,7 @@ class GPUStackRerank(Base):
        }

        try:
            response = requests.post(self.base_url, json=payload, headers=self.headers)
            response.raise_for_status()
            response_json = response.json()

@@ -623,11 +602,12 @@ class GPUStackRerank(Base):
            )

        except httpx.HTTPStatusError as e:
            raise ValueError(f"Error calling GPUStackRerank model {self.model_name}: {e.response.status_code} - {e.response.text}")


class NovitaRerank(JinaRerank):
    _FACTORY_NAME = "NovitaAI"

    def __init__(self, key, model_name, base_url="https://api.novita.ai/v3/openai/rerank"):
        if not base_url:
            base_url = "https://api.novita.ai/v3/openai/rerank"

@@ -635,6 +615,8 @@ class NovitaRerank(JinaRerank):
class GiteeRerank(JinaRerank):
    _FACTORY_NAME = "GiteeAI"

    def __init__(self, key, model_name, base_url="https://ai.gitee.com/v1/rerank"):
        if not base_url:
            base_url = "https://ai.gitee.com/v1/rerank"

View File

@ -13,16 +13,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import base64
import io
import json
import os
import re
from abc import ABC

import requests
from openai import OpenAI
from openai.lib.azure import AzureOpenAI

from rag.utils import num_tokens_from_string
class Base(ABC):
@@ -30,11 +32,7 @@ class Base(ABC):
        pass

    def transcription(self, audio, **kwargs):
        transcription = self.client.audio.transcriptions.create(model=self.model_name, file=audio, response_format="text")
        return transcription.text.strip(), num_tokens_from_string(transcription.text.strip())
    def audio2base64(self, audio):
@@ -46,6 +44,8 @@ class Base(ABC):
class GPTSeq2txt(Base):
    _FACTORY_NAME = "OpenAI"

    def __init__(self, key, model_name="whisper-1", base_url="https://api.openai.com/v1"):
        if not base_url:
            base_url = "https://api.openai.com/v1"
@@ -54,31 +54,34 @@ class GPTSeq2txt(Base):
class QWenSeq2txt(Base):
    _FACTORY_NAME = "Tongyi-Qianwen"

    def __init__(self, key, model_name="paraformer-realtime-8k-v1", **kwargs):
        import dashscope

        dashscope.api_key = key
        self.model_name = model_name

    def transcription(self, audio, format):
        from http import HTTPStatus

        from dashscope.audio.asr import Recognition

        recognition = Recognition(model=self.model_name, format=format, sample_rate=16000, callback=None)
        result = recognition.call(audio)

        ans = ""
        if result.status_code == HTTPStatus.OK:
            for sentence in result.get_sentence():
                ans += sentence.text.decode("utf-8") + "\n"
            return ans, num_tokens_from_string(ans)
        return "**ERROR**: " + result.message, 0
class AzureSeq2txt(Base):
    _FACTORY_NAME = "Azure-OpenAI"

    def __init__(self, key, model_name, lang="Chinese", **kwargs):
        self.client = AzureOpenAI(api_key=key, azure_endpoint=kwargs["base_url"], api_version="2024-02-01")
        self.model_name = model_name
@@ -86,43 +89,33 @@ class AzureSeq2txt(Base):
class XinferenceSeq2txt(Base):
    _FACTORY_NAME = "Xinference"

    def __init__(self, key, model_name="whisper-small", **kwargs):
        self.base_url = kwargs.get("base_url", None)
        self.model_name = model_name
        self.key = key

    def transcription(self, audio, language="zh", prompt=None, response_format="json", temperature=0.7):
        if isinstance(audio, str):
            audio_file = open(audio, "rb")
            audio_data = audio_file.read()
            audio_file_name = audio.split("/")[-1]
        else:
            audio_data = audio
            audio_file_name = "audio.wav"

        payload = {"model": self.model_name, "language": language, "prompt": prompt, "response_format": response_format, "temperature": temperature}
        files = {"file": (audio_file_name, audio_data, "audio/wav")}

        try:
            response = requests.post(f"{self.base_url}/v1/audio/transcriptions", files=files, data=payload)
            response.raise_for_status()
            result = response.json()

            if "text" in result:
                transcription_text = result["text"].strip()
                return transcription_text, num_tokens_from_string(transcription_text)
            else:
                return "**ERROR**: Failed to retrieve transcription.", 0
@@ -132,11 +125,11 @@ class XinferenceSeq2txt(Base):
class TencentCloudSeq2txt(Base):
    _FACTORY_NAME = "Tencent Cloud"

    def __init__(self, key, model_name="16k_zh", base_url="https://asr.tencentcloudapi.com"):
        from tencentcloud.asr.v20190614 import asr_client
        from tencentcloud.common import credential

        key = json.loads(key)
        sid = key.get("tencent_cloud_sid", "")
@@ -146,11 +139,12 @@ class TencentCloudSeq2txt(Base):
        self.model_name = model_name
    def transcription(self, audio, max_retries=60, retry_interval=5):
        import time

        from tencentcloud.asr.v20190614 import models
        from tencentcloud.common.exception.tencent_cloud_sdk_exception import (
            TencentCloudSDKException,
        )

        b64 = self.audio2base64(audio)
        try:
@@ -174,9 +168,7 @@ class TencentCloudSeq2txt(Base):
            while retries < max_retries:
                resp = self.client.DescribeTaskStatus(req)
                if resp.Data.StatusStr == "success":
                    text = re.sub(r"\[\d+:\d+\.\d+,\d+:\d+\.\d+\]\s*", "", resp.Data.Result).strip()
                    return text, num_tokens_from_string(text)
                elif resp.Data.StatusStr == "failed":
                    return (
@@ -195,6 +187,8 @@ class TencentCloudSeq2txt(Base):
class GPUStackSeq2txt(Base):
    _FACTORY_NAME = "GPUStack"

    def __init__(self, key, model_name, base_url):
        if not base_url:
            raise ValueError("url cannot be None")
@@ -206,8 +200,11 @@ class GPUStackSeq2txt(Base):
class GiteeSeq2txt(Base):
    _FACTORY_NAME = "GiteeAI"

    def __init__(self, key, model_name="whisper-1", base_url="https://ai.gitee.com/v1/"):
        if not base_url:
            base_url = "https://ai.gitee.com/v1/"
        self.client = OpenAI(api_key=key, base_url=base_url)
        self.model_name = model_name
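With every speech-to-text provider tagged the same way, a caller can resolve a class by factory name instead of through a hand-written dispatch chain. A hedged usage sketch follows; `SEQ2TXT_REGISTRY` and the `transcribe` helper reuse the hypothetical `build_registry` from the sketch above and are not the project's actual wiring.

```python
# Hypothetical lookup over a registry built as sketched earlier.
SEQ2TXT_REGISTRY: dict[str, type] = build_registry("sequence2txt_model")


def transcribe(factory: str, key: str, model_name: str, audio, **kwargs):
    cls = SEQ2TXT_REGISTRY.get(factory)
    if cls is None:
        raise ValueError(f"unknown seq2txt factory: {factory}")
    model = cls(key, model_name, **kwargs)  # constructor signatures vary slightly per provider
    return model.transcription(audio)  # -> (text, token_count)
```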

View File: tts_model.py

@@ -70,10 +70,12 @@ class Base(ABC):
        pass

    def normalize_text(self, text):
        return re.sub(r"(\*\*|##\d+\$\$|#)", "", text)
class FishAudioTTS(Base):
    _FACTORY_NAME = "Fish Audio"

    def __init__(self, key, model_name, base_url="https://api.fish.audio/v1/tts"):
        if not base_url:
            base_url = "https://api.fish.audio/v1/tts"
@@ -96,9 +98,7 @@ class FishAudioTTS(Base):
            with client.stream(
                method="POST",
                url=self.base_url,
                content=ormsgpack.packb(request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC),
                headers=self.headers,
                timeout=None,
            ) as response:
@@ -115,6 +115,8 @@ class FishAudioTTS(Base):
class QwenTTS(Base):
    _FACTORY_NAME = "Tongyi-Qianwen"

    def __init__(self, key, model_name, base_url=""):
        import dashscope
@@ -122,10 +124,11 @@ class QwenTTS(Base):
        dashscope.api_key = key

    def tts(self, text):
        from collections import deque

        from dashscope.api_entities.dashscope_response import SpeechSynthesisResponse
        from dashscope.audio.tts import ResultCallback, SpeechSynthesisResult, SpeechSynthesizer

        class Callback(ResultCallback):
            def __init__(self) -> None:
                self.dque = deque()
@@ -159,10 +162,7 @@ class QwenTTS(Base):
        text = self.normalize_text(text)
        callback = Callback()
        SpeechSynthesizer.call(model=self.model_name, text=text, callback=callback, format="mp3")
        try:
            for data in callback._run():
                yield data
@@ -173,24 +173,19 @@ class QwenTTS(Base):
class OpenAITTS(Base):
    _FACTORY_NAME = "OpenAI"

    def __init__(self, key, model_name="tts-1", base_url="https://api.openai.com/v1"):
        if not base_url:
            base_url = "https://api.openai.com/v1"
        self.api_key = key
        self.model_name = model_name
        self.base_url = base_url
        self.headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
def tts(self, text, voice="alloy"): def tts(self, text, voice="alloy"):
text = self.normalize_text(text) text = self.normalize_text(text)
payload = { payload = {"model": self.model_name, "voice": voice, "input": text}
"model": self.model_name,
"voice": voice,
"input": text
}
response = requests.post(f"{self.base_url}/audio/speech", headers=self.headers, json=payload, stream=True) response = requests.post(f"{self.base_url}/audio/speech", headers=self.headers, json=payload, stream=True)
@@ -201,7 +196,8 @@ class OpenAITTS(Base):
            yield chunk

class SparkTTS(Base):
    _FACTORY_NAME = "XunFei Spark"

    STATUS_FIRST_FRAME = 0
    STATUS_CONTINUE_FRAME = 1
    STATUS_LAST_FRAME = 2
@@ -219,29 +215,23 @@ class SparkTTS:
    # generate the websocket request URL
    def create_url(self):
        url = "wss://tts-api.xfyun.cn/v2/tts"
        now = datetime.now()
        date = format_date_time(mktime(now.timetuple()))

        signature_origin = "host: " + "ws-api.xfyun.cn" + "\n"
        signature_origin += "date: " + date + "\n"
        signature_origin += "GET " + "/v2/tts " + "HTTP/1.1"

        signature_sha = hmac.new(self.APISecret.encode("utf-8"), signature_origin.encode("utf-8"), digestmod=hashlib.sha256).digest()
        signature_sha = base64.b64encode(signature_sha).decode(encoding="utf-8")

        authorization_origin = 'api_key="%s", algorithm="%s", headers="%s", signature="%s"' % (self.APIKey, "hmac-sha256", "host date request-line", signature_sha)
        authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode(encoding="utf-8")

        v = {"authorization": authorization, "date": date, "host": "ws-api.xfyun.cn"}
        url = url + "?" + urlencode(v)
        return url
    def tts(self, text):
        BusinessArgs = {"aue": "lame", "sfl": 1, "auf": "audio/L16;rate=16000", "vcn": self.model_name, "tte": "utf8"}
        Data = {"status": 2, "text": base64.b64encode(text.encode("utf-8")).decode("utf-8")}
        CommonArgs = {"app_id": self.APPID}
        audio_queue = self.audio_queue
        model_name = self.model_name
@@ -273,9 +263,7 @@ class SparkTTS:
        def on_open(self, ws):
            def run(*args):
                d = {"common": CommonArgs, "business": BusinessArgs, "data": Data}
                ws.send(json.dumps(d))

            thread.start_new_thread(run, ())
@@ -283,44 +271,32 @@ class SparkTTS:
        wsUrl = self.create_url()
        websocket.enableTrace(False)
        a = Callback()
        ws = websocket.WebSocketApp(wsUrl, on_open=a.on_open, on_error=a.on_error, on_close=a.on_close, on_message=a.on_message)
        status_code = 0
        ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})

        while True:
            audio_chunk = self.audio_queue.get()
            if audio_chunk is None:
                if status_code == 0:
                    raise Exception(f"Fail to access model({model_name}) using the provided credentials. **ERROR**: Invalid APPID, API Secret, or API Key.")
                else:
                    break
            status_code = 1
            yield audio_chunk
class XinferenceTTS(Base):
    _FACTORY_NAME = "Xinference"

    def __init__(self, key, model_name, **kwargs):
        self.base_url = kwargs.get("base_url", None)
        self.model_name = model_name
        self.headers = {"accept": "application/json", "Content-Type": "application/json"}
def tts(self, text, voice="中文女", stream=True): def tts(self, text, voice="中文女", stream=True):
payload = { payload = {"model": self.model_name, "input": text, "voice": voice}
"model": self.model_name,
"input": text,
"voice": voice
}
response = requests.post( response = requests.post(f"{self.base_url}/v1/audio/speech", headers=self.headers, json=payload, stream=stream)
f"{self.base_url}/v1/audio/speech",
headers=self.headers,
json=payload,
stream=stream
)
        if response.status_code != 200:
            raise Exception(f"**Error**: {response.status_code}, {response.text}")
@@ -336,18 +312,12 @@ class OllamaTTS(Base):
base_url = "https://api.ollama.ai/v1" base_url = "https://api.ollama.ai/v1"
self.model_name = model_name self.model_name = model_name
self.base_url = base_url self.base_url = base_url
self.headers = { self.headers = {"Content-Type": "application/json"}
"Content-Type": "application/json"
}
if key and key != "x": if key and key != "x":
self.headers["Authorization"] = f"Bear {key}" self.headers["Authorization"] = f"Bear {key}"
def tts(self, text, voice="standard-voice"): def tts(self, text, voice="standard-voice"):
payload = { payload = {"model": self.model_name, "voice": voice, "input": text}
"model": self.model_name,
"voice": voice,
"input": text
}
response = requests.post(f"{self.base_url}/audio/tts", headers=self.headers, json=payload, stream=True) response = requests.post(f"{self.base_url}/audio/tts", headers=self.headers, json=payload, stream=True)
@@ -359,30 +329,19 @@ class OllamaTTS(Base):
            yield chunk

class GPUStackTTS(Base):
    _FACTORY_NAME = "GPUStack"

    def __init__(self, key, model_name, **kwargs):
        self.base_url = kwargs.get("base_url", None)
        self.api_key = key
        self.model_name = model_name
        self.headers = {"accept": "application/json", "Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
def tts(self, text, voice="Chinese Female", stream=True): def tts(self, text, voice="Chinese Female", stream=True):
payload = { payload = {"model": self.model_name, "input": text, "voice": voice}
"model": self.model_name,
"input": text,
"voice": voice
}
response = requests.post( response = requests.post(f"{self.base_url}/v1/audio/speech", headers=self.headers, json=payload, stream=stream)
f"{self.base_url}/v1/audio/speech",
headers=self.headers,
json=payload,
stream=stream
)
        if response.status_code != 200:
            raise Exception(f"**Error**: {response.status_code}, {response.text}")
@@ -393,16 +352,15 @@ class GPUStackTTS(Base):
class SILICONFLOWTTS(Base):
    _FACTORY_NAME = "SILICONFLOW"

    def __init__(self, key, model_name="FunAudioLLM/CosyVoice2-0.5B", base_url="https://api.siliconflow.cn/v1"):
        if not base_url:
            base_url = "https://api.siliconflow.cn/v1"
        self.api_key = key
        self.model_name = model_name
        self.base_url = base_url
        self.headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
def tts(self, text, voice="anna"): def tts(self, text, voice="anna"):
text = self.normalize_text(text) text = self.normalize_text(text)
@@ -414,7 +372,7 @@ class SILICONFLOWTTS(Base):
"sample_rate": 123, "sample_rate": 123,
"stream": True, "stream": True,
"speed": 1, "speed": 1,
"gain": 0 "gain": 0,
} }
response = requests.post(f"{self.base_url}/audio/speech", headers=self.headers, json=payload) response = requests.post(f"{self.base_url}/audio/speech", headers=self.headers, json=payload)