mirror of https://github.com/infiniflow/ragflow.git
synced 2026-01-31 23:55:06 +08:00
Fix: Hunyuan cannot work properly (#12843)
### What problem does this PR solve?

Hunyuan cannot work properly.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
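In short, this PR registers Tencent Hunyuan as a LiteLLM provider (a `SupportedLiteLLMProvider` member, a default base URL, and the `openai/` model prefix) and removes the old SDK-based `HunyuanChat` class, so Hunyuan requests go through `LiteLLMBase` like the other OpenAI-compatible providers. Below is a minimal sketch of what that mapping amounts to; it assumes the `litellm` package, reuses values from the diff, and the helper function is hypothetical rather than code from this PR.

```python
import litellm  # the library that LiteLLMBase wraps

# Values below come from the diff in this PR; the helper itself is illustrative only.
HUNYUAN_PREFIX = "openai/"
HUNYUAN_BASE_URL = "https://api.hunyuan.cloud.tencent.com/v1"

def hunyuan_chat(model_name: str, messages: list, api_key: str):
    # Prefixing the model with "openai/" routes it through LiteLLM's
    # OpenAI-compatible handler, pointed at Tencent's Hunyuan endpoint.
    return litellm.completion(
        model=HUNYUAN_PREFIX + model_name,
        messages=messages,
        api_key=api_key,
        api_base=HUNYUAN_BASE_URL,
    )
```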
@@ -57,6 +57,7 @@ class SupportedLiteLLMProvider(StrEnum):
    OpenAI = "OpenAI"
    Azure_OpenAI = "Azure-OpenAI"
    n1n = "n1n"
    HunYuan = "Tencent Hunyuan"


FACTORY_DEFAULT_BASE_URL = {

@@ -83,6 +84,7 @@ FACTORY_DEFAULT_BASE_URL = {
    SupportedLiteLLMProvider.DeerAPI: "https://api.deerapi.com/v1",
    SupportedLiteLLMProvider.OpenAI: "https://api.openai.com/v1",
    SupportedLiteLLMProvider.n1n: "https://api.n1n.ai/v1",
    SupportedLiteLLMProvider.HunYuan: "https://api.hunyuan.cloud.tencent.com/v1",
}

@@ -121,6 +123,7 @@ LITELLM_PROVIDER_PREFIX = {
    SupportedLiteLLMProvider.OpenAI: "openai/",
    SupportedLiteLLMProvider.Azure_OpenAI: "azure/",
    SupportedLiteLLMProvider.n1n: "openai/",
    SupportedLiteLLMProvider.HunYuan: "openai/",
}

ChatModel = globals().get("ChatModel", {})
@@ -34,8 +34,6 @@ from common.token_utils import num_tokens_from_string, total_token_count_from_re
from rag.llm import FACTORY_DEFAULT_BASE_URL, LITELLM_PROVIDER_PREFIX, SupportedLiteLLMProvider
from rag.nlp import is_chinese, is_english

# Error message constants

from common.misc_utils import thread_pool_exec
class LLMErrorCode(StrEnum):
    ERROR_RATE_LIMIT = "RATE_LIMIT_EXCEEDED"
@@ -106,7 +104,7 @@ class Base(ABC):
        if "gpt-5" in model_name_lower:
            gen_conf = {}
            return gen_conf

        if "max_tokens" in gen_conf:
            del gen_conf["max_tokens"]
@@ -793,84 +791,6 @@ class ReplicateChat(Base):
        yield num_tokens_from_string(ans)


class HunyuanChat(Base):
    _FACTORY_NAME = "Tencent Hunyuan"

    def __init__(self, key, model_name, base_url=None, **kwargs):
        super().__init__(key, model_name, base_url=base_url, **kwargs)

        from tencentcloud.common import credential
        from tencentcloud.hunyuan.v20230901 import hunyuan_client

        key = json.loads(key)
        sid = key.get("hunyuan_sid", "")
        sk = key.get("hunyuan_sk", "")
        cred = credential.Credential(sid, sk)
        self.model_name = model_name
        self.client = hunyuan_client.HunyuanClient(cred, "")

    def _clean_conf(self, gen_conf):
        _gen_conf = {}
        if "temperature" in gen_conf:
            _gen_conf["Temperature"] = gen_conf["temperature"]
        if "top_p" in gen_conf:
            _gen_conf["TopP"] = gen_conf["top_p"]
        return _gen_conf

    def _chat(self, history, gen_conf={}, **kwargs):
        from tencentcloud.hunyuan.v20230901 import models

        hist = [{k.capitalize(): v for k, v in item.items()} for item in history]
        req = models.ChatCompletionsRequest()
        params = {"Model": self.model_name, "Messages": hist, **gen_conf}
        req.from_json_string(json.dumps(params))
        response = self.client.ChatCompletions(req)
        ans = response.Choices[0].Message.Content
        return ans, response.Usage.TotalTokens

    def chat_streamly(self, system, history, gen_conf={}, **kwargs):
        from tencentcloud.common.exception.tencent_cloud_sdk_exception import (
            TencentCloudSDKException,
        )
        from tencentcloud.hunyuan.v20230901 import models

        _gen_conf = {}
        _history = [{k.capitalize(): v for k, v in item.items()} for item in history]
        if system and history and history[0].get("role") != "system":
            _history.insert(0, {"Role": "system", "Content": system})
        if "max_tokens" in gen_conf:
            del gen_conf["max_tokens"]
        if "temperature" in gen_conf:
            _gen_conf["Temperature"] = gen_conf["temperature"]
        if "top_p" in gen_conf:
            _gen_conf["TopP"] = gen_conf["top_p"]
        req = models.ChatCompletionsRequest()
        params = {
            "Model": self.model_name,
            "Messages": _history,
            "Stream": True,
            **_gen_conf,
        }
        req.from_json_string(json.dumps(params))
        ans = ""
        total_tokens = 0
        try:
            response = self.client.ChatCompletions(req)
            for resp in response:
                resp = json.loads(resp["data"])
                if not resp["Choices"] or not resp["Choices"][0]["Delta"]["Content"]:
                    continue
                ans = resp["Choices"][0]["Delta"]["Content"]
                total_tokens += 1

                yield ans

        except TencentCloudSDKException as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield total_tokens


class SparkChat(Base):
    _FACTORY_NAME = "XunFei Spark"
@@ -1209,6 +1129,7 @@ class LiteLLMBase(ABC):
        "GPUStack",
        "OpenAI",
        "Azure-OpenAI",
        "Tencent Hunyuan",
    ]

    def __init__(self, key, model_name, base_url=None, **kwargs):
@@ -1259,6 +1180,11 @@
        return LLMErrorCode.ERROR_GENERIC

    def _clean_conf(self, gen_conf):
        if self.provider == SupportedLiteLLMProvider.HunYuan:
            unsupported = ["presence_penalty", "frequency_penalty"]
            for key in unsupported:
                gen_conf.pop(key, None)

        if "max_tokens" in gen_conf:
            del gen_conf["max_tokens"]
        return gen_conf
@@ -1704,3 +1630,4 @@ class LiteLLMBase(ABC):
        if extra_headers:
            completion_args["extra_headers"] = extra_headers
        return completion_args
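For reference, a standalone sketch of what the new HunYuan branch of `LiteLLMBase._clean_conf` does: `presence_penalty` and `frequency_penalty` are treated as unsupported for Tencent Hunyuan and dropped, and `max_tokens` is removed as in the generic path. The helper name and sample values below are invented for illustration.

```python
def clean_hunyuan_conf(gen_conf: dict) -> dict:
    # Mirrors the HunYuan-specific branch added in this PR:
    # these two parameters are treated as unsupported and removed.
    for key in ("presence_penalty", "frequency_penalty"):
        gen_conf.pop(key, None)
    # max_tokens is stripped for every provider in this code path.
    gen_conf.pop("max_tokens", None)
    return gen_conf

# Example: only supported generation settings survive.
print(clean_hunyuan_conf({"temperature": 0.7, "presence_penalty": 0.4, "max_tokens": 512}))
# -> {'temperature': 0.7}
```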