diff --git a/agent/canvas.py b/agent/canvas.py
index c447b77b3..667caec29 100644
--- a/agent/canvas.py
+++ b/agent/canvas.py
@@ -16,6 +16,7 @@
import asyncio
import base64
import inspect
+import binascii
import json
import logging
import re
@@ -28,7 +29,9 @@ from typing import Any, Union, Tuple
from agent.component import component_class
from agent.component.base import ComponentBase
from api.db.services.file_service import FileService
+from api.db.services.llm_service import LLMBundle
from api.db.services.task_service import has_canceled
+from common.constants import LLMType
from common.misc_utils import get_uuid, hash_str2int
from common.exceptions import TaskCanceledException
from rag.prompts.generator import chunks_format
@@ -356,8 +359,6 @@ class Canvas(Graph):
self.globals[k] = ""
else:
self.globals[k] = ""
- print(self.globals)
-
async def run(self, **kwargs):
st = time.perf_counter()
@@ -456,6 +457,7 @@ class Canvas(Graph):
self.error = ""
idx = len(self.path) - 1
partials = []
+ tts_mdl = None
while idx < len(self.path):
to = len(self.path)
for i in range(idx, to):
@@ -473,31 +475,51 @@ class Canvas(Graph):
cpn = self.get_component(self.path[i])
cpn_obj = self.get_component_obj(self.path[i])
if cpn_obj.component_name.lower() == "message":
+ if cpn_obj.get_param("auto_play"):
+ tts_mdl = LLMBundle(self._tenant_id, LLMType.TTS)
if isinstance(cpn_obj.output("content"), partial):
_m = ""
+ buff_m = ""
stream = cpn_obj.output("content")()
+ async def _process_stream(m):
+ nonlocal buff_m, _m, tts_mdl
+ if not m:
+ return
+ if m == "":
+ return decorate("message", {"content": "", "start_to_think": True})
+
+ elif m == "":
+ return decorate("message", {"content": "", "end_to_think": True})
+
+ buff_m += m
+ _m += m
+
+ if len(buff_m) > 16:
+ ev = decorate(
+ "message",
+ {
+ "content": m,
+ "audio_binary": self.tts(tts_mdl, buff_m)
+ }
+ )
+ buff_m = ""
+ return ev
+
+ return decorate("message", {"content": m})
+
if inspect.isasyncgen(stream):
async for m in stream:
- if not m:
- continue
- if m == "":
- yield decorate("message", {"content": "", "start_to_think": True})
- elif m == "":
- yield decorate("message", {"content": "", "end_to_think": True})
- else:
- yield decorate("message", {"content": m})
- _m += m
+ ev= await _process_stream(m)
+ if ev:
+ yield ev
else:
for m in stream:
- if not m:
- continue
- if m == "":
- yield decorate("message", {"content": "", "start_to_think": True})
- elif m == "":
- yield decorate("message", {"content": "", "end_to_think": True})
- else:
- yield decorate("message", {"content": m})
- _m += m
+ ev= await _process_stream(m)
+ if ev:
+ yield ev
+ if buff_m:
+ yield decorate("message", {"content": "", "audio_binary": self.tts(tts_mdl, buff_m)})
+ buff_m = ""
cpn_obj.set_output("content", _m)
cite = re.search(r"\[ID:[ 0-9]+\]", _m)
else:
@@ -618,6 +640,50 @@ class Canvas(Graph):
return False
return True
+
def tts(self, tts_mdl, text):
    """Synthesize speech for ``text`` and return the audio hex-encoded.

    The text is sanitized before being sent to the model: bytes that cannot
    be encoded as UTF-8 are dropped, ASCII control characters (except tab,
    newline and carriage return) are removed, emoji are stripped together
    with any zero-width joiner / variation selector that trails them,
    whitespace runs collapse to a single space, and the result is capped at
    500 characters.

    Args:
        tts_mdl: Object exposing ``tts(text)`` that yields audio byte chunks.
            When falsy, no synthesis is attempted.
        text: The text to speak.

    Returns:
        The concatenated audio chunks as a lowercase hex string, or ``None``
        when there is no model, nothing speakable remains after cleaning, or
        the model raises.
    """
    if not tts_mdl or not text:
        return None

    max_len = 500
    emoji_ranges = (
        "\U0001F600-\U0001F64F"  # emoticons
        "\U0001F300-\U0001F5FF"  # miscellaneous symbols and pictographs
        "\U0001F680-\U0001F6FF"  # transport and map symbols
        "\U0001F1E0-\U0001F1FF"  # regional indicators (flags)
        "\U00002700-\U000027BF"  # dingbats
        "\U0001F900-\U0001F9FF"  # supplemental symbols and pictographs
        "\U0001FA70-\U0001FAFF"  # symbols and pictographs extended-A
    )

    # Round-trip through UTF-8 to drop lone surrogates and similar junk.
    text = text.encode("utf-8", "ignore").decode("utf-8", "ignore")
    text = re.sub(r"[\x00-\x08\x0B-\x0C\x0E-\x1F\x7F]", "", text)
    # Remove each emoji plus the joiners/variation selectors glued to it;
    # otherwise an emoji-only string leaves invisible characters behind and
    # would still be sent to the model.
    text = re.sub("(?:[" + emoji_ranges + "][\u200D\uFE0E\uFE0F]*)+", "", text)
    text = re.sub(r"\s+", " ", text).strip()
    # Truncation may cut right after a space, so trim again.
    text = text[:max_len].rstrip()
    if not text:
        return None

    try:
        # join() avoids the quadratic cost of repeated bytes concatenation.
        audio = b"".join(tts_mdl.tts(text))
    except Exception as e:
        # Best effort: a TTS failure must not break the message stream.
        logging.exception("TTS failed: %s, text=%r", e, text)
        return None
    return audio.hex()
+
def get_history(self, window_size):
convs = []
if window_size <= 0:
diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py
index f1b74ce82..4afdd1f3c 100644
--- a/api/db/services/dialog_service.py
+++ b/api/db/services/dialog_service.py
@@ -761,13 +761,48 @@ Please write the SQL, only SQL, without any other explanations or text.
"prompt": sys_prompt,
}
# Compiled once at import time instead of on every call.
_TTS_CONTROL_CHARS_RE = re.compile(r"[\x00-\x08\x0B-\x0C\x0E-\x1F\x7F]")
_TTS_EMOJI_RE = re.compile(
    "(?:["
    "\U0001F600-\U0001F64F"  # emoticons
    "\U0001F300-\U0001F5FF"  # miscellaneous symbols and pictographs
    "\U0001F680-\U0001F6FF"  # transport and map symbols
    "\U0001F1E0-\U0001F1FF"  # regional indicators (flags)
    "\U00002700-\U000027BF"  # dingbats
    "\U0001F900-\U0001F9FF"  # supplemental symbols and pictographs
    "\U0001FA70-\U0001FAFF"  # symbols and pictographs extended-A
    "]"
    # Joiners / variation selectors glued to an emoji go with it, so that an
    # emoji-only string does not leave invisible characters behind.
    "[\u200D\uFE0E\uFE0F]*)+"
)
_TTS_WHITESPACE_RE = re.compile(r"\s+")


def clean_tts_text(text: str, max_len: int = 500) -> str:
    """Sanitize ``text`` so it is safe to hand to a TTS model.

    Steps, in order: drop anything that cannot be encoded as UTF-8, remove
    ASCII control characters (tab, newline and carriage return are kept and
    later folded into spaces), strip emoji, collapse whitespace runs to one
    space, and cap the length.

    Args:
        text: Raw text; a falsy value yields ``""``.
        max_len: Maximum number of characters kept after cleaning.

    Returns:
        The cleaned text, possibly empty.
    """
    if not text:
        return ""

    # Round-trip through UTF-8 to drop lone surrogates and similar junk.
    text = text.encode("utf-8", "ignore").decode("utf-8", "ignore")
    text = _TTS_CONTROL_CHARS_RE.sub("", text)
    text = _TTS_EMOJI_RE.sub("", text)
    text = _TTS_WHITESPACE_RE.sub(" ", text).strip()

    # Truncation may cut right after a space, so trim the tail again.
    return text[:max_len].rstrip()
def tts(tts_mdl, text):
    """Synthesize speech for ``text`` and return the audio hex-encoded.

    Args:
        tts_mdl: Object exposing ``tts(text)`` that yields audio byte chunks.
            When falsy, no synthesis is attempted.
        text: The text to speak; it is passed through ``clean_tts_text``
            before being sent to the model.

    Returns:
        The concatenated audio chunks as a lowercase hex string, or ``None``
        when there is no model, the text is empty (before or after
        cleaning), or the model raises.
    """
    if not tts_mdl or not text:
        return None
    text = clean_tts_text(text)
    if not text:
        return None
    try:
        # join() avoids the quadratic cost of repeated bytes concatenation
        # and keeps the builtin name ``bin`` unshadowed.
        audio = b"".join(tts_mdl.tts(text))
    except Exception as e:
        # Best effort: a TTS failure must not break the answer itself.
        logging.exception("TTS failed: %s, text=%r", e, text)
        return None
    return audio.hex()