mirror of https://github.com/infiniflow/ragflow.git (synced 2025-12-08 20:42:30 +08:00)
add support for Anthropic (#2148)
### What problem does this PR solve?

#1853 add support for Anthropic

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: Zhedong Cen <cenzhedong2@126.com>
Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
@@ -104,7 +104,8 @@ ChatModel = {
     "Replicate": ReplicateChat,
     "Tencent Hunyuan": HunyuanChat,
     "XunFei Spark": SparkChat,
-    "BaiduYiyan": BaiduYiyanChat
+    "BaiduYiyan": BaiduYiyanChat,
+    "Anthropic": AnthropicChat
 }
 
 
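With the new registry entry, the Anthropic provider is resolved by factory name just like the existing entries. A minimal lookup sketch, assuming the `ChatModel` dict shown above is in scope; the key and model name are placeholders, not values from this PR:

```python
# Hypothetical lookup sketch: the factory name selects the chat class,
# which is then instantiated with the provider credentials.
chat_cls = ChatModel["Anthropic"]      # -> AnthropicChat
mdl = chat_cls(key="<anthropic_api_key>", model_name="<claude_model_name>")
```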
@@ -1132,7 +1132,7 @@ class SparkChat(Base):
 class BaiduYiyanChat(Base):
     def __init__(self, key, model_name, base_url=None):
         import qianfan
-
+
         key = json.loads(key)
         ak = key.get("yiyan_ak","")
         sk = key.get("yiyan_sk","")
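As the `__init__` above shows, `BaiduYiyanChat` expects `key` to be a JSON string carrying the Qianfan credential pair. A short sketch of building such a key; the field names come from the code above, while the model name is a placeholder:

```python
import json

# Hypothetical example: BaiduYiyanChat parses its `key` argument as JSON
# and reads the "yiyan_ak" / "yiyan_sk" fields shown in the diff above.
yiyan_key = json.dumps({"yiyan_ak": "<your_ak>", "yiyan_sk": "<your_sk>"})
mdl = BaiduYiyanChat(key=yiyan_key, model_name="<qianfan_model_name>")
```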
@@ -1149,7 +1149,7 @@ class BaiduYiyanChat(Base):
         if "max_tokens" in gen_conf:
             gen_conf["max_output_tokens"] = gen_conf["max_tokens"]
         ans = ""
-
+
         try:
             response = self.client.do(
                 model=self.model_name,
@@ -1159,7 +1159,7 @@ class BaiduYiyanChat(Base):
             ).body
             ans = response['result']
             return ans, response["usage"]["total_tokens"]
-
+
         except Exception as e:
             return ans + "\n**ERROR**: " + str(e), 0
 
@@ -1173,7 +1173,7 @@ class BaiduYiyanChat(Base):
             gen_conf["max_output_tokens"] = gen_conf["max_tokens"]
         ans = ""
         total_tokens = 0
-
+
         try:
             response = self.client.do(
                 model=self.model_name,
@@ -1193,3 +1193,67 @@ class BaiduYiyanChat(Base):
             return ans + "\n**ERROR**: " + str(e), 0
 
         yield total_tokens
+
+
+class AnthropicChat(Base):
+    def __init__(self, key, model_name, base_url=None):
+        import anthropic
+
+        self.client = anthropic.Anthropic(api_key=key)
+        self.model_name = model_name
+        self.system = ""
+
+    def chat(self, system, history, gen_conf):
+        if system:
+            self.system = system
+        if "max_tokens" not in gen_conf:
+            gen_conf["max_tokens"] = 4096
+
+        try:
+            response = self.client.messages.create(
+                model=self.model_name,
+                messages=history,
+                system=self.system,
+                stream=False,
+                **gen_conf,
+            ).json()
+            ans = response["content"][0]["text"]
+            if response["stop_reason"] == "max_tokens":
+                ans += (
+                    "...\nFor the content length reason, it stopped, continue?"
+                    if is_english([ans])
+                    else "······\n由于长度的原因,回答被截断了,要继续吗?"
+                )
+            return (
+                ans,
+                response["usage"]["input_tokens"] + response["usage"]["output_tokens"],
+            )
+        except Exception as e:
+            return ans + "\n**ERROR**: " + str(e), 0
+
+    def chat_streamly(self, system, history, gen_conf):
+        if system:
+            self.system = system
+        if "max_tokens" not in gen_conf:
+            gen_conf["max_tokens"] = 4096
+
+        ans = ""
+        total_tokens = 0
+        try:
+            response = self.client.messages.create(
+                model=self.model_name,
+                messages=history,
+                system=self.system,
+                stream=True,
+                **gen_conf,
+            )
+            for res in response.iter_lines():
+                res = res.decode("utf-8")
+                if "content_block_delta" in res and "data" in res:
+                    text = json.loads(res[6:])["delta"]["text"]
+                    ans += text
+                    total_tokens += num_tokens_from_string(text)
+        except Exception as e:
+            yield ans + "\n**ERROR**: " + str(e)
+
+        yield total_tokens
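A short usage sketch for the class added above, assuming a valid Anthropic API key; the key, model name, and messages are placeholders, and the history/system split mirrors the `chat` signature in the diff:

```python
# Hypothetical usage sketch for the new AnthropicChat class.
mdl = AnthropicChat(key="<anthropic_api_key>", model_name="<claude_model_name>")

history = [{"role": "user", "content": "Hello, who are you?"}]
gen_conf = {"temperature": 0.7, "max_tokens": 512}

# Blocking call: returns the answer text and total token usage.
answer, used_tokens = mdl.chat("You are a helpful assistant.", history, gen_conf)

# Streaming call: per the code above, the generator yields any error text as a
# string and, as its final value, the accumulated token count.
for chunk in mdl.chat_streamly("You are a helpful assistant.", history, gen_conf):
    print(chunk)
```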