mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-21 05:16:54 +08:00
Feat: Further update Bedrock model configs (#12029)
### What problem does this PR solve? Feat: Further update Bedrock model configs #12020 #12008 <img width="700" alt="2b4f0f7fab803a2a2d5f345c756a2c69" src="https://github.com/user-attachments/assets/e1b9eaad-5c60-47bd-a6f4-88a104ce0c63" /> <img width="700" alt="afe88ec3c58f745f85c5c507b040c250" src="https://github.com/user-attachments/assets/9de39745-395d-4145-930b-96eb452ad6ef" /> <img width="700" alt="1a21bb2b7cd8003dce1e5207f27efc69" src="https://github.com/user-attachments/assets/ddba1682-6654-4954-aa71-41b8ebc04ac0" /> ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@ -157,7 +157,7 @@ async def add_llm():
|
|||||||
elif factory == "Bedrock":
|
elif factory == "Bedrock":
|
||||||
# For Bedrock, due to its special authentication method
|
# For Bedrock, due to its special authentication method
|
||||||
# Assemble bedrock_ak, bedrock_sk, bedrock_region
|
# Assemble bedrock_ak, bedrock_sk, bedrock_region
|
||||||
api_key = apikey_json(["bedrock_ak", "bedrock_sk", "bedrock_region"])
|
api_key = apikey_json(["auth_mode", "bedrock_ak", "bedrock_sk", "bedrock_region", "aws_role_arn"])
|
||||||
|
|
||||||
elif factory == "LocalAI":
|
elif factory == "LocalAI":
|
||||||
llm_name += "___LocalAI"
|
llm_name += "___LocalAI"
|
||||||
|
|||||||
@ -1217,11 +1217,7 @@ class LiteLLMBase(ABC):
|
|||||||
self.toolcall_sessions = {}
|
self.toolcall_sessions = {}
|
||||||
|
|
||||||
# Factory specific fields
|
# Factory specific fields
|
||||||
if self.provider == SupportedLiteLLMProvider.Bedrock:
|
if self.provider == SupportedLiteLLMProvider.OpenRouter:
|
||||||
self.bedrock_ak = json.loads(key).get("bedrock_ak", "")
|
|
||||||
self.bedrock_sk = json.loads(key).get("bedrock_sk", "")
|
|
||||||
self.bedrock_region = json.loads(key).get("bedrock_region", "")
|
|
||||||
elif self.provider == SupportedLiteLLMProvider.OpenRouter:
|
|
||||||
self.api_key = json.loads(key).get("api_key", "")
|
self.api_key = json.loads(key).get("api_key", "")
|
||||||
self.provider_order = json.loads(key).get("provider_order", "")
|
self.provider_order = json.loads(key).get("provider_order", "")
|
||||||
elif self.provider == SupportedLiteLLMProvider.Azure_OpenAI:
|
elif self.provider == SupportedLiteLLMProvider.Azure_OpenAI:
|
||||||
@ -1624,17 +1620,38 @@ class LiteLLMBase(ABC):
|
|||||||
if self.provider in FACTORY_DEFAULT_BASE_URL:
|
if self.provider in FACTORY_DEFAULT_BASE_URL:
|
||||||
completion_args.update({"api_base": self.base_url})
|
completion_args.update({"api_base": self.base_url})
|
||||||
elif self.provider == SupportedLiteLLMProvider.Bedrock:
|
elif self.provider == SupportedLiteLLMProvider.Bedrock:
|
||||||
|
import boto3
|
||||||
|
|
||||||
completion_args.pop("api_key", None)
|
completion_args.pop("api_key", None)
|
||||||
completion_args.pop("api_base", None)
|
completion_args.pop("api_base", None)
|
||||||
bedrock_credentials = { "aws_region_name": self.bedrock_region }
|
|
||||||
if self.bedrock_ak and self.bedrock_sk:
|
bedrock_key = json.loads(self.api_key)
|
||||||
bedrock_credentials["aws_access_key_id"] = self.bedrock_ak
|
mode = bedrock_key.get("auth_mode")
|
||||||
bedrock_credentials["aws_secret_access_key"] = self.bedrock_sk
|
if not mode:
|
||||||
|
logging.error("Bedrock auth_mode is not provided in the key")
|
||||||
|
raise ValueError("Bedrock auth_mode must be provided in the key")
|
||||||
|
|
||||||
|
bedrock_region = bedrock_key.get("bedrock_region")
|
||||||
|
bedrock_credentials = {"bedrock_region": bedrock_region}
|
||||||
|
|
||||||
|
if mode == "access_key_secret":
|
||||||
|
bedrock_credentials["aws_access_key_id"] = bedrock_key.get("bedrock_ak")
|
||||||
|
bedrock_credentials["aws_secret_access_key"] = bedrock_key.get("bedrock_sk")
|
||||||
|
elif mode == "iam_role":
|
||||||
|
aws_role_arn = bedrock_key.get("aws_role_arn")
|
||||||
|
sts_client = boto3.client("sts", region_name=bedrock_region)
|
||||||
|
resp = sts_client.assume_role(RoleArn=aws_role_arn, RoleSessionName="BedrockSession")
|
||||||
|
creds = resp["Credentials"]
|
||||||
|
bedrock_credentials["aws_access_key_id"] = creds["AccessKeyId"]
|
||||||
|
bedrock_credentials["aws_secret_access_key"] = creds["SecretAccessKey"]
|
||||||
|
bedrock_credentials["aws_session_token"] = creds["SessionToken"]
|
||||||
|
|
||||||
completion_args.update(
|
completion_args.update(
|
||||||
{
|
{
|
||||||
"bedrock_credentials": bedrock_credentials,
|
"bedrock_credentials": bedrock_credentials,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
elif self.provider == SupportedLiteLLMProvider.OpenRouter:
|
elif self.provider == SupportedLiteLLMProvider.OpenRouter:
|
||||||
if self.provider_order:
|
if self.provider_order:
|
||||||
|
|
||||||
|
|||||||
@ -463,20 +463,44 @@ class BedrockEmbed(Base):
|
|||||||
|
|
||||||
def __init__(self, key, model_name, **kwargs):
|
def __init__(self, key, model_name, **kwargs):
|
||||||
import boto3
|
import boto3
|
||||||
|
# `key` protocol (backend stores as JSON string in `api_key`):
|
||||||
|
# - Must decode into a dict.
|
||||||
|
# - Required: `auth_mode`, `bedrock_region`.
|
||||||
|
# - Supported auth modes:
|
||||||
|
# - "access_key_secret": requires `bedrock_ak` + `bedrock_sk`.
|
||||||
|
# - "iam_role": requires `aws_role_arn` and assumes role via STS.
|
||||||
|
# - else: treated as "assume_role" (default AWS credential chain).
|
||||||
|
key = json.loads(key)
|
||||||
|
mode = key.get("auth_mode")
|
||||||
|
if not mode:
|
||||||
|
logging.error("Bedrock auth_mode is not provided in the key")
|
||||||
|
raise ValueError("Bedrock auth_mode must be provided in the key")
|
||||||
|
|
||||||
|
self.bedrock_region = key.get("bedrock_region")
|
||||||
|
|
||||||
self.bedrock_ak = json.loads(key).get("bedrock_ak", "")
|
|
||||||
self.bedrock_sk = json.loads(key).get("bedrock_sk", "")
|
|
||||||
self.bedrock_region = json.loads(key).get("bedrock_region", "")
|
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
self.is_amazon = self.model_name.split(".")[0] == "amazon"
|
self.is_amazon = self.model_name.split(".")[0] == "amazon"
|
||||||
self.is_cohere = self.model_name.split(".")[0] == "cohere"
|
self.is_cohere = self.model_name.split(".")[0] == "cohere"
|
||||||
|
|
||||||
if self.bedrock_ak == "" or self.bedrock_sk == "":
|
if mode == "access_key_secret":
|
||||||
# Try to create a client using the default credentials if ak/sk are not provided.
|
self.bedrock_ak = key.get("bedrock_ak")
|
||||||
# Must provide a region.
|
self.bedrock_sk = key.get("bedrock_sk")
|
||||||
self.client = boto3.client("bedrock-runtime", region_name=self.bedrock_region)
|
|
||||||
else:
|
|
||||||
self.client = boto3.client(service_name="bedrock-runtime", region_name=self.bedrock_region, aws_access_key_id=self.bedrock_ak, aws_secret_access_key=self.bedrock_sk)
|
self.client = boto3.client(service_name="bedrock-runtime", region_name=self.bedrock_region, aws_access_key_id=self.bedrock_ak, aws_secret_access_key=self.bedrock_sk)
|
||||||
|
elif mode == "iam_role":
|
||||||
|
self.aws_role_arn = key.get("aws_role_arn")
|
||||||
|
sts_client = boto3.client("sts", region_name=self.bedrock_region)
|
||||||
|
resp = sts_client.assume_role(RoleArn=self.aws_role_arn, RoleSessionName="BedrockSession")
|
||||||
|
creds = resp["Credentials"]
|
||||||
|
|
||||||
|
self.client = boto3.client(
|
||||||
|
service_name="bedrock-runtime",
|
||||||
|
aws_access_key_id=creds["AccessKeyId"],
|
||||||
|
aws_secret_access_key=creds["SecretAccessKey"],
|
||||||
|
aws_session_token=creds["SessionToken"],
|
||||||
|
)
|
||||||
|
else: # assume_role
|
||||||
|
self.client = boto3.client("bedrock-runtime", region_name=self.bedrock_region)
|
||||||
|
|
||||||
|
|
||||||
def encode(self, texts: list):
|
def encode(self, texts: list):
|
||||||
texts = [truncate(t, 8196) for t in texts]
|
texts = [truncate(t, 8196) for t in texts]
|
||||||
|
|||||||
@ -787,6 +787,15 @@ This auto-tagging feature enhances retrieval by adding another layer of domain-s
|
|||||||
deleteModel: 'Delete model',
|
deleteModel: 'Delete model',
|
||||||
bedrockCredentialsHint:
|
bedrockCredentialsHint:
|
||||||
'Tip: Leave Access Key / Secret Key blank to use AWS IAM authentication.',
|
'Tip: Leave Access Key / Secret Key blank to use AWS IAM authentication.',
|
||||||
|
awsAuthModeAccessKeySecret: 'Access Key',
|
||||||
|
awsAuthModeIamRole: 'IAM Role',
|
||||||
|
awsAuthModeAssumeRole: 'Assume Role',
|
||||||
|
awsAccessKeyId: 'AWS Access Key ID',
|
||||||
|
awsSecretAccessKey: 'AWS Secret Access Key',
|
||||||
|
awsRoleArn: 'AWS Role ARN',
|
||||||
|
awsRoleArnMessage: 'Please enter AWS Role ARN',
|
||||||
|
awsAssumeRoleTip:
|
||||||
|
'If you select this mode, the Amazon EC2 instance will assume its existing role to access AWS services. No additional credentials are required.',
|
||||||
modelEmptyTip:
|
modelEmptyTip:
|
||||||
'No models available. <br>Please add models from the panel on the right.',
|
'No models available. <br>Please add models from the panel on the right.',
|
||||||
sourceEmptyTip: 'No data sources added yet. Select one below to connect.',
|
sourceEmptyTip: 'No data sources added yet. Select one below to connect.',
|
||||||
|
|||||||
@ -546,6 +546,15 @@ export default {
|
|||||||
profileDescription: '在此更新您的照片和個人詳細信息。',
|
profileDescription: '在此更新您的照片和個人詳細信息。',
|
||||||
bedrockCredentialsHint:
|
bedrockCredentialsHint:
|
||||||
'提示:Access Key / Secret Key 可留空,以啟用 AWS IAM 自動驗證。',
|
'提示:Access Key / Secret Key 可留空,以啟用 AWS IAM 自動驗證。',
|
||||||
|
awsAuthModeAccessKeySecret: 'Access Key 和 Secret',
|
||||||
|
awsAuthModeIamRole: 'IAM Role',
|
||||||
|
awsAuthModeAssumeRole: 'Assume Role',
|
||||||
|
awsAccessKeyId: 'AWS Access Key ID',
|
||||||
|
awsSecretAccessKey: 'AWS Secret Access Key',
|
||||||
|
awsRoleArn: 'AWS Role ARN',
|
||||||
|
awsRoleArnMessage: '請輸入 AWS Role ARN',
|
||||||
|
awsAssumeRoleTip:
|
||||||
|
'選擇此模式後,EC2 執行個體將使用其既有的 IAM Role 存取 AWS 服務,無需額外憑證。',
|
||||||
maxTokens: '最大token數',
|
maxTokens: '最大token數',
|
||||||
maxTokensMessage: '最大token數是必填項',
|
maxTokensMessage: '最大token數是必填項',
|
||||||
maxTokensTip:
|
maxTokensTip:
|
||||||
|
|||||||
@ -876,6 +876,15 @@ General:实体和关系提取提示来自 GitHub - microsoft/graphrag:基于
|
|||||||
bedrockAKMessage: '请输入 ACCESS KEY',
|
bedrockAKMessage: '请输入 ACCESS KEY',
|
||||||
addBedrockSK: 'SECRET KEY',
|
addBedrockSK: 'SECRET KEY',
|
||||||
bedrockSKMessage: '请输入 SECRET KEY',
|
bedrockSKMessage: '请输入 SECRET KEY',
|
||||||
|
awsAuthModeAccessKeySecret: 'Access Key 和 Secret',
|
||||||
|
awsAuthModeIamRole: 'IAM Role',
|
||||||
|
awsAuthModeAssumeRole: 'Assume Role',
|
||||||
|
awsAccessKeyId: 'AWS Access Key ID',
|
||||||
|
awsSecretAccessKey: 'AWS Secret Access Key',
|
||||||
|
awsRoleArn: 'AWS Role ARN',
|
||||||
|
awsRoleArnMessage: '请输入 AWS Role ARN',
|
||||||
|
awsAssumeRoleTip:
|
||||||
|
'选择此模式后,EC2 实例将使用其已有的 IAM Role 访问 AWS 服务,无需额外的凭证。',
|
||||||
bedrockRegion: 'AWS Region',
|
bedrockRegion: 'AWS Region',
|
||||||
bedrockRegionMessage: '请选择!',
|
bedrockRegionMessage: '请选择!',
|
||||||
'us-east-1': '美国东部 (弗吉尼亚北部)',
|
'us-east-1': '美国东部 (弗吉尼亚北部)',
|
||||||
|
|||||||
@ -1,15 +1,25 @@
|
|||||||
import { useTranslate } from '@/hooks/common-hooks';
|
import { useTranslate } from '@/hooks/common-hooks';
|
||||||
import { IModalProps } from '@/interfaces/common';
|
import { IModalProps } from '@/interfaces/common';
|
||||||
import { IAddLlmRequestBody } from '@/interfaces/request/llm';
|
import { IAddLlmRequestBody } from '@/interfaces/request/llm';
|
||||||
import { Form, Input, InputNumber, Modal, Select, Typography } from 'antd';
|
import {
|
||||||
import { useMemo } from 'react';
|
Form,
|
||||||
|
Input,
|
||||||
|
InputNumber,
|
||||||
|
Modal,
|
||||||
|
Segmented,
|
||||||
|
Select,
|
||||||
|
Typography,
|
||||||
|
} from 'antd';
|
||||||
|
import { useMemo, useState } from 'react';
|
||||||
import { LLMHeader } from '../../components/llm-header';
|
import { LLMHeader } from '../../components/llm-header';
|
||||||
import { BedrockRegionList } from '../../constant';
|
import { BedrockRegionList } from '../../constant';
|
||||||
|
|
||||||
type FieldType = IAddLlmRequestBody & {
|
type FieldType = IAddLlmRequestBody & {
|
||||||
|
auth_mode?: 'access_key_secret' | 'iam_role' | 'assume_role';
|
||||||
bedrock_ak: string;
|
bedrock_ak: string;
|
||||||
bedrock_sk: string;
|
bedrock_sk: string;
|
||||||
bedrock_region: string;
|
bedrock_region: string;
|
||||||
|
aws_role_arn?: string;
|
||||||
};
|
};
|
||||||
|
|
||||||
const { Option } = Select;
|
const { Option } = Select;
|
||||||
@ -23,6 +33,8 @@ const BedrockModal = ({
|
|||||||
llmFactory,
|
llmFactory,
|
||||||
}: IModalProps<IAddLlmRequestBody> & { llmFactory: string }) => {
|
}: IModalProps<IAddLlmRequestBody> & { llmFactory: string }) => {
|
||||||
const [form] = Form.useForm<FieldType>();
|
const [form] = Form.useForm<FieldType>();
|
||||||
|
const [authMode, setAuthMode] =
|
||||||
|
useState<FieldType['auth_mode']>('access_key_secret');
|
||||||
|
|
||||||
const { t } = useTranslate('setting');
|
const { t } = useTranslate('setting');
|
||||||
const options = useMemo(
|
const options = useMemo(
|
||||||
@ -33,13 +45,32 @@ const BedrockModal = ({
|
|||||||
const handleOk = async () => {
|
const handleOk = async () => {
|
||||||
const values = await form.validateFields();
|
const values = await form.validateFields();
|
||||||
|
|
||||||
|
// Only submit fields related to the active auth mode.
|
||||||
|
const cleanedValues: Record<string, any> = { ...values };
|
||||||
|
|
||||||
|
const fieldsByMode: Record<string, string[]> = {
|
||||||
|
access_key_secret: ['bedrock_ak', 'bedrock_sk'],
|
||||||
|
iam_role: ['aws_role_arn'],
|
||||||
|
assume_role: [],
|
||||||
|
};
|
||||||
|
|
||||||
|
cleanedValues.auth_mode = authMode;
|
||||||
|
|
||||||
|
Object.keys(fieldsByMode).forEach((mode) => {
|
||||||
|
if (mode !== authMode) {
|
||||||
|
fieldsByMode[mode].forEach((field) => {
|
||||||
|
delete cleanedValues[field];
|
||||||
|
});
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
const data = {
|
const data = {
|
||||||
...values,
|
...cleanedValues,
|
||||||
llm_factory: llmFactory,
|
llm_factory: llmFactory,
|
||||||
max_tokens: values.max_tokens,
|
max_tokens: values.max_tokens,
|
||||||
};
|
};
|
||||||
|
|
||||||
onOk?.(data);
|
onOk?.(data as unknown as IAddLlmRequestBody);
|
||||||
};
|
};
|
||||||
|
|
||||||
return (
|
return (
|
||||||
@ -47,9 +78,6 @@ const BedrockModal = ({
|
|||||||
title={
|
title={
|
||||||
<div>
|
<div>
|
||||||
<LLMHeader name={llmFactory} />
|
<LLMHeader name={llmFactory} />
|
||||||
<Text type="secondary" style={{ display: 'block', marginTop: 4 }}>
|
|
||||||
{t('bedrockCredentialsHint')}
|
|
||||||
</Text>
|
|
||||||
</div>
|
</div>
|
||||||
}
|
}
|
||||||
open={visible}
|
open={visible}
|
||||||
@ -82,20 +110,75 @@ const BedrockModal = ({
|
|||||||
>
|
>
|
||||||
<Input placeholder={t('bedrockModelNameMessage')} />
|
<Input placeholder={t('bedrockModelNameMessage')} />
|
||||||
</Form.Item>
|
</Form.Item>
|
||||||
<Form.Item<FieldType>
|
|
||||||
label={t('addBedrockEngineAK')}
|
{/* AWS Credential Mode Switch (AK/SK section only) */}
|
||||||
name="bedrock_ak"
|
<Form.Item>
|
||||||
rules={[{ message: t('bedrockAKMessage') }]}
|
<Segmented
|
||||||
>
|
block
|
||||||
<Input placeholder={t('bedrockAKMessage')} />
|
value={authMode}
|
||||||
</Form.Item>
|
onChange={(v) => {
|
||||||
<Form.Item<FieldType>
|
const next = v as FieldType['auth_mode'];
|
||||||
label={t('addBedrockSK')}
|
setAuthMode(next);
|
||||||
name="bedrock_sk"
|
// Clear non-active fields so they won't be validated/submitted by accident.
|
||||||
rules={[{ message: t('bedrockSKMessage') }]}
|
if (next !== 'access_key_secret') {
|
||||||
>
|
form.setFieldsValue({ bedrock_ak: '', bedrock_sk: '' } as any);
|
||||||
<Input placeholder={t('bedrockSKMessage')} />
|
}
|
||||||
|
if (next !== 'iam_role') {
|
||||||
|
form.setFieldsValue({ aws_role_arn: '' } as any);
|
||||||
|
}
|
||||||
|
if (next !== 'assume_role') {
|
||||||
|
form.setFieldsValue({ role_arn: '' } as any);
|
||||||
|
}
|
||||||
|
}}
|
||||||
|
options={[
|
||||||
|
{
|
||||||
|
label: t('awsAuthModeAccessKeySecret'),
|
||||||
|
value: 'access_key_secret',
|
||||||
|
},
|
||||||
|
{ label: t('awsAuthModeIamRole'), value: 'iam_role' },
|
||||||
|
{ label: t('awsAuthModeAssumeRole'), value: 'assume_role' },
|
||||||
|
]}
|
||||||
|
/>
|
||||||
</Form.Item>
|
</Form.Item>
|
||||||
|
|
||||||
|
{authMode === 'access_key_secret' && (
|
||||||
|
<>
|
||||||
|
<Form.Item<FieldType>
|
||||||
|
label={t('awsAccessKeyId')}
|
||||||
|
name="bedrock_ak"
|
||||||
|
rules={[{ required: true, message: t('bedrockAKMessage') }]}
|
||||||
|
>
|
||||||
|
<Input placeholder={t('bedrockAKMessage')} />
|
||||||
|
</Form.Item>
|
||||||
|
<Form.Item<FieldType>
|
||||||
|
label={t('awsSecretAccessKey')}
|
||||||
|
name="bedrock_sk"
|
||||||
|
rules={[{ required: true, message: t('bedrockSKMessage') }]}
|
||||||
|
>
|
||||||
|
<Input placeholder={t('bedrockSKMessage')} />
|
||||||
|
</Form.Item>
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{authMode === 'iam_role' && (
|
||||||
|
<Form.Item<FieldType>
|
||||||
|
label={t('awsRoleArn')}
|
||||||
|
name="aws_role_arn"
|
||||||
|
rules={[{ required: true, message: t('awsRoleArnMessage') }]}
|
||||||
|
>
|
||||||
|
<Input placeholder={t('awsRoleArnMessage')} />
|
||||||
|
</Form.Item>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{authMode === 'assume_role' && (
|
||||||
|
<Form.Item
|
||||||
|
style={{ marginTop: -8, marginBottom: 16 }}
|
||||||
|
// keep layout consistent with other modes
|
||||||
|
>
|
||||||
|
<Text type="secondary">{t('awsAssumeRoleTip')}</Text>
|
||||||
|
</Form.Item>
|
||||||
|
)}
|
||||||
|
|
||||||
<Form.Item<FieldType>
|
<Form.Item<FieldType>
|
||||||
label={t('bedrockRegion')}
|
label={t('bedrockRegion')}
|
||||||
name="bedrock_region"
|
name="bedrock_region"
|
||||||
|
|||||||
Reference in New Issue
Block a user