mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-01-23 03:26:53 +08:00
Feat: API adds audio to text and text to speech functions (#12764)
### What problem does this PR solve? API adds audio to text and text to speech functions ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@ -19,6 +19,10 @@ import re
|
||||
import time
|
||||
|
||||
import tiktoken
|
||||
import os
|
||||
import tempfile
|
||||
import logging
|
||||
|
||||
from quart import Response, jsonify, request
|
||||
|
||||
from agent.canvas import Canvas
|
||||
@ -35,7 +39,7 @@ from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from common.metadata_utils import apply_meta_data_filter, convert_conditions, meta_filter
|
||||
from api.db.services.search_service import SearchService
|
||||
from api.db.services.user_service import UserTenantService
|
||||
from api.db.services.user_service import TenantService,UserTenantService
|
||||
from common.misc_utils import get_uuid
|
||||
from api.utils.api_utils import check_duplicate_ids, get_data_openai, get_error_data_result, get_json_result, \
|
||||
get_result, get_request_json, server_error_response, token_required, validate_request
|
||||
@ -1220,3 +1224,93 @@ async def mindmap():
|
||||
if "error" in mind_map:
|
||||
return server_error_response(Exception(mind_map["error"]))
|
||||
return get_json_result(data=mind_map)
|
||||
|
||||
@manager.route("/sequence2txt", methods=["POST"]) # noqa: F821
|
||||
@token_required
|
||||
async def sequence2txt(tenant_id):
|
||||
req = await request.form
|
||||
stream_mode = req.get("stream", "false").lower() == "true"
|
||||
files = await request.files
|
||||
if "file" not in files:
|
||||
return get_error_data_result(message="Missing 'file' in multipart form-data")
|
||||
|
||||
uploaded = files["file"]
|
||||
|
||||
ALLOWED_EXTS = {
|
||||
".wav", ".mp3", ".m4a", ".aac",
|
||||
".flac", ".ogg", ".webm",
|
||||
".opus", ".wma"
|
||||
}
|
||||
|
||||
filename = uploaded.filename or ""
|
||||
suffix = os.path.splitext(filename)[-1].lower()
|
||||
if suffix not in ALLOWED_EXTS:
|
||||
return get_error_data_result(message=
|
||||
f"Unsupported audio format: {suffix}. "
|
||||
f"Allowed: {', '.join(sorted(ALLOWED_EXTS))}"
|
||||
)
|
||||
fd, temp_audio_path = tempfile.mkstemp(suffix=suffix)
|
||||
os.close(fd)
|
||||
await uploaded.save(temp_audio_path)
|
||||
|
||||
tenants = TenantService.get_info_by(tenant_id)
|
||||
if not tenants:
|
||||
return get_error_data_result(message="Tenant not found!")
|
||||
|
||||
asr_id = tenants[0]["asr_id"]
|
||||
if not asr_id:
|
||||
return get_error_data_result(message="No default ASR model is set")
|
||||
|
||||
asr_mdl=LLMBundle(tenants[0]["tenant_id"], LLMType.SPEECH2TEXT, asr_id)
|
||||
if not stream_mode:
|
||||
text = asr_mdl.transcription(temp_audio_path)
|
||||
try:
|
||||
os.remove(temp_audio_path)
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to remove temp audio file: {str(e)}")
|
||||
return get_json_result(data={"text": text})
|
||||
async def event_stream():
|
||||
try:
|
||||
for evt in asr_mdl.stream_transcription(temp_audio_path):
|
||||
yield f"data: {json.dumps(evt, ensure_ascii=False)}\n\n"
|
||||
except Exception as e:
|
||||
err = {"event": "error", "text": str(e)}
|
||||
yield f"data: {json.dumps(err, ensure_ascii=False)}\n\n"
|
||||
finally:
|
||||
try:
|
||||
os.remove(temp_audio_path)
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to remove temp audio file: {str(e)}")
|
||||
|
||||
return Response(event_stream(), content_type="text/event-stream")
|
||||
|
||||
@manager.route("/tts", methods=["POST"]) # noqa: F821
|
||||
@token_required
|
||||
async def tts(tenant_id):
|
||||
req = await get_request_json()
|
||||
text = req["text"]
|
||||
|
||||
tenants = TenantService.get_info_by(tenant_id)
|
||||
if not tenants:
|
||||
return get_error_data_result(message="Tenant not found!")
|
||||
|
||||
tts_id = tenants[0]["tts_id"]
|
||||
if not tts_id:
|
||||
return get_error_data_result(message="No default TTS model is set")
|
||||
|
||||
tts_mdl = LLMBundle(tenants[0]["tenant_id"], LLMType.TTS, tts_id)
|
||||
|
||||
def stream_audio():
|
||||
try:
|
||||
for txt in re.split(r"[,。/《》?;:!\n\r:;]+", text):
|
||||
for chunk in tts_mdl.tts(txt):
|
||||
yield chunk
|
||||
except Exception as e:
|
||||
yield ("data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e)}}, ensure_ascii=False)).encode("utf-8")
|
||||
|
||||
resp = Response(stream_audio(), mimetype="audio/mpeg")
|
||||
resp.headers.add_header("Cache-Control", "no-cache")
|
||||
resp.headers.add_header("Connection", "keep-alive")
|
||||
resp.headers.add_header("X-Accel-Buffering", "no")
|
||||
|
||||
return resp
|
||||
Reference in New Issue
Block a user