diff --git a/rag/llm/__init__.py b/rag/llm/__init__.py index b22d722d0..e9542bbe8 100644 --- a/rag/llm/__init__.py +++ b/rag/llm/__init__.py @@ -15,289 +15,53 @@ # # AFTER UPDATING THIS FILE, PLEASE ENSURE THAT docs/references/supported_models.mdx IS ALSO UPDATED for consistency! # -from .embedding_model import ( - OllamaEmbed, - LocalAIEmbed, - OpenAIEmbed, - AzureEmbed, - XinferenceEmbed, - QWenEmbed, - ZhipuEmbed, - FastEmbed, - YoudaoEmbed, - BaiChuanEmbed, - JinaEmbed, - DefaultEmbedding, - MistralEmbed, - BedrockEmbed, - GeminiEmbed, - NvidiaEmbed, - LmStudioEmbed, - OpenAI_APIEmbed, - CoHereEmbed, - TogetherAIEmbed, - PerfXCloudEmbed, - UpstageEmbed, - SILICONFLOWEmbed, - ReplicateEmbed, - BaiduYiyanEmbed, - VoyageEmbed, - HuggingFaceEmbed, - VolcEngineEmbed, - GPUStackEmbed, - NovitaEmbed, - GiteeEmbed -) -from .chat_model import ( - GptTurbo, - AzureChat, - ZhipuChat, - QWenChat, - OllamaChat, - LocalAIChat, - XinferenceChat, - MoonshotChat, - DeepSeekChat, - VolcEngineChat, - BaiChuanChat, - MiniMaxChat, - MistralChat, - GeminiChat, - BedrockChat, - GroqChat, - OpenRouterChat, - StepFunChat, - NvidiaChat, - LmStudioChat, - OpenAI_APIChat, - CoHereChat, - LeptonAIChat, - TogetherAIChat, - PerfXCloudChat, - UpstageChat, - NovitaAIChat, - SILICONFLOWChat, - PPIOChat, - YiChat, - ReplicateChat, - HunyuanChat, - SparkChat, - BaiduYiyanChat, - AnthropicChat, - GoogleChat, - HuggingFaceChat, - GPUStackChat, - ModelScopeChat, - GiteeChat -) -from .cv_model import ( - GptV4, - AzureGptV4, - OllamaCV, - XinferenceCV, - QWenCV, - Zhipu4V, - LocalCV, - GeminiCV, - OpenRouterCV, - LocalAICV, - NvidiaCV, - LmStudioCV, - StepFunCV, - OpenAI_APICV, - TogetherAICV, - YiCV, - HunyuanCV, - AnthropicCV, - SILICONFLOWCV, - GPUStackCV, - GoogleCV, -) +import importlib +import inspect -from .rerank_model import ( - LocalAIRerank, - DefaultRerank, - JinaRerank, - YoudaoRerank, - XInferenceRerank, - NvidiaRerank, - LmStudioRerank, - OpenAI_APIRerank, - CoHereRerank, - TogetherAIRerank, - SILICONFLOWRerank, - BaiduYiyanRerank, - VoyageRerank, - QWenRerank, - GPUStackRerank, - HuggingfaceRerank, - NovitaRerank, - GiteeRerank -) +ChatModel = globals().get("ChatModel", {}) +CvModel = globals().get("CvModel", {}) +EmbeddingModel = globals().get("EmbeddingModel", {}) +RerankModel = globals().get("RerankModel", {}) +Seq2txtModel = globals().get("Seq2txtModel", {}) +TTSModel = globals().get("TTSModel", {}) -from .sequence2txt_model import ( - GPTSeq2txt, - QWenSeq2txt, - AzureSeq2txt, - XinferenceSeq2txt, - TencentCloudSeq2txt, - GPUStackSeq2txt, - GiteeSeq2txt -) - -from .tts_model import ( - FishAudioTTS, - QwenTTS, - OpenAITTS, - SparkTTS, - XinferenceTTS, - GPUStackTTS, - SILICONFLOWTTS, -) - -EmbeddingModel = { - "Ollama": OllamaEmbed, - "LocalAI": LocalAIEmbed, - "OpenAI": OpenAIEmbed, - "Azure-OpenAI": AzureEmbed, - "Xinference": XinferenceEmbed, - "Tongyi-Qianwen": QWenEmbed, - "ZHIPU-AI": ZhipuEmbed, - "FastEmbed": FastEmbed, - "Youdao": YoudaoEmbed, - "BaiChuan": BaiChuanEmbed, - "Jina": JinaEmbed, - "BAAI": DefaultEmbedding, - "Mistral": MistralEmbed, - "Bedrock": BedrockEmbed, - "Gemini": GeminiEmbed, - "NVIDIA": NvidiaEmbed, - "LM-Studio": LmStudioEmbed, - "OpenAI-API-Compatible": OpenAI_APIEmbed, - "VLLM": OpenAI_APIEmbed, - "Cohere": CoHereEmbed, - "TogetherAI": TogetherAIEmbed, - "PerfXCloud": PerfXCloudEmbed, - "Upstage": UpstageEmbed, - "SILICONFLOW": SILICONFLOWEmbed, - "Replicate": ReplicateEmbed, - "BaiduYiyan": BaiduYiyanEmbed, - "Voyage AI": VoyageEmbed, - "HuggingFace": HuggingFaceEmbed, 
- "VolcEngine": VolcEngineEmbed, - "GPUStack": GPUStackEmbed, - "NovitaAI": NovitaEmbed, - "GiteeAI": GiteeEmbed +MODULE_MAPPING = { + "chat_model": ChatModel, + "cv_model": CvModel, + "embedding_model": EmbeddingModel, + "rerank_model": RerankModel, + "sequence2txt_model": Seq2txtModel, + "tts_model": TTSModel, } -CvModel = { - "OpenAI": GptV4, - "Azure-OpenAI": AzureGptV4, - "Ollama": OllamaCV, - "Xinference": XinferenceCV, - "Tongyi-Qianwen": QWenCV, - "ZHIPU-AI": Zhipu4V, - "Moonshot": LocalCV, - "Gemini": GeminiCV, - "OpenRouter": OpenRouterCV, - "LocalAI": LocalAICV, - "NVIDIA": NvidiaCV, - "LM-Studio": LmStudioCV, - "StepFun": StepFunCV, - "OpenAI-API-Compatible": OpenAI_APICV, - "VLLM": OpenAI_APICV, - "TogetherAI": TogetherAICV, - "01.AI": YiCV, - "Tencent Hunyuan": HunyuanCV, - "Anthropic": AnthropicCV, - "SILICONFLOW": SILICONFLOWCV, - "GPUStack": GPUStackCV, - "Google Cloud": GoogleCV -} +package_name = __name__ -ChatModel = { - "OpenAI": GptTurbo, - "Azure-OpenAI": AzureChat, - "ZHIPU-AI": ZhipuChat, - "Tongyi-Qianwen": QWenChat, - "Ollama": OllamaChat, - "LocalAI": LocalAIChat, - "Xinference": XinferenceChat, - "Moonshot": MoonshotChat, - "DeepSeek": DeepSeekChat, - "VolcEngine": VolcEngineChat, - "BaiChuan": BaiChuanChat, - "MiniMax": MiniMaxChat, - "Mistral": MistralChat, - "Gemini": GeminiChat, - "Bedrock": BedrockChat, - "Groq": GroqChat, - "OpenRouter": OpenRouterChat, - "StepFun": StepFunChat, - "NVIDIA": NvidiaChat, - "LM-Studio": LmStudioChat, - "OpenAI-API-Compatible": OpenAI_APIChat, - "VLLM": OpenAI_APIChat, - "Cohere": CoHereChat, - "LeptonAI": LeptonAIChat, - "TogetherAI": TogetherAIChat, - "PerfXCloud": PerfXCloudChat, - "Upstage": UpstageChat, - "NovitaAI": NovitaAIChat, - "SILICONFLOW": SILICONFLOWChat, - "PPIO": PPIOChat, - "01.AI": YiChat, - "Replicate": ReplicateChat, - "Tencent Hunyuan": HunyuanChat, - "XunFei Spark": SparkChat, - "BaiduYiyan": BaiduYiyanChat, - "Anthropic": AnthropicChat, - "Google Cloud": GoogleChat, - "HuggingFace": HuggingFaceChat, - "GPUStack": GPUStackChat, - "ModelScope":ModelScopeChat, - "GiteeAI": GiteeChat -} +for module_name, mapping_dict in MODULE_MAPPING.items(): + full_module_name = f"{package_name}.{module_name}" + module = importlib.import_module(full_module_name) -RerankModel = { - "LocalAI": LocalAIRerank, - "BAAI": DefaultRerank, - "Jina": JinaRerank, - "Youdao": YoudaoRerank, - "Xinference": XInferenceRerank, - "NVIDIA": NvidiaRerank, - "LM-Studio": LmStudioRerank, - "OpenAI-API-Compatible": OpenAI_APIRerank, - "VLLM": CoHereRerank, - "Cohere": CoHereRerank, - "TogetherAI": TogetherAIRerank, - "SILICONFLOW": SILICONFLOWRerank, - "BaiduYiyan": BaiduYiyanRerank, - "Voyage AI": VoyageRerank, - "Tongyi-Qianwen": QWenRerank, - "GPUStack": GPUStackRerank, - "HuggingFace": HuggingfaceRerank, - "NovitaAI": NovitaRerank, - "GiteeAI": GiteeRerank -} + base_class = None + for name, obj in inspect.getmembers(module): + if inspect.isclass(obj) and name == "Base": + base_class = obj + break + if base_class is None: + continue -Seq2txtModel = { - "OpenAI": GPTSeq2txt, - "Tongyi-Qianwen": QWenSeq2txt, - "Azure-OpenAI": AzureSeq2txt, - "Xinference": XinferenceSeq2txt, - "Tencent Cloud": TencentCloudSeq2txt, - "GPUStack": GPUStackSeq2txt, - "GiteeAI": GiteeSeq2txt -} + for _, obj in inspect.getmembers(module): + if inspect.isclass(obj) and issubclass(obj, base_class) and obj is not base_class and hasattr(obj, "_FACTORY_NAME"): + if isinstance(obj._FACTORY_NAME, list): + for factory_name in obj._FACTORY_NAME: + mapping_dict[factory_name] 
= obj + else: + mapping_dict[obj._FACTORY_NAME] = obj -TTSModel = { - "Fish Audio": FishAudioTTS, - "Tongyi-Qianwen": QwenTTS, - "OpenAI": OpenAITTS, - "XunFei Spark": SparkTTS, - "Xinference": XinferenceTTS, - "GPUStack": GPUStackTTS, - "SILICONFLOW": SILICONFLOWTTS, -} +__all__ = [ + "ChatModel", + "CvModel", + "EmbeddingModel", + "RerankModel", + "Seq2txtModel", + "TTSModel", +] diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py index 020815104..f254ed186 100644 --- a/rag/llm/chat_model.py +++ b/rag/llm/chat_model.py @@ -142,11 +142,7 @@ class Base(ABC): return f"{ERROR_PREFIX}: {error_code} - {str(e)}" def _verbose_tool_use(self, name, args, res): - return "" + json.dumps({ - "name": name, - "args": args, - "result": res - }, ensure_ascii=False, indent=2) + "" + return "" + json.dumps({"name": name, "args": args, "result": res}, ensure_ascii=False, indent=2) + "" def _append_history(self, hist, tool_call, tool_res): hist.append( @@ -191,10 +187,10 @@ class Base(ABC): tk_count = 0 hist = deepcopy(history) # Implement exponential backoff retry strategy - for attempt in range(self.max_retries+1): + for attempt in range(self.max_retries + 1): history = hist try: - for _ in range(self.max_rounds*2): + for _ in range(self.max_rounds * 2): response = self.client.chat.completions.create(model=self.model_name, messages=history, tools=self.tools, **gen_conf) tk_count += self.total_token_count(response) if any([not response.choices, not response.choices[0].message]): @@ -269,7 +265,7 @@ class Base(ABC): for attempt in range(self.max_retries + 1): history = hist try: - for _ in range(self.max_rounds*2): + for _ in range(self.max_rounds * 2): reasoning_start = False response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, **gen_conf) final_tool_calls = {} @@ -430,6 +426,8 @@ class Base(ABC): class GptTurbo(Base): + _FACTORY_NAME = "OpenAI" + def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1", **kwargs): if not base_url: base_url = "https://api.openai.com/v1" @@ -437,6 +435,8 @@ class GptTurbo(Base): class MoonshotChat(Base): + _FACTORY_NAME = "Moonshot" + def __init__(self, key, model_name="moonshot-v1-8k", base_url="https://api.moonshot.cn/v1", **kwargs): if not base_url: base_url = "https://api.moonshot.cn/v1" @@ -444,6 +444,8 @@ class MoonshotChat(Base): class XinferenceChat(Base): + _FACTORY_NAME = "Xinference" + def __init__(self, key=None, model_name="", base_url="", **kwargs): if not base_url: raise ValueError("Local llm url cannot be None") @@ -452,6 +454,8 @@ class XinferenceChat(Base): class HuggingFaceChat(Base): + _FACTORY_NAME = "HuggingFace" + def __init__(self, key=None, model_name="", base_url="", **kwargs): if not base_url: raise ValueError("Local llm url cannot be None") @@ -460,6 +464,8 @@ class HuggingFaceChat(Base): class ModelScopeChat(Base): + _FACTORY_NAME = "ModelScope" + def __init__(self, key=None, model_name="", base_url="", **kwargs): if not base_url: raise ValueError("Local llm url cannot be None") @@ -468,6 +474,8 @@ class ModelScopeChat(Base): class DeepSeekChat(Base): + _FACTORY_NAME = "DeepSeek" + def __init__(self, key, model_name="deepseek-chat", base_url="https://api.deepseek.com/v1", **kwargs): if not base_url: base_url = "https://api.deepseek.com/v1" @@ -475,6 +483,8 @@ class DeepSeekChat(Base): class AzureChat(Base): + _FACTORY_NAME = "Azure-OpenAI" + def __init__(self, key, model_name, base_url, **kwargs): api_key = json.loads(key).get("api_key", "") 
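The centerpiece of this change is the registration loop that replaces the hand-maintained factory dicts in `rag/llm/__init__.py`: each provider class tags itself with `_FACTORY_NAME`, and the package scans every model module for subclasses of that module's `Base` class. A minimal, self-contained sketch of the same pattern — `Base`, the stub provider classes, and `build_registry` are illustrative stand-ins, not RAGFlow's real API:

```python
import inspect
import sys
from abc import ABC


class Base(ABC):
    """Stand-in for the per-module base class the loop searches for."""


class OpenAICompatibleStub(Base):
    # A class may register under several names, as OpenAI_APIChat does
    # with ["VLLM", "OpenAI-API-Compatible"] in this PR.
    _FACTORY_NAME = ["VLLM", "OpenAI-API-Compatible"]


class OllamaStub(Base):
    _FACTORY_NAME = "Ollama"


def build_registry(module):
    """Map every concrete subclass of the module's Base that carries a
    _FACTORY_NAME into a factory dict, as the new __init__.py does."""
    base = getattr(module, "Base", None)
    registry = {}
    if base is None:
        return registry
    for _, obj in inspect.getmembers(module, inspect.isclass):
        if issubclass(obj, base) and obj is not base and hasattr(obj, "_FACTORY_NAME"):
            names = obj._FACTORY_NAME
            for factory_name in names if isinstance(names, list) else [names]:
                registry[factory_name] = obj
    return registry


registry = build_registry(sys.modules[__name__])
assert registry["Ollama"] is OllamaStub
assert registry["OpenAI-API-Compatible"] is OpenAICompatibleStub
```

The `ChatModel = globals().get("ChatModel", {})` guards at the top of the new file appear to be defensive: on a `importlib.reload` they keep already-populated dicts rather than resetting them.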
api_version = json.loads(key).get("api_version", "2024-02-01") @@ -484,6 +494,8 @@ class AzureChat(Base): class BaiChuanChat(Base): + _FACTORY_NAME = "BaiChuan" + def __init__(self, key, model_name="Baichuan3-Turbo", base_url="https://api.baichuan-ai.com/v1", **kwargs): if not base_url: base_url = "https://api.baichuan-ai.com/v1" @@ -557,6 +569,8 @@ class BaiChuanChat(Base): class QWenChat(Base): + _FACTORY_NAME = "Tongyi-Qianwen" + def __init__(self, key, model_name=Generation.Models.qwen_turbo, base_url=None, **kwargs): if not base_url: base_url = "https://dashscope.aliyuncs.com/compatible-mode/v1" @@ -565,6 +579,8 @@ class QWenChat(Base): class ZhipuChat(Base): + _FACTORY_NAME = "ZHIPU-AI" + def __init__(self, key, model_name="glm-3-turbo", base_url=None, **kwargs): super().__init__(key, model_name, base_url=base_url, **kwargs) @@ -630,6 +646,8 @@ class ZhipuChat(Base): class OllamaChat(Base): + _FACTORY_NAME = "Ollama" + def __init__(self, key, model_name, base_url=None, **kwargs): super().__init__(key, model_name, base_url=base_url, **kwargs) @@ -694,6 +712,8 @@ class OllamaChat(Base): class LocalAIChat(Base): + _FACTORY_NAME = "LocalAI" + def __init__(self, key, model_name, base_url=None, **kwargs): super().__init__(key, model_name, base_url=base_url, **kwargs) @@ -752,6 +772,8 @@ class LocalLLM(Base): class VolcEngineChat(Base): + _FACTORY_NAME = "VolcEngine" + def __init__(self, key, model_name, base_url="https://ark.cn-beijing.volces.com/api/v3", **kwargs): """ Since do not want to modify the original database fields, and the VolcEngine authentication method is quite special, @@ -765,6 +787,8 @@ class VolcEngineChat(Base): class MiniMaxChat(Base): + _FACTORY_NAME = "MiniMax" + def __init__(self, key, model_name, base_url="https://api.minimax.chat/v1/text/chatcompletion_v2", **kwargs): super().__init__(key, model_name, base_url=base_url, **kwargs) @@ -843,6 +867,8 @@ class MiniMaxChat(Base): class MistralChat(Base): + _FACTORY_NAME = "Mistral" + def __init__(self, key, model_name, base_url=None, **kwargs): super().__init__(key, model_name, base_url=base_url, **kwargs) @@ -896,6 +922,8 @@ class MistralChat(Base): class BedrockChat(Base): + _FACTORY_NAME = "Bedrock" + def __init__(self, key, model_name, base_url=None, **kwargs): super().__init__(key, model_name, base_url=base_url, **kwargs) @@ -978,6 +1006,8 @@ class BedrockChat(Base): class GeminiChat(Base): + _FACTORY_NAME = "Gemini" + def __init__(self, key, model_name, base_url=None, **kwargs): super().__init__(key, model_name, base_url=base_url, **kwargs) @@ -997,6 +1027,7 @@ class GeminiChat(Base): def _chat(self, history, gen_conf): from google.generativeai.types import content_types + system = history[0]["content"] if history and history[0]["role"] == "system" else "" hist = [] for item in history: @@ -1019,6 +1050,7 @@ class GeminiChat(Base): def chat_streamly(self, system, history, gen_conf): from google.generativeai.types import content_types + gen_conf = self._clean_conf(gen_conf) if system: self.model._system_instruction = content_types.to_content(system) @@ -1042,6 +1074,8 @@ class GeminiChat(Base): class GroqChat(Base): + _FACTORY_NAME = "Groq" + def __init__(self, key, model_name, base_url=None, **kwargs): super().__init__(key, model_name, base_url=base_url, **kwargs) @@ -1086,6 +1120,8 @@ class GroqChat(Base): ## openrouter class OpenRouterChat(Base): + _FACTORY_NAME = "OpenRouter" + def __init__(self, key, model_name, base_url="https://openrouter.ai/api/v1", **kwargs): if not base_url: base_url = 
"https://openrouter.ai/api/v1" @@ -1093,6 +1129,8 @@ class OpenRouterChat(Base): class StepFunChat(Base): + _FACTORY_NAME = "StepFun" + def __init__(self, key, model_name, base_url="https://api.stepfun.com/v1", **kwargs): if not base_url: base_url = "https://api.stepfun.com/v1" @@ -1100,6 +1138,8 @@ class StepFunChat(Base): class NvidiaChat(Base): + _FACTORY_NAME = "NVIDIA" + def __init__(self, key, model_name, base_url="https://integrate.api.nvidia.com/v1", **kwargs): if not base_url: base_url = "https://integrate.api.nvidia.com/v1" @@ -1107,6 +1147,8 @@ class NvidiaChat(Base): class LmStudioChat(Base): + _FACTORY_NAME = "LM-Studio" + def __init__(self, key, model_name, base_url, **kwargs): if not base_url: raise ValueError("Local llm url cannot be None") @@ -1117,6 +1159,8 @@ class LmStudioChat(Base): class OpenAI_APIChat(Base): + _FACTORY_NAME = ["VLLM", "OpenAI-API-Compatible"] + def __init__(self, key, model_name, base_url): if not base_url: raise ValueError("url cannot be None") @@ -1125,6 +1169,8 @@ class OpenAI_APIChat(Base): class PPIOChat(Base): + _FACTORY_NAME = "PPIO" + def __init__(self, key, model_name, base_url="https://api.ppinfra.com/v3/openai", **kwargs): if not base_url: base_url = "https://api.ppinfra.com/v3/openai" @@ -1132,6 +1178,8 @@ class PPIOChat(Base): class CoHereChat(Base): + _FACTORY_NAME = "Cohere" + def __init__(self, key, model_name, base_url=None, **kwargs): super().__init__(key, model_name, base_url=base_url, **kwargs) @@ -1207,6 +1255,8 @@ class CoHereChat(Base): class LeptonAIChat(Base): + _FACTORY_NAME = "LeptonAI" + def __init__(self, key, model_name, base_url=None, **kwargs): if not base_url: base_url = urljoin("https://" + model_name + ".lepton.run", "api/v1") @@ -1214,6 +1264,8 @@ class LeptonAIChat(Base): class TogetherAIChat(Base): + _FACTORY_NAME = "TogetherAI" + def __init__(self, key, model_name, base_url="https://api.together.xyz/v1", **kwargs): if not base_url: base_url = "https://api.together.xyz/v1" @@ -1221,6 +1273,8 @@ class TogetherAIChat(Base): class PerfXCloudChat(Base): + _FACTORY_NAME = "PerfXCloud" + def __init__(self, key, model_name, base_url="https://cloud.perfxlab.cn/v1", **kwargs): if not base_url: base_url = "https://cloud.perfxlab.cn/v1" @@ -1228,6 +1282,8 @@ class PerfXCloudChat(Base): class UpstageChat(Base): + _FACTORY_NAME = "Upstage" + def __init__(self, key, model_name, base_url="https://api.upstage.ai/v1/solar", **kwargs): if not base_url: base_url = "https://api.upstage.ai/v1/solar" @@ -1235,6 +1291,8 @@ class UpstageChat(Base): class NovitaAIChat(Base): + _FACTORY_NAME = "NovitaAI" + def __init__(self, key, model_name, base_url="https://api.novita.ai/v3/openai", **kwargs): if not base_url: base_url = "https://api.novita.ai/v3/openai" @@ -1242,6 +1300,8 @@ class NovitaAIChat(Base): class SILICONFLOWChat(Base): + _FACTORY_NAME = "SILICONFLOW" + def __init__(self, key, model_name, base_url="https://api.siliconflow.cn/v1", **kwargs): if not base_url: base_url = "https://api.siliconflow.cn/v1" @@ -1249,6 +1309,8 @@ class SILICONFLOWChat(Base): class YiChat(Base): + _FACTORY_NAME = "01.AI" + def __init__(self, key, model_name, base_url="https://api.lingyiwanwu.com/v1", **kwargs): if not base_url: base_url = "https://api.lingyiwanwu.com/v1" @@ -1256,6 +1318,8 @@ class YiChat(Base): class GiteeChat(Base): + _FACTORY_NAME = "GiteeAI" + def __init__(self, key, model_name, base_url="https://ai.gitee.com/v1/", **kwargs): if not base_url: base_url = "https://ai.gitee.com/v1/" @@ -1263,6 +1327,8 @@ class GiteeChat(Base): class 
ReplicateChat(Base): + _FACTORY_NAME = "Replicate" + def __init__(self, key, model_name, base_url=None, **kwargs): super().__init__(key, model_name, base_url=base_url, **kwargs) @@ -1302,6 +1368,8 @@ class ReplicateChat(Base): class HunyuanChat(Base): + _FACTORY_NAME = "Tencent Hunyuan" + def __init__(self, key, model_name, base_url=None, **kwargs): super().__init__(key, model_name, base_url=base_url, **kwargs) @@ -1378,6 +1446,8 @@ class HunyuanChat(Base): class SparkChat(Base): + _FACTORY_NAME = "XunFei Spark" + def __init__(self, key, model_name, base_url="https://spark-api-open.xf-yun.com/v1", **kwargs): if not base_url: base_url = "https://spark-api-open.xf-yun.com/v1" @@ -1398,6 +1468,8 @@ class SparkChat(Base): class BaiduYiyanChat(Base): + _FACTORY_NAME = "BaiduYiyan" + def __init__(self, key, model_name, base_url=None, **kwargs): super().__init__(key, model_name, base_url=base_url, **kwargs) @@ -1444,6 +1516,8 @@ class BaiduYiyanChat(Base): class AnthropicChat(Base): + _FACTORY_NAME = "Anthropic" + def __init__(self, key, model_name, base_url="https://api.anthropic.com/v1/", **kwargs): if not base_url: base_url = "https://api.anthropic.com/v1/" @@ -1451,6 +1525,8 @@ class AnthropicChat(Base): class GoogleChat(Base): + _FACTORY_NAME = "Google Cloud" + def __init__(self, key, model_name, base_url=None, **kwargs): super().__init__(key, model_name, base_url=base_url, **kwargs) @@ -1529,9 +1605,11 @@ class GoogleChat(Base): if "role" in item and item["role"] == "assistant": item["role"] = "model" if "content" in item: - item["parts"] = [{ - "text": item.pop("content"), - }] + item["parts"] = [ + { + "text": item.pop("content"), + } + ] response = self.client.generate_content(hist, generation_config=gen_conf) ans = response.text @@ -1587,8 +1665,10 @@ class GoogleChat(Base): class GPUStackChat(Base): + _FACTORY_NAME = "GPUStack" + def __init__(self, key=None, model_name="", base_url="", **kwargs): if not base_url: raise ValueError("Local llm url cannot be None") base_url = urljoin(base_url, "v1") - super().__init__(key, model_name, base_url, **kwargs) \ No newline at end of file + super().__init__(key, model_name, base_url, **kwargs) diff --git a/rag/llm/cv_model.py b/rag/llm/cv_model.py index da7efaef1..afa39d69a 100644 --- a/rag/llm/cv_model.py +++ b/rag/llm/cv_model.py @@ -57,7 +57,7 @@ class Base(ABC): model=self.model_name, messages=history, temperature=gen_conf.get("temperature", 0.3), - top_p=gen_conf.get("top_p", 0.7) + top_p=gen_conf.get("top_p", 0.7), ) return response.choices[0].message.content.strip(), response.usage.total_tokens except Exception as e: @@ -79,7 +79,7 @@ class Base(ABC): messages=history, temperature=gen_conf.get("temperature", 0.3), top_p=gen_conf.get("top_p", 0.7), - stream=True + stream=True, ) for resp in response: if not resp.choices[0].delta.content: @@ -87,8 +87,7 @@ class Base(ABC): delta = resp.choices[0].delta.content ans += delta if resp.choices[0].finish_reason == "length": - ans += "...\nFor the content length reason, it stopped, continue?" if is_english( - [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?" + ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?" 
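Most provider additions in `chat_model.py` follow one shape: subclass `Base` (an OpenAI-compatible client wrapper), tag the class with `_FACTORY_NAME`, and fall back to the provider's public endpoint when `base_url` is empty. A hedged sketch of that shape with an invented provider — every name and URL below is a placeholder:

```python
class Base:  # stub for rag.llm.chat_model.Base, an OpenAI-compatible wrapper
    def __init__(self, key, model_name, base_url=None, **kwargs):
        self.key, self.model_name, self.base_url = key, model_name, base_url


class ExampleProviderChat(Base):
    # The factory name becomes a key in ChatModel; this provider, its
    # default model, and its endpoint are all invented for illustration.
    _FACTORY_NAME = "ExampleProvider"

    def __init__(self, key, model_name="example-chat-v1", base_url="https://api.example.com/v1", **kwargs):
        # Same fallback convention the PR's hosted providers use: an
        # empty base_url resolves to the public endpoint.
        if not base_url:
            base_url = "https://api.example.com/v1"
        super().__init__(key, model_name, base_url, **kwargs)


chat = ExampleProviderChat(key="sk-test", model_name="example-chat-v1", base_url="")
assert chat.base_url == "https://api.example.com/v1"
```

Locally hosted backends (`XinferenceChat`, `LmStudioChat`, `GPUStackChat`) invert the convention and raise `ValueError` when no URL is given, since there is no public default to fall back to.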
tk_count = resp.usage.total_tokens if resp.choices[0].finish_reason == "stop": tk_count = resp.usage.total_tokens @@ -117,13 +116,12 @@ class Base(ABC): "content": [ { "type": "image_url", - "image_url": { - "url": f"data:image/jpeg;base64,{b64}" - }, + "image_url": {"url": f"data:image/jpeg;base64,{b64}"}, }, { - "text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else - "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.", + "text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" + if self.lang.lower() == "chinese" + else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.", }, ], } @@ -136,9 +134,7 @@ class Base(ABC): "content": [ { "type": "image_url", - "image_url": { - "url": f"data:image/jpeg;base64,{b64}" - }, + "image_url": {"url": f"data:image/jpeg;base64,{b64}"}, }, { "type": "text", @@ -156,14 +152,13 @@ class Base(ABC): "url": f"data:image/jpeg;base64,{b64}", }, }, - { - "type": "text", - "text": text - }, + {"type": "text", "text": text}, ] class GptV4(Base): + _FACTORY_NAME = "OpenAI" + def __init__(self, key, model_name="gpt-4-vision-preview", lang="Chinese", base_url="https://api.openai.com/v1"): if not base_url: base_url = "https://api.openai.com/v1" @@ -181,7 +176,7 @@ class GptV4(Base): res = self.client.chat.completions.create( model=self.model_name, - messages=prompt + messages=prompt, ) return res.choices[0].message.content.strip(), res.usage.total_tokens @@ -197,9 +192,11 @@ class GptV4(Base): class AzureGptV4(Base): + _FACTORY_NAME = "Azure-OpenAI" + def __init__(self, key, model_name, lang="Chinese", **kwargs): - api_key = json.loads(key).get('api_key', '') - api_version = json.loads(key).get('api_version', '2024-02-01') + api_key = json.loads(key).get("api_key", "") + api_version = json.loads(key).get("api_version", "2024-02-01") self.client = AzureOpenAI(api_key=api_key, azure_endpoint=kwargs["base_url"], api_version=api_version) self.model_name = model_name self.lang = lang @@ -212,10 +209,7 @@ class AzureGptV4(Base): if "text" in c: c["type"] = "text" - res = self.client.chat.completions.create( - model=self.model_name, - messages=prompt - ) + res = self.client.chat.completions.create(model=self.model_name, messages=prompt) return res.choices[0].message.content.strip(), res.usage.total_tokens def describe_with_prompt(self, image, prompt=None): @@ -230,8 +224,11 @@ class AzureGptV4(Base): class QWenCV(Base): + _FACTORY_NAME = "Tongyi-Qianwen" + def __init__(self, key, model_name="qwen-vl-chat-v1", lang="Chinese", **kwargs): import dashscope + dashscope.api_key = key self.model_name = model_name self.lang = lang @@ -247,12 +244,11 @@ class QWenCV(Base): { "role": "user", "content": [ + {"image": f"file://{path}"}, { - "image": f"file://{path}" - }, - { - "text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else - "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.", + "text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" + if self.lang.lower() == "chinese" + else "Please describe the content of this picture, like where, when, who, what happen. 
If it has number data, please extract them out.", }, ], } @@ -270,11 +266,9 @@ class QWenCV(Base): { "role": "user", "content": [ + {"image": f"file://{path}"}, { - "image": f"file://{path}" - }, - { - "text": prompt if prompt else vision_llm_describe_prompt(), + "text": prompt if prompt else vision_llm_describe_prompt(), }, ], } @@ -290,9 +284,10 @@ class QWenCV(Base): from http import HTTPStatus from dashscope import MultiModalConversation + response = MultiModalConversation.call(model=self.model_name, messages=self.prompt(image)) if response.status_code == HTTPStatus.OK: - return response.output.choices[0]['message']['content'][0]["text"], response.usage.output_tokens + return response.output.choices[0]["message"]["content"][0]["text"], response.usage.output_tokens return response.message, 0 def describe_with_prompt(self, image, prompt=None): @@ -303,33 +298,36 @@ class QWenCV(Base): vision_prompt = self.vision_llm_prompt(image, prompt) if prompt else self.vision_llm_prompt(image) response = MultiModalConversation.call(model=self.model_name, messages=vision_prompt) if response.status_code == HTTPStatus.OK: - return response.output.choices[0]['message']['content'][0]["text"], response.usage.output_tokens + return response.output.choices[0]["message"]["content"][0]["text"], response.usage.output_tokens return response.message, 0 def chat(self, system, history, gen_conf, image=""): from http import HTTPStatus from dashscope import MultiModalConversation + if system: history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"] for his in history: if his["role"] == "user": his["content"] = self.chat_prompt(his["content"], image) - response = MultiModalConversation.call(model=self.model_name, messages=history, - temperature=gen_conf.get("temperature", 0.3), - top_p=gen_conf.get("top_p", 0.7)) + response = MultiModalConversation.call( + model=self.model_name, + messages=history, + temperature=gen_conf.get("temperature", 0.3), + top_p=gen_conf.get("top_p", 0.7), + ) ans = "" tk_count = 0 if response.status_code == HTTPStatus.OK: - ans = response.output.choices[0]['message']['content'] + ans = response.output.choices[0]["message"]["content"] if isinstance(ans, list): ans = ans[0]["text"] if ans else "" tk_count += response.usage.total_tokens if response.output.choices[0].get("finish_reason", "") == "length": - ans += "...\nFor the content length reason, it stopped, continue?" if is_english( - [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?" + ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?" 
return ans, tk_count return "**ERROR**: " + response.message, tk_count @@ -338,6 +336,7 @@ class QWenCV(Base): from http import HTTPStatus from dashscope import MultiModalConversation + if system: history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"] @@ -348,24 +347,25 @@ class QWenCV(Base): ans = "" tk_count = 0 try: - response = MultiModalConversation.call(model=self.model_name, messages=history, - temperature=gen_conf.get("temperature", 0.3), - top_p=gen_conf.get("top_p", 0.7), - stream=True) + response = MultiModalConversation.call( + model=self.model_name, + messages=history, + temperature=gen_conf.get("temperature", 0.3), + top_p=gen_conf.get("top_p", 0.7), + stream=True, + ) for resp in response: if resp.status_code == HTTPStatus.OK: - cnt = resp.output.choices[0]['message']['content'] + cnt = resp.output.choices[0]["message"]["content"] if isinstance(cnt, list): cnt = cnt[0]["text"] if ans else "" ans += cnt tk_count = resp.usage.total_tokens if resp.output.choices[0].get("finish_reason", "") == "length": - ans += "...\nFor the content length reason, it stopped, continue?" if is_english( - [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?" + ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?" yield ans else: - yield ans + "\n**ERROR**: " + resp.message if str(resp.message).find( - "Access") < 0 else "Out of credit. Please set the API key in **settings > Model providers.**" + yield ans + "\n**ERROR**: " + resp.message if str(resp.message).find("Access") < 0 else "Out of credit. Please set the API key in **settings > Model providers.**" except Exception as e: yield ans + "\n**ERROR**: " + str(e) @@ -373,6 +373,8 @@ class QWenCV(Base): class Zhipu4V(Base): + _FACTORY_NAME = "ZHIPU-AI" + def __init__(self, key, model_name="glm-4v", lang="Chinese", **kwargs): self.client = ZhipuAI(api_key=key) self.model_name = model_name @@ -394,10 +396,7 @@ class Zhipu4V(Base): b64 = self.image2base64(image) vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64) - res = self.client.chat.completions.create( - model=self.model_name, - messages=vision_prompt - ) + res = self.client.chat.completions.create(model=self.model_name, messages=vision_prompt) return res.choices[0].message.content.strip(), res.usage.total_tokens def chat(self, system, history, gen_conf, image=""): @@ -412,7 +411,7 @@ class Zhipu4V(Base): model=self.model_name, messages=history, temperature=gen_conf.get("temperature", 0.3), - top_p=gen_conf.get("top_p", 0.7) + top_p=gen_conf.get("top_p", 0.7), ) return response.choices[0].message.content.strip(), response.usage.total_tokens except Exception as e: @@ -434,7 +433,7 @@ class Zhipu4V(Base): messages=history, temperature=gen_conf.get("temperature", 0.3), top_p=gen_conf.get("top_p", 0.7), - stream=True + stream=True, ) for resp in response: if not resp.choices[0].delta.content: @@ -442,8 +441,7 @@ class Zhipu4V(Base): delta = resp.choices[0].delta.content ans += delta if resp.choices[0].finish_reason == "length": - ans += "...\nFor the content length reason, it stopped, continue?" if is_english( - [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?" + ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?" 
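The string the diff keeps reflowing ("For the content length reason, it stopped, continue?") belongs to a streaming loop several providers repeat: accumulate deltas and, when generation stops on `finish_reason == "length"`, append a bilingual continuation notice. A condensed sketch of that shared logic; the `chunks` tuples stand in for the SDK-specific response objects:

```python
def stream_answer(chunks, lang_is_english=True):
    """Mirrors the chat_streamly generators touched here: yield the
    running answer per chunk, then the token count last."""
    ans, tokens = "", 0
    for delta, finish_reason, total_tokens in chunks:
        if delta:
            ans += delta
        if finish_reason == "length":
            # Bilingual truncation notice, selected by is_english() upstream.
            ans += ("...\nFor the content length reason, it stopped, continue?"
                    if lang_is_english else "······\n由于长度的原因,回答被截断了,要继续吗?")
        if finish_reason in ("length", "stop"):
            tokens = total_tokens
        yield ans
    yield tokens


chunks = [("Hel", None, 0), ("lo", "stop", 5)]
*answers, used = stream_answer(chunks)
assert answers[-1] == "Hello" and used == 5
```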
tk_count = resp.usage.total_tokens if resp.choices[0].finish_reason == "stop": tk_count = resp.usage.total_tokens @@ -455,6 +453,8 @@ class Zhipu4V(Base): class OllamaCV(Base): + _FACTORY_NAME = "Ollama" + def __init__(self, key, model_name, lang="Chinese", **kwargs): self.client = Client(host=kwargs["base_url"]) self.model_name = model_name @@ -466,7 +466,7 @@ class OllamaCV(Base): response = self.client.generate( model=self.model_name, prompt=prompt[0]["content"][1]["text"], - images=[image] + images=[image], ) ans = response["response"].strip() return ans, 128 @@ -507,7 +507,7 @@ class OllamaCV(Base): model=self.model_name, messages=history, options=options, - keep_alive=-1 + keep_alive=-1, ) ans = response["message"]["content"].strip() @@ -538,7 +538,7 @@ class OllamaCV(Base): messages=history, stream=True, options=options, - keep_alive=-1 + keep_alive=-1, ) for resp in response: if resp["done"]: @@ -551,6 +551,8 @@ class OllamaCV(Base): class LocalAICV(GptV4): + _FACTORY_NAME = "LocalAI" + def __init__(self, key, model_name, base_url, lang="Chinese"): if not base_url: raise ValueError("Local cv model url cannot be None") @@ -561,6 +563,8 @@ class LocalAICV(GptV4): class XinferenceCV(Base): + _FACTORY_NAME = "Xinference" + def __init__(self, key, model_name="", lang="Chinese", base_url=""): base_url = urljoin(base_url, "v1") self.client = OpenAI(api_key=key, base_url=base_url) @@ -570,10 +574,7 @@ class XinferenceCV(Base): def describe(self, image): b64 = self.image2base64(image) - res = self.client.chat.completions.create( - model=self.model_name, - messages=self.prompt(b64) - ) + res = self.client.chat.completions.create(model=self.model_name, messages=self.prompt(b64)) return res.choices[0].message.content.strip(), res.usage.total_tokens def describe_with_prompt(self, image, prompt=None): @@ -588,8 +589,11 @@ class XinferenceCV(Base): class GeminiCV(Base): + _FACTORY_NAME = "Gemini" + def __init__(self, key, model_name="gemini-1.0-pro-vision-latest", lang="Chinese", **kwargs): from google.generativeai import GenerativeModel, client + client.configure(api_key=key) _client = client.get_default_generative_client() self.model_name = model_name @@ -599,18 +603,21 @@ class GeminiCV(Base): def describe(self, image): from PIL.Image import open - prompt = "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else \ - "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out." + + prompt = ( + "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" + if self.lang.lower() == "chinese" + else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out." 
+ ) b64 = self.image2base64(image) img = open(BytesIO(base64.b64decode(b64))) input = [prompt, img] - res = self.model.generate_content( - input - ) + res = self.model.generate_content(input) return res.text, res.usage_metadata.total_token_count def describe_with_prompt(self, image, prompt=None): from PIL.Image import open + b64 = self.image2base64(image) vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64) img = open(BytesIO(base64.b64decode(b64))) @@ -622,6 +629,7 @@ class GeminiCV(Base): def chat(self, system, history, gen_conf, image=""): from transformers import GenerationConfig + if system: history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"] try: @@ -635,9 +643,7 @@ class GeminiCV(Base): his.pop("content") history[-1]["parts"].append("data:image/jpeg;base64," + image) - response = self.model.generate_content(history, generation_config=GenerationConfig( - temperature=gen_conf.get("temperature", 0.3), - top_p=gen_conf.get("top_p", 0.7))) + response = self.model.generate_content(history, generation_config=GenerationConfig(temperature=gen_conf.get("temperature", 0.3), top_p=gen_conf.get("top_p", 0.7))) ans = response.text return ans, response.usage_metadata.total_token_count @@ -646,6 +652,7 @@ class GeminiCV(Base): def chat_streamly(self, system, history, gen_conf, image=""): from transformers import GenerationConfig + if system: history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"] @@ -661,9 +668,11 @@ class GeminiCV(Base): his.pop("content") history[-1]["parts"].append("data:image/jpeg;base64," + image) - response = self.model.generate_content(history, generation_config=GenerationConfig( - temperature=gen_conf.get("temperature", 0.3), - top_p=gen_conf.get("top_p", 0.7)), stream=True) + response = self.model.generate_content( + history, + generation_config=GenerationConfig(temperature=gen_conf.get("temperature", 0.3), top_p=gen_conf.get("top_p", 0.7)), + stream=True, + ) for resp in response: if not resp.text: @@ -677,6 +686,8 @@ class GeminiCV(Base): class OpenRouterCV(GptV4): + _FACTORY_NAME = "OpenRouter" + def __init__( self, key, @@ -692,6 +703,8 @@ class OpenRouterCV(GptV4): class LocalCV(Base): + _FACTORY_NAME = "Moonshot" + def __init__(self, key, model_name="glm-4v", lang="Chinese", **kwargs): pass @@ -700,6 +713,8 @@ class LocalCV(Base): class NvidiaCV(Base): + _FACTORY_NAME = "NVIDIA" + def __init__( self, key, @@ -726,9 +741,7 @@ class NvidiaCV(Base): "content-type": "application/json", "Authorization": f"Bearer {self.key}", }, - json={ - "messages": self.prompt(b64) - }, + json={"messages": self.prompt(b64)}, ) response = response.json() return ( @@ -774,10 +787,7 @@ class NvidiaCV(Base): return [ { "role": "user", - "content": ( - prompt if prompt else vision_llm_describe_prompt() - ) - + f' ', + "content": (prompt if prompt else vision_llm_describe_prompt()) + f' ', } ] @@ -791,6 +801,8 @@ class NvidiaCV(Base): class StepFunCV(GptV4): + _FACTORY_NAME = "StepFun" + def __init__(self, key, model_name="step-1v-8k", lang="Chinese", base_url="https://api.stepfun.com/v1"): if not base_url: base_url = "https://api.stepfun.com/v1" @@ -800,6 +812,8 @@ class StepFunCV(GptV4): class LmStudioCV(GptV4): + _FACTORY_NAME = "LM-Studio" + def __init__(self, key, model_name, lang="Chinese", base_url=""): if not base_url: raise ValueError("Local llm url cannot be None") @@ -810,6 +824,8 @@ class LmStudioCV(GptV4): class OpenAI_APICV(GptV4): + _FACTORY_NAME = 
["VLLM", "OpenAI-API-Compatible"] + def __init__(self, key, model_name, lang="Chinese", base_url=""): if not base_url: raise ValueError("url cannot be None") @@ -820,6 +836,8 @@ class OpenAI_APICV(GptV4): class TogetherAICV(GptV4): + _FACTORY_NAME = "TogetherAI" + def __init__(self, key, model_name, lang="Chinese", base_url="https://api.together.xyz/v1"): if not base_url: base_url = "https://api.together.xyz/v1" @@ -827,20 +845,38 @@ class TogetherAICV(GptV4): class YiCV(GptV4): - def __init__(self, key, model_name, lang="Chinese", base_url="https://api.lingyiwanwu.com/v1",): + _FACTORY_NAME = "01.AI" + + def __init__( + self, + key, + model_name, + lang="Chinese", + base_url="https://api.lingyiwanwu.com/v1", + ): if not base_url: base_url = "https://api.lingyiwanwu.com/v1" super().__init__(key, model_name, lang, base_url) class SILICONFLOWCV(GptV4): - def __init__(self, key, model_name, lang="Chinese", base_url="https://api.siliconflow.cn/v1",): + _FACTORY_NAME = "SILICONFLOW" + + def __init__( + self, + key, + model_name, + lang="Chinese", + base_url="https://api.siliconflow.cn/v1", + ): if not base_url: base_url = "https://api.siliconflow.cn/v1" super().__init__(key, model_name, lang, base_url) class HunyuanCV(Base): + _FACTORY_NAME = "Tencent Hunyuan" + def __init__(self, key, model_name, lang="Chinese", base_url=None): from tencentcloud.common import credential from tencentcloud.hunyuan.v20230901 import hunyuan_client @@ -895,14 +931,13 @@ class HunyuanCV(Base): "Contents": [ { "Type": "image_url", - "ImageUrl": { - "Url": f"data:image/jpeg;base64,{b64}" - }, + "ImageUrl": {"Url": f"data:image/jpeg;base64,{b64}"}, }, { "Type": "text", - "Text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else - "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.", + "Text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" + if self.lang.lower() == "chinese" + else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.", }, ], } @@ -910,6 +945,8 @@ class HunyuanCV(Base): class AnthropicCV(Base): + _FACTORY_NAME = "Anthropic" + def __init__(self, key, model_name, base_url=None): import anthropic @@ -933,38 +970,29 @@ class AnthropicCV(Base): "data": b64, }, }, - { - "type": "text", - "text": prompt - } + {"type": "text", "text": prompt}, ], } ] def describe(self, image): b64 = self.image2base64(image) - prompt = self.prompt(b64, - "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else - "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out." - ) - - response = self.client.messages.create( - model=self.model_name, - max_tokens=self.max_tokens, - messages=prompt + prompt = self.prompt( + b64, + "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" + if self.lang.lower() == "chinese" + else "Please describe the content of this picture, like where, when, who, what happen. 
If it has number data, please extract them out.", ) - return response["content"][0]["text"].strip(), response["usage"]["input_tokens"]+response["usage"]["output_tokens"] + + response = self.client.messages.create(model=self.model_name, max_tokens=self.max_tokens, messages=prompt) + return response["content"][0]["text"].strip(), response["usage"]["input_tokens"] + response["usage"]["output_tokens"] def describe_with_prompt(self, image, prompt=None): b64 = self.image2base64(image) prompt = self.prompt(b64, prompt if prompt else vision_llm_describe_prompt()) - response = self.client.messages.create( - model=self.model_name, - max_tokens=self.max_tokens, - messages=prompt - ) - return response["content"][0]["text"].strip(), response["usage"]["input_tokens"]+response["usage"]["output_tokens"] + response = self.client.messages.create(model=self.model_name, max_tokens=self.max_tokens, messages=prompt) + return response["content"][0]["text"].strip(), response["usage"]["input_tokens"] + response["usage"]["output_tokens"] def chat(self, system, history, gen_conf): if "presence_penalty" in gen_conf: @@ -984,11 +1012,7 @@ class AnthropicCV(Base): ).to_dict() ans = response["content"][0]["text"] if response["stop_reason"] == "max_tokens": - ans += ( - "...\nFor the content length reason, it stopped, continue?" - if is_english([ans]) - else "······\n由于长度的原因,回答被截断了,要继续吗?" - ) + ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?" return ( ans, response["usage"]["input_tokens"] + response["usage"]["output_tokens"], @@ -1014,7 +1038,7 @@ class AnthropicCV(Base): **gen_conf, ) for res in response: - if res.type == 'content_block_delta': + if res.type == "content_block_delta": if res.delta.type == "thinking_delta" and res.delta.thinking: if ans.find("") < 0: ans += "" @@ -1030,7 +1054,10 @@ class AnthropicCV(Base): yield total_tokens + class GPUStackCV(GptV4): + _FACTORY_NAME = "GPUStack" + def __init__(self, key, model_name, lang="Chinese", base_url=""): if not base_url: raise ValueError("Local llm url cannot be None") @@ -1041,11 +1068,13 @@ class GPUStackCV(GptV4): class GoogleCV(Base): + _FACTORY_NAME = "Google Cloud" + def __init__(self, key, model_name, lang="Chinese", base_url=None, **kwargs): import base64 from google.oauth2 import service_account - + key = json.loads(key) access_token = json.loads(base64.b64decode(key.get("google_service_account_key", ""))) project_id = key.get("google_project_id", "") @@ -1079,9 +1108,12 @@ class GoogleCV(Base): self.client = glm.GenerativeModel(model_name=self.model_name) def describe(self, image): - prompt = "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else \ - "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out." - + prompt = ( + "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" + if self.lang.lower() == "chinese" + else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out." 
+ ) + if "claude" in self.model_name: b64 = self.image2base64(image) vision_prompt = [ @@ -1096,28 +1128,22 @@ class GoogleCV(Base): "data": b64, }, }, - { - "type": "text", - "text": prompt - } + {"type": "text", "text": prompt}, ], } ] response = self.client.messages.create( model=self.model_name, max_tokens=8192, - messages=vision_prompt + messages=vision_prompt, ) return response.content[0].text.strip(), response.usage.input_tokens + response.usage.output_tokens else: import vertexai.generative_models as glm - + b64 = self.image2base64(image) # Create proper image part for Gemini - image_part = glm.Part.from_data( - data=base64.b64decode(b64), - mime_type="image/jpeg" - ) + image_part = glm.Part.from_data(data=base64.b64decode(b64), mime_type="image/jpeg") input = [prompt, image_part] res = self.client.generate_content(input) return res.text, res.usage_metadata.total_token_count @@ -1137,29 +1163,19 @@ class GoogleCV(Base): "data": b64, }, }, - { - "type": "text", - "text": prompt if prompt else vision_llm_describe_prompt() - } + {"type": "text", "text": prompt if prompt else vision_llm_describe_prompt()}, ], } ] - response = self.client.messages.create( - model=self.model_name, - max_tokens=8192, - messages=vision_prompt - ) + response = self.client.messages.create(model=self.model_name, max_tokens=8192, messages=vision_prompt) return response.content[0].text.strip(), response.usage.input_tokens + response.usage.output_tokens else: import vertexai.generative_models as glm - + b64 = self.image2base64(image) vision_prompt = prompt if prompt else vision_llm_describe_prompt() # Create proper image part for Gemini - image_part = glm.Part.from_data( - data=base64.b64decode(b64), - mime_type="image/jpeg" - ) + image_part = glm.Part.from_data(data=base64.b64decode(b64), mime_type="image/jpeg") input = [vision_prompt, image_part] res = self.client.generate_content(input) return res.text, res.usage_metadata.total_token_count @@ -1180,25 +1196,17 @@ class GoogleCV(Base): "data": image, }, }, - { - "type": "text", - "text": his["content"] - } + {"type": "text", "text": his["content"]}, ] - response = self.client.messages.create( - model=self.model_name, - max_tokens=8192, - messages=history, - temperature=gen_conf.get("temperature", 0.3), - top_p=gen_conf.get("top_p", 0.7) - ) + response = self.client.messages.create(model=self.model_name, max_tokens=8192, messages=history, temperature=gen_conf.get("temperature", 0.3), top_p=gen_conf.get("top_p", 0.7)) return response.content[0].text.strip(), response.usage.input_tokens + response.usage.output_tokens except Exception as e: return "**ERROR**: " + str(e), 0 else: import vertexai.generative_models as glm from transformers import GenerationConfig + if system: history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"] try: @@ -1210,20 +1218,15 @@ class GoogleCV(Base): if his["role"] == "user": his["parts"] = [his["content"]] his.pop("content") - + # Create proper image part for Gemini img_bytes = base64.b64decode(image) - image_part = glm.Part.from_data( - data=img_bytes, - mime_type="image/jpeg" - ) + image_part = glm.Part.from_data(data=img_bytes, mime_type="image/jpeg") history[-1]["parts"].append(image_part) - response = self.client.generate_content(history, generation_config=GenerationConfig( - temperature=gen_conf.get("temperature", 0.3), - top_p=gen_conf.get("top_p", 0.7))) + response = self.client.generate_content(history, generation_config=GenerationConfig(temperature=gen_conf.get("temperature", 0.3), 
top_p=gen_conf.get("top_p", 0.7))) ans = response.text return ans, response.usage_metadata.total_token_count except Exception as e: - return "**ERROR**: " + str(e), 0 \ No newline at end of file + return "**ERROR**: " + str(e), 0 diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py index f53d492b9..fbfb7468b 100644 --- a/rag/llm/embedding_model.py +++ b/rag/llm/embedding_model.py @@ -13,28 +13,27 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import json import logging +import os import re import threading +from abc import ABC from urllib.parse import urljoin +import dashscope +import google.generativeai as genai +import numpy as np import requests from huggingface_hub import snapshot_download -from zhipuai import ZhipuAI -import os -from abc import ABC from ollama import Client -import dashscope from openai import OpenAI -import numpy as np -import asyncio +from zhipuai import ZhipuAI from api import settings from api.utils.file_utils import get_home_cache_dir from api.utils.log_utils import log_exception from rag.utils import num_tokens_from_string, truncate -import google.generativeai as genai -import json class Base(ABC): @@ -60,7 +59,8 @@ class Base(ABC): class DefaultEmbedding(Base): - os.environ['CUDA_VISIBLE_DEVICES'] = '0' + _FACTORY_NAME = "BAAI" + os.environ["CUDA_VISIBLE_DEVICES"] = "0" _model = None _model_name = "" _model_lock = threading.Lock() @@ -79,21 +79,22 @@ class DefaultEmbedding(Base): """ if not settings.LIGHTEN: with DefaultEmbedding._model_lock: - from FlagEmbedding import FlagModel import torch + from FlagEmbedding import FlagModel + if not DefaultEmbedding._model or model_name != DefaultEmbedding._model_name: try: - DefaultEmbedding._model = FlagModel(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), - query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:", - use_fp16=torch.cuda.is_available()) + DefaultEmbedding._model = FlagModel( + os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), + query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:", + use_fp16=torch.cuda.is_available(), + ) DefaultEmbedding._model_name = model_name except Exception: - model_dir = snapshot_download(repo_id="BAAI/bge-large-zh-v1.5", - local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), - local_dir_use_symlinks=False) - DefaultEmbedding._model = FlagModel(model_dir, - query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:", - use_fp16=torch.cuda.is_available()) + model_dir = snapshot_download( + repo_id="BAAI/bge-large-zh-v1.5", local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), local_dir_use_symlinks=False + ) + DefaultEmbedding._model = FlagModel(model_dir, query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:", use_fp16=torch.cuda.is_available()) self._model = DefaultEmbedding._model self._model_name = DefaultEmbedding._model_name @@ -105,7 +106,7 @@ class DefaultEmbedding(Base): token_count += num_tokens_from_string(t) ress = [] for i in range(0, len(texts), batch_size): - ress.extend(self._model.encode(texts[i:i + batch_size]).tolist()) + ress.extend(self._model.encode(texts[i : i + batch_size]).tolist()) return np.array(ress), token_count def encode_queries(self, text: str): @@ -114,8 +115,9 @@ class DefaultEmbedding(Base): class OpenAIEmbed(Base): - def __init__(self, key, model_name="text-embedding-ada-002", - base_url="https://api.openai.com/v1"): + _FACTORY_NAME = "OpenAI" + + 
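The embedding classes share one batching idiom that the diff reflows into `texts[i : i + batch_size]` slices: call the endpoint per slice and accumulate vectors plus token usage. A sketch with a toy stand-in for the provider client — `embed_batch` is an assumption in place of, e.g., `client.embeddings.create`:

```python
import numpy as np


def encode(texts, embed_batch, batch_size=16):
    """Accumulate embeddings and token counts slice by slice, mirroring
    OpenAIEmbed.encode / XinferenceEmbed.encode in this diff."""
    vectors, total_tokens = [], 0
    for i in range(0, len(texts), batch_size):
        embeddings, tokens = embed_batch(texts[i : i + batch_size])
        vectors.extend(embeddings)
        total_tokens += tokens
    return np.array(vectors), total_tokens


# Toy usage: a deterministic stand-in "model" returning 4-dim vectors.
fake = lambda batch: ([[float(len(t))] * 4 for t in batch], len(batch))
vecs, tokens = encode(["a", "bb", "ccc"], fake, batch_size=2)
assert vecs.shape == (3, 4) and tokens == 3
```

Batch sizes differ per provider (4 for DashScope, 10 for Youdao, 16 for most OpenAI-compatible endpoints) because of each API's input limits, which is why the constant stays inside each `encode` rather than on `Base`.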
def __init__(self, key, model_name="text-embedding-ada-002", base_url="https://api.openai.com/v1"): if not base_url: base_url = "https://api.openai.com/v1" self.client = OpenAI(api_key=key, base_url=base_url) @@ -128,8 +130,7 @@ class OpenAIEmbed(Base): ress = [] total_tokens = 0 for i in range(0, len(texts), batch_size): - res = self.client.embeddings.create(input=texts[i:i + batch_size], - model=self.model_name) + res = self.client.embeddings.create(input=texts[i : i + batch_size], model=self.model_name) try: ress.extend([d.embedding for d in res.data]) total_tokens += self.total_token_count(res) @@ -138,12 +139,13 @@ class OpenAIEmbed(Base): return np.array(ress), total_tokens def encode_queries(self, text): - res = self.client.embeddings.create(input=[truncate(text, 8191)], - model=self.model_name) + res = self.client.embeddings.create(input=[truncate(text, 8191)], model=self.model_name) return np.array(res.data[0].embedding), self.total_token_count(res) class LocalAIEmbed(Base): + _FACTORY_NAME = "LocalAI" + def __init__(self, key, model_name, base_url): if not base_url: raise ValueError("Local embedding model url cannot be None") @@ -155,7 +157,7 @@ class LocalAIEmbed(Base): batch_size = 16 ress = [] for i in range(0, len(texts), batch_size): - res = self.client.embeddings.create(input=texts[i:i + batch_size], model=self.model_name) + res = self.client.embeddings.create(input=texts[i : i + batch_size], model=self.model_name) try: ress.extend([d.embedding for d in res.data]) except Exception as _e: @@ -169,41 +171,42 @@ class LocalAIEmbed(Base): class AzureEmbed(OpenAIEmbed): + _FACTORY_NAME = "Azure-OpenAI" + def __init__(self, key, model_name, **kwargs): from openai.lib.azure import AzureOpenAI - api_key = json.loads(key).get('api_key', '') - api_version = json.loads(key).get('api_version', '2024-02-01') + + api_key = json.loads(key).get("api_key", "") + api_version = json.loads(key).get("api_version", "2024-02-01") self.client = AzureOpenAI(api_key=api_key, azure_endpoint=kwargs["base_url"], api_version=api_version) self.model_name = model_name class BaiChuanEmbed(OpenAIEmbed): - def __init__(self, key, - model_name='Baichuan-Text-Embedding', - base_url='https://api.baichuan-ai.com/v1'): + _FACTORY_NAME = "BaiChuan" + + def __init__(self, key, model_name="Baichuan-Text-Embedding", base_url="https://api.baichuan-ai.com/v1"): if not base_url: base_url = "https://api.baichuan-ai.com/v1" super().__init__(key, model_name, base_url) class QWenEmbed(Base): + _FACTORY_NAME = "Tongyi-Qianwen" + def __init__(self, key, model_name="text_embedding_v2", **kwargs): self.key = key self.model_name = model_name def encode(self, texts: list): import dashscope + batch_size = 4 res = [] token_count = 0 texts = [truncate(t, 2048) for t in texts] for i in range(0, len(texts), batch_size): - resp = dashscope.TextEmbedding.call( - model=self.model_name, - input=texts[i:i + batch_size], - api_key=self.key, - text_type="document" - ) + resp = dashscope.TextEmbedding.call(model=self.model_name, input=texts[i : i + batch_size], api_key=self.key, text_type="document") try: embds = [[] for _ in range(len(resp["output"]["embeddings"]))] for e in resp["output"]["embeddings"]: @@ -216,20 +219,16 @@ class QWenEmbed(Base): return np.array(res), token_count def encode_queries(self, text): - resp = dashscope.TextEmbedding.call( - model=self.model_name, - input=text[:2048], - api_key=self.key, - text_type="query" - ) + resp = dashscope.TextEmbedding.call(model=self.model_name, input=text[:2048], api_key=self.key, 
text_type="query") try: - return np.array(resp["output"]["embeddings"][0] - ["embedding"]), self.total_token_count(resp) + return np.array(resp["output"]["embeddings"][0]["embedding"]), self.total_token_count(resp) except Exception as _e: log_exception(_e, resp) class ZhipuEmbed(Base): + _FACTORY_NAME = "ZHIPU-AI" + def __init__(self, key, model_name="embedding-2", **kwargs): self.client = ZhipuAI(api_key=key) self.model_name = model_name @@ -246,8 +245,7 @@ class ZhipuEmbed(Base): texts = [truncate(t, MAX_LEN) for t in texts] for txt in texts: - res = self.client.embeddings.create(input=txt, - model=self.model_name) + res = self.client.embeddings.create(input=txt, model=self.model_name) try: arr.append(res.data[0].embedding) tks_num += self.total_token_count(res) @@ -256,8 +254,7 @@ class ZhipuEmbed(Base): return np.array(arr), tks_num def encode_queries(self, text): - res = self.client.embeddings.create(input=text, - model=self.model_name) + res = self.client.embeddings.create(input=text, model=self.model_name) try: return np.array(res.data[0].embedding), self.total_token_count(res) except Exception as _e: @@ -265,18 +262,17 @@ class ZhipuEmbed(Base): class OllamaEmbed(Base): + _FACTORY_NAME = "Ollama" + def __init__(self, key, model_name, **kwargs): - self.client = Client(host=kwargs["base_url"]) if not key or key == "x" else \ - Client(host=kwargs["base_url"], headers={"Authorization": f"Bear {key}"}) + self.client = Client(host=kwargs["base_url"]) if not key or key == "x" else Client(host=kwargs["base_url"], headers={"Authorization": f"Bear {key}"}) self.model_name = model_name def encode(self, texts: list): arr = [] tks_num = 0 for txt in texts: - res = self.client.embeddings(prompt=txt, - model=self.model_name, - options={"use_mmap": True}) + res = self.client.embeddings(prompt=txt, model=self.model_name, options={"use_mmap": True}) try: arr.append(res["embedding"]) except Exception as _e: @@ -285,9 +281,7 @@ class OllamaEmbed(Base): return np.array(arr), tks_num def encode_queries(self, text): - res = self.client.embeddings(prompt=text, - model=self.model_name, - options={"use_mmap": True}) + res = self.client.embeddings(prompt=text, model=self.model_name, options={"use_mmap": True}) try: return np.array(res["embedding"]), 128 except Exception as _e: @@ -295,27 +289,28 @@ class OllamaEmbed(Base): class FastEmbed(DefaultEmbedding): - + _FACTORY_NAME = "FastEmbed" + def __init__( - self, - key: str | None = None, - model_name: str = "BAAI/bge-small-en-v1.5", - cache_dir: str | None = None, - threads: int | None = None, - **kwargs, + self, + key: str | None = None, + model_name: str = "BAAI/bge-small-en-v1.5", + cache_dir: str | None = None, + threads: int | None = None, + **kwargs, ): if not settings.LIGHTEN: with FastEmbed._model_lock: from fastembed import TextEmbedding + if not DefaultEmbedding._model or model_name != DefaultEmbedding._model_name: try: DefaultEmbedding._model = TextEmbedding(model_name, cache_dir, threads, **kwargs) DefaultEmbedding._model_name = model_name except Exception: - cache_dir = snapshot_download(repo_id="BAAI/bge-small-en-v1.5", - local_dir=os.path.join(get_home_cache_dir(), - re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), - local_dir_use_symlinks=False) + cache_dir = snapshot_download( + repo_id="BAAI/bge-small-en-v1.5", local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), local_dir_use_symlinks=False + ) DefaultEmbedding._model = TextEmbedding(model_name, cache_dir, threads, **kwargs) self._model = 
@@ -340,6 +335,8 @@ class FastEmbed(DefaultEmbedding):


 class XinferenceEmbed(Base):
+    _FACTORY_NAME = "Xinference"
+
     def __init__(self, key, model_name="", base_url=""):
         base_url = urljoin(base_url, "v1")
         self.client = OpenAI(api_key=key, base_url=base_url)
@@ -350,7 +347,7 @@ class XinferenceEmbed(Base):
         ress = []
         total_tokens = 0
         for i in range(0, len(texts), batch_size):
-            res = self.client.embeddings.create(input=texts[i:i + batch_size], model=self.model_name)
+            res = self.client.embeddings.create(input=texts[i : i + batch_size], model=self.model_name)
             try:
                 ress.extend([d.embedding for d in res.data])
                 total_tokens += self.total_token_count(res)
@@ -359,8 +356,7 @@ class XinferenceEmbed(Base):
         return np.array(ress), total_tokens

     def encode_queries(self, text):
-        res = self.client.embeddings.create(input=[text],
-                                            model=self.model_name)
+        res = self.client.embeddings.create(input=[text], model=self.model_name)
         try:
             return np.array(res.data[0].embedding), self.total_token_count(res)
         except Exception as _e:
@@ -368,20 +364,18 @@ class XinferenceEmbed(Base):


 class YoudaoEmbed(Base):
+    _FACTORY_NAME = "Youdao"
     _client = None

     def __init__(self, key=None, model_name="maidalun1020/bce-embedding-base_v1", **kwargs):
         if not settings.LIGHTEN and not YoudaoEmbed._client:
             from BCEmbedding import EmbeddingModel as qanthing
+
             try:
                 logging.info("LOADING BCE...")
-                YoudaoEmbed._client = qanthing(model_name_or_path=os.path.join(
-                    get_home_cache_dir(),
-                    "bce-embedding-base_v1"))
+                YoudaoEmbed._client = qanthing(model_name_or_path=os.path.join(get_home_cache_dir(), "bce-embedding-base_v1"))
             except Exception:
-                YoudaoEmbed._client = qanthing(
-                    model_name_or_path=model_name.replace(
-                        "maidalun1020", "InfiniFlow"))
+                YoudaoEmbed._client = qanthing(model_name_or_path=model_name.replace("maidalun1020", "InfiniFlow"))

     def encode(self, texts: list):
         batch_size = 10
@@ -390,7 +384,7 @@ class YoudaoEmbed(Base):
         for t in texts:
             token_count += num_tokens_from_string(t)
         for i in range(0, len(texts), batch_size):
-            embds = YoudaoEmbed._client.encode(texts[i:i + batch_size])
+            embds = YoudaoEmbed._client.encode(texts[i : i + batch_size])
             res.extend(embds)
         return np.array(res), token_count

@@ -400,14 +394,11 @@ class YoudaoEmbed(Base):


 class JinaEmbed(Base):
-    def __init__(self, key, model_name="jina-embeddings-v3",
-                 base_url="https://api.jina.ai/v1/embeddings"):
+    _FACTORY_NAME = "Jina"

+    def __init__(self, key, model_name="jina-embeddings-v3", base_url="https://api.jina.ai/v1/embeddings"):
         self.base_url = "https://api.jina.ai/v1/embeddings"
-        self.headers = {
-            "Content-Type": "application/json",
-            "Authorization": f"Bearer {key}"
-        }
+        self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
         self.model_name = model_name

     def encode(self, texts: list):
@@ -416,11 +407,7 @@ class JinaEmbed(Base):
         ress = []
         token_count = 0
         for i in range(0, len(texts), batch_size):
-            data = {
-                "model": self.model_name,
-                "input": texts[i:i + batch_size],
-                'encoding_type': 'float'
-            }
+            data = {"model": self.model_name, "input": texts[i : i + batch_size], "encoding_type": "float"}
             response = requests.post(self.base_url, headers=self.headers, json=data)
             try:
                 res = response.json()
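Every provider in this file implements the same two-method contract: `encode` batches its inputs and returns `(vectors, total_tokens)`, while `encode_queries` embeds a single string. A hypothetical provider written against that contract — the class name and its placeholder vectors are illustrative only; `Base` and `num_tokens_from_string` come from this module:

```python
import numpy as np

class ExampleEmbed(Base):  # hypothetical; not part of this patch
    _FACTORY_NAME = "Example"

    def __init__(self, key, model_name, base_url=None):
        self.model_name = model_name

    def encode(self, texts: list):
        batch_size = 16
        ress, total_tokens = [], 0
        for i in range(0, len(texts), batch_size):
            batch = texts[i : i + batch_size]
            ress.extend([[0.0] * 8 for _ in batch])  # stand-in for a real API call
            total_tokens += sum(num_tokens_from_string(t) for t in batch)
        return np.array(ress), total_tokens

    def encode_queries(self, text):
        vecs, cnt = self.encode([text])
        return vecs[0], cnt
```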
@@ -435,50 +422,12 @@ class JinaEmbed(Base):
         return np.array(embds[0]), cnt


-class InfinityEmbed(Base):
-    _model = None
-
-    def __init__(
-        self,
-        model_names: list[str] = ("BAAI/bge-small-en-v1.5",),
-        engine_kwargs: dict = {},
-        key = None,
-    ):
-
-        from infinity_emb import EngineArgs
-        from infinity_emb.engine import AsyncEngineArray
-
-        self._default_model = model_names[0]
-        self.engine_array = AsyncEngineArray.from_args([EngineArgs(model_name_or_path = model_name, **engine_kwargs) for model_name in model_names])
-
-    async def _embed(self, sentences: list[str], model_name: str = ""):
-        if not model_name:
-            model_name = self._default_model
-        engine = self.engine_array[model_name]
-        was_already_running = engine.is_running
-        if not was_already_running:
-            await engine.astart()
-        embeddings, usage = await engine.embed(sentences=sentences)
-        if not was_already_running:
-            await engine.astop()
-        return embeddings, usage
-
-    def encode(self, texts: list[str], model_name: str = "") -> tuple[np.ndarray, int]:
-        # Using the internal tokenizer to encode the texts and get the total
-        # number of tokens
-        embeddings, usage = asyncio.run(self._embed(texts, model_name))
-        return np.array(embeddings), usage
-
-    def encode_queries(self, text: str) -> tuple[np.ndarray, int]:
-        # Using the internal tokenizer to encode the texts and get the total
-        # number of tokens
-        return self.encode([text])
-
-
 class MistralEmbed(Base):
-    def __init__(self, key, model_name="mistral-embed",
-                 base_url=None):
+    _FACTORY_NAME = "Mistral"
+
+    def __init__(self, key, model_name="mistral-embed", base_url=None):
         from mistralai.client import MistralClient
+
         self.client = MistralClient(api_key=key)
         self.model_name = model_name

@@ -488,8 +437,7 @@ class MistralEmbed(Base):
         ress = []
         token_count = 0
         for i in range(0, len(texts), batch_size):
-            res = self.client.embeddings(input=texts[i:i + batch_size],
-                                         model=self.model_name)
+            res = self.client.embeddings(input=texts[i : i + batch_size], model=self.model_name)
             try:
                 ress.extend([d.embedding for d in res.data])
                 token_count += self.total_token_count(res)
@@ -498,8 +446,7 @@ class MistralEmbed(Base):
         return np.array(ress), token_count

     def encode_queries(self, text):
-        res = self.client.embeddings(input=[truncate(text, 8196)],
-                                     model=self.model_name)
+        res = self.client.embeddings(input=[truncate(text, 8196)], model=self.model_name)
         try:
             return np.array(res.data[0].embedding), self.total_token_count(res)
         except Exception as _e:
@@ -507,30 +454,31 @@ class MistralEmbed(Base):


 class BedrockEmbed(Base):
-    def __init__(self, key, model_name,
-                 **kwargs):
+    _FACTORY_NAME = "Bedrock"
+
+    def __init__(self, key, model_name, **kwargs):
         import boto3
-        self.bedrock_ak = json.loads(key).get('bedrock_ak', '')
-        self.bedrock_sk = json.loads(key).get('bedrock_sk', '')
-        self.bedrock_region = json.loads(key).get('bedrock_region', '')
+
+        self.bedrock_ak = json.loads(key).get("bedrock_ak", "")
+        self.bedrock_sk = json.loads(key).get("bedrock_sk", "")
+        self.bedrock_region = json.loads(key).get("bedrock_region", "")
         self.model_name = model_name
-
-        if self.bedrock_ak == '' or self.bedrock_sk == '' or self.bedrock_region == '':
+
+        if self.bedrock_ak == "" or self.bedrock_sk == "" or self.bedrock_region == "":
             # Try to create a client using the default credentials (AWS_PROFILE, AWS_DEFAULT_REGION, etc.)
-            self.client = boto3.client('bedrock-runtime')
+            self.client = boto3.client("bedrock-runtime")
         else:
-            self.client = boto3.client(service_name='bedrock-runtime', region_name=self.bedrock_region,
-                                       aws_access_key_id=self.bedrock_ak, aws_secret_access_key=self.bedrock_sk)
+            self.client = boto3.client(service_name="bedrock-runtime", region_name=self.bedrock_region, aws_access_key_id=self.bedrock_ak, aws_secret_access_key=self.bedrock_sk)

     def encode(self, texts: list):
         texts = [truncate(t, 8196) for t in texts]
         embeddings = []
         token_count = 0
         for text in texts:
-            if self.model_name.split('.')[0] == 'amazon':
+            if self.model_name.split(".")[0] == "amazon":
                 body = {"inputText": text}
-            elif self.model_name.split('.')[0] == 'cohere':
-                body = {"texts": [text], "input_type": 'search_document'}
+            elif self.model_name.split(".")[0] == "cohere":
+                body = {"texts": [text], "input_type": "search_document"}

             response = self.client.invoke_model(modelId=self.model_name, body=json.dumps(body))
             try:
@@ -545,10 +493,10 @@ class BedrockEmbed(Base):
     def encode_queries(self, text):
         embeddings = []
         token_count = num_tokens_from_string(text)
-        if self.model_name.split('.')[0] == 'amazon':
+        if self.model_name.split(".")[0] == "amazon":
             body = {"inputText": truncate(text, 8196)}
-        elif self.model_name.split('.')[0] == 'cohere':
-            body = {"texts": [truncate(text, 8196)], "input_type": 'search_query'}
+        elif self.model_name.split(".")[0] == "cohere":
+            body = {"texts": [truncate(text, 8196)], "input_type": "search_query"}

         response = self.client.invoke_model(modelId=self.model_name, body=json.dumps(body))
         try:
@@ -561,11 +509,12 @@ class BedrockEmbed(Base):


 class GeminiEmbed(Base):
-    def __init__(self, key, model_name='models/text-embedding-004',
-                 **kwargs):
+    _FACTORY_NAME = "Gemini"
+
+    def __init__(self, key, model_name="models/text-embedding-004", **kwargs):
         self.key = key
-        self.model_name = 'models/' + model_name
-
+        self.model_name = "models/" + model_name
+
     def encode(self, texts: list):
         texts = [truncate(t, 2048) for t in texts]
         token_count = sum(num_tokens_from_string(text) for text in texts)
@@ -573,35 +522,27 @@ class GeminiEmbed(Base):
         batch_size = 16
         ress = []
         for i in range(0, len(texts), batch_size):
-            result = genai.embed_content(
-                model=self.model_name,
-                content=texts[i: i + batch_size],
-                task_type="retrieval_document",
-                title="Embedding of single string")
+            result = genai.embed_content(model=self.model_name, content=texts[i : i + batch_size], task_type="retrieval_document", title="Embedding of single string")
             try:
-                ress.extend(result['embedding'])
+                ress.extend(result["embedding"])
             except Exception as _e:
                 log_exception(_e, result)
-        return np.array(ress),token_count
-
+        return np.array(ress), token_count
+
     def encode_queries(self, text):
         genai.configure(api_key=self.key)
-        result = genai.embed_content(
-            model=self.model_name,
-            content=truncate(text,2048),
-            task_type="retrieval_document",
-            title="Embedding of single string")
+        result = genai.embed_content(model=self.model_name, content=truncate(text, 2048), task_type="retrieval_document", title="Embedding of single string")
         token_count = num_tokens_from_string(text)
         try:
-            return np.array(result['embedding']), token_count
+            return np.array(result["embedding"]), token_count
         except Exception as _e:
             log_exception(_e, result)
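Unlike most factories, `BedrockEmbed` (like `AzureEmbed` and `VolcEngineEmbed` in this file) receives its credentials as a JSON-encoded `key` rather than a bare token. A sketch of what a caller would pass, with placeholder values:

```python
import json

# Leaving all three fields empty (or absent) falls back to boto3's default
# credential chain: AWS_PROFILE, AWS_DEFAULT_REGION, instance roles, and so on.
key = json.dumps({
    "bedrock_ak": "AKIA...",        # placeholder access key
    "bedrock_sk": "...",            # placeholder secret key
    "bedrock_region": "us-east-1",
})
# mdl = BedrockEmbed(key, "amazon.titan-embed-text-v2:0")  # model id is an example
```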
base_url="https://integrate.api.nvidia.com/v1/embeddings"): if not base_url: base_url = "https://integrate.api.nvidia.com/v1/embeddings" self.api_key = key @@ -645,6 +586,8 @@ class NvidiaEmbed(Base): class LmStudioEmbed(LocalAIEmbed): + _FACTORY_NAME = "LM-Studio" + def __init__(self, key, model_name, base_url): if not base_url: raise ValueError("Local llm url cannot be None") @@ -654,6 +597,8 @@ class LmStudioEmbed(LocalAIEmbed): class OpenAI_APIEmbed(OpenAIEmbed): + _FACTORY_NAME = ["VLLM", "OpenAI-API-Compatible"] + def __init__(self, key, model_name, base_url): if not base_url: raise ValueError("url cannot be None") @@ -663,6 +608,8 @@ class OpenAI_APIEmbed(OpenAIEmbed): class CoHereEmbed(Base): + _FACTORY_NAME = "Cohere" + def __init__(self, key, model_name, base_url=None): from cohere import Client @@ -701,6 +648,8 @@ class CoHereEmbed(Base): class TogetherAIEmbed(OpenAIEmbed): + _FACTORY_NAME = "TogetherAI" + def __init__(self, key, model_name, base_url="https://api.together.xyz/v1"): if not base_url: base_url = "https://api.together.xyz/v1" @@ -708,6 +657,8 @@ class TogetherAIEmbed(OpenAIEmbed): class PerfXCloudEmbed(OpenAIEmbed): + _FACTORY_NAME = "PerfXCloud" + def __init__(self, key, model_name, base_url="https://cloud.perfxlab.cn/v1"): if not base_url: base_url = "https://cloud.perfxlab.cn/v1" @@ -715,6 +666,8 @@ class PerfXCloudEmbed(OpenAIEmbed): class UpstageEmbed(OpenAIEmbed): + _FACTORY_NAME = "Upstage" + def __init__(self, key, model_name, base_url="https://api.upstage.ai/v1/solar"): if not base_url: base_url = "https://api.upstage.ai/v1/solar" @@ -722,6 +675,8 @@ class UpstageEmbed(OpenAIEmbed): class SILICONFLOWEmbed(Base): + _FACTORY_NAME = "SILICONFLOW" + def __init__(self, key, model_name, base_url="https://api.siliconflow.cn/v1/embeddings"): if not base_url: base_url = "https://api.siliconflow.cn/v1/embeddings" @@ -769,6 +724,8 @@ class SILICONFLOWEmbed(Base): class ReplicateEmbed(Base): + _FACTORY_NAME = "Replicate" + def __init__(self, key, model_name, base_url=None): from replicate.client import Client @@ -790,6 +747,8 @@ class ReplicateEmbed(Base): class BaiduYiyanEmbed(Base): + _FACTORY_NAME = "BaiduYiyan" + def __init__(self, key, model_name, base_url=None): import qianfan @@ -821,6 +780,8 @@ class BaiduYiyanEmbed(Base): class VoyageEmbed(Base): + _FACTORY_NAME = "Voyage AI" + def __init__(self, key, model_name, base_url=None): import voyageai @@ -832,9 +793,7 @@ class VoyageEmbed(Base): ress = [] token_count = 0 for i in range(0, len(texts), batch_size): - res = self.client.embed( - texts=texts[i : i + batch_size], model=self.model_name, input_type="document" - ) + res = self.client.embed(texts=texts[i : i + batch_size], model=self.model_name, input_type="document") try: ress.extend(res.embeddings) token_count += res.total_tokens @@ -843,9 +802,7 @@ class VoyageEmbed(Base): return np.array(ress), token_count def encode_queries(self, text): - res = self.client.embed( - texts=text, model=self.model_name, input_type="query" - ) + res = self.client.embed(texts=text, model=self.model_name, input_type="query") try: return np.array(res.embeddings)[0], res.total_tokens except Exception as _e: @@ -853,6 +810,8 @@ class VoyageEmbed(Base): class HuggingFaceEmbed(Base): + _FACTORY_NAME = "HuggingFace" + def __init__(self, key, model_name, base_url=None): if not model_name: raise ValueError("Model name cannot be None") @@ -863,11 +822,7 @@ class HuggingFaceEmbed(Base): def encode(self, texts: list): embeddings = [] for text in texts: - response = requests.post( - 
f"{self.base_url}/embed", - json={"inputs": text}, - headers={'Content-Type': 'application/json'} - ) + response = requests.post(f"{self.base_url}/embed", json={"inputs": text}, headers={"Content-Type": "application/json"}) if response.status_code == 200: embedding = response.json() embeddings.append(embedding[0]) @@ -876,11 +831,7 @@ class HuggingFaceEmbed(Base): return np.array(embeddings), sum([num_tokens_from_string(text) for text in texts]) def encode_queries(self, text): - response = requests.post( - f"{self.base_url}/embed", - json={"inputs": text}, - headers={'Content-Type': 'application/json'} - ) + response = requests.post(f"{self.base_url}/embed", json={"inputs": text}, headers={"Content-Type": "application/json"}) if response.status_code == 200: embedding = response.json() return np.array(embedding[0]), num_tokens_from_string(text) @@ -889,15 +840,19 @@ class HuggingFaceEmbed(Base): class VolcEngineEmbed(OpenAIEmbed): + _FACTORY_NAME = "VolcEngine" + def __init__(self, key, model_name, base_url="https://ark.cn-beijing.volces.com/api/v3"): if not base_url: base_url = "https://ark.cn-beijing.volces.com/api/v3" - ark_api_key = json.loads(key).get('ark_api_key', '') - model_name = json.loads(key).get('ep_id', '') + json.loads(key).get('endpoint_id', '') - super().__init__(ark_api_key,model_name,base_url) + ark_api_key = json.loads(key).get("ark_api_key", "") + model_name = json.loads(key).get("ep_id", "") + json.loads(key).get("endpoint_id", "") + super().__init__(ark_api_key, model_name, base_url) class GPUStackEmbed(OpenAIEmbed): + _FACTORY_NAME = "GPUStack" + def __init__(self, key, model_name, base_url): if not base_url: raise ValueError("url cannot be None") @@ -908,6 +863,8 @@ class GPUStackEmbed(OpenAIEmbed): class NovitaEmbed(SILICONFLOWEmbed): + _FACTORY_NAME = "NovitaAI" + def __init__(self, key, model_name, base_url="https://api.novita.ai/v3/openai/embeddings"): if not base_url: base_url = "https://api.novita.ai/v3/openai/embeddings" @@ -915,7 +872,9 @@ class NovitaEmbed(SILICONFLOWEmbed): class GiteeEmbed(SILICONFLOWEmbed): + _FACTORY_NAME = "GiteeAI" + def __init__(self, key, model_name, base_url="https://ai.gitee.com/v1/embeddings"): if not base_url: base_url = "https://ai.gitee.com/v1/embeddings" - super().__init__(key, model_name, base_url) \ No newline at end of file + super().__init__(key, model_name, base_url) diff --git a/rag/llm/rerank_model.py b/rag/llm/rerank_model.py index fafab7ee0..88fda1478 100644 --- a/rag/llm/rerank_model.py +++ b/rag/llm/rerank_model.py @@ -13,24 +13,24 @@ # See the License for the specific language governing permissions and # limitations under the License. 
diff --git a/rag/llm/rerank_model.py b/rag/llm/rerank_model.py
index fafab7ee0..88fda1478 100644
--- a/rag/llm/rerank_model.py
+++ b/rag/llm/rerank_model.py
@@ -13,24 +13,24 @@
 #  See the License for the specific language governing permissions and
 #  limitations under the License.
 #
+import json
+import os
 import re
 import threading
+from abc import ABC
 from collections.abc import Iterable
 from urllib.parse import urljoin

-import requests
 import httpx
-from huggingface_hub import snapshot_download
-import os
-from abc import ABC
 import numpy as np
+import requests
+from huggingface_hub import snapshot_download
 from yarl import URL

 from api import settings
 from api.utils.file_utils import get_home_cache_dir
 from api.utils.log_utils import log_exception
 from rag.utils import num_tokens_from_string, truncate
-import json


 def sigmoid(x):
@@ -57,6 +57,7 @@ class Base(ABC):


 class DefaultRerank(Base):
+    _FACTORY_NAME = "BAAI"
     _model = None
     _model_lock = threading.Lock()

@@ -75,17 +76,13 @@ class DefaultRerank(Base):
         if not settings.LIGHTEN and not DefaultRerank._model:
             import torch
             from FlagEmbedding import FlagReranker
+
             with DefaultRerank._model_lock:
                 if not DefaultRerank._model:
                     try:
-                        DefaultRerank._model = FlagReranker(
-                            os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)),
-                            use_fp16=torch.cuda.is_available())
+                        DefaultRerank._model = FlagReranker(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), use_fp16=torch.cuda.is_available())
                     except Exception:
-                        model_dir = snapshot_download(repo_id=model_name,
-                                                      local_dir=os.path.join(get_home_cache_dir(),
-                                                                             re.sub(r"^[a-zA-Z0-9]+/", "", model_name)),
-                                                      local_dir_use_symlinks=False)
+                        model_dir = snapshot_download(repo_id=model_name, local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)), local_dir_use_symlinks=False)
                         DefaultRerank._model = FlagReranker(model_dir, use_fp16=torch.cuda.is_available())
         self._model = DefaultRerank._model
         self._dynamic_batch_size = 8
@@ -94,6 +91,7 @@ class DefaultRerank(Base):
     def torch_empty_cache(self):
         try:
             import torch
+
             torch.cuda.empty_cache()
         except Exception as e:
             print(f"Error emptying cache: {e}")
@@ -112,7 +110,7 @@ class DefaultRerank(Base):
         while retry_count < max_retries:
             try:
                 # call subclass implemented batch processing calculation
-                batch_scores = self._compute_batch_scores(pairs[i:i + current_batch])
+                batch_scores = self._compute_batch_scores(pairs[i : i + current_batch])
                 res.extend(batch_scores)
                 i += current_batch
                 self._dynamic_batch_size = min(self._dynamic_batch_size * 2, 8)
@@ -152,23 +150,16 @@ class DefaultRerank(Base):


 class JinaRerank(Base):
-    def __init__(self, key, model_name="jina-reranker-v2-base-multilingual",
-                 base_url="https://api.jina.ai/v1/rerank"):
+    _FACTORY_NAME = "Jina"
+
+    def __init__(self, key, model_name="jina-reranker-v2-base-multilingual", base_url="https://api.jina.ai/v1/rerank"):
         self.base_url = "https://api.jina.ai/v1/rerank"
-        self.headers = {
-            "Content-Type": "application/json",
-            "Authorization": f"Bearer {key}"
-        }
+        self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
         self.model_name = model_name

     def similarity(self, query: str, texts: list):
         texts = [truncate(t, 8196) for t in texts]
-        data = {
-            "model": self.model_name,
-            "query": query,
-            "documents": texts,
-            "top_n": len(texts)
-        }
+        data = {"model": self.model_name, "query": query, "documents": texts, "top_n": len(texts)}
         res = requests.post(self.base_url, headers=self.headers, json=data).json()
         rank = np.zeros(len(texts), dtype=float)
         try:
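`DefaultRerank.similarity` walks its pairs with a `_dynamic_batch_size` that is halved on failure and doubled back on success, capped at 8. The control flow, distilled into a standalone sketch — the names mirror the diff, and the scoring callable is a stand-in:

```python
def batched_scores(pairs, compute, start_batch=8, max_retries=10):
    """Shrink the batch on failure (e.g. CUDA OOM), grow it back on success."""
    res, i, batch = [], 0, start_batch
    while i < len(pairs):
        for _ in range(max_retries):
            try:
                res.extend(compute(pairs[i : i + batch]))
                i += batch
                batch = min(batch * 2, start_batch)  # recover, capped at the start size
                break
            except RuntimeError:                     # torch surfaces OOM as RuntimeError
                batch = max(1, batch // 2)           # halve and retry the same slice
        else:
            raise RuntimeError("rerank batch failed after max retries")
    return res
```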
@@ -180,22 +171,20 @@ class JinaRerank(Base):


 class YoudaoRerank(DefaultRerank):
+    _FACTORY_NAME = "Youdao"
     _model = None
     _model_lock = threading.Lock()

     def __init__(self, key=None, model_name="maidalun1020/bce-reranker-base_v1", **kwargs):
         if not settings.LIGHTEN and not YoudaoRerank._model:
             from BCEmbedding import RerankerModel
+
             with YoudaoRerank._model_lock:
                 if not YoudaoRerank._model:
                     try:
-                        YoudaoRerank._model = RerankerModel(model_name_or_path=os.path.join(
-                            get_home_cache_dir(),
-                            re.sub(r"^[a-zA-Z0-9]+/", "", model_name)))
+                        YoudaoRerank._model = RerankerModel(model_name_or_path=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)))
                     except Exception:
-                        YoudaoRerank._model = RerankerModel(
-                            model_name_or_path=model_name.replace(
-                                "maidalun1020", "InfiniFlow"))
+                        YoudaoRerank._model = RerankerModel(model_name_or_path=model_name.replace("maidalun1020", "InfiniFlow"))

         self._model = YoudaoRerank._model
         self._dynamic_batch_size = 8
@@ -212,6 +201,8 @@ class YoudaoRerank(DefaultRerank):


 class XInferenceRerank(Base):
+    _FACTORY_NAME = "Xinference"
+
     def __init__(self, key="x", model_name="", base_url=""):
         if base_url.find("/v1") == -1:
             base_url = urljoin(base_url, "/v1/rerank")
@@ -219,10 +210,7 @@ class XInferenceRerank(Base):
             base_url = urljoin(base_url, "/v1/rerank")
         self.model_name = model_name
         self.base_url = base_url
-        self.headers = {
-            "Content-Type": "application/json",
-            "accept": "application/json"
-        }
+        self.headers = {"Content-Type": "application/json", "accept": "application/json"}
         if key and key != "x":
             self.headers["Authorization"] = f"Bearer {key}"

@@ -233,13 +221,7 @@ class XInferenceRerank(Base):
         token_count = 0
         for _, t in pairs:
             token_count += num_tokens_from_string(t)
-        data = {
-            "model": self.model_name,
-            "query": query,
-            "return_documents": "true",
-            "return_len": "true",
-            "documents": texts
-        }
+        data = {"model": self.model_name, "query": query, "return_documents": "true", "return_len": "true", "documents": texts}
         res = requests.post(self.base_url, headers=self.headers, json=data).json()
         rank = np.zeros(len(texts), dtype=float)
         try:
@@ -251,15 +233,14 @@ class XInferenceRerank(Base):


 class LocalAIRerank(Base):
+    _FACTORY_NAME = "LocalAI"
+
     def __init__(self, key, model_name, base_url):
         if base_url.find("/rerank") == -1:
             self.base_url = urljoin(base_url, "/rerank")
         else:
             self.base_url = base_url
-        self.headers = {
-            "Content-Type": "application/json",
-            "Authorization": f"Bearer {key}"
-        }
+        self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
         self.model_name = model_name.split("___")[0]

     def similarity(self, query: str, texts: list):
@@ -296,16 +277,15 @@ class LocalAIRerank(Base):


 class NvidiaRerank(Base):
-    def __init__(
-        self, key, model_name, base_url="https://ai.api.nvidia.com/v1/retrieval/nvidia/"
-    ):
+    _FACTORY_NAME = "NVIDIA"
+
+    def __init__(self, key, model_name, base_url="https://ai.api.nvidia.com/v1/retrieval/nvidia/"):
         if not base_url:
             base_url = "https://ai.api.nvidia.com/v1/retrieval/nvidia/"
         self.model_name = model_name

         if self.model_name == "nvidia/nv-rerankqa-mistral-4b-v3":
-            self.base_url = urljoin(base_url, "nv-rerankqa-mistral-4b-v3/reranking"
-                                    )
+            self.base_url = urljoin(base_url, "nv-rerankqa-mistral-4b-v3/reranking")

         if self.model_name == "nvidia/rerank-qa-mistral-4b":
             self.base_url = urljoin(base_url, "reranking")
@@ -318,9 +298,7 @@ class NvidiaRerank(Base):
         }

     def similarity(self, query: str, texts: list):
-        token_count = num_tokens_from_string(query) + sum(
-            [num_tokens_from_string(t) for t in texts]
-        )
+        token_count = num_tokens_from_string(query) + sum([num_tokens_from_string(t) for t in texts])
         data = {
             "model": self.model_name,
             "query": {"text": query},
@@ -339,6 +317,8 @@ class NvidiaRerank(Base):


 class LmStudioRerank(Base):
+    _FACTORY_NAME = "LM-Studio"
+
     def __init__(self, key, model_name, base_url):
         pass

@@ -347,15 +327,14 @@ class LmStudioRerank(Base):


 class OpenAI_APIRerank(Base):
+    _FACTORY_NAME = "OpenAI-API-Compatible"
+
     def __init__(self, key, model_name, base_url):
         if base_url.find("/rerank") == -1:
             self.base_url = urljoin(base_url, "/rerank")
         else:
             self.base_url = base_url
-        self.headers = {
-            "Content-Type": "application/json",
-            "Authorization": f"Bearer {key}"
-        }
+        self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
         self.model_name = model_name.split("___")[0]

     def similarity(self, query: str, texts: list):
@@ -392,6 +371,8 @@ class OpenAI_APIRerank(Base):


 class CoHereRerank(Base):
+    _FACTORY_NAME = ["Cohere", "VLLM"]
+
     def __init__(self, key, model_name, base_url=None):
         from cohere import Client

@@ -399,9 +380,7 @@ class CoHereRerank(Base):
         self.model_name = model_name.split("___")[0]

     def similarity(self, query: str, texts: list):
-        token_count = num_tokens_from_string(query) + sum(
-            [num_tokens_from_string(t) for t in texts]
-        )
+        token_count = num_tokens_from_string(query) + sum([num_tokens_from_string(t) for t in texts])
         res = self.client.rerank(
             model=self.model_name,
             query=query,
@@ -419,6 +398,8 @@ class CoHereRerank(Base):


 class TogetherAIRerank(Base):
+    _FACTORY_NAME = "TogetherAI"
+
     def __init__(self, key, model_name, base_url):
         pass

@@ -427,9 +408,9 @@ class TogetherAIRerank(Base):


 class SILICONFLOWRerank(Base):
-    def __init__(
-        self, key, model_name, base_url="https://api.siliconflow.cn/v1/rerank"
-    ):
+    _FACTORY_NAME = "SILICONFLOW"
+
+    def __init__(self, key, model_name, base_url="https://api.siliconflow.cn/v1/rerank"):
        if not base_url:
             base_url = "https://api.siliconflow.cn/v1/rerank"
         self.model_name = model_name
@@ -450,9 +431,7 @@ class SILICONFLOWRerank(Base):
             "max_chunks_per_doc": 1024,
             "overlap_tokens": 80,
         }
-        response = requests.post(
-            self.base_url, json=payload, headers=self.headers
-        ).json()
+        response = requests.post(self.base_url, json=payload, headers=self.headers).json()
         rank = np.zeros(len(texts), dtype=float)
         try:
             for d in response["results"]:
@@ -466,6 +445,8 @@ class SILICONFLOWRerank(Base):


 class BaiduYiyanRerank(Base):
+    _FACTORY_NAME = "BaiduYiyan"
+
     def __init__(self, key, model_name, base_url=None):
         from qianfan.resources import Reranker

@@ -492,6 +473,8 @@ class BaiduYiyanRerank(Base):


 class VoyageRerank(Base):
+    _FACTORY_NAME = "Voyage AI"
+
     def __init__(self, key, model_name, base_url=None):
         import voyageai

@@ -502,9 +485,7 @@ class VoyageRerank(Base):
         rank = np.zeros(len(texts), dtype=float)
         if not texts:
             return rank, 0
-        res = self.client.rerank(
-            query=query, documents=texts, model=self.model_name, top_k=len(texts)
-        )
+        res = self.client.rerank(query=query, documents=texts, model=self.model_name, top_k=len(texts))
         try:
             for r in res.results:
                 rank[r.index] = r.relevance_score
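All rerankers in this file share one contract: `similarity(query, texts)` returns a score array aligned to the input order plus a token count, scattering provider results back by their `index` field. A minimal HTTP-style sketch of that contract — the response shape is an assumption modeled on the providers above, and `num_tokens_from_string` comes from this module:

```python
import numpy as np

def similarity(query: str, texts: list, call_api) -> tuple[np.ndarray, int]:
    rank = np.zeros(len(texts), dtype=float)
    # call_api is assumed to return {"results": [{"index": i, "relevance_score": s}, ...]}
    for d in call_api(query, texts)["results"]:
        rank[d["index"]] = d["relevance_score"]
    token_count = num_tokens_from_string(query) + sum(num_tokens_from_string(t) for t in texts)
    return rank, token_count
```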
@@ -514,22 +495,20 @@ class VoyageRerank(Base):


 class QWenRerank(Base):
-    def __init__(self, key, model_name='gte-rerank', base_url=None, **kwargs):
+    _FACTORY_NAME = "Tongyi-Qianwen"
+
+    def __init__(self, key, model_name="gte-rerank", base_url=None, **kwargs):
         import dashscope
+
         self.api_key = key
         self.model_name = dashscope.TextReRank.Models.gte_rerank if model_name is None else model_name

     def similarity(self, query: str, texts: list):
-        import dashscope
         from http import HTTPStatus
-        resp = dashscope.TextReRank.call(
-            api_key=self.api_key,
-            model=self.model_name,
-            query=query,
-            documents=texts,
-            top_n=len(texts),
-            return_documents=False
-        )
+
+        import dashscope
+
+        resp = dashscope.TextReRank.call(api_key=self.api_key, model=self.model_name, query=query, documents=texts, top_n=len(texts), return_documents=False)
         rank = np.zeros(len(texts), dtype=float)
         if resp.status_code == HTTPStatus.OK:
             try:
@@ -543,6 +522,8 @@ class QWenRerank(Base):


 class HuggingfaceRerank(DefaultRerank):
+    _FACTORY_NAME = "HuggingFace"
+
     @staticmethod
     def post(query: str, texts: list, url="127.0.0.1"):
         exc = None
@@ -550,9 +531,9 @@ class HuggingfaceRerank(DefaultRerank):
         batch_size = 8
         for i in range(0, len(texts), batch_size):
             try:
-                res = requests.post(f"http://{url}/rerank", headers={"Content-Type": "application/json"},
-                                    json={"query": query, "texts": texts[i: i + batch_size],
-                                          "raw_scores": False, "truncate": True})
+                res = requests.post(
+                    f"http://{url}/rerank", headers={"Content-Type": "application/json"}, json={"query": query, "texts": texts[i : i + batch_size], "raw_scores": False, "truncate": True}
+                )

                 for o in res.json():
                     scores[o["index"] + i] = o["score"]
@@ -577,9 +558,9 @@ class HuggingfaceRerank(DefaultRerank):


 class GPUStackRerank(Base):
-    def __init__(
-        self, key, model_name, base_url
-    ):
+    _FACTORY_NAME = "GPUStack"
+
+    def __init__(self, key, model_name, base_url):
         if not base_url:
             raise ValueError("url cannot be None")

@@ -600,9 +581,7 @@ class GPUStackRerank(Base):
         }

         try:
-            response = requests.post(
-                self.base_url, json=payload, headers=self.headers
-            )
+            response = requests.post(self.base_url, json=payload, headers=self.headers)
             response.raise_for_status()
             response_json = response.json()

@@ -623,11 +602,12 @@ class GPUStackRerank(Base):
             )

         except httpx.HTTPStatusError as e:
-            raise ValueError(
-                f"Error calling GPUStackRerank model {self.model_name}: {e.response.status_code} - {e.response.text}")
+            raise ValueError(f"Error calling GPUStackRerank model {self.model_name}: {e.response.status_code} - {e.response.text}")


 class NovitaRerank(JinaRerank):
+    _FACTORY_NAME = "NovitaAI"
+
     def __init__(self, key, model_name, base_url="https://api.novita.ai/v3/openai/rerank"):
         if not base_url:
             base_url = "https://api.novita.ai/v3/openai/rerank"
@@ -635,7 +615,9 @@ class NovitaRerank(JinaRerank):


 class GiteeRerank(JinaRerank):
+    _FACTORY_NAME = "GiteeAI"
+
     def __init__(self, key, model_name, base_url="https://ai.gitee.com/v1/rerank"):
         if not base_url:
             base_url = "https://ai.gitee.com/v1/rerank"
-        super().__init__(key, model_name, base_url)
\ No newline at end of file
+        super().__init__(key, model_name, base_url)
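Lookup for rerankers presumably mirrors the embedding registry; a hypothetical call site, with placeholder key and model:

```python
from rag.llm import RerankModel

rerank = RerankModel["Jina"](key="jina_...", model_name="jina-reranker-v2-base-multilingual")
scores, tokens = rerank.similarity("what is RAG?", ["chunk one", "chunk two"])
order = scores.argsort()[::-1]  # best-scoring chunk first
```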
diff --git a/rag/llm/sequence2txt_model.py b/rag/llm/sequence2txt_model.py
index 193f0e14d..7d6c24b76 100644
--- a/rag/llm/sequence2txt_model.py
+++ b/rag/llm/sequence2txt_model.py
@@ -13,16 +13,18 @@
 #  See the License for the specific language governing permissions and
 #  limitations under the License.
 #
-import os
-import requests
-from openai.lib.azure import AzureOpenAI
-import io
-from abc import ABC
-from openai import OpenAI
-import json
-from rag.utils import num_tokens_from_string
 import base64
+import io
+import json
+import os
 import re
+from abc import ABC
+
+import requests
+from openai import OpenAI
+from openai.lib.azure import AzureOpenAI
+
+from rag.utils import num_tokens_from_string


 class Base(ABC):
@@ -30,11 +32,7 @@ class Base(ABC):
         pass

     def transcription(self, audio, **kwargs):
-        transcription = self.client.audio.transcriptions.create(
-            model=self.model_name,
-            file=audio,
-            response_format="text"
-        )
+        transcription = self.client.audio.transcriptions.create(model=self.model_name, file=audio, response_format="text")
         return transcription.text.strip(), num_tokens_from_string(transcription.text.strip())

     def audio2base64(self, audio):
@@ -46,6 +44,8 @@ class Base(ABC):


 class GPTSeq2txt(Base):
+    _FACTORY_NAME = "OpenAI"
+
     def __init__(self, key, model_name="whisper-1", base_url="https://api.openai.com/v1"):
         if not base_url:
             base_url = "https://api.openai.com/v1"
@@ -54,31 +54,34 @@ class GPTSeq2txt(Base):


 class QWenSeq2txt(Base):
+    _FACTORY_NAME = "Tongyi-Qianwen"
+
     def __init__(self, key, model_name="paraformer-realtime-8k-v1", **kwargs):
         import dashscope
+
         dashscope.api_key = key
         self.model_name = model_name

     def transcription(self, audio, format):
         from http import HTTPStatus
+
         from dashscope.audio.asr import Recognition

-        recognition = Recognition(model=self.model_name,
-                                  format=format,
-                                  sample_rate=16000,
-                                  callback=None)
+        recognition = Recognition(model=self.model_name, format=format, sample_rate=16000, callback=None)
         result = recognition.call(audio)

         ans = ""
         if result.status_code == HTTPStatus.OK:
             for sentence in result.get_sentence():
-                ans += sentence.text.decode('utf-8') + '\n'
+                ans += sentence.text.decode("utf-8") + "\n"
             return ans, num_tokens_from_string(ans)

         return "**ERROR**: " + result.message, 0


 class AzureSeq2txt(Base):
+    _FACTORY_NAME = "Azure-OpenAI"
+
     def __init__(self, key, model_name, lang="Chinese", **kwargs):
         self.client = AzureOpenAI(api_key=key, azure_endpoint=kwargs["base_url"], api_version="2024-02-01")
         self.model_name = model_name
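Driving a speech-to-text factory follows the same pattern as the other registries; a hypothetical call site, assuming the `Seq2txtModel` registry keeps its dict shape (the file name and key are placeholders):

```python
from rag.llm import Seq2txtModel

stt = Seq2txtModel["OpenAI"](key="sk-...", model_name="whisper-1")
with open("meeting.wav", "rb") as audio:  # any binary file object works here
    text, tokens = stt.transcription(audio)
```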
@@ -86,43 +89,33 @@ class AzureSeq2txt(Base):


 class XinferenceSeq2txt(Base):
+    _FACTORY_NAME = "Xinference"
+
     def __init__(self, key, model_name="whisper-small", **kwargs):
-        self.base_url = kwargs.get('base_url', None)
+        self.base_url = kwargs.get("base_url", None)
         self.model_name = model_name
         self.key = key

     def transcription(self, audio, language="zh", prompt=None, response_format="json", temperature=0.7):
         if isinstance(audio, str):
-            audio_file = open(audio, 'rb')
+            audio_file = open(audio, "rb")
             audio_data = audio_file.read()
             audio_file_name = audio.split("/")[-1]
         else:
             audio_data = audio
             audio_file_name = "audio.wav"

-        payload = {
-            "model": self.model_name,
-            "language": language,
-            "prompt": prompt,
-            "response_format": response_format,
-            "temperature": temperature
-        }
+        payload = {"model": self.model_name, "language": language, "prompt": prompt, "response_format": response_format, "temperature": temperature}

-        files = {
-            "file": (audio_file_name, audio_data, 'audio/wav')
-        }
+        files = {"file": (audio_file_name, audio_data, "audio/wav")}

         try:
-            response = requests.post(
-                f"{self.base_url}/v1/audio/transcriptions",
-                files=files,
-                data=payload
-            )
+            response = requests.post(f"{self.base_url}/v1/audio/transcriptions", files=files, data=payload)
             response.raise_for_status()
             result = response.json()

-            if 'text' in result:
-                transcription_text = result['text'].strip()
+            if "text" in result:
+                transcription_text = result["text"].strip()
                 return transcription_text, num_tokens_from_string(transcription_text)
             else:
                 return "**ERROR**: Failed to retrieve transcription.", 0
@@ -132,11 +125,11 @@ class XinferenceSeq2txt(Base):


 class TencentCloudSeq2txt(Base):
-    def __init__(
-        self, key, model_name="16k_zh", base_url="https://asr.tencentcloudapi.com"
-    ):
-        from tencentcloud.common import credential
+    _FACTORY_NAME = "Tencent Cloud"
+
+    def __init__(self, key, model_name="16k_zh", base_url="https://asr.tencentcloudapi.com"):
         from tencentcloud.asr.v20190614 import asr_client
+        from tencentcloud.common import credential

         key = json.loads(key)
         sid = key.get("tencent_cloud_sid", "")
@@ -146,11 +139,12 @@ class TencentCloudSeq2txt(Base):
         self.model_name = model_name

     def transcription(self, audio, max_retries=60, retry_interval=5):
+        import time
+
+        from tencentcloud.asr.v20190614 import models
         from tencentcloud.common.exception.tencent_cloud_sdk_exception import (
             TencentCloudSDKException,
         )
-        from tencentcloud.asr.v20190614 import models
-        import time

         b64 = self.audio2base64(audio)
         try:
@@ -174,9 +168,7 @@ class TencentCloudSeq2txt(Base):
             while retries < max_retries:
                 resp = self.client.DescribeTaskStatus(req)
                 if resp.Data.StatusStr == "success":
-                    text = re.sub(
-                        r"\[\d+:\d+\.\d+,\d+:\d+\.\d+\]\s*", "", resp.Data.Result
-                    ).strip()
+                    text = re.sub(r"\[\d+:\d+\.\d+,\d+:\d+\.\d+\]\s*", "", resp.Data.Result).strip()
                     return text, num_tokens_from_string(text)
                 elif resp.Data.StatusStr == "failed":
                     return (
@@ -195,6 +187,8 @@ class TencentCloudSeq2txt(Base):


 class GPUStackSeq2txt(Base):
+    _FACTORY_NAME = "GPUStack"
+
     def __init__(self, key, model_name, base_url):
         if not base_url:
             raise ValueError("url cannot be None")
@@ -206,8 +200,11 @@ class GPUStackSeq2txt(Base):


 class GiteeSeq2txt(Base):
+    _FACTORY_NAME = "GiteeAI"
+
     def __init__(self, key, model_name="whisper-1", base_url="https://ai.gitee.com/v1/"):
         if not base_url:
             base_url = "https://ai.gitee.com/v1/"
         self.client = OpenAI(api_key=key, base_url=base_url)
-        self.model_name = model_name
\ No newline at end of file
+        self.model_name = model_name
+
diff --git a/rag/llm/tts_model.py b/rag/llm/tts_model.py
index 7111b2f60..b2333b426 100644
--- a/rag/llm/tts_model.py
+++ b/rag/llm/tts_model.py
@@ -70,10 +70,12 @@ class Base(ABC):
         pass

     def normalize_text(self, text):
-        return re.sub(r'(\*\*|##\d+\$\$|#)', '', text)
+        return re.sub(r"(\*\*|##\d+\$\$|#)", "", text)


 class FishAudioTTS(Base):
+    _FACTORY_NAME = "Fish Audio"
+
     def __init__(self, key, model_name, base_url="https://api.fish.audio/v1/tts"):
         if not base_url:
             base_url = "https://api.fish.audio/v1/tts"
@@ -94,13 +96,11 @@ class FishAudioTTS(Base):
         with httpx.Client() as client:
             try:
                 with client.stream(
-                        method="POST",
-                        url=self.base_url,
-                        content=ormsgpack.packb(
-                            request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC
-                        ),
-                        headers=self.headers,
-                        timeout=None,
+                    method="POST",
+                    url=self.base_url,
+                    content=ormsgpack.packb(request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC),
+                    headers=self.headers,
+                    timeout=None,
                 ) as response:
                     if response.status_code == HTTPStatus.OK:
                         for chunk in response.iter_bytes():
@@ -115,6 +115,8 @@ class FishAudioTTS(Base):


 class QwenTTS(Base):
+    _FACTORY_NAME = "Tongyi-Qianwen"
+
     def __init__(self, key, model_name, base_url=""):
         import dashscope

@@ -122,10 +124,11 @@ class QwenTTS(Base):
         dashscope.api_key = key

     def tts(self, text):
-        from dashscope.api_entities.dashscope_response import SpeechSynthesisResponse
-        from dashscope.audio.tts import ResultCallback, SpeechSynthesizer, SpeechSynthesisResult
         from collections import deque

+        from dashscope.api_entities.dashscope_response import SpeechSynthesisResponse
+        from dashscope.audio.tts import ResultCallback, SpeechSynthesisResult, SpeechSynthesizer
+
         class Callback(ResultCallback):
             def __init__(self) -> None:
                 self.dque = deque()
@@ -159,10 +162,7 @@ class QwenTTS(Base):
         text = self.normalize_text(text)
         callback = Callback()
-        SpeechSynthesizer.call(model=self.model_name,
-                               text=text,
-                               callback=callback,
-                               format="mp3")
+        SpeechSynthesizer.call(model=self.model_name, text=text, callback=callback, format="mp3")
         try:
             for data in callback._run():
                 yield data
@@ -173,24 +173,19 @@ class QwenTTS(Base):


 class OpenAITTS(Base):
+    _FACTORY_NAME = "OpenAI"
+
     def __init__(self, key, model_name="tts-1", base_url="https://api.openai.com/v1"):
         if not base_url:
             base_url = "https://api.openai.com/v1"
         self.api_key = key
         self.model_name = model_name
         self.base_url = base_url
-        self.headers = {
-            "Authorization": f"Bearer {self.api_key}",
-            "Content-Type": "application/json"
-        }
+        self.headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}

     def tts(self, text, voice="alloy"):
         text = self.normalize_text(text)
-        payload = {
-            "model": self.model_name,
-            "voice": voice,
-            "input": text
-        }
+        payload = {"model": self.model_name, "voice": voice, "input": text}

         response = requests.post(f"{self.base_url}/audio/speech", headers=self.headers, json=payload, stream=True)
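The `tts()` methods across this file are generators that yield raw audio bytes, so a consumer simply streams chunks to a sink. Hypothetical usage, assuming the `TTSModel` registry keeps its dict shape (key and voice are placeholders):

```python
from rag.llm import TTSModel

tts = TTSModel["OpenAI"](key="sk-...", model_name="tts-1")
with open("answer.mp3", "wb") as f:
    for chunk in tts.tts("Hello from the TTS factory!", voice="alloy"):
        f.write(chunk)  # each chunk is a bytes object streamed from the provider
```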
@@ -201,7 +196,8 @@ class OpenAITTS(Base):
             yield chunk


-class SparkTTS:
+class SparkTTS(Base):
+    _FACTORY_NAME = "XunFei Spark"
     STATUS_FIRST_FRAME = 0
     STATUS_CONTINUE_FRAME = 1
     STATUS_LAST_FRAME = 2
@@ -219,29 +215,23 @@ class SparkTTS:
     # 生成url
     def create_url(self):
-        url = 'wss://tts-api.xfyun.cn/v2/tts'
+        url = "wss://tts-api.xfyun.cn/v2/tts"
         now = datetime.now()
         date = format_date_time(mktime(now.timetuple()))

         signature_origin = "host: " + "ws-api.xfyun.cn" + "\n"
         signature_origin += "date: " + date + "\n"
         signature_origin += "GET " + "/v2/tts " + "HTTP/1.1"

-        signature_sha = hmac.new(self.APISecret.encode('utf-8'), signature_origin.encode('utf-8'),
-                                 digestmod=hashlib.sha256).digest()
-        signature_sha = base64.b64encode(signature_sha).decode(encoding='utf-8')
-        authorization_origin = "api_key=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"" % (
-            self.APIKey, "hmac-sha256", "host date request-line", signature_sha)
-        authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
-        v = {
-            "authorization": authorization,
-            "date": date,
-            "host": "ws-api.xfyun.cn"
-        }
-        url = url + '?' + urlencode(v)
+        signature_sha = hmac.new(self.APISecret.encode("utf-8"), signature_origin.encode("utf-8"), digestmod=hashlib.sha256).digest()
+        signature_sha = base64.b64encode(signature_sha).decode(encoding="utf-8")
+        authorization_origin = 'api_key="%s", algorithm="%s", headers="%s", signature="%s"' % (self.APIKey, "hmac-sha256", "host date request-line", signature_sha)
+        authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode(encoding="utf-8")
+        v = {"authorization": authorization, "date": date, "host": "ws-api.xfyun.cn"}
+        url = url + "?" + urlencode(v)
         return url

     def tts(self, text):
         BusinessArgs = {"aue": "lame", "sfl": 1, "auf": "audio/L16;rate=16000", "vcn": self.model_name, "tte": "utf8"}
-        Data = {"status": 2, "text": base64.b64encode(text.encode('utf-8')).decode('utf-8')}
+        Data = {"status": 2, "text": base64.b64encode(text.encode("utf-8")).decode("utf-8")}
         CommonArgs = {"app_id": self.APPID}
         audio_queue = self.audio_queue
         model_name = self.model_name
@@ -273,9 +263,7 @@ def on_open(self, ws):
         def run(*args):
-            d = {"common": CommonArgs,
-                 "business": BusinessArgs,
-                 "data": Data}
+            d = {"common": CommonArgs, "business": BusinessArgs, "data": Data}
             ws.send(json.dumps(d))

         thread.start_new_thread(run, ())
@@ -283,44 +271,32 @@
         wsUrl = self.create_url()
         websocket.enableTrace(False)
         a = Callback()
-        ws = websocket.WebSocketApp(wsUrl, on_open=a.on_open, on_error=a.on_error, on_close=a.on_close,
-                                    on_message=a.on_message)
+        ws = websocket.WebSocketApp(wsUrl, on_open=a.on_open, on_error=a.on_error, on_close=a.on_close, on_message=a.on_message)
         status_code = 0
         ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
         while True:
             audio_chunk = self.audio_queue.get()
             if audio_chunk is None:
                 if status_code == 0:
-                    raise Exception(
-                        f"Fail to access model({model_name}) using the provided credentials. **ERROR**: Invalid APPID, API Secret, or API Key.")
+                    raise Exception(f"Fail to access model({model_name}) using the provided credentials. **ERROR**: Invalid APPID, API Secret, or API Key.")
                 else:
                     break
             status_code = 1
             yield audio_chunk


-class XinferenceTTS:
+class XinferenceTTS(Base):
+    _FACTORY_NAME = "Xinference"
+
     def __init__(self, key, model_name, **kwargs):
         self.base_url = kwargs.get("base_url", None)
         self.model_name = model_name
-        self.headers = {
-            "accept": "application/json",
-            "Content-Type": "application/json"
-        }
+        self.headers = {"accept": "application/json", "Content-Type": "application/json"}

     def tts(self, text, voice="中文女", stream=True):
-        payload = {
-            "model": self.model_name,
-            "input": text,
-            "voice": voice
-        }
+        payload = {"model": self.model_name, "input": text, "voice": voice}

-        response = requests.post(
-            f"{self.base_url}/v1/audio/speech",
-            headers=self.headers,
-            json=payload,
-            stream=stream
-        )
+        response = requests.post(f"{self.base_url}/v1/audio/speech", headers=self.headers, json=payload, stream=stream)

         if response.status_code != 200:
             raise Exception(f"**Error**: {response.status_code}, {response.text}")
@@ -332,22 +308,16 @@ class XinferenceTTS:

 class OllamaTTS(Base):
     def __init__(self, key, model_name="ollama-tts", base_url="https://api.ollama.ai/v1"):
-        if not base_url: 
+        if not base_url:
             base_url = "https://api.ollama.ai/v1"
         self.model_name = model_name
         self.base_url = base_url
-        self.headers = {
-            "Content-Type": "application/json"
-        }
+        self.headers = {"Content-Type": "application/json"}

         if key and key != "x":
             self.headers["Authorization"] = f"Bear {key}"

     def tts(self, text, voice="standard-voice"):
-        payload = {
-            "model": self.model_name,
-            "voice": voice,
-            "input": text
-        }
+        payload = {"model": self.model_name, "voice": voice, "input": text}

         response = requests.post(f"{self.base_url}/audio/tts", headers=self.headers, json=payload, stream=True)

@@ -359,30 +329,19 @@ class OllamaTTS(Base):
             yield chunk


-class GPUStackTTS:
+class GPUStackTTS(Base):
+    _FACTORY_NAME = "GPUStack"
+
     def __init__(self, key, model_name, **kwargs):
         self.base_url = kwargs.get("base_url", None)
         self.api_key = key
         self.model_name = model_name
-        self.headers = {
-            "accept": "application/json",
-            "Content-Type": "application/json",
-            "Authorization": f"Bearer {self.api_key}"
-        }
+        self.headers = {"accept": "application/json", "Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}

     def tts(self, text, voice="Chinese Female", stream=True):
-        payload = {
-            "model": self.model_name,
-            "input": text,
-            "voice": voice
-        }
+        payload = {"model": self.model_name, "input": text, "voice": voice}

-        response = requests.post(
-            f"{self.base_url}/v1/audio/speech",
-            headers=self.headers,
-            json=payload,
-            stream=stream
-        )
+        response = requests.post(f"{self.base_url}/v1/audio/speech", headers=self.headers, json=payload, stream=stream)

         if response.status_code != 200:
             raise Exception(f"**Error**: {response.status_code}, {response.text}")
@@ -393,16 +352,15 @@ class GPUStackTTS(Base):


 class SILICONFLOWTTS(Base):
+    _FACTORY_NAME = "SILICONFLOW"
+
     def __init__(self, key, model_name="FunAudioLLM/CosyVoice2-0.5B", base_url="https://api.siliconflow.cn/v1"):
         if not base_url:
             base_url = "https://api.siliconflow.cn/v1"
         self.api_key = key
         self.model_name = model_name
         self.base_url = base_url
-        self.headers = {
-            "Authorization": f"Bearer {self.api_key}",
-            "Content-Type": "application/json"
-        }
+        self.headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}

     def tts(self, text, voice="anna"):
         text = self.normalize_text(text)
@@ -414,7 +372,7 @@ class SILICONFLOWTTS(Base):
             "sample_rate": 123,
             "stream": True,
             "speed": 1,
-            "gain": 0
+            "gain": 0,
         }

         response = requests.post(f"{self.base_url}/audio/speech", headers=self.headers, json=payload)