diff --git a/backend/app/core/providers.py b/backend/app/core/providers.py index ee7f52cfd..4680bea28 100644 --- a/backend/app/core/providers.py +++ b/backend/app/core/providers.py @@ -13,6 +13,7 @@ class Provider(str, Enum): AWS = "aws" LANGFUSE = "langfuse" GOOGLE = "google" + SARVAMAI = "sarvamai" @dataclass @@ -32,6 +33,7 @@ class ProviderConfig: required_fields=["secret_key", "public_key", "host"] ), Provider.GOOGLE: ProviderConfig(required_fields=["api_key"]), + Provider.SARVAMAI: ProviderConfig(required_fields=["api_key"]), } diff --git a/backend/app/models/llm/request.py b/backend/app/models/llm/request.py index 57ccf2740..22d75d18f 100644 --- a/backend/app/models/llm/request.py +++ b/backend/app/models/llm/request.py @@ -22,7 +22,7 @@ class TextLLMParams(SQLModel): description="Reasoning configuration or instructions", ) temperature: float | None = Field( - default=None, + default=0.1, ge=0.0, le=2.0, ) @@ -35,7 +35,7 @@ class TextLLMParams(SQLModel): class STTLLMParams(SQLModel): model: str - instructions: str + instructions: str | None = None input_language: str | None = None output_language: str | None = None response_format: Literal["text"] | None = Field( @@ -43,9 +43,10 @@ class STTLLMParams(SQLModel): description="Currently supports text type", ) temperature: float | None = Field( - default=0.2, + default=None, ge=0.0, le=2.0, + description="Temperature parameter (not supported by all STT providers)", ) @@ -190,7 +191,7 @@ class NativeCompletionConfig(SQLModel): Supports any LLM provider's native API format. """ - provider: Literal["openai-native", "google-native"] = Field( + provider: Literal["openai-native", "google-native", "sarvamai-native"] = Field( ..., description="Native provider type (e.g., openai-native)", ) @@ -210,8 +211,8 @@ class KaapiCompletionConfig(SQLModel): Supports multiple providers: OpenAI, Claude, Gemini, etc. """ - provider: Literal["openai", "google"] = Field( - ..., description="LLM provider (openai)" + provider: Literal["openai", "google", "sarvamai"] = Field( + ..., description="LLM provider (openai, google, sarvamai)" ) type: Literal["text", "stt", "tts"] = Field( diff --git a/backend/app/services/llm/mappers.py b/backend/app/services/llm/mappers.py index d4efc2e9f..b9e1ebae4 100644 --- a/backend/app/services/llm/mappers.py +++ b/backend/app/services/llm/mappers.py @@ -142,12 +142,98 @@ def map_kaapi_to_google_params(kaapi_params: dict) -> tuple[dict, list[str]]: return google_params, warnings +def map_kaapi_to_sarvam_params(kaapi_params: dict) -> tuple[dict, list[str]]: + """Map Kaapi-abstracted parameters to SarvamAI API parameters. + + Handles both STTLLMParams and TTSLLMParams. 
+
+    STTLLMParams: model, instructions, input_language, output_language, response_format, temperature
+    TTSLLMParams: model, voice, language, response_format
+
+    Args:
+        kaapi_params: Dictionary with standardized Kaapi parameters
+
+    Returns:
+        Tuple of:
+        - Dictionary of SarvamAI API parameters
+        - List of warnings for unsupported parameters
+    """
+    sarvam_params = {}
+    warnings = []
+
+    # Model is required for all completion types
+    model = kaapi_params.get("model")
+    if not model:
+        return {}, ["Missing required 'model' parameter"]
+    sarvam_params["model"] = model
+
+    # Determine if STT or TTS based on presence of specific params
+    voice = kaapi_params.get("voice")
+    input_language = kaapi_params.get("input_language")
+
+    if voice is not None:
+        # TTS mode - map TTSLLMParams
+        sarvam_params["speaker"] = voice
+
+        language = kaapi_params.get("language")
+        if not language:
+            return {}, ["Missing required 'language' parameter for TTS"]
+        sarvam_params["target_language_code"] = language
+
+        response_format = kaapi_params.get("response_format")
+        if response_format:
+            # Map audio format to SarvamAI codec
+            format_mapping = {"mp3": "mp3", "wav": "wav", "ogg": "ogg"}
+            sarvam_params["output_audio_codec"] = format_mapping.get(
+                response_format, "wav"
+            )
+
+    elif input_language is not None or kaapi_params.get("output_language") is not None:
+        # STT mode - map STTLLMParams
+        output_language = kaapi_params.get("output_language")
+        transcription_mode = "transcribe"
+
+        if input_language == "auto":
+            sarvam_params["language_code"] = "unknown"
+        elif input_language:
+            sarvam_params["language_code"] = input_language
+
+        if output_language is None:
+            output_language = input_language
+
+        if output_language == "en-IN" and input_language != output_language:
+            transcription_mode = "translate"
+
+        sarvam_params["mode"] = transcription_mode
+
+        # Warn about unsupported STT parameters
+        instructions = kaapi_params.get("instructions")
+        if instructions:
+            warnings.append(
+                "Parameter 'instructions' is not supported by SarvamAI STT and was ignored"
+            )
+
+        temperature = kaapi_params.get("temperature")
+        if temperature is not None:
+            warnings.append(
+                "Parameter 'temperature' is not supported by SarvamAI STT and was ignored"
+            )
+
+        response_format = kaapi_params.get("response_format")
+        if response_format:
+            warnings.append(
+                "Parameter 'response_format' is not supported by SarvamAI STT and was ignored"
+            )
+
+    return sarvam_params, warnings
+
+
 def transform_kaapi_config_to_native(
     kaapi_config: KaapiCompletionConfig,
 ) -> tuple[NativeCompletionConfig, list[str]]:
     """Transform Kaapi completion config to native provider config with mapped parameters.
 
-    Supports OpenAI and Google AI providers.
+    Supports OpenAI, Google AI, and Sarvam AI providers.
 
    
Args: kaapi_config: KaapiCompletionConfig with abstracted parameters @@ -175,4 +261,13 @@ def transform_kaapi_config_to_native( warnings, ) + if kaapi_config.provider == "sarvamai": + mapped_params, warnings = map_kaapi_to_sarvam_params(kaapi_config.params) + return ( + NativeCompletionConfig( + provider="sarvamai-native", params=mapped_params, type=kaapi_config.type + ), + warnings, + ) + raise ValueError(f"Unsupported provider: {kaapi_config.provider}") diff --git a/backend/app/services/llm/providers/registry.py b/backend/app/services/llm/providers/registry.py index 5eff4db19..7085b35cf 100644 --- a/backend/app/services/llm/providers/registry.py +++ b/backend/app/services/llm/providers/registry.py @@ -6,6 +6,7 @@ from app.services.llm.providers.base import BaseProvider from app.services.llm.providers.oai import OpenAIProvider from app.services.llm.providers.gai import GoogleAIProvider +from app.services.llm.providers.sai import SarvamAIProvider logger = logging.getLogger(__name__) @@ -16,6 +17,7 @@ class LLMProvider: # Future constants for native providers: # CLAUDE_NATIVE = "claude-native" GOOGLE_NATIVE = "google-native" + SARVAMAI_NATIVE = "sarvamai-native" _registry: dict[str, type[BaseProvider]] = { OPENAI_NATIVE: OpenAIProvider, @@ -23,6 +25,7 @@ class LLMProvider: # Future native providers: # CLAUDE_NATIVE: ClaudeProvider, GOOGLE_NATIVE: GoogleAIProvider, + SARVAMAI_NATIVE: SarvamAIProvider, } @classmethod diff --git a/backend/app/services/llm/providers/sai.py b/backend/app/services/llm/providers/sai.py new file mode 100644 index 000000000..c2984e6aa --- /dev/null +++ b/backend/app/services/llm/providers/sai.py @@ -0,0 +1,275 @@ +import logging +import os +from typing import Any + +from sarvamai import SarvamAI + + +from app.models.llm import ( + NativeCompletionConfig, + LLMCallResponse, + QueryParams, + TextOutput, + LLMResponse, + Usage, + TextContent, +) +from app.models.llm.response import AudioOutput +from app.models.llm.request import AudioContent +from app.services.llm.providers.base import BaseProvider + + +logger = logging.getLogger(__name__) + + +class SarvamAIProvider(BaseProvider): + def __init__(self, client: SarvamAI): + """Initialize SarvamAI provider with client. + + Args: + client: SarvamAI client instance + """ + super().__init__(client) + self.client = client + + @staticmethod + def create_client(credentials: dict[str, Any]) -> Any: + if "api_key" not in credentials: + raise ValueError("API Key for SarvamAI Not Set") + return SarvamAI(api_subscription_key=credentials["api_key"]) + + def _parse_input( + self, query_input: Any, completion_type: str, provider: str + ) -> str: + if completion_type == "stt": + if isinstance(query_input, str) and os.path.exists(query_input): + return query_input + else: + raise ValueError(f"{provider} STT requires a valid file path as input") + elif completion_type == "tts": + if isinstance(query_input, str): + return query_input + else: + raise ValueError(f"{provider} TTS requires a text string as input") + raise ValueError( + f"Unsupported completion type '{completion_type}' for {provider}" + ) + + def _execute_stt( + self, + completion_config: NativeCompletionConfig, + resolved_input: str, + include_provider_raw_response: bool = False, + ) -> tuple[LLMCallResponse | None, str | None]: + """Execute speech-to-text completion using SarvamAI. 
+ + Args: + completion_config: Configuration for the completion request (with already-mapped params) + resolved_input: File path to the audio input + include_provider_raw_response: Whether to include raw provider response + + Returns: + Tuple of (response, error_message) + """ + provider_name = completion_config.provider + params = completion_config.params + + # Extract already-mapped parameters from the mapper + model = params.get("model") + if not model: + return None, "Missing 'model' in native params for SarvamAI STT" + + language_code = params.get("language_code") + mode = params.get("mode") + + # Parse and validate input + parsed_input_path = self._parse_input( + query_input=resolved_input, + completion_type="stt", + provider=provider_name, + ) + + try: + with open(parsed_input_path, "rb") as audio_file: + # Call SarvamAI transcribe with all mapped parameters + sarvam_response = self.client.speech_to_text.transcribe( + file=audio_file, + model=model, + language_code=language_code, + mode=mode, + ) + + # Estimate token usage (not directly provided by SarvamAI STT) + input_tokens_estimate = 0 + output_tokens_estimate = len(sarvam_response.transcript.split()) + total_tokens_estimate = input_tokens_estimate + output_tokens_estimate + + llm_response = LLMCallResponse( + response=LLMResponse( + provider_response_id=sarvam_response.request_id or "unknown", + conversation_id=None, + provider=provider_name, + model=model, + output=TextOutput( + content=TextContent(value=sarvam_response.transcript) + ), + ), + usage=Usage( + input_tokens=input_tokens_estimate, + output_tokens=output_tokens_estimate, + total_tokens=total_tokens_estimate, + reasoning_tokens=None, + ), + ) + + if include_provider_raw_response: + llm_response.provider_raw_response = sarvam_response.model_dump() + + logger.info( + f"[_execute_stt] Successfully transcribed audio | " + f"request_id={sarvam_response.request_id}, model={model}, mode={mode}" + ) + return llm_response, None + + except Exception as e: + error_message = f"SarvamAI STT transcription failed: {str(e)}" + logger.error(f"[_execute_stt] {error_message}", exc_info=True) + return None, error_message + + def _execute_tts( + self, + completion_config: NativeCompletionConfig, + resolved_input: str, + include_provider_raw_response: bool = False, + ) -> tuple[LLMCallResponse | None, str | None]: + """Execute text-to-speech completion using SarvamAI. 
+ + Args: + completion_config: Configuration for the completion request (with already-mapped params) + resolved_input: Text string to convert to speech + include_provider_raw_response: Whether to include raw provider response + + Returns: + Tuple of (response, error_message) + """ + provider_name = completion_config.provider + params = completion_config.params + + # Extract already-mapped parameters from the mapper + model = params.get("model") + if not model: + return None, "Missing 'model' in native params for SarvamAI TTS" + + target_language_code = params.get("target_language_code") + if not target_language_code: + return ( + None, + "Missing 'target_language_code' in native params for SarvamAI TTS", + ) + + speaker = params.get("speaker") + output_audio_codec = params.get("output_audio_codec") + + # Parse and validate input + parsed_text = self._parse_input( + query_input=resolved_input, + completion_type="tts", + provider=provider_name, + ) + + try: + # Call SarvamAI TTS with all mapped parameters + sarvam_response = self.client.text_to_speech.convert( + text=parsed_text, + target_language_code=target_language_code, + model=model, + speaker=speaker, + output_audio_codec=output_audio_codec, + ) + + # SarvamAI returns a list of base64-encoded audio strings + # For single text input, take the first audio + if not sarvam_response.audios or len(sarvam_response.audios) == 0: + return None, "SarvamAI TTS returned no audio data" + + audio_base64 = sarvam_response.audios[0] + + # Estimate token usage (not directly provided by SarvamAI TTS) + input_tokens_estimate = len(parsed_text.split()) + output_tokens_estimate = 0 # Audio output, no tokens + total_tokens_estimate = input_tokens_estimate + + llm_response = LLMCallResponse( + response=LLMResponse( + provider_response_id=sarvam_response.request_id or "unknown", + conversation_id=None, + provider=provider_name, + model=model, + output=AudioOutput( + content=AudioContent( + format="base64", + value=audio_base64, + mime_type=f"audio/{output_audio_codec or 'wav'}", + ) + ), + ), + usage=Usage( + input_tokens=input_tokens_estimate, + output_tokens=output_tokens_estimate, + total_tokens=total_tokens_estimate, + reasoning_tokens=None, + ), + ) + + if include_provider_raw_response: + llm_response.provider_raw_response = sarvam_response.model_dump() + + logger.info( + f"[_execute_tts] Successfully converted text to speech | " + f"request_id={sarvam_response.request_id}, model={model}, speaker={speaker}" + ) + return llm_response, None + + except Exception as e: + error_message = f"SarvamAI TTS conversion failed: {str(e)}" + logger.error(f"[_execute_tts] {error_message}", exc_info=True) + return None, error_message + + def execute( + self, + completion_config: NativeCompletionConfig, + query: QueryParams, # noqa: ARG002 - Required by base class interface, unused for STT/TTS + resolved_input: str, + include_provider_raw_response: bool = False, + ) -> tuple[LLMCallResponse | None, str | None]: + try: + completion_type = completion_config.type + + if completion_type == "stt": + return self._execute_stt( + completion_config=completion_config, + resolved_input=resolved_input, + include_provider_raw_response=include_provider_raw_response, + ) + elif completion_type == "tts": + return self._execute_tts( + completion_config=completion_config, + resolved_input=resolved_input, + include_provider_raw_response=include_provider_raw_response, + ) + else: + return ( + None, + f"Unsupported completion type '{completion_type}' for SarvamAIProvider", + ) + + except 
ValueError as e:
+            error_message = f"Input validation error: {str(e)}"
+            logger.error(f"[SarvamAIProvider.execute] {error_message}", exc_info=True)
+            return None, error_message
+        except Exception as e:
+            error_message = "Unexpected error occurred during SarvamAI execution"
+            logger.error(
+                f"[SarvamAIProvider.execute] {error_message}: {str(e)}", exc_info=True
+            )
+            return None, error_message
diff --git a/backend/app/tests/services/llm/providers/test_sai.py b/backend/app/tests/services/llm/providers/test_sai.py
new file mode 100644
index 000000000..474e4e09c
--- /dev/null
+++ b/backend/app/tests/services/llm/providers/test_sai.py
@@ -0,0 +1,530 @@
+"""
+Tests for the SarvamAI provider (STT and TTS).
+"""
+
+import base64
+import pytest
+from unittest.mock import MagicMock, patch
+from types import SimpleNamespace
+
+from app.models.llm import (
+    NativeCompletionConfig,
+    QueryParams,
+)
+from app.services.llm.providers.sai import SarvamAIProvider
+
+
+def mock_sarvam_stt_response(
+    transcript: str = "नमस्ते",
+    request_id: str = "req_stt_123",
+) -> SimpleNamespace:
+    """Create a mock SarvamAI STT response object."""
+    response = SimpleNamespace(
+        transcript=transcript,
+        request_id=request_id,
+        model_dump=lambda: {
+            "transcript": transcript,
+            "request_id": request_id,
+        },
+    )
+    return response
+
+
+def mock_sarvam_tts_response(
+    audio_base64: str = "YXVkaW9kYXRh",
+    request_id: str = "req_tts_456",
+) -> SimpleNamespace:
+    """Create a mock SarvamAI TTS response object."""
+    response = SimpleNamespace(
+        audios=[audio_base64],
+        request_id=request_id,
+        model_dump=lambda: {
+            "audios": [audio_base64],
+            "request_id": request_id,
+        },
+    )
+    return response
+
+
+class TestSarvamAIProviderSTT:
+    """Test cases for SarvamAIProvider STT functionality."""
+
+    @pytest.fixture
+    def mock_client(self):
+        """Create a mock SarvamAI client."""
+        client = MagicMock()
+        client.speech_to_text = MagicMock()
+        return client
+
+    @pytest.fixture
+    def provider(self, mock_client):
+        """Create a SarvamAIProvider instance with mock client."""
+        return SarvamAIProvider(client=mock_client)
+
+    @pytest.fixture
+    def stt_config(self):
+        """Create a basic STT completion config."""
+        return NativeCompletionConfig(
+            provider="sarvamai-native",
+            type="stt",
+            params={
+                "model": "saarika:v1",
+                "language_code": "hi-IN",
+                "mode": "transcribe",
+            },
+        )
+
+    @pytest.fixture
+    def query_params(self):
+        """Create basic query parameters."""
+        return QueryParams(input="Test audio input")
+
+    @pytest.fixture
+    def temp_audio_file(self, tmp_path):
+        """Create a temporary audio file for testing."""
+        audio_file = tmp_path / "test_audio.wav"
+        audio_file.write_bytes(b"fake audio data")
+        return str(audio_file)
+
+    def test_stt_success_basic_transcription(
+        self, provider, mock_client, stt_config, query_params, temp_audio_file
+    ):
+        """Test successful STT transcription."""
+        mock_response = mock_sarvam_stt_response(transcript="नमस्ते दुनिया")
+        mock_client.speech_to_text.transcribe.return_value = mock_response
+
+        result, error = provider.execute(stt_config, query_params, temp_audio_file)
+
+        assert error is None
+        assert result is not None
+        assert result.response.output.content.value == "नमस्ते दुनिया"
+        assert result.response.model == "saarika:v1"
+        assert result.response.provider == "sarvamai-native"
+        assert result.response.provider_response_id == "req_stt_123"
+        assert result.usage.output_tokens == 2  # Number of words
+
+    def test_stt_success_with_translate_mode(
+        self, 
provider, mock_client, query_params, temp_audio_file + ): + """Test STT with translate mode.""" + config = NativeCompletionConfig( + provider="sarvamai-native", + type="stt", + params={ + "model": "saarika:v1", + "language_code": "hi-IN", + "mode": "translate", + }, + ) + mock_response = mock_sarvam_stt_response(transcript="Hello world") + mock_client.speech_to_text.transcribe.return_value = mock_response + + result, error = provider.execute(config, query_params, temp_audio_file) + + assert error is None + assert result is not None + assert result.response.output.content.value == "Hello world" + # Verify translate mode was passed to API + call_args = mock_client.speech_to_text.transcribe.call_args + assert call_args.kwargs["mode"] == "translate" + + def test_stt_success_with_unknown_language( + self, provider, mock_client, query_params, temp_audio_file + ): + """Test STT with unknown/auto language detection.""" + config = NativeCompletionConfig( + provider="sarvamai-native", + type="stt", + params={ + "model": "saarika:v1", + "language_code": "unknown", + "mode": "transcribe", + }, + ) + mock_response = mock_sarvam_stt_response(transcript="Detected text") + mock_client.speech_to_text.transcribe.return_value = mock_response + + result, error = provider.execute(config, query_params, temp_audio_file) + + assert error is None + assert result is not None + call_args = mock_client.speech_to_text.transcribe.call_args + assert call_args.kwargs["language_code"] == "unknown" + + def test_stt_missing_model_param( + self, provider, mock_client, query_params, temp_audio_file + ): + """Test STT with missing model parameter.""" + config = NativeCompletionConfig( + provider="sarvamai-native", + type="stt", + params={ + "language_code": "hi-IN", + "mode": "transcribe", + }, + ) + + result, error = provider.execute(config, query_params, temp_audio_file) + + assert result is None + assert error is not None + assert "model" in error.lower() + + def test_stt_invalid_file_path( + self, provider, mock_client, stt_config, query_params + ): + """Test STT with non-existent file path.""" + result, error = provider.execute( + stt_config, query_params, "/nonexistent/path/audio.wav" + ) + + assert result is None + assert error is not None + + def test_stt_api_exception( + self, provider, mock_client, stt_config, query_params, temp_audio_file + ): + """Test STT when API raises exception.""" + mock_client.speech_to_text.transcribe.side_effect = Exception( + "API connection failed" + ) + + result, error = provider.execute(stt_config, query_params, temp_audio_file) + + assert result is None + assert error is not None + assert "API connection failed" in error + + def test_stt_include_provider_raw_response( + self, provider, mock_client, stt_config, query_params, temp_audio_file + ): + """Test STT with include_provider_raw_response flag.""" + mock_response = mock_sarvam_stt_response(transcript="Test") + mock_client.speech_to_text.transcribe.return_value = mock_response + + result, error = provider.execute( + stt_config, + query_params, + temp_audio_file, + include_provider_raw_response=True, + ) + + assert error is None + assert result is not None + assert result.provider_raw_response is not None + assert result.provider_raw_response["transcript"] == "Test" + assert result.provider_raw_response["request_id"] == "req_stt_123" + + +class TestSarvamAIProviderTTS: + """Test cases for SarvamAIProvider TTS functionality.""" + + @pytest.fixture + def mock_client(self): + """Create a mock SarvamAI client.""" + client = MagicMock() + 
client.text_to_speech = MagicMock() + return client + + @pytest.fixture + def provider(self, mock_client): + """Create a SarvamAIProvider instance with mock client.""" + return SarvamAIProvider(client=mock_client) + + @pytest.fixture + def tts_config(self): + """Create a basic TTS completion config.""" + return NativeCompletionConfig( + provider="sarvamai-native", + type="tts", + params={ + "model": "bulbul:v1", + "target_language_code": "hi-IN", + "speaker": "meera", + "output_audio_codec": "wav", + }, + ) + + @pytest.fixture + def query_params(self): + """Create basic query parameters.""" + return QueryParams(input="Test text input") + + def test_tts_success_basic_conversion( + self, provider, mock_client, tts_config, query_params + ): + """Test successful TTS conversion.""" + audio_data = base64.b64encode(b"fake audio binary data").decode("utf-8") + mock_response = mock_sarvam_tts_response(audio_base64=audio_data) + mock_client.text_to_speech.convert.return_value = mock_response + + result, error = provider.execute(tts_config, query_params, "नमस्ते दुनिया") + + assert error is None + assert result is not None + assert result.response.output.content.value == audio_data + assert result.response.output.content.format == "base64" + assert result.response.output.content.mime_type == "audio/wav" + assert result.response.model == "bulbul:v1" + assert result.response.provider == "sarvamai-native" + + def test_tts_with_mp3_codec(self, provider, mock_client, query_params): + """Test TTS with MP3 codec.""" + config = NativeCompletionConfig( + provider="sarvamai-native", + type="tts", + params={ + "model": "bulbul:v1", + "target_language_code": "en-IN", + "speaker": "arvind", + "output_audio_codec": "mp3", + }, + ) + audio_data = base64.b64encode(b"mp3 audio data").decode("utf-8") + mock_response = mock_sarvam_tts_response(audio_base64=audio_data) + mock_client.text_to_speech.convert.return_value = mock_response + + result, error = provider.execute(config, query_params, "Hello world") + + assert error is None + assert result is not None + assert result.response.output.content.mime_type == "audio/mp3" + call_args = mock_client.text_to_speech.convert.call_args + assert call_args.kwargs["output_audio_codec"] == "mp3" + + def test_tts_with_ogg_codec(self, provider, mock_client, query_params): + """Test TTS with OGG codec.""" + config = NativeCompletionConfig( + provider="sarvamai-native", + type="tts", + params={ + "model": "bulbul:v1", + "target_language_code": "hi-IN", + "speaker": "meera", + "output_audio_codec": "ogg", + }, + ) + audio_data = base64.b64encode(b"ogg audio data").decode("utf-8") + mock_response = mock_sarvam_tts_response(audio_base64=audio_data) + mock_client.text_to_speech.convert.return_value = mock_response + + result, error = provider.execute(config, query_params, "Test text") + + assert error is None + assert result is not None + assert result.response.output.content.mime_type == "audio/ogg" + + def test_tts_missing_model_param(self, provider, mock_client, query_params): + """Test TTS with missing model parameter.""" + config = NativeCompletionConfig( + provider="sarvamai-native", + type="tts", + params={ + "target_language_code": "hi-IN", + "speaker": "meera", + }, + ) + + result, error = provider.execute(config, query_params, "Test text") + + assert result is None + assert error is not None + assert "model" in error.lower() + + def test_tts_missing_target_language_code( + self, provider, mock_client, query_params + ): + """Test TTS with missing target_language_code.""" + 
config = NativeCompletionConfig( + provider="sarvamai-native", + type="tts", + params={ + "model": "bulbul:v1", + "speaker": "meera", + }, + ) + + result, error = provider.execute(config, query_params, "Test text") + + assert result is None + assert error is not None + assert "target_language_code" in error.lower() + + def test_tts_empty_audio_response( + self, provider, mock_client, tts_config, query_params + ): + """Test TTS when API returns empty audio list.""" + mock_response = SimpleNamespace( + audios=[], + request_id="req_123", + model_dump=lambda: {"audios": [], "request_id": "req_123"}, + ) + mock_client.text_to_speech.convert.return_value = mock_response + + result, error = provider.execute(tts_config, query_params, "Test text") + + assert result is None + assert error is not None + assert "no audio data" in error.lower() + + def test_tts_api_exception(self, provider, mock_client, tts_config, query_params): + """Test TTS when API raises exception.""" + mock_client.text_to_speech.convert.side_effect = Exception( + "TTS service unavailable" + ) + + result, error = provider.execute(tts_config, query_params, "Test text") + + assert result is None + assert error is not None + assert "TTS service unavailable" in error + + def test_tts_include_provider_raw_response( + self, provider, mock_client, tts_config, query_params + ): + """Test TTS with include_provider_raw_response flag.""" + audio_data = base64.b64encode(b"audio data").decode("utf-8") + mock_response = mock_sarvam_tts_response(audio_base64=audio_data) + mock_client.text_to_speech.convert.return_value = mock_response + + result, error = provider.execute( + tts_config, + query_params, + "Test text", + include_provider_raw_response=True, + ) + + assert error is None + assert result is not None + assert result.provider_raw_response is not None + assert result.provider_raw_response["audios"] == [audio_data] + + def test_tts_usage_estimates(self, provider, mock_client, tts_config, query_params): + """Test that TTS properly estimates token usage based on input text.""" + audio_data = base64.b64encode(b"audio").decode("utf-8") + mock_response = mock_sarvam_tts_response(audio_base64=audio_data) + mock_client.text_to_speech.convert.return_value = mock_response + + # Test with multi-word input + result, error = provider.execute( + tts_config, query_params, "Hello world how are you" + ) + + assert error is None + assert result.usage.input_tokens == 5 # 5 words + assert result.usage.output_tokens == 0 # Audio has no output tokens + assert result.usage.total_tokens == 5 + + +class TestSarvamAIProviderClientCreation: + """Test cases for SarvamAIProvider client creation.""" + + def test_create_client_with_valid_api_key(self): + """Test client creation with valid API key.""" + credentials = {"api_key": "test_api_key_123"} + + with patch("app.services.llm.providers.sai.SarvamAI") as mock_sarvam_class: + client = SarvamAIProvider.create_client(credentials) + + mock_sarvam_class.assert_called_once_with( + api_subscription_key="test_api_key_123" + ) + + def test_create_client_missing_api_key(self): + """Test client creation with missing API key.""" + credentials = {} + + with pytest.raises(ValueError) as exc_info: + SarvamAIProvider.create_client(credentials) + + assert "API Key for SarvamAI Not Set" in str(exc_info.value) + + def test_create_client_empty_credentials(self): + """Test client creation with empty credentials dict.""" + credentials = {"other_key": "value"} + + with pytest.raises(ValueError) as exc_info: + 
SarvamAIProvider.create_client(credentials) + + assert "API Key for SarvamAI Not Set" in str(exc_info.value) + + +class TestSarvamAIProviderInputParsing: + """Test cases for SarvamAIProvider input parsing.""" + + @pytest.fixture + def provider(self): + """Create a SarvamAIProvider instance.""" + mock_client = MagicMock() + return SarvamAIProvider(client=mock_client) + + @pytest.fixture + def temp_audio_file(self, tmp_path): + """Create a temporary audio file.""" + audio_file = tmp_path / "test.wav" + audio_file.write_bytes(b"audio data") + return str(audio_file) + + def test_parse_input_stt_valid_file(self, provider, temp_audio_file): + """Test parsing valid file path for STT.""" + result = provider._parse_input(temp_audio_file, "stt", "sarvamai") + assert result == temp_audio_file + + def test_parse_input_stt_invalid_file(self, provider): + """Test parsing invalid file path for STT.""" + with pytest.raises(ValueError) as exc_info: + provider._parse_input("/nonexistent/file.wav", "stt", "sarvamai") + + assert "valid file path" in str(exc_info.value) + + def test_parse_input_tts_valid_text(self, provider): + """Test parsing valid text for TTS.""" + result = provider._parse_input("Hello world", "tts", "sarvamai") + assert result == "Hello world" + + def test_parse_input_tts_invalid_type(self, provider): + """Test parsing invalid type for TTS.""" + with pytest.raises(ValueError) as exc_info: + provider._parse_input(12345, "tts", "sarvamai") + + assert "text string" in str(exc_info.value) + + def test_parse_input_unsupported_completion_type(self, provider): + """Test parsing with unsupported completion type.""" + with pytest.raises(ValueError) as exc_info: + provider._parse_input("input", "unsupported", "sarvamai") + + assert "Unsupported completion type" in str(exc_info.value) + + +class TestSarvamAIProviderExecute: + """Test cases for SarvamAIProvider execute method.""" + + @pytest.fixture + def mock_client(self): + """Create a mock SarvamAI client.""" + return MagicMock() + + @pytest.fixture + def provider(self, mock_client): + """Create a SarvamAIProvider instance.""" + return SarvamAIProvider(client=mock_client) + + @pytest.fixture + def query_params(self): + """Create basic query parameters.""" + return QueryParams(input="Test input") + + def test_execute_unsupported_completion_type(self, provider, query_params): + """Test execute with unsupported completion type.""" + config = NativeCompletionConfig( + provider="sarvamai-native", + type="text", # Unsupported for SarvamAI + params={"model": "test-model"}, + ) + + result, error = provider.execute(config, query_params, "input") + + assert result is None + assert error is not None + assert "Unsupported completion type" in error diff --git a/backend/app/tests/services/llm/test_mappers.py b/backend/app/tests/services/llm/test_mappers.py index 7a70cf46c..67e60cf3c 100644 --- a/backend/app/tests/services/llm/test_mappers.py +++ b/backend/app/tests/services/llm/test_mappers.py @@ -16,6 +16,7 @@ from app.services.llm.mappers import ( map_kaapi_to_openai_params, map_kaapi_to_google_params, + map_kaapi_to_sarvam_params, transform_kaapi_config_to_native, ) @@ -31,7 +32,8 @@ def test_basic_model_mapping(self): kaapi_params.model_dump(exclude_none=True) ) - assert result == {"model": "gpt-4o"} + # TextLLMParams has default temperature=0.1 + assert result == {"model": "gpt-4o", "temperature": 0.1} assert warnings == [] def test_instructions_mapping(self): @@ -91,7 +93,10 @@ def test_reasoning_mapping_for_reasoning_models(self): assert result["model"] 
== "o1" assert result["reasoning"] == {"effort": "high"} - assert warnings == [] + # Temperature is suppressed for reasoning models (even default value) + assert "temperature" not in result + assert len(warnings) == 1 + assert "temperature" in warnings[0].lower() def test_knowledge_base_ids_mapping(self): """Test knowledge_base_ids mapping to OpenAI tools format.""" @@ -211,7 +216,8 @@ def test_minimal_params(self): kaapi_params.model_dump(exclude_none=True) ) - assert result == {"model": "gpt-4"} + # TextLLMParams has default temperature=0.1 + assert result == {"model": "gpt-4", "temperature": 0.1} assert warnings == [] def test_only_knowledge_base_ids(self): @@ -242,7 +248,8 @@ def test_basic_model_mapping(self): kaapi_params.model_dump(exclude_none=True) ) - assert result == {"model": "gemini-2.5-pro"} + # TextLLMParams has default temperature=0.1 + assert result == {"model": "gemini-2.5-pro", "temperature": 0.1} assert warnings == [] def test_instructions_mapping(self): @@ -324,6 +331,249 @@ def test_knowledge_base_ids_unsupported(self): assert "knowledge_base_ids" in warnings[0].lower() +class TestMapKaapiToSarvamParams: + """Test cases for map_kaapi_to_sarvam_params function.""" + + def test_stt_basic_mapping(self): + """Test basic STT parameter mapping.""" + kaapi_params = STTLLMParams( + model="saarika:v1", + input_language="hi-IN", + ) + + result, warnings = map_kaapi_to_sarvam_params( + kaapi_params.model_dump(exclude_none=True) + ) + + assert result["model"] == "saarika:v1" + assert result["language_code"] == "hi-IN" + assert result["mode"] == "transcribe" + assert warnings == [] + + def test_stt_auto_language_detection(self): + """Test STT with auto language detection.""" + kaapi_params = STTLLMParams( + model="saarika:v1", + input_language="auto", + ) + + result, warnings = map_kaapi_to_sarvam_params( + kaapi_params.model_dump(exclude_none=True) + ) + + assert result["model"] == "saarika:v1" + assert result["language_code"] == "unknown" + assert result["mode"] == "transcribe" + assert warnings == [] + + def test_stt_translate_mode(self): + """Test STT with translation to English.""" + kaapi_params = STTLLMParams( + model="saarika:v1", + input_language="hi-IN", + output_language="en-IN", + ) + + result, warnings = map_kaapi_to_sarvam_params( + kaapi_params.model_dump(exclude_none=True) + ) + + assert result["model"] == "saarika:v1" + assert result["language_code"] == "hi-IN" + assert result["mode"] == "translate" + assert warnings == [] + + def test_stt_same_input_output_language(self): + """Test STT when input and output languages are the same.""" + kaapi_params = STTLLMParams( + model="saarika:v1", + input_language="hi-IN", + output_language="hi-IN", + ) + + result, warnings = map_kaapi_to_sarvam_params( + kaapi_params.model_dump(exclude_none=True) + ) + + assert result["mode"] == "transcribe" + assert warnings == [] + + def test_stt_unsupported_instructions_warning(self): + """Test that instructions parameter generates warning for STT.""" + kaapi_params = STTLLMParams( + model="saarika:v1", + input_language="hi-IN", + instructions="Please transcribe accurately", + ) + + result, warnings = map_kaapi_to_sarvam_params( + kaapi_params.model_dump(exclude_none=True) + ) + + assert result["model"] == "saarika:v1" + assert "instructions" not in result + assert len(warnings) == 1 + assert "instructions" in warnings[0].lower() + assert "not supported" in warnings[0] + + def test_stt_unsupported_temperature_warning(self): + """Test that temperature parameter generates warning for 
STT.""" + kaapi_params = STTLLMParams( + model="saarika:v1", + input_language="hi-IN", + temperature=0.5, + ) + + result, warnings = map_kaapi_to_sarvam_params( + kaapi_params.model_dump(exclude_none=True) + ) + + assert result["model"] == "saarika:v1" + assert "temperature" not in result + assert len(warnings) == 1 + assert "temperature" in warnings[0].lower() + assert "not supported" in warnings[0] + + def test_stt_unsupported_response_format_warning(self): + """Test that response_format parameter generates warning for STT.""" + kaapi_params = STTLLMParams( + model="saarika:v1", + input_language="hi-IN", + response_format="text", + ) + + result, warnings = map_kaapi_to_sarvam_params( + kaapi_params.model_dump(exclude_none=True) + ) + + assert result["model"] == "saarika:v1" + assert "response_format" not in result + assert len(warnings) == 1 + assert "response_format" in warnings[0].lower() + + def test_stt_multiple_unsupported_params(self): + """Test STT with multiple unsupported parameters.""" + kaapi_params = STTLLMParams( + model="saarika:v1", + input_language="hi-IN", + instructions="Transcribe", + temperature=0.5, + response_format="text", + ) + + result, warnings = map_kaapi_to_sarvam_params( + kaapi_params.model_dump(exclude_none=True) + ) + + assert result["model"] == "saarika:v1" + assert "instructions" not in result + assert "temperature" not in result + assert "response_format" not in result + assert len(warnings) == 3 + + def test_tts_basic_mapping(self): + """Test basic TTS parameter mapping.""" + kaapi_params = TTSLLMParams( + model="bulbul:v1", + voice="meera", + language="hi-IN", + ) + + result, warnings = map_kaapi_to_sarvam_params( + kaapi_params.model_dump(exclude_none=True) + ) + + assert result["model"] == "bulbul:v1" + assert result["speaker"] == "meera" + assert result["target_language_code"] == "hi-IN" + assert warnings == [] + + def test_tts_with_audio_format(self): + """Test TTS with custom audio format.""" + kaapi_params = TTSLLMParams( + model="bulbul:v1", + voice="meera", + language="hi-IN", + response_format="mp3", + ) + + result, warnings = map_kaapi_to_sarvam_params( + kaapi_params.model_dump(exclude_none=True) + ) + + assert result["model"] == "bulbul:v1" + assert result["speaker"] == "meera" + assert result["target_language_code"] == "hi-IN" + assert result["output_audio_codec"] == "mp3" + assert warnings == [] + + def test_tts_default_wav_format(self): + """Test TTS with default WAV format.""" + kaapi_params = TTSLLMParams( + model="bulbul:v1", + voice="arvind", + language="en-IN", + response_format="wav", + ) + + result, warnings = map_kaapi_to_sarvam_params( + kaapi_params.model_dump(exclude_none=True) + ) + + assert result["output_audio_codec"] == "wav" + assert warnings == [] + + def test_tts_ogg_format(self): + """Test TTS with OGG format.""" + kaapi_params = TTSLLMParams( + model="bulbul:v1", + voice="meera", + language="hi-IN", + response_format="ogg", + ) + + result, warnings = map_kaapi_to_sarvam_params( + kaapi_params.model_dump(exclude_none=True) + ) + + assert result["output_audio_codec"] == "ogg" + assert warnings == [] + + def test_tts_missing_language(self): + """Test that missing language returns error for TTS.""" + kaapi_params = {"model": "bulbul:v1", "voice": "meera"} + + result, warnings = map_kaapi_to_sarvam_params(kaapi_params) + + assert result == {} + assert len(warnings) == 1 + assert "language" in warnings[0].lower() + + def test_missing_model(self): + """Test that missing model returns error.""" + kaapi_params = {"voice": 
"meera", "language": "hi-IN"} + + result, warnings = map_kaapi_to_sarvam_params(kaapi_params) + + assert result == {} + assert len(warnings) == 1 + assert "model" in warnings[0].lower() + + def test_stt_output_language_defaults_to_input(self): + """Test that output_language defaults to input_language when not provided.""" + kaapi_params = STTLLMParams( + model="saarika:v1", + input_language="hi-IN", + ) + + result, warnings = map_kaapi_to_sarvam_params( + kaapi_params.model_dump(exclude_none=True) + ) + + assert result["mode"] == "transcribe" + assert warnings == [] + + class TestTransformKaapiConfigToNative: """Test cases for transform_kaapi_config_to_native function.""" @@ -386,7 +636,10 @@ def test_transform_with_reasoning(self): assert result.provider == "openai-native" assert result.params["model"] == "o1" assert result.params["reasoning"] == {"effort": "medium"} - assert warnings == [] + # Temperature is suppressed for reasoning models (even default value) + assert "temperature" not in result.params + assert len(warnings) == 1 + assert "temperature" in warnings[0].lower() def test_transform_with_both_temperature_and_reasoning(self): """Test that transformation handles temperature + reasoning intelligently for reasoning models.""" @@ -486,3 +739,71 @@ def test_transform_google_with_unsupported_params(self): assert result.params["reasoning"] == "high" assert "knowledge_base_ids" not in result.params assert len(warnings) == 1 + + def test_transform_sarvamai_stt_config(self): + """Test transformation of Kaapi SarvamAI STT config to native format.""" + kaapi_config = KaapiCompletionConfig( + provider="sarvamai", + type="stt", + params={ + "model": "saarika:v1", + "input_language": "hi-IN", + }, + ) + + result, warnings = transform_kaapi_config_to_native(kaapi_config) + + assert isinstance(result, NativeCompletionConfig) + assert result.provider == "sarvamai-native" + assert result.type == "stt" + assert result.params["model"] == "saarika:v1" + assert result.params["language_code"] == "hi-IN" + assert result.params["mode"] == "transcribe" + assert warnings == [] + + def test_transform_sarvamai_tts_config(self): + """Test transformation of Kaapi SarvamAI TTS config to native format.""" + kaapi_config = KaapiCompletionConfig( + provider="sarvamai", + type="tts", + params={ + "model": "bulbul:v1", + "voice": "meera", + "language": "hi-IN", + "response_format": "mp3", + }, + ) + + result, warnings = transform_kaapi_config_to_native(kaapi_config) + + assert isinstance(result, NativeCompletionConfig) + assert result.provider == "sarvamai-native" + assert result.type == "tts" + assert result.params["model"] == "bulbul:v1" + assert result.params["speaker"] == "meera" + assert result.params["target_language_code"] == "hi-IN" + assert result.params["output_audio_codec"] == "mp3" + assert warnings == [] + + def test_transform_sarvamai_stt_with_unsupported_params(self): + """Test SarvamAI STT transformation with unsupported parameters.""" + kaapi_config = KaapiCompletionConfig( + provider="sarvamai", + type="stt", + params={ + "model": "saarika:v1", + "input_language": "hi-IN", + "instructions": "Transcribe carefully", + "temperature": 0.5, + }, + ) + + result, warnings = transform_kaapi_config_to_native(kaapi_config) + + assert result.provider == "sarvamai-native" + assert result.params["model"] == "saarika:v1" + assert "instructions" not in result.params + assert "temperature" not in result.params + assert len(warnings) == 2 + assert any("instructions" in w.lower() for w in warnings) + assert 
any("temperature" in w.lower() for w in warnings) diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 0493677b4..d83df66d5 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -37,6 +37,7 @@ dependencies = [ "redis>=5.0.0,<6.0.0", "flower>=2.0.1", "google-genai>=1.59.0", + "sarvamai>=0.1.25", "pydub>=0.25.1", ] diff --git a/backend/uv.lock b/backend/uv.lock index 4ad522c4b..e79415bc4 100644 --- a/backend/uv.lock +++ b/backend/uv.lock @@ -234,6 +234,7 @@ dependencies = [ { name = "pytest" }, { name = "python-multipart" }, { name = "redis" }, + { name = "sarvamai" }, { name = "scikit-learn" }, { name = "sentry-sdk", extra = ["fastapi"] }, { name = "sqlmodel" }, @@ -282,6 +283,7 @@ requires-dist = [ { name = "pytest", specifier = ">=7.4.4" }, { name = "python-multipart", specifier = ">=0.0.22,<1.0.0" }, { name = "redis", specifier = ">=5.0.0,<6.0.0" }, + { name = "sarvamai", specifier = ">=0.1.25" }, { name = "scikit-learn", specifier = ">=1.7.1" }, { name = "sentry-sdk", extras = ["fastapi"], specifier = ">=2.20.0" }, { name = "sqlmodel", specifier = ">=0.0.21,<1.0.0" }, @@ -3146,6 +3148,22 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/fc/51/727abb13f44c1fcf6d145979e1535a35794db0f6e450a0cb46aa24732fe2/s3transfer-0.16.0-py3-none-any.whl", hash = "sha256:18e25d66fed509e3868dc1572b3f427ff947dd2c56f844a5bf09481ad3f3b2fe", size = 86830, upload-time = "2025-12-01T02:30:57.729Z" }, ] +[[package]] +name = "sarvamai" +version = "0.1.25" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "httpx" }, + { name = "pydantic" }, + { name = "pydantic-core" }, + { name = "typing-extensions" }, + { name = "websockets" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/91/f7/f24106109458b01ae9317a885f991a7d91b3c09f3707365cccda5b6f860b/sarvamai-0.1.25.tar.gz", hash = "sha256:590c1b5d4337852529c26a3ecbb08acfd4692ce27089fd4ace3bc55b5f5b60f2", size = 107235, upload-time = "2026-02-10T13:52:25.647Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/54/78/f30a7cfab12fceeeaa7df0f40822c56e14544bb85a44c04f455aea792818/sarvamai-0.1.25-py3-none-any.whl", hash = "sha256:0daa7b8a48ad2696d323105e7f1fc06741068f4ad9c89688dfb81843b0892d17", size = 213774, upload-time = "2026-02-10T13:52:23.861Z" }, +] + [[package]] name = "scikit-learn" version = "1.8.0"