Feat: optimize aws s3 connector (#12078)

### What problem does this PR solve?

Feat: optimize aws s3 connector #12008 

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
This commit is contained in:
Magicbook1108
2025-12-22 19:06:01 +08:00
committed by GitHub
parent 6d3d3a40ab
commit 4cbc91f2fa
5 changed files with 313 additions and 59 deletions

View File

@ -64,15 +64,23 @@ class BlobStorageConnector(LoadConnector, PollConnector):
elif self.bucket_type == BlobType.S3:
authentication_method = credentials.get("authentication_method", "access_key")
if authentication_method == "access_key":
if not all(
credentials.get(key)
for key in ["aws_access_key_id", "aws_secret_access_key"]
):
raise ConnectorMissingCredentialError("Amazon S3")
elif authentication_method == "iam_role":
if not credentials.get("aws_role_arn"):
raise ConnectorMissingCredentialError("Amazon S3 IAM role ARN is required")
elif authentication_method == "assume_role":
pass
else:
raise ConnectorMissingCredentialError("Unsupported S3 authentication method")
elif self.bucket_type == BlobType.GOOGLE_CLOUD_STORAGE:
if not all(
@ -293,4 +301,4 @@ if __name__ == "__main__":
except ConnectorMissingCredentialError as e:
print(f"Error: {e}")
except Exception as e:
print(f"An unexpected error occurred: {e}")
print(f"An unexpected error occurred: {e}")

View File

@ -254,18 +254,21 @@ def create_s3_client(bucket_type: BlobType, credentials: dict[str, Any], europea
elif bucket_type == BlobType.S3:
authentication_method = credentials.get("authentication_method", "access_key")
region_name = credentials.get("region") or None
if authentication_method == "access_key":
session = boto3.Session(
aws_access_key_id=credentials["aws_access_key_id"],
aws_secret_access_key=credentials["aws_secret_access_key"],
region_name=region_name,
)
return session.client("s3")
return session.client("s3", region_name=region_name)
elif authentication_method == "iam_role":
role_arn = credentials["aws_role_arn"]
def _refresh_credentials() -> dict[str, str]:
sts_client = boto3.client("sts")
sts_client = boto3.client("sts", region_name=credentials.get("region") or None)
assumed_role_object = sts_client.assume_role(
RoleArn=role_arn,
RoleSessionName=f"onyx_blob_storage_{int(datetime.now().timestamp())}",
@ -285,11 +288,11 @@ def create_s3_client(bucket_type: BlobType, credentials: dict[str, Any], europea
)
botocore_session = get_session()
botocore_session._credentials = refreshable
session = boto3.Session(botocore_session=botocore_session)
return session.client("s3")
session = boto3.Session(botocore_session=botocore_session, region_name=region_name)
return session.client("s3", region_name=region_name)
elif authentication_method == "assume_role":
return boto3.client("s3")
return boto3.client("s3", region_name=region_name)
else:
raise ValueError("Invalid authentication method for S3.")

View File

@ -0,0 +1,247 @@
import { useEffect, useMemo, useState } from 'react';
import { useFormContext } from 'react-hook-form';
import { SelectWithSearch } from '@/components/originui/select-with-search';
import { RAGFlowFormItem } from '@/components/ragflow-form';
import { Input } from '@/components/ui/input';
import { Segmented } from '@/components/ui/segmented';
import { t } from 'i18next';
// UI-only auth modes for S3
// access_key: Access Key ID + Secret
// iam_role: only Role ARN
// assume_role: no input fields (uses environment role)
type AuthMode = 'access_key' | 'iam_role' | 'assume_role';
type BlobMode = 's3' | 's3_compatible';
const modeOptions = [
{ label: 'S3', value: 's3' },
{ label: 'S3 Compatible', value: 's3_compatible' },
];
const authOptions = [
{ label: 'Access Key', value: 'access_key' },
{ label: 'IAM Role', value: 'iam_role' },
{ label: 'Assume Role', value: 'assume_role' },
];
const addressingOptions = [
{ label: 'Virtual Hosted Style', value: 'virtual' },
{ label: 'Path Style', value: 'path' },
];
const deriveInitialAuthMode = (credentials: any): AuthMode => {
const authMethod = credentials?.authentication_method;
if (authMethod === 'iam_role') return 'iam_role';
if (authMethod === 'assume_role') return 'assume_role';
if (credentials?.aws_role_arn) return 'iam_role';
if (credentials?.aws_access_key_id || credentials?.aws_secret_access_key)
return 'access_key';
return 'access_key';
};
const deriveInitialMode = (bucketType?: string): BlobMode =>
bucketType === 's3_compatible' ? 's3_compatible' : 's3';
const BlobTokenField = () => {
const form = useFormContext();
const credentials = form.watch('config.credentials');
const watchedBucketType = form.watch('config.bucket_type');
const [mode, setMode] = useState<BlobMode>(
deriveInitialMode(watchedBucketType),
);
const [authMode, setAuthMode] = useState<AuthMode>(() =>
deriveInitialAuthMode(credentials),
);
// Keep bucket_type in sync with UI mode
useEffect(() => {
const nextMode = deriveInitialMode(watchedBucketType);
setMode((prev) => (prev === nextMode ? prev : nextMode));
}, [watchedBucketType]);
useEffect(() => {
form.setValue('config.bucket_type', mode, { shouldDirty: true });
// Default addressing style for compatible mode
if (
mode === 's3_compatible' &&
!form.getValues('config.credentials.addressing_style')
) {
form.setValue('config.credentials.addressing_style', 'virtual', {
shouldDirty: false,
});
}
if (mode === 's3_compatible' && authMode !== 'access_key') {
setAuthMode('access_key');
}
// Persist authentication_method for backend
const nextAuthMethod: AuthMode =
mode === 's3_compatible' ? 'access_key' : authMode;
form.setValue('config.credentials.authentication_method', nextAuthMethod, {
shouldDirty: true,
});
// Clear errors for fields that are not relevant in the current mode/auth selection
const inactiveFields: string[] = [];
if (mode === 's3_compatible') {
inactiveFields.push('config.credentials.aws_role_arn');
} else {
if (authMode === 'iam_role') {
inactiveFields.push('config.credentials.aws_access_key_id');
inactiveFields.push('config.credentials.aws_secret_access_key');
}
if (authMode === 'assume_role') {
inactiveFields.push('config.credentials.aws_access_key_id');
inactiveFields.push('config.credentials.aws_secret_access_key');
inactiveFields.push('config.credentials.aws_role_arn');
}
}
if (inactiveFields.length) {
form.clearErrors(inactiveFields as any);
}
}, [form, mode, authMode]);
const isS3 = mode === 's3';
const requiresAccessKey =
authMode === 'access_key' || mode === 's3_compatible';
const requiresRoleArn = isS3 && authMode === 'iam_role';
// Help text for assume role (no inputs)
const assumeRoleNote = useMemo(
() => t('No credentials required. Uses the default environment role.'),
[t],
);
return (
<div className="flex flex-col gap-4">
<div className="flex flex-col gap-2">
<div className="text-sm text-text-secondary">Mode</div>
<Segmented
options={modeOptions}
value={mode}
onChange={(val) => setMode(val as BlobMode)}
className="w-full"
itemClassName="flex-1 justify-center"
/>
</div>
{isS3 && (
<div className="flex flex-col gap-2">
<div className="text-sm text-text-secondary">Authentication</div>
<Segmented
options={authOptions}
value={authMode}
onChange={(val) => setAuthMode(val as AuthMode)}
className="w-full"
itemClassName="flex-1 justify-center"
/>
</div>
)}
{requiresAccessKey && (
<RAGFlowFormItem
name="config.credentials.aws_access_key_id"
label="AWS Access Key ID"
required={requiresAccessKey}
rules={{
validate: (val) =>
requiresAccessKey
? Boolean(val) || 'Access Key ID is required'
: true,
}}
>
{(field) => (
<Input {...field} placeholder="AKIA..." autoComplete="off" />
)}
</RAGFlowFormItem>
)}
{requiresAccessKey && (
<RAGFlowFormItem
name="config.credentials.aws_secret_access_key"
label="AWS Secret Access Key"
required={requiresAccessKey}
rules={{
validate: (val) =>
requiresAccessKey
? Boolean(val) || 'Secret Access Key is required'
: true,
}}
>
{(field) => (
<Input
{...field}
type="password"
placeholder="****************"
autoComplete="new-password"
/>
)}
</RAGFlowFormItem>
)}
{requiresRoleArn && (
<RAGFlowFormItem
name="config.credentials.aws_role_arn"
label="Role ARN"
required={requiresRoleArn}
tooltip="The role will be assumed by the runtime environment."
rules={{
validate: (val) =>
requiresRoleArn ? Boolean(val) || 'Role ARN is required' : true,
}}
>
{(field) => (
<Input
{...field}
placeholder="arn:aws:iam::123456789012:role/YourRole"
autoComplete="off"
/>
)}
</RAGFlowFormItem>
)}
{isS3 && authMode === 'assume_role' && (
<div className="text-sm text-text-secondary bg-bg-card border border-border-button rounded-md px-3 py-2">
{assumeRoleNote}
</div>
)}
{mode === 's3_compatible' && (
<div className="flex flex-col gap-4">
<RAGFlowFormItem
name="config.credentials.addressing_style"
label="Addressing Style"
tooltip={t('setting.S3CompatibleAddressingStyleTip')}
required={false}
>
{(field) => (
<SelectWithSearch
triggerClassName="!shrink"
options={addressingOptions}
value={field.value || 'virtual'}
onChange={(val) => field.onChange(val)}
/>
)}
</RAGFlowFormItem>
<RAGFlowFormItem
name="config.credentials.endpoint_url"
label="Endpoint URL"
required={false}
tooltip={t('setting.S3CompatibleEndpointUrlTip')}
>
{(field) => (
<Input
{...field}
placeholder="https://fsn1.your-objectstorage.com"
autoComplete="off"
/>
)}
</RAGFlowFormItem>
</div>
)}
</div>
);
};
export default BlobTokenField;

View File

@ -3,6 +3,8 @@ import SvgIcon from '@/components/svg-icon';
import { t, TFunction } from 'i18next';
import { useEffect, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { BedrockRegionList } from '../setting-model/constant';
import BlobTokenField from './component/blob-token-field';
import BoxTokenField from './component/box-token-field';
import { ConfluenceIndexingModeField } from './component/confluence-token-field';
import GmailTokenField from './component/gmail-token-field';
@ -105,6 +107,11 @@ export const generateDataSourceInfo = (t: TFunction) => {
};
};
const awsRegionOptions = BedrockRegionList.map((r) => ({
label: r,
value: r,
}));
export const useDataSourceInfo = () => {
const { t } = useTranslation();
const [dataSourceInfo, setDataSourceInfo] = useState<IDataSourceInfoMap>(
@ -222,18 +229,6 @@ export const DataSourceFormFields = {
},
],
[DataSourceKey.S3]: [
{
label: 'AWS Access Key ID',
name: 'config.credentials.aws_access_key_id',
type: FormFieldType.Text,
required: true,
},
{
label: 'AWS Secret Access Key',
name: 'config.credentials.aws_secret_access_key',
type: FormFieldType.Password,
required: true,
},
{
label: 'Bucket Name',
name: 'config.bucket_name',
@ -241,39 +236,21 @@ export const DataSourceFormFields = {
required: true,
},
{
label: 'Bucket Type',
name: 'config.bucket_type',
label: 'Region',
name: 'config.credentials.region',
type: FormFieldType.Select,
options: [
{ label: 'S3', value: 's3' },
{ label: 'S3 Compatible', value: 's3_compatible' },
],
required: true,
},
{
label: 'Addressing Style',
name: 'config.credentials.addressing_style',
type: FormFieldType.Select,
options: [
{ label: 'Virtual Hosted Style', value: 'virtual' },
{ label: 'Path Style', value: 'path' },
],
required: false,
placeholder: 'Virtual Hosted Style',
tooltip: t('setting.S3CompatibleAddressingStyleTip'),
shouldRender: (formValues: any) => {
return formValues?.config?.bucket_type === 's3_compatible';
},
},
{
label: 'Endpoint URL',
name: 'config.credentials.endpoint_url',
type: FormFieldType.Text,
required: false,
placeholder: 'https://fsn1.your-objectstorage.com',
tooltip: t('setting.S3CompatibleEndpointUrlTip'),
shouldRender: (formValues: any) => {
return formValues?.config?.bucket_type === 's3_compatible';
options: awsRegionOptions,
customValidate: (val: string, formValues: any) => {
const credentials = formValues?.config?.credentials || {};
const bucketType = formValues?.config?.bucket_type || 's3';
const hasAccessKey = Boolean(
credentials.aws_access_key_id || credentials.aws_secret_access_key,
);
if (bucketType === 's3' && hasAccessKey) {
return Boolean(val) || 'Region is required when using access key';
}
return true;
},
},
{
@ -283,6 +260,14 @@ export const DataSourceFormFields = {
required: false,
tooltip: t('setting.s3PrefixTip'),
},
{
label: 'Credentials',
name: 'config.credentials.__blob_token',
type: FormFieldType.Custom,
hideLabel: true,
required: false,
render: () => <BlobTokenField />,
},
],
[DataSourceKey.NOTION]: [
{
@ -700,6 +685,9 @@ export const DataSourceFormDefaultValues = {
credentials: {
aws_access_key_id: '',
aws_secret_access_key: '',
region: '',
authentication_method: 'access_key',
aws_role_arn: '',
endpoint_url: '',
addressing_style: 'virtual',
},

View File

@ -10,8 +10,7 @@ import { Input } from '@/components/ui/input';
import { Separator } from '@/components/ui/separator';
import { RunningStatus } from '@/constants/knowledge';
import { t } from 'i18next';
import { debounce } from 'lodash';
import { CirclePause, Repeat } from 'lucide-react';
import { CirclePause, Loader2, Repeat } from 'lucide-react';
import { useCallback, useEffect, useMemo, useRef, useState } from 'react';
import { FieldValues } from 'react-hook-form';
import {
@ -120,11 +119,11 @@ const SourceDetailPage = () => {
];
}, [detail, runSchedule]);
const { handleAddOk } = useAddDataSource();
const { addLoading, handleAddOk } = useAddDataSource();
const onSubmit = useCallback(() => {
formRef?.current?.submit();
}, [formRef]);
}, []);
useEffect(() => {
if (detail) {
@ -140,9 +139,7 @@ const SourceDetailPage = () => {
return {
...field,
horizontal: true,
onChange: () => {
onSubmit();
},
onChange: undefined,
};
});
setFields(newFields);
@ -175,12 +172,23 @@ const SourceDetailPage = () => {
<DynamicForm.Root
ref={formRef}
fields={fields}
onSubmit={debounce((data) => {
handleAddOk(data);
}, 500)}
onSubmit={(data) => handleAddOk(data)}
defaultValues={defaultValues}
/>
</div>
<div className="max-w-[1200px] flex justify-end">
<button
type="button"
onClick={onSubmit}
disabled={addLoading}
className="flex items-center justify-center min-w-[100px] px-4 py-2 bg-primary text-white rounded-md disabled:opacity-60"
>
{addLoading && <Loader2 className="mr-2 h-4 w-4 animate-spin" />}
{addLoading
? t('modal.loadingText', { defaultValue: 'Submitting...' })
: t('modal.okText', { defaultValue: 'Submit' })}
</button>
</div>
<section className="flex flex-col gap-2">
<div className="text-2xl text-text-primary mb-2">
{t('setting.log')}