diff --git a/ai_agents/agents/ten_packages/extension/generic_video_python/README.md b/ai_agents/agents/ten_packages/extension/generic_video_python/README.md index 7d7c80f1d6..308e2ef482 100644 --- a/ai_agents/agents/ten_packages/extension/generic_video_python/README.md +++ b/ai_agents/agents/ten_packages/extension/generic_video_python/README.md @@ -1,29 +1,66 @@ # generic_video_python - +TEN avatar/video extension for ConvoAI-compatible providers using REST session +setup plus a WebSocket audio stream. ## Features - - -- xxx feature +- starts and stops remote avatar sessions using the current ConvoAI contract +- sends `init`, `voice`, `voice_end`, `voice_interrupt`, and `heartbeat` +- forwards the actual incoming TEN audio frame sample rate without resampling +- supports dynamic channel injection via the canonical `channel` property +- masks API keys in config logging ## API -Refer to `api` definition in [manifest.json] and default values in [property.json](property.json). +Refer to the `api.property.properties` schema in +[manifest.json](manifest.json) and defaults in [property.json](property.json). 
- +Canonical properties: +- `channel` +- `agora_avatar_uid` +- `generic_video_api_key` +- `avatar_id` +- `quality` +- `version` +- `video_encoding` +- `enable_string_uid` +- `activity_idle_timeout` +- `area` +- `start_endpoint` +- `stop_endpoint` +- `input_audio_sample_rate` +- `params` -## Development +Backward-compatible aliases still accepted by the loader: +- `agora_channel_name` -> `channel` +- `agora_video_uid` -> `agora_avatar_uid` + +Vendor passthrough: +- `params` can contain vendor-specific top-level fields +- known keys are normalized onto the named config fields first +- unknown keys are forwarded as top-level keys in both the session start body + and the WebSocket `init` message +- `api_key` inside `params` is accepted as an alias for + `generic_video_api_key` and is not forwarded downstream -### Build +## Protocol Notes - +- REST start payload includes `area` +- WebSocket `init` payload includes `area` +- stop requests send both `session_id` and `session_token` in the DELETE body +- WebSocket auth uses `Authorization: Bearer {session_token}` + +## Development ### Unit test - +```bash +tests/bin/start +``` ## Misc - +This package is validated against the checked-out `convoai_to_video` +reference contract, but the tests use local fixtures and mocks rather than +importing that repo directly. diff --git a/ai_agents/agents/ten_packages/extension/generic_video_python/__init__.py b/ai_agents/agents/ten_packages/extension/generic_video_python/__init__.py index 72593ab225..da402faf43 100644 --- a/ai_agents/agents/ten_packages/extension/generic_video_python/__init__.py +++ b/ai_agents/agents/ten_packages/extension/generic_video_python/__init__.py @@ -3,4 +3,3 @@ # Licensed under the Apache License, Version 2.0. # See the LICENSE file for more information. # -from . 
import addon diff --git a/ai_agents/agents/ten_packages/extension/generic_video_python/config.py b/ai_agents/agents/ten_packages/extension/generic_video_python/config.py new file mode 100644 index 0000000000..0b7010acf0 --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/generic_video_python/config.py @@ -0,0 +1,190 @@ +# +# This file is part of TEN Framework, an open source project. +# Licensed under the Apache License, Version 2.0. +# See the LICENSE file for more information. +# +from __future__ import annotations + +import copy +from typing import Any + +from pydantic import AliasChoices, BaseModel, ConfigDict, Field, field_validator +from ten_ai_base import utils +from ten_runtime import AsyncTenEnv + + +VALID_QUALITIES = {"low", "medium", "high"} +VALID_VIDEO_ENCODINGS = {"H264", "VP8", "AV1"} +VALID_AREAS = { + "GLOBAL", + "NORTH_AMERICA", + "EUROPE", + "ASIA", + "INDIA", + "JAPAN", +} + +_PARAM_FIELD_ALIASES: dict[str, str] = { + "api_key": "generic_video_api_key", + "generic_video_api_key": "generic_video_api_key", + "agora_appid": "agora_appid", + "agora_appcert": "agora_appcert", + "channel": "channel", + "agora_channel_name": "channel", + "agora_avatar_uid": "agora_avatar_uid", + "agora_video_uid": "agora_avatar_uid", + "avatar_id": "avatar_id", + "quality": "quality", + "version": "version", + "video_encoding": "video_encoding", + "enable_string_uid": "enable_string_uid", + "activity_idle_timeout": "activity_idle_timeout", + "area": "area", + "start_endpoint": "start_endpoint", + "stop_endpoint": "stop_endpoint", + "input_audio_sample_rate": "input_audio_sample_rate", +} + +_PASSTHROUGH_EXCLUDED_KEYS = set(_PARAM_FIELD_ALIASES) + + +class GenericVideoConfig(BaseModel): + model_config = ConfigDict(populate_by_name=True, extra="ignore") + + agora_appid: str = "" + agora_appcert: str = "" + channel: str = Field( + default="", + validation_alias=AliasChoices("channel", "agora_channel_name"), + ) + agora_avatar_uid: int = Field( + default=0, + 
validation_alias=AliasChoices( + "agora_avatar_uid", + "agora_video_uid", + ), + ) + generic_video_api_key: str = "" + avatar_id: str = "16cb73e7de08" + quality: str = "high" + version: str = "v1" + video_encoding: str = "H264" + enable_string_uid: bool = False + activity_idle_timeout: int = 120 + area: str = "GLOBAL" + start_endpoint: str = "https://api.example.com/v1/sessions/start" + stop_endpoint: str = "https://api.example.com/v1/sessions/stop" + input_audio_sample_rate: int = 48000 + params: dict[str, Any] = Field(default_factory=dict) + + @classmethod + async def create_async( + cls, ten_env: AsyncTenEnv + ) -> "GenericVideoConfig": + config_json, _ = await ten_env.get_property_to_json("") + config = cls.model_validate_json(config_json or "{}") + config.normalize_params() + config.validate_required() + return config + + def normalize_params(self) -> None: + params = self._ensure_dict(self.params) + normalized_params: dict[str, Any] = {} + + for key, value in params.items(): + if value is None: + continue + field_name = _PARAM_FIELD_ALIASES.get(key) + if field_name: + setattr(self, field_name, value) + continue + normalized_params[key] = value + + self.params = normalized_params + + def validate_required(self) -> None: + required_fields = { + "agora_appid": self.agora_appid, + "channel": self.channel, + "generic_video_api_key": self.generic_video_api_key, + "avatar_id": self.avatar_id, + "start_endpoint": self.start_endpoint, + "stop_endpoint": self.stop_endpoint, + } + + for field_name, value in required_fields.items(): + if not value or (isinstance(value, str) and value.strip() == ""): + raise ValueError( + f"Required field is missing or empty: {field_name}" + ) + + def to_str(self, sensitive_handling: bool = True) -> str: + if not sensitive_handling: + return f"{self}" + + config = copy.deepcopy(self) + if config.generic_video_api_key: + config.generic_video_api_key = utils.encrypt( + config.generic_video_api_key + ) + if "api_key" in config.params and 
config.params["api_key"]: + config.params["api_key"] = utils.encrypt(config.params["api_key"]) + return f"{config}" + + @property + def vendor_params(self) -> dict[str, Any]: + return { + key: value + for key, value in self.params.items() + if key not in _PASSTHROUGH_EXCLUDED_KEYS and value is not None + } + + @field_validator("quality") + @classmethod + def validate_quality(cls, value: str) -> str: + if value not in VALID_QUALITIES: + raise ValueError( + f"quality must be one of: {', '.join(sorted(VALID_QUALITIES))}" + ) + return value + + @field_validator("video_encoding") + @classmethod + def validate_video_encoding(cls, value: str) -> str: + if value not in VALID_VIDEO_ENCODINGS: + raise ValueError( + "video_encoding must be one of: " + f"{', '.join(sorted(VALID_VIDEO_ENCODINGS))}" + ) + return value + + @field_validator("area") + @classmethod + def validate_area(cls, value: str) -> str: + if value not in VALID_AREAS: + raise ValueError( + f"area must be one of: {', '.join(sorted(VALID_AREAS))}" + ) + return value + + @field_validator("activity_idle_timeout") + @classmethod + def validate_activity_idle_timeout(cls, value: int) -> int: + if value < 0: + raise ValueError("activity_idle_timeout must be >= 0") + return value + + @field_validator("input_audio_sample_rate") + @classmethod + def validate_input_audio_sample_rate(cls, value: int) -> int: + if value <= 0: + raise ValueError("input_audio_sample_rate must be > 0") + return value + + @staticmethod + def _ensure_dict(value: Any) -> dict[str, Any]: + if isinstance(value, dict): + return dict(value) + if value is None: + return {} + return dict(value) diff --git a/ai_agents/agents/ten_packages/extension/generic_video_python/extension.py b/ai_agents/agents/ten_packages/extension/generic_video_python/extension.py index 97e5cfd3bd..3c20446293 100644 --- a/ai_agents/agents/ten_packages/extension/generic_video_python/extension.py +++ b/ai_agents/agents/ten_packages/extension/generic_video_python/extension.py @@ 
-17,52 +17,17 @@ CmdResult, Data, ) -from ten_ai_base.config import BaseConfig from .generic import AgoraGenericRecorder -from dataclasses import dataclass +from .config import GenericVideoConfig # Error codes ERROR_CODE_CONFIG_VALIDATION_ERROR = -1012 -@dataclass -class GenericVideoConfig(BaseConfig): - agora_appid: str = "" - agora_appcert: str = "" - agora_channel_name: str = "" - agora_video_uid: int = 0 - generic_video_api_key: str = "" - avatar_id: str = "16cb73e7de08" - quality: str = "high" - version: str = "v1" - video_encoding: str = "H264" - enable_string_uid: bool = False - activity_idle_timeout: int = 60 - start_endpoint: str = "https://api.example.com/v1/sessions/start" - stop_endpoint: str = "https://api.example.com/v1/sessions/stop" - input_audio_sample_rate: int = 48000 - - def validate_params(self) -> None: - """Validate required configuration parameters.""" - required_fields = { - "agora_appid": self.agora_appid, - "generic_video_api_key": self.generic_video_api_key, - "avatar_id": self.avatar_id, - "start_endpoint": self.start_endpoint, - "stop_endpoint": self.stop_endpoint, - } - - for field_name, value in required_fields.items(): - if not value or (isinstance(value, str) and value.strip() == ""): - raise ValueError( - f"Required field is missing or empty: {field_name}" - ) - - class GenericVideoExtension(AsyncExtension): def __init__(self, name: str): super().__init__(name) - self.config = None + self.config: GenericVideoConfig | None = None self.input_audio_queue = asyncio.Queue() self.recorder: AgoraGenericRecorder = None self.ten_env: AsyncTenEnv = None @@ -71,6 +36,7 @@ def __init__(self, name: str): self._audio_task = None self._config_valid = False # Track configuration validation status self._connection_task = None + self._sample_rate_mismatch_warned = False async def on_init(self, ten_env: AsyncTenEnv) -> None: ten_env.log_debug("on_init") @@ -81,38 +47,38 @@ async def on_start(self, ten_env: AsyncTenEnv) -> None: try: self.config = 
await GenericVideoConfig.create_async(ten_env) - - # Validate configuration - self.config.validate_params() self._config_valid = True - # Log configuration summary ten_env.log_info( - f"[GENERIC-VIDEO] Config: avatar={self.config.avatar_id} " - f"quality={self.config.quality} sample_rate={self.config.input_audio_sample_rate}" + "[GENERIC-VIDEO] Config: " + f"{self.config.to_str(sensitive_handling=True)}" ) recorder = AgoraGenericRecorder( api_key=self.config.generic_video_api_key, app_id=self.config.agora_appid, app_cert=self.config.agora_appcert, - channel_name=self.config.agora_channel_name, - avatar_uid=self.config.agora_video_uid, + channel_name=self.config.channel, + avatar_uid=self.config.agora_avatar_uid, ten_env=ten_env, avatar_id=self.config.avatar_id, activity_idle_timeout=self.config.activity_idle_timeout, quality=self.config.quality, version=self.config.version, video_encoding=self.config.video_encoding, + area=self.config.area, enable_string_uid=self.config.enable_string_uid, start_endpoint=self.config.start_endpoint, stop_endpoint=self.config.stop_endpoint, + vendor_params=self.config.vendor_params, ) self.recorder = recorder self._audio_processing_enabled = True - asyncio.create_task(self._loop_input_audio_sender(ten_env)) + self._audio_task = asyncio.create_task( + self._loop_input_audio_sender(ten_env) + ) await self.recorder.connect() @@ -126,26 +92,36 @@ async def on_start(self, ten_env: AsyncTenEnv) -> None: async def _loop_input_audio_sender(self, _: AsyncTenEnv): while self._audio_processing_enabled: - audio_frame = await self.input_audio_queue.get() + audio_frame, actual_sample_rate = await self.input_audio_queue.get() # Wait for recorder to be ready await self._wait_for_recorder_ready() if self.recorder is not None and self.recorder.ws_connected(): try: - original_rate = self.config.input_audio_sample_rate + expected_rate = self.config.input_audio_sample_rate if len(audio_frame) == 0: continue - # Send audio at original sample rate - let 
server handle any resampling + if ( + expected_rate != actual_sample_rate + and not self._sample_rate_mismatch_warned + ): + self._sample_rate_mismatch_warned = True + self.ten_env.log_warn( + "Audio frame sample rate does not match configured " + f"input_audio_sample_rate: actual={actual_sample_rate}, " + f"configured={expected_rate}. " + "Sending the actual frame rate without resampling." + ) + base64_audio_data = base64.b64encode(audio_frame).decode( "utf-8" ) - # Update the recorder to send with actual sample rate await self.recorder.send( - base64_audio_data, sample_rate=original_rate + base64_audio_data, sample_rate=actual_sample_rate ) except Exception as e: @@ -297,7 +273,9 @@ async def on_audio_frame( self.ten_env.log_warn("on_audio_frame: empty pcm_frame detected.") return - self.input_audio_queue.put_nowait(frame_buf) + self.input_audio_queue.put_nowait( + (frame_buf, audio_frame_sample_rate) + ) async def on_video_frame( self, ten_env: AsyncTenEnv, video_frame: VideoFrame diff --git a/ai_agents/agents/ten_packages/extension/generic_video_python/generic.py b/ai_agents/agents/ten_packages/extension/generic_video_python/generic.py index c52fea65a2..e7a90ceb27 100644 --- a/ai_agents/agents/ten_packages/extension/generic_video_python/generic.py +++ b/ai_agents/agents/ten_packages/extension/generic_video_python/generic.py @@ -1,10 +1,14 @@ +import asyncio +import hashlib import json import os +import random +import tempfile import uuid -import asyncio -import requests +from typing import Any + +import httpx import websockets -import random from time import time from agora_token_builder import RtcTokenBuilder @@ -31,8 +35,6 @@ class AgoraGenericRecorder: - SESSION_CACHE_PATH = "/tmp/generic_session_id.txt" - def __init__( self, app_id: str, @@ -45,13 +47,17 @@ def __init__( quality: str, version: str, video_encoding: str, + area: str, enable_string_uid: bool, start_endpoint: str, stop_endpoint: str, activity_idle_timeout: int, + vendor_params: dict[str, Any] | 
None = None, + http_client: httpx.AsyncClient | None = None, + session_cache_path: str | None = None, ): # Validate required fields - self._validate_config(app_id, api_key, avatar_id) + self._validate_config(app_id, api_key, avatar_id, channel_name) self.app_id = app_id self.app_cert = app_cert @@ -63,36 +69,49 @@ def __init__( self.quality = quality self.version = version self.video_encoding = video_encoding + self.area = area self.enable_string_uid = enable_string_uid self.start_endpoint = start_endpoint self.stop_endpoint = stop_endpoint self.activity_idle_timeout = activity_idle_timeout + self.vendor_params = dict(vendor_params or {}) self.token_server = self._generate_token(self.uid_avatar, 1) + self.session_cache_path = session_cache_path or self._build_cache_path() self.headers = { "accept": "application/json", "content-type": "application/json", "x-api-key": self.api_key, } + self.http_client = http_client or httpx.AsyncClient(timeout=30.0) + self._owns_http_client = http_client is None self.session_id = None self.session_token = None self.realtime_endpoint = None self.websocket = None self.websocket_task = None self.heartbeat_task = None + self.listener_task = None self._should_reconnect = True self._connection_broken = False # Flag to trigger reconnection self._speak_end_timer_task: asyncio.Task | None = None self._speak_end_event = asyncio.Event() - def _validate_config(self, app_id: str, api_key: str, avatar_id: str): + def _validate_config( + self, + app_id: str, + api_key: str, + avatar_id: str, + channel_name: str, + ): """Validate required configuration parameters.""" required_fields = { "app_id": app_id, "api_key": api_key, "avatar_id": avatar_id, + "channel_name": channel_name, } for field_name, value in required_fields.items(): @@ -117,21 +136,124 @@ def _generate_token(self, uid, role): privilege_expired_ts, ) - def _load_cached_session_id(self): - if os.path.exists(self.SESSION_CACHE_PATH): - with open(self.SESSION_CACHE_PATH, "r", 
encoding="utf-8") as f: - return f.read().strip() + def _build_cache_path(self) -> str: + fingerprint = hashlib.sha1( + "|".join( + [ + self.start_endpoint, + self.stop_endpoint, + self.channel_name, + str(self.uid_avatar), + self.avatar_id, + ] + ).encode("utf-8") + ).hexdigest()[:12] + return os.path.join( + tempfile.gettempdir(), + f"generic_video_session_{fingerprint}.json", + ) + + def _load_cached_session(self) -> dict[str, str] | None: + if os.path.exists(self.session_cache_path): + with open(self.session_cache_path, "r", encoding="utf-8") as f: + raw = f.read().strip() + if not raw: + return None + try: + data = json.loads(raw) + if isinstance(data, dict): + return { + "session_id": str(data.get("session_id", "")), + "session_token": str(data.get("session_token", "")), + } + except json.JSONDecodeError: + return {"session_id": raw, "session_token": ""} return None - def _save_session_id(self, session_id: str): - with open(self.SESSION_CACHE_PATH, "w", encoding="utf-8") as f: - f.write(session_id) + def _save_session(self, session_id: str, session_token: str): + with open(self.session_cache_path, "w", encoding="utf-8") as f: + json.dump( + { + "session_id": session_id, + "session_token": session_token, + }, + f, + ) - def _clear_session_id_cache(self): - if os.path.exists(self.SESSION_CACHE_PATH): - os.remove(self.SESSION_CACHE_PATH) + def _clear_session_cache(self): + if os.path.exists(self.session_cache_path): + os.remove(self.session_cache_path) + + def _masked_headers( + self, headers: dict[str, str] | None = None + ) -> dict[str, str]: + safe_headers = dict(headers or self.headers) + if safe_headers.get("x-api-key"): + safe_headers["x-api-key"] = "***masked***" + if safe_headers.get("authorization"): + safe_headers["authorization"] = "Bearer ***masked***" + return safe_headers + + def _build_start_payload(self) -> dict[str, Any]: + payload = { + "avatar_id": self.avatar_id, + "quality": self.quality, + "version": self.version, + "video_encoding": 
self.video_encoding, + "activity_idle_timeout": self.activity_idle_timeout, + "area": self.area, + "agora_settings": { + "app_id": self.app_id, + "token": self.token_server, + "channel": self.channel_name, + "uid": str(self.uid_avatar), + "enable_string_uid": self.enable_string_uid, + }, + } + payload.update(self.vendor_params) + return payload - def get_connection_status(self) -> dict[str, any]: + def _build_init_payload(self) -> dict[str, Any]: + payload = { + "command": "init", + "session_id": self.session_id, + "avatar_id": self.avatar_id, + "quality": self.quality, + "version": self.version, + "video_encoding": self.video_encoding, + "activity_idle_timeout": self.activity_idle_timeout, + "area": self.area, + "agora_settings": { + "app_id": self.app_id, + "token": self.token_server, + "channel": self.channel_name, + "uid": str(self.uid_avatar), + "enable_string_uid": self.enable_string_uid, + }, + } + payload.update(self.vendor_params) + return payload + + def _build_stop_payload( + self, + session_id: str, + session_token: str | None = None, + ) -> dict[str, str]: + token = session_token or self.session_token or "" + if not token: + raise ValueError("session_token is required to stop a session") + return { + "session_id": session_id, + "session_token": token, + } + + def _websocket_headers(self) -> dict[str, str]: + headers: dict[str, str] = {} + if self.session_token: + headers["authorization"] = f"Bearer {self.session_token}" + return headers + + def get_connection_status(self) -> dict[str, Any]: """Get current connection status and information.""" return { "connected": self.websocket is not None, @@ -142,21 +264,32 @@ def get_connection_status(self) -> dict[str, any]: async def connect(self): # Check and stop old session if needed - old_session_id = self._load_cached_session_id() - if old_session_id: + old_session = self._load_cached_session() + if old_session and old_session.get("session_id"): try: - self.ten_env.log_info( - f"Found previous session id: 
{old_session_id}, attempting to stop it." - ) - await self._stop_session(old_session_id) - self.ten_env.log_info("Previous session stopped.") - self._clear_session_id_cache() + old_session_token = old_session.get("session_token", "") + if old_session_token: + self.ten_env.log_info( + "Found previous cached session, attempting to stop it." + ) + await self._stop_session( + old_session["session_id"], + session_token=old_session_token, + ) + self.ten_env.log_info("Previous session stopped.") + else: + self.ten_env.log_warn( + "Found legacy cached session without session_token. " + "Clearing cache because stop endpoint now requires both " + "session_id and session_token." + ) + self._clear_session_cache() except Exception as e: self.ten_env.log_error(f"Failed to stop old session: {e}") try: await self._create_session() - self._save_session_id(self.session_id) + self._save_session(self.session_id, self.session_token) # Start WebSocket connection self.websocket_task = asyncio.create_task( @@ -203,6 +336,17 @@ async def disconnect(self): f"Error while cancelling WebSocket task: {e}" ) + if self.listener_task: + self.listener_task.cancel() + try: + await self.listener_task + except asyncio.CancelledError: + pass + except Exception as e: + self.ten_env.log_error( + f"Error while cancelling listener task: {e}" + ) + # Stop session if self.session_id: try: @@ -213,32 +357,24 @@ async def disconnect(self): code=ERROR_CODE_FAILED_TO_STOP_SESSION, ) + if self._owns_http_client: + await self.http_client.aclose() + self.ten_env.log_info("Disconnection completed") async def _create_session(self): - payload = { - "avatar_id": self.avatar_id, - "quality": self.quality, - "version": self.version, - "video_encoding": self.video_encoding, - "activity_idle_timeout": self.activity_idle_timeout, - "agora_settings": { - "app_id": self.app_id, - "token": self.token_server, - "channel": self.channel_name, - "uid": str(self.uid_avatar), - "enable_string_uid": self.enable_string_uid, - }, - } + 
payload = self._build_start_payload() # Log the request details using existing logging mechanism self.ten_env.log_info("Creating new session with details:") self.ten_env.log_info(f"URL: {self.start_endpoint}") - self.ten_env.log_info(f"Headers: {json.dumps(self.headers, indent=2)}") + self.ten_env.log_info( + f"Headers: {json.dumps(self._masked_headers(), indent=2)}" + ) self.ten_env.log_info(f"Payload: {json.dumps(payload, indent=2)}") - response = requests.post( - self.start_endpoint, json=payload, headers=self.headers, timeout=30 + response = await self.http_client.post( + self.start_endpoint, json=payload, headers=self.headers ) await self._raise_for_status_verbose(response) data = response.json() @@ -258,7 +394,7 @@ async def _create_session(self): async def _raise_for_status_verbose(self, response): try: response.raise_for_status() - except requests.HTTPError as e: + except httpx.HTTPStatusError as e: # Try to parse JSON error response error_details = f"HTTP {response.status_code} Error: {e}" try: @@ -282,30 +418,32 @@ async def _raise_for_status_verbose(self, response): ) raise - async def _stop_session(self, session_id: str): + async def _stop_session( + self, + session_id: str, + session_token: str | None = None, + ): try: - # Payload contains only session_id - payload = {"session_id": session_id} - - # Add session token to headers for authentication - headers = self.headers.copy() - if self.session_token: - headers["authorization"] = f"Bearer {self.session_token}" + payload = self._build_stop_payload( + session_id, + session_token=session_token, + ) self.ten_env.log_info("_stop_session with details:") self.ten_env.log_info(f"URL: {self.stop_endpoint}") self.ten_env.log_info( - f"Headers: {json.dumps({k: v for k, v in headers.items() if k != 'authorization'}, indent=2)}" + f"Headers: {json.dumps(self._masked_headers(), indent=2)}" ) - self.ten_env.log_info("Authorization: Bearer ***masked***") self.ten_env.log_info(f"Payload: {json.dumps(payload, 
indent=2)}") - # Use DELETE method as specified in API documentation - response = requests.delete( - self.stop_endpoint, json=payload, headers=headers, timeout=30 + response = await self.http_client.request( + "DELETE", + self.stop_endpoint, + json=payload, + headers=self.headers, ) await self._raise_for_status_verbose(response) - self._clear_session_id_cache() + self._clear_session_cache() except Exception as e: self.ten_env.log_error(f"Failed to stop session: {e}") raise @@ -345,12 +483,8 @@ async def _connect_websocket_loop(self): f"Connecting to WebSocket at {self.realtime_endpoint} (attempt {attempt + 1})" ) - # Prepare WebSocket headers with session token (same as test scripts) - headers = {} - if self.session_token: - headers["authorization"] = f"Bearer {self.session_token}" + headers = self._websocket_headers() - # Use additional_headers for WebSocket authentication (websockets 10.4) async with websockets.connect( self.realtime_endpoint, additional_headers=headers ) as websocket: @@ -363,29 +497,15 @@ async def _connect_websocket_loop(self): "WebSocket connected successfully with headers" ) - # Send initial configuration payload with init command - initial_payload = { - "command": "init", - "session_id": self.session_id, - "avatar_id": self.avatar_id, - "quality": self.quality, - "version": self.version, - "video_encoding": self.video_encoding, - "activity_idle_timeout": self.activity_idle_timeout, - "agora_settings": { - "app_id": self.app_id, - "token": self.token_server, - "channel": self.channel_name, - "uid": str(self.uid_avatar), - "enable_string_uid": self.enable_string_uid, - }, - } + initial_payload = self._build_init_payload() await self.websocket.send(json.dumps(initial_payload)) self.ten_env.log_info("Sent initial configuration payload") # Start listening for messages - asyncio.create_task(self._listen_for_messages()) + self.listener_task = asyncio.create_task( + self._listen_for_messages() + ) # Wait for connection to be broken or cancelled 
while ( @@ -400,6 +520,13 @@ async def _connect_websocket_loop(self): attempt += 1 await self._handle_connection_error(e, attempt) finally: + if self.listener_task and not self.listener_task.done(): + self.listener_task.cancel() + try: + await self.listener_task + except asyncio.CancelledError: + pass + self.listener_task = None self.websocket = None async def _handle_connection_error( @@ -435,11 +562,11 @@ async def _handle_connection_error( ) try: # Clear old session cache - self._clear_session_id_cache() + self._clear_session_cache() # Create a new session await self._create_session() - self._save_session_id(self.session_id) + self._save_session(self.session_id, self.session_token) self.ten_env.log_info(f"New session created: {self.session_id}") # Continue with normal delay logic instead of immediate retry diff --git a/ai_agents/agents/ten_packages/extension/generic_video_python/manifest.json b/ai_agents/agents/ten_packages/extension/generic_video_python/manifest.json index 3c92436b42..3aec5780d9 100644 --- a/ai_agents/agents/ten_packages/extension/generic_video_python/manifest.json +++ b/ai_agents/agents/ten_packages/extension/generic_video_python/manifest.json @@ -30,14 +30,45 @@ "agora_appcert": { "type": "string" }, - "agora_channel_name": { + "channel": { "type": "string" }, "agora_avatar_uid": { "type": "int64" }, + "avatar_id": { + "type": "string" + }, + "quality": { + "type": "string" + }, + "version": { + "type": "string" + }, + "video_encoding": { + "type": "string" + }, + "enable_string_uid": { + "type": "bool" + }, + "activity_idle_timeout": { + "type": "int64" + }, + "area": { + "type": "string" + }, + "start_endpoint": { + "type": "string" + }, + "stop_endpoint": { + "type": "string" + }, "input_audio_sample_rate": { "type": "int64" + }, + "params": { + "type": "object", + "properties": {} } } } @@ -45,4 +76,4 @@ "scripts": { "test": "tests/bin/start" } -} \ No newline at end of file +} diff --git 
a/ai_agents/agents/ten_packages/extension/generic_video_python/property.json b/ai_agents/agents/ten_packages/extension/generic_video_python/property.json index f3a65c64eb..825a3c13cb 100644 --- a/ai_agents/agents/ten_packages/extension/generic_video_python/property.json +++ b/ai_agents/agents/ten_packages/extension/generic_video_python/property.json @@ -1,8 +1,18 @@ { "agora_appid": "${env:AGORA_APP_ID}", "agora_appcert": "${env:AGORA_APP_CERTIFICATE}", - "agora_channel_name": "ten_agent_test", + "channel": "ten_agent_test", "agora_avatar_uid": 12345, "generic_video_api_key": "${env:GENERIC_VIDEO_API_KEY}", - "input_audio_sample_rate": 16000 -} \ No newline at end of file + "avatar_id": "16cb73e7de08", + "quality": "high", + "version": "v1", + "video_encoding": "H264", + "enable_string_uid": false, + "activity_idle_timeout": 120, + "area": "GLOBAL", + "start_endpoint": "https://api.example.com/v1/sessions/start", + "stop_endpoint": "https://api.example.com/v1/sessions/stop", + "input_audio_sample_rate": 16000, + "params": {} +} diff --git a/ai_agents/agents/ten_packages/extension/generic_video_python/requirements.txt b/ai_agents/agents/ten_packages/extension/generic_video_python/requirements.txt index 611085e982..35d36446ef 100644 --- a/ai_agents/agents/ten_packages/extension/generic_video_python/requirements.txt +++ b/ai_agents/agents/ten_packages/extension/generic_video_python/requirements.txt @@ -1,5 +1,4 @@ agora-token-builder>=1.0.0 -requests>=2.32.3 +httpx>=0.27.0 +pydantic>=2.0.0 websockets>=15.0.1 -scipy -numpy \ No newline at end of file diff --git a/ai_agents/agents/ten_packages/extension/generic_video_python/tests/__init__.py b/ai_agents/agents/ten_packages/extension/generic_video_python/tests/__init__.py new file mode 100644 index 0000000000..da402faf43 --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/generic_video_python/tests/__init__.py @@ -0,0 +1,5 @@ +# +# This file is part of TEN Framework, an open source project. 
+# Licensed under the Apache License, Version 2.0. +# See the LICENSE file for more information. +# diff --git a/ai_agents/agents/ten_packages/extension/generic_video_python/tests/bin/start b/ai_agents/agents/ten_packages/extension/generic_video_python/tests/bin/start new file mode 100755 index 0000000000..1407caea4d --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/generic_video_python/tests/bin/start @@ -0,0 +1,9 @@ +#!/bin/bash + +set -e + +cd "$(dirname "${BASH_SOURCE[0]}")/../.." + +export PYTHONPATH=.ten/app:.ten/app/ten_packages/system/ten_runtime_python/lib:.ten/app/ten_packages/system/ten_runtime_python/interface:.ten/app/ten_packages/system/ten_ai_base/interface:$PYTHONPATH + +python3 -m pytest tests/ "$@" diff --git a/ai_agents/agents/ten_packages/extension/generic_video_python/tests/conftest.py b/ai_agents/agents/ten_packages/extension/generic_video_python/tests/conftest.py new file mode 100644 index 0000000000..645c4f3f3d --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/generic_video_python/tests/conftest.py @@ -0,0 +1,127 @@ +# +# This file is part of TEN Framework, an open source project. +# Licensed under the Apache License, Version 2.0. +# See the LICENSE file for more information. 
#
"""Pytest bootstrap that installs in-memory stand-ins for TEN runtime modules.

Importing this module registers minimal ``ten_runtime``, ``ten_ai_base`` and
``ten_ai_base.utils`` modules in ``sys.modules`` (unless real ones are already
loaded) so the extension package can be imported by the unit tests without
the TEN runtime being installed.
"""
from __future__ import annotations

import json
import sys
import types


def _install_ten_runtime_stub() -> None:
    """Register a stub ``ten_runtime`` module unless one is already loaded."""
    if "ten_runtime" in sys.modules:
        return

    module = types.ModuleType("ten_runtime")

    class AsyncExtension:
        def __init__(self, _name: str):
            pass

    class AsyncTenEnv:
        pass

    class AudioFrame:
        pass

    class VideoFrame:
        pass

    class Cmd:
        def __init__(self, name: str):
            self._name = name

        @classmethod
        def create(cls, name: str) -> "Cmd":
            return cls(name)

        def get_name(self) -> str:
            return self._name

    class StatusCode:
        OK = 0

    class CmdResult:
        @classmethod
        def create(cls, _status_code, _cmd):
            return cls()

    class Data:
        def __init__(self, name: str):
            self._name = name
            self._payload = ""

        @classmethod
        def create(cls, name: str) -> "Data":
            return cls(name)

        def get_name(self) -> str:
            return self._name

        def set_property_from_json(self, _path, payload: str) -> None:
            # The path is ignored: the stub keeps a single JSON payload.
            self._payload = payload

        def get_property_to_json(self, _path):
            # Returns a (payload, error) pair; the stub never reports an error.
            return self._payload, None

    class Addon:
        pass

    class TenEnv:
        def on_create_instance_done(self, *_args, **_kwargs):
            pass

    def register_addon_as_extension(_name: str):
        # Identity decorator: the decorated class is returned unchanged.
        def decorator(cls):
            return cls

        return decorator

    module.AsyncExtension = AsyncExtension
    module.AsyncTenEnv = AsyncTenEnv
    module.AudioFrame = AudioFrame
    module.VideoFrame = VideoFrame
    module.Cmd = Cmd
    module.StatusCode = StatusCode
    module.CmdResult = CmdResult
    module.Data = Data
    module.Addon = Addon
    module.TenEnv = TenEnv
    module.register_addon_as_extension = register_addon_as_extension
    sys.modules["ten_runtime"] = module


def _install_ten_ai_base_stub() -> None:
    """Register stub ``ten_ai_base`` / ``ten_ai_base.utils`` modules if absent."""
    if "ten_ai_base" in sys.modules:
        return

    module = types.ModuleType("ten_ai_base")
    utils_module = types.ModuleType("ten_ai_base.utils")

    def encrypt(value: str) -> str:
        """Mask *value* with at most six asterisks; falsy input is returned as-is."""
        if not value:
            return value
        return "*" * min(len(value), 6)

    class ErrorMessage:
        def __init__(self, module: str, message: str, code: int):
            self.module = module
            self.message = message
            self.code = code

        def model_dump_json(self) -> str:
            """Serialise the three fields as a compact JSON object.

            ``json.dumps`` is used instead of ``%`` string formatting so that
            quotes, backslashes and control characters in ``module`` or
            ``message`` are escaped and the result is always valid JSON. For
            plain values the output is identical to the previous format.
            """
            return json.dumps(
                {
                    "module": self.module,
                    "message": self.message,
                    "code": self.code,
                },
                separators=(",", ":"),
                ensure_ascii=False,
            )

    utils_module.encrypt = encrypt
    module.utils = utils_module
    module.ErrorMessage = ErrorMessage
    sys.modules["ten_ai_base"] = module
    sys.modules["ten_ai_base.utils"] = utils_module


_install_ten_runtime_stub()
_install_ten_ai_base_stub()

# --- patch boundary: new file
# ai_agents/agents/ten_packages/extension/generic_video_python/tests/test_config.py
#
# This file is part of TEN Framework, an open source project.
# Licensed under the Apache License, Version 2.0.
# See the LICENSE file for more information.
#
"""Validation tests for ``GenericVideoConfig``: aliases, masking, params."""
import pytest

from generic_video_python.config import GenericVideoConfig

# Fields supplied by every test in this module. The channel is deliberately
# not part of this set because one test provides it only through its legacy
# alias.
_REQUIRED_FIELDS = {
    "agora_appid": "appid",
    "generic_video_api_key": "secret",
    "avatar_id": "avatar",
    "start_endpoint": "https://example.test/start",
    "stop_endpoint": "https://example.test/stop",
}


def _payload(**overrides) -> dict:
    """Return a fresh config payload: the shared fields plus *overrides*."""
    return {**_REQUIRED_FIELDS, **overrides}


def test_config_accepts_legacy_aliases():
    """Legacy ``agora_channel_name`` / ``agora_video_uid`` map to new fields."""
    config = GenericVideoConfig.model_validate(
        _payload(agora_channel_name="room-a", agora_video_uid=42)
    )

    assert config.channel == "room-a"
    assert config.agora_avatar_uid == 42


def test_config_masks_sensitive_fields():
    """The API key must not appear in the sensitive-handling string dump."""
    config = GenericVideoConfig.model_validate(_payload(channel="room-a"))

    config_str = config.to_str(sensitive_handling=True)

    assert "secret" not in config_str
    # Any masked output contains at least one asterisk; the former extra
    # check for "******" was implied by this one.
    assert "*" in config_str


def test_config_normalizes_known_params_and_keeps_vendor_passthrough():
    """Known keys in ``params`` override fields; unknown keys are passed on."""
    config = GenericVideoConfig.model_validate(
        _payload(
            channel="room-a",
            params={
                "api_key": "secret-2",
                "agora_channel_name": "room-b",
                "agora_video_uid": 77,
                "area": "JAPAN",
                "model": "vendor-model-1",
                "style": "cinematic",
            },
        )
    )

    config.normalize_params()

    assert config.generic_video_api_key == "secret-2"
    assert config.channel == "room-b"
    assert config.agora_avatar_uid == 77
    assert config.area == "JAPAN"
    assert config.vendor_params == {
        "model": "vendor-model-1",
        "style": "cinematic",
    }


@pytest.mark.parametrize("quality", ["bad", "HIGH", ""])
def test_config_rejects_invalid_quality(quality: str):
    """Unknown, wrongly-cased and empty ``quality`` values are rejected."""
    # ValueError rather than Exception: pydantic's ValidationError derives
    # from ValueError, and the narrower type keeps an unrelated failure
    # (TypeError, AttributeError, ...) from making this test pass.
    with pytest.raises(ValueError):
        GenericVideoConfig.model_validate(
            _payload(channel="room-a", quality=quality)
        )


def test_config_rejects_invalid_area():
    """An ``area`` outside the supported set is rejected."""
    with pytest.raises(ValueError):
        GenericVideoConfig.model_validate(
            _payload(channel="room-a", area="MARS")
        )

# --- patch boundary: new file
# ai_agents/agents/ten_packages/extension/generic_video_python/tests/test_extension.py
#
# This file is part of TEN Framework, an open source project.
# Licensed under the Apache License, Version 2.0.
# See the LICENSE file for more information.
#
"""Behavioural tests for ``GenericVideoExtension`` using in-memory fakes."""
from __future__ import annotations

import asyncio
import contextlib
import json

from generic_video_python.extension import GenericVideoExtension


class FakeAudioFrame:
    """Audio frame double exposing only the accessors the tests rely on."""

    def __init__(self, payload: bytes, sample_rate: int):
        self._payload = payload
        self._sample_rate = sample_rate

    def get_name(self) -> str:
        return "pcm"

    def get_sample_rate(self) -> int:
        return self._sample_rate

    def get_buf(self) -> bytes:
        return self._payload


class FakeData:
    """Data message double whose property JSON is a fixed payload."""

    def __init__(self, name: str, payload: dict):
        self._name = name
        self._payload = payload

    def get_name(self) -> str:
        return self._name

    def get_property_to_json(self, _path):
        # Mirrors the (json_string, error) pair shape; never reports an error.
        return json.dumps(self._payload), None


class FakeEnv:
    """TEN env double that records log lines and the names of sent commands."""

    def __init__(self):
        self.warns: list[str] = []
        self.infos: list[str] = []
        self.debugs: list[str] = []
        self.errors: list[str] = []
        self.sent_cmds = []

    def log_warn(self, msg: str, **_kwargs):
        self.warns.append(msg)

    def log_info(self, msg: str, **_kwargs):
        self.infos.append(msg)

    def log_debug(self, msg: str, **_kwargs):
        self.debugs.append(msg)

    def log_error(self, msg: str, **_kwargs):
        self.errors.append(msg)

    async def send_cmd(self, cmd):
        self.sent_cmds.append(cmd.get_name())

    async def return_result(self, _result):
        return None


class FakeRecorder:
    """Recorder double that is always connected and records what it is sent."""

    def __init__(self):
        self.voice_end_count = 0
        self.sent_audio: list[tuple[str, int]] = []

    def ws_connected(self) -> bool:
        return True

    async def send_voice_end(self):
        self.voice_end_count += 1

    async def send(self, audio_base64: str, sample_rate: int):
        self.sent_audio.append((audio_base64, sample_rate))

    async def interrupt(self):
        return True


async def _wait_until(predicate, *, timeout: float = 1.0, interval: float = 0.005) -> bool:
    """Poll *predicate* until it is truthy or *timeout* seconds have elapsed.

    Returns the final truthiness of the predicate. Used instead of a fixed
    sleep so the test is fast on a quick machine and tolerant of a slow one.
    """
    loop = asyncio.get_running_loop()
    deadline = loop.time() + timeout
    while not predicate():
        if loop.time() >= deadline:
            return False
        await asyncio.sleep(interval)
    return True


def test_on_audio_frame_queues_actual_sample_rate():
    """``on_audio_frame`` enqueues the frame bytes with the frame's own rate."""

    async def _run():
        extension = GenericVideoExtension("generic_video_python")
        extension.ten_env = FakeEnv()
        extension._audio_processing_enabled = True

        frame = FakeAudioFrame(b"\x01\x02", 44100)
        await extension.on_audio_frame(extension.ten_env, frame)

        payload, sample_rate = extension.input_audio_queue.get_nowait()
        assert payload == b"\x01\x02"
        assert sample_rate == 44100

    asyncio.run(_run())


def test_audio_sender_uses_actual_sample_rate_and_warns_once():
    """The sender forwards the queued rate (48000) and logs one warning."""

    async def _run():
        extension = GenericVideoExtension("generic_video_python")
        env = FakeEnv()
        recorder = FakeRecorder()
        extension.ten_env = env
        extension.config = type(
            "Config",
            (),
            {"input_audio_sample_rate": 16000},
        )()
        extension.recorder = recorder
        extension._audio_processing_enabled = True

        await extension.input_audio_queue.put((b"\x00\x01", 48000))

        task = asyncio.create_task(extension._loop_input_audio_sender(env))
        try:
            # Wait for the observable effects rather than a fixed delay; stop
            # early if the sender task has already finished or failed.
            await _wait_until(
                lambda: task.done() or (recorder.sent_audio and env.warns)
            )
        finally:
            # Always stop the background loop, even if the wait itself failed.
            extension._audio_processing_enabled = False
            task.cancel()
            with contextlib.suppress(asyncio.CancelledError):
                await task

        assert recorder.sent_audio, "sender did not forward the queued frame"
        assert recorder.sent_audio[0][1] == 48000
        assert len(env.warns) == 1

    asyncio.run(_run())


def test_tts_audio_end_reason_one_triggers_voice_end():
    """A ``tts_audio_end`` data message with reason 1 sends one voice_end."""

    async def _run():
        extension = GenericVideoExtension("generic_video_python")
        extension.recorder = FakeRecorder()
        env = FakeEnv()

        await extension.on_data(
            env,
            FakeData(
                "tts_audio_end",
                {"reason": 1, "request_id": "req-1"},
            ),
        )

        assert extension.recorder.voice_end_count == 1

    asyncio.run(_run())

# --- patch boundary: new file
# ai_agents/agents/ten_packages/extension/generic_video_python/tests/test_protocol.py
#
# This file is part of TEN Framework, an open source project.
# Licensed under the Apache License, Version 2.0.
# See the LICENSE file for more information.
#
"""Protocol-level tests for ``AgoraGenericRecorder`` payloads and requests."""
from __future__ import annotations

import asyncio
import json

import httpx
import pytest

from generic_video_python.generic import AgoraGenericRecorder


class FakeTenEnv:
    """TEN env double that records log lines and accepts outgoing data."""

    def __init__(self):
        self.infos: list[str] = []
        self.errors: list[str] = []
        self.warns: list[str] = []
        self.debugs: list[str] = []

    def log_info(self, msg: str, **_kwargs) -> None:
        self.infos.append(msg)

    def log_error(self, msg: str, **_kwargs) -> None:
        self.errors.append(msg)

    def log_warn(self, msg: str, **_kwargs) -> None:
        self.warns.append(msg)

    def log_debug(self, msg: str, **_kwargs) -> None:
        self.debugs.append(msg)

    async def send_data(self, _data) -> None:
        return None


class RecordingWebSocket:
    """WebSocket double that reports an OPEN state and decodes sent JSON."""

    def __init__(self):
        self.messages: list[dict] = []
        self.state = type("OpenState", (), {"name": "OPEN"})()

    async def send(self, payload: str) -> None:
        self.messages.append(json.loads(payload))


def create_recorder(
    *,
    http_client: httpx.AsyncClient | None = None,
    session_cache_path: str | None = None,
    vendor_params: dict | None = None,
    channel_name: str = "room-a",
    area: str = "NORTH_AMERICA",
) -> AgoraGenericRecorder:
    """Build a recorder with fixed test settings.

    ``channel_name`` and ``area`` default to the values every existing caller
    relied on; they are exposed so a test that needs a second, differently
    configured recorder does not have to repeat the whole constructor call.
    """
    return AgoraGenericRecorder(
        app_id="appid",
        app_cert="",
        api_key="api-key",
        channel_name=channel_name,
        avatar_uid=321,
        ten_env=FakeTenEnv(),
        avatar_id="avatar-1",
        quality="high",
        version="v1",
        video_encoding="H264",
        area=area,
        enable_string_uid=False,
        start_endpoint="https://example.test/session/start",
        stop_endpoint="https://example.test/session/stop",
        activity_idle_timeout=120,
        vendor_params=vendor_params,
        http_client=http_client,
        session_cache_path=session_cache_path,
    )


def _close_client(recorder: AgoraGenericRecorder) -> None:
    """Close the recorder's HTTP client from synchronous test code."""
    asyncio.run(recorder.http_client.aclose())


def test_start_and_init_payloads_match_contract():
    """Start and init payloads carry area, channel, uid and vendor params."""
    recorder = create_recorder(vendor_params={"model": "vendor-model-1"})
    try:
        recorder.session_id = "session-1"

        start_payload = recorder._build_start_payload()
        init_payload = recorder._build_init_payload()

        assert start_payload["area"] == "NORTH_AMERICA"
        assert start_payload["agora_settings"]["channel"] == "room-a"
        assert start_payload["agora_settings"]["uid"] == "321"
        assert start_payload["model"] == "vendor-model-1"
        assert init_payload["command"] == "init"
        assert init_payload["session_id"] == "session-1"
        assert init_payload["area"] == "NORTH_AMERICA"
        assert init_payload["model"] == "vendor-model-1"
    finally:
        _close_client(recorder)


def test_stop_payload_requires_session_token():
    """Building a stop payload without any session token raises ValueError."""
    recorder = create_recorder()
    try:
        with pytest.raises(ValueError):
            recorder._build_stop_payload("session-1")
    finally:
        _close_client(recorder)


def test_stop_payload_includes_session_token():
    """An explicit session token is echoed next to the session id."""
    recorder = create_recorder()
    try:
        payload = recorder._build_stop_payload(
            "session-1", session_token="session-token"
        )

        assert payload == {
            "session_id": "session-1",
            "session_token": "session-token",
        }
    finally:
        _close_client(recorder)


def test_cache_path_is_scoped_per_recorder():
    """Two recorders on different channels get different cache paths."""
    recorder_a = create_recorder()
    recorder_b = create_recorder(
        channel_name="room-b",
        area="GLOBAL",
        http_client=httpx.AsyncClient(
            transport=httpx.MockTransport(
                lambda request: httpx.Response(200, json={})
            )
        ),
    )
    try:
        assert recorder_a.session_cache_path != recorder_b.session_cache_path
    finally:
        # Close both clients even when the assertion fails.
        _close_client(recorder_a)
        _close_client(recorder_b)


def test_create_session_sends_area_and_parses_response():
    """``_create_session`` posts the contract body and stores the response."""
    captured = {}

    def handler(request: httpx.Request) -> httpx.Response:
        captured["json"] = json.loads(request.content.decode())
        captured["headers"] = dict(request.headers)
        return httpx.Response(
            200,
            json={
                "session_id": "session-1",
                "websocket_address": "ws://example.test/ws",
                "session_token": "token-1",
            },
        )

    async def _run():
        # The context manager closes the client even if the call raises.
        async with httpx.AsyncClient(
            transport=httpx.MockTransport(handler)
        ) as client:
            recorder = create_recorder(
                http_client=client,
                vendor_params={"model": "vendor-model-1", "style": "studio"},
            )
            await recorder._create_session()

        assert captured["json"]["area"] == "NORTH_AMERICA"
        assert captured["json"]["activity_idle_timeout"] == 120
        assert captured["json"]["model"] == "vendor-model-1"
        assert captured["json"]["style"] == "studio"
        assert captured["headers"]["x-api-key"] == "api-key"
        assert recorder.session_id == "session-1"
        assert recorder.realtime_endpoint == "ws://example.test/ws"
        assert recorder.session_token == "token-1"

    asyncio.run(_run())


def test_stop_session_sends_session_token_in_body(tmp_path):
    """``_stop_session`` issues DELETE with id + token and drops the cache."""
    captured = {}
    cache_file = tmp_path / "session.json"

    def handler(request: httpx.Request) -> httpx.Response:
        captured["method"] = request.method
        captured["json"] = json.loads(request.content.decode())
        return httpx.Response(200, json={"status": "success"})

    async def _run():
        async with httpx.AsyncClient(
            transport=httpx.MockTransport(handler)
        ) as client:
            recorder = create_recorder(
                http_client=client,
                session_cache_path=str(cache_file),
            )
            recorder.session_token = "token-1"
            recorder._save_session("session-1", "token-1")
            await recorder._stop_session("session-1")

        assert captured["method"] == "DELETE"
        assert captured["json"] == {
            "session_id": "session-1",
            "session_token": "token-1",
        }
        assert not cache_file.exists()

    asyncio.run(_run())


def test_send_interrupt_voice_end_and_voice_messages():
    """voice, voice_interrupt and voice_end are emitted in call order."""

    async def _run():
        recorder = create_recorder()
        try:
            recorder.websocket = RecordingWebSocket()

            await recorder.send("YWJj", sample_rate=44100)
            await recorder.interrupt()
            await recorder.send_voice_end()
            messages = recorder.websocket.messages

            assert messages[0]["command"] == "voice"
            assert messages[0]["sampleRate"] == 44100
            assert messages[0]["encoding"] == "PCM16"
            assert messages[1]["command"] == "voice_interrupt"
            assert messages[2]["command"] == "voice_end"
        finally:
            # Release the recorder's client even when an assertion fails.
            await recorder.http_client.aclose()

    asyncio.run(_run())