mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
add support for Tencent Hunyuan (#2015)
### What problem does this PR solve? #1853 ### Type of change - [X] New Feature (non-breaking change which adds functionality) Co-authored-by: Zhedong Cen <cenzhedong2@126.com>
This commit is contained in:
@ -63,7 +63,8 @@ CvModel = {
|
||||
"StepFun":StepFunCV,
|
||||
"OpenAI-API-Compatible": OpenAI_APICV,
|
||||
"TogetherAI": TogetherAICV,
|
||||
"01.AI": YiCV
|
||||
"01.AI": YiCV,
|
||||
"Tencent Hunyuan": HunyuanCV
|
||||
}
|
||||
|
||||
|
||||
@ -98,7 +99,8 @@ ChatModel = {
|
||||
"novita.ai": NovitaAIChat,
|
||||
"SILICONFLOW": SILICONFLOWChat,
|
||||
"01.AI": YiChat,
|
||||
"Replicate": ReplicateChat
|
||||
"Replicate": ReplicateChat,
|
||||
"Tencent Hunyuan": HunyuanChat
|
||||
}
|
||||
|
||||
|
||||
|
||||
@ -1088,3 +1088,83 @@ class ReplicateChat(Base):
|
||||
yield ans + "\n**ERROR**: " + str(e)
|
||||
|
||||
yield num_tokens_from_string(ans)
|
||||
|
||||
|
||||
class HunyuanChat(Base):
|
||||
def __init__(self, key, model_name, base_url=None):
|
||||
from tencentcloud.common import credential
|
||||
from tencentcloud.hunyuan.v20230901 import hunyuan_client
|
||||
|
||||
key = json.loads(key)
|
||||
sid = key.get("hunyuan_sid", "")
|
||||
sk = key.get("hunyuan_sk", "")
|
||||
cred = credential.Credential(sid, sk)
|
||||
self.model_name = model_name
|
||||
self.client = hunyuan_client.HunyuanClient(cred, "")
|
||||
|
||||
def chat(self, system, history, gen_conf):
|
||||
from tencentcloud.hunyuan.v20230901 import models
|
||||
from tencentcloud.common.exception.tencent_cloud_sdk_exception import (
|
||||
TencentCloudSDKException,
|
||||
)
|
||||
|
||||
_gen_conf = {}
|
||||
_history = [{k.capitalize(): v for k, v in item.items() } for item in history]
|
||||
if system:
|
||||
_history.insert(0, {"Role": "system", "Content": system})
|
||||
if "temperature" in gen_conf:
|
||||
_gen_conf["Temperature"] = gen_conf["temperature"]
|
||||
if "top_p" in gen_conf:
|
||||
_gen_conf["TopP"] = gen_conf["top_p"]
|
||||
|
||||
req = models.ChatCompletionsRequest()
|
||||
params = {"Model": self.model_name, "Messages": _history, **_gen_conf}
|
||||
req.from_json_string(json.dumps(params))
|
||||
ans = ""
|
||||
try:
|
||||
response = self.client.ChatCompletions(req)
|
||||
ans = response.Choices[0].Message.Content
|
||||
return ans, response.Usage.TotalTokens
|
||||
except TencentCloudSDKException as e:
|
||||
return ans + "\n**ERROR**: " + str(e), 0
|
||||
|
||||
def chat_streamly(self, system, history, gen_conf):
|
||||
from tencentcloud.hunyuan.v20230901 import models
|
||||
from tencentcloud.common.exception.tencent_cloud_sdk_exception import (
|
||||
TencentCloudSDKException,
|
||||
)
|
||||
|
||||
_gen_conf = {}
|
||||
_history = [{k.capitalize(): v for k, v in item.items() } for item in history]
|
||||
if system:
|
||||
_history.insert(0, {"Role": "system", "Content": system})
|
||||
|
||||
if "temperature" in gen_conf:
|
||||
_gen_conf["Temperature"] = gen_conf["temperature"]
|
||||
if "top_p" in gen_conf:
|
||||
_gen_conf["TopP"] = gen_conf["top_p"]
|
||||
req = models.ChatCompletionsRequest()
|
||||
params = {
|
||||
"Model": self.model_name,
|
||||
"Messages": _history,
|
||||
"Stream": True,
|
||||
**_gen_conf,
|
||||
}
|
||||
req.from_json_string(json.dumps(params))
|
||||
ans = ""
|
||||
total_tokens = 0
|
||||
try:
|
||||
response = self.client.ChatCompletions(req)
|
||||
for resp in response:
|
||||
resp = json.loads(resp["data"])
|
||||
if not resp["Choices"] or not resp["Choices"][0]["Delta"]["Content"]:
|
||||
continue
|
||||
ans += resp["Choices"][0]["Delta"]["Content"]
|
||||
total_tokens += 1
|
||||
|
||||
yield ans
|
||||
|
||||
except TencentCloudSDKException as e:
|
||||
yield ans + "\n**ERROR**: " + str(e)
|
||||
|
||||
yield total_tokens
|
||||
|
||||
@ -664,4 +664,56 @@ class YiCV(GptV4):
|
||||
def __init__(self, key, model_name, lang="Chinese",base_url="https://api.lingyiwanwu.com/v1",):
|
||||
if not base_url:
|
||||
base_url = "https://api.lingyiwanwu.com/v1"
|
||||
super().__init__(key, model_name,lang,base_url)
|
||||
super().__init__(key, model_name,lang,base_url)
|
||||
|
||||
|
||||
class HunyuanCV(Base):
|
||||
def __init__(self, key, model_name, lang="Chinese",base_url=None):
|
||||
from tencentcloud.common import credential
|
||||
from tencentcloud.hunyuan.v20230901 import hunyuan_client
|
||||
|
||||
key = json.loads(key)
|
||||
sid = key.get("hunyuan_sid", "")
|
||||
sk = key.get("hunyuan_sk", "")
|
||||
cred = credential.Credential(sid, sk)
|
||||
self.model_name = model_name
|
||||
self.client = hunyuan_client.HunyuanClient(cred, "")
|
||||
self.lang = lang
|
||||
|
||||
def describe(self, image, max_tokens=4096):
|
||||
from tencentcloud.hunyuan.v20230901 import models
|
||||
from tencentcloud.common.exception.tencent_cloud_sdk_exception import (
|
||||
TencentCloudSDKException,
|
||||
)
|
||||
|
||||
b64 = self.image2base64(image)
|
||||
req = models.ChatCompletionsRequest()
|
||||
params = {"Model": self.model_name, "Messages": self.prompt(b64)}
|
||||
req.from_json_string(json.dumps(params))
|
||||
ans = ""
|
||||
try:
|
||||
response = self.client.ChatCompletions(req)
|
||||
ans = response.Choices[0].Message.Content
|
||||
return ans, response.Usage.TotalTokens
|
||||
except TencentCloudSDKException as e:
|
||||
return ans + "\n**ERROR**: " + str(e), 0
|
||||
|
||||
def prompt(self, b64):
|
||||
return [
|
||||
{
|
||||
"Role": "user",
|
||||
"Contents": [
|
||||
{
|
||||
"Type": "image_url",
|
||||
"ImageUrl": {
|
||||
"Url": f"data:image/jpeg;base64,{b64}"
|
||||
},
|
||||
},
|
||||
{
|
||||
"Type": "text",
|
||||
"Text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else
|
||||
"Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
Reference in New Issue
Block a user