Mirror of https://github.com/infiniflow/ragflow.git (synced 2025-12-08 20:42:30 +08:00)
Feat: Support Specifying OpenRouter Model Provider (#10550)
### What problem does this PR solve?

Issue: [#5787](https://github.com/infiniflow/ragflow/issues/5787)
Change: support specifying the OpenRouter model provider order.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
```diff
@@ -194,6 +194,9 @@ def add_llm():
     elif factory == "Azure-OpenAI":
         api_key = apikey_json(["api_key", "api_version"])
 
+    elif factory == "OpenRouter":
+        api_key = apikey_json(["api_key", "provider_order"])
+
     llm = {
         "tenant_id": current_user.id,
         "llm_factory": factory,
```
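For context, a minimal sketch (the key value and `sk-or-` prefix are illustrative) of the credential blob that `apikey_json(["api_key", "provider_order"])` assembles for an OpenRouter entry, and how later code reads it back:

```python
import json

# Illustrative only: the JSON string stored as the model's API key
# when factory == "OpenRouter".
key = json.dumps({
    "api_key": "sk-or-xxxx",             # OpenRouter API key (placeholder)
    "provider_order": "Groq,Fireworks",  # comma-separated upstream providers
})

# The LiteLLMBase hunk below reads the fields back one by one:
assert json.loads(key).get("api_key", "") == "sk-or-xxxx"
assert json.loads(key).get("provider_order", "") == "Groq,Fireworks"
```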
```diff
@@ -1425,6 +1425,9 @@ class LiteLLMBase(ABC):
             self.bedrock_ak = json.loads(key).get("bedrock_ak", "")
             self.bedrock_sk = json.loads(key).get("bedrock_sk", "")
             self.bedrock_region = json.loads(key).get("bedrock_region", "")
+        elif self.provider == SupportedLiteLLMProvider.OpenRouter:
+            self.api_key = json.loads(key).get("api_key", "")
+            self.provider_order = json.loads(key).get("provider_order", "")
 
     def _get_delay(self):
         """Calculate retry delay time"""
```
```diff
@@ -1469,7 +1472,6 @@ class LiteLLMBase(ABC):
                 timeout=self.timeout,
             )
             # response = self.client.chat.completions.create(model=self.model_name, messages=history, **gen_conf, **kwargs)
-
             if any([not response.choices, not response.choices[0].message, not response.choices[0].message.content]):
                 return "", 0
             ans = response.choices[0].message.content.strip()
```
```diff
@@ -1620,6 +1622,24 @@ class LiteLLMBase(ABC):
                     "aws_region_name": self.bedrock_region,
                 }
             )
+
+        if self.provider == SupportedLiteLLMProvider.OpenRouter:
+            if self.provider_order:
+                def _to_order_list(x):
+                    if x is None:
+                        return []
+                    if isinstance(x, str):
+                        return [s.strip() for s in x.split(",") if s.strip()]
+                    if isinstance(x, (list, tuple)):
+                        return [str(s).strip() for s in x if str(s).strip()]
+                    return []
+                extra_body = {}
+                provider_cfg = {}
+                provider_order = _to_order_list(self.provider_order)
+                provider_cfg["order"] = provider_order
+                provider_cfg["allow_fallbacks"] = False
+                extra_body["provider"] = provider_cfg
+                completion_args.update({"extra_body": extra_body})
         return completion_args
 
     def chat_with_tools(self, system: str, history: list, gen_conf: dict = {}):
```
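The helper normalizes whatever the UI stored, then attaches OpenRouter's provider-routing preferences through the `extra_body` pass-through. A quick sketch of the result (the model id and the `completion_args` shape are assumed for illustration):

```python
# Illustrative values: what the hunk above builds for provider_order="Groq, Fireworks,"
provider_order = "Groq, Fireworks,"
order = [s.strip() for s in provider_order.split(",") if s.strip()]
assert order == ["Groq", "Fireworks"]  # stray spaces and the trailing comma are dropped

completion_args = {"model": "openrouter/openai/gpt-4o-mini"}  # assumed shape
completion_args.update({
    "extra_body": {"provider": {"order": order, "allow_fallbacks": False}}
})
# The extra_body rides along with the completion call to OpenRouter, which
# tries the listed providers in order and, with allow_fallbacks=False,
# fails rather than silently routing to an unlisted provider.
```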
```diff
@@ -38,6 +38,7 @@ class Base(ABC):
         self.is_tools = False
         self.tools = []
         self.toolcall_sessions = {}
+        self.extra_body = None
 
     def describe(self, image):
         raise NotImplementedError("Please implement encode method!")
```
```diff
@@ -77,7 +78,8 @@ class Base(ABC):
         try:
             response = self.client.chat.completions.create(
                 model=self.model_name,
-                messages=self._form_history(system, history, images)
+                messages=self._form_history(system, history, images),
+                extra_body=self.extra_body,
             )
             return response.choices[0].message.content.strip(), response.usage.total_tokens
         except Exception as e:
```
```diff
@@ -90,7 +92,8 @@ class Base(ABC):
             response = self.client.chat.completions.create(
                 model=self.model_name,
                 messages=self._form_history(system, history, images),
-                stream=True
+                stream=True,
+                extra_body=self.extra_body,
             )
             for resp in response:
                 if not resp.choices[0].delta.content:
```
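These call sites lean on the openai SDK's `extra_body` parameter, which merges extra keys into the request JSON and is a no-op when `None`, so non-OpenRouter vision models are unaffected by the new default. A self-contained sketch (the API key and model id are placeholders):

```python
from openai import OpenAI

client = OpenAI(api_key="sk-or-xxxx", base_url="https://openrouter.ai/api/v1")
response = client.chat.completions.create(
    model="openai/gpt-4o-mini",  # placeholder model id
    messages=[{"role": "user", "content": "ping"}],
    # Merged into the request body by the SDK; OpenRouter reads the
    # "provider" field, while other OpenAI-compatible servers ignore it.
    extra_body={"provider": {"order": ["Groq", "Fireworks"], "allow_fallbacks": False}},
)
```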
```diff
@@ -177,6 +180,7 @@ class GptV4(Base):
         res = self.client.chat.completions.create(
             model=self.model_name,
             messages=self.prompt(b64),
+            extra_body=self.extra_body,
         )
         return res.choices[0].message.content.strip(), total_token_count_from_response(res)
 
```
```diff
@@ -185,6 +189,7 @@ class GptV4(Base):
         res = self.client.chat.completions.create(
             model=self.model_name,
             messages=self.vision_llm_prompt(b64, prompt),
+            extra_body=self.extra_body,
         )
         return res.choices[0].message.content.strip(),total_token_count_from_response(res)
 
```
```diff
@@ -327,10 +332,27 @@ class OpenRouterCV(GptV4):
     ):
         if not base_url:
             base_url = "https://openrouter.ai/api/v1"
-        self.client = OpenAI(api_key=key, base_url=base_url)
+        api_key = json.loads(key).get("api_key", "")
+        self.client = OpenAI(api_key=api_key, base_url=base_url)
         self.model_name = model_name
         self.lang = lang
         Base.__init__(self, **kwargs)
+        provider_order = json.loads(key).get("provider_order", "")
+        self.extra_body = {}
+        if provider_order:
+            def _to_order_list(x):
+                if x is None:
+                    return []
+                if isinstance(x, str):
+                    return [s.strip() for s in x.split(",") if s.strip()]
+                if isinstance(x, (list, tuple)):
+                    return [str(s).strip() for s in x if str(s).strip()]
+                return []
+            provider_cfg = {}
+            provider_order = _to_order_list(provider_order)
+            provider_cfg["order"] = provider_order
+            provider_cfg["allow_fallbacks"] = False
+            self.extra_body["provider"] = provider_cfg
 
 
 class LocalAICV(GptV4):
```
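`OpenRouterCV` inlines the same `_to_order_list` normalizer that `LiteLLMBase` defines above; a possible follow-up (a sketch, not part of this PR) would hoist one shared copy into a module-level utility:

```python
def to_order_list(x) -> list[str]:
    """Normalize a provider order given as a comma-separated string or a sequence.

    Sketch of a shared helper; both LiteLLMBase and OpenRouterCV currently
    define identical nested versions of this function.
    """
    if x is None:
        return []
    if isinstance(x, str):
        return [s.strip() for s in x.split(",") if s.strip()]
    if isinstance(x, (list, tuple)):
        return [str(s).strip() for s in x if str(s).strip()]
    return []
```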
```diff
@@ -15,7 +15,10 @@ import {
 import omit from 'lodash/omit';
 import { useEffect } from 'react';
 
-type FieldType = IAddLlmRequestBody & { vision: boolean };
+type FieldType = IAddLlmRequestBody & {
+  vision: boolean;
+  provider_order?: string;
+};
 
 const { Option } = Select;
 
```
```diff
@@ -128,6 +131,10 @@ const OllamaModal = ({
       { value: 'speech2text', label: 'sequence2text' },
       { value: 'tts', label: 'tts' },
     ],
+    [LLMFactory.OpenRouter]: [
+      { value: 'chat', label: 'chat' },
+      { value: 'image2text', label: 'image2text' },
+    ],
     Default: [
       { value: 'chat', label: 'chat' },
       { value: 'embedding', label: 'embedding' },
```
```diff
@@ -233,6 +240,16 @@ const OllamaModal = ({
             onKeyDown={handleKeyDown}
           />
         </Form.Item>
+        {llmFactory === LLMFactory.OpenRouter && (
+          <Form.Item<FieldType>
+            label="Provider Order"
+            name="provider_order"
+            tooltip="Comma-separated provider list, e.g. Groq,Fireworks"
+            rules={[]}
+          >
+            <Input placeholder="Groq,Fireworks" onKeyDown={handleKeyDown} />
+          </Form.Item>
+        )}
 
         <Form.Item noStyle dependencies={['model_type']}>
           {({ getFieldValue }) =>
```
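End to end, when OpenRouter is selected the modal submits the new field alongside the usual ones. A hypothetical sketch of the resulting `add_llm` request body (field names come from `FieldType` and the backend hunk; the exact payload shape is assumed):

```python
# Hypothetical add_llm payload when factory == "OpenRouter"; the backend
# packs api_key and provider_order into one JSON credential blob via
# apikey_json(["api_key", "provider_order"]).
payload = {
    "llm_factory": "OpenRouter",
    "llm_name": "openai/gpt-4o-mini",    # placeholder model id
    "model_type": "chat",                # or "image2text" for vision models
    "api_key": "sk-or-xxxx",             # placeholder key
    "provider_order": "Groq,Fireworks",  # routed in this order, no fallbacks
}
```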