Mirror of https://github.com/infiniflow/ragflow.git
support sequence2txt and tts model in Xinference (#2696)

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>

Commit a3ab5ba9ac (parent c552a02e7f), committed via GitHub.

```diff
@@ -47,10 +47,9 @@ EmbeddingModel = {
     "Replicate": ReplicateEmbed,
     "BaiduYiyan": BaiduYiyanEmbed,
     "Voyage AI": VoyageEmbed,
-    "HuggingFace":HuggingFaceEmbed,
+    "HuggingFace": HuggingFaceEmbed,
 }


 CvModel = {
     "OpenAI": GptV4,
     "Azure-OpenAI": AzureGptV4,
@@ -64,14 +63,13 @@ CvModel = {
     "LocalAI": LocalAICV,
     "NVIDIA": NvidiaCV,
     "LM-Studio": LmStudioCV,
-    "StepFun":StepFunCV,
+    "StepFun": StepFunCV,
     "OpenAI-API-Compatible": OpenAI_APICV,
     "TogetherAI": TogetherAICV,
     "01.AI": YiCV,
     "Tencent Hunyuan": HunyuanCV
 }


 ChatModel = {
     "OpenAI": GptTurbo,
     "Azure-OpenAI": AzureChat,
@@ -99,7 +97,7 @@ ChatModel = {
     "LeptonAI": LeptonAIChat,
     "TogetherAI": TogetherAIChat,
     "PerfXCloud": PerfXCloudChat,
-    "Upstage":UpstageChat,
+    "Upstage": UpstageChat,
     "novita.ai": NovitaAIChat,
     "SILICONFLOW": SILICONFLOWChat,
     "01.AI": YiChat,
@@ -111,7 +109,6 @@ ChatModel = {
     "Google Cloud": GoogleChat,
 }


 RerankModel = {
     "BAAI": DefaultRerank,
     "Jina": JinaRerank,
@@ -127,11 +124,9 @@ RerankModel = {
     "Voyage AI": VoyageRerank
 }


 Seq2txtModel = {
     "OpenAI": GPTSeq2txt,
     "Tongyi-Qianwen": QWenSeq2txt,
     "Ollama": OllamaSeq2txt,
     "Azure-OpenAI": AzureSeq2txt,
     "Xinference": XinferenceSeq2txt,
     "Tencent Cloud": TencentCloudSeq2txt
@@ -140,6 +135,7 @@ Seq2txtModel = {
 TTSModel = {
     "Fish Audio": FishAudioTTS,
     "Tongyi-Qianwen": QwenTTS,
-    "OpenAI":OpenAITTS,
-    "XunFei Spark":SparkTTS
-}
+    "OpenAI": OpenAITTS,
+    "XunFei Spark": SparkTTS,
+    "Xinference": XinferenceTTS,
+}
```
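
These dictionaries act as provider-name to implementation registries, so the functional change in this hunk is adding `"Xinference": XinferenceTTS` to `TTSModel` (plus spacing cleanups); the `XinferenceSeq2txt` and `XinferenceTTS` classes themselves are defined in the hunks below. A minimal sketch of resolving the new entries — the key, model names, and server URL are placeholder values, not part of the commit:

```python
# Hedged sketch of looking up providers in the registries above; only the
# dictionary keys and classes come from the diff, everything else is illustrative.
stt_cls = Seq2txtModel["Xinference"]   # -> XinferenceSeq2txt
tts_cls = TTSModel["Xinference"]       # -> XinferenceTTS

stt = stt_cls(key="", model_name="whisper-small", base_url="http://localhost:9997")
tts = tts_cls(key="", model_name="ChatTTS", base_url="http://localhost:9997")
```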

```diff
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import requests
 from openai.lib.azure import AzureOpenAI
 from zhipuai import ZhipuAI
 import io
@@ -25,6 +26,7 @@ from rag.utils import num_tokens_from_string
 import base64
 import re


 class Base(ABC):
     def __init__(self, key, model_name):
         pass
@@ -36,8 +38,8 @@ class Base(ABC):
             response_format="text"
         )
         return transcription.text.strip(), num_tokens_from_string(transcription.text.strip())

-    def audio2base64(self,audio):
+    def audio2base64(self, audio):
         if isinstance(audio, bytes):
             return base64.b64encode(audio).decode("utf-8")
         if isinstance(audio, io.BytesIO):
@@ -77,13 +79,6 @@ class QWenSeq2txt(Base):
         return "**ERROR**: " + result.message, 0


-class OllamaSeq2txt(Base):
-    def __init__(self, key, model_name, lang="Chinese", **kwargs):
-        self.client = Client(host=kwargs["base_url"])
-        self.model_name = model_name
-        self.lang = lang
-
-
 class AzureSeq2txt(Base):
     def __init__(self, key, model_name, lang="Chinese", **kwargs):
         self.client = AzureOpenAI(api_key=key, azure_endpoint=kwargs["base_url"], api_version="2024-02-01")
@@ -92,16 +87,53 @@ class AzureSeq2txt(Base):


 class XinferenceSeq2txt(Base):
-    def __init__(self, key, model_name="", base_url=""):
-        if base_url.split("/")[-1] != "v1":
-            base_url = os.path.join(base_url, "v1")
-        self.client = OpenAI(api_key="xxx", base_url=base_url)
+    def __init__(self,key,model_name="whisper-small",**kwargs):
+        self.base_url = kwargs.get('base_url', None)
+        self.model_name = model_name
+
+    def transcription(self, audio, language="zh", prompt=None, response_format="json", temperature=0.7):
+        if isinstance(audio, str):
+            audio_file = open(audio, 'rb')
+            audio_data = audio_file.read()
+            audio_file_name = audio.split("/")[-1]
+        else:
+            audio_data = audio
+            audio_file_name = "audio.wav"
+
+        payload = {
+            "model": self.model_name,
+            "language": language,
+            "prompt": prompt,
+            "response_format": response_format,
+            "temperature": temperature
+        }
+
+        files = {
+            "file": (audio_file_name, audio_data, 'audio/wav')
+        }
+
+        try:
+            response = requests.post(
+                f"{self.base_url}/v1/audio/transcriptions",
+                files=files,
+                data=payload
+            )
+            response.raise_for_status()
+            result = response.json()
+
+            if 'text' in result:
+                transcription_text = result['text'].strip()
+                return transcription_text, num_tokens_from_string(transcription_text)
+            else:
+                return "**ERROR**: Failed to retrieve transcription.", 0
+
+        except requests.exceptions.RequestException as e:
+            return f"**ERROR**: {str(e)}", 0


 class TencentCloudSeq2txt(Base):
     def __init__(
-        self, key, model_name="16k_zh", base_url="https://asr.tencentcloudapi.com"
+        self, key, model_name="16k_zh", base_url="https://asr.tencentcloudapi.com"
     ):
         from tencentcloud.common import credential
         from tencentcloud.asr.v20190614 import asr_client
```
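
The reworked `XinferenceSeq2txt` drops the OpenAI-client wrapper and posts the audio directly to Xinference's OpenAI-compatible `/v1/audio/transcriptions` endpoint with `requests`. It accepts either a file path or raw audio bytes and returns `(text, token_count)`, or an `**ERROR**` string with a zero token count on failure. A minimal usage sketch of the class added above — the server URL and audio file name are placeholders:

```python
# Hedged usage sketch; assumes an Xinference server at a placeholder URL
# with a "whisper-small" model deployed.
stt = XinferenceSeq2txt(key="", model_name="whisper-small",
                        base_url="http://localhost:9997")

# Either pass a path on disk...
text, tokens = stt.transcription("meeting.wav", language="zh")

# ...or the raw bytes already held in memory.
with open("meeting.wav", "rb") as f:
    text, tokens = stt.transcription(f.read())

print(text, tokens)
```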

```diff
@@ -297,3 +297,36 @@ class SparkTTS:
                 break
             status_code = 1
             yield audio_chunk
+
+
+class XinferenceTTS:
+    def __init__(self, key, model_name, **kwargs):
+        self.base_url = kwargs.get("base_url", None)
+        self.model_name = model_name
+        self.headers = {
+            "accept": "application/json",
+            "Content-Type": "application/json"
+        }
+
+    def tts(self, text, voice="中文女", stream=True):
+        payload = {
+            "model": self.model_name,
+            "input": text,
+            "voice": voice
+        }
+
+        response = requests.post(
+            f"{self.base_url}/v1/audio/speech",
+            headers=self.headers,
+            json=payload,
+            stream=stream
+        )
+
+        if response.status_code != 200:
+            raise Exception(f"**Error**: {response.status_code}, {response.text}")
+
+        for chunk in response.iter_content(chunk_size=1024):
+            if chunk:
+                yield chunk
```
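
`XinferenceTTS.tts` streams the synthesized speech back as raw byte chunks from Xinference's `/v1/audio/speech` endpoint, so a caller can write the audio to a file or forward it to a client incrementally. A minimal usage sketch — the server URL, model name, and output file name are placeholders (e.g. `ChatTTS` is only an example of a TTS model one might have deployed on Xinference):

```python
# Hedged usage sketch; base_url, model_name, and the output path are placeholders.
tts = XinferenceTTS(key="", model_name="ChatTTS",
                    base_url="http://localhost:9997")

with open("hello.mp3", "wb") as out:
    for chunk in tts.tts("你好，欢迎使用 RAGFlow。", voice="中文女"):
        out.write(chunk)
```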