mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Synchronize with enterprise version (#4325)
### Type of change - [x] Refactoring
This commit is contained in:
@ -299,8 +299,6 @@ class SparkTTS:
|
||||
yield audio_chunk
|
||||
|
||||
|
||||
|
||||
|
||||
class XinferenceTTS:
|
||||
def __init__(self, key, model_name, **kwargs):
|
||||
self.base_url = kwargs.get("base_url", None)
|
||||
@ -330,3 +328,30 @@ class XinferenceTTS:
|
||||
for chunk in response.iter_content(chunk_size=1024):
|
||||
if chunk:
|
||||
yield chunk
|
||||
|
||||
|
||||
class OllamaTTS(Base):
|
||||
def __init__(self, key, model_name="ollama-tts", base_url="https://api.ollama.ai/v1"):
|
||||
if not base_url:
|
||||
base_url = "https://api.ollama.ai/v1"
|
||||
self.model_name = model_name
|
||||
self.base_url = base_url
|
||||
self.headers = {
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
def tts(self, text, voice="standard-voice"):
|
||||
payload = {
|
||||
"model": self.model_name,
|
||||
"voice": voice,
|
||||
"input": text
|
||||
}
|
||||
|
||||
response = requests.post(f"{self.base_url}/audio/tts", headers=self.headers, json=payload, stream=True)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"**Error**: {response.status_code}, {response.text}")
|
||||
|
||||
for chunk in response.iter_content():
|
||||
if chunk:
|
||||
yield chunk
|
||||
|
||||
Reference in New Issue
Block a user