mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-24 07:26:47 +08:00
Feat: optimize aws s3 connector (#12078)
### What problem does this PR solve?

Feat: optimize aws s3 connector #12008

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
This commit is contained in:
@@ -64,15 +64,23 @@ class BlobStorageConnector(LoadConnector, PollConnector):
|
||||
|
||||
elif self.bucket_type == BlobType.S3:
|
||||
authentication_method = credentials.get("authentication_method", "access_key")
|
||||
|
||||
if authentication_method == "access_key":
|
||||
if not all(
|
||||
credentials.get(key)
|
||||
for key in ["aws_access_key_id", "aws_secret_access_key"]
|
||||
):
|
||||
raise ConnectorMissingCredentialError("Amazon S3")
|
||||
|
||||
elif authentication_method == "iam_role":
|
||||
if not credentials.get("aws_role_arn"):
|
||||
raise ConnectorMissingCredentialError("Amazon S3 IAM role ARN is required")
|
||||
|
||||
elif authentication_method == "assume_role":
|
||||
pass
|
||||
|
||||
else:
|
||||
raise ConnectorMissingCredentialError("Unsupported S3 authentication method")
|
||||
|
||||
elif self.bucket_type == BlobType.GOOGLE_CLOUD_STORAGE:
|
||||
if not all(
|
||||
@@ -293,4 +301,4 @@ if __name__ == "__main__":
|
||||
except ConnectorMissingCredentialError as e:
|
||||
print(f"Error: {e}")
|
||||
except Exception as e:
|
||||
print(f"An unexpected error occurred: {e}")
|
||||
print(f"An unexpected error occurred: {e}")
|
||||
|
||||
@@ -254,18 +254,21 @@ def create_s3_client(bucket_type: BlobType, credentials: dict[str, Any], europea
|
||||
elif bucket_type == BlobType.S3:
|
||||
authentication_method = credentials.get("authentication_method", "access_key")
|
||||
|
||||
region_name = credentials.get("region") or None
|
||||
|
||||
if authentication_method == "access_key":
|
||||
session = boto3.Session(
|
||||
aws_access_key_id=credentials["aws_access_key_id"],
|
||||
aws_secret_access_key=credentials["aws_secret_access_key"],
|
||||
region_name=region_name,
|
||||
)
|
||||
return session.client("s3")
|
||||
return session.client("s3", region_name=region_name)
|
||||
|
||||
elif authentication_method == "iam_role":
|
||||
role_arn = credentials["aws_role_arn"]
|
||||
|
||||
def _refresh_credentials() -> dict[str, str]:
|
||||
sts_client = boto3.client("sts")
|
||||
sts_client = boto3.client("sts", region_name=credentials.get("region") or None)
|
||||
assumed_role_object = sts_client.assume_role(
|
||||
RoleArn=role_arn,
|
||||
RoleSessionName=f"onyx_blob_storage_{int(datetime.now().timestamp())}",
|
||||
@@ -285,11 +288,11 @@ def create_s3_client(bucket_type: BlobType, credentials: dict[str, Any], europea
|
||||
)
|
||||
botocore_session = get_session()
|
||||
botocore_session._credentials = refreshable
|
||||
session = boto3.Session(botocore_session=botocore_session)
|
||||
return session.client("s3")
|
||||
session = boto3.Session(botocore_session=botocore_session, region_name=region_name)
|
||||
return session.client("s3", region_name=region_name)
|
||||
|
||||
elif authentication_method == "assume_role":
|
||||
return boto3.client("s3")
|
||||
return boto3.client("s3", region_name=region_name)
|
||||
|
||||
else:
|
||||
raise ValueError("Invalid authentication method for S3.")
|
||||
|
||||
Reference in New Issue
Block a user