mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
add support for Tencent Cloud ASR (#2102)
### What problem does this PR solve? add support for Tencent Cloud ASR ### Type of change - [x] New Feature (non-breaking change which adds functionality) --------- Co-authored-by: Zhedong Cen <cenzhedong2@126.com> Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
This commit is contained in:
@ -128,7 +128,8 @@ Seq2txtModel = {
|
||||
"Tongyi-Qianwen": QWenSeq2txt,
|
||||
"Ollama": OllamaSeq2txt,
|
||||
"Azure-OpenAI": AzureSeq2txt,
|
||||
"Xinference": XinferenceSeq2txt
|
||||
"Xinference": XinferenceSeq2txt,
|
||||
"Tencent Cloud": TencentCloudSeq2txt
|
||||
}
|
||||
|
||||
TTSModel = {
|
||||
|
||||
@ -22,7 +22,8 @@ from openai import OpenAI
|
||||
import os
|
||||
import json
|
||||
from rag.utils import num_tokens_from_string
|
||||
|
||||
import base64
|
||||
import re
|
||||
|
||||
class Base(ABC):
|
||||
def __init__(self, key, model_name):
|
||||
@ -35,6 +36,13 @@ class Base(ABC):
|
||||
response_format="text"
|
||||
)
|
||||
return transcription.text.strip(), num_tokens_from_string(transcription.text.strip())
|
||||
|
||||
def audio2base64(self,audio):
|
||||
if isinstance(audio, bytes):
|
||||
return base64.b64encode(audio).decode("utf-8")
|
||||
if isinstance(audio, io.BytesIO):
|
||||
return base64.b64encode(audio.getvalue()).decode("utf-8")
|
||||
raise TypeError("The input audio file should be in binary format.")
|
||||
|
||||
|
||||
class GPTSeq2txt(Base):
|
||||
@ -87,3 +95,66 @@ class XinferenceSeq2txt(Base):
|
||||
def __init__(self, key, model_name="", base_url=""):
|
||||
self.client = OpenAI(api_key="xxx", base_url=base_url)
|
||||
self.model_name = model_name
|
||||
|
||||
|
||||
class TencentCloudSeq2txt(Base):
|
||||
def __init__(
|
||||
self, key, model_name="16k_zh", base_url="https://asr.tencentcloudapi.com"
|
||||
):
|
||||
from tencentcloud.common import credential
|
||||
from tencentcloud.asr.v20190614 import asr_client
|
||||
|
||||
key = json.loads(key)
|
||||
sid = key.get("tencent_cloud_sid", "")
|
||||
sk = key.get("tencent_cloud_sk", "")
|
||||
cred = credential.Credential(sid, sk)
|
||||
self.client = asr_client.AsrClient(cred, "")
|
||||
self.model_name = model_name
|
||||
|
||||
def transcription(self, audio, max_retries=60, retry_interval=5):
|
||||
from tencentcloud.common.exception.tencent_cloud_sdk_exception import (
|
||||
TencentCloudSDKException,
|
||||
)
|
||||
from tencentcloud.asr.v20190614 import models
|
||||
import time
|
||||
|
||||
b64 = self.audio2base64(audio)
|
||||
try:
|
||||
# dispatch disk
|
||||
req = models.CreateRecTaskRequest()
|
||||
params = {
|
||||
"EngineModelType": self.model_name,
|
||||
"ChannelNum": 1,
|
||||
"ResTextFormat": 0,
|
||||
"SourceType": 1,
|
||||
"Data": b64,
|
||||
}
|
||||
req.from_json_string(json.dumps(params))
|
||||
resp = self.client.CreateRecTask(req)
|
||||
|
||||
# loop query
|
||||
req = models.DescribeTaskStatusRequest()
|
||||
params = {"TaskId": resp.Data.TaskId}
|
||||
req.from_json_string(json.dumps(params))
|
||||
retries = 0
|
||||
while retries < max_retries:
|
||||
resp = self.client.DescribeTaskStatus(req)
|
||||
if resp.Data.StatusStr == "success":
|
||||
text = re.sub(
|
||||
r"\[\d+:\d+\.\d+,\d+:\d+\.\d+\]\s*", "", resp.Data.Result
|
||||
).strip()
|
||||
return text, num_tokens_from_string(text)
|
||||
elif resp.Data.StatusStr == "failed":
|
||||
return (
|
||||
"**ERROR**: Failed to retrieve speech recognition results.",
|
||||
0,
|
||||
)
|
||||
else:
|
||||
time.sleep(retry_interval)
|
||||
retries += 1
|
||||
return "**ERROR**: Max retries exceeded. Task may still be processing.", 0
|
||||
|
||||
except TencentCloudSDKException as e:
|
||||
return "**ERROR**: " + str(e), 0
|
||||
except Exception as e:
|
||||
return "**ERROR**: " + str(e), 0
|
||||
|
||||
Reference in New Issue
Block a user