mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-21 13:32:49 +08:00
Feat: Further update Bedrock model configs (#12029)
### What problem does this PR solve? Feat: Further update Bedrock model configs #12020 #12008 <img width="700" alt="2b4f0f7fab803a2a2d5f345c756a2c69" src="https://github.com/user-attachments/assets/e1b9eaad-5c60-47bd-a6f4-88a104ce0c63" /> <img width="700" alt="afe88ec3c58f745f85c5c507b040c250" src="https://github.com/user-attachments/assets/9de39745-395d-4145-930b-96eb452ad6ef" /> <img width="700" alt="1a21bb2b7cd8003dce1e5207f27efc69" src="https://github.com/user-attachments/assets/ddba1682-6654-4954-aa71-41b8ebc04ac0" /> ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@ -463,20 +463,44 @@ class BedrockEmbed(Base):
|
||||
|
||||
def __init__(self, key, model_name, **kwargs):
|
||||
import boto3
|
||||
# `key` protocol (backend stores as JSON string in `api_key`):
|
||||
# - Must decode into a dict.
|
||||
# - Required: `auth_mode`, `bedrock_region`.
|
||||
# - Supported auth modes:
|
||||
# - "access_key_secret": requires `bedrock_ak` + `bedrock_sk`.
|
||||
# - "iam_role": requires `aws_role_arn` and assumes role via STS.
|
||||
# - else: treated as "assume_role" (default AWS credential chain).
|
||||
key = json.loads(key)
|
||||
mode = key.get("auth_mode")
|
||||
if not mode:
|
||||
logging.error("Bedrock auth_mode is not provided in the key")
|
||||
raise ValueError("Bedrock auth_mode must be provided in the key")
|
||||
|
||||
self.bedrock_region = key.get("bedrock_region")
|
||||
|
||||
self.bedrock_ak = json.loads(key).get("bedrock_ak", "")
|
||||
self.bedrock_sk = json.loads(key).get("bedrock_sk", "")
|
||||
self.bedrock_region = json.loads(key).get("bedrock_region", "")
|
||||
self.model_name = model_name
|
||||
self.is_amazon = self.model_name.split(".")[0] == "amazon"
|
||||
self.is_cohere = self.model_name.split(".")[0] == "cohere"
|
||||
|
||||
if self.bedrock_ak == "" or self.bedrock_sk == "":
|
||||
# Try to create a client using the default credentials if ak/sk are not provided.
|
||||
# Must provide a region.
|
||||
self.client = boto3.client("bedrock-runtime", region_name=self.bedrock_region)
|
||||
else:
|
||||
|
||||
if mode == "access_key_secret":
|
||||
self.bedrock_ak = key.get("bedrock_ak")
|
||||
self.bedrock_sk = key.get("bedrock_sk")
|
||||
self.client = boto3.client(service_name="bedrock-runtime", region_name=self.bedrock_region, aws_access_key_id=self.bedrock_ak, aws_secret_access_key=self.bedrock_sk)
|
||||
elif mode == "iam_role":
|
||||
self.aws_role_arn = key.get("aws_role_arn")
|
||||
sts_client = boto3.client("sts", region_name=self.bedrock_region)
|
||||
resp = sts_client.assume_role(RoleArn=self.aws_role_arn, RoleSessionName="BedrockSession")
|
||||
creds = resp["Credentials"]
|
||||
|
||||
self.client = boto3.client(
|
||||
service_name="bedrock-runtime",
|
||||
aws_access_key_id=creds["AccessKeyId"],
|
||||
aws_secret_access_key=creds["SecretAccessKey"],
|
||||
aws_session_token=creds["SessionToken"],
|
||||
)
|
||||
else: # assume_role
|
||||
self.client = boto3.client("bedrock-runtime", region_name=self.bedrock_region)
|
||||
|
||||
|
||||
def encode(self, texts: list):
|
||||
texts = [truncate(t, 8196) for t in texts]
|
||||
|
||||
Reference in New Issue
Block a user