Mirror of https://github.com/infiniflow/ragflow.git, synced 2025-12-08 20:42:30 +08:00
Added support for Baichuan LLM (#934)
### What problem does this PR solve?

- Added support for Baichuan LLM

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

Co-authored-by: 海贼宅 <stu_xyx@163.com>
@@ -26,7 +26,8 @@ EmbeddingModel = {
     "ZHIPU-AI": ZhipuEmbed,
     "FastEmbed": FastEmbed,
     "Youdao": YoudaoEmbed,
-    "DeepSeek": DefaultEmbedding
+    "DeepSeek": DefaultEmbedding,
+    "BaiChuan": BaiChuanEmbed
 }
@@ -47,6 +48,7 @@ ChatModel = {
     "Ollama": OllamaChat,
     "Xinference": XinferenceChat,
     "Moonshot": MoonshotChat,
-    "DeepSeek": DeepSeekChat
+    "DeepSeek": DeepSeekChat,
+    "BaiChuan": BaiChuanChat
 }
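For context, a minimal usage sketch (not part of this commit) of how these factory maps can be used to pick a provider class by its factory name. The import path rag.llm and the API key are assumptions / placeholders:

# Hypothetical lookup through the factory maps; "rag.llm" is the assumed module
# that defines EmbeddingModel and ChatModel, and the key below is a placeholder.
from rag.llm import ChatModel, EmbeddingModel

factory = "BaiChuan"
chat_mdl = ChatModel[factory]("your-baichuan-api-key", model_name="Baichuan3-Turbo")
embd_mdl = EmbeddingModel[factory]("your-baichuan-api-key", model_name="Baichuan-Text-Embedding")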
@@ -95,6 +95,84 @@ class DeepSeekChat(Base):
         super().__init__(key, model_name, base_url)
 
 
+class BaiChuanChat(Base):
+    def __init__(self, key, model_name="Baichuan3-Turbo", base_url="https://api.baichuan-ai.com/v1"):
+        if not base_url:
+            base_url = "https://api.baichuan-ai.com/v1"
+        super().__init__(key, model_name, base_url)
+
+    @staticmethod
+    def _format_params(params):
+        return {
+            "temperature": params.get("temperature", 0.3),
+            "max_tokens": params.get("max_tokens", 2048),
+            "top_p": params.get("top_p", 0.85),
+        }
+
+    def chat(self, system, history, gen_conf):
+        if system:
+            history.insert(0, {"role": "system", "content": system})
+        try:
+            response = self.client.chat.completions.create(
+                model=self.model_name,
+                messages=history,
+                extra_body={
+                    "tools": [{
+                        "type": "web_search",
+                        "web_search": {
+                            "enable": True,
+                            "search_mode": "performance_first"
+                        }
+                    }]
+                },
+                **self._format_params(gen_conf))
+            ans = response.choices[0].message.content.strip()
+            if response.choices[0].finish_reason == "length":
+                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
+                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
+            return ans, response.usage.total_tokens
+        except openai.APIError as e:
+            return "**ERROR**: " + str(e), 0
+
+    def chat_streamly(self, system, history, gen_conf):
+        if system:
+            history.insert(0, {"role": "system", "content": system})
+        ans = ""
+        total_tokens = 0
+        try:
+            response = self.client.chat.completions.create(
+                model=self.model_name,
+                messages=history,
+                extra_body={
+                    "tools": [{
+                        "type": "web_search",
+                        "web_search": {
+                            "enable": True,
+                            "search_mode": "performance_first"
+                        }
+                    }]
+                },
+                stream=True,
+                **self._format_params(gen_conf))
+            for resp in response:
+                if resp.choices[0].finish_reason == "stop":
+                    if not resp.choices[0].delta.content:
+                        continue
+                    total_tokens = resp.usage.get('total_tokens', 0)
+                if not resp.choices[0].delta.content:
+                    continue
+                ans += resp.choices[0].delta.content
+                if resp.choices[0].finish_reason == "length":
+                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
+                        [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
+                yield ans
+
+        except Exception as e:
+            yield ans + "\n**ERROR**: " + str(e)
+
+        yield total_tokens
+
+
 class QWenChat(Base):
     def __init__(self, key, model_name=Generation.Models.qwen_turbo, **kwargs):
         import dashscope
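A usage sketch of the new BaiChuanChat class (not part of the diff). The API key is a placeholder, and gen_conf simply mirrors the defaults from _format_params:

# Illustrative only: instantiate the adapter and call both entry points.
# chat() returns (answer, total_tokens); chat_streamly() yields progressively
# longer partial answers and finally the total token count.
mdl = BaiChuanChat("your-baichuan-api-key", model_name="Baichuan3-Turbo")
gen_conf = {"temperature": 0.3, "max_tokens": 2048, "top_p": 0.85}

answer, tokens = mdl.chat("You are a helpful assistant.",
                          [{"role": "user", "content": "What is RAGFlow?"}],
                          gen_conf)

for chunk in mdl.chat_streamly("You are a helpful assistant.",
                               [{"role": "user", "content": "What is RAGFlow?"}],
                               gen_conf):
    if isinstance(chunk, str):
        print(chunk)      # partial answer so far
    else:
        tokens = chunk    # final yield is the total token count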
@@ -104,6 +104,15 @@ class OpenAIEmbed(Base):
         return np.array(res.data[0].embedding), res.usage.total_tokens
 
 
+class BaiChuanEmbed(OpenAIEmbed):
+    def __init__(self, key,
+                 model_name='Baichuan-Text-Embedding',
+                 base_url='https://api.baichuan-ai.com/v1'):
+        if not base_url:
+            base_url = "https://api.baichuan-ai.com/v1"
+        super().__init__(key, model_name, base_url)
+
+
 class QWenEmbed(Base):
     def __init__(self, key, model_name="text_embedding_v2", **kwargs):
         dashscope.api_key = key
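A similar sketch for BaiChuanEmbed (not part of the diff). Since it subclasses OpenAIEmbed, it is assumed to inherit that class's query-embedding method shown in the context line above; the API key is a placeholder:

# Illustrative only: embed a query through the inherited OpenAI-compatible client.
# encode_queries() is assumed from OpenAIEmbed and returns (vector, token_count).
embd = BaiChuanEmbed("your-baichuan-api-key")
vec, tokens = embd.encode_queries("What is retrieval-augmented generation?")
print(vec.shape, tokens)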