diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py index a9e3c1ab7..dc59e1fb8 100644 --- a/rag/llm/chat_model.py +++ b/rag/llm/chat_model.py @@ -1632,25 +1632,22 @@ class LiteLLMBase(ABC): raise ValueError("Bedrock auth_mode must be provided in the key") bedrock_region = bedrock_key.get("bedrock_region") - bedrock_credentials = {"bedrock_region": bedrock_region} if mode == "access_key_secret": - bedrock_credentials["aws_access_key_id"] = bedrock_key.get("bedrock_ak") - bedrock_credentials["aws_secret_access_key"] = bedrock_key.get("bedrock_sk") + completion_args.update({"aws_region_name": bedrock_region}) + completion_args.update({"aws_access_key_id": bedrock_key.get("bedrock_ak")}) + completion_args.update({"aws_secret_access_key": bedrock_key.get("bedrock_sk")}) elif mode == "iam_role": aws_role_arn = bedrock_key.get("aws_role_arn") sts_client = boto3.client("sts", region_name=bedrock_region) resp = sts_client.assume_role(RoleArn=aws_role_arn, RoleSessionName="BedrockSession") creds = resp["Credentials"] - bedrock_credentials["aws_access_key_id"] = creds["AccessKeyId"] - bedrock_credentials["aws_secret_access_key"] = creds["SecretAccessKey"] - bedrock_credentials["aws_session_token"] = creds["SessionToken"] - - completion_args.update( - { - "bedrock_credentials": bedrock_credentials, - } - ) + completion_args.update({"aws_region_name": bedrock_region}) + completion_args.update({"aws_access_key_id": creds["AccessKeyId"]}) + completion_args.update({"aws_secret_access_key": creds["SecretAccessKey"]}) + completion_args.update({"aws_session_token": creds["SessionToken"]}) + else: # assume_role - use default credential chain (IRSA, instance profile, etc.) + completion_args.update({"aws_region_name": bedrock_region}) elif self.provider == SupportedLiteLLMProvider.OpenRouter: if self.provider_order: