Mirror of https://github.com/infiniflow/ragflow.git (synced 2025-12-08 20:42:30 +08:00)
Refa: automatic LLMs registration (#8651)
### What problem does this PR solve?

Support automatic LLM registration: instead of hand-listing every model class in the package `__init__`, each provider class declares a `_FACTORY_NAME` attribute and is discovered and registered at import time.

### Type of change

- [x] Refactoring
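The mechanism, in brief: each concrete model class now carries a `_FACTORY_NAME` class attribute, and the package `__init__` scans its submodules at import time, registering every subclass of that module's `Base` that declares one. A minimal, self-contained sketch of the pattern (`_FACTORY_NAME` is real; the provider classes and the simplified registry loop here are illustrative stand-ins, not code from this PR — the actual loop, shown in the first hunk below, walks modules with `importlib`/`inspect`):

```python
# Minimal sketch of the registration pattern this PR introduces.
# `_FACTORY_NAME` comes from the diff; everything else is illustrative.

class Base:
    """Stand-in for the shared base class each model family defines."""


class ExampleChat(Base):
    _FACTORY_NAME = "Example"  # the factory name users select


class ExampleCompatChat(Base):
    # A list registers one class under several factory names.
    _FACTORY_NAME = ["Example-A", "Example-B"]


registry = {}
for cls in Base.__subclasses__():
    names = getattr(cls, "_FACTORY_NAME", None)
    if names is None:
        continue  # subclasses without a factory name are skipped
    for name in (names if isinstance(names, list) else [names]):
        registry[name] = cls

assert registry["Example"] is ExampleChat
assert registry["Example-A"] is registry["Example-B"] is ExampleCompatChat
```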
```diff
@@ -15,289 +15,53 @@
 #
 # AFTER UPDATING THIS FILE, PLEASE ENSURE THAT docs/references/supported_models.mdx IS ALSO UPDATED for consistency!
 #
-from .embedding_model import (
-    OllamaEmbed,
-    LocalAIEmbed,
-    OpenAIEmbed,
-    AzureEmbed,
-    XinferenceEmbed,
-    QWenEmbed,
-    ZhipuEmbed,
-    FastEmbed,
-    YoudaoEmbed,
-    BaiChuanEmbed,
-    JinaEmbed,
-    DefaultEmbedding,
-    MistralEmbed,
-    BedrockEmbed,
-    GeminiEmbed,
-    NvidiaEmbed,
-    LmStudioEmbed,
-    OpenAI_APIEmbed,
-    CoHereEmbed,
-    TogetherAIEmbed,
-    PerfXCloudEmbed,
-    UpstageEmbed,
-    SILICONFLOWEmbed,
-    ReplicateEmbed,
-    BaiduYiyanEmbed,
-    VoyageEmbed,
-    HuggingFaceEmbed,
-    VolcEngineEmbed,
-    GPUStackEmbed,
-    NovitaEmbed,
-    GiteeEmbed
-)
-from .chat_model import (
-    GptTurbo,
-    AzureChat,
-    ZhipuChat,
-    QWenChat,
-    OllamaChat,
-    LocalAIChat,
-    XinferenceChat,
-    MoonshotChat,
-    DeepSeekChat,
-    VolcEngineChat,
-    BaiChuanChat,
-    MiniMaxChat,
-    MistralChat,
-    GeminiChat,
-    BedrockChat,
-    GroqChat,
-    OpenRouterChat,
-    StepFunChat,
-    NvidiaChat,
-    LmStudioChat,
-    OpenAI_APIChat,
-    CoHereChat,
-    LeptonAIChat,
-    TogetherAIChat,
-    PerfXCloudChat,
-    UpstageChat,
-    NovitaAIChat,
-    SILICONFLOWChat,
-    PPIOChat,
-    YiChat,
-    ReplicateChat,
-    HunyuanChat,
-    SparkChat,
-    BaiduYiyanChat,
-    AnthropicChat,
-    GoogleChat,
-    HuggingFaceChat,
-    GPUStackChat,
-    ModelScopeChat,
-    GiteeChat
-)
-
-from .cv_model import (
-    GptV4,
-    AzureGptV4,
-    OllamaCV,
-    XinferenceCV,
-    QWenCV,
-    Zhipu4V,
-    LocalCV,
-    GeminiCV,
-    OpenRouterCV,
-    LocalAICV,
-    NvidiaCV,
-    LmStudioCV,
-    StepFunCV,
-    OpenAI_APICV,
-    TogetherAICV,
-    YiCV,
-    HunyuanCV,
-    AnthropicCV,
-    SILICONFLOWCV,
-    GPUStackCV,
-    GoogleCV,
-)
-
-from .rerank_model import (
-    LocalAIRerank,
-    DefaultRerank,
-    JinaRerank,
-    YoudaoRerank,
-    XInferenceRerank,
-    NvidiaRerank,
-    LmStudioRerank,
-    OpenAI_APIRerank,
-    CoHereRerank,
-    TogetherAIRerank,
-    SILICONFLOWRerank,
-    BaiduYiyanRerank,
-    VoyageRerank,
-    QWenRerank,
-    GPUStackRerank,
-    HuggingfaceRerank,
-    NovitaRerank,
-    GiteeRerank
-)
-
-from .sequence2txt_model import (
-    GPTSeq2txt,
-    QWenSeq2txt,
-    AzureSeq2txt,
-    XinferenceSeq2txt,
-    TencentCloudSeq2txt,
-    GPUStackSeq2txt,
-    GiteeSeq2txt
-)
-
-from .tts_model import (
-    FishAudioTTS,
-    QwenTTS,
-    OpenAITTS,
-    SparkTTS,
-    XinferenceTTS,
-    GPUStackTTS,
-    SILICONFLOWTTS,
-)
-
-EmbeddingModel = {
-    "Ollama": OllamaEmbed,
-    "LocalAI": LocalAIEmbed,
-    "OpenAI": OpenAIEmbed,
-    "Azure-OpenAI": AzureEmbed,
-    "Xinference": XinferenceEmbed,
-    "Tongyi-Qianwen": QWenEmbed,
-    "ZHIPU-AI": ZhipuEmbed,
-    "FastEmbed": FastEmbed,
-    "Youdao": YoudaoEmbed,
-    "BaiChuan": BaiChuanEmbed,
-    "Jina": JinaEmbed,
-    "BAAI": DefaultEmbedding,
-    "Mistral": MistralEmbed,
-    "Bedrock": BedrockEmbed,
-    "Gemini": GeminiEmbed,
-    "NVIDIA": NvidiaEmbed,
-    "LM-Studio": LmStudioEmbed,
-    "OpenAI-API-Compatible": OpenAI_APIEmbed,
-    "VLLM": OpenAI_APIEmbed,
-    "Cohere": CoHereEmbed,
-    "TogetherAI": TogetherAIEmbed,
-    "PerfXCloud": PerfXCloudEmbed,
-    "Upstage": UpstageEmbed,
-    "SILICONFLOW": SILICONFLOWEmbed,
-    "Replicate": ReplicateEmbed,
-    "BaiduYiyan": BaiduYiyanEmbed,
-    "Voyage AI": VoyageEmbed,
-    "HuggingFace": HuggingFaceEmbed,
-    "VolcEngine": VolcEngineEmbed,
-    "GPUStack": GPUStackEmbed,
-    "NovitaAI": NovitaEmbed,
-    "GiteeAI": GiteeEmbed
-}
-
-CvModel = {
-    "OpenAI": GptV4,
-    "Azure-OpenAI": AzureGptV4,
-    "Ollama": OllamaCV,
-    "Xinference": XinferenceCV,
-    "Tongyi-Qianwen": QWenCV,
-    "ZHIPU-AI": Zhipu4V,
-    "Moonshot": LocalCV,
-    "Gemini": GeminiCV,
-    "OpenRouter": OpenRouterCV,
-    "LocalAI": LocalAICV,
-    "NVIDIA": NvidiaCV,
-    "LM-Studio": LmStudioCV,
-    "StepFun": StepFunCV,
-    "OpenAI-API-Compatible": OpenAI_APICV,
-    "VLLM": OpenAI_APICV,
-    "TogetherAI": TogetherAICV,
-    "01.AI": YiCV,
-    "Tencent Hunyuan": HunyuanCV,
-    "Anthropic": AnthropicCV,
-    "SILICONFLOW": SILICONFLOWCV,
-    "GPUStack": GPUStackCV,
-    "Google Cloud": GoogleCV
-}
-
-ChatModel = {
-    "OpenAI": GptTurbo,
-    "Azure-OpenAI": AzureChat,
-    "ZHIPU-AI": ZhipuChat,
-    "Tongyi-Qianwen": QWenChat,
-    "Ollama": OllamaChat,
-    "LocalAI": LocalAIChat,
-    "Xinference": XinferenceChat,
-    "Moonshot": MoonshotChat,
-    "DeepSeek": DeepSeekChat,
-    "VolcEngine": VolcEngineChat,
-    "BaiChuan": BaiChuanChat,
-    "MiniMax": MiniMaxChat,
-    "Mistral": MistralChat,
-    "Gemini": GeminiChat,
-    "Bedrock": BedrockChat,
-    "Groq": GroqChat,
-    "OpenRouter": OpenRouterChat,
-    "StepFun": StepFunChat,
-    "NVIDIA": NvidiaChat,
-    "LM-Studio": LmStudioChat,
-    "OpenAI-API-Compatible": OpenAI_APIChat,
-    "VLLM": OpenAI_APIChat,
-    "Cohere": CoHereChat,
-    "LeptonAI": LeptonAIChat,
-    "TogetherAI": TogetherAIChat,
-    "PerfXCloud": PerfXCloudChat,
-    "Upstage": UpstageChat,
-    "NovitaAI": NovitaAIChat,
-    "SILICONFLOW": SILICONFLOWChat,
-    "PPIO": PPIOChat,
-    "01.AI": YiChat,
-    "Replicate": ReplicateChat,
-    "Tencent Hunyuan": HunyuanChat,
-    "XunFei Spark": SparkChat,
-    "BaiduYiyan": BaiduYiyanChat,
-    "Anthropic": AnthropicChat,
-    "Google Cloud": GoogleChat,
-    "HuggingFace": HuggingFaceChat,
-    "GPUStack": GPUStackChat,
-    "ModelScope": ModelScopeChat,
-    "GiteeAI": GiteeChat
-}
-
-RerankModel = {
-    "LocalAI": LocalAIRerank,
-    "BAAI": DefaultRerank,
-    "Jina": JinaRerank,
-    "Youdao": YoudaoRerank,
-    "Xinference": XInferenceRerank,
-    "NVIDIA": NvidiaRerank,
-    "LM-Studio": LmStudioRerank,
-    "OpenAI-API-Compatible": OpenAI_APIRerank,
-    "VLLM": CoHereRerank,
-    "Cohere": CoHereRerank,
-    "TogetherAI": TogetherAIRerank,
-    "SILICONFLOW": SILICONFLOWRerank,
-    "BaiduYiyan": BaiduYiyanRerank,
-    "Voyage AI": VoyageRerank,
-    "Tongyi-Qianwen": QWenRerank,
-    "GPUStack": GPUStackRerank,
-    "HuggingFace": HuggingfaceRerank,
-    "NovitaAI": NovitaRerank,
-    "GiteeAI": GiteeRerank
-}
-
-Seq2txtModel = {
-    "OpenAI": GPTSeq2txt,
-    "Tongyi-Qianwen": QWenSeq2txt,
-    "Azure-OpenAI": AzureSeq2txt,
-    "Xinference": XinferenceSeq2txt,
-    "Tencent Cloud": TencentCloudSeq2txt,
-    "GPUStack": GPUStackSeq2txt,
-    "GiteeAI": GiteeSeq2txt
-}
-
-TTSModel = {
-    "Fish Audio": FishAudioTTS,
-    "Tongyi-Qianwen": QwenTTS,
-    "OpenAI": OpenAITTS,
-    "XunFei Spark": SparkTTS,
-    "Xinference": XinferenceTTS,
-    "GPUStack": GPUStackTTS,
-    "SILICONFLOW": SILICONFLOWTTS,
-}
+import importlib
+import inspect
+
+ChatModel = globals().get("ChatModel", {})
+CvModel = globals().get("CvModel", {})
+EmbeddingModel = globals().get("EmbeddingModel", {})
+RerankModel = globals().get("RerankModel", {})
+Seq2txtModel = globals().get("Seq2txtModel", {})
+TTSModel = globals().get("TTSModel", {})
+
+MODULE_MAPPING = {
+    "chat_model": ChatModel,
+    "cv_model": CvModel,
+    "embedding_model": EmbeddingModel,
+    "rerank_model": RerankModel,
+    "sequence2txt_model": Seq2txtModel,
+    "tts_model": TTSModel,
+}
+
+package_name = __name__
+
+for module_name, mapping_dict in MODULE_MAPPING.items():
+    full_module_name = f"{package_name}.{module_name}"
+    module = importlib.import_module(full_module_name)
+
+    base_class = None
+    for name, obj in inspect.getmembers(module):
+        if inspect.isclass(obj) and name == "Base":
+            base_class = obj
+            break
+    if base_class is None:
+        continue
+
+    for _, obj in inspect.getmembers(module):
+        if inspect.isclass(obj) and issubclass(obj, base_class) and obj is not base_class and hasattr(obj, "_FACTORY_NAME"):
+            if isinstance(obj._FACTORY_NAME, list):
+                for factory_name in obj._FACTORY_NAME:
+                    mapping_dict[factory_name] = obj
+            else:
+                mapping_dict[obj._FACTORY_NAME] = obj
+
+
+__all__ = [
+    "ChatModel",
+    "CvModel",
+    "EmbeddingModel",
+    "RerankModel",
+    "Seq2txtModel",
+    "TTSModel",
+]
```
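The public interface is unchanged: callers still look up a provider class by its factory name in the module-level mappings, which are now populated by the loop above rather than written out by hand. A hypothetical caller-side example (the key and model name are placeholders; `rag.llm` is assumed to be the package these mappings live in within the ragflow repo):

```python
# Hypothetical usage; the lookup style is the same before and after
# this refactor. "DeepSeek" is registered via DeepSeekChat's
# _FACTORY_NAME below; the API key is a placeholder.
from rag.llm import ChatModel

chat_cls = ChatModel["DeepSeek"]  # -> DeepSeekChat
chat = chat_cls(key="YOUR_API_KEY", model_name="deepseek-chat")
```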
```diff
@@ -142,11 +142,7 @@ class Base(ABC):
         return f"{ERROR_PREFIX}: {error_code} - {str(e)}"

     def _verbose_tool_use(self, name, args, res):
-        return "<tool_call>" + json.dumps({
-            "name": name,
-            "args": args,
-            "result": res
-        }, ensure_ascii=False, indent=2) + "</tool_call>"
+        return "<tool_call>" + json.dumps({"name": name, "args": args, "result": res}, ensure_ascii=False, indent=2) + "</tool_call>"

     def _append_history(self, hist, tool_call, tool_res):
         hist.append(
@@ -191,10 +187,10 @@ class Base(ABC):
         tk_count = 0
         hist = deepcopy(history)
         # Implement exponential backoff retry strategy
-        for attempt in range(self.max_retries+1):
+        for attempt in range(self.max_retries + 1):
             history = hist
             try:
-                for _ in range(self.max_rounds*2):
+                for _ in range(self.max_rounds * 2):
                     response = self.client.chat.completions.create(model=self.model_name, messages=history, tools=self.tools, **gen_conf)
                     tk_count += self.total_token_count(response)
                     if any([not response.choices, not response.choices[0].message]):
@@ -269,7 +265,7 @@ class Base(ABC):
         for attempt in range(self.max_retries + 1):
             history = hist
             try:
-                for _ in range(self.max_rounds*2):
+                for _ in range(self.max_rounds * 2):
                     reasoning_start = False
                     response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, **gen_conf)
                     final_tool_calls = {}
@@ -430,6 +426,8 @@ class Base(ABC):


 class GptTurbo(Base):
+    _FACTORY_NAME = "OpenAI"
+
     def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1", **kwargs):
         if not base_url:
             base_url = "https://api.openai.com/v1"
@@ -437,6 +435,8 @@ class GptTurbo(Base):


 class MoonshotChat(Base):
+    _FACTORY_NAME = "Moonshot"
+
     def __init__(self, key, model_name="moonshot-v1-8k", base_url="https://api.moonshot.cn/v1", **kwargs):
         if not base_url:
             base_url = "https://api.moonshot.cn/v1"
@@ -444,6 +444,8 @@ class MoonshotChat(Base):


 class XinferenceChat(Base):
+    _FACTORY_NAME = "Xinference"
+
     def __init__(self, key=None, model_name="", base_url="", **kwargs):
         if not base_url:
             raise ValueError("Local llm url cannot be None")
@@ -452,6 +454,8 @@ class XinferenceChat(Base):


 class HuggingFaceChat(Base):
+    _FACTORY_NAME = "HuggingFace"
+
     def __init__(self, key=None, model_name="", base_url="", **kwargs):
         if not base_url:
             raise ValueError("Local llm url cannot be None")
@@ -460,6 +464,8 @@ class HuggingFaceChat(Base):


 class ModelScopeChat(Base):
+    _FACTORY_NAME = "ModelScope"
+
     def __init__(self, key=None, model_name="", base_url="", **kwargs):
         if not base_url:
             raise ValueError("Local llm url cannot be None")
@@ -468,6 +474,8 @@ class ModelScopeChat(Base):


 class DeepSeekChat(Base):
+    _FACTORY_NAME = "DeepSeek"
+
     def __init__(self, key, model_name="deepseek-chat", base_url="https://api.deepseek.com/v1", **kwargs):
         if not base_url:
             base_url = "https://api.deepseek.com/v1"
@@ -475,6 +483,8 @@ class DeepSeekChat(Base):


 class AzureChat(Base):
+    _FACTORY_NAME = "Azure-OpenAI"
+
     def __init__(self, key, model_name, base_url, **kwargs):
         api_key = json.loads(key).get("api_key", "")
         api_version = json.loads(key).get("api_version", "2024-02-01")
@@ -484,6 +494,8 @@ class AzureChat(Base):


 class BaiChuanChat(Base):
+    _FACTORY_NAME = "BaiChuan"
+
     def __init__(self, key, model_name="Baichuan3-Turbo", base_url="https://api.baichuan-ai.com/v1", **kwargs):
         if not base_url:
             base_url = "https://api.baichuan-ai.com/v1"
@@ -557,6 +569,8 @@ class BaiChuanChat(Base):


 class QWenChat(Base):
+    _FACTORY_NAME = "Tongyi-Qianwen"
+
     def __init__(self, key, model_name=Generation.Models.qwen_turbo, base_url=None, **kwargs):
         if not base_url:
             base_url = "https://dashscope.aliyuncs.com/compatible-mode/v1"
@@ -565,6 +579,8 @@ class QWenChat(Base):


 class ZhipuChat(Base):
+    _FACTORY_NAME = "ZHIPU-AI"
+
     def __init__(self, key, model_name="glm-3-turbo", base_url=None, **kwargs):
         super().__init__(key, model_name, base_url=base_url, **kwargs)

@@ -630,6 +646,8 @@ class ZhipuChat(Base):


 class OllamaChat(Base):
+    _FACTORY_NAME = "Ollama"
+
     def __init__(self, key, model_name, base_url=None, **kwargs):
         super().__init__(key, model_name, base_url=base_url, **kwargs)

@@ -694,6 +712,8 @@ class OllamaChat(Base):


 class LocalAIChat(Base):
+    _FACTORY_NAME = "LocalAI"
+
     def __init__(self, key, model_name, base_url=None, **kwargs):
         super().__init__(key, model_name, base_url=base_url, **kwargs)

@@ -752,6 +772,8 @@ class LocalLLM(Base):


 class VolcEngineChat(Base):
+    _FACTORY_NAME = "VolcEngine"
+
     def __init__(self, key, model_name, base_url="https://ark.cn-beijing.volces.com/api/v3", **kwargs):
         """
         Since do not want to modify the original database fields, and the VolcEngine authentication method is quite special,
@@ -765,6 +787,8 @@ class VolcEngineChat(Base):


 class MiniMaxChat(Base):
+    _FACTORY_NAME = "MiniMax"
+
     def __init__(self, key, model_name, base_url="https://api.minimax.chat/v1/text/chatcompletion_v2", **kwargs):
         super().__init__(key, model_name, base_url=base_url, **kwargs)

@@ -843,6 +867,8 @@ class MiniMaxChat(Base):


 class MistralChat(Base):
+    _FACTORY_NAME = "Mistral"
+
     def __init__(self, key, model_name, base_url=None, **kwargs):
         super().__init__(key, model_name, base_url=base_url, **kwargs)

@@ -896,6 +922,8 @@ class MistralChat(Base):


 class BedrockChat(Base):
+    _FACTORY_NAME = "Bedrock"
+
     def __init__(self, key, model_name, base_url=None, **kwargs):
         super().__init__(key, model_name, base_url=base_url, **kwargs)

@@ -978,6 +1006,8 @@ class BedrockChat(Base):


 class GeminiChat(Base):
+    _FACTORY_NAME = "Gemini"
+
     def __init__(self, key, model_name, base_url=None, **kwargs):
         super().__init__(key, model_name, base_url=base_url, **kwargs)

@@ -997,6 +1027,7 @@ class GeminiChat(Base):

     def _chat(self, history, gen_conf):
         from google.generativeai.types import content_types
+
         system = history[0]["content"] if history and history[0]["role"] == "system" else ""
         hist = []
         for item in history:
@@ -1019,6 +1050,7 @@ class GeminiChat(Base):

     def chat_streamly(self, system, history, gen_conf):
         from google.generativeai.types import content_types
+
         gen_conf = self._clean_conf(gen_conf)
         if system:
             self.model._system_instruction = content_types.to_content(system)
@@ -1042,6 +1074,8 @@ class GeminiChat(Base):


 class GroqChat(Base):
+    _FACTORY_NAME = "Groq"
+
     def __init__(self, key, model_name, base_url=None, **kwargs):
         super().__init__(key, model_name, base_url=base_url, **kwargs)

@@ -1086,6 +1120,8 @@ class GroqChat(Base):

 ## openrouter
 class OpenRouterChat(Base):
+    _FACTORY_NAME = "OpenRouter"
+
     def __init__(self, key, model_name, base_url="https://openrouter.ai/api/v1", **kwargs):
         if not base_url:
             base_url = "https://openrouter.ai/api/v1"
@@ -1093,6 +1129,8 @@ class OpenRouterChat(Base):


 class StepFunChat(Base):
+    _FACTORY_NAME = "StepFun"
+
     def __init__(self, key, model_name, base_url="https://api.stepfun.com/v1", **kwargs):
         if not base_url:
             base_url = "https://api.stepfun.com/v1"
@@ -1100,6 +1138,8 @@ class StepFunChat(Base):


 class NvidiaChat(Base):
+    _FACTORY_NAME = "NVIDIA"
+
     def __init__(self, key, model_name, base_url="https://integrate.api.nvidia.com/v1", **kwargs):
         if not base_url:
             base_url = "https://integrate.api.nvidia.com/v1"
@@ -1107,6 +1147,8 @@ class NvidiaChat(Base):


 class LmStudioChat(Base):
+    _FACTORY_NAME = "LM-Studio"
+
     def __init__(self, key, model_name, base_url, **kwargs):
         if not base_url:
             raise ValueError("Local llm url cannot be None")
@@ -1117,6 +1159,8 @@ class LmStudioChat(Base):


 class OpenAI_APIChat(Base):
+    _FACTORY_NAME = ["VLLM", "OpenAI-API-Compatible"]
+
     def __init__(self, key, model_name, base_url):
         if not base_url:
             raise ValueError("url cannot be None")
@@ -1125,6 +1169,8 @@ class OpenAI_APIChat(Base):


 class PPIOChat(Base):
+    _FACTORY_NAME = "PPIO"
+
     def __init__(self, key, model_name, base_url="https://api.ppinfra.com/v3/openai", **kwargs):
         if not base_url:
             base_url = "https://api.ppinfra.com/v3/openai"
@@ -1132,6 +1178,8 @@ class PPIOChat(Base):


 class CoHereChat(Base):
+    _FACTORY_NAME = "Cohere"
+
     def __init__(self, key, model_name, base_url=None, **kwargs):
         super().__init__(key, model_name, base_url=base_url, **kwargs)

@@ -1207,6 +1255,8 @@ class CoHereChat(Base):


 class LeptonAIChat(Base):
+    _FACTORY_NAME = "LeptonAI"
+
     def __init__(self, key, model_name, base_url=None, **kwargs):
         if not base_url:
             base_url = urljoin("https://" + model_name + ".lepton.run", "api/v1")
@@ -1214,6 +1264,8 @@ class LeptonAIChat(Base):


 class TogetherAIChat(Base):
+    _FACTORY_NAME = "TogetherAI"
+
     def __init__(self, key, model_name, base_url="https://api.together.xyz/v1", **kwargs):
         if not base_url:
             base_url = "https://api.together.xyz/v1"
@@ -1221,6 +1273,8 @@ class TogetherAIChat(Base):


 class PerfXCloudChat(Base):
+    _FACTORY_NAME = "PerfXCloud"
+
     def __init__(self, key, model_name, base_url="https://cloud.perfxlab.cn/v1", **kwargs):
         if not base_url:
             base_url = "https://cloud.perfxlab.cn/v1"
@@ -1228,6 +1282,8 @@ class PerfXCloudChat(Base):


 class UpstageChat(Base):
+    _FACTORY_NAME = "Upstage"
+
     def __init__(self, key, model_name, base_url="https://api.upstage.ai/v1/solar", **kwargs):
         if not base_url:
             base_url = "https://api.upstage.ai/v1/solar"
@@ -1235,6 +1291,8 @@ class UpstageChat(Base):


 class NovitaAIChat(Base):
+    _FACTORY_NAME = "NovitaAI"
+
     def __init__(self, key, model_name, base_url="https://api.novita.ai/v3/openai", **kwargs):
         if not base_url:
             base_url = "https://api.novita.ai/v3/openai"
@@ -1242,6 +1300,8 @@ class NovitaAIChat(Base):


 class SILICONFLOWChat(Base):
+    _FACTORY_NAME = "SILICONFLOW"
+
     def __init__(self, key, model_name, base_url="https://api.siliconflow.cn/v1", **kwargs):
         if not base_url:
             base_url = "https://api.siliconflow.cn/v1"
@@ -1249,6 +1309,8 @@ class SILICONFLOWChat(Base):


 class YiChat(Base):
+    _FACTORY_NAME = "01.AI"
+
     def __init__(self, key, model_name, base_url="https://api.lingyiwanwu.com/v1", **kwargs):
         if not base_url:
             base_url = "https://api.lingyiwanwu.com/v1"
@@ -1256,6 +1318,8 @@ class YiChat(Base):


 class GiteeChat(Base):
+    _FACTORY_NAME = "GiteeAI"
+
     def __init__(self, key, model_name, base_url="https://ai.gitee.com/v1/", **kwargs):
         if not base_url:
             base_url = "https://ai.gitee.com/v1/"
@@ -1263,6 +1327,8 @@ class GiteeChat(Base):


 class ReplicateChat(Base):
+    _FACTORY_NAME = "Replicate"
+
     def __init__(self, key, model_name, base_url=None, **kwargs):
         super().__init__(key, model_name, base_url=base_url, **kwargs)

@@ -1302,6 +1368,8 @@ class ReplicateChat(Base):


 class HunyuanChat(Base):
+    _FACTORY_NAME = "Tencent Hunyuan"
+
     def __init__(self, key, model_name, base_url=None, **kwargs):
         super().__init__(key, model_name, base_url=base_url, **kwargs)

@@ -1378,6 +1446,8 @@ class HunyuanChat(Base):


 class SparkChat(Base):
+    _FACTORY_NAME = "XunFei Spark"
+
     def __init__(self, key, model_name, base_url="https://spark-api-open.xf-yun.com/v1", **kwargs):
         if not base_url:
             base_url = "https://spark-api-open.xf-yun.com/v1"
@@ -1398,6 +1468,8 @@ class SparkChat(Base):


 class BaiduYiyanChat(Base):
+    _FACTORY_NAME = "BaiduYiyan"
+
     def __init__(self, key, model_name, base_url=None, **kwargs):
         super().__init__(key, model_name, base_url=base_url, **kwargs)

@@ -1444,6 +1516,8 @@ class BaiduYiyanChat(Base):


 class AnthropicChat(Base):
+    _FACTORY_NAME = "Anthropic"
+
     def __init__(self, key, model_name, base_url="https://api.anthropic.com/v1/", **kwargs):
         if not base_url:
             base_url = "https://api.anthropic.com/v1/"
@@ -1451,6 +1525,8 @@ class AnthropicChat(Base):


 class GoogleChat(Base):
+    _FACTORY_NAME = "Google Cloud"
+
     def __init__(self, key, model_name, base_url=None, **kwargs):
         super().__init__(key, model_name, base_url=base_url, **kwargs)

@@ -1529,9 +1605,11 @@ class GoogleChat(Base):
             if "role" in item and item["role"] == "assistant":
                 item["role"] = "model"
             if "content" in item:
-                item["parts"] = [{
-                    "text": item.pop("content"),
-                }]
+                item["parts"] = [
+                    {
+                        "text": item.pop("content"),
+                    }
+                ]

         response = self.client.generate_content(hist, generation_config=gen_conf)
         ans = response.text
@@ -1587,6 +1665,8 @@ class GoogleChat(Base):


 class GPUStackChat(Base):
+    _FACTORY_NAME = "GPUStack"
+
     def __init__(self, key=None, model_name="", base_url="", **kwargs):
         if not base_url:
             raise ValueError("Local llm url cannot be None")
```
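Two details worth noting in the hunks above. First, `OpenAI_APIChat` declares `_FACTORY_NAME = ["VLLM", "OpenAI-API-Compatible"]`, so the list branch of the registration loop maps both names to the same class, replacing the duplicated dictionary entries the old `__init__` carried. Second, the removed hand-written `RerankModel` literal mapped "VLLM" to `CoHereRerank`, which looks like a copy-paste slip; deriving the mappings from class attributes removes that class of error. A quick check of the aliasing (hypothetical snippet, assuming the `rag.llm` package path):

```python
# Hypothetical check: both factory names resolve to the same class,
# mirroring the old duplicate entries for OpenAI-compatible backends.
from rag.llm import ChatModel

assert ChatModel["VLLM"] is ChatModel["OpenAI-API-Compatible"]
```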
@ -57,7 +57,7 @@ class Base(ABC):
|
|||||||
model=self.model_name,
|
model=self.model_name,
|
||||||
messages=history,
|
messages=history,
|
||||||
temperature=gen_conf.get("temperature", 0.3),
|
temperature=gen_conf.get("temperature", 0.3),
|
||||||
top_p=gen_conf.get("top_p", 0.7)
|
top_p=gen_conf.get("top_p", 0.7),
|
||||||
)
|
)
|
||||||
return response.choices[0].message.content.strip(), response.usage.total_tokens
|
return response.choices[0].message.content.strip(), response.usage.total_tokens
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -79,7 +79,7 @@ class Base(ABC):
|
|||||||
messages=history,
|
messages=history,
|
||||||
temperature=gen_conf.get("temperature", 0.3),
|
temperature=gen_conf.get("temperature", 0.3),
|
||||||
top_p=gen_conf.get("top_p", 0.7),
|
top_p=gen_conf.get("top_p", 0.7),
|
||||||
stream=True
|
stream=True,
|
||||||
)
|
)
|
||||||
for resp in response:
|
for resp in response:
|
||||||
if not resp.choices[0].delta.content:
|
if not resp.choices[0].delta.content:
|
||||||
@ -87,8 +87,7 @@ class Base(ABC):
|
|||||||
delta = resp.choices[0].delta.content
|
delta = resp.choices[0].delta.content
|
||||||
ans += delta
|
ans += delta
|
||||||
if resp.choices[0].finish_reason == "length":
|
if resp.choices[0].finish_reason == "length":
|
||||||
ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
|
ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
|
||||||
[ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
|
|
||||||
tk_count = resp.usage.total_tokens
|
tk_count = resp.usage.total_tokens
|
||||||
if resp.choices[0].finish_reason == "stop":
|
if resp.choices[0].finish_reason == "stop":
|
||||||
tk_count = resp.usage.total_tokens
|
tk_count = resp.usage.total_tokens
|
||||||
@ -117,13 +116,12 @@ class Base(ABC):
|
|||||||
"content": [
|
"content": [
|
||||||
{
|
{
|
||||||
"type": "image_url",
|
"type": "image_url",
|
||||||
"image_url": {
|
"image_url": {"url": f"data:image/jpeg;base64,{b64}"},
|
||||||
"url": f"data:image/jpeg;base64,{b64}"
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else
|
"text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
|
||||||
"Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
|
if self.lang.lower() == "chinese"
|
||||||
|
else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
@ -136,9 +134,7 @@ class Base(ABC):
|
|||||||
"content": [
|
"content": [
|
||||||
{
|
{
|
||||||
"type": "image_url",
|
"type": "image_url",
|
||||||
"image_url": {
|
"image_url": {"url": f"data:image/jpeg;base64,{b64}"},
|
||||||
"url": f"data:image/jpeg;base64,{b64}"
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"type": "text",
|
"type": "text",
|
||||||
@ -156,14 +152,13 @@ class Base(ABC):
|
|||||||
"url": f"data:image/jpeg;base64,{b64}",
|
"url": f"data:image/jpeg;base64,{b64}",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{"type": "text", "text": text},
|
||||||
"type": "text",
|
|
||||||
"text": text
|
|
||||||
},
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class GptV4(Base):
|
class GptV4(Base):
|
||||||
|
_FACTORY_NAME = "OpenAI"
|
||||||
|
|
||||||
def __init__(self, key, model_name="gpt-4-vision-preview", lang="Chinese", base_url="https://api.openai.com/v1"):
|
def __init__(self, key, model_name="gpt-4-vision-preview", lang="Chinese", base_url="https://api.openai.com/v1"):
|
||||||
if not base_url:
|
if not base_url:
|
||||||
base_url = "https://api.openai.com/v1"
|
base_url = "https://api.openai.com/v1"
|
||||||
@ -181,7 +176,7 @@ class GptV4(Base):
|
|||||||
|
|
||||||
res = self.client.chat.completions.create(
|
res = self.client.chat.completions.create(
|
||||||
model=self.model_name,
|
model=self.model_name,
|
||||||
messages=prompt
|
messages=prompt,
|
||||||
)
|
)
|
||||||
return res.choices[0].message.content.strip(), res.usage.total_tokens
|
return res.choices[0].message.content.strip(), res.usage.total_tokens
|
||||||
|
|
||||||
@ -197,9 +192,11 @@ class GptV4(Base):
|
|||||||
|
|
||||||
|
|
||||||
class AzureGptV4(Base):
|
class AzureGptV4(Base):
|
||||||
|
_FACTORY_NAME = "Azure-OpenAI"
|
||||||
|
|
||||||
def __init__(self, key, model_name, lang="Chinese", **kwargs):
|
def __init__(self, key, model_name, lang="Chinese", **kwargs):
|
||||||
api_key = json.loads(key).get('api_key', '')
|
api_key = json.loads(key).get("api_key", "")
|
||||||
api_version = json.loads(key).get('api_version', '2024-02-01')
|
api_version = json.loads(key).get("api_version", "2024-02-01")
|
||||||
self.client = AzureOpenAI(api_key=api_key, azure_endpoint=kwargs["base_url"], api_version=api_version)
|
self.client = AzureOpenAI(api_key=api_key, azure_endpoint=kwargs["base_url"], api_version=api_version)
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
self.lang = lang
|
self.lang = lang
|
||||||
@ -212,10 +209,7 @@ class AzureGptV4(Base):
|
|||||||
if "text" in c:
|
if "text" in c:
|
||||||
c["type"] = "text"
|
c["type"] = "text"
|
||||||
|
|
||||||
res = self.client.chat.completions.create(
|
res = self.client.chat.completions.create(model=self.model_name, messages=prompt)
|
||||||
model=self.model_name,
|
|
||||||
messages=prompt
|
|
||||||
)
|
|
||||||
return res.choices[0].message.content.strip(), res.usage.total_tokens
|
return res.choices[0].message.content.strip(), res.usage.total_tokens
|
||||||
|
|
||||||
def describe_with_prompt(self, image, prompt=None):
|
def describe_with_prompt(self, image, prompt=None):
|
||||||
@ -230,8 +224,11 @@ class AzureGptV4(Base):
|
|||||||
|
|
||||||
|
|
||||||
class QWenCV(Base):
|
class QWenCV(Base):
|
||||||
|
_FACTORY_NAME = "Tongyi-Qianwen"
|
||||||
|
|
||||||
def __init__(self, key, model_name="qwen-vl-chat-v1", lang="Chinese", **kwargs):
|
def __init__(self, key, model_name="qwen-vl-chat-v1", lang="Chinese", **kwargs):
|
||||||
import dashscope
|
import dashscope
|
||||||
|
|
||||||
dashscope.api_key = key
|
dashscope.api_key = key
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
self.lang = lang
|
self.lang = lang
|
||||||
@ -247,12 +244,11 @@ class QWenCV(Base):
|
|||||||
{
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": [
|
"content": [
|
||||||
|
{"image": f"file://{path}"},
|
||||||
{
|
{
|
||||||
"image": f"file://{path}"
|
"text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
|
||||||
},
|
if self.lang.lower() == "chinese"
|
||||||
{
|
else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
|
||||||
"text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else
|
|
||||||
"Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
|
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
@ -270,9 +266,7 @@ class QWenCV(Base):
|
|||||||
{
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": [
|
"content": [
|
||||||
{
|
{"image": f"file://{path}"},
|
||||||
"image": f"file://{path}"
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"text": prompt if prompt else vision_llm_describe_prompt(),
|
"text": prompt if prompt else vision_llm_describe_prompt(),
|
||||||
},
|
},
|
||||||
@ -290,9 +284,10 @@ class QWenCV(Base):
|
|||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
|
|
||||||
from dashscope import MultiModalConversation
|
from dashscope import MultiModalConversation
|
||||||
|
|
||||||
response = MultiModalConversation.call(model=self.model_name, messages=self.prompt(image))
|
response = MultiModalConversation.call(model=self.model_name, messages=self.prompt(image))
|
||||||
if response.status_code == HTTPStatus.OK:
|
if response.status_code == HTTPStatus.OK:
|
||||||
return response.output.choices[0]['message']['content'][0]["text"], response.usage.output_tokens
|
return response.output.choices[0]["message"]["content"][0]["text"], response.usage.output_tokens
|
||||||
return response.message, 0
|
return response.message, 0
|
||||||
|
|
||||||
def describe_with_prompt(self, image, prompt=None):
|
def describe_with_prompt(self, image, prompt=None):
|
||||||
@ -303,33 +298,36 @@ class QWenCV(Base):
|
|||||||
vision_prompt = self.vision_llm_prompt(image, prompt) if prompt else self.vision_llm_prompt(image)
|
vision_prompt = self.vision_llm_prompt(image, prompt) if prompt else self.vision_llm_prompt(image)
|
||||||
response = MultiModalConversation.call(model=self.model_name, messages=vision_prompt)
|
response = MultiModalConversation.call(model=self.model_name, messages=vision_prompt)
|
||||||
if response.status_code == HTTPStatus.OK:
|
if response.status_code == HTTPStatus.OK:
|
||||||
return response.output.choices[0]['message']['content'][0]["text"], response.usage.output_tokens
|
return response.output.choices[0]["message"]["content"][0]["text"], response.usage.output_tokens
|
||||||
return response.message, 0
|
return response.message, 0
|
||||||
|
|
||||||
def chat(self, system, history, gen_conf, image=""):
|
def chat(self, system, history, gen_conf, image=""):
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
|
|
||||||
from dashscope import MultiModalConversation
|
from dashscope import MultiModalConversation
|
||||||
|
|
||||||
if system:
|
if system:
|
||||||
history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
|
history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
|
||||||
|
|
||||||
for his in history:
|
for his in history:
|
||||||
if his["role"] == "user":
|
if his["role"] == "user":
|
||||||
his["content"] = self.chat_prompt(his["content"], image)
|
his["content"] = self.chat_prompt(his["content"], image)
|
||||||
response = MultiModalConversation.call(model=self.model_name, messages=history,
|
response = MultiModalConversation.call(
|
||||||
|
model=self.model_name,
|
||||||
|
messages=history,
|
||||||
temperature=gen_conf.get("temperature", 0.3),
|
temperature=gen_conf.get("temperature", 0.3),
|
||||||
top_p=gen_conf.get("top_p", 0.7))
|
top_p=gen_conf.get("top_p", 0.7),
|
||||||
|
)
|
||||||
|
|
||||||
ans = ""
|
ans = ""
|
||||||
tk_count = 0
|
tk_count = 0
|
||||||
if response.status_code == HTTPStatus.OK:
|
if response.status_code == HTTPStatus.OK:
|
||||||
ans = response.output.choices[0]['message']['content']
|
ans = response.output.choices[0]["message"]["content"]
|
||||||
if isinstance(ans, list):
|
if isinstance(ans, list):
|
||||||
ans = ans[0]["text"] if ans else ""
|
ans = ans[0]["text"] if ans else ""
|
||||||
tk_count += response.usage.total_tokens
|
tk_count += response.usage.total_tokens
|
||||||
if response.output.choices[0].get("finish_reason", "") == "length":
|
if response.output.choices[0].get("finish_reason", "") == "length":
|
||||||
ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
|
ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
|
||||||
[ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
|
|
||||||
return ans, tk_count
|
return ans, tk_count
|
||||||
|
|
||||||
return "**ERROR**: " + response.message, tk_count
|
return "**ERROR**: " + response.message, tk_count
|
||||||
@ -338,6 +336,7 @@ class QWenCV(Base):
|
|||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
|
|
||||||
from dashscope import MultiModalConversation
|
from dashscope import MultiModalConversation
|
||||||
|
|
||||||
if system:
|
if system:
|
||||||
history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
|
history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
|
||||||
|
|
||||||
@ -348,24 +347,25 @@ class QWenCV(Base):
|
|||||||
ans = ""
|
ans = ""
|
||||||
tk_count = 0
|
tk_count = 0
|
||||||
try:
|
try:
|
||||||
response = MultiModalConversation.call(model=self.model_name, messages=history,
|
response = MultiModalConversation.call(
|
||||||
|
model=self.model_name,
|
||||||
|
messages=history,
|
||||||
temperature=gen_conf.get("temperature", 0.3),
|
temperature=gen_conf.get("temperature", 0.3),
|
||||||
top_p=gen_conf.get("top_p", 0.7),
|
top_p=gen_conf.get("top_p", 0.7),
|
||||||
stream=True)
|
stream=True,
|
||||||
|
)
|
||||||
for resp in response:
|
for resp in response:
|
||||||
if resp.status_code == HTTPStatus.OK:
|
if resp.status_code == HTTPStatus.OK:
|
||||||
cnt = resp.output.choices[0]['message']['content']
|
cnt = resp.output.choices[0]["message"]["content"]
|
||||||
if isinstance(cnt, list):
|
if isinstance(cnt, list):
|
||||||
cnt = cnt[0]["text"] if ans else ""
|
cnt = cnt[0]["text"] if ans else ""
|
||||||
ans += cnt
|
ans += cnt
|
||||||
tk_count = resp.usage.total_tokens
|
tk_count = resp.usage.total_tokens
|
||||||
if resp.output.choices[0].get("finish_reason", "") == "length":
|
if resp.output.choices[0].get("finish_reason", "") == "length":
|
||||||
ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
|
ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
|
||||||
[ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
|
|
||||||
yield ans
|
yield ans
|
||||||
else:
|
else:
|
||||||
yield ans + "\n**ERROR**: " + resp.message if str(resp.message).find(
|
yield ans + "\n**ERROR**: " + resp.message if str(resp.message).find("Access") < 0 else "Out of credit. Please set the API key in **settings > Model providers.**"
|
||||||
"Access") < 0 else "Out of credit. Please set the API key in **settings > Model providers.**"
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
yield ans + "\n**ERROR**: " + str(e)
|
yield ans + "\n**ERROR**: " + str(e)
|
||||||
|
|
||||||
@ -373,6 +373,8 @@ class QWenCV(Base):
|
|||||||
|
|
||||||
|
|
||||||
class Zhipu4V(Base):
|
class Zhipu4V(Base):
|
||||||
|
_FACTORY_NAME = "ZHIPU-AI"
|
||||||
|
|
||||||
def __init__(self, key, model_name="glm-4v", lang="Chinese", **kwargs):
|
def __init__(self, key, model_name="glm-4v", lang="Chinese", **kwargs):
|
||||||
self.client = ZhipuAI(api_key=key)
|
self.client = ZhipuAI(api_key=key)
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
@ -394,10 +396,7 @@ class Zhipu4V(Base):
|
|||||||
b64 = self.image2base64(image)
|
b64 = self.image2base64(image)
|
||||||
vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64)
|
vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64)
|
||||||
|
|
||||||
res = self.client.chat.completions.create(
|
res = self.client.chat.completions.create(model=self.model_name, messages=vision_prompt)
|
||||||
model=self.model_name,
|
|
||||||
messages=vision_prompt
|
|
||||||
)
|
|
||||||
return res.choices[0].message.content.strip(), res.usage.total_tokens
|
return res.choices[0].message.content.strip(), res.usage.total_tokens
|
||||||
|
|
||||||
def chat(self, system, history, gen_conf, image=""):
|
def chat(self, system, history, gen_conf, image=""):
|
||||||
@ -412,7 +411,7 @@ class Zhipu4V(Base):
|
|||||||
model=self.model_name,
|
model=self.model_name,
|
||||||
messages=history,
|
messages=history,
|
||||||
temperature=gen_conf.get("temperature", 0.3),
|
temperature=gen_conf.get("temperature", 0.3),
|
||||||
top_p=gen_conf.get("top_p", 0.7)
|
top_p=gen_conf.get("top_p", 0.7),
|
||||||
)
|
)
|
||||||
return response.choices[0].message.content.strip(), response.usage.total_tokens
|
return response.choices[0].message.content.strip(), response.usage.total_tokens
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -434,7 +433,7 @@ class Zhipu4V(Base):
|
|||||||
messages=history,
|
messages=history,
|
||||||
temperature=gen_conf.get("temperature", 0.3),
|
temperature=gen_conf.get("temperature", 0.3),
|
||||||
top_p=gen_conf.get("top_p", 0.7),
|
top_p=gen_conf.get("top_p", 0.7),
|
||||||
stream=True
|
stream=True,
|
||||||
)
|
)
|
||||||
for resp in response:
|
for resp in response:
|
||||||
if not resp.choices[0].delta.content:
|
if not resp.choices[0].delta.content:
|
||||||
@ -442,8 +441,7 @@ class Zhipu4V(Base):
|
|||||||
delta = resp.choices[0].delta.content
|
delta = resp.choices[0].delta.content
|
||||||
ans += delta
|
ans += delta
|
||||||
if resp.choices[0].finish_reason == "length":
|
if resp.choices[0].finish_reason == "length":
|
||||||
ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
|
ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
|
||||||
[ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
|
|
||||||
tk_count = resp.usage.total_tokens
|
tk_count = resp.usage.total_tokens
|
||||||
if resp.choices[0].finish_reason == "stop":
|
if resp.choices[0].finish_reason == "stop":
|
||||||
tk_count = resp.usage.total_tokens
|
tk_count = resp.usage.total_tokens
|
||||||
@ -455,6 +453,8 @@ class Zhipu4V(Base):
|
|||||||
|
|
||||||
|
|
||||||
class OllamaCV(Base):
|
class OllamaCV(Base):
|
||||||
|
_FACTORY_NAME = "Ollama"
|
||||||
|
|
||||||
def __init__(self, key, model_name, lang="Chinese", **kwargs):
|
def __init__(self, key, model_name, lang="Chinese", **kwargs):
|
||||||
self.client = Client(host=kwargs["base_url"])
|
self.client = Client(host=kwargs["base_url"])
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
@ -466,7 +466,7 @@ class OllamaCV(Base):
|
|||||||
response = self.client.generate(
|
response = self.client.generate(
|
||||||
model=self.model_name,
|
model=self.model_name,
|
||||||
prompt=prompt[0]["content"][1]["text"],
|
prompt=prompt[0]["content"][1]["text"],
|
||||||
images=[image]
|
images=[image],
|
||||||
)
|
)
|
||||||
ans = response["response"].strip()
|
ans = response["response"].strip()
|
||||||
return ans, 128
|
return ans, 128
|
||||||
@ -507,7 +507,7 @@ class OllamaCV(Base):
|
|||||||
model=self.model_name,
|
model=self.model_name,
|
||||||
messages=history,
|
messages=history,
|
||||||
options=options,
|
options=options,
|
||||||
keep_alive=-1
|
keep_alive=-1,
|
||||||
)
|
)
|
||||||
|
|
||||||
ans = response["message"]["content"].strip()
|
ans = response["message"]["content"].strip()
|
||||||
@@ -538,7 +538,7 @@ class OllamaCV(Base):
                 messages=history,
                 stream=True,
                 options=options,
-                keep_alive=-1
+                keep_alive=-1,
             )
             for resp in response:
                 if resp["done"]:
@@ -551,6 +551,8 @@ class OllamaCV(Base):


 class LocalAICV(GptV4):
+    _FACTORY_NAME = "LocalAI"
+
     def __init__(self, key, model_name, base_url, lang="Chinese"):
         if not base_url:
             raise ValueError("Local cv model url cannot be None")
@@ -561,6 +563,8 @@ class LocalAICV(GptV4):


 class XinferenceCV(Base):
+    _FACTORY_NAME = "Xinference"
+
     def __init__(self, key, model_name="", lang="Chinese", base_url=""):
         base_url = urljoin(base_url, "v1")
         self.client = OpenAI(api_key=key, base_url=base_url)
@@ -570,10 +574,7 @@ class XinferenceCV(Base):
     def describe(self, image):
         b64 = self.image2base64(image)

-        res = self.client.chat.completions.create(
-            model=self.model_name,
-            messages=self.prompt(b64)
-        )
+        res = self.client.chat.completions.create(model=self.model_name, messages=self.prompt(b64))
         return res.choices[0].message.content.strip(), res.usage.total_tokens

     def describe_with_prompt(self, image, prompt=None):
@@ -588,8 +589,11 @@ class XinferenceCV(Base):


 class GeminiCV(Base):
+    _FACTORY_NAME = "Gemini"
+
     def __init__(self, key, model_name="gemini-1.0-pro-vision-latest", lang="Chinese", **kwargs):
         from google.generativeai import GenerativeModel, client
+
         client.configure(api_key=key)
         _client = client.get_default_generative_client()
         self.model_name = model_name
@@ -599,18 +603,21 @@ class GeminiCV(Base):

     def describe(self, image):
         from PIL.Image import open
-        prompt = "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else \
-            "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
+        prompt = (
+            "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
+            if self.lang.lower() == "chinese"
+            else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
+        )
         b64 = self.image2base64(image)
         img = open(BytesIO(base64.b64decode(b64)))
         input = [prompt, img]
-        res = self.model.generate_content(
-            input
-        )
+        res = self.model.generate_content(input)
         return res.text, res.usage_metadata.total_token_count

     def describe_with_prompt(self, image, prompt=None):
         from PIL.Image import open

         b64 = self.image2base64(image)
         vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64)
         img = open(BytesIO(base64.b64decode(b64)))
@@ -622,6 +629,7 @@ class GeminiCV(Base):

     def chat(self, system, history, gen_conf, image=""):
         from transformers import GenerationConfig
+
         if system:
             history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
         try:
@@ -635,9 +643,7 @@ class GeminiCV(Base):
                     his.pop("content")
             history[-1]["parts"].append("data:image/jpeg;base64," + image)

-            response = self.model.generate_content(history, generation_config=GenerationConfig(
-                temperature=gen_conf.get("temperature", 0.3),
-                top_p=gen_conf.get("top_p", 0.7)))
+            response = self.model.generate_content(history, generation_config=GenerationConfig(temperature=gen_conf.get("temperature", 0.3), top_p=gen_conf.get("top_p", 0.7)))

             ans = response.text
             return ans, response.usage_metadata.total_token_count
@@ -646,6 +652,7 @@ class GeminiCV(Base):

     def chat_streamly(self, system, history, gen_conf, image=""):
         from transformers import GenerationConfig
+
         if system:
             history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]

@@ -661,9 +668,11 @@ class GeminiCV(Base):
                     his.pop("content")
             history[-1]["parts"].append("data:image/jpeg;base64," + image)

-            response = self.model.generate_content(history, generation_config=GenerationConfig(
-                temperature=gen_conf.get("temperature", 0.3),
-                top_p=gen_conf.get("top_p", 0.7)), stream=True)
+            response = self.model.generate_content(
+                history,
+                generation_config=GenerationConfig(temperature=gen_conf.get("temperature", 0.3), top_p=gen_conf.get("top_p", 0.7)),
+                stream=True,
+            )

             for resp in response:
                 if not resp.text:
@@ -677,6 +686,8 @@ class GeminiCV(Base):


 class OpenRouterCV(GptV4):
+    _FACTORY_NAME = "OpenRouter"
+
     def __init__(
         self,
         key,
@@ -692,6 +703,8 @@ class OpenRouterCV(GptV4):


 class LocalCV(Base):
+    _FACTORY_NAME = "Moonshot"
+
     def __init__(self, key, model_name="glm-4v", lang="Chinese", **kwargs):
         pass

@@ -700,6 +713,8 @@ class LocalCV(Base):


 class NvidiaCV(Base):
+    _FACTORY_NAME = "NVIDIA"
+
     def __init__(
         self,
         key,
@@ -726,9 +741,7 @@ class NvidiaCV(Base):
                 "content-type": "application/json",
                 "Authorization": f"Bearer {self.key}",
             },
-            json={
-                "messages": self.prompt(b64)
-            },
+            json={"messages": self.prompt(b64)},
         )
         response = response.json()
         return (
@@ -774,10 +787,7 @@ class NvidiaCV(Base):
         return [
             {
                 "role": "user",
-                "content": (
-                    prompt if prompt else vision_llm_describe_prompt()
-                )
-                + f' <img src="data:image/jpeg;base64,{b64}"/>',
+                "content": (prompt if prompt else vision_llm_describe_prompt()) + f' <img src="data:image/jpeg;base64,{b64}"/>',
             }
         ]

@@ -791,6 +801,8 @@ class NvidiaCV(Base):


 class StepFunCV(GptV4):
+    _FACTORY_NAME = "StepFun"
+
     def __init__(self, key, model_name="step-1v-8k", lang="Chinese", base_url="https://api.stepfun.com/v1"):
         if not base_url:
             base_url = "https://api.stepfun.com/v1"
@@ -800,6 +812,8 @@ class StepFunCV(GptV4):


 class LmStudioCV(GptV4):
+    _FACTORY_NAME = "LM-Studio"
+
     def __init__(self, key, model_name, lang="Chinese", base_url=""):
         if not base_url:
             raise ValueError("Local llm url cannot be None")
@@ -810,6 +824,8 @@ class LmStudioCV(GptV4):


 class OpenAI_APICV(GptV4):
+    _FACTORY_NAME = ["VLLM", "OpenAI-API-Compatible"]
+
     def __init__(self, key, model_name, lang="Chinese", base_url=""):
         if not base_url:
             raise ValueError("url cannot be None")
@@ -820,6 +836,8 @@ class OpenAI_APICV(GptV4):


 class TogetherAICV(GptV4):
+    _FACTORY_NAME = "TogetherAI"
+
     def __init__(self, key, model_name, lang="Chinese", base_url="https://api.together.xyz/v1"):
         if not base_url:
             base_url = "https://api.together.xyz/v1"
@@ -827,20 +845,38 @@ class TogetherAICV(GptV4):


 class YiCV(GptV4):
-    def __init__(self, key, model_name, lang="Chinese", base_url="https://api.lingyiwanwu.com/v1",):
+    _FACTORY_NAME = "01.AI"
+
+    def __init__(
+        self,
+        key,
+        model_name,
+        lang="Chinese",
+        base_url="https://api.lingyiwanwu.com/v1",
+    ):
         if not base_url:
             base_url = "https://api.lingyiwanwu.com/v1"
         super().__init__(key, model_name, lang, base_url)


 class SILICONFLOWCV(GptV4):
-    def __init__(self, key, model_name, lang="Chinese", base_url="https://api.siliconflow.cn/v1",):
+    _FACTORY_NAME = "SILICONFLOW"
+
+    def __init__(
+        self,
+        key,
+        model_name,
+        lang="Chinese",
+        base_url="https://api.siliconflow.cn/v1",
+    ):
         if not base_url:
             base_url = "https://api.siliconflow.cn/v1"
         super().__init__(key, model_name, lang, base_url)


 class HunyuanCV(Base):
+    _FACTORY_NAME = "Tencent Hunyuan"
+
     def __init__(self, key, model_name, lang="Chinese", base_url=None):
         from tencentcloud.common import credential
         from tencentcloud.hunyuan.v20230901 import hunyuan_client
@@ -895,14 +931,13 @@ class HunyuanCV(Base):
             "Contents": [
                 {
                     "Type": "image_url",
-                    "ImageUrl": {
-                        "Url": f"data:image/jpeg;base64,{b64}"
-                    },
+                    "ImageUrl": {"Url": f"data:image/jpeg;base64,{b64}"},
                 },
                 {
                     "Type": "text",
-                    "Text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else
-                    "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
+                    "Text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
+                    if self.lang.lower() == "chinese"
+                    else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
                 },
             ],
         }
@@ -910,6 +945,8 @@ class HunyuanCV(Base):


 class AnthropicCV(Base):
+    _FACTORY_NAME = "Anthropic"
+
     def __init__(self, key, model_name, base_url=None):
         import anthropic

@@ -933,38 +970,29 @@ class AnthropicCV(Base):
                         "data": b64,
                     },
                 },
-                {
-                    "type": "text",
-                    "text": prompt
-                }
+                {"type": "text", "text": prompt},
             ],
         }
     ]

     def describe(self, image):
         b64 = self.image2base64(image)
-        prompt = self.prompt(b64,
-                             "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else
-                             "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
+        prompt = self.prompt(
+            b64,
+            "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
+            if self.lang.lower() == "chinese"
+            else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
         )

-        response = self.client.messages.create(
-            model=self.model_name,
-            max_tokens=self.max_tokens,
-            messages=prompt
-        )
-        return response["content"][0]["text"].strip(), response["usage"]["input_tokens"]+response["usage"]["output_tokens"]
+        response = self.client.messages.create(model=self.model_name, max_tokens=self.max_tokens, messages=prompt)
+        return response["content"][0]["text"].strip(), response["usage"]["input_tokens"] + response["usage"]["output_tokens"]

     def describe_with_prompt(self, image, prompt=None):
         b64 = self.image2base64(image)
         prompt = self.prompt(b64, prompt if prompt else vision_llm_describe_prompt())

-        response = self.client.messages.create(
-            model=self.model_name,
-            max_tokens=self.max_tokens,
-            messages=prompt
-        )
-        return response["content"][0]["text"].strip(), response["usage"]["input_tokens"]+response["usage"]["output_tokens"]
+        response = self.client.messages.create(model=self.model_name, max_tokens=self.max_tokens, messages=prompt)
+        return response["content"][0]["text"].strip(), response["usage"]["input_tokens"] + response["usage"]["output_tokens"]

     def chat(self, system, history, gen_conf):
         if "presence_penalty" in gen_conf:
@@ -984,11 +1012,7 @@ class AnthropicCV(Base):
             ).to_dict()
             ans = response["content"][0]["text"]
             if response["stop_reason"] == "max_tokens":
-                ans += (
-                    "...\nFor the content length reason, it stopped, continue?"
-                    if is_english([ans])
-                    else "······\n由于长度的原因,回答被截断了,要继续吗?"
-                )
+                ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
             return (
                 ans,
                 response["usage"]["input_tokens"] + response["usage"]["output_tokens"],
@@ -1014,7 +1038,7 @@ class AnthropicCV(Base):
                 **gen_conf,
             )
             for res in response:
-                if res.type == 'content_block_delta':
+                if res.type == "content_block_delta":
                     if res.delta.type == "thinking_delta" and res.delta.thinking:
                         if ans.find("<think>") < 0:
                             ans += "<think>"
@@ -1030,7 +1054,10 @@ class AnthropicCV(Base):

         yield total_tokens

+
 class GPUStackCV(GptV4):
+    _FACTORY_NAME = "GPUStack"
+
     def __init__(self, key, model_name, lang="Chinese", base_url=""):
         if not base_url:
             raise ValueError("Local llm url cannot be None")
@@ -1041,6 +1068,8 @@ class GPUStackCV(GptV4):


 class GoogleCV(Base):
+    _FACTORY_NAME = "Google Cloud"
+
     def __init__(self, key, model_name, lang="Chinese", base_url=None, **kwargs):
         import base64

@@ -1079,8 +1108,11 @@ class GoogleCV(Base):
         self.client = glm.GenerativeModel(model_name=self.model_name)

     def describe(self, image):
-        prompt = "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else \
-            "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
+        prompt = (
+            "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
+            if self.lang.lower() == "chinese"
+            else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
+        )

         if "claude" in self.model_name:
             b64 = self.image2base64(image)
@@ -1096,17 +1128,14 @@ class GoogleCV(Base):
                             "data": b64,
                         },
                     },
-                    {
-                        "type": "text",
-                        "text": prompt
-                    }
+                    {"type": "text", "text": prompt},
                 ],
             }
         ]
         response = self.client.messages.create(
             model=self.model_name,
             max_tokens=8192,
-            messages=vision_prompt
+            messages=vision_prompt,
         )
         return response.content[0].text.strip(), response.usage.input_tokens + response.usage.output_tokens
     else:
@@ -1114,10 +1143,7 @@ class GoogleCV(Base):

             b64 = self.image2base64(image)
             # Create proper image part for Gemini
-            image_part = glm.Part.from_data(
-                data=base64.b64decode(b64),
-                mime_type="image/jpeg"
-            )
+            image_part = glm.Part.from_data(data=base64.b64decode(b64), mime_type="image/jpeg")
             input = [prompt, image_part]
             res = self.client.generate_content(input)
             return res.text, res.usage_metadata.total_token_count
@@ -1137,18 +1163,11 @@ class GoogleCV(Base):
                             "data": b64,
                         },
                     },
-                    {
-                        "type": "text",
-                        "text": prompt if prompt else vision_llm_describe_prompt()
-                    }
+                    {"type": "text", "text": prompt if prompt else vision_llm_describe_prompt()},
                 ],
             }
         ]
-        response = self.client.messages.create(
-            model=self.model_name,
-            max_tokens=8192,
-            messages=vision_prompt
-        )
+        response = self.client.messages.create(model=self.model_name, max_tokens=8192, messages=vision_prompt)
         return response.content[0].text.strip(), response.usage.input_tokens + response.usage.output_tokens
     else:
         import vertexai.generative_models as glm
@@ -1156,10 +1175,7 @@ class GoogleCV(Base):
         b64 = self.image2base64(image)
         vision_prompt = prompt if prompt else vision_llm_describe_prompt()
         # Create proper image part for Gemini
-        image_part = glm.Part.from_data(
-            data=base64.b64decode(b64),
-            mime_type="image/jpeg"
-        )
+        image_part = glm.Part.from_data(data=base64.b64decode(b64), mime_type="image/jpeg")
         input = [vision_prompt, image_part]
         res = self.client.generate_content(input)
         return res.text, res.usage_metadata.total_token_count
@@ -1180,25 +1196,17 @@ class GoogleCV(Base):
                             "data": image,
                         },
                     },
-                    {
-                        "type": "text",
-                        "text": his["content"]
-                    }
+                    {"type": "text", "text": his["content"]},
                 ]

-                response = self.client.messages.create(
-                    model=self.model_name,
-                    max_tokens=8192,
-                    messages=history,
-                    temperature=gen_conf.get("temperature", 0.3),
-                    top_p=gen_conf.get("top_p", 0.7)
-                )
+                response = self.client.messages.create(model=self.model_name, max_tokens=8192, messages=history, temperature=gen_conf.get("temperature", 0.3), top_p=gen_conf.get("top_p", 0.7))
                 return response.content[0].text.strip(), response.usage.input_tokens + response.usage.output_tokens
             except Exception as e:
                 return "**ERROR**: " + str(e), 0
         else:
             import vertexai.generative_models as glm
             from transformers import GenerationConfig

             if system:
                 history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
             try:
@@ -1213,15 +1221,10 @@ class GoogleCV(Base):

                 # Create proper image part for Gemini
                 img_bytes = base64.b64decode(image)
-                image_part = glm.Part.from_data(
-                    data=img_bytes,
-                    mime_type="image/jpeg"
-                )
+                image_part = glm.Part.from_data(data=img_bytes, mime_type="image/jpeg")
                 history[-1]["parts"].append(image_part)

-                response = self.client.generate_content(history, generation_config=GenerationConfig(
-                    temperature=gen_conf.get("temperature", 0.3),
-                    top_p=gen_conf.get("top_p", 0.7)))
+                response = self.client.generate_content(history, generation_config=GenerationConfig(temperature=gen_conf.get("temperature", 0.3), top_p=gen_conf.get("top_p", 0.7)))

                 ans = response.text
                 return ans, response.usage_metadata.total_token_count
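Note: the hunks above repeat a single pattern: each vision-model class gains a `_FACTORY_NAME` class attribute, which is what makes the hand-maintained per-type import lists removable. A minimal sketch of how an attribute-driven scan can populate the registries; the function name, module paths, and loop details below are illustrative assumptions, not the exact `__init__.py` of this PR:

import importlib
import inspect

CvModel = {}
EmbeddingModel = {}


def register_classes(module_name: str, registry: dict):
    # Import the provider module and collect every class declaring a
    # _FACTORY_NAME; base and helper classes without one are skipped.
    module = importlib.import_module(module_name)
    for _, klass in inspect.getmembers(module, inspect.isclass):
        factory = getattr(klass, "_FACTORY_NAME", None)
        if not factory:
            continue
        # _FACTORY_NAME is either a single name ("Gemini") or a list of
        # names (["VLLM", "OpenAI-API-Compatible"]) mapping to one class.
        for name in factory if isinstance(factory, list) else [factory]:
            registry[name] = klass


register_classes("rag.llm.cv_model", CvModel)
register_classes("rag.llm.embedding_model", EmbeddingModel)

Under this scheme, a class that carries no `_FACTORY_NAME` simply never enters a registry, which is why the embedding diff below can delete `InfinityEmbed` outright without bookkeeping anywhere else.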
|
|||||||
@ -13,28 +13,27 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
import re
|
import re
|
||||||
import threading
|
import threading
|
||||||
|
from abc import ABC
|
||||||
from urllib.parse import urljoin
|
from urllib.parse import urljoin
|
||||||
|
|
||||||
|
import dashscope
|
||||||
|
import google.generativeai as genai
|
||||||
|
import numpy as np
|
||||||
import requests
|
import requests
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
from zhipuai import ZhipuAI
|
|
||||||
import os
|
|
||||||
from abc import ABC
|
|
||||||
from ollama import Client
|
from ollama import Client
|
||||||
import dashscope
|
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
import numpy as np
|
from zhipuai import ZhipuAI
|
||||||
import asyncio
|
|
||||||
|
|
||||||
from api import settings
|
from api import settings
|
||||||
from api.utils.file_utils import get_home_cache_dir
|
from api.utils.file_utils import get_home_cache_dir
|
||||||
from api.utils.log_utils import log_exception
|
from api.utils.log_utils import log_exception
|
||||||
from rag.utils import num_tokens_from_string, truncate
|
from rag.utils import num_tokens_from_string, truncate
|
||||||
import google.generativeai as genai
|
|
||||||
import json
|
|
||||||
|
|
||||||
|
|
||||||
class Base(ABC):
|
class Base(ABC):
|
||||||
@ -60,7 +59,8 @@ class Base(ABC):
|
|||||||
|
|
||||||
|
|
||||||
class DefaultEmbedding(Base):
|
class DefaultEmbedding(Base):
|
||||||
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
|
_FACTORY_NAME = "BAAI"
|
||||||
|
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
||||||
_model = None
|
_model = None
|
||||||
_model_name = ""
|
_model_name = ""
|
||||||
_model_lock = threading.Lock()
|
_model_lock = threading.Lock()
|
||||||
@ -79,21 +79,22 @@ class DefaultEmbedding(Base):
|
|||||||
"""
|
"""
|
||||||
if not settings.LIGHTEN:
|
if not settings.LIGHTEN:
|
||||||
with DefaultEmbedding._model_lock:
|
with DefaultEmbedding._model_lock:
|
||||||
from FlagEmbedding import FlagModel
|
|
||||||
import torch
|
import torch
|
||||||
|
from FlagEmbedding import FlagModel
|
||||||
|
|
||||||
if not DefaultEmbedding._model or model_name != DefaultEmbedding._model_name:
|
if not DefaultEmbedding._model or model_name != DefaultEmbedding._model_name:
|
||||||
try:
|
try:
|
||||||
DefaultEmbedding._model = FlagModel(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)),
|
DefaultEmbedding._model = FlagModel(
|
||||||
|
os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)),
|
||||||
query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
|
query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
|
||||||
use_fp16=torch.cuda.is_available())
|
use_fp16=torch.cuda.is_available(),
|
||||||
|
)
|
||||||
DefaultEmbedding._model_name = model_name
|
DefaultEmbedding._model_name = model_name
|
||||||
except Exception:
|
except Exception:
|
||||||
model_dir = snapshot_download(repo_id="BAAI/bge-large-zh-v1.5",
|
model_dir = snapshot_download(
|
||||||
local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)),
|
repo_id="BAAI/bge-large-zh-v1.5", local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), local_dir_use_symlinks=False
|
||||||
local_dir_use_symlinks=False)
|
)
|
||||||
DefaultEmbedding._model = FlagModel(model_dir,
|
DefaultEmbedding._model = FlagModel(model_dir, query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:", use_fp16=torch.cuda.is_available())
|
||||||
query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
|
|
||||||
use_fp16=torch.cuda.is_available())
|
|
||||||
self._model = DefaultEmbedding._model
|
self._model = DefaultEmbedding._model
|
||||||
self._model_name = DefaultEmbedding._model_name
|
self._model_name = DefaultEmbedding._model_name
|
||||||
|
|
||||||
@ -105,7 +106,7 @@ class DefaultEmbedding(Base):
|
|||||||
token_count += num_tokens_from_string(t)
|
token_count += num_tokens_from_string(t)
|
||||||
ress = []
|
ress = []
|
||||||
for i in range(0, len(texts), batch_size):
|
for i in range(0, len(texts), batch_size):
|
||||||
ress.extend(self._model.encode(texts[i:i + batch_size]).tolist())
|
ress.extend(self._model.encode(texts[i : i + batch_size]).tolist())
|
||||||
return np.array(ress), token_count
|
return np.array(ress), token_count
|
||||||
|
|
||||||
def encode_queries(self, text: str):
|
def encode_queries(self, text: str):
|
||||||
@ -114,8 +115,9 @@ class DefaultEmbedding(Base):
|
|||||||
|
|
||||||
|
|
||||||
class OpenAIEmbed(Base):
|
class OpenAIEmbed(Base):
|
||||||
def __init__(self, key, model_name="text-embedding-ada-002",
|
_FACTORY_NAME = "OpenAI"
|
||||||
base_url="https://api.openai.com/v1"):
|
|
||||||
|
def __init__(self, key, model_name="text-embedding-ada-002", base_url="https://api.openai.com/v1"):
|
||||||
if not base_url:
|
if not base_url:
|
||||||
base_url = "https://api.openai.com/v1"
|
base_url = "https://api.openai.com/v1"
|
||||||
self.client = OpenAI(api_key=key, base_url=base_url)
|
self.client = OpenAI(api_key=key, base_url=base_url)
|
||||||
@ -128,8 +130,7 @@ class OpenAIEmbed(Base):
|
|||||||
ress = []
|
ress = []
|
||||||
total_tokens = 0
|
total_tokens = 0
|
||||||
for i in range(0, len(texts), batch_size):
|
for i in range(0, len(texts), batch_size):
|
||||||
res = self.client.embeddings.create(input=texts[i:i + batch_size],
|
res = self.client.embeddings.create(input=texts[i : i + batch_size], model=self.model_name)
|
||||||
model=self.model_name)
|
|
||||||
try:
|
try:
|
||||||
ress.extend([d.embedding for d in res.data])
|
ress.extend([d.embedding for d in res.data])
|
||||||
total_tokens += self.total_token_count(res)
|
total_tokens += self.total_token_count(res)
|
||||||
@ -138,12 +139,13 @@ class OpenAIEmbed(Base):
|
|||||||
return np.array(ress), total_tokens
|
return np.array(ress), total_tokens
|
||||||
|
|
||||||
def encode_queries(self, text):
|
def encode_queries(self, text):
|
||||||
res = self.client.embeddings.create(input=[truncate(text, 8191)],
|
res = self.client.embeddings.create(input=[truncate(text, 8191)], model=self.model_name)
|
||||||
model=self.model_name)
|
|
||||||
return np.array(res.data[0].embedding), self.total_token_count(res)
|
return np.array(res.data[0].embedding), self.total_token_count(res)
|
||||||
|
|
||||||
|
|
||||||
class LocalAIEmbed(Base):
|
class LocalAIEmbed(Base):
|
||||||
|
_FACTORY_NAME = "LocalAI"
|
||||||
|
|
||||||
def __init__(self, key, model_name, base_url):
|
def __init__(self, key, model_name, base_url):
|
||||||
if not base_url:
|
if not base_url:
|
||||||
raise ValueError("Local embedding model url cannot be None")
|
raise ValueError("Local embedding model url cannot be None")
|
||||||
@ -155,7 +157,7 @@ class LocalAIEmbed(Base):
|
|||||||
batch_size = 16
|
batch_size = 16
|
||||||
ress = []
|
ress = []
|
||||||
for i in range(0, len(texts), batch_size):
|
for i in range(0, len(texts), batch_size):
|
||||||
res = self.client.embeddings.create(input=texts[i:i + batch_size], model=self.model_name)
|
res = self.client.embeddings.create(input=texts[i : i + batch_size], model=self.model_name)
|
||||||
try:
|
try:
|
||||||
ress.extend([d.embedding for d in res.data])
|
ress.extend([d.embedding for d in res.data])
|
||||||
except Exception as _e:
|
except Exception as _e:
|
||||||
@ -169,41 +171,42 @@ class LocalAIEmbed(Base):
|
|||||||
|
|
||||||
|
|
||||||
class AzureEmbed(OpenAIEmbed):
|
class AzureEmbed(OpenAIEmbed):
|
||||||
|
_FACTORY_NAME = "Azure-OpenAI"
|
||||||
|
|
||||||
def __init__(self, key, model_name, **kwargs):
|
def __init__(self, key, model_name, **kwargs):
|
||||||
from openai.lib.azure import AzureOpenAI
|
from openai.lib.azure import AzureOpenAI
|
||||||
api_key = json.loads(key).get('api_key', '')
|
|
||||||
api_version = json.loads(key).get('api_version', '2024-02-01')
|
api_key = json.loads(key).get("api_key", "")
|
||||||
|
api_version = json.loads(key).get("api_version", "2024-02-01")
|
||||||
self.client = AzureOpenAI(api_key=api_key, azure_endpoint=kwargs["base_url"], api_version=api_version)
|
self.client = AzureOpenAI(api_key=api_key, azure_endpoint=kwargs["base_url"], api_version=api_version)
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
|
|
||||||
|
|
||||||
class BaiChuanEmbed(OpenAIEmbed):
|
class BaiChuanEmbed(OpenAIEmbed):
|
||||||
def __init__(self, key,
|
_FACTORY_NAME = "BaiChuan"
|
||||||
model_name='Baichuan-Text-Embedding',
|
|
||||||
base_url='https://api.baichuan-ai.com/v1'):
|
def __init__(self, key, model_name="Baichuan-Text-Embedding", base_url="https://api.baichuan-ai.com/v1"):
|
||||||
if not base_url:
|
if not base_url:
|
||||||
base_url = "https://api.baichuan-ai.com/v1"
|
base_url = "https://api.baichuan-ai.com/v1"
|
||||||
super().__init__(key, model_name, base_url)
|
super().__init__(key, model_name, base_url)
|
||||||
|
|
||||||
|
|
||||||
class QWenEmbed(Base):
|
class QWenEmbed(Base):
|
||||||
|
_FACTORY_NAME = "Tongyi-Qianwen"
|
||||||
|
|
||||||
def __init__(self, key, model_name="text_embedding_v2", **kwargs):
|
def __init__(self, key, model_name="text_embedding_v2", **kwargs):
|
||||||
self.key = key
|
self.key = key
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
|
|
||||||
def encode(self, texts: list):
|
def encode(self, texts: list):
|
||||||
import dashscope
|
import dashscope
|
||||||
|
|
||||||
batch_size = 4
|
batch_size = 4
|
||||||
res = []
|
res = []
|
||||||
token_count = 0
|
token_count = 0
|
||||||
texts = [truncate(t, 2048) for t in texts]
|
texts = [truncate(t, 2048) for t in texts]
|
||||||
for i in range(0, len(texts), batch_size):
|
for i in range(0, len(texts), batch_size):
|
||||||
resp = dashscope.TextEmbedding.call(
|
resp = dashscope.TextEmbedding.call(model=self.model_name, input=texts[i : i + batch_size], api_key=self.key, text_type="document")
|
||||||
model=self.model_name,
|
|
||||||
input=texts[i:i + batch_size],
|
|
||||||
api_key=self.key,
|
|
||||||
text_type="document"
|
|
||||||
)
|
|
||||||
try:
|
try:
|
||||||
embds = [[] for _ in range(len(resp["output"]["embeddings"]))]
|
embds = [[] for _ in range(len(resp["output"]["embeddings"]))]
|
||||||
for e in resp["output"]["embeddings"]:
|
for e in resp["output"]["embeddings"]:
|
||||||
@ -216,20 +219,16 @@ class QWenEmbed(Base):
|
|||||||
return np.array(res), token_count
|
return np.array(res), token_count
|
||||||
|
|
||||||
def encode_queries(self, text):
|
def encode_queries(self, text):
|
||||||
resp = dashscope.TextEmbedding.call(
|
resp = dashscope.TextEmbedding.call(model=self.model_name, input=text[:2048], api_key=self.key, text_type="query")
|
||||||
model=self.model_name,
|
|
||||||
input=text[:2048],
|
|
||||||
api_key=self.key,
|
|
||||||
text_type="query"
|
|
||||||
)
|
|
||||||
try:
|
try:
|
||||||
return np.array(resp["output"]["embeddings"][0]
|
return np.array(resp["output"]["embeddings"][0]["embedding"]), self.total_token_count(resp)
|
||||||
["embedding"]), self.total_token_count(resp)
|
|
||||||
except Exception as _e:
|
except Exception as _e:
|
||||||
log_exception(_e, resp)
|
log_exception(_e, resp)
|
||||||
|
|
||||||
|
|
||||||
class ZhipuEmbed(Base):
|
class ZhipuEmbed(Base):
|
||||||
|
_FACTORY_NAME = "ZHIPU-AI"
|
||||||
|
|
||||||
def __init__(self, key, model_name="embedding-2", **kwargs):
|
def __init__(self, key, model_name="embedding-2", **kwargs):
|
||||||
self.client = ZhipuAI(api_key=key)
|
self.client = ZhipuAI(api_key=key)
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
@ -246,8 +245,7 @@ class ZhipuEmbed(Base):
|
|||||||
texts = [truncate(t, MAX_LEN) for t in texts]
|
texts = [truncate(t, MAX_LEN) for t in texts]
|
||||||
|
|
||||||
for txt in texts:
|
for txt in texts:
|
||||||
res = self.client.embeddings.create(input=txt,
|
res = self.client.embeddings.create(input=txt, model=self.model_name)
|
||||||
model=self.model_name)
|
|
||||||
try:
|
try:
|
||||||
arr.append(res.data[0].embedding)
|
arr.append(res.data[0].embedding)
|
||||||
tks_num += self.total_token_count(res)
|
tks_num += self.total_token_count(res)
|
||||||
@ -256,8 +254,7 @@ class ZhipuEmbed(Base):
|
|||||||
return np.array(arr), tks_num
|
return np.array(arr), tks_num
|
||||||
|
|
||||||
def encode_queries(self, text):
|
def encode_queries(self, text):
|
||||||
res = self.client.embeddings.create(input=text,
|
res = self.client.embeddings.create(input=text, model=self.model_name)
|
||||||
model=self.model_name)
|
|
||||||
try:
|
try:
|
||||||
return np.array(res.data[0].embedding), self.total_token_count(res)
|
return np.array(res.data[0].embedding), self.total_token_count(res)
|
||||||
except Exception as _e:
|
except Exception as _e:
|
||||||
@ -265,18 +262,17 @@ class ZhipuEmbed(Base):
|
|||||||
|
|
||||||
|
|
||||||
class OllamaEmbed(Base):
|
class OllamaEmbed(Base):
|
||||||
|
_FACTORY_NAME = "Ollama"
|
||||||
|
|
||||||
def __init__(self, key, model_name, **kwargs):
|
def __init__(self, key, model_name, **kwargs):
|
||||||
self.client = Client(host=kwargs["base_url"]) if not key or key == "x" else \
|
self.client = Client(host=kwargs["base_url"]) if not key or key == "x" else Client(host=kwargs["base_url"], headers={"Authorization": f"Bear {key}"})
|
||||||
Client(host=kwargs["base_url"], headers={"Authorization": f"Bear {key}"})
|
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
|
|
||||||
def encode(self, texts: list):
|
def encode(self, texts: list):
|
||||||
arr = []
|
arr = []
|
||||||
tks_num = 0
|
tks_num = 0
|
||||||
for txt in texts:
|
for txt in texts:
|
||||||
res = self.client.embeddings(prompt=txt,
|
res = self.client.embeddings(prompt=txt, model=self.model_name, options={"use_mmap": True})
|
||||||
model=self.model_name,
|
|
||||||
options={"use_mmap": True})
|
|
||||||
try:
|
try:
|
||||||
arr.append(res["embedding"])
|
arr.append(res["embedding"])
|
||||||
except Exception as _e:
|
except Exception as _e:
|
||||||
@ -285,9 +281,7 @@ class OllamaEmbed(Base):
|
|||||||
return np.array(arr), tks_num
|
return np.array(arr), tks_num
|
||||||
|
|
||||||
def encode_queries(self, text):
|
def encode_queries(self, text):
|
||||||
res = self.client.embeddings(prompt=text,
|
res = self.client.embeddings(prompt=text, model=self.model_name, options={"use_mmap": True})
|
||||||
model=self.model_name,
|
|
||||||
options={"use_mmap": True})
|
|
||||||
try:
|
try:
|
||||||
return np.array(res["embedding"]), 128
|
return np.array(res["embedding"]), 128
|
||||||
except Exception as _e:
|
except Exception as _e:
|
||||||
@ -295,6 +289,7 @@ class OllamaEmbed(Base):
|
|||||||
|
|
||||||
|
|
||||||
class FastEmbed(DefaultEmbedding):
|
class FastEmbed(DefaultEmbedding):
|
||||||
|
_FACTORY_NAME = "FastEmbed"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -307,15 +302,15 @@ class FastEmbed(DefaultEmbedding):
|
|||||||
if not settings.LIGHTEN:
|
if not settings.LIGHTEN:
|
||||||
with FastEmbed._model_lock:
|
with FastEmbed._model_lock:
|
||||||
from fastembed import TextEmbedding
|
from fastembed import TextEmbedding
|
||||||
|
|
||||||
if not DefaultEmbedding._model or model_name != DefaultEmbedding._model_name:
|
if not DefaultEmbedding._model or model_name != DefaultEmbedding._model_name:
|
||||||
try:
|
try:
|
||||||
DefaultEmbedding._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
|
DefaultEmbedding._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
|
||||||
DefaultEmbedding._model_name = model_name
|
DefaultEmbedding._model_name = model_name
|
||||||
except Exception:
|
except Exception:
|
||||||
cache_dir = snapshot_download(repo_id="BAAI/bge-small-en-v1.5",
|
cache_dir = snapshot_download(
|
||||||
local_dir=os.path.join(get_home_cache_dir(),
|
repo_id="BAAI/bge-small-en-v1.5", local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), local_dir_use_symlinks=False
|
||||||
re.sub(r"^[a-zA-Z0-9]+/", "", model_name)),
|
)
|
||||||
local_dir_use_symlinks=False)
|
|
||||||
DefaultEmbedding._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
|
DefaultEmbedding._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
|
||||||
self._model = DefaultEmbedding._model
|
self._model = DefaultEmbedding._model
|
||||||
self._model_name = model_name
|
self._model_name = model_name
|
||||||
@ -340,6 +335,8 @@ class FastEmbed(DefaultEmbedding):
|
|||||||
|
|
||||||
|
|
||||||
class XinferenceEmbed(Base):
|
class XinferenceEmbed(Base):
|
||||||
|
_FACTORY_NAME = "Xinference"
|
||||||
|
|
||||||
def __init__(self, key, model_name="", base_url=""):
|
def __init__(self, key, model_name="", base_url=""):
|
||||||
base_url = urljoin(base_url, "v1")
|
base_url = urljoin(base_url, "v1")
|
||||||
self.client = OpenAI(api_key=key, base_url=base_url)
|
self.client = OpenAI(api_key=key, base_url=base_url)
|
||||||
@ -350,7 +347,7 @@ class XinferenceEmbed(Base):
|
|||||||
ress = []
|
ress = []
|
||||||
total_tokens = 0
|
total_tokens = 0
|
||||||
for i in range(0, len(texts), batch_size):
|
for i in range(0, len(texts), batch_size):
|
||||||
res = self.client.embeddings.create(input=texts[i:i + batch_size], model=self.model_name)
|
res = self.client.embeddings.create(input=texts[i : i + batch_size], model=self.model_name)
|
||||||
try:
|
try:
|
||||||
ress.extend([d.embedding for d in res.data])
|
ress.extend([d.embedding for d in res.data])
|
||||||
total_tokens += self.total_token_count(res)
|
total_tokens += self.total_token_count(res)
|
||||||
@ -359,8 +356,7 @@ class XinferenceEmbed(Base):
|
|||||||
return np.array(ress), total_tokens
|
return np.array(ress), total_tokens
|
||||||
|
|
||||||
def encode_queries(self, text):
|
def encode_queries(self, text):
|
||||||
res = self.client.embeddings.create(input=[text],
|
res = self.client.embeddings.create(input=[text], model=self.model_name)
|
||||||
model=self.model_name)
|
|
||||||
try:
|
try:
|
||||||
return np.array(res.data[0].embedding), self.total_token_count(res)
|
return np.array(res.data[0].embedding), self.total_token_count(res)
|
||||||
except Exception as _e:
|
except Exception as _e:
|
||||||
@ -368,20 +364,18 @@ class XinferenceEmbed(Base):
|
|||||||
|
|
||||||
|
|
||||||
class YoudaoEmbed(Base):
|
class YoudaoEmbed(Base):
|
||||||
|
_FACTORY_NAME = "Youdao"
|
||||||
_client = None
|
_client = None
|
||||||
|
|
||||||
def __init__(self, key=None, model_name="maidalun1020/bce-embedding-base_v1", **kwargs):
|
def __init__(self, key=None, model_name="maidalun1020/bce-embedding-base_v1", **kwargs):
|
||||||
if not settings.LIGHTEN and not YoudaoEmbed._client:
|
if not settings.LIGHTEN and not YoudaoEmbed._client:
|
||||||
from BCEmbedding import EmbeddingModel as qanthing
|
from BCEmbedding import EmbeddingModel as qanthing
|
||||||
|
|
||||||
try:
|
try:
|
||||||
logging.info("LOADING BCE...")
|
logging.info("LOADING BCE...")
|
||||||
YoudaoEmbed._client = qanthing(model_name_or_path=os.path.join(
|
YoudaoEmbed._client = qanthing(model_name_or_path=os.path.join(get_home_cache_dir(), "bce-embedding-base_v1"))
|
||||||
get_home_cache_dir(),
|
|
||||||
"bce-embedding-base_v1"))
|
|
||||||
except Exception:
|
except Exception:
|
||||||
YoudaoEmbed._client = qanthing(
|
YoudaoEmbed._client = qanthing(model_name_or_path=model_name.replace("maidalun1020", "InfiniFlow"))
|
||||||
model_name_or_path=model_name.replace(
|
|
||||||
"maidalun1020", "InfiniFlow"))
|
|
||||||
|
|
||||||
def encode(self, texts: list):
|
def encode(self, texts: list):
|
||||||
batch_size = 10
|
batch_size = 10
|
||||||
@ -390,7 +384,7 @@ class YoudaoEmbed(Base):
|
|||||||
for t in texts:
|
for t in texts:
|
||||||
token_count += num_tokens_from_string(t)
|
token_count += num_tokens_from_string(t)
|
||||||
for i in range(0, len(texts), batch_size):
|
for i in range(0, len(texts), batch_size):
|
||||||
embds = YoudaoEmbed._client.encode(texts[i:i + batch_size])
|
embds = YoudaoEmbed._client.encode(texts[i : i + batch_size])
|
||||||
res.extend(embds)
|
res.extend(embds)
|
||||||
return np.array(res), token_count
|
return np.array(res), token_count
|
||||||
|
|
||||||
@ -400,14 +394,11 @@ class YoudaoEmbed(Base):
|
|||||||
|
|
||||||
|
|
||||||
class JinaEmbed(Base):
|
class JinaEmbed(Base):
|
||||||
def __init__(self, key, model_name="jina-embeddings-v3",
|
_FACTORY_NAME = "Jina"
|
||||||
base_url="https://api.jina.ai/v1/embeddings"):
|
|
||||||
|
|
||||||
|
def __init__(self, key, model_name="jina-embeddings-v3", base_url="https://api.jina.ai/v1/embeddings"):
|
||||||
self.base_url = "https://api.jina.ai/v1/embeddings"
|
self.base_url = "https://api.jina.ai/v1/embeddings"
|
||||||
self.headers = {
|
self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
|
||||||
"Content-Type": "application/json",
|
|
||||||
"Authorization": f"Bearer {key}"
|
|
||||||
}
|
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
|
|
||||||
def encode(self, texts: list):
|
def encode(self, texts: list):
|
||||||
@ -416,11 +407,7 @@ class JinaEmbed(Base):
|
|||||||
ress = []
|
ress = []
|
||||||
token_count = 0
|
token_count = 0
|
||||||
for i in range(0, len(texts), batch_size):
|
for i in range(0, len(texts), batch_size):
|
||||||
data = {
|
data = {"model": self.model_name, "input": texts[i : i + batch_size], "encoding_type": "float"}
|
||||||
"model": self.model_name,
|
|
||||||
"input": texts[i:i + batch_size],
|
|
||||||
'encoding_type': 'float'
|
|
||||||
}
|
|
||||||
response = requests.post(self.base_url, headers=self.headers, json=data)
|
response = requests.post(self.base_url, headers=self.headers, json=data)
|
||||||
try:
|
try:
|
||||||
res = response.json()
|
res = response.json()
|
||||||
@ -435,50 +422,12 @@ class JinaEmbed(Base):
|
|||||||
return np.array(embds[0]), cnt
|
return np.array(embds[0]), cnt
|
||||||
|
|
||||||
|
|
||||||
class InfinityEmbed(Base):
|
|
||||||
_model = None
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
model_names: list[str] = ("BAAI/bge-small-en-v1.5",),
|
|
||||||
engine_kwargs: dict = {},
|
|
||||||
key = None,
|
|
||||||
):
|
|
||||||
|
|
||||||
from infinity_emb import EngineArgs
|
|
||||||
from infinity_emb.engine import AsyncEngineArray
|
|
||||||
|
|
||||||
self._default_model = model_names[0]
|
|
||||||
self.engine_array = AsyncEngineArray.from_args([EngineArgs(model_name_or_path = model_name, **engine_kwargs) for model_name in model_names])
|
|
||||||
|
|
||||||
async def _embed(self, sentences: list[str], model_name: str = ""):
|
|
||||||
if not model_name:
|
|
||||||
model_name = self._default_model
|
|
||||||
engine = self.engine_array[model_name]
|
|
||||||
was_already_running = engine.is_running
|
|
||||||
if not was_already_running:
|
|
||||||
await engine.astart()
|
|
||||||
embeddings, usage = await engine.embed(sentences=sentences)
|
|
||||||
if not was_already_running:
|
|
||||||
await engine.astop()
|
|
||||||
return embeddings, usage
|
|
||||||
|
|
||||||
def encode(self, texts: list[str], model_name: str = "") -> tuple[np.ndarray, int]:
|
|
||||||
# Using the internal tokenizer to encode the texts and get the total
|
|
||||||
# number of tokens
|
|
||||||
embeddings, usage = asyncio.run(self._embed(texts, model_name))
|
|
||||||
return np.array(embeddings), usage
|
|
||||||
|
|
||||||
def encode_queries(self, text: str) -> tuple[np.ndarray, int]:
|
|
||||||
# Using the internal tokenizer to encode the texts and get the total
|
|
||||||
# number of tokens
|
|
||||||
return self.encode([text])
|
|
||||||
|
|
||||||
|
|
||||||
class MistralEmbed(Base):
|
class MistralEmbed(Base):
|
||||||
def __init__(self, key, model_name="mistral-embed",
|
_FACTORY_NAME = "Mistral"
|
||||||
base_url=None):
|
|
||||||
|
def __init__(self, key, model_name="mistral-embed", base_url=None):
|
||||||
from mistralai.client import MistralClient
|
from mistralai.client import MistralClient
|
||||||
|
|
||||||
self.client = MistralClient(api_key=key)
|
self.client = MistralClient(api_key=key)
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
|
|
||||||
@ -488,8 +437,7 @@ class MistralEmbed(Base):
|
|||||||
ress = []
|
ress = []
|
||||||
token_count = 0
|
token_count = 0
|
||||||
for i in range(0, len(texts), batch_size):
|
for i in range(0, len(texts), batch_size):
|
||||||
res = self.client.embeddings(input=texts[i:i + batch_size],
|
res = self.client.embeddings(input=texts[i : i + batch_size], model=self.model_name)
|
||||||
model=self.model_name)
|
|
||||||
try:
|
try:
|
||||||
ress.extend([d.embedding for d in res.data])
|
ress.extend([d.embedding for d in res.data])
|
||||||
token_count += self.total_token_count(res)
|
token_count += self.total_token_count(res)
|
||||||
@ -498,8 +446,7 @@ class MistralEmbed(Base):
|
|||||||
return np.array(ress), token_count
|
return np.array(ress), token_count
|
||||||
|
|
||||||
def encode_queries(self, text):
|
def encode_queries(self, text):
|
||||||
res = self.client.embeddings(input=[truncate(text, 8196)],
|
res = self.client.embeddings(input=[truncate(text, 8196)], model=self.model_name)
|
||||||
model=self.model_name)
|
|
||||||
try:
|
try:
|
||||||
return np.array(res.data[0].embedding), self.total_token_count(res)
|
return np.array(res.data[0].embedding), self.total_token_count(res)
|
||||||
except Exception as _e:
|
except Exception as _e:
|
||||||
@ -507,30 +454,31 @@ class MistralEmbed(Base):
|
|||||||
|
|
||||||
|
|
||||||
class BedrockEmbed(Base):
|
class BedrockEmbed(Base):
|
||||||
def __init__(self, key, model_name,
|
_FACTORY_NAME = "Bedrock"
|
||||||
**kwargs):
|
|
||||||
|
def __init__(self, key, model_name, **kwargs):
|
||||||
import boto3
|
import boto3
|
||||||
self.bedrock_ak = json.loads(key).get('bedrock_ak', '')
|
|
||||||
self.bedrock_sk = json.loads(key).get('bedrock_sk', '')
|
self.bedrock_ak = json.loads(key).get("bedrock_ak", "")
|
||||||
self.bedrock_region = json.loads(key).get('bedrock_region', '')
|
self.bedrock_sk = json.loads(key).get("bedrock_sk", "")
|
||||||
|
self.bedrock_region = json.loads(key).get("bedrock_region", "")
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
|
|
||||||
if self.bedrock_ak == '' or self.bedrock_sk == '' or self.bedrock_region == '':
|
if self.bedrock_ak == "" or self.bedrock_sk == "" or self.bedrock_region == "":
|
||||||
# Try to create a client using the default credentials (AWS_PROFILE, AWS_DEFAULT_REGION, etc.)
|
# Try to create a client using the default credentials (AWS_PROFILE, AWS_DEFAULT_REGION, etc.)
|
||||||
self.client = boto3.client('bedrock-runtime')
|
self.client = boto3.client("bedrock-runtime")
|
||||||
else:
|
else:
|
||||||
self.client = boto3.client(service_name='bedrock-runtime', region_name=self.bedrock_region,
|
self.client = boto3.client(service_name="bedrock-runtime", region_name=self.bedrock_region, aws_access_key_id=self.bedrock_ak, aws_secret_access_key=self.bedrock_sk)
|
||||||
aws_access_key_id=self.bedrock_ak, aws_secret_access_key=self.bedrock_sk)
|
|
||||||
|
|
||||||
def encode(self, texts: list):
|
def encode(self, texts: list):
|
||||||
texts = [truncate(t, 8196) for t in texts]
|
texts = [truncate(t, 8196) for t in texts]
|
||||||
embeddings = []
|
embeddings = []
|
||||||
token_count = 0
|
token_count = 0
|
||||||
for text in texts:
|
for text in texts:
|
||||||
if self.model_name.split('.')[0] == 'amazon':
|
if self.model_name.split(".")[0] == "amazon":
|
||||||
body = {"inputText": text}
|
body = {"inputText": text}
|
||||||
elif self.model_name.split('.')[0] == 'cohere':
|
elif self.model_name.split(".")[0] == "cohere":
|
||||||
body = {"texts": [text], "input_type": 'search_document'}
|
body = {"texts": [text], "input_type": "search_document"}
|
||||||
|
|
||||||
response = self.client.invoke_model(modelId=self.model_name, body=json.dumps(body))
|
response = self.client.invoke_model(modelId=self.model_name, body=json.dumps(body))
|
||||||
try:
|
try:
|
||||||
@ -545,10 +493,10 @@ class BedrockEmbed(Base):
|
|||||||
def encode_queries(self, text):
|
def encode_queries(self, text):
|
||||||
embeddings = []
|
embeddings = []
|
||||||
token_count = num_tokens_from_string(text)
|
token_count = num_tokens_from_string(text)
|
||||||
if self.model_name.split('.')[0] == 'amazon':
|
if self.model_name.split(".")[0] == "amazon":
|
||||||
body = {"inputText": truncate(text, 8196)}
|
body = {"inputText": truncate(text, 8196)}
|
||||||
elif self.model_name.split('.')[0] == 'cohere':
|
elif self.model_name.split(".")[0] == "cohere":
|
||||||
body = {"texts": [truncate(text, 8196)], "input_type": 'search_query'}
|
body = {"texts": [truncate(text, 8196)], "input_type": "search_query"}
|
||||||
|
|
||||||
response = self.client.invoke_model(modelId=self.model_name, body=json.dumps(body))
|
response = self.client.invoke_model(modelId=self.model_name, body=json.dumps(body))
|
||||||
try:
|
try:
|
||||||
@ -561,10 +509,11 @@ class BedrockEmbed(Base):
|
|||||||
|
|
||||||
|
|
||||||
class GeminiEmbed(Base):
|
class GeminiEmbed(Base):
|
||||||
def __init__(self, key, model_name='models/text-embedding-004',
|
_FACTORY_NAME = "Gemini"
|
||||||
**kwargs):
|
|
||||||
|
def __init__(self, key, model_name="models/text-embedding-004", **kwargs):
|
||||||
self.key = key
|
self.key = key
|
||||||
self.model_name = 'models/' + model_name
|
self.model_name = "models/" + model_name
|
||||||
|
|
||||||
def encode(self, texts: list):
|
def encode(self, texts: list):
|
||||||
texts = [truncate(t, 2048) for t in texts]
|
texts = [truncate(t, 2048) for t in texts]
|
||||||
@ -573,35 +522,27 @@ class GeminiEmbed(Base):
|
|||||||
batch_size = 16
|
batch_size = 16
|
||||||
ress = []
|
ress = []
|
||||||
for i in range(0, len(texts), batch_size):
|
for i in range(0, len(texts), batch_size):
|
||||||
result = genai.embed_content(
|
result = genai.embed_content(model=self.model_name, content=texts[i : i + batch_size], task_type="retrieval_document", title="Embedding of single string")
|
||||||
model=self.model_name,
|
|
||||||
content=texts[i: i + batch_size],
|
|
||||||
task_type="retrieval_document",
|
|
||||||
title="Embedding of single string")
|
|
||||||
try:
|
try:
|
||||||
ress.extend(result['embedding'])
|
ress.extend(result["embedding"])
|
||||||
except Exception as _e:
|
except Exception as _e:
|
||||||
log_exception(_e, result)
|
log_exception(_e, result)
|
||||||
return np.array(ress),token_count
|
return np.array(ress), token_count
|
||||||
|
|
||||||
def encode_queries(self, text):
|
def encode_queries(self, text):
|
||||||
genai.configure(api_key=self.key)
|
genai.configure(api_key=self.key)
|
||||||
result = genai.embed_content(
|
result = genai.embed_content(model=self.model_name, content=truncate(text, 2048), task_type="retrieval_document", title="Embedding of single string")
|
||||||
model=self.model_name,
|
|
||||||
content=truncate(text,2048),
|
|
||||||
task_type="retrieval_document",
|
|
||||||
title="Embedding of single string")
|
|
||||||
token_count = num_tokens_from_string(text)
|
token_count = num_tokens_from_string(text)
|
||||||
try:
|
try:
|
||||||
return np.array(result['embedding']), token_count
|
return np.array(result["embedding"]), token_count
|
||||||
except Exception as _e:
|
except Exception as _e:
|
||||||
log_exception(_e, result)
|
log_exception(_e, result)
|
||||||
|
|
||||||
|
|
||||||
class NvidiaEmbed(Base):
|
class NvidiaEmbed(Base):
|
||||||
def __init__(
|
_FACTORY_NAME = "NVIDIA"
|
||||||
self, key, model_name, base_url="https://integrate.api.nvidia.com/v1/embeddings"
|
|
||||||
):
|
def __init__(self, key, model_name, base_url="https://integrate.api.nvidia.com/v1/embeddings"):
|
||||||
if not base_url:
|
if not base_url:
|
||||||
base_url = "https://integrate.api.nvidia.com/v1/embeddings"
|
base_url = "https://integrate.api.nvidia.com/v1/embeddings"
|
||||||
self.api_key = key
|
self.api_key = key
|
||||||
@ -645,6 +586,8 @@ class NvidiaEmbed(Base):
|
|||||||
|
|
||||||
|
|
||||||
class LmStudioEmbed(LocalAIEmbed):
|
class LmStudioEmbed(LocalAIEmbed):
|
||||||
|
_FACTORY_NAME = "LM-Studio"
|
||||||
|
|
||||||
def __init__(self, key, model_name, base_url):
|
def __init__(self, key, model_name, base_url):
|
||||||
if not base_url:
|
if not base_url:
|
||||||
raise ValueError("Local llm url cannot be None")
|
raise ValueError("Local llm url cannot be None")
|
||||||
@ -654,6 +597,8 @@ class LmStudioEmbed(LocalAIEmbed):
|
|||||||
|
|
||||||
|
|
||||||
class OpenAI_APIEmbed(OpenAIEmbed):
|
class OpenAI_APIEmbed(OpenAIEmbed):
|
||||||
|
_FACTORY_NAME = ["VLLM", "OpenAI-API-Compatible"]
|
||||||
|
|
||||||
def __init__(self, key, model_name, base_url):
|
def __init__(self, key, model_name, base_url):
|
||||||
if not base_url:
|
if not base_url:
|
||||||
raise ValueError("url cannot be None")
|
raise ValueError("url cannot be None")
|
||||||
@ -663,6 +608,8 @@ class OpenAI_APIEmbed(OpenAIEmbed):
|
|||||||
|
|
||||||
|
|
||||||
class CoHereEmbed(Base):
|
class CoHereEmbed(Base):
|
||||||
|
_FACTORY_NAME = "Cohere"
|
||||||
|
|
||||||
def __init__(self, key, model_name, base_url=None):
|
def __init__(self, key, model_name, base_url=None):
|
||||||
from cohere import Client
|
from cohere import Client
|
||||||
|
|
||||||
@ -701,6 +648,8 @@ class CoHereEmbed(Base):
|
|||||||
|
|
||||||
|
|
||||||
class TogetherAIEmbed(OpenAIEmbed):
    _FACTORY_NAME = "TogetherAI"

    def __init__(self, key, model_name, base_url="https://api.together.xyz/v1"):
        if not base_url:
            base_url = "https://api.together.xyz/v1"
@@ -708,6 +657,8 @@ class TogetherAIEmbed(OpenAIEmbed):

class PerfXCloudEmbed(OpenAIEmbed):
    _FACTORY_NAME = "PerfXCloud"

    def __init__(self, key, model_name, base_url="https://cloud.perfxlab.cn/v1"):
        if not base_url:
            base_url = "https://cloud.perfxlab.cn/v1"
@@ -715,6 +666,8 @@ class PerfXCloudEmbed(OpenAIEmbed):

class UpstageEmbed(OpenAIEmbed):
    _FACTORY_NAME = "Upstage"

    def __init__(self, key, model_name, base_url="https://api.upstage.ai/v1/solar"):
        if not base_url:
            base_url = "https://api.upstage.ai/v1/solar"
@@ -722,6 +675,8 @@ class UpstageEmbed(OpenAIEmbed):

class SILICONFLOWEmbed(Base):
    _FACTORY_NAME = "SILICONFLOW"

    def __init__(self, key, model_name, base_url="https://api.siliconflow.cn/v1/embeddings"):
        if not base_url:
            base_url = "https://api.siliconflow.cn/v1/embeddings"
@@ -769,6 +724,8 @@ class SILICONFLOWEmbed(Base):

class ReplicateEmbed(Base):
    _FACTORY_NAME = "Replicate"

    def __init__(self, key, model_name, base_url=None):
        from replicate.client import Client

@@ -790,6 +747,8 @@ class ReplicateEmbed(Base):

class BaiduYiyanEmbed(Base):
    _FACTORY_NAME = "BaiduYiyan"

    def __init__(self, key, model_name, base_url=None):
        import qianfan

@@ -821,6 +780,8 @@ class BaiduYiyanEmbed(Base):

class VoyageEmbed(Base):
    _FACTORY_NAME = "Voyage AI"

    def __init__(self, key, model_name, base_url=None):
        import voyageai

@@ -832,9 +793,7 @@ class VoyageEmbed(Base):
        ress = []
        token_count = 0
        for i in range(0, len(texts), batch_size):
            res = self.client.embed(texts=texts[i : i + batch_size], model=self.model_name, input_type="document")
            try:
                ress.extend(res.embeddings)
                token_count += res.total_tokens
@@ -843,9 +802,7 @@ class VoyageEmbed(Base):
        return np.array(ress), token_count

    def encode_queries(self, text):
        res = self.client.embed(texts=text, model=self.model_name, input_type="query")
        try:
            return np.array(res.embeddings)[0], res.total_tokens
        except Exception as _e:
@@ -853,6 +810,8 @@ class VoyageEmbed(Base):

class HuggingFaceEmbed(Base):
    _FACTORY_NAME = "HuggingFace"

    def __init__(self, key, model_name, base_url=None):
        if not model_name:
            raise ValueError("Model name cannot be None")
@@ -863,11 +822,7 @@ class HuggingFaceEmbed(Base):
    def encode(self, texts: list):
        embeddings = []
        for text in texts:
            response = requests.post(f"{self.base_url}/embed", json={"inputs": text}, headers={"Content-Type": "application/json"})
            if response.status_code == 200:
                embedding = response.json()
                embeddings.append(embedding[0])
@@ -876,11 +831,7 @@ class HuggingFaceEmbed(Base):
        return np.array(embeddings), sum([num_tokens_from_string(text) for text in texts])

    def encode_queries(self, text):
        response = requests.post(f"{self.base_url}/embed", json={"inputs": text}, headers={"Content-Type": "application/json"})
        if response.status_code == 200:
            embedding = response.json()
            return np.array(embedding[0]), num_tokens_from_string(text)
@@ -889,15 +840,19 @@ class HuggingFaceEmbed(Base):

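The request shape here matches a text-embeddings-inference style server (POST /embed with {"inputs": ...}). A usage sketch under that assumption, with made-up model and URL, and assuming the constructor tail elided by the hunk stores base_url:

# Usage sketch (TEI-style server; values are illustrative, not from this PR):
embedder = HuggingFaceEmbed(key=None, model_name="BAAI/bge-small-en-v1.5", base_url="http://localhost:8080")
vectors, token_count = embedder.encode(["hello", "world"])  # -> (2, dim) ndarray plus token tally
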
class VolcEngineEmbed(OpenAIEmbed):
    _FACTORY_NAME = "VolcEngine"

    def __init__(self, key, model_name, base_url="https://ark.cn-beijing.volces.com/api/v3"):
        if not base_url:
            base_url = "https://ark.cn-beijing.volces.com/api/v3"
        ark_api_key = json.loads(key).get("ark_api_key", "")
        model_name = json.loads(key).get("ep_id", "") + json.loads(key).get("endpoint_id", "")
        super().__init__(ark_api_key, model_name, base_url)

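VolcEngine is the one embedder whose `key` is itself a JSON blob: the constructor unpacks the real API key and stitches the endpoint id into the model name before delegating to the OpenAI-compatible base. A hedged illustration of the expected shape (field values are made up):

# Illustrative only: the credential JSON carries the Ark API key plus an
# endpoint id under "ep_id" or "endpoint_id" (whichever is present).
key = '{"ark_api_key": "ak-xxxx", "endpoint_id": "ep-2024xxxx"}'
embedder = VolcEngineEmbed(key, model_name="ignored-and-overwritten", base_url=None)
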
class GPUStackEmbed(OpenAIEmbed):
    _FACTORY_NAME = "GPUStack"

    def __init__(self, key, model_name, base_url):
        if not base_url:
            raise ValueError("url cannot be None")
@@ -908,6 +863,8 @@ class GPUStackEmbed(OpenAIEmbed):

class NovitaEmbed(SILICONFLOWEmbed):
    _FACTORY_NAME = "NovitaAI"

    def __init__(self, key, model_name, base_url="https://api.novita.ai/v3/openai/embeddings"):
        if not base_url:
            base_url = "https://api.novita.ai/v3/openai/embeddings"
@@ -915,6 +872,8 @@ class NovitaEmbed(SILICONFLOWEmbed):

class GiteeEmbed(SILICONFLOWEmbed):
    _FACTORY_NAME = "GiteeAI"

    def __init__(self, key, model_name, base_url="https://ai.gitee.com/v1/embeddings"):
        if not base_url:
            base_url = "https://ai.gitee.com/v1/embeddings"

@@ -13,24 +13,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import os
import re
import threading
from abc import ABC
from collections.abc import Iterable
from urllib.parse import urljoin

import httpx
import numpy as np
import requests
from huggingface_hub import snapshot_download
from yarl import URL

from api import settings
from api.utils.file_utils import get_home_cache_dir
from api.utils.log_utils import log_exception
from rag.utils import num_tokens_from_string, truncate


def sigmoid(x):
@@ -57,6 +57,7 @@ class Base(ABC):

class DefaultRerank(Base):
    _FACTORY_NAME = "BAAI"
    _model = None
    _model_lock = threading.Lock()

@@ -75,17 +76,13 @@ class DefaultRerank(Base):
        if not settings.LIGHTEN and not DefaultRerank._model:
            import torch
            from FlagEmbedding import FlagReranker

            with DefaultRerank._model_lock:
                if not DefaultRerank._model:
                    try:
                        DefaultRerank._model = FlagReranker(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), use_fp16=torch.cuda.is_available())
                    except Exception:
                        model_dir = snapshot_download(repo_id=model_name, local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), local_dir_use_symlinks=False)
                        DefaultRerank._model = FlagReranker(model_dir, use_fp16=torch.cuda.is_available())
        self._model = DefaultRerank._model
        self._dynamic_batch_size = 8
@@ -94,6 +91,7 @@ class DefaultRerank(Base):
    def torch_empty_cache(self):
        try:
            import torch

            torch.cuda.empty_cache()
        except Exception as e:
            print(f"Error emptying cache: {e}")
@@ -112,7 +110,7 @@ class DefaultRerank(Base):
        while retry_count < max_retries:
            try:
                # call subclass implemented batch processing calculation
                batch_scores = self._compute_batch_scores(pairs[i : i + current_batch])
                res.extend(batch_scores)
                i += current_batch
                self._dynamic_batch_size = min(self._dynamic_batch_size * 2, 8)
@@ -152,23 +150,16 @@ class DefaultRerank(Base):

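Only the success path of the batching loop survives the hunk above: each good batch doubles the dynamic batch size back up, capped at 8. A standalone sketch of the full pattern, assuming the elided failure branch halves the batch (the usual counterpart to doubling on success):

def score_in_adaptive_batches(pairs, compute_batch_scores, start_batch=8):
    # Grow the batch after every success (capped at 8); on failure, e.g. a
    # CUDA out-of-memory RuntimeError, shrink it and retry the same offset.
    res, i, batch = [], 0, start_batch
    while i < len(pairs):
        try:
            res.extend(compute_batch_scores(pairs[i : i + batch]))
            i += batch
            batch = min(batch * 2, 8)
        except RuntimeError:
            if batch == 1:
                raise
            batch = max(batch // 2, 1)
    return res
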
class JinaRerank(Base):
    _FACTORY_NAME = "Jina"

    def __init__(self, key, model_name="jina-reranker-v2-base-multilingual", base_url="https://api.jina.ai/v1/rerank"):
        self.base_url = "https://api.jina.ai/v1/rerank"
        self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
        self.model_name = model_name

    def similarity(self, query: str, texts: list):
        texts = [truncate(t, 8196) for t in texts]
        data = {"model": self.model_name, "query": query, "documents": texts, "top_n": len(texts)}
        res = requests.post(self.base_url, headers=self.headers, json=data).json()
        rank = np.zeros(len(texts), dtype=float)
        try:
@@ -180,22 +171,20 @@ class JinaRerank(Base):

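All rerankers share one contract: `similarity(query, texts)` fills a score array aligned with `texts`. A usage sketch with a placeholder key, assuming the tail elided by the hunk returns (rank, token_count) as the other rerankers do:

# Usage sketch (the key is a placeholder, not a real credential):
reranker = JinaRerank(key="jina_xxx")
scores, tokens = reranker.similarity("what is RAG?", ["retrieval-augmented generation", "a pasta recipe"])
# scores[0] should dominate scores[1] for this query
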
class YoudaoRerank(DefaultRerank):
    _FACTORY_NAME = "Youdao"
    _model = None
    _model_lock = threading.Lock()

    def __init__(self, key=None, model_name="maidalun1020/bce-reranker-base_v1", **kwargs):
        if not settings.LIGHTEN and not YoudaoRerank._model:
            from BCEmbedding import RerankerModel

            with YoudaoRerank._model_lock:
                if not YoudaoRerank._model:
                    try:
                        YoudaoRerank._model = RerankerModel(model_name_or_path=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)))
                    except Exception:
                        YoudaoRerank._model = RerankerModel(model_name_or_path=model_name.replace("maidalun1020", "InfiniFlow"))

        self._model = YoudaoRerank._model
        self._dynamic_batch_size = 8
@@ -212,6 +201,8 @@ class YoudaoRerank(DefaultRerank):

class XInferenceRerank(Base):
    _FACTORY_NAME = "Xinference"

    def __init__(self, key="x", model_name="", base_url=""):
        if base_url.find("/v1") == -1:
            base_url = urljoin(base_url, "/v1/rerank")
@@ -219,10 +210,7 @@ class XInferenceRerank(Base):
            base_url = urljoin(base_url, "/v1/rerank")
        self.model_name = model_name
        self.base_url = base_url
        self.headers = {"Content-Type": "application/json", "accept": "application/json"}
        if key and key != "x":
            self.headers["Authorization"] = f"Bearer {key}"

@@ -233,13 +221,7 @@ class XInferenceRerank(Base):
        token_count = 0
        for _, t in pairs:
            token_count += num_tokens_from_string(t)
        data = {"model": self.model_name, "query": query, "return_documents": "true", "return_len": "true", "documents": texts}
        res = requests.post(self.base_url, headers=self.headers, json=data).json()
        rank = np.zeros(len(texts), dtype=float)
        try:
@@ -251,15 +233,14 @@ class XInferenceRerank(Base):

class LocalAIRerank(Base):
    _FACTORY_NAME = "LocalAI"

    def __init__(self, key, model_name, base_url):
        if base_url.find("/rerank") == -1:
            self.base_url = urljoin(base_url, "/rerank")
        else:
            self.base_url = base_url
        self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
        self.model_name = model_name.split("___")[0]

    def similarity(self, query: str, texts: list):
@@ -296,16 +277,15 @@ class LocalAIRerank(Base):

class NvidiaRerank(Base):
    _FACTORY_NAME = "NVIDIA"

    def __init__(self, key, model_name, base_url="https://ai.api.nvidia.com/v1/retrieval/nvidia/"):
        if not base_url:
            base_url = "https://ai.api.nvidia.com/v1/retrieval/nvidia/"
        self.model_name = model_name

        if self.model_name == "nvidia/nv-rerankqa-mistral-4b-v3":
            self.base_url = urljoin(base_url, "nv-rerankqa-mistral-4b-v3/reranking")

        if self.model_name == "nvidia/rerank-qa-mistral-4b":
            self.base_url = urljoin(base_url, "reranking")
@@ -318,9 +298,7 @@ class NvidiaRerank(Base):
        }

    def similarity(self, query: str, texts: list):
        token_count = num_tokens_from_string(query) + sum([num_tokens_from_string(t) for t in texts])
        data = {
            "model": self.model_name,
            "query": {"text": query},
@@ -339,6 +317,8 @@ class NvidiaRerank(Base):

class LmStudioRerank(Base):
    _FACTORY_NAME = "LM-Studio"

    def __init__(self, key, model_name, base_url):
        pass

@@ -347,15 +327,14 @@ class LmStudioRerank(Base):

class OpenAI_APIRerank(Base):
    _FACTORY_NAME = "OpenAI-API-Compatible"

    def __init__(self, key, model_name, base_url):
        if base_url.find("/rerank") == -1:
            self.base_url = urljoin(base_url, "/rerank")
        else:
            self.base_url = base_url
        self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
        self.model_name = model_name.split("___")[0]

    def similarity(self, query: str, texts: list):
@@ -392,6 +371,8 @@ class OpenAI_APIRerank(Base):

class CoHereRerank(Base):
    _FACTORY_NAME = ["Cohere", "VLLM"]

    def __init__(self, key, model_name, base_url=None):
        from cohere import Client

@@ -399,9 +380,7 @@ class CoHereRerank(Base):
        self.model_name = model_name.split("___")[0]

    def similarity(self, query: str, texts: list):
        token_count = num_tokens_from_string(query) + sum([num_tokens_from_string(t) for t in texts])
        res = self.client.rerank(
            model=self.model_name,
            query=query,
@@ -419,6 +398,8 @@ class CoHereRerank(Base):

class TogetherAIRerank(Base):
    _FACTORY_NAME = "TogetherAI"

    def __init__(self, key, model_name, base_url):
        pass

@@ -427,9 +408,9 @@ class TogetherAIRerank(Base):

class SILICONFLOWRerank(Base):
    _FACTORY_NAME = "SILICONFLOW"

    def __init__(self, key, model_name, base_url="https://api.siliconflow.cn/v1/rerank"):
        if not base_url:
            base_url = "https://api.siliconflow.cn/v1/rerank"
        self.model_name = model_name
@@ -450,9 +431,7 @@ class SILICONFLOWRerank(Base):
            "max_chunks_per_doc": 1024,
            "overlap_tokens": 80,
        }
        response = requests.post(self.base_url, json=payload, headers=self.headers).json()
        rank = np.zeros(len(texts), dtype=float)
        try:
            for d in response["results"]:
@@ -466,6 +445,8 @@ class SILICONFLOWRerank(Base):

class BaiduYiyanRerank(Base):
    _FACTORY_NAME = "BaiduYiyan"

    def __init__(self, key, model_name, base_url=None):
        from qianfan.resources import Reranker

@@ -492,6 +473,8 @@ class BaiduYiyanRerank(Base):

class VoyageRerank(Base):
    _FACTORY_NAME = "Voyage AI"

    def __init__(self, key, model_name, base_url=None):
        import voyageai

@@ -502,9 +485,7 @@ class VoyageRerank(Base):
        rank = np.zeros(len(texts), dtype=float)
        if not texts:
            return rank, 0
        res = self.client.rerank(query=query, documents=texts, model=self.model_name, top_k=len(texts))
        try:
            for r in res.results:
                rank[r.index] = r.relevance_score
@@ -514,22 +495,20 @@ class VoyageRerank(Base):

class QWenRerank(Base):
    _FACTORY_NAME = "Tongyi-Qianwen"

    def __init__(self, key, model_name="gte-rerank", base_url=None, **kwargs):
        import dashscope

        self.api_key = key
        self.model_name = dashscope.TextReRank.Models.gte_rerank if model_name is None else model_name

    def similarity(self, query: str, texts: list):
        from http import HTTPStatus

        import dashscope

        resp = dashscope.TextReRank.call(api_key=self.api_key, model=self.model_name, query=query, documents=texts, top_n=len(texts), return_documents=False)
        rank = np.zeros(len(texts), dtype=float)
        if resp.status_code == HTTPStatus.OK:
            try:
@@ -543,6 +522,8 @@ class QWenRerank(Base):

class HuggingfaceRerank(DefaultRerank):
    _FACTORY_NAME = "HuggingFace"

    @staticmethod
    def post(query: str, texts: list, url="127.0.0.1"):
        exc = None
@@ -550,9 +531,9 @@ class HuggingfaceRerank(DefaultRerank):
        batch_size = 8
        for i in range(0, len(texts), batch_size):
            try:
                res = requests.post(
                    f"http://{url}/rerank", headers={"Content-Type": "application/json"}, json={"query": query, "texts": texts[i : i + batch_size], "raw_scores": False, "truncate": True}
                )

                for o in res.json():
                    scores[o["index"] + i] = o["score"]
@@ -577,9 +558,9 @@ class HuggingfaceRerank(DefaultRerank):

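The payload here targets a text-embeddings-inference style /rerank route, scoring eight texts per request and reassembling results by their returned indices. A usage sketch, assuming the tail elided by the hunk returns the filled scores array (host and port are illustrative):

# Usage sketch, not the PR's exact call site:
scores = HuggingfaceRerank.post("what is RAG?", ["doc a", "doc b"], url="localhost:6006")
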
class GPUStackRerank(Base):
    _FACTORY_NAME = "GPUStack"

    def __init__(self, key, model_name, base_url):
        if not base_url:
            raise ValueError("url cannot be None")

@@ -600,9 +581,7 @@ class GPUStackRerank(Base):
        }

        try:
            response = requests.post(self.base_url, json=payload, headers=self.headers)
            response.raise_for_status()
            response_json = response.json()

@@ -623,11 +602,12 @@ class GPUStackRerank(Base):
            )

        except httpx.HTTPStatusError as e:
            raise ValueError(f"Error calling GPUStackRerank model {self.model_name}: {e.response.status_code} - {e.response.text}")

class NovitaRerank(JinaRerank):
    _FACTORY_NAME = "NovitaAI"

    def __init__(self, key, model_name, base_url="https://api.novita.ai/v3/openai/rerank"):
        if not base_url:
            base_url = "https://api.novita.ai/v3/openai/rerank"
@@ -635,6 +615,8 @@ class NovitaRerank(JinaRerank):

class GiteeRerank(JinaRerank):
    _FACTORY_NAME = "GiteeAI"

    def __init__(self, key, model_name, base_url="https://ai.gitee.com/v1/rerank"):
        if not base_url:
            base_url = "https://ai.gitee.com/v1/rerank"

@@ -13,16 +13,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import base64
import io
import json
import os
import re
from abc import ABC

import requests
from openai import OpenAI
from openai.lib.azure import AzureOpenAI

from rag.utils import num_tokens_from_string


class Base(ABC):
@@ -30,11 +32,7 @@ class Base(ABC):
        pass

    def transcription(self, audio, **kwargs):
        transcription = self.client.audio.transcriptions.create(model=self.model_name, file=audio, response_format="text")
        return transcription.text.strip(), num_tokens_from_string(transcription.text.strip())

    def audio2base64(self, audio):
@@ -46,6 +44,8 @@ class Base(ABC):

class GPTSeq2txt(Base):
    _FACTORY_NAME = "OpenAI"

    def __init__(self, key, model_name="whisper-1", base_url="https://api.openai.com/v1"):
        if not base_url:
            base_url = "https://api.openai.com/v1"
@@ -54,31 +54,34 @@ class GPTSeq2txt(Base):

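Any subclass that sets `self.client` to an OpenAI-compatible client inherits the Whisper-style `transcription` call from Base above. A usage sketch with placeholder values, assuming the constructor tail elided by the hunk wires up such a client:

# Usage sketch (key and file name are placeholders):
stt = GPTSeq2txt(key="sk-xxx")
with open("meeting.wav", "rb") as f:
    text, token_count = stt.transcription(f)
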
class QWenSeq2txt(Base):
    _FACTORY_NAME = "Tongyi-Qianwen"

    def __init__(self, key, model_name="paraformer-realtime-8k-v1", **kwargs):
        import dashscope

        dashscope.api_key = key
        self.model_name = model_name

    def transcription(self, audio, format):
        from http import HTTPStatus

        from dashscope.audio.asr import Recognition

        recognition = Recognition(model=self.model_name, format=format, sample_rate=16000, callback=None)
        result = recognition.call(audio)

        ans = ""
        if result.status_code == HTTPStatus.OK:
            for sentence in result.get_sentence():
                ans += sentence.text.decode("utf-8") + "\n"
            return ans, num_tokens_from_string(ans)

        return "**ERROR**: " + result.message, 0

class AzureSeq2txt(Base):
    _FACTORY_NAME = "Azure-OpenAI"

    def __init__(self, key, model_name, lang="Chinese", **kwargs):
        self.client = AzureOpenAI(api_key=key, azure_endpoint=kwargs["base_url"], api_version="2024-02-01")
        self.model_name = model_name
@@ -86,43 +89,33 @@ class AzureSeq2txt(Base):

class XinferenceSeq2txt(Base):
    _FACTORY_NAME = "Xinference"

    def __init__(self, key, model_name="whisper-small", **kwargs):
        self.base_url = kwargs.get("base_url", None)
        self.model_name = model_name
        self.key = key

    def transcription(self, audio, language="zh", prompt=None, response_format="json", temperature=0.7):
        if isinstance(audio, str):
            audio_file = open(audio, "rb")
            audio_data = audio_file.read()
            audio_file_name = audio.split("/")[-1]
        else:
            audio_data = audio
            audio_file_name = "audio.wav"

        payload = {"model": self.model_name, "language": language, "prompt": prompt, "response_format": response_format, "temperature": temperature}

        files = {"file": (audio_file_name, audio_data, "audio/wav")}

        try:
            response = requests.post(f"{self.base_url}/v1/audio/transcriptions", files=files, data=payload)
            response.raise_for_status()
            result = response.json()

            if "text" in result:
                transcription_text = result["text"].strip()
                return transcription_text, num_tokens_from_string(transcription_text)
            else:
                return "**ERROR**: Failed to retrieve transcription.", 0
@@ -132,11 +125,11 @@ class XinferenceSeq2txt(Base):

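Note that this transcriber accepts either a file path or raw bytes. A usage sketch against a hypothetical local Xinference deployment (URL is illustrative):

# Usage sketch:
stt = XinferenceSeq2txt(key="x", model_name="whisper-small", base_url="http://localhost:9997")
text, token_count = stt.transcription("sample.wav", language="en")
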
class TencentCloudSeq2txt(Base):
    _FACTORY_NAME = "Tencent Cloud"

    def __init__(self, key, model_name="16k_zh", base_url="https://asr.tencentcloudapi.com"):
        from tencentcloud.asr.v20190614 import asr_client
        from tencentcloud.common import credential

        key = json.loads(key)
        sid = key.get("tencent_cloud_sid", "")
@@ -146,11 +139,12 @@ class TencentCloudSeq2txt(Base):
        self.model_name = model_name

    def transcription(self, audio, max_retries=60, retry_interval=5):
        import time

        from tencentcloud.asr.v20190614 import models
        from tencentcloud.common.exception.tencent_cloud_sdk_exception import (
            TencentCloudSDKException,
        )

        b64 = self.audio2base64(audio)
        try:
@@ -174,9 +168,7 @@ class TencentCloudSeq2txt(Base):
        while retries < max_retries:
            resp = self.client.DescribeTaskStatus(req)
            if resp.Data.StatusStr == "success":
                text = re.sub(r"\[\d+:\d+\.\d+,\d+:\d+\.\d+\]\s*", "", resp.Data.Result).strip()
                return text, num_tokens_from_string(text)
            elif resp.Data.StatusStr == "failed":
                return (
@@ -195,6 +187,8 @@ class TencentCloudSeq2txt(Base):

class GPUStackSeq2txt(Base):
    _FACTORY_NAME = "GPUStack"

    def __init__(self, key, model_name, base_url):
        if not base_url:
            raise ValueError("url cannot be None")
@@ -206,8 +200,11 @@ class GPUStackSeq2txt(Base):

class GiteeSeq2txt(Base):
    _FACTORY_NAME = "GiteeAI"

    def __init__(self, key, model_name="whisper-1", base_url="https://ai.gitee.com/v1/"):
        if not base_url:
            base_url = "https://ai.gitee.com/v1/"
        self.client = OpenAI(api_key=key, base_url=base_url)
        self.model_name = model_name

@@ -70,10 +70,12 @@ class Base(ABC):
        pass

    def normalize_text(self, text):
        return re.sub(r"(\*\*|##\d+\$\$|#)", "", text)

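The regex scrubs markup that would otherwise be read aloud: `**` bold markers, `#` heading marks, and `##N$$` tokens (which appear to be RAGFlow's inline chunk-citation syntax). For example:

import re

# Illustration of normalize_text (the doubled space follows from plain removal):
re.sub(r"(\*\*|##\d+\$\$|#)", "", "**Answer** ##0$$ is here")  # -> "Answer  is here"
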
class FishAudioTTS(Base):
    _FACTORY_NAME = "Fish Audio"

def __init__(self, key, model_name, base_url="https://api.fish.audio/v1/tts"):
|
def __init__(self, key, model_name, base_url="https://api.fish.audio/v1/tts"):
|
||||||
if not base_url:
|
if not base_url:
|
||||||
base_url = "https://api.fish.audio/v1/tts"
|
base_url = "https://api.fish.audio/v1/tts"
|
||||||
@ -96,9 +98,7 @@ class FishAudioTTS(Base):
|
|||||||
with client.stream(
|
with client.stream(
|
||||||
method="POST",
|
method="POST",
|
||||||
url=self.base_url,
|
url=self.base_url,
|
||||||
content=ormsgpack.packb(
|
content=ormsgpack.packb(request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC),
|
||||||
request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC
|
|
||||||
),
|
|
||||||
headers=self.headers,
|
headers=self.headers,
|
||||||
timeout=None,
|
timeout=None,
|
||||||
) as response:
|
) as response:
|
||||||
@ -115,6 +115,8 @@ class FishAudioTTS(Base):
|
|||||||
|
|
||||||
|
|
||||||
class QwenTTS(Base):
    _FACTORY_NAME = "Tongyi-Qianwen"

    def __init__(self, key, model_name, base_url=""):
        import dashscope

@@ -122,10 +124,11 @@ class QwenTTS(Base):
        dashscope.api_key = key

    def tts(self, text):
        from collections import deque

        from dashscope.api_entities.dashscope_response import SpeechSynthesisResponse
        from dashscope.audio.tts import ResultCallback, SpeechSynthesisResult, SpeechSynthesizer

        class Callback(ResultCallback):
            def __init__(self) -> None:
                self.dque = deque()
@@ -159,10 +162,7 @@ class QwenTTS(Base):

        text = self.normalize_text(text)
        callback = Callback()
        SpeechSynthesizer.call(model=self.model_name, text=text, callback=callback, format="mp3")
        try:
            for data in callback._run():
                yield data
@@ -173,24 +173,19 @@ class QwenTTS(Base):

class OpenAITTS(Base):
    _FACTORY_NAME = "OpenAI"

    def __init__(self, key, model_name="tts-1", base_url="https://api.openai.com/v1"):
        if not base_url:
            base_url = "https://api.openai.com/v1"
        self.api_key = key
        self.model_name = model_name
        self.base_url = base_url
        self.headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}

    def tts(self, text, voice="alloy"):
        text = self.normalize_text(text)
        payload = {"model": self.model_name, "voice": voice, "input": text}

        response = requests.post(f"{self.base_url}/audio/speech", headers=self.headers, json=payload, stream=True)

@@ -201,7 +196,8 @@ class OpenAITTS(Base):
        yield chunk

class SparkTTS(Base):
    _FACTORY_NAME = "XunFei Spark"
    STATUS_FIRST_FRAME = 0
    STATUS_CONTINUE_FRAME = 1
    STATUS_LAST_FRAME = 2
@@ -219,29 +215,23 @@ class SparkTTS:

    # build the signed websocket URL
    def create_url(self):
        url = "wss://tts-api.xfyun.cn/v2/tts"
        now = datetime.now()
        date = format_date_time(mktime(now.timetuple()))
        signature_origin = "host: " + "ws-api.xfyun.cn" + "\n"
        signature_origin += "date: " + date + "\n"
        signature_origin += "GET " + "/v2/tts " + "HTTP/1.1"
        signature_sha = hmac.new(self.APISecret.encode("utf-8"), signature_origin.encode("utf-8"), digestmod=hashlib.sha256).digest()
        signature_sha = base64.b64encode(signature_sha).decode(encoding="utf-8")
        authorization_origin = 'api_key="%s", algorithm="%s", headers="%s", signature="%s"' % (self.APIKey, "hmac-sha256", "host date request-line", signature_sha)
        authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode(encoding="utf-8")
        v = {"authorization": authorization, "date": date, "host": "ws-api.xfyun.cn"}
        url = url + "?" + urlencode(v)
        return url

    def tts(self, text):
        BusinessArgs = {"aue": "lame", "sfl": 1, "auf": "audio/L16;rate=16000", "vcn": self.model_name, "tte": "utf8"}
        Data = {"status": 2, "text": base64.b64encode(text.encode("utf-8")).decode("utf-8")}
        CommonArgs = {"app_id": self.APPID}
        audio_queue = self.audio_queue
        model_name = self.model_name
@@ -273,9 +263,7 @@ class SparkTTS:

            def on_open(self, ws):
                def run(*args):
                    d = {"common": CommonArgs, "business": BusinessArgs, "data": Data}
                    ws.send(json.dumps(d))

                thread.start_new_thread(run, ())
@@ -283,44 +271,32 @@ class SparkTTS:
        wsUrl = self.create_url()
        websocket.enableTrace(False)
        a = Callback()
        ws = websocket.WebSocketApp(wsUrl, on_open=a.on_open, on_error=a.on_error, on_close=a.on_close, on_message=a.on_message)
        status_code = 0
        ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
        while True:
            audio_chunk = self.audio_queue.get()
            if audio_chunk is None:
                if status_code == 0:
                    raise Exception(f"Fail to access model({model_name}) using the provided credentials. **ERROR**: Invalid APPID, API Secret, or API Key.")
                else:
                    break
            status_code = 1
            yield audio_chunk

class XinferenceTTS(Base):
    _FACTORY_NAME = "Xinference"

    def __init__(self, key, model_name, **kwargs):
        self.base_url = kwargs.get("base_url", None)
        self.model_name = model_name
        self.headers = {"accept": "application/json", "Content-Type": "application/json"}

    def tts(self, text, voice="中文女", stream=True):
        payload = {"model": self.model_name, "input": text, "voice": voice}

        response = requests.post(f"{self.base_url}/v1/audio/speech", headers=self.headers, json=payload, stream=stream)

        if response.status_code != 200:
            raise Exception(f"**Error**: {response.status_code}, {response.text}")

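Every tts() in this module is a generator of audio byte chunks, so callers can stream straight to disk or an HTTP response. A usage sketch against a hypothetical local Xinference server, assuming the tail elided by the hunk yields the response chunks:

# Usage sketch (URL and model are illustrative):
tts = XinferenceTTS(key="x", model_name="ChatTTS", base_url="http://localhost:9997")
with open("out.mp3", "wb") as f:
    for chunk in tts.tts("hello there", stream=True):
        f.write(chunk)
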
@@ -336,18 +312,12 @@ class OllamaTTS(Base):
            base_url = "https://api.ollama.ai/v1"
        self.model_name = model_name
        self.base_url = base_url
        self.headers = {"Content-Type": "application/json"}
        if key and key != "x":
            self.headers["Authorization"] = f"Bearer {key}"

    def tts(self, text, voice="standard-voice"):
        payload = {"model": self.model_name, "voice": voice, "input": text}

        response = requests.post(f"{self.base_url}/audio/tts", headers=self.headers, json=payload, stream=True)

@@ -359,30 +329,19 @@ class OllamaTTS(Base):
        yield chunk

class GPUStackTTS(Base):
    _FACTORY_NAME = "GPUStack"

    def __init__(self, key, model_name, **kwargs):
        self.base_url = kwargs.get("base_url", None)
        self.api_key = key
        self.model_name = model_name
        self.headers = {"accept": "application/json", "Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}

    def tts(self, text, voice="Chinese Female", stream=True):
        payload = {"model": self.model_name, "input": text, "voice": voice}

        response = requests.post(f"{self.base_url}/v1/audio/speech", headers=self.headers, json=payload, stream=stream)

        if response.status_code != 200:
            raise Exception(f"**Error**: {response.status_code}, {response.text}")
@@ -393,16 +352,15 @@ class GPUStackTTS:

class SILICONFLOWTTS(Base):
    _FACTORY_NAME = "SILICONFLOW"

    def __init__(self, key, model_name="FunAudioLLM/CosyVoice2-0.5B", base_url="https://api.siliconflow.cn/v1"):
        if not base_url:
            base_url = "https://api.siliconflow.cn/v1"
        self.api_key = key
        self.model_name = model_name
        self.base_url = base_url
        self.headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}

    def tts(self, text, voice="anna"):
        text = self.normalize_text(text)
@@ -414,7 +372,7 @@ class SILICONFLOWTTS(Base):
            "sample_rate": 123,
            "stream": True,
            "speed": 1,
            "gain": 0,
        }

        response = requests.post(f"{self.base_url}/audio/speech", headers=self.headers, json=payload)