add support for Google Cloud (#2175)

### What problem does this PR solve?

#1853 add support for Google Cloud

### Type of change


- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: Zhedong Cen <cenzhedong2@126.com>
This commit is contained in:
黄腾
2024-09-02 12:06:41 +08:00
committed by GitHub
parent def18308d0
commit 5decdde182
14 changed files with 352 additions and 3 deletions

View File

@ -701,9 +701,13 @@ class GeminiChat(Base):
self.model = GenerativeModel(model_name=self.model_name)
self.model._client = _client
def chat(self,system,history,gen_conf):
from google.generativeai.types import content_types
if system:
history.insert(0, {"role": "user", "parts": system})
self.model._system_instruction = content_types.to_content(system)
if 'max_tokens' in gen_conf:
gen_conf['max_output_tokens'] = gen_conf['max_tokens']
for k in list(gen_conf.keys()):
@ -725,8 +729,10 @@ class GeminiChat(Base):
return "**ERROR**: " + str(e), 0
def chat_streamly(self, system, history, gen_conf):
from google.generativeai.types import content_types
if system:
history.insert(0, {"role": "user", "parts": system})
self.model._system_instruction = content_types.to_content(system)
if 'max_tokens' in gen_conf:
gen_conf['max_output_tokens'] = gen_conf['max_tokens']
for k in list(gen_conf.keys()):
@ -1257,3 +1263,154 @@ class AnthropicChat(Base):
yield ans + "\n**ERROR**: " + str(e)
yield total_tokens
class GoogleChat(Base):
def __init__(self, key, model_name, base_url=None):
from google.oauth2 import service_account
import base64
key = json.load(key)
access_token = json.loads(
base64.b64decode(key.get("google_service_account_key", ""))
)
project_id = key.get("google_project_id", "")
region = key.get("google_region", "")
scopes = ["https://www.googleapis.com/auth/cloud-platform"]
self.model_name = model_name
self.system = ""
if "claude" in self.model_name:
from anthropic import AnthropicVertex
from google.auth.transport.requests import Request
if access_token:
credits = service_account.Credentials.from_service_account_info(
access_token, scopes=scopes
)
request = Request()
credits.refresh(request)
token = credits.token
self.client = AnthropicVertex(
region=region, project_id=project_id, access_token=token
)
else:
self.client = AnthropicVertex(region=region, project_id=project_id)
else:
from google.cloud import aiplatform
import vertexai.generative_models as glm
if access_token:
credits = service_account.Credentials.from_service_account_info(
access_token
)
aiplatform.init(
credentials=credits, project=project_id, location=region
)
else:
aiplatform.init(project=project_id, location=region)
self.client = glm.GenerativeModel(model_name=self.model_name)
def chat(self, system, history, gen_conf):
if system:
self.system = system
if "claude" in self.model_name:
if "max_tokens" not in gen_conf:
gen_conf["max_tokens"] = 4096
try:
response = self.client.messages.create(
model=self.model_name,
messages=history,
system=self.system,
stream=False,
**gen_conf,
).json()
ans = response["content"][0]["text"]
if response["stop_reason"] == "max_tokens":
ans += (
"...\nFor the content length reason, it stopped, continue?"
if is_english([ans])
else "······\n由于长度的原因,回答被截断了,要继续吗?"
)
return (
ans,
response["usage"]["input_tokens"]
+ response["usage"]["output_tokens"],
)
except Exception as e:
return ans + "\n**ERROR**: " + str(e), 0
else:
self.client._system_instruction = self.system
if "max_tokens" in gen_conf:
gen_conf["max_output_tokens"] = gen_conf["max_tokens"]
for k in list(gen_conf.keys()):
if k not in ["temperature", "top_p", "max_output_tokens"]:
del gen_conf[k]
for item in history:
if "role" in item and item["role"] == "assistant":
item["role"] = "model"
if "content" in item:
item["parts"] = item.pop("content")
try:
response = self.client.generate_content(
history, generation_config=gen_conf
)
ans = response.text
return ans, response.usage_metadata.total_token_count
except Exception as e:
return "**ERROR**: " + str(e), 0
def chat_streamly(self, system, history, gen_conf):
if system:
self.system = system
if "claude" in self.model_name:
if "max_tokens" not in gen_conf:
gen_conf["max_tokens"] = 4096
ans = ""
total_tokens = 0
try:
response = self.client.messages.create(
model=self.model_name,
messages=history,
system=self.system,
stream=True,
**gen_conf,
)
for res in response.iter_lines():
res = res.decode("utf-8")
if "content_block_delta" in res and "data" in res:
text = json.loads(res[6:])["delta"]["text"]
ans += text
total_tokens += num_tokens_from_string(text)
except Exception as e:
yield ans + "\n**ERROR**: " + str(e)
yield total_tokens
else:
self.client._system_instruction = self.system
if "max_tokens" in gen_conf:
gen_conf["max_output_tokens"] = gen_conf["max_tokens"]
for k in list(gen_conf.keys()):
if k not in ["temperature", "top_p", "max_output_tokens"]:
del gen_conf[k]
for item in history:
if "role" in item and item["role"] == "assistant":
item["role"] = "model"
if "content" in item:
item["parts"] = item.pop("content")
ans = ""
try:
response = self.model.generate_content(
history, generation_config=gen_conf, stream=True
)
for resp in response:
ans += resp.text
yield ans
except Exception as e:
yield ans + "\n**ERROR**: " + str(e)
yield response._chunks[-1].usage_metadata.total_token_count