Mirror of https://github.com/infiniflow/ragflow.git (synced 2025-12-08 20:42:30 +08:00)
support sequence2txt and tts model in Xinference (#2696)
### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
committed by GitHub
parent c552a02e7f
commit a3ab5ba9ac
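The diff below adds an HTTP-based speech-to-text (sequence2txt) client for Xinference. For orientation: the request it ultimately sends is an ordinary multipart POST to Xinference's OpenAI-compatible transcription endpoint. A minimal standalone sketch of that call, assuming a hypothetical local server at http://localhost:9997 with a deployed whisper-small model:

```python
import requests

# Hypothetical values -- point these at your own Xinference deployment.
base_url = "http://localhost:9997"
model_name = "whisper-small"

with open("sample.wav", "rb") as f:
    audio_data = f.read()

resp = requests.post(
    f"{base_url}/v1/audio/transcriptions",  # same endpoint the new class below calls
    files={"file": ("sample.wav", audio_data, "audio/wav")},
    data={"model": model_name, "language": "zh", "response_format": "json"},
)
resp.raise_for_status()
print(resp.json().get("text", "").strip())  # transcribed text on success
```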
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import requests
 from openai.lib.azure import AzureOpenAI
 from zhipuai import ZhipuAI
 import io
@@ -25,6 +26,7 @@ from rag.utils import num_tokens_from_string
 import base64
 import re
 
+
 class Base(ABC):
     def __init__(self, key, model_name):
         pass
@@ -36,8 +38,8 @@ class Base(ABC):
             response_format="text"
         )
         return transcription.text.strip(), num_tokens_from_string(transcription.text.strip())
-
-    def audio2base64(self,audio):
+
+    def audio2base64(self, audio):
         if isinstance(audio, bytes):
             return base64.b64encode(audio).decode("utf-8")
         if isinstance(audio, io.BytesIO):
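The whitespace fix above lands in Base.audio2base64, which normalizes either raw bytes or an io.BytesIO buffer into a UTF-8 base64 string. The hunk is truncated right after the BytesIO check; a self-contained sketch of the full behavior (the BytesIO branch and the error case are assumptions based on that pattern):

```python
import base64
import io

def audio2base64(audio):
    # Raw bytes: encode directly.
    if isinstance(audio, bytes):
        return base64.b64encode(audio).decode("utf-8")
    # In-memory buffer: encode its full contents (assumed branch).
    if isinstance(audio, io.BytesIO):
        return base64.b64encode(audio.getvalue()).decode("utf-8")
    raise TypeError("audio must be bytes or io.BytesIO")

print(audio2base64(b"\x00\x01"))               # 'AAE='
print(audio2base64(io.BytesIO(b"\x00\x01")))   # 'AAE='
```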
@@ -77,13 +79,6 @@ class QWenSeq2txt(Base):
         return "**ERROR**: " + result.message, 0
 
 
-class OllamaSeq2txt(Base):
-    def __init__(self, key, model_name, lang="Chinese", **kwargs):
-        self.client = Client(host=kwargs["base_url"])
-        self.model_name = model_name
-        self.lang = lang
-
-
 class AzureSeq2txt(Base):
     def __init__(self, key, model_name, lang="Chinese", **kwargs):
         self.client = AzureOpenAI(api_key=key, azure_endpoint=kwargs["base_url"], api_version="2024-02-01")
@@ -92,16 +87,53 @@ class AzureSeq2txt(Base):
 
 
 class XinferenceSeq2txt(Base):
-    def __init__(self, key, model_name="", base_url=""):
-        if base_url.split("/")[-1] != "v1":
-            base_url = os.path.join(base_url, "v1")
-        self.client = OpenAI(api_key="xxx", base_url=base_url)
+    def __init__(self,key,model_name="whisper-small",**kwargs):
+        self.base_url = kwargs.get('base_url', None)
+        self.model_name = model_name
 
+    def transcription(self, audio, language="zh", prompt=None, response_format="json", temperature=0.7):
+        if isinstance(audio, str):
+            audio_file = open(audio, 'rb')
+            audio_data = audio_file.read()
+            audio_file_name = audio.split("/")[-1]
+        else:
+            audio_data = audio
+            audio_file_name = "audio.wav"
+
+        payload = {
+            "model": self.model_name,
+            "language": language,
+            "prompt": prompt,
+            "response_format": response_format,
+            "temperature": temperature
+        }
+
+        files = {
+            "file": (audio_file_name, audio_data, 'audio/wav')
+        }
+
+        try:
+            response = requests.post(
+                f"{self.base_url}/v1/audio/transcriptions",
+                files=files,
+                data=payload
+            )
+            response.raise_for_status()
+            result = response.json()
+
+            if 'text' in result:
+                transcription_text = result['text'].strip()
+                return transcription_text, num_tokens_from_string(transcription_text)
+            else:
+                return "**ERROR**: Failed to retrieve transcription.", 0
+
+        except requests.exceptions.RequestException as e:
+            return f"**ERROR**: {str(e)}", 0
 
 
 class TencentCloudSeq2txt(Base):
     def __init__(
-        self, key, model_name="16k_zh", base_url="https://asr.tencentcloudapi.com"
+        self, key, model_name="16k_zh", base_url="https://asr.tencentcloudapi.com"
     ):
         from tencentcloud.common import credential
         from tencentcloud.asr.v20190614 import asr_client
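Putting the new class to work: it is constructed with the model name and the server's base_url (passed through **kwargs), and transcription() accepts either a file path or raw bytes, returning the transcribed text together with its token count, or an "**ERROR**: ..." string and 0 on failure. A minimal usage sketch, assuming the hypothetical deployment values below and that XinferenceSeq2txt from the patched module is in scope:

```python
# Hypothetical deployment values; the class itself is defined in the diff above.
stt = XinferenceSeq2txt(key="xxx", model_name="whisper-small",
                        base_url="http://localhost:9997")

# Path input: the file is read from disk and its basename is used as the upload name.
text, tokens = stt.transcription("meeting.wav", language="zh")

# Bytes input: uploaded under the fallback name "audio.wav".
with open("meeting.wav", "rb") as f:
    text, tokens = stt.transcription(f.read())

print(text, tokens)  # transcription and token count, or "**ERROR**: ..." and 0
```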