Mirror of https://github.com/infiniflow/ragflow.git (synced 2025-12-19 20:16:49 +08:00)
Fix: Azure-OpenAI resource not found (#11934)
### What problem does this PR solve?

Azure-OpenAI resource not found. #11750

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
```diff
@@ -55,6 +55,7 @@ class SupportedLiteLLMProvider(StrEnum):
     DeerAPI = "DeerAPI"
     GPUStack = "GPUStack"
     OpenAI = "OpenAI"
+    Azure_OpenAI = "Azure-OpenAI"


 FACTORY_DEFAULT_BASE_URL = {
```
```diff
@@ -116,7 +117,7 @@ LITELLM_PROVIDER_PREFIX = {
     SupportedLiteLLMProvider.DeerAPI: "openai/",
     SupportedLiteLLMProvider.GPUStack: "openai/",
     SupportedLiteLLMProvider.OpenAI: "openai/",
-
+    SupportedLiteLLMProvider.Azure_OpenAI: "azure/",
 }

 ChatModel = globals().get("ChatModel", {})
```
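For context: LiteLLM dispatches a request based on the provider prefix on the model name, so mapping Azure-OpenAI to `azure/` (rather than the generic `openai/` route) is what makes LiteLLM use its Azure client. A minimal sketch of how a prefix map like this is typically applied; the helper `to_litellm_model` and the model names are illustrative, not from this patch:

```python
# Minimal sketch (not from this patch): applying a provider-prefix map
# like LITELLM_PROVIDER_PREFIX before handing the model name to litellm.
LITELLM_PROVIDER_PREFIX = {
    "OpenAI": "openai/",
    "Azure-OpenAI": "azure/",  # "azure/" routes to litellm's Azure client
}

def to_litellm_model(provider: str, model_name: str) -> str:
    """Prepend the provider prefix so litellm dispatches to the right backend."""
    return LITELLM_PROVIDER_PREFIX.get(provider, "") + model_name

assert to_litellm_model("Azure-OpenAI", "gpt-4o") == "azure/gpt-4o"
```

Routing an Azure deployment through the plain OpenAI path addresses it like a vanilla OpenAI model, which is a plausible way to end up with a 404 "resource not found" from an Azure endpoint.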
```diff
@@ -28,7 +28,6 @@ import json_repair
 import litellm
 import openai
 from openai import AsyncOpenAI, OpenAI
-from openai.lib.azure import AzureOpenAI, AsyncAzureOpenAI
 from strenum import StrEnum

 from common.token_utils import num_tokens_from_string, total_token_count_from_response
```
```diff
@@ -191,7 +190,7 @@ class Base(ABC):
             except Exception as e:
                 e = await self._exceptions_async(e, attempt)
                 if e:
-                    yield e
+                yield e
         yield total_tokens
         return
```
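The surrounding code suggests the streaming generator's contract: on failure, classify the exception via `_exceptions_async`; if a non-retryable error remains, yield it as a final text chunk, then yield the accumulated token count, then stop. A toy sketch of that contract, under that assumption; apart from the names taken from the hunk, everything here is hypothetical:

```python
# Toy sketch of the yield-error-then-token-count contract suggested by the
# hunk above; everything except the names from the diff is hypothetical.
import asyncio

async def chat_streamly_sketch(chunks):
    total_tokens = 0
    try:
        for piece in chunks:
            total_tokens += 1
            yield piece
    except Exception as e:
        yield f"**ERROR**: {e}"  # surface the error as a final text chunk
    yield total_tokens           # callers read the last item as the token count

async def main():
    async for item in chat_streamly_sketch(["Hello", ", world"]):
        print(item)

asyncio.run(main())
```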
```diff
@@ -517,26 +516,6 @@ class ModelScopeChat(Base):
         super().__init__(key, model_name.split("___")[0], base_url, **kwargs)


-class AzureChat(Base):
-    _FACTORY_NAME = "Azure-OpenAI"
-
-    def __init__(self, key, model_name, base_url, **kwargs):
-        api_key = json.loads(key).get("api_key", "")
-        api_version = json.loads(key).get("api_version", "2024-02-01")
-        super().__init__(key, model_name, base_url, **kwargs)
-        self.client = AzureOpenAI(api_key=api_key, azure_endpoint=base_url, api_version=api_version)
-        self.async_client = AsyncAzureOpenAI(api_key=key, base_url=base_url, api_version=api_version)
-        self.model_name = model_name
-
-    @property
-    def _retryable_errors(self) -> set[str]:
-        return {
-            LLMErrorCode.ERROR_RATE_LIMIT,
-            LLMErrorCode.ERROR_SERVER,
-            LLMErrorCode.ERROR_QUOTA,
-        }
-
-
 class BaiChuanChat(Base):
     _FACTORY_NAME = "BaiChuan"
```
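Note the likely root cause visible in the removed class: the synchronous client is built from the parsed `api_key` and `azure_endpoint`, but the async client passes the raw JSON credential string (`api_key=key`) and uses `base_url=` instead of `azure_endpoint=`, so async requests carried a garbage key and targeted the wrong URL shape. A minimal sketch of symmetric construction with the `openai` SDK (endpoint and key are placeholders):

```python
# Minimal sketch (placeholder credentials): constructing the Azure clients
# symmetrically with the openai SDK, which is what the removed AzureChat's
# async branch failed to do.
import json
from openai import AzureOpenAI, AsyncAzureOpenAI

key = json.dumps({"api_key": "<azure-api-key>", "api_version": "2024-02-01"})
cfg = json.loads(key)

common = dict(
    api_key=cfg["api_key"],                           # parsed key, not the raw JSON blob
    azure_endpoint="https://myres.openai.azure.com",  # azure_endpoint, not base_url
    api_version=cfg.get("api_version", "2024-02-01"),
)
client = AzureOpenAI(**common)
async_client = AsyncAzureOpenAI(**common)
```

Rather than patching the async branch, the commit deletes the class entirely and routes Azure-OpenAI through `LiteLLMBase`, as the remaining hunks show.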
```diff
@@ -1219,6 +1198,7 @@ class LiteLLMBase(ABC):
         "DeerAPI",
         "GPUStack",
         "OpenAI",
+        "Azure-OpenAI",
     ]

     def __init__(self, key, model_name, base_url=None, **kwargs):
```
```diff
@@ -1244,6 +1224,9 @@ class LiteLLMBase(ABC):
         elif self.provider == SupportedLiteLLMProvider.OpenRouter:
             self.api_key = json.loads(key).get("api_key", "")
             self.provider_order = json.loads(key).get("provider_order", "")
+        elif self.provider == SupportedLiteLLMProvider.Azure_OpenAI:
+            self.api_key = json.loads(key).get("api_key", "")
+            self.api_version = json.loads(key).get("api_version", "2024-02-01")

     def _get_delay(self):
         return self.base_delay * random.uniform(10, 150)
```
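So for Azure-OpenAI, the `key` argument is expected to be a JSON string carrying both the credential and the API version. A sketch of the assumed shape, mirroring the `json.loads(...).get(...)` calls in the hunk above (values are placeholders):

```python
# Assumed shape of the `key` argument for Azure-OpenAI (placeholder values),
# matching the json.loads(...).get(...) calls in the hunk above.
import json

key = json.dumps({
    "api_key": "<azure-api-key>",
    "api_version": "2024-06-01",  # falls back to "2024-02-01" when absent
})

api_key = json.loads(key).get("api_key", "")
api_version = json.loads(key).get("api_version", "2024-02-01")
print(api_key, api_version)
```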
```diff
@@ -1675,6 +1658,16 @@ class LiteLLMBase(ABC):
                     "api_base": self.base_url,
                 }
             )
+        elif self.provider == SupportedLiteLLMProvider.Azure_OpenAI:
+            completion_args.pop("api_key", None)
+            completion_args.pop("api_base", None)
+            completion_args.update(
+                {
+                    "api_key": self.api_key,
+                    "api_base": self.base_url,
+                    "api_version": self.api_version,
+                }
+            )

         # Ollama deployments commonly sit behind a reverse proxy that enforces
         # Bearer auth. Ensure the Authorization header is set when an API key
```
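With those completion args in place, the eventual LiteLLM call for Azure should look roughly like the following; the deployment name, endpoint, and key are placeholders, not values from this patch, and `litellm.completion` does accept `api_key`, `api_base`, and `api_version` on the `azure/` route:

```python
# Rough shape of the resulting LiteLLM call for Azure-OpenAI; deployment,
# endpoint, and key are placeholders, not values from this patch.
import litellm

response = litellm.completion(
    model="azure/my-gpt4o-deployment",          # "azure/" prefix from LITELLM_PROVIDER_PREFIX
    messages=[{"role": "user", "content": "ping"}],
    api_key="<azure-api-key>",
    api_base="https://myres.openai.azure.com",  # Azure resource endpoint
    api_version="2024-02-01",
)
print(response.choices[0].message.content)
```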