support sequence2txt and tts model in Xinference (#2696)

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
This commit is contained in:
JobSmithManipulation
2024-10-08 10:43:18 +08:00
committed by GitHub
parent c552a02e7f
commit a3ab5ba9ac
5 changed files with 112 additions and 38 deletions

View File

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import requests
from openai.lib.azure import AzureOpenAI
from zhipuai import ZhipuAI
import io
@ -25,6 +26,7 @@ from rag.utils import num_tokens_from_string
import base64
import re
class Base(ABC):
def __init__(self, key, model_name):
pass
@ -36,8 +38,8 @@ class Base(ABC):
response_format="text"
)
return transcription.text.strip(), num_tokens_from_string(transcription.text.strip())
def audio2base64(self,audio):
def audio2base64(self, audio):
if isinstance(audio, bytes):
return base64.b64encode(audio).decode("utf-8")
if isinstance(audio, io.BytesIO):
@ -77,13 +79,6 @@ class QWenSeq2txt(Base):
return "**ERROR**: " + result.message, 0
class OllamaSeq2txt(Base):
def __init__(self, key, model_name, lang="Chinese", **kwargs):
self.client = Client(host=kwargs["base_url"])
self.model_name = model_name
self.lang = lang
class AzureSeq2txt(Base):
def __init__(self, key, model_name, lang="Chinese", **kwargs):
self.client = AzureOpenAI(api_key=key, azure_endpoint=kwargs["base_url"], api_version="2024-02-01")
@ -92,16 +87,53 @@ class AzureSeq2txt(Base):
class XinferenceSeq2txt(Base):
def __init__(self, key, model_name="", base_url=""):
if base_url.split("/")[-1] != "v1":
base_url = os.path.join(base_url, "v1")
self.client = OpenAI(api_key="xxx", base_url=base_url)
def __init__(self,key,model_name="whisper-small",**kwargs):
self.base_url = kwargs.get('base_url', None)
self.model_name = model_name
def transcription(self, audio, language="zh", prompt=None, response_format="json", temperature=0.7):
if isinstance(audio, str):
audio_file = open(audio, 'rb')
audio_data = audio_file.read()
audio_file_name = audio.split("/")[-1]
else:
audio_data = audio
audio_file_name = "audio.wav"
payload = {
"model": self.model_name,
"language": language,
"prompt": prompt,
"response_format": response_format,
"temperature": temperature
}
files = {
"file": (audio_file_name, audio_data, 'audio/wav')
}
try:
response = requests.post(
f"{self.base_url}/v1/audio/transcriptions",
files=files,
data=payload
)
response.raise_for_status()
result = response.json()
if 'text' in result:
transcription_text = result['text'].strip()
return transcription_text, num_tokens_from_string(transcription_text)
else:
return "**ERROR**: Failed to retrieve transcription.", 0
except requests.exceptions.RequestException as e:
return f"**ERROR**: {str(e)}", 0
class TencentCloudSeq2txt(Base):
def __init__(
self, key, model_name="16k_zh", base_url="https://asr.tencentcloudapi.com"
self, key, model_name="16k_zh", base_url="https://asr.tencentcloudapi.com"
):
from tencentcloud.common import credential
from tencentcloud.asr.v20190614 import asr_client