Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 96 additions & 0 deletions backend/app/api/routes/private.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,24 @@
import base64
import logging
from typing import Any

from fastapi import APIRouter
from pydantic import BaseModel
from sqlmodel import col, select

from app.api.deps import SessionDep
from app.core.cloud.storage import get_cloud_storage
from app.core.security import get_password_hash
from app.core.storage_utils import upload_audio_bytes_to_s3
from app.core.util import now
from app.models import (
LlmCall,
User,
UserPublic,
)

logger = logging.getLogger(__name__)

router = APIRouter(tags=["private"], prefix="/private")


Expand All @@ -20,6 +29,93 @@ class PrivateUserCreate(BaseModel):
is_verified: bool = False


@router.post("/migrate/tts-base64-to-s3", include_in_schema=False)
def migrate_tts_base64_to_s3(session: SessionDep) -> dict:
    """
    One-shot migration: find all llm_call rows with input_type=text / output_type=audio
    whose content still holds raw base64, upload the audio to S3, and replace with a URI.

    Returns:
        Summary dict with processed/skipped/failed counts and up to 50 error details.
    """
    processed = skipped = failed = 0
    errors: list[dict] = []

    # Storage instances are cached per project_id to avoid redundant DB lookups.
    storage_cache: dict[int, Any] = {}

    statement = (
        select(LlmCall)
        .where(
            LlmCall.input_type == "text",
            LlmCall.output_type == "audio",
            col(LlmCall.deleted_at).is_(None),
        )
        .order_by(col(LlmCall.created_at).desc())
        # Stream rows in batches of 100 so the full result set is never held in memory.
        .execution_options(yield_per=100)
    )

    for call in session.exec(statement):
        content = call.content
        if not content:
            skipped += 1
            continue

        audio_content = content.get("content", {})
        # Guard against malformed rows where "content" is not a mapping:
        # calling .get() on e.g. a str here would raise AttributeError OUTSIDE
        # the try block below and abort the whole migration with a 500.
        if not isinstance(audio_content, dict):
            skipped += 1
            continue

        if audio_content.get("format") != "base64":
            skipped += 1
            continue

        b64_value = audio_content.get("value")
        if not b64_value:
            skipped += 1
            continue

        try:
            if call.project_id not in storage_cache:
                storage_cache[call.project_id] = get_cloud_storage(
                    session, call.project_id
                )
            storage = storage_cache[call.project_id]

            audio_bytes = base64.b64decode(b64_value)
            s3_url = upload_audio_bytes_to_s3(
                storage,
                audio_bytes,
                call.id,
                audio_content.get("mime_type"),
                "llm/tts/audio",
            )

            if not s3_url:
                raise RuntimeError("upload returned None")

            # Replace the inline base64 payload with a lightweight URI reference.
            call.content = {
                "type": "audio",
                "content": {
                    "format": "uri",
                    "value": s3_url,
                    "mime_type": audio_content.get("mime_type"),
                },
            }
            call.updated_at = now()
            session.add(call)
            processed += 1

        except Exception as e:
            failed += 1
            errors.append({"call_id": str(call.id), "error": str(e)})
            # Lazy %-style args: the message is formatted only if the record is emitted.
            logger.warning(
                "[migrate_tts_base64_to_s3] Failed | call_id=%s, error=%s", call.id, e
            )

    session.commit()

    return {
        "processed": processed,
        "skipped": skipped,
        "failed": failed,
        # Cap the error payload so the response stays small on mass failure.
        "errors": errors[:50],
    }


@router.post("/users", response_model=UserPublic, include_in_schema=False)
def create_user(user_in: PrivateUserCreate, session: SessionDep) -> Any:
"""
Expand Down
43 changes: 42 additions & 1 deletion backend/app/core/storage_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,13 @@
from datetime import datetime
from io import BytesIO
from pathlib import Path
from typing import Literal
from urllib.parse import unquote, urlparse
from uuid import UUID

from starlette.datastructures import Headers, UploadFile

from app.core.cloud.storage import CloudStorage, CloudStorageError
from typing import Literal

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -207,6 +208,46 @@ def load_json_from_object_store(storage: CloudStorage, url: str) -> list | dict
return None


# Maps audio MIME types to the file extension used when naming uploaded
# objects; several aliases (e.g. audio/wave, audio/x-wav) collapse to one
# canonical extension. Unknown types fall back to "wav" at the call site.
_MIME_TO_EXT: dict[str, str] = {
    "audio/mpeg": "mp3",
    "audio/mp3": "mp3",
    "audio/ogg": "ogg",
    "audio/wav": "wav",
    "audio/wave": "wav",
    "audio/x-wav": "wav",
    "audio/webm": "webm",
    "audio/mp4": "mp4",
    "audio/aac": "aac",
    "audio/flac": "flac",
}


def upload_audio_bytes_to_s3(
    storage: CloudStorage,
    audio_bytes: bytes,
    call_id: UUID,
    mime_type: str | None,
    prefix: str,
) -> str | None:
    """Persist decoded audio bytes to S3 and hand back the s3:// URI.

    Args:
        storage: CloudStorage instance
        audio_bytes: Raw audio bytes
        call_id: LLM call UUID used as the filename stem
        mime_type: MIME type of the audio (determines file extension)
        prefix: S3 subdirectory, e.g. "llm/tts/audio" or "llm/stt/audio"

    Returns:
        s3:// URI if successful, None on failure
    """
    # Unknown or missing MIME types fall back to WAV for both the extension
    # and the content type passed to the object store.
    extension = _MIME_TO_EXT.get(mime_type or "", "wav")
    content_type = mime_type or "audio/wav"
    object_name = f"{call_id}.{extension}"
    return upload_to_object_store(
        storage, audio_bytes, object_name, prefix, content_type
    )


def generate_timestamped_filename(base_name: str, extension: str = "csv") -> str:
"""
Generate a filename with timestamp.
Expand Down
Loading