Add Support for AWS Bedrock (#1408)

### What problem does this PR solve?

#308 

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: KevinHuSh <kevinhu.sh@gmail.com>
This commit is contained in:
H
2024-07-08 09:37:34 +08:00
committed by GitHub
parent b3ebc66b13
commit 6144a109ab
8 changed files with 325 additions and 7 deletions

View File

@ -533,3 +533,90 @@ class MistralChat(Base):
yield ans + "\n**ERROR**: " + str(e)
yield total_tokens
class BedrockChat(Base):
def __init__(self, key, model_name, **kwargs):
import boto3
from botocore.exceptions import ClientError
self.bedrock_ak = eval(key).get('bedrock_ak', '')
self.bedrock_sk = eval(key).get('bedrock_sk', '')
self.bedrock_region = eval(key).get('bedrock_region', '')
self.model_name = model_name
self.client = boto3.client(service_name='bedrock-runtime', region_name=self.bedrock_region,
aws_access_key_id=self.bedrock_ak, aws_secret_access_key=self.bedrock_sk)
def chat(self, system, history, gen_conf):
if system:
history.insert(0, {"role": "system", "content": system})
for k in list(gen_conf.keys()):
if k not in ["temperature", "top_p", "max_tokens"]:
del gen_conf[k]
if "max_tokens" in gen_conf:
gen_conf["maxTokens"] = gen_conf["max_tokens"]
_ = gen_conf.pop("max_tokens")
if "top_p" in gen_conf:
gen_conf["topP"] = gen_conf["top_p"]
_ = gen_conf.pop("top_p")
try:
# Send the message to the model, using a basic inference configuration.
response = self.client.converse(
modelId=self.model_name,
messages=history,
inferenceConfig=gen_conf
)
# Extract and print the response text.
ans = response["output"]["message"]["content"][0]["text"]
return ans, num_tokens_from_string(ans)
except (ClientError, Exception) as e:
return f"ERROR: Can't invoke '{self.model_name}'. Reason: {e}", 0
def chat_streamly(self, system, history, gen_conf):
if system:
history.insert(0, {"role": "system", "content": system})
for k in list(gen_conf.keys()):
if k not in ["temperature", "top_p", "max_tokens"]:
del gen_conf[k]
if "max_tokens" in gen_conf:
gen_conf["maxTokens"] = gen_conf["max_tokens"]
_ = gen_conf.pop("max_tokens")
if "top_p" in gen_conf:
gen_conf["topP"] = gen_conf["top_p"]
_ = gen_conf.pop("top_p")
if self.model_name.split('.')[0] == 'ai21':
try:
response = self.client.converse(
modelId=self.model_name,
messages=history,
inferenceConfig=gen_conf
)
ans = response["output"]["message"]["content"][0]["text"]
return ans, num_tokens_from_string(ans)
except (ClientError, Exception) as e:
return f"ERROR: Can't invoke '{self.model_name}'. Reason: {e}", 0
ans = ""
try:
# Send the message to the model, using a basic inference configuration.
streaming_response = self.client.converse_stream(
modelId=self.model_name,
messages=history,
inferenceConfig=gen_conf
)
# Extract and print the streamed response text in real-time.
for resp in streaming_response["stream"]:
if "contentBlockDelta" in resp:
ans += resp["contentBlockDelta"]["delta"]["text"]
yield ans
except (ClientError, Exception) as e:
yield ans + f"ERROR: Can't invoke '{self.model_name}'. Reason: {e}"
yield num_tokens_from_string(ans)