diff --git a/ai_agents/agents/examples/voice-assistant/tenapp/manifest.json b/ai_agents/agents/examples/voice-assistant/tenapp/manifest.json index 020768c826..97f3d5c3a2 100644 --- a/ai_agents/agents/examples/voice-assistant/tenapp/manifest.json +++ b/ai_agents/agents/examples/voice-assistant/tenapp/manifest.json @@ -152,6 +152,9 @@ }, { "path": "../../../ten_packages/extension/oracle_tts_python" + }, + { + "path": "../../../ten_packages/extension/deepgram_tts" } ], "scripts": { diff --git a/ai_agents/agents/examples/voice-assistant/tenapp/property.json b/ai_agents/agents/examples/voice-assistant/tenapp/property.json index 270bfb77be..dcd0b8e214 100644 --- a/ai_agents/agents/examples/voice-assistant/tenapp/property.json +++ b/ai_agents/agents/examples/voice-assistant/tenapp/property.json @@ -185,6 +185,190 @@ ] } }, + { + "name": "voice_assistant_deepgram_tts", + "auto_start": false, + "graph": { + "nodes": [ + { + "type": "extension", + "name": "agora_rtc", + "addon": "agora_rtc", + "extension_group": "default", + "property": { + "app_id": "${env:AGORA_APP_ID}", + "app_certificate": "${env:AGORA_APP_CERTIFICATE|}", + "channel": "ten_agent_test", + "stream_id": 1234, + "remote_stream_id": 123, + "subscribe_audio": true, + "publish_audio": true, + "publish_data": true, + "enable_agora_asr": false + } + }, + { + "type": "extension", + "name": "stt", + "addon": "deepgram_asr_python", + "extension_group": "stt", + "property": { + "params": { + "api_key": "${env:DEEPGRAM_API_KEY}", + "language": "en-US", + "model": "nova-3" + } + } + }, + { + "type": "extension", + "name": "llm", + "addon": "openai_llm2_python", + "extension_group": "chatgpt", + "property": { + "base_url": "https://api.openai.com/v1", + "api_key": "${env:OPENAI_API_KEY}", + "frequency_penalty": 0.9, + "model": "${env:OPENAI_MODEL}", + "max_tokens": 512, + "prompt": "", + "proxy_url": "${env:OPENAI_PROXY_URL|}", + "greeting": "TEN Agent connected. 
How can I help you today?", + "max_memory_length": 10 + } + }, + { + "type": "extension", + "name": "tts", + "addon": "deepgram_tts", + "extension_group": "tts", + "property": { + "dump": false, + "dump_path": "/tmp", + "params": { + "api_key": "${env:DEEPGRAM_API_KEY}", + "model": "aura-2-thalia-en", + "encoding": "linear16", + "sample_rate": 24000 + } + } + }, + { + "type": "extension", + "name": "main_control", + "addon": "main_python", + "extension_group": "control", + "property": { + "greeting": "TEN Agent connected. How can I help you today?" + } + }, + { + "type": "extension", + "name": "message_collector", + "addon": "message_collector2", + "extension_group": "transcriber", + "property": {} + }, + { + "type": "extension", + "name": "weatherapi_tool_python", + "addon": "weatherapi_tool_python", + "extension_group": "default", + "property": { + "api_key": "${env:WEATHERAPI_API_KEY|}" + } + }, + { + "type": "extension", + "name": "streamid_adapter", + "addon": "streamid_adapter", + "property": {} + } + ], + "connections": [ + { + "extension": "main_control", + "cmd": [ + { + "names": [ + "on_user_joined", + "on_user_left" + ], + "source": [ + { + "extension": "agora_rtc" + } + ] + }, + { + "names": [ + "tool_register" + ], + "source": [ + { + "extension": "weatherapi_tool_python" + } + ] + } + ], + "data": [ + { + "name": "asr_result", + "source": [ + { + "extension": "stt" + } + ] + } + ] + }, + { + "extension": "agora_rtc", + "audio_frame": [ + { + "name": "pcm_frame", + "dest": [ + { + "extension": "streamid_adapter" + } + ] + }, + { + "name": "pcm_frame", + "source": [ + { + "extension": "tts" + } + ] + } + ], + "data": [ + { + "name": "data", + "source": [ + { + "extension": "message_collector" + } + ] + } + ] + }, + { + "extension": "streamid_adapter", + "audio_frame": [ + { + "name": "pcm_frame", + "dest": [ + { + "extension": "stt" + } + ] + } + ] + } + ] + } + }, { "name": "voice_assistant_oracle", "auto_start": false, diff --git 
a/ai_agents/agents/ten_packages/extension/deepgram_tts/README.md b/ai_agents/agents/ten_packages/extension/deepgram_tts/README.md new file mode 100644 index 0000000000..ab18a5b30b --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/deepgram_tts/README.md @@ -0,0 +1,103 @@ +# Deepgram TTS Extension + +A TEN Framework extension that provides Text-to-Speech (TTS) capabilities using Deepgram's Aura streaming API. + +## Features + +- Real-time streaming TTS via WebSocket +- Multiple voice models (Aura-2 series) +- Configurable sample rates (8000, 16000, 24000, 48000 Hz) +- Linear16 PCM audio output +- TTFB (Time to First Byte) metrics reporting +- Audio dump capability for debugging + +## Configuration + +### Properties + +| Property | Type | Default | Description | +|----------|------|---------|-------------| +| `params.api_key` | string | Required | Deepgram API key | +| `params.model` | string | `aura-2-thalia-en` | Voice model to use | +| `params.encoding` | string | `linear16` | Audio encoding format | +| `params.sample_rate` | int | `24000` | Output sample rate in Hz | +| `params.base_url` | string | `wss://api.deepgram.com/v1/speak` | WebSocket endpoint | +| `params.` | scalar | Optional | Additional Deepgram websocket query parameters passed through to the vendor | +| `dump` | bool | `false` | Enable audio dumping | +| `dump_path` | string | `/tmp` | Path for audio dump files | + +### Example Configuration + +```json +{ + "params": { + "api_key": "${env:DEEPGRAM_API_KEY}", + "model": "aura-2-thalia-en", + "encoding": "linear16", + "sample_rate": 24000, + "container": "none" + }, + "dump": false, + "dump_path": "/tmp" +} +``` + +Known extension-owned keys such as `api_key`, `base_url`, `model`, `encoding`, +and `sample_rate` are normalized onto the config object. Any remaining scalar +keys under `params` are appended to the Deepgram websocket query string. 
+ +## Available Voice Models + +Deepgram Aura-2 voices: +- `aura-2-thalia-en` - Female, English (default) +- `aura-2-luna-en` - Female, English +- `aura-2-stella-en` - Female, English +- `aura-2-athena-en` - Female, English +- `aura-2-hera-en` - Female, English +- `aura-2-orion-en` - Male, English +- `aura-2-arcas-en` - Male, English +- `aura-2-perseus-en` - Male, English +- `aura-2-angus-en` - Male, English +- `aura-2-orpheus-en` - Male, English +- `aura-2-helios-en` - Male, English +- `aura-2-zeus-en` - Male, English + +## Supported Sample Rates + +- 8000 Hz +- 16000 Hz +- 24000 Hz (recommended) +- 48000 Hz + +## API Interface + +This extension implements the standard TEN TTS interface: + +### Input Data +- `tts_text_input` - Text to synthesize +- `tts_flush` - Flush pending audio + +### Output Data +- `tts_audio_start` - Audio generation started +- `tts_audio_end` - Audio generation completed +- `metrics` - Performance metrics (TTFB, duration) +- `error` - Error information + +### Output Audio +- `pcm_frame` - PCM audio data (16-bit, mono) + +## Running Tests + +```bash +cd deepgram_tts +tman -y install --standalone +./tests/bin/start +``` + +## Environment Variables + +- `DEEPGRAM_API_KEY` - Your Deepgram API key + +## License + +Apache License, Version 2.0 diff --git a/ai_agents/agents/ten_packages/extension/deepgram_tts/__init__.py b/ai_agents/agents/ten_packages/extension/deepgram_tts/__init__.py new file mode 100644 index 0000000000..72593ab225 --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/deepgram_tts/__init__.py @@ -0,0 +1,6 @@ +# +# This file is part of TEN Framework, an open source project. +# Licensed under the Apache License, Version 2.0. +# See the LICENSE file for more information. +# +from . 
#
# This file is part of TEN Framework, an open source project.
# Licensed under the Apache License, Version 2.0.
# See the LICENSE file for more information.
#
from ten_runtime import (
    Addon,
    register_addon_as_extension,
    TenEnv,
)


@register_addon_as_extension("deepgram_tts")
class DeepgramTTSExtensionAddon(Addon):
    """Addon entry point that creates Deepgram TTS extension instances."""

    def on_create_instance(self, ten_env: TenEnv, name: str, context) -> None:
        """Instantiate the extension and hand it back to the runtime."""
        # Lazy import: the extension module (and its heavy dependencies)
        # is only loaded when an instance is actually requested.
        from .extension import DeepgramTTSExtension

        ten_env.log_info("DeepgramTTSExtensionAddon on_create_instance")
        instance = DeepgramTTSExtension(name)
        ten_env.on_create_instance_done(instance, context)
+# +from __future__ import annotations + +from typing import Any +import copy + +from ten_ai_base import utils + +from pydantic import BaseModel, Field + + +class DeepgramTTSConfig(BaseModel): + api_key: str = "" + base_url: str = "wss://api.deepgram.com/v1/speak" + + model: str = "aura-2-thalia-en" + encoding: str = "linear16" + sample_rate: int = 24000 + + dump: bool = False + dump_path: str = "/tmp" + params: dict[str, Any] = Field(default_factory=dict) + + def update_params(self) -> None: + params = self._ensure_dict(self.params) + self.params = params + + if "api_key" in params: + self.api_key = params["api_key"] + del params["api_key"] + + if "base_url" in params: + self.base_url = params["base_url"] + del params["base_url"] + + if "model" in params: + self.model = params["model"] + del params["model"] + + if "encoding" in params: + self.encoding = params["encoding"] + del params["encoding"] + + if "sample_rate" in params: + self.sample_rate = params["sample_rate"] + del params["sample_rate"] + + def to_str(self, sensitive_handling: bool = True) -> str: + """ + Convert the configuration to a string representation. + """ + if not sensitive_handling: + return f"{self}" + + config = copy.deepcopy(self) + + if config.api_key: + config.api_key = utils.encrypt(config.api_key) + + return f"{config}" + + @staticmethod + def _ensure_dict(value: Any) -> dict[str, Any]: + if isinstance(value, dict): + return value + return {} diff --git a/ai_agents/agents/ten_packages/extension/deepgram_tts/deepgram_tts.py b/ai_agents/agents/ten_packages/extension/deepgram_tts/deepgram_tts.py new file mode 100644 index 0000000000..57c69132f3 --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/deepgram_tts/deepgram_tts.py @@ -0,0 +1,294 @@ +# +# This file is part of TEN Framework, an open source project. +# Licensed under the Apache License, Version 2.0. +# See the LICENSE file for more information. 
#
# This file is part of TEN Framework, an open source project.
# Licensed under the Apache License, Version 2.0.
# See the LICENSE file for more information.
#
import asyncio
import json
from datetime import datetime
from typing import AsyncIterator
from urllib.parse import urlencode

import websockets
from websockets.asyncio.client import ClientConnection
from websockets.exceptions import InvalidStatus

from .config import DeepgramTTSConfig
from ten_runtime import AsyncTenEnv
from ten_ai_base.const import LOG_CATEGORY_VENDOR

# Event types communicated back to the extension.
# 4 is reserved (used by other TTS extensions for flush events).
EVENT_TTS_RESPONSE = 1
EVENT_TTS_END = 2
EVENT_TTS_ERROR = 3
EVENT_TTS_TTFB_METRIC = 5

# Seconds to wait for a WebSocket response before timeout
WS_RECV_TIMEOUT = 8.0


class DeepgramTTSConnectionException(Exception):
    """Exception raised when Deepgram TTS connection fails"""

    def __init__(self, status_code: int, body: str):
        # Kept as attributes so the extension can map 401 -> FATAL_ERROR.
        self.status_code = status_code
        self.body = body
        super().__init__(
            f"Deepgram TTS connection failed " f"(code: {status_code}): {body}"
        )


class DeepgramTTSClient:
    """WebSocket client for Deepgram TTS.

    Each get() call sends Speak+Flush and streams audio
    until Flushed. Connection is reused across calls but
    reconnected when needed (cancel, error, new request).
    """

    def __init__(
        self,
        config: DeepgramTTSConfig,
        ten_env: AsyncTenEnv,
    ):
        """Store config/env and precompute the websocket URL.

        No network I/O happens here; call start() to connect.
        """
        self.config = config
        self.ten_env = ten_env

        # Live websocket connection, or None when disconnected.
        self._ws: ClientConnection | None = None
        # Set by cancel()/stop(); checked by the get() streaming loop.
        self._is_cancelled = False
        # Set after errors/failed drains; forces a reconnect on next get().
        self._needs_reconnect = False

        # TTFB tracking
        self._sent_ts: datetime | None = None
        self._ttfb_sent: bool = False

        self._ws_url = self._build_ws_url()

    def _build_ws_url(self) -> str:
        """Build the Deepgram /v1/speak websocket URL with query params."""
        base = self.config.base_url
        query_params: dict[str, str | int | float | bool] = {
            "model": self.config.model,
            "encoding": self.config.encoding,
            "sample_rate": self.config.sample_rate,
        }

        # Forward any additional Deepgram vendor params through the websocket
        # query string while keeping auth and endpoint configuration out of it.
        for key, value in self.config.params.items():
            if key in {"api_key", "base_url"} or value is None:
                continue
            query_params[key] = value

        return f"{base}?{urlencode(query_params, doseq=True)}"

    async def start(self) -> None:
        """Preheat: establish initial connection."""
        # Failure here is logged but not raised: get() will retry via
        # _ensure_connection() on the first real request.
        try:
            await self._connect()
        except Exception as e:
            self.ten_env.log_error(f"Deepgram TTS preheat failed: {e}")

    async def stop(self) -> None:
        """Politely close the websocket and drop the connection.

        Best-effort: send/close failures are swallowed because the
        connection may already be dead.
        """
        self._is_cancelled = True
        if self._ws:
            try:
                await self._ws.send(json.dumps({"type": "Close"}))
            except Exception:
                pass
            try:
                await self._ws.close()
            except Exception:
                pass
            self._ws = None

    async def cancel(self) -> None:
        """Cancel current TTS.

        Sends Flush and drains until Flushed so the
        connection is clean for the next request.
        """
        self.ten_env.log_debug("Cancelling current TTS task.")
        self._is_cancelled = True
        self.reset_ttfb()
        if self._ws:
            try:
                await self._ws.send(json.dumps({"type": "Flush"}))
                # Drain until Flushed to leave connection clean
                await asyncio.wait_for(self._drain_until_flushed(), timeout=3.0)
            except Exception as e:
                # Drain failed: connection state unknown, so reconnect lazily.
                self.ten_env.log_warn(
                    f"Cancel drain failed: {e}, "
                    "will reconnect on next request"
                )
                self._needs_reconnect = True

    async def _drain_until_flushed(self) -> None:
        """Read and discard WS messages until Flushed."""
        # Binary frames (audio) are dropped implicitly; only text frames
        # are parsed looking for the Flushed control message.
        while self._ws:
            msg = await self._ws.recv()
            if isinstance(msg, str):
                try:
                    data = json.loads(msg)
                    if data.get("type") == "Flushed":
                        return
                except json.JSONDecodeError:
                    pass

    def reset_ttfb(self) -> None:
        """Reset TTFB tracking for a fresh request."""
        self._sent_ts = None
        self._ttfb_sent = False

    async def get(
        self, text: str
    ) -> AsyncIterator[tuple[bytes | int | None, int]]:
        """Send text and yield audio events.

        Yields (payload, event) tuples where event is one of the
        EVENT_TTS_* constants: bytes for RESPONSE, int milliseconds for
        TTFB_METRIC, utf-8 error bytes for ERROR, None for END.
        """
        if len(text.strip()) == 0:
            self.ten_env.log_warn("DeepgramTTS: empty text, returning END")
            yield None, EVENT_TTS_END
            return

        # Reconnect if needed (after error or cancel)
        if self._needs_reconnect:
            await self._reconnect()
            self._needs_reconnect = False

        await self._ensure_connection()

        # Start TTFB timing only until the first audio chunk of the
        # request has been observed (_ttfb_sent flips in the loop below).
        if not self._ttfb_sent:
            self._sent_ts = datetime.now()

        # Clear cancel flag just before sending, not at
        # method entry — avoids race with concurrent cancel()
        self._is_cancelled = False

        # Send Speak + Flush
        speak_msg = {"type": "Speak", "text": text}
        await self._ws.send(json.dumps(speak_msg))
        await self._ws.send(json.dumps({"type": "Flush"}))

        # Receive audio until Flushed
        try:
            while True:
                if self._is_cancelled:
                    self.ten_env.log_debug("Cancelled, stopping stream.")
                    break

                try:
                    message = await asyncio.wait_for(
                        self._ws.recv(), timeout=WS_RECV_TIMEOUT
                    )
                except asyncio.TimeoutError:
                    self.ten_env.log_error("Timeout waiting for Deepgram audio")
                    self._needs_reconnect = True
                    yield (
                        b"Timeout waiting for Deepgram audio",
                        EVENT_TTS_ERROR,
                    )
                    break

                if isinstance(message, bytes):
                    # Re-check cancellation after the await: cancel() may
                    # have run while we were blocked in recv().
                    if self._is_cancelled:
                        self.ten_env.log_debug("Dropping audio (cancelled)")
                        break

                    # TTFB on first audio chunk
                    if self._sent_ts and not self._ttfb_sent:
                        ttfb_ms = int(
                            (datetime.now() - self._sent_ts).total_seconds()
                            * 1000
                        )
                        yield ttfb_ms, EVENT_TTS_TTFB_METRIC
                        self._ttfb_sent = True

                    self.ten_env.log_debug(
                        f"DeepgramTTS: audio chunk, " f"length: {len(message)}"
                    )
                    yield message, EVENT_TTS_RESPONSE
                else:
                    # Text frame: Deepgram control/status message.
                    try:
                        data = json.loads(message)
                        msg_type = data.get("type", "")

                        if msg_type == "Flushed":
                            self.ten_env.log_debug("DeepgramTTS: Flushed")
                            yield None, EVENT_TTS_END
                            break

                        elif msg_type == "Warning":
                            self.ten_env.log_warn(
                                f"Deepgram warning: "
                                f"{data.get('warn_msg', '')}"
                            )

                        elif msg_type == "Error":
                            error_msg = data.get("err_msg", "Unknown error")
                            self.ten_env.log_error(
                                f"Deepgram error: {error_msg}"
                            )
                            self._needs_reconnect = True
                            yield (
                                error_msg.encode("utf-8"),
                                EVENT_TTS_ERROR,
                            )
                            break

                    except json.JSONDecodeError:
                        self.ten_env.log_warn(f"Failed to parse: {message}")

            if not self._is_cancelled:
                self.ten_env.log_debug("DeepgramTTS: complete")

        except Exception as e:
            # Unexpected transport failure: surface as vendor error and
            # force a fresh connection for the next request.
            self.ten_env.log_error(
                f"vendor_error: {e}",
                category=LOG_CATEGORY_VENDOR,
            )
            self._needs_reconnect = True
            yield (
                str(e).encode("utf-8"),
                EVENT_TTS_ERROR,
            )

    async def _connect(self) -> None:
        """Open the websocket, authenticating via the Authorization header.

        Raises DeepgramTTSConnectionException on rejected handshakes so
        the extension can classify the failure by HTTP status code.
        """
        try:
            extra_headers = {
                "Authorization": f"Token {self.config.api_key}",
            }
            self._ws = await websockets.connect(
                self._ws_url,
                additional_headers=extra_headers,
            )
            self.ten_env.log_debug(
                "vendor_status: connected to deepgram tts",
                category=LOG_CATEGORY_VENDOR,
            )
        except InvalidStatus as e:
            raise DeepgramTTSConnectionException(
                status_code=e.response.status_code,
                body=str(e),
            ) from e
        except Exception as e:
            error_message = str(e)
            # Fallback string match for non-websockets
            # exceptions (e.g., mocked tests)
            if "401" in error_message or "Unauthorized" in error_message:
                raise DeepgramTTSConnectionException(
                    status_code=401, body=error_message
                ) from e
            self.ten_env.log_error(f"Deepgram TTS connection failed: {e}")
            raise

    async def _ensure_connection(self) -> None:
        """Connect if no websocket is currently open."""
        if not self._ws:
            await self._connect()

    async def _reconnect(self) -> None:
        """Close and re-establish the connection."""
        if self._ws:
            try:
                await self._ws.close()
            except Exception:
                pass
            self._ws = None
        await self._connect()
#
# This file is part of TEN Framework, an open source project.
# Licensed under the Apache License, Version 2.0.
# See the LICENSE file for more information.
#
from datetime import datetime
import os
import traceback

from ten_ai_base.helper import PCMWriter
from ten_ai_base.message import (
    ModuleError,
    ModuleErrorCode,
    ModuleType,
    ModuleErrorVendorInfo,
    TTSAudioEndReason,
)
from ten_ai_base.struct import TTSTextInput
from ten_ai_base.tts2 import AsyncTTS2BaseExtension
from ten_ai_base.const import LOG_CATEGORY_VENDOR, LOG_CATEGORY_KEY_POINT
from .config import DeepgramTTSConfig

from .deepgram_tts import (
    EVENT_TTS_END,
    EVENT_TTS_RESPONSE,
    EVENT_TTS_TTFB_METRIC,
    EVENT_TTS_ERROR,
    DeepgramTTSClient,
    DeepgramTTSConnectionException,
)
from ten_runtime import AsyncTenEnv


class DeepgramTTSExtension(AsyncTTS2BaseExtension):
    """TTS2 extension that streams synthesis through DeepgramTTSClient.

    Tracks one in-flight request at a time (current_request_id) and maps
    the client's EVENT_TTS_* stream onto the TTS2 base-class callbacks
    (audio start/data/end, TTFB metrics, errors).
    """

    def __init__(self, name: str) -> None:
        super().__init__(name)
        # Parsed configuration; set in on_init.
        self.config: DeepgramTTSConfig | None = None
        # Websocket client; recreated on fatal connection errors.
        self.client: DeepgramTTSClient | None = None
        # Identity of the request currently being synthesized.
        self.current_request_id: str | None = None
        self.current_turn_id: int = -1
        # Timestamp used for request_event_interval_ms reporting.
        self.sent_ts: datetime | None = None
        # True once text_input_end has been seen for the current request.
        self.current_request_finished: bool = False
        # Running PCM byte count for duration calculation.
        self.total_audio_bytes: int = 0
        self._is_stopped: bool = False
        # request_id -> PCMWriter for debug dumps (config.dump only).
        self.recorder_map: dict[str, PCMWriter] = {}
        # Guards against sending audio_end without a preceding audio_start.
        self._audio_start_sent: bool = False

    async def on_init(self, ten_env: AsyncTenEnv) -> None:
        """Load config, validate the API key, and preheat the client.

        Any failure is reported as a FATAL_ERROR tts_error rather than
        raised, per the TTS2 extension contract.
        """
        try:
            await super().on_init(ten_env)
            config_json_str, _ = await self.ten_env.get_property_to_json("")

            if not config_json_str or config_json_str.strip() == "{}":
                raise ValueError(
                    "Configuration is empty. "
                    "Required parameter 'api_key' is missing."
                )

            self.config = DeepgramTTSConfig.model_validate_json(config_json_str)
            self.config.update_params()
            # Log with sensitive handling so the API key is masked.
            ten_env.log_info(
                self.config.to_str(sensitive_handling=True),
                category=LOG_CATEGORY_KEY_POINT,
            )

            if not self.config.api_key:
                raise ValueError("API key is required")

            self.client = self._create_client(ten_env)
            await self.client.start()
            ten_env.log_debug("DeepgramTTS client initialized successfully")
        except Exception as e:
            ten_env.log_error(f"on_init failed: {traceback.format_exc()}")
            await self.send_tts_error(
                request_id="",
                error=ModuleError(
                    message=f"Initialization failed: {e}",
                    module=ModuleType.TTS,
                    code=ModuleErrorCode.FATAL_ERROR,
                    vendor_info=ModuleErrorVendorInfo(vendor=self.vendor()),
                ),
            )

    async def on_stop(self, ten_env: AsyncTenEnv) -> None:
        """Shut down the client and flush all pending dump writers."""
        self._is_stopped = True
        ten_env.log_debug("Extension stopping, rejecting new requests")

        if self.client:
            await self.client.stop()
            self.client = None

        # Flush every recorder; iterate a copy since flush may be slow.
        for request_id, recorder in list(self.recorder_map.items()):
            try:
                await recorder.flush()
                ten_env.log_debug(
                    f"Flushed PCMWriter for request_id: " f"{request_id}"
                )
            except Exception as e:
                ten_env.log_error(
                    f"Error flushing PCMWriter for "
                    f"request_id {request_id}: {e}"
                )

        await super().on_stop(ten_env)
        ten_env.log_debug("on_stop")

    async def on_deinit(self, ten_env: AsyncTenEnv) -> None:
        await super().on_deinit(ten_env)
        ten_env.log_debug("on_deinit")

    async def cancel_tts(self) -> None:
        """Interrupt the in-flight request (e.g. on user barge-in)."""
        # Mark finished first so late chunks for this request are rejected.
        self.current_request_finished = True
        if self.current_request_id:
            self.ten_env.log_debug(
                f"Cancelling request {self.current_request_id}"
            )
            if self.client:
                await self.client.cancel()
            await self._finalize_request(TTSAudioEndReason.INTERRUPTED)
        else:
            self.ten_env.log_warn("No current request, skipping cancel.")

    def vendor(self) -> str:
        """Vendor tag used in error reports and metrics."""
        return "deepgram"

    def synthesize_audio_sample_rate(self) -> int:
        """Output sample rate in Hz; falls back to 24000 pre-config."""
        if self.config is None:
            return 24000
        return self.config.sample_rate

    def _create_client(self, ten_env: AsyncTenEnv) -> DeepgramTTSClient:
        """Build a client bound to the current config."""
        return DeepgramTTSClient(
            config=self.config,
            ten_env=ten_env,
        )

    async def _ensure_client(self) -> None:
        """Ensure client is connected, reconnecting if needed."""
        if self.client is None:
            self.ten_env.log_debug(
                "TTS client is not initialized, reconnecting..."
            )
            self.client = self._create_client(self.ten_env)
            await self.client.start()
            self.ten_env.log_debug("TTS client reconnected successfully.")

    async def _reconnect_client(self) -> None:
        """Destroy current client and reconnect immediately."""
        if self.client:
            await self.client.stop()
            self.client = None
        try:
            self.client = self._create_client(self.ten_env)
            await self.client.start()
            self.ten_env.log_debug("Client reconnected after error.")
        except Exception as e:
            # Leave client as None; _ensure_client will retry on the next
            # incoming request.
            self.ten_env.log_error(f"Immediate reconnect failed: {e}")
            self.client = None

    async def _finalize_request(
        self,
        reason: TTSAudioEndReason,
        error: ModuleError | None = None,
    ) -> None:
        """Send audio end, flush recorder, finish request."""
        # Base-class contract expects audio_start before audio_end, even
        # when no audio was produced (e.g. immediate error/interrupt).
        if not self._audio_start_sent:
            await self.send_tts_audio_start(
                request_id=self.current_request_id,
            )
            self._audio_start_sent = True

        request_event_interval = self._current_request_interval_ms()
        duration_ms = self._calculate_audio_duration_ms()

        await self.send_tts_audio_end(
            request_id=self.current_request_id,
            request_event_interval_ms=request_event_interval,
            request_total_audio_duration_ms=duration_ms,
            reason=reason,
        )

        if self.current_request_id in self.recorder_map:
            await self.recorder_map[self.current_request_id].flush()

        await self.finish_request(
            request_id=self.current_request_id,
            reason=reason,
            error=error,
        )

        # Reset interval timer so a stale timestamp never leaks into the
        # next request's metrics.
        self.sent_ts = None
        self.ten_env.log_debug(
            f"Finalized request, reason: {reason}, "
            f"interval: {request_event_interval}ms, "
            f"duration: {duration_ms}ms"
        )

    async def request_tts(self, t: TTSTextInput) -> None:
        """Handle TTS requests."""
        try:
            self.ten_env.log_info(
                f"Requesting TTS for text: {t.text}, "
                f"text_input_end: {t.text_input_end} "
                f"request ID: {t.request_id}",
            )

            await self._ensure_client()

            if t.request_id != self.current_request_id:
                # New request: reset all per-request bookkeeping.
                self.ten_env.log_debug(
                    f"New TTS request with ID: {t.request_id}"
                )
                if self.client:
                    self.client.reset_ttfb()
                self.current_request_id = t.request_id
                self.current_request_finished = False
                self.total_audio_bytes = 0
                self.sent_ts = None
                self._audio_start_sent = False
                if t.metadata is not None:
                    self.session_id = t.metadata.get("session_id", "")
                    self.current_turn_id = t.metadata.get("turn_id", -1)
                await self._setup_recorder(t.request_id)
            elif self.current_request_finished:
                # Chunk arrived after the request was already closed.
                self.ten_env.log_error(
                    f"Received a message for a finished "
                    f"request_id '{t.request_id}' with "
                    f"text_input_end=False."
                )
                return

            if t.text_input_end:
                self.ten_env.log_debug(
                    f"KEYPOINT finish session for "
                    f"request ID: {t.request_id}"
                )
                self.current_request_finished = True

            prepared_text = t.text.strip()

            if self._is_stopped:
                self.ten_env.log_debug(
                    f"TTS is stopped, skipping " f"request_id: {t.request_id}"
                )
                return

            if prepared_text != "":
                await self._process_tts_text(prepared_text, t)
            elif t.text_input_end:
                # Empty final chunk: close the request without synthesis.
                await self._finalize_request(TTSAudioEndReason.REQUEST_END)

        except DeepgramTTSConnectionException as e:
            await self._handle_connection_error(e)

        except Exception as e:
            self.ten_env.log_error(
                f"Error in request_tts: "
                f"{traceback.format_exc()}. text: {t.text}"
            )
            error = ModuleError(
                message=str(e),
                module=ModuleType.TTS,
                code=ModuleErrorCode.NON_FATAL_ERROR,
                vendor_info=ModuleErrorVendorInfo(vendor=self.vendor()),
            )
            await self._finalize_request(TTSAudioEndReason.ERROR, error=error)
            if isinstance(e, ConnectionRefusedError):
                await self._reconnect_client()

    async def _process_tts_text(self, text: str, t: TTSTextInput) -> None:
        """Process non-empty text through the TTS pipeline."""
        self.ten_env.log_debug(
            f"send_text_to_tts_server: {text} "
            f"of request_id: {t.request_id}",
            category=LOG_CATEGORY_VENDOR,
        )
        data = self.client.get(text)

        chunk_count = 0
        if self.sent_ts is None:
            self.sent_ts = datetime.now()

        async for data_msg, event_status in data:
            self.ten_env.log_debug(f"Received event_status: {event_status}")
            if event_status == EVENT_TTS_RESPONSE:
                if (
                    data_msg is not None
                    and isinstance(data_msg, bytes)
                    and len(data_msg) > 0
                ):
                    chunk_count += 1
                    self.total_audio_bytes += len(data_msg)
                    self.ten_env.log_info(
                        f"Received audio chunk "
                        f"#{chunk_count}, "
                        f"size: {len(data_msg)} bytes"
                    )
                    await self._write_dump(data_msg)
                    await self.send_tts_audio_data(data_msg)
                else:
                    self.ten_env.log_debug("Empty payload, ignoring")

            elif event_status == EVENT_TTS_TTFB_METRIC:
                if data_msg is not None and isinstance(data_msg, int):
                    # Overwrite sent_ts to audio-start time so that
                    # _current_request_interval_ms() measures streaming
                    # duration (first audio → last audio), not total
                    # request time. This matches the HTTP base class.
                    self.sent_ts = datetime.now()
                    ttfb = data_msg
                    await self.send_tts_audio_start(
                        request_id=self.current_request_id,
                    )
                    self._audio_start_sent = True
                    await self.send_tts_ttfb_metrics(
                        request_id=self.current_request_id,
                        ttfb_ms=ttfb,
                        extra_metadata={
                            "model": self.config.model,
                        },
                    )
                    self.ten_env.log_debug(
                        f"Sent TTS audio start and " f"TTFB metrics: {ttfb}ms"
                    )

            elif event_status == EVENT_TTS_END:
                if t.text_input_end:
                    self.ten_env.log_info(
                        f"Received final TTS_END event from Deepgram TTS "
                        f"for request_id: {t.request_id}"
                    )
                    await self._finalize_request(TTSAudioEndReason.REQUEST_END)
                else:
                    # Intermediate chunk completed; more chunks expected
                    # for this request, so do not finalize yet.
                    self.ten_env.log_debug(
                        f"Received intermediate TTS_END event from "
                        f"Deepgram TTS for request_id: {t.request_id}"
                    )
                break

            elif event_status == EVENT_TTS_ERROR:
                error_msg = (
                    data_msg.decode("utf-8")
                    if isinstance(data_msg, bytes)
                    else str(data_msg)
                )
                self.ten_env.log_error(f"TTS_ERROR from Deepgram: {error_msg}")
                error = ModuleError(
                    message=error_msg,
                    module=ModuleType.TTS,
                    code=ModuleErrorCode.NON_FATAL_ERROR,
                    vendor_info=ModuleErrorVendorInfo(vendor=self.vendor()),
                )
                if t.text_input_end:
                    # Final chunk: surface error and
                    # finalize the request
                    await self._finalize_request(
                        TTSAudioEndReason.ERROR,
                        error=error,
                    )
                else:
                    # Non-final chunk: log only. The base
                    # class will send subsequent chunks for
                    # this request_id; errors on partial
                    # streaming are transient.
                    self.ten_env.log_warn(
                        f"Transient TTS error on non-final "
                        f"chunk for {t.request_id}: "
                        f"{error_msg}"
                    )
                break

        self.ten_env.log_debug(
            f"TTS processing completed, " f"total chunks: {chunk_count}"
        )

    async def _handle_connection_error(
        self, e: DeepgramTTSConnectionException
    ) -> None:
        """Handle Deepgram connection errors.

        Sends exactly one error event via _finalize_request.
        """
        self.ten_env.log_error(f"DeepgramTTSConnectionException: {e.body}")
        # 401 means bad credentials: unrecoverable without a config change.
        if e.status_code == 401:
            code = ModuleErrorCode.FATAL_ERROR
        else:
            code = ModuleErrorCode.NON_FATAL_ERROR

        error = ModuleError(
            message=e.body,
            module=ModuleType.TTS,
            code=code,
            vendor_info=ModuleErrorVendorInfo(
                vendor=self.vendor(),
                code=str(e.status_code),
                message=e.body,
            ),
        )
        await self._finalize_request(TTSAudioEndReason.ERROR, error=error)

    async def _setup_recorder(self, request_id: str) -> None:
        """Set up PCMWriter for a new request."""
        if not (self.config and self.config.dump):
            return
        # Clean up old PCMWriters
        for old_rid in [
            rid for rid in self.recorder_map.keys() if rid != request_id
        ]:
            try:
                await self.recorder_map[old_rid].flush()
                del self.recorder_map[old_rid]
                self.ten_env.log_debug(
                    f"Cleaned up old PCMWriter for " f"request_id: {old_rid}"
                )
            except Exception as e:
                self.ten_env.log_error(
                    f"Error cleaning up PCMWriter for "
                    f"request_id {old_rid}: {e}"
                )

        if request_id not in self.recorder_map:
            dump_file_path = os.path.join(
                self.config.dump_path,
                f"deepgram_dump_{request_id}.pcm",
            )
            self.recorder_map[request_id] = PCMWriter(dump_file_path)
            self.ten_env.log_debug(
                f"Created PCMWriter for request_id: "
                f"{request_id}, file: {dump_file_path}"
            )

    async def _write_dump(self, data: bytes) -> None:
        """Write audio data to dump file if enabled."""
        if (
            self.config
            and self.config.dump
            and self.current_request_id
            and self.current_request_id in self.recorder_map
        ):
            try:
                await self.recorder_map[self.current_request_id].write(data)
            except Exception as e:
                # Dump is a debug aid; never fail the request over it.
                self.ten_env.log_error(f"Dump write failed: {e}")

    def _current_request_interval_ms(self) -> int:
        """Milliseconds since sent_ts (streaming interval); 0 if unset."""
        if not self.sent_ts:
            return 0
        return int((datetime.now() - self.sent_ts).total_seconds() * 1000)

    def _calculate_audio_duration_ms(self) -> int:
        """Derive audio duration from byte count for 16-bit mono PCM."""
        if self.config is None:
            return 0
        bytes_per_sample = 2  # 16-bit PCM
        channels = 1  # Mono
        duration_sec = self.total_audio_bytes / (
            self.synthesize_audio_sample_rate() * bytes_per_sample * channels
        )
        return int(duration_sec * 1000)
b/ai_agents/agents/ten_packages/extension/deepgram_tts/requirements.txt new file mode 100644 index 0000000000..31b5e2f348 --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/deepgram_tts/requirements.txt @@ -0,0 +1 @@ +websockets>=12.0 diff --git a/ai_agents/agents/ten_packages/extension/deepgram_tts/tests/__init__.py b/ai_agents/agents/ten_packages/extension/deepgram_tts/tests/__init__.py new file mode 100644 index 0000000000..da402faf43 --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/deepgram_tts/tests/__init__.py @@ -0,0 +1,5 @@ +# +# This file is part of TEN Framework, an open source project. +# Licensed under the Apache License, Version 2.0. +# See the LICENSE file for more information. +# diff --git a/ai_agents/agents/ten_packages/extension/deepgram_tts/tests/bin/start b/ai_agents/agents/ten_packages/extension/deepgram_tts/tests/bin/start new file mode 100755 index 0000000000..41da3fdb45 --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/deepgram_tts/tests/bin/start @@ -0,0 +1,21 @@ +#!/bin/bash + +set -e + +cd "$(dirname "${BASH_SOURCE[0]}")/../.." + +export PYTHONPATH=.ten/app:.ten/app/ten_packages/system/ten_runtime_python/lib:.ten/app/ten_packages/system/ten_runtime_python/interface:.ten/app/ten_packages/system/ten_ai_base/interface:$PYTHONPATH + +# If the Python app imports some modules that are compiled with a different +# version of libstdc++ (ex: PyTorch), the Python app may encounter confusing +# errors. To solve this problem, we can preload the correct version of +# libstdc++. +# +# export LD_PRELOAD=/lib/x86_64-linux-gnu/libstdc++.so.6 +# +# Another solution is to make sure the module 'ten_runtime_python' is imported +# _after_ the module that requires another version of libstdc++ is imported. 
+# +# Refer to https://github.com/pytorch/pytorch/issues/102360?from_wecom=1#issuecomment-1708989096 + +pytest tests/ "$@" diff --git a/ai_agents/agents/ten_packages/extension/deepgram_tts/tests/configs/property_basic_audio_setting1.json b/ai_agents/agents/ten_packages/extension/deepgram_tts/tests/configs/property_basic_audio_setting1.json new file mode 100644 index 0000000000..ff0a081e87 --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/deepgram_tts/tests/configs/property_basic_audio_setting1.json @@ -0,0 +1,10 @@ +{ + "dump": true, + "dump_path": "./tests/keep_dump_output/", + "params": { + "api_key": "${env:DEEPGRAM_API_KEY}", + "model": "aura-2-thalia-en", + "encoding": "linear16", + "sample_rate": 24000 + } +} diff --git a/ai_agents/agents/ten_packages/extension/deepgram_tts/tests/configs/property_basic_audio_setting2.json b/ai_agents/agents/ten_packages/extension/deepgram_tts/tests/configs/property_basic_audio_setting2.json new file mode 100644 index 0000000000..c753384856 --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/deepgram_tts/tests/configs/property_basic_audio_setting2.json @@ -0,0 +1,10 @@ +{ + "dump": true, + "dump_path": "./tests/keep_dump_output/", + "params": { + "api_key": "${env:DEEPGRAM_API_KEY}", + "model": "aura-2-luna-en", + "encoding": "linear16", + "sample_rate": 16000 + } +} diff --git a/ai_agents/agents/ten_packages/extension/deepgram_tts/tests/configs/property_dump.json b/ai_agents/agents/ten_packages/extension/deepgram_tts/tests/configs/property_dump.json new file mode 100644 index 0000000000..4690fecb76 --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/deepgram_tts/tests/configs/property_dump.json @@ -0,0 +1,10 @@ +{ + "dump": true, + "dump_path": "./dump/", + "params": { + "api_key": "${env:DEEPGRAM_API_KEY}", + "model": "aura-2-thalia-en", + "encoding": "linear16", + "sample_rate": 24000 + } +} diff --git a/ai_agents/agents/ten_packages/extension/deepgram_tts/tests/configs/property_invalid.json 
b/ai_agents/agents/ten_packages/extension/deepgram_tts/tests/configs/property_invalid.json new file mode 100644 index 0000000000..6233cf106a --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/deepgram_tts/tests/configs/property_invalid.json @@ -0,0 +1,5 @@ +{ + "params": { + "api_key": "invalid" + } +} diff --git a/ai_agents/agents/ten_packages/extension/deepgram_tts/tests/configs/property_miss_required.json b/ai_agents/agents/ten_packages/extension/deepgram_tts/tests/configs/property_miss_required.json new file mode 100644 index 0000000000..df133e721a --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/deepgram_tts/tests/configs/property_miss_required.json @@ -0,0 +1,5 @@ +{ + "params": { + "api_key": "" + } +} diff --git a/ai_agents/agents/ten_packages/extension/deepgram_tts/tests/conftest.py b/ai_agents/agents/ten_packages/extension/deepgram_tts/tests/conftest.py new file mode 100644 index 0000000000..958647c64d --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/deepgram_tts/tests/conftest.py @@ -0,0 +1,107 @@ +import sys +from pathlib import Path + +# Add project root to sys.path for test imports +project_root = str(Path(__file__).resolve().parents[6]) +if project_root not in sys.path: + sys.path.insert(0, project_root) + +# +# This file is part of TEN Framework, an open source project. +# Licensed under the Apache License, Version 2.0. +# See the LICENSE file for more information. +# +import json +import threading +from typing_extensions import override +import pytest +from ten_runtime import ( + App, + TenEnv, +) + + +class FakeApp(App): + def __init__(self): + super().__init__() + self.event: threading.Event | None = None + + # In the case of a fake app, we use `on_init` to allow the blocked testing + # fixture to continue execution, rather than using `on_configure`. The + # reason is that in the TEN runtime C core, the relationship between the + # addon manager and the (fake) app is bound after `on_configure_done` is + # called. 
So we only need to let the testing fixture continue execution + # after this action in the TEN runtime C core, and at the upper layer + # timing, the earliest point is within the `on_init()` function of the upper + # TEN app. Therefore, we release the testing fixture lock within the user + # layer's `on_init()` of the TEN app. + @override + def on_init(self, ten_env: TenEnv) -> None: + assert self.event + self.event.set() + + ten_env.on_init_done() + + @override + def on_configure(self, ten_env: TenEnv) -> None: + ten_env.init_property_from_json( + json.dumps( + { + "ten": { + "log": { + "handlers": [ + { + "matchers": [{"level": "debug"}], + "formatter": { + "type": "plain", + "colored": True, + }, + "emitter": { + "type": "console", + "config": {"stream": "stdout"}, + }, + } + ] + } + } + } + ), + ) + + ten_env.on_configure_done() + + +class FakeAppCtx: + def __init__(self, event: threading.Event): + self.fake_app: FakeApp | None = None + self.event = event + + +def run_fake_app(fake_app_ctx: FakeAppCtx): + app = FakeApp() + app.event = fake_app_ctx.event + fake_app_ctx.fake_app = app + app.run(False) + + +@pytest.fixture(scope="session", autouse=True) +def global_setup_and_teardown(): + event = threading.Event() + fake_app_ctx = FakeAppCtx(event) + + fake_app_thread = threading.Thread( + target=run_fake_app, args=(fake_app_ctx,) + ) + fake_app_thread.start() + + event.wait() + + assert fake_app_ctx.fake_app is not None + + # Yield control to the test; after the test execution is complete, continue + # with the teardown process. + yield + + # Teardown part. 
+ fake_app_ctx.fake_app.close() + fake_app_thread.join() diff --git a/ai_agents/agents/ten_packages/extension/deepgram_tts/tests/test_basic.py b/ai_agents/agents/ten_packages/extension/deepgram_tts/tests/test_basic.py new file mode 100644 index 0000000000..2f001d17f3 --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/deepgram_tts/tests/test_basic.py @@ -0,0 +1,314 @@ +# +# This file is part of TEN Framework, an open source project. +# Licensed under the Apache License, Version 2.0. +# See the LICENSE file for more information. +# +import json +from unittest.mock import patch, AsyncMock +import os +import asyncio +import filecmp +import shutil + +from ten_runtime import ( + ExtensionTester, + TenEnvTester, + Data, +) +from ten_ai_base.struct import TTSTextInput, TTSFlush +from deepgram_tts.deepgram_tts import ( + EVENT_TTS_RESPONSE, + EVENT_TTS_END, + EVENT_TTS_TTFB_METRIC, +) + + +# ================ test dump file functionality ================ +class ExtensionTesterDump(ExtensionTester): + def __init__(self): + super().__init__() + self.dump_dir = "./dump/" + self.test_dump_file_path = os.path.join( + self.dump_dir, "test_manual_dump.pcm" + ) + self.audio_end_received = False + self.received_audio_chunks = [] + + def on_start(self, ten_env_tester: TenEnvTester) -> None: + ten_env_tester.log_info("Dump test started, sending TTS request.") + + tts_input = TTSTextInput( + request_id="tts_request_1", + text="hello word, hello agora", + text_input_end=True, + ) + data = Data.create("tts_text_input") + data.set_property_from_json(None, tts_input.model_dump_json()) + ten_env_tester.send_data(data) + ten_env_tester.on_start_done() + + def on_data(self, ten_env: TenEnvTester, data) -> None: + name = data.get_name() + if name == "tts_audio_end": + ten_env.log_info("Received tts_audio_end, stopping test.") + self.audio_end_received = True + ten_env.stop_test() + + def on_audio_frame(self, ten_env: TenEnvTester, audio_frame): + buf = audio_frame.lock_buf() + try: + 
copied_data = bytes(buf) + self.received_audio_chunks.append(copied_data) + finally: + audio_frame.unlock_buf(buf) + + def write_test_dump_file(self): + with open(self.test_dump_file_path, "wb") as f: + for chunk in self.received_audio_chunks: + f.write(chunk) + + def find_tts_dump_file(self) -> str | None: + if not os.path.exists(self.dump_dir): + return None + for filename in os.listdir(self.dump_dir): + if filename.endswith(".pcm") and filename != os.path.basename( + self.test_dump_file_path + ): + return os.path.join(self.dump_dir, filename) + return None + + +@patch("deepgram_tts.extension.DeepgramTTSClient") +def test_dump_functionality(MockDeepgramTTSClient): + """Tests that the dump file from the TTS extension matches the audio received.""" + print("Starting test_dump_functionality with mock...") + + DUMP_PATH = "./dump/" + + if os.path.exists(DUMP_PATH): + shutil.rmtree(DUMP_PATH) + os.makedirs(DUMP_PATH) + + mock_instance = MockDeepgramTTSClient.return_value + mock_instance.start = AsyncMock() + mock_instance.stop = AsyncMock() + mock_instance.cancel = AsyncMock() + mock_instance.reset_ttfb = lambda: None + + fake_audio_chunk_1 = b"\x11\x22\x33\x44" * 20 + fake_audio_chunk_2 = b"\xaa\xbb\xcc\xdd" * 20 + + async def mock_get_audio_stream(text: str): + yield (255, EVENT_TTS_TTFB_METRIC) + yield (fake_audio_chunk_1, EVENT_TTS_RESPONSE) + await asyncio.sleep(0.01) + yield (fake_audio_chunk_2, EVENT_TTS_RESPONSE) + await asyncio.sleep(0.01) + yield (None, EVENT_TTS_END) + + mock_instance.get.side_effect = mock_get_audio_stream + + tester = ExtensionTesterDump() + + dump_config = { + "dump": True, + "dump_path": DUMP_PATH, + "params": { + "api_key": "test_api_key", + "model": "aura-2-thalia-en", + "encoding": "linear16", + "sample_rate": 24000, + }, + } + + tester.set_test_mode_single("deepgram_tts", json.dumps(dump_config)) + + print("Running dump test...") + tester.run() + print("Dump test completed.") + + assert tester.audio_end_received, "Expected to 
receive tts_audio_end" + assert ( + len(tester.received_audio_chunks) > 0 + ), "Expected to receive audio chunks" + + tester.write_test_dump_file() + + tts_dump_file = tester.find_tts_dump_file() + assert ( + tts_dump_file is not None + ), f"Expected to find a TTS dump file in {DUMP_PATH}" + assert os.path.exists( + tts_dump_file + ), f"TTS dump file should exist: {tts_dump_file}" + + print( + f"Comparing test file {tester.test_dump_file_path} with TTS dump file {tts_dump_file}" + ) + assert filecmp.cmp( + tester.test_dump_file_path, tts_dump_file, shallow=False + ), "Test dump file and TTS dump file should have the same content" + + print( + f"Dump test passed: received {len(tester.received_audio_chunks)} audio chunks" + ) + + if os.path.exists(DUMP_PATH): + shutil.rmtree(DUMP_PATH) + + +# ================ test basic audio output ================ +class ExtensionTesterBasic(ExtensionTester): + def __init__(self): + super().__init__() + self.audio_start_received = False + self.audio_end_received = False + self.audio_chunks_count = 0 + + def on_start(self, ten_env_tester: TenEnvTester) -> None: + ten_env_tester.log_info("Basic test started, sending TTS request.") + + tts_input = TTSTextInput( + request_id="tts_request_basic", + text="Hello, this is a test of the Deepgram TTS extension.", + text_input_end=True, + ) + data = Data.create("tts_text_input") + data.set_property_from_json(None, tts_input.model_dump_json()) + ten_env_tester.send_data(data) + ten_env_tester.on_start_done() + + def on_data(self, ten_env: TenEnvTester, data) -> None: + name = data.get_name() + if name == "tts_audio_start": + ten_env.log_info("Received tts_audio_start.") + self.audio_start_received = True + elif name == "tts_audio_end": + ten_env.log_info("Received tts_audio_end, stopping test.") + self.audio_end_received = True + ten_env.stop_test() + + def on_audio_frame(self, ten_env: TenEnvTester, audio_frame): + self.audio_chunks_count += 1 + + 
+@patch("deepgram_tts.extension.DeepgramTTSClient") +def test_basic_audio(MockDeepgramTTSClient): + """Test basic TTS audio generation.""" + mock_instance = MockDeepgramTTSClient.return_value + mock_instance.start = AsyncMock() + mock_instance.stop = AsyncMock() + mock_instance.cancel = AsyncMock() + mock_instance.reset_ttfb = lambda: None + + fake_audio_chunk = b"\x00\x01\x02\x03" * 100 + + async def mock_get_audio_stream(text: str): + yield (150, EVENT_TTS_TTFB_METRIC) + yield (fake_audio_chunk, EVENT_TTS_RESPONSE) + yield (None, EVENT_TTS_END) + + mock_instance.get.side_effect = mock_get_audio_stream + + tester = ExtensionTesterBasic() + tester.set_test_mode_single( + "deepgram_tts", + json.dumps( + { + "params": { + "api_key": "test_api_key", + "model": "aura-2-thalia-en", + "encoding": "linear16", + "sample_rate": 24000, + }, + } + ), + ) + + tester.run() + + assert tester.audio_start_received, "tts_audio_start was not received." + assert tester.audio_end_received, "tts_audio_end was not received." + assert tester.audio_chunks_count > 0, "No audio chunks received." 
+ + +# ================ test flush functionality ================ +class ExtensionTesterFlush(ExtensionTester): + def __init__(self): + super().__init__() + self.audio_end_received = False + + def on_start(self, ten_env_tester: TenEnvTester) -> None: + ten_env_tester.log_info("Flush test started.") + + tts_input = TTSTextInput( + request_id="tts_request_flush", + text="This is the first sentence.", + text_input_end=False, + ) + data = Data.create("tts_text_input") + data.set_property_from_json(None, tts_input.model_dump_json()) + ten_env_tester.send_data(data) + + flush = TTSFlush(flush_id="flush_1") + flush_data = Data.create("tts_flush") + flush_data.set_property_from_json(None, flush.model_dump_json()) + ten_env_tester.send_data(flush_data) + + tts_input2 = TTSTextInput( + request_id="tts_request_flush", + text="This is the final sentence.", + text_input_end=True, + ) + data2 = Data.create("tts_text_input") + data2.set_property_from_json(None, tts_input2.model_dump_json()) + ten_env_tester.send_data(data2) + + ten_env_tester.on_start_done() + + def on_data(self, ten_env: TenEnvTester, data) -> None: + name = data.get_name() + if name == "tts_audio_end": + ten_env.log_info("Received tts_audio_end, stopping test.") + self.audio_end_received = True + ten_env.stop_test() + + +@patch("deepgram_tts.extension.DeepgramTTSClient") +def test_flush(MockDeepgramTTSClient): + """Test TTS flush functionality.""" + mock_instance = MockDeepgramTTSClient.return_value + mock_instance.start = AsyncMock() + mock_instance.stop = AsyncMock() + mock_instance.cancel = AsyncMock() + mock_instance.reset_ttfb = lambda: None + + fake_audio_chunk = b"\x00\x01\x02\x03" * 50 + + async def mock_get_audio_stream(text: str): + yield (100, EVENT_TTS_TTFB_METRIC) + yield (fake_audio_chunk, EVENT_TTS_RESPONSE) + yield (None, EVENT_TTS_END) + + mock_instance.get.side_effect = mock_get_audio_stream + + tester = ExtensionTesterFlush() + tester.set_test_mode_single( + "deepgram_tts", + json.dumps( + { 
+ "params": { + "api_key": "test_api_key", + "model": "aura-2-thalia-en", + "encoding": "linear16", + "sample_rate": 24000, + }, + } + ), + ) + + tester.run() + + assert ( + tester.audio_end_received + ), "tts_audio_end was not received after flush." diff --git a/ai_agents/agents/ten_packages/extension/deepgram_tts/tests/test_error_msg.py b/ai_agents/agents/ten_packages/extension/deepgram_tts/tests/test_error_msg.py new file mode 100644 index 0000000000..f194ca34cc --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/deepgram_tts/tests/test_error_msg.py @@ -0,0 +1,166 @@ +# +# This file is part of TEN Framework, an open source project. +# Licensed under the Apache License, Version 2.0. +# See the LICENSE file for more information. +# +import json +from unittest.mock import patch, AsyncMock, MagicMock + +from ten_runtime import ( + ExtensionTester, + TenEnvTester, + Data, +) +from ten_ai_base.struct import TTSTextInput + + +# ================ test empty params ================ +class ExtensionTesterEmptyParams(ExtensionTester): + def __init__(self): + super().__init__() + self.error_received = False + self.error_code = None + self.error_message = None + self.error_module = None + + def on_start(self, ten_env_tester: TenEnvTester) -> None: + """Called when test starts""" + ten_env_tester.log_info("Test started") + ten_env_tester.on_start_done() + + def on_data(self, ten_env: TenEnvTester, data) -> None: + name = data.get_name() + ten_env.log_info(f"on_data name: {name}") + + if name == "error": + self.error_received = True + json_str, _ = data.get_property_to_json(None) + error_data = json.loads(json_str) + + self.error_code = error_data.get("code") + self.error_message = error_data.get("message", "") + self.error_module = error_data.get("module", "") + + ten_env.log_info( + f"Received error: code={self.error_code}, message={self.error_message}" + ) + ten_env.stop_test() + + +def test_empty_params_fatal_error(): + """Test that empty params raises FATAL ERROR 
with code -1000""" + print("Starting test_empty_params_fatal_error...") + + # Empty params configuration + empty_params_config = { + "params": { + "api_key": "", + } + } + + tester = ExtensionTesterEmptyParams() + tester.set_test_mode_single("deepgram_tts", json.dumps(empty_params_config)) + + print("Running test...") + tester.run() + print("Test completed.") + + # Verify FATAL ERROR was received + assert tester.error_received, "Expected to receive error message" + assert ( + tester.error_code == -1000 + ), f"Expected error code -1000 (FATAL_ERROR), got {tester.error_code}" + assert tester.error_message is not None, "Error message should not be None" + assert len(tester.error_message) > 0, "Error message should not be empty" + + print(f"Empty params test passed: code={tester.error_code}") + + +# ================ test invalid api key ================ +class ExtensionTesterInvalidApiKey(ExtensionTester): + def __init__(self): + super().__init__() + self.error_received = False + self.error_code = None + self.error_message = None + self.vendor_info = None + + def on_start(self, ten_env_tester: TenEnvTester) -> None: + """Called when test starts, sends a TTS request to trigger the logic.""" + ten_env_tester.log_info( + "Invalid API key test started, sending TTS request" + ) + + tts_input = TTSTextInput( + request_id="test-request-invalid-key", + text="This text will trigger API key validation.", + text_input_end=True, + ) + data = Data.create("tts_text_input") + data.set_property_from_json(None, tts_input.model_dump_json()) + ten_env_tester.send_data(data) + + ten_env_tester.on_start_done() + + def on_data(self, ten_env: TenEnvTester, data) -> None: + name = data.get_name() + ten_env.log_info(f"on_data name: {name}") + + if name == "error": + self.error_received = True + json_str, _ = data.get_property_to_json(None) + error_data = json.loads(json_str) + + self.error_code = error_data.get("code") + self.error_message = error_data.get("message", "") + self.vendor_info = 
error_data.get("vendor_info", {}) + + ten_env.log_info( + f"Received error: code={self.error_code}, message={self.error_message}" + ) + ten_env.stop_test() + elif name == "tts_audio_end": + ten_env.stop_test() + + +@patch("deepgram_tts.deepgram_tts.websockets.connect") +def test_invalid_api_key_error(mock_websocket_connect): + """Test that an invalid API key is handled correctly with a mock.""" + print("Starting test_invalid_api_key_error with mock...") + + # Mock websocket to raise 401 unauthorized error + mock_websocket_connect.side_effect = Exception( + "401 Unauthorized - Invalid API key" + ) + + # Config with invalid API key + invalid_key_config = { + "params": { + "api_key": "invalid_api_key_test", + "model": "aura-2-thalia-en", + "encoding": "linear16", + "sample_rate": 24000, + }, + } + + tester = ExtensionTesterInvalidApiKey() + tester.set_test_mode_single("deepgram_tts", json.dumps(invalid_key_config)) + + print("Running test with mock...") + tester.run() + print("Test with mock completed.") + + # Verify FATAL ERROR was received for incorrect API key + assert tester.error_received, "Expected to receive error message" + assert ( + tester.error_code == -1000 + ), f"Expected error code -1000 (FATAL_ERROR), got {tester.error_code}" + + # Verify vendor_info + vendor_info = tester.vendor_info + assert vendor_info is not None, "Expected vendor_info to be present" + assert ( + vendor_info.get("vendor") == "deepgram" + ), f"Expected vendor 'deepgram', got {vendor_info.get('vendor')}" + + print(f"Invalid API key test passed: code={tester.error_code}") diff --git a/ai_agents/agents/ten_packages/extension/deepgram_tts/tests/test_metrics.py b/ai_agents/agents/ten_packages/extension/deepgram_tts/tests/test_metrics.py new file mode 100644 index 0000000000..60d7cdfe20 --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/deepgram_tts/tests/test_metrics.py @@ -0,0 +1,127 @@ +# +# This file is part of TEN Framework, an open source project. 
+# Licensed under the Apache License, Version 2.0. +# See the LICENSE file for more information. +# +import json +from unittest.mock import patch, AsyncMock +import asyncio + +from ten_runtime import ( + ExtensionTester, + TenEnvTester, + Data, +) +from ten_ai_base.struct import TTSTextInput +from deepgram_tts.deepgram_tts import ( + EVENT_TTS_RESPONSE, + EVENT_TTS_END, + EVENT_TTS_TTFB_METRIC, +) + + +# ================ test metrics ================ +class ExtensionTesterMetrics(ExtensionTester): + def __init__(self): + super().__init__() + self.ttfb_received = False + self.ttfb_value = -1 + self.audio_frame_received = False + self.audio_end_received = False + + def on_start(self, ten_env_tester: TenEnvTester) -> None: + """Called when test starts, sends a TTS request.""" + ten_env_tester.log_info("Metrics test started, sending TTS request.") + + tts_input = TTSTextInput( + request_id="tts_request_for_metrics", + text="hello, this is a metrics test.", + text_input_end=True, + ) + data = Data.create("tts_text_input") + data.set_property_from_json(None, tts_input.model_dump_json()) + ten_env_tester.send_data(data) + ten_env_tester.on_start_done() + + def on_data(self, ten_env: TenEnvTester, data) -> None: + name = data.get_name() + ten_env.log_info(f"on_data name: {name}") + if name == "metrics": + json_str, _ = data.get_property_to_json(None) + ten_env.log_info(f"Received metrics: {json_str}") + metrics_data = json.loads(json_str) + + # According to the structure, 'ttfb' is nested inside a 'metrics' object. 
+ nested_metrics = metrics_data.get("metrics", {}) + if "ttfb" in nested_metrics: + self.ttfb_received = True + self.ttfb_value = nested_metrics.get("ttfb", -1) + ten_env.log_info( + f"Received TTFB metric with value: {self.ttfb_value}" + ) + + elif name == "tts_audio_end": + self.audio_end_received = True + # Stop the test only after both TTFB and audio end are received + if self.ttfb_received: + ten_env.log_info("Received tts_audio_end, stopping test.") + ten_env.stop_test() + + def on_audio_frame(self, ten_env: TenEnvTester, audio_frame): + """Receives audio frames and confirms the stream is working.""" + if not self.audio_frame_received: + self.audio_frame_received = True + ten_env.log_info("First audio frame received.") + + +@patch("deepgram_tts.extension.DeepgramTTSClient") +def test_ttfb_metric_is_sent(MockDeepgramTTSClient): + """ + Tests that a TTFB (Time To First Byte) metric is correctly sent after + receiving the first audio chunk from the TTS service. + """ + print("Starting test_ttfb_metric_is_sent with mock...") + + # --- Mock Configuration --- + mock_instance = MockDeepgramTTSClient.return_value + mock_instance.start = AsyncMock() + mock_instance.stop = AsyncMock() + mock_instance.cancel = AsyncMock() + mock_instance.reset_ttfb = lambda: None + + # This async generator simulates the TTS client's get() method with a delay + async def mock_get_audio_with_delay(text: str): + await asyncio.sleep(0.2) + yield (255, EVENT_TTS_TTFB_METRIC) + yield (b"\x11\x22\x33", EVENT_TTS_RESPONSE) + yield (None, EVENT_TTS_END) + + mock_instance.get.side_effect = mock_get_audio_with_delay + + # --- Test Setup --- + metrics_config = { + "params": { + "api_key": "test_api_key", + "model": "aura-2-thalia-en", + "encoding": "linear16", + "sample_rate": 24000, + } + } + tester = ExtensionTesterMetrics() + tester.set_test_mode_single("deepgram_tts", json.dumps(metrics_config)) + + print("Running TTFB metrics test...") + tester.run() + print("TTFB metrics test completed.") + + 
# --- Assertions --- + assert tester.audio_frame_received, "Did not receive any audio frame." + assert tester.audio_end_received, "Did not receive the tts_audio_end event." + assert tester.ttfb_received, "TTFB metric was not received." + + # Check if the TTFB value matches what we sent + assert ( + tester.ttfb_value == 255 + ), f"Expected TTFB to be 255ms, but got {tester.ttfb_value}ms." + + print(f"TTFB metric test passed. Received TTFB: {tester.ttfb_value}ms.") diff --git a/ai_agents/agents/ten_packages/extension/deepgram_tts/tests/test_params.py b/ai_agents/agents/ten_packages/extension/deepgram_tts/tests/test_params.py new file mode 100644 index 0000000000..48ed8fe1b6 --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/deepgram_tts/tests/test_params.py @@ -0,0 +1,180 @@ +# +# This file is part of TEN Framework, an open source project. +# Licensed under the Apache License, Version 2.0. +# See the LICENSE file for more information. +# +import json +from urllib.parse import parse_qs, urlparse +from unittest.mock import patch, AsyncMock, MagicMock + + +from ten_runtime import ( + ExtensionTester, + TenEnvTester, + Data, +) +from ten_ai_base.struct import TTSTextInput +from deepgram_tts.deepgram_tts import ( + EVENT_TTS_RESPONSE, + EVENT_TTS_END, + EVENT_TTS_TTFB_METRIC, +) +from deepgram_tts.config import DeepgramTTSConfig +from deepgram_tts.deepgram_tts import DeepgramTTSClient + + +def create_mock_client(): + mock = MagicMock() + mock.start = AsyncMock() + mock.stop = AsyncMock() + mock.cancel = AsyncMock() + mock.reset_ttfb = lambda: None + fake_audio = b"\x00\x01\x02\x03" * 100 + + async def mock_get(text): + yield (100, EVENT_TTS_TTFB_METRIC) + yield (fake_audio, EVENT_TTS_RESPONSE) + yield (None, EVENT_TTS_END) + + mock.get.side_effect = mock_get + return mock + + +def test_params_passthrough(): + """Additional Deepgram params should be appended to the websocket URL.""" + config = DeepgramTTSConfig( + params={ + "api_key": "test_api_key", + 
"base_url": "wss://api.deepgram.com/v1/speak", + "model": "aura-2-thalia-en", + "encoding": "linear16", + "sample_rate": 24000, + "bit_rate": 64000, + "container": "none", + } + ) + config.update_params() + + client = DeepgramTTSClient(config=config, ten_env=MagicMock()) + parsed = urlparse(client._ws_url) + query = parse_qs(parsed.query) + + assert parsed.scheme == "wss" + assert parsed.netloc == "api.deepgram.com" + assert parsed.path == "/v1/speak" + assert query["model"] == ["aura-2-thalia-en"] + assert query["encoding"] == ["linear16"] + assert query["sample_rate"] == ["24000"] + assert query["bit_rate"] == ["64000"] + assert query["container"] == ["none"] + assert "api_key" not in query + assert "base_url" not in query + + +# ================ test different sample rates ================ +class ExtensionTesterSampleRate(ExtensionTester): + def __init__(self, sample_rate: int): + super().__init__() + self.sample_rate = sample_rate + self.audio_end_received = False + self.audio_chunks_count = 0 + + def on_start(self, ten_env_tester: TenEnvTester) -> None: + ten_env_tester.log_info(f"Sample rate test: {self.sample_rate}Hz") + + tts_input = TTSTextInput( + request_id="tts_request_sr", + text="Testing different sample rates.", + text_input_end=True, + ) + data = Data.create("tts_text_input") + data.set_property_from_json(None, tts_input.model_dump_json()) + ten_env_tester.send_data(data) + ten_env_tester.on_start_done() + + def on_data(self, ten_env: TenEnvTester, data) -> None: + name = data.get_name() + if name == "tts_audio_end": + self.audio_end_received = True + ten_env.stop_test() + + def on_audio_frame(self, ten_env: TenEnvTester, audio_frame): + self.audio_chunks_count += 1 + + +@patch("deepgram_tts.extension.DeepgramTTSClient") +def test_sample_rate_16000(MockDeepgramTTSClient): + """Test with 16000 Hz sample rate.""" + MockDeepgramTTSClient.return_value = create_mock_client() + + tester = ExtensionTesterSampleRate(16000) + tester.set_test_mode_single( + 
"deepgram_tts", + json.dumps( + { + "params": { + "api_key": "test_api_key", + "model": "aura-2-thalia-en", + "encoding": "linear16", + "sample_rate": 16000, + }, + } + ), + ) + + tester.run() + + assert tester.audio_end_received, "tts_audio_end was not received." + assert tester.audio_chunks_count > 0, "No audio chunks received." + + +@patch("deepgram_tts.extension.DeepgramTTSClient") +def test_sample_rate_24000(MockDeepgramTTSClient): + """Test with 24000 Hz sample rate.""" + MockDeepgramTTSClient.return_value = create_mock_client() + + tester = ExtensionTesterSampleRate(24000) + tester.set_test_mode_single( + "deepgram_tts", + json.dumps( + { + "params": { + "api_key": "test_api_key", + "model": "aura-2-thalia-en", + "encoding": "linear16", + "sample_rate": 24000, + }, + } + ), + ) + + tester.run() + + assert tester.audio_end_received, "tts_audio_end was not received." + assert tester.audio_chunks_count > 0, "No audio chunks received." + + +@patch("deepgram_tts.extension.DeepgramTTSClient") +def test_sample_rate_48000(MockDeepgramTTSClient): + """Test with 48000 Hz sample rate.""" + MockDeepgramTTSClient.return_value = create_mock_client() + + tester = ExtensionTesterSampleRate(48000) + tester.set_test_mode_single( + "deepgram_tts", + json.dumps( + { + "params": { + "api_key": "test_api_key", + "model": "aura-2-thalia-en", + "encoding": "linear16", + "sample_rate": 48000, + }, + } + ), + ) + + tester.run() + + assert tester.audio_end_received, "tts_audio_end was not received." + assert tester.audio_chunks_count > 0, "No audio chunks received." diff --git a/ai_agents/agents/ten_packages/extension/deepgram_tts/tests/test_robustness.py b/ai_agents/agents/ten_packages/extension/deepgram_tts/tests/test_robustness.py new file mode 100644 index 0000000000..6191c8f14a --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/deepgram_tts/tests/test_robustness.py @@ -0,0 +1,267 @@ +# +# This file is part of TEN Framework, an open source project. 
+# Licensed under the Apache License, Version 2.0. +# See the LICENSE file for more information. +# +import json +from unittest.mock import patch, AsyncMock + + +from ten_runtime import ( + ExtensionTester, + TenEnvTester, + Data, +) +from ten_ai_base.struct import TTSTextInput +from deepgram_tts.deepgram_tts import ( + EVENT_TTS_RESPONSE, + EVENT_TTS_END, + EVENT_TTS_TTFB_METRIC, +) +from unittest.mock import MagicMock + + +def create_mock_client(): + mock = MagicMock() + mock.start = AsyncMock() + mock.stop = AsyncMock() + mock.cancel = AsyncMock() + mock.reset_ttfb = lambda: None + fake_audio = b"\x00\x01\x02\x03" * 100 + + async def mock_get(text): + yield (100, EVENT_TTS_TTFB_METRIC) + yield (fake_audio, EVENT_TTS_RESPONSE) + yield (None, EVENT_TTS_END) + + mock.get.side_effect = mock_get + return mock + + +# ================ test empty text ================ +class ExtensionTesterEmptyText(ExtensionTester): + def __init__(self): + super().__init__() + self.audio_end_received = False + + def on_start(self, ten_env_tester: TenEnvTester) -> None: + ten_env_tester.log_info("Empty text test started.") + + tts_input = TTSTextInput( + request_id="tts_request_empty", + text="", + text_input_end=True, + ) + data = Data.create("tts_text_input") + data.set_property_from_json(None, tts_input.model_dump_json()) + ten_env_tester.send_data(data) + ten_env_tester.on_start_done() + + def on_data(self, ten_env: TenEnvTester, data) -> None: + name = data.get_name() + if name == "tts_audio_end": + ten_env.log_info("Received tts_audio_end for empty text.") + self.audio_end_received = True + ten_env.stop_test() + + +@patch("deepgram_tts.extension.DeepgramTTSClient") +def test_empty_text(MockDeepgramTTSClient): + """Test that empty text is handled gracefully.""" + MockDeepgramTTSClient.return_value = create_mock_client() + + tester = ExtensionTesterEmptyText() + tester.set_test_mode_single( + "deepgram_tts", + json.dumps( + { + "params": { + "api_key": "test_api_key", + "model": 
"aura-2-thalia-en", + "encoding": "linear16", + "sample_rate": 24000, + }, + } + ), + ) + + tester.run() + + assert ( + tester.audio_end_received + ), "tts_audio_end should be sent for empty text." + + +# ================ test whitespace only text ================ +class ExtensionTesterWhitespaceText(ExtensionTester): + def __init__(self): + super().__init__() + self.audio_end_received = False + + def on_start(self, ten_env_tester: TenEnvTester) -> None: + ten_env_tester.log_info("Whitespace text test started.") + + tts_input = TTSTextInput( + request_id="tts_request_whitespace", + text=" \n\t ", + text_input_end=True, + ) + data = Data.create("tts_text_input") + data.set_property_from_json(None, tts_input.model_dump_json()) + ten_env_tester.send_data(data) + ten_env_tester.on_start_done() + + def on_data(self, ten_env: TenEnvTester, data) -> None: + name = data.get_name() + if name == "tts_audio_end": + ten_env.log_info("Received tts_audio_end for whitespace text.") + self.audio_end_received = True + ten_env.stop_test() + + +@patch("deepgram_tts.extension.DeepgramTTSClient") +def test_whitespace_text(MockDeepgramTTSClient): + """Test that whitespace-only text is handled gracefully.""" + MockDeepgramTTSClient.return_value = create_mock_client() + + tester = ExtensionTesterWhitespaceText() + tester.set_test_mode_single( + "deepgram_tts", + json.dumps( + { + "params": { + "api_key": "test_api_key", + "model": "aura-2-thalia-en", + "encoding": "linear16", + "sample_rate": 24000, + }, + } + ), + ) + + tester.run() + + assert ( + tester.audio_end_received + ), "tts_audio_end should be sent for whitespace text." 
# ================ test long text ================
class ExtensionTesterLongText(ExtensionTester):
    """Drives the extension with a long, repeated sentence and records
    whether audio frames and the terminal tts_audio_end both arrive."""

    def __init__(self):
        super().__init__()
        # Set once tts_audio_end is observed for the long request.
        self.audio_end_received = False
        # Number of audio frames the extension emitted back to us.
        self.audio_chunks_count = 0

    def on_start(self, ten_env_tester: TenEnvTester) -> None:
        """Send a single final-chunk TTS request carrying the long text."""
        ten_env_tester.log_info("Long text test started.")

        repeated_sentence = "This is a longer piece of text. " * 20
        payload = TTSTextInput(
            request_id="tts_request_long",
            text=repeated_sentence,
            text_input_end=True,
        )
        message = Data.create("tts_text_input")
        message.set_property_from_json(None, payload.model_dump_json())
        ten_env_tester.send_data(message)
        ten_env_tester.on_start_done()

    def on_data(self, ten_env: TenEnvTester, data) -> None:
        """Stop the test as soon as the terminal audio-end event arrives."""
        if data.get_name() == "tts_audio_end":
            ten_env.log_info("Received tts_audio_end for long text.")
            self.audio_end_received = True
            ten_env.stop_test()

    def on_audio_frame(self, ten_env: TenEnvTester, audio_frame):
        """Tally every audio frame produced for the request."""
        self.audio_chunks_count += 1


@patch("deepgram_tts.extension.DeepgramTTSClient")
def test_long_text(MockDeepgramTTSClient):
    """Test that long text is handled correctly."""
    MockDeepgramTTSClient.return_value = create_mock_client()

    # Same config shape the sibling tests use; the mocked client
    # never contacts Deepgram, so the key is a placeholder.
    config = {
        "params": {
            "api_key": "test_api_key",
            "model": "aura-2-thalia-en",
            "encoding": "linear16",
            "sample_rate": 24000,
        },
    }

    tester = ExtensionTesterLongText()
    tester.set_test_mode_single("deepgram_tts", json.dumps(config))
    tester.run()

    assert tester.audio_end_received, (
        "tts_audio_end was not received for long text."
    )
    assert tester.audio_chunks_count > 0, (
        "No audio chunks received for long text."
    )
+ + +# ================ test special characters ================ +class ExtensionTesterSpecialChars(ExtensionTester): + def __init__(self): + super().__init__() + self.audio_end_received = False + self.error_received = False + + def on_start(self, ten_env_tester: TenEnvTester) -> None: + ten_env_tester.log_info("Special characters test started.") + + tts_input = TTSTextInput( + request_id="tts_request_special", + text="Hello! How are you? I'm fine, thanks. $100 is 100%.", + text_input_end=True, + ) + data = Data.create("tts_text_input") + data.set_property_from_json(None, tts_input.model_dump_json()) + ten_env_tester.send_data(data) + ten_env_tester.on_start_done() + + def on_data(self, ten_env: TenEnvTester, data) -> None: + name = data.get_name() + if name == "tts_audio_end": + self.audio_end_received = True + ten_env.stop_test() + elif name == "error": + self.error_received = True + ten_env.stop_test() + + +@patch("deepgram_tts.extension.DeepgramTTSClient") +def test_special_characters(MockDeepgramTTSClient): + """Test that special characters are handled correctly.""" + MockDeepgramTTSClient.return_value = create_mock_client() + + tester = ExtensionTesterSpecialChars() + tester.set_test_mode_single( + "deepgram_tts", + json.dumps( + { + "params": { + "api_key": "test_api_key", + "model": "aura-2-thalia-en", + "encoding": "linear16", + "sample_rate": 24000, + }, + } + ), + ) + + tester.run() + + assert tester.audio_end_received, "tts_audio_end was not received." + assert ( + not tester.error_received + ), "Error should not be received for special chars." diff --git a/ai_agents/agents/ten_packages/extension/deepgram_tts/tests/test_state_machine.py b/ai_agents/agents/ten_packages/extension/deepgram_tts/tests/test_state_machine.py new file mode 100644 index 0000000000..12650f9d2c --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/deepgram_tts/tests/test_state_machine.py @@ -0,0 +1,426 @@ +# +# This file is part of TEN Framework, an open source project. 
+# Licensed under the Apache License, Version 2.0. +# See the LICENSE file for more information. +# +import asyncio +import json +from unittest.mock import patch, AsyncMock, MagicMock + + +from ten_runtime import ( + ExtensionTester, + TenEnvTester, + Data, +) +from ten_ai_base.struct import TTSTextInput +from deepgram_tts.deepgram_tts import ( + EVENT_TTS_RESPONSE, + EVENT_TTS_END, + EVENT_TTS_TTFB_METRIC, + EVENT_TTS_ERROR, + DeepgramTTSClient, +) +from deepgram_tts.config import DeepgramTTSConfig + +MOCK_CONFIG = { + "params": { + "api_key": "test_api_key", + "model": "aura-2-thalia-en", + "encoding": "linear16", + "sample_rate": 24000, + }, +} + + +def create_mock_client(): + mock = MagicMock() + mock.start = AsyncMock() + mock.stop = AsyncMock() + mock.cancel = AsyncMock() + mock.reset_ttfb = lambda: None + fake_audio = b"\x00\x01\x02\x03" * 100 + + async def mock_get(text): + yield (100, EVENT_TTS_TTFB_METRIC) + yield (fake_audio, EVENT_TTS_RESPONSE) + yield (None, EVENT_TTS_END) + + mock.get.side_effect = mock_get + return mock + + +# ================ test sequential requests ================ +class SequentialRequestsTester(ExtensionTester): + """Send 3 requests with different IDs sequentially. + + Each request should produce tts_audio_start, audio + frames, and tts_audio_end with the correct request_id. 
+ """ + + def __init__(self): + super().__init__() + self.completed_request_ids = [] + self.audio_start_ids = [] + self.expected_ids = [ + "seq_req_1", + "seq_req_2", + "seq_req_3", + ] + self.send_index = 0 + + def on_start(self, ten_env_tester: TenEnvTester) -> None: + ten_env_tester.log_info("Sequential requests test started.") + self._send_next(ten_env_tester) + ten_env_tester.on_start_done() + + def _send_next(self, ten_env_tester: TenEnvTester) -> None: + if self.send_index >= len(self.expected_ids): + return + req_id = self.expected_ids[self.send_index] + tts_input = TTSTextInput( + request_id=req_id, + text=f"Hello from request {self.send_index + 1}.", + text_input_end=True, + ) + data = Data.create("tts_text_input") + data.set_property_from_json(None, tts_input.model_dump_json()) + ten_env_tester.send_data(data) + self.send_index += 1 + + def on_data(self, ten_env: TenEnvTester, data) -> None: + name = data.get_name() + if name == "tts_audio_start": + json_str, _ = data.get_property_to_json("") + d = json.loads(json_str) if json_str else {} + rid = d.get("request_id", "") + self.audio_start_ids.append(rid) + elif name == "tts_audio_end": + json_str, _ = data.get_property_to_json("") + d = json.loads(json_str) if json_str else {} + rid = d.get("request_id", "") + self.completed_request_ids.append(rid) + ten_env.log_info(f"Completed request: {rid}") + if len(self.completed_request_ids) < len(self.expected_ids): + self._send_next(ten_env) + else: + ten_env.stop_test() + + +@patch("deepgram_tts.extension.DeepgramTTSClient") +def test_sequential_requests(MockClient): + """Each sequential request should complete with its own + request_id in audio_start and audio_end.""" + MockClient.return_value = create_mock_client() + + tester = SequentialRequestsTester() + tester.set_test_mode_single("deepgram_tts", json.dumps(MOCK_CONFIG)) + tester.run() + + assert tester.completed_request_ids == [ + "seq_req_1", + "seq_req_2", + "seq_req_3", + ], ( + f"Expected 3 sequential 
completions, got " + f"{tester.completed_request_ids}" + ) + assert tester.audio_start_ids == [ + "seq_req_1", + "seq_req_2", + "seq_req_3", + ], f"audio_start ids mismatch: {tester.audio_start_ids}" + + +# ================ test reconnect after error ================ +class ReconnectAfterErrorTester(ExtensionTester): + """First request errors, second request should succeed. + + Validates that the client recovers after a mid-stream + failure. + """ + + def __init__(self): + super().__init__() + self.error_received = False + self.second_audio_end = False + + def on_start(self, ten_env_tester: TenEnvTester) -> None: + # First request will trigger an error + tts_input = TTSTextInput( + request_id="err_req_1", + text="This will error.", + text_input_end=True, + ) + data = Data.create("tts_text_input") + data.set_property_from_json(None, tts_input.model_dump_json()) + ten_env_tester.send_data(data) + ten_env_tester.on_start_done() + + def on_data(self, ten_env: TenEnvTester, data) -> None: + name = data.get_name() + if name == "tts_audio_end": + if not self.error_received: + # First request ended (with error) — send + # second request + self.error_received = True + tts_input = TTSTextInput( + request_id="ok_req_2", + text="This should work.", + text_input_end=True, + ) + data2 = Data.create("tts_text_input") + data2.set_property_from_json(None, tts_input.model_dump_json()) + ten_env.send_data(data2) + else: + self.second_audio_end = True + ten_env.stop_test() + + +@patch("deepgram_tts.extension.DeepgramTTSClient") +def test_reconnect_after_error(MockClient): + """After an error, subsequent requests should succeed.""" + call_count = 0 + + def create_mock(): + mock = MagicMock() + mock.start = AsyncMock() + mock.stop = AsyncMock() + mock.cancel = AsyncMock() + mock.reset_ttfb = lambda: None + + fake_audio = b"\x00\x01" * 200 + + async def mock_get(text): + nonlocal call_count + call_count += 1 + if call_count == 1: + # First call: error + yield ( + b"Simulated error", + 
EVENT_TTS_ERROR, + ) + else: + # Subsequent calls: success + yield (100, EVENT_TTS_TTFB_METRIC) + yield (fake_audio, EVENT_TTS_RESPONSE) + yield (None, EVENT_TTS_END) + + mock.get.side_effect = mock_get + return mock + + MockClient.return_value = create_mock() + + tester = ReconnectAfterErrorTester() + tester.set_test_mode_single("deepgram_tts", json.dumps(MOCK_CONFIG)) + tester.run() + + assert ( + tester.second_audio_end + ), "Second request should complete after first errored." + + +# ================ test config redaction ================ +def test_config_redacts_api_key(): + """to_str(sensitive_handling=True) must not leak the + API key.""" + config = DeepgramTTSConfig( + params={ + "api_key": "super-secret-key-12345", + "model": "aura-2-thalia-en", + } + ) + config.update_params() + + safe_str = config.to_str(sensitive_handling=True) + + assert "super-secret-key-12345" not in safe_str + assert "aura-2-thalia-en" in safe_str + + +# ================ test empty text yields END ================ +def test_client_empty_text_yields_end(): + """get() with empty text should yield EVENT_TTS_END + immediately without connecting.""" + + async def _run(): + ten_env = MagicMock() + ten_env.log_warn = MagicMock() + config = DeepgramTTSConfig(api_key="test") + client = DeepgramTTSClient(config=config, ten_env=ten_env) + + events = [] + async for data, event in client.get(""): + events.append(event) + + assert events == [EVENT_TTS_END] + assert client._ws is None # no connection made + + asyncio.run(_run()) + + +def test_client_whitespace_text_yields_end(): + """get() with whitespace-only text should yield + EVENT_TTS_END.""" + + async def _run(): + ten_env = MagicMock() + ten_env.log_warn = MagicMock() + config = DeepgramTTSConfig(api_key="test") + client = DeepgramTTSClient(config=config, ten_env=ten_env) + + events = [] + async for data, event in client.get(" \n\t "): + events.append(event) + + assert events == [EVENT_TTS_END] + + asyncio.run(_run()) + + +# 
================ test 401 emits exactly one error ================ +class AuthErrorTester(ExtensionTester): + """Validates that a 401 auth failure emits exactly one + error event and one terminal audio_end.""" + + def __init__(self): + super().__init__() + self.error_count = 0 + self.audio_end_count = 0 + + def on_start(self, ten_env_tester: TenEnvTester) -> None: + tts_input = TTSTextInput( + request_id="auth_err_req", + text="This should fail with 401.", + text_input_end=True, + ) + data = Data.create("tts_text_input") + data.set_property_from_json(None, tts_input.model_dump_json()) + ten_env_tester.send_data(data) + ten_env_tester.on_start_done() + + def on_data(self, ten_env: TenEnvTester, data) -> None: + name = data.get_name() + if name == "error": + self.error_count += 1 + elif name == "tts_audio_end": + self.audio_end_count += 1 + ten_env.stop_test() + + +@patch("deepgram_tts.extension.DeepgramTTSClient") +def test_auth_error_single_emission(MockClient): + """401 should produce exactly 1 error event, not + duplicates.""" + from deepgram_tts.deepgram_tts import ( + DeepgramTTSConnectionException, + ) + + mock = MagicMock() + mock.start = AsyncMock() + mock.stop = AsyncMock() + mock.cancel = AsyncMock() + mock.reset_ttfb = lambda: None + + async def mock_get_auth_fail(text): + raise DeepgramTTSConnectionException( + status_code=401, body="Unauthorized" + ) + yield # make it a generator # pragma: no cover + + mock.get.side_effect = mock_get_auth_fail + MockClient.return_value = mock + + tester = AuthErrorTester() + tester.set_test_mode_single("deepgram_tts", json.dumps(MOCK_CONFIG)) + tester.run() + + assert tester.error_count == 1, ( + f"Expected exactly 1 error event, got " f"{tester.error_count}" + ) + + +# ================ test non-final error contract ================ +class NonFinalErrorTester(ExtensionTester): + """Validates that an error on a non-final chunk does NOT + produce a public error event. 
Partial stream errors are + transient — only logged, not surfaced to callers.""" + + def __init__(self): + super().__init__() + self.error_count = 0 + self.audio_end_received = False + + def on_start(self, ten_env_tester: TenEnvTester) -> None: + # First chunk: non-final, will error + tts_input = TTSTextInput( + request_id="nonfinal_req", + text="First chunk errors.", + text_input_end=False, + ) + data = Data.create("tts_text_input") + data.set_property_from_json(None, tts_input.model_dump_json()) + ten_env_tester.send_data(data) + + # Second chunk: final, succeeds + tts_input2 = TTSTextInput( + request_id="nonfinal_req", + text="Second chunk works.", + text_input_end=True, + ) + data2 = Data.create("tts_text_input") + data2.set_property_from_json(None, tts_input2.model_dump_json()) + ten_env_tester.send_data(data2) + ten_env_tester.on_start_done() + + def on_data(self, ten_env: TenEnvTester, data) -> None: + name = data.get_name() + if name == "error": + self.error_count += 1 + elif name == "tts_audio_end": + self.audio_end_received = True + ten_env.stop_test() + + +@patch("deepgram_tts.extension.DeepgramTTSClient") +def test_nonfinal_error_not_surfaced(MockClient): + """Error on non-final chunk should not emit public + error event. 
This is the intended contract: partial + stream errors are transient.""" + call_count = 0 + + def create_mock(): + mock = MagicMock() + mock.start = AsyncMock() + mock.stop = AsyncMock() + mock.cancel = AsyncMock() + mock.reset_ttfb = lambda: None + + fake_audio = b"\x00\x01" * 200 + + async def mock_get(text): + nonlocal call_count + call_count += 1 + if call_count == 1: + yield (b"Transient error", EVENT_TTS_ERROR) + else: + yield (100, EVENT_TTS_TTFB_METRIC) + yield (fake_audio, EVENT_TTS_RESPONSE) + yield (None, EVENT_TTS_END) + + mock.get.side_effect = mock_get + return mock + + MockClient.return_value = create_mock() + + tester = NonFinalErrorTester() + tester.set_test_mode_single("deepgram_tts", json.dumps(MOCK_CONFIG)) + tester.run() + + assert tester.error_count == 0, ( + f"Non-final error should not produce public error " + f"event, got {tester.error_count}" + ) + assert ( + tester.audio_end_received + ), "Request should still complete after non-final error"