diff --git a/api/apps/llm_app.py b/api/apps/llm_app.py
index 1ef1c3e71..4252456b1 100644
--- a/api/apps/llm_app.py
+++ b/api/apps/llm_app.py
@@ -194,6 +194,9 @@ def add_llm():
     elif factory == "Azure-OpenAI":
         api_key = apikey_json(["api_key", "api_version"])
 
+    elif factory == "OpenRouter":
+        api_key = apikey_json(["api_key", "provider_order"])
+
     llm = {
         "tenant_id": current_user.id,
         "llm_factory": factory,
diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py
index 5a552fa50..61a09d0df 100644
--- a/rag/llm/chat_model.py
+++ b/rag/llm/chat_model.py
@@ -1425,6 +1425,9 @@ class LiteLLMBase(ABC):
             self.bedrock_ak = json.loads(key).get("bedrock_ak", "")
             self.bedrock_sk = json.loads(key).get("bedrock_sk", "")
             self.bedrock_region = json.loads(key).get("bedrock_region", "")
+        elif self.provider == SupportedLiteLLMProvider.OpenRouter:
+            self.api_key = json.loads(key).get("api_key", "")
+            self.provider_order = json.loads(key).get("provider_order", "")
 
     def _get_delay(self):
         """Calculate retry delay time"""
@@ -1469,7 +1472,6 @@ class LiteLLMBase(ABC):
                     timeout=self.timeout,
                 )
                 # response = self.client.chat.completions.create(model=self.model_name, messages=history, **gen_conf, **kwargs)
-
                 if any([not response.choices, not response.choices[0].message, not response.choices[0].message.content]):
                     return "", 0
                 ans = response.choices[0].message.content.strip()
@@ -1620,6 +1622,24 @@ class LiteLLMBase(ABC):
                     "aws_region_name": self.bedrock_region,
                 }
             )
+
+        if self.provider == SupportedLiteLLMProvider.OpenRouter:
+            if self.provider_order:
+                def _to_order_list(x):
+                    if x is None:
+                        return []
+                    if isinstance(x, str):
+                        return [s.strip() for s in x.split(",") if s.strip()]
+                    if isinstance(x, (list, tuple)):
+                        return [str(s).strip() for s in x if str(s).strip()]
+                    return []
+                extra_body = {}
+                provider_cfg = {}
+                provider_order = _to_order_list(self.provider_order)
+                provider_cfg["order"] = provider_order
+                provider_cfg["allow_fallbacks"] = False
+                extra_body["provider"] = provider_cfg
+                completion_args.update({"extra_body": extra_body})
         return completion_args
 
     def chat_with_tools(self, system: str, history: list, gen_conf: dict = {}):
diff --git a/rag/llm/cv_model.py b/rag/llm/cv_model.py
index 55f01021a..b96d02d60 100644
--- a/rag/llm/cv_model.py
+++ b/rag/llm/cv_model.py
@@ -38,6 +38,7 @@ class Base(ABC):
         self.is_tools = False
         self.tools = []
         self.toolcall_sessions = {}
+        self.extra_body = None
 
     def describe(self, image):
         raise NotImplementedError("Please implement encode method!")
@@ -77,7 +78,8 @@ class Base(ABC):
         try:
             response = self.client.chat.completions.create(
                 model=self.model_name,
-                messages=self._form_history(system, history, images)
+                messages=self._form_history(system, history, images),
+                extra_body=self.extra_body,
             )
             return response.choices[0].message.content.strip(), response.usage.total_tokens
         except Exception as e:
@@ -90,7 +92,8 @@ class Base(ABC):
             response = self.client.chat.completions.create(
                 model=self.model_name,
                 messages=self._form_history(system, history, images),
-                stream=True
+                stream=True,
+                extra_body=self.extra_body,
             )
             for resp in response:
                 if not resp.choices[0].delta.content:
@@ -177,6 +180,7 @@ class GptV4(Base):
         res = self.client.chat.completions.create(
             model=self.model_name,
             messages=self.prompt(b64),
+            extra_body=self.extra_body,
         )
         return res.choices[0].message.content.strip(), total_token_count_from_response(res)
 
@@ -185,6 +189,7 @@
         res = self.client.chat.completions.create(
             model=self.model_name,
             messages=self.vision_llm_prompt(b64, prompt),
+            extra_body=self.extra_body,
         )
         return res.choices[0].message.content.strip(),total_token_count_from_response(res)
 
@@ -327,10 +332,27 @@ class OpenRouterCV(GptV4):
     ):
         if not base_url:
             base_url = "https://openrouter.ai/api/v1"
-        self.client = OpenAI(api_key=key, base_url=base_url)
+        api_key = json.loads(key).get("api_key", "")
+        self.client = OpenAI(api_key=api_key, base_url=base_url)
         self.model_name = model_name
         self.lang = lang
         Base.__init__(self, **kwargs)
+        provider_order = json.loads(key).get("provider_order", "")
+        self.extra_body = {}
+        if provider_order:
+            def _to_order_list(x):
+                if x is None:
+                    return []
+                if isinstance(x, str):
+                    return [s.strip() for s in x.split(",") if s.strip()]
+                if isinstance(x, (list, tuple)):
+                    return [str(s).strip() for s in x if str(s).strip()]
+                return []
+            provider_cfg = {}
+            provider_order = _to_order_list(provider_order)
+            provider_cfg["order"] = provider_order
+            provider_cfg["allow_fallbacks"] = False
+            self.extra_body["provider"] = provider_cfg
 
 
 class LocalAICV(GptV4):
diff --git a/web/src/pages/user-setting/setting-model/ollama-modal/index.tsx b/web/src/pages/user-setting/setting-model/ollama-modal/index.tsx
index 8ef4199d6..2dfa4601a 100644
--- a/web/src/pages/user-setting/setting-model/ollama-modal/index.tsx
+++ b/web/src/pages/user-setting/setting-model/ollama-modal/index.tsx
@@ -15,7 +15,10 @@ import {
 import omit from 'lodash/omit';
 import { useEffect } from 'react';
 
-type FieldType = IAddLlmRequestBody & { vision: boolean };
+type FieldType = IAddLlmRequestBody & {
+  vision: boolean;
+  provider_order?: string;
+};
 
 const { Option } = Select;
 
@@ -128,6 +131,10 @@ const OllamaModal = ({
       { value: 'speech2text', label: 'sequence2text' },
       { value: 'tts', label: 'tts' },
     ],
+    [LLMFactory.OpenRouter]: [
+      { value: 'chat', label: 'chat' },
+      { value: 'image2text', label: 'image2text' },
+    ],
     Default: [
       { value: 'chat', label: 'chat' },
      { value: 'embedding', label: 'embedding' },
@@ -233,6 +240,16 @@ const OllamaModal = ({
           onKeyDown={handleKeyDown}
         />
+      {llmFactory === LLMFactory.OpenRouter && (
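
Note on the provider_order handling added in both LiteLLMBase and OpenRouterCV (illustrative only, not part of the patch): the comma-separated provider_order string stored in the key JSON is turned into OpenRouter's provider routing block and sent via extra_body, with fallbacks disabled. A minimal standalone sketch of that transformation, using placeholder provider names:

import json

def build_openrouter_extra_body(key: str) -> dict | None:
    """Sketch of the patch's provider_order handling: the comma-separated
    string saved by add_llm() becomes OpenRouter's provider.order list,
    with allow_fallbacks set to False. Returns None when no order is set."""
    raw = json.loads(key).get("provider_order", "")
    if isinstance(raw, (list, tuple)):
        order = [str(s).strip() for s in raw if str(s).strip()]
    else:
        order = [s.strip() for s in str(raw).split(",") if s.strip()]
    if not order:
        return None
    return {"provider": {"order": order, "allow_fallbacks": False}}

# Placeholder key JSON and provider names, for illustration only.
key_json = json.dumps({"api_key": "sk-or-xxx", "provider_order": "deepinfra, together"})
print(build_openrouter_extra_body(key_json))
# -> {'provider': {'order': ['deepinfra', 'together'], 'allow_fallbacks': False}}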