mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Feat: migrate OpenAI-compatible chats to LiteLLM (#10148)
### What problem does this PR solve? Migrate OpenAI-compatible chats to LiteLLM. ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@ -895,25 +895,6 @@ class MistralChat(Base):
|
||||
yield total_tokens
|
||||
|
||||
|
||||
## openrouter
|
||||
class OpenRouterChat(Base):
|
||||
_FACTORY_NAME = "OpenRouter"
|
||||
|
||||
def __init__(self, key, model_name, base_url="https://openrouter.ai/api/v1", **kwargs):
|
||||
if not base_url:
|
||||
base_url = "https://openrouter.ai/api/v1"
|
||||
super().__init__(key, model_name, base_url, **kwargs)
|
||||
|
||||
|
||||
class StepFunChat(Base):
|
||||
_FACTORY_NAME = "StepFun"
|
||||
|
||||
def __init__(self, key, model_name, base_url="https://api.stepfun.com/v1", **kwargs):
|
||||
if not base_url:
|
||||
base_url = "https://api.stepfun.com/v1"
|
||||
super().__init__(key, model_name, base_url, **kwargs)
|
||||
|
||||
|
||||
class LmStudioChat(Base):
|
||||
_FACTORY_NAME = "LM-Studio"
|
||||
|
||||
@ -936,15 +917,6 @@ class OpenAI_APIChat(Base):
|
||||
super().__init__(key, model_name, base_url, **kwargs)
|
||||
|
||||
|
||||
class PPIOChat(Base):
|
||||
_FACTORY_NAME = "PPIO"
|
||||
|
||||
def __init__(self, key, model_name, base_url="https://api.ppinfra.com/v3/openai", **kwargs):
|
||||
if not base_url:
|
||||
base_url = "https://api.ppinfra.com/v3/openai"
|
||||
super().__init__(key, model_name, base_url, **kwargs)
|
||||
|
||||
|
||||
class LeptonAIChat(Base):
|
||||
_FACTORY_NAME = "LeptonAI"
|
||||
|
||||
@ -954,60 +926,6 @@ class LeptonAIChat(Base):
|
||||
super().__init__(key, model_name, base_url, **kwargs)
|
||||
|
||||
|
||||
class PerfXCloudChat(Base):
|
||||
_FACTORY_NAME = "PerfXCloud"
|
||||
|
||||
def __init__(self, key, model_name, base_url="https://cloud.perfxlab.cn/v1", **kwargs):
|
||||
if not base_url:
|
||||
base_url = "https://cloud.perfxlab.cn/v1"
|
||||
super().__init__(key, model_name, base_url, **kwargs)
|
||||
|
||||
|
||||
class UpstageChat(Base):
|
||||
_FACTORY_NAME = "Upstage"
|
||||
|
||||
def __init__(self, key, model_name, base_url="https://api.upstage.ai/v1/solar", **kwargs):
|
||||
if not base_url:
|
||||
base_url = "https://api.upstage.ai/v1/solar"
|
||||
super().__init__(key, model_name, base_url, **kwargs)
|
||||
|
||||
|
||||
class NovitaAIChat(Base):
|
||||
_FACTORY_NAME = "NovitaAI"
|
||||
|
||||
def __init__(self, key, model_name, base_url="https://api.novita.ai/v3/openai", **kwargs):
|
||||
if not base_url:
|
||||
base_url = "https://api.novita.ai/v3/openai"
|
||||
super().__init__(key, model_name, base_url, **kwargs)
|
||||
|
||||
|
||||
class SILICONFLOWChat(Base):
|
||||
_FACTORY_NAME = "SILICONFLOW"
|
||||
|
||||
def __init__(self, key, model_name, base_url="https://api.siliconflow.cn/v1", **kwargs):
|
||||
if not base_url:
|
||||
base_url = "https://api.siliconflow.cn/v1"
|
||||
super().__init__(key, model_name, base_url, **kwargs)
|
||||
|
||||
|
||||
class YiChat(Base):
|
||||
_FACTORY_NAME = "01.AI"
|
||||
|
||||
def __init__(self, key, model_name, base_url="https://api.lingyiwanwu.com/v1", **kwargs):
|
||||
if not base_url:
|
||||
base_url = "https://api.lingyiwanwu.com/v1"
|
||||
super().__init__(key, model_name, base_url, **kwargs)
|
||||
|
||||
|
||||
class GiteeChat(Base):
|
||||
_FACTORY_NAME = "GiteeAI"
|
||||
|
||||
def __init__(self, key, model_name, base_url="https://ai.gitee.com/v1/", **kwargs):
|
||||
if not base_url:
|
||||
base_url = "https://ai.gitee.com/v1/"
|
||||
super().__init__(key, model_name, base_url, **kwargs)
|
||||
|
||||
|
||||
class ReplicateChat(Base):
|
||||
_FACTORY_NAME = "Replicate"
|
||||
|
||||
@ -1347,24 +1265,6 @@ class GPUStackChat(Base):
|
||||
super().__init__(key, model_name, base_url, **kwargs)
|
||||
|
||||
|
||||
class Ai302Chat(Base):
|
||||
_FACTORY_NAME = "302.AI"
|
||||
|
||||
def __init__(self, key, model_name, base_url="https://api.302.ai/v1", **kwargs):
|
||||
if not base_url:
|
||||
base_url = "https://api.302.ai/v1"
|
||||
super().__init__(key, model_name, base_url, **kwargs)
|
||||
|
||||
|
||||
class CometChat(Base):
|
||||
_FACTORY_NAME = "CometAPI"
|
||||
|
||||
def __init__(self, key, model_name, base_url="https://api.cometapi.com/v1", **kwargs):
|
||||
if not base_url:
|
||||
base_url = "https://api.cometapi.com/v1"
|
||||
super().__init__(key, model_name, base_url, **kwargs)
|
||||
|
||||
|
||||
class TokenPonyChat(Base):
|
||||
_FACTORY_NAME = "TokenPony"
|
||||
|
||||
@ -1372,18 +1272,39 @@ class TokenPonyChat(Base):
|
||||
if not base_url:
|
||||
base_url = "https://ragflow.vip-api.tokenpony.cn/v1"
|
||||
|
||||
|
||||
class MeituanChat(Base):
|
||||
_FACTORY_NAME = "Meituan"
|
||||
|
||||
def __init__(self, key, model_name, base_url="https://api.longcat.chat/openai", **kwargs):
|
||||
if not base_url:
|
||||
base_url = "https://api.longcat.chat/openai"
|
||||
super().__init__(key, model_name, base_url, **kwargs)
|
||||
|
||||
|
||||
class LiteLLMBase(ABC):
|
||||
_FACTORY_NAME = ["Tongyi-Qianwen", "Bedrock", "Moonshot", "xAI", "DeepInfra", "Groq", "Cohere", "Gemini", "DeepSeek", "NVIDIA", "TogetherAI", "Anthropic", "Ollama"]
|
||||
_FACTORY_NAME = [
|
||||
"Tongyi-Qianwen",
|
||||
"Bedrock",
|
||||
"Moonshot",
|
||||
"xAI",
|
||||
"DeepInfra",
|
||||
"Groq",
|
||||
"Cohere",
|
||||
"Gemini",
|
||||
"DeepSeek",
|
||||
"NVIDIA",
|
||||
"TogetherAI",
|
||||
"Anthropic",
|
||||
"Ollama",
|
||||
"Meituan",
|
||||
"CometAPI",
|
||||
"SILICONFLOW",
|
||||
"OpenRouter",
|
||||
"StepFun",
|
||||
"PPIO",
|
||||
"PerfXCloud",
|
||||
"Upstage",
|
||||
"NovitaAI",
|
||||
"01.AI",
|
||||
"GiteeAI",
|
||||
"302.AI",
|
||||
]
|
||||
|
||||
import litellm
|
||||
|
||||
litellm._turn_on_debug()
|
||||
|
||||
def __init__(self, key, model_name, base_url=None, **kwargs):
|
||||
self.timeout = int(os.environ.get("LM_TIMEOUT_SECONDS", 600))
|
||||
@ -1391,7 +1312,7 @@ class LiteLLMBase(ABC):
|
||||
self.prefix = LITELLM_PROVIDER_PREFIX.get(self.provider, "")
|
||||
self.model_name = f"{self.prefix}{model_name}"
|
||||
self.api_key = key
|
||||
self.base_url = (base_url or FACTORY_DEFAULT_BASE_URL.get(self.provider, "")).rstrip('/')
|
||||
self.base_url = (base_url or FACTORY_DEFAULT_BASE_URL.get(self.provider, "")).rstrip("/")
|
||||
# Configure retry parameters
|
||||
self.max_retries = kwargs.get("max_retries", int(os.environ.get("LLM_MAX_RETRIES", 5)))
|
||||
self.base_delay = kwargs.get("retry_interval", float(os.environ.get("LLM_BASE_DELAY", 2.0)))
|
||||
|
||||
Reference in New Issue
Block a user