mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
add support for Baidu yiyan (#2049)
### What problem does this PR solve? add support for Baidu yiyan ### Type of change - [x] New Feature (non-breaking change which adds functionality) --------- Co-authored-by: Zhedong Cen <cenzhedong2@126.com>
This commit is contained in:
@ -1185,3 +1185,69 @@ class SparkChat(Base):
|
||||
}
|
||||
model_version = model2version[model_name]
|
||||
super().__init__(key, model_version, base_url)
|
||||
|
||||
|
||||
class BaiduYiyanChat(Base):
|
||||
def __init__(self, key, model_name, base_url=None):
|
||||
import qianfan
|
||||
|
||||
key = json.loads(key)
|
||||
ak = key.get("yiyan_ak","")
|
||||
sk = key.get("yiyan_sk","")
|
||||
self.client = qianfan.ChatCompletion(ak=ak,sk=sk)
|
||||
self.model_name = model_name.lower()
|
||||
self.system = ""
|
||||
|
||||
def chat(self, system, history, gen_conf):
|
||||
if system:
|
||||
self.system = system
|
||||
gen_conf["penalty_score"] = (
|
||||
(gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty", 0)) / 2
|
||||
) + 1
|
||||
if "max_tokens" in gen_conf:
|
||||
gen_conf["max_output_tokens"] = gen_conf["max_tokens"]
|
||||
ans = ""
|
||||
|
||||
try:
|
||||
response = self.client.do(
|
||||
model=self.model_name,
|
||||
messages=history,
|
||||
system=self.system,
|
||||
**gen_conf
|
||||
).body
|
||||
ans = response['result']
|
||||
return ans, response["usage"]["total_tokens"]
|
||||
|
||||
except Exception as e:
|
||||
return ans + "\n**ERROR**: " + str(e), 0
|
||||
|
||||
def chat_streamly(self, system, history, gen_conf):
|
||||
if system:
|
||||
self.system = system
|
||||
gen_conf["penalty_score"] = (
|
||||
(gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty", 0)) / 2
|
||||
) + 1
|
||||
if "max_tokens" in gen_conf:
|
||||
gen_conf["max_output_tokens"] = gen_conf["max_tokens"]
|
||||
ans = ""
|
||||
total_tokens = 0
|
||||
|
||||
try:
|
||||
response = self.client.do(
|
||||
model=self.model_name,
|
||||
messages=history,
|
||||
system=self.system,
|
||||
stream=True,
|
||||
**gen_conf
|
||||
)
|
||||
for resp in response:
|
||||
resp = resp.body
|
||||
ans += resp['result']
|
||||
total_tokens = resp["usage"]["total_tokens"]
|
||||
|
||||
yield ans
|
||||
|
||||
except Exception as e:
|
||||
return ans + "\n**ERROR**: " + str(e), 0
|
||||
|
||||
yield total_tokens
|
||||
|
||||
Reference in New Issue
Block a user