mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-02-02 08:35:08 +08:00
Refa: automatic LLMs registration (#8651)
### What problem does this PR solve?

Support automatic LLMs registration.

### Type of change

- [x] Refactoring
This commit is contained in:
@ -70,10 +70,12 @@ class Base(ABC):
|
||||
pass
|
||||
|
||||
def normalize_text(self, text):
    """Return *text* with markup tokens removed before synthesis.

    Every occurrence of ``**``, of ``##<digits>$$``, and of a single ``#``
    is deleted; all other characters are left untouched.

    Args:
        text: The string to clean.

    Returns:
        The cleaned string.
    """
    # Alternation order matters: the longer ``##<digits>$$`` form must be
    # tried before the bare ``#`` so the digits and ``$$`` are removed too.
    return re.sub(r"(\*\*|##\d+\$\$|#)", "", text)
|
||||
|
||||
|
||||
class FishAudioTTS(Base):
|
||||
_FACTORY_NAME = "Fish Audio"
|
||||
|
||||
def __init__(self, key, model_name, base_url="https://api.fish.audio/v1/tts"):
|
||||
if not base_url:
|
||||
base_url = "https://api.fish.audio/v1/tts"
|
||||
@ -94,13 +96,11 @@ class FishAudioTTS(Base):
|
||||
with httpx.Client() as client:
|
||||
try:
|
||||
with client.stream(
|
||||
method="POST",
|
||||
url=self.base_url,
|
||||
content=ormsgpack.packb(
|
||||
request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC
|
||||
),
|
||||
headers=self.headers,
|
||||
timeout=None,
|
||||
method="POST",
|
||||
url=self.base_url,
|
||||
content=ormsgpack.packb(request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC),
|
||||
headers=self.headers,
|
||||
timeout=None,
|
||||
) as response:
|
||||
if response.status_code == HTTPStatus.OK:
|
||||
for chunk in response.iter_bytes():
|
||||
@ -115,6 +115,8 @@ class FishAudioTTS(Base):
|
||||
|
||||
|
||||
class QwenTTS(Base):
|
||||
_FACTORY_NAME = "Tongyi-Qianwen"
|
||||
|
||||
def __init__(self, key, model_name, base_url=""):
|
||||
import dashscope
|
||||
|
||||
@ -122,10 +124,11 @@ class QwenTTS(Base):
|
||||
dashscope.api_key = key
|
||||
|
||||
def tts(self, text):
|
||||
from dashscope.api_entities.dashscope_response import SpeechSynthesisResponse
|
||||
from dashscope.audio.tts import ResultCallback, SpeechSynthesizer, SpeechSynthesisResult
|
||||
from collections import deque
|
||||
|
||||
from dashscope.api_entities.dashscope_response import SpeechSynthesisResponse
|
||||
from dashscope.audio.tts import ResultCallback, SpeechSynthesisResult, SpeechSynthesizer
|
||||
|
||||
class Callback(ResultCallback):
|
||||
def __init__(self) -> None:
|
||||
self.dque = deque()
|
||||
@ -159,10 +162,7 @@ class QwenTTS(Base):
|
||||
|
||||
text = self.normalize_text(text)
|
||||
callback = Callback()
|
||||
SpeechSynthesizer.call(model=self.model_name,
|
||||
text=text,
|
||||
callback=callback,
|
||||
format="mp3")
|
||||
SpeechSynthesizer.call(model=self.model_name, text=text, callback=callback, format="mp3")
|
||||
try:
|
||||
for data in callback._run():
|
||||
yield data
|
||||
@ -173,24 +173,19 @@ class QwenTTS(Base):
|
||||
|
||||
|
||||
class OpenAITTS(Base):
|
||||
_FACTORY_NAME = "OpenAI"
|
||||
|
||||
def __init__(self, key, model_name="tts-1", base_url="https://api.openai.com/v1"):
    """Store the credentials and endpoint used for speech requests.

    Args:
        key: API key; sent as a ``Bearer`` token in the Authorization header.
        model_name: Name of the TTS model to request.
        base_url: API root. Any falsy value (``""``, ``None``) falls back to
            ``https://api.openai.com/v1``.
    """
    # Callers may pass an empty string explicitly, so the signature default
    # alone is not enough to guarantee a usable endpoint.
    if not base_url:
        base_url = "https://api.openai.com/v1"
    self.api_key = key
    self.model_name = model_name
    self.base_url = base_url
    self.headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
|
||||
|
||||
def tts(self, text, voice="alloy"):
|
||||
text = self.normalize_text(text)
|
||||
payload = {
|
||||
"model": self.model_name,
|
||||
"voice": voice,
|
||||
"input": text
|
||||
}
|
||||
payload = {"model": self.model_name, "voice": voice, "input": text}
|
||||
|
||||
response = requests.post(f"{self.base_url}/audio/speech", headers=self.headers, json=payload, stream=True)
|
||||
|
||||
@ -201,7 +196,8 @@ class OpenAITTS(Base):
|
||||
yield chunk
|
||||
|
||||
|
||||
class SparkTTS:
|
||||
class SparkTTS(Base):
|
||||
_FACTORY_NAME = "XunFei Spark"
|
||||
STATUS_FIRST_FRAME = 0
|
||||
STATUS_CONTINUE_FRAME = 1
|
||||
STATUS_LAST_FRAME = 2
|
||||
@ -219,29 +215,23 @@ class SparkTTS:
|
||||
|
||||
# Generate the URL
|
||||
def create_url(self):
    """Build the signed WebSocket URL for a TTS request.

    The signature is an HMAC-SHA256, keyed with ``self.APISecret``, over the
    three lines ``host: ...``, ``date: ...`` and the ``GET /v2/tts HTTP/1.1``
    request line. It is base64-encoded, embedded together with
    ``self.APIKey`` in an authorization string, and that string is
    base64-encoded again and sent as the ``authorization`` query parameter
    alongside ``date`` and ``host``.

    Returns:
        The ``wss://tts-api.xfyun.cn/v2/tts`` URL with the URL-encoded
        ``authorization``, ``date`` and ``host`` query parameters appended.
    """
    # NOTE(review): the signed host ("ws-api.xfyun.cn") differs from the host
    # in the endpoint URL ("tts-api.xfyun.cn"); kept exactly as in the
    # original code - confirm against the provider's signing rules.
    host = "ws-api.xfyun.cn"

    # HTTP-date (RFC 1123, GMT) derived from the current local time.
    date = format_date_time(mktime(datetime.now().timetuple()))

    signature_origin = f"host: {host}\ndate: {date}\nGET /v2/tts HTTP/1.1"
    digest = hmac.new(self.APISecret.encode("utf-8"), signature_origin.encode("utf-8"), digestmod=hashlib.sha256).digest()
    signature = base64.b64encode(digest).decode("utf-8")

    authorization_origin = f'api_key="{self.APIKey}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature}"'
    authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode("utf-8")

    query = {"authorization": authorization, "date": date, "host": host}
    return "wss://tts-api.xfyun.cn/v2/tts?" + urlencode(query)
|
||||
|
||||
def tts(self, text):
|
||||
BusinessArgs = {"aue": "lame", "sfl": 1, "auf": "audio/L16;rate=16000", "vcn": self.model_name, "tte": "utf8"}
|
||||
Data = {"status": 2, "text": base64.b64encode(text.encode('utf-8')).decode('utf-8')}
|
||||
Data = {"status": 2, "text": base64.b64encode(text.encode("utf-8")).decode("utf-8")}
|
||||
CommonArgs = {"app_id": self.APPID}
|
||||
audio_queue = self.audio_queue
|
||||
model_name = self.model_name
|
||||
@ -273,9 +263,7 @@ class SparkTTS:
|
||||
|
||||
def on_open(self, ws):
|
||||
def run(*args):
|
||||
d = {"common": CommonArgs,
|
||||
"business": BusinessArgs,
|
||||
"data": Data}
|
||||
d = {"common": CommonArgs, "business": BusinessArgs, "data": Data}
|
||||
ws.send(json.dumps(d))
|
||||
|
||||
thread.start_new_thread(run, ())
|
||||
@ -283,44 +271,32 @@ class SparkTTS:
|
||||
wsUrl = self.create_url()
|
||||
websocket.enableTrace(False)
|
||||
a = Callback()
|
||||
ws = websocket.WebSocketApp(wsUrl, on_open=a.on_open, on_error=a.on_error, on_close=a.on_close,
|
||||
on_message=a.on_message)
|
||||
ws = websocket.WebSocketApp(wsUrl, on_open=a.on_open, on_error=a.on_error, on_close=a.on_close, on_message=a.on_message)
|
||||
status_code = 0
|
||||
ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
|
||||
while True:
|
||||
audio_chunk = self.audio_queue.get()
|
||||
if audio_chunk is None:
|
||||
if status_code == 0:
|
||||
raise Exception(
|
||||
f"Fail to access model({model_name}) using the provided credentials. **ERROR**: Invalid APPID, API Secret, or API Key.")
|
||||
raise Exception(f"Fail to access model({model_name}) using the provided credentials. **ERROR**: Invalid APPID, API Secret, or API Key.")
|
||||
else:
|
||||
break
|
||||
status_code = 1
|
||||
yield audio_chunk
|
||||
|
||||
|
||||
class XinferenceTTS:
|
||||
class XinferenceTTS(Base):
|
||||
_FACTORY_NAME = "Xinference"
|
||||
|
||||
def __init__(self, key, model_name, **kwargs):
    """Store the endpoint and model used for speech requests.

    Args:
        key: Accepted for signature parity with the other TTS classes; it is
            not stored or sent by this constructor.
        model_name: Name of the model to request.
        **kwargs: Only ``base_url`` is read. It is ``None`` when omitted.
            NOTE(review): ``tts`` interpolates ``self.base_url`` into the
            request URL, so callers are presumably expected to supply it.
    """
    self.base_url = kwargs.get("base_url", None)
    self.model_name = model_name
    self.headers = {"accept": "application/json", "Content-Type": "application/json"}
|
||||
|
||||
def tts(self, text, voice="中文女", stream=True):
|
||||
payload = {
|
||||
"model": self.model_name,
|
||||
"input": text,
|
||||
"voice": voice
|
||||
}
|
||||
payload = {"model": self.model_name, "input": text, "voice": voice}
|
||||
|
||||
response = requests.post(
|
||||
f"{self.base_url}/v1/audio/speech",
|
||||
headers=self.headers,
|
||||
json=payload,
|
||||
stream=stream
|
||||
)
|
||||
response = requests.post(f"{self.base_url}/v1/audio/speech", headers=self.headers, json=payload, stream=stream)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"**Error**: {response.status_code}, {response.text}")
|
||||
@ -332,22 +308,16 @@ class XinferenceTTS:
|
||||
|
||||
class OllamaTTS(Base):
|
||||
def __init__(self, key, model_name="ollama-tts", base_url="https://api.ollama.ai/v1"):
    """Store the endpoint, model and optional credentials for speech requests.

    Args:
        key: Optional API key. An Authorization header is added only when the
            key is truthy and not the placeholder value ``"x"``.
        model_name: Name of the TTS model to request.
        base_url: API root. Any falsy value (``""``, ``None``) falls back to
            ``https://api.ollama.ai/v1``.
    """
    if not base_url:
        base_url = "https://api.ollama.ai/v1"
    self.model_name = model_name
    self.base_url = base_url
    self.headers = {"Content-Type": "application/json"}
    if key and key != "x":
        # The HTTP auth scheme is "Bearer"; the previous "Bear" prefix
        # produced a header that servers would not recognise as a token.
        self.headers["Authorization"] = f"Bearer {key}"
|
||||
|
||||
def tts(self, text, voice="standard-voice"):
|
||||
payload = {
|
||||
"model": self.model_name,
|
||||
"voice": voice,
|
||||
"input": text
|
||||
}
|
||||
payload = {"model": self.model_name, "voice": voice, "input": text}
|
||||
|
||||
response = requests.post(f"{self.base_url}/audio/tts", headers=self.headers, json=payload, stream=True)
|
||||
|
||||
@ -359,30 +329,19 @@ class OllamaTTS(Base):
|
||||
yield chunk
|
||||
|
||||
|
||||
class GPUStackTTS:
|
||||
class GPUStackTTS(Base):
|
||||
_FACTORY_NAME = "GPUStack"
|
||||
|
||||
def __init__(self, key, model_name, **kwargs):
    """Store the endpoint, model and credentials used for speech requests.

    Args:
        key: API key; sent as a ``Bearer`` token in the Authorization header.
        model_name: Name of the model to request.
        **kwargs: Only ``base_url`` is read. It is ``None`` when omitted.
            NOTE(review): ``tts`` interpolates ``self.base_url`` into the
            request URL, so callers are presumably expected to supply it.
    """
    self.base_url = kwargs.get("base_url", None)
    self.api_key = key
    self.model_name = model_name
    self.headers = {"accept": "application/json", "Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
|
||||
|
||||
def tts(self, text, voice="Chinese Female", stream=True):
|
||||
payload = {
|
||||
"model": self.model_name,
|
||||
"input": text,
|
||||
"voice": voice
|
||||
}
|
||||
payload = {"model": self.model_name, "input": text, "voice": voice}
|
||||
|
||||
response = requests.post(
|
||||
f"{self.base_url}/v1/audio/speech",
|
||||
headers=self.headers,
|
||||
json=payload,
|
||||
stream=stream
|
||||
)
|
||||
response = requests.post(f"{self.base_url}/v1/audio/speech", headers=self.headers, json=payload, stream=stream)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"**Error**: {response.status_code}, {response.text}")
|
||||
@ -393,16 +352,15 @@ class GPUStackTTS:
|
||||
|
||||
|
||||
class SILICONFLOWTTS(Base):
|
||||
_FACTORY_NAME = "SILICONFLOW"
|
||||
|
||||
def __init__(self, key, model_name="FunAudioLLM/CosyVoice2-0.5B", base_url="https://api.siliconflow.cn/v1"):
    """Store the credentials and endpoint used for speech requests.

    Args:
        key: API key; sent as a ``Bearer`` token in the Authorization header.
        model_name: Name of the TTS model to request.
        base_url: API root. Any falsy value (``""``, ``None``) falls back to
            ``https://api.siliconflow.cn/v1``.
    """
    # Callers may pass an empty string explicitly, so the signature default
    # alone is not enough to guarantee a usable endpoint.
    if not base_url:
        base_url = "https://api.siliconflow.cn/v1"
    self.api_key = key
    self.model_name = model_name
    self.base_url = base_url
    self.headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
|
||||
|
||||
def tts(self, text, voice="anna"):
|
||||
text = self.normalize_text(text)
|
||||
@ -414,7 +372,7 @@ class SILICONFLOWTTS(Base):
|
||||
"sample_rate": 123,
|
||||
"stream": True,
|
||||
"speed": 1,
|
||||
"gain": 0
|
||||
"gain": 0,
|
||||
}
|
||||
|
||||
response = requests.post(f"{self.base_url}/audio/speech", headers=self.headers, json=payload)
|
||||
|
||||
Reference in New Issue
Block a user