Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 42 additions & 1 deletion backend/app/core/storage_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,13 @@
from datetime import datetime
from io import BytesIO
from pathlib import Path
from typing import Literal
from urllib.parse import unquote, urlparse
from uuid import UUID

from starlette.datastructures import Headers, UploadFile

from app.core.cloud.storage import CloudStorage, CloudStorageError
from typing import Literal

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -207,6 +208,46 @@ def load_json_from_object_store(storage: CloudStorage, url: str) -> list | dict
return None


# Audio MIME type -> file extension used when naming uploaded audio objects.
# Keys are bare, lowercase "type/subtype" strings (no parameters); lookups must
# normalize the incoming MIME type to that form first.
_MIME_TO_EXT: dict[str, str] = {
    "audio/mpeg": "mp3",
    "audio/mp3": "mp3",
    "audio/ogg": "ogg",
    "audio/wav": "wav",
    "audio/wave": "wav",
    "audio/x-wav": "wav",
    "audio/webm": "webm",
    "audio/mp4": "mp4",
    "audio/aac": "aac",
    "audio/flac": "flac",
}


def upload_audio_bytes_to_s3(
    storage: CloudStorage,
    audio_bytes: bytes,
    call_id: UUID,
    mime_type: str | None,
    prefix: str,
) -> str | None:
    """Upload decoded audio bytes to S3 and return the s3:// URI.

    The object is named ``<call_id>.<ext>``, where ``ext`` is looked up in
    ``_MIME_TO_EXT`` from the normalized MIME type and falls back to ``wav``
    when the type is missing or not in the table.

    Args:
        storage: CloudStorage instance
        audio_bytes: Raw audio bytes
        call_id: LLM call UUID used as the filename stem
        mime_type: MIME type of the audio (determines file extension). May be
            mixed-case or carry parameters, e.g. "audio/ogg; codecs=opus".
        prefix: S3 subdirectory, e.g. "llm/tts/audio" or "llm/stt/audio"

    Returns:
        s3:// URI if successful, None on failure
    """
    # MIME types are case-insensitive and may carry ";param=value" suffixes;
    # compare only the bare type/subtype so such values still map correctly
    # instead of silently falling back to "wav".
    media_type = (mime_type or "").split(";", 1)[0].strip().lower()
    ext = _MIME_TO_EXT.get(media_type, "wav")
    filename = f"{call_id}.{ext}"
    # The content type handed to the store is the caller's original value
    # (unchanged from before), defaulting to audio/wav when absent.
    return upload_to_object_store(
        storage, audio_bytes, filename, prefix, mime_type or "audio/wav"
    )


def generate_timestamped_filename(base_name: str, extension: str = "csv") -> str:
"""
Generate a filename with timestamp.
Expand Down
39 changes: 35 additions & 4 deletions backend/app/crud/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,15 @@ def serialize_input(query_input: QueryInput | str) -> str:
elif isinstance(query_input, TextInput):
return query_input.content.value
elif isinstance(query_input, AudioInput):
if query_input.content.format == "url":
return json.dumps(
{
"type": "audio",
"format": "url",
"mime_type": query_input.content.mime_type,
"url": query_input.content.value,
}
)
return json.dumps(
{
"type": "audio",
Expand Down Expand Up @@ -187,11 +196,12 @@ def update_llm_call_response(
db_llm_call.provider_response_id = provider_response_id

if content is not None:
# For audio outputs (AudioOutput model): calculate size metadata from base64 content
# AudioOutput serializes as: {"type": "audio", "content": {"format": "base64", "value": "...", "mime_type": "..."}}
# For audio outputs: calculate size only when content is still base64 (not a URI)
if content.get("type") == "audio":
audio_value = content.get("content", {}).get("value")
if audio_value:
audio_content = content.get("content", {})
audio_format = audio_content.get("format")
audio_value = audio_content.get("value")
if audio_value and audio_format == "base64":
try:
audio_data = base64.b64decode(audio_value)
content["audio_size_bytes"] = len(audio_data)
Expand All @@ -218,6 +228,27 @@ def update_llm_call_response(
return db_llm_call


def update_llm_call_input(
    session: Session,
    llm_call_id: UUID,
    s3_uri: str,
) -> None:
    """Replace the stored input of an LLM call with an S3 URI.

    Looks up the ``LlmCall`` row by primary key. When it exists, its ``input``
    is set to ``s3_uri``, ``updated_at`` is refreshed, and the change is
    committed. When it does not exist, a warning is logged and nothing is
    written.

    Args:
        session: Database session used for the lookup and commit.
        llm_call_id: Primary key of the LLM call to update.
        s3_uri: URI to store in the call's ``input`` field.
    """
    record = session.get(LlmCall, llm_call_id)

    if record:
        record.input = s3_uri
        record.updated_at = now()
        session.add(record)
        session.commit()
        logger.info(
            f"[update_llm_call_input] Updated input URI | llm_call_id={llm_call_id}"
        )
        return

    # Nothing to update: report the missing row and leave the database untouched.
    logger.warning(
        f"[update_llm_call_input] LLM call not found | llm_call_id={llm_call_id}"
    )


def get_llm_call_by_id(
session: Session,
llm_call_id: UUID,
Expand Down
10 changes: 8 additions & 2 deletions backend/app/models/llm/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,13 +102,19 @@ class TextContent(SQLModel):


class AudioContent(SQLModel):
    # How `value` is to be interpreted: "base64" means the audio is carried
    # inline, "url" means `value` is a URL to download the audio from.
    # Defaults to "base64".
    format: Literal["base64", "url"] = "base64"
    value: str = Field(
        ..., description="Base64 encoded audio or public URL to download from"
    )
    # mime_type is intentionally not restricted to a fixed set of values,
    # since it does not affect the base64 encoding.
    mime_type: str | None = Field(
        None,
        description="MIME type of the audio (e.g., audio/wav, audio/mp3, audio/ogg)",
    )
    # Optional; unset by default.
    uri: str | None = Field(
        None,
        description="Presigned URL to the audio file in object storage (when available)",
    )


class ImageContent(SQLModel):
Expand Down
33 changes: 32 additions & 1 deletion backend/app/services/llm/chain/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from sqlmodel import Session

from app.core.cloud.storage import get_cloud_storage
from app.core.db import engine
from app.crud.jobs import JobCrud
from app.crud.llm_chain import update_llm_chain_block_completed, update_llm_chain_status
Expand All @@ -10,7 +11,11 @@
ChainStatus,
LLMChainRequest,
)
from app.models.llm.response import IntermediateChainResponse, LLMChainResponse
from app.models.llm.response import (
AudioOutput,
IntermediateChainResponse,
LLMChainResponse,
)
from app.services.llm.chain.chain import ChainContext, LLMChain
from app.services.llm.chain.types import BlockResult
from app.utils import APIResponse, get_webhook_secret, send_callback
Expand Down Expand Up @@ -65,10 +70,33 @@ def _setup(self) -> None:
self._context.project_id, self._context.organization_id
)

def _resolve_presigned_url(self, output) -> None:
    """Swap the s3:// URI in content.uri for a presigned URL in-place.

    Only ``AudioOutput`` values with a non-empty ``content.uri`` are touched;
    anything else is left as-is.

    The call is idempotent: a URI that is already an http(s) URL is left
    untouched. This method mutates ``output`` and is invoked from more than
    one place, so the same object may be seen again after it was resolved;
    re-signing an already-presigned URL would otherwise risk clearing it.

    Non-fatal: clears uri on failure so clients don't receive a raw s3:// address.

    Args:
        output: The response output to inspect and, if applicable, mutate.
    """
    if not isinstance(output, AudioOutput):
        return

    uri = output.content.uri
    if not uri:
        return

    # Already resolved by an earlier call on this same object; nothing to do.
    if uri.startswith(("http://", "https://")):
        return

    try:
        with Session(engine) as session:
            storage = get_cloud_storage(session, self._context.project_id)
            output.content.uri = storage.get_signed_url(uri, expires_in=3600)
    except Exception as e:
        # Best-effort by design: a signing failure must not fail the job.
        logger.warning(
            f"[_resolve_presigned_url] Failed to generate presigned URL: {e} | "
            f"job_id={self._context.job_id}",
            exc_info=True,
        )
        output.content.uri = None

def _teardown(self, result: BlockResult) -> dict:
"""Finalize chain record, send callback, and update job status."""

if result.success:
if result.response:
self._resolve_presigned_url(result.response.response.output)

final = LLMChainResponse(
response=result.response.response,
usage=result.usage,
Expand Down Expand Up @@ -159,6 +187,9 @@ def _send_intermediate_callback(
) -> None:
"""Send intermediate callback for a completed block."""
try:
if result.response:
self._resolve_presigned_url(result.response.response.output)

intermediate = IntermediateChainResponse(
block_index=block_index + 1,
total_blocks=self._context.total_blocks,
Expand Down
Loading
Loading