Add support for Mistral (#1153)

### What problem does this PR solve?

#433 

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
Author: KevinHuSh
Date: 2024-06-14 11:32:58 +08:00
Committed by: GitHub
Parent: a25d32496c
Commit: 7dc39cbfa6
4 changed files with 141 additions and 3 deletions


@@ -472,3 +472,57 @@ class MiniMaxChat(Base):
        if not base_url:
            base_url = "https://api.minimax.chat/v1/text/chatcompletion_v2"
        super().__init__(key, model_name, base_url)


class MistralChat(Base):
    def __init__(self, key, model_name, base_url=None):
        # Legacy mistralai (0.x) client; base_url is accepted for interface
        # parity with the other providers but is not used by this SDK.
        from mistralai.client import MistralClient
        self.client = MistralClient(api_key=key)
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        # Mistral accepts only a subset of the generation options; drop the rest.
        for k in list(gen_conf.keys()):
            if k not in ["temperature", "top_p", "max_tokens"]:
                del gen_conf[k]
        try:
            response = self.client.chat(
                model=self.model_name,
                messages=history,
                **gen_conf)
            ans = response.choices[0].message.content
            if response.choices[0].finish_reason == "length":
                # Chinese fallback reads: "Due to length, the answer was
                # truncated. Continue?"
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
            return ans, response.usage.total_tokens
        except Exception as e:
            # The Mistral SDK raises its own exception types rather than
            # openai.APIError, so catch broadly and surface the message.
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        for k in list(gen_conf.keys()):
            if k not in ["temperature", "top_p", "max_tokens"]:
                del gen_conf[k]
        ans = ""
        total_tokens = 0
        try:
            response = self.client.chat_stream(
                model=self.model_name,
                messages=history,
                **gen_conf)
            for resp in response:
                if not resp.choices or not resp.choices[0].delta.content:
                    continue
                ans += resp.choices[0].delta.content
                # Rough usage estimate: one token per streamed chunk.
                total_tokens += 1
                if resp.choices[0].finish_reason == "length":
                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                        [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        # The final yield is the approximate total token count.
        yield total_tokens
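
For reviewers, here is a minimal usage sketch (not part of the diff) showing how the new `MistralChat` class slots into the existing provider pattern. The model name, environment variable, and prompts are illustrative assumptions; the class itself relies on `Base` and `is_english` from the surrounding module.

```python
# Usage sketch only, assuming MistralChat as defined in this commit is
# importable. MISTRAL_API_KEY and "mistral-small-latest" are illustrative
# assumptions, not part of the diff.
import os

mdl = MistralChat(key=os.environ["MISTRAL_API_KEY"],
                  model_name="mistral-small-latest")

history = [{"role": "user", "content": "Summarize RAG in one sentence."}]

# Non-streaming call: unsupported options such as presence_penalty are
# filtered out before the request is sent.
ans, used = mdl.chat(
    system="You are a concise assistant.",
    history=list(history),
    gen_conf={"temperature": 0.3, "max_tokens": 128, "presence_penalty": 0.4})
print(ans, used)

# Streaming call: each yield is the cumulative answer so far; the final
# yield is the approximate token count (an int), so filter by type.
for chunk in mdl.chat_streamly(system="You are a concise assistant.",
                               history=list(history),
                               gen_conf={"temperature": 0.3}):
    if isinstance(chunk, str):
        print(chunk)
```

Passing a copy of `history` matters here: both methods insert the system prompt into the list in place, so reusing the same list across calls would accumulate duplicate system messages.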