mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-02-06 10:35:06 +08:00
Feat: update stepfun list (#12991)
### What problem does this PR solve? Update stepfun list. Add TTS and Sequence2Text functionalities. ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@ -59,6 +59,15 @@ class GPTSeq2txt(Base):
|
||||
self.model_name = model_name
|
||||
|
||||
|
||||
class StepFunSeq2txt(GPTSeq2txt):
|
||||
_FACTORY_NAME = "StepFun"
|
||||
|
||||
def __init__(self, key, model_name="step-asr", lang="Chinese", base_url="https://api.stepfun.com/v1", **kwargs):
|
||||
if not base_url:
|
||||
base_url = "https://api.stepfun.com/v1"
|
||||
super().__init__(key, model_name=model_name, base_url=base_url, **kwargs)
|
||||
|
||||
|
||||
class QWenSeq2txt(Base):
|
||||
_FACTORY_NAME = "Tongyi-Qianwen"
|
||||
|
||||
|
||||
@ -19,6 +19,7 @@ import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import os
|
||||
import queue
|
||||
import re
|
||||
import ssl
|
||||
@ -36,6 +37,7 @@ import requests
|
||||
import websocket
|
||||
from pydantic import BaseModel, conint
|
||||
|
||||
from common.http_client import sync_request
|
||||
from common.token_utils import num_tokens_from_string
|
||||
|
||||
|
||||
@ -387,6 +389,7 @@ class SILICONFLOWTTS(Base):
|
||||
if chunk:
|
||||
yield chunk
|
||||
|
||||
|
||||
class DeepInfraTTS(OpenAITTS):
|
||||
_FACTORY_NAME = "DeepInfra"
|
||||
|
||||
@ -394,7 +397,8 @@ class DeepInfraTTS(OpenAITTS):
|
||||
if not base_url:
|
||||
base_url = "https://api.deepinfra.com/v1/openai"
|
||||
super().__init__(key, model_name, base_url, **kwargs)
|
||||
|
||||
|
||||
|
||||
class CometAPITTS(OpenAITTS):
|
||||
_FACTORY_NAME = "CometAPI"
|
||||
|
||||
@ -402,7 +406,8 @@ class CometAPITTS(OpenAITTS):
|
||||
if not base_url:
|
||||
base_url = "https://api.cometapi.com/v1"
|
||||
super().__init__(key, model_name, base_url, **kwargs)
|
||||
|
||||
|
||||
|
||||
class DeerAPITTS(OpenAITTS):
|
||||
_FACTORY_NAME = "DeerAPI"
|
||||
|
||||
@ -410,3 +415,37 @@ class DeerAPITTS(OpenAITTS):
|
||||
if not base_url:
|
||||
base_url = "https://api.deerapi.com/v1"
|
||||
super().__init__(key, model_name, base_url, **kwargs)
|
||||
|
||||
|
||||
class StepFunTTS(OpenAITTS):
|
||||
_FACTORY_NAME = "StepFun"
|
||||
_SUPPORTED_RESPONSE_FORMATS = {"wav", "mp3", "flac", "opus", "pcm"}
|
||||
|
||||
def __init__(self, key, model_name, base_url="https://api.stepfun.com/v1", **kwargs):
|
||||
if not base_url:
|
||||
base_url = "https://api.stepfun.com/v1"
|
||||
self.default_voice = os.environ.get("STEPFUN_TTS_VOICE") or "cixingnansheng"
|
||||
super().__init__(key, model_name, base_url, **kwargs)
|
||||
|
||||
def tts(self, text, voice=None, response_format: Literal["wav", "mp3", "flac", "opus", "pcm"] = "mp3"):
|
||||
text = self.normalize_text(text)
|
||||
if response_format not in self._SUPPORTED_RESPONSE_FORMATS:
|
||||
raise ValueError(f"Unsupported response_format={response_format!r}. Supported: {sorted(self._SUPPORTED_RESPONSE_FORMATS)}")
|
||||
|
||||
payload = {
|
||||
"model": self.model_name,
|
||||
"voice": voice or self.default_voice,
|
||||
"input": text,
|
||||
"response_format": response_format,
|
||||
}
|
||||
|
||||
response = sync_request("POST", f"{self.base_url}/audio/speech", headers=self.headers, json=payload)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"**Error**: {response.status_code}, {response.text}")
|
||||
|
||||
for chunk in response.iter_bytes():
|
||||
if chunk:
|
||||
yield chunk
|
||||
|
||||
yield num_tokens_from_string(text)
|
||||
|
||||
Reference in New Issue
Block a user