From 4693c5382afee22688c9b11b15ae1dd12011e0bd Mon Sep 17 00:00:00 2001 From: Yongteng Lei Date: Thu, 18 Sep 2025 17:16:59 +0800 Subject: [PATCH] Feat: migrate OpenAI-compatible chats to LiteLLM (#10148) ### What problem does this PR solve? Migrate OpenAI-compatible chats to LiteLLM. ### Type of change - [x] New Feature (non-breaking change which adds functionality) --- rag/llm/__init__.py | 36 +++++++++++ rag/llm/chat_model.py | 143 ++++++++++-------------------------------- 2 files changed, 68 insertions(+), 111 deletions(-) diff --git a/rag/llm/__init__.py b/rag/llm/__init__.py index 14e9a8a19..d91f57736 100644 --- a/rag/llm/__init__.py +++ b/rag/llm/__init__.py @@ -37,6 +37,18 @@ class SupportedLiteLLMProvider(StrEnum): TogetherAI = "TogetherAI" Anthropic = "Anthropic" Ollama = "Ollama" + Meituan = "Meituan" + CometAPI = "CometAPI" + SILICONFLOW = "SILICONFLOW" + OpenRouter = "OpenRouter" + StepFun = "StepFun" + PPIO = "PPIO" + PerfXCloud = "PerfXCloud" + Upstage = "Upstage" + NovitaAI = "NovitaAI" + Lingyi_AI = "01.AI" + GiteeAI = "GiteeAI" + AI_302 = "302.AI" FACTORY_DEFAULT_BASE_URL = { @@ -44,6 +56,18 @@ FACTORY_DEFAULT_BASE_URL = { SupportedLiteLLMProvider.Dashscope: "https://dashscope.aliyuncs.com/compatible-mode/v1", SupportedLiteLLMProvider.Moonshot: "https://api.moonshot.cn/v1", SupportedLiteLLMProvider.Ollama: "", + SupportedLiteLLMProvider.Meituan: "https://api.longcat.chat/openai", + SupportedLiteLLMProvider.CometAPI: "https://api.cometapi.com/v1", + SupportedLiteLLMProvider.SILICONFLOW: "https://api.siliconflow.cn/v1", + SupportedLiteLLMProvider.OpenRouter: "https://openrouter.ai/api/v1", + SupportedLiteLLMProvider.StepFun: "https://api.stepfun.com/v1", + SupportedLiteLLMProvider.PPIO: "https://api.ppinfra.com/v3/openai", + SupportedLiteLLMProvider.PerfXCloud: "https://cloud.perfxlab.cn/v1", + SupportedLiteLLMProvider.Upstage: "https://api.upstage.ai/v1/solar", + SupportedLiteLLMProvider.NovitaAI: "https://api.novita.ai/v3/openai", + SupportedLiteLLMProvider.Lingyi_AI: "https://api.lingyiwanwu.com/v1", + SupportedLiteLLMProvider.GiteeAI: "https://ai.gitee.com/v1/", + SupportedLiteLLMProvider.AI_302: "https://api.302.ai/v1", } @@ -62,6 +86,18 @@ LITELLM_PROVIDER_PREFIX = { SupportedLiteLLMProvider.TogetherAI: "together_ai/", SupportedLiteLLMProvider.Anthropic: "", # don't need a prefix SupportedLiteLLMProvider.Ollama: "ollama_chat/", + SupportedLiteLLMProvider.Meituan: "openai/", + SupportedLiteLLMProvider.CometAPI: "openai/", + SupportedLiteLLMProvider.SILICONFLOW: "openai/", + SupportedLiteLLMProvider.OpenRouter: "openai/", + SupportedLiteLLMProvider.StepFun: "openai/", + SupportedLiteLLMProvider.PPIO: "openai/", + SupportedLiteLLMProvider.PerfXCloud: "openai/", + SupportedLiteLLMProvider.Upstage: "openai/", + SupportedLiteLLMProvider.NovitaAI: "openai/", + SupportedLiteLLMProvider.Lingyi_AI: "openai/", + SupportedLiteLLMProvider.GiteeAI: "openai/", + SupportedLiteLLMProvider.AI_302: "openai/", } ChatModel = globals().get("ChatModel", {}) diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py index a2631920b..b43277fc0 100644 --- a/rag/llm/chat_model.py +++ b/rag/llm/chat_model.py @@ -895,25 +895,6 @@ class MistralChat(Base): yield total_tokens -## openrouter -class OpenRouterChat(Base): - _FACTORY_NAME = "OpenRouter" - - def __init__(self, key, model_name, base_url="https://openrouter.ai/api/v1", **kwargs): - if not base_url: - base_url = "https://openrouter.ai/api/v1" - super().__init__(key, model_name, base_url, **kwargs) - - -class StepFunChat(Base): - _FACTORY_NAME = "StepFun" - - def __init__(self, key, model_name, base_url="https://api.stepfun.com/v1", **kwargs): - if not base_url: - base_url = "https://api.stepfun.com/v1" - super().__init__(key, model_name, base_url, **kwargs) - - class LmStudioChat(Base): _FACTORY_NAME = "LM-Studio" @@ -936,15 +917,6 @@ class OpenAI_APIChat(Base): super().__init__(key, model_name, base_url, **kwargs) -class PPIOChat(Base): - _FACTORY_NAME = "PPIO" - - def __init__(self, key, model_name, base_url="https://api.ppinfra.com/v3/openai", **kwargs): - if not base_url: - base_url = "https://api.ppinfra.com/v3/openai" - super().__init__(key, model_name, base_url, **kwargs) - - class LeptonAIChat(Base): _FACTORY_NAME = "LeptonAI" @@ -954,60 +926,6 @@ class LeptonAIChat(Base): super().__init__(key, model_name, base_url, **kwargs) -class PerfXCloudChat(Base): - _FACTORY_NAME = "PerfXCloud" - - def __init__(self, key, model_name, base_url="https://cloud.perfxlab.cn/v1", **kwargs): - if not base_url: - base_url = "https://cloud.perfxlab.cn/v1" - super().__init__(key, model_name, base_url, **kwargs) - - -class UpstageChat(Base): - _FACTORY_NAME = "Upstage" - - def __init__(self, key, model_name, base_url="https://api.upstage.ai/v1/solar", **kwargs): - if not base_url: - base_url = "https://api.upstage.ai/v1/solar" - super().__init__(key, model_name, base_url, **kwargs) - - -class NovitaAIChat(Base): - _FACTORY_NAME = "NovitaAI" - - def __init__(self, key, model_name, base_url="https://api.novita.ai/v3/openai", **kwargs): - if not base_url: - base_url = "https://api.novita.ai/v3/openai" - super().__init__(key, model_name, base_url, **kwargs) - - -class SILICONFLOWChat(Base): - _FACTORY_NAME = "SILICONFLOW" - - def __init__(self, key, model_name, base_url="https://api.siliconflow.cn/v1", **kwargs): - if not base_url: - base_url = "https://api.siliconflow.cn/v1" - super().__init__(key, model_name, base_url, **kwargs) - - -class YiChat(Base): - _FACTORY_NAME = "01.AI" - - def __init__(self, key, model_name, base_url="https://api.lingyiwanwu.com/v1", **kwargs): - if not base_url: - base_url = "https://api.lingyiwanwu.com/v1" - super().__init__(key, model_name, base_url, **kwargs) - - -class GiteeChat(Base): - _FACTORY_NAME = "GiteeAI" - - def __init__(self, key, model_name, base_url="https://ai.gitee.com/v1/", **kwargs): - if not base_url: - base_url = "https://ai.gitee.com/v1/" - super().__init__(key, model_name, base_url, **kwargs) - - class ReplicateChat(Base): _FACTORY_NAME = "Replicate" @@ -1347,24 +1265,6 @@ class GPUStackChat(Base): super().__init__(key, model_name, base_url, **kwargs) -class Ai302Chat(Base): - _FACTORY_NAME = "302.AI" - - def __init__(self, key, model_name, base_url="https://api.302.ai/v1", **kwargs): - if not base_url: - base_url = "https://api.302.ai/v1" - super().__init__(key, model_name, base_url, **kwargs) - - -class CometChat(Base): - _FACTORY_NAME = "CometAPI" - - def __init__(self, key, model_name, base_url="https://api.cometapi.com/v1", **kwargs): - if not base_url: - base_url = "https://api.cometapi.com/v1" - super().__init__(key, model_name, base_url, **kwargs) - - class TokenPonyChat(Base): _FACTORY_NAME = "TokenPony" @@ -1372,18 +1272,39 @@ class TokenPonyChat(Base): if not base_url: base_url = "https://ragflow.vip-api.tokenpony.cn/v1" - -class MeituanChat(Base): - _FACTORY_NAME = "Meituan" - - def __init__(self, key, model_name, base_url="https://api.longcat.chat/openai", **kwargs): - if not base_url: - base_url = "https://api.longcat.chat/openai" - super().__init__(key, model_name, base_url, **kwargs) - class LiteLLMBase(ABC): - _FACTORY_NAME = ["Tongyi-Qianwen", "Bedrock", "Moonshot", "xAI", "DeepInfra", "Groq", "Cohere", "Gemini", "DeepSeek", "NVIDIA", "TogetherAI", "Anthropic", "Ollama"] + _FACTORY_NAME = [ + "Tongyi-Qianwen", + "Bedrock", + "Moonshot", + "xAI", + "DeepInfra", + "Groq", + "Cohere", + "Gemini", + "DeepSeek", + "NVIDIA", + "TogetherAI", + "Anthropic", + "Ollama", + "Meituan", + "CometAPI", + "SILICONFLOW", + "OpenRouter", + "StepFun", + "PPIO", + "PerfXCloud", + "Upstage", + "NovitaAI", + "01.AI", + "GiteeAI", + "302.AI", + ] + + import litellm + + litellm._turn_on_debug() def __init__(self, key, model_name, base_url=None, **kwargs): self.timeout = int(os.environ.get("LM_TIMEOUT_SECONDS", 600)) @@ -1391,7 +1312,7 @@ class LiteLLMBase(ABC): self.prefix = LITELLM_PROVIDER_PREFIX.get(self.provider, "") self.model_name = f"{self.prefix}{model_name}" self.api_key = key - self.base_url = (base_url or FACTORY_DEFAULT_BASE_URL.get(self.provider, "")).rstrip('/') + self.base_url = (base_url or FACTORY_DEFAULT_BASE_URL.get(self.provider, "")).rstrip("/") # Configure retry parameters self.max_retries = kwargs.get("max_retries", int(os.environ.get("LLM_MAX_RETRIES", 5))) self.base_delay = kwargs.get("retry_interval", float(os.environ.get("LLM_BASE_DELAY", 2.0)))