Refa: automatic LLMs registration (#8651)

### What problem does this PR solve?

Support automatic registration of LLMs.

### Type of change

- [x] Refactoring
This commit is contained in:
Yongteng Lei
2025-07-03 19:05:31 +08:00
committed by GitHub
parent 3234a15aae
commit f8a6987f1e
7 changed files with 619 additions and 876 deletions

View File

@ -70,10 +70,12 @@ class Base(ABC):
pass
def normalize_text(self, text):
return re.sub(r'(\*\*|##\d+\$\$|#)', '', text)
return re.sub(r"(\*\*|##\d+\$\$|#)", "", text)
class FishAudioTTS(Base):
_FACTORY_NAME = "Fish Audio"
def __init__(self, key, model_name, base_url="https://api.fish.audio/v1/tts"):
if not base_url:
base_url = "https://api.fish.audio/v1/tts"
@ -94,13 +96,11 @@ class FishAudioTTS(Base):
with httpx.Client() as client:
try:
with client.stream(
method="POST",
url=self.base_url,
content=ormsgpack.packb(
request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC
),
headers=self.headers,
timeout=None,
method="POST",
url=self.base_url,
content=ormsgpack.packb(request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC),
headers=self.headers,
timeout=None,
) as response:
if response.status_code == HTTPStatus.OK:
for chunk in response.iter_bytes():
@ -115,6 +115,8 @@ class FishAudioTTS(Base):
class QwenTTS(Base):
_FACTORY_NAME = "Tongyi-Qianwen"
def __init__(self, key, model_name, base_url=""):
import dashscope
@ -122,10 +124,11 @@ class QwenTTS(Base):
dashscope.api_key = key
def tts(self, text):
from dashscope.api_entities.dashscope_response import SpeechSynthesisResponse
from dashscope.audio.tts import ResultCallback, SpeechSynthesizer, SpeechSynthesisResult
from collections import deque
from dashscope.api_entities.dashscope_response import SpeechSynthesisResponse
from dashscope.audio.tts import ResultCallback, SpeechSynthesisResult, SpeechSynthesizer
class Callback(ResultCallback):
def __init__(self) -> None:
self.dque = deque()
@ -159,10 +162,7 @@ class QwenTTS(Base):
text = self.normalize_text(text)
callback = Callback()
SpeechSynthesizer.call(model=self.model_name,
text=text,
callback=callback,
format="mp3")
SpeechSynthesizer.call(model=self.model_name, text=text, callback=callback, format="mp3")
try:
for data in callback._run():
yield data
@ -173,24 +173,19 @@ class QwenTTS(Base):
class OpenAITTS(Base):
_FACTORY_NAME = "OpenAI"
def __init__(self, key, model_name="tts-1", base_url="https://api.openai.com/v1"):
if not base_url:
base_url = "https://api.openai.com/v1"
self.api_key = key
self.model_name = model_name
self.base_url = base_url
self.headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
self.headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
def tts(self, text, voice="alloy"):
text = self.normalize_text(text)
payload = {
"model": self.model_name,
"voice": voice,
"input": text
}
payload = {"model": self.model_name, "voice": voice, "input": text}
response = requests.post(f"{self.base_url}/audio/speech", headers=self.headers, json=payload, stream=True)
@ -201,7 +196,8 @@ class OpenAITTS(Base):
yield chunk
class SparkTTS:
class SparkTTS(Base):
_FACTORY_NAME = "XunFei Spark"
STATUS_FIRST_FRAME = 0
STATUS_CONTINUE_FRAME = 1
STATUS_LAST_FRAME = 2
@ -219,29 +215,23 @@ class SparkTTS:
# Generate the URL
def create_url(self):
url = 'wss://tts-api.xfyun.cn/v2/tts'
url = "wss://tts-api.xfyun.cn/v2/tts"
now = datetime.now()
date = format_date_time(mktime(now.timetuple()))
signature_origin = "host: " + "ws-api.xfyun.cn" + "\n"
signature_origin += "date: " + date + "\n"
signature_origin += "GET " + "/v2/tts " + "HTTP/1.1"
signature_sha = hmac.new(self.APISecret.encode('utf-8'), signature_origin.encode('utf-8'),
digestmod=hashlib.sha256).digest()
signature_sha = base64.b64encode(signature_sha).decode(encoding='utf-8')
authorization_origin = "api_key=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"" % (
self.APIKey, "hmac-sha256", "host date request-line", signature_sha)
authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
v = {
"authorization": authorization,
"date": date,
"host": "ws-api.xfyun.cn"
}
url = url + '?' + urlencode(v)
signature_sha = hmac.new(self.APISecret.encode("utf-8"), signature_origin.encode("utf-8"), digestmod=hashlib.sha256).digest()
signature_sha = base64.b64encode(signature_sha).decode(encoding="utf-8")
authorization_origin = 'api_key="%s", algorithm="%s", headers="%s", signature="%s"' % (self.APIKey, "hmac-sha256", "host date request-line", signature_sha)
authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode(encoding="utf-8")
v = {"authorization": authorization, "date": date, "host": "ws-api.xfyun.cn"}
url = url + "?" + urlencode(v)
return url
def tts(self, text):
BusinessArgs = {"aue": "lame", "sfl": 1, "auf": "audio/L16;rate=16000", "vcn": self.model_name, "tte": "utf8"}
Data = {"status": 2, "text": base64.b64encode(text.encode('utf-8')).decode('utf-8')}
Data = {"status": 2, "text": base64.b64encode(text.encode("utf-8")).decode("utf-8")}
CommonArgs = {"app_id": self.APPID}
audio_queue = self.audio_queue
model_name = self.model_name
@ -273,9 +263,7 @@ class SparkTTS:
def on_open(self, ws):
def run(*args):
d = {"common": CommonArgs,
"business": BusinessArgs,
"data": Data}
d = {"common": CommonArgs, "business": BusinessArgs, "data": Data}
ws.send(json.dumps(d))
thread.start_new_thread(run, ())
@ -283,44 +271,32 @@ class SparkTTS:
wsUrl = self.create_url()
websocket.enableTrace(False)
a = Callback()
ws = websocket.WebSocketApp(wsUrl, on_open=a.on_open, on_error=a.on_error, on_close=a.on_close,
on_message=a.on_message)
ws = websocket.WebSocketApp(wsUrl, on_open=a.on_open, on_error=a.on_error, on_close=a.on_close, on_message=a.on_message)
status_code = 0
ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
while True:
audio_chunk = self.audio_queue.get()
if audio_chunk is None:
if status_code == 0:
raise Exception(
f"Fail to access model({model_name}) using the provided credentials. **ERROR**: Invalid APPID, API Secret, or API Key.")
raise Exception(f"Fail to access model({model_name}) using the provided credentials. **ERROR**: Invalid APPID, API Secret, or API Key.")
else:
break
status_code = 1
yield audio_chunk
class XinferenceTTS:
class XinferenceTTS(Base):
_FACTORY_NAME = "Xinference"
def __init__(self, key, model_name, **kwargs):
self.base_url = kwargs.get("base_url", None)
self.model_name = model_name
self.headers = {
"accept": "application/json",
"Content-Type": "application/json"
}
self.headers = {"accept": "application/json", "Content-Type": "application/json"}
def tts(self, text, voice="中文女", stream=True):
payload = {
"model": self.model_name,
"input": text,
"voice": voice
}
payload = {"model": self.model_name, "input": text, "voice": voice}
response = requests.post(
f"{self.base_url}/v1/audio/speech",
headers=self.headers,
json=payload,
stream=stream
)
response = requests.post(f"{self.base_url}/v1/audio/speech", headers=self.headers, json=payload, stream=stream)
if response.status_code != 200:
raise Exception(f"**Error**: {response.status_code}, {response.text}")
@ -332,22 +308,16 @@ class XinferenceTTS:
class OllamaTTS(Base):
def __init__(self, key, model_name="ollama-tts", base_url="https://api.ollama.ai/v1"):
if not base_url:
if not base_url:
base_url = "https://api.ollama.ai/v1"
self.model_name = model_name
self.base_url = base_url
self.headers = {
"Content-Type": "application/json"
}
self.headers = {"Content-Type": "application/json"}
if key and key != "x":
self.headers["Authorization"] = f"Bear {key}"
def tts(self, text, voice="standard-voice"):
payload = {
"model": self.model_name,
"voice": voice,
"input": text
}
payload = {"model": self.model_name, "voice": voice, "input": text}
response = requests.post(f"{self.base_url}/audio/tts", headers=self.headers, json=payload, stream=True)
@ -359,30 +329,19 @@ class OllamaTTS(Base):
yield chunk
class GPUStackTTS:
class GPUStackTTS(Base):
_FACTORY_NAME = "GPUStack"
def __init__(self, key, model_name, **kwargs):
self.base_url = kwargs.get("base_url", None)
self.api_key = key
self.model_name = model_name
self.headers = {
"accept": "application/json",
"Content-Type": "application/json",
"Authorization": f"Bearer {self.api_key}"
}
self.headers = {"accept": "application/json", "Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
def tts(self, text, voice="Chinese Female", stream=True):
payload = {
"model": self.model_name,
"input": text,
"voice": voice
}
payload = {"model": self.model_name, "input": text, "voice": voice}
response = requests.post(
f"{self.base_url}/v1/audio/speech",
headers=self.headers,
json=payload,
stream=stream
)
response = requests.post(f"{self.base_url}/v1/audio/speech", headers=self.headers, json=payload, stream=stream)
if response.status_code != 200:
raise Exception(f"**Error**: {response.status_code}, {response.text}")
@ -393,16 +352,15 @@ class GPUStackTTS:
class SILICONFLOWTTS(Base):
_FACTORY_NAME = "SILICONFLOW"
def __init__(self, key, model_name="FunAudioLLM/CosyVoice2-0.5B", base_url="https://api.siliconflow.cn/v1"):
if not base_url:
base_url = "https://api.siliconflow.cn/v1"
self.api_key = key
self.model_name = model_name
self.base_url = base_url
self.headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
self.headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
def tts(self, text, voice="anna"):
text = self.normalize_text(text)
@ -414,7 +372,7 @@ class SILICONFLOWTTS(Base):
"sample_rate": 123,
"stream": True,
"speed": 1,
"gain": 0
"gain": 0,
}
response = requests.post(f"{self.base_url}/audio/speech", headers=self.headers, json=payload)