diff --git a/agent/canvas.py b/agent/canvas.py index c447b77b3..667caec29 100644 --- a/agent/canvas.py +++ b/agent/canvas.py @@ -16,6 +16,7 @@ import asyncio import base64 import inspect +import binascii import json import logging import re @@ -28,7 +29,9 @@ from typing import Any, Union, Tuple from agent.component import component_class from agent.component.base import ComponentBase from api.db.services.file_service import FileService +from api.db.services.llm_service import LLMBundle from api.db.services.task_service import has_canceled +from common.constants import LLMType from common.misc_utils import get_uuid, hash_str2int from common.exceptions import TaskCanceledException from rag.prompts.generator import chunks_format @@ -356,8 +359,6 @@ class Canvas(Graph): self.globals[k] = "" else: self.globals[k] = "" - print(self.globals) - async def run(self, **kwargs): st = time.perf_counter() @@ -456,6 +457,7 @@ class Canvas(Graph): self.error = "" idx = len(self.path) - 1 partials = [] + tts_mdl = None while idx < len(self.path): to = len(self.path) for i in range(idx, to): @@ -473,31 +475,51 @@ class Canvas(Graph): cpn = self.get_component(self.path[i]) cpn_obj = self.get_component_obj(self.path[i]) if cpn_obj.component_name.lower() == "message": + if cpn_obj.get_param("auto_play"): + tts_mdl = LLMBundle(self._tenant_id, LLMType.TTS) if isinstance(cpn_obj.output("content"), partial): _m = "" + buff_m = "" stream = cpn_obj.output("content")() + async def _process_stream(m): + nonlocal buff_m, _m, tts_mdl + if not m: + return + if m == "": + return decorate("message", {"content": "", "start_to_think": True}) + + elif m == "": + return decorate("message", {"content": "", "end_to_think": True}) + + buff_m += m + _m += m + + if len(buff_m) > 16: + ev = decorate( + "message", + { + "content": m, + "audio_binary": self.tts(tts_mdl, buff_m) + } + ) + buff_m = "" + return ev + + return decorate("message", {"content": m}) + if inspect.isasyncgen(stream): async for m in stream: - if not m: - continue - if m == "": - yield decorate("message", {"content": "", "start_to_think": True}) - elif m == "": - yield decorate("message", {"content": "", "end_to_think": True}) - else: - yield decorate("message", {"content": m}) - _m += m + ev= await _process_stream(m) + if ev: + yield ev else: for m in stream: - if not m: - continue - if m == "": - yield decorate("message", {"content": "", "start_to_think": True}) - elif m == "": - yield decorate("message", {"content": "", "end_to_think": True}) - else: - yield decorate("message", {"content": m}) - _m += m + ev= await _process_stream(m) + if ev: + yield ev + if buff_m: + yield decorate("message", {"content": "", "audio_binary": self.tts(tts_mdl, buff_m)}) + buff_m = "" cpn_obj.set_output("content", _m) cite = re.search(r"\[ID:[ 0-9]+\]", _m) else: @@ -618,6 +640,50 @@ class Canvas(Graph): return False return True + + def tts(self,tts_mdl, text): + def clean_tts_text(text: str) -> str: + if not text: + return "" + + text = text.encode("utf-8", "ignore").decode("utf-8", "ignore") + + text = re.sub(r"[\x00-\x08\x0B-\x0C\x0E-\x1F\x7F]", "", text) + + emoji_pattern = re.compile( + "[\U0001F600-\U0001F64F" + "\U0001F300-\U0001F5FF" + "\U0001F680-\U0001F6FF" + "\U0001F1E0-\U0001F1FF" + "\U00002700-\U000027BF" + "\U0001F900-\U0001F9FF" + "\U0001FA70-\U0001FAFF" + "\U0001FAD0-\U0001FAFF]+", + flags=re.UNICODE + ) + text = emoji_pattern.sub("", text) + + text = re.sub(r"\s+", " ", text).strip() + + MAX_LEN = 500 + if len(text) > MAX_LEN: + text = text[:MAX_LEN] + + return text + if not tts_mdl or not text: + return None + text = clean_tts_text(text) + if not text: + return None + bin = b"" + try: + for chunk in tts_mdl.tts(text): + bin += chunk + except Exception as e: + logging.error(f"TTS failed: {e}, text={text!r}") + return None + return binascii.hexlify(bin).decode("utf-8") + def get_history(self, window_size): convs = [] if window_size <= 0: diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py index f1b74ce82..4afdd1f3c 100644 --- a/api/db/services/dialog_service.py +++ b/api/db/services/dialog_service.py @@ -761,13 +761,48 @@ Please write the SQL, only SQL, without any other explanations or text. "prompt": sys_prompt, } +def clean_tts_text(text: str) -> str: + if not text: + return "" + + text = text.encode("utf-8", "ignore").decode("utf-8", "ignore") + + text = re.sub(r"[\x00-\x08\x0B-\x0C\x0E-\x1F\x7F]", "", text) + + emoji_pattern = re.compile( + "[\U0001F600-\U0001F64F" + "\U0001F300-\U0001F5FF" + "\U0001F680-\U0001F6FF" + "\U0001F1E0-\U0001F1FF" + "\U00002700-\U000027BF" + "\U0001F900-\U0001F9FF" + "\U0001FA70-\U0001FAFF" + "\U0001FAD0-\U0001FAFF]+", + flags=re.UNICODE + ) + text = emoji_pattern.sub("", text) + + text = re.sub(r"\s+", " ", text).strip() + + MAX_LEN = 500 + if len(text) > MAX_LEN: + text = text[:MAX_LEN] + + return text def tts(tts_mdl, text): if not tts_mdl or not text: return None + text = clean_tts_text(text) + if not text: + return None bin = b"" - for chunk in tts_mdl.tts(text): - bin += chunk + try: + for chunk in tts_mdl.tts(text): + bin += chunk + except Exception as e: + logging.error(f"TTS failed: {e}, text={text!r}") + return None return binascii.hexlify(bin).decode("utf-8")