diff --git a/rag/llm/__init__.py b/rag/llm/__init__.py
index b22d722d0..e9542bbe8 100644
--- a/rag/llm/__init__.py
+++ b/rag/llm/__init__.py
@@ -15,289 +15,53 @@
#
# AFTER UPDATING THIS FILE, PLEASE ENSURE THAT docs/references/supported_models.mdx IS ALSO UPDATED for consistency!
#
-from .embedding_model import (
- OllamaEmbed,
- LocalAIEmbed,
- OpenAIEmbed,
- AzureEmbed,
- XinferenceEmbed,
- QWenEmbed,
- ZhipuEmbed,
- FastEmbed,
- YoudaoEmbed,
- BaiChuanEmbed,
- JinaEmbed,
- DefaultEmbedding,
- MistralEmbed,
- BedrockEmbed,
- GeminiEmbed,
- NvidiaEmbed,
- LmStudioEmbed,
- OpenAI_APIEmbed,
- CoHereEmbed,
- TogetherAIEmbed,
- PerfXCloudEmbed,
- UpstageEmbed,
- SILICONFLOWEmbed,
- ReplicateEmbed,
- BaiduYiyanEmbed,
- VoyageEmbed,
- HuggingFaceEmbed,
- VolcEngineEmbed,
- GPUStackEmbed,
- NovitaEmbed,
- GiteeEmbed
-)
-from .chat_model import (
- GptTurbo,
- AzureChat,
- ZhipuChat,
- QWenChat,
- OllamaChat,
- LocalAIChat,
- XinferenceChat,
- MoonshotChat,
- DeepSeekChat,
- VolcEngineChat,
- BaiChuanChat,
- MiniMaxChat,
- MistralChat,
- GeminiChat,
- BedrockChat,
- GroqChat,
- OpenRouterChat,
- StepFunChat,
- NvidiaChat,
- LmStudioChat,
- OpenAI_APIChat,
- CoHereChat,
- LeptonAIChat,
- TogetherAIChat,
- PerfXCloudChat,
- UpstageChat,
- NovitaAIChat,
- SILICONFLOWChat,
- PPIOChat,
- YiChat,
- ReplicateChat,
- HunyuanChat,
- SparkChat,
- BaiduYiyanChat,
- AnthropicChat,
- GoogleChat,
- HuggingFaceChat,
- GPUStackChat,
- ModelScopeChat,
- GiteeChat
-)
-from .cv_model import (
- GptV4,
- AzureGptV4,
- OllamaCV,
- XinferenceCV,
- QWenCV,
- Zhipu4V,
- LocalCV,
- GeminiCV,
- OpenRouterCV,
- LocalAICV,
- NvidiaCV,
- LmStudioCV,
- StepFunCV,
- OpenAI_APICV,
- TogetherAICV,
- YiCV,
- HunyuanCV,
- AnthropicCV,
- SILICONFLOWCV,
- GPUStackCV,
- GoogleCV,
-)
+import importlib
+import inspect
-from .rerank_model import (
- LocalAIRerank,
- DefaultRerank,
- JinaRerank,
- YoudaoRerank,
- XInferenceRerank,
- NvidiaRerank,
- LmStudioRerank,
- OpenAI_APIRerank,
- CoHereRerank,
- TogetherAIRerank,
- SILICONFLOWRerank,
- BaiduYiyanRerank,
- VoyageRerank,
- QWenRerank,
- GPUStackRerank,
- HuggingfaceRerank,
- NovitaRerank,
- GiteeRerank
-)
+ChatModel = globals().get("ChatModel", {})
+CvModel = globals().get("CvModel", {})
+EmbeddingModel = globals().get("EmbeddingModel", {})
+RerankModel = globals().get("RerankModel", {})
+Seq2txtModel = globals().get("Seq2txtModel", {})
+TTSModel = globals().get("TTSModel", {})
-from .sequence2txt_model import (
- GPTSeq2txt,
- QWenSeq2txt,
- AzureSeq2txt,
- XinferenceSeq2txt,
- TencentCloudSeq2txt,
- GPUStackSeq2txt,
- GiteeSeq2txt
-)
-
-from .tts_model import (
- FishAudioTTS,
- QwenTTS,
- OpenAITTS,
- SparkTTS,
- XinferenceTTS,
- GPUStackTTS,
- SILICONFLOWTTS,
-)
-
-EmbeddingModel = {
- "Ollama": OllamaEmbed,
- "LocalAI": LocalAIEmbed,
- "OpenAI": OpenAIEmbed,
- "Azure-OpenAI": AzureEmbed,
- "Xinference": XinferenceEmbed,
- "Tongyi-Qianwen": QWenEmbed,
- "ZHIPU-AI": ZhipuEmbed,
- "FastEmbed": FastEmbed,
- "Youdao": YoudaoEmbed,
- "BaiChuan": BaiChuanEmbed,
- "Jina": JinaEmbed,
- "BAAI": DefaultEmbedding,
- "Mistral": MistralEmbed,
- "Bedrock": BedrockEmbed,
- "Gemini": GeminiEmbed,
- "NVIDIA": NvidiaEmbed,
- "LM-Studio": LmStudioEmbed,
- "OpenAI-API-Compatible": OpenAI_APIEmbed,
- "VLLM": OpenAI_APIEmbed,
- "Cohere": CoHereEmbed,
- "TogetherAI": TogetherAIEmbed,
- "PerfXCloud": PerfXCloudEmbed,
- "Upstage": UpstageEmbed,
- "SILICONFLOW": SILICONFLOWEmbed,
- "Replicate": ReplicateEmbed,
- "BaiduYiyan": BaiduYiyanEmbed,
- "Voyage AI": VoyageEmbed,
- "HuggingFace": HuggingFaceEmbed,
- "VolcEngine": VolcEngineEmbed,
- "GPUStack": GPUStackEmbed,
- "NovitaAI": NovitaEmbed,
- "GiteeAI": GiteeEmbed
+MODULE_MAPPING = {
+ "chat_model": ChatModel,
+ "cv_model": CvModel,
+ "embedding_model": EmbeddingModel,
+ "rerank_model": RerankModel,
+ "sequence2txt_model": Seq2txtModel,
+ "tts_model": TTSModel,
}
-CvModel = {
- "OpenAI": GptV4,
- "Azure-OpenAI": AzureGptV4,
- "Ollama": OllamaCV,
- "Xinference": XinferenceCV,
- "Tongyi-Qianwen": QWenCV,
- "ZHIPU-AI": Zhipu4V,
- "Moonshot": LocalCV,
- "Gemini": GeminiCV,
- "OpenRouter": OpenRouterCV,
- "LocalAI": LocalAICV,
- "NVIDIA": NvidiaCV,
- "LM-Studio": LmStudioCV,
- "StepFun": StepFunCV,
- "OpenAI-API-Compatible": OpenAI_APICV,
- "VLLM": OpenAI_APICV,
- "TogetherAI": TogetherAICV,
- "01.AI": YiCV,
- "Tencent Hunyuan": HunyuanCV,
- "Anthropic": AnthropicCV,
- "SILICONFLOW": SILICONFLOWCV,
- "GPUStack": GPUStackCV,
- "Google Cloud": GoogleCV
-}
+package_name = __name__
-ChatModel = {
- "OpenAI": GptTurbo,
- "Azure-OpenAI": AzureChat,
- "ZHIPU-AI": ZhipuChat,
- "Tongyi-Qianwen": QWenChat,
- "Ollama": OllamaChat,
- "LocalAI": LocalAIChat,
- "Xinference": XinferenceChat,
- "Moonshot": MoonshotChat,
- "DeepSeek": DeepSeekChat,
- "VolcEngine": VolcEngineChat,
- "BaiChuan": BaiChuanChat,
- "MiniMax": MiniMaxChat,
- "Mistral": MistralChat,
- "Gemini": GeminiChat,
- "Bedrock": BedrockChat,
- "Groq": GroqChat,
- "OpenRouter": OpenRouterChat,
- "StepFun": StepFunChat,
- "NVIDIA": NvidiaChat,
- "LM-Studio": LmStudioChat,
- "OpenAI-API-Compatible": OpenAI_APIChat,
- "VLLM": OpenAI_APIChat,
- "Cohere": CoHereChat,
- "LeptonAI": LeptonAIChat,
- "TogetherAI": TogetherAIChat,
- "PerfXCloud": PerfXCloudChat,
- "Upstage": UpstageChat,
- "NovitaAI": NovitaAIChat,
- "SILICONFLOW": SILICONFLOWChat,
- "PPIO": PPIOChat,
- "01.AI": YiChat,
- "Replicate": ReplicateChat,
- "Tencent Hunyuan": HunyuanChat,
- "XunFei Spark": SparkChat,
- "BaiduYiyan": BaiduYiyanChat,
- "Anthropic": AnthropicChat,
- "Google Cloud": GoogleChat,
- "HuggingFace": HuggingFaceChat,
- "GPUStack": GPUStackChat,
- "ModelScope":ModelScopeChat,
- "GiteeAI": GiteeChat
-}
+for module_name, mapping_dict in MODULE_MAPPING.items():
+ full_module_name = f"{package_name}.{module_name}"
+ module = importlib.import_module(full_module_name)
-RerankModel = {
- "LocalAI": LocalAIRerank,
- "BAAI": DefaultRerank,
- "Jina": JinaRerank,
- "Youdao": YoudaoRerank,
- "Xinference": XInferenceRerank,
- "NVIDIA": NvidiaRerank,
- "LM-Studio": LmStudioRerank,
- "OpenAI-API-Compatible": OpenAI_APIRerank,
- "VLLM": CoHereRerank,
- "Cohere": CoHereRerank,
- "TogetherAI": TogetherAIRerank,
- "SILICONFLOW": SILICONFLOWRerank,
- "BaiduYiyan": BaiduYiyanRerank,
- "Voyage AI": VoyageRerank,
- "Tongyi-Qianwen": QWenRerank,
- "GPUStack": GPUStackRerank,
- "HuggingFace": HuggingfaceRerank,
- "NovitaAI": NovitaRerank,
- "GiteeAI": GiteeRerank
-}
+ base_class = None
+ for name, obj in inspect.getmembers(module):
+ if inspect.isclass(obj) and name == "Base":
+ base_class = obj
+ break
+ if base_class is None:
+ continue
-Seq2txtModel = {
- "OpenAI": GPTSeq2txt,
- "Tongyi-Qianwen": QWenSeq2txt,
- "Azure-OpenAI": AzureSeq2txt,
- "Xinference": XinferenceSeq2txt,
- "Tencent Cloud": TencentCloudSeq2txt,
- "GPUStack": GPUStackSeq2txt,
- "GiteeAI": GiteeSeq2txt
-}
+ for _, obj in inspect.getmembers(module):
+ if inspect.isclass(obj) and issubclass(obj, base_class) and obj is not base_class and hasattr(obj, "_FACTORY_NAME"):
+ if isinstance(obj._FACTORY_NAME, list):
+ for factory_name in obj._FACTORY_NAME:
+ mapping_dict[factory_name] = obj
+ else:
+ mapping_dict[obj._FACTORY_NAME] = obj
-TTSModel = {
- "Fish Audio": FishAudioTTS,
- "Tongyi-Qianwen": QwenTTS,
- "OpenAI": OpenAITTS,
- "XunFei Spark": SparkTTS,
- "Xinference": XinferenceTTS,
- "GPUStack": GPUStackTTS,
- "SILICONFLOW": SILICONFLOWTTS,
-}
+__all__ = [
+ "ChatModel",
+ "CvModel",
+ "EmbeddingModel",
+ "RerankModel",
+ "Seq2txtModel",
+ "TTSModel",
+]
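
Note on the `__init__.py` rewrite above: the six hand-maintained provider dicts are replaced by a scan. Each `*_model` module is imported, its `Base` class is located, and every subclass of `Base` that declares a `_FACTORY_NAME` (a plain string, or a list of alias strings) is registered into the matching mapping. A minimal, self-contained sketch of the pattern, where the module layout and the `DemoChat` class are illustrative and not part of this diff:

```python
# Sketch of the _FACTORY_NAME registration scan, run over this module itself
# instead of rag/llm/*_model.py. DemoChat is a made-up example class.
import inspect
import sys
from abc import ABC


class Base(ABC):
    """Stand-in for a per-module Base class such as chat_model.Base."""

    def __init__(self, key, model_name, base_url=None, **kwargs):
        self.model_name = model_name


class DemoChat(Base):
    # A plain string registers one factory name; a list registers aliases,
    # as OpenAI_APIChat does with ["VLLM", "OpenAI-API-Compatible"].
    _FACTORY_NAME = ["Demo", "Demo-Compatible"]


ChatModel = {}

# Same membership test the new __init__.py applies to each imported module.
module = sys.modules[__name__]
for _, obj in inspect.getmembers(module):
    if inspect.isclass(obj) and issubclass(obj, Base) and obj is not Base and hasattr(obj, "_FACTORY_NAME"):
        names = obj._FACTORY_NAME if isinstance(obj._FACTORY_NAME, list) else [obj._FACTORY_NAME]
        for factory_name in names:
            ChatModel[factory_name] = obj

assert ChatModel["Demo"] is DemoChat and ChatModel["Demo-Compatible"] is DemoChat
```
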
diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py
index 020815104..f254ed186 100644
--- a/rag/llm/chat_model.py
+++ b/rag/llm/chat_model.py
@@ -142,11 +142,7 @@ class Base(ABC):
return f"{ERROR_PREFIX}: {error_code} - {str(e)}"
def _verbose_tool_use(self, name, args, res):
-        return "<tool_call>" + json.dumps({
-            "name": name,
-            "args": args,
-            "result": res
-        }, ensure_ascii=False, indent=2) + "</tool_call>"
+        return "<tool_call>" + json.dumps({"name": name, "args": args, "result": res}, ensure_ascii=False, indent=2) + "</tool_call>"
def _append_history(self, hist, tool_call, tool_res):
hist.append(
@@ -191,10 +187,10 @@ class Base(ABC):
tk_count = 0
hist = deepcopy(history)
# Implement exponential backoff retry strategy
- for attempt in range(self.max_retries+1):
+ for attempt in range(self.max_retries + 1):
history = hist
try:
- for _ in range(self.max_rounds*2):
+ for _ in range(self.max_rounds * 2):
response = self.client.chat.completions.create(model=self.model_name, messages=history, tools=self.tools, **gen_conf)
tk_count += self.total_token_count(response)
if any([not response.choices, not response.choices[0].message]):
@@ -269,7 +265,7 @@ class Base(ABC):
for attempt in range(self.max_retries + 1):
history = hist
try:
- for _ in range(self.max_rounds*2):
+ for _ in range(self.max_rounds * 2):
reasoning_start = False
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, **gen_conf)
final_tool_calls = {}
@@ -430,6 +426,8 @@ class Base(ABC):
class GptTurbo(Base):
+ _FACTORY_NAME = "OpenAI"
+
def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1", **kwargs):
if not base_url:
base_url = "https://api.openai.com/v1"
@@ -437,6 +435,8 @@ class GptTurbo(Base):
class MoonshotChat(Base):
+ _FACTORY_NAME = "Moonshot"
+
def __init__(self, key, model_name="moonshot-v1-8k", base_url="https://api.moonshot.cn/v1", **kwargs):
if not base_url:
base_url = "https://api.moonshot.cn/v1"
@@ -444,6 +444,8 @@ class MoonshotChat(Base):
class XinferenceChat(Base):
+ _FACTORY_NAME = "Xinference"
+
def __init__(self, key=None, model_name="", base_url="", **kwargs):
if not base_url:
raise ValueError("Local llm url cannot be None")
@@ -452,6 +454,8 @@ class XinferenceChat(Base):
class HuggingFaceChat(Base):
+ _FACTORY_NAME = "HuggingFace"
+
def __init__(self, key=None, model_name="", base_url="", **kwargs):
if not base_url:
raise ValueError("Local llm url cannot be None")
@@ -460,6 +464,8 @@ class HuggingFaceChat(Base):
class ModelScopeChat(Base):
+ _FACTORY_NAME = "ModelScope"
+
def __init__(self, key=None, model_name="", base_url="", **kwargs):
if not base_url:
raise ValueError("Local llm url cannot be None")
@@ -468,6 +474,8 @@ class ModelScopeChat(Base):
class DeepSeekChat(Base):
+ _FACTORY_NAME = "DeepSeek"
+
def __init__(self, key, model_name="deepseek-chat", base_url="https://api.deepseek.com/v1", **kwargs):
if not base_url:
base_url = "https://api.deepseek.com/v1"
@@ -475,6 +483,8 @@ class DeepSeekChat(Base):
class AzureChat(Base):
+ _FACTORY_NAME = "Azure-OpenAI"
+
def __init__(self, key, model_name, base_url, **kwargs):
api_key = json.loads(key).get("api_key", "")
api_version = json.loads(key).get("api_version", "2024-02-01")
@@ -484,6 +494,8 @@ class AzureChat(Base):
class BaiChuanChat(Base):
+ _FACTORY_NAME = "BaiChuan"
+
def __init__(self, key, model_name="Baichuan3-Turbo", base_url="https://api.baichuan-ai.com/v1", **kwargs):
if not base_url:
base_url = "https://api.baichuan-ai.com/v1"
@@ -557,6 +569,8 @@ class BaiChuanChat(Base):
class QWenChat(Base):
+ _FACTORY_NAME = "Tongyi-Qianwen"
+
def __init__(self, key, model_name=Generation.Models.qwen_turbo, base_url=None, **kwargs):
if not base_url:
base_url = "https://dashscope.aliyuncs.com/compatible-mode/v1"
@@ -565,6 +579,8 @@ class QWenChat(Base):
class ZhipuChat(Base):
+ _FACTORY_NAME = "ZHIPU-AI"
+
def __init__(self, key, model_name="glm-3-turbo", base_url=None, **kwargs):
super().__init__(key, model_name, base_url=base_url, **kwargs)
@@ -630,6 +646,8 @@ class ZhipuChat(Base):
class OllamaChat(Base):
+ _FACTORY_NAME = "Ollama"
+
def __init__(self, key, model_name, base_url=None, **kwargs):
super().__init__(key, model_name, base_url=base_url, **kwargs)
@@ -694,6 +712,8 @@ class OllamaChat(Base):
class LocalAIChat(Base):
+ _FACTORY_NAME = "LocalAI"
+
def __init__(self, key, model_name, base_url=None, **kwargs):
super().__init__(key, model_name, base_url=base_url, **kwargs)
@@ -752,6 +772,8 @@ class LocalLLM(Base):
class VolcEngineChat(Base):
+ _FACTORY_NAME = "VolcEngine"
+
def __init__(self, key, model_name, base_url="https://ark.cn-beijing.volces.com/api/v3", **kwargs):
"""
        Since we do not want to modify the original database fields, and the VolcEngine authentication method is quite special,
@@ -765,6 +787,8 @@ class VolcEngineChat(Base):
class MiniMaxChat(Base):
+ _FACTORY_NAME = "MiniMax"
+
def __init__(self, key, model_name, base_url="https://api.minimax.chat/v1/text/chatcompletion_v2", **kwargs):
super().__init__(key, model_name, base_url=base_url, **kwargs)
@@ -843,6 +867,8 @@ class MiniMaxChat(Base):
class MistralChat(Base):
+ _FACTORY_NAME = "Mistral"
+
def __init__(self, key, model_name, base_url=None, **kwargs):
super().__init__(key, model_name, base_url=base_url, **kwargs)
@@ -896,6 +922,8 @@ class MistralChat(Base):
class BedrockChat(Base):
+ _FACTORY_NAME = "Bedrock"
+
def __init__(self, key, model_name, base_url=None, **kwargs):
super().__init__(key, model_name, base_url=base_url, **kwargs)
@@ -978,6 +1006,8 @@ class BedrockChat(Base):
class GeminiChat(Base):
+ _FACTORY_NAME = "Gemini"
+
def __init__(self, key, model_name, base_url=None, **kwargs):
super().__init__(key, model_name, base_url=base_url, **kwargs)
@@ -997,6 +1027,7 @@ class GeminiChat(Base):
def _chat(self, history, gen_conf):
from google.generativeai.types import content_types
+
system = history[0]["content"] if history and history[0]["role"] == "system" else ""
hist = []
for item in history:
@@ -1019,6 +1050,7 @@ class GeminiChat(Base):
def chat_streamly(self, system, history, gen_conf):
from google.generativeai.types import content_types
+
gen_conf = self._clean_conf(gen_conf)
if system:
self.model._system_instruction = content_types.to_content(system)
@@ -1042,6 +1074,8 @@ class GeminiChat(Base):
class GroqChat(Base):
+ _FACTORY_NAME = "Groq"
+
def __init__(self, key, model_name, base_url=None, **kwargs):
super().__init__(key, model_name, base_url=base_url, **kwargs)
@@ -1086,6 +1120,8 @@ class GroqChat(Base):
## openrouter
class OpenRouterChat(Base):
+ _FACTORY_NAME = "OpenRouter"
+
def __init__(self, key, model_name, base_url="https://openrouter.ai/api/v1", **kwargs):
if not base_url:
base_url = "https://openrouter.ai/api/v1"
@@ -1093,6 +1129,8 @@ class OpenRouterChat(Base):
class StepFunChat(Base):
+ _FACTORY_NAME = "StepFun"
+
def __init__(self, key, model_name, base_url="https://api.stepfun.com/v1", **kwargs):
if not base_url:
base_url = "https://api.stepfun.com/v1"
@@ -1100,6 +1138,8 @@ class StepFunChat(Base):
class NvidiaChat(Base):
+ _FACTORY_NAME = "NVIDIA"
+
def __init__(self, key, model_name, base_url="https://integrate.api.nvidia.com/v1", **kwargs):
if not base_url:
base_url = "https://integrate.api.nvidia.com/v1"
@@ -1107,6 +1147,8 @@ class NvidiaChat(Base):
class LmStudioChat(Base):
+ _FACTORY_NAME = "LM-Studio"
+
def __init__(self, key, model_name, base_url, **kwargs):
if not base_url:
raise ValueError("Local llm url cannot be None")
@@ -1117,6 +1159,8 @@ class LmStudioChat(Base):
class OpenAI_APIChat(Base):
+ _FACTORY_NAME = ["VLLM", "OpenAI-API-Compatible"]
+
def __init__(self, key, model_name, base_url):
if not base_url:
raise ValueError("url cannot be None")
@@ -1125,6 +1169,8 @@ class OpenAI_APIChat(Base):
class PPIOChat(Base):
+ _FACTORY_NAME = "PPIO"
+
def __init__(self, key, model_name, base_url="https://api.ppinfra.com/v3/openai", **kwargs):
if not base_url:
base_url = "https://api.ppinfra.com/v3/openai"
@@ -1132,6 +1178,8 @@ class PPIOChat(Base):
class CoHereChat(Base):
+ _FACTORY_NAME = "Cohere"
+
def __init__(self, key, model_name, base_url=None, **kwargs):
super().__init__(key, model_name, base_url=base_url, **kwargs)
@@ -1207,6 +1255,8 @@ class CoHereChat(Base):
class LeptonAIChat(Base):
+ _FACTORY_NAME = "LeptonAI"
+
def __init__(self, key, model_name, base_url=None, **kwargs):
if not base_url:
base_url = urljoin("https://" + model_name + ".lepton.run", "api/v1")
@@ -1214,6 +1264,8 @@ class LeptonAIChat(Base):
class TogetherAIChat(Base):
+ _FACTORY_NAME = "TogetherAI"
+
def __init__(self, key, model_name, base_url="https://api.together.xyz/v1", **kwargs):
if not base_url:
base_url = "https://api.together.xyz/v1"
@@ -1221,6 +1273,8 @@ class TogetherAIChat(Base):
class PerfXCloudChat(Base):
+ _FACTORY_NAME = "PerfXCloud"
+
def __init__(self, key, model_name, base_url="https://cloud.perfxlab.cn/v1", **kwargs):
if not base_url:
base_url = "https://cloud.perfxlab.cn/v1"
@@ -1228,6 +1282,8 @@ class PerfXCloudChat(Base):
class UpstageChat(Base):
+ _FACTORY_NAME = "Upstage"
+
def __init__(self, key, model_name, base_url="https://api.upstage.ai/v1/solar", **kwargs):
if not base_url:
base_url = "https://api.upstage.ai/v1/solar"
@@ -1235,6 +1291,8 @@ class UpstageChat(Base):
class NovitaAIChat(Base):
+ _FACTORY_NAME = "NovitaAI"
+
def __init__(self, key, model_name, base_url="https://api.novita.ai/v3/openai", **kwargs):
if not base_url:
base_url = "https://api.novita.ai/v3/openai"
@@ -1242,6 +1300,8 @@ class NovitaAIChat(Base):
class SILICONFLOWChat(Base):
+ _FACTORY_NAME = "SILICONFLOW"
+
def __init__(self, key, model_name, base_url="https://api.siliconflow.cn/v1", **kwargs):
if not base_url:
base_url = "https://api.siliconflow.cn/v1"
@@ -1249,6 +1309,8 @@ class SILICONFLOWChat(Base):
class YiChat(Base):
+ _FACTORY_NAME = "01.AI"
+
def __init__(self, key, model_name, base_url="https://api.lingyiwanwu.com/v1", **kwargs):
if not base_url:
base_url = "https://api.lingyiwanwu.com/v1"
@@ -1256,6 +1318,8 @@ class YiChat(Base):
class GiteeChat(Base):
+ _FACTORY_NAME = "GiteeAI"
+
def __init__(self, key, model_name, base_url="https://ai.gitee.com/v1/", **kwargs):
if not base_url:
base_url = "https://ai.gitee.com/v1/"
@@ -1263,6 +1327,8 @@ class GiteeChat(Base):
class ReplicateChat(Base):
+ _FACTORY_NAME = "Replicate"
+
def __init__(self, key, model_name, base_url=None, **kwargs):
super().__init__(key, model_name, base_url=base_url, **kwargs)
@@ -1302,6 +1368,8 @@ class ReplicateChat(Base):
class HunyuanChat(Base):
+ _FACTORY_NAME = "Tencent Hunyuan"
+
def __init__(self, key, model_name, base_url=None, **kwargs):
super().__init__(key, model_name, base_url=base_url, **kwargs)
@@ -1378,6 +1446,8 @@ class HunyuanChat(Base):
class SparkChat(Base):
+ _FACTORY_NAME = "XunFei Spark"
+
def __init__(self, key, model_name, base_url="https://spark-api-open.xf-yun.com/v1", **kwargs):
if not base_url:
base_url = "https://spark-api-open.xf-yun.com/v1"
@@ -1398,6 +1468,8 @@ class SparkChat(Base):
class BaiduYiyanChat(Base):
+ _FACTORY_NAME = "BaiduYiyan"
+
def __init__(self, key, model_name, base_url=None, **kwargs):
super().__init__(key, model_name, base_url=base_url, **kwargs)
@@ -1444,6 +1516,8 @@ class BaiduYiyanChat(Base):
class AnthropicChat(Base):
+ _FACTORY_NAME = "Anthropic"
+
def __init__(self, key, model_name, base_url="https://api.anthropic.com/v1/", **kwargs):
if not base_url:
base_url = "https://api.anthropic.com/v1/"
@@ -1451,6 +1525,8 @@ class AnthropicChat(Base):
class GoogleChat(Base):
+ _FACTORY_NAME = "Google Cloud"
+
def __init__(self, key, model_name, base_url=None, **kwargs):
super().__init__(key, model_name, base_url=base_url, **kwargs)
@@ -1529,9 +1605,11 @@ class GoogleChat(Base):
if "role" in item and item["role"] == "assistant":
item["role"] = "model"
if "content" in item:
- item["parts"] = [{
- "text": item.pop("content"),
- }]
+ item["parts"] = [
+ {
+ "text": item.pop("content"),
+ }
+ ]
response = self.client.generate_content(hist, generation_config=gen_conf)
ans = response.text
@@ -1587,8 +1665,10 @@ class GoogleChat(Base):
class GPUStackChat(Base):
+ _FACTORY_NAME = "GPUStack"
+
def __init__(self, key=None, model_name="", base_url="", **kwargs):
if not base_url:
raise ValueError("Local llm url cannot be None")
base_url = urljoin(base_url, "v1")
- super().__init__(key, model_name, base_url, **kwargs)
\ No newline at end of file
+ super().__init__(key, model_name, base_url, **kwargs)
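
Consumer code is unchanged by the chat_model.py edits above: callers still index the mapping by factory name, the names now just live next to the classes they construct. A hypothetical lookup, where the key string is a placeholder:

```python
# Hypothetical caller-side usage; "sk-..." stands in for a real API key.
from rag.llm import ChatModel

chat = ChatModel["DeepSeek"](key="sk-...", model_name="deepseek-chat")

# Aliases declared via a list-valued _FACTORY_NAME resolve to the same class.
assert ChatModel["VLLM"] is ChatModel["OpenAI-API-Compatible"]
```
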
diff --git a/rag/llm/cv_model.py b/rag/llm/cv_model.py
index da7efaef1..afa39d69a 100644
--- a/rag/llm/cv_model.py
+++ b/rag/llm/cv_model.py
@@ -57,7 +57,7 @@ class Base(ABC):
model=self.model_name,
messages=history,
temperature=gen_conf.get("temperature", 0.3),
- top_p=gen_conf.get("top_p", 0.7)
+ top_p=gen_conf.get("top_p", 0.7),
)
return response.choices[0].message.content.strip(), response.usage.total_tokens
except Exception as e:
@@ -79,7 +79,7 @@ class Base(ABC):
messages=history,
temperature=gen_conf.get("temperature", 0.3),
top_p=gen_conf.get("top_p", 0.7),
- stream=True
+ stream=True,
)
for resp in response:
if not resp.choices[0].delta.content:
@@ -87,8 +87,7 @@ class Base(ABC):
delta = resp.choices[0].delta.content
ans += delta
if resp.choices[0].finish_reason == "length":
- ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
- [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
+ ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
tk_count = resp.usage.total_tokens
if resp.choices[0].finish_reason == "stop":
tk_count = resp.usage.total_tokens
@@ -117,13 +116,12 @@ class Base(ABC):
"content": [
{
"type": "image_url",
- "image_url": {
- "url": f"data:image/jpeg;base64,{b64}"
- },
+ "image_url": {"url": f"data:image/jpeg;base64,{b64}"},
},
{
- "text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else
- "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
+ "text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
+ if self.lang.lower() == "chinese"
+ else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
},
],
}
@@ -136,9 +134,7 @@ class Base(ABC):
"content": [
{
"type": "image_url",
- "image_url": {
- "url": f"data:image/jpeg;base64,{b64}"
- },
+ "image_url": {"url": f"data:image/jpeg;base64,{b64}"},
},
{
"type": "text",
@@ -156,14 +152,13 @@ class Base(ABC):
"url": f"data:image/jpeg;base64,{b64}",
},
},
- {
- "type": "text",
- "text": text
- },
+ {"type": "text", "text": text},
]
class GptV4(Base):
+ _FACTORY_NAME = "OpenAI"
+
def __init__(self, key, model_name="gpt-4-vision-preview", lang="Chinese", base_url="https://api.openai.com/v1"):
if not base_url:
base_url = "https://api.openai.com/v1"
@@ -181,7 +176,7 @@ class GptV4(Base):
res = self.client.chat.completions.create(
model=self.model_name,
- messages=prompt
+ messages=prompt,
)
return res.choices[0].message.content.strip(), res.usage.total_tokens
@@ -197,9 +192,11 @@ class GptV4(Base):
class AzureGptV4(Base):
+ _FACTORY_NAME = "Azure-OpenAI"
+
def __init__(self, key, model_name, lang="Chinese", **kwargs):
- api_key = json.loads(key).get('api_key', '')
- api_version = json.loads(key).get('api_version', '2024-02-01')
+ api_key = json.loads(key).get("api_key", "")
+ api_version = json.loads(key).get("api_version", "2024-02-01")
self.client = AzureOpenAI(api_key=api_key, azure_endpoint=kwargs["base_url"], api_version=api_version)
self.model_name = model_name
self.lang = lang
@@ -212,10 +209,7 @@ class AzureGptV4(Base):
if "text" in c:
c["type"] = "text"
- res = self.client.chat.completions.create(
- model=self.model_name,
- messages=prompt
- )
+ res = self.client.chat.completions.create(model=self.model_name, messages=prompt)
return res.choices[0].message.content.strip(), res.usage.total_tokens
def describe_with_prompt(self, image, prompt=None):
@@ -230,8 +224,11 @@ class AzureGptV4(Base):
class QWenCV(Base):
+ _FACTORY_NAME = "Tongyi-Qianwen"
+
def __init__(self, key, model_name="qwen-vl-chat-v1", lang="Chinese", **kwargs):
import dashscope
+
dashscope.api_key = key
self.model_name = model_name
self.lang = lang
@@ -247,12 +244,11 @@ class QWenCV(Base):
{
"role": "user",
"content": [
+ {"image": f"file://{path}"},
{
- "image": f"file://{path}"
- },
- {
- "text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else
- "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
+ "text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
+ if self.lang.lower() == "chinese"
+ else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
},
],
}
@@ -270,11 +266,9 @@ class QWenCV(Base):
{
"role": "user",
"content": [
+ {"image": f"file://{path}"},
{
- "image": f"file://{path}"
- },
- {
- "text": prompt if prompt else vision_llm_describe_prompt(),
+ "text": prompt if prompt else vision_llm_describe_prompt(),
},
],
}
@@ -290,9 +284,10 @@ class QWenCV(Base):
from http import HTTPStatus
from dashscope import MultiModalConversation
+
response = MultiModalConversation.call(model=self.model_name, messages=self.prompt(image))
if response.status_code == HTTPStatus.OK:
- return response.output.choices[0]['message']['content'][0]["text"], response.usage.output_tokens
+ return response.output.choices[0]["message"]["content"][0]["text"], response.usage.output_tokens
return response.message, 0
def describe_with_prompt(self, image, prompt=None):
@@ -303,33 +298,36 @@ class QWenCV(Base):
vision_prompt = self.vision_llm_prompt(image, prompt) if prompt else self.vision_llm_prompt(image)
response = MultiModalConversation.call(model=self.model_name, messages=vision_prompt)
if response.status_code == HTTPStatus.OK:
- return response.output.choices[0]['message']['content'][0]["text"], response.usage.output_tokens
+ return response.output.choices[0]["message"]["content"][0]["text"], response.usage.output_tokens
return response.message, 0
def chat(self, system, history, gen_conf, image=""):
from http import HTTPStatus
from dashscope import MultiModalConversation
+
if system:
history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
for his in history:
if his["role"] == "user":
his["content"] = self.chat_prompt(his["content"], image)
- response = MultiModalConversation.call(model=self.model_name, messages=history,
- temperature=gen_conf.get("temperature", 0.3),
- top_p=gen_conf.get("top_p", 0.7))
+ response = MultiModalConversation.call(
+ model=self.model_name,
+ messages=history,
+ temperature=gen_conf.get("temperature", 0.3),
+ top_p=gen_conf.get("top_p", 0.7),
+ )
ans = ""
tk_count = 0
if response.status_code == HTTPStatus.OK:
- ans = response.output.choices[0]['message']['content']
+ ans = response.output.choices[0]["message"]["content"]
if isinstance(ans, list):
ans = ans[0]["text"] if ans else ""
tk_count += response.usage.total_tokens
if response.output.choices[0].get("finish_reason", "") == "length":
- ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
- [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
+ ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
return ans, tk_count
return "**ERROR**: " + response.message, tk_count
@@ -338,6 +336,7 @@ class QWenCV(Base):
from http import HTTPStatus
from dashscope import MultiModalConversation
+
if system:
history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
@@ -348,24 +347,25 @@ class QWenCV(Base):
ans = ""
tk_count = 0
try:
- response = MultiModalConversation.call(model=self.model_name, messages=history,
- temperature=gen_conf.get("temperature", 0.3),
- top_p=gen_conf.get("top_p", 0.7),
- stream=True)
+ response = MultiModalConversation.call(
+ model=self.model_name,
+ messages=history,
+ temperature=gen_conf.get("temperature", 0.3),
+ top_p=gen_conf.get("top_p", 0.7),
+ stream=True,
+ )
for resp in response:
if resp.status_code == HTTPStatus.OK:
- cnt = resp.output.choices[0]['message']['content']
+ cnt = resp.output.choices[0]["message"]["content"]
if isinstance(cnt, list):
cnt = cnt[0]["text"] if ans else ""
ans += cnt
tk_count = resp.usage.total_tokens
if resp.output.choices[0].get("finish_reason", "") == "length":
- ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
- [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
+ ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
yield ans
else:
- yield ans + "\n**ERROR**: " + resp.message if str(resp.message).find(
- "Access") < 0 else "Out of credit. Please set the API key in **settings > Model providers.**"
+ yield ans + "\n**ERROR**: " + resp.message if str(resp.message).find("Access") < 0 else "Out of credit. Please set the API key in **settings > Model providers.**"
except Exception as e:
yield ans + "\n**ERROR**: " + str(e)
@@ -373,6 +373,8 @@ class QWenCV(Base):
class Zhipu4V(Base):
+ _FACTORY_NAME = "ZHIPU-AI"
+
def __init__(self, key, model_name="glm-4v", lang="Chinese", **kwargs):
self.client = ZhipuAI(api_key=key)
self.model_name = model_name
@@ -394,10 +396,7 @@ class Zhipu4V(Base):
b64 = self.image2base64(image)
vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64)
- res = self.client.chat.completions.create(
- model=self.model_name,
- messages=vision_prompt
- )
+ res = self.client.chat.completions.create(model=self.model_name, messages=vision_prompt)
return res.choices[0].message.content.strip(), res.usage.total_tokens
def chat(self, system, history, gen_conf, image=""):
@@ -412,7 +411,7 @@ class Zhipu4V(Base):
model=self.model_name,
messages=history,
temperature=gen_conf.get("temperature", 0.3),
- top_p=gen_conf.get("top_p", 0.7)
+ top_p=gen_conf.get("top_p", 0.7),
)
return response.choices[0].message.content.strip(), response.usage.total_tokens
except Exception as e:
@@ -434,7 +433,7 @@ class Zhipu4V(Base):
messages=history,
temperature=gen_conf.get("temperature", 0.3),
top_p=gen_conf.get("top_p", 0.7),
- stream=True
+ stream=True,
)
for resp in response:
if not resp.choices[0].delta.content:
@@ -442,8 +441,7 @@ class Zhipu4V(Base):
delta = resp.choices[0].delta.content
ans += delta
if resp.choices[0].finish_reason == "length":
- ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
- [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
+ ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
tk_count = resp.usage.total_tokens
if resp.choices[0].finish_reason == "stop":
tk_count = resp.usage.total_tokens
@@ -455,6 +453,8 @@ class Zhipu4V(Base):
class OllamaCV(Base):
+ _FACTORY_NAME = "Ollama"
+
def __init__(self, key, model_name, lang="Chinese", **kwargs):
self.client = Client(host=kwargs["base_url"])
self.model_name = model_name
@@ -466,7 +466,7 @@ class OllamaCV(Base):
response = self.client.generate(
model=self.model_name,
prompt=prompt[0]["content"][1]["text"],
- images=[image]
+ images=[image],
)
ans = response["response"].strip()
return ans, 128
@@ -507,7 +507,7 @@ class OllamaCV(Base):
model=self.model_name,
messages=history,
options=options,
- keep_alive=-1
+ keep_alive=-1,
)
ans = response["message"]["content"].strip()
@@ -538,7 +538,7 @@ class OllamaCV(Base):
messages=history,
stream=True,
options=options,
- keep_alive=-1
+ keep_alive=-1,
)
for resp in response:
if resp["done"]:
@@ -551,6 +551,8 @@ class OllamaCV(Base):
class LocalAICV(GptV4):
+ _FACTORY_NAME = "LocalAI"
+
def __init__(self, key, model_name, base_url, lang="Chinese"):
if not base_url:
raise ValueError("Local cv model url cannot be None")
@@ -561,6 +563,8 @@ class LocalAICV(GptV4):
class XinferenceCV(Base):
+ _FACTORY_NAME = "Xinference"
+
def __init__(self, key, model_name="", lang="Chinese", base_url=""):
base_url = urljoin(base_url, "v1")
self.client = OpenAI(api_key=key, base_url=base_url)
@@ -570,10 +574,7 @@ class XinferenceCV(Base):
def describe(self, image):
b64 = self.image2base64(image)
- res = self.client.chat.completions.create(
- model=self.model_name,
- messages=self.prompt(b64)
- )
+ res = self.client.chat.completions.create(model=self.model_name, messages=self.prompt(b64))
return res.choices[0].message.content.strip(), res.usage.total_tokens
def describe_with_prompt(self, image, prompt=None):
@@ -588,8 +589,11 @@ class XinferenceCV(Base):
class GeminiCV(Base):
+ _FACTORY_NAME = "Gemini"
+
def __init__(self, key, model_name="gemini-1.0-pro-vision-latest", lang="Chinese", **kwargs):
from google.generativeai import GenerativeModel, client
+
client.configure(api_key=key)
_client = client.get_default_generative_client()
self.model_name = model_name
@@ -599,18 +603,21 @@ class GeminiCV(Base):
def describe(self, image):
from PIL.Image import open
- prompt = "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else \
- "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
+
+ prompt = (
+ "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
+ if self.lang.lower() == "chinese"
+ else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
+ )
b64 = self.image2base64(image)
img = open(BytesIO(base64.b64decode(b64)))
input = [prompt, img]
- res = self.model.generate_content(
- input
- )
+ res = self.model.generate_content(input)
return res.text, res.usage_metadata.total_token_count
def describe_with_prompt(self, image, prompt=None):
from PIL.Image import open
+
b64 = self.image2base64(image)
vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64)
img = open(BytesIO(base64.b64decode(b64)))
@@ -622,6 +629,7 @@ class GeminiCV(Base):
def chat(self, system, history, gen_conf, image=""):
from transformers import GenerationConfig
+
if system:
history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
try:
@@ -635,9 +643,7 @@ class GeminiCV(Base):
his.pop("content")
history[-1]["parts"].append("data:image/jpeg;base64," + image)
- response = self.model.generate_content(history, generation_config=GenerationConfig(
- temperature=gen_conf.get("temperature", 0.3),
- top_p=gen_conf.get("top_p", 0.7)))
+ response = self.model.generate_content(history, generation_config=GenerationConfig(temperature=gen_conf.get("temperature", 0.3), top_p=gen_conf.get("top_p", 0.7)))
ans = response.text
return ans, response.usage_metadata.total_token_count
@@ -646,6 +652,7 @@ class GeminiCV(Base):
def chat_streamly(self, system, history, gen_conf, image=""):
from transformers import GenerationConfig
+
if system:
history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
@@ -661,9 +668,11 @@ class GeminiCV(Base):
his.pop("content")
history[-1]["parts"].append("data:image/jpeg;base64," + image)
- response = self.model.generate_content(history, generation_config=GenerationConfig(
- temperature=gen_conf.get("temperature", 0.3),
- top_p=gen_conf.get("top_p", 0.7)), stream=True)
+ response = self.model.generate_content(
+ history,
+ generation_config=GenerationConfig(temperature=gen_conf.get("temperature", 0.3), top_p=gen_conf.get("top_p", 0.7)),
+ stream=True,
+ )
for resp in response:
if not resp.text:
@@ -677,6 +686,8 @@ class GeminiCV(Base):
class OpenRouterCV(GptV4):
+ _FACTORY_NAME = "OpenRouter"
+
def __init__(
self,
key,
@@ -692,6 +703,8 @@ class OpenRouterCV(GptV4):
class LocalCV(Base):
+ _FACTORY_NAME = "Moonshot"
+
def __init__(self, key, model_name="glm-4v", lang="Chinese", **kwargs):
pass
@@ -700,6 +713,8 @@ class LocalCV(Base):
class NvidiaCV(Base):
+ _FACTORY_NAME = "NVIDIA"
+
def __init__(
self,
key,
@@ -726,9 +741,7 @@ class NvidiaCV(Base):
"content-type": "application/json",
"Authorization": f"Bearer {self.key}",
},
- json={
- "messages": self.prompt(b64)
- },
+ json={"messages": self.prompt(b64)},
)
response = response.json()
return (
@@ -774,10 +787,7 @@ class NvidiaCV(Base):
return [
{
"role": "user",
-                "content": (
-                    prompt if prompt else vision_llm_describe_prompt()
-                )
-                + f' <img src="data:image/jpeg;base64,{b64}"/>',
+                "content": (prompt if prompt else vision_llm_describe_prompt()) + f' <img src="data:image/jpeg;base64,{b64}"/>',
}
]
@@ -791,6 +801,8 @@ class NvidiaCV(Base):
class StepFunCV(GptV4):
+ _FACTORY_NAME = "StepFun"
+
def __init__(self, key, model_name="step-1v-8k", lang="Chinese", base_url="https://api.stepfun.com/v1"):
if not base_url:
base_url = "https://api.stepfun.com/v1"
@@ -800,6 +812,8 @@ class StepFunCV(GptV4):
class LmStudioCV(GptV4):
+ _FACTORY_NAME = "LM-Studio"
+
def __init__(self, key, model_name, lang="Chinese", base_url=""):
if not base_url:
raise ValueError("Local llm url cannot be None")
@@ -810,6 +824,8 @@ class LmStudioCV(GptV4):
class OpenAI_APICV(GptV4):
+ _FACTORY_NAME = ["VLLM", "OpenAI-API-Compatible"]
+
def __init__(self, key, model_name, lang="Chinese", base_url=""):
if not base_url:
raise ValueError("url cannot be None")
@@ -820,6 +836,8 @@ class OpenAI_APICV(GptV4):
class TogetherAICV(GptV4):
+ _FACTORY_NAME = "TogetherAI"
+
def __init__(self, key, model_name, lang="Chinese", base_url="https://api.together.xyz/v1"):
if not base_url:
base_url = "https://api.together.xyz/v1"
@@ -827,20 +845,38 @@ class TogetherAICV(GptV4):
class YiCV(GptV4):
- def __init__(self, key, model_name, lang="Chinese", base_url="https://api.lingyiwanwu.com/v1",):
+ _FACTORY_NAME = "01.AI"
+
+ def __init__(
+ self,
+ key,
+ model_name,
+ lang="Chinese",
+ base_url="https://api.lingyiwanwu.com/v1",
+ ):
if not base_url:
base_url = "https://api.lingyiwanwu.com/v1"
super().__init__(key, model_name, lang, base_url)
class SILICONFLOWCV(GptV4):
- def __init__(self, key, model_name, lang="Chinese", base_url="https://api.siliconflow.cn/v1",):
+ _FACTORY_NAME = "SILICONFLOW"
+
+ def __init__(
+ self,
+ key,
+ model_name,
+ lang="Chinese",
+ base_url="https://api.siliconflow.cn/v1",
+ ):
if not base_url:
base_url = "https://api.siliconflow.cn/v1"
super().__init__(key, model_name, lang, base_url)
class HunyuanCV(Base):
+ _FACTORY_NAME = "Tencent Hunyuan"
+
def __init__(self, key, model_name, lang="Chinese", base_url=None):
from tencentcloud.common import credential
from tencentcloud.hunyuan.v20230901 import hunyuan_client
@@ -895,14 +931,13 @@ class HunyuanCV(Base):
"Contents": [
{
"Type": "image_url",
- "ImageUrl": {
- "Url": f"data:image/jpeg;base64,{b64}"
- },
+ "ImageUrl": {"Url": f"data:image/jpeg;base64,{b64}"},
},
{
"Type": "text",
- "Text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else
- "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
+ "Text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
+ if self.lang.lower() == "chinese"
+ else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
},
],
}
@@ -910,6 +945,8 @@ class HunyuanCV(Base):
class AnthropicCV(Base):
+ _FACTORY_NAME = "Anthropic"
+
def __init__(self, key, model_name, base_url=None):
import anthropic
@@ -933,38 +970,29 @@ class AnthropicCV(Base):
"data": b64,
},
},
- {
- "type": "text",
- "text": prompt
- }
+ {"type": "text", "text": prompt},
],
}
]
def describe(self, image):
b64 = self.image2base64(image)
- prompt = self.prompt(b64,
- "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else
- "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
- )
-
- response = self.client.messages.create(
- model=self.model_name,
- max_tokens=self.max_tokens,
- messages=prompt
+ prompt = self.prompt(
+ b64,
+ "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
+ if self.lang.lower() == "chinese"
+ else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
)
- return response["content"][0]["text"].strip(), response["usage"]["input_tokens"]+response["usage"]["output_tokens"]
+
+ response = self.client.messages.create(model=self.model_name, max_tokens=self.max_tokens, messages=prompt)
+ return response["content"][0]["text"].strip(), response["usage"]["input_tokens"] + response["usage"]["output_tokens"]
def describe_with_prompt(self, image, prompt=None):
b64 = self.image2base64(image)
prompt = self.prompt(b64, prompt if prompt else vision_llm_describe_prompt())
- response = self.client.messages.create(
- model=self.model_name,
- max_tokens=self.max_tokens,
- messages=prompt
- )
- return response["content"][0]["text"].strip(), response["usage"]["input_tokens"]+response["usage"]["output_tokens"]
+ response = self.client.messages.create(model=self.model_name, max_tokens=self.max_tokens, messages=prompt)
+ return response["content"][0]["text"].strip(), response["usage"]["input_tokens"] + response["usage"]["output_tokens"]
def chat(self, system, history, gen_conf):
if "presence_penalty" in gen_conf:
@@ -984,11 +1012,7 @@ class AnthropicCV(Base):
).to_dict()
ans = response["content"][0]["text"]
if response["stop_reason"] == "max_tokens":
- ans += (
- "...\nFor the content length reason, it stopped, continue?"
- if is_english([ans])
- else "······\n由于长度的原因,回答被截断了,要继续吗?"
- )
+ ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
return (
ans,
response["usage"]["input_tokens"] + response["usage"]["output_tokens"],
@@ -1014,7 +1038,7 @@ class AnthropicCV(Base):
**gen_conf,
)
for res in response:
- if res.type == 'content_block_delta':
+ if res.type == "content_block_delta":
if res.delta.type == "thinking_delta" and res.delta.thinking:
                        if ans.find("<think>") < 0:
                            ans += "<think>"
@@ -1030,7 +1054,10 @@ class AnthropicCV(Base):
yield total_tokens
+
class GPUStackCV(GptV4):
+ _FACTORY_NAME = "GPUStack"
+
def __init__(self, key, model_name, lang="Chinese", base_url=""):
if not base_url:
raise ValueError("Local llm url cannot be None")
@@ -1041,11 +1068,13 @@ class GPUStackCV(GptV4):
class GoogleCV(Base):
+ _FACTORY_NAME = "Google Cloud"
+
def __init__(self, key, model_name, lang="Chinese", base_url=None, **kwargs):
import base64
from google.oauth2 import service_account
-
+
key = json.loads(key)
access_token = json.loads(base64.b64decode(key.get("google_service_account_key", "")))
project_id = key.get("google_project_id", "")
@@ -1079,9 +1108,12 @@ class GoogleCV(Base):
self.client = glm.GenerativeModel(model_name=self.model_name)
def describe(self, image):
- prompt = "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else \
- "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
-
+ prompt = (
+ "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
+ if self.lang.lower() == "chinese"
+ else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
+ )
+
if "claude" in self.model_name:
b64 = self.image2base64(image)
vision_prompt = [
@@ -1096,28 +1128,22 @@ class GoogleCV(Base):
"data": b64,
},
},
- {
- "type": "text",
- "text": prompt
- }
+ {"type": "text", "text": prompt},
],
}
]
response = self.client.messages.create(
model=self.model_name,
max_tokens=8192,
- messages=vision_prompt
+ messages=vision_prompt,
)
return response.content[0].text.strip(), response.usage.input_tokens + response.usage.output_tokens
else:
import vertexai.generative_models as glm
-
+
b64 = self.image2base64(image)
# Create proper image part for Gemini
- image_part = glm.Part.from_data(
- data=base64.b64decode(b64),
- mime_type="image/jpeg"
- )
+ image_part = glm.Part.from_data(data=base64.b64decode(b64), mime_type="image/jpeg")
input = [prompt, image_part]
res = self.client.generate_content(input)
return res.text, res.usage_metadata.total_token_count
@@ -1137,29 +1163,19 @@ class GoogleCV(Base):
"data": b64,
},
},
- {
- "type": "text",
- "text": prompt if prompt else vision_llm_describe_prompt()
- }
+ {"type": "text", "text": prompt if prompt else vision_llm_describe_prompt()},
],
}
]
- response = self.client.messages.create(
- model=self.model_name,
- max_tokens=8192,
- messages=vision_prompt
- )
+ response = self.client.messages.create(model=self.model_name, max_tokens=8192, messages=vision_prompt)
return response.content[0].text.strip(), response.usage.input_tokens + response.usage.output_tokens
else:
import vertexai.generative_models as glm
-
+
b64 = self.image2base64(image)
vision_prompt = prompt if prompt else vision_llm_describe_prompt()
# Create proper image part for Gemini
- image_part = glm.Part.from_data(
- data=base64.b64decode(b64),
- mime_type="image/jpeg"
- )
+ image_part = glm.Part.from_data(data=base64.b64decode(b64), mime_type="image/jpeg")
input = [vision_prompt, image_part]
res = self.client.generate_content(input)
return res.text, res.usage_metadata.total_token_count
@@ -1180,25 +1196,17 @@ class GoogleCV(Base):
"data": image,
},
},
- {
- "type": "text",
- "text": his["content"]
- }
+ {"type": "text", "text": his["content"]},
]
- response = self.client.messages.create(
- model=self.model_name,
- max_tokens=8192,
- messages=history,
- temperature=gen_conf.get("temperature", 0.3),
- top_p=gen_conf.get("top_p", 0.7)
- )
+ response = self.client.messages.create(model=self.model_name, max_tokens=8192, messages=history, temperature=gen_conf.get("temperature", 0.3), top_p=gen_conf.get("top_p", 0.7))
return response.content[0].text.strip(), response.usage.input_tokens + response.usage.output_tokens
except Exception as e:
return "**ERROR**: " + str(e), 0
else:
import vertexai.generative_models as glm
from transformers import GenerationConfig
+
if system:
history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
try:
@@ -1210,20 +1218,15 @@ class GoogleCV(Base):
if his["role"] == "user":
his["parts"] = [his["content"]]
his.pop("content")
-
+
# Create proper image part for Gemini
img_bytes = base64.b64decode(image)
- image_part = glm.Part.from_data(
- data=img_bytes,
- mime_type="image/jpeg"
- )
+ image_part = glm.Part.from_data(data=img_bytes, mime_type="image/jpeg")
history[-1]["parts"].append(image_part)
- response = self.client.generate_content(history, generation_config=GenerationConfig(
- temperature=gen_conf.get("temperature", 0.3),
- top_p=gen_conf.get("top_p", 0.7)))
+ response = self.client.generate_content(history, generation_config=GenerationConfig(temperature=gen_conf.get("temperature", 0.3), top_p=gen_conf.get("top_p", 0.7)))
ans = response.text
return ans, response.usage_metadata.total_token_count
except Exception as e:
- return "**ERROR**: " + str(e), 0
\ No newline at end of file
+ return "**ERROR**: " + str(e), 0
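
One subtlety in cv_model.py: many providers subclass `GptV4` rather than `Base` directly (`LocalAICV`, `StepFunCV`, `OpenAI_APICV`, `TogetherAICV`, ...). They still register because `issubclass` is transitive, but `_FACTORY_NAME` is also inherited, so every subclass must set its own value. A toy illustration of the pitfall, with names reused for familiarity only:

```python
# Why each subclass in this diff declares its own _FACTORY_NAME: the attribute
# is inherited, so a forgotten override would register the child under its
# parent's factory name and one of the two entries would clobber the other.
class GptV4:
    _FACTORY_NAME = "OpenAI"


class LocalAICV(GptV4):
    pass  # no override: hasattr(LocalAICV, "_FACTORY_NAME") is still True


assert LocalAICV._FACTORY_NAME == "OpenAI"
```
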
diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py
index f53d492b9..fbfb7468b 100644
--- a/rag/llm/embedding_model.py
+++ b/rag/llm/embedding_model.py
@@ -13,28 +13,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
+import json
import logging
+import os
import re
import threading
+from abc import ABC
from urllib.parse import urljoin
+import dashscope
+import google.generativeai as genai
+import numpy as np
import requests
from huggingface_hub import snapshot_download
-from zhipuai import ZhipuAI
-import os
-from abc import ABC
from ollama import Client
-import dashscope
from openai import OpenAI
-import numpy as np
-import asyncio
+from zhipuai import ZhipuAI
from api import settings
from api.utils.file_utils import get_home_cache_dir
from api.utils.log_utils import log_exception
from rag.utils import num_tokens_from_string, truncate
-import google.generativeai as genai
-import json
class Base(ABC):
@@ -60,7 +59,8 @@ class Base(ABC):
class DefaultEmbedding(Base):
- os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+ _FACTORY_NAME = "BAAI"
+ os.environ["CUDA_VISIBLE_DEVICES"] = "0"
_model = None
_model_name = ""
_model_lock = threading.Lock()
@@ -79,21 +79,22 @@ class DefaultEmbedding(Base):
"""
if not settings.LIGHTEN:
with DefaultEmbedding._model_lock:
- from FlagEmbedding import FlagModel
import torch
+ from FlagEmbedding import FlagModel
+
if not DefaultEmbedding._model or model_name != DefaultEmbedding._model_name:
try:
- DefaultEmbedding._model = FlagModel(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)),
- query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
- use_fp16=torch.cuda.is_available())
+ DefaultEmbedding._model = FlagModel(
+ os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)),
+ query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
+ use_fp16=torch.cuda.is_available(),
+ )
DefaultEmbedding._model_name = model_name
except Exception:
- model_dir = snapshot_download(repo_id="BAAI/bge-large-zh-v1.5",
- local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)),
- local_dir_use_symlinks=False)
- DefaultEmbedding._model = FlagModel(model_dir,
- query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
- use_fp16=torch.cuda.is_available())
+ model_dir = snapshot_download(
+ repo_id="BAAI/bge-large-zh-v1.5", local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), local_dir_use_symlinks=False
+ )
+ DefaultEmbedding._model = FlagModel(model_dir, query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:", use_fp16=torch.cuda.is_available())
self._model = DefaultEmbedding._model
self._model_name = DefaultEmbedding._model_name
@@ -105,7 +106,7 @@ class DefaultEmbedding(Base):
token_count += num_tokens_from_string(t)
ress = []
for i in range(0, len(texts), batch_size):
- ress.extend(self._model.encode(texts[i:i + batch_size]).tolist())
+ ress.extend(self._model.encode(texts[i : i + batch_size]).tolist())
return np.array(ress), token_count
def encode_queries(self, text: str):
@@ -114,8 +115,9 @@ class DefaultEmbedding(Base):
class OpenAIEmbed(Base):
- def __init__(self, key, model_name="text-embedding-ada-002",
- base_url="https://api.openai.com/v1"):
+ _FACTORY_NAME = "OpenAI"
+
+ def __init__(self, key, model_name="text-embedding-ada-002", base_url="https://api.openai.com/v1"):
if not base_url:
base_url = "https://api.openai.com/v1"
self.client = OpenAI(api_key=key, base_url=base_url)
@@ -128,8 +130,7 @@ class OpenAIEmbed(Base):
ress = []
total_tokens = 0
for i in range(0, len(texts), batch_size):
- res = self.client.embeddings.create(input=texts[i:i + batch_size],
- model=self.model_name)
+ res = self.client.embeddings.create(input=texts[i : i + batch_size], model=self.model_name)
try:
ress.extend([d.embedding for d in res.data])
total_tokens += self.total_token_count(res)
@@ -138,12 +139,13 @@ class OpenAIEmbed(Base):
return np.array(ress), total_tokens
def encode_queries(self, text):
- res = self.client.embeddings.create(input=[truncate(text, 8191)],
- model=self.model_name)
+ res = self.client.embeddings.create(input=[truncate(text, 8191)], model=self.model_name)
return np.array(res.data[0].embedding), self.total_token_count(res)
class LocalAIEmbed(Base):
+ _FACTORY_NAME = "LocalAI"
+
def __init__(self, key, model_name, base_url):
if not base_url:
raise ValueError("Local embedding model url cannot be None")
@@ -155,7 +157,7 @@ class LocalAIEmbed(Base):
batch_size = 16
ress = []
for i in range(0, len(texts), batch_size):
- res = self.client.embeddings.create(input=texts[i:i + batch_size], model=self.model_name)
+ res = self.client.embeddings.create(input=texts[i : i + batch_size], model=self.model_name)
try:
ress.extend([d.embedding for d in res.data])
except Exception as _e:
@@ -169,41 +171,42 @@ class LocalAIEmbed(Base):
class AzureEmbed(OpenAIEmbed):
+ _FACTORY_NAME = "Azure-OpenAI"
+
def __init__(self, key, model_name, **kwargs):
from openai.lib.azure import AzureOpenAI
- api_key = json.loads(key).get('api_key', '')
- api_version = json.loads(key).get('api_version', '2024-02-01')
+
+ api_key = json.loads(key).get("api_key", "")
+ api_version = json.loads(key).get("api_version", "2024-02-01")
self.client = AzureOpenAI(api_key=api_key, azure_endpoint=kwargs["base_url"], api_version=api_version)
self.model_name = model_name
class BaiChuanEmbed(OpenAIEmbed):
- def __init__(self, key,
- model_name='Baichuan-Text-Embedding',
- base_url='https://api.baichuan-ai.com/v1'):
+ _FACTORY_NAME = "BaiChuan"
+
+ def __init__(self, key, model_name="Baichuan-Text-Embedding", base_url="https://api.baichuan-ai.com/v1"):
if not base_url:
base_url = "https://api.baichuan-ai.com/v1"
super().__init__(key, model_name, base_url)
class QWenEmbed(Base):
+ _FACTORY_NAME = "Tongyi-Qianwen"
+
def __init__(self, key, model_name="text_embedding_v2", **kwargs):
self.key = key
self.model_name = model_name
def encode(self, texts: list):
import dashscope
+
batch_size = 4
res = []
token_count = 0
texts = [truncate(t, 2048) for t in texts]
for i in range(0, len(texts), batch_size):
- resp = dashscope.TextEmbedding.call(
- model=self.model_name,
- input=texts[i:i + batch_size],
- api_key=self.key,
- text_type="document"
- )
+ resp = dashscope.TextEmbedding.call(model=self.model_name, input=texts[i : i + batch_size], api_key=self.key, text_type="document")
try:
embds = [[] for _ in range(len(resp["output"]["embeddings"]))]
for e in resp["output"]["embeddings"]:
@@ -216,20 +219,16 @@ class QWenEmbed(Base):
return np.array(res), token_count
def encode_queries(self, text):
- resp = dashscope.TextEmbedding.call(
- model=self.model_name,
- input=text[:2048],
- api_key=self.key,
- text_type="query"
- )
+ resp = dashscope.TextEmbedding.call(model=self.model_name, input=text[:2048], api_key=self.key, text_type="query")
try:
- return np.array(resp["output"]["embeddings"][0]
- ["embedding"]), self.total_token_count(resp)
+ return np.array(resp["output"]["embeddings"][0]["embedding"]), self.total_token_count(resp)
except Exception as _e:
log_exception(_e, resp)
class ZhipuEmbed(Base):
+ _FACTORY_NAME = "ZHIPU-AI"
+
def __init__(self, key, model_name="embedding-2", **kwargs):
self.client = ZhipuAI(api_key=key)
self.model_name = model_name
@@ -246,8 +245,7 @@ class ZhipuEmbed(Base):
texts = [truncate(t, MAX_LEN) for t in texts]
for txt in texts:
- res = self.client.embeddings.create(input=txt,
- model=self.model_name)
+ res = self.client.embeddings.create(input=txt, model=self.model_name)
try:
arr.append(res.data[0].embedding)
tks_num += self.total_token_count(res)
@@ -256,8 +254,7 @@ class ZhipuEmbed(Base):
return np.array(arr), tks_num
def encode_queries(self, text):
- res = self.client.embeddings.create(input=text,
- model=self.model_name)
+ res = self.client.embeddings.create(input=text, model=self.model_name)
try:
return np.array(res.data[0].embedding), self.total_token_count(res)
except Exception as _e:
@@ -265,18 +262,17 @@ class ZhipuEmbed(Base):
class OllamaEmbed(Base):
+ _FACTORY_NAME = "Ollama"
+
def __init__(self, key, model_name, **kwargs):
- self.client = Client(host=kwargs["base_url"]) if not key or key == "x" else \
- Client(host=kwargs["base_url"], headers={"Authorization": f"Bear {key}"})
+        self.client = Client(host=kwargs["base_url"]) if not key or key == "x" else Client(host=kwargs["base_url"], headers={"Authorization": f"Bearer {key}"})
self.model_name = model_name
def encode(self, texts: list):
arr = []
tks_num = 0
for txt in texts:
- res = self.client.embeddings(prompt=txt,
- model=self.model_name,
- options={"use_mmap": True})
+ res = self.client.embeddings(prompt=txt, model=self.model_name, options={"use_mmap": True})
try:
arr.append(res["embedding"])
except Exception as _e:
@@ -285,9 +281,7 @@ class OllamaEmbed(Base):
return np.array(arr), tks_num
def encode_queries(self, text):
- res = self.client.embeddings(prompt=text,
- model=self.model_name,
- options={"use_mmap": True})
+ res = self.client.embeddings(prompt=text, model=self.model_name, options={"use_mmap": True})
try:
return np.array(res["embedding"]), 128
except Exception as _e:
@@ -295,27 +289,28 @@ class OllamaEmbed(Base):
class FastEmbed(DefaultEmbedding):
-
+ _FACTORY_NAME = "FastEmbed"
+
def __init__(
- self,
- key: str | None = None,
- model_name: str = "BAAI/bge-small-en-v1.5",
- cache_dir: str | None = None,
- threads: int | None = None,
- **kwargs,
+ self,
+ key: str | None = None,
+ model_name: str = "BAAI/bge-small-en-v1.5",
+ cache_dir: str | None = None,
+ threads: int | None = None,
+ **kwargs,
):
if not settings.LIGHTEN:
with FastEmbed._model_lock:
from fastembed import TextEmbedding
+
if not DefaultEmbedding._model or model_name != DefaultEmbedding._model_name:
try:
DefaultEmbedding._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
DefaultEmbedding._model_name = model_name
except Exception:
- cache_dir = snapshot_download(repo_id="BAAI/bge-small-en-v1.5",
- local_dir=os.path.join(get_home_cache_dir(),
- re.sub(r"^[a-zA-Z0-9]+/", "", model_name)),
- local_dir_use_symlinks=False)
+ cache_dir = snapshot_download(
+ repo_id="BAAI/bge-small-en-v1.5", local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), local_dir_use_symlinks=False
+ )
DefaultEmbedding._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
self._model = DefaultEmbedding._model
self._model_name = model_name
@@ -340,6 +335,8 @@ class FastEmbed(DefaultEmbedding):
class XinferenceEmbed(Base):
+ _FACTORY_NAME = "Xinference"
+
def __init__(self, key, model_name="", base_url=""):
base_url = urljoin(base_url, "v1")
self.client = OpenAI(api_key=key, base_url=base_url)
@@ -350,7 +347,7 @@ class XinferenceEmbed(Base):
ress = []
total_tokens = 0
for i in range(0, len(texts), batch_size):
- res = self.client.embeddings.create(input=texts[i:i + batch_size], model=self.model_name)
+ res = self.client.embeddings.create(input=texts[i : i + batch_size], model=self.model_name)
try:
ress.extend([d.embedding for d in res.data])
total_tokens += self.total_token_count(res)
@@ -359,8 +356,7 @@ class XinferenceEmbed(Base):
return np.array(ress), total_tokens
def encode_queries(self, text):
- res = self.client.embeddings.create(input=[text],
- model=self.model_name)
+ res = self.client.embeddings.create(input=[text], model=self.model_name)
try:
return np.array(res.data[0].embedding), self.total_token_count(res)
except Exception as _e:
@@ -368,20 +364,18 @@ class XinferenceEmbed(Base):
class YoudaoEmbed(Base):
+ _FACTORY_NAME = "Youdao"
_client = None
def __init__(self, key=None, model_name="maidalun1020/bce-embedding-base_v1", **kwargs):
if not settings.LIGHTEN and not YoudaoEmbed._client:
from BCEmbedding import EmbeddingModel as qanthing
+
try:
logging.info("LOADING BCE...")
- YoudaoEmbed._client = qanthing(model_name_or_path=os.path.join(
- get_home_cache_dir(),
- "bce-embedding-base_v1"))
+ YoudaoEmbed._client = qanthing(model_name_or_path=os.path.join(get_home_cache_dir(), "bce-embedding-base_v1"))
except Exception:
- YoudaoEmbed._client = qanthing(
- model_name_or_path=model_name.replace(
- "maidalun1020", "InfiniFlow"))
+ YoudaoEmbed._client = qanthing(model_name_or_path=model_name.replace("maidalun1020", "InfiniFlow"))
def encode(self, texts: list):
batch_size = 10
@@ -390,7 +384,7 @@ class YoudaoEmbed(Base):
for t in texts:
token_count += num_tokens_from_string(t)
for i in range(0, len(texts), batch_size):
- embds = YoudaoEmbed._client.encode(texts[i:i + batch_size])
+ embds = YoudaoEmbed._client.encode(texts[i : i + batch_size])
res.extend(embds)
return np.array(res), token_count
@@ -400,14 +394,11 @@ class YoudaoEmbed(Base):
class JinaEmbed(Base):
- def __init__(self, key, model_name="jina-embeddings-v3",
- base_url="https://api.jina.ai/v1/embeddings"):
+ _FACTORY_NAME = "Jina"
+ def __init__(self, key, model_name="jina-embeddings-v3", base_url="https://api.jina.ai/v1/embeddings"):
self.base_url = "https://api.jina.ai/v1/embeddings"
- self.headers = {
- "Content-Type": "application/json",
- "Authorization": f"Bearer {key}"
- }
+ self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
self.model_name = model_name
def encode(self, texts: list):
@@ -416,11 +407,7 @@ class JinaEmbed(Base):
ress = []
token_count = 0
for i in range(0, len(texts), batch_size):
- data = {
- "model": self.model_name,
- "input": texts[i:i + batch_size],
- 'encoding_type': 'float'
- }
+ data = {"model": self.model_name, "input": texts[i : i + batch_size], "encoding_type": "float"}
response = requests.post(self.base_url, headers=self.headers, json=data)
try:
res = response.json()
@@ -435,50 +422,12 @@ class JinaEmbed(Base):
return np.array(embds[0]), cnt
-class InfinityEmbed(Base):
- _model = None
-
- def __init__(
- self,
- model_names: list[str] = ("BAAI/bge-small-en-v1.5",),
- engine_kwargs: dict = {},
- key = None,
- ):
-
- from infinity_emb import EngineArgs
- from infinity_emb.engine import AsyncEngineArray
-
- self._default_model = model_names[0]
- self.engine_array = AsyncEngineArray.from_args([EngineArgs(model_name_or_path = model_name, **engine_kwargs) for model_name in model_names])
-
- async def _embed(self, sentences: list[str], model_name: str = ""):
- if not model_name:
- model_name = self._default_model
- engine = self.engine_array[model_name]
- was_already_running = engine.is_running
- if not was_already_running:
- await engine.astart()
- embeddings, usage = await engine.embed(sentences=sentences)
- if not was_already_running:
- await engine.astop()
- return embeddings, usage
-
- def encode(self, texts: list[str], model_name: str = "") -> tuple[np.ndarray, int]:
- # Using the internal tokenizer to encode the texts and get the total
- # number of tokens
- embeddings, usage = asyncio.run(self._embed(texts, model_name))
- return np.array(embeddings), usage
-
- def encode_queries(self, text: str) -> tuple[np.ndarray, int]:
- # Using the internal tokenizer to encode the texts and get the total
- # number of tokens
- return self.encode([text])
-
-
class MistralEmbed(Base):
- def __init__(self, key, model_name="mistral-embed",
- base_url=None):
+ _FACTORY_NAME = "Mistral"
+
+ def __init__(self, key, model_name="mistral-embed", base_url=None):
from mistralai.client import MistralClient
+
self.client = MistralClient(api_key=key)
self.model_name = model_name
@@ -488,8 +437,7 @@ class MistralEmbed(Base):
ress = []
token_count = 0
for i in range(0, len(texts), batch_size):
- res = self.client.embeddings(input=texts[i:i + batch_size],
- model=self.model_name)
+ res = self.client.embeddings(input=texts[i : i + batch_size], model=self.model_name)
try:
ress.extend([d.embedding for d in res.data])
token_count += self.total_token_count(res)
@@ -498,8 +446,7 @@ class MistralEmbed(Base):
return np.array(ress), token_count
def encode_queries(self, text):
- res = self.client.embeddings(input=[truncate(text, 8196)],
- model=self.model_name)
+ res = self.client.embeddings(input=[truncate(text, 8196)], model=self.model_name)
try:
return np.array(res.data[0].embedding), self.total_token_count(res)
except Exception as _e:
@@ -507,30 +454,31 @@ class MistralEmbed(Base):
class BedrockEmbed(Base):
- def __init__(self, key, model_name,
- **kwargs):
+ _FACTORY_NAME = "Bedrock"
+
+ def __init__(self, key, model_name, **kwargs):
import boto3
- self.bedrock_ak = json.loads(key).get('bedrock_ak', '')
- self.bedrock_sk = json.loads(key).get('bedrock_sk', '')
- self.bedrock_region = json.loads(key).get('bedrock_region', '')
+
+ self.bedrock_ak = json.loads(key).get("bedrock_ak", "")
+ self.bedrock_sk = json.loads(key).get("bedrock_sk", "")
+ self.bedrock_region = json.loads(key).get("bedrock_region", "")
self.model_name = model_name
-
- if self.bedrock_ak == '' or self.bedrock_sk == '' or self.bedrock_region == '':
+
+ if self.bedrock_ak == "" or self.bedrock_sk == "" or self.bedrock_region == "":
# Try to create a client using the default credentials (AWS_PROFILE, AWS_DEFAULT_REGION, etc.)
- self.client = boto3.client('bedrock-runtime')
+ self.client = boto3.client("bedrock-runtime")
else:
- self.client = boto3.client(service_name='bedrock-runtime', region_name=self.bedrock_region,
- aws_access_key_id=self.bedrock_ak, aws_secret_access_key=self.bedrock_sk)
+ self.client = boto3.client(service_name="bedrock-runtime", region_name=self.bedrock_region, aws_access_key_id=self.bedrock_ak, aws_secret_access_key=self.bedrock_sk)
def encode(self, texts: list):
texts = [truncate(t, 8196) for t in texts]
embeddings = []
token_count = 0
for text in texts:
- if self.model_name.split('.')[0] == 'amazon':
+ if self.model_name.split(".")[0] == "amazon":
body = {"inputText": text}
- elif self.model_name.split('.')[0] == 'cohere':
- body = {"texts": [text], "input_type": 'search_document'}
+ elif self.model_name.split(".")[0] == "cohere":
+ body = {"texts": [text], "input_type": "search_document"}
response = self.client.invoke_model(modelId=self.model_name, body=json.dumps(body))
try:
@@ -545,10 +493,10 @@ class BedrockEmbed(Base):
def encode_queries(self, text):
embeddings = []
token_count = num_tokens_from_string(text)
- if self.model_name.split('.')[0] == 'amazon':
+ if self.model_name.split(".")[0] == "amazon":
body = {"inputText": truncate(text, 8196)}
- elif self.model_name.split('.')[0] == 'cohere':
- body = {"texts": [truncate(text, 8196)], "input_type": 'search_query'}
+ elif self.model_name.split(".")[0] == "cohere":
+ body = {"texts": [truncate(text, 8196)], "input_type": "search_query"}
response = self.client.invoke_model(modelId=self.model_name, body=json.dumps(body))
try:
@@ -561,11 +509,12 @@ class BedrockEmbed(Base):
class GeminiEmbed(Base):
- def __init__(self, key, model_name='models/text-embedding-004',
- **kwargs):
+ _FACTORY_NAME = "Gemini"
+
+ def __init__(self, key, model_name="models/text-embedding-004", **kwargs):
self.key = key
- self.model_name = 'models/' + model_name
-
+ self.model_name = "models/" + model_name
+
def encode(self, texts: list):
texts = [truncate(t, 2048) for t in texts]
token_count = sum(num_tokens_from_string(text) for text in texts)
@@ -573,35 +522,27 @@ class GeminiEmbed(Base):
batch_size = 16
ress = []
for i in range(0, len(texts), batch_size):
- result = genai.embed_content(
- model=self.model_name,
- content=texts[i: i + batch_size],
- task_type="retrieval_document",
- title="Embedding of single string")
+ result = genai.embed_content(model=self.model_name, content=texts[i : i + batch_size], task_type="retrieval_document", title="Embedding of single string")
try:
- ress.extend(result['embedding'])
+ ress.extend(result["embedding"])
except Exception as _e:
log_exception(_e, result)
- return np.array(ress),token_count
-
+ return np.array(ress), token_count
+
def encode_queries(self, text):
genai.configure(api_key=self.key)
- result = genai.embed_content(
- model=self.model_name,
- content=truncate(text,2048),
- task_type="retrieval_document",
- title="Embedding of single string")
+ result = genai.embed_content(model=self.model_name, content=truncate(text, 2048), task_type="retrieval_document", title="Embedding of single string")
token_count = num_tokens_from_string(text)
try:
- return np.array(result['embedding']), token_count
+ return np.array(result["embedding"]), token_count
except Exception as _e:
log_exception(_e, result)
class NvidiaEmbed(Base):
- def __init__(
- self, key, model_name, base_url="https://integrate.api.nvidia.com/v1/embeddings"
- ):
+ _FACTORY_NAME = "NVIDIA"
+
+ def __init__(self, key, model_name, base_url="https://integrate.api.nvidia.com/v1/embeddings"):
if not base_url:
base_url = "https://integrate.api.nvidia.com/v1/embeddings"
self.api_key = key
@@ -645,6 +586,8 @@ class NvidiaEmbed(Base):
class LmStudioEmbed(LocalAIEmbed):
+ _FACTORY_NAME = "LM-Studio"
+
def __init__(self, key, model_name, base_url):
if not base_url:
raise ValueError("Local llm url cannot be None")
@@ -654,6 +597,8 @@ class LmStudioEmbed(LocalAIEmbed):
class OpenAI_APIEmbed(OpenAIEmbed):
+ _FACTORY_NAME = ["VLLM", "OpenAI-API-Compatible"]
+
def __init__(self, key, model_name, base_url):
if not base_url:
raise ValueError("url cannot be None")
@@ -663,6 +608,8 @@ class OpenAI_APIEmbed(OpenAIEmbed):
class CoHereEmbed(Base):
+ _FACTORY_NAME = "Cohere"
+
def __init__(self, key, model_name, base_url=None):
from cohere import Client
@@ -701,6 +648,8 @@ class CoHereEmbed(Base):
class TogetherAIEmbed(OpenAIEmbed):
+ _FACTORY_NAME = "TogetherAI"
+
def __init__(self, key, model_name, base_url="https://api.together.xyz/v1"):
if not base_url:
base_url = "https://api.together.xyz/v1"
@@ -708,6 +657,8 @@ class TogetherAIEmbed(OpenAIEmbed):
class PerfXCloudEmbed(OpenAIEmbed):
+ _FACTORY_NAME = "PerfXCloud"
+
def __init__(self, key, model_name, base_url="https://cloud.perfxlab.cn/v1"):
if not base_url:
base_url = "https://cloud.perfxlab.cn/v1"
@@ -715,6 +666,8 @@ class PerfXCloudEmbed(OpenAIEmbed):
class UpstageEmbed(OpenAIEmbed):
+ _FACTORY_NAME = "Upstage"
+
def __init__(self, key, model_name, base_url="https://api.upstage.ai/v1/solar"):
if not base_url:
base_url = "https://api.upstage.ai/v1/solar"
@@ -722,6 +675,8 @@ class UpstageEmbed(OpenAIEmbed):
class SILICONFLOWEmbed(Base):
+ _FACTORY_NAME = "SILICONFLOW"
+
def __init__(self, key, model_name, base_url="https://api.siliconflow.cn/v1/embeddings"):
if not base_url:
base_url = "https://api.siliconflow.cn/v1/embeddings"
@@ -769,6 +724,8 @@ class SILICONFLOWEmbed(Base):
class ReplicateEmbed(Base):
+ _FACTORY_NAME = "Replicate"
+
def __init__(self, key, model_name, base_url=None):
from replicate.client import Client
@@ -790,6 +747,8 @@ class ReplicateEmbed(Base):
class BaiduYiyanEmbed(Base):
+ _FACTORY_NAME = "BaiduYiyan"
+
def __init__(self, key, model_name, base_url=None):
import qianfan
@@ -821,6 +780,8 @@ class BaiduYiyanEmbed(Base):
class VoyageEmbed(Base):
+ _FACTORY_NAME = "Voyage AI"
+
def __init__(self, key, model_name, base_url=None):
import voyageai
@@ -832,9 +793,7 @@ class VoyageEmbed(Base):
ress = []
token_count = 0
for i in range(0, len(texts), batch_size):
- res = self.client.embed(
- texts=texts[i : i + batch_size], model=self.model_name, input_type="document"
- )
+ res = self.client.embed(texts=texts[i : i + batch_size], model=self.model_name, input_type="document")
try:
ress.extend(res.embeddings)
token_count += res.total_tokens
@@ -843,9 +802,7 @@ class VoyageEmbed(Base):
return np.array(ress), token_count
def encode_queries(self, text):
- res = self.client.embed(
- texts=text, model=self.model_name, input_type="query"
- )
+ res = self.client.embed(texts=text, model=self.model_name, input_type="query")
try:
return np.array(res.embeddings)[0], res.total_tokens
except Exception as _e:
@@ -853,6 +810,8 @@ class VoyageEmbed(Base):
class HuggingFaceEmbed(Base):
+ _FACTORY_NAME = "HuggingFace"
+
def __init__(self, key, model_name, base_url=None):
if not model_name:
raise ValueError("Model name cannot be None")
@@ -863,11 +822,7 @@ class HuggingFaceEmbed(Base):
def encode(self, texts: list):
embeddings = []
for text in texts:
- response = requests.post(
- f"{self.base_url}/embed",
- json={"inputs": text},
- headers={'Content-Type': 'application/json'}
- )
+ response = requests.post(f"{self.base_url}/embed", json={"inputs": text}, headers={"Content-Type": "application/json"})
if response.status_code == 200:
embedding = response.json()
embeddings.append(embedding[0])
@@ -876,11 +831,7 @@ class HuggingFaceEmbed(Base):
return np.array(embeddings), sum([num_tokens_from_string(text) for text in texts])
def encode_queries(self, text):
- response = requests.post(
- f"{self.base_url}/embed",
- json={"inputs": text},
- headers={'Content-Type': 'application/json'}
- )
+ response = requests.post(f"{self.base_url}/embed", json={"inputs": text}, headers={"Content-Type": "application/json"})
if response.status_code == 200:
embedding = response.json()
return np.array(embedding[0]), num_tokens_from_string(text)
@@ -889,15 +840,19 @@ class HuggingFaceEmbed(Base):
class VolcEngineEmbed(OpenAIEmbed):
+ _FACTORY_NAME = "VolcEngine"
+
def __init__(self, key, model_name, base_url="https://ark.cn-beijing.volces.com/api/v3"):
if not base_url:
base_url = "https://ark.cn-beijing.volces.com/api/v3"
- ark_api_key = json.loads(key).get('ark_api_key', '')
- model_name = json.loads(key).get('ep_id', '') + json.loads(key).get('endpoint_id', '')
- super().__init__(ark_api_key,model_name,base_url)
+ ark_api_key = json.loads(key).get("ark_api_key", "")
+ model_name = json.loads(key).get("ep_id", "") + json.loads(key).get("endpoint_id", "")
+ super().__init__(ark_api_key, model_name, base_url)
class GPUStackEmbed(OpenAIEmbed):
+ _FACTORY_NAME = "GPUStack"
+
def __init__(self, key, model_name, base_url):
if not base_url:
raise ValueError("url cannot be None")
@@ -908,6 +863,8 @@ class GPUStackEmbed(OpenAIEmbed):
class NovitaEmbed(SILICONFLOWEmbed):
+ _FACTORY_NAME = "NovitaAI"
+
def __init__(self, key, model_name, base_url="https://api.novita.ai/v3/openai/embeddings"):
if not base_url:
base_url = "https://api.novita.ai/v3/openai/embeddings"
@@ -915,7 +872,9 @@ class NovitaEmbed(SILICONFLOWEmbed):
class GiteeEmbed(SILICONFLOWEmbed):
+ _FACTORY_NAME = "GiteeAI"
+
def __init__(self, key, model_name, base_url="https://ai.gitee.com/v1/embeddings"):
if not base_url:
base_url = "https://ai.gitee.com/v1/embeddings"
- super().__init__(key, model_name, base_url)
\ No newline at end of file
+ super().__init__(key, model_name, base_url)
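
Every embedding class above now declares a `_FACTORY_NAME` class attribute (a single factory name, or a list of aliases such as `["VLLM", "OpenAI-API-Compatible"]`) in place of the hand-maintained import lists and name-to-class dictionaries. Below is a minimal sketch of how a registry could be assembled from these attributes; the module path and loader shape are illustrative assumptions, not necessarily the package's actual mechanism:

```python
# Hedged sketch: scan a module for classes declaring _FACTORY_NAME and map
# each declared name (a string, or a list of aliases) to the class itself.
import importlib
import inspect


def build_registry(module_path: str) -> dict:
    registry = {}
    module = importlib.import_module(module_path)
    for _, cls in inspect.getmembers(module, inspect.isclass):
        names = getattr(cls, "_FACTORY_NAME", None)
        if not names:
            continue
        for name in names if isinstance(names, list) else [names]:
            registry[name] = cls
    return registry


# Usage sketch (module path and constructor arguments are assumptions):
# EmbeddingModel = build_registry("rag.llm.embedding_model")
# embedder = EmbeddingModel["SILICONFLOW"](key, "BAAI/bge-m3")
```

Under this scheme, registering a new provider is a one-line change on the class itself, and several factory names (for example `VLLM` and `OpenAI-API-Compatible`) can share one implementation without duplicating a mapping entry.
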
diff --git a/rag/llm/rerank_model.py b/rag/llm/rerank_model.py
index fafab7ee0..88fda1478 100644
--- a/rag/llm/rerank_model.py
+++ b/rag/llm/rerank_model.py
@@ -13,24 +13,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
+import json
+import os
import re
import threading
+from abc import ABC
from collections.abc import Iterable
from urllib.parse import urljoin
-import requests
import httpx
-from huggingface_hub import snapshot_download
-import os
-from abc import ABC
import numpy as np
+import requests
+from huggingface_hub import snapshot_download
from yarl import URL
from api import settings
from api.utils.file_utils import get_home_cache_dir
from api.utils.log_utils import log_exception
from rag.utils import num_tokens_from_string, truncate
-import json
def sigmoid(x):
@@ -57,6 +57,7 @@ class Base(ABC):
class DefaultRerank(Base):
+ _FACTORY_NAME = "BAAI"
_model = None
_model_lock = threading.Lock()
@@ -75,17 +76,13 @@ class DefaultRerank(Base):
if not settings.LIGHTEN and not DefaultRerank._model:
import torch
from FlagEmbedding import FlagReranker
+
with DefaultRerank._model_lock:
if not DefaultRerank._model:
try:
- DefaultRerank._model = FlagReranker(
- os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)),
- use_fp16=torch.cuda.is_available())
+ DefaultRerank._model = FlagReranker(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), use_fp16=torch.cuda.is_available())
except Exception:
- model_dir = snapshot_download(repo_id=model_name,
- local_dir=os.path.join(get_home_cache_dir(),
- re.sub(r"^[a-zA-Z0-9]+/", "", model_name)),
- local_dir_use_symlinks=False)
+ model_dir = snapshot_download(repo_id=model_name, local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), local_dir_use_symlinks=False)
DefaultRerank._model = FlagReranker(model_dir, use_fp16=torch.cuda.is_available())
self._model = DefaultRerank._model
self._dynamic_batch_size = 8
@@ -94,6 +91,7 @@ class DefaultRerank(Base):
def torch_empty_cache(self):
try:
import torch
+
torch.cuda.empty_cache()
except Exception as e:
print(f"Error emptying cache: {e}")
@@ -112,7 +110,7 @@ class DefaultRerank(Base):
while retry_count < max_retries:
try:
# call subclass implemented batch processing calculation
- batch_scores = self._compute_batch_scores(pairs[i:i + current_batch])
+ batch_scores = self._compute_batch_scores(pairs[i : i + current_batch])
res.extend(batch_scores)
i += current_batch
self._dynamic_batch_size = min(self._dynamic_batch_size * 2, 8)
@@ -152,23 +150,16 @@ class DefaultRerank(Base):
class JinaRerank(Base):
- def __init__(self, key, model_name="jina-reranker-v2-base-multilingual",
- base_url="https://api.jina.ai/v1/rerank"):
+ _FACTORY_NAME = "Jina"
+
+ def __init__(self, key, model_name="jina-reranker-v2-base-multilingual", base_url="https://api.jina.ai/v1/rerank"):
self.base_url = "https://api.jina.ai/v1/rerank"
- self.headers = {
- "Content-Type": "application/json",
- "Authorization": f"Bearer {key}"
- }
+ self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
self.model_name = model_name
def similarity(self, query: str, texts: list):
texts = [truncate(t, 8196) for t in texts]
- data = {
- "model": self.model_name,
- "query": query,
- "documents": texts,
- "top_n": len(texts)
- }
+ data = {"model": self.model_name, "query": query, "documents": texts, "top_n": len(texts)}
res = requests.post(self.base_url, headers=self.headers, json=data).json()
rank = np.zeros(len(texts), dtype=float)
try:
@@ -180,22 +171,20 @@ class JinaRerank(Base):
class YoudaoRerank(DefaultRerank):
+ _FACTORY_NAME = "Youdao"
_model = None
_model_lock = threading.Lock()
def __init__(self, key=None, model_name="maidalun1020/bce-reranker-base_v1", **kwargs):
if not settings.LIGHTEN and not YoudaoRerank._model:
from BCEmbedding import RerankerModel
+
with YoudaoRerank._model_lock:
if not YoudaoRerank._model:
try:
- YoudaoRerank._model = RerankerModel(model_name_or_path=os.path.join(
- get_home_cache_dir(),
- re.sub(r"^[a-zA-Z0-9]+/", "", model_name)))
+ YoudaoRerank._model = RerankerModel(model_name_or_path=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)))
except Exception:
- YoudaoRerank._model = RerankerModel(
- model_name_or_path=model_name.replace(
- "maidalun1020", "InfiniFlow"))
+ YoudaoRerank._model = RerankerModel(model_name_or_path=model_name.replace("maidalun1020", "InfiniFlow"))
self._model = YoudaoRerank._model
self._dynamic_batch_size = 8
@@ -212,6 +201,8 @@ class YoudaoRerank(DefaultRerank):
class XInferenceRerank(Base):
+ _FACTORY_NAME = "Xinference"
+
def __init__(self, key="x", model_name="", base_url=""):
if base_url.find("/v1") == -1:
base_url = urljoin(base_url, "/v1/rerank")
@@ -219,10 +210,7 @@ class XInferenceRerank(Base):
base_url = urljoin(base_url, "/v1/rerank")
self.model_name = model_name
self.base_url = base_url
- self.headers = {
- "Content-Type": "application/json",
- "accept": "application/json"
- }
+ self.headers = {"Content-Type": "application/json", "accept": "application/json"}
if key and key != "x":
self.headers["Authorization"] = f"Bearer {key}"
@@ -233,13 +221,7 @@ class XInferenceRerank(Base):
token_count = 0
for _, t in pairs:
token_count += num_tokens_from_string(t)
- data = {
- "model": self.model_name,
- "query": query,
- "return_documents": "true",
- "return_len": "true",
- "documents": texts
- }
+ data = {"model": self.model_name, "query": query, "return_documents": "true", "return_len": "true", "documents": texts}
res = requests.post(self.base_url, headers=self.headers, json=data).json()
rank = np.zeros(len(texts), dtype=float)
try:
@@ -251,15 +233,14 @@ class XInferenceRerank(Base):
class LocalAIRerank(Base):
+ _FACTORY_NAME = "LocalAI"
+
def __init__(self, key, model_name, base_url):
if base_url.find("/rerank") == -1:
self.base_url = urljoin(base_url, "/rerank")
else:
self.base_url = base_url
- self.headers = {
- "Content-Type": "application/json",
- "Authorization": f"Bearer {key}"
- }
+ self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
self.model_name = model_name.split("___")[0]
def similarity(self, query: str, texts: list):
@@ -296,16 +277,15 @@ class LocalAIRerank(Base):
class NvidiaRerank(Base):
- def __init__(
- self, key, model_name, base_url="https://ai.api.nvidia.com/v1/retrieval/nvidia/"
- ):
+ _FACTORY_NAME = "NVIDIA"
+
+ def __init__(self, key, model_name, base_url="https://ai.api.nvidia.com/v1/retrieval/nvidia/"):
if not base_url:
base_url = "https://ai.api.nvidia.com/v1/retrieval/nvidia/"
self.model_name = model_name
if self.model_name == "nvidia/nv-rerankqa-mistral-4b-v3":
- self.base_url = urljoin(base_url, "nv-rerankqa-mistral-4b-v3/reranking"
- )
+ self.base_url = urljoin(base_url, "nv-rerankqa-mistral-4b-v3/reranking")
if self.model_name == "nvidia/rerank-qa-mistral-4b":
self.base_url = urljoin(base_url, "reranking")
@@ -318,9 +298,7 @@ class NvidiaRerank(Base):
}
def similarity(self, query: str, texts: list):
- token_count = num_tokens_from_string(query) + sum(
- [num_tokens_from_string(t) for t in texts]
- )
+ token_count = num_tokens_from_string(query) + sum([num_tokens_from_string(t) for t in texts])
data = {
"model": self.model_name,
"query": {"text": query},
@@ -339,6 +317,8 @@ class NvidiaRerank(Base):
class LmStudioRerank(Base):
+ _FACTORY_NAME = "LM-Studio"
+
def __init__(self, key, model_name, base_url):
pass
@@ -347,15 +327,14 @@ class LmStudioRerank(Base):
class OpenAI_APIRerank(Base):
+ _FACTORY_NAME = "OpenAI-API-Compatible"
+
def __init__(self, key, model_name, base_url):
if base_url.find("/rerank") == -1:
self.base_url = urljoin(base_url, "/rerank")
else:
self.base_url = base_url
- self.headers = {
- "Content-Type": "application/json",
- "Authorization": f"Bearer {key}"
- }
+ self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
self.model_name = model_name.split("___")[0]
def similarity(self, query: str, texts: list):
@@ -392,6 +371,8 @@ class OpenAI_APIRerank(Base):
class CoHereRerank(Base):
+ _FACTORY_NAME = ["Cohere", "VLLM"]
+
def __init__(self, key, model_name, base_url=None):
from cohere import Client
@@ -399,9 +380,7 @@ class CoHereRerank(Base):
self.model_name = model_name.split("___")[0]
def similarity(self, query: str, texts: list):
- token_count = num_tokens_from_string(query) + sum(
- [num_tokens_from_string(t) for t in texts]
- )
+ token_count = num_tokens_from_string(query) + sum([num_tokens_from_string(t) for t in texts])
res = self.client.rerank(
model=self.model_name,
query=query,
@@ -419,6 +398,8 @@ class CoHereRerank(Base):
class TogetherAIRerank(Base):
+ _FACTORY_NAME = "TogetherAI"
+
def __init__(self, key, model_name, base_url):
pass
@@ -427,9 +408,9 @@ class TogetherAIRerank(Base):
class SILICONFLOWRerank(Base):
- def __init__(
- self, key, model_name, base_url="https://api.siliconflow.cn/v1/rerank"
- ):
+ _FACTORY_NAME = "SILICONFLOW"
+
+ def __init__(self, key, model_name, base_url="https://api.siliconflow.cn/v1/rerank"):
if not base_url:
base_url = "https://api.siliconflow.cn/v1/rerank"
self.model_name = model_name
@@ -450,9 +431,7 @@ class SILICONFLOWRerank(Base):
"max_chunks_per_doc": 1024,
"overlap_tokens": 80,
}
- response = requests.post(
- self.base_url, json=payload, headers=self.headers
- ).json()
+ response = requests.post(self.base_url, json=payload, headers=self.headers).json()
rank = np.zeros(len(texts), dtype=float)
try:
for d in response["results"]:
@@ -466,6 +445,8 @@ class SILICONFLOWRerank(Base):
class BaiduYiyanRerank(Base):
+ _FACTORY_NAME = "BaiduYiyan"
+
def __init__(self, key, model_name, base_url=None):
from qianfan.resources import Reranker
@@ -492,6 +473,8 @@ class BaiduYiyanRerank(Base):
class VoyageRerank(Base):
+ _FACTORY_NAME = "Voyage AI"
+
def __init__(self, key, model_name, base_url=None):
import voyageai
@@ -502,9 +485,7 @@ class VoyageRerank(Base):
rank = np.zeros(len(texts), dtype=float)
if not texts:
return rank, 0
- res = self.client.rerank(
- query=query, documents=texts, model=self.model_name, top_k=len(texts)
- )
+ res = self.client.rerank(query=query, documents=texts, model=self.model_name, top_k=len(texts))
try:
for r in res.results:
rank[r.index] = r.relevance_score
@@ -514,22 +495,20 @@ class VoyageRerank(Base):
class QWenRerank(Base):
- def __init__(self, key, model_name='gte-rerank', base_url=None, **kwargs):
+ _FACTORY_NAME = "Tongyi-Qianwen"
+
+ def __init__(self, key, model_name="gte-rerank", base_url=None, **kwargs):
import dashscope
+
self.api_key = key
self.model_name = dashscope.TextReRank.Models.gte_rerank if model_name is None else model_name
def similarity(self, query: str, texts: list):
- import dashscope
from http import HTTPStatus
- resp = dashscope.TextReRank.call(
- api_key=self.api_key,
- model=self.model_name,
- query=query,
- documents=texts,
- top_n=len(texts),
- return_documents=False
- )
+
+ import dashscope
+
+ resp = dashscope.TextReRank.call(api_key=self.api_key, model=self.model_name, query=query, documents=texts, top_n=len(texts), return_documents=False)
rank = np.zeros(len(texts), dtype=float)
if resp.status_code == HTTPStatus.OK:
try:
@@ -543,6 +522,8 @@ class QWenRerank(Base):
class HuggingfaceRerank(DefaultRerank):
+ _FACTORY_NAME = "HuggingFace"
+
@staticmethod
def post(query: str, texts: list, url="127.0.0.1"):
exc = None
@@ -550,9 +531,9 @@ class HuggingfaceRerank(DefaultRerank):
batch_size = 8
for i in range(0, len(texts), batch_size):
try:
- res = requests.post(f"http://{url}/rerank", headers={"Content-Type": "application/json"},
- json={"query": query, "texts": texts[i: i + batch_size],
- "raw_scores": False, "truncate": True})
+ res = requests.post(
+ f"http://{url}/rerank", headers={"Content-Type": "application/json"}, json={"query": query, "texts": texts[i : i + batch_size], "raw_scores": False, "truncate": True}
+ )
for o in res.json():
scores[o["index"] + i] = o["score"]
@@ -577,9 +558,9 @@ class HuggingfaceRerank(DefaultRerank):
class GPUStackRerank(Base):
- def __init__(
- self, key, model_name, base_url
- ):
+ _FACTORY_NAME = "GPUStack"
+
+ def __init__(self, key, model_name, base_url):
if not base_url:
raise ValueError("url cannot be None")
@@ -600,9 +581,7 @@ class GPUStackRerank(Base):
}
try:
- response = requests.post(
- self.base_url, json=payload, headers=self.headers
- )
+ response = requests.post(self.base_url, json=payload, headers=self.headers)
response.raise_for_status()
response_json = response.json()
@@ -623,11 +602,12 @@ class GPUStackRerank(Base):
)
except httpx.HTTPStatusError as e:
- raise ValueError(
- f"Error calling GPUStackRerank model {self.model_name}: {e.response.status_code} - {e.response.text}")
+ raise ValueError(f"Error calling GPUStackRerank model {self.model_name}: {e.response.status_code} - {e.response.text}")
class NovitaRerank(JinaRerank):
+ _FACTORY_NAME = "NovitaAI"
+
def __init__(self, key, model_name, base_url="https://api.novita.ai/v3/openai/rerank"):
if not base_url:
base_url = "https://api.novita.ai/v3/openai/rerank"
@@ -635,7 +615,9 @@ class NovitaRerank(JinaRerank):
class GiteeRerank(JinaRerank):
+ _FACTORY_NAME = "GiteeAI"
+
def __init__(self, key, model_name, base_url="https://ai.gitee.com/v1/rerank"):
if not base_url:
base_url = "https://ai.gitee.com/v1/rerank"
- super().__init__(key, model_name, base_url)
\ No newline at end of file
+ super().__init__(key, model_name, base_url)
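
The `DefaultRerank.similarity` loop above scores query/text pairs through `_compute_batch_scores` with a dynamic batch size that is restored toward 8 after every successful batch (`self._dynamic_batch_size = min(self._dynamic_batch_size * 2, 8)`). Only the success path is visible in the hunk, so the shrink-on-failure recovery in this sketch is an assumption about how the retry loop copes with failures such as CUDA out-of-memory errors:

```python
# Hedged sketch of the adaptive batching loop: grow the batch back toward the
# cap after each success, and (assumed) halve it and retry when a batch fails.
def score_in_batches(pairs, compute_batch_scores, start_batch=8, max_batch=8, max_retries=5):
    res, i, batch = [], 0, start_batch
    while i < len(pairs):
        retry_count = 0
        while retry_count < max_retries:
            try:
                res.extend(compute_batch_scores(pairs[i : i + batch]))
                i += batch
                batch = min(batch * 2, max_batch)  # recover batch size after success
                break
            except Exception:
                batch = max(1, batch // 2)  # assumed recovery: shrink and retry
                retry_count += 1
        else:
            raise RuntimeError("batch scoring failed after retries")
    return res
```
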
diff --git a/rag/llm/sequence2txt_model.py b/rag/llm/sequence2txt_model.py
index 193f0e14d..7d6c24b76 100644
--- a/rag/llm/sequence2txt_model.py
+++ b/rag/llm/sequence2txt_model.py
@@ -13,16 +13,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-import os
-import requests
-from openai.lib.azure import AzureOpenAI
-import io
-from abc import ABC
-from openai import OpenAI
-import json
-from rag.utils import num_tokens_from_string
import base64
+import io
+import json
+import os
import re
+from abc import ABC
+
+import requests
+from openai import OpenAI
+from openai.lib.azure import AzureOpenAI
+
+from rag.utils import num_tokens_from_string
class Base(ABC):
@@ -30,11 +32,7 @@ class Base(ABC):
pass
def transcription(self, audio, **kwargs):
- transcription = self.client.audio.transcriptions.create(
- model=self.model_name,
- file=audio,
- response_format="text"
- )
+ transcription = self.client.audio.transcriptions.create(model=self.model_name, file=audio, response_format="text")
return transcription.text.strip(), num_tokens_from_string(transcription.text.strip())
def audio2base64(self, audio):
@@ -46,6 +44,8 @@ class Base(ABC):
class GPTSeq2txt(Base):
+ _FACTORY_NAME = "OpenAI"
+
def __init__(self, key, model_name="whisper-1", base_url="https://api.openai.com/v1"):
if not base_url:
base_url = "https://api.openai.com/v1"
@@ -54,31 +54,34 @@ class GPTSeq2txt(Base):
class QWenSeq2txt(Base):
+ _FACTORY_NAME = "Tongyi-Qianwen"
+
def __init__(self, key, model_name="paraformer-realtime-8k-v1", **kwargs):
import dashscope
+
dashscope.api_key = key
self.model_name = model_name
def transcription(self, audio, format):
from http import HTTPStatus
+
from dashscope.audio.asr import Recognition
- recognition = Recognition(model=self.model_name,
- format=format,
- sample_rate=16000,
- callback=None)
+ recognition = Recognition(model=self.model_name, format=format, sample_rate=16000, callback=None)
result = recognition.call(audio)
ans = ""
if result.status_code == HTTPStatus.OK:
for sentence in result.get_sentence():
- ans += sentence.text.decode('utf-8') + '\n'
+ ans += sentence.text.decode("utf-8") + "\n"
return ans, num_tokens_from_string(ans)
return "**ERROR**: " + result.message, 0
class AzureSeq2txt(Base):
+ _FACTORY_NAME = "Azure-OpenAI"
+
def __init__(self, key, model_name, lang="Chinese", **kwargs):
self.client = AzureOpenAI(api_key=key, azure_endpoint=kwargs["base_url"], api_version="2024-02-01")
self.model_name = model_name
@@ -86,43 +89,33 @@ class AzureSeq2txt(Base):
class XinferenceSeq2txt(Base):
+ _FACTORY_NAME = "Xinference"
+
def __init__(self, key, model_name="whisper-small", **kwargs):
- self.base_url = kwargs.get('base_url', None)
+ self.base_url = kwargs.get("base_url", None)
self.model_name = model_name
self.key = key
def transcription(self, audio, language="zh", prompt=None, response_format="json", temperature=0.7):
if isinstance(audio, str):
- audio_file = open(audio, 'rb')
+ audio_file = open(audio, "rb")
audio_data = audio_file.read()
audio_file_name = audio.split("/")[-1]
else:
audio_data = audio
audio_file_name = "audio.wav"
- payload = {
- "model": self.model_name,
- "language": language,
- "prompt": prompt,
- "response_format": response_format,
- "temperature": temperature
- }
+ payload = {"model": self.model_name, "language": language, "prompt": prompt, "response_format": response_format, "temperature": temperature}
- files = {
- "file": (audio_file_name, audio_data, 'audio/wav')
- }
+ files = {"file": (audio_file_name, audio_data, "audio/wav")}
try:
- response = requests.post(
- f"{self.base_url}/v1/audio/transcriptions",
- files=files,
- data=payload
- )
+ response = requests.post(f"{self.base_url}/v1/audio/transcriptions", files=files, data=payload)
response.raise_for_status()
result = response.json()
- if 'text' in result:
- transcription_text = result['text'].strip()
+ if "text" in result:
+ transcription_text = result["text"].strip()
return transcription_text, num_tokens_from_string(transcription_text)
else:
return "**ERROR**: Failed to retrieve transcription.", 0
@@ -132,11 +125,11 @@ class XinferenceSeq2txt(Base):
class TencentCloudSeq2txt(Base):
- def __init__(
- self, key, model_name="16k_zh", base_url="https://asr.tencentcloudapi.com"
- ):
- from tencentcloud.common import credential
+ _FACTORY_NAME = "Tencent Cloud"
+
+ def __init__(self, key, model_name="16k_zh", base_url="https://asr.tencentcloudapi.com"):
from tencentcloud.asr.v20190614 import asr_client
+ from tencentcloud.common import credential
key = json.loads(key)
sid = key.get("tencent_cloud_sid", "")
@@ -146,11 +139,12 @@ class TencentCloudSeq2txt(Base):
self.model_name = model_name
def transcription(self, audio, max_retries=60, retry_interval=5):
+ import time
+
+ from tencentcloud.asr.v20190614 import models
from tencentcloud.common.exception.tencent_cloud_sdk_exception import (
TencentCloudSDKException,
)
- from tencentcloud.asr.v20190614 import models
- import time
b64 = self.audio2base64(audio)
try:
@@ -174,9 +168,7 @@ class TencentCloudSeq2txt(Base):
while retries < max_retries:
resp = self.client.DescribeTaskStatus(req)
if resp.Data.StatusStr == "success":
- text = re.sub(
- r"\[\d+:\d+\.\d+,\d+:\d+\.\d+\]\s*", "", resp.Data.Result
- ).strip()
+ text = re.sub(r"\[\d+:\d+\.\d+,\d+:\d+\.\d+\]\s*", "", resp.Data.Result).strip()
return text, num_tokens_from_string(text)
elif resp.Data.StatusStr == "failed":
return (
@@ -195,6 +187,8 @@ class TencentCloudSeq2txt(Base):
class GPUStackSeq2txt(Base):
+ _FACTORY_NAME = "GPUStack"
+
def __init__(self, key, model_name, base_url):
if not base_url:
raise ValueError("url cannot be None")
@@ -206,8 +200,11 @@ class GPUStackSeq2txt(Base):
class GiteeSeq2txt(Base):
+ _FACTORY_NAME = "GiteeAI"
+
def __init__(self, key, model_name="whisper-1", base_url="https://ai.gitee.com/v1/"):
if not base_url:
base_url = "https://ai.gitee.com/v1/"
self.client = OpenAI(api_key=key, base_url=base_url)
- self.model_name = model_name
\ No newline at end of file
+ self.model_name = model_name
+
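
`TencentCloudSeq2txt.transcription` above submits an asynchronous recognition task and then polls `DescribeTaskStatus` every `retry_interval` seconds, up to `max_retries` times, stripping timestamp markers from the result on success. A generic sketch of that submit-then-poll shape, with the SDK call abstracted behind a plain callable (the names here are placeholders, not the Tencent Cloud SDK's API):

```python
# Hedged sketch of the submit-then-poll pattern. describe_status is a
# placeholder callable returning (status, result); it stands in for the
# real SDK's DescribeTaskStatus request.
import time


def poll_transcription(describe_status, task_id, max_retries=60, retry_interval=5):
    for _ in range(max_retries):
        status, result = describe_status(task_id)
        if status == "success":
            return result
        if status == "failed":
            raise RuntimeError("transcription task failed")
        time.sleep(retry_interval)
    raise TimeoutError("transcription task did not complete in time")
```
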
diff --git a/rag/llm/tts_model.py b/rag/llm/tts_model.py
index 7111b2f60..b2333b426 100644
--- a/rag/llm/tts_model.py
+++ b/rag/llm/tts_model.py
@@ -70,10 +70,12 @@ class Base(ABC):
pass
def normalize_text(self, text):
- return re.sub(r'(\*\*|##\d+\$\$|#)', '', text)
+ return re.sub(r"(\*\*|##\d+\$\$|#)", "", text)
class FishAudioTTS(Base):
+ _FACTORY_NAME = "Fish Audio"
+
def __init__(self, key, model_name, base_url="https://api.fish.audio/v1/tts"):
if not base_url:
base_url = "https://api.fish.audio/v1/tts"
@@ -94,13 +96,11 @@ class FishAudioTTS(Base):
with httpx.Client() as client:
try:
with client.stream(
- method="POST",
- url=self.base_url,
- content=ormsgpack.packb(
- request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC
- ),
- headers=self.headers,
- timeout=None,
+ method="POST",
+ url=self.base_url,
+ content=ormsgpack.packb(request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC),
+ headers=self.headers,
+ timeout=None,
) as response:
if response.status_code == HTTPStatus.OK:
for chunk in response.iter_bytes():
@@ -115,6 +115,8 @@ class FishAudioTTS(Base):
class QwenTTS(Base):
+ _FACTORY_NAME = "Tongyi-Qianwen"
+
def __init__(self, key, model_name, base_url=""):
import dashscope
@@ -122,10 +124,11 @@ class QwenTTS(Base):
dashscope.api_key = key
def tts(self, text):
- from dashscope.api_entities.dashscope_response import SpeechSynthesisResponse
- from dashscope.audio.tts import ResultCallback, SpeechSynthesizer, SpeechSynthesisResult
from collections import deque
+ from dashscope.api_entities.dashscope_response import SpeechSynthesisResponse
+ from dashscope.audio.tts import ResultCallback, SpeechSynthesisResult, SpeechSynthesizer
+
class Callback(ResultCallback):
def __init__(self) -> None:
self.dque = deque()
@@ -159,10 +162,7 @@ class QwenTTS(Base):
text = self.normalize_text(text)
callback = Callback()
- SpeechSynthesizer.call(model=self.model_name,
- text=text,
- callback=callback,
- format="mp3")
+ SpeechSynthesizer.call(model=self.model_name, text=text, callback=callback, format="mp3")
try:
for data in callback._run():
yield data
@@ -173,24 +173,19 @@ class QwenTTS(Base):
class OpenAITTS(Base):
+ _FACTORY_NAME = "OpenAI"
+
def __init__(self, key, model_name="tts-1", base_url="https://api.openai.com/v1"):
if not base_url:
base_url = "https://api.openai.com/v1"
self.api_key = key
self.model_name = model_name
self.base_url = base_url
- self.headers = {
- "Authorization": f"Bearer {self.api_key}",
- "Content-Type": "application/json"
- }
+ self.headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
def tts(self, text, voice="alloy"):
text = self.normalize_text(text)
- payload = {
- "model": self.model_name,
- "voice": voice,
- "input": text
- }
+ payload = {"model": self.model_name, "voice": voice, "input": text}
response = requests.post(f"{self.base_url}/audio/speech", headers=self.headers, json=payload, stream=True)
@@ -201,7 +196,8 @@ class OpenAITTS(Base):
yield chunk
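
`OpenAITTS.tts`, like the other TTS classes in this file, posts with `stream=True` and yields raw audio chunks as they arrive, so a caller can begin writing or playback before synthesis finishes. A small consumer sketch; the model instance, output path, and voice are illustrative assumptions (`voice="alloy"` matches the `tts` signature above):

```python
# Hedged sketch: drain a streaming tts() generator to disk as chunks arrive.
# The tts_model instance and file name are placeholder assumptions.
def save_tts_stream(tts_model, text, path="speech.mp3", voice="alloy"):
    with open(path, "wb") as out:
        for chunk in tts_model.tts(text, voice=voice):
            out.write(chunk)
```
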
-class SparkTTS:
+class SparkTTS(Base):
+ _FACTORY_NAME = "XunFei Spark"
STATUS_FIRST_FRAME = 0
STATUS_CONTINUE_FRAME = 1
STATUS_LAST_FRAME = 2
@@ -219,29 +215,23 @@ class SparkTTS:
# 生成url
def create_url(self):
- url = 'wss://tts-api.xfyun.cn/v2/tts'
+ url = "wss://tts-api.xfyun.cn/v2/tts"
now = datetime.now()
date = format_date_time(mktime(now.timetuple()))
signature_origin = "host: " + "ws-api.xfyun.cn" + "\n"
signature_origin += "date: " + date + "\n"
signature_origin += "GET " + "/v2/tts " + "HTTP/1.1"
- signature_sha = hmac.new(self.APISecret.encode('utf-8'), signature_origin.encode('utf-8'),
- digestmod=hashlib.sha256).digest()
- signature_sha = base64.b64encode(signature_sha).decode(encoding='utf-8')
- authorization_origin = "api_key=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"" % (
- self.APIKey, "hmac-sha256", "host date request-line", signature_sha)
- authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
- v = {
- "authorization": authorization,
- "date": date,
- "host": "ws-api.xfyun.cn"
- }
- url = url + '?' + urlencode(v)
+ signature_sha = hmac.new(self.APISecret.encode("utf-8"), signature_origin.encode("utf-8"), digestmod=hashlib.sha256).digest()
+ signature_sha = base64.b64encode(signature_sha).decode(encoding="utf-8")
+ authorization_origin = 'api_key="%s", algorithm="%s", headers="%s", signature="%s"' % (self.APIKey, "hmac-sha256", "host date request-line", signature_sha)
+ authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode(encoding="utf-8")
+ v = {"authorization": authorization, "date": date, "host": "ws-api.xfyun.cn"}
+ url = url + "?" + urlencode(v)
return url
def tts(self, text):
BusinessArgs = {"aue": "lame", "sfl": 1, "auf": "audio/L16;rate=16000", "vcn": self.model_name, "tte": "utf8"}
- Data = {"status": 2, "text": base64.b64encode(text.encode('utf-8')).decode('utf-8')}
+ Data = {"status": 2, "text": base64.b64encode(text.encode("utf-8")).decode("utf-8")}
CommonArgs = {"app_id": self.APPID}
audio_queue = self.audio_queue
model_name = self.model_name
@@ -273,9 +263,7 @@ class SparkTTS:
def on_open(self, ws):
def run(*args):
- d = {"common": CommonArgs,
- "business": BusinessArgs,
- "data": Data}
+ d = {"common": CommonArgs, "business": BusinessArgs, "data": Data}
ws.send(json.dumps(d))
thread.start_new_thread(run, ())
@@ -283,44 +271,32 @@ class SparkTTS:
wsUrl = self.create_url()
websocket.enableTrace(False)
a = Callback()
- ws = websocket.WebSocketApp(wsUrl, on_open=a.on_open, on_error=a.on_error, on_close=a.on_close,
- on_message=a.on_message)
+ ws = websocket.WebSocketApp(wsUrl, on_open=a.on_open, on_error=a.on_error, on_close=a.on_close, on_message=a.on_message)
status_code = 0
ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
while True:
audio_chunk = self.audio_queue.get()
if audio_chunk is None:
if status_code == 0:
- raise Exception(
- f"Fail to access model({model_name}) using the provided credentials. **ERROR**: Invalid APPID, API Secret, or API Key.")
+ raise Exception(f"Fail to access model({model_name}) using the provided credentials. **ERROR**: Invalid APPID, API Secret, or API Key.")
else:
break
status_code = 1
yield audio_chunk
-class XinferenceTTS:
+class XinferenceTTS(Base):
+ _FACTORY_NAME = "Xinference"
+
def __init__(self, key, model_name, **kwargs):
self.base_url = kwargs.get("base_url", None)
self.model_name = model_name
- self.headers = {
- "accept": "application/json",
- "Content-Type": "application/json"
- }
+ self.headers = {"accept": "application/json", "Content-Type": "application/json"}
def tts(self, text, voice="中文女", stream=True):
- payload = {
- "model": self.model_name,
- "input": text,
- "voice": voice
- }
+ payload = {"model": self.model_name, "input": text, "voice": voice}
- response = requests.post(
- f"{self.base_url}/v1/audio/speech",
- headers=self.headers,
- json=payload,
- stream=stream
- )
+ response = requests.post(f"{self.base_url}/v1/audio/speech", headers=self.headers, json=payload, stream=stream)
if response.status_code != 200:
raise Exception(f"**Error**: {response.status_code}, {response.text}")
@@ -332,22 +308,16 @@ class XinferenceTTS:
class OllamaTTS(Base):
def __init__(self, key, model_name="ollama-tts", base_url="https://api.ollama.ai/v1"):
- if not base_url:
+ if not base_url:
base_url = "https://api.ollama.ai/v1"
self.model_name = model_name
self.base_url = base_url
- self.headers = {
- "Content-Type": "application/json"
- }
+ self.headers = {"Content-Type": "application/json"}
if key and key != "x":
self.headers["Authorization"] = f"Bear {key}"
def tts(self, text, voice="standard-voice"):
- payload = {
- "model": self.model_name,
- "voice": voice,
- "input": text
- }
+ payload = {"model": self.model_name, "voice": voice, "input": text}
response = requests.post(f"{self.base_url}/audio/tts", headers=self.headers, json=payload, stream=True)
@@ -359,30 +329,19 @@ class OllamaTTS(Base):
yield chunk
-class GPUStackTTS:
+class GPUStackTTS(Base):
+ _FACTORY_NAME = "GPUStack"
+
def __init__(self, key, model_name, **kwargs):
self.base_url = kwargs.get("base_url", None)
self.api_key = key
self.model_name = model_name
- self.headers = {
- "accept": "application/json",
- "Content-Type": "application/json",
- "Authorization": f"Bearer {self.api_key}"
- }
+ self.headers = {"accept": "application/json", "Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
def tts(self, text, voice="Chinese Female", stream=True):
- payload = {
- "model": self.model_name,
- "input": text,
- "voice": voice
- }
+ payload = {"model": self.model_name, "input": text, "voice": voice}
- response = requests.post(
- f"{self.base_url}/v1/audio/speech",
- headers=self.headers,
- json=payload,
- stream=stream
- )
+ response = requests.post(f"{self.base_url}/v1/audio/speech", headers=self.headers, json=payload, stream=stream)
if response.status_code != 200:
raise Exception(f"**Error**: {response.status_code}, {response.text}")
@@ -393,16 +352,15 @@ class GPUStackTTS:
class SILICONFLOWTTS(Base):
+ _FACTORY_NAME = "SILICONFLOW"
+
def __init__(self, key, model_name="FunAudioLLM/CosyVoice2-0.5B", base_url="https://api.siliconflow.cn/v1"):
if not base_url:
base_url = "https://api.siliconflow.cn/v1"
self.api_key = key
self.model_name = model_name
self.base_url = base_url
- self.headers = {
- "Authorization": f"Bearer {self.api_key}",
- "Content-Type": "application/json"
- }
+ self.headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
def tts(self, text, voice="anna"):
text = self.normalize_text(text)
@@ -414,7 +372,7 @@ class SILICONFLOWTTS(Base):
"sample_rate": 123,
"stream": True,
"speed": 1,
- "gain": 0
+ "gain": 0,
}
response = requests.post(f"{self.base_url}/audio/speech", headers=self.headers, json=payload)