add support for XunFei Spark (#2017)

### What problem does this PR solve?

#1853  add support for XunFei Spark

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

Co-authored-by: Zhedong Cen <cenzhedong2@126.com>
This commit is contained in:
黄腾
2024-08-20 16:56:42 +08:00
committed by GitHub
parent 02985fc905
commit be431449bd
12 changed files with 190 additions and 6 deletions

View File

@ -100,7 +100,8 @@ ChatModel = {
"SILICONFLOW": SILICONFLOWChat,
"01.AI": YiChat,
"Replicate": ReplicateChat,
"Tencent Hunyuan": HunyuanChat
"Tencent Hunyuan": HunyuanChat,
"XunFei Spark": SparkChat
}

View File

@ -1133,12 +1133,12 @@ class HunyuanChat(Base):
from tencentcloud.common.exception.tencent_cloud_sdk_exception import (
TencentCloudSDKException,
)
_gen_conf = {}
_history = [{k.capitalize(): v for k, v in item.items() } for item in history]
if system:
_history.insert(0, {"Role": "system", "Content": system})
if "temperature" in gen_conf:
_gen_conf["Temperature"] = gen_conf["temperature"]
if "top_p" in gen_conf:
@ -1168,3 +1168,20 @@ class HunyuanChat(Base):
yield ans + "\n**ERROR**: " + str(e)
yield total_tokens
class SparkChat(Base):
def __init__(
self, key, model_name, base_url="https://spark-api-open.xf-yun.com/v1"
):
if not base_url:
base_url = "https://spark-api-open.xf-yun.com/v1"
model2version = {
"Spark-Max": "generalv3.5",
"Spark-Lite": "general",
"Spark-Pro": "generalv3",
"Spark-Pro-128K": "pro-128k",
"Spark-4.0-Ultra": "4.0Ultra",
}
model_version = model2version[model_name]
super().__init__(key, model_version, base_url)