mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-30 16:45:35 +08:00
Refactor: automatic LLM registration (#8651)
### What problem does this PR solve?
Support automatic LLM registration.
### Type of change
- [x] Refactoring
This commit is contained in:
#
# AFTER UPDATING THIS FILE, PLEASE ENSURE THAT docs/references/supported_models.mdx IS ALSO UPDATED for consistency!
#
import importlib
import inspect

# Factory-name -> model-class registries, one per model capability.
# ``globals().get`` preserves any entries already present if this module is
# re-executed (e.g. on reload) instead of silently discarding them.
ChatModel = globals().get("ChatModel", {})
CvModel = globals().get("CvModel", {})
EmbeddingModel = globals().get("EmbeddingModel", {})
RerankModel = globals().get("RerankModel", {})
Seq2txtModel = globals().get("Seq2txtModel", {})
TTSModel = globals().get("TTSModel", {})

# Sub-module name -> the registry that its model classes are collected into.
MODULE_MAPPING = {
    "chat_model": ChatModel,
    "cv_model": CvModel,
    "embedding_model": EmbeddingModel,
    "rerank_model": RerankModel,
    "sequence2txt_model": Seq2txtModel,
    "tts_model": TTSModel,
}

package_name = __name__

for module_name, mapping_dict in MODULE_MAPPING.items():
    full_module_name = f"{package_name}.{module_name}"
    module = importlib.import_module(full_module_name)

    # Each sub-module is expected to declare a class literally named "Base";
    # only subclasses of it are eligible for automatic registration.
    base_class = None
    for name, obj in inspect.getmembers(module):
        if inspect.isclass(obj) and name == "Base":
            base_class = obj
            break
    if base_class is None:
        # No "Base" marker class: nothing to register from this sub-module.
        continue

    # Register every concrete subclass that opts in by defining _FACTORY_NAME,
    # which is either a single factory name or a list of alias names
    # (e.g. one class serving both "OpenAI-API-Compatible" and "VLLM").
    for _, obj in inspect.getmembers(module):
        if inspect.isclass(obj) and issubclass(obj, base_class) and obj is not base_class and hasattr(obj, "_FACTORY_NAME"):
            if isinstance(obj._FACTORY_NAME, list):
                for factory_name in obj._FACTORY_NAME:
                    mapping_dict[factory_name] = obj
            else:
                mapping_dict[obj._FACTORY_NAME] = obj

# Public API: the six populated registries.
__all__ = [
    "ChatModel",
    "CvModel",
    "EmbeddingModel",
    "RerankModel",
    "Seq2txtModel",
    "TTSModel",
]
|
||||
|
||||
Reference in New Issue
Block a user