Feat: update stepfun list (#12991)

### What problem does this PR solve?

Update the StepFun list.

Add StepFun TTS (text-to-speech) and Seq2txt (speech-to-text) support.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
Yongteng Lei
2026-02-05 12:47:04 +08:00
committed by GitHub
parent 803b480f9c
commit 6361fc4b33
3 changed files with 122 additions and 11 deletions

View File

@ -59,6 +59,15 @@ class GPTSeq2txt(Base):
self.model_name = model_name
class StepFunSeq2txt(GPTSeq2txt):
    """Speech-to-text model registered under the "StepFun" factory name.

    All behaviour comes from ``GPTSeq2txt``; this subclass only supplies the
    StepFun default model name and endpoint.
    """

    _FACTORY_NAME = "StepFun"

    def __init__(self, key, model_name="step-asr", lang="Chinese", base_url="https://api.stepfun.com/v1", **kwargs):
        """Initialise the client, defaulting to the StepFun endpoint.

        Args:
            key: API key, passed through to the parent constructor.
            model_name: Model identifier, passed through to the parent.
            lang: Accepted by this signature but not forwarded to the parent
                constructor here.
            base_url: API root; any falsy value (``None``, ``""``) is replaced
                by the StepFun default.
            **kwargs: Extra options passed through to the parent.
        """
        # A caller may explicitly pass an empty/None URL, so the signature
        # default alone is not enough.
        endpoint = base_url or "https://api.stepfun.com/v1"
        super().__init__(key, model_name=model_name, base_url=endpoint, **kwargs)
class QWenSeq2txt(Base):
_FACTORY_NAME = "Tongyi-Qianwen"

View File

@ -19,6 +19,7 @@ import base64
import hashlib
import hmac
import json
import os
import queue
import re
import ssl
@ -36,6 +37,7 @@ import requests
import websocket
from pydantic import BaseModel, conint
from common.http_client import sync_request
from common.token_utils import num_tokens_from_string
@ -387,6 +389,7 @@ class SILICONFLOWTTS(Base):
if chunk:
yield chunk
class DeepInfraTTS(OpenAITTS):
_FACTORY_NAME = "DeepInfra"
@ -394,7 +397,8 @@ class DeepInfraTTS(OpenAITTS):
if not base_url:
base_url = "https://api.deepinfra.com/v1/openai"
super().__init__(key, model_name, base_url, **kwargs)
class CometAPITTS(OpenAITTS):
_FACTORY_NAME = "CometAPI"
@ -402,7 +406,8 @@ class CometAPITTS(OpenAITTS):
if not base_url:
base_url = "https://api.cometapi.com/v1"
super().__init__(key, model_name, base_url, **kwargs)
class DeerAPITTS(OpenAITTS):
_FACTORY_NAME = "DeerAPI"
@ -410,3 +415,37 @@ class DeerAPITTS(OpenAITTS):
if not base_url:
base_url = "https://api.deerapi.com/v1"
super().__init__(key, model_name, base_url, **kwargs)
class StepFunTTS(OpenAITTS):
    """Text-to-speech model registered under the "StepFun" factory name.

    Sends synthesis requests to ``{base_url}/audio/speech`` and streams the
    returned audio bytes back to the caller.
    """

    _FACTORY_NAME = "StepFun"
    # Values accepted for the ``response_format`` argument of ``tts``.
    _SUPPORTED_RESPONSE_FORMATS = {"wav", "mp3", "flac", "opus", "pcm"}

    def __init__(self, key, model_name, base_url="https://api.stepfun.com/v1", **kwargs):
        """Initialise the client and choose the default voice.

        Args:
            key: API key, passed through to the parent constructor.
            model_name: Model identifier, passed through to the parent.
            base_url: API root; any falsy value (``None``, ``""``) is replaced
                by the StepFun default.
            **kwargs: Extra options passed through to the parent.
        """
        # The STEPFUN_TTS_VOICE environment variable overrides the built-in
        # voice; an unset or empty variable falls back to the default.
        self.default_voice = os.environ.get("STEPFUN_TTS_VOICE") or "cixingnansheng"
        endpoint = base_url or "https://api.stepfun.com/v1"
        super().__init__(key, model_name, endpoint, **kwargs)

    def tts(self, text, voice=None, response_format: Literal["wav", "mp3", "flac", "opus", "pcm"] = "mp3"):
        """Synthesise ``text`` and yield the audio, then the token count.

        This is a generator, so nothing below (including validation and the
        HTTP request) runs until the first item is requested.

        Args:
            text: Text to speak; passed through ``self.normalize_text`` first.
            voice: Voice identifier; a falsy value selects
                ``self.default_voice``.
            response_format: Audio container/codec to request.

        Yields:
            Each non-empty chunk from ``response.iter_bytes()``, followed by a
            final item: ``num_tokens_from_string`` of the normalised text.

        Raises:
            ValueError: If ``response_format`` is not a supported format.
            Exception: If the endpoint answers with a status other than 200.
        """
        text = self.normalize_text(text)

        if response_format not in self._SUPPORTED_RESPONSE_FORMATS:
            raise ValueError(f"Unsupported response_format={response_format!r}. Supported: {sorted(self._SUPPORTED_RESPONSE_FORMATS)}")

        chosen_voice = voice if voice else self.default_voice
        request_body = {
            "model": self.model_name,
            "voice": chosen_voice,
            "input": text,
            "response_format": response_format,
        }

        # self.base_url / self.headers are expected to be provided by the
        # parent class (not visible in this block).
        speech_url = f"{self.base_url}/audio/speech"
        response = sync_request("POST", speech_url, headers=self.headers, json=request_body)

        if response.status_code != 200:
            raise Exception(f"**Error**: {response.status_code}, {response.text}")

        # Skip empty chunks so consumers only ever receive real audio data.
        yield from (piece for piece in response.iter_bytes() if piece)

        # The last yielded item is the token count, not audio.
        yield num_tokens_from_string(text)