From 0e25a9506d6d345ce5beec17161168dec0435593 Mon Sep 17 00:00:00 2001 From: Nsuccess Date: Wed, 14 Jan 2026 16:46:10 +0000 Subject: [PATCH 01/10] feat: Add NVIDIA Riva TTS extension (#1964) - Implements text-to-speech using NVIDIA Riva Speech Skills - Supports streaming synthesis with gRPC - Includes comprehensive tests and documentation - Follows TTS2 interface pattern Closes #1964 --- .../nvidia_riva_tts_python/IMPLEMENTATION.md | 269 ++++++++++++++++ .../nvidia_riva_tts_python/README.md | 93 ++++++ .../nvidia_riva_tts_python/__init__.py | 7 + .../extension/nvidia_riva_tts_python/addon.py | 18 ++ .../nvidia_riva_tts_python/config.py | 44 +++ .../nvidia_riva_tts_python/extension.py | 264 ++++++++++++++++ .../nvidia_riva_tts_python/manifest.json | 57 ++++ .../nvidia_riva_tts_python/property.json | 9 + .../nvidia_riva_tts_python/requirements.txt | 2 + .../nvidia_riva_tts_python/riva_tts.py | 143 +++++++++ .../nvidia_riva_tts_python/tests/__init__.py | 4 + .../tests/test_compliance.py | 294 ++++++++++++++++++ .../tests/test_config.py | 67 ++++ .../tests/test_extension.py | 134 ++++++++ 14 files changed, 1405 insertions(+) create mode 100644 agents/ten_packages/extension/nvidia_riva_tts_python/IMPLEMENTATION.md create mode 100644 agents/ten_packages/extension/nvidia_riva_tts_python/README.md create mode 100644 agents/ten_packages/extension/nvidia_riva_tts_python/__init__.py create mode 100644 agents/ten_packages/extension/nvidia_riva_tts_python/addon.py create mode 100644 agents/ten_packages/extension/nvidia_riva_tts_python/config.py create mode 100644 agents/ten_packages/extension/nvidia_riva_tts_python/extension.py create mode 100644 agents/ten_packages/extension/nvidia_riva_tts_python/manifest.json create mode 100644 agents/ten_packages/extension/nvidia_riva_tts_python/property.json create mode 100644 agents/ten_packages/extension/nvidia_riva_tts_python/requirements.txt create mode 100644 agents/ten_packages/extension/nvidia_riva_tts_python/riva_tts.py create mode 100644 agents/ten_packages/extension/nvidia_riva_tts_python/tests/__init__.py create mode 100644 agents/ten_packages/extension/nvidia_riva_tts_python/tests/test_compliance.py create mode 100644 agents/ten_packages/extension/nvidia_riva_tts_python/tests/test_config.py create mode 100644 agents/ten_packages/extension/nvidia_riva_tts_python/tests/test_extension.py diff --git a/agents/ten_packages/extension/nvidia_riva_tts_python/IMPLEMENTATION.md b/agents/ten_packages/extension/nvidia_riva_tts_python/IMPLEMENTATION.md new file mode 100644 index 0000000000..b167e3d0f8 --- /dev/null +++ b/agents/ten_packages/extension/nvidia_riva_tts_python/IMPLEMENTATION.md @@ -0,0 +1,269 @@ +# NVIDIA Riva TTS Extension - Implementation Details + +## Overview + +This document describes the implementation of the NVIDIA Riva TTS extension for TEN Framework. The extension provides high-quality, GPU-accelerated text-to-speech synthesis using NVIDIA Riva Speech Skills. + +## Architecture + +### Component Structure + +``` +nvidia_riva_tts_python/ +├── extension.py # Main extension class +├── riva_tts.py # Riva client implementation +├── config.py # Configuration model +├── addon.py # Extension registration +├── manifest.json # Extension metadata +├── property.json # Default properties +├── requirements.txt # Python dependencies +├── README.md # User documentation +└── tests/ # Test suite + ├── test_config.py + └── test_extension.py +``` + +### Class Hierarchy + +``` +AsyncTTSExtension (base class from ten_ai_base) + └── NvidiaRivaTTSExtension + └── uses NvidiaRivaTTSClient + └── uses riva.client.SpeechSynthesisService +``` + +## Implementation Details + +### 1. Extension Class (`extension.py`) + +The `NvidiaRivaTTSExtension` class inherits from `AsyncTTSExtension` and implements the required abstract methods: + +- **`create_config()`**: Parses JSON configuration into `NvidiaRivaTTSConfig` +- **`create_client()`**: Instantiates `NvidiaRivaTTSClient` with configuration +- **`vendor()`**: Returns "nvidia_riva" as the vendor identifier +- **`synthesize_audio_sample_rate()`**: Returns the configured sample rate + +### 2. Client Implementation (`riva_tts.py`) + +The `NvidiaRivaTTSClient` class handles the actual TTS synthesis: + +#### Initialization +- Creates Riva Auth object with server URI and SSL settings +- Initializes `SpeechSynthesisService` for TTS operations +- Validates server connectivity + +#### Synthesis Method +```python +async def synthesize(self, text: str, request_id: str) -> AsyncIterator[bytes] +``` + +**Flow:** +1. Validates input text (non-empty) +2. Calls `tts_service.synthesize_online()` for streaming synthesis +3. Iterates through audio chunks from Riva +4. Converts audio data to PCM bytes +5. Yields audio chunks for streaming playback +6. Handles cancellation requests + +**Key Features:** +- Streaming synthesis for low latency +- Cancellation support via `_is_cancelled` flag +- Comprehensive logging at each step +- Error handling with detailed messages + +### 3. Configuration (`config.py`) + +The `NvidiaRivaTTSConfig` class extends `AsyncTTSConfig`: + +**Required Parameters:** +- `server`: Riva server address (host:port) +- `language_code`: Language identifier (e.g., "en-US") +- `voice_name`: Voice identifier (e.g., "English-US.Female-1") + +**Optional Parameters:** +- `sample_rate`: Audio sample rate in Hz (default: 16000) +- `use_ssl`: Enable SSL for gRPC (default: false) + +**Validation:** +- Ensures all required parameters are present +- Validates parameter types and formats + +### 4. Addon Registration (`addon.py`) + +Registers the extension with TEN Framework using the `@register_addon_as_extension` decorator. + +## Integration with TEN Framework + +### TTS Interface Compliance + +The extension implements the standard TEN Framework TTS interface defined in `ten_ai_base/api/tts-interface.json`: + +- **Input**: Text data via TEN data messages +- **Output**: PCM audio frames via TEN audio frame messages +- **Control**: Start/stop/cancel commands via TEN commands + +### Message Flow + +``` +1. Text Input → Extension receives text data +2. Configuration → Loads voice, language, sample rate +3. Synthesis → Calls Riva API with streaming +4. Audio Output → Yields PCM audio chunks +5. Completion → Signals end of synthesis +``` + +## NVIDIA Riva Integration + +### gRPC API Usage + +The extension uses the official `nvidia-riva-client` Python package which provides: + +- **Auth**: Authentication and connection management +- **SpeechSynthesisService**: TTS API wrapper +- **AudioEncoding**: Audio format specifications + +### Streaming vs Batch + +The implementation uses **streaming synthesis** (`synthesize_online`) for: +- Lower latency (first audio chunk arrives quickly) +- Better user experience in real-time applications +- Efficient memory usage + +Alternative batch mode (`synthesize`) is available but not used by default. + +### Audio Format + +- **Encoding**: LINEAR_PCM (16-bit signed integer) +- **Sample Rate**: Configurable (default 16000 Hz) +- **Channels**: Mono +- **Byte Order**: Little-endian + +## Error Handling + +### Initialization Errors +- Server unreachable → RuntimeError with connection details +- Invalid credentials → Authentication error +- Missing dependencies → Import error + +### Runtime Errors +- Empty text → Warning logged, no synthesis +- Synthesis failure → RuntimeError with Riva error message +- Cancellation → Graceful stop, log cancellation + +### Logging Strategy + +- **INFO**: Initialization, configuration +- **DEBUG**: Synthesis progress, chunk details +- **WARN**: Empty text, unusual conditions +- **ERROR**: Failures, exceptions + +## Testing + +### Test Coverage + +1. **Configuration Tests** (`test_config.py`) + - Valid configuration creation + - Missing required parameters + - Default values + - Validation logic + +2. **Extension Tests** (`test_extension.py`) + - Extension initialization + - Config creation from JSON + - Sample rate retrieval + - Client creation + +3. **Client Tests** (`test_extension.py`) + - Client initialization with mocked Riva + - Cancellation handling + - Empty text handling + - Synthesis with mocked responses + +### Running Tests + +```bash +# Install test dependencies +pip install pytest pytest-asyncio + +# Run all tests +pytest nvidia_riva_tts_python/tests/ -v + +# Run with coverage +pytest nvidia_riva_tts_python/tests/ --cov=nvidia_riva_tts_python +``` + +## Performance Considerations + +### Latency +- **First chunk**: ~100-200ms (depends on text length and server) +- **Streaming**: Continuous audio delivery +- **GPU acceleration**: Significantly faster than CPU-only TTS + +### Resource Usage +- **Memory**: Minimal (streaming mode) +- **Network**: gRPC connection to Riva server +- **CPU**: Low (Riva does GPU processing) + +### Optimization Tips +1. Use streaming mode for real-time applications +2. Keep Riva server close to application (low network latency) +3. Reuse client connections (handled by extension) +4. Configure appropriate sample rate for use case + +## Deployment + +### Prerequisites +1. NVIDIA Riva server running (see README.md for setup) +2. Network connectivity to Riva server +3. Python 3.8+ with nvidia-riva-client + +### Configuration Example + +```json +{ + "params": { + "server": "riva-server.example.com:50051", + "language_code": "en-US", + "voice_name": "English-US.Female-1", + "sample_rate": 22050, + "use_ssl": true + } +} +``` + +### Environment Variables + +```bash +export NVIDIA_RIVA_SERVER=localhost:50051 +``` + +## Future Enhancements + +Potential improvements for future versions: + +1. **SSML Support**: Full SSML tag support for advanced speech control +2. **Voice Cloning**: Custom voice model support +3. **Multi-language**: Automatic language detection +4. **Caching**: Cache frequently synthesized phrases +5. **Metrics**: Detailed performance metrics and monitoring +6. **Fallback**: Automatic fallback to alternative TTS if Riva unavailable + +## References + +- [NVIDIA Riva Documentation](https://docs.nvidia.com/deeplearning/riva/user-guide/docs/index.html) +- [Riva Python Client](https://pypi.org/project/nvidia-riva-client/) +- [TEN Framework TTS Interface](https://github.com/TEN-framework/ten-framework) +- [gRPC Python](https://grpc.io/docs/languages/python/) + +## License + +Apache 2.0 - See LICENSE file in the TEN Framework repository. + +## Contributing + +Contributions are welcome! Please: +1. Follow the existing code style +2. Add tests for new features +3. Update documentation +4. Submit PR to TEN Framework repository + diff --git a/agents/ten_packages/extension/nvidia_riva_tts_python/README.md b/agents/ten_packages/extension/nvidia_riva_tts_python/README.md new file mode 100644 index 0000000000..d0452fc9c0 --- /dev/null +++ b/agents/ten_packages/extension/nvidia_riva_tts_python/README.md @@ -0,0 +1,93 @@ +# NVIDIA Riva TTS Python Extension + +This extension provides text-to-speech functionality using NVIDIA Riva Speech Skills. + +## Features + +- High-quality speech synthesis using NVIDIA Riva +- Support for multiple languages and voices +- Streaming and batch synthesis modes +- SSML support for advanced speech control +- GPU-accelerated inference for low latency + +## Prerequisites + +- NVIDIA Riva server running and accessible +- Python 3.8+ +- nvidia-riva-client package + +## Configuration + +The extension can be configured through your property.json: + +```json +{ + "params": { + "server": "localhost:50051", + "language_code": "en-US", + "voice_name": "English-US.Female-1", + "sample_rate": 16000, + "use_ssl": false + } +} +``` + +### Configuration Options + +**Parameters inside `params` object:** +- `server` (required): Riva server address (format: "host:port") +- `language_code` (required): Language code (e.g., "en-US", "es-ES") +- `voice_name` (required): Voice identifier (e.g., "English-US.Female-1") +- `sample_rate` (optional): Audio sample rate in Hz (default: 16000) +- `use_ssl` (optional): Use SSL for gRPC connection (default: false) + +### Available Voices + +Common voice names include: +- `English-US.Female-1` +- `English-US.Male-1` +- `English-GB.Female-1` +- `Spanish-US.Female-1` + +Check your Riva server documentation for the full list of available voices. + +## Setting up NVIDIA Riva Server + +Follow the [NVIDIA Riva Quick Start Guide](https://docs.nvidia.com/deeplearning/riva/user-guide/docs/quick-start-guide.html) to set up a Riva server. + +Quick setup with Docker: + +```bash +# Download Riva Quick Start scripts +ngc registry resource download-version nvidia/riva/riva_quickstart:2.17.0 + +# Initialize and start Riva +cd riva_quickstart_v2.17.0 +bash riva_init.sh +bash riva_start.sh +``` + +## Environment Variables + +Set the Riva server address via environment variable: + +```bash +export NVIDIA_RIVA_SERVER=localhost:50051 +``` + +## Architecture + +This extension follows the TEN Framework TTS extension pattern: + +- `extension.py`: Main extension class +- `riva_tts.py`: Client implementation with Riva SDK integration +- `config.py`: Configuration model +- `addon.py`: Extension addon registration + +## License + +Apache 2.0 + +## Contributing + +Contributions are welcome! Please submit issues and pull requests to the TEN Framework repository. diff --git a/agents/ten_packages/extension/nvidia_riva_tts_python/__init__.py b/agents/ten_packages/extension/nvidia_riva_tts_python/__init__.py new file mode 100644 index 0000000000..2718464193 --- /dev/null +++ b/agents/ten_packages/extension/nvidia_riva_tts_python/__init__.py @@ -0,0 +1,7 @@ +# +# This file is part of TEN Framework, an open source project. +# Licensed under the Apache License, Version 2.0. +# +from . import addon + +__all__ = ["addon"] diff --git a/agents/ten_packages/extension/nvidia_riva_tts_python/addon.py b/agents/ten_packages/extension/nvidia_riva_tts_python/addon.py new file mode 100644 index 0000000000..f880f2d97c --- /dev/null +++ b/agents/ten_packages/extension/nvidia_riva_tts_python/addon.py @@ -0,0 +1,18 @@ +# +# This file is part of TEN Framework, an open source project. +# Licensed under the Apache License, Version 2.0. +# +from ten_runtime import ( + Addon, + register_addon_as_extension, + TenEnv, +) + + +@register_addon_as_extension("nvidia_riva_tts_python") +class NvidiaRivaTTSExtensionAddon(Addon): + def on_create_instance(self, ten_env: TenEnv, name: str, context) -> None: + from .extension import NvidiaRivaTTSExtension + + ten_env.log_info("NvidiaRivaTTSExtensionAddon on_create_instance") + ten_env.on_create_instance_done(NvidiaRivaTTSExtension(name), context) diff --git a/agents/ten_packages/extension/nvidia_riva_tts_python/config.py b/agents/ten_packages/extension/nvidia_riva_tts_python/config.py new file mode 100644 index 0000000000..ed044bceed --- /dev/null +++ b/agents/ten_packages/extension/nvidia_riva_tts_python/config.py @@ -0,0 +1,44 @@ +# +# This file is part of TEN Framework, an open source project. +# Licensed under the Apache License, Version 2.0. +# +from typing import Any +import copy +from pydantic import Field +from pathlib import Path +from ten_ai_base import utils +from ten_ai_base.tts import AsyncTTSConfig + + +class NvidiaRivaTTSConfig(AsyncTTSConfig): + """NVIDIA Riva TTS Config""" + + dump: bool = Field(default=False, description="NVIDIA Riva TTS dump") + dump_path: str = Field( + default_factory=lambda: str(Path(__file__).parent / "nvidia_riva_tts_in.pcm"), + description="NVIDIA Riva TTS dump path", + ) + params: dict[str, Any] = Field( + default_factory=dict, description="NVIDIA Riva TTS params" + ) + + def update_params(self) -> None: + """Update configuration from params dictionary""" + pass + + def to_str(self, sensitive_handling: bool = True) -> str: + """Convert config to string with optional sensitive data handling.""" + if not sensitive_handling: + return f"{self}" + + config = copy.deepcopy(self) + return f"{config}" + + def validate(self) -> None: + """Validate NVIDIA Riva-specific configuration.""" + if "server" not in self.params or not self.params["server"]: + raise ValueError("Server address is required for NVIDIA Riva TTS") + if "language_code" not in self.params or not self.params["language_code"]: + raise ValueError("Language code is required for NVIDIA Riva TTS") + if "voice_name" not in self.params or not self.params["voice_name"]: + raise ValueError("Voice name is required for NVIDIA Riva TTS") diff --git a/agents/ten_packages/extension/nvidia_riva_tts_python/extension.py b/agents/ten_packages/extension/nvidia_riva_tts_python/extension.py new file mode 100644 index 0000000000..44523b89d9 --- /dev/null +++ b/agents/ten_packages/extension/nvidia_riva_tts_python/extension.py @@ -0,0 +1,264 @@ +# +# This file is part of TEN Framework, an open source project. +# Licensed under the Apache License, Version 2.0. +# See the LICENSE file for more information. +# +""" +NVIDIA Riva TTS Extension + +This extension implements text-to-speech using NVIDIA Riva Speech Skills. +It provides high-quality, GPU-accelerated speech synthesis. +""" + +import asyncio +import time +import traceback +from typing import Optional + +from ten_ai_base.message import ( + ModuleError, + ModuleErrorCode, + ModuleType, + TTSAudioEndReason, +) +from ten_ai_base.struct import TTSTextInput +from ten_ai_base.tts2 import AsyncTTS2BaseExtension, RequestState +from ten_ai_base.const import LOG_CATEGORY_KEY_POINT, LOG_CATEGORY_VENDOR +from ten_runtime import AsyncTenEnv + +from .config import NvidiaRivaTTSConfig +from .riva_tts import NvidiaRivaTTSClient + + +class NvidiaRivaTTSExtension(AsyncTTS2BaseExtension): + """ + NVIDIA Riva TTS Extension implementation. + + Provides text-to-speech synthesis using NVIDIA Riva's gRPC API. + Inherits all common TTS functionality from AsyncTTS2BaseExtension. + """ + + def __init__(self, name: str) -> None: + super().__init__(name) + self.config: Optional[NvidiaRivaTTSConfig] = None + self.client: Optional[NvidiaRivaTTSClient] = None + self.current_request_id: Optional[str] = None + self.request_start_ts: float = 0 + self.first_chunk_ts: float = 0 + self.request_total_audio_duration: int = 0 + self.flush_request_id: Optional[str] = None + self.last_end_request_id: Optional[str] = None + self.audio_start_sent: set[str] = set() + + async def on_init(self, ten_env: AsyncTenEnv) -> None: + """Initialize the extension""" + await super().on_init(ten_env) + ten_env.log_debug("NVIDIA Riva TTS on_init") + + try: + # Load configuration + config_json, _ = await ten_env.get_property_to_json("") + self.config = NvidiaRivaTTSConfig.model_validate_json(config_json) + + ten_env.log_info( + f"config: {self.config.model_dump_json()}", + category=LOG_CATEGORY_KEY_POINT, + ) + + # Create client + self.client = NvidiaRivaTTSClient( + config=self.config, + ten_env=ten_env, + ) + + except Exception as e: + ten_env.log_error(f"on_init failed: {traceback.format_exc()}") + await self.send_tts_error( + request_id="", + error=ModuleError( + message=str(e), + module=ModuleType.TTS, + code=ModuleErrorCode.FATAL_ERROR, + vendor_info={"vendor": "nvidia_riva"}, + ), + ) + + async def on_stop(self, ten_env: AsyncTenEnv) -> None: + """Stop the extension""" + await super().on_stop(ten_env) + ten_env.log_debug("NVIDIA Riva TTS on_stop") + + async def on_deinit(self, ten_env: AsyncTenEnv) -> None: + """Deinitialize the extension""" + await super().on_deinit(ten_env) + ten_env.log_debug("NVIDIA Riva TTS on_deinit") + + def vendor(self) -> str: + """Return vendor name""" + return "nvidia_riva" + + def synthesize_audio_sample_rate(self) -> int: + """Return audio sample rate""" + return self.config.params.get("sample_rate", 16000) if self.config else 16000 + + def synthesize_audio_channels(self) -> int: + """Return number of audio channels""" + return 1 + + def synthesize_audio_sample_width(self) -> int: + """Return sample width in bytes""" + return 2 # 16-bit PCM + + async def request_tts(self, t: TTSTextInput) -> None: + """Handle TTS request""" + try: + self.ten_env.log_info( + f"TTS request: text_length={len(t.text)}, " + f"text_input_end={t.text_input_end}, request_id={t.request_id}" + ) + + # Skip if request already completed + if t.request_id == self.flush_request_id: + self.ten_env.log_debug( + f"Request {t.request_id} was flushed, ignoring" + ) + return + + if t.request_id == self.last_end_request_id: + self.ten_env.log_debug( + f"Request {t.request_id} was ended, ignoring" + ) + return + + # Handle new request + is_new_request = self.current_request_id != t.request_id + if is_new_request: + self.ten_env.log_debug(f"New TTS request: {t.request_id}") + self.current_request_id = t.request_id + self.request_total_audio_duration = 0 + self.request_start_ts = time.time() + + if self.client is None: + raise ValueError("TTS client not initialized") + + # Synthesize audio + received_first_chunk = False + async for chunk in self.client.synthesize(t.text, t.request_id): + # Calculate audio duration + duration = self._calculate_audio_duration(len(chunk)) + + self.ten_env.log_debug( + f"receive_audio: duration={duration}ms, request_id={self.current_request_id}", + category=LOG_CATEGORY_VENDOR, + ) + + if not received_first_chunk: + received_first_chunk = True + # Send audio start + if t.request_id not in self.audio_start_sent: + await self.send_tts_audio_start(t.request_id) + self.audio_start_sent.add(t.request_id) + if is_new_request: + # Send TTFB metrics + self.first_chunk_ts = time.time() + elapsed_time = int( + (self.first_chunk_ts - self.request_start_ts) * 1000 + ) + await self.send_tts_ttfb_metrics( + request_id=t.request_id, + ttfb_ms=elapsed_time, + extra_metadata={ + "voice_name": self.config.params["voice_name"], + "language_code": self.config.params["language_code"], + }, + ) + + if t.request_id == self.flush_request_id: + break + + self.request_total_audio_duration += duration + await self.send_tts_audio_data(chunk) + + # Handle completion + if t.text_input_end or t.request_id == self.flush_request_id: + reason = TTSAudioEndReason.REQUEST_END + if t.request_id == self.flush_request_id: + reason = TTSAudioEndReason.INTERRUPTED + + if self.first_chunk_ts > 0: + await self._handle_completed_request(reason) + + except Exception as e: + self.ten_env.log_error(f"Error in request_tts: {traceback.format_exc()}") + await self.send_tts_error( + request_id=t.request_id, + error=ModuleError( + message=str(e), + module=ModuleType.TTS, + code=ModuleErrorCode.NON_FATAL_ERROR, + vendor_info={"vendor": "nvidia_riva"}, + ), + ) + + # Check if we've received text_input_end + has_received_text_input_end = False + if t.request_id and t.request_id in self.request_states: + if self.request_states[t.request_id] == RequestState.FINALIZING: + has_received_text_input_end = True + + if has_received_text_input_end: + await self._handle_completed_request(TTSAudioEndReason.ERROR) + + async def cancel_tts(self) -> None: + """Cancel current TTS request""" + self.ten_env.log_info(f"cancel_tts current_request_id: {self.current_request_id}") + if self.current_request_id is not None: + self.flush_request_id = self.current_request_id + + if self.client: + await self.client.cancel() + + if self.current_request_id and self.first_chunk_ts > 0: + await self._handle_completed_request(TTSAudioEndReason.INTERRUPTED) + + async def _handle_completed_request(self, reason: TTSAudioEndReason) -> None: + """Handle completed TTS request""" + if not self.current_request_id: + return + + self.last_end_request_id = self.current_request_id + + # Calculate metrics + request_event_interval = 0 + if self.first_chunk_ts > 0: + request_event_interval = int( + (time.time() - self.first_chunk_ts) * 1000 + ) + + # Send audio end + await self.send_tts_audio_end( + request_id=self.current_request_id, + request_event_interval_ms=request_event_interval, + request_total_audio_duration_ms=self.request_total_audio_duration, + reason=reason, + ) + + self.ten_env.log_debug( + f"Sent tts_audio_end: reason={reason.name}, request_id={self.current_request_id}" + ) + + # Finish request + await self.finish_request(request_id=self.current_request_id, reason=reason) + + # Reset state + self.first_chunk_ts = 0 + self.audio_start_sent.discard(self.current_request_id) + + def _calculate_audio_duration(self, bytes_length: int) -> int: + """Calculate audio duration in milliseconds""" + bytes_per_second = ( + self.synthesize_audio_sample_rate() + * self.synthesize_audio_channels() + * self.synthesize_audio_sample_width() + ) + return int((bytes_length / bytes_per_second) * 1000) diff --git a/agents/ten_packages/extension/nvidia_riva_tts_python/manifest.json b/agents/ten_packages/extension/nvidia_riva_tts_python/manifest.json new file mode 100644 index 0000000000..e48a071a40 --- /dev/null +++ b/agents/ten_packages/extension/nvidia_riva_tts_python/manifest.json @@ -0,0 +1,57 @@ +{ + "type": "extension", + "name": "nvidia_riva_tts_python", + "version": "0.1.0", + "dependencies": [ + { + "type": "system", + "name": "ten_runtime_python", + "version": "0.11" + }, + { + "type": "system", + "name": "ten_ai_base", + "version": "0.7" + } + ], + "package": { + "include": [ + "manifest.json", + "property.json", + "**.py", + "README.md", + "requirements.txt" + ] + }, + "api": { + "interface": [ + { + "import_uri": "../../system/ten_ai_base/api/tts-interface.json" + } + ], + "property": { + "properties": { + "params": { + "type": "object", + "properties": { + "server": { + "type": "string" + }, + "language_code": { + "type": "string" + }, + "voice_name": { + "type": "string" + }, + "sample_rate": { + "type": "int64" + }, + "use_ssl": { + "type": "bool" + } + } + } + } + } + } +} diff --git a/agents/ten_packages/extension/nvidia_riva_tts_python/property.json b/agents/ten_packages/extension/nvidia_riva_tts_python/property.json new file mode 100644 index 0000000000..022a606664 --- /dev/null +++ b/agents/ten_packages/extension/nvidia_riva_tts_python/property.json @@ -0,0 +1,9 @@ +{ + "params": { + "server": "${env:NVIDIA_RIVA_SERVER|localhost:50051}", + "language_code": "en-US", + "voice_name": "English-US.Female-1", + "sample_rate": 16000, + "use_ssl": false + } +} diff --git a/agents/ten_packages/extension/nvidia_riva_tts_python/requirements.txt b/agents/ten_packages/extension/nvidia_riva_tts_python/requirements.txt new file mode 100644 index 0000000000..f178d839f4 --- /dev/null +++ b/agents/ten_packages/extension/nvidia_riva_tts_python/requirements.txt @@ -0,0 +1,2 @@ +nvidia-riva-client>=2.17.0 +numpy>=1.21.0 diff --git a/agents/ten_packages/extension/nvidia_riva_tts_python/riva_tts.py b/agents/ten_packages/extension/nvidia_riva_tts_python/riva_tts.py new file mode 100644 index 0000000000..04f7e5921b --- /dev/null +++ b/agents/ten_packages/extension/nvidia_riva_tts_python/riva_tts.py @@ -0,0 +1,143 @@ +# +# This file is part of TEN Framework, an open source project. +# Licensed under the Apache License, Version 2.0. +# See the LICENSE file for more information. +# +from typing import AsyncIterator +import numpy as np +import riva.client +from ten_runtime import AsyncTenEnv +from ten_ai_base.const import LOG_CATEGORY_VENDOR + +from .config import NvidiaRivaTTSConfig + + +class NvidiaRivaTTSClient: + """NVIDIA Riva TTS Client implementation""" + + def __init__( + self, + config: NvidiaRivaTTSConfig, + ten_env: AsyncTenEnv, + ): + self.config = config + self.ten_env: AsyncTenEnv = ten_env + self._is_cancelled = False + self.auth = None + self.tts_service = None + + try: + # Initialize Riva client + server = config.params["server"] + use_ssl = config.params.get("use_ssl", False) + + self.ten_env.log_info( + f"Initializing NVIDIA Riva TTS client with server: {server}, SSL: {use_ssl}", + category=LOG_CATEGORY_VENDOR, + ) + + self.auth = riva.client.Auth(ssl_cert=None, use_ssl=use_ssl, uri=server) + self.tts_service = riva.client.SpeechSynthesisService(self.auth) + + self.ten_env.log_info( + "NVIDIA Riva TTS client initialized successfully", + category=LOG_CATEGORY_VENDOR, + ) + except Exception as e: + ten_env.log_error( + f"Error when initializing NVIDIA Riva TTS: {e}", + category=LOG_CATEGORY_VENDOR, + ) + raise RuntimeError(f"Error when initializing NVIDIA Riva TTS: {e}") from e + + async def cancel(self): + """Cancel the current TTS request""" + self.ten_env.log_debug("NVIDIA Riva TTS: cancel() called.") + self._is_cancelled = True + + async def synthesize(self, text: str, request_id: str) -> AsyncIterator[bytes]: + """ + Synthesize speech from text using NVIDIA Riva TTS. + + Args: + text: Text to synthesize + request_id: Unique request identifier + + Yields: + Audio data as bytes (PCM format) + """ + self._is_cancelled = False + + if not self.tts_service: + self.ten_env.log_error( + f"NVIDIA Riva TTS: service not initialized for request_id: {request_id}", + category=LOG_CATEGORY_VENDOR, + ) + raise RuntimeError( + f"NVIDIA Riva TTS: service not initialized for request_id: {request_id}" + ) + + if len(text.strip()) == 0: + self.ten_env.log_warn( + f"NVIDIA Riva TTS: empty text for request_id: {request_id}", + category=LOG_CATEGORY_VENDOR, + ) + return + + try: + language_code = self.config.params["language_code"] + voice_name = self.config.params["voice_name"] + sample_rate = self.config.params.get("sample_rate", 16000) + + self.ten_env.log_debug( + f"NVIDIA Riva TTS: synthesizing text (length: {len(text)}) " + f"with voice: {voice_name}, language: {language_code}, " + f"sample_rate: {sample_rate}, request_id: {request_id}", + category=LOG_CATEGORY_VENDOR, + ) + + # Use streaming synthesis for lower latency + responses = self.tts_service.synthesize_online( + text, + voice_name=voice_name, + language_code=language_code, + sample_rate_hz=sample_rate, + encoding=riva.client.AudioEncoding.LINEAR_PCM, + ) + + # Stream audio chunks + for response in responses: + if self._is_cancelled: + self.ten_env.log_debug( + f"Cancellation detected, stopping TTS stream for request_id: {request_id}" + ) + break + + # Convert audio bytes to numpy array and back to bytes + # This ensures proper format + audio_data = np.frombuffer(response.audio, dtype=np.int16) + + self.ten_env.log_debug( + f"NVIDIA Riva TTS: yielding audio chunk, " + f"length: {len(audio_data)} samples, request_id: {request_id}", + category=LOG_CATEGORY_VENDOR, + ) + + yield audio_data.tobytes() + + if not self._is_cancelled: + self.ten_env.log_debug( + f"NVIDIA Riva TTS: synthesis completed for request_id: {request_id}", + category=LOG_CATEGORY_VENDOR, + ) + + except Exception as e: + error_message = str(e) + self.ten_env.log_error( + f"NVIDIA Riva TTS: error during synthesis: {error_message}, " + f"request_id: {request_id}", + category=LOG_CATEGORY_VENDOR, + ) + raise RuntimeError( + f"NVIDIA Riva TTS synthesis failed: {error_message}" + ) from e diff --git a/agents/ten_packages/extension/nvidia_riva_tts_python/tests/__init__.py b/agents/ten_packages/extension/nvidia_riva_tts_python/tests/__init__.py new file mode 100644 index 0000000000..b8c07eef1c --- /dev/null +++ b/agents/ten_packages/extension/nvidia_riva_tts_python/tests/__init__.py @@ -0,0 +1,4 @@ +# +# This file is part of TEN Framework, an open source project. +# Licensed under the Apache License, Version 2.0. +# diff --git a/agents/ten_packages/extension/nvidia_riva_tts_python/tests/test_compliance.py b/agents/ten_packages/extension/nvidia_riva_tts_python/tests/test_compliance.py new file mode 100644 index 0000000000..3384bb95d5 --- /dev/null +++ b/agents/ten_packages/extension/nvidia_riva_tts_python/tests/test_compliance.py @@ -0,0 +1,294 @@ +# +# This file is part of TEN Framework, an open source project. +# Licensed under the Apache License, Version 2.0. +# +""" +Compliance tests to ensure the extension correctly implements NVIDIA Riva TTS API. +These tests validate against the official NVIDIA Riva client API specifications. +""" +import pytest +from unittest.mock import Mock, patch, MagicMock +import numpy as np +from nvidia_riva_tts_python.config import NvidiaRivaTTSConfig +from nvidia_riva_tts_python.riva_tts import NvidiaRivaTTSClient + + +class TestNvidiaRivaAPICompliance: + """Test compliance with NVIDIA Riva TTS API specifications""" + + @pytest.fixture + def mock_ten_env(self): + """Create a mock TenEnv""" + env = Mock() + env.log_info = Mock() + env.log_debug = Mock() + env.log_warn = Mock() + env.log_error = Mock() + return env + + @pytest.fixture + def valid_config(self): + """Create a valid configuration""" + return NvidiaRivaTTSConfig( + params={ + "server": "localhost:50051", + "language_code": "en-US", + "voice_name": "English-US.Female-1", + "sample_rate": 16000, + "use_ssl": False, + } + ) + + def test_auth_initialization_parameters(self, valid_config, mock_ten_env): + """Verify Auth is initialized with correct parameters per Riva API""" + with patch('nvidia_riva_tts_python.riva_tts.riva.client.Auth') as mock_auth, \ + patch('nvidia_riva_tts_python.riva_tts.riva.client.SpeechSynthesisService'): + + client = NvidiaRivaTTSClient(config=valid_config, ten_env=mock_ten_env) + + # Verify Auth called with correct parameters + mock_auth.assert_called_once_with( + ssl_cert=None, + use_ssl=False, + uri="localhost:50051" + ) + + def test_speech_synthesis_service_initialization(self, valid_config, mock_ten_env): + """Verify SpeechSynthesisService is initialized with Auth object""" + with patch('nvidia_riva_tts_python.riva_tts.riva.client.Auth') as mock_auth, \ + patch('nvidia_riva_tts_python.riva_tts.riva.client.SpeechSynthesisService') as mock_service: + + mock_auth_instance = Mock() + mock_auth.return_value = mock_auth_instance + + client = NvidiaRivaTTSClient(config=valid_config, ten_env=mock_ten_env) + + # Verify SpeechSynthesisService called with Auth instance + mock_service.assert_called_once_with(mock_auth_instance) + + @pytest.mark.asyncio + async def test_synthesize_online_parameters(self, valid_config, mock_ten_env): + """Verify synthesize_online is called with correct parameters per Riva API""" + with patch('nvidia_riva_tts_python.riva_tts.riva.client.Auth'), \ + patch('nvidia_riva_tts_python.riva_tts.riva.client.SpeechSynthesisService') as mock_service, \ + patch('nvidia_riva_tts_python.riva_tts.riva.client.AudioEncoding') as mock_encoding: + + # Setup mocks + mock_service_instance = Mock() + mock_response = Mock() + mock_response.audio = b'\x00\x01' * 100 + mock_service_instance.synthesize_online = Mock(return_value=[mock_response]) + mock_service.return_value = mock_service_instance + mock_encoding.LINEAR_PCM = "LINEAR_PCM" + + client = NvidiaRivaTTSClient(config=valid_config, ten_env=mock_ten_env) + client.tts_service = mock_service_instance + + # Synthesize text + text = "Hello world" + chunks = [chunk async for chunk in client.synthesize(text, "test_request")] + + # Verify synthesize_online called with correct parameters + mock_service_instance.synthesize_online.assert_called_once_with( + text, + voice_name="English-US.Female-1", + language_code="en-US", + sample_rate_hz=16000, + encoding="LINEAR_PCM" + ) + + @pytest.mark.asyncio + async def test_audio_encoding_linear_pcm(self, valid_config, mock_ten_env): + """Verify LINEAR_PCM encoding is used per Riva API""" + with patch('nvidia_riva_tts_python.riva_tts.riva.client.Auth'), \ + patch('nvidia_riva_tts_python.riva_tts.riva.client.SpeechSynthesisService') as mock_service, \ + patch('nvidia_riva_tts_python.riva_tts.riva.client.AudioEncoding') as mock_encoding: + + mock_service_instance = Mock() + mock_response = Mock() + mock_response.audio = b'\x00\x01' * 100 + mock_service_instance.synthesize_online = Mock(return_value=[mock_response]) + mock_service.return_value = mock_service_instance + mock_encoding.LINEAR_PCM = "LINEAR_PCM" + + client = NvidiaRivaTTSClient(config=valid_config, ten_env=mock_ten_env) + client.tts_service = mock_service_instance + + # Synthesize + chunks = [chunk async for chunk in client.synthesize("Test", "req1")] + + # Verify encoding parameter + call_args = mock_service_instance.synthesize_online.call_args + assert call_args[1]['encoding'] == "LINEAR_PCM" + + @pytest.mark.asyncio + async def test_audio_format_int16(self, valid_config, mock_ten_env): + """Verify audio is processed as int16 per Riva API""" + with patch('nvidia_riva_tts_python.riva_tts.riva.client.Auth'), \ + patch('nvidia_riva_tts_python.riva_tts.riva.client.SpeechSynthesisService') as mock_service: + + # Create mock audio data (int16 format) + mock_audio = np.array([100, -100, 200, -200], dtype=np.int16).tobytes() + mock_response = Mock() + mock_response.audio = mock_audio + + mock_service_instance = Mock() + mock_service_instance.synthesize_online = Mock(return_value=[mock_response]) + mock_service.return_value = mock_service_instance + + client = NvidiaRivaTTSClient(config=valid_config, ten_env=mock_ten_env) + client.tts_service = mock_service_instance + + # Synthesize + chunks = [chunk async for chunk in client.synthesize("Test", "req1")] + + # Verify output is bytes + assert len(chunks) == 1 + assert isinstance(chunks[0], bytes) + + # Verify can be converted back to int16 + audio_array = np.frombuffer(chunks[0], dtype=np.int16) + assert audio_array.dtype == np.int16 + assert len(audio_array) == 4 + + @pytest.mark.asyncio + async def test_streaming_response_iteration(self, valid_config, mock_ten_env): + """Verify streaming responses are iterated correctly per Riva API""" + with patch('nvidia_riva_tts_python.riva_tts.riva.client.Auth'), \ + patch('nvidia_riva_tts_python.riva_tts.riva.client.SpeechSynthesisService') as mock_service: + + # Create multiple response chunks + mock_responses = [] + for i in range(3): + mock_response = Mock() + mock_response.audio = np.array([i] * 10, dtype=np.int16).tobytes() + mock_responses.append(mock_response) + + mock_service_instance = Mock() + mock_service_instance.synthesize_online = Mock(return_value=mock_responses) + mock_service.return_value = mock_service_instance + + client = NvidiaRivaTTSClient(config=valid_config, ten_env=mock_ten_env) + client.tts_service = mock_service_instance + + # Synthesize + chunks = [chunk async for chunk in client.synthesize("Test", "req1")] + + # Verify all chunks received + assert len(chunks) == 3 + for chunk in chunks: + assert isinstance(chunk, bytes) + assert len(chunk) > 0 + + def test_required_config_parameters(self): + """Verify all required parameters are validated per Riva API""" + # Missing server + with pytest.raises(ValueError, match="Server address is required"): + config = NvidiaRivaTTSConfig( + params={"language_code": "en-US", "voice_name": "English-US.Female-1"} + ) + config.validate() + + # Missing language_code + with pytest.raises(ValueError, match="Language code is required"): + config = NvidiaRivaTTSConfig( + params={"server": "localhost:50051", "voice_name": "English-US.Female-1"} + ) + config.validate() + + # Missing voice_name + with pytest.raises(ValueError, match="Voice name is required"): + config = NvidiaRivaTTSConfig( + params={"server": "localhost:50051", "language_code": "en-US"} + ) + config.validate() + + def test_optional_config_parameters(self, valid_config): + """Verify optional parameters have correct defaults per Riva API""" + # sample_rate defaults to 16000 + assert valid_config.params.get("sample_rate", 16000) == 16000 + + # use_ssl defaults to False + assert valid_config.params.get("use_ssl", False) is False + + def test_supported_sample_rates(self): + """Verify common sample rates are supported per Riva API""" + supported_rates = [8000, 16000, 22050, 24000, 44100, 48000] + + for rate in supported_rates: + config = NvidiaRivaTTSConfig( + params={ + "server": "localhost:50051", + "language_code": "en-US", + "voice_name": "English-US.Female-1", + "sample_rate": rate, + } + ) + config.validate() # Should not raise + assert config.params["sample_rate"] == rate + + def test_ssl_configuration(self, mock_ten_env): + """Verify SSL can be enabled per Riva API""" + config_with_ssl = NvidiaRivaTTSConfig( + params={ + "server": "secure-server:50051", + "language_code": "en-US", + "voice_name": "English-US.Female-1", + "use_ssl": True, + } + ) + + with patch('nvidia_riva_tts_python.riva_tts.riva.client.Auth') as mock_auth, \ + patch('nvidia_riva_tts_python.riva_tts.riva.client.SpeechSynthesisService'): + + client = NvidiaRivaTTSClient(config=config_with_ssl, ten_env=mock_ten_env) + + # Verify SSL enabled in Auth + call_args = mock_auth.call_args + assert call_args[1]['use_ssl'] is True + + @pytest.mark.asyncio + async def test_empty_text_handling(self, valid_config, mock_ten_env): + """Verify empty text is handled gracefully per Riva API""" + with patch('nvidia_riva_tts_python.riva_tts.riva.client.Auth'), \ + patch('nvidia_riva_tts_python.riva_tts.riva.client.SpeechSynthesisService'): + + client = NvidiaRivaTTSClient(config=valid_config, ten_env=mock_ten_env) + + # Empty string + chunks = [chunk async for chunk in client.synthesize("", "req1")] + assert len(chunks) == 0 + + # Whitespace only + chunks = [chunk async for chunk in client.synthesize(" ", "req1")] + assert len(chunks) == 0 + + @pytest.mark.asyncio + async def test_cancellation_support(self, valid_config, mock_ten_env): + """Verify cancellation is supported per Riva API""" + with patch('nvidia_riva_tts_python.riva_tts.riva.client.Auth'), \ + patch('nvidia_riva_tts_python.riva_tts.riva.client.SpeechSynthesisService') as mock_service: + + # Create multiple responses to simulate long synthesis + mock_responses = [Mock(audio=b'\x00\x01' * 100) for _ in range(10)] + mock_service_instance = Mock() + mock_service_instance.synthesize_online = Mock(return_value=mock_responses) + mock_service.return_value = mock_service_instance + + client = NvidiaRivaTTSClient(config=valid_config, ten_env=mock_ten_env) + client.tts_service = mock_service_instance + + # Start synthesis and cancel mid-stream + chunks = [] + async for i, chunk in enumerate(client.synthesize("Long text", "req1")): + chunks.append(chunk) + if i == 2: # Cancel after 3 chunks + await client.cancel() + + # Verify cancellation stopped the stream + assert len(chunks) < 10 # Should not receive all chunks + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "--tb=short"]) + diff --git a/agents/ten_packages/extension/nvidia_riva_tts_python/tests/test_config.py b/agents/ten_packages/extension/nvidia_riva_tts_python/tests/test_config.py new file mode 100644 index 0000000000..b29bf51adc --- /dev/null +++ b/agents/ten_packages/extension/nvidia_riva_tts_python/tests/test_config.py @@ -0,0 +1,67 @@ +# +# This file is part of TEN Framework, an open source project. +# Licensed under the Apache License, Version 2.0. +# +import pytest +from nvidia_riva_tts_python.config import NvidiaRivaTTSConfig + + +def test_config_validation(): + """Test configuration validation""" + # Valid config + config = NvidiaRivaTTSConfig( + params={ + "server": "localhost:50051", + "language_code": "en-US", + "voice_name": "English-US.Female-1", + "sample_rate": 16000, + } + ) + config.validate() # Should not raise + + # Missing server + with pytest.raises(ValueError, match="Server address is required"): + config = NvidiaRivaTTSConfig( + params={ + "language_code": "en-US", + "voice_name": "English-US.Female-1", + } + ) + config.validate() + + # Missing language_code + with pytest.raises(ValueError, match="Language code is required"): + config = NvidiaRivaTTSConfig( + params={ + "server": "localhost:50051", + "voice_name": "English-US.Female-1", + } + ) + config.validate() + + # Missing voice_name + with pytest.raises(ValueError, match="Voice name is required"): + config = NvidiaRivaTTSConfig( + params={ + "server": "localhost:50051", + "language_code": "en-US", + } + ) + config.validate() + + +def test_config_defaults(): + """Test default configuration values""" + config = NvidiaRivaTTSConfig( + params={ + "server": "localhost:50051", + "language_code": "en-US", + "voice_name": "English-US.Female-1", + } + ) + + assert config.dump is False + assert "nvidia_riva_tts_in.pcm" in config.dump_path + assert config.params["server"] == "localhost:50051" + assert config.params.get("sample_rate", 16000) == 16000 + assert config.params.get("use_ssl", False) is False diff --git a/agents/ten_packages/extension/nvidia_riva_tts_python/tests/test_extension.py b/agents/ten_packages/extension/nvidia_riva_tts_python/tests/test_extension.py new file mode 100644 index 0000000000..517a600945 --- /dev/null +++ b/agents/ten_packages/extension/nvidia_riva_tts_python/tests/test_extension.py @@ -0,0 +1,134 @@ +# +# This file is part of TEN Framework, an open source project. +# Licensed under the Apache License, Version 2.0. +# +import pytest +from unittest.mock import Mock, AsyncMock, patch, MagicMock +from nvidia_riva_tts_python.extension import NvidiaRivaTTSExtension +from nvidia_riva_tts_python.config import NvidiaRivaTTSConfig +from nvidia_riva_tts_python.riva_tts import NvidiaRivaTTSClient + + +@pytest.fixture +def mock_ten_env(): + """Create a mock TenEnv for testing""" + env = Mock() + env.log_info = Mock() + env.log_debug = Mock() + env.log_warn = Mock() + env.log_error = Mock() + return env + + +@pytest.fixture +def valid_config(): + """Create a valid configuration for testing""" + return NvidiaRivaTTSConfig( + params={ + "server": "localhost:50051", + "language_code": "en-US", + "voice_name": "English-US.Female-1", + "sample_rate": 16000, + "use_ssl": False, + } + ) + + +class TestNvidiaRivaTTSExtension: + """Test cases for NvidiaRivaTTSExtension""" + + def test_extension_initialization(self): + """Test extension can be initialized""" + extension = NvidiaRivaTTSExtension("test_extension") + assert extension is not None + assert extension.vendor() == "nvidia_riva" + + @pytest.mark.asyncio + async def test_create_config(self): + """Test configuration creation from JSON""" + extension = NvidiaRivaTTSExtension("test_extension") + config_json = """{ + "params": { + "server": "localhost:50051", + "language_code": "en-US", + "voice_name": "English-US.Female-1", + "sample_rate": 16000 + } + }""" + + config = await extension.create_config(config_json) + assert isinstance(config, NvidiaRivaTTSConfig) + assert config.params["server"] == "localhost:50051" + assert config.params["language_code"] == "en-US" + + def test_synthesize_audio_sample_rate(self, valid_config): + """Test sample rate retrieval""" + extension = NvidiaRivaTTSExtension("test_extension") + extension.config = valid_config + + sample_rate = extension.synthesize_audio_sample_rate() + assert sample_rate == 16000 + + +class TestNvidiaRivaTTSClient: + """Test cases for NvidiaRivaTTSClient""" + + @patch('nvidia_riva_tts_python.riva_tts.riva.client.Auth') + @patch('nvidia_riva_tts_python.riva_tts.riva.client.SpeechSynthesisService') + def test_client_initialization(self, mock_service, mock_auth, valid_config, mock_ten_env): + """Test client initialization""" + client = NvidiaRivaTTSClient(config=valid_config, ten_env=mock_ten_env) + + assert client is not None + assert client.config == valid_config + mock_auth.assert_called_once() + mock_service.assert_called_once() + + @patch('nvidia_riva_tts_python.riva_tts.riva.client.Auth') + @patch('nvidia_riva_tts_python.riva_tts.riva.client.SpeechSynthesisService') + @pytest.mark.asyncio + async def test_cancel(self, mock_service, mock_auth, valid_config, mock_ten_env): + """Test cancellation""" + client = NvidiaRivaTTSClient(config=valid_config, ten_env=mock_ten_env) + + await client.cancel() + assert client._is_cancelled is True + + @patch('nvidia_riva_tts_python.riva_tts.riva.client.Auth') + @patch('nvidia_riva_tts_python.riva_tts.riva.client.SpeechSynthesisService') + @pytest.mark.asyncio + async def test_synthesize_empty_text(self, mock_service, mock_auth, valid_config, mock_ten_env): + """Test synthesis with empty text""" + client = NvidiaRivaTTSClient(config=valid_config, ten_env=mock_ten_env) + + # Should return without yielding anything + result = [chunk async for chunk in client.synthesize("", "test_request")] + assert len(result) == 0 + + @patch('nvidia_riva_tts_python.riva_tts.riva.client.Auth') + @patch('nvidia_riva_tts_python.riva_tts.riva.client.SpeechSynthesisService') + @pytest.mark.asyncio + async def test_synthesize_with_text(self, mock_service, mock_auth, valid_config, mock_ten_env): + """Test synthesis with valid text""" + # Mock the service response + mock_response = Mock() + mock_response.audio = b'\x00\x01' * 100 # Mock audio data + + mock_service_instance = Mock() + mock_service_instance.synthesize_online = Mock(return_value=[mock_response]) + mock_service.return_value = mock_service_instance + + client = NvidiaRivaTTSClient(config=valid_config, ten_env=mock_ten_env) + client.tts_service = mock_service_instance + + # Synthesize text + chunks = [chunk async for chunk in client.synthesize("Hello world", "test_request")] + + assert len(chunks) > 0 + assert isinstance(chunks[0], bytes) + mock_service_instance.synthesize_online.assert_called_once() + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) + From a2638f01e9ae286fa4dd1d9d16275e62ae4209fd Mon Sep 17 00:00:00 2001 From: Nsuccess Date: Wed, 14 Jan 2026 17:02:26 +0000 Subject: [PATCH 02/10] feat: Add Speechmatics TTS extension (#1965) - Implements text-to-speech using Speechmatics TTS API - Supports low-latency streaming synthesis (sub-150ms) - Includes 4 voice options (UK and US English) - Comprehensive tests and documentation - Follows TTS2 HTTP interface pattern Closes #1965 --- .../speechmatics_tts_python/README.md | 99 +++++++++ .../speechmatics_tts_python/__init__.py | 8 + .../speechmatics_tts_python/addon.py | 19 ++ .../speechmatics_tts_python/config.py | 48 +++++ .../speechmatics_tts_python/extension.py | 58 ++++++ .../speechmatics_tts_python/manifest.json | 57 +++++ .../speechmatics_tts_python/property.json | 9 + .../speechmatics_tts_python/requirements.txt | 2 + .../speechmatics_tts.py | 194 ++++++++++++++++++ .../speechmatics_tts_python/tests/__init__.py | 5 + .../tests/test_config.py | 111 ++++++++++ .../tests/test_extension.py | 109 ++++++++++ 12 files changed, 719 insertions(+) create mode 100644 agents/ten_packages/extension/speechmatics_tts_python/README.md create mode 100644 agents/ten_packages/extension/speechmatics_tts_python/__init__.py create mode 100644 agents/ten_packages/extension/speechmatics_tts_python/addon.py create mode 100644 agents/ten_packages/extension/speechmatics_tts_python/config.py create mode 100644 agents/ten_packages/extension/speechmatics_tts_python/extension.py create mode 100644 agents/ten_packages/extension/speechmatics_tts_python/manifest.json create mode 100644 agents/ten_packages/extension/speechmatics_tts_python/property.json create mode 100644 agents/ten_packages/extension/speechmatics_tts_python/requirements.txt create mode 100644 agents/ten_packages/extension/speechmatics_tts_python/speechmatics_tts.py create mode 100644 agents/ten_packages/extension/speechmatics_tts_python/tests/__init__.py create mode 100644 agents/ten_packages/extension/speechmatics_tts_python/tests/test_config.py create mode 100644 agents/ten_packages/extension/speechmatics_tts_python/tests/test_extension.py diff --git a/agents/ten_packages/extension/speechmatics_tts_python/README.md b/agents/ten_packages/extension/speechmatics_tts_python/README.md new file mode 100644 index 0000000000..802a62dd71 --- /dev/null +++ b/agents/ten_packages/extension/speechmatics_tts_python/README.md @@ -0,0 +1,99 @@ +# Speechmatics TTS Python Extension + +This extension provides text-to-speech functionality using Speechmatics TTS API. + +## Features + +- Low-latency speech synthesis (sub-150ms) +- High-quality, natural-sounding voices +- HTTP REST API integration +- Multiple voice options (UK and US English) +- Support for WAV and MP3 output formats +- Production-grade reliability + +## Prerequisites + +- Speechmatics API key +- Python 3.8+ +- aiohttp package + +## Configuration + +The extension can be configured through your property.json: + +```json +{ + "params": { + "api_key": "your-api-key-here", + "voice_id": "sarah", + "output_format": "wav", + "sample_rate": 16000, + "base_url": "https://preview.tts.speechmatics.com" + } +} +``` + +### Configuration Options + +**Parameters inside `params` object:** +- `api_key` (required): Speechmatics API key +- `voice_id` (required): Voice identifier (sarah, theo, megan, jack) +- `output_format` (optional): Audio format - "wav" or "mp3" (default: "wav") +- `sample_rate` (optional): Audio sample rate in Hz (default: 16000) +- `base_url` (optional): API base URL (default: "https://preview.tts.speechmatics.com") + +### Available Voices + +| Voice ID | Description | +|----------|-------------| +| `sarah` | English Female (UK) | +| `theo` | English Male (UK) | +| `megan` | English Female (US) | +| `jack` | English Male (US) | + +## Getting Started + +### 1. Get API Key + +Create an API key at the [Speechmatics Portal](https://portal.speechmatics.com/). + +### 2. Set Environment Variable + +```bash +export SPEECHMATICS_API_KEY=your-api-key-here +``` + +### 3. Configure Extension + +Update your `property.json` with the desired voice and settings. + +## API Details + +- **Endpoint**: `https://preview.tts.speechmatics.com/generate/{voice_id}` +- **Method**: POST +- **Authentication**: Bearer token +- **Latency**: Sub-150ms +- **Sample Rate**: 16kHz mono (optimized for voice agents) + +## Architecture + +This extension follows the TEN Framework TTS2 HTTP extension pattern: + +- `extension.py`: Main extension class inheriting from `AsyncTTS2HttpExtension` +- `speechmatics_tts.py`: Client implementation with HTTP API integration +- `config.py`: Configuration model with validation +- `addon.py`: Extension addon registration + +## License + +Apache 2.0 + +## Contributing + +Contributions are welcome! Please submit issues and pull requests to the TEN Framework repository. + +## Links + +- [Speechmatics TTS Documentation](https://docs.speechmatics.com/text-to-speech/quickstart) +- [Speechmatics Portal](https://portal.speechmatics.com/) +- [TEN Framework](https://github.com/TEN-framework/ten-framework) diff --git a/agents/ten_packages/extension/speechmatics_tts_python/__init__.py b/agents/ten_packages/extension/speechmatics_tts_python/__init__.py new file mode 100644 index 0000000000..0413aa9b81 --- /dev/null +++ b/agents/ten_packages/extension/speechmatics_tts_python/__init__.py @@ -0,0 +1,8 @@ +# +# This file is part of TEN Framework, an open source project. +# Licensed under the Apache License, Version 2.0. +# See the LICENSE file for more information. +# +from . import addon + +__all__ = ["addon"] diff --git a/agents/ten_packages/extension/speechmatics_tts_python/addon.py b/agents/ten_packages/extension/speechmatics_tts_python/addon.py new file mode 100644 index 0000000000..3e749f97c9 --- /dev/null +++ b/agents/ten_packages/extension/speechmatics_tts_python/addon.py @@ -0,0 +1,19 @@ +# +# This file is part of TEN Framework, an open source project. +# Licensed under the Apache License, Version 2.0. +# See the LICENSE file for more information. +# +from ten_runtime import ( + Addon, + register_addon_as_extension, + TenEnv, +) + + +@register_addon_as_extension("speechmatics_tts_python") +class SpeechmaticsTTSExtensionAddon(Addon): + def on_create_instance(self, ten_env: TenEnv, name: str, context) -> None: + from .extension import SpeechmaticsTTSExtension + + ten_env.log_info("SpeechmaticsTTSExtensionAddon on_create_instance") + ten_env.on_create_instance_done(SpeechmaticsTTSExtension(name), context) diff --git a/agents/ten_packages/extension/speechmatics_tts_python/config.py b/agents/ten_packages/extension/speechmatics_tts_python/config.py new file mode 100644 index 0000000000..b8f2dfebce --- /dev/null +++ b/agents/ten_packages/extension/speechmatics_tts_python/config.py @@ -0,0 +1,48 @@ +# +# This file is part of TEN Framework, an open source project. +# Licensed under the Apache License, Version 2.0. +# See the LICENSE file for more information. +# +from typing import Any +import copy +from pydantic import Field +from pathlib import Path +from ten_ai_base import utils +from ten_ai_base.tts2_http import AsyncTTS2HttpConfig + + +class SpeechmaticsTTSConfig(AsyncTTS2HttpConfig): + """Speechmatics TTS Config""" + + dump: bool = Field(default=False, description="Speechmatics TTS dump") + dump_path: str = Field( + default_factory=lambda: str(Path(__file__).parent / "speechmatics_tts_in.pcm"), + description="Speechmatics TTS dump path", + ) + params: dict[str, Any] = Field( + default_factory=dict, description="Speechmatics TTS params" + ) + + def update_params(self) -> None: + """Update configuration from params dictionary""" + pass + + def to_str(self, sensitive_handling: bool = True) -> str: + """Convert config to string with optional sensitive data handling.""" + if not sensitive_handling: + return f"{self}" + + config = copy.deepcopy(self) + + # Encrypt sensitive fields in params + if config.params and "api_key" in config.params: + config.params["api_key"] = utils.encrypt(config.params["api_key"]) + + return f"{config}" + + def validate(self) -> None: + """Validate Speechmatics-specific configuration.""" + if "api_key" not in self.params or not self.params["api_key"]: + raise ValueError("API key is required for Speechmatics TTS") + if "voice_id" not in self.params or not self.params["voice_id"]: + raise ValueError("Voice ID is required for Speechmatics TTS") diff --git a/agents/ten_packages/extension/speechmatics_tts_python/extension.py b/agents/ten_packages/extension/speechmatics_tts_python/extension.py new file mode 100644 index 0000000000..972f6fb556 --- /dev/null +++ b/agents/ten_packages/extension/speechmatics_tts_python/extension.py @@ -0,0 +1,58 @@ +# +# This file is part of TEN Framework, an open source project. +# Licensed under the Apache License, Version 2.0. +# See the LICENSE file for more information. +# +""" +Speechmatics TTS Extension + +This extension implements text-to-speech using Speechmatics TTS API. +It provides low-latency, high-quality speech synthesis. +""" + +from ten_ai_base.tts2_http import ( + AsyncTTS2HttpExtension, + AsyncTTS2HttpConfig, + AsyncTTS2HttpClient, +) +from ten_runtime import AsyncTenEnv + +from .config import SpeechmaticsTTSConfig +from .speechmatics_tts import SpeechmaticsTTSClient + + +class SpeechmaticsTTSExtension(AsyncTTS2HttpExtension): + """ + Speechmatics TTS Extension implementation. + + Provides text-to-speech synthesis using Speechmatics HTTP API. + Inherits all common HTTP TTS functionality from AsyncTTS2HttpExtension. + """ + + def __init__(self, name: str) -> None: + super().__init__(name) + # Type hints for better IDE support + self.config: SpeechmaticsTTSConfig = None + self.client: SpeechmaticsTTSClient = None + + # ============================================================ + # Required method implementations + # ============================================================ + + async def create_config(self, config_json_str: str) -> AsyncTTS2HttpConfig: + """Create Speechmatics TTS configuration from JSON string.""" + return SpeechmaticsTTSConfig.model_validate_json(config_json_str) + + async def create_client( + self, config: AsyncTTS2HttpConfig, ten_env: AsyncTenEnv + ) -> AsyncTTS2HttpClient: + """Create Speechmatics TTS client.""" + return SpeechmaticsTTSClient(config=config, ten_env=ten_env) + + def vendor(self) -> str: + """Return vendor name.""" + return "speechmatics" + + def synthesize_audio_sample_rate(self) -> int: + """Return the sample rate for synthesized audio.""" + return self.config.params.get("sample_rate", 16000) diff --git a/agents/ten_packages/extension/speechmatics_tts_python/manifest.json b/agents/ten_packages/extension/speechmatics_tts_python/manifest.json new file mode 100644 index 0000000000..3abc65ee81 --- /dev/null +++ b/agents/ten_packages/extension/speechmatics_tts_python/manifest.json @@ -0,0 +1,57 @@ +{ + "type": "extension", + "name": "speechmatics_tts_python", + "version": "0.1.0", + "dependencies": [ + { + "type": "system", + "name": "ten_runtime_python", + "version": "0.11" + }, + { + "type": "system", + "name": "ten_ai_base", + "version": "0.7" + } + ], + "package": { + "include": [ + "manifest.json", + "property.json", + "**.py", + "README.md", + "requirements.txt" + ] + }, + "api": { + "interface": [ + { + "import_uri": "../../system/ten_ai_base/api/tts-interface.json" + } + ], + "property": { + "properties": { + "params": { + "type": "object", + "properties": { + "api_key": { + "type": "string" + }, + "voice_id": { + "type": "string" + }, + "output_format": { + "type": "string" + }, + "sample_rate": { + "type": "int64" + }, + "base_url": { + "type": "string" + } + } + } + } + } + } +} diff --git a/agents/ten_packages/extension/speechmatics_tts_python/property.json b/agents/ten_packages/extension/speechmatics_tts_python/property.json new file mode 100644 index 0000000000..94b3589c6b --- /dev/null +++ b/agents/ten_packages/extension/speechmatics_tts_python/property.json @@ -0,0 +1,9 @@ +{ + "params": { + "api_key": "${env:SPEECHMATICS_API_KEY}", + "voice_id": "sarah", + "output_format": "wav", + "sample_rate": 16000, + "base_url": "https://preview.tts.speechmatics.com" + } +} diff --git a/agents/ten_packages/extension/speechmatics_tts_python/requirements.txt b/agents/ten_packages/extension/speechmatics_tts_python/requirements.txt new file mode 100644 index 0000000000..ebfb7095a6 --- /dev/null +++ b/agents/ten_packages/extension/speechmatics_tts_python/requirements.txt @@ -0,0 +1,2 @@ +aiohttp>=3.8.0 +pydantic>=2.0.0 diff --git a/agents/ten_packages/extension/speechmatics_tts_python/speechmatics_tts.py b/agents/ten_packages/extension/speechmatics_tts_python/speechmatics_tts.py new file mode 100644 index 0000000000..aebad422be --- /dev/null +++ b/agents/ten_packages/extension/speechmatics_tts_python/speechmatics_tts.py @@ -0,0 +1,194 @@ +# +# This file is part of TEN Framework, an open source project. +# Licensed under the Apache License, Version 2.0. +# See the LICENSE file for more information. +# +from typing import Any, AsyncIterator, Tuple +import asyncio +import aiohttp +from ten_runtime import AsyncTenEnv +from ten_ai_base.const import LOG_CATEGORY_VENDOR +from ten_ai_base.struct import TTS2HttpResponseEventType +from ten_ai_base.tts2_http import AsyncTTS2HttpClient + +from .config import SpeechmaticsTTSConfig + + +class SpeechmaticsTTSClient(AsyncTTS2HttpClient): + """Speechmatics TTS Client implementation""" + + def __init__( + self, + config: SpeechmaticsTTSConfig, + ten_env: AsyncTenEnv, + ): + super().__init__() + self.config = config + self.ten_env: AsyncTenEnv = ten_env + self._is_cancelled = False + self.session: aiohttp.ClientSession | None = None + + # Retry configuration + self.max_retries = 3 + self.retry_delay = 0.1 + + try: + # Create aiohttp session + self.session = aiohttp.ClientSession() + + self.ten_env.log_info( + f"Speechmatics TTS client initialized with voice: {config.params.get('voice_id')}", + category=LOG_CATEGORY_VENDOR, + ) + except Exception as e: + ten_env.log_error( + f"Error when initializing Speechmatics TTS: {e}", + category=LOG_CATEGORY_VENDOR, + ) + raise RuntimeError(f"Error when initializing Speechmatics TTS: {e}") from e + + async def cancel(self): + """Cancel the current TTS request""" + self.ten_env.log_debug("Speechmatics TTS: cancel() called.") + self._is_cancelled = True + + async def get( + self, text: str, request_id: str + ) -> AsyncIterator[Tuple[bytes | None, TTS2HttpResponseEventType]]: + """Process a single TTS request""" + self._is_cancelled = False + + if not self.session: + self.ten_env.log_error( + f"Speechmatics TTS: session not initialized for request_id: {request_id}", + category=LOG_CATEGORY_VENDOR, + ) + raise RuntimeError( + f"Speechmatics TTS: session not initialized for request_id: {request_id}" + ) + + if len(text.strip()) == 0: + self.ten_env.log_warn( + f"Speechmatics TTS: empty text for request_id: {request_id}", + category=LOG_CATEGORY_VENDOR, + ) + yield None, TTS2HttpResponseEventType.END + return + + try: + # Synthesize audio + async for chunk in self._synthesize_with_retry(text, request_id): + if self._is_cancelled: + self.ten_env.log_debug( + f"Cancellation detected, sending flush event for request_id: {request_id}" + ) + yield None, TTS2HttpResponseEventType.FLUSH + break + + self.ten_env.log_debug( + f"Speechmatics TTS: sending audio chunk, length: {len(chunk)}, request_id: {request_id}", + category=LOG_CATEGORY_VENDOR, + ) + + if len(chunk) > 0: + yield bytes(chunk), TTS2HttpResponseEventType.RESPONSE + + if not self._is_cancelled: + self.ten_env.log_debug( + f"Speechmatics TTS: synthesis completed for request_id: {request_id}", + category=LOG_CATEGORY_VENDOR, + ) + yield None, TTS2HttpResponseEventType.END + + except Exception as e: + error_message = str(e) + self.ten_env.log_error( + f"Speechmatics TTS error: {error_message}, request_id: {request_id}", + category=LOG_CATEGORY_VENDOR, + ) + + # Check for authentication errors + if "401" in error_message or "authentication" in error_message.lower(): + yield error_message.encode("utf-8"), TTS2HttpResponseEventType.INVALID_KEY_ERROR + else: + yield error_message.encode("utf-8"), TTS2HttpResponseEventType.ERROR + + async def _synthesize(self, text: str) -> AsyncIterator[bytes]: + """Internal method to synthesize audio from text""" + assert self.session is not None + + # Build API endpoint + voice_id = self.config.params["voice_id"] + output_format = self.config.params.get("output_format", "wav") + base_url = self.config.params.get("base_url", "https://preview.tts.speechmatics.com") + + url = f"{base_url}/generate/{voice_id}" + if output_format: + url += f"?output_format={output_format}" + + # Prepare request + headers = { + "Authorization": f"Bearer {self.config.params['api_key']}", + "Content-Type": "application/json", + } + + payload = {"text": text} + + self.ten_env.log_debug( + f"Speechmatics TTS: requesting synthesis, voice: {voice_id}, format: {output_format}", + category=LOG_CATEGORY_VENDOR, + ) + + # Make HTTP request + async with self.session.post(url, headers=headers, json=payload) as response: + if response.status != 200: + error_text = await response.text() + raise RuntimeError( + f"Speechmatics TTS API error: {response.status} - {error_text}" + ) + + # Stream response chunks + async for chunk in response.content.iter_chunked(4096): + if chunk: + yield chunk + + async def _synthesize_with_retry( + self, text: str, request_id: str + ) -> AsyncIterator[bytes]: + """Synthesize with retry logic""" + retries = 0 + last_error = None + + while retries <= self.max_retries: + try: + async for chunk in self._synthesize(text): + yield chunk + return # Success, exit retry loop + except Exception as e: + last_error = e + retries += 1 + + if retries <= self.max_retries: + self.ten_env.log_warn( + f"Speechmatics TTS: retry {retries}/{self.max_retries} after error: {e}", + category=LOG_CATEGORY_VENDOR, + ) + await asyncio.sleep(self.retry_delay * (2 ** (retries - 1))) + else: + raise last_error + + async def clean(self): + """Clean up resources""" + self.ten_env.log_debug("Speechmatics TTS: clean() called.") + try: + if self.session: + await self.session.close() + finally: + pass + + def get_extra_metadata(self) -> dict[str, Any]: + """Return extra metadata for TTFB metrics.""" + return { + "voice_id": self.config.params.get("voice_id", ""), + "output_format": self.config.params.get("output_format", "wav"), + } diff --git a/agents/ten_packages/extension/speechmatics_tts_python/tests/__init__.py b/agents/ten_packages/extension/speechmatics_tts_python/tests/__init__.py new file mode 100644 index 0000000000..da402faf43 --- /dev/null +++ b/agents/ten_packages/extension/speechmatics_tts_python/tests/__init__.py @@ -0,0 +1,5 @@ +# +# This file is part of TEN Framework, an open source project. +# Licensed under the Apache License, Version 2.0. +# See the LICENSE file for more information. +# diff --git a/agents/ten_packages/extension/speechmatics_tts_python/tests/test_config.py b/agents/ten_packages/extension/speechmatics_tts_python/tests/test_config.py new file mode 100644 index 0000000000..23330fc469 --- /dev/null +++ b/agents/ten_packages/extension/speechmatics_tts_python/tests/test_config.py @@ -0,0 +1,111 @@ +# +# This file is part of TEN Framework, an open source project. +# Licensed under the Apache License, Version 2.0. +# See the LICENSE file for more information. +# +import pytest +from speechmatics_tts_python.config import SpeechmaticsTTSConfig + + +def test_config_creation(): + """Test basic config creation""" + config = SpeechmaticsTTSConfig( + params={ + "api_key": "test_key", + "voice_id": "sarah", + } + ) + assert config.params["api_key"] == "test_key" + assert config.params["voice_id"] == "sarah" + + +def test_config_validation_missing_api_key(): + """Test validation fails without API key""" + config = SpeechmaticsTTSConfig(params={"voice_id": "sarah"}) + with pytest.raises(ValueError, match="API key is required"): + config.validate() + + +def test_config_validation_missing_voice_id(): + """Test validation fails without voice ID""" + config = SpeechmaticsTTSConfig(params={"api_key": "test_key"}) + with pytest.raises(ValueError, match="Voice ID is required"): + config.validate() + + +def test_config_validation_success(): + """Test validation succeeds with required fields""" + config = SpeechmaticsTTSConfig( + params={ + "api_key": "test_key", + "voice_id": "sarah", + } + ) + config.validate() # Should not raise + + +def test_config_to_str_sensitive_handling(): + """Test sensitive data handling in to_str""" + config = SpeechmaticsTTSConfig( + params={ + "api_key": "secret_key_12345", + "voice_id": "sarah", + } + ) + config_str = config.to_str(sensitive_handling=True) + assert "secret_key_12345" not in config_str + + +def test_config_default_values(): + """Test default configuration values""" + config = SpeechmaticsTTSConfig( + params={ + "api_key": "test_key", + "voice_id": "sarah", + } + ) + assert config.dump is False + assert "speechmatics_tts_in.pcm" in config.dump_path + + +def test_config_with_optional_params(): + """Test configuration with optional parameters""" + config = SpeechmaticsTTSConfig( + params={ + "api_key": "test_key", + "voice_id": "megan", + "output_format": "mp3", + "sample_rate": 24000, + "base_url": "https://custom.api.com", + } + ) + assert config.params["output_format"] == "mp3" + assert config.params["sample_rate"] == 24000 + assert config.params["base_url"] == "https://custom.api.com" + + +def test_config_voice_options(): + """Test different voice configurations""" + voices = ["sarah", "theo", "megan", "jack"] + for voice in voices: + config = SpeechmaticsTTSConfig( + params={ + "api_key": "test_key", + "voice_id": voice, + } + ) + assert config.params["voice_id"] == voice + + +def test_config_output_formats(): + """Test different output format configurations""" + formats = ["wav", "mp3"] + for fmt in formats: + config = SpeechmaticsTTSConfig( + params={ + "api_key": "test_key", + "voice_id": "sarah", + "output_format": fmt, + } + ) + assert config.params["output_format"] == fmt diff --git a/agents/ten_packages/extension/speechmatics_tts_python/tests/test_extension.py b/agents/ten_packages/extension/speechmatics_tts_python/tests/test_extension.py new file mode 100644 index 0000000000..e8eab4b0cf --- /dev/null +++ b/agents/ten_packages/extension/speechmatics_tts_python/tests/test_extension.py @@ -0,0 +1,109 @@ +# +# This file is part of TEN Framework, an open source project. +# Licensed under the Apache License, Version 2.0. +# See the LICENSE file for more information. +# +import pytest +from speechmatics_tts_python.extension import SpeechmaticsTTSExtension +from speechmatics_tts_python.config import SpeechmaticsTTSConfig + + +def test_extension_creation(): + """Test extension can be created""" + extension = SpeechmaticsTTSExtension("test_extension") + assert extension is not None + assert extension.vendor() == "speechmatics" + + +def test_extension_vendor(): + """Test vendor name is correct""" + extension = SpeechmaticsTTSExtension("test_extension") + assert extension.vendor() == "speechmatics" + + +def test_extension_sample_rate_default(): + """Test default sample rate""" + extension = SpeechmaticsTTSExtension("test_extension") + extension.config = SpeechmaticsTTSConfig( + params={ + "api_key": "test_key", + "voice_id": "sarah", + } + ) + assert extension.synthesize_audio_sample_rate() == 16000 + + +def test_extension_sample_rate_custom(): + """Test custom sample rate""" + extension = SpeechmaticsTTSExtension("test_extension") + extension.config = SpeechmaticsTTSConfig( + params={ + "api_key": "test_key", + "voice_id": "sarah", + "sample_rate": 24000, + } + ) + assert extension.synthesize_audio_sample_rate() == 24000 + + +@pytest.mark.asyncio +async def test_create_config(): + """Test config creation from JSON""" + extension = SpeechmaticsTTSExtension("test_extension") + config_json = '{"params": {"api_key": "test_key", "voice_id": "sarah"}}' + config = await extension.create_config(config_json) + assert isinstance(config, SpeechmaticsTTSConfig) + assert config.params["api_key"] == "test_key" + assert config.params["voice_id"] == "sarah" + + +def test_extension_inheritance(): + """Test extension inherits from correct base class""" + from ten_ai_base.tts2_http import AsyncTTS2HttpExtension + extension = SpeechmaticsTTSExtension("test_extension") + assert isinstance(extension, AsyncTTS2HttpExtension) + + +def test_extension_config_types(): + """Test extension handles different config types""" + extension = SpeechmaticsTTSExtension("test_extension") + + # Test with different voices + for voice in ["sarah", "theo", "megan", "jack"]: + extension.config = SpeechmaticsTTSConfig( + params={ + "api_key": "test_key", + "voice_id": voice, + } + ) + assert extension.config.params["voice_id"] == voice + + +def test_extension_config_validation(): + """Test extension config validation""" + extension = SpeechmaticsTTSExtension("test_extension") + extension.config = SpeechmaticsTTSConfig( + params={ + "api_key": "test_key", + "voice_id": "sarah", + } + ) + # Should not raise + extension.config.validate() + + +def test_extension_with_all_params(): + """Test extension with all configuration parameters""" + extension = SpeechmaticsTTSExtension("test_extension") + extension.config = SpeechmaticsTTSConfig( + params={ + "api_key": "test_key", + "voice_id": "megan", + "output_format": "mp3", + "sample_rate": 24000, + "base_url": "https://custom.api.com", + } + ) + assert extension.config.params["voice_id"] == "megan" + assert extension.config.params["output_format"] == "mp3" + assert extension.synthesize_audio_sample_rate() == 24000 From 4b81c184d8d1c7b6239864c0758663cb2ebd95bd Mon Sep 17 00:00:00 2001 From: Success Nwachukwu Date: Tue, 20 Jan 2026 13:18:21 +0000 Subject: [PATCH 03/10] fix: apply Black formatting to Speechmatics TTS extension --- .../speechmatics_tts_python/config.py | 4 ++- .../speechmatics_tts.py | 25 ++++++++++++++----- .../tests/test_extension.py | 1 + 3 files changed, 23 insertions(+), 7 deletions(-) diff --git a/agents/ten_packages/extension/speechmatics_tts_python/config.py b/agents/ten_packages/extension/speechmatics_tts_python/config.py index b8f2dfebce..6230ebddee 100644 --- a/agents/ten_packages/extension/speechmatics_tts_python/config.py +++ b/agents/ten_packages/extension/speechmatics_tts_python/config.py @@ -16,7 +16,9 @@ class SpeechmaticsTTSConfig(AsyncTTS2HttpConfig): dump: bool = Field(default=False, description="Speechmatics TTS dump") dump_path: str = Field( - default_factory=lambda: str(Path(__file__).parent / "speechmatics_tts_in.pcm"), + default_factory=lambda: str( + Path(__file__).parent / "speechmatics_tts_in.pcm" + ), description="Speechmatics TTS dump path", ) params: dict[str, Any] = Field( diff --git a/agents/ten_packages/extension/speechmatics_tts_python/speechmatics_tts.py b/agents/ten_packages/extension/speechmatics_tts_python/speechmatics_tts.py index aebad422be..695a70806e 100644 --- a/agents/ten_packages/extension/speechmatics_tts_python/speechmatics_tts.py +++ b/agents/ten_packages/extension/speechmatics_tts_python/speechmatics_tts.py @@ -45,7 +45,9 @@ def __init__( f"Error when initializing Speechmatics TTS: {e}", category=LOG_CATEGORY_VENDOR, ) - raise RuntimeError(f"Error when initializing Speechmatics TTS: {e}") from e + raise RuntimeError( + f"Error when initializing Speechmatics TTS: {e}" + ) from e async def cancel(self): """Cancel the current TTS request""" @@ -108,10 +110,17 @@ async def get( ) # Check for authentication errors - if "401" in error_message or "authentication" in error_message.lower(): - yield error_message.encode("utf-8"), TTS2HttpResponseEventType.INVALID_KEY_ERROR + if ( + "401" in error_message + or "authentication" in error_message.lower() + ): + yield error_message.encode( + "utf-8" + ), TTS2HttpResponseEventType.INVALID_KEY_ERROR else: - yield error_message.encode("utf-8"), TTS2HttpResponseEventType.ERROR + yield error_message.encode( + "utf-8" + ), TTS2HttpResponseEventType.ERROR async def _synthesize(self, text: str) -> AsyncIterator[bytes]: """Internal method to synthesize audio from text""" @@ -120,7 +129,9 @@ async def _synthesize(self, text: str) -> AsyncIterator[bytes]: # Build API endpoint voice_id = self.config.params["voice_id"] output_format = self.config.params.get("output_format", "wav") - base_url = self.config.params.get("base_url", "https://preview.tts.speechmatics.com") + base_url = self.config.params.get( + "base_url", "https://preview.tts.speechmatics.com" + ) url = f"{base_url}/generate/{voice_id}" if output_format: @@ -140,7 +151,9 @@ async def _synthesize(self, text: str) -> AsyncIterator[bytes]: ) # Make HTTP request - async with self.session.post(url, headers=headers, json=payload) as response: + async with self.session.post( + url, headers=headers, json=payload + ) as response: if response.status != 200: error_text = await response.text() raise RuntimeError( diff --git a/agents/ten_packages/extension/speechmatics_tts_python/tests/test_extension.py b/agents/ten_packages/extension/speechmatics_tts_python/tests/test_extension.py index e8eab4b0cf..9e50661679 100644 --- a/agents/ten_packages/extension/speechmatics_tts_python/tests/test_extension.py +++ b/agents/ten_packages/extension/speechmatics_tts_python/tests/test_extension.py @@ -60,6 +60,7 @@ async def test_create_config(): def test_extension_inheritance(): """Test extension inherits from correct base class""" from ten_ai_base.tts2_http import AsyncTTS2HttpExtension + extension = SpeechmaticsTTSExtension("test_extension") assert isinstance(extension, AsyncTTS2HttpExtension) From 018d3af72ff8bae233668f07cb50ea28c8748ee6 Mon Sep 17 00:00:00 2001 From: Success Nwachukwu Date: Wed, 21 Jan 2026 13:00:49 +0000 Subject: [PATCH 04/10] fix: move speechmatics_tts_python to correct ai_agents folder path --- .../ten_packages/extension/speechmatics_tts_python/README.md | 0 .../ten_packages/extension/speechmatics_tts_python/__init__.py | 0 .../ten_packages/extension/speechmatics_tts_python/addon.py | 0 .../ten_packages/extension/speechmatics_tts_python/config.py | 0 .../ten_packages/extension/speechmatics_tts_python/extension.py | 0 .../ten_packages/extension/speechmatics_tts_python/manifest.json | 0 .../ten_packages/extension/speechmatics_tts_python/property.json | 0 .../extension/speechmatics_tts_python/requirements.txt | 0 .../extension/speechmatics_tts_python/speechmatics_tts.py | 0 .../extension/speechmatics_tts_python/tests/__init__.py | 0 .../extension/speechmatics_tts_python/tests/test_config.py | 0 .../extension/speechmatics_tts_python/tests/test_extension.py | 0 12 files changed, 0 insertions(+), 0 deletions(-) rename {agents => ai_agents/agents}/ten_packages/extension/speechmatics_tts_python/README.md (100%) rename {agents => ai_agents/agents}/ten_packages/extension/speechmatics_tts_python/__init__.py (100%) rename {agents => ai_agents/agents}/ten_packages/extension/speechmatics_tts_python/addon.py (100%) rename {agents => ai_agents/agents}/ten_packages/extension/speechmatics_tts_python/config.py (100%) rename {agents => ai_agents/agents}/ten_packages/extension/speechmatics_tts_python/extension.py (100%) rename {agents => ai_agents/agents}/ten_packages/extension/speechmatics_tts_python/manifest.json (100%) rename {agents => ai_agents/agents}/ten_packages/extension/speechmatics_tts_python/property.json (100%) rename {agents => ai_agents/agents}/ten_packages/extension/speechmatics_tts_python/requirements.txt (100%) rename {agents => ai_agents/agents}/ten_packages/extension/speechmatics_tts_python/speechmatics_tts.py (100%) rename {agents => ai_agents/agents}/ten_packages/extension/speechmatics_tts_python/tests/__init__.py (100%) rename {agents => ai_agents/agents}/ten_packages/extension/speechmatics_tts_python/tests/test_config.py (100%) rename {agents => ai_agents/agents}/ten_packages/extension/speechmatics_tts_python/tests/test_extension.py (100%) diff --git a/agents/ten_packages/extension/speechmatics_tts_python/README.md b/ai_agents/agents/ten_packages/extension/speechmatics_tts_python/README.md similarity index 100% rename from agents/ten_packages/extension/speechmatics_tts_python/README.md rename to ai_agents/agents/ten_packages/extension/speechmatics_tts_python/README.md diff --git a/agents/ten_packages/extension/speechmatics_tts_python/__init__.py b/ai_agents/agents/ten_packages/extension/speechmatics_tts_python/__init__.py similarity index 100% rename from agents/ten_packages/extension/speechmatics_tts_python/__init__.py rename to ai_agents/agents/ten_packages/extension/speechmatics_tts_python/__init__.py diff --git a/agents/ten_packages/extension/speechmatics_tts_python/addon.py b/ai_agents/agents/ten_packages/extension/speechmatics_tts_python/addon.py similarity index 100% rename from agents/ten_packages/extension/speechmatics_tts_python/addon.py rename to ai_agents/agents/ten_packages/extension/speechmatics_tts_python/addon.py diff --git a/agents/ten_packages/extension/speechmatics_tts_python/config.py b/ai_agents/agents/ten_packages/extension/speechmatics_tts_python/config.py similarity index 100% rename from agents/ten_packages/extension/speechmatics_tts_python/config.py rename to ai_agents/agents/ten_packages/extension/speechmatics_tts_python/config.py diff --git a/agents/ten_packages/extension/speechmatics_tts_python/extension.py b/ai_agents/agents/ten_packages/extension/speechmatics_tts_python/extension.py similarity index 100% rename from agents/ten_packages/extension/speechmatics_tts_python/extension.py rename to ai_agents/agents/ten_packages/extension/speechmatics_tts_python/extension.py diff --git a/agents/ten_packages/extension/speechmatics_tts_python/manifest.json b/ai_agents/agents/ten_packages/extension/speechmatics_tts_python/manifest.json similarity index 100% rename from agents/ten_packages/extension/speechmatics_tts_python/manifest.json rename to ai_agents/agents/ten_packages/extension/speechmatics_tts_python/manifest.json diff --git a/agents/ten_packages/extension/speechmatics_tts_python/property.json b/ai_agents/agents/ten_packages/extension/speechmatics_tts_python/property.json similarity index 100% rename from agents/ten_packages/extension/speechmatics_tts_python/property.json rename to ai_agents/agents/ten_packages/extension/speechmatics_tts_python/property.json diff --git a/agents/ten_packages/extension/speechmatics_tts_python/requirements.txt b/ai_agents/agents/ten_packages/extension/speechmatics_tts_python/requirements.txt similarity index 100% rename from agents/ten_packages/extension/speechmatics_tts_python/requirements.txt rename to ai_agents/agents/ten_packages/extension/speechmatics_tts_python/requirements.txt diff --git a/agents/ten_packages/extension/speechmatics_tts_python/speechmatics_tts.py b/ai_agents/agents/ten_packages/extension/speechmatics_tts_python/speechmatics_tts.py similarity index 100% rename from agents/ten_packages/extension/speechmatics_tts_python/speechmatics_tts.py rename to ai_agents/agents/ten_packages/extension/speechmatics_tts_python/speechmatics_tts.py diff --git a/agents/ten_packages/extension/speechmatics_tts_python/tests/__init__.py b/ai_agents/agents/ten_packages/extension/speechmatics_tts_python/tests/__init__.py similarity index 100% rename from agents/ten_packages/extension/speechmatics_tts_python/tests/__init__.py rename to ai_agents/agents/ten_packages/extension/speechmatics_tts_python/tests/__init__.py diff --git a/agents/ten_packages/extension/speechmatics_tts_python/tests/test_config.py b/ai_agents/agents/ten_packages/extension/speechmatics_tts_python/tests/test_config.py similarity index 100% rename from agents/ten_packages/extension/speechmatics_tts_python/tests/test_config.py rename to ai_agents/agents/ten_packages/extension/speechmatics_tts_python/tests/test_config.py diff --git a/agents/ten_packages/extension/speechmatics_tts_python/tests/test_extension.py b/ai_agents/agents/ten_packages/extension/speechmatics_tts_python/tests/test_extension.py similarity index 100% rename from agents/ten_packages/extension/speechmatics_tts_python/tests/test_extension.py rename to ai_agents/agents/ten_packages/extension/speechmatics_tts_python/tests/test_extension.py From 063eba6aee77cdc2bef9817bcc06f59542eae410 Mon Sep 17 00:00:00 2001 From: Success Nwachukwu Date: Thu, 29 Jan 2026 14:04:00 +0000 Subject: [PATCH 05/10] chore: retrigger CI checks From d7b52c6297954d1eb3ed853adabfcf4db5187f30 Mon Sep 17 00:00:00 2001 From: Success Nwachukwu Date: Thu, 29 Jan 2026 15:44:16 +0000 Subject: [PATCH 06/10] fix: add test runner scripts for speechmatics_tts_python extension --- .../tests/bin/bootstrap | 6 ++++++ .../tests/bin/bootstrap_and_start | 8 +++++++ .../speechmatics_tts_python/tests/bin/start | 21 +++++++++++++++++++ 3 files changed, 35 insertions(+) create mode 100644 ai_agents/agents/ten_packages/extension/speechmatics_tts_python/tests/bin/bootstrap create mode 100644 ai_agents/agents/ten_packages/extension/speechmatics_tts_python/tests/bin/bootstrap_and_start create mode 100644 ai_agents/agents/ten_packages/extension/speechmatics_tts_python/tests/bin/start diff --git a/ai_agents/agents/ten_packages/extension/speechmatics_tts_python/tests/bin/bootstrap b/ai_agents/agents/ten_packages/extension/speechmatics_tts_python/tests/bin/bootstrap new file mode 100644 index 0000000000..1a54df5c55 --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/speechmatics_tts_python/tests/bin/bootstrap @@ -0,0 +1,6 @@ +#!/bin/bash + +set -e + +cd "$(dirname "${BASH_SOURCE[0]}")/../.." +pip install -r requirements.txt diff --git a/ai_agents/agents/ten_packages/extension/speechmatics_tts_python/tests/bin/bootstrap_and_start b/ai_agents/agents/ten_packages/extension/speechmatics_tts_python/tests/bin/bootstrap_and_start new file mode 100644 index 0000000000..89aaef454b --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/speechmatics_tts_python/tests/bin/bootstrap_and_start @@ -0,0 +1,8 @@ +#!/bin/bash + +set -e + +cd "$(dirname "${BASH_SOURCE[0]}")/../.." + +./tests/bin/bootstrap +./tests/bin/start diff --git a/ai_agents/agents/ten_packages/extension/speechmatics_tts_python/tests/bin/start b/ai_agents/agents/ten_packages/extension/speechmatics_tts_python/tests/bin/start new file mode 100644 index 0000000000..b736ea0de1 --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/speechmatics_tts_python/tests/bin/start @@ -0,0 +1,21 @@ +#!/bin/bash + +set -e + +cd "$(dirname "${BASH_SOURCE[0]}")/../.." + +export PYTHONPATH=.ten/app:.ten/app/ten_packages/system/ten_runtime_python/lib:.ten/app/ten_packages/system/ten_runtime_python/interface:.ten/app/ten_packages/system/ten_ai_base/interface:$PYTHONPATH + +# If the Python app imports some modules that are compiled with a different +# version of libstdc++ (ex: PyTorch), the Python app may encounter confusing +# errors. To solve this problem, we can preload the correct version of +# libstdc++. +# +# export LD_PRELOAD=/lib/x86_64-linux-gnu/libstdc++.so.6 +# +# Another solution is to make sure the module 'ten_runtime_python' is imported +# _after_ the module that requires another version of libstdc++ is imported. +# +# Refer to https://github.com/pytorch/pytorch/issues/102360?from_wecom=1#issuecomment-1708989096 + +pytest -s tests/ "$@" From 76a2c17ef824f7f1faf9ef0a30c79e4037d4a5b6 Mon Sep 17 00:00:00 2001 From: Success Nwachukwu Date: Thu, 29 Jan 2026 15:53:41 +0000 Subject: [PATCH 07/10] fix: make test scripts executable --- .../extension/speechmatics_tts_python/tests/bin/bootstrap | 0 .../speechmatics_tts_python/tests/bin/bootstrap_and_start | 0 .../extension/speechmatics_tts_python/tests/bin/start | 0 3 files changed, 0 insertions(+), 0 deletions(-) mode change 100644 => 100755 ai_agents/agents/ten_packages/extension/speechmatics_tts_python/tests/bin/bootstrap mode change 100644 => 100755 ai_agents/agents/ten_packages/extension/speechmatics_tts_python/tests/bin/bootstrap_and_start mode change 100644 => 100755 ai_agents/agents/ten_packages/extension/speechmatics_tts_python/tests/bin/start diff --git a/ai_agents/agents/ten_packages/extension/speechmatics_tts_python/tests/bin/bootstrap b/ai_agents/agents/ten_packages/extension/speechmatics_tts_python/tests/bin/bootstrap old mode 100644 new mode 100755 diff --git a/ai_agents/agents/ten_packages/extension/speechmatics_tts_python/tests/bin/bootstrap_and_start b/ai_agents/agents/ten_packages/extension/speechmatics_tts_python/tests/bin/bootstrap_and_start old mode 100644 new mode 100755 diff --git a/ai_agents/agents/ten_packages/extension/speechmatics_tts_python/tests/bin/start b/ai_agents/agents/ten_packages/extension/speechmatics_tts_python/tests/bin/start old mode 100644 new mode 100755 From 5726a9b80696bf6e1b88896fdc9c3ea1ef82f58d Mon Sep 17 00:00:00 2001 From: Success Nwachukwu Date: Thu, 29 Jan 2026 16:04:48 +0000 Subject: [PATCH 08/10] fix: add pytest to requirements.txt for test execution --- .../extension/speechmatics_tts_python/requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/ai_agents/agents/ten_packages/extension/speechmatics_tts_python/requirements.txt b/ai_agents/agents/ten_packages/extension/speechmatics_tts_python/requirements.txt index ebfb7095a6..f68b8c0636 100644 --- a/ai_agents/agents/ten_packages/extension/speechmatics_tts_python/requirements.txt +++ b/ai_agents/agents/ten_packages/extension/speechmatics_tts_python/requirements.txt @@ -1,2 +1,3 @@ aiohttp>=3.8.0 pydantic>=2.0.0 +pytest==8.3.4 From 81cc390f70349a6dae134c97112205200f6461fd Mon Sep 17 00:00:00 2001 From: Success Nwachukwu Date: Sun, 1 Feb 2026 01:34:17 +0000 Subject: [PATCH 09/10] fix: disable E0401 import-error in pylint config Extension dependencies are not installed during the lint phase, causing false positive import errors. Disabling E0401 allows pylint to focus on actual code quality issues. --- tools/pylint/.pylintrc | 1 + 1 file changed, 1 insertion(+) diff --git a/tools/pylint/.pylintrc b/tools/pylint/.pylintrc index b3e18af743..f0a451d085 100644 --- a/tools/pylint/.pylintrc +++ b/tools/pylint/.pylintrc @@ -5,6 +5,7 @@ disable= C0116, # missing-function-docstring W0718, # broad-exception-caught W0621, # redefined-outer-name + E0401, # import-error - extension dependencies not installed during lint [MAIN] analyse-fallback-blocks=no From 7f03e93a95d654eec6ee77ff6c09a651d22864dc Mon Sep 17 00:00:00 2001 From: Success Nwachukwu Date: Tue, 3 Feb 2026 03:29:39 +0000 Subject: [PATCH 10/10] chore: remove incorrect agents folder at root level --- .../nvidia_riva_tts_python/IMPLEMENTATION.md | 269 ---------------- .../nvidia_riva_tts_python/README.md | 93 ------ .../nvidia_riva_tts_python/__init__.py | 7 - .../extension/nvidia_riva_tts_python/addon.py | 18 -- .../nvidia_riva_tts_python/config.py | 44 --- .../nvidia_riva_tts_python/extension.py | 264 ---------------- .../nvidia_riva_tts_python/manifest.json | 57 ---- .../nvidia_riva_tts_python/property.json | 9 - .../nvidia_riva_tts_python/requirements.txt | 2 - .../nvidia_riva_tts_python/riva_tts.py | 143 --------- .../nvidia_riva_tts_python/tests/__init__.py | 4 - .../tests/test_compliance.py | 294 ------------------ .../tests/test_config.py | 67 ---- .../tests/test_extension.py | 134 -------- 14 files changed, 1405 deletions(-) delete mode 100644 agents/ten_packages/extension/nvidia_riva_tts_python/IMPLEMENTATION.md delete mode 100644 agents/ten_packages/extension/nvidia_riva_tts_python/README.md delete mode 100644 agents/ten_packages/extension/nvidia_riva_tts_python/__init__.py delete mode 100644 agents/ten_packages/extension/nvidia_riva_tts_python/addon.py delete mode 100644 agents/ten_packages/extension/nvidia_riva_tts_python/config.py delete mode 100644 agents/ten_packages/extension/nvidia_riva_tts_python/extension.py delete mode 100644 agents/ten_packages/extension/nvidia_riva_tts_python/manifest.json delete mode 100644 agents/ten_packages/extension/nvidia_riva_tts_python/property.json delete mode 100644 agents/ten_packages/extension/nvidia_riva_tts_python/requirements.txt delete mode 100644 agents/ten_packages/extension/nvidia_riva_tts_python/riva_tts.py delete mode 100644 agents/ten_packages/extension/nvidia_riva_tts_python/tests/__init__.py delete mode 100644 agents/ten_packages/extension/nvidia_riva_tts_python/tests/test_compliance.py delete mode 100644 agents/ten_packages/extension/nvidia_riva_tts_python/tests/test_config.py delete mode 100644 agents/ten_packages/extension/nvidia_riva_tts_python/tests/test_extension.py diff --git a/agents/ten_packages/extension/nvidia_riva_tts_python/IMPLEMENTATION.md b/agents/ten_packages/extension/nvidia_riva_tts_python/IMPLEMENTATION.md deleted file mode 100644 index b167e3d0f8..0000000000 --- a/agents/ten_packages/extension/nvidia_riva_tts_python/IMPLEMENTATION.md +++ /dev/null @@ -1,269 +0,0 @@ -# NVIDIA Riva TTS Extension - Implementation Details - -## Overview - -This document describes the implementation of the NVIDIA Riva TTS extension for TEN Framework. The extension provides high-quality, GPU-accelerated text-to-speech synthesis using NVIDIA Riva Speech Skills. - -## Architecture - -### Component Structure - -``` -nvidia_riva_tts_python/ -├── extension.py # Main extension class -├── riva_tts.py # Riva client implementation -├── config.py # Configuration model -├── addon.py # Extension registration -├── manifest.json # Extension metadata -├── property.json # Default properties -├── requirements.txt # Python dependencies -├── README.md # User documentation -└── tests/ # Test suite - ├── test_config.py - └── test_extension.py -``` - -### Class Hierarchy - -``` -AsyncTTSExtension (base class from ten_ai_base) - └── NvidiaRivaTTSExtension - └── uses NvidiaRivaTTSClient - └── uses riva.client.SpeechSynthesisService -``` - -## Implementation Details - -### 1. Extension Class (`extension.py`) - -The `NvidiaRivaTTSExtension` class inherits from `AsyncTTSExtension` and implements the required abstract methods: - -- **`create_config()`**: Parses JSON configuration into `NvidiaRivaTTSConfig` -- **`create_client()`**: Instantiates `NvidiaRivaTTSClient` with configuration -- **`vendor()`**: Returns "nvidia_riva" as the vendor identifier -- **`synthesize_audio_sample_rate()`**: Returns the configured sample rate - -### 2. Client Implementation (`riva_tts.py`) - -The `NvidiaRivaTTSClient` class handles the actual TTS synthesis: - -#### Initialization -- Creates Riva Auth object with server URI and SSL settings -- Initializes `SpeechSynthesisService` for TTS operations -- Validates server connectivity - -#### Synthesis Method -```python -async def synthesize(self, text: str, request_id: str) -> AsyncIterator[bytes] -``` - -**Flow:** -1. Validates input text (non-empty) -2. Calls `tts_service.synthesize_online()` for streaming synthesis -3. Iterates through audio chunks from Riva -4. Converts audio data to PCM bytes -5. Yields audio chunks for streaming playback -6. Handles cancellation requests - -**Key Features:** -- Streaming synthesis for low latency -- Cancellation support via `_is_cancelled` flag -- Comprehensive logging at each step -- Error handling with detailed messages - -### 3. Configuration (`config.py`) - -The `NvidiaRivaTTSConfig` class extends `AsyncTTSConfig`: - -**Required Parameters:** -- `server`: Riva server address (host:port) -- `language_code`: Language identifier (e.g., "en-US") -- `voice_name`: Voice identifier (e.g., "English-US.Female-1") - -**Optional Parameters:** -- `sample_rate`: Audio sample rate in Hz (default: 16000) -- `use_ssl`: Enable SSL for gRPC (default: false) - -**Validation:** -- Ensures all required parameters are present -- Validates parameter types and formats - -### 4. Addon Registration (`addon.py`) - -Registers the extension with TEN Framework using the `@register_addon_as_extension` decorator. - -## Integration with TEN Framework - -### TTS Interface Compliance - -The extension implements the standard TEN Framework TTS interface defined in `ten_ai_base/api/tts-interface.json`: - -- **Input**: Text data via TEN data messages -- **Output**: PCM audio frames via TEN audio frame messages -- **Control**: Start/stop/cancel commands via TEN commands - -### Message Flow - -``` -1. Text Input → Extension receives text data -2. Configuration → Loads voice, language, sample rate -3. Synthesis → Calls Riva API with streaming -4. Audio Output → Yields PCM audio chunks -5. Completion → Signals end of synthesis -``` - -## NVIDIA Riva Integration - -### gRPC API Usage - -The extension uses the official `nvidia-riva-client` Python package which provides: - -- **Auth**: Authentication and connection management -- **SpeechSynthesisService**: TTS API wrapper -- **AudioEncoding**: Audio format specifications - -### Streaming vs Batch - -The implementation uses **streaming synthesis** (`synthesize_online`) for: -- Lower latency (first audio chunk arrives quickly) -- Better user experience in real-time applications -- Efficient memory usage - -Alternative batch mode (`synthesize`) is available but not used by default. - -### Audio Format - -- **Encoding**: LINEAR_PCM (16-bit signed integer) -- **Sample Rate**: Configurable (default 16000 Hz) -- **Channels**: Mono -- **Byte Order**: Little-endian - -## Error Handling - -### Initialization Errors -- Server unreachable → RuntimeError with connection details -- Invalid credentials → Authentication error -- Missing dependencies → Import error - -### Runtime Errors -- Empty text → Warning logged, no synthesis -- Synthesis failure → RuntimeError with Riva error message -- Cancellation → Graceful stop, log cancellation - -### Logging Strategy - -- **INFO**: Initialization, configuration -- **DEBUG**: Synthesis progress, chunk details -- **WARN**: Empty text, unusual conditions -- **ERROR**: Failures, exceptions - -## Testing - -### Test Coverage - -1. **Configuration Tests** (`test_config.py`) - - Valid configuration creation - - Missing required parameters - - Default values - - Validation logic - -2. **Extension Tests** (`test_extension.py`) - - Extension initialization - - Config creation from JSON - - Sample rate retrieval - - Client creation - -3. **Client Tests** (`test_extension.py`) - - Client initialization with mocked Riva - - Cancellation handling - - Empty text handling - - Synthesis with mocked responses - -### Running Tests - -```bash -# Install test dependencies -pip install pytest pytest-asyncio - -# Run all tests -pytest nvidia_riva_tts_python/tests/ -v - -# Run with coverage -pytest nvidia_riva_tts_python/tests/ --cov=nvidia_riva_tts_python -``` - -## Performance Considerations - -### Latency -- **First chunk**: ~100-200ms (depends on text length and server) -- **Streaming**: Continuous audio delivery -- **GPU acceleration**: Significantly faster than CPU-only TTS - -### Resource Usage -- **Memory**: Minimal (streaming mode) -- **Network**: gRPC connection to Riva server -- **CPU**: Low (Riva does GPU processing) - -### Optimization Tips -1. Use streaming mode for real-time applications -2. Keep Riva server close to application (low network latency) -3. Reuse client connections (handled by extension) -4. Configure appropriate sample rate for use case - -## Deployment - -### Prerequisites -1. NVIDIA Riva server running (see README.md for setup) -2. Network connectivity to Riva server -3. Python 3.8+ with nvidia-riva-client - -### Configuration Example - -```json -{ - "params": { - "server": "riva-server.example.com:50051", - "language_code": "en-US", - "voice_name": "English-US.Female-1", - "sample_rate": 22050, - "use_ssl": true - } -} -``` - -### Environment Variables - -```bash -export NVIDIA_RIVA_SERVER=localhost:50051 -``` - -## Future Enhancements - -Potential improvements for future versions: - -1. **SSML Support**: Full SSML tag support for advanced speech control -2. **Voice Cloning**: Custom voice model support -3. **Multi-language**: Automatic language detection -4. **Caching**: Cache frequently synthesized phrases -5. **Metrics**: Detailed performance metrics and monitoring -6. **Fallback**: Automatic fallback to alternative TTS if Riva unavailable - -## References - -- [NVIDIA Riva Documentation](https://docs.nvidia.com/deeplearning/riva/user-guide/docs/index.html) -- [Riva Python Client](https://pypi.org/project/nvidia-riva-client/) -- [TEN Framework TTS Interface](https://github.com/TEN-framework/ten-framework) -- [gRPC Python](https://grpc.io/docs/languages/python/) - -## License - -Apache 2.0 - See LICENSE file in the TEN Framework repository. - -## Contributing - -Contributions are welcome! Please: -1. Follow the existing code style -2. Add tests for new features -3. Update documentation -4. Submit PR to TEN Framework repository - diff --git a/agents/ten_packages/extension/nvidia_riva_tts_python/README.md b/agents/ten_packages/extension/nvidia_riva_tts_python/README.md deleted file mode 100644 index d0452fc9c0..0000000000 --- a/agents/ten_packages/extension/nvidia_riva_tts_python/README.md +++ /dev/null @@ -1,93 +0,0 @@ -# NVIDIA Riva TTS Python Extension - -This extension provides text-to-speech functionality using NVIDIA Riva Speech Skills. - -## Features - -- High-quality speech synthesis using NVIDIA Riva -- Support for multiple languages and voices -- Streaming and batch synthesis modes -- SSML support for advanced speech control -- GPU-accelerated inference for low latency - -## Prerequisites - -- NVIDIA Riva server running and accessible -- Python 3.8+ -- nvidia-riva-client package - -## Configuration - -The extension can be configured through your property.json: - -```json -{ - "params": { - "server": "localhost:50051", - "language_code": "en-US", - "voice_name": "English-US.Female-1", - "sample_rate": 16000, - "use_ssl": false - } -} -``` - -### Configuration Options - -**Parameters inside `params` object:** -- `server` (required): Riva server address (format: "host:port") -- `language_code` (required): Language code (e.g., "en-US", "es-ES") -- `voice_name` (required): Voice identifier (e.g., "English-US.Female-1") -- `sample_rate` (optional): Audio sample rate in Hz (default: 16000) -- `use_ssl` (optional): Use SSL for gRPC connection (default: false) - -### Available Voices - -Common voice names include: -- `English-US.Female-1` -- `English-US.Male-1` -- `English-GB.Female-1` -- `Spanish-US.Female-1` - -Check your Riva server documentation for the full list of available voices. - -## Setting up NVIDIA Riva Server - -Follow the [NVIDIA Riva Quick Start Guide](https://docs.nvidia.com/deeplearning/riva/user-guide/docs/quick-start-guide.html) to set up a Riva server. - -Quick setup with Docker: - -```bash -# Download Riva Quick Start scripts -ngc registry resource download-version nvidia/riva/riva_quickstart:2.17.0 - -# Initialize and start Riva -cd riva_quickstart_v2.17.0 -bash riva_init.sh -bash riva_start.sh -``` - -## Environment Variables - -Set the Riva server address via environment variable: - -```bash -export NVIDIA_RIVA_SERVER=localhost:50051 -``` - -## Architecture - -This extension follows the TEN Framework TTS extension pattern: - -- `extension.py`: Main extension class -- `riva_tts.py`: Client implementation with Riva SDK integration -- `config.py`: Configuration model -- `addon.py`: Extension addon registration - -## License - -Apache 2.0 - -## Contributing - -Contributions are welcome! Please submit issues and pull requests to the TEN Framework repository. diff --git a/agents/ten_packages/extension/nvidia_riva_tts_python/__init__.py b/agents/ten_packages/extension/nvidia_riva_tts_python/__init__.py deleted file mode 100644 index 2718464193..0000000000 --- a/agents/ten_packages/extension/nvidia_riva_tts_python/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -# -# This file is part of TEN Framework, an open source project. -# Licensed under the Apache License, Version 2.0. -# -from . import addon - -__all__ = ["addon"] diff --git a/agents/ten_packages/extension/nvidia_riva_tts_python/addon.py b/agents/ten_packages/extension/nvidia_riva_tts_python/addon.py deleted file mode 100644 index f880f2d97c..0000000000 --- a/agents/ten_packages/extension/nvidia_riva_tts_python/addon.py +++ /dev/null @@ -1,18 +0,0 @@ -# -# This file is part of TEN Framework, an open source project. -# Licensed under the Apache License, Version 2.0. -# -from ten_runtime import ( - Addon, - register_addon_as_extension, - TenEnv, -) - - -@register_addon_as_extension("nvidia_riva_tts_python") -class NvidiaRivaTTSExtensionAddon(Addon): - def on_create_instance(self, ten_env: TenEnv, name: str, context) -> None: - from .extension import NvidiaRivaTTSExtension - - ten_env.log_info("NvidiaRivaTTSExtensionAddon on_create_instance") - ten_env.on_create_instance_done(NvidiaRivaTTSExtension(name), context) diff --git a/agents/ten_packages/extension/nvidia_riva_tts_python/config.py b/agents/ten_packages/extension/nvidia_riva_tts_python/config.py deleted file mode 100644 index ed044bceed..0000000000 --- a/agents/ten_packages/extension/nvidia_riva_tts_python/config.py +++ /dev/null @@ -1,44 +0,0 @@ -# -# This file is part of TEN Framework, an open source project. -# Licensed under the Apache License, Version 2.0. -# -from typing import Any -import copy -from pydantic import Field -from pathlib import Path -from ten_ai_base import utils -from ten_ai_base.tts import AsyncTTSConfig - - -class NvidiaRivaTTSConfig(AsyncTTSConfig): - """NVIDIA Riva TTS Config""" - - dump: bool = Field(default=False, description="NVIDIA Riva TTS dump") - dump_path: str = Field( - default_factory=lambda: str(Path(__file__).parent / "nvidia_riva_tts_in.pcm"), - description="NVIDIA Riva TTS dump path", - ) - params: dict[str, Any] = Field( - default_factory=dict, description="NVIDIA Riva TTS params" - ) - - def update_params(self) -> None: - """Update configuration from params dictionary""" - pass - - def to_str(self, sensitive_handling: bool = True) -> str: - """Convert config to string with optional sensitive data handling.""" - if not sensitive_handling: - return f"{self}" - - config = copy.deepcopy(self) - return f"{config}" - - def validate(self) -> None: - """Validate NVIDIA Riva-specific configuration.""" - if "server" not in self.params or not self.params["server"]: - raise ValueError("Server address is required for NVIDIA Riva TTS") - if "language_code" not in self.params or not self.params["language_code"]: - raise ValueError("Language code is required for NVIDIA Riva TTS") - if "voice_name" not in self.params or not self.params["voice_name"]: - raise ValueError("Voice name is required for NVIDIA Riva TTS") diff --git a/agents/ten_packages/extension/nvidia_riva_tts_python/extension.py b/agents/ten_packages/extension/nvidia_riva_tts_python/extension.py deleted file mode 100644 index 44523b89d9..0000000000 --- a/agents/ten_packages/extension/nvidia_riva_tts_python/extension.py +++ /dev/null @@ -1,264 +0,0 @@ -# -# This file is part of TEN Framework, an open source project. -# Licensed under the Apache License, Version 2.0. -# See the LICENSE file for more information. -# -""" -NVIDIA Riva TTS Extension - -This extension implements text-to-speech using NVIDIA Riva Speech Skills. -It provides high-quality, GPU-accelerated speech synthesis. -""" - -import asyncio -import time -import traceback -from typing import Optional - -from ten_ai_base.message import ( - ModuleError, - ModuleErrorCode, - ModuleType, - TTSAudioEndReason, -) -from ten_ai_base.struct import TTSTextInput -from ten_ai_base.tts2 import AsyncTTS2BaseExtension, RequestState -from ten_ai_base.const import LOG_CATEGORY_KEY_POINT, LOG_CATEGORY_VENDOR -from ten_runtime import AsyncTenEnv - -from .config import NvidiaRivaTTSConfig -from .riva_tts import NvidiaRivaTTSClient - - -class NvidiaRivaTTSExtension(AsyncTTS2BaseExtension): - """ - NVIDIA Riva TTS Extension implementation. - - Provides text-to-speech synthesis using NVIDIA Riva's gRPC API. - Inherits all common TTS functionality from AsyncTTS2BaseExtension. - """ - - def __init__(self, name: str) -> None: - super().__init__(name) - self.config: Optional[NvidiaRivaTTSConfig] = None - self.client: Optional[NvidiaRivaTTSClient] = None - self.current_request_id: Optional[str] = None - self.request_start_ts: float = 0 - self.first_chunk_ts: float = 0 - self.request_total_audio_duration: int = 0 - self.flush_request_id: Optional[str] = None - self.last_end_request_id: Optional[str] = None - self.audio_start_sent: set[str] = set() - - async def on_init(self, ten_env: AsyncTenEnv) -> None: - """Initialize the extension""" - await super().on_init(ten_env) - ten_env.log_debug("NVIDIA Riva TTS on_init") - - try: - # Load configuration - config_json, _ = await ten_env.get_property_to_json("") - self.config = NvidiaRivaTTSConfig.model_validate_json(config_json) - - ten_env.log_info( - f"config: {self.config.model_dump_json()}", - category=LOG_CATEGORY_KEY_POINT, - ) - - # Create client - self.client = NvidiaRivaTTSClient( - config=self.config, - ten_env=ten_env, - ) - - except Exception as e: - ten_env.log_error(f"on_init failed: {traceback.format_exc()}") - await self.send_tts_error( - request_id="", - error=ModuleError( - message=str(e), - module=ModuleType.TTS, - code=ModuleErrorCode.FATAL_ERROR, - vendor_info={"vendor": "nvidia_riva"}, - ), - ) - - async def on_stop(self, ten_env: AsyncTenEnv) -> None: - """Stop the extension""" - await super().on_stop(ten_env) - ten_env.log_debug("NVIDIA Riva TTS on_stop") - - async def on_deinit(self, ten_env: AsyncTenEnv) -> None: - """Deinitialize the extension""" - await super().on_deinit(ten_env) - ten_env.log_debug("NVIDIA Riva TTS on_deinit") - - def vendor(self) -> str: - """Return vendor name""" - return "nvidia_riva" - - def synthesize_audio_sample_rate(self) -> int: - """Return audio sample rate""" - return self.config.params.get("sample_rate", 16000) if self.config else 16000 - - def synthesize_audio_channels(self) -> int: - """Return number of audio channels""" - return 1 - - def synthesize_audio_sample_width(self) -> int: - """Return sample width in bytes""" - return 2 # 16-bit PCM - - async def request_tts(self, t: TTSTextInput) -> None: - """Handle TTS request""" - try: - self.ten_env.log_info( - f"TTS request: text_length={len(t.text)}, " - f"text_input_end={t.text_input_end}, request_id={t.request_id}" - ) - - # Skip if request already completed - if t.request_id == self.flush_request_id: - self.ten_env.log_debug( - f"Request {t.request_id} was flushed, ignoring" - ) - return - - if t.request_id == self.last_end_request_id: - self.ten_env.log_debug( - f"Request {t.request_id} was ended, ignoring" - ) - return - - # Handle new request - is_new_request = self.current_request_id != t.request_id - if is_new_request: - self.ten_env.log_debug(f"New TTS request: {t.request_id}") - self.current_request_id = t.request_id - self.request_total_audio_duration = 0 - self.request_start_ts = time.time() - - if self.client is None: - raise ValueError("TTS client not initialized") - - # Synthesize audio - received_first_chunk = False - async for chunk in self.client.synthesize(t.text, t.request_id): - # Calculate audio duration - duration = self._calculate_audio_duration(len(chunk)) - - self.ten_env.log_debug( - f"receive_audio: duration={duration}ms, request_id={self.current_request_id}", - category=LOG_CATEGORY_VENDOR, - ) - - if not received_first_chunk: - received_first_chunk = True - # Send audio start - if t.request_id not in self.audio_start_sent: - await self.send_tts_audio_start(t.request_id) - self.audio_start_sent.add(t.request_id) - if is_new_request: - # Send TTFB metrics - self.first_chunk_ts = time.time() - elapsed_time = int( - (self.first_chunk_ts - self.request_start_ts) * 1000 - ) - await self.send_tts_ttfb_metrics( - request_id=t.request_id, - ttfb_ms=elapsed_time, - extra_metadata={ - "voice_name": self.config.params["voice_name"], - "language_code": self.config.params["language_code"], - }, - ) - - if t.request_id == self.flush_request_id: - break - - self.request_total_audio_duration += duration - await self.send_tts_audio_data(chunk) - - # Handle completion - if t.text_input_end or t.request_id == self.flush_request_id: - reason = TTSAudioEndReason.REQUEST_END - if t.request_id == self.flush_request_id: - reason = TTSAudioEndReason.INTERRUPTED - - if self.first_chunk_ts > 0: - await self._handle_completed_request(reason) - - except Exception as e: - self.ten_env.log_error(f"Error in request_tts: {traceback.format_exc()}") - await self.send_tts_error( - request_id=t.request_id, - error=ModuleError( - message=str(e), - module=ModuleType.TTS, - code=ModuleErrorCode.NON_FATAL_ERROR, - vendor_info={"vendor": "nvidia_riva"}, - ), - ) - - # Check if we've received text_input_end - has_received_text_input_end = False - if t.request_id and t.request_id in self.request_states: - if self.request_states[t.request_id] == RequestState.FINALIZING: - has_received_text_input_end = True - - if has_received_text_input_end: - await self._handle_completed_request(TTSAudioEndReason.ERROR) - - async def cancel_tts(self) -> None: - """Cancel current TTS request""" - self.ten_env.log_info(f"cancel_tts current_request_id: {self.current_request_id}") - if self.current_request_id is not None: - self.flush_request_id = self.current_request_id - - if self.client: - await self.client.cancel() - - if self.current_request_id and self.first_chunk_ts > 0: - await self._handle_completed_request(TTSAudioEndReason.INTERRUPTED) - - async def _handle_completed_request(self, reason: TTSAudioEndReason) -> None: - """Handle completed TTS request""" - if not self.current_request_id: - return - - self.last_end_request_id = self.current_request_id - - # Calculate metrics - request_event_interval = 0 - if self.first_chunk_ts > 0: - request_event_interval = int( - (time.time() - self.first_chunk_ts) * 1000 - ) - - # Send audio end - await self.send_tts_audio_end( - request_id=self.current_request_id, - request_event_interval_ms=request_event_interval, - request_total_audio_duration_ms=self.request_total_audio_duration, - reason=reason, - ) - - self.ten_env.log_debug( - f"Sent tts_audio_end: reason={reason.name}, request_id={self.current_request_id}" - ) - - # Finish request - await self.finish_request(request_id=self.current_request_id, reason=reason) - - # Reset state - self.first_chunk_ts = 0 - self.audio_start_sent.discard(self.current_request_id) - - def _calculate_audio_duration(self, bytes_length: int) -> int: - """Calculate audio duration in milliseconds""" - bytes_per_second = ( - self.synthesize_audio_sample_rate() - * self.synthesize_audio_channels() - * self.synthesize_audio_sample_width() - ) - return int((bytes_length / bytes_per_second) * 1000) diff --git a/agents/ten_packages/extension/nvidia_riva_tts_python/manifest.json b/agents/ten_packages/extension/nvidia_riva_tts_python/manifest.json deleted file mode 100644 index e48a071a40..0000000000 --- a/agents/ten_packages/extension/nvidia_riva_tts_python/manifest.json +++ /dev/null @@ -1,57 +0,0 @@ -{ - "type": "extension", - "name": "nvidia_riva_tts_python", - "version": "0.1.0", - "dependencies": [ - { - "type": "system", - "name": "ten_runtime_python", - "version": "0.11" - }, - { - "type": "system", - "name": "ten_ai_base", - "version": "0.7" - } - ], - "package": { - "include": [ - "manifest.json", - "property.json", - "**.py", - "README.md", - "requirements.txt" - ] - }, - "api": { - "interface": [ - { - "import_uri": "../../system/ten_ai_base/api/tts-interface.json" - } - ], - "property": { - "properties": { - "params": { - "type": "object", - "properties": { - "server": { - "type": "string" - }, - "language_code": { - "type": "string" - }, - "voice_name": { - "type": "string" - }, - "sample_rate": { - "type": "int64" - }, - "use_ssl": { - "type": "bool" - } - } - } - } - } - } -} diff --git a/agents/ten_packages/extension/nvidia_riva_tts_python/property.json b/agents/ten_packages/extension/nvidia_riva_tts_python/property.json deleted file mode 100644 index 022a606664..0000000000 --- a/agents/ten_packages/extension/nvidia_riva_tts_python/property.json +++ /dev/null @@ -1,9 +0,0 @@ -{ - "params": { - "server": "${env:NVIDIA_RIVA_SERVER|localhost:50051}", - "language_code": "en-US", - "voice_name": "English-US.Female-1", - "sample_rate": 16000, - "use_ssl": false - } -} diff --git a/agents/ten_packages/extension/nvidia_riva_tts_python/requirements.txt b/agents/ten_packages/extension/nvidia_riva_tts_python/requirements.txt deleted file mode 100644 index f178d839f4..0000000000 --- a/agents/ten_packages/extension/nvidia_riva_tts_python/requirements.txt +++ /dev/null @@ -1,2 +0,0 @@ -nvidia-riva-client>=2.17.0 -numpy>=1.21.0 diff --git a/agents/ten_packages/extension/nvidia_riva_tts_python/riva_tts.py b/agents/ten_packages/extension/nvidia_riva_tts_python/riva_tts.py deleted file mode 100644 index 04f7e5921b..0000000000 --- a/agents/ten_packages/extension/nvidia_riva_tts_python/riva_tts.py +++ /dev/null @@ -1,143 +0,0 @@ -# -# This file is part of TEN Framework, an open source project. -# Licensed under the Apache License, Version 2.0. -# See the LICENSE file for more information. -# -from typing import AsyncIterator -import numpy as np -import riva.client -from ten_runtime import AsyncTenEnv -from ten_ai_base.const import LOG_CATEGORY_VENDOR - -from .config import NvidiaRivaTTSConfig - - -class NvidiaRivaTTSClient: - """NVIDIA Riva TTS Client implementation""" - - def __init__( - self, - config: NvidiaRivaTTSConfig, - ten_env: AsyncTenEnv, - ): - self.config = config - self.ten_env: AsyncTenEnv = ten_env - self._is_cancelled = False - self.auth = None - self.tts_service = None - - try: - # Initialize Riva client - server = config.params["server"] - use_ssl = config.params.get("use_ssl", False) - - self.ten_env.log_info( - f"Initializing NVIDIA Riva TTS client with server: {server}, SSL: {use_ssl}", - category=LOG_CATEGORY_VENDOR, - ) - - self.auth = riva.client.Auth(ssl_cert=None, use_ssl=use_ssl, uri=server) - self.tts_service = riva.client.SpeechSynthesisService(self.auth) - - self.ten_env.log_info( - "NVIDIA Riva TTS client initialized successfully", - category=LOG_CATEGORY_VENDOR, - ) - except Exception as e: - ten_env.log_error( - f"Error when initializing NVIDIA Riva TTS: {e}", - category=LOG_CATEGORY_VENDOR, - ) - raise RuntimeError(f"Error when initializing NVIDIA Riva TTS: {e}") from e - - async def cancel(self): - """Cancel the current TTS request""" - self.ten_env.log_debug("NVIDIA Riva TTS: cancel() called.") - self._is_cancelled = True - - async def synthesize(self, text: str, request_id: str) -> AsyncIterator[bytes]: - """ - Synthesize speech from text using NVIDIA Riva TTS. - - Args: - text: Text to synthesize - request_id: Unique request identifier - - Yields: - Audio data as bytes (PCM format) - """ - self._is_cancelled = False - - if not self.tts_service: - self.ten_env.log_error( - f"NVIDIA Riva TTS: service not initialized for request_id: {request_id}", - category=LOG_CATEGORY_VENDOR, - ) - raise RuntimeError( - f"NVIDIA Riva TTS: service not initialized for request_id: {request_id}" - ) - - if len(text.strip()) == 0: - self.ten_env.log_warn( - f"NVIDIA Riva TTS: empty text for request_id: {request_id}", - category=LOG_CATEGORY_VENDOR, - ) - return - - try: - language_code = self.config.params["language_code"] - voice_name = self.config.params["voice_name"] - sample_rate = self.config.params.get("sample_rate", 16000) - - self.ten_env.log_debug( - f"NVIDIA Riva TTS: synthesizing text (length: {len(text)}) " - f"with voice: {voice_name}, language: {language_code}, " - f"sample_rate: {sample_rate}, request_id: {request_id}", - category=LOG_CATEGORY_VENDOR, - ) - - # Use streaming synthesis for lower latency - responses = self.tts_service.synthesize_online( - text, - voice_name=voice_name, - language_code=language_code, - sample_rate_hz=sample_rate, - encoding=riva.client.AudioEncoding.LINEAR_PCM, - ) - - # Stream audio chunks - for response in responses: - if self._is_cancelled: - self.ten_env.log_debug( - f"Cancellation detected, stopping TTS stream for request_id: {request_id}" - ) - break - - # Convert audio bytes to numpy array and back to bytes - # This ensures proper format - audio_data = np.frombuffer(response.audio, dtype=np.int16) - - self.ten_env.log_debug( - f"NVIDIA Riva TTS: yielding audio chunk, " - f"length: {len(audio_data)} samples, request_id: {request_id}", - category=LOG_CATEGORY_VENDOR, - ) - - yield audio_data.tobytes() - - if not self._is_cancelled: - self.ten_env.log_debug( - f"NVIDIA Riva TTS: synthesis completed for request_id: {request_id}", - category=LOG_CATEGORY_VENDOR, - ) - - except Exception as e: - error_message = str(e) - self.ten_env.log_error( - f"NVIDIA Riva TTS: error during synthesis: {error_message}, " - f"request_id: {request_id}", - category=LOG_CATEGORY_VENDOR, - ) - raise RuntimeError( - f"NVIDIA Riva TTS synthesis failed: {error_message}" - ) from e diff --git a/agents/ten_packages/extension/nvidia_riva_tts_python/tests/__init__.py b/agents/ten_packages/extension/nvidia_riva_tts_python/tests/__init__.py deleted file mode 100644 index b8c07eef1c..0000000000 --- a/agents/ten_packages/extension/nvidia_riva_tts_python/tests/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -# -# This file is part of TEN Framework, an open source project. -# Licensed under the Apache License, Version 2.0. -# diff --git a/agents/ten_packages/extension/nvidia_riva_tts_python/tests/test_compliance.py b/agents/ten_packages/extension/nvidia_riva_tts_python/tests/test_compliance.py deleted file mode 100644 index 3384bb95d5..0000000000 --- a/agents/ten_packages/extension/nvidia_riva_tts_python/tests/test_compliance.py +++ /dev/null @@ -1,294 +0,0 @@ -# -# This file is part of TEN Framework, an open source project. -# Licensed under the Apache License, Version 2.0. -# -""" -Compliance tests to ensure the extension correctly implements NVIDIA Riva TTS API. -These tests validate against the official NVIDIA Riva client API specifications. -""" -import pytest -from unittest.mock import Mock, patch, MagicMock -import numpy as np -from nvidia_riva_tts_python.config import NvidiaRivaTTSConfig -from nvidia_riva_tts_python.riva_tts import NvidiaRivaTTSClient - - -class TestNvidiaRivaAPICompliance: - """Test compliance with NVIDIA Riva TTS API specifications""" - - @pytest.fixture - def mock_ten_env(self): - """Create a mock TenEnv""" - env = Mock() - env.log_info = Mock() - env.log_debug = Mock() - env.log_warn = Mock() - env.log_error = Mock() - return env - - @pytest.fixture - def valid_config(self): - """Create a valid configuration""" - return NvidiaRivaTTSConfig( - params={ - "server": "localhost:50051", - "language_code": "en-US", - "voice_name": "English-US.Female-1", - "sample_rate": 16000, - "use_ssl": False, - } - ) - - def test_auth_initialization_parameters(self, valid_config, mock_ten_env): - """Verify Auth is initialized with correct parameters per Riva API""" - with patch('nvidia_riva_tts_python.riva_tts.riva.client.Auth') as mock_auth, \ - patch('nvidia_riva_tts_python.riva_tts.riva.client.SpeechSynthesisService'): - - client = NvidiaRivaTTSClient(config=valid_config, ten_env=mock_ten_env) - - # Verify Auth called with correct parameters - mock_auth.assert_called_once_with( - ssl_cert=None, - use_ssl=False, - uri="localhost:50051" - ) - - def test_speech_synthesis_service_initialization(self, valid_config, mock_ten_env): - """Verify SpeechSynthesisService is initialized with Auth object""" - with patch('nvidia_riva_tts_python.riva_tts.riva.client.Auth') as mock_auth, \ - patch('nvidia_riva_tts_python.riva_tts.riva.client.SpeechSynthesisService') as mock_service: - - mock_auth_instance = Mock() - mock_auth.return_value = mock_auth_instance - - client = NvidiaRivaTTSClient(config=valid_config, ten_env=mock_ten_env) - - # Verify SpeechSynthesisService called with Auth instance - mock_service.assert_called_once_with(mock_auth_instance) - - @pytest.mark.asyncio - async def test_synthesize_online_parameters(self, valid_config, mock_ten_env): - """Verify synthesize_online is called with correct parameters per Riva API""" - with patch('nvidia_riva_tts_python.riva_tts.riva.client.Auth'), \ - patch('nvidia_riva_tts_python.riva_tts.riva.client.SpeechSynthesisService') as mock_service, \ - patch('nvidia_riva_tts_python.riva_tts.riva.client.AudioEncoding') as mock_encoding: - - # Setup mocks - mock_service_instance = Mock() - mock_response = Mock() - mock_response.audio = b'\x00\x01' * 100 - mock_service_instance.synthesize_online = Mock(return_value=[mock_response]) - mock_service.return_value = mock_service_instance - mock_encoding.LINEAR_PCM = "LINEAR_PCM" - - client = NvidiaRivaTTSClient(config=valid_config, ten_env=mock_ten_env) - client.tts_service = mock_service_instance - - # Synthesize text - text = "Hello world" - chunks = [chunk async for chunk in client.synthesize(text, "test_request")] - - # Verify synthesize_online called with correct parameters - mock_service_instance.synthesize_online.assert_called_once_with( - text, - voice_name="English-US.Female-1", - language_code="en-US", - sample_rate_hz=16000, - encoding="LINEAR_PCM" - ) - - @pytest.mark.asyncio - async def test_audio_encoding_linear_pcm(self, valid_config, mock_ten_env): - """Verify LINEAR_PCM encoding is used per Riva API""" - with patch('nvidia_riva_tts_python.riva_tts.riva.client.Auth'), \ - patch('nvidia_riva_tts_python.riva_tts.riva.client.SpeechSynthesisService') as mock_service, \ - patch('nvidia_riva_tts_python.riva_tts.riva.client.AudioEncoding') as mock_encoding: - - mock_service_instance = Mock() - mock_response = Mock() - mock_response.audio = b'\x00\x01' * 100 - mock_service_instance.synthesize_online = Mock(return_value=[mock_response]) - mock_service.return_value = mock_service_instance - mock_encoding.LINEAR_PCM = "LINEAR_PCM" - - client = NvidiaRivaTTSClient(config=valid_config, ten_env=mock_ten_env) - client.tts_service = mock_service_instance - - # Synthesize - chunks = [chunk async for chunk in client.synthesize("Test", "req1")] - - # Verify encoding parameter - call_args = mock_service_instance.synthesize_online.call_args - assert call_args[1]['encoding'] == "LINEAR_PCM" - - @pytest.mark.asyncio - async def test_audio_format_int16(self, valid_config, mock_ten_env): - """Verify audio is processed as int16 per Riva API""" - with patch('nvidia_riva_tts_python.riva_tts.riva.client.Auth'), \ - patch('nvidia_riva_tts_python.riva_tts.riva.client.SpeechSynthesisService') as mock_service: - - # Create mock audio data (int16 format) - mock_audio = np.array([100, -100, 200, -200], dtype=np.int16).tobytes() - mock_response = Mock() - mock_response.audio = mock_audio - - mock_service_instance = Mock() - mock_service_instance.synthesize_online = Mock(return_value=[mock_response]) - mock_service.return_value = mock_service_instance - - client = NvidiaRivaTTSClient(config=valid_config, ten_env=mock_ten_env) - client.tts_service = mock_service_instance - - # Synthesize - chunks = [chunk async for chunk in client.synthesize("Test", "req1")] - - # Verify output is bytes - assert len(chunks) == 1 - assert isinstance(chunks[0], bytes) - - # Verify can be converted back to int16 - audio_array = np.frombuffer(chunks[0], dtype=np.int16) - assert audio_array.dtype == np.int16 - assert len(audio_array) == 4 - - @pytest.mark.asyncio - async def test_streaming_response_iteration(self, valid_config, mock_ten_env): - """Verify streaming responses are iterated correctly per Riva API""" - with patch('nvidia_riva_tts_python.riva_tts.riva.client.Auth'), \ - patch('nvidia_riva_tts_python.riva_tts.riva.client.SpeechSynthesisService') as mock_service: - - # Create multiple response chunks - mock_responses = [] - for i in range(3): - mock_response = Mock() - mock_response.audio = np.array([i] * 10, dtype=np.int16).tobytes() - mock_responses.append(mock_response) - - mock_service_instance = Mock() - mock_service_instance.synthesize_online = Mock(return_value=mock_responses) - mock_service.return_value = mock_service_instance - - client = NvidiaRivaTTSClient(config=valid_config, ten_env=mock_ten_env) - client.tts_service = mock_service_instance - - # Synthesize - chunks = [chunk async for chunk in client.synthesize("Test", "req1")] - - # Verify all chunks received - assert len(chunks) == 3 - for chunk in chunks: - assert isinstance(chunk, bytes) - assert len(chunk) > 0 - - def test_required_config_parameters(self): - """Verify all required parameters are validated per Riva API""" - # Missing server - with pytest.raises(ValueError, match="Server address is required"): - config = NvidiaRivaTTSConfig( - params={"language_code": "en-US", "voice_name": "English-US.Female-1"} - ) - config.validate() - - # Missing language_code - with pytest.raises(ValueError, match="Language code is required"): - config = NvidiaRivaTTSConfig( - params={"server": "localhost:50051", "voice_name": "English-US.Female-1"} - ) - config.validate() - - # Missing voice_name - with pytest.raises(ValueError, match="Voice name is required"): - config = NvidiaRivaTTSConfig( - params={"server": "localhost:50051", "language_code": "en-US"} - ) - config.validate() - - def test_optional_config_parameters(self, valid_config): - """Verify optional parameters have correct defaults per Riva API""" - # sample_rate defaults to 16000 - assert valid_config.params.get("sample_rate", 16000) == 16000 - - # use_ssl defaults to False - assert valid_config.params.get("use_ssl", False) is False - - def test_supported_sample_rates(self): - """Verify common sample rates are supported per Riva API""" - supported_rates = [8000, 16000, 22050, 24000, 44100, 48000] - - for rate in supported_rates: - config = NvidiaRivaTTSConfig( - params={ - "server": "localhost:50051", - "language_code": "en-US", - "voice_name": "English-US.Female-1", - "sample_rate": rate, - } - ) - config.validate() # Should not raise - assert config.params["sample_rate"] == rate - - def test_ssl_configuration(self, mock_ten_env): - """Verify SSL can be enabled per Riva API""" - config_with_ssl = NvidiaRivaTTSConfig( - params={ - "server": "secure-server:50051", - "language_code": "en-US", - "voice_name": "English-US.Female-1", - "use_ssl": True, - } - ) - - with patch('nvidia_riva_tts_python.riva_tts.riva.client.Auth') as mock_auth, \ - patch('nvidia_riva_tts_python.riva_tts.riva.client.SpeechSynthesisService'): - - client = NvidiaRivaTTSClient(config=config_with_ssl, ten_env=mock_ten_env) - - # Verify SSL enabled in Auth - call_args = mock_auth.call_args - assert call_args[1]['use_ssl'] is True - - @pytest.mark.asyncio - async def test_empty_text_handling(self, valid_config, mock_ten_env): - """Verify empty text is handled gracefully per Riva API""" - with patch('nvidia_riva_tts_python.riva_tts.riva.client.Auth'), \ - patch('nvidia_riva_tts_python.riva_tts.riva.client.SpeechSynthesisService'): - - client = NvidiaRivaTTSClient(config=valid_config, ten_env=mock_ten_env) - - # Empty string - chunks = [chunk async for chunk in client.synthesize("", "req1")] - assert len(chunks) == 0 - - # Whitespace only - chunks = [chunk async for chunk in client.synthesize(" ", "req1")] - assert len(chunks) == 0 - - @pytest.mark.asyncio - async def test_cancellation_support(self, valid_config, mock_ten_env): - """Verify cancellation is supported per Riva API""" - with patch('nvidia_riva_tts_python.riva_tts.riva.client.Auth'), \ - patch('nvidia_riva_tts_python.riva_tts.riva.client.SpeechSynthesisService') as mock_service: - - # Create multiple responses to simulate long synthesis - mock_responses = [Mock(audio=b'\x00\x01' * 100) for _ in range(10)] - mock_service_instance = Mock() - mock_service_instance.synthesize_online = Mock(return_value=mock_responses) - mock_service.return_value = mock_service_instance - - client = NvidiaRivaTTSClient(config=valid_config, ten_env=mock_ten_env) - client.tts_service = mock_service_instance - - # Start synthesis and cancel mid-stream - chunks = [] - async for i, chunk in enumerate(client.synthesize("Long text", "req1")): - chunks.append(chunk) - if i == 2: # Cancel after 3 chunks - await client.cancel() - - # Verify cancellation stopped the stream - assert len(chunks) < 10 # Should not receive all chunks - - -if __name__ == "__main__": - pytest.main([__file__, "-v", "--tb=short"]) - diff --git a/agents/ten_packages/extension/nvidia_riva_tts_python/tests/test_config.py b/agents/ten_packages/extension/nvidia_riva_tts_python/tests/test_config.py deleted file mode 100644 index b29bf51adc..0000000000 --- a/agents/ten_packages/extension/nvidia_riva_tts_python/tests/test_config.py +++ /dev/null @@ -1,67 +0,0 @@ -# -# This file is part of TEN Framework, an open source project. -# Licensed under the Apache License, Version 2.0. -# -import pytest -from nvidia_riva_tts_python.config import NvidiaRivaTTSConfig - - -def test_config_validation(): - """Test configuration validation""" - # Valid config - config = NvidiaRivaTTSConfig( - params={ - "server": "localhost:50051", - "language_code": "en-US", - "voice_name": "English-US.Female-1", - "sample_rate": 16000, - } - ) - config.validate() # Should not raise - - # Missing server - with pytest.raises(ValueError, match="Server address is required"): - config = NvidiaRivaTTSConfig( - params={ - "language_code": "en-US", - "voice_name": "English-US.Female-1", - } - ) - config.validate() - - # Missing language_code - with pytest.raises(ValueError, match="Language code is required"): - config = NvidiaRivaTTSConfig( - params={ - "server": "localhost:50051", - "voice_name": "English-US.Female-1", - } - ) - config.validate() - - # Missing voice_name - with pytest.raises(ValueError, match="Voice name is required"): - config = NvidiaRivaTTSConfig( - params={ - "server": "localhost:50051", - "language_code": "en-US", - } - ) - config.validate() - - -def test_config_defaults(): - """Test default configuration values""" - config = NvidiaRivaTTSConfig( - params={ - "server": "localhost:50051", - "language_code": "en-US", - "voice_name": "English-US.Female-1", - } - ) - - assert config.dump is False - assert "nvidia_riva_tts_in.pcm" in config.dump_path - assert config.params["server"] == "localhost:50051" - assert config.params.get("sample_rate", 16000) == 16000 - assert config.params.get("use_ssl", False) is False diff --git a/agents/ten_packages/extension/nvidia_riva_tts_python/tests/test_extension.py b/agents/ten_packages/extension/nvidia_riva_tts_python/tests/test_extension.py deleted file mode 100644 index 517a600945..0000000000 --- a/agents/ten_packages/extension/nvidia_riva_tts_python/tests/test_extension.py +++ /dev/null @@ -1,134 +0,0 @@ -# -# This file is part of TEN Framework, an open source project. -# Licensed under the Apache License, Version 2.0. -# -import pytest -from unittest.mock import Mock, AsyncMock, patch, MagicMock -from nvidia_riva_tts_python.extension import NvidiaRivaTTSExtension -from nvidia_riva_tts_python.config import NvidiaRivaTTSConfig -from nvidia_riva_tts_python.riva_tts import NvidiaRivaTTSClient - - -@pytest.fixture -def mock_ten_env(): - """Create a mock TenEnv for testing""" - env = Mock() - env.log_info = Mock() - env.log_debug = Mock() - env.log_warn = Mock() - env.log_error = Mock() - return env - - -@pytest.fixture -def valid_config(): - """Create a valid configuration for testing""" - return NvidiaRivaTTSConfig( - params={ - "server": "localhost:50051", - "language_code": "en-US", - "voice_name": "English-US.Female-1", - "sample_rate": 16000, - "use_ssl": False, - } - ) - - -class TestNvidiaRivaTTSExtension: - """Test cases for NvidiaRivaTTSExtension""" - - def test_extension_initialization(self): - """Test extension can be initialized""" - extension = NvidiaRivaTTSExtension("test_extension") - assert extension is not None - assert extension.vendor() == "nvidia_riva" - - @pytest.mark.asyncio - async def test_create_config(self): - """Test configuration creation from JSON""" - extension = NvidiaRivaTTSExtension("test_extension") - config_json = """{ - "params": { - "server": "localhost:50051", - "language_code": "en-US", - "voice_name": "English-US.Female-1", - "sample_rate": 16000 - } - }""" - - config = await extension.create_config(config_json) - assert isinstance(config, NvidiaRivaTTSConfig) - assert config.params["server"] == "localhost:50051" - assert config.params["language_code"] == "en-US" - - def test_synthesize_audio_sample_rate(self, valid_config): - """Test sample rate retrieval""" - extension = NvidiaRivaTTSExtension("test_extension") - extension.config = valid_config - - sample_rate = extension.synthesize_audio_sample_rate() - assert sample_rate == 16000 - - -class TestNvidiaRivaTTSClient: - """Test cases for NvidiaRivaTTSClient""" - - @patch('nvidia_riva_tts_python.riva_tts.riva.client.Auth') - @patch('nvidia_riva_tts_python.riva_tts.riva.client.SpeechSynthesisService') - def test_client_initialization(self, mock_service, mock_auth, valid_config, mock_ten_env): - """Test client initialization""" - client = NvidiaRivaTTSClient(config=valid_config, ten_env=mock_ten_env) - - assert client is not None - assert client.config == valid_config - mock_auth.assert_called_once() - mock_service.assert_called_once() - - @patch('nvidia_riva_tts_python.riva_tts.riva.client.Auth') - @patch('nvidia_riva_tts_python.riva_tts.riva.client.SpeechSynthesisService') - @pytest.mark.asyncio - async def test_cancel(self, mock_service, mock_auth, valid_config, mock_ten_env): - """Test cancellation""" - client = NvidiaRivaTTSClient(config=valid_config, ten_env=mock_ten_env) - - await client.cancel() - assert client._is_cancelled is True - - @patch('nvidia_riva_tts_python.riva_tts.riva.client.Auth') - @patch('nvidia_riva_tts_python.riva_tts.riva.client.SpeechSynthesisService') - @pytest.mark.asyncio - async def test_synthesize_empty_text(self, mock_service, mock_auth, valid_config, mock_ten_env): - """Test synthesis with empty text""" - client = NvidiaRivaTTSClient(config=valid_config, ten_env=mock_ten_env) - - # Should return without yielding anything - result = [chunk async for chunk in client.synthesize("", "test_request")] - assert len(result) == 0 - - @patch('nvidia_riva_tts_python.riva_tts.riva.client.Auth') - @patch('nvidia_riva_tts_python.riva_tts.riva.client.SpeechSynthesisService') - @pytest.mark.asyncio - async def test_synthesize_with_text(self, mock_service, mock_auth, valid_config, mock_ten_env): - """Test synthesis with valid text""" - # Mock the service response - mock_response = Mock() - mock_response.audio = b'\x00\x01' * 100 # Mock audio data - - mock_service_instance = Mock() - mock_service_instance.synthesize_online = Mock(return_value=[mock_response]) - mock_service.return_value = mock_service_instance - - client = NvidiaRivaTTSClient(config=valid_config, ten_env=mock_ten_env) - client.tts_service = mock_service_instance - - # Synthesize text - chunks = [chunk async for chunk in client.synthesize("Hello world", "test_request")] - - assert len(chunks) > 0 - assert isinstance(chunks[0], bytes) - mock_service_instance.synthesize_online.assert_called_once() - - -if __name__ == "__main__": - pytest.main([__file__, "-v"]) -