Synchronize with enterprise version (#4325)

### Type of change

- [x] Refactoring
This commit is contained in:
Yingfeng
2025-01-02 13:44:44 +08:00
committed by GitHub
parent 564277736a
commit 50f209204e
6 changed files with 94 additions and 69 deletions

View File

@ -299,8 +299,6 @@ class SparkTTS:
yield audio_chunk
class XinferenceTTS:
def __init__(self, key, model_name, **kwargs):
self.base_url = kwargs.get("base_url", None)
@ -330,3 +328,30 @@ class XinferenceTTS:
for chunk in response.iter_content(chunk_size=1024):
if chunk:
yield chunk
class OllamaTTS(Base):
def __init__(self, key, model_name="ollama-tts", base_url="https://api.ollama.ai/v1"):
if not base_url:
base_url = "https://api.ollama.ai/v1"
self.model_name = model_name
self.base_url = base_url
self.headers = {
"Content-Type": "application/json"
}
def tts(self, text, voice="standard-voice"):
payload = {
"model": self.model_name,
"voice": voice,
"input": text
}
response = requests.post(f"{self.base_url}/audio/tts", headers=self.headers, json=payload, stream=True)
if response.status_code != 200:
raise Exception(f"**Error**: {response.status_code}, {response.text}")
for chunk in response.iter_content():
if chunk:
yield chunk