diff --git a/backend/app/api/docs/llm/llm_call.md b/backend/app/api/docs/llm/llm_call.md index fec4fbc49..8a594390c 100644 --- a/backend/app/api/docs/llm/llm_call.md +++ b/backend/app/api/docs/llm/llm_call.md @@ -6,7 +6,14 @@ for processing, and results are delivered via the callback URL when complete. ### Key Parameters **`query`** (required) - Query parameters for this LLM call: -- `input` (required, string, min 1 char): User question/prompt/query +- `input` (required): User input — accepts one of: + - A plain **string** e.g. `"input": "Hello"` (automatically normalized to a text input internally) + - A **structured input object** with `type` and `content` fields e.g. `"input": {"type": "text", "content": {"format": "text", "value": "Hello"}}` + - A **list of structured input objects** for multimodal inputs e.g. `"input": [{"type": "text", ...}, {"type": "image", ...}]` + - Supported input types: `text`, `audio`, `image`, `pdf` + - For `image` and `pdf` types, `content` accepts a single object or a list e.g. `"content": [{"format": "base64", "value": "..."}, ...]` + - Content `format` varies by type: `"text"` for text, `"base64"` for encoded data, `"url"` for image/pdf URLs + - Default MIME types when not specified: `image/png` for images, `application/pdf` for PDFs - `conversation` (optional, object): Conversation configuration - `id` (optional, string): Existing conversation ID to continue - `auto_create` (optional, boolean, default false): Create new conversation if no ID provided @@ -23,8 +30,9 @@ for processing, and results are delivered via the callback URL when complete. 
- **Mode 2: Ad-hoc Configuration** - `blob` (object): Complete configuration object - `completion` (required, object): Completion configuration - - `provider` (required, string): Provider type - either `"openai"` (Kaapi abstraction) or `"openai-native"` (pass-through) - - `params` (required, object): Parameters structure depends on provider type (see schema for detailed structure) + - `provider` (required, string): Provider type — `"openai"` or `"google"` (Kaapi abstraction), or `"openai-native"` or `"google-native"` (pass-through) + - `type` (required, string): Completion type — `"text"`, `"stt"`, `"tts"` for Kaapi providers; additionally `"image"`, `"pdf"`, `"multimodal"` for native providers + - `params` (required, object): Parameters structure depends on provider and type (see schema for detailed structure) - **Note** - When using ad-hoc configuration, do not include `id` and `version` fields - When using the Kaapi abstraction, parameters that are not supported by the selected provider or model are automatically suppressed. If any parameters are ignored, a list of warnings is included in the metadata.warnings. For example, the GPT-5 model does not support the temperature parameter, so Kaapi will neither throw an error nor pass this parameter to the model; instead, it will return a warning in the metadata.warnings response. 
diff --git a/backend/app/crud/llm.py b/backend/app/crud/llm.py index c1e01e7e7..360bab4f2 100644 --- a/backend/app/crud/llm.py +++ b/backend/app/crud/llm.py @@ -11,6 +11,8 @@ TextInput, AudioInput, QueryInput, + ImageInput, + PDFInput, ) logger = logging.getLogger(__name__) @@ -73,15 +75,26 @@ def create_llm_call( else getattr(completion_config.params, "type", "text") ) - input_type: Literal["text", "audio", "image"] + input_type: Literal["text", "audio", "image", "pdf", "multimodal"] output_type: Literal["text", "audio", "image"] | None + query_input = request.query.input + if completion_type == "stt": input_type = "audio" output_type = "text" elif completion_type == "tts": input_type = "text" output_type = "audio" + elif isinstance(query_input, ImageInput): + input_type = "image" + output_type = "text" + elif isinstance(query_input, PDFInput): + input_type = "pdf" + output_type = "text" + elif isinstance(query_input, list): + input_type = "multimodal" + output_type = "text" else: input_type = "text" output_type = "text" diff --git a/backend/app/models/llm/__init__.py b/backend/app/models/llm/__init__.py index b183543c4..67b288f39 100644 --- a/backend/app/models/llm/__init__.py +++ b/backend/app/models/llm/__init__.py @@ -9,6 +9,10 @@ LlmCall, AudioContent, TextContent, + ImageContent, + PDFContent, + ImageInput, + PDFInput, ) from app.models.llm.response import ( LLMCallResponse, diff --git a/backend/app/models/llm/request.py b/backend/app/models/llm/request.py index 0991aeba8..57ccf2740 100644 --- a/backend/app/models/llm/request.py +++ b/backend/app/models/llm/request.py @@ -1,5 +1,5 @@ import sqlalchemy as sa -from typing import Annotated, Any, Literal, Union +from typing import Annotated, Any, List, Literal, Union from uuid import UUID, uuid4 from pydantic import model_validator, HttpUrl from datetime import datetime @@ -56,7 +56,11 @@ class TTSLLMParams(SQLModel): response_format: Literal["mp3", "wav", "ogg"] | None = "wav" -KaapiLLMParams = 
Union[TextLLMParams, STTLLMParams, TTSLLMParams] +KaapiLLMParams = Union[ + TextLLMParams, + STTLLMParams, + TTSLLMParams, +] # Input type models for discriminated union @@ -75,6 +79,28 @@ class AudioContent(SQLModel): ) +class ImageContent(SQLModel): + format: Literal["base64", "url"] = "base64" + value: str = Field( + ..., description="Base64 encoded image or Public URL to the image" + ) + # keeping the mime_type + mime_type: str | None = Field( + None, + description="MIME type of the image (e.g., image/png, image/jpeg)", + ) + + +class PDFContent(SQLModel): + format: Literal["base64", "url"] = "base64" + value: str = Field(..., description="Base64 encoded PDF or Public URL to the PDF") + # keeping the mime_type + mime_type: str | None = Field( + None, + description="MIME type of the PDF (e.g., application/pdf)", + ) + + class TextInput(SQLModel): type: Literal["text"] = "text" content: TextContent @@ -85,9 +111,19 @@ class AudioInput(SQLModel): content: AudioContent +class ImageInput(SQLModel): + type: Literal["image"] = "image" + content: ImageContent | list[ImageContent] + + +class PDFInput(SQLModel): + type: Literal["pdf"] = "pdf" + content: PDFContent | list[PDFContent] + + # Discriminated union for query input types QueryInput = Annotated[ - Union[TextInput, AudioInput], + Union[TextInput, AudioInput, ImageInput, PDFInput], Field(discriminator="type"), ] @@ -122,7 +158,7 @@ def validate_conversation_logic(self): class QueryParams(SQLModel): """Query-specific parameters for each LLM call.""" - input: str | QueryInput = Field( + input: str | QueryInput | list[QueryInput] = Field( ..., description=( "User input - either a plain string (text) or a structured input object. " @@ -389,12 +425,13 @@ class LlmCall(SQLModel, table=True): }, ) - input_type: Literal["text", "audio", "image"] = Field( + # NOTE: image, pdf, multimodal are internal labels stored in the table, not user-facing.
+ input_type: Literal["text", "audio", "image", "pdf", "multimodal"] = Field( ..., sa_column=sa.Column( sa.String, nullable=False, - comment="Input type: text, audio, image", + comment="Input type: text, audio, image, pdf, multimodal", ), ) diff --git a/backend/app/services/llm/jobs.py b/backend/app/services/llm/jobs.py index 33aff370a..5cdc0d32b 100644 --- a/backend/app/services/llm/jobs.py +++ b/backend/app/services/llm/jobs.py @@ -20,6 +20,8 @@ KaapiCompletionConfig, TextInput, AudioInput, + ImageInput, + PDFInput, ) from app.models.llm.response import TextOutput from app.services.llm.guardrails import ( @@ -102,13 +104,16 @@ def handle_job_error( @contextmanager -def resolved_input_context(query_input: TextInput | AudioInput): +def resolved_input_context( + query_input: TextInput | AudioInput | ImageInput | PDFInput | list, +): """Context manager for resolving and cleaning up input resources. Ensures temporary files (e.g., downloaded audio) are cleaned up even if errors occur during LLM execution. """ resolved_input, error = resolve_input(query_input) + if error: raise ValueError(error) diff --git a/backend/app/services/llm/mappers.py b/backend/app/services/llm/mappers.py index 8b0b895e3..d4efc2e9f 100644 --- a/backend/app/services/llm/mappers.py +++ b/backend/app/services/llm/mappers.py @@ -127,17 +127,18 @@ def map_kaapi_to_google_params(kaapi_params: dict) -> tuple[dict, list[str]]: response_format = kaapi_params.get("response_format") if response_format: google_params["response_format"] = response_format + + reasoning = kaapi_params.get("reasoning") + if reasoning: + google_params["reasoning"] = reasoning + # Warn about unsupported parameters if kaapi_params.get("knowledge_base_ids"): + # TODO: Will take up later, when we add google filesearch tool support warnings.append( "Parameter 'knowledge_base_ids' is not supported by Google AI and was ignored." 
) - if kaapi_params.get("reasoning") is not None: - warnings.append( - "Parameter 'reasoning' is not applicable for Google AI and was ignored." - ) - return google_params, warnings diff --git a/backend/app/services/llm/providers/base.py b/backend/app/services/llm/providers/base.py index d8f7cafe7..f159f0f1c 100644 --- a/backend/app/services/llm/providers/base.py +++ b/backend/app/services/llm/providers/base.py @@ -7,7 +7,25 @@ from abc import ABC, abstractmethod from typing import Any +from pydantic import model_validator +from sqlmodel import SQLModel + from app.models.llm import NativeCompletionConfig, LLMCallResponse, QueryParams +from app.models.llm.request import TextContent, ImageContent, PDFContent + +ContentPart = TextContent | ImageContent | PDFContent + + +class MultiModalInput(SQLModel): + """Resolved multimodal input containing a list of content parts.""" + + parts: list[ContentPart] + + @model_validator(mode="after") + def validate_parts(self): + if not self.parts: + raise ValueError("MultiModalInput requires at least one content part") + return self class BaseProvider(ABC): @@ -44,7 +62,7 @@ def execute( self, completion_config: NativeCompletionConfig, query: QueryParams, - resolved_input: str, + resolved_input: str | list[ContentPart], include_provider_raw_response: bool = False, ) -> tuple[LLMCallResponse | None, str | None]: """Execute LLM API call. 
diff --git a/backend/app/services/llm/providers/gai.py b/backend/app/services/llm/providers/gai.py index ce9bf6ad4..05fa46fc1 100644 --- a/backend/app/services/llm/providers/gai.py +++ b/backend/app/services/llm/providers/gai.py @@ -20,9 +20,11 @@ Usage, TextOutput, TextContent, + ImageContent, + PDFContent, ) from app.models.llm.response import AudioOutput, AudioContent -from app.services.llm.providers.base import BaseProvider +from app.services.llm.providers.base import BaseProvider, ContentPart, MultiModalInput from app.core.audio_utils import convert_pcm_to_mp3, convert_pcm_to_ogg logger = logging.getLogger(__name__) @@ -44,6 +46,57 @@ def create_client(credentials: dict[str, Any]) -> Any: raise ValueError("API Key for Google Gemini Not Set") return genai.Client(api_key=credentials["api_key"]) + @staticmethod + def format_parts( + parts: list[ContentPart], + ) -> list[dict]: + items = [] + for part in parts: + if isinstance(part, TextContent): + items.append({"text": part.value}) + + elif isinstance(part, ImageContent): + if part.format == "base64": + items.append( + { + "inline_data": { + "data": part.value, + "mime_type": part.mime_type, + } + } + ) + else: + items.append( + { + "file_data": { + "file_uri": part.value, + "mime_type": part.mime_type, + "display_name": None, + } + } + ) + elif isinstance(part, PDFContent): + if part.format == "base64": + items.append( + { + "inline_data": { + "data": part.value, + "mime_type": part.mime_type, + } + } + ) + else: + items.append( + { + "file_data": { + "file_uri": part.value, + "mime_type": part.mime_type, + "display_name": None, + } + } + ) + return items + def _execute_stt( self, completion_config: NativeCompletionConfig, @@ -324,16 +377,92 @@ def _execute_tts( return llm_response, None + def _execute_text( + self, + completion_config: NativeCompletionConfig, + resolved_input: str | list[ContentPart] | MultiModalInput, + include_provider_raw_response: bool = False, + ) -> tuple[LLMCallResponse | None, str | 
None]: + model = completion_config.params.get("model") + if not model: + return None, "Missing 'model' in native params" + + if isinstance(resolved_input, MultiModalInput): + gemini_parts = self.format_parts(resolved_input.parts) + contents = [{"role": "user", "parts": gemini_parts}] + elif isinstance(resolved_input, list): + gemini_parts = self.format_parts(resolved_input) + contents = [{"role": "user", "parts": gemini_parts}] + else: + contents = [{"role": "user", "parts": [{"text": resolved_input}]}] + + instructions = completion_config.params.get("instructions", "") + temperature = completion_config.params.get("temperature", None) + thinking_level = completion_config.params.get("reasoning", None) + + generation_kwargs = {} + if instructions: + generation_kwargs["system_instruction"] = instructions + + if temperature is not None: + generation_kwargs["temperature"] = temperature + + if thinking_level is not None: + generation_kwargs["thinking_config"] = ThinkingConfig( + include_thoughts=False, thinking_level=thinking_level + ) + + response = self.client.models.generate_content( + model=model, + contents=contents, + config=GenerateContentConfig(**generation_kwargs), + ) + + if response.usage_metadata: + input_tokens = response.usage_metadata.prompt_token_count or 0 + output_tokens = response.usage_metadata.candidates_token_count or 0 + total_tokens = response.usage_metadata.total_token_count or 0 + reasoning_tokens = response.usage_metadata.thoughts_token_count or 0 + else: + logger.warning( + "[GoogleAIProvider._execute_text] Response missing usage_metadata, using zeros" + ) + input_tokens = 0 + output_tokens = 0 + total_tokens = 0 + reasoning_tokens = 0 + + llm_response = LLMCallResponse( + response=LLMResponse( + provider_response_id=response.response_id, + model=response.model_version or model, + provider=completion_config.provider, + output=TextOutput(content=TextContent(value=response.text)), + ), + usage=Usage( + input_tokens=input_tokens, +
output_tokens=output_tokens, + total_tokens=total_tokens, + reasoning_tokens=reasoning_tokens, + ), + ) + if include_provider_raw_response: + llm_response.provider_raw_response = response.model_dump(mode="json") + + logger.info( + f"[GoogleAIProvider._execute_text] Successfully generated text response: {response.response_id}" + ) + return llm_response, None + def execute( self, completion_config: NativeCompletionConfig, - query: QueryParams, # Not used by Google AI provider (no conversation support yet) - resolved_input: str, + query: QueryParams, + resolved_input: str | list[ContentPart] | MultiModalInput, include_provider_raw_response: bool = False, ) -> tuple[LLMCallResponse | None, str | None]: try: completion_type = completion_config.type - if completion_type == "stt": return self._execute_stt( completion_config=completion_config, @@ -346,10 +475,17 @@ def execute( resolved_input=resolved_input, include_provider_raw_response=include_provider_raw_response, ) - else: - return ( - None, - f"Unsupported completion type '{completion_type}' for Google AI provider", + + elif completion_type == "text": + return self._execute_text( + completion_config=completion_config, + resolved_input=resolved_input, + include_provider_raw_response=include_provider_raw_response, ) + else: + return ( + None, + f"Unsupported completion type '{completion_type}' for Google AI provider", + ) except TypeError as e: diff --git a/backend/app/services/llm/providers/oai.py index 83c0aa8d7..392487eea 100644 --- a/backend/app/services/llm/providers/oai.py +++ b/backend/app/services/llm/providers/oai.py @@ -13,9 +13,10 @@ Usage, TextOutput, TextContent, + ImageContent, + PDFContent, ) -from app.services.llm.providers.base import BaseProvider - +from app.services.llm.providers.base import BaseProvider, ContentPart, MultiModalInput logger = logging.getLogger(__name__) @@ -36,11 +37,36 @@ def create_client(credentials: dict[str, Any]) -> Any: raise ValueError("OpenAI credentials not configured for this project.") return OpenAI(api_key=credentials["api_key"]) +
@staticmethod + def format_parts( + parts: list[ContentPart], + ) -> list[dict]: + items = [] + for part in parts: + if isinstance(part, TextContent): + items.append({"type": "input_text", "text": part.value}) + + elif isinstance(part, ImageContent): + if part.format == "base64": + url = f"data:{part.mime_type};base64,{part.value}" + else: + url = part.value + items.append({"type": "input_image", "image_url": url}) + + elif isinstance(part, PDFContent): + if part.format == "base64": + url = f"data:{part.mime_type};base64,{part.value}" + else: + url = part.value + items.append({"type": "input_file", "file_url": url}) + + return items + def execute( self, completion_config: NativeCompletionConfig, query: QueryParams, - resolved_input: str, + resolved_input: str | list[ImageContent] | list[PDFContent] | MultiModalInput, include_provider_raw_response: bool = False, ) -> tuple[LLMCallResponse | None, str | None]: response: Response | None = None @@ -50,7 +76,16 @@ def execute( params = { **completion_config.params, } - params["input"] = resolved_input + if isinstance(resolved_input, MultiModalInput): + params["input"] = [ + {"role": "user", "content": self.format_parts(resolved_input.parts)} + ] + elif isinstance(resolved_input, list): + params["input"] = [ + {"role": "user", "content": self.format_parts(resolved_input)} + ] + else: + params["input"] = resolved_input conversation_cfg = query.conversation diff --git a/backend/app/services/llm/providers/registry.py b/backend/app/services/llm/providers/registry.py index 15236b8d7..5eff4db19 100644 --- a/backend/app/services/llm/providers/registry.py +++ b/backend/app/services/llm/providers/registry.py @@ -3,7 +3,6 @@ import logging from sqlmodel import Session -from app.crud import get_provider_credential from app.services.llm.providers.base import BaseProvider from app.services.llm.providers.oai import OpenAIProvider from app.services.llm.providers.gai import GoogleAIProvider @@ -46,6 +45,8 @@ def supported_providers(cls) 
-> list[str]: def get_llm_provider( session: Session, provider_type: str, project_id: int, organization_id: int ) -> BaseProvider: + from app.crud.credentials import get_provider_credential + provider_class = LLMProvider.get_provider_class(provider_type) # e.g., "openai-native" → "openai", "claude-native" → "claude" diff --git a/backend/app/tests/services/llm/providers/test_registry.py b/backend/app/tests/services/llm/providers/test_registry.py index b3daa44c4..4349da107 100644 --- a/backend/app/tests/services/llm/providers/test_registry.py +++ b/backend/app/tests/services/llm/providers/test_registry.py @@ -40,9 +40,7 @@ def test_get_llm_provider_with_openai(self, db: Session): """Test getting OpenAI provider successfully.""" project = get_project(db) - with patch( - "app.services.llm.providers.registry.get_provider_credential" - ) as mock_get_creds: + with patch("app.crud.credentials.get_provider_credential") as mock_get_creds: mock_get_creds.return_value = {"api_key": "test-api-key"} provider = get_llm_provider( @@ -94,9 +92,7 @@ def test_get_llm_provider_with_missing_credentials(self, db: Session): """Test handling of errors when credentials are not found.""" project = get_project(db) - with patch( - "app.services.llm.providers.registry.get_provider_credential" - ) as mock_get_creds: + with patch("app.crud.credentials.get_provider_credential") as mock_get_creds: mock_get_creds.return_value = None with pytest.raises(ValueError) as exc_info: diff --git a/backend/app/tests/services/llm/test_mappers.py b/backend/app/tests/services/llm/test_mappers.py index 2ecbcd7b2..7a70cf46c 100644 --- a/backend/app/tests/services/llm/test_mappers.py +++ b/backend/app/tests/services/llm/test_mappers.py @@ -292,8 +292,7 @@ def test_knowledge_base_ids_warning(self): assert "knowledge_base_ids" in warnings[0].lower() assert "not supported" in warnings[0] - def test_reasoning_warning(self): - """Test that reasoning parameter is not supported and generates warning.""" + def 
test_reasoning_passed_through(self): kaapi_params = TextLLMParams( model="gemini-2.5-pro", reasoning="high", @@ -304,13 +303,10 @@ def test_reasoning_warning(self): ) assert result["model"] == "gemini-2.5-pro" - assert "reasoning" not in result - assert len(warnings) == 1 - assert "reasoning" in warnings[0].lower() - assert "not applicable" in warnings[0] + assert result["reasoning"] == "high" + assert len(warnings) == 0 - def test_multiple_unsupported_params(self): - """Test that multiple unsupported parameters generate multiple warnings.""" + def test_knowledge_base_ids_unsupported(self): kaapi_params = TextLLMParams( model="gemini-2.5-pro", reasoning="medium", @@ -322,13 +318,10 @@ def test_multiple_unsupported_params(self): ) assert result["model"] == "gemini-2.5-pro" - assert "reasoning" not in result + assert result["reasoning"] == "medium" assert "knowledge_base_ids" not in result - assert len(warnings) == 2 - # Check both warnings are present - warning_text = " ".join(warnings).lower() - assert "reasoning" in warning_text - assert "knowledge_base_ids" in warning_text + assert len(warnings) == 1 + assert "knowledge_base_ids" in warnings[0].lower() class TestTransformKaapiConfigToNative: @@ -476,7 +469,6 @@ def test_transform_google_config(self): assert warnings == [] def test_transform_google_with_unsupported_params(self): - """Test that Google transformation warns about unsupported parameters.""" kaapi_config = KaapiCompletionConfig( provider="google", type="text", @@ -491,6 +483,6 @@ def test_transform_google_with_unsupported_params(self): assert result.provider == "google-native" assert result.params["model"] == "gemini-2.5-pro" + assert result.params["reasoning"] == "high" assert "knowledge_base_ids" not in result.params - assert "reasoning" not in result.params - assert len(warnings) == 2 + assert len(warnings) == 1 diff --git a/backend/app/tests/services/llm/test_multimodal.py b/backend/app/tests/services/llm/test_multimodal.py new file mode 100644 
index 000000000..bae09308a --- /dev/null +++ b/backend/app/tests/services/llm/test_multimodal.py @@ -0,0 +1,534 @@ +import pytest +from unittest.mock import MagicMock + +from app.models.llm.request import ( + TextInput, + AudioInput, + ImageInput, + PDFInput, + TextContent, + AudioContent, + ImageContent, + PDFContent, + NativeCompletionConfig, + QueryParams, +) +from app.services.llm.providers.base import ( + ContentPart, + MultiModalInput, +) +from app.services.llm.providers.oai import OpenAIProvider +from app.services.llm.providers.gai import GoogleAIProvider +from app.utils import ( + resolve_input, + resolve_image_content, + resolve_pdf_content, +) + + +class TestMultiModalInput: + def test_valid_parts(self): + mm = MultiModalInput( + parts=[ + TextContent(value="hello"), + ImageContent(format="base64", value="abc", mime_type="image/png"), + PDFContent(format="base64", value="abc", mime_type="application/pdf"), + ] + ) + assert len(mm.parts) == 3 + + def test_empty_parts_raises(self): + with pytest.raises(Exception): + MultiModalInput(parts=[]) + + def test_single_text_part(self): + mm = MultiModalInput(parts=[TextContent(value="only text")]) + assert len(mm.parts) == 1 + + +class TestResolveInputMultimodal: + def test_image_input_returns_image_content_list(self): + img = ImageInput( + content=ImageContent(format="base64", value="abc", mime_type="image/png") + ) + result, error = resolve_input(img) + assert error is None + assert isinstance(result, list) + assert len(result) == 1 + assert isinstance(result[0], ImageContent) + + def test_pdf_input_returns_pdf_content_list(self): + pdf = PDFInput( + content=PDFContent( + format="base64", value="abc", mime_type="application/pdf" + ) + ) + result, error = resolve_input(pdf) + assert error is None + assert isinstance(result, list) + assert len(result) == 1 + assert isinstance(result[0], PDFContent) + + def test_multimodal_list_returns_multimodal_input(self): + inputs = [ + 
TextInput(content=TextContent(value="describe")), + ImageInput( + content=ImageContent( + format="base64", value="abc", mime_type="image/png" + ) + ), + ] + result, error = resolve_input(inputs) + assert error is None + assert isinstance(result, MultiModalInput) + assert len(result.parts) == 2 + + def test_multimodal_list_with_pdf(self): + inputs = [ + TextInput(content=TextContent(value="analyze")), + PDFInput( + content=PDFContent( + format="base64", value="abc", mime_type="application/pdf" + ) + ), + ] + result, error = resolve_input(inputs) + assert error is None + assert isinstance(result, MultiModalInput) + assert len(result.parts) == 2 + + def test_multimodal_list_with_audio_rejected(self): + inputs = [ + TextInput(content=TextContent(value="hello")), + AudioInput(content=AudioContent(value="abc", mime_type="audio/wav")), + ] + result, error = resolve_input(inputs) + assert error is not None + assert "audio" in error.lower() + assert "stt" in error.lower() + + def test_image_input_default_mime_type(self): + img = ImageInput(content=ImageContent(format="base64", value="abc")) + result, error = resolve_input(img) + assert error is None + assert result[0].mime_type == "image/png" + + def test_pdf_input_default_mime_type(self): + pdf = PDFInput(content=PDFContent(format="base64", value="abc")) + result, error = resolve_input(pdf) + assert error is None + assert result[0].mime_type == "application/pdf" + + def test_image_input_multiple_contents(self): + img = ImageInput( + content=[ + ImageContent(format="base64", value="abc1", mime_type="image/png"), + ImageContent( + format="url", + value="https://example.com/img.jpg", + mime_type="image/jpeg", + ), + ] + ) + result, error = resolve_input(img) + assert error is None + assert len(result) == 2 + + def test_multimodal_mixed_types_in_parts(self): + inputs = [ + TextInput(content=TextContent(value="look at these")), + ImageInput( + content=ImageContent( + format="base64", value="img", mime_type="image/png" + ) + ), 
+ PDFInput( + content=PDFContent( + format="base64", value="pdf", mime_type="application/pdf" + ) + ), + ] + result, error = resolve_input(inputs) + assert error is None + assert isinstance(result, MultiModalInput) + assert len(result.parts) == 3 + assert isinstance(result.parts[0], TextContent) + assert isinstance(result.parts[1], ImageContent) + assert isinstance(result.parts[2], PDFContent) + + +class TestOpenAIFormatParts: + def test_text_part(self): + parts = [TextContent(value="hello")] + result = OpenAIProvider.format_parts(parts) + assert result == [{"type": "input_text", "text": "hello"}] + + def test_image_base64_part(self): + parts = [ImageContent(format="base64", value="abc123", mime_type="image/png")] + result = OpenAIProvider.format_parts(parts) + assert len(result) == 1 + assert result[0]["type"] == "input_image" + assert result[0]["image_url"] == "data:image/png;base64,abc123" + + def test_image_url_part(self): + parts = [ + ImageContent( + format="url", + value="https://example.com/img.jpg", + mime_type="image/jpeg", + ) + ] + result = OpenAIProvider.format_parts(parts) + assert result[0]["type"] == "input_image" + assert result[0]["image_url"] == "https://example.com/img.jpg" + + def test_pdf_base64_part(self): + parts = [ + PDFContent(format="base64", value="pdf123", mime_type="application/pdf") + ] + result = OpenAIProvider.format_parts(parts) + assert len(result) == 1 + assert result[0]["type"] == "input_file" + assert result[0]["file_url"] == "data:application/pdf;base64,pdf123" + + def test_pdf_url_part(self): + parts = [ + PDFContent( + format="url", + value="https://example.com/doc.pdf", + mime_type="application/pdf", + ) + ] + result = OpenAIProvider.format_parts(parts) + assert result[0]["type"] == "input_file" + assert result[0]["file_url"] == "https://example.com/doc.pdf" + + def test_mixed_parts(self): + parts = [ + TextContent(value="describe"), + ImageContent(format="base64", value="img", mime_type="image/png"), + PDFContent( + 
format="url", + value="https://example.com/doc.pdf", + mime_type="application/pdf", + ), + ] + result = OpenAIProvider.format_parts(parts) + assert len(result) == 3 + assert result[0]["type"] == "input_text" + assert result[1]["type"] == "input_image" + assert result[2]["type"] == "input_file" + + +class TestGoogleAIFormatParts: + def test_text_part(self): + parts = [TextContent(value="hello")] + result = GoogleAIProvider.format_parts(parts) + assert result == [{"text": "hello"}] + + def test_image_base64_part(self): + parts = [ImageContent(format="base64", value="abc123", mime_type="image/png")] + result = GoogleAIProvider.format_parts(parts) + assert len(result) == 1 + assert result[0] == { + "inline_data": {"data": "abc123", "mime_type": "image/png"} + } + + def test_image_url_part(self): + parts = [ + ImageContent( + format="url", + value="https://example.com/img.jpg", + mime_type="image/jpeg", + ) + ] + result = GoogleAIProvider.format_parts(parts) + assert result[0] == { + "file_data": { + "file_uri": "https://example.com/img.jpg", + "mime_type": "image/jpeg", + "display_name": None, + } + } + + def test_pdf_base64_part(self): + parts = [ + PDFContent(format="base64", value="pdf123", mime_type="application/pdf") + ] + result = GoogleAIProvider.format_parts(parts) + assert result[0] == { + "inline_data": {"data": "pdf123", "mime_type": "application/pdf"} + } + + def test_pdf_url_part(self): + parts = [ + PDFContent( + format="url", + value="https://example.com/doc.pdf", + mime_type="application/pdf", + ) + ] + result = GoogleAIProvider.format_parts(parts) + assert result[0] == { + "file_data": { + "file_uri": "https://example.com/doc.pdf", + "mime_type": "application/pdf", + "display_name": None, + } + } + + def test_mixed_parts(self): + parts = [ + TextContent(value="analyze"), + ImageContent( + format="url", value="https://img.com/a.jpg", mime_type="image/jpeg" + ), + PDFContent(format="base64", value="pdf", mime_type="application/pdf"), + ] + result = 
GoogleAIProvider.format_parts(parts) + assert len(result) == 3 + assert "text" in result[0] + assert "file_data" in result[1] + assert "inline_data" in result[2] + + +class TestResolveImageContent: + def test_single_content(self): + img = ImageInput( + content=ImageContent(format="base64", value="abc", mime_type="image/png") + ) + result = resolve_image_content(img) + assert len(result) == 1 + assert result[0].mime_type == "image/png" + + def test_default_mime_type(self): + img = ImageInput(content=ImageContent(format="base64", value="abc")) + result = resolve_image_content(img) + assert result[0].mime_type == "image/png" + + def test_list_content(self): + img = ImageInput( + content=[ + ImageContent(format="base64", value="a", mime_type="image/png"), + ImageContent(format="base64", value="b", mime_type="image/jpeg"), + ] + ) + result = resolve_image_content(img) + assert len(result) == 2 + + +class TestResolvePdfContent: + def test_single_content(self): + pdf = PDFInput( + content=PDFContent( + format="base64", value="abc", mime_type="application/pdf" + ) + ) + result = resolve_pdf_content(pdf) + assert len(result) == 1 + assert result[0].mime_type == "application/pdf" + + def test_default_mime_type(self): + pdf = PDFInput(content=PDFContent(format="base64", value="abc")) + result = resolve_pdf_content(pdf) + assert result[0].mime_type == "application/pdf" + + def test_list_content(self): + pdf = PDFInput( + content=[ + PDFContent(format="base64", value="a", mime_type="application/pdf"), + PDFContent( + format="url", + value="https://example.com/doc.pdf", + mime_type="application/pdf", + ), + ] + ) + result = resolve_pdf_content(pdf) + assert len(result) == 2 + + +class TestResolveInputEdgeCases: + def test_unknown_input_type(self): + result, error = resolve_input(12345) + assert error is not None + assert "Unknown input type" in error + + def test_unsupported_type_in_multimodal_list(self): + result, error = resolve_input(["not_a_valid_input"]) + assert error is 
not None + assert "Unsupported input type" in error + + def test_text_input_resolves_string(self): + text = TextInput(content=TextContent(value="hello world")) + result, error = resolve_input(text) + assert error is None + assert result == "hello world" + + +class TestOpenAIExecuteInputRouting: + def _make_provider(self): + mock_client = MagicMock() + mock_resp = MagicMock() + mock_resp.id = "resp_123" + mock_resp.model = "gpt-4o-mini" + mock_resp.output_text = "result" + mock_resp.usage.input_tokens = 10 + mock_resp.usage.output_tokens = 5 + mock_resp.usage.total_tokens = 15 + mock_resp.conversation = None + mock_client.responses.create.return_value = mock_resp + return OpenAIProvider(client=mock_client), mock_client + + def _make_config(self): + return NativeCompletionConfig( + provider="openai-native", type="text", params={"model": "gpt-4o-mini"} + ) + + def _make_query(self): + return QueryParams(input="test") + + def test_multimodal_input(self): + provider, mock_client = self._make_provider() + mm = MultiModalInput( + parts=[ + TextContent(value="describe"), + ImageContent(format="base64", value="img", mime_type="image/png"), + ] + ) + response, error = provider.execute( + completion_config=self._make_config(), + query=self._make_query(), + resolved_input=mm, + ) + assert error is None + call_kwargs = mock_client.responses.create.call_args[1] + assert call_kwargs["input"][0]["role"] == "user" + assert len(call_kwargs["input"][0]["content"]) == 2 + + def test_list_input(self): + provider, mock_client = self._make_provider() + parts = [ImageContent(format="base64", value="img", mime_type="image/png")] + response, error = provider.execute( + completion_config=self._make_config(), + query=self._make_query(), + resolved_input=parts, + ) + assert error is None + call_kwargs = mock_client.responses.create.call_args[1] + assert call_kwargs["input"][0]["role"] == "user" + + def test_string_input(self): + provider, mock_client = self._make_provider() + response, error = 
class TestGoogleAIExecuteTextRouting:
    """Verify GoogleAIProvider.execute text routing and config handling.

    A MagicMock client captures the payload handed to models.generate_content
    so assertions can inspect exactly what the provider sent.
    """

    def _make_provider(self):
        # Build a provider around a mocked Google client with a canned response.
        client = MagicMock()
        fake_response = MagicMock()
        fake_response.response_id = "resp_gai_123"
        fake_response.model_version = "gemini-2.0-flash"
        fake_response.text = "response text"
        fake_response.usage_metadata.prompt_token_count = 10
        fake_response.usage_metadata.candidates_token_count = 5
        fake_response.usage_metadata.total_token_count = 15
        fake_response.usage_metadata.thoughts_token_count = 0
        client.models.generate_content.return_value = fake_response
        return GoogleAIProvider(client=client), client

    def _make_config(self, **extra_params):
        params = {"model": "gemini-2.0-flash", **extra_params}
        return NativeCompletionConfig(
            provider="google-native", type="text", params=params
        )

    def _make_query(self):
        return QueryParams(input="test")

    def _execute(self, provider, resolved_input, config=None):
        # Shared call path; an explicit config overrides the default one.
        return provider.execute(
            completion_config=config if config is not None else self._make_config(),
            query=self._make_query(),
            resolved_input=resolved_input,
        )

    def test_multimodal_input(self):
        provider, client = self._make_provider()
        multimodal = MultiModalInput(
            parts=[
                TextContent(value="describe"),
                ImageContent(format="base64", value="img", mime_type="image/png"),
            ]
        )
        _, error = self._execute(provider, multimodal)
        assert error is None
        sent = client.models.generate_content.call_args[1]
        assert sent["contents"][0]["role"] == "user"
        assert len(sent["contents"][0]["parts"]) == 2

    def test_list_input(self):
        provider, client = self._make_provider()
        image_parts = [
            ImageContent(format="base64", value="img", mime_type="image/png")
        ]
        _, error = self._execute(provider, image_parts)
        assert error is None
        sent = client.models.generate_content.call_args[1]
        assert sent["contents"][0]["role"] == "user"

    def test_string_input(self):
        provider, client = self._make_provider()
        _, error = self._execute(provider, "hello")
        assert error is None
        sent = client.models.generate_content.call_args[1]
        # Plain strings are wrapped into a single text part.
        assert sent["contents"][0]["parts"] == [{"text": "hello"}]

    def test_missing_model(self):
        provider, _ = self._make_provider()
        empty_config = NativeCompletionConfig(
            provider="google-native", type="text", params={}
        )
        response, error = self._execute(provider, "hello", config=empty_config)
        assert response is None
        assert "Missing 'model'" in error

    def test_instructions_passed_to_config(self):
        provider, client = self._make_provider()
        _, error = self._execute(
            provider, "hello", config=self._make_config(instructions="be helpful")
        )
        assert error is None
        generation_config = client.models.generate_content.call_args[1]["config"]
        assert generation_config.system_instruction == "be helpful"

    def test_no_usage_metadata(self):
        provider, client = self._make_provider()
        # Simulate a response without usage metadata; counts must default to 0.
        client.models.generate_content.return_value.usage_metadata = None
        response, error = self._execute(provider, "hello")
        assert error is None
        assert response.usage.input_tokens == 0
        assert response.usage.output_tokens == 0
import socket + from typing import Any, Dict, Generic, Optional, TypeVar from urllib.parse import urlparse @@ -25,6 +28,15 @@ from app.core import security from app.core.config import settings from app.crud.credentials import get_provider_credential +from app.models.llm.request import ( + TextInput, + AudioInput, + ImageInput, + PDFInput, + ImageContent, + PDFContent, +) +from app.services.llm.providers.base import ContentPart, MultiModalInput logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -444,27 +456,78 @@ def resolve_audio_base64(data: str, mime_type: str) -> tuple[str, str | None]: return "", f"Failed to write audio to temp file: {str(e)}" -def resolve_input(query_input) -> tuple[str, str | None]: - """Resolve discriminated union input to content string. +def resolve_image_content(image_input: ImageInput) -> list[ImageContent]: + contents = ( + image_input.content + if isinstance(image_input.content, list) + else [image_input.content] + ) + for c in contents: + if not c.mime_type: + c.mime_type = "image/png" + return contents - Args: - query_input: The input from QueryParams (TextInput or AudioInput) + +def resolve_pdf_content(pdf_input: PDFInput) -> list[PDFContent]: + contents = ( + pdf_input.content + if isinstance(pdf_input.content, list) + else [pdf_input.content] + ) + for c in contents: + if not c.mime_type: + c.mime_type = "application/pdf" + return contents + + +def resolve_input( + query_input, +) -> tuple[str | list[ImageContent] | list[PDFContent] | "MultiModalInput", str | None]: + """Resolve query input to provider-ready format. 
Returns: - (content_string, None) on success - for text returns content value, for audio returns temp file path - ("", error_message) on failure + - TextInput/AudioInput: (str, None) + - ImageInput: (list[ImageContent], None) + - PDFInput: (list[PDFContent], None) + - list[QueryInput]: (MultiModalInput, None) + - Error: ("", error_message) """ - from app.models.llm.request import TextInput, AudioInput try: if isinstance(query_input, TextInput): return query_input.content.value, None elif isinstance(query_input, AudioInput): - # AudioInput content is base64-encoded audio mime_type = query_input.content.mime_type or "audio/wav" return resolve_audio_base64(query_input.content.value, mime_type) + elif isinstance(query_input, ImageInput): + return resolve_image_content(query_input), None + + elif isinstance(query_input, PDFInput): + return resolve_pdf_content(query_input), None + + elif isinstance(query_input, list): + parts: list[ContentPart] = [] + for item in query_input: + if isinstance(item, TextInput): + parts.append(item.content) + elif isinstance(item, ImageInput): + parts.extend(resolve_image_content(item)) + elif isinstance(item, PDFInput): + parts.extend(resolve_pdf_content(item)) + elif isinstance(item, AudioInput): + return ( + "", + "Audio input is not supported in multimodal. Please use completion type 'stt' for audio processing.", + ) + else: + return ( + "", + "Unsupported input type in multimodal list. Multimodal only supports text, image, and pdf inputs.", + ) + return MultiModalInput(parts=parts), None + else: return "", f"Unknown input type: {type(query_input)}"