Mirror of https://github.com/infiniflow/ragflow.git (synced 2025-12-08 20:42:30 +08:00)
Refa: automatic LLMs registration (#8651)
### What problem does this PR solve?

Support automatic LLM registration.

### Type of change

- [x] Refactoring
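The diff below replaces the manually maintained factory dictionaries in `rag/llm/__init__.py` with automatic registration: every provider class carries a `_FACTORY_NAME` attribute, and the package scans its sub-modules and fills the dictionaries itself. As orientation before reading the hunks, here is a minimal, self-contained sketch of that pattern; the `Example*` classes and factory names are illustrative stand-ins, not providers that exist in the codebase.

```python
# Standalone sketch of the registration pattern this PR introduces.
# "Example", "Example-A", "Example-B" are made-up factory names for illustration.
import inspect
from abc import ABC


class Base(ABC):
    """Stand-in for the per-module Base class (chat_model.py, cv_model.py, ...)."""


class ExampleChat(Base):
    _FACTORY_NAME = "Example"  # a single factory name


class ExampleCompatChat(Base):
    _FACTORY_NAME = ["Example-A", "Example-B"]  # one class can serve several names


ChatModel = {}

# The real __init__.py walks MODULE_MAPPING and imports each sub-module with
# importlib; here we simply scan the classes defined in this script.
for _, obj in list(globals().items()):
    if inspect.isclass(obj) and issubclass(obj, Base) and obj is not Base and hasattr(obj, "_FACTORY_NAME"):
        names = obj._FACTORY_NAME if isinstance(obj._FACTORY_NAME, list) else [obj._FACTORY_NAME]
        for factory_name in names:
            ChatModel[factory_name] = obj

print(sorted(ChatModel))  # ['Example', 'Example-A', 'Example-B']
```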
@@ -15,289 +15,53 @@
#
# AFTER UPDATING THIS FILE, PLEASE ENSURE THAT docs/references/supported_models.mdx IS ALSO UPDATED for consistency!
#
from .embedding_model import (
    OllamaEmbed,
    LocalAIEmbed,
    OpenAIEmbed,
    AzureEmbed,
    XinferenceEmbed,
    QWenEmbed,
    ZhipuEmbed,
    FastEmbed,
    YoudaoEmbed,
    BaiChuanEmbed,
    JinaEmbed,
    DefaultEmbedding,
    MistralEmbed,
    BedrockEmbed,
    GeminiEmbed,
    NvidiaEmbed,
    LmStudioEmbed,
    OpenAI_APIEmbed,
    CoHereEmbed,
    TogetherAIEmbed,
    PerfXCloudEmbed,
    UpstageEmbed,
    SILICONFLOWEmbed,
    ReplicateEmbed,
    BaiduYiyanEmbed,
    VoyageEmbed,
    HuggingFaceEmbed,
    VolcEngineEmbed,
    GPUStackEmbed,
    NovitaEmbed,
    GiteeEmbed
)
from .chat_model import (
    GptTurbo,
    AzureChat,
    ZhipuChat,
    QWenChat,
    OllamaChat,
    LocalAIChat,
    XinferenceChat,
    MoonshotChat,
    DeepSeekChat,
    VolcEngineChat,
    BaiChuanChat,
    MiniMaxChat,
    MistralChat,
    GeminiChat,
    BedrockChat,
    GroqChat,
    OpenRouterChat,
    StepFunChat,
    NvidiaChat,
    LmStudioChat,
    OpenAI_APIChat,
    CoHereChat,
    LeptonAIChat,
    TogetherAIChat,
    PerfXCloudChat,
    UpstageChat,
    NovitaAIChat,
    SILICONFLOWChat,
    PPIOChat,
    YiChat,
    ReplicateChat,
    HunyuanChat,
    SparkChat,
    BaiduYiyanChat,
    AnthropicChat,
    GoogleChat,
    HuggingFaceChat,
    GPUStackChat,
    ModelScopeChat,
    GiteeChat
)

from .cv_model import (
    GptV4,
    AzureGptV4,
    OllamaCV,
    XinferenceCV,
    QWenCV,
    Zhipu4V,
    LocalCV,
    GeminiCV,
    OpenRouterCV,
    LocalAICV,
    NvidiaCV,
    LmStudioCV,
    StepFunCV,
    OpenAI_APICV,
    TogetherAICV,
    YiCV,
    HunyuanCV,
    AnthropicCV,
    SILICONFLOWCV,
    GPUStackCV,
    GoogleCV,
)
import importlib
import inspect

from .rerank_model import (
    LocalAIRerank,
    DefaultRerank,
    JinaRerank,
    YoudaoRerank,
    XInferenceRerank,
    NvidiaRerank,
    LmStudioRerank,
    OpenAI_APIRerank,
    CoHereRerank,
    TogetherAIRerank,
    SILICONFLOWRerank,
    BaiduYiyanRerank,
    VoyageRerank,
    QWenRerank,
    GPUStackRerank,
    HuggingfaceRerank,
    NovitaRerank,
    GiteeRerank
)
ChatModel = globals().get("ChatModel", {})
CvModel = globals().get("CvModel", {})
EmbeddingModel = globals().get("EmbeddingModel", {})
RerankModel = globals().get("RerankModel", {})
Seq2txtModel = globals().get("Seq2txtModel", {})
TTSModel = globals().get("TTSModel", {})

from .sequence2txt_model import (
    GPTSeq2txt,
    QWenSeq2txt,
    AzureSeq2txt,
    XinferenceSeq2txt,
    TencentCloudSeq2txt,
    GPUStackSeq2txt,
    GiteeSeq2txt
)

from .tts_model import (
    FishAudioTTS,
    QwenTTS,
    OpenAITTS,
    SparkTTS,
    XinferenceTTS,
    GPUStackTTS,
    SILICONFLOWTTS,
)

EmbeddingModel = {
    "Ollama": OllamaEmbed,
    "LocalAI": LocalAIEmbed,
    "OpenAI": OpenAIEmbed,
    "Azure-OpenAI": AzureEmbed,
    "Xinference": XinferenceEmbed,
    "Tongyi-Qianwen": QWenEmbed,
    "ZHIPU-AI": ZhipuEmbed,
    "FastEmbed": FastEmbed,
    "Youdao": YoudaoEmbed,
    "BaiChuan": BaiChuanEmbed,
    "Jina": JinaEmbed,
    "BAAI": DefaultEmbedding,
    "Mistral": MistralEmbed,
    "Bedrock": BedrockEmbed,
    "Gemini": GeminiEmbed,
    "NVIDIA": NvidiaEmbed,
    "LM-Studio": LmStudioEmbed,
    "OpenAI-API-Compatible": OpenAI_APIEmbed,
    "VLLM": OpenAI_APIEmbed,
    "Cohere": CoHereEmbed,
    "TogetherAI": TogetherAIEmbed,
    "PerfXCloud": PerfXCloudEmbed,
    "Upstage": UpstageEmbed,
    "SILICONFLOW": SILICONFLOWEmbed,
    "Replicate": ReplicateEmbed,
    "BaiduYiyan": BaiduYiyanEmbed,
    "Voyage AI": VoyageEmbed,
    "HuggingFace": HuggingFaceEmbed,
    "VolcEngine": VolcEngineEmbed,
    "GPUStack": GPUStackEmbed,
    "NovitaAI": NovitaEmbed,
    "GiteeAI": GiteeEmbed
MODULE_MAPPING = {
    "chat_model": ChatModel,
    "cv_model": CvModel,
    "embedding_model": EmbeddingModel,
    "rerank_model": RerankModel,
    "sequence2txt_model": Seq2txtModel,
    "tts_model": TTSModel,
}

CvModel = {
    "OpenAI": GptV4,
    "Azure-OpenAI": AzureGptV4,
    "Ollama": OllamaCV,
    "Xinference": XinferenceCV,
    "Tongyi-Qianwen": QWenCV,
    "ZHIPU-AI": Zhipu4V,
    "Moonshot": LocalCV,
    "Gemini": GeminiCV,
    "OpenRouter": OpenRouterCV,
    "LocalAI": LocalAICV,
    "NVIDIA": NvidiaCV,
    "LM-Studio": LmStudioCV,
    "StepFun": StepFunCV,
    "OpenAI-API-Compatible": OpenAI_APICV,
    "VLLM": OpenAI_APICV,
    "TogetherAI": TogetherAICV,
    "01.AI": YiCV,
    "Tencent Hunyuan": HunyuanCV,
    "Anthropic": AnthropicCV,
    "SILICONFLOW": SILICONFLOWCV,
    "GPUStack": GPUStackCV,
    "Google Cloud": GoogleCV
}
package_name = __name__

ChatModel = {
    "OpenAI": GptTurbo,
    "Azure-OpenAI": AzureChat,
    "ZHIPU-AI": ZhipuChat,
    "Tongyi-Qianwen": QWenChat,
    "Ollama": OllamaChat,
    "LocalAI": LocalAIChat,
    "Xinference": XinferenceChat,
    "Moonshot": MoonshotChat,
    "DeepSeek": DeepSeekChat,
    "VolcEngine": VolcEngineChat,
    "BaiChuan": BaiChuanChat,
    "MiniMax": MiniMaxChat,
    "Mistral": MistralChat,
    "Gemini": GeminiChat,
    "Bedrock": BedrockChat,
    "Groq": GroqChat,
    "OpenRouter": OpenRouterChat,
    "StepFun": StepFunChat,
    "NVIDIA": NvidiaChat,
    "LM-Studio": LmStudioChat,
    "OpenAI-API-Compatible": OpenAI_APIChat,
    "VLLM": OpenAI_APIChat,
    "Cohere": CoHereChat,
    "LeptonAI": LeptonAIChat,
    "TogetherAI": TogetherAIChat,
    "PerfXCloud": PerfXCloudChat,
    "Upstage": UpstageChat,
    "NovitaAI": NovitaAIChat,
    "SILICONFLOW": SILICONFLOWChat,
    "PPIO": PPIOChat,
    "01.AI": YiChat,
    "Replicate": ReplicateChat,
    "Tencent Hunyuan": HunyuanChat,
    "XunFei Spark": SparkChat,
    "BaiduYiyan": BaiduYiyanChat,
    "Anthropic": AnthropicChat,
    "Google Cloud": GoogleChat,
    "HuggingFace": HuggingFaceChat,
    "GPUStack": GPUStackChat,
    "ModelScope": ModelScopeChat,
    "GiteeAI": GiteeChat
}
for module_name, mapping_dict in MODULE_MAPPING.items():
    full_module_name = f"{package_name}.{module_name}"
    module = importlib.import_module(full_module_name)

RerankModel = {
    "LocalAI": LocalAIRerank,
    "BAAI": DefaultRerank,
    "Jina": JinaRerank,
    "Youdao": YoudaoRerank,
    "Xinference": XInferenceRerank,
    "NVIDIA": NvidiaRerank,
    "LM-Studio": LmStudioRerank,
    "OpenAI-API-Compatible": OpenAI_APIRerank,
    "VLLM": CoHereRerank,
    "Cohere": CoHereRerank,
    "TogetherAI": TogetherAIRerank,
    "SILICONFLOW": SILICONFLOWRerank,
    "BaiduYiyan": BaiduYiyanRerank,
    "Voyage AI": VoyageRerank,
    "Tongyi-Qianwen": QWenRerank,
    "GPUStack": GPUStackRerank,
    "HuggingFace": HuggingfaceRerank,
    "NovitaAI": NovitaRerank,
    "GiteeAI": GiteeRerank
}
    base_class = None
    for name, obj in inspect.getmembers(module):
        if inspect.isclass(obj) and name == "Base":
            base_class = obj
            break
    if base_class is None:
        continue

Seq2txtModel = {
    "OpenAI": GPTSeq2txt,
    "Tongyi-Qianwen": QWenSeq2txt,
    "Azure-OpenAI": AzureSeq2txt,
    "Xinference": XinferenceSeq2txt,
    "Tencent Cloud": TencentCloudSeq2txt,
    "GPUStack": GPUStackSeq2txt,
    "GiteeAI": GiteeSeq2txt
}
    for _, obj in inspect.getmembers(module):
        if inspect.isclass(obj) and issubclass(obj, base_class) and obj is not base_class and hasattr(obj, "_FACTORY_NAME"):
            if isinstance(obj._FACTORY_NAME, list):
                for factory_name in obj._FACTORY_NAME:
                    mapping_dict[factory_name] = obj
            else:
                mapping_dict[obj._FACTORY_NAME] = obj

TTSModel = {
    "Fish Audio": FishAudioTTS,
    "Tongyi-Qianwen": QwenTTS,
    "OpenAI": OpenAITTS,
    "XunFei Spark": SparkTTS,
    "Xinference": XinferenceTTS,
    "GPUStack": GPUStackTTS,
    "SILICONFLOW": SILICONFLOWTTS,
}
__all__ = [
    "ChatModel",
    "CvModel",
    "EmbeddingModel",
    "RerankModel",
    "Seq2txtModel",
    "TTSModel",
]
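The `__init__.py` hunk above interleaves the removed manual registries (the explicit imports and the `EmbeddingModel`/`CvModel`/`ChatModel`/`RerankModel`/`Seq2txtModel`/`TTSModel` dictionaries) with the new auto-registration code (`MODULE_MAPPING`, the `importlib`/`inspect` loop, and `__all__`). Consumers are unchanged: the dictionaries are still looked up by provider name, they are simply populated automatically now. A hypothetical lookup (the key, model name, and URL below are placeholders):

```python
# Hypothetical usage; the factory name "OpenAI-API-Compatible" appears in this diff,
# but the API key, model name, and base_url are placeholders.
from rag.llm import ChatModel

chat_cls = ChatModel["OpenAI-API-Compatible"]  # -> OpenAI_APIChat, registered via _FACTORY_NAME
mdl = chat_cls("sk-placeholder", "my-model", "http://localhost:8000/v1")
```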
@@ -142,11 +142,7 @@ class Base(ABC):
        return f"{ERROR_PREFIX}: {error_code} - {str(e)}"

    def _verbose_tool_use(self, name, args, res):
        return "<tool_call>" + json.dumps({
            "name": name,
            "args": args,
            "result": res
        }, ensure_ascii=False, indent=2) + "</tool_call>"
        return "<tool_call>" + json.dumps({"name": name, "args": args, "result": res}, ensure_ascii=False, indent=2) + "</tool_call>"

    def _append_history(self, hist, tool_call, tool_res):
        hist.append(
@@ -191,10 +187,10 @@ class Base(ABC):
        tk_count = 0
        hist = deepcopy(history)
        # Implement exponential backoff retry strategy
        for attempt in range(self.max_retries+1):
        for attempt in range(self.max_retries + 1):
            history = hist
            try:
                for _ in range(self.max_rounds*2):
                for _ in range(self.max_rounds * 2):
                    response = self.client.chat.completions.create(model=self.model_name, messages=history, tools=self.tools, **gen_conf)
                    tk_count += self.total_token_count(response)
                    if any([not response.choices, not response.choices[0].message]):
@@ -269,7 +265,7 @@ class Base(ABC):
        for attempt in range(self.max_retries + 1):
            history = hist
            try:
                for _ in range(self.max_rounds*2):
                for _ in range(self.max_rounds * 2):
                    reasoning_start = False
                    response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, **gen_conf)
                    final_tool_calls = {}
@@ -430,6 +426,8 @@ class Base(ABC):


class GptTurbo(Base):
    _FACTORY_NAME = "OpenAI"

    def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1", **kwargs):
        if not base_url:
            base_url = "https://api.openai.com/v1"
@@ -437,6 +435,8 @@ class GptTurbo(Base):


class MoonshotChat(Base):
    _FACTORY_NAME = "Moonshot"

    def __init__(self, key, model_name="moonshot-v1-8k", base_url="https://api.moonshot.cn/v1", **kwargs):
        if not base_url:
            base_url = "https://api.moonshot.cn/v1"
@@ -444,6 +444,8 @@ class MoonshotChat(Base):


class XinferenceChat(Base):
    _FACTORY_NAME = "Xinference"

    def __init__(self, key=None, model_name="", base_url="", **kwargs):
        if not base_url:
            raise ValueError("Local llm url cannot be None")
@@ -452,6 +454,8 @@ class XinferenceChat(Base):


class HuggingFaceChat(Base):
    _FACTORY_NAME = "HuggingFace"

    def __init__(self, key=None, model_name="", base_url="", **kwargs):
        if not base_url:
            raise ValueError("Local llm url cannot be None")
@@ -460,6 +464,8 @@ class HuggingFaceChat(Base):


class ModelScopeChat(Base):
    _FACTORY_NAME = "ModelScope"

    def __init__(self, key=None, model_name="", base_url="", **kwargs):
        if not base_url:
            raise ValueError("Local llm url cannot be None")
@@ -468,6 +474,8 @@ class ModelScopeChat(Base):


class DeepSeekChat(Base):
    _FACTORY_NAME = "DeepSeek"

    def __init__(self, key, model_name="deepseek-chat", base_url="https://api.deepseek.com/v1", **kwargs):
        if not base_url:
            base_url = "https://api.deepseek.com/v1"
@@ -475,6 +483,8 @@ class DeepSeekChat(Base):


class AzureChat(Base):
    _FACTORY_NAME = "Azure-OpenAI"

    def __init__(self, key, model_name, base_url, **kwargs):
        api_key = json.loads(key).get("api_key", "")
        api_version = json.loads(key).get("api_version", "2024-02-01")
@@ -484,6 +494,8 @@ class AzureChat(Base):


class BaiChuanChat(Base):
    _FACTORY_NAME = "BaiChuan"

    def __init__(self, key, model_name="Baichuan3-Turbo", base_url="https://api.baichuan-ai.com/v1", **kwargs):
        if not base_url:
            base_url = "https://api.baichuan-ai.com/v1"
@@ -557,6 +569,8 @@ class BaiChuanChat(Base):


class QWenChat(Base):
    _FACTORY_NAME = "Tongyi-Qianwen"

    def __init__(self, key, model_name=Generation.Models.qwen_turbo, base_url=None, **kwargs):
        if not base_url:
            base_url = "https://dashscope.aliyuncs.com/compatible-mode/v1"
@@ -565,6 +579,8 @@ class QWenChat(Base):


class ZhipuChat(Base):
    _FACTORY_NAME = "ZHIPU-AI"

    def __init__(self, key, model_name="glm-3-turbo", base_url=None, **kwargs):
        super().__init__(key, model_name, base_url=base_url, **kwargs)

@@ -630,6 +646,8 @@ class ZhipuChat(Base):


class OllamaChat(Base):
    _FACTORY_NAME = "Ollama"

    def __init__(self, key, model_name, base_url=None, **kwargs):
        super().__init__(key, model_name, base_url=base_url, **kwargs)

@@ -694,6 +712,8 @@ class OllamaChat(Base):


class LocalAIChat(Base):
    _FACTORY_NAME = "LocalAI"

    def __init__(self, key, model_name, base_url=None, **kwargs):
        super().__init__(key, model_name, base_url=base_url, **kwargs)

@@ -752,6 +772,8 @@ class LocalLLM(Base):


class VolcEngineChat(Base):
    _FACTORY_NAME = "VolcEngine"

    def __init__(self, key, model_name, base_url="https://ark.cn-beijing.volces.com/api/v3", **kwargs):
        """
        Since do not want to modify the original database fields, and the VolcEngine authentication method is quite special,
@@ -765,6 +787,8 @@ class VolcEngineChat(Base):


class MiniMaxChat(Base):
    _FACTORY_NAME = "MiniMax"

    def __init__(self, key, model_name, base_url="https://api.minimax.chat/v1/text/chatcompletion_v2", **kwargs):
        super().__init__(key, model_name, base_url=base_url, **kwargs)

@@ -843,6 +867,8 @@ class MiniMaxChat(Base):


class MistralChat(Base):
    _FACTORY_NAME = "Mistral"

    def __init__(self, key, model_name, base_url=None, **kwargs):
        super().__init__(key, model_name, base_url=base_url, **kwargs)

@@ -896,6 +922,8 @@ class MistralChat(Base):


class BedrockChat(Base):
    _FACTORY_NAME = "Bedrock"

    def __init__(self, key, model_name, base_url=None, **kwargs):
        super().__init__(key, model_name, base_url=base_url, **kwargs)

@@ -978,6 +1006,8 @@ class BedrockChat(Base):


class GeminiChat(Base):
    _FACTORY_NAME = "Gemini"

    def __init__(self, key, model_name, base_url=None, **kwargs):
        super().__init__(key, model_name, base_url=base_url, **kwargs)

@@ -997,6 +1027,7 @@ class GeminiChat(Base):

    def _chat(self, history, gen_conf):
        from google.generativeai.types import content_types

        system = history[0]["content"] if history and history[0]["role"] == "system" else ""
        hist = []
        for item in history:
@@ -1019,6 +1050,7 @@ class GeminiChat(Base):

    def chat_streamly(self, system, history, gen_conf):
        from google.generativeai.types import content_types

        gen_conf = self._clean_conf(gen_conf)
        if system:
            self.model._system_instruction = content_types.to_content(system)
@@ -1042,6 +1074,8 @@ class GeminiChat(Base):


class GroqChat(Base):
    _FACTORY_NAME = "Groq"

    def __init__(self, key, model_name, base_url=None, **kwargs):
        super().__init__(key, model_name, base_url=base_url, **kwargs)

@@ -1086,6 +1120,8 @@ class GroqChat(Base):

## openrouter
class OpenRouterChat(Base):
    _FACTORY_NAME = "OpenRouter"

    def __init__(self, key, model_name, base_url="https://openrouter.ai/api/v1", **kwargs):
        if not base_url:
            base_url = "https://openrouter.ai/api/v1"
@@ -1093,6 +1129,8 @@ class OpenRouterChat(Base):


class StepFunChat(Base):
    _FACTORY_NAME = "StepFun"

    def __init__(self, key, model_name, base_url="https://api.stepfun.com/v1", **kwargs):
        if not base_url:
            base_url = "https://api.stepfun.com/v1"
@@ -1100,6 +1138,8 @@ class StepFunChat(Base):


class NvidiaChat(Base):
    _FACTORY_NAME = "NVIDIA"

    def __init__(self, key, model_name, base_url="https://integrate.api.nvidia.com/v1", **kwargs):
        if not base_url:
            base_url = "https://integrate.api.nvidia.com/v1"
@@ -1107,6 +1147,8 @@ class NvidiaChat(Base):


class LmStudioChat(Base):
    _FACTORY_NAME = "LM-Studio"

    def __init__(self, key, model_name, base_url, **kwargs):
        if not base_url:
            raise ValueError("Local llm url cannot be None")
@@ -1117,6 +1159,8 @@ class LmStudioChat(Base):


class OpenAI_APIChat(Base):
    _FACTORY_NAME = ["VLLM", "OpenAI-API-Compatible"]

    def __init__(self, key, model_name, base_url):
        if not base_url:
            raise ValueError("url cannot be None")
@@ -1125,6 +1169,8 @@ class OpenAI_APIChat(Base):


class PPIOChat(Base):
    _FACTORY_NAME = "PPIO"

    def __init__(self, key, model_name, base_url="https://api.ppinfra.com/v3/openai", **kwargs):
        if not base_url:
            base_url = "https://api.ppinfra.com/v3/openai"
@@ -1132,6 +1178,8 @@ class PPIOChat(Base):


class CoHereChat(Base):
    _FACTORY_NAME = "Cohere"

    def __init__(self, key, model_name, base_url=None, **kwargs):
        super().__init__(key, model_name, base_url=base_url, **kwargs)

@@ -1207,6 +1255,8 @@ class CoHereChat(Base):


class LeptonAIChat(Base):
    _FACTORY_NAME = "LeptonAI"

    def __init__(self, key, model_name, base_url=None, **kwargs):
        if not base_url:
            base_url = urljoin("https://" + model_name + ".lepton.run", "api/v1")
@@ -1214,6 +1264,8 @@ class LeptonAIChat(Base):


class TogetherAIChat(Base):
    _FACTORY_NAME = "TogetherAI"

    def __init__(self, key, model_name, base_url="https://api.together.xyz/v1", **kwargs):
        if not base_url:
            base_url = "https://api.together.xyz/v1"
@@ -1221,6 +1273,8 @@ class TogetherAIChat(Base):


class PerfXCloudChat(Base):
    _FACTORY_NAME = "PerfXCloud"

    def __init__(self, key, model_name, base_url="https://cloud.perfxlab.cn/v1", **kwargs):
        if not base_url:
            base_url = "https://cloud.perfxlab.cn/v1"
@@ -1228,6 +1282,8 @@ class PerfXCloudChat(Base):


class UpstageChat(Base):
    _FACTORY_NAME = "Upstage"

    def __init__(self, key, model_name, base_url="https://api.upstage.ai/v1/solar", **kwargs):
        if not base_url:
            base_url = "https://api.upstage.ai/v1/solar"
@@ -1235,6 +1291,8 @@ class UpstageChat(Base):


class NovitaAIChat(Base):
    _FACTORY_NAME = "NovitaAI"

    def __init__(self, key, model_name, base_url="https://api.novita.ai/v3/openai", **kwargs):
        if not base_url:
            base_url = "https://api.novita.ai/v3/openai"
@@ -1242,6 +1300,8 @@ class NovitaAIChat(Base):


class SILICONFLOWChat(Base):
    _FACTORY_NAME = "SILICONFLOW"

    def __init__(self, key, model_name, base_url="https://api.siliconflow.cn/v1", **kwargs):
        if not base_url:
            base_url = "https://api.siliconflow.cn/v1"
@@ -1249,6 +1309,8 @@ class SILICONFLOWChat(Base):


class YiChat(Base):
    _FACTORY_NAME = "01.AI"

    def __init__(self, key, model_name, base_url="https://api.lingyiwanwu.com/v1", **kwargs):
        if not base_url:
            base_url = "https://api.lingyiwanwu.com/v1"
@@ -1256,6 +1318,8 @@ class YiChat(Base):


class GiteeChat(Base):
    _FACTORY_NAME = "GiteeAI"

    def __init__(self, key, model_name, base_url="https://ai.gitee.com/v1/", **kwargs):
        if not base_url:
            base_url = "https://ai.gitee.com/v1/"
@@ -1263,6 +1327,8 @@ class GiteeChat(Base):


class ReplicateChat(Base):
    _FACTORY_NAME = "Replicate"

    def __init__(self, key, model_name, base_url=None, **kwargs):
        super().__init__(key, model_name, base_url=base_url, **kwargs)

@@ -1302,6 +1368,8 @@ class ReplicateChat(Base):


class HunyuanChat(Base):
    _FACTORY_NAME = "Tencent Hunyuan"

    def __init__(self, key, model_name, base_url=None, **kwargs):
        super().__init__(key, model_name, base_url=base_url, **kwargs)

@@ -1378,6 +1446,8 @@ class HunyuanChat(Base):


class SparkChat(Base):
    _FACTORY_NAME = "XunFei Spark"

    def __init__(self, key, model_name, base_url="https://spark-api-open.xf-yun.com/v1", **kwargs):
        if not base_url:
            base_url = "https://spark-api-open.xf-yun.com/v1"
@@ -1398,6 +1468,8 @@ class SparkChat(Base):


class BaiduYiyanChat(Base):
    _FACTORY_NAME = "BaiduYiyan"

    def __init__(self, key, model_name, base_url=None, **kwargs):
        super().__init__(key, model_name, base_url=base_url, **kwargs)

@@ -1444,6 +1516,8 @@ class BaiduYiyanChat(Base):


class AnthropicChat(Base):
    _FACTORY_NAME = "Anthropic"

    def __init__(self, key, model_name, base_url="https://api.anthropic.com/v1/", **kwargs):
        if not base_url:
            base_url = "https://api.anthropic.com/v1/"
@@ -1451,6 +1525,8 @@ class AnthropicChat(Base):


class GoogleChat(Base):
    _FACTORY_NAME = "Google Cloud"

    def __init__(self, key, model_name, base_url=None, **kwargs):
        super().__init__(key, model_name, base_url=base_url, **kwargs)

@@ -1529,9 +1605,11 @@ class GoogleChat(Base):
            if "role" in item and item["role"] == "assistant":
                item["role"] = "model"
            if "content" in item:
                item["parts"] = [{
                item["parts"] = [
                    {
                        "text": item.pop("content"),
                }]
                    }
                ]

        response = self.client.generate_content(hist, generation_config=gen_conf)
        ans = response.text
@@ -1587,6 +1665,8 @@ class GoogleChat(Base):


class GPUStackChat(Base):
    _FACTORY_NAME = "GPUStack"

    def __init__(self, key=None, model_name="", base_url="", **kwargs):
        if not base_url:
            raise ValueError("Local llm url cannot be None")
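The chat-model hunks above are mostly mechanical: each provider class gains a `_FACTORY_NAME` matching the key it previously had in the manual `ChatModel` dictionary, plus some formatting cleanups. After this change, adding a chat provider should only require defining a class in `chat_model.py`; a hypothetical example following the same pattern (the "Acme" provider and its URL do not exist in this PR):

```python
# Hypothetical new provider, shown only to illustrate the registration pattern.
class AcmeChat(Base):
    _FACTORY_NAME = "Acme"  # the name the UI/config would use to select this provider

    def __init__(self, key, model_name, base_url="https://api.acme.example/v1", **kwargs):
        if not base_url:
            base_url = "https://api.acme.example/v1"
        super().__init__(key, model_name, base_url=base_url, **kwargs)
```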
@@ -57,7 +57,7 @@ class Base(ABC):
                model=self.model_name,
                messages=history,
                temperature=gen_conf.get("temperature", 0.3),
                top_p=gen_conf.get("top_p", 0.7)
                top_p=gen_conf.get("top_p", 0.7),
            )
            return response.choices[0].message.content.strip(), response.usage.total_tokens
        except Exception as e:
@@ -79,7 +79,7 @@ class Base(ABC):
                messages=history,
                temperature=gen_conf.get("temperature", 0.3),
                top_p=gen_conf.get("top_p", 0.7),
                stream=True
                stream=True,
            )
            for resp in response:
                if not resp.choices[0].delta.content:
@@ -87,8 +87,7 @@ class Base(ABC):
                delta = resp.choices[0].delta.content
                ans += delta
                if resp.choices[0].finish_reason == "length":
                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                        [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                tk_count = resp.usage.total_tokens
                if resp.choices[0].finish_reason == "stop":
                    tk_count = resp.usage.total_tokens
@@ -117,13 +116,12 @@ class Base(ABC):
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/jpeg;base64,{b64}"
                        },
                        "image_url": {"url": f"data:image/jpeg;base64,{b64}"},
                    },
                    {
                        "text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else
                        "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
                        "text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
                        if self.lang.lower() == "chinese"
                        else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
                    },
                ],
            }
@@ -136,9 +134,7 @@ class Base(ABC):
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/jpeg;base64,{b64}"
                        },
                        "image_url": {"url": f"data:image/jpeg;base64,{b64}"},
                    },
                    {
                        "type": "text",
@@ -156,14 +152,13 @@ class Base(ABC):
                        "url": f"data:image/jpeg;base64,{b64}",
                    },
                },
                {
                    "type": "text",
                    "text": text
                },
                {"type": "text", "text": text},
            ]


class GptV4(Base):
    _FACTORY_NAME = "OpenAI"

    def __init__(self, key, model_name="gpt-4-vision-preview", lang="Chinese", base_url="https://api.openai.com/v1"):
        if not base_url:
            base_url = "https://api.openai.com/v1"
@@ -181,7 +176,7 @@ class GptV4(Base):

        res = self.client.chat.completions.create(
            model=self.model_name,
            messages=prompt
            messages=prompt,
        )
        return res.choices[0].message.content.strip(), res.usage.total_tokens

@@ -197,9 +192,11 @@ class GptV4(Base):


class AzureGptV4(Base):
    _FACTORY_NAME = "Azure-OpenAI"

    def __init__(self, key, model_name, lang="Chinese", **kwargs):
        api_key = json.loads(key).get('api_key', '')
        api_version = json.loads(key).get('api_version', '2024-02-01')
        api_key = json.loads(key).get("api_key", "")
        api_version = json.loads(key).get("api_version", "2024-02-01")
        self.client = AzureOpenAI(api_key=api_key, azure_endpoint=kwargs["base_url"], api_version=api_version)
        self.model_name = model_name
        self.lang = lang
@@ -212,10 +209,7 @@ class AzureGptV4(Base):
            if "text" in c:
                c["type"] = "text"

        res = self.client.chat.completions.create(
            model=self.model_name,
            messages=prompt
        )
        res = self.client.chat.completions.create(model=self.model_name, messages=prompt)
        return res.choices[0].message.content.strip(), res.usage.total_tokens

    def describe_with_prompt(self, image, prompt=None):
@@ -230,8 +224,11 @@ class AzureGptV4(Base):


class QWenCV(Base):
    _FACTORY_NAME = "Tongyi-Qianwen"

    def __init__(self, key, model_name="qwen-vl-chat-v1", lang="Chinese", **kwargs):
        import dashscope

        dashscope.api_key = key
        self.model_name = model_name
        self.lang = lang
@@ -247,12 +244,11 @@ class QWenCV(Base):
            {
                "role": "user",
                "content": [
                    {"image": f"file://{path}"},
                    {
                        "image": f"file://{path}"
                    },
                    {
                        "text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else
                        "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
                        "text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
                        if self.lang.lower() == "chinese"
                        else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
                    },
                ],
            }
@@ -270,9 +266,7 @@ class QWenCV(Base):
            {
                "role": "user",
                "content": [
                    {
                        "image": f"file://{path}"
                    },
                    {"image": f"file://{path}"},
                    {
                        "text": prompt if prompt else vision_llm_describe_prompt(),
                    },
@@ -290,9 +284,10 @@ class QWenCV(Base):
        from http import HTTPStatus

        from dashscope import MultiModalConversation

        response = MultiModalConversation.call(model=self.model_name, messages=self.prompt(image))
        if response.status_code == HTTPStatus.OK:
            return response.output.choices[0]['message']['content'][0]["text"], response.usage.output_tokens
            return response.output.choices[0]["message"]["content"][0]["text"], response.usage.output_tokens
        return response.message, 0

    def describe_with_prompt(self, image, prompt=None):
@@ -303,33 +298,36 @@ class QWenCV(Base):
        vision_prompt = self.vision_llm_prompt(image, prompt) if prompt else self.vision_llm_prompt(image)
        response = MultiModalConversation.call(model=self.model_name, messages=vision_prompt)
        if response.status_code == HTTPStatus.OK:
            return response.output.choices[0]['message']['content'][0]["text"], response.usage.output_tokens
            return response.output.choices[0]["message"]["content"][0]["text"], response.usage.output_tokens
        return response.message, 0

    def chat(self, system, history, gen_conf, image=""):
        from http import HTTPStatus

        from dashscope import MultiModalConversation

        if system:
            history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]

        for his in history:
            if his["role"] == "user":
                his["content"] = self.chat_prompt(his["content"], image)
        response = MultiModalConversation.call(model=self.model_name, messages=history,
        response = MultiModalConversation.call(
            model=self.model_name,
            messages=history,
            temperature=gen_conf.get("temperature", 0.3),
            top_p=gen_conf.get("top_p", 0.7))
            top_p=gen_conf.get("top_p", 0.7),
        )

        ans = ""
        tk_count = 0
        if response.status_code == HTTPStatus.OK:
            ans = response.output.choices[0]['message']['content']
            ans = response.output.choices[0]["message"]["content"]
            if isinstance(ans, list):
                ans = ans[0]["text"] if ans else ""
            tk_count += response.usage.total_tokens
            if response.output.choices[0].get("finish_reason", "") == "length":
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
            return ans, tk_count

        return "**ERROR**: " + response.message, tk_count
@@ -338,6 +336,7 @@ class QWenCV(Base):
        from http import HTTPStatus

        from dashscope import MultiModalConversation

        if system:
            history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]

@@ -348,24 +347,25 @@ class QWenCV(Base):
        ans = ""
        tk_count = 0
        try:
            response = MultiModalConversation.call(model=self.model_name, messages=history,
            response = MultiModalConversation.call(
                model=self.model_name,
                messages=history,
                temperature=gen_conf.get("temperature", 0.3),
                top_p=gen_conf.get("top_p", 0.7),
                stream=True)
                stream=True,
            )
            for resp in response:
                if resp.status_code == HTTPStatus.OK:
                    cnt = resp.output.choices[0]['message']['content']
                    cnt = resp.output.choices[0]["message"]["content"]
                    if isinstance(cnt, list):
                        cnt = cnt[0]["text"] if ans else ""
                    ans += cnt
                    tk_count = resp.usage.total_tokens
                    if resp.output.choices[0].get("finish_reason", "") == "length":
                        ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                            [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                        ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                    yield ans
                else:
                    yield ans + "\n**ERROR**: " + resp.message if str(resp.message).find(
                        "Access") < 0 else "Out of credit. Please set the API key in **settings > Model providers.**"
                    yield ans + "\n**ERROR**: " + resp.message if str(resp.message).find("Access") < 0 else "Out of credit. Please set the API key in **settings > Model providers.**"
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

@@ -373,6 +373,8 @@ class QWenCV(Base):


class Zhipu4V(Base):
    _FACTORY_NAME = "ZHIPU-AI"

    def __init__(self, key, model_name="glm-4v", lang="Chinese", **kwargs):
        self.client = ZhipuAI(api_key=key)
        self.model_name = model_name
@@ -394,10 +396,7 @@ class Zhipu4V(Base):
        b64 = self.image2base64(image)
        vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64)

        res = self.client.chat.completions.create(
            model=self.model_name,
            messages=vision_prompt
        )
        res = self.client.chat.completions.create(model=self.model_name, messages=vision_prompt)
        return res.choices[0].message.content.strip(), res.usage.total_tokens

    def chat(self, system, history, gen_conf, image=""):
@@ -412,7 +411,7 @@ class Zhipu4V(Base):
                model=self.model_name,
                messages=history,
                temperature=gen_conf.get("temperature", 0.3),
                top_p=gen_conf.get("top_p", 0.7)
                top_p=gen_conf.get("top_p", 0.7),
            )
            return response.choices[0].message.content.strip(), response.usage.total_tokens
        except Exception as e:
@@ -434,7 +433,7 @@ class Zhipu4V(Base):
                messages=history,
                temperature=gen_conf.get("temperature", 0.3),
                top_p=gen_conf.get("top_p", 0.7),
                stream=True
                stream=True,
            )
            for resp in response:
                if not resp.choices[0].delta.content:
@@ -442,8 +441,7 @@ class Zhipu4V(Base):
                delta = resp.choices[0].delta.content
                ans += delta
                if resp.choices[0].finish_reason == "length":
                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                        [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                tk_count = resp.usage.total_tokens
                if resp.choices[0].finish_reason == "stop":
                    tk_count = resp.usage.total_tokens
@@ -455,6 +453,8 @@ class Zhipu4V(Base):


class OllamaCV(Base):
    _FACTORY_NAME = "Ollama"

    def __init__(self, key, model_name, lang="Chinese", **kwargs):
        self.client = Client(host=kwargs["base_url"])
        self.model_name = model_name
@@ -466,7 +466,7 @@ class OllamaCV(Base):
        response = self.client.generate(
            model=self.model_name,
            prompt=prompt[0]["content"][1]["text"],
            images=[image]
            images=[image],
        )
        ans = response["response"].strip()
        return ans, 128
@@ -507,7 +507,7 @@ class OllamaCV(Base):
                model=self.model_name,
                messages=history,
                options=options,
                keep_alive=-1
                keep_alive=-1,
            )

            ans = response["message"]["content"].strip()
@@ -538,7 +538,7 @@ class OllamaCV(Base):
                messages=history,
                stream=True,
                options=options,
                keep_alive=-1
                keep_alive=-1,
            )
            for resp in response:
                if resp["done"]:
@@ -551,6 +551,8 @@ class OllamaCV(Base):


class LocalAICV(GptV4):
    _FACTORY_NAME = "LocalAI"

    def __init__(self, key, model_name, base_url, lang="Chinese"):
        if not base_url:
            raise ValueError("Local cv model url cannot be None")
@@ -561,6 +563,8 @@ class LocalAICV(GptV4):


class XinferenceCV(Base):
    _FACTORY_NAME = "Xinference"

    def __init__(self, key, model_name="", lang="Chinese", base_url=""):
        base_url = urljoin(base_url, "v1")
        self.client = OpenAI(api_key=key, base_url=base_url)
@@ -570,10 +574,7 @@ class XinferenceCV(Base):
    def describe(self, image):
        b64 = self.image2base64(image)

        res = self.client.chat.completions.create(
            model=self.model_name,
            messages=self.prompt(b64)
        )
        res = self.client.chat.completions.create(model=self.model_name, messages=self.prompt(b64))
        return res.choices[0].message.content.strip(), res.usage.total_tokens

    def describe_with_prompt(self, image, prompt=None):
@@ -588,8 +589,11 @@ class XinferenceCV(Base):


class GeminiCV(Base):
    _FACTORY_NAME = "Gemini"

    def __init__(self, key, model_name="gemini-1.0-pro-vision-latest", lang="Chinese", **kwargs):
        from google.generativeai import GenerativeModel, client

        client.configure(api_key=key)
        _client = client.get_default_generative_client()
        self.model_name = model_name
@@ -599,18 +603,21 @@ class GeminiCV(Base):

    def describe(self, image):
        from PIL.Image import open
        prompt = "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else \
            "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."

        prompt = (
            "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
            if self.lang.lower() == "chinese"
            else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
        )
        b64 = self.image2base64(image)
        img = open(BytesIO(base64.b64decode(b64)))
        input = [prompt, img]
        res = self.model.generate_content(
            input
        )
        res = self.model.generate_content(input)
        return res.text, res.usage_metadata.total_token_count

    def describe_with_prompt(self, image, prompt=None):
        from PIL.Image import open

        b64 = self.image2base64(image)
        vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64)
        img = open(BytesIO(base64.b64decode(b64)))
@@ -622,6 +629,7 @@ class GeminiCV(Base):

    def chat(self, system, history, gen_conf, image=""):
        from transformers import GenerationConfig

        if system:
            history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
        try:
@@ -635,9 +643,7 @@ class GeminiCV(Base):
                his.pop("content")
            history[-1]["parts"].append("data:image/jpeg;base64," + image)

            response = self.model.generate_content(history, generation_config=GenerationConfig(
                temperature=gen_conf.get("temperature", 0.3),
                top_p=gen_conf.get("top_p", 0.7)))
            response = self.model.generate_content(history, generation_config=GenerationConfig(temperature=gen_conf.get("temperature", 0.3), top_p=gen_conf.get("top_p", 0.7)))

            ans = response.text
            return ans, response.usage_metadata.total_token_count
@@ -646,6 +652,7 @@ class GeminiCV(Base):

    def chat_streamly(self, system, history, gen_conf, image=""):
        from transformers import GenerationConfig

        if system:
            history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]

@@ -661,9 +668,11 @@ class GeminiCV(Base):
                his.pop("content")
            history[-1]["parts"].append("data:image/jpeg;base64," + image)

            response = self.model.generate_content(history, generation_config=GenerationConfig(
                temperature=gen_conf.get("temperature", 0.3),
                top_p=gen_conf.get("top_p", 0.7)), stream=True)
            response = self.model.generate_content(
                history,
                generation_config=GenerationConfig(temperature=gen_conf.get("temperature", 0.3), top_p=gen_conf.get("top_p", 0.7)),
                stream=True,
            )

            for resp in response:
                if not resp.text:
@@ -677,6 +686,8 @@ class GeminiCV(Base):


class OpenRouterCV(GptV4):
    _FACTORY_NAME = "OpenRouter"

    def __init__(
        self,
        key,
@@ -692,6 +703,8 @@ class OpenRouterCV(GptV4):


class LocalCV(Base):
    _FACTORY_NAME = "Moonshot"

    def __init__(self, key, model_name="glm-4v", lang="Chinese", **kwargs):
        pass

@@ -700,6 +713,8 @@ class LocalCV(Base):


class NvidiaCV(Base):
    _FACTORY_NAME = "NVIDIA"

    def __init__(
        self,
        key,
@@ -726,9 +741,7 @@ class NvidiaCV(Base):
                "content-type": "application/json",
                "Authorization": f"Bearer {self.key}",
            },
            json={
                "messages": self.prompt(b64)
            },
            json={"messages": self.prompt(b64)},
        )
        response = response.json()
        return (
@@ -774,10 +787,7 @@ class NvidiaCV(Base):
        return [
            {
                "role": "user",
                "content": (
                    prompt if prompt else vision_llm_describe_prompt()
                )
                + f' <img src="data:image/jpeg;base64,{b64}"/>',
                "content": (prompt if prompt else vision_llm_describe_prompt()) + f' <img src="data:image/jpeg;base64,{b64}"/>',
            }
        ]

@@ -791,6 +801,8 @@ class NvidiaCV(Base):


class StepFunCV(GptV4):
    _FACTORY_NAME = "StepFun"

    def __init__(self, key, model_name="step-1v-8k", lang="Chinese", base_url="https://api.stepfun.com/v1"):
        if not base_url:
            base_url = "https://api.stepfun.com/v1"
@@ -800,6 +812,8 @@ class StepFunCV(GptV4):


class LmStudioCV(GptV4):
    _FACTORY_NAME = "LM-Studio"

    def __init__(self, key, model_name, lang="Chinese", base_url=""):
        if not base_url:
            raise ValueError("Local llm url cannot be None")
@@ -810,6 +824,8 @@ class LmStudioCV(GptV4):


class OpenAI_APICV(GptV4):
    _FACTORY_NAME = ["VLLM", "OpenAI-API-Compatible"]

    def __init__(self, key, model_name, lang="Chinese", base_url=""):
        if not base_url:
            raise ValueError("url cannot be None")
@@ -820,6 +836,8 @@ class OpenAI_APICV(GptV4):


class TogetherAICV(GptV4):
    _FACTORY_NAME = "TogetherAI"

    def __init__(self, key, model_name, lang="Chinese", base_url="https://api.together.xyz/v1"):
        if not base_url:
            base_url = "https://api.together.xyz/v1"
@@ -827,20 +845,38 @@ class TogetherAICV(GptV4):


class YiCV(GptV4):
    def __init__(self, key, model_name, lang="Chinese", base_url="https://api.lingyiwanwu.com/v1",):
    _FACTORY_NAME = "01.AI"

    def __init__(
        self,
        key,
        model_name,
        lang="Chinese",
        base_url="https://api.lingyiwanwu.com/v1",
    ):
        if not base_url:
            base_url = "https://api.lingyiwanwu.com/v1"
        super().__init__(key, model_name, lang, base_url)


class SILICONFLOWCV(GptV4):
    def __init__(self, key, model_name, lang="Chinese", base_url="https://api.siliconflow.cn/v1",):
    _FACTORY_NAME = "SILICONFLOW"

    def __init__(
        self,
        key,
        model_name,
        lang="Chinese",
        base_url="https://api.siliconflow.cn/v1",
    ):
        if not base_url:
            base_url = "https://api.siliconflow.cn/v1"
        super().__init__(key, model_name, lang, base_url)


class HunyuanCV(Base):
    _FACTORY_NAME = "Tencent Hunyuan"

    def __init__(self, key, model_name, lang="Chinese", base_url=None):
        from tencentcloud.common import credential
        from tencentcloud.hunyuan.v20230901 import hunyuan_client
@@ -895,14 +931,13 @@ class HunyuanCV(Base):
                "Contents": [
                    {
                        "Type": "image_url",
                        "ImageUrl": {
                            "Url": f"data:image/jpeg;base64,{b64}"
                        },
                        "ImageUrl": {"Url": f"data:image/jpeg;base64,{b64}"},
                    },
                    {
                        "Type": "text",
                        "Text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else
                        "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
                        "Text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
                        if self.lang.lower() == "chinese"
                        else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
                    },
                ],
            }
@@ -910,6 +945,8 @@ class HunyuanCV(Base):


class AnthropicCV(Base):
    _FACTORY_NAME = "Anthropic"

    def __init__(self, key, model_name, base_url=None):
        import anthropic

@@ -933,38 +970,29 @@ class AnthropicCV(Base):
                        "data": b64,
                    },
                },
                {
                    "type": "text",
                    "text": prompt
                }
                {"type": "text", "text": prompt},
            ],
        }
    ]

    def describe(self, image):
        b64 = self.image2base64(image)
        prompt = self.prompt(b64,
            "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else
            "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
        prompt = self.prompt(
            b64,
            "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
            if self.lang.lower() == "chinese"
            else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
        )

        response = self.client.messages.create(
            model=self.model_name,
            max_tokens=self.max_tokens,
            messages=prompt
        )
        return response["content"][0]["text"].strip(), response["usage"]["input_tokens"]+response["usage"]["output_tokens"]
        response = self.client.messages.create(model=self.model_name, max_tokens=self.max_tokens, messages=prompt)
        return response["content"][0]["text"].strip(), response["usage"]["input_tokens"] + response["usage"]["output_tokens"]

    def describe_with_prompt(self, image, prompt=None):
        b64 = self.image2base64(image)
        prompt = self.prompt(b64, prompt if prompt else vision_llm_describe_prompt())

        response = self.client.messages.create(
            model=self.model_name,
            max_tokens=self.max_tokens,
            messages=prompt
        )
        return response["content"][0]["text"].strip(), response["usage"]["input_tokens"]+response["usage"]["output_tokens"]
        response = self.client.messages.create(model=self.model_name, max_tokens=self.max_tokens, messages=prompt)
        return response["content"][0]["text"].strip(), response["usage"]["input_tokens"] + response["usage"]["output_tokens"]

    def chat(self, system, history, gen_conf):
        if "presence_penalty" in gen_conf:
@@ -984,11 +1012,7 @@ class AnthropicCV(Base):
        ).to_dict()
        ans = response["content"][0]["text"]
        if response["stop_reason"] == "max_tokens":
            ans += (
                "...\nFor the content length reason, it stopped, continue?"
                if is_english([ans])
                else "······\n由于长度的原因,回答被截断了,要继续吗?"
            )
            ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
        return (
            ans,
            response["usage"]["input_tokens"] + response["usage"]["output_tokens"],
@@ -1014,7 +1038,7 @@ class AnthropicCV(Base):
            **gen_conf,
        )
        for res in response:
            if res.type == 'content_block_delta':
            if res.type == "content_block_delta":
                if res.delta.type == "thinking_delta" and res.delta.thinking:
                    if ans.find("<think>") < 0:
                        ans += "<think>"
@@ -1030,7 +1054,10 @@ class AnthropicCV(Base):

        yield total_tokens


class GPUStackCV(GptV4):
    _FACTORY_NAME = "GPUStack"

    def __init__(self, key, model_name, lang="Chinese", base_url=""):
        if not base_url:
            raise ValueError("Local llm url cannot be None")
@@ -1041,6 +1068,8 @@ class GPUStackCV(GptV4):


class GoogleCV(Base):
    _FACTORY_NAME = "Google Cloud"

    def __init__(self, key, model_name, lang="Chinese", base_url=None, **kwargs):
        import base64

@@ -1079,8 +1108,11 @@ class GoogleCV(Base):
        self.client = glm.GenerativeModel(model_name=self.model_name)

    def describe(self, image):
        prompt = "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else \
            "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
        prompt = (
            "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
            if self.lang.lower() == "chinese"
            else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
        )

        if "claude" in self.model_name:
            b64 = self.image2base64(image)
@@ -1096,17 +1128,14 @@ class GoogleCV(Base):
                            "data": b64,
                        },
                    },
                    {
                        "type": "text",
                        "text": prompt
                    }
                    {"type": "text", "text": prompt},
                ],
            }
        ]
            response = self.client.messages.create(
                model=self.model_name,
                max_tokens=8192,
                messages=vision_prompt
                messages=vision_prompt,
            )
            return response.content[0].text.strip(), response.usage.input_tokens + response.usage.output_tokens
        else:
@@ -1114,10 +1143,7 @@ class GoogleCV(Base):

            b64 = self.image2base64(image)
            # Create proper image part for Gemini
            image_part = glm.Part.from_data(
                data=base64.b64decode(b64),
                mime_type="image/jpeg"
            )
            image_part = glm.Part.from_data(data=base64.b64decode(b64), mime_type="image/jpeg")
            input = [prompt, image_part]
            res = self.client.generate_content(input)
            return res.text, res.usage_metadata.total_token_count
@@ -1137,18 +1163,11 @@ class GoogleCV(Base):
                            "data": b64,
                        },
                    },
                    {
                        "type": "text",
                        "text": prompt if prompt else vision_llm_describe_prompt()
                    }
                    {"type": "text", "text": prompt if prompt else vision_llm_describe_prompt()},
                ],
            }
        ]
            response = self.client.messages.create(
                model=self.model_name,
                max_tokens=8192,
                messages=vision_prompt
            )
            response = self.client.messages.create(model=self.model_name, max_tokens=8192, messages=vision_prompt)
            return response.content[0].text.strip(), response.usage.input_tokens + response.usage.output_tokens
        else:
            import vertexai.generative_models as glm
@@ -1156,10 +1175,7 @@ class GoogleCV(Base):
            b64 = self.image2base64(image)
            vision_prompt = prompt if prompt else vision_llm_describe_prompt()
            # Create proper image part for Gemini
            image_part = glm.Part.from_data(
                data=base64.b64decode(b64),
                mime_type="image/jpeg"
            )
            image_part = glm.Part.from_data(data=base64.b64decode(b64), mime_type="image/jpeg")
            input = [vision_prompt, image_part]
            res = self.client.generate_content(input)
            return res.text, res.usage_metadata.total_token_count
@@ -1180,25 +1196,17 @@ class GoogleCV(Base):
                            "data": image,
                        },
                    },
                    {
                        "type": "text",
                        "text": his["content"]
                    }
                    {"type": "text", "text": his["content"]},
                ]

                response = self.client.messages.create(
                    model=self.model_name,
                    max_tokens=8192,
                    messages=history,
                    temperature=gen_conf.get("temperature", 0.3),
                    top_p=gen_conf.get("top_p", 0.7)
                )
                response = self.client.messages.create(model=self.model_name, max_tokens=8192, messages=history, temperature=gen_conf.get("temperature", 0.3), top_p=gen_conf.get("top_p", 0.7))
                return response.content[0].text.strip(), response.usage.input_tokens + response.usage.output_tokens
            except Exception as e:
                return "**ERROR**: " + str(e), 0
        else:
            import vertexai.generative_models as glm
            from transformers import GenerationConfig

            if system:
                history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
            try:
@@ -1213,15 +1221,10 @@ class GoogleCV(Base):

                # Create proper image part for Gemini
                img_bytes = base64.b64decode(image)
                image_part = glm.Part.from_data(
                    data=img_bytes,
                    mime_type="image/jpeg"
                )
                image_part = glm.Part.from_data(data=img_bytes, mime_type="image/jpeg")
                history[-1]["parts"].append(image_part)

                response = self.client.generate_content(history, generation_config=GenerationConfig(
                    temperature=gen_conf.get("temperature", 0.3),
                    top_p=gen_conf.get("top_p", 0.7)))
                response = self.client.generate_content(history, generation_config=GenerationConfig(temperature=gen_conf.get("temperature", 0.3), top_p=gen_conf.get("top_p", 0.7)))

                ans = response.text
                return ans, response.usage_metadata.total_token_count
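Note the list-valued factory names in the CV hunks: `OpenAI_APICV` (like `OpenAI_APIChat` earlier) declares `_FACTORY_NAME = ["VLLM", "OpenAI-API-Compatible"]`, and the `isinstance(obj._FACTORY_NAME, list)` branch in `__init__.py` registers the class once per name, replacing the duplicated dictionary entries the old manual registries carried. A sketch of the expected post-registration state (not a test that exists in this PR):

```python
# Both keys should resolve to the same class after registration.
from rag.llm import CvModel

assert CvModel["VLLM"] is CvModel["OpenAI-API-Compatible"]
```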
@ -13,28 +13,27 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import threading
|
||||
from abc import ABC
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import dashscope
|
||||
import google.generativeai as genai
|
||||
import numpy as np
|
||||
import requests
|
||||
from huggingface_hub import snapshot_download
|
||||
from zhipuai import ZhipuAI
|
||||
import os
|
||||
from abc import ABC
|
||||
from ollama import Client
|
||||
import dashscope
|
||||
from openai import OpenAI
|
||||
import numpy as np
|
||||
import asyncio
|
||||
from zhipuai import ZhipuAI
|
||||
|
||||
from api import settings
|
||||
from api.utils.file_utils import get_home_cache_dir
|
||||
from api.utils.log_utils import log_exception
|
||||
from rag.utils import num_tokens_from_string, truncate
|
||||
import google.generativeai as genai
|
||||
import json
|
||||
|
||||
|
||||
class Base(ABC):
|
||||
@ -60,7 +59,8 @@ class Base(ABC):
|
||||
|
||||
|
||||
class DefaultEmbedding(Base):
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
|
||||
_FACTORY_NAME = "BAAI"
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
||||
_model = None
|
||||
_model_name = ""
|
||||
_model_lock = threading.Lock()
|
||||
@ -79,21 +79,22 @@ class DefaultEmbedding(Base):
"""
if not settings.LIGHTEN:
with DefaultEmbedding._model_lock:
from FlagEmbedding import FlagModel
import torch
from FlagEmbedding import FlagModel

if not DefaultEmbedding._model or model_name != DefaultEmbedding._model_name:
try:
DefaultEmbedding._model = FlagModel(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)),
DefaultEmbedding._model = FlagModel(
os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)),
query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
use_fp16=torch.cuda.is_available())
use_fp16=torch.cuda.is_available(),
)
DefaultEmbedding._model_name = model_name
except Exception:
model_dir = snapshot_download(repo_id="BAAI/bge-large-zh-v1.5",
local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)),
local_dir_use_symlinks=False)
DefaultEmbedding._model = FlagModel(model_dir,
query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
use_fp16=torch.cuda.is_available())
model_dir = snapshot_download(
repo_id="BAAI/bge-large-zh-v1.5", local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), local_dir_use_symlinks=False
)
DefaultEmbedding._model = FlagModel(model_dir, query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:", use_fp16=torch.cuda.is_available())
self._model = DefaultEmbedding._model
self._model_name = DefaultEmbedding._model_name

@ -105,7 +106,7 @@ class DefaultEmbedding(Base):
|
||||
token_count += num_tokens_from_string(t)
|
||||
ress = []
|
||||
for i in range(0, len(texts), batch_size):
|
||||
ress.extend(self._model.encode(texts[i:i + batch_size]).tolist())
|
||||
ress.extend(self._model.encode(texts[i : i + batch_size]).tolist())
|
||||
return np.array(ress), token_count
|
||||
|
||||
def encode_queries(self, text: str):
|
||||
@ -114,8 +115,9 @@ class DefaultEmbedding(Base):
|
||||
|
||||
|
||||
class OpenAIEmbed(Base):
|
||||
def __init__(self, key, model_name="text-embedding-ada-002",
|
||||
base_url="https://api.openai.com/v1"):
|
||||
_FACTORY_NAME = "OpenAI"
|
||||
|
||||
def __init__(self, key, model_name="text-embedding-ada-002", base_url="https://api.openai.com/v1"):
|
||||
if not base_url:
|
||||
base_url = "https://api.openai.com/v1"
|
||||
self.client = OpenAI(api_key=key, base_url=base_url)
|
||||
@ -128,8 +130,7 @@ class OpenAIEmbed(Base):
|
||||
ress = []
|
||||
total_tokens = 0
|
||||
for i in range(0, len(texts), batch_size):
|
||||
res = self.client.embeddings.create(input=texts[i:i + batch_size],
|
||||
model=self.model_name)
|
||||
res = self.client.embeddings.create(input=texts[i : i + batch_size], model=self.model_name)
|
||||
try:
|
||||
ress.extend([d.embedding for d in res.data])
|
||||
total_tokens += self.total_token_count(res)
|
||||
@ -138,12 +139,13 @@ class OpenAIEmbed(Base):
|
||||
return np.array(ress), total_tokens
|
||||
|
||||
def encode_queries(self, text):
|
||||
res = self.client.embeddings.create(input=[truncate(text, 8191)],
|
||||
model=self.model_name)
|
||||
res = self.client.embeddings.create(input=[truncate(text, 8191)], model=self.model_name)
|
||||
return np.array(res.data[0].embedding), self.total_token_count(res)
|
||||
|
||||
|
||||
class LocalAIEmbed(Base):
|
||||
_FACTORY_NAME = "LocalAI"
|
||||
|
||||
def __init__(self, key, model_name, base_url):
|
||||
if not base_url:
|
||||
raise ValueError("Local embedding model url cannot be None")
|
||||
@ -155,7 +157,7 @@ class LocalAIEmbed(Base):
|
||||
batch_size = 16
|
||||
ress = []
|
||||
for i in range(0, len(texts), batch_size):
|
||||
res = self.client.embeddings.create(input=texts[i:i + batch_size], model=self.model_name)
|
||||
res = self.client.embeddings.create(input=texts[i : i + batch_size], model=self.model_name)
|
||||
try:
|
||||
ress.extend([d.embedding for d in res.data])
|
||||
except Exception as _e:
|
||||
@ -169,41 +171,42 @@ class LocalAIEmbed(Base):
|
||||
|
||||
|
||||
class AzureEmbed(OpenAIEmbed):
|
||||
_FACTORY_NAME = "Azure-OpenAI"
|
||||
|
||||
def __init__(self, key, model_name, **kwargs):
|
||||
from openai.lib.azure import AzureOpenAI
|
||||
api_key = json.loads(key).get('api_key', '')
|
||||
api_version = json.loads(key).get('api_version', '2024-02-01')
|
||||
|
||||
api_key = json.loads(key).get("api_key", "")
|
||||
api_version = json.loads(key).get("api_version", "2024-02-01")
|
||||
self.client = AzureOpenAI(api_key=api_key, azure_endpoint=kwargs["base_url"], api_version=api_version)
|
||||
self.model_name = model_name
|
||||
|
||||
|
||||
class BaiChuanEmbed(OpenAIEmbed):
|
||||
def __init__(self, key,
|
||||
model_name='Baichuan-Text-Embedding',
|
||||
base_url='https://api.baichuan-ai.com/v1'):
|
||||
_FACTORY_NAME = "BaiChuan"
|
||||
|
||||
def __init__(self, key, model_name="Baichuan-Text-Embedding", base_url="https://api.baichuan-ai.com/v1"):
|
||||
if not base_url:
|
||||
base_url = "https://api.baichuan-ai.com/v1"
|
||||
super().__init__(key, model_name, base_url)
|
||||
|
||||
|
||||
class QWenEmbed(Base):
|
||||
_FACTORY_NAME = "Tongyi-Qianwen"
|
||||
|
||||
def __init__(self, key, model_name="text_embedding_v2", **kwargs):
|
||||
self.key = key
|
||||
self.model_name = model_name
|
||||
|
||||
def encode(self, texts: list):
|
||||
import dashscope
|
||||
|
||||
batch_size = 4
|
||||
res = []
|
||||
token_count = 0
|
||||
texts = [truncate(t, 2048) for t in texts]
|
||||
for i in range(0, len(texts), batch_size):
|
||||
resp = dashscope.TextEmbedding.call(
|
||||
model=self.model_name,
|
||||
input=texts[i:i + batch_size],
|
||||
api_key=self.key,
|
||||
text_type="document"
|
||||
)
|
||||
resp = dashscope.TextEmbedding.call(model=self.model_name, input=texts[i : i + batch_size], api_key=self.key, text_type="document")
|
||||
try:
|
||||
embds = [[] for _ in range(len(resp["output"]["embeddings"]))]
|
||||
for e in resp["output"]["embeddings"]:
|
||||
@ -216,20 +219,16 @@ class QWenEmbed(Base):
|
||||
return np.array(res), token_count
|
||||
|
||||
def encode_queries(self, text):
|
||||
resp = dashscope.TextEmbedding.call(
|
||||
model=self.model_name,
|
||||
input=text[:2048],
|
||||
api_key=self.key,
|
||||
text_type="query"
|
||||
)
|
||||
resp = dashscope.TextEmbedding.call(model=self.model_name, input=text[:2048], api_key=self.key, text_type="query")
|
||||
try:
|
||||
return np.array(resp["output"]["embeddings"][0]
|
||||
["embedding"]), self.total_token_count(resp)
|
||||
return np.array(resp["output"]["embeddings"][0]["embedding"]), self.total_token_count(resp)
|
||||
except Exception as _e:
|
||||
log_exception(_e, resp)
|
||||
|
||||
|
||||
class ZhipuEmbed(Base):
|
||||
_FACTORY_NAME = "ZHIPU-AI"
|
||||
|
||||
def __init__(self, key, model_name="embedding-2", **kwargs):
|
||||
self.client = ZhipuAI(api_key=key)
|
||||
self.model_name = model_name
|
||||
@ -246,8 +245,7 @@ class ZhipuEmbed(Base):
|
||||
texts = [truncate(t, MAX_LEN) for t in texts]
|
||||
|
||||
for txt in texts:
|
||||
res = self.client.embeddings.create(input=txt,
|
||||
model=self.model_name)
|
||||
res = self.client.embeddings.create(input=txt, model=self.model_name)
|
||||
try:
|
||||
arr.append(res.data[0].embedding)
|
||||
tks_num += self.total_token_count(res)
|
||||
@ -256,8 +254,7 @@ class ZhipuEmbed(Base):
|
||||
return np.array(arr), tks_num
|
||||
|
||||
def encode_queries(self, text):
|
||||
res = self.client.embeddings.create(input=text,
|
||||
model=self.model_name)
|
||||
res = self.client.embeddings.create(input=text, model=self.model_name)
|
||||
try:
|
||||
return np.array(res.data[0].embedding), self.total_token_count(res)
|
||||
except Exception as _e:
|
||||
@ -265,18 +262,17 @@ class ZhipuEmbed(Base):
|
||||
|
||||
|
||||
class OllamaEmbed(Base):
|
||||
_FACTORY_NAME = "Ollama"
|
||||
|
||||
def __init__(self, key, model_name, **kwargs):
|
||||
self.client = Client(host=kwargs["base_url"]) if not key or key == "x" else \
|
||||
Client(host=kwargs["base_url"], headers={"Authorization": f"Bear {key}"})
|
||||
self.client = Client(host=kwargs["base_url"]) if not key or key == "x" else Client(host=kwargs["base_url"], headers={"Authorization": f"Bear {key}"})
|
||||
self.model_name = model_name
|
||||
|
||||
def encode(self, texts: list):
|
||||
arr = []
|
||||
tks_num = 0
|
||||
for txt in texts:
|
||||
res = self.client.embeddings(prompt=txt,
|
||||
model=self.model_name,
|
||||
options={"use_mmap": True})
|
||||
res = self.client.embeddings(prompt=txt, model=self.model_name, options={"use_mmap": True})
|
||||
try:
|
||||
arr.append(res["embedding"])
|
||||
except Exception as _e:
|
||||
@ -285,9 +281,7 @@ class OllamaEmbed(Base):
|
||||
return np.array(arr), tks_num
|
||||
|
||||
def encode_queries(self, text):
|
||||
res = self.client.embeddings(prompt=text,
|
||||
model=self.model_name,
|
||||
options={"use_mmap": True})
|
||||
res = self.client.embeddings(prompt=text, model=self.model_name, options={"use_mmap": True})
|
||||
try:
|
||||
return np.array(res["embedding"]), 128
|
||||
except Exception as _e:
|
||||
@ -295,6 +289,7 @@ class OllamaEmbed(Base):
|
||||
|
||||
|
||||
class FastEmbed(DefaultEmbedding):
|
||||
_FACTORY_NAME = "FastEmbed"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -307,15 +302,15 @@ class FastEmbed(DefaultEmbedding):
|
||||
if not settings.LIGHTEN:
|
||||
with FastEmbed._model_lock:
|
||||
from fastembed import TextEmbedding
|
||||
|
||||
if not DefaultEmbedding._model or model_name != DefaultEmbedding._model_name:
|
||||
try:
|
||||
DefaultEmbedding._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
|
||||
DefaultEmbedding._model_name = model_name
|
||||
except Exception:
|
||||
cache_dir = snapshot_download(repo_id="BAAI/bge-small-en-v1.5",
|
||||
local_dir=os.path.join(get_home_cache_dir(),
|
||||
re.sub(r"^[a-zA-Z0-9]+/", "", model_name)),
|
||||
local_dir_use_symlinks=False)
|
||||
cache_dir = snapshot_download(
|
||||
repo_id="BAAI/bge-small-en-v1.5", local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), local_dir_use_symlinks=False
|
||||
)
|
||||
DefaultEmbedding._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
|
||||
self._model = DefaultEmbedding._model
|
||||
self._model_name = model_name
|
||||
@ -340,6 +335,8 @@ class FastEmbed(DefaultEmbedding):
|
||||
|
||||
|
||||
class XinferenceEmbed(Base):
|
||||
_FACTORY_NAME = "Xinference"
|
||||
|
||||
def __init__(self, key, model_name="", base_url=""):
|
||||
base_url = urljoin(base_url, "v1")
|
||||
self.client = OpenAI(api_key=key, base_url=base_url)
|
||||
@ -350,7 +347,7 @@ class XinferenceEmbed(Base):
|
||||
ress = []
|
||||
total_tokens = 0
|
||||
for i in range(0, len(texts), batch_size):
|
||||
res = self.client.embeddings.create(input=texts[i:i + batch_size], model=self.model_name)
|
||||
res = self.client.embeddings.create(input=texts[i : i + batch_size], model=self.model_name)
|
||||
try:
|
||||
ress.extend([d.embedding for d in res.data])
|
||||
total_tokens += self.total_token_count(res)
|
||||
@ -359,8 +356,7 @@ class XinferenceEmbed(Base):
|
||||
return np.array(ress), total_tokens
|
||||
|
||||
def encode_queries(self, text):
|
||||
res = self.client.embeddings.create(input=[text],
|
||||
model=self.model_name)
|
||||
res = self.client.embeddings.create(input=[text], model=self.model_name)
|
||||
try:
|
||||
return np.array(res.data[0].embedding), self.total_token_count(res)
|
||||
except Exception as _e:
|
||||
@ -368,20 +364,18 @@ class XinferenceEmbed(Base):
|
||||
|
||||
|
||||
class YoudaoEmbed(Base):
|
||||
_FACTORY_NAME = "Youdao"
|
||||
_client = None
|
||||
|
||||
def __init__(self, key=None, model_name="maidalun1020/bce-embedding-base_v1", **kwargs):
|
||||
if not settings.LIGHTEN and not YoudaoEmbed._client:
|
||||
from BCEmbedding import EmbeddingModel as qanthing
|
||||
|
||||
try:
|
||||
logging.info("LOADING BCE...")
|
||||
YoudaoEmbed._client = qanthing(model_name_or_path=os.path.join(
|
||||
get_home_cache_dir(),
|
||||
"bce-embedding-base_v1"))
|
||||
YoudaoEmbed._client = qanthing(model_name_or_path=os.path.join(get_home_cache_dir(), "bce-embedding-base_v1"))
|
||||
except Exception:
|
||||
YoudaoEmbed._client = qanthing(
|
||||
model_name_or_path=model_name.replace(
|
||||
"maidalun1020", "InfiniFlow"))
|
||||
YoudaoEmbed._client = qanthing(model_name_or_path=model_name.replace("maidalun1020", "InfiniFlow"))
|
||||
|
||||
def encode(self, texts: list):
|
||||
batch_size = 10
|
||||
@ -390,7 +384,7 @@ class YoudaoEmbed(Base):
|
||||
for t in texts:
|
||||
token_count += num_tokens_from_string(t)
|
||||
for i in range(0, len(texts), batch_size):
|
||||
embds = YoudaoEmbed._client.encode(texts[i:i + batch_size])
|
||||
embds = YoudaoEmbed._client.encode(texts[i : i + batch_size])
|
||||
res.extend(embds)
|
||||
return np.array(res), token_count
|
||||
|
||||
@ -400,14 +394,11 @@ class YoudaoEmbed(Base):
|
||||
|
||||
|
||||
class JinaEmbed(Base):
|
||||
def __init__(self, key, model_name="jina-embeddings-v3",
|
||||
base_url="https://api.jina.ai/v1/embeddings"):
|
||||
_FACTORY_NAME = "Jina"
|
||||
|
||||
def __init__(self, key, model_name="jina-embeddings-v3", base_url="https://api.jina.ai/v1/embeddings"):
|
||||
self.base_url = "https://api.jina.ai/v1/embeddings"
|
||||
self.headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {key}"
|
||||
}
|
||||
self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
|
||||
self.model_name = model_name
|
||||
|
||||
def encode(self, texts: list):
|
||||
@ -416,11 +407,7 @@ class JinaEmbed(Base):
|
||||
ress = []
|
||||
token_count = 0
|
||||
for i in range(0, len(texts), batch_size):
|
||||
data = {
|
||||
"model": self.model_name,
|
||||
"input": texts[i:i + batch_size],
|
||||
'encoding_type': 'float'
|
||||
}
|
||||
data = {"model": self.model_name, "input": texts[i : i + batch_size], "encoding_type": "float"}
|
||||
response = requests.post(self.base_url, headers=self.headers, json=data)
|
||||
try:
|
||||
res = response.json()
|
||||
@ -435,50 +422,12 @@ class JinaEmbed(Base):
|
||||
return np.array(embds[0]), cnt
|
||||
|
||||
|
||||
class InfinityEmbed(Base):
|
||||
_model = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_names: list[str] = ("BAAI/bge-small-en-v1.5",),
|
||||
engine_kwargs: dict = {},
|
||||
key = None,
|
||||
):
|
||||
|
||||
from infinity_emb import EngineArgs
|
||||
from infinity_emb.engine import AsyncEngineArray
|
||||
|
||||
self._default_model = model_names[0]
|
||||
self.engine_array = AsyncEngineArray.from_args([EngineArgs(model_name_or_path = model_name, **engine_kwargs) for model_name in model_names])
|
||||
|
||||
async def _embed(self, sentences: list[str], model_name: str = ""):
|
||||
if not model_name:
|
||||
model_name = self._default_model
|
||||
engine = self.engine_array[model_name]
|
||||
was_already_running = engine.is_running
|
||||
if not was_already_running:
|
||||
await engine.astart()
|
||||
embeddings, usage = await engine.embed(sentences=sentences)
|
||||
if not was_already_running:
|
||||
await engine.astop()
|
||||
return embeddings, usage
|
||||
|
||||
def encode(self, texts: list[str], model_name: str = "") -> tuple[np.ndarray, int]:
|
||||
# Using the internal tokenizer to encode the texts and get the total
|
||||
# number of tokens
|
||||
embeddings, usage = asyncio.run(self._embed(texts, model_name))
|
||||
return np.array(embeddings), usage
|
||||
|
||||
def encode_queries(self, text: str) -> tuple[np.ndarray, int]:
|
||||
# Using the internal tokenizer to encode the texts and get the total
|
||||
# number of tokens
|
||||
return self.encode([text])
|
||||
|
||||
|
||||
class MistralEmbed(Base):
|
||||
def __init__(self, key, model_name="mistral-embed",
|
||||
base_url=None):
|
||||
_FACTORY_NAME = "Mistral"
|
||||
|
||||
def __init__(self, key, model_name="mistral-embed", base_url=None):
|
||||
from mistralai.client import MistralClient
|
||||
|
||||
self.client = MistralClient(api_key=key)
|
||||
self.model_name = model_name
|
||||
|
||||
@ -488,8 +437,7 @@ class MistralEmbed(Base):
|
||||
ress = []
|
||||
token_count = 0
|
||||
for i in range(0, len(texts), batch_size):
|
||||
res = self.client.embeddings(input=texts[i:i + batch_size],
|
||||
model=self.model_name)
|
||||
res = self.client.embeddings(input=texts[i : i + batch_size], model=self.model_name)
|
||||
try:
|
||||
ress.extend([d.embedding for d in res.data])
|
||||
token_count += self.total_token_count(res)
|
||||
@ -498,8 +446,7 @@ class MistralEmbed(Base):
|
||||
return np.array(ress), token_count
|
||||
|
||||
def encode_queries(self, text):
|
||||
res = self.client.embeddings(input=[truncate(text, 8196)],
|
||||
model=self.model_name)
|
||||
res = self.client.embeddings(input=[truncate(text, 8196)], model=self.model_name)
|
||||
try:
|
||||
return np.array(res.data[0].embedding), self.total_token_count(res)
|
||||
except Exception as _e:
|
||||
@ -507,30 +454,31 @@ class MistralEmbed(Base):
|
||||
|
||||
|
||||
class BedrockEmbed(Base):
|
||||
def __init__(self, key, model_name,
|
||||
**kwargs):
|
||||
_FACTORY_NAME = "Bedrock"
|
||||
|
||||
def __init__(self, key, model_name, **kwargs):
|
||||
import boto3
|
||||
self.bedrock_ak = json.loads(key).get('bedrock_ak', '')
|
||||
self.bedrock_sk = json.loads(key).get('bedrock_sk', '')
|
||||
self.bedrock_region = json.loads(key).get('bedrock_region', '')
|
||||
|
||||
self.bedrock_ak = json.loads(key).get("bedrock_ak", "")
|
||||
self.bedrock_sk = json.loads(key).get("bedrock_sk", "")
|
||||
self.bedrock_region = json.loads(key).get("bedrock_region", "")
|
||||
self.model_name = model_name
|
||||
|
||||
if self.bedrock_ak == '' or self.bedrock_sk == '' or self.bedrock_region == '':
|
||||
if self.bedrock_ak == "" or self.bedrock_sk == "" or self.bedrock_region == "":
|
||||
# Try to create a client using the default credentials (AWS_PROFILE, AWS_DEFAULT_REGION, etc.)
|
||||
self.client = boto3.client('bedrock-runtime')
|
||||
self.client = boto3.client("bedrock-runtime")
|
||||
else:
|
||||
self.client = boto3.client(service_name='bedrock-runtime', region_name=self.bedrock_region,
|
||||
aws_access_key_id=self.bedrock_ak, aws_secret_access_key=self.bedrock_sk)
|
||||
self.client = boto3.client(service_name="bedrock-runtime", region_name=self.bedrock_region, aws_access_key_id=self.bedrock_ak, aws_secret_access_key=self.bedrock_sk)
|
||||
|
||||
def encode(self, texts: list):
|
||||
texts = [truncate(t, 8196) for t in texts]
|
||||
embeddings = []
|
||||
token_count = 0
|
||||
for text in texts:
|
||||
if self.model_name.split('.')[0] == 'amazon':
|
||||
if self.model_name.split(".")[0] == "amazon":
|
||||
body = {"inputText": text}
|
||||
elif self.model_name.split('.')[0] == 'cohere':
|
||||
body = {"texts": [text], "input_type": 'search_document'}
|
||||
elif self.model_name.split(".")[0] == "cohere":
|
||||
body = {"texts": [text], "input_type": "search_document"}
|
||||
|
||||
response = self.client.invoke_model(modelId=self.model_name, body=json.dumps(body))
|
||||
try:
|
||||
@ -545,10 +493,10 @@ class BedrockEmbed(Base):
|
||||
def encode_queries(self, text):
|
||||
embeddings = []
|
||||
token_count = num_tokens_from_string(text)
|
||||
if self.model_name.split('.')[0] == 'amazon':
|
||||
if self.model_name.split(".")[0] == "amazon":
|
||||
body = {"inputText": truncate(text, 8196)}
|
||||
elif self.model_name.split('.')[0] == 'cohere':
|
||||
body = {"texts": [truncate(text, 8196)], "input_type": 'search_query'}
|
||||
elif self.model_name.split(".")[0] == "cohere":
|
||||
body = {"texts": [truncate(text, 8196)], "input_type": "search_query"}
|
||||
|
||||
response = self.client.invoke_model(modelId=self.model_name, body=json.dumps(body))
|
||||
try:
|
||||
@ -561,10 +509,11 @@ class BedrockEmbed(Base):
|
||||
|
||||
|
||||
class GeminiEmbed(Base):
|
||||
def __init__(self, key, model_name='models/text-embedding-004',
|
||||
**kwargs):
|
||||
_FACTORY_NAME = "Gemini"
|
||||
|
||||
def __init__(self, key, model_name="models/text-embedding-004", **kwargs):
|
||||
self.key = key
|
||||
self.model_name = 'models/' + model_name
|
||||
self.model_name = "models/" + model_name
|
||||
|
||||
def encode(self, texts: list):
|
||||
texts = [truncate(t, 2048) for t in texts]
|
||||
@ -573,35 +522,27 @@ class GeminiEmbed(Base):
|
||||
batch_size = 16
|
||||
ress = []
|
||||
for i in range(0, len(texts), batch_size):
|
||||
result = genai.embed_content(
|
||||
model=self.model_name,
|
||||
content=texts[i: i + batch_size],
|
||||
task_type="retrieval_document",
|
||||
title="Embedding of single string")
|
||||
result = genai.embed_content(model=self.model_name, content=texts[i : i + batch_size], task_type="retrieval_document", title="Embedding of single string")
|
||||
try:
|
||||
ress.extend(result['embedding'])
|
||||
ress.extend(result["embedding"])
|
||||
except Exception as _e:
|
||||
log_exception(_e, result)
|
||||
return np.array(ress),token_count
|
||||
return np.array(ress), token_count
|
||||
|
||||
def encode_queries(self, text):
|
||||
genai.configure(api_key=self.key)
|
||||
result = genai.embed_content(
|
||||
model=self.model_name,
|
||||
content=truncate(text,2048),
|
||||
task_type="retrieval_document",
|
||||
title="Embedding of single string")
|
||||
result = genai.embed_content(model=self.model_name, content=truncate(text, 2048), task_type="retrieval_document", title="Embedding of single string")
|
||||
token_count = num_tokens_from_string(text)
|
||||
try:
|
||||
return np.array(result['embedding']), token_count
|
||||
return np.array(result["embedding"]), token_count
|
||||
except Exception as _e:
|
||||
log_exception(_e, result)
|
||||
|
||||
|
||||
class NvidiaEmbed(Base):
|
||||
def __init__(
|
||||
self, key, model_name, base_url="https://integrate.api.nvidia.com/v1/embeddings"
|
||||
):
|
||||
_FACTORY_NAME = "NVIDIA"
|
||||
|
||||
def __init__(self, key, model_name, base_url="https://integrate.api.nvidia.com/v1/embeddings"):
|
||||
if not base_url:
|
||||
base_url = "https://integrate.api.nvidia.com/v1/embeddings"
|
||||
self.api_key = key
|
||||
@ -645,6 +586,8 @@ class NvidiaEmbed(Base):
|
||||
|
||||
|
||||
class LmStudioEmbed(LocalAIEmbed):
|
||||
_FACTORY_NAME = "LM-Studio"
|
||||
|
||||
def __init__(self, key, model_name, base_url):
|
||||
if not base_url:
|
||||
raise ValueError("Local llm url cannot be None")
|
||||
@ -654,6 +597,8 @@ class LmStudioEmbed(LocalAIEmbed):
|
||||
|
||||
|
||||
class OpenAI_APIEmbed(OpenAIEmbed):
|
||||
_FACTORY_NAME = ["VLLM", "OpenAI-API-Compatible"]
|
||||
|
||||
def __init__(self, key, model_name, base_url):
|
||||
if not base_url:
|
||||
raise ValueError("url cannot be None")
|
||||
@ -663,6 +608,8 @@ class OpenAI_APIEmbed(OpenAIEmbed):
|
||||
|
||||
|
||||
class CoHereEmbed(Base):
|
||||
_FACTORY_NAME = "Cohere"
|
||||
|
||||
def __init__(self, key, model_name, base_url=None):
|
||||
from cohere import Client
|
||||
|
||||
@ -701,6 +648,8 @@ class CoHereEmbed(Base):
|
||||
|
||||
|
||||
class TogetherAIEmbed(OpenAIEmbed):
|
||||
_FACTORY_NAME = "TogetherAI"
|
||||
|
||||
def __init__(self, key, model_name, base_url="https://api.together.xyz/v1"):
|
||||
if not base_url:
|
||||
base_url = "https://api.together.xyz/v1"
|
||||
@ -708,6 +657,8 @@ class TogetherAIEmbed(OpenAIEmbed):
|
||||
|
||||
|
||||
class PerfXCloudEmbed(OpenAIEmbed):
|
||||
_FACTORY_NAME = "PerfXCloud"
|
||||
|
||||
def __init__(self, key, model_name, base_url="https://cloud.perfxlab.cn/v1"):
|
||||
if not base_url:
|
||||
base_url = "https://cloud.perfxlab.cn/v1"
|
||||
@ -715,6 +666,8 @@ class PerfXCloudEmbed(OpenAIEmbed):
|
||||
|
||||
|
||||
class UpstageEmbed(OpenAIEmbed):
|
||||
_FACTORY_NAME = "Upstage"
|
||||
|
||||
def __init__(self, key, model_name, base_url="https://api.upstage.ai/v1/solar"):
|
||||
if not base_url:
|
||||
base_url = "https://api.upstage.ai/v1/solar"
|
||||
@ -722,6 +675,8 @@ class UpstageEmbed(OpenAIEmbed):
|
||||
|
||||
|
||||
class SILICONFLOWEmbed(Base):
|
||||
_FACTORY_NAME = "SILICONFLOW"
|
||||
|
||||
def __init__(self, key, model_name, base_url="https://api.siliconflow.cn/v1/embeddings"):
|
||||
if not base_url:
|
||||
base_url = "https://api.siliconflow.cn/v1/embeddings"
|
||||
@ -769,6 +724,8 @@ class SILICONFLOWEmbed(Base):
|
||||
|
||||
|
||||
class ReplicateEmbed(Base):
|
||||
_FACTORY_NAME = "Replicate"
|
||||
|
||||
def __init__(self, key, model_name, base_url=None):
|
||||
from replicate.client import Client
|
||||
|
||||
@ -790,6 +747,8 @@ class ReplicateEmbed(Base):
|
||||
|
||||
|
||||
class BaiduYiyanEmbed(Base):
|
||||
_FACTORY_NAME = "BaiduYiyan"
|
||||
|
||||
def __init__(self, key, model_name, base_url=None):
|
||||
import qianfan
|
||||
|
||||
@ -821,6 +780,8 @@ class BaiduYiyanEmbed(Base):
|
||||
|
||||
|
||||
class VoyageEmbed(Base):
|
||||
_FACTORY_NAME = "Voyage AI"
|
||||
|
||||
def __init__(self, key, model_name, base_url=None):
|
||||
import voyageai
|
||||
|
||||
@ -832,9 +793,7 @@ class VoyageEmbed(Base):
|
||||
ress = []
|
||||
token_count = 0
|
||||
for i in range(0, len(texts), batch_size):
|
||||
res = self.client.embed(
|
||||
texts=texts[i : i + batch_size], model=self.model_name, input_type="document"
|
||||
)
|
||||
res = self.client.embed(texts=texts[i : i + batch_size], model=self.model_name, input_type="document")
|
||||
try:
|
||||
ress.extend(res.embeddings)
|
||||
token_count += res.total_tokens
|
||||
@ -843,9 +802,7 @@ class VoyageEmbed(Base):
|
||||
return np.array(ress), token_count
|
||||
|
||||
def encode_queries(self, text):
|
||||
res = self.client.embed(
|
||||
texts=text, model=self.model_name, input_type="query"
|
||||
)
|
||||
res = self.client.embed(texts=text, model=self.model_name, input_type="query")
|
||||
try:
|
||||
return np.array(res.embeddings)[0], res.total_tokens
|
||||
except Exception as _e:
|
||||
@ -853,6 +810,8 @@ class VoyageEmbed(Base):
|
||||
|
||||
|
||||
class HuggingFaceEmbed(Base):
|
||||
_FACTORY_NAME = "HuggingFace"
|
||||
|
||||
def __init__(self, key, model_name, base_url=None):
|
||||
if not model_name:
|
||||
raise ValueError("Model name cannot be None")
|
||||
@ -863,11 +822,7 @@ class HuggingFaceEmbed(Base):
|
||||
def encode(self, texts: list):
|
||||
embeddings = []
|
||||
for text in texts:
|
||||
response = requests.post(
|
||||
f"{self.base_url}/embed",
|
||||
json={"inputs": text},
|
||||
headers={'Content-Type': 'application/json'}
|
||||
)
|
||||
response = requests.post(f"{self.base_url}/embed", json={"inputs": text}, headers={"Content-Type": "application/json"})
|
||||
if response.status_code == 200:
|
||||
embedding = response.json()
|
||||
embeddings.append(embedding[0])
|
||||
@ -876,11 +831,7 @@ class HuggingFaceEmbed(Base):
|
||||
return np.array(embeddings), sum([num_tokens_from_string(text) for text in texts])
|
||||
|
||||
def encode_queries(self, text):
|
||||
response = requests.post(
|
||||
f"{self.base_url}/embed",
|
||||
json={"inputs": text},
|
||||
headers={'Content-Type': 'application/json'}
|
||||
)
|
||||
response = requests.post(f"{self.base_url}/embed", json={"inputs": text}, headers={"Content-Type": "application/json"})
|
||||
if response.status_code == 200:
|
||||
embedding = response.json()
|
||||
return np.array(embedding[0]), num_tokens_from_string(text)
|
||||
@ -889,15 +840,19 @@ class HuggingFaceEmbed(Base):
|
||||
|
||||
|
||||
class VolcEngineEmbed(OpenAIEmbed):
|
||||
_FACTORY_NAME = "VolcEngine"
|
||||
|
||||
def __init__(self, key, model_name, base_url="https://ark.cn-beijing.volces.com/api/v3"):
|
||||
if not base_url:
|
||||
base_url = "https://ark.cn-beijing.volces.com/api/v3"
|
||||
ark_api_key = json.loads(key).get('ark_api_key', '')
|
||||
model_name = json.loads(key).get('ep_id', '') + json.loads(key).get('endpoint_id', '')
|
||||
super().__init__(ark_api_key,model_name,base_url)
|
||||
ark_api_key = json.loads(key).get("ark_api_key", "")
|
||||
model_name = json.loads(key).get("ep_id", "") + json.loads(key).get("endpoint_id", "")
|
||||
super().__init__(ark_api_key, model_name, base_url)
|
||||
|
||||
|
||||
class GPUStackEmbed(OpenAIEmbed):
|
||||
_FACTORY_NAME = "GPUStack"
|
||||
|
||||
def __init__(self, key, model_name, base_url):
|
||||
if not base_url:
|
||||
raise ValueError("url cannot be None")
|
||||
@ -908,6 +863,8 @@ class GPUStackEmbed(OpenAIEmbed):
|
||||
|
||||
|
||||
class NovitaEmbed(SILICONFLOWEmbed):
_FACTORY_NAME = "NovitaAI"

def __init__(self, key, model_name, base_url="https://api.novita.ai/v3/openai/embeddings"):
if not base_url:
base_url = "https://api.novita.ai/v3/openai/embeddings"
@ -915,6 +872,8 @@ class NovitaEmbed(SILICONFLOWEmbed):


class GiteeEmbed(SILICONFLOWEmbed):
_FACTORY_NAME = "GiteeAI"

def __init__(self, key, model_name, base_url="https://ai.gitee.com/v1/embeddings"):
if not base_url:
base_url = "https://ai.gitee.com/v1/embeddings"

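As the classes above illustrate, most hosted embedding providers now reduce to an OpenAI-compatible subclass that sets `_FACTORY_NAME` and a default endpoint. A hypothetical new provider would follow the same shape (the factory name and URL below are invented for illustration):

```python
# Hypothetical example only; "ExampleAI" and its endpoint are placeholders.
class ExampleAIEmbed(OpenAIEmbed):
    _FACTORY_NAME = "ExampleAI"

    def __init__(self, key, model_name, base_url="https://api.example-ai.invalid/v1"):
        if not base_url:
            base_url = "https://api.example-ai.invalid/v1"
        super().__init__(key, model_name, base_url)
```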
@ -13,24 +13,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import os
import re
import threading
from abc import ABC
from collections.abc import Iterable
from urllib.parse import urljoin

import requests
import httpx
from huggingface_hub import snapshot_download
import os
from abc import ABC
import numpy as np
import requests
from huggingface_hub import snapshot_download
from yarl import URL

from api import settings
from api.utils.file_utils import get_home_cache_dir
from api.utils.log_utils import log_exception
from rag.utils import num_tokens_from_string, truncate
import json


def sigmoid(x):
@ -57,6 +57,7 @@ class Base(ABC):


class DefaultRerank(Base):
_FACTORY_NAME = "BAAI"
_model = None
_model_lock = threading.Lock()

@ -75,17 +76,13 @@ class DefaultRerank(Base):
|
||||
if not settings.LIGHTEN and not DefaultRerank._model:
|
||||
import torch
|
||||
from FlagEmbedding import FlagReranker
|
||||
|
||||
with DefaultRerank._model_lock:
|
||||
if not DefaultRerank._model:
|
||||
try:
|
||||
DefaultRerank._model = FlagReranker(
|
||||
os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)),
|
||||
use_fp16=torch.cuda.is_available())
|
||||
DefaultRerank._model = FlagReranker(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), use_fp16=torch.cuda.is_available())
|
||||
except Exception:
|
||||
model_dir = snapshot_download(repo_id=model_name,
|
||||
local_dir=os.path.join(get_home_cache_dir(),
|
||||
re.sub(r"^[a-zA-Z0-9]+/", "", model_name)),
|
||||
local_dir_use_symlinks=False)
|
||||
model_dir = snapshot_download(repo_id=model_name, local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), local_dir_use_symlinks=False)
|
||||
DefaultRerank._model = FlagReranker(model_dir, use_fp16=torch.cuda.is_available())
|
||||
self._model = DefaultRerank._model
|
||||
self._dynamic_batch_size = 8
|
||||
@ -94,6 +91,7 @@ class DefaultRerank(Base):
|
||||
def torch_empty_cache(self):
|
||||
try:
|
||||
import torch
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
except Exception as e:
|
||||
print(f"Error emptying cache: {e}")
|
||||
@ -112,7 +110,7 @@ class DefaultRerank(Base):
while retry_count < max_retries:
try:
# call subclass implemented batch processing calculation
batch_scores = self._compute_batch_scores(pairs[i:i + current_batch])
batch_scores = self._compute_batch_scores(pairs[i : i + current_batch])
res.extend(batch_scores)
i += current_batch
self._dynamic_batch_size = min(self._dynamic_batch_size * 2, 8)
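The hunk above only reflows the adaptive batching loop: each successful batch advances the cursor and doubles the dynamic batch size up to a cap of 8 (the shrink-on-failure branch sits outside this hunk). A standalone sketch of that pattern, with invented names and an assumed back-off rule, might look like this:

```python
# Sketch of adaptive batch scoring: grow the batch after success, and (assumed)
# halve it and retry after a failure. Names are illustrative, not the real API.
def score_pairs(pairs, compute_batch, start_batch=8, cap=8, max_retries=5):
    res, i, batch = [], 0, start_batch
    while i < len(pairs):
        retries = 0
        while retries < max_retries:
            try:
                res.extend(compute_batch(pairs[i : i + batch]))
                i += batch
                batch = min(batch * 2, cap)  # grow back toward the cap
                break
            except Exception:
                batch = max(1, batch // 2)  # assumption: back off on errors
                retries += 1
        else:
            raise RuntimeError("batch scoring failed after retries")
    return res
```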
@ -152,23 +150,16 @@ class DefaultRerank(Base):
|
||||
|
||||
|
||||
class JinaRerank(Base):
|
||||
def __init__(self, key, model_name="jina-reranker-v2-base-multilingual",
|
||||
base_url="https://api.jina.ai/v1/rerank"):
|
||||
_FACTORY_NAME = "Jina"
|
||||
|
||||
def __init__(self, key, model_name="jina-reranker-v2-base-multilingual", base_url="https://api.jina.ai/v1/rerank"):
|
||||
self.base_url = "https://api.jina.ai/v1/rerank"
|
||||
self.headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {key}"
|
||||
}
|
||||
self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
|
||||
self.model_name = model_name
|
||||
|
||||
def similarity(self, query: str, texts: list):
|
||||
texts = [truncate(t, 8196) for t in texts]
|
||||
data = {
|
||||
"model": self.model_name,
|
||||
"query": query,
|
||||
"documents": texts,
|
||||
"top_n": len(texts)
|
||||
}
|
||||
data = {"model": self.model_name, "query": query, "documents": texts, "top_n": len(texts)}
|
||||
res = requests.post(self.base_url, headers=self.headers, json=data).json()
|
||||
rank = np.zeros(len(texts), dtype=float)
|
||||
try:
|
||||
@ -180,22 +171,20 @@ class JinaRerank(Base):
|
||||
|
||||
|
||||
class YoudaoRerank(DefaultRerank):
|
||||
_FACTORY_NAME = "Youdao"
|
||||
_model = None
|
||||
_model_lock = threading.Lock()
|
||||
|
||||
def __init__(self, key=None, model_name="maidalun1020/bce-reranker-base_v1", **kwargs):
|
||||
if not settings.LIGHTEN and not YoudaoRerank._model:
|
||||
from BCEmbedding import RerankerModel
|
||||
|
||||
with YoudaoRerank._model_lock:
|
||||
if not YoudaoRerank._model:
|
||||
try:
|
||||
YoudaoRerank._model = RerankerModel(model_name_or_path=os.path.join(
|
||||
get_home_cache_dir(),
|
||||
re.sub(r"^[a-zA-Z0-9]+/", "", model_name)))
|
||||
YoudaoRerank._model = RerankerModel(model_name_or_path=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)))
|
||||
except Exception:
|
||||
YoudaoRerank._model = RerankerModel(
|
||||
model_name_or_path=model_name.replace(
|
||||
"maidalun1020", "InfiniFlow"))
|
||||
YoudaoRerank._model = RerankerModel(model_name_or_path=model_name.replace("maidalun1020", "InfiniFlow"))
|
||||
|
||||
self._model = YoudaoRerank._model
|
||||
self._dynamic_batch_size = 8
|
||||
@ -212,6 +201,8 @@ class YoudaoRerank(DefaultRerank):
|
||||
|
||||
|
||||
class XInferenceRerank(Base):
|
||||
_FACTORY_NAME = "Xinference"
|
||||
|
||||
def __init__(self, key="x", model_name="", base_url=""):
|
||||
if base_url.find("/v1") == -1:
|
||||
base_url = urljoin(base_url, "/v1/rerank")
|
||||
@ -219,10 +210,7 @@ class XInferenceRerank(Base):
|
||||
base_url = urljoin(base_url, "/v1/rerank")
|
||||
self.model_name = model_name
|
||||
self.base_url = base_url
|
||||
self.headers = {
|
||||
"Content-Type": "application/json",
|
||||
"accept": "application/json"
|
||||
}
|
||||
self.headers = {"Content-Type": "application/json", "accept": "application/json"}
|
||||
if key and key != "x":
|
||||
self.headers["Authorization"] = f"Bearer {key}"
|
||||
|
||||
@ -233,13 +221,7 @@ class XInferenceRerank(Base):
|
||||
token_count = 0
|
||||
for _, t in pairs:
|
||||
token_count += num_tokens_from_string(t)
|
||||
data = {
|
||||
"model": self.model_name,
|
||||
"query": query,
|
||||
"return_documents": "true",
|
||||
"return_len": "true",
|
||||
"documents": texts
|
||||
}
|
||||
data = {"model": self.model_name, "query": query, "return_documents": "true", "return_len": "true", "documents": texts}
|
||||
res = requests.post(self.base_url, headers=self.headers, json=data).json()
|
||||
rank = np.zeros(len(texts), dtype=float)
|
||||
try:
|
||||
@ -251,15 +233,14 @@ class XInferenceRerank(Base):
|
||||
|
||||
|
||||
class LocalAIRerank(Base):
|
||||
_FACTORY_NAME = "LocalAI"
|
||||
|
||||
def __init__(self, key, model_name, base_url):
|
||||
if base_url.find("/rerank") == -1:
|
||||
self.base_url = urljoin(base_url, "/rerank")
|
||||
else:
|
||||
self.base_url = base_url
|
||||
self.headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {key}"
|
||||
}
|
||||
self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
|
||||
self.model_name = model_name.split("___")[0]
|
||||
|
||||
def similarity(self, query: str, texts: list):
|
||||
@ -296,16 +277,15 @@ class LocalAIRerank(Base):
|
||||
|
||||
|
||||
class NvidiaRerank(Base):
|
||||
def __init__(
|
||||
self, key, model_name, base_url="https://ai.api.nvidia.com/v1/retrieval/nvidia/"
|
||||
):
|
||||
_FACTORY_NAME = "NVIDIA"
|
||||
|
||||
def __init__(self, key, model_name, base_url="https://ai.api.nvidia.com/v1/retrieval/nvidia/"):
|
||||
if not base_url:
|
||||
base_url = "https://ai.api.nvidia.com/v1/retrieval/nvidia/"
|
||||
self.model_name = model_name
|
||||
|
||||
if self.model_name == "nvidia/nv-rerankqa-mistral-4b-v3":
|
||||
self.base_url = urljoin(base_url, "nv-rerankqa-mistral-4b-v3/reranking"
|
||||
)
|
||||
self.base_url = urljoin(base_url, "nv-rerankqa-mistral-4b-v3/reranking")
|
||||
|
||||
if self.model_name == "nvidia/rerank-qa-mistral-4b":
|
||||
self.base_url = urljoin(base_url, "reranking")
|
||||
@ -318,9 +298,7 @@ class NvidiaRerank(Base):
|
||||
}
|
||||
|
||||
def similarity(self, query: str, texts: list):
|
||||
token_count = num_tokens_from_string(query) + sum(
|
||||
[num_tokens_from_string(t) for t in texts]
|
||||
)
|
||||
token_count = num_tokens_from_string(query) + sum([num_tokens_from_string(t) for t in texts])
|
||||
data = {
|
||||
"model": self.model_name,
|
||||
"query": {"text": query},
|
||||
@ -339,6 +317,8 @@ class NvidiaRerank(Base):
|
||||
|
||||
|
||||
class LmStudioRerank(Base):
|
||||
_FACTORY_NAME = "LM-Studio"
|
||||
|
||||
def __init__(self, key, model_name, base_url):
|
||||
pass
|
||||
|
||||
@ -347,15 +327,14 @@ class LmStudioRerank(Base):
|
||||
|
||||
|
||||
class OpenAI_APIRerank(Base):
|
||||
_FACTORY_NAME = "OpenAI-API-Compatible"
|
||||
|
||||
def __init__(self, key, model_name, base_url):
|
||||
if base_url.find("/rerank") == -1:
|
||||
self.base_url = urljoin(base_url, "/rerank")
|
||||
else:
|
||||
self.base_url = base_url
|
||||
self.headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {key}"
|
||||
}
|
||||
self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
|
||||
self.model_name = model_name.split("___")[0]
|
||||
|
||||
def similarity(self, query: str, texts: list):
|
||||
@ -392,6 +371,8 @@ class OpenAI_APIRerank(Base):
|
||||
|
||||
|
||||
class CoHereRerank(Base):
|
||||
_FACTORY_NAME = ["Cohere", "VLLM"]
|
||||
|
||||
def __init__(self, key, model_name, base_url=None):
|
||||
from cohere import Client
|
||||
|
||||
@ -399,9 +380,7 @@ class CoHereRerank(Base):
|
||||
self.model_name = model_name.split("___")[0]
|
||||
|
||||
def similarity(self, query: str, texts: list):
|
||||
token_count = num_tokens_from_string(query) + sum(
|
||||
[num_tokens_from_string(t) for t in texts]
|
||||
)
|
||||
token_count = num_tokens_from_string(query) + sum([num_tokens_from_string(t) for t in texts])
|
||||
res = self.client.rerank(
|
||||
model=self.model_name,
|
||||
query=query,
|
||||
@ -419,6 +398,8 @@ class CoHereRerank(Base):
|
||||
|
||||
|
||||
class TogetherAIRerank(Base):
|
||||
_FACTORY_NAME = "TogetherAI"
|
||||
|
||||
def __init__(self, key, model_name, base_url):
|
||||
pass
|
||||
|
||||
@ -427,9 +408,9 @@ class TogetherAIRerank(Base):
|
||||
|
||||
|
||||
class SILICONFLOWRerank(Base):
|
||||
def __init__(
|
||||
self, key, model_name, base_url="https://api.siliconflow.cn/v1/rerank"
|
||||
):
|
||||
_FACTORY_NAME = "SILICONFLOW"
|
||||
|
||||
def __init__(self, key, model_name, base_url="https://api.siliconflow.cn/v1/rerank"):
|
||||
if not base_url:
|
||||
base_url = "https://api.siliconflow.cn/v1/rerank"
|
||||
self.model_name = model_name
|
||||
@ -450,9 +431,7 @@ class SILICONFLOWRerank(Base):
|
||||
"max_chunks_per_doc": 1024,
|
||||
"overlap_tokens": 80,
|
||||
}
|
||||
response = requests.post(
|
||||
self.base_url, json=payload, headers=self.headers
|
||||
).json()
|
||||
response = requests.post(self.base_url, json=payload, headers=self.headers).json()
|
||||
rank = np.zeros(len(texts), dtype=float)
|
||||
try:
|
||||
for d in response["results"]:
|
||||
@ -466,6 +445,8 @@ class SILICONFLOWRerank(Base):
|
||||
|
||||
|
||||
class BaiduYiyanRerank(Base):
|
||||
_FACTORY_NAME = "BaiduYiyan"
|
||||
|
||||
def __init__(self, key, model_name, base_url=None):
|
||||
from qianfan.resources import Reranker
|
||||
|
||||
@ -492,6 +473,8 @@ class BaiduYiyanRerank(Base):
|
||||
|
||||
|
||||
class VoyageRerank(Base):
|
||||
_FACTORY_NAME = "Voyage AI"
|
||||
|
||||
def __init__(self, key, model_name, base_url=None):
|
||||
import voyageai
|
||||
|
||||
@ -502,9 +485,7 @@ class VoyageRerank(Base):
|
||||
rank = np.zeros(len(texts), dtype=float)
|
||||
if not texts:
|
||||
return rank, 0
|
||||
res = self.client.rerank(
|
||||
query=query, documents=texts, model=self.model_name, top_k=len(texts)
|
||||
)
|
||||
res = self.client.rerank(query=query, documents=texts, model=self.model_name, top_k=len(texts))
|
||||
try:
|
||||
for r in res.results:
|
||||
rank[r.index] = r.relevance_score
|
||||
@ -514,22 +495,20 @@ class VoyageRerank(Base):
|
||||
|
||||
|
||||
class QWenRerank(Base):
|
||||
def __init__(self, key, model_name='gte-rerank', base_url=None, **kwargs):
|
||||
_FACTORY_NAME = "Tongyi-Qianwen"
|
||||
|
||||
def __init__(self, key, model_name="gte-rerank", base_url=None, **kwargs):
|
||||
import dashscope
|
||||
|
||||
self.api_key = key
|
||||
self.model_name = dashscope.TextReRank.Models.gte_rerank if model_name is None else model_name
|
||||
|
||||
def similarity(self, query: str, texts: list):
|
||||
import dashscope
|
||||
from http import HTTPStatus
|
||||
resp = dashscope.TextReRank.call(
|
||||
api_key=self.api_key,
|
||||
model=self.model_name,
|
||||
query=query,
|
||||
documents=texts,
|
||||
top_n=len(texts),
|
||||
return_documents=False
|
||||
)
|
||||
|
||||
import dashscope
|
||||
|
||||
resp = dashscope.TextReRank.call(api_key=self.api_key, model=self.model_name, query=query, documents=texts, top_n=len(texts), return_documents=False)
|
||||
rank = np.zeros(len(texts), dtype=float)
|
||||
if resp.status_code == HTTPStatus.OK:
|
||||
try:
|
||||
@ -543,6 +522,8 @@ class QWenRerank(Base):
|
||||
|
||||
|
||||
class HuggingfaceRerank(DefaultRerank):
|
||||
_FACTORY_NAME = "HuggingFace"
|
||||
|
||||
@staticmethod
|
||||
def post(query: str, texts: list, url="127.0.0.1"):
|
||||
exc = None
|
||||
@ -550,9 +531,9 @@ class HuggingfaceRerank(DefaultRerank):
|
||||
batch_size = 8
|
||||
for i in range(0, len(texts), batch_size):
|
||||
try:
|
||||
res = requests.post(f"http://{url}/rerank", headers={"Content-Type": "application/json"},
|
||||
json={"query": query, "texts": texts[i: i + batch_size],
|
||||
"raw_scores": False, "truncate": True})
|
||||
res = requests.post(
|
||||
f"http://{url}/rerank", headers={"Content-Type": "application/json"}, json={"query": query, "texts": texts[i : i + batch_size], "raw_scores": False, "truncate": True}
|
||||
)
|
||||
|
||||
for o in res.json():
|
||||
scores[o["index"] + i] = o["score"]
|
||||
@ -577,9 +558,9 @@ class HuggingfaceRerank(DefaultRerank):
|
||||
|
||||
|
||||
class GPUStackRerank(Base):
|
||||
def __init__(
|
||||
self, key, model_name, base_url
|
||||
):
|
||||
_FACTORY_NAME = "GPUStack"
|
||||
|
||||
def __init__(self, key, model_name, base_url):
|
||||
if not base_url:
|
||||
raise ValueError("url cannot be None")
|
||||
|
||||
@ -600,9 +581,7 @@ class GPUStackRerank(Base):
|
||||
}
|
||||
|
||||
try:
|
||||
response = requests.post(
|
||||
self.base_url, json=payload, headers=self.headers
|
||||
)
|
||||
response = requests.post(self.base_url, json=payload, headers=self.headers)
|
||||
response.raise_for_status()
|
||||
response_json = response.json()
|
||||
|
||||
@ -623,11 +602,12 @@ class GPUStackRerank(Base):
|
||||
)
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
raise ValueError(
|
||||
f"Error calling GPUStackRerank model {self.model_name}: {e.response.status_code} - {e.response.text}")
|
||||
raise ValueError(f"Error calling GPUStackRerank model {self.model_name}: {e.response.status_code} - {e.response.text}")
|
||||
|
||||
|
||||
class NovitaRerank(JinaRerank):
_FACTORY_NAME = "NovitaAI"

def __init__(self, key, model_name, base_url="https://api.novita.ai/v3/openai/rerank"):
if not base_url:
base_url = "https://api.novita.ai/v3/openai/rerank"
@ -635,6 +615,8 @@ class NovitaRerank(JinaRerank):


class GiteeRerank(JinaRerank):
_FACTORY_NAME = "GiteeAI"

def __init__(self, key, model_name, base_url="https://ai.gitee.com/v1/rerank"):
if not base_url:
base_url = "https://ai.gitee.com/v1/rerank"

@ -13,16 +13,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import requests
from openai.lib.azure import AzureOpenAI
import io
from abc import ABC
from openai import OpenAI
import json
from rag.utils import num_tokens_from_string
import base64
import io
import json
import os
import re
from abc import ABC

import requests
from openai import OpenAI
from openai.lib.azure import AzureOpenAI

from rag.utils import num_tokens_from_string


class Base(ABC):
|
||||
@ -30,11 +32,7 @@ class Base(ABC):
|
||||
pass
|
||||
|
||||
def transcription(self, audio, **kwargs):
|
||||
transcription = self.client.audio.transcriptions.create(
|
||||
model=self.model_name,
|
||||
file=audio,
|
||||
response_format="text"
|
||||
)
|
||||
transcription = self.client.audio.transcriptions.create(model=self.model_name, file=audio, response_format="text")
|
||||
return transcription.text.strip(), num_tokens_from_string(transcription.text.strip())
|
||||
|
||||
def audio2base64(self, audio):
|
||||
@ -46,6 +44,8 @@ class Base(ABC):
|
||||
|
||||
|
||||
class GPTSeq2txt(Base):
|
||||
_FACTORY_NAME = "OpenAI"
|
||||
|
||||
def __init__(self, key, model_name="whisper-1", base_url="https://api.openai.com/v1"):
|
||||
if not base_url:
|
||||
base_url = "https://api.openai.com/v1"
|
||||
@ -54,31 +54,34 @@ class GPTSeq2txt(Base):
|
||||
|
||||
|
||||
class QWenSeq2txt(Base):
|
||||
_FACTORY_NAME = "Tongyi-Qianwen"
|
||||
|
||||
def __init__(self, key, model_name="paraformer-realtime-8k-v1", **kwargs):
|
||||
import dashscope
|
||||
|
||||
dashscope.api_key = key
|
||||
self.model_name = model_name
|
||||
|
||||
def transcription(self, audio, format):
|
||||
from http import HTTPStatus
|
||||
|
||||
from dashscope.audio.asr import Recognition
|
||||
|
||||
recognition = Recognition(model=self.model_name,
|
||||
format=format,
|
||||
sample_rate=16000,
|
||||
callback=None)
|
||||
recognition = Recognition(model=self.model_name, format=format, sample_rate=16000, callback=None)
|
||||
result = recognition.call(audio)
|
||||
|
||||
ans = ""
|
||||
if result.status_code == HTTPStatus.OK:
for sentence in result.get_sentence():
ans += sentence.text.decode('utf-8') + '\n'
ans += sentence.text.decode("utf-8") + "\n"
return ans, num_tokens_from_string(ans)

return "**ERROR**: " + result.message, 0


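For context, a hypothetical caller of the recognizer above passes the audio plus its format and gets back the transcript with a token estimate (key, model id, and file name are placeholders):

```python
# Hypothetical usage sketch of the DashScope-backed recognizer defined above.
s2t = QWenSeq2txt(key="sk-...", model_name="paraformer-realtime-8k-v1")
text, tokens = s2t.transcription("meeting.wav", format="wav")
print(text[:80], tokens)
```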
class AzureSeq2txt(Base):
|
||||
_FACTORY_NAME = "Azure-OpenAI"
|
||||
|
||||
def __init__(self, key, model_name, lang="Chinese", **kwargs):
|
||||
self.client = AzureOpenAI(api_key=key, azure_endpoint=kwargs["base_url"], api_version="2024-02-01")
|
||||
self.model_name = model_name
|
||||
@ -86,43 +89,33 @@ class AzureSeq2txt(Base):
|
||||
|
||||
|
||||
class XinferenceSeq2txt(Base):
|
||||
_FACTORY_NAME = "Xinference"
|
||||
|
||||
def __init__(self, key, model_name="whisper-small", **kwargs):
|
||||
self.base_url = kwargs.get('base_url', None)
|
||||
self.base_url = kwargs.get("base_url", None)
|
||||
self.model_name = model_name
|
||||
self.key = key
|
||||
|
||||
def transcription(self, audio, language="zh", prompt=None, response_format="json", temperature=0.7):
|
||||
if isinstance(audio, str):
|
||||
audio_file = open(audio, 'rb')
|
||||
audio_file = open(audio, "rb")
|
||||
audio_data = audio_file.read()
|
||||
audio_file_name = audio.split("/")[-1]
|
||||
else:
|
||||
audio_data = audio
|
||||
audio_file_name = "audio.wav"
|
||||
|
||||
payload = {
|
||||
"model": self.model_name,
|
||||
"language": language,
|
||||
"prompt": prompt,
|
||||
"response_format": response_format,
|
||||
"temperature": temperature
|
||||
}
|
||||
payload = {"model": self.model_name, "language": language, "prompt": prompt, "response_format": response_format, "temperature": temperature}
|
||||
|
||||
files = {
|
||||
"file": (audio_file_name, audio_data, 'audio/wav')
|
||||
}
|
||||
files = {"file": (audio_file_name, audio_data, "audio/wav")}
|
||||
|
||||
try:
|
||||
response = requests.post(
|
||||
f"{self.base_url}/v1/audio/transcriptions",
|
||||
files=files,
|
||||
data=payload
|
||||
)
|
||||
response = requests.post(f"{self.base_url}/v1/audio/transcriptions", files=files, data=payload)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
|
||||
if 'text' in result:
|
||||
transcription_text = result['text'].strip()
|
||||
if "text" in result:
|
||||
transcription_text = result["text"].strip()
|
||||
return transcription_text, num_tokens_from_string(transcription_text)
|
||||
else:
|
||||
return "**ERROR**: Failed to retrieve transcription.", 0
|
||||
@ -132,11 +125,11 @@ class XinferenceSeq2txt(Base):
|
||||
|
||||
|
||||
class TencentCloudSeq2txt(Base):
|
||||
def __init__(
|
||||
self, key, model_name="16k_zh", base_url="https://asr.tencentcloudapi.com"
|
||||
):
|
||||
from tencentcloud.common import credential
|
||||
_FACTORY_NAME = "Tencent Cloud"
|
||||
|
||||
def __init__(self, key, model_name="16k_zh", base_url="https://asr.tencentcloudapi.com"):
|
||||
from tencentcloud.asr.v20190614 import asr_client
|
||||
from tencentcloud.common import credential
|
||||
|
||||
key = json.loads(key)
|
||||
sid = key.get("tencent_cloud_sid", "")
|
||||
@ -146,11 +139,12 @@ class TencentCloudSeq2txt(Base):
|
||||
self.model_name = model_name
|
||||
|
||||
def transcription(self, audio, max_retries=60, retry_interval=5):
|
||||
import time
|
||||
|
||||
from tencentcloud.asr.v20190614 import models
|
||||
from tencentcloud.common.exception.tencent_cloud_sdk_exception import (
|
||||
TencentCloudSDKException,
|
||||
)
|
||||
from tencentcloud.asr.v20190614 import models
|
||||
import time
|
||||
|
||||
b64 = self.audio2base64(audio)
|
||||
try:
|
||||
@ -174,9 +168,7 @@ class TencentCloudSeq2txt(Base):
|
||||
while retries < max_retries:
|
||||
resp = self.client.DescribeTaskStatus(req)
|
||||
if resp.Data.StatusStr == "success":
|
||||
text = re.sub(
|
||||
r"\[\d+:\d+\.\d+,\d+:\d+\.\d+\]\s*", "", resp.Data.Result
|
||||
).strip()
|
||||
text = re.sub(r"\[\d+:\d+\.\d+,\d+:\d+\.\d+\]\s*", "", resp.Data.Result).strip()
|
||||
return text, num_tokens_from_string(text)
|
||||
elif resp.Data.StatusStr == "failed":
|
||||
return (
|
||||
@ -195,6 +187,8 @@ class TencentCloudSeq2txt(Base):
|
||||
|
||||
|
||||
class GPUStackSeq2txt(Base):
|
||||
_FACTORY_NAME = "GPUStack"
|
||||
|
||||
def __init__(self, key, model_name, base_url):
|
||||
if not base_url:
|
||||
raise ValueError("url cannot be None")
|
||||
@ -206,8 +200,11 @@ class GPUStackSeq2txt(Base):
|
||||
|
||||
|
||||
class GiteeSeq2txt(Base):
|
||||
_FACTORY_NAME = "GiteeAI"
|
||||
|
||||
def __init__(self, key, model_name="whisper-1", base_url="https://ai.gitee.com/v1/"):
|
||||
if not base_url:
|
||||
base_url = "https://ai.gitee.com/v1/"
|
||||
self.client = OpenAI(api_key=key, base_url=base_url)
|
||||
self.model_name = model_name
|
||||
|
||||
|
||||
@ -70,10 +70,12 @@ class Base(ABC):
pass

def normalize_text(self, text):
return re.sub(r'(\*\*|##\d+\$\$|#)', '', text)
return re.sub(r"(\*\*|##\d+\$\$|#)", "", text)


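`normalize_text` strips markdown bold markers, `##N$$`-style chunk references, and stray `#` characters before synthesis, for example:

```python
import re

# Same pattern as Base.normalize_text above, applied to a sample string.
sample = "**Hello** world, see ##3$$ for details #"
print(re.sub(r"(\*\*|##\d+\$\$|#)", "", sample))  # -> "Hello world, see  for details "
```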
class FishAudioTTS(Base):
|
||||
_FACTORY_NAME = "Fish Audio"
|
||||
|
||||
def __init__(self, key, model_name, base_url="https://api.fish.audio/v1/tts"):
|
||||
if not base_url:
|
||||
base_url = "https://api.fish.audio/v1/tts"
|
||||
@ -96,9 +98,7 @@ class FishAudioTTS(Base):
|
||||
with client.stream(
|
||||
method="POST",
|
||||
url=self.base_url,
|
||||
content=ormsgpack.packb(
|
||||
request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC
|
||||
),
|
||||
content=ormsgpack.packb(request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC),
|
||||
headers=self.headers,
|
||||
timeout=None,
|
||||
) as response:
|
||||
@ -115,6 +115,8 @@ class FishAudioTTS(Base):
|
||||
|
||||
|
||||
class QwenTTS(Base):
|
||||
_FACTORY_NAME = "Tongyi-Qianwen"
|
||||
|
||||
def __init__(self, key, model_name, base_url=""):
|
||||
import dashscope
|
||||
|
||||
@ -122,10 +124,11 @@ class QwenTTS(Base):
|
||||
dashscope.api_key = key
|
||||
|
||||
def tts(self, text):
|
||||
from dashscope.api_entities.dashscope_response import SpeechSynthesisResponse
|
||||
from dashscope.audio.tts import ResultCallback, SpeechSynthesizer, SpeechSynthesisResult
|
||||
from collections import deque
|
||||
|
||||
from dashscope.api_entities.dashscope_response import SpeechSynthesisResponse
|
||||
from dashscope.audio.tts import ResultCallback, SpeechSynthesisResult, SpeechSynthesizer
|
||||
|
||||
class Callback(ResultCallback):
|
||||
def __init__(self) -> None:
|
||||
self.dque = deque()
|
||||
@ -159,10 +162,7 @@ class QwenTTS(Base):
|
||||
|
||||
text = self.normalize_text(text)
|
||||
callback = Callback()
|
||||
SpeechSynthesizer.call(model=self.model_name,
|
||||
text=text,
|
||||
callback=callback,
|
||||
format="mp3")
|
||||
SpeechSynthesizer.call(model=self.model_name, text=text, callback=callback, format="mp3")
|
||||
try:
|
||||
for data in callback._run():
|
||||
yield data
|
||||
@ -173,24 +173,19 @@ class QwenTTS(Base):
|
||||
|
||||
|
||||
class OpenAITTS(Base):
_FACTORY_NAME = "OpenAI"

def __init__(self, key, model_name="tts-1", base_url="https://api.openai.com/v1"):
if not base_url:
base_url = "https://api.openai.com/v1"
self.api_key = key
self.model_name = model_name
self.base_url = base_url
self.headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
self.headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}

def tts(self, text, voice="alloy"):
text = self.normalize_text(text)
payload = {
"model": self.model_name,
"voice": voice,
"input": text
}
payload = {"model": self.model_name, "voice": voice, "input": text}

response = requests.post(f"{self.base_url}/audio/speech", headers=self.headers, json=payload, stream=True)

@ -201,7 +196,8 @@ class OpenAITTS(Base):
yield chunk


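A hypothetical caller would consume the generator above and stream the returned audio bytes to disk (key and output path are placeholders):

```python
# Hypothetical usage of the TTS class defined above; writes streamed chunks to a file.
tts = OpenAITTS(key="sk-...", model_name="tts-1")
with open("speech.mp3", "wb") as f:
    for chunk in tts.tts("Hello from RAGFlow", voice="alloy"):
        f.write(chunk)
```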
class SparkTTS:
|
||||
class SparkTTS(Base):
|
||||
_FACTORY_NAME = "XunFei Spark"
|
||||
STATUS_FIRST_FRAME = 0
|
||||
STATUS_CONTINUE_FRAME = 1
|
||||
STATUS_LAST_FRAME = 2
|
||||
@ -219,29 +215,23 @@ class SparkTTS:
|
||||
|
||||
# Generate the WebSocket request URL
|
||||
def create_url(self):
url = 'wss://tts-api.xfyun.cn/v2/tts'
url = "wss://tts-api.xfyun.cn/v2/tts"
now = datetime.now()
date = format_date_time(mktime(now.timetuple()))
signature_origin = "host: " + "ws-api.xfyun.cn" + "\n"
signature_origin += "date: " + date + "\n"
signature_origin += "GET " + "/v2/tts " + "HTTP/1.1"
signature_sha = hmac.new(self.APISecret.encode('utf-8'), signature_origin.encode('utf-8'),
digestmod=hashlib.sha256).digest()
signature_sha = base64.b64encode(signature_sha).decode(encoding='utf-8')
authorization_origin = "api_key=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"" % (
self.APIKey, "hmac-sha256", "host date request-line", signature_sha)
authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
v = {
"authorization": authorization,
"date": date,
"host": "ws-api.xfyun.cn"
}
url = url + '?' + urlencode(v)
signature_sha = hmac.new(self.APISecret.encode("utf-8"), signature_origin.encode("utf-8"), digestmod=hashlib.sha256).digest()
signature_sha = base64.b64encode(signature_sha).decode(encoding="utf-8")
authorization_origin = 'api_key="%s", algorithm="%s", headers="%s", signature="%s"' % (self.APIKey, "hmac-sha256", "host date request-line", signature_sha)
authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode(encoding="utf-8")
v = {"authorization": authorization, "date": date, "host": "ws-api.xfyun.cn"}
url = url + "?" + urlencode(v)
return url

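The query string built by `create_url()` carries the signed credentials: `authorization` is the base64 of an `api_key="...", algorithm="hmac-sha256", headers="host date request-line", signature="..."` string, where the signature is itself a base64 HMAC-SHA256 over the host/date/request-line text. A small inspection helper (hypothetical, not part of this change) makes that encoding visible:

```python
# Hypothetical helper: decode the signed query string produced by create_url()
# to see the authorization fields the Spark TTS endpoint will receive.
import base64
from urllib.parse import parse_qs, urlparse

def inspect_signed_url(url: str) -> dict:
    params = {k: v[0] for k, v in parse_qs(urlparse(url).query).items()}
    params["authorization_decoded"] = base64.b64decode(params["authorization"]).decode("utf-8")
    return params
```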
def tts(self, text):
BusinessArgs = {"aue": "lame", "sfl": 1, "auf": "audio/L16;rate=16000", "vcn": self.model_name, "tte": "utf8"}
Data = {"status": 2, "text": base64.b64encode(text.encode('utf-8')).decode('utf-8')}
Data = {"status": 2, "text": base64.b64encode(text.encode("utf-8")).decode("utf-8")}
CommonArgs = {"app_id": self.APPID}
audio_queue = self.audio_queue
model_name = self.model_name
@@ -273,9 +263,7 @@ class SparkTTS:

def on_open(self, ws):
def run(*args):
d = {"common": CommonArgs,
"business": BusinessArgs,
"data": Data}
d = {"common": CommonArgs, "business": BusinessArgs, "data": Data}
ws.send(json.dumps(d))

thread.start_new_thread(run, ())
@@ -283,44 +271,32 @@ class SparkTTS:
wsUrl = self.create_url()
websocket.enableTrace(False)
a = Callback()
ws = websocket.WebSocketApp(wsUrl, on_open=a.on_open, on_error=a.on_error, on_close=a.on_close,
on_message=a.on_message)
ws = websocket.WebSocketApp(wsUrl, on_open=a.on_open, on_error=a.on_error, on_close=a.on_close, on_message=a.on_message)
status_code = 0
ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
while True:
audio_chunk = self.audio_queue.get()
if audio_chunk is None:
if status_code == 0:
raise Exception(
f"Fail to access model({model_name}) using the provided credentials. **ERROR**: Invalid APPID, API Secret, or API Key.")
raise Exception(f"Fail to access model({model_name}) using the provided credentials. **ERROR**: Invalid APPID, API Secret, or API Key.")
else:
break
status_code = 1
yield audio_chunk

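The loop above bridges the websocket callback thread and the generator through `self.audio_queue`, using `None` as an end-of-stream sentinel. A minimal sketch of that producer/consumer pattern (simplified names, an assumption beyond what the diff shows):

```python
# Sketch of the queue-based handoff: the websocket callback pushes audio
# chunks, and the generator drains them until a None sentinel appears.
import queue

audio_queue: queue.Queue = queue.Queue()

def on_audio_frame(chunk: bytes, is_last: bool) -> None:
    audio_queue.put(chunk)
    if is_last:
        audio_queue.put(None)  # signal end of stream

def stream_audio():
    while True:
        chunk = audio_queue.get()
        if chunk is None:
            break
        yield chunk
```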
class XinferenceTTS:
class XinferenceTTS(Base):
_FACTORY_NAME = "Xinference"

def __init__(self, key, model_name, **kwargs):
self.base_url = kwargs.get("base_url", None)
self.model_name = model_name
self.headers = {
"accept": "application/json",
"Content-Type": "application/json"
}
self.headers = {"accept": "application/json", "Content-Type": "application/json"}

def tts(self, text, voice="中文女", stream=True):
payload = {
"model": self.model_name,
"input": text,
"voice": voice
}
payload = {"model": self.model_name, "input": text, "voice": voice}

response = requests.post(
f"{self.base_url}/v1/audio/speech",
headers=self.headers,
json=payload,
stream=stream
)
response = requests.post(f"{self.base_url}/v1/audio/speech", headers=self.headers, json=payload, stream=stream)

if response.status_code != 200:
raise Exception(f"**Error**: {response.status_code}, {response.text}")
@@ -336,18 +312,12 @@ class OllamaTTS(Base):
base_url = "https://api.ollama.ai/v1"
self.model_name = model_name
self.base_url = base_url
self.headers = {
"Content-Type": "application/json"
}
self.headers = {"Content-Type": "application/json"}
if key and key != "x":
self.headers["Authorization"] = f"Bear {key}"
|
||||
|
||||
def tts(self, text, voice="standard-voice"):
payload = {
"model": self.model_name,
"voice": voice,
"input": text
}
payload = {"model": self.model_name, "voice": voice, "input": text}

response = requests.post(f"{self.base_url}/audio/tts", headers=self.headers, json=payload, stream=True)

@@ -359,30 +329,19 @@ class OllamaTTS(Base):
yield chunk

class GPUStackTTS:
class GPUStackTTS(Base):
_FACTORY_NAME = "GPUStack"

def __init__(self, key, model_name, **kwargs):
self.base_url = kwargs.get("base_url", None)
self.api_key = key
self.model_name = model_name
self.headers = {
"accept": "application/json",
"Content-Type": "application/json",
"Authorization": f"Bearer {self.api_key}"
}
self.headers = {"accept": "application/json", "Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}

def tts(self, text, voice="Chinese Female", stream=True):
payload = {
"model": self.model_name,
"input": text,
"voice": voice
}
payload = {"model": self.model_name, "input": text, "voice": voice}

response = requests.post(
f"{self.base_url}/v1/audio/speech",
headers=self.headers,
json=payload,
stream=stream
)
response = requests.post(f"{self.base_url}/v1/audio/speech", headers=self.headers, json=payload, stream=stream)

if response.status_code != 200:
raise Exception(f"**Error**: {response.status_code}, {response.text}")
@@ -393,16 +352,15 @@ class GPUStackTTS:


class SILICONFLOWTTS(Base):
_FACTORY_NAME = "SILICONFLOW"

def __init__(self, key, model_name="FunAudioLLM/CosyVoice2-0.5B", base_url="https://api.siliconflow.cn/v1"):
if not base_url:
base_url = "https://api.siliconflow.cn/v1"
self.api_key = key
self.model_name = model_name
self.base_url = base_url
self.headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
self.headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}

def tts(self, text, voice="anna"):
text = self.normalize_text(text)
@@ -414,7 +372,7 @@ class SILICONFLOWTTS(Base):
"sample_rate": 123,
"stream": True,
"speed": 1,
"gain": 0
"gain": 0,
}

response = requests.post(f"{self.base_url}/audio/speech", headers=self.headers, json=payload)
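The recurring change across these hunks is the new `_FACTORY_NAME` class attribute ("OpenAI", "XunFei Spark", "Xinference", "GPUStack", "SILICONFLOW"), which gives each TTS provider a lookup name on the class itself. One way a registry could be built from such attributes — a sketch under assumptions; the module path and helper name are hypothetical, not necessarily this PR's exact mechanism — is:

```python
# Hypothetical sketch: scan a module for classes that declare _FACTORY_NAME and
# build a registry mapping factory name -> class.
import inspect

from rag.llm import tts_model  # assumed module path


def collect_factories(module) -> dict:
    registry = {}
    for _, cls in inspect.getmembers(module, inspect.isclass):
        names = getattr(cls, "_FACTORY_NAME", None)
        if not names:
            continue
        # _FACTORY_NAME is a single string in this file; tolerate lists of aliases too.
        for name in ([names] if isinstance(names, str) else names):
            registry[name] = cls
    return registry


# TTSModel = collect_factories(tts_model)
# TTSModel["OpenAI"] is OpenAITTS; TTSModel["GPUStack"] is GPUStackTTS
```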