diff --git a/ai_agents/.env.example b/ai_agents/.env.example index a367cf720f..844176e462 100644 --- a/ai_agents/.env.example +++ b/ai_agents/.env.example @@ -75,6 +75,9 @@ DEEPSEEK_API_KEY= STEPFUN_API_KEY= GLADIA_API_KEY= +# SiliconFlow unified API key (LLM / TTS) +SILICONFLOW_API_KEY= + # Extension: bedrock_llm # Extension: polly_tts AWS_ACCESS_KEY_ID= diff --git a/ai_agents/AGENTS.md b/ai_agents/AGENTS.md index 88ed9863d9..7c7abb27fc 100644 --- a/ai_agents/AGENTS.md +++ b/ai_agents/AGENTS.md @@ -509,7 +509,7 @@ Required `.env` variables depend on extensions used. Common ones: - `DEEPGRAM_API_KEY`, `AZURE_ASR_API_KEY`, `AZURE_ASR_REGION` **TTS:** -- `ELEVENLABS_TTS_KEY`, `AZURE_TTS_KEY`, `AZURE_TTS_REGION` +- `ELEVENLABS_TTS_KEY`, `AZURE_TTS_KEY`, `AZURE_TTS_REGION`, `SILICONFLOW_API_KEY` See `.env.example` for complete list. diff --git a/ai_agents/agents/ten_packages/extension/siliconflow_tts2_python/AGENT.md b/ai_agents/agents/ten_packages/extension/siliconflow_tts2_python/AGENT.md new file mode 100644 index 0000000000..50291f9a5a --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/siliconflow_tts2_python/AGENT.md @@ -0,0 +1,18 @@ +# siliconflow_tts2_python/ +> L2 | 父级: /mnt/e/wsf-project/ai_agents/AGENTS.md + +成员清单 +`__init__.py`: 包声明,保持 Python 扩展目录可导入。 +`addon.py`: TEN addon 注册入口,暴露 `siliconflow_tts2_python` 扩展实例。 +`config.py`: SiliconFlow TTS 配置模型,归一化默认参数并校验采样率与响应格式。 +`extension.py`: HTTP TTS 基座适配层,负责创建配置/客户端并暴露采样率。 +`siliconflow_tts.py`: 供应商 HTTP 客户端,请求 `/audio/speech`,嗅探真实响应格式,并把 MPEG 解码成 PCM 数据块。 +`wav_stream_parser.py`: 流式 WAV 头解析器,仅在响应真实为 RIFF/WAV 时拆出 PCM 数据。 +`manifest.json`: 扩展元数据与属性模式,供 tman 和 TEN 运行时读取。 +`property.json`: 默认属性模板,约定 SiliconFlow 的环境变量和默认音色。 +`requirements.txt`: Python 依赖声明,包含 `httpx` 与 `miniaudio`。 +`README.md`: 扩展说明与最小配置示例。 + +法则: 成员完整·一行一文件·父级链接·技术词前置 + +[PROTOCOL]: 变更时更新此头部,然后检查 AGENT.md diff --git a/ai_agents/agents/ten_packages/extension/siliconflow_tts2_python/CLAUDE.md b/ai_agents/agents/ten_packages/extension/siliconflow_tts2_python/CLAUDE.md new file mode 100644 index 0000000000..90fbe27e66 --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/siliconflow_tts2_python/CLAUDE.md @@ -0,0 +1,18 @@ +# siliconflow_tts2_python/ +> L2 | 父级: /mnt/e/wsf-project/ai_agents/CLAUDE.md + +成员清单 +`__init__.py`: 包声明,保持 Python 扩展目录可导入。 +`addon.py`: TEN addon 注册入口,暴露 `siliconflow_tts2_python` 扩展实例。 +`config.py`: SiliconFlow TTS 配置模型,归一化默认参数并校验采样率与响应格式。 +`extension.py`: HTTP TTS 基座适配层,负责创建配置/客户端并暴露采样率。 +`siliconflow_tts.py`: 供应商 HTTP 客户端,请求 `/audio/speech`,嗅探真实响应格式,并把 MPEG 解码成 PCM 数据块。 +`wav_stream_parser.py`: 流式 WAV 头解析器,仅在响应真实为 RIFF/WAV 时拆出 PCM 数据。 +`manifest.json`: 扩展元数据与属性模式,供 tman 和 TEN 运行时读取。 +`property.json`: 默认属性模板,约定 SiliconFlow 的环境变量和默认音色。 +`requirements.txt`: Python 依赖声明,包含 `httpx` 与 `miniaudio`。 +`README.md`: 扩展说明与最小配置示例。 + +法则: 成员完整·一行一文件·父级链接·技术词前置 + +[PROTOCOL]: 变更时更新此头部,然后检查 AGENT.md diff --git a/ai_agents/agents/ten_packages/extension/siliconflow_tts2_python/README.md b/ai_agents/agents/ten_packages/extension/siliconflow_tts2_python/README.md new file mode 100644 index 0000000000..b4622f29cd --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/siliconflow_tts2_python/README.md @@ -0,0 +1,30 @@ +# siliconflow_tts2_python + +SiliconFlow TTS extension built on TEN's `AsyncTTS2HttpExtension`. + +## Notes + +- Uses `POST /v1/audio/speech` +- Defaults to `response_format: "mp3"` because SiliconFlow currently returns `audio/mpeg` +- The extension decodes returned MP3 into mono 16-bit PCM before handing audio to TEN + +## Required Params + +```json +{ + "params": { + "api_key": "${env:SILICONFLOW_API_KEY}", + "base_url": "https://api.siliconflow.cn/v1", + "model": "IndexTeam/IndexTTS-2", + "voice": "IndexTeam/IndexTTS-2:anna" + } +} +``` + +## Optional Params + +- `sample_rate` +- `speed` +- `gain` +- `max_tokens` +- `response_format` (`mp3`, `wav` or `pcm`) diff --git a/ai_agents/agents/ten_packages/extension/siliconflow_tts2_python/__init__.py b/ai_agents/agents/ten_packages/extension/siliconflow_tts2_python/__init__.py new file mode 100644 index 0000000000..d6cb9a597a --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/siliconflow_tts2_python/__init__.py @@ -0,0 +1,14 @@ +# +# This file is part of TEN Framework, an open source project. +# Licensed under the Apache License, Version 2.0. +# See the LICENSE file for more information. +# +""" +/* [INPUT]: 依赖 addon.py 的注册副作用 + * [OUTPUT]: 包导入时自动触发 siliconflow_tts2_python 的 addon 注册 + * [POS]: siliconflow_tts2_python 包入口,适配 Python addon loader 的导入约定 + * [PROTOCOL]: 变更时更新此头部,然后检查 AGENT.md + */ +""" + +from . import addon diff --git a/ai_agents/agents/ten_packages/extension/siliconflow_tts2_python/addon.py b/ai_agents/agents/ten_packages/extension/siliconflow_tts2_python/addon.py new file mode 100644 index 0000000000..204f6cb6c4 --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/siliconflow_tts2_python/addon.py @@ -0,0 +1,26 @@ +# +# This file is part of TEN Framework, an open source project. +# Licensed under the Apache License, Version 2.0. +# See the LICENSE file for more information. +# +""" +/* [INPUT]: 依赖 ten_runtime 的 Addon/注册器,依赖 extension.py 的 SiliconFlowTTSExtension + * [OUTPUT]: 对外提供 siliconflow_tts2_python 扩展注册入口 + * [POS]: siliconflow_tts2_python 模块的 TEN 入口,被运行时按 addon 名称实例化 + * [PROTOCOL]: 变更时更新此头部,然后检查 AGENT.md + */ +""" + +from ten_runtime import Addon, TenEnv, register_addon_as_extension + +from .extension import SiliconFlowTTSExtension + + +@register_addon_as_extension("siliconflow_tts2_python") +class SiliconFlowTTSExtensionAddon(Addon): + def on_create_instance(self, ten_env: TenEnv, name: str, context) -> None: + ten_env.log_info("SiliconFlowTTSExtensionAddon on_create_instance") + ten_env.on_create_instance_done( + SiliconFlowTTSExtension(name), context + ) + diff --git a/ai_agents/agents/ten_packages/extension/siliconflow_tts2_python/config.py b/ai_agents/agents/ten_packages/extension/siliconflow_tts2_python/config.py new file mode 100644 index 0000000000..b59dff2933 --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/siliconflow_tts2_python/config.py @@ -0,0 +1,81 @@ +# +# This file is part of TEN Framework, an open source project. +# Licensed under the Apache License, Version 2.0. +# See the LICENSE file for more information. +# +""" +/* [INPUT]: 依赖 pydantic 的 Field,依赖 ten_ai_base.tts2_http 的 AsyncTTS2HttpConfig + * [OUTPUT]: 对外提供 SiliconFlowTTSConfig 配置模型和参数校验能力 + * [POS]: siliconflow_tts2_python 的配置归一化层,给 extension/client 提供单一真相源 + * [PROTOCOL]: 变更时更新此头部,然后检查 AGENT.md + */ +""" + +from pathlib import Path +from typing import Any +import copy + +from pydantic import Field + +from ten_ai_base import utils +from ten_ai_base.tts2_http import AsyncTTS2HttpConfig + + +SUPPORTED_RESPONSE_FORMATS = {"wav", "pcm", "mp3"} + + +class SiliconFlowTTSConfig(AsyncTTS2HttpConfig): + dump: bool = Field(default=False, description="SiliconFlow TTS dump") + dump_path: str = Field( + default_factory=lambda: str( + Path(__file__).parent / "siliconflow_tts_in.pcm" + ), + description="SiliconFlow TTS dump path", + ) + params: dict[str, Any] = Field( + default_factory=dict, description="SiliconFlow TTS params" + ) + sample_rate: int = Field(default=32000, description="PCM sample rate") + + def update_params(self) -> None: + self.params.pop("input", None) + self.params["stream"] = True + self.params.setdefault("base_url", "https://api.siliconflow.cn/v1") + self.params.setdefault("model", "IndexTeam/IndexTTS-2") + self.params.setdefault("voice", "IndexTeam/IndexTTS-2:anna") + self.params.setdefault("max_tokens", 2048) + self.params.setdefault("speed", 1) + self.params.setdefault("gain", 0) + self.params.setdefault("response_format", "mp3") + + if "sample_rate" in self.params: + self.sample_rate = int(self.params["sample_rate"]) + else: + self.params["sample_rate"] = self.sample_rate + + def to_str(self, sensitive_handling: bool = True) -> str: + if not sensitive_handling: + return f"{self}" + + config = copy.deepcopy(self) + if config.params and "api_key" in config.params: + config.params["api_key"] = utils.encrypt(config.params["api_key"]) + return f"{config}" + + def validate(self) -> None: + if "api_key" not in self.params or not self.params["api_key"]: + raise ValueError("API key is required for SiliconFlow TTS") + if "model" not in self.params or not self.params["model"]: + raise ValueError("Model is required for SiliconFlow TTS") + if "voice" not in self.params or not self.params["voice"]: + raise ValueError("Voice is required for SiliconFlow TTS") + + response_format = str(self.params.get("response_format", "wav")).lower() + if response_format not in SUPPORTED_RESPONSE_FORMATS: + raise ValueError( + "SiliconFlow TTS in TEN only supports 'wav', 'pcm' or 'mp3' " + "response_format" + ) + + if self.sample_rate <= 0: + raise ValueError("sample_rate must be a positive integer") diff --git a/ai_agents/agents/ten_packages/extension/siliconflow_tts2_python/extension.py b/ai_agents/agents/ten_packages/extension/siliconflow_tts2_python/extension.py new file mode 100644 index 0000000000..d8cd1fd384 --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/siliconflow_tts2_python/extension.py @@ -0,0 +1,46 @@ +# +# This file is part of TEN Framework, an open source project. +# Licensed under the Apache License, Version 2.0. +# See the LICENSE file for more information. +# +""" +/* [INPUT]: 依赖 ten_ai_base.tts2_http 的 HTTP TTS 基座,依赖 config.py 和 siliconflow_tts.py + * [OUTPUT]: 对外提供 SiliconFlowTTSExtension 扩展类 + * [POS]: siliconflow_tts2_python 的运行时适配层,负责把 TEN 生命周期接到 SiliconFlow 客户端 + * [PROTOCOL]: 变更时更新此头部,然后检查 AGENT.md + */ +""" + +from ten_ai_base.tts2_http import ( + AsyncTTS2HttpClient, + AsyncTTS2HttpConfig, + AsyncTTS2HttpExtension, +) +from ten_runtime import AsyncTenEnv + +from .config import SiliconFlowTTSConfig +from .siliconflow_tts import SiliconFlowTTSClient + + +class SiliconFlowTTSExtension(AsyncTTS2HttpExtension): + def __init__(self, name: str) -> None: + super().__init__(name) + self.config: SiliconFlowTTSConfig | None = None + self.client: SiliconFlowTTSClient | None = None + + async def create_config(self, config_json_str: str) -> AsyncTTS2HttpConfig: + return SiliconFlowTTSConfig.model_validate_json(config_json_str) + + async def create_client( + self, config: AsyncTTS2HttpConfig, ten_env: AsyncTenEnv + ) -> AsyncTTS2HttpClient: + return SiliconFlowTTSClient(config=config, ten_env=ten_env) + + def vendor(self) -> str: + return "siliconflow" + + def synthesize_audio_sample_rate(self) -> int: + if self.config is None: + return 32000 + return self.config.sample_rate + diff --git a/ai_agents/agents/ten_packages/extension/siliconflow_tts2_python/manifest.json b/ai_agents/agents/ten_packages/extension/siliconflow_tts2_python/manifest.json new file mode 100644 index 0000000000..18f801c5ec --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/siliconflow_tts2_python/manifest.json @@ -0,0 +1,82 @@ +{ + "type": "extension", + "name": "siliconflow_tts2_python", + "version": "0.1.0", + "dependencies": [ + { + "type": "system", + "name": "ten_runtime_python", + "version": "0.11" + }, + { + "type": "system", + "name": "ten_ai_base", + "version": "0.7" + } + ], + "package": { + "include": [ + "manifest.json", + "property.json", + "**.tent", + "**.py", + "README.md", + "requirements.txt", + "AGENT.md", + "CLAUDE.md" + ] + }, + "api": { + "interface": [ + { + "import_uri": "../../system/ten_ai_base/api/tts-interface.json" + } + ], + "property": { + "properties": { + "params": { + "type": "object", + "properties": { + "api_key": { + "type": "string" + }, + "base_url": { + "type": "string" + }, + "model": { + "type": "string" + }, + "voice": { + "type": "string" + }, + "sample_rate": { + "type": "int64" + }, + "speed": { + "type": "float64" + }, + "gain": { + "type": "float64" + }, + "max_tokens": { + "type": "int64" + }, + "response_format": { + "type": "string" + }, + "stream": { + "type": "bool" + } + } + }, + "dump": { + "type": "bool" + }, + "dump_path": { + "type": "string" + } + } + } + } +} + diff --git a/ai_agents/agents/ten_packages/extension/siliconflow_tts2_python/property.json b/ai_agents/agents/ten_packages/extension/siliconflow_tts2_python/property.json new file mode 100644 index 0000000000..99ce02f302 --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/siliconflow_tts2_python/property.json @@ -0,0 +1,16 @@ +{ + "dump": false, + "dump_path": "./", + "params": { + "api_key": "${env:SILICONFLOW_API_KEY|}", + "base_url": "https://api.siliconflow.cn/v1", + "model": "IndexTeam/IndexTTS-2", + "voice": "IndexTeam/IndexTTS-2:anna", + "sample_rate": 32000, + "speed": 1, + "gain": 0, + "max_tokens": 2048, + "response_format": "mp3", + "stream": true + } +} diff --git a/ai_agents/agents/ten_packages/extension/siliconflow_tts2_python/requirements.txt b/ai_agents/agents/ten_packages/extension/siliconflow_tts2_python/requirements.txt new file mode 100644 index 0000000000..bcdd7315a3 --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/siliconflow_tts2_python/requirements.txt @@ -0,0 +1,2 @@ +httpx +miniaudio diff --git a/ai_agents/agents/ten_packages/extension/siliconflow_tts2_python/siliconflow_tts.py b/ai_agents/agents/ten_packages/extension/siliconflow_tts2_python/siliconflow_tts.py new file mode 100644 index 0000000000..dd478b879a --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/siliconflow_tts2_python/siliconflow_tts.py @@ -0,0 +1,288 @@ +# +# This file is part of TEN Framework, an open source project. +# Licensed under the Apache License, Version 2.0. +# See the LICENSE file for more information. +# +""" +/* [INPUT]: 依赖 httpx 的流式 HTTP 客户端,依赖 config.py 与 wav_stream_parser.py + * [OUTPUT]: 对外提供 SiliconFlowTTSClient,向 TEN 输出 PCM 音频块和错误事件 + * [POS]: siliconflow_tts2_python 的核心供应商适配器,负责请求 SiliconFlow /audio/speech + * [PROTOCOL]: 变更时更新此头部,然后检查 AGENT.md + */ +""" + +from typing import Any, AsyncIterator, Tuple + +from httpx import AsyncClient, Limits, Timeout +import miniaudio + +from ten_ai_base.const import LOG_CATEGORY_VENDOR +from ten_ai_base.struct import TTS2HttpResponseEventType +from ten_ai_base.tts2_http import AsyncTTS2HttpClient +from ten_runtime import AsyncTenEnv + +from .config import SiliconFlowTTSConfig +from .wav_stream_parser import WavStreamParser + + +BYTES_PER_SAMPLE = 2 +NUMBER_OF_CHANNELS = 1 +PCM_CHUNK_SIZE = 4096 + + +class SiliconFlowTTSClient(AsyncTTS2HttpClient): + def __init__(self, config: SiliconFlowTTSConfig, ten_env: AsyncTenEnv): + super().__init__() + self.config = config + self.ten_env = ten_env + self._is_cancelled = False + base_url = str(self.config.params.get("base_url", "")).rstrip("/") + self.endpoint = f"{base_url}/audio/speech" + self.headers = { + "Authorization": f"Bearer {self.config.params['api_key']}", + "Content-Type": "application/json", + } + self.client = AsyncClient( + timeout=Timeout(timeout=60.0, connect=10.0), + limits=Limits( + max_connections=100, + max_keepalive_connections=20, + keepalive_expiry=600.0, + ), + http2=True, + ) + + async def cancel(self) -> None: + self.ten_env.log_debug("SiliconFlowTTS: cancel() called.") + self._is_cancelled = True + + async def get( + self, text: str, request_id: str + ) -> AsyncIterator[Tuple[bytes | None, TTS2HttpResponseEventType]]: + self._is_cancelled = False + + if len(text.strip()) == 0: + self.ten_env.log_warn( + f"SiliconFlowTTS: empty text for request_id: {request_id}.", + category=LOG_CATEGORY_VENDOR, + ) + yield None, TTS2HttpResponseEventType.END + return + + payload = {**self.config.params} + payload.pop("api_key", None) + payload.pop("base_url", None) + payload["input"] = text + payload["stream"] = True + + try: + async with self.client.stream( + "POST", + self.endpoint, + headers=self.headers, + json=payload, + ) as response: + if response.status_code != 200: + error_message = ( + f"HTTP {response.status_code}: {await response.aread()}" + ) + self.ten_env.log_error( + f"vendor_error: {error_message} of request_id: {request_id}.", + category=LOG_CATEGORY_VENDOR, + ) + if response.status_code in (401, 403): + yield error_message.encode("utf-8"), ( + TTS2HttpResponseEventType.INVALID_KEY_ERROR + ) + else: + yield error_message.encode("utf-8"), ( + TTS2HttpResponseEventType.ERROR + ) + return + + content_type = response.headers.get("content-type", "").lower() + sniffed_format = await self._sniff_response_format( + response.aiter_bytes() + ) + self.ten_env.log_info( + "SiliconFlowTTS: " + f"request_id={request_id}, " + f"requested_format={payload.get('response_format', 'mp3')}, " + f"content_type={content_type}, " + f"sniffed_format={sniffed_format}", + category=LOG_CATEGORY_VENDOR, + ) + + if sniffed_format == "mpeg": + async for chunk in self._iter_mpeg_stream( + request_id=request_id + ): + if chunk is None: + yield None, TTS2HttpResponseEventType.FLUSH + return + yield chunk, TTS2HttpResponseEventType.RESPONSE + elif sniffed_format == "wav": + async for chunk in self._iter_wav_stream( + request_id=request_id + ): + if chunk is None: + yield None, TTS2HttpResponseEventType.FLUSH + return + yield chunk, TTS2HttpResponseEventType.RESPONSE + else: + async for chunk in self._iter_pcm_stream( + request_id=request_id + ): + if chunk is None: + yield None, TTS2HttpResponseEventType.FLUSH + return + yield chunk, TTS2HttpResponseEventType.RESPONSE + + if not self._is_cancelled: + yield None, TTS2HttpResponseEventType.END + + except Exception as exc: + error_message = str(exc) + self.ten_env.log_error( + f"vendor_error: {error_message} of request_id: {request_id}.", + category=LOG_CATEGORY_VENDOR, + ) + if "401" in error_message or "403" in error_message: + yield error_message.encode("utf-8"), ( + TTS2HttpResponseEventType.INVALID_KEY_ERROR + ) + else: + yield error_message.encode("utf-8"), ( + TTS2HttpResponseEventType.ERROR + ) + + async def _sniff_response_format( + self, byte_stream: AsyncIterator[bytes] + ) -> str: + self._stream_iterator = byte_stream + self._prefetched_chunk = b"" + + async for chunk in self._stream_iterator: + if chunk: + self._prefetched_chunk = chunk + break + + prefix = self._prefetched_chunk[:4] + if prefix == b"RIFF": + return "wav" + if prefix[:3] == b"ID3" or prefix[:2] == b"\xff\xfb": + return "mpeg" + return "pcm" + + async def _iter_wav_stream( + self, request_id: str + ) -> AsyncIterator[bytes | None]: + stream_parser = WavStreamParser(self._iter_prefetched_stream()) + format_info = await stream_parser.get_format_info() + self.config.sample_rate = int( + format_info.get("framerate", self.config.sample_rate) + ) + + channels = int(format_info.get("channels", NUMBER_OF_CHANNELS)) + sample_width = int( + format_info.get("sample_width_bytes", BYTES_PER_SAMPLE) + ) + if channels != NUMBER_OF_CHANNELS or sample_width != BYTES_PER_SAMPLE: + raise ValueError( + "SiliconFlow WAV stream must be mono 16-bit PCM compatible, " + f"got channels={channels}, sample_width={sample_width}" + ) + + async for chunk in stream_parser: + if self._is_cancelled: + self.ten_env.log_debug( + "Cancellation flag detected, stopping SiliconFlow WAV stream " + f"of request_id: {request_id}." + ) + yield None + return + + if len(chunk) > 0: + yield chunk + + async def _iter_mpeg_stream( + self, request_id: str + ) -> AsyncIterator[bytes | None]: + audio_bytes = bytearray() + async for chunk in self._iter_prefetched_stream(): + if self._is_cancelled: + self.ten_env.log_debug( + "Cancellation flag detected before MPEG decode, stopping " + f"SiliconFlow stream of request_id: {request_id}." + ) + yield None + return + + audio_bytes.extend(chunk) + + try: + decoded = miniaudio.decode( + bytes(audio_bytes), + output_format=miniaudio.SampleFormat.SIGNED16, + nchannels=NUMBER_OF_CHANNELS, + sample_rate=self.config.sample_rate, + ) + except Exception as exc: + raise ValueError(f"Failed to decode MPEG audio: {exc}") from exc + + pcm_bytes = decoded.samples.tobytes() + for offset in range(0, len(pcm_bytes), PCM_CHUNK_SIZE): + if self._is_cancelled: + self.ten_env.log_debug( + "Cancellation flag detected during MPEG playback, stopping " + f"SiliconFlow stream of request_id: {request_id}." + ) + yield None + return + + chunk = pcm_bytes[offset : offset + PCM_CHUNK_SIZE] + if chunk: + yield chunk + + async def _iter_pcm_stream( + self, request_id: str + ) -> AsyncIterator[bytes | None]: + cache_audio_bytes = bytearray() + async for chunk in self._iter_prefetched_stream(): + if self._is_cancelled: + self.ten_env.log_debug( + "Cancellation flag detected, stopping SiliconFlow PCM stream " + f"of request_id: {request_id}." + ) + yield None + return + + if len(cache_audio_bytes) > 0: + chunk = bytes(cache_audio_bytes) + chunk + cache_audio_bytes = bytearray() + + left_size = len(chunk) % (BYTES_PER_SAMPLE * NUMBER_OF_CHANNELS) + if left_size > 0: + cache_audio_bytes = bytearray(chunk[-left_size:]) + chunk = chunk[:-left_size] + + if len(chunk) > 0: + yield chunk + + async def _iter_prefetched_stream(self) -> AsyncIterator[bytes]: + if self._prefetched_chunk: + yield self._prefetched_chunk + self._prefetched_chunk = b"" + + async for chunk in self._stream_iterator: + yield chunk + + async def clean(self) -> None: + self.ten_env.log_debug("SiliconFlowTTS: clean() called.") + await self.client.aclose() + + def get_extra_metadata(self) -> dict[str, Any]: + return { + "model": self.config.params.get("model", ""), + "voice": self.config.params.get("voice", ""), + } diff --git a/ai_agents/agents/ten_packages/extension/siliconflow_tts2_python/wav_stream_parser.py b/ai_agents/agents/ten_packages/extension/siliconflow_tts2_python/wav_stream_parser.py new file mode 100644 index 0000000000..470a24cf18 --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/siliconflow_tts2_python/wav_stream_parser.py @@ -0,0 +1,76 @@ +# +# This file is part of TEN Framework, an open source project. +# Licensed under the Apache License, Version 2.0. +# See the LICENSE file for more information. +# +""" +/* [INPUT]: 依赖标准库 io/wave,依赖上游 HTTP 字节异步迭代器 + * [OUTPUT]: 对外提供 WavStreamParser,用于把流式 WAV 头剥离为 PCM 数据流 + * [POS]: siliconflow_tts2_python 的格式桥接层,避免在主客户端里塞额外分支 + * [PROTOCOL]: 变更时更新此头部,然后检查 AGENT.md + */ +""" + +import io +import wave +from typing import Any, AsyncGenerator, AsyncIterator + + +class WavStreamParser: + def __init__( + self, + aiter_bytes: AsyncGenerator[bytes, None] | AsyncIterator[bytes], + initial_buffer_size: int = 4096, + ) -> None: + self._stream_iterator = aiter_bytes + self._initial_buffer_size = initial_buffer_size + self._format_info: dict[str, Any] = {} + self._header_parsed = False + self._first_pcm_chunk: bytes | None = None + + async def _parse_header(self) -> None: + if self._header_parsed: + return + + header_buffer = bytearray() + async for chunk in self._stream_iterator: + header_buffer.extend(chunk) + if len(header_buffer) >= self._initial_buffer_size: + break + + with io.BytesIO(header_buffer) as in_memory_file: + try: + with wave.open(in_memory_file, "rb") as wav_reader: + self._format_info = { + "channels": wav_reader.getnchannels(), + "sample_width_bytes": wav_reader.getsampwidth(), + "framerate": wav_reader.getframerate(), + } + except wave.Error as exc: + raise ValueError( + f"Failed to parse WAV header: {exc}" + ) from exc + + data_chunk_start = header_buffer.find(b"data") + if data_chunk_start == -1: + raise ValueError("The 'data' chunk was not found in the stream") + + pcm_start_offset = data_chunk_start + 8 + self._first_pcm_chunk = bytes(header_buffer[pcm_start_offset:]) + self._header_parsed = True + + async def get_format_info(self) -> dict[str, Any]: + if not self._header_parsed: + await self._parse_header() + return self._format_info + + async def __aiter__(self) -> AsyncGenerator[bytes, None]: + if not self._header_parsed: + await self._parse_header() + + if self._first_pcm_chunk: + yield self._first_pcm_chunk + + async for chunk in self._stream_iterator: + yield chunk +