mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Feat:support tts in agent (#11675)
### What problem does this PR solve? change: support tts in agent ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
106
agent/canvas.py
106
agent/canvas.py
@ -16,6 +16,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import base64
|
import base64
|
||||||
import inspect
|
import inspect
|
||||||
|
import binascii
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
@ -28,7 +29,9 @@ from typing import Any, Union, Tuple
|
|||||||
from agent.component import component_class
|
from agent.component import component_class
|
||||||
from agent.component.base import ComponentBase
|
from agent.component.base import ComponentBase
|
||||||
from api.db.services.file_service import FileService
|
from api.db.services.file_service import FileService
|
||||||
|
from api.db.services.llm_service import LLMBundle
|
||||||
from api.db.services.task_service import has_canceled
|
from api.db.services.task_service import has_canceled
|
||||||
|
from common.constants import LLMType
|
||||||
from common.misc_utils import get_uuid, hash_str2int
|
from common.misc_utils import get_uuid, hash_str2int
|
||||||
from common.exceptions import TaskCanceledException
|
from common.exceptions import TaskCanceledException
|
||||||
from rag.prompts.generator import chunks_format
|
from rag.prompts.generator import chunks_format
|
||||||
@ -356,8 +359,6 @@ class Canvas(Graph):
|
|||||||
self.globals[k] = ""
|
self.globals[k] = ""
|
||||||
else:
|
else:
|
||||||
self.globals[k] = ""
|
self.globals[k] = ""
|
||||||
print(self.globals)
|
|
||||||
|
|
||||||
|
|
||||||
async def run(self, **kwargs):
|
async def run(self, **kwargs):
|
||||||
st = time.perf_counter()
|
st = time.perf_counter()
|
||||||
@ -456,6 +457,7 @@ class Canvas(Graph):
|
|||||||
self.error = ""
|
self.error = ""
|
||||||
idx = len(self.path) - 1
|
idx = len(self.path) - 1
|
||||||
partials = []
|
partials = []
|
||||||
|
tts_mdl = None
|
||||||
while idx < len(self.path):
|
while idx < len(self.path):
|
||||||
to = len(self.path)
|
to = len(self.path)
|
||||||
for i in range(idx, to):
|
for i in range(idx, to):
|
||||||
@ -473,31 +475,51 @@ class Canvas(Graph):
|
|||||||
cpn = self.get_component(self.path[i])
|
cpn = self.get_component(self.path[i])
|
||||||
cpn_obj = self.get_component_obj(self.path[i])
|
cpn_obj = self.get_component_obj(self.path[i])
|
||||||
if cpn_obj.component_name.lower() == "message":
|
if cpn_obj.component_name.lower() == "message":
|
||||||
|
if cpn_obj.get_param("auto_play"):
|
||||||
|
tts_mdl = LLMBundle(self._tenant_id, LLMType.TTS)
|
||||||
if isinstance(cpn_obj.output("content"), partial):
|
if isinstance(cpn_obj.output("content"), partial):
|
||||||
_m = ""
|
_m = ""
|
||||||
|
buff_m = ""
|
||||||
stream = cpn_obj.output("content")()
|
stream = cpn_obj.output("content")()
|
||||||
|
async def _process_stream(m):
|
||||||
|
nonlocal buff_m, _m, tts_mdl
|
||||||
|
if not m:
|
||||||
|
return
|
||||||
|
if m == "<think>":
|
||||||
|
return decorate("message", {"content": "", "start_to_think": True})
|
||||||
|
|
||||||
|
elif m == "</think>":
|
||||||
|
return decorate("message", {"content": "", "end_to_think": True})
|
||||||
|
|
||||||
|
buff_m += m
|
||||||
|
_m += m
|
||||||
|
|
||||||
|
if len(buff_m) > 16:
|
||||||
|
ev = decorate(
|
||||||
|
"message",
|
||||||
|
{
|
||||||
|
"content": m,
|
||||||
|
"audio_binary": self.tts(tts_mdl, buff_m)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
buff_m = ""
|
||||||
|
return ev
|
||||||
|
|
||||||
|
return decorate("message", {"content": m})
|
||||||
|
|
||||||
if inspect.isasyncgen(stream):
|
if inspect.isasyncgen(stream):
|
||||||
async for m in stream:
|
async for m in stream:
|
||||||
if not m:
|
ev= await _process_stream(m)
|
||||||
continue
|
if ev:
|
||||||
if m == "<think>":
|
yield ev
|
||||||
yield decorate("message", {"content": "", "start_to_think": True})
|
|
||||||
elif m == "</think>":
|
|
||||||
yield decorate("message", {"content": "", "end_to_think": True})
|
|
||||||
else:
|
|
||||||
yield decorate("message", {"content": m})
|
|
||||||
_m += m
|
|
||||||
else:
|
else:
|
||||||
for m in stream:
|
for m in stream:
|
||||||
if not m:
|
ev= await _process_stream(m)
|
||||||
continue
|
if ev:
|
||||||
if m == "<think>":
|
yield ev
|
||||||
yield decorate("message", {"content": "", "start_to_think": True})
|
if buff_m:
|
||||||
elif m == "</think>":
|
yield decorate("message", {"content": "", "audio_binary": self.tts(tts_mdl, buff_m)})
|
||||||
yield decorate("message", {"content": "", "end_to_think": True})
|
buff_m = ""
|
||||||
else:
|
|
||||||
yield decorate("message", {"content": m})
|
|
||||||
_m += m
|
|
||||||
cpn_obj.set_output("content", _m)
|
cpn_obj.set_output("content", _m)
|
||||||
cite = re.search(r"\[ID:[ 0-9]+\]", _m)
|
cite = re.search(r"\[ID:[ 0-9]+\]", _m)
|
||||||
else:
|
else:
|
||||||
@ -618,6 +640,50 @@ class Canvas(Graph):
|
|||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def tts(self,tts_mdl, text):
|
||||||
|
def clean_tts_text(text: str) -> str:
|
||||||
|
if not text:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
text = text.encode("utf-8", "ignore").decode("utf-8", "ignore")
|
||||||
|
|
||||||
|
text = re.sub(r"[\x00-\x08\x0B-\x0C\x0E-\x1F\x7F]", "", text)
|
||||||
|
|
||||||
|
emoji_pattern = re.compile(
|
||||||
|
"[\U0001F600-\U0001F64F"
|
||||||
|
"\U0001F300-\U0001F5FF"
|
||||||
|
"\U0001F680-\U0001F6FF"
|
||||||
|
"\U0001F1E0-\U0001F1FF"
|
||||||
|
"\U00002700-\U000027BF"
|
||||||
|
"\U0001F900-\U0001F9FF"
|
||||||
|
"\U0001FA70-\U0001FAFF"
|
||||||
|
"\U0001FAD0-\U0001FAFF]+",
|
||||||
|
flags=re.UNICODE
|
||||||
|
)
|
||||||
|
text = emoji_pattern.sub("", text)
|
||||||
|
|
||||||
|
text = re.sub(r"\s+", " ", text).strip()
|
||||||
|
|
||||||
|
MAX_LEN = 500
|
||||||
|
if len(text) > MAX_LEN:
|
||||||
|
text = text[:MAX_LEN]
|
||||||
|
|
||||||
|
return text
|
||||||
|
if not tts_mdl or not text:
|
||||||
|
return None
|
||||||
|
text = clean_tts_text(text)
|
||||||
|
if not text:
|
||||||
|
return None
|
||||||
|
bin = b""
|
||||||
|
try:
|
||||||
|
for chunk in tts_mdl.tts(text):
|
||||||
|
bin += chunk
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"TTS failed: {e}, text={text!r}")
|
||||||
|
return None
|
||||||
|
return binascii.hexlify(bin).decode("utf-8")
|
||||||
|
|
||||||
def get_history(self, window_size):
|
def get_history(self, window_size):
|
||||||
convs = []
|
convs = []
|
||||||
if window_size <= 0:
|
if window_size <= 0:
|
||||||
|
|||||||
@ -761,13 +761,48 @@ Please write the SQL, only SQL, without any other explanations or text.
|
|||||||
"prompt": sys_prompt,
|
"prompt": sys_prompt,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def clean_tts_text(text: str) -> str:
|
||||||
|
if not text:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
text = text.encode("utf-8", "ignore").decode("utf-8", "ignore")
|
||||||
|
|
||||||
|
text = re.sub(r"[\x00-\x08\x0B-\x0C\x0E-\x1F\x7F]", "", text)
|
||||||
|
|
||||||
|
emoji_pattern = re.compile(
|
||||||
|
"[\U0001F600-\U0001F64F"
|
||||||
|
"\U0001F300-\U0001F5FF"
|
||||||
|
"\U0001F680-\U0001F6FF"
|
||||||
|
"\U0001F1E0-\U0001F1FF"
|
||||||
|
"\U00002700-\U000027BF"
|
||||||
|
"\U0001F900-\U0001F9FF"
|
||||||
|
"\U0001FA70-\U0001FAFF"
|
||||||
|
"\U0001FAD0-\U0001FAFF]+",
|
||||||
|
flags=re.UNICODE
|
||||||
|
)
|
||||||
|
text = emoji_pattern.sub("", text)
|
||||||
|
|
||||||
|
text = re.sub(r"\s+", " ", text).strip()
|
||||||
|
|
||||||
|
MAX_LEN = 500
|
||||||
|
if len(text) > MAX_LEN:
|
||||||
|
text = text[:MAX_LEN]
|
||||||
|
|
||||||
|
return text
|
||||||
|
|
||||||
def tts(tts_mdl, text):
|
def tts(tts_mdl, text):
|
||||||
if not tts_mdl or not text:
|
if not tts_mdl or not text:
|
||||||
return None
|
return None
|
||||||
|
text = clean_tts_text(text)
|
||||||
|
if not text:
|
||||||
|
return None
|
||||||
bin = b""
|
bin = b""
|
||||||
for chunk in tts_mdl.tts(text):
|
try:
|
||||||
bin += chunk
|
for chunk in tts_mdl.tts(text):
|
||||||
|
bin += chunk
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"TTS failed: {e}, text={text!r}")
|
||||||
|
return None
|
||||||
return binascii.hexlify(bin).decode("utf-8")
|
return binascii.hexlify(bin).decode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user