mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
add support for Tencent Hunyuan (#2015)
### What problem does this PR solve?

#1853

### Type of change

- [X] New Feature (non-breaking change which adds functionality)

Co-authored-by: Zhedong Cen <cenzhedong2@126.com>
This commit is contained in:
@ -1088,3 +1088,83 @@ class ReplicateChat(Base):
|
||||
yield ans + "\n**ERROR**: " + str(e)
|
||||
|
||||
yield num_tokens_from_string(ans)
|
||||
|
||||
|
||||
class HunyuanChat(Base):
    """Chat adapter for Tencent Hunyuan models, built on the Tencent Cloud SDK.

    Note: ``Base.__init__`` is not called; this class creates its own
    ``HunyuanClient`` and stores it on ``self.client``.
    """

    def __init__(self, key, model_name, base_url=None):
        """Create a Hunyuan client from JSON-encoded credentials.

        Args:
            key: JSON string with the entries ``hunyuan_sid`` and
                ``hunyuan_sk``; a missing entry falls back to ``""``.
            model_name: Model identifier sent as ``Model`` with every request.
            base_url: Not used by this class.
        """
        from tencentcloud.common import credential
        from tencentcloud.hunyuan.v20230901 import hunyuan_client

        secrets = json.loads(key)
        cred = credential.Credential(
            secrets.get("hunyuan_sid", ""),
            secrets.get("hunyuan_sk", ""),
        )
        self.model_name = model_name
        # NOTE(review): the second positional argument is presumably the
        # region; it is passed as an empty string -- confirm against the SDK.
        self.client = hunyuan_client.HunyuanClient(cred, "")

    def _build_request(self, system, history, gen_conf, stream=False):
        """Assemble the ``ChatCompletionsRequest`` shared by both chat modes.

        Args:
            system: Optional system prompt; when truthy it is prepended to the
                messages with role ``system``.
            history: Iterable of message dicts (e.g. ``role`` / ``content``).
            gen_conf: Generation options; only ``temperature`` and ``top_p``
                are forwarded.
            stream: When true, ``Stream`` is set on the request.

        Returns:
            A populated ``models.ChatCompletionsRequest``.
        """
        from tencentcloud.hunyuan.v20230901 import models

        # The request uses capitalized field names ("role" -> "Role").
        messages = [
            {field.capitalize(): value for field, value in msg.items()}
            for msg in history
        ]
        if system:
            messages.insert(0, {"Role": "system", "Content": system})

        params = {"Model": self.model_name, "Messages": messages}
        if stream:
            params["Stream"] = True
        for src, dst in (("temperature", "Temperature"), ("top_p", "TopP")):
            if src in gen_conf:
                params[dst] = gen_conf[src]

        req = models.ChatCompletionsRequest()
        req.from_json_string(json.dumps(params))
        return req

    def chat(self, system, history, gen_conf):
        """Run a single, non-streaming chat completion.

        Returns:
            ``(answer, total_tokens)`` on success. If the SDK raises
            ``TencentCloudSDKException`` the answer carries an
            ``**ERROR**`` suffix and the token count is ``0``.
        """
        from tencentcloud.common.exception.tencent_cloud_sdk_exception import (
            TencentCloudSDKException,
        )

        req = self._build_request(system, history, gen_conf)
        ans = ""
        try:
            response = self.client.ChatCompletions(req)
            ans = response.Choices[0].Message.Content
            return ans, response.Usage.TotalTokens
        except TencentCloudSDKException as e:
            return ans + "\n**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        """Stream a chat completion.

        Yields:
            The accumulated answer text after every non-empty content chunk
            (with an ``**ERROR**`` suffix if the SDK raises
            ``TencentCloudSDKException``), followed by one final integer: the
            total token count.

        The token count is the last ``Usage.TotalTokens`` value seen in the
        stream events. If no event carries usage data, the number of content
        chunks is reported instead (the previous behaviour).
        """
        from tencentcloud.common.exception.tencent_cloud_sdk_exception import (
            TencentCloudSDKException,
        )

        req = self._build_request(system, history, gen_conf, stream=True)
        ans = ""
        chunk_count = 0
        reported_tokens = 0
        try:
            for event in self.client.ChatCompletions(req):
                payload = json.loads(event["data"])

                # Read usage before the emptiness checks below: a chunk with
                # no content may still carry token usage.
                usage = payload.get("Usage") or {}
                reported_tokens = usage.get("TotalTokens") or reported_tokens

                choices = payload.get("Choices") or []
                if not choices:
                    continue
                piece = (choices[0].get("Delta") or {}).get("Content")
                if not piece:
                    continue

                ans += piece
                chunk_count += 1
                yield ans
        except TencentCloudSDKException as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield reported_tokens or chunk_count
|
||||
|
||||
Reference in New Issue
Block a user