diff --git a/backend/app/api/docs/stt_evaluation/update_sample.md b/backend/app/api/docs/stt_evaluation/update_sample.md new file mode 100644 index 000000000..4360c5695 --- /dev/null +++ b/backend/app/api/docs/stt_evaluation/update_sample.md @@ -0,0 +1,3 @@ +Update an STT sample's language and/or ground truth transcription. + +Only the provided fields will be updated. Fields set to `null` in the request will not modify the existing value. diff --git a/backend/app/api/docs/tts_evaluation/update_feedback.md b/backend/app/api/docs/tts_evaluation/update_feedback.md index 7701bb52a..910227f83 100644 --- a/backend/app/api/docs/tts_evaluation/update_feedback.md +++ b/backend/app/api/docs/tts_evaluation/update_feedback.md @@ -1,5 +1,20 @@ -Update human feedback on a TTS synthesis result. +Update human feedback and score on a TTS synthesis result. + +Only the provided fields will be updated. Fields omitted from the request will not modify the existing value. Sending a field as `null` will clear its value. Fields: - **is_correct**: Whether the synthesized audio quality is acceptable (null to clear) - **comment**: Optional feedback comment +- **score**: Evaluation metrics for the synthesized audio + +**Example request:** +```json +{ + "is_correct": true, + "comment": "string", + "score": { + "Speech Naturalness": "low | medium | high", + "Pronunciation Accuracy": "low | medium | high" + } +} +``` diff --git a/backend/app/api/routes/evaluations/dataset.py b/backend/app/api/routes/evaluations/dataset.py index a63eeba42..202774da2 100644 --- a/backend/app/api/routes/evaluations/dataset.py +++ b/backend/app/api/routes/evaluations/dataset.py @@ -42,6 +42,7 @@ def _dataset_to_response( return DatasetUploadResponse( dataset_id=dataset.id, dataset_name=dataset.name, + description=dataset.description, total_items=dataset.dataset_metadata.get("total_items_count", 0), original_items=dataset.dataset_metadata.get("original_items_count", 0), duplication_factor=dataset.dataset_metadata.get("duplication_factor", 1), diff --git a/backend/app/api/routes/languages.py b/backend/app/api/routes/languages.py index f04896b94..9a184ea81 100644 --- a/backend/app/api/routes/languages.py +++ b/backend/app/api/routes/languages.py @@ -1,6 +1,6 @@ import logging -from fastapi import APIRouter, HTTPException +from fastapi import APIRouter from app.api.deps import AuthContextDep, SessionDep from app.crud.language import get_language_by_id, get_languages @@ -37,8 +37,10 @@ def get_language(session: SessionDep, auth_context: AuthContextDep, language_id: """ Retrieve a language by ID. """ - language = get_language_by_id(session=session, language_id=language_id) - if language is None: - logger.error(f"[get_language] Language not found | language_id={language_id}") - raise HTTPException(status_code=404, detail="Language not found") + language = get_language_by_id( + session=session, + language_id=language_id, + status_code=404, + detail="Language not found", + ) return APIResponse.success_response(language) diff --git a/backend/app/api/routes/stt_evaluations/dataset.py b/backend/app/api/routes/stt_evaluations/dataset.py index ea6ee9362..b42a4b17e 100644 --- a/backend/app/api/routes/stt_evaluations/dataset.py +++ b/backend/app/api/routes/stt_evaluations/dataset.py @@ -13,12 +13,14 @@ get_samples_by_dataset_id, get_stt_dataset_by_id, list_stt_datasets, + update_stt_sample, ) from app.models.stt_evaluation import ( STTDatasetCreate, STTDatasetPublic, STTDatasetWithSamples, STTSamplePublic, + STTSampleUpdate, ) from app.services.stt_evaluations.dataset import upload_stt_dataset from app.utils import APIResponse, load_description @@ -43,13 +45,7 @@ def create_dataset( """Create an STT evaluation dataset.""" # Validate language_id if dataset_create.language_id is not None: - language = get_language_by_id( - session=session, language_id=dataset_create.language_id - ) - if not language: - raise HTTPException( - status_code=400, detail="Invalid language_id: language not found" - ) + get_language_by_id(session=session, language_id=dataset_create.language_id) dataset, samples = upload_stt_dataset( session=session, @@ -165,7 +161,6 @@ def get_dataset( session=session, project_id=auth_context.project_.id ) - samples = [] for s in sample_records: signed_url = None if storage and s.file_id in file_map: @@ -209,3 +204,50 @@ def get_dataset( ), metadata={"samples_total": samples_total}, ) + + +@router.patch( + "/samples/{sample_id}", + response_model=APIResponse[STTSamplePublic], + dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], + summary="Update STT sample", + description=load_description("stt_evaluation/update_sample.md"), +) +def update_sample( + session: SessionDep, + auth_context: AuthContextDep, + sample_id: int, + sample_update: STTSampleUpdate = Body(...), +) -> APIResponse[STTSamplePublic]: + """Update an STT sample's language and/or ground truth.""" + logger.info(f"[update_sample] Updating sample | " f"sample_id: {sample_id}") + + if sample_update.language_id is not None: + get_language_by_id(session=session, language_id=sample_update.language_id) + + sample = update_stt_sample( + session=session, + sample_id=sample_id, + org_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + language_id=sample_update.language_id, + ground_truth=sample_update.ground_truth, + ) + + if not sample: + raise HTTPException(status_code=404, detail="Sample not found") + + return APIResponse.success_response( + data=STTSamplePublic( + id=sample.id, + file_id=sample.file_id, + language_id=sample.language_id, + ground_truth=sample.ground_truth, + sample_metadata=sample.sample_metadata, + dataset_id=sample.dataset_id, + organization_id=sample.organization_id, + project_id=sample.project_id, + inserted_at=sample.inserted_at, + updated_at=sample.updated_at, + ) + ) diff --git a/backend/app/api/routes/tts_evaluations/dataset.py b/backend/app/api/routes/tts_evaluations/dataset.py index c08d4e68e..115df8fb5 100644 --- a/backend/app/api/routes/tts_evaluations/dataset.py +++ b/backend/app/api/routes/tts_evaluations/dataset.py @@ -6,6 +6,7 @@ from app.api.deps import AuthContextDep, SessionDep from app.api.permissions import Permission, require_permission +from app.core.cloud import get_cloud_storage from app.crud.language import get_language_by_id from app.crud.tts_evaluations import ( get_tts_dataset_by_id, @@ -38,13 +39,7 @@ def create_dataset( """Create a TTS evaluation dataset.""" # Validate language_id if provided if dataset_create.language_id is not None: - language = get_language_by_id( - session=session, language_id=dataset_create.language_id - ) - if not language: - raise HTTPException( - status_code=400, detail="Invalid language_id: language not found" - ) + get_language_by_id(session=session, language_id=dataset_create.language_id) dataset = upload_tts_dataset( session=session, @@ -71,6 +66,9 @@ def list_datasets( auth_context: AuthContextDep, limit: int = Query(50, ge=1, le=100, description="Maximum results to return"), offset: int = Query(0, ge=0, description="Number of results to skip"), + include_signed_url: bool = Query( + False, description="Include signed URL for dataset files" + ), ) -> APIResponse[list[TTSDatasetPublic]]: """List TTS evaluation datasets.""" datasets, total = list_tts_datasets( @@ -81,8 +79,21 @@ def list_datasets( offset=offset, ) + storage = None + if include_signed_url: + storage = get_cloud_storage( + session=session, project_id=auth_context.project_.id + ) + + data = [] + for dataset in datasets: + signed_url = None + if storage and dataset.object_store_url: + signed_url = storage.get_signed_url(dataset.object_store_url) + data.append(TTSDatasetPublic.from_model(dataset, signed_url=signed_url)) + return APIResponse.success_response( - data=datasets, + data=data, metadata={"total": total, "limit": limit, "offset": offset}, ) @@ -98,6 +109,9 @@ def get_dataset( session: SessionDep, auth_context: AuthContextDep, dataset_id: int, + include_signed_url: bool = Query( + False, description="Include signed URL for dataset file" + ), ) -> APIResponse[TTSDatasetPublic]: """Get a TTS evaluation dataset.""" dataset = get_tts_dataset_by_id( @@ -110,8 +124,15 @@ def get_dataset( if not dataset: raise HTTPException(status_code=404, detail="Dataset not found") + signed_url = None + if include_signed_url and dataset.object_store_url: + storage = get_cloud_storage( + session=session, project_id=auth_context.project_.id + ) + signed_url = storage.get_signed_url(dataset.object_store_url) + return APIResponse.success_response( - data=TTSDatasetPublic.from_model(dataset), + data=TTSDatasetPublic.from_model(dataset, signed_url=signed_url), metadata={ "sample_count": (dataset.dataset_metadata or {}).get("sample_count", 0) }, diff --git a/backend/app/api/routes/tts_evaluations/result.py b/backend/app/api/routes/tts_evaluations/result.py index 2f1998fd0..92a82b002 100644 --- a/backend/app/api/routes/tts_evaluations/result.py +++ b/backend/app/api/routes/tts_evaluations/result.py @@ -70,6 +70,8 @@ def update_result_feedback( update_kwargs["is_correct"] = feedback.is_correct if "comment" in feedback.model_fields_set: update_kwargs["comment"] = feedback.comment + if "score" in feedback.model_fields_set: + update_kwargs["score"] = feedback.score result = update_tts_human_feedback( session=session, diff --git a/backend/app/crud/evaluations/langfuse.py b/backend/app/crud/evaluations/langfuse.py index 477de7e57..1dd0c519e 100644 --- a/backend/app/crud/evaluations/langfuse.py +++ b/backend/app/crud/evaluations/langfuse.py @@ -209,7 +209,7 @@ def update_traces_with_cosine_scores( try: langfuse.score( trace_id=trace_id, - name="cosine_similarity", + name="Cosine Similarity", value=cosine_score, comment=( "Cosine similarity between generated output and " diff --git a/backend/app/crud/evaluations/processing.py b/backend/app/crud/evaluations/processing.py index d33207302..1e36fd13b 100644 --- a/backend/app/crud/evaluations/processing.py +++ b/backend/app/crud/evaluations/processing.py @@ -392,7 +392,7 @@ async def process_completed_embedding_batch( eval_run.score = { "summary_scores": [ { - "name": "cosine_similarity", + "name": "Cosine Similarity", "avg": round(float(similarity_stats["cosine_similarity_avg"]), 2), "std": round(float(similarity_stats["cosine_similarity_std"]), 2), "total_pairs": similarity_stats["total_pairs"], diff --git a/backend/app/crud/language.py b/backend/app/crud/language.py index eb4f7d273..765e50542 100644 --- a/backend/app/crud/language.py +++ b/backend/app/crud/language.py @@ -1,6 +1,7 @@ import logging from typing import Optional +from fastapi import HTTPException from sqlmodel import Session, select from app.models import Language @@ -16,10 +17,19 @@ def get_languages(session: Session, skip: int = 0, limit: int = 100) -> list[Lan return list(session.exec(statement).all()) -def get_language_by_id(session: Session, language_id: int) -> Optional[Language]: - """Retrieve a language by its ID.""" +def get_language_by_id( + session: Session, + language_id: int, + *, + status_code: int = 400, + detail: str = "Invalid language_id: language not found", +) -> Language: + """Retrieve a language by its ID. Raises HTTPException if not found.""" statement = select(Language).where(Language.id == language_id) - return session.exec(statement).first() + language = session.exec(statement).first() + if not language: + raise HTTPException(status_code=status_code, detail=detail) + return language def get_language_by_locale(session: Session, locale: str) -> Optional[Language]: diff --git a/backend/app/crud/stt_evaluations/__init__.py b/backend/app/crud/stt_evaluations/__init__.py index 7cc235e65..071d168bf 100644 --- a/backend/app/crud/stt_evaluations/__init__.py +++ b/backend/app/crud/stt_evaluations/__init__.py @@ -6,8 +6,10 @@ create_stt_dataset, create_stt_samples, get_stt_dataset_by_id, + get_stt_sample_by_id, list_stt_datasets, get_samples_by_dataset_id, + update_stt_sample, ) from .run import ( create_stt_run, @@ -30,8 +32,10 @@ "create_stt_dataset", "create_stt_samples", "get_stt_dataset_by_id", + "get_stt_sample_by_id", "list_stt_datasets", "get_samples_by_dataset_id", + "update_stt_sample", # Run "create_stt_run", "get_stt_run_by_id", diff --git a/backend/app/crud/stt_evaluations/dataset.py b/backend/app/crud/stt_evaluations/dataset.py index c30d615d7..2c559fb10 100644 --- a/backend/app/crud/stt_evaluations/dataset.py +++ b/backend/app/crud/stt_evaluations/dataset.py @@ -204,6 +204,85 @@ def create_stt_samples( return created_samples +def get_stt_sample_by_id( + *, + session: Session, + sample_id: int, + org_id: int, + project_id: int, +) -> STTSample | None: + """Get an STT sample by ID. + + Args: + session: Database session + sample_id: Sample ID + org_id: Organization ID + project_id: Project ID + + Returns: + STTSample | None: Sample if found + """ + statement = select(STTSample).where( + STTSample.id == sample_id, + STTSample.organization_id == org_id, + STTSample.project_id == project_id, + ) + + return session.exec(statement).one_or_none() + + +def update_stt_sample( + *, + session: Session, + sample_id: int, + org_id: int, + project_id: int, + language_id: int | None = None, + ground_truth: str | None = None, +) -> STTSample | None: + """Update an STT sample's language and/or ground truth. + + Args: + session: Database session + sample_id: Sample ID + org_id: Organization ID + project_id: Project ID + language_id: Optional new language ID + ground_truth: Optional new ground truth transcription + + Returns: + STTSample | None: Updated sample, or None if not found + """ + sample = get_stt_sample_by_id( + session=session, + sample_id=sample_id, + org_id=org_id, + project_id=project_id, + ) + + if not sample: + return None + + if language_id is not None: + sample.language_id = language_id + + if ground_truth is not None: + sample.ground_truth = ground_truth + + sample.updated_at = now() + + session.add(sample) + session.flush() + + logger.info( + f"[update_stt_sample] Sample updated | " + f"sample_id: {sample_id}, language_id: {language_id}, " + f"ground_truth_updated: {ground_truth is not None}" + ) + + return sample + + def get_stt_dataset_by_id( *, session: Session, diff --git a/backend/app/crud/tts_evaluations/dataset.py b/backend/app/crud/tts_evaluations/dataset.py index e2956885d..c81007e92 100644 --- a/backend/app/crud/tts_evaluations/dataset.py +++ b/backend/app/crud/tts_evaluations/dataset.py @@ -10,7 +10,6 @@ from app.core.util import now from app.models import EvaluationDataset from app.models.stt_evaluation import EvaluationType -from app.models.tts_evaluation import TTSDatasetPublic logger = logging.getLogger(__name__) @@ -121,7 +120,7 @@ def list_tts_datasets( project_id: int, limit: int = 50, offset: int = 0, -) -> tuple[list[TTSDatasetPublic], int]: +) -> tuple[list[EvaluationDataset], int]: """List TTS datasets for a project. Args: @@ -132,7 +131,7 @@ def list_tts_datasets( offset: Number of results to skip Returns: - tuple[list[TTSDatasetPublic], int]: Datasets and total count + tuple[list[EvaluationDataset], int]: Datasets and total count """ base_filter = ( EvaluationDataset.organization_id == org_id, @@ -151,8 +150,6 @@ def list_tts_datasets( .limit(limit) ) - datasets = session.exec(statement).all() + datasets = list(session.exec(statement).all()) - result = [TTSDatasetPublic.from_model(dataset) for dataset in datasets] - - return result, total + return datasets, total diff --git a/backend/app/crud/tts_evaluations/result.py b/backend/app/crud/tts_evaluations/result.py index 2095f7def..b13b5a62e 100644 --- a/backend/app/crud/tts_evaluations/result.py +++ b/backend/app/crud/tts_evaluations/result.py @@ -237,6 +237,9 @@ def update_tts_human_feedback( if "comment" in kwargs: result.comment = kwargs["comment"] + if "score" in kwargs: + result.score = kwargs["score"] + result.updated_at = now() session.add(result) diff --git a/backend/app/models/evaluation.py b/backend/app/models/evaluation.py index 18e7749bf..d2d2beecc 100644 --- a/backend/app/models/evaluation.py +++ b/backend/app/models/evaluation.py @@ -28,6 +28,7 @@ class DatasetUploadResponse(BaseModel): dataset_id: int = Field(..., description="Database ID of the created dataset") dataset_name: str = Field(..., description="Name of the created dataset") + description: str | None = Field(None, description="Description of the dataset") total_items: int = Field( ..., description="Total number of items uploaded (after duplication)" ) diff --git a/backend/app/models/stt_evaluation.py b/backend/app/models/stt_evaluation.py index 7d0bd75f7..5e953e36d 100644 --- a/backend/app/models/stt_evaluation.py +++ b/backend/app/models/stt_evaluation.py @@ -268,6 +268,13 @@ class STTResultWithSample(STTResultPublic): sample: STTSamplePublic +class STTSampleUpdate(BaseModel): + """Request model for updating an STT sample's language and ground truth.""" + + language_id: int | None = Field(None, description="Language ID for this sample") + ground_truth: str | None = Field(None, description="Reference transcription") + + class STTFeedbackUpdate(BaseModel): """Request model for updating human feedback on a result.""" diff --git a/backend/app/models/tts_evaluation.py b/backend/app/models/tts_evaluation.py index 3ee796759..a0bbdb054 100644 --- a/backend/app/models/tts_evaluation.py +++ b/backend/app/models/tts_evaluation.py @@ -179,6 +179,7 @@ class TTSDatasetPublic(BaseModel): type: str language_id: int | None object_store_url: str | None + signed_url: str | None = None dataset_metadata: dict[str, Any] organization_id: int project_id: int @@ -186,7 +187,12 @@ class TTSDatasetPublic(BaseModel): updated_at: datetime @classmethod - def from_model(cls, dataset: EvaluationDataset) -> TTSDatasetPublic: + def from_model( + cls, + dataset: EvaluationDataset, + *, + signed_url: str | None = None, + ) -> TTSDatasetPublic: """Create from an EvaluationDataset model instance.""" return cls( id=dataset.id, @@ -195,6 +201,7 @@ def from_model(cls, dataset: EvaluationDataset) -> TTSDatasetPublic: type=dataset.type, language_id=dataset.language_id, object_store_url=dataset.object_store_url, + signed_url=signed_url, dataset_metadata=dataset.dataset_metadata, organization_id=dataset.organization_id, project_id=dataset.project_id, @@ -260,6 +267,16 @@ class TTSFeedbackUpdate(BaseModel): None, description="Is the synthesized audio correct?" ) comment: str | None = Field(None, description="Feedback comment") + score: dict[str, Any] | None = Field( + None, + description="Evaluation metrics", + json_schema_extra={ + "example": { + "Speech Naturalness": "low | medium | high", + "Pronunciation Accuracy": "low | medium | high", + } + }, + ) class TTSEvaluationRunCreate(BaseModel): diff --git a/backend/app/services/evaluations/validators.py b/backend/app/services/evaluations/validators.py index 92733a2f8..61d0c3b06 100644 --- a/backend/app/services/evaluations/validators.py +++ b/backend/app/services/evaluations/validators.py @@ -137,12 +137,19 @@ def parse_csv_items(csv_content: bytes) -> list[dict[str, str]]: field.strip().lower(): field for field in csv_reader.fieldnames } - # Validate required headers (case-insensitive) - if "question" not in clean_headers or "answer" not in clean_headers: + # Validate exactly 'question' and 'answer' columns (case-insensitive) + if set(clean_headers.keys()) != {"question", "answer"}: + extra = set(clean_headers.keys()) - {"question", "answer"} + missing = {"question", "answer"} - set(clean_headers.keys()) + parts = [] + if missing: + parts.append(f"Missing: {sorted(missing)}") + if extra: + parts.append(f"Unexpected: {sorted(extra)}") raise HTTPException( status_code=422, - detail=f"CSV must contain 'question' and 'answer' columns " - f"Found columns: {csv_reader.fieldnames}", + detail=f"CSV must contain exactly 'question' and 'answer' columns. " + f"{'. '.join(parts)}. Found columns: {csv_reader.fieldnames}", ) question_col = clean_headers["question"] diff --git a/backend/app/tests/crud/evaluations/test_langfuse.py b/backend/app/tests/crud/evaluations/test_langfuse.py index 128135f39..4cc1183ca 100644 --- a/backend/app/tests/crud/evaluations/test_langfuse.py +++ b/backend/app/tests/crud/evaluations/test_langfuse.py @@ -382,7 +382,7 @@ def test_update_traces_with_cosine_scores_success(self) -> None: calls = mock_langfuse.score.call_args_list assert calls[0].kwargs["trace_id"] == "trace_1" - assert calls[0].kwargs["name"] == "cosine_similarity" + assert calls[0].kwargs["name"] == "Cosine Similarity" assert calls[0].kwargs["value"] == 0.95 assert "cosine similarity" in calls[0].kwargs["comment"].lower() diff --git a/backend/app/tests/crud/evaluations/test_processing.py b/backend/app/tests/crud/evaluations/test_processing.py index 8d34900bc..29d62244d 100644 --- a/backend/app/tests/crud/evaluations/test_processing.py +++ b/backend/app/tests/crud/evaluations/test_processing.py @@ -580,7 +580,7 @@ async def test_process_completed_embedding_batch_success( assert "summary_scores" in result.score summary_scores = result.score["summary_scores"] cosine_score = next( - (s for s in summary_scores if s["name"] == "cosine_similarity"), None + (s for s in summary_scores if s["name"] == "Cosine Similarity"), None ) assert cosine_score is not None assert cosine_score["avg"] == 0.95