# Added support for Baichuan LLM (#934)

### What problem does this PR solve?

- Added support for Baichuan LLM: a `BaiChuanChat` chat model (Baichuan's OpenAI-compatible API, with its `web_search` tool enabled) and a `BaiChuanEmbed` embedding model (`Baichuan-Text-Embedding`).

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

Co-authored-by: 海贼宅 <stu_xyx@163.com>

Author: yungongzi
Date: 2024-05-28 09:09:37 +08:00 (committed by GitHub)
Commit: 9ffd7ae321 (parent: ec6ae744a1)
6 changed files with 169 additions and 3 deletions


```diff
@@ -26,7 +26,8 @@ EmbeddingModel = {
     "ZHIPU-AI": ZhipuEmbed,
     "FastEmbed": FastEmbed,
     "Youdao": YoudaoEmbed,
-    "DeepSeek": DefaultEmbedding
+    "DeepSeek": DefaultEmbedding,
+    "BaiChuan": BaiChuanEmbed
 }
@@ -47,6 +48,7 @@ ChatModel = {
     "Ollama": OllamaChat,
     "Xinference": XinferenceChat,
     "Moonshot": MoonshotChat,
-    "DeepSeek": DeepSeekChat
+    "DeepSeek": DeepSeekChat,
+    "BaiChuan": BaiChuanChat
 }
```
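With both registry entries in place, resolving the provider name `"BaiChuan"` through the factory dicts yields the new classes. A minimal sketch of that lookup, assuming the `ChatModel`/`EmbeddingModel` dicts above are in scope and using a placeholder key:

```python
# Minimal sketch of the factory lookup; assumes the ChatModel and
# EmbeddingModel registries above are importable, and "sk-..." stands in
# for a real Baichuan API key.
chat_mdl = ChatModel["BaiChuan"](key="sk-...")        # -> BaiChuanChat, default Baichuan3-Turbo
embed_mdl = EmbeddingModel["BaiChuan"](key="sk-...")  # -> BaiChuanEmbed, Baichuan-Text-Embedding
```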


```diff
@@ -95,6 +95,84 @@ class DeepSeekChat(Base):
         super().__init__(key, model_name, base_url)
 
 
+class BaiChuanChat(Base):
+    def __init__(self, key, model_name="Baichuan3-Turbo", base_url="https://api.baichuan-ai.com/v1"):
+        if not base_url:
+            base_url = "https://api.baichuan-ai.com/v1"
+        super().__init__(key, model_name, base_url)
+
+    @staticmethod
+    def _format_params(params):
+        return {
+            "temperature": params.get("temperature", 0.3),
+            "max_tokens": params.get("max_tokens", 2048),
+            "top_p": params.get("top_p", 0.85),
+        }
+
+    def chat(self, system, history, gen_conf):
+        if system:
+            history.insert(0, {"role": "system", "content": system})
+        try:
+            response = self.client.chat.completions.create(
+                model=self.model_name,
+                messages=history,
+                extra_body={
+                    "tools": [{
+                        "type": "web_search",
+                        "web_search": {
+                            "enable": True,
+                            "search_mode": "performance_first"
+                        }
+                    }]
+                },
+                **self._format_params(gen_conf))
+            ans = response.choices[0].message.content.strip()
+            if response.choices[0].finish_reason == "length":
+                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
+                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
+            return ans, response.usage.total_tokens
+        except openai.APIError as e:
+            return "**ERROR**: " + str(e), 0
+
+    def chat_streamly(self, system, history, gen_conf):
+        if system:
+            history.insert(0, {"role": "system", "content": system})
+        ans = ""
+        total_tokens = 0
+        try:
+            response = self.client.chat.completions.create(
+                model=self.model_name,
+                messages=history,
+                extra_body={
+                    "tools": [{
+                        "type": "web_search",
+                        "web_search": {
+                            "enable": True,
+                            "search_mode": "performance_first"
+                        }
+                    }]
+                },
+                stream=True,
+                **self._format_params(gen_conf))
+            for resp in response:
+                if resp.choices[0].finish_reason == "stop":
+                    if not resp.choices[0].delta.content:
+                        continue
+                    total_tokens = resp.usage.get('total_tokens', 0)
+                if not resp.choices[0].delta.content:
+                    continue
+                ans += resp.choices[0].delta.content
+                if resp.choices[0].finish_reason == "length":
+                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
+                        [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
+                yield ans
+        except Exception as e:
+            yield ans + "\n**ERROR**: " + str(e)
+
+        yield total_tokens
+
+
 class QWenChat(Base):
     def __init__(self, key, model_name=Generation.Models.qwen_turbo, **kwargs):
         import dashscope
```
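For reference, a hedged usage sketch of `BaiChuanChat` (the key is a placeholder, and the `gen_conf` keys mirror the defaults in `_format_params`):

```python
# Usage sketch under stated assumptions: "sk-..." is a placeholder key and
# BaiChuanChat is imported from the module this diff modifies.
mdl = BaiChuanChat(key="sk-...")  # model_name defaults to "Baichuan3-Turbo"
gen_conf = {"temperature": 0.3, "max_tokens": 2048, "top_p": 0.85}

# Blocking call: returns the full answer plus total tokens consumed.
answer, used = mdl.chat(
    "You are a helpful assistant.",
    [{"role": "user", "content": "What is RAG?"}],
    gen_conf,
)

# Streaming call: yields progressively longer answer strings, then the
# total token count (an int) as the final item.
for chunk in mdl.chat_streamly(
        "You are a helpful assistant.",
        [{"role": "user", "content": "What is RAG?"}],
        gen_conf):
    if isinstance(chunk, str):
        print(chunk)
    else:
        used = chunk
```

Note that both methods prepend the system prompt to `history` in place (`history.insert(0, ...)`), so pass a fresh list per call.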


```diff
@@ -104,6 +104,15 @@ class OpenAIEmbed(Base):
         return np.array(res.data[0].embedding), res.usage.total_tokens
 
 
+class BaiChuanEmbed(OpenAIEmbed):
+    def __init__(self, key,
+                 model_name='Baichuan-Text-Embedding',
+                 base_url='https://api.baichuan-ai.com/v1'):
+        if not base_url:
+            base_url = "https://api.baichuan-ai.com/v1"
+        super().__init__(key, model_name, base_url)
+
+
 class QWenEmbed(Base):
     def __init__(self, key, model_name="text_embedding_v2", **kwargs):
         dashscope.api_key = key
```
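Since `BaiChuanEmbed` only overrides the constructor, request handling is inherited from `OpenAIEmbed` against Baichuan's OpenAI-compatible endpoint. A brief sketch, assuming the inherited single-text method is the one whose return appears in the context line above:

```python
# Hedged sketch; "sk-..." is a placeholder API key.
embd = BaiChuanEmbed(key="sk-...")  # Baichuan-Text-Embedding via api.baichuan-ai.com/v1

# encode_queries (assumed inherited from OpenAIEmbed) returns the embedding
# as a numpy array plus the tokens billed for the request.
vec, used = embd.encode_queries("What is retrieval-augmented generation?")
print(vec.shape, used)
```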