From f8fd1ea7e177869e75a78f301f849bf7a906df89 Mon Sep 17 00:00:00 2001 From: Magicbook1108 Date: Fri, 19 Dec 2025 11:32:20 +0800 Subject: [PATCH] Feat: Further update Bedrock model configs (#12029) ### What problem does this PR solve? Feat: Further update Bedrock model configs #12020 #12008 2b4f0f7fab803a2a2d5f345c756a2c69 afe88ec3c58f745f85c5c507b040c250 1a21bb2b7cd8003dce1e5207f27efc69 ### Type of change - [x] New Feature (non-breaking change which adds functionality) --- api/apps/llm_app.py | 2 +- rag/llm/chat_model.py | 35 +++-- rag/llm/embedding_model.py | 42 ++++-- web/src/locales/en.ts | 9 ++ web/src/locales/zh-traditional.ts | 9 ++ web/src/locales/zh.ts | 9 ++ .../modal/bedrock-modal/index.tsx | 123 +++++++++++++++--- 7 files changed, 190 insertions(+), 39 deletions(-) diff --git a/api/apps/llm_app.py b/api/apps/llm_app.py index 1f46a4098..e77b90506 100644 --- a/api/apps/llm_app.py +++ b/api/apps/llm_app.py @@ -157,7 +157,7 @@ async def add_llm(): elif factory == "Bedrock": # For Bedrock, due to its special authentication method # Assemble bedrock_ak, bedrock_sk, bedrock_region - api_key = apikey_json(["bedrock_ak", "bedrock_sk", "bedrock_region"]) + api_key = apikey_json(["auth_mode", "bedrock_ak", "bedrock_sk", "bedrock_region", "aws_role_arn"]) elif factory == "LocalAI": llm_name += "___LocalAI" diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py index 7c8e2476e..e1451c00d 100644 --- a/rag/llm/chat_model.py +++ b/rag/llm/chat_model.py @@ -1217,11 +1217,7 @@ class LiteLLMBase(ABC): self.toolcall_sessions = {} # Factory specific fields - if self.provider == SupportedLiteLLMProvider.Bedrock: - self.bedrock_ak = json.loads(key).get("bedrock_ak", "") - self.bedrock_sk = json.loads(key).get("bedrock_sk", "") - self.bedrock_region = json.loads(key).get("bedrock_region", "") - elif self.provider == SupportedLiteLLMProvider.OpenRouter: + if self.provider == SupportedLiteLLMProvider.OpenRouter: self.api_key = json.loads(key).get("api_key", "") self.provider_order = json.loads(key).get("provider_order", "") elif self.provider == SupportedLiteLLMProvider.Azure_OpenAI: @@ -1624,17 +1620,38 @@ class LiteLLMBase(ABC): if self.provider in FACTORY_DEFAULT_BASE_URL: completion_args.update({"api_base": self.base_url}) elif self.provider == SupportedLiteLLMProvider.Bedrock: + import boto3 + completion_args.pop("api_key", None) completion_args.pop("api_base", None) - bedrock_credentials = { "aws_region_name": self.bedrock_region } - if self.bedrock_ak and self.bedrock_sk: - bedrock_credentials["aws_access_key_id"] = self.bedrock_ak - bedrock_credentials["aws_secret_access_key"] = self.bedrock_sk + + bedrock_key = json.loads(self.api_key) + mode = bedrock_key.get("auth_mode") + if not mode: + logging.error("Bedrock auth_mode is not provided in the key") + raise ValueError("Bedrock auth_mode must be provided in the key") + + bedrock_region = bedrock_key.get("bedrock_region") + bedrock_credentials = {"bedrock_region": bedrock_region} + + if mode == "access_key_secret": + bedrock_credentials["aws_access_key_id"] = bedrock_key.get("bedrock_ak") + bedrock_credentials["aws_secret_access_key"] = bedrock_key.get("bedrock_sk") + elif mode == "iam_role": + aws_role_arn = bedrock_key.get("aws_role_arn") + sts_client = boto3.client("sts", region_name=bedrock_region) + resp = sts_client.assume_role(RoleArn=aws_role_arn, RoleSessionName="BedrockSession") + creds = resp["Credentials"] + bedrock_credentials["aws_access_key_id"] = creds["AccessKeyId"] + bedrock_credentials["aws_secret_access_key"] = creds["SecretAccessKey"] + bedrock_credentials["aws_session_token"] = creds["SessionToken"] + completion_args.update( { "bedrock_credentials": bedrock_credentials, } ) + elif self.provider == SupportedLiteLLMProvider.OpenRouter: if self.provider_order: diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py index bf6559960..24b312df2 100644 --- a/rag/llm/embedding_model.py +++ b/rag/llm/embedding_model.py @@ -463,20 +463,44 @@ class BedrockEmbed(Base): def __init__(self, key, model_name, **kwargs): import boto3 + # `key` protocol (backend stores as JSON string in `api_key`): + # - Must decode into a dict. + # - Required: `auth_mode`, `bedrock_region`. + # - Supported auth modes: + # - "access_key_secret": requires `bedrock_ak` + `bedrock_sk`. + # - "iam_role": requires `aws_role_arn` and assumes role via STS. + # - else: treated as "assume_role" (default AWS credential chain). + key = json.loads(key) + mode = key.get("auth_mode") + if not mode: + logging.error("Bedrock auth_mode is not provided in the key") + raise ValueError("Bedrock auth_mode must be provided in the key") + + self.bedrock_region = key.get("bedrock_region") - self.bedrock_ak = json.loads(key).get("bedrock_ak", "") - self.bedrock_sk = json.loads(key).get("bedrock_sk", "") - self.bedrock_region = json.loads(key).get("bedrock_region", "") self.model_name = model_name self.is_amazon = self.model_name.split(".")[0] == "amazon" self.is_cohere = self.model_name.split(".")[0] == "cohere" - - if self.bedrock_ak == "" or self.bedrock_sk == "": - # Try to create a client using the default credentials if ak/sk are not provided. - # Must provide a region. - self.client = boto3.client("bedrock-runtime", region_name=self.bedrock_region) - else: + + if mode == "access_key_secret": + self.bedrock_ak = key.get("bedrock_ak") + self.bedrock_sk = key.get("bedrock_sk") self.client = boto3.client(service_name="bedrock-runtime", region_name=self.bedrock_region, aws_access_key_id=self.bedrock_ak, aws_secret_access_key=self.bedrock_sk) + elif mode == "iam_role": + self.aws_role_arn = key.get("aws_role_arn") + sts_client = boto3.client("sts", region_name=self.bedrock_region) + resp = sts_client.assume_role(RoleArn=self.aws_role_arn, RoleSessionName="BedrockSession") + creds = resp["Credentials"] + + self.client = boto3.client( + service_name="bedrock-runtime", + aws_access_key_id=creds["AccessKeyId"], + aws_secret_access_key=creds["SecretAccessKey"], + aws_session_token=creds["SessionToken"], + ) + else: # assume_role + self.client = boto3.client("bedrock-runtime", region_name=self.bedrock_region) + def encode(self, texts: list): texts = [truncate(t, 8196) for t in texts] diff --git a/web/src/locales/en.ts b/web/src/locales/en.ts index 0d9ee1fa1..e3d5affc9 100644 --- a/web/src/locales/en.ts +++ b/web/src/locales/en.ts @@ -787,6 +787,15 @@ This auto-tagging feature enhances retrieval by adding another layer of domain-s deleteModel: 'Delete model', bedrockCredentialsHint: 'Tip: Leave Access Key / Secret Key blank to use AWS IAM authentication.', + awsAuthModeAccessKeySecret: 'Access Key', + awsAuthModeIamRole: 'IAM Role', + awsAuthModeAssumeRole: 'Assume Role', + awsAccessKeyId: 'AWS Access Key ID', + awsSecretAccessKey: 'AWS Secret Access Key', + awsRoleArn: 'AWS Role ARN', + awsRoleArnMessage: 'Please enter AWS Role ARN', + awsAssumeRoleTip: + 'If you select this mode, the Amazon EC2 instance will assume its existing role to access AWS services. No additional credentials are required.', modelEmptyTip: 'No models available.
Please add models from the panel on the right.', sourceEmptyTip: 'No data sources added yet. Select one below to connect.', diff --git a/web/src/locales/zh-traditional.ts b/web/src/locales/zh-traditional.ts index 1c5f40b6f..0eaf5f436 100644 --- a/web/src/locales/zh-traditional.ts +++ b/web/src/locales/zh-traditional.ts @@ -546,6 +546,15 @@ export default { profileDescription: '在此更新您的照片和個人詳細信息。', bedrockCredentialsHint: '提示:Access Key / Secret Key 可留空,以啟用 AWS IAM 自動驗證。', + awsAuthModeAccessKeySecret: 'Access Key 和 Secret', + awsAuthModeIamRole: 'IAM Role', + awsAuthModeAssumeRole: 'Assume Role', + awsAccessKeyId: 'AWS Access Key ID', + awsSecretAccessKey: 'AWS Secret Access Key', + awsRoleArn: 'AWS Role ARN', + awsRoleArnMessage: '請輸入 AWS Role ARN', + awsAssumeRoleTip: + '選擇此模式後,EC2 執行個體將使用其既有的 IAM Role 存取 AWS 服務,無需額外憑證。', maxTokens: '最大token數', maxTokensMessage: '最大token數是必填項', maxTokensTip: diff --git a/web/src/locales/zh.ts b/web/src/locales/zh.ts index c1a5426e8..828dc5882 100644 --- a/web/src/locales/zh.ts +++ b/web/src/locales/zh.ts @@ -876,6 +876,15 @@ General:实体和关系提取提示来自 GitHub - microsoft/graphrag:基于 bedrockAKMessage: '请输入 ACCESS KEY', addBedrockSK: 'SECRET KEY', bedrockSKMessage: '请输入 SECRET KEY', + awsAuthModeAccessKeySecret: 'Access Key 和 Secret', + awsAuthModeIamRole: 'IAM Role', + awsAuthModeAssumeRole: 'Assume Role', + awsAccessKeyId: 'AWS Access Key ID', + awsSecretAccessKey: 'AWS Secret Access Key', + awsRoleArn: 'AWS Role ARN', + awsRoleArnMessage: '请输入 AWS Role ARN', + awsAssumeRoleTip: + '选择此模式后,EC2 实例将使用其已有的 IAM Role 访问 AWS 服务,无需额外的凭证。', bedrockRegion: 'AWS Region', bedrockRegionMessage: '请选择!', 'us-east-1': '美国东部 (弗吉尼亚北部)', diff --git a/web/src/pages/user-setting/setting-model/modal/bedrock-modal/index.tsx b/web/src/pages/user-setting/setting-model/modal/bedrock-modal/index.tsx index 2127701a3..6a610d34a 100644 --- a/web/src/pages/user-setting/setting-model/modal/bedrock-modal/index.tsx +++ b/web/src/pages/user-setting/setting-model/modal/bedrock-modal/index.tsx @@ -1,15 +1,25 @@ import { useTranslate } from '@/hooks/common-hooks'; import { IModalProps } from '@/interfaces/common'; import { IAddLlmRequestBody } from '@/interfaces/request/llm'; -import { Form, Input, InputNumber, Modal, Select, Typography } from 'antd'; -import { useMemo } from 'react'; +import { + Form, + Input, + InputNumber, + Modal, + Segmented, + Select, + Typography, +} from 'antd'; +import { useMemo, useState } from 'react'; import { LLMHeader } from '../../components/llm-header'; import { BedrockRegionList } from '../../constant'; type FieldType = IAddLlmRequestBody & { + auth_mode?: 'access_key_secret' | 'iam_role' | 'assume_role'; bedrock_ak: string; bedrock_sk: string; bedrock_region: string; + aws_role_arn?: string; }; const { Option } = Select; @@ -23,6 +33,8 @@ const BedrockModal = ({ llmFactory, }: IModalProps & { llmFactory: string }) => { const [form] = Form.useForm(); + const [authMode, setAuthMode] = + useState('access_key_secret'); const { t } = useTranslate('setting'); const options = useMemo( @@ -33,13 +45,32 @@ const BedrockModal = ({ const handleOk = async () => { const values = await form.validateFields(); + // Only submit fields related to the active auth mode. + const cleanedValues: Record = { ...values }; + + const fieldsByMode: Record = { + access_key_secret: ['bedrock_ak', 'bedrock_sk'], + iam_role: ['aws_role_arn'], + assume_role: [], + }; + + cleanedValues.auth_mode = authMode; + + Object.keys(fieldsByMode).forEach((mode) => { + if (mode !== authMode) { + fieldsByMode[mode].forEach((field) => { + delete cleanedValues[field]; + }); + } + }); + const data = { - ...values, + ...cleanedValues, llm_factory: llmFactory, max_tokens: values.max_tokens, }; - onOk?.(data); + onOk?.(data as unknown as IAddLlmRequestBody); }; return ( @@ -47,9 +78,6 @@ const BedrockModal = ({ title={
- - {t('bedrockCredentialsHint')} -
} open={visible} @@ -82,20 +110,75 @@ const BedrockModal = ({ > - - label={t('addBedrockEngineAK')} - name="bedrock_ak" - rules={[{ message: t('bedrockAKMessage') }]} - > - - - - label={t('addBedrockSK')} - name="bedrock_sk" - rules={[{ message: t('bedrockSKMessage') }]} - > - + + {/* AWS Credential Mode Switch (AK/SK section only) */} + + { + const next = v as FieldType['auth_mode']; + setAuthMode(next); + // Clear non-active fields so they won't be validated/submitted by accident. + if (next !== 'access_key_secret') { + form.setFieldsValue({ bedrock_ak: '', bedrock_sk: '' } as any); + } + if (next !== 'iam_role') { + form.setFieldsValue({ aws_role_arn: '' } as any); + } + if (next !== 'assume_role') { + form.setFieldsValue({ role_arn: '' } as any); + } + }} + options={[ + { + label: t('awsAuthModeAccessKeySecret'), + value: 'access_key_secret', + }, + { label: t('awsAuthModeIamRole'), value: 'iam_role' }, + { label: t('awsAuthModeAssumeRole'), value: 'assume_role' }, + ]} + /> + + {authMode === 'access_key_secret' && ( + <> + + label={t('awsAccessKeyId')} + name="bedrock_ak" + rules={[{ required: true, message: t('bedrockAKMessage') }]} + > + + + + label={t('awsSecretAccessKey')} + name="bedrock_sk" + rules={[{ required: true, message: t('bedrockSKMessage') }]} + > + + + + )} + + {authMode === 'iam_role' && ( + + label={t('awsRoleArn')} + name="aws_role_arn" + rules={[{ required: true, message: t('awsRoleArnMessage') }]} + > + + + )} + + {authMode === 'assume_role' && ( + + {t('awsAssumeRoleTip')} + + )} + label={t('bedrockRegion')} name="bedrock_region"