Feat: Support Specifying OpenRouter Model Provider (#10550)

### What problem does this PR solve?
issue:
[#5787](https://github.com/infiniflow/ragflow/issues/5787)
change:
Support Specifying OpenRouter Model Provider

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
buua436
2025-10-16 09:39:59 +08:00
committed by GitHub
parent c99034f717
commit 4e86ee4ff9
4 changed files with 67 additions and 5 deletions

View File

@ -1425,6 +1425,9 @@ class LiteLLMBase(ABC):
self.bedrock_ak = json.loads(key).get("bedrock_ak", "")
self.bedrock_sk = json.loads(key).get("bedrock_sk", "")
self.bedrock_region = json.loads(key).get("bedrock_region", "")
elif self.provider == SupportedLiteLLMProvider.OpenRouter:
self.api_key = json.loads(key).get("api_key", "")
self.provider_order = json.loads(key).get("provider_order", "")
def _get_delay(self):
"""Calculate retry delay time"""
@ -1469,7 +1472,6 @@ class LiteLLMBase(ABC):
timeout=self.timeout,
)
# response = self.client.chat.completions.create(model=self.model_name, messages=history, **gen_conf, **kwargs)
if any([not response.choices, not response.choices[0].message, not response.choices[0].message.content]):
return "", 0
ans = response.choices[0].message.content.strip()
@ -1620,6 +1622,24 @@ class LiteLLMBase(ABC):
"aws_region_name": self.bedrock_region,
}
)
if self.provider == SupportedLiteLLMProvider.OpenRouter:
if self.provider_order:
def _to_order_list(x):
if x is None:
return []
if isinstance(x, str):
return [s.strip() for s in x.split(",") if s.strip()]
if isinstance(x, (list, tuple)):
return [str(s).strip() for s in x if str(s).strip()]
return []
extra_body = {}
provider_cfg = {}
provider_order = _to_order_list(self.provider_order)
provider_cfg["order"] = provider_order
provider_cfg["allow_fallbacks"] = False
extra_body["provider"] = provider_cfg
completion_args.update({"extra_body": extra_body})
return completion_args
def chat_with_tools(self, system: str, history: list, gen_conf: dict = {}):

View File

@ -38,6 +38,7 @@ class Base(ABC):
self.is_tools = False
self.tools = []
self.toolcall_sessions = {}
self.extra_body = None
def describe(self, image):
raise NotImplementedError("Please implement encode method!")
@ -77,7 +78,8 @@ class Base(ABC):
try:
response = self.client.chat.completions.create(
model=self.model_name,
messages=self._form_history(system, history, images)
messages=self._form_history(system, history, images),
extra_body=self.extra_body,
)
return response.choices[0].message.content.strip(), response.usage.total_tokens
except Exception as e:
@ -90,7 +92,8 @@ class Base(ABC):
response = self.client.chat.completions.create(
model=self.model_name,
messages=self._form_history(system, history, images),
stream=True
stream=True,
extra_body=self.extra_body,
)
for resp in response:
if not resp.choices[0].delta.content:
@ -177,6 +180,7 @@ class GptV4(Base):
res = self.client.chat.completions.create(
model=self.model_name,
messages=self.prompt(b64),
extra_body=self.extra_body,
)
return res.choices[0].message.content.strip(), total_token_count_from_response(res)
@ -185,6 +189,7 @@ class GptV4(Base):
res = self.client.chat.completions.create(
model=self.model_name,
messages=self.vision_llm_prompt(b64, prompt),
extra_body=self.extra_body,
)
return res.choices[0].message.content.strip(),total_token_count_from_response(res)
@ -327,10 +332,27 @@ class OpenRouterCV(GptV4):
):
if not base_url:
base_url = "https://openrouter.ai/api/v1"
self.client = OpenAI(api_key=key, base_url=base_url)
api_key = json.loads(key).get("api_key", "")
self.client = OpenAI(api_key=api_key, base_url=base_url)
self.model_name = model_name
self.lang = lang
Base.__init__(self, **kwargs)
provider_order = json.loads(key).get("provider_order", "")
self.extra_body = {}
if provider_order:
def _to_order_list(x):
if x is None:
return []
if isinstance(x, str):
return [s.strip() for s in x.split(",") if s.strip()]
if isinstance(x, (list, tuple)):
return [str(s).strip() for s in x if str(s).strip()]
return []
provider_cfg = {}
provider_order = _to_order_list(provider_order)
provider_cfg["order"] = provider_order
provider_cfg["allow_fallbacks"] = False
self.extra_body["provider"] = provider_cfg
class LocalAICV(GptV4):