mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-21 13:32:49 +08:00
Feat: Further update Bedrock model configs (#12029)
### What problem does this PR solve? Feat: Further update Bedrock model configs #12020 #12008 <img width="700" alt="2b4f0f7fab803a2a2d5f345c756a2c69" src="https://github.com/user-attachments/assets/e1b9eaad-5c60-47bd-a6f4-88a104ce0c63" /> <img width="700" alt="afe88ec3c58f745f85c5c507b040c250" src="https://github.com/user-attachments/assets/9de39745-395d-4145-930b-96eb452ad6ef" /> <img width="700" alt="1a21bb2b7cd8003dce1e5207f27efc69" src="https://github.com/user-attachments/assets/ddba1682-6654-4954-aa71-41b8ebc04ac0" /> ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@ -1217,11 +1217,7 @@ class LiteLLMBase(ABC):
|
||||
self.toolcall_sessions = {}
|
||||
|
||||
# Factory specific fields
|
||||
if self.provider == SupportedLiteLLMProvider.Bedrock:
|
||||
self.bedrock_ak = json.loads(key).get("bedrock_ak", "")
|
||||
self.bedrock_sk = json.loads(key).get("bedrock_sk", "")
|
||||
self.bedrock_region = json.loads(key).get("bedrock_region", "")
|
||||
elif self.provider == SupportedLiteLLMProvider.OpenRouter:
|
||||
if self.provider == SupportedLiteLLMProvider.OpenRouter:
|
||||
self.api_key = json.loads(key).get("api_key", "")
|
||||
self.provider_order = json.loads(key).get("provider_order", "")
|
||||
elif self.provider == SupportedLiteLLMProvider.Azure_OpenAI:
|
||||
@ -1624,17 +1620,38 @@ class LiteLLMBase(ABC):
|
||||
if self.provider in FACTORY_DEFAULT_BASE_URL:
|
||||
completion_args.update({"api_base": self.base_url})
|
||||
elif self.provider == SupportedLiteLLMProvider.Bedrock:
|
||||
import boto3
|
||||
|
||||
completion_args.pop("api_key", None)
|
||||
completion_args.pop("api_base", None)
|
||||
bedrock_credentials = { "aws_region_name": self.bedrock_region }
|
||||
if self.bedrock_ak and self.bedrock_sk:
|
||||
bedrock_credentials["aws_access_key_id"] = self.bedrock_ak
|
||||
bedrock_credentials["aws_secret_access_key"] = self.bedrock_sk
|
||||
|
||||
bedrock_key = json.loads(self.api_key)
|
||||
mode = bedrock_key.get("auth_mode")
|
||||
if not mode:
|
||||
logging.error("Bedrock auth_mode is not provided in the key")
|
||||
raise ValueError("Bedrock auth_mode must be provided in the key")
|
||||
|
||||
bedrock_region = bedrock_key.get("bedrock_region")
|
||||
bedrock_credentials = {"bedrock_region": bedrock_region}
|
||||
|
||||
if mode == "access_key_secret":
|
||||
bedrock_credentials["aws_access_key_id"] = bedrock_key.get("bedrock_ak")
|
||||
bedrock_credentials["aws_secret_access_key"] = bedrock_key.get("bedrock_sk")
|
||||
elif mode == "iam_role":
|
||||
aws_role_arn = bedrock_key.get("aws_role_arn")
|
||||
sts_client = boto3.client("sts", region_name=bedrock_region)
|
||||
resp = sts_client.assume_role(RoleArn=aws_role_arn, RoleSessionName="BedrockSession")
|
||||
creds = resp["Credentials"]
|
||||
bedrock_credentials["aws_access_key_id"] = creds["AccessKeyId"]
|
||||
bedrock_credentials["aws_secret_access_key"] = creds["SecretAccessKey"]
|
||||
bedrock_credentials["aws_session_token"] = creds["SessionToken"]
|
||||
|
||||
completion_args.update(
|
||||
{
|
||||
"bedrock_credentials": bedrock_credentials,
|
||||
}
|
||||
)
|
||||
|
||||
elif self.provider == SupportedLiteLLMProvider.OpenRouter:
|
||||
if self.provider_order:
|
||||
|
||||
|
||||
Reference in New Issue
Block a user