add support for Anthropic (#2148)

### What problem does this PR solve?

#1853  add support for Anthropic

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: Zhedong Cen <cenzhedong2@126.com>
Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
This commit is contained in:
黄腾
2024-08-29 13:30:06 +08:00
committed by GitHub
parent 0abc01311b
commit 06abef66ef
7 changed files with 124 additions and 5 deletions

View File

@ -104,7 +104,8 @@ ChatModel = {
"Replicate": ReplicateChat,
"Tencent Hunyuan": HunyuanChat,
"XunFei Spark": SparkChat,
"BaiduYiyan": BaiduYiyanChat
"BaiduYiyan": BaiduYiyanChat,
"Anthropic": AnthropicChat
}

View File

@ -1132,7 +1132,7 @@ class SparkChat(Base):
class BaiduYiyanChat(Base):
def __init__(self, key, model_name, base_url=None):
import qianfan
key = json.loads(key)
ak = key.get("yiyan_ak","")
sk = key.get("yiyan_sk","")
@ -1149,7 +1149,7 @@ class BaiduYiyanChat(Base):
if "max_tokens" in gen_conf:
gen_conf["max_output_tokens"] = gen_conf["max_tokens"]
ans = ""
try:
response = self.client.do(
model=self.model_name,
@ -1159,7 +1159,7 @@ class BaiduYiyanChat(Base):
).body
ans = response['result']
return ans, response["usage"]["total_tokens"]
except Exception as e:
return ans + "\n**ERROR**: " + str(e), 0
@ -1173,7 +1173,7 @@ class BaiduYiyanChat(Base):
gen_conf["max_output_tokens"] = gen_conf["max_tokens"]
ans = ""
total_tokens = 0
try:
response = self.client.do(
model=self.model_name,
@ -1193,3 +1193,67 @@ class BaiduYiyanChat(Base):
return ans + "\n**ERROR**: " + str(e), 0
yield total_tokens
class AnthropicChat(Base):
    """Chat model backed by the Anthropic Messages API.

    Exposes the two entry points used by the other chat classes in this
    module: ``chat`` (one blocking request) and ``chat_streamly`` (a
    generator of progressively longer answers followed by a token count).
    """

    # Used when the caller does not provide ``max_tokens``; the Messages API
    # requires this parameter on every request.
    _DEFAULT_MAX_TOKENS = 4096

    # OpenAI-style generation settings that ``messages.create`` does not
    # accept; forwarding them would fail with an unexpected-keyword error.
    _UNSUPPORTED_CONF_KEYS = ("presence_penalty", "frequency_penalty")

    def __init__(self, key, model_name, base_url=None):
        """Create the Anthropic client.

        Args:
            key: Anthropic API key.
            model_name: Model identifier sent with every request.
            base_url: Accepted for signature parity with the other chat
                classes; it is not used by this class.
        """
        # Imported lazily so the module loads without the SDK installed.
        import anthropic

        self.client = anthropic.Anthropic(api_key=key)
        self.model_name = model_name
        # Last non-empty system prompt; reused when a call passes none.
        self.system = ""

    def _request_conf(self, system, gen_conf):
        """Remember the system prompt and build the request keyword arguments.

        Works on a copy so the caller's ``gen_conf`` dict is never mutated.
        """
        if system:
            self.system = system
        conf = {
            name: value
            for name, value in gen_conf.items()
            if name not in self._UNSUPPORTED_CONF_KEYS
        }
        conf.setdefault("max_tokens", self._DEFAULT_MAX_TOKENS)
        return conf

    def chat(self, system, history, gen_conf):
        """Send one blocking chat request.

        Args:
            system: System prompt; when empty the previously stored prompt
                is reused.
            history: Message list passed to the API as ``messages``.
            gen_conf: Generation settings forwarded as keyword arguments.

        Returns:
            ``(answer, total_tokens)``. On failure the answer ends with
            ``"\\n**ERROR**: <message>"`` and the token count is ``0``.
        """
        conf = self._request_conf(system, gen_conf)
        # Bound before the try block so the error path can always use it.
        ans = ""
        try:
            response = self.client.messages.create(
                model=self.model_name,
                messages=history,
                system=self.system,
                stream=False,
                **conf,
            )
            # The SDK returns a typed message object; read its attributes
            # rather than ``.json()``, which yields a JSON *string*.
            ans = response.content[0].text
            if response.stop_reason == "max_tokens":
                ans += (
                    "...\nFor the content length reason, it stopped, continue?"
                    if is_english([ans])
                    else "······\n由于长度的原因,回答被截断了,要继续吗?"
                )
            usage = response.usage
            return ans, usage.input_tokens + usage.output_tokens
        except Exception as e:
            return ans + "\n**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        """Stream a chat response.

        Args:
            system: System prompt; when empty the previously stored prompt
                is reused.
            history: Message list passed to the API as ``messages``.
            gen_conf: Generation settings forwarded as keyword arguments.

        Yields:
            The accumulated answer text after each received text delta (or
            the text so far plus ``"\\n**ERROR**: <message>"`` on failure),
            and finally the locally estimated token count as an ``int``.
        """
        conf = self._request_conf(system, gen_conf)
        ans = ""
        total_tokens = 0
        try:
            stream = self.client.messages.create(
                model=self.model_name,
                messages=history,
                system=self.system,
                stream=True,
                **conf,
            )
            # The SDK stream yields typed events; only text deltas carry
            # answer content.
            for event in stream:
                if getattr(event, "type", None) != "content_block_delta":
                    continue
                text = getattr(event.delta, "text", None)
                if not text:
                    # Non-text deltas (e.g. tool input JSON) have no text.
                    continue
                ans += text
                total_tokens += num_tokens_from_string(text)
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield total_tokens