From f03e860425ac2263a9feccd554ce598974dec25a Mon Sep 17 00:00:00 2001 From: AkhileshNegi Date: Sat, 14 Mar 2026 12:52:11 +0530 Subject: [PATCH 1/8] added signed url for tts dataset --- .../app/api/routes/tts_evaluations/dataset.py | 31 +++++++++++++++++-- backend/app/crud/tts_evaluations/dataset.py | 11 +++---- backend/app/models/tts_evaluation.py | 9 +++++- 3 files changed, 41 insertions(+), 10 deletions(-) diff --git a/backend/app/api/routes/tts_evaluations/dataset.py b/backend/app/api/routes/tts_evaluations/dataset.py index c08d4e68e..2005a4dc3 100644 --- a/backend/app/api/routes/tts_evaluations/dataset.py +++ b/backend/app/api/routes/tts_evaluations/dataset.py @@ -6,6 +6,7 @@ from app.api.deps import AuthContextDep, SessionDep from app.api.permissions import Permission, require_permission +from app.core.cloud import get_cloud_storage from app.crud.language import get_language_by_id from app.crud.tts_evaluations import ( get_tts_dataset_by_id, @@ -71,6 +72,9 @@ def list_datasets( auth_context: AuthContextDep, limit: int = Query(50, ge=1, le=100, description="Maximum results to return"), offset: int = Query(0, ge=0, description="Number of results to skip"), + include_signed_url: bool = Query( + False, description="Include signed URL for dataset files" + ), ) -> APIResponse[list[TTSDatasetPublic]]: """List TTS evaluation datasets.""" datasets, total = list_tts_datasets( @@ -81,8 +85,21 @@ def list_datasets( offset=offset, ) + storage = None + if include_signed_url: + storage = get_cloud_storage( + session=session, project_id=auth_context.project_.id + ) + + data = [] + for dataset in datasets: + signed_url = None + if storage and dataset.object_store_url: + signed_url = storage.get_signed_url(dataset.object_store_url) + data.append(TTSDatasetPublic.from_model(dataset, signed_url=signed_url)) + return APIResponse.success_response( - data=datasets, + data=data, metadata={"total": total, "limit": limit, "offset": offset}, ) @@ -98,6 +115,9 @@ def get_dataset( session: SessionDep, auth_context: AuthContextDep, dataset_id: int, + include_signed_url: bool = Query( + False, description="Include signed URL for dataset file" + ), ) -> APIResponse[TTSDatasetPublic]: """Get a TTS evaluation dataset.""" dataset = get_tts_dataset_by_id( @@ -110,8 +130,15 @@ def get_dataset( if not dataset: raise HTTPException(status_code=404, detail="Dataset not found") + signed_url = None + if include_signed_url and dataset.object_store_url: + storage = get_cloud_storage( + session=session, project_id=auth_context.project_.id + ) + signed_url = storage.get_signed_url(dataset.object_store_url) + return APIResponse.success_response( - data=TTSDatasetPublic.from_model(dataset), + data=TTSDatasetPublic.from_model(dataset, signed_url=signed_url), metadata={ "sample_count": (dataset.dataset_metadata or {}).get("sample_count", 0) }, diff --git a/backend/app/crud/tts_evaluations/dataset.py b/backend/app/crud/tts_evaluations/dataset.py index e2956885d..c81007e92 100644 --- a/backend/app/crud/tts_evaluations/dataset.py +++ b/backend/app/crud/tts_evaluations/dataset.py @@ -10,7 +10,6 @@ from app.core.util import now from app.models import EvaluationDataset from app.models.stt_evaluation import EvaluationType -from app.models.tts_evaluation import TTSDatasetPublic logger = logging.getLogger(__name__) @@ -121,7 +120,7 @@ def list_tts_datasets( project_id: int, limit: int = 50, offset: int = 0, -) -> tuple[list[TTSDatasetPublic], int]: +) -> tuple[list[EvaluationDataset], int]: """List TTS datasets for a project. Args: @@ -132,7 +131,7 @@ def list_tts_datasets( offset: Number of results to skip Returns: - tuple[list[TTSDatasetPublic], int]: Datasets and total count + tuple[list[EvaluationDataset], int]: Datasets and total count """ base_filter = ( EvaluationDataset.organization_id == org_id, @@ -151,8 +150,6 @@ def list_tts_datasets( .limit(limit) ) - datasets = session.exec(statement).all() + datasets = list(session.exec(statement).all()) - result = [TTSDatasetPublic.from_model(dataset) for dataset in datasets] - - return result, total + return datasets, total diff --git a/backend/app/models/tts_evaluation.py b/backend/app/models/tts_evaluation.py index 3ee796759..e1222ab97 100644 --- a/backend/app/models/tts_evaluation.py +++ b/backend/app/models/tts_evaluation.py @@ -179,6 +179,7 @@ class TTSDatasetPublic(BaseModel): type: str language_id: int | None object_store_url: str | None + signed_url: str | None = None dataset_metadata: dict[str, Any] organization_id: int project_id: int @@ -186,7 +187,12 @@ class TTSDatasetPublic(BaseModel): updated_at: datetime @classmethod - def from_model(cls, dataset: EvaluationDataset) -> TTSDatasetPublic: + def from_model( + cls, + dataset: EvaluationDataset, + *, + signed_url: str | None = None, + ) -> TTSDatasetPublic: """Create from an EvaluationDataset model instance.""" return cls( id=dataset.id, @@ -195,6 +201,7 @@ def from_model(cls, dataset: EvaluationDataset) -> TTSDatasetPublic: type=dataset.type, language_id=dataset.language_id, object_store_url=dataset.object_store_url, + signed_url=signed_url, dataset_metadata=dataset.dataset_metadata, organization_id=dataset.organization_id, project_id=dataset.project_id, From f9ecdb7e0117ed798c0f1481663f1e8fe9e065b9 Mon Sep 17 00:00:00 2001 From: AkhileshNegi Date: Sat, 14 Mar 2026 13:16:11 +0530 Subject: [PATCH 2/8] added API to update STT samples --- .../api/docs/stt_evaluation/update_sample.md | 3 + .../app/api/routes/stt_evaluations/dataset.py | 53 ++++++++++++ backend/app/crud/stt_evaluations/__init__.py | 4 + backend/app/crud/stt_evaluations/dataset.py | 83 +++++++++++++++++++ backend/app/models/stt_evaluation.py | 7 ++ 5 files changed, 150 insertions(+) create mode 100644 backend/app/api/docs/stt_evaluation/update_sample.md diff --git a/backend/app/api/docs/stt_evaluation/update_sample.md b/backend/app/api/docs/stt_evaluation/update_sample.md new file mode 100644 index 000000000..4360c5695 --- /dev/null +++ b/backend/app/api/docs/stt_evaluation/update_sample.md @@ -0,0 +1,3 @@ +Update an STT sample's language and/or ground truth transcription. + +Only the provided fields will be updated. Fields set to `null` in the request will not modify the existing value. diff --git a/backend/app/api/routes/stt_evaluations/dataset.py b/backend/app/api/routes/stt_evaluations/dataset.py index ea6ee9362..fb18f4843 100644 --- a/backend/app/api/routes/stt_evaluations/dataset.py +++ b/backend/app/api/routes/stt_evaluations/dataset.py @@ -12,13 +12,16 @@ from app.crud.stt_evaluations import ( get_samples_by_dataset_id, get_stt_dataset_by_id, + get_stt_sample_by_id, list_stt_datasets, + update_stt_sample, ) from app.models.stt_evaluation import ( STTDatasetCreate, STTDatasetPublic, STTDatasetWithSamples, STTSamplePublic, + STTSampleUpdate, ) from app.services.stt_evaluations.dataset import upload_stt_dataset from app.utils import APIResponse, load_description @@ -209,3 +212,53 @@ def get_dataset( ), metadata={"samples_total": samples_total}, ) + + +@router.patch( + "/samples/{sample_id}", + response_model=APIResponse[STTSamplePublic], + dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], + summary="Update STT sample", + description=load_description("stt_evaluation/update_sample.md"), +) +def update_sample( + session: SessionDep, + auth_context: AuthContextDep, + sample_id: int, + sample_update: STTSampleUpdate = Body(...), +) -> APIResponse[STTSamplePublic]: + """Update an STT sample's language and/or ground truth.""" + logger.info(f"[update_sample] Updating sample | " f"sample_id: {sample_id}") + + if sample_update.language_id is not None: + language = get_language_by_id( + session=session, language_id=sample_update.language_id + ) + if not language: + raise HTTPException( + status_code=400, detail="Invalid language_id: language not found" + ) + + sample = update_stt_sample( + session=session, + sample_id=sample_id, + org_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + language_id=sample_update.language_id, + ground_truth=sample_update.ground_truth, + ) + + return APIResponse.success_response( + data=STTSamplePublic( + id=sample.id, + file_id=sample.file_id, + language_id=sample.language_id, + ground_truth=sample.ground_truth, + sample_metadata=sample.sample_metadata, + dataset_id=sample.dataset_id, + organization_id=sample.organization_id, + project_id=sample.project_id, + inserted_at=sample.inserted_at, + updated_at=sample.updated_at, + ) + ) diff --git a/backend/app/crud/stt_evaluations/__init__.py b/backend/app/crud/stt_evaluations/__init__.py index 7cc235e65..071d168bf 100644 --- a/backend/app/crud/stt_evaluations/__init__.py +++ b/backend/app/crud/stt_evaluations/__init__.py @@ -6,8 +6,10 @@ create_stt_dataset, create_stt_samples, get_stt_dataset_by_id, + get_stt_sample_by_id, list_stt_datasets, get_samples_by_dataset_id, + update_stt_sample, ) from .run import ( create_stt_run, @@ -30,8 +32,10 @@ "create_stt_dataset", "create_stt_samples", "get_stt_dataset_by_id", + "get_stt_sample_by_id", "list_stt_datasets", "get_samples_by_dataset_id", + "update_stt_sample", # Run "create_stt_run", "get_stt_run_by_id", diff --git a/backend/app/crud/stt_evaluations/dataset.py b/backend/app/crud/stt_evaluations/dataset.py index c30d615d7..665c65239 100644 --- a/backend/app/crud/stt_evaluations/dataset.py +++ b/backend/app/crud/stt_evaluations/dataset.py @@ -204,6 +204,89 @@ def create_stt_samples( return created_samples +def get_stt_sample_by_id( + *, + session: Session, + sample_id: int, + org_id: int, + project_id: int, +) -> STTSample | None: + """Get an STT sample by ID. + + Args: + session: Database session + sample_id: Sample ID + org_id: Organization ID + project_id: Project ID + + Returns: + STTSample | None: Sample if found + """ + statement = select(STTSample).where( + STTSample.id == sample_id, + STTSample.organization_id == org_id, + STTSample.project_id == project_id, + ) + + return session.exec(statement).one_or_none() + + +def update_stt_sample( + *, + session: Session, + sample_id: int, + org_id: int, + project_id: int, + language_id: int | None = None, + ground_truth: str | None = None, +) -> STTSample: + """Update an STT sample's language and/or ground truth. + + Args: + session: Database session + sample_id: Sample ID + org_id: Organization ID + project_id: Project ID + language_id: Optional new language ID + ground_truth: Optional new ground truth transcription + + Returns: + STTSample: Updated sample + + Raises: + HTTPException: If sample not found + """ + sample = get_stt_sample_by_id( + session=session, + sample_id=sample_id, + org_id=org_id, + project_id=project_id, + ) + + if not sample: + raise HTTPException(status_code=404, detail="Sample not found") + + if language_id is not None: + sample.language_id = language_id + + if ground_truth is not None: + sample.ground_truth = ground_truth + + sample.updated_at = now() + + session.add(sample) + session.commit() + session.refresh(sample) + + logger.info( + f"[update_stt_sample] Sample updated | " + f"sample_id: {sample_id}, language_id: {language_id}, " + f"ground_truth_updated: {ground_truth is not None}" + ) + + return sample + + def get_stt_dataset_by_id( *, session: Session, diff --git a/backend/app/models/stt_evaluation.py b/backend/app/models/stt_evaluation.py index 7d0bd75f7..5e953e36d 100644 --- a/backend/app/models/stt_evaluation.py +++ b/backend/app/models/stt_evaluation.py @@ -268,6 +268,13 @@ class STTResultWithSample(STTResultPublic): sample: STTSamplePublic +class STTSampleUpdate(BaseModel): + """Request model for updating an STT sample's language and ground truth.""" + + language_id: int | None = Field(None, description="Language ID for this sample") + ground_truth: str | None = Field(None, description="Reference transcription") + + class STTFeedbackUpdate(BaseModel): """Request model for updating human feedback on a result.""" From 33db59e59ae06457a8725dd083f06fd0a02c45f8 Mon Sep 17 00:00:00 2001 From: AkhileshNegi Date: Sat, 14 Mar 2026 21:17:21 +0530 Subject: [PATCH 3/8] updated TTS feedback results --- .../api/docs/tts_evaluation/update_feedback.md | 17 ++++++++++++++++- .../app/api/routes/tts_evaluations/result.py | 2 ++ backend/app/crud/tts_evaluations/result.py | 3 +++ backend/app/models/tts_evaluation.py | 10 ++++++++++ 4 files changed, 31 insertions(+), 1 deletion(-) diff --git a/backend/app/api/docs/tts_evaluation/update_feedback.md b/backend/app/api/docs/tts_evaluation/update_feedback.md index 7701bb52a..910227f83 100644 --- a/backend/app/api/docs/tts_evaluation/update_feedback.md +++ b/backend/app/api/docs/tts_evaluation/update_feedback.md @@ -1,5 +1,20 @@ -Update human feedback on a TTS synthesis result. +Update human feedback and score on a TTS synthesis result. + +Only the provided fields will be updated. Fields omitted from the request will not modify the existing value. Sending a field as `null` will clear its value. Fields: - **is_correct**: Whether the synthesized audio quality is acceptable (null to clear) - **comment**: Optional feedback comment +- **score**: Evaluation metrics for the synthesized audio + +**Example request:** +```json +{ + "is_correct": true, + "comment": "string", + "score": { + "Speech Naturalness": "low | medium | high", + "Pronunciation Accuracy": "low | medium | high" + } +} +``` diff --git a/backend/app/api/routes/tts_evaluations/result.py b/backend/app/api/routes/tts_evaluations/result.py index 2f1998fd0..92a82b002 100644 --- a/backend/app/api/routes/tts_evaluations/result.py +++ b/backend/app/api/routes/tts_evaluations/result.py @@ -70,6 +70,8 @@ def update_result_feedback( update_kwargs["is_correct"] = feedback.is_correct if "comment" in feedback.model_fields_set: update_kwargs["comment"] = feedback.comment + if "score" in feedback.model_fields_set: + update_kwargs["score"] = feedback.score result = update_tts_human_feedback( session=session, diff --git a/backend/app/crud/tts_evaluations/result.py b/backend/app/crud/tts_evaluations/result.py index 2095f7def..b13b5a62e 100644 --- a/backend/app/crud/tts_evaluations/result.py +++ b/backend/app/crud/tts_evaluations/result.py @@ -237,6 +237,9 @@ def update_tts_human_feedback( if "comment" in kwargs: result.comment = kwargs["comment"] + if "score" in kwargs: + result.score = kwargs["score"] + result.updated_at = now() session.add(result) diff --git a/backend/app/models/tts_evaluation.py b/backend/app/models/tts_evaluation.py index e1222ab97..a0bbdb054 100644 --- a/backend/app/models/tts_evaluation.py +++ b/backend/app/models/tts_evaluation.py @@ -267,6 +267,16 @@ class TTSFeedbackUpdate(BaseModel): None, description="Is the synthesized audio correct?" ) comment: str | None = Field(None, description="Feedback comment") + score: dict[str, Any] | None = Field( + None, + description="Evaluation metrics", + json_schema_extra={ + "example": { + "Speech Naturalness": "low | medium | high", + "Pronunciation Accuracy": "low | medium | high", + } + }, + ) class TTSEvaluationRunCreate(BaseModel): From bdc8133edeb85e597fb81445268f53822e687b0b Mon Sep 17 00:00:00 2001 From: AkhileshNegi Date: Sun, 15 Mar 2026 23:24:41 +0530 Subject: [PATCH 4/8] added description and headers validation --- backend/app/api/routes/evaluations/dataset.py | 1 + backend/app/models/evaluation.py | 1 + backend/app/services/evaluations/validators.py | 15 +++++++++++---- 3 files changed, 13 insertions(+), 4 deletions(-) diff --git a/backend/app/api/routes/evaluations/dataset.py b/backend/app/api/routes/evaluations/dataset.py index a63eeba42..202774da2 100644 --- a/backend/app/api/routes/evaluations/dataset.py +++ b/backend/app/api/routes/evaluations/dataset.py @@ -42,6 +42,7 @@ def _dataset_to_response( return DatasetUploadResponse( dataset_id=dataset.id, dataset_name=dataset.name, + description=dataset.description, total_items=dataset.dataset_metadata.get("total_items_count", 0), original_items=dataset.dataset_metadata.get("original_items_count", 0), duplication_factor=dataset.dataset_metadata.get("duplication_factor", 1), diff --git a/backend/app/models/evaluation.py b/backend/app/models/evaluation.py index 18e7749bf..d2d2beecc 100644 --- a/backend/app/models/evaluation.py +++ b/backend/app/models/evaluation.py @@ -28,6 +28,7 @@ class DatasetUploadResponse(BaseModel): dataset_id: int = Field(..., description="Database ID of the created dataset") dataset_name: str = Field(..., description="Name of the created dataset") + description: str | None = Field(None, description="Description of the dataset") total_items: int = Field( ..., description="Total number of items uploaded (after duplication)" ) diff --git a/backend/app/services/evaluations/validators.py b/backend/app/services/evaluations/validators.py index 92733a2f8..61d0c3b06 100644 --- a/backend/app/services/evaluations/validators.py +++ b/backend/app/services/evaluations/validators.py @@ -137,12 +137,19 @@ def parse_csv_items(csv_content: bytes) -> list[dict[str, str]]: field.strip().lower(): field for field in csv_reader.fieldnames } - # Validate required headers (case-insensitive) - if "question" not in clean_headers or "answer" not in clean_headers: + # Validate exactly 'question' and 'answer' columns (case-insensitive) + if set(clean_headers.keys()) != {"question", "answer"}: + extra = set(clean_headers.keys()) - {"question", "answer"} + missing = {"question", "answer"} - set(clean_headers.keys()) + parts = [] + if missing: + parts.append(f"Missing: {sorted(missing)}") + if extra: + parts.append(f"Unexpected: {sorted(extra)}") raise HTTPException( status_code=422, - detail=f"CSV must contain 'question' and 'answer' columns " - f"Found columns: {csv_reader.fieldnames}", + detail=f"CSV must contain exactly 'question' and 'answer' columns. " + f"{'. '.join(parts)}. Found columns: {csv_reader.fieldnames}", ) question_col = clean_headers["question"] From 029bb487b75d116ce7bff2daab8aca56c6484d9b Mon Sep 17 00:00:00 2001 From: AkhileshNegi Date: Sun, 15 Mar 2026 23:29:33 +0530 Subject: [PATCH 5/8] using cosine similarity naming for consistency --- backend/app/crud/evaluations/langfuse.py | 2 +- backend/app/crud/evaluations/processing.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/backend/app/crud/evaluations/langfuse.py b/backend/app/crud/evaluations/langfuse.py index 477de7e57..1dd0c519e 100644 --- a/backend/app/crud/evaluations/langfuse.py +++ b/backend/app/crud/evaluations/langfuse.py @@ -209,7 +209,7 @@ def update_traces_with_cosine_scores( try: langfuse.score( trace_id=trace_id, - name="cosine_similarity", + name="Cosine Similarity", value=cosine_score, comment=( "Cosine similarity between generated output and " diff --git a/backend/app/crud/evaluations/processing.py b/backend/app/crud/evaluations/processing.py index d33207302..1e36fd13b 100644 --- a/backend/app/crud/evaluations/processing.py +++ b/backend/app/crud/evaluations/processing.py @@ -392,7 +392,7 @@ async def process_completed_embedding_batch( eval_run.score = { "summary_scores": [ { - "name": "cosine_similarity", + "name": "Cosine Similarity", "avg": round(float(similarity_stats["cosine_similarity_avg"]), 2), "std": round(float(similarity_stats["cosine_similarity_std"]), 2), "total_pairs": similarity_stats["total_pairs"], From 05ff2c77d6c7363cd0cf9b820bf86048550e5cad Mon Sep 17 00:00:00 2001 From: AkhileshNegi Date: Tue, 17 Mar 2026 13:21:07 +0530 Subject: [PATCH 6/8] updating testcases --- backend/app/tests/crud/evaluations/test_langfuse.py | 2 +- backend/app/tests/crud/evaluations/test_processing.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/backend/app/tests/crud/evaluations/test_langfuse.py b/backend/app/tests/crud/evaluations/test_langfuse.py index 128135f39..4cc1183ca 100644 --- a/backend/app/tests/crud/evaluations/test_langfuse.py +++ b/backend/app/tests/crud/evaluations/test_langfuse.py @@ -382,7 +382,7 @@ def test_update_traces_with_cosine_scores_success(self) -> None: calls = mock_langfuse.score.call_args_list assert calls[0].kwargs["trace_id"] == "trace_1" - assert calls[0].kwargs["name"] == "cosine_similarity" + assert calls[0].kwargs["name"] == "Cosine Similarity" assert calls[0].kwargs["value"] == 0.95 assert "cosine similarity" in calls[0].kwargs["comment"].lower() diff --git a/backend/app/tests/crud/evaluations/test_processing.py b/backend/app/tests/crud/evaluations/test_processing.py index 8d34900bc..29d62244d 100644 --- a/backend/app/tests/crud/evaluations/test_processing.py +++ b/backend/app/tests/crud/evaluations/test_processing.py @@ -580,7 +580,7 @@ async def test_process_completed_embedding_batch_success( assert "summary_scores" in result.score summary_scores = result.score["summary_scores"] cosine_score = next( - (s for s in summary_scores if s["name"] == "cosine_similarity"), None + (s for s in summary_scores if s["name"] == "Cosine Similarity"), None ) assert cosine_score is not None assert cosine_score["avg"] == 0.95 From ad328d5aa179944212617e3186c66d9205d45016 Mon Sep 17 00:00:00 2001 From: AkhileshNegi Date: Wed, 18 Mar 2026 12:45:54 +0530 Subject: [PATCH 7/8] moved it language exception to language module --- backend/app/api/routes/languages.py | 5 +++-- .../app/api/routes/stt_evaluations/dataset.py | 16 ++-------------- .../app/api/routes/tts_evaluations/dataset.py | 8 +------- backend/app/crud/language.py | 12 +++++++++--- 4 files changed, 15 insertions(+), 26 deletions(-) diff --git a/backend/app/api/routes/languages.py b/backend/app/api/routes/languages.py index f04896b94..eb0d06042 100644 --- a/backend/app/api/routes/languages.py +++ b/backend/app/api/routes/languages.py @@ -37,8 +37,9 @@ def get_language(session: SessionDep, auth_context: AuthContextDep, language_id: """ Retrieve a language by ID. """ - language = get_language_by_id(session=session, language_id=language_id) - if language is None: + try: + language = get_language_by_id(session=session, language_id=language_id) + except HTTPException: logger.error(f"[get_language] Language not found | language_id={language_id}") raise HTTPException(status_code=404, detail="Language not found") return APIResponse.success_response(language) diff --git a/backend/app/api/routes/stt_evaluations/dataset.py b/backend/app/api/routes/stt_evaluations/dataset.py index fb18f4843..368d8587c 100644 --- a/backend/app/api/routes/stt_evaluations/dataset.py +++ b/backend/app/api/routes/stt_evaluations/dataset.py @@ -46,13 +46,7 @@ def create_dataset( """Create an STT evaluation dataset.""" # Validate language_id if dataset_create.language_id is not None: - language = get_language_by_id( - session=session, language_id=dataset_create.language_id - ) - if not language: - raise HTTPException( - status_code=400, detail="Invalid language_id: language not found" - ) + get_language_by_id(session=session, language_id=dataset_create.language_id) dataset, samples = upload_stt_dataset( session=session, @@ -231,13 +225,7 @@ def update_sample( logger.info(f"[update_sample] Updating sample | " f"sample_id: {sample_id}") if sample_update.language_id is not None: - language = get_language_by_id( - session=session, language_id=sample_update.language_id - ) - if not language: - raise HTTPException( - status_code=400, detail="Invalid language_id: language not found" - ) + get_language_by_id(session=session, language_id=sample_update.language_id) sample = update_stt_sample( session=session, diff --git a/backend/app/api/routes/tts_evaluations/dataset.py b/backend/app/api/routes/tts_evaluations/dataset.py index 2005a4dc3..115df8fb5 100644 --- a/backend/app/api/routes/tts_evaluations/dataset.py +++ b/backend/app/api/routes/tts_evaluations/dataset.py @@ -39,13 +39,7 @@ def create_dataset( """Create a TTS evaluation dataset.""" # Validate language_id if provided if dataset_create.language_id is not None: - language = get_language_by_id( - session=session, language_id=dataset_create.language_id - ) - if not language: - raise HTTPException( - status_code=400, detail="Invalid language_id: language not found" - ) + get_language_by_id(session=session, language_id=dataset_create.language_id) dataset = upload_tts_dataset( session=session, diff --git a/backend/app/crud/language.py b/backend/app/crud/language.py index eb4f7d273..9585587e2 100644 --- a/backend/app/crud/language.py +++ b/backend/app/crud/language.py @@ -1,6 +1,7 @@ import logging from typing import Optional +from fastapi import HTTPException from sqlmodel import Session, select from app.models import Language @@ -16,10 +17,15 @@ def get_languages(session: Session, skip: int = 0, limit: int = 100) -> list[Lan return list(session.exec(statement).all()) -def get_language_by_id(session: Session, language_id: int) -> Optional[Language]: - """Retrieve a language by its ID.""" +def get_language_by_id(session: Session, language_id: int) -> Language: + """Retrieve a language by its ID. Raises HTTPException if not found.""" statement = select(Language).where(Language.id == language_id) - return session.exec(statement).first() + language = session.exec(statement).first() + if not language: + raise HTTPException( + status_code=400, detail="Invalid language_id: language not found" + ) + return language def get_language_by_locale(session: Session, locale: str) -> Optional[Language]: From 586f728cc05cd9a3620ff90c744efeca2de6f45a Mon Sep 17 00:00:00 2001 From: AkhileshNegi Date: Wed, 18 Mar 2026 12:53:33 +0530 Subject: [PATCH 8/8] code cleanups --- backend/app/api/routes/languages.py | 13 +++++++------ backend/app/api/routes/stt_evaluations/dataset.py | 5 +++-- backend/app/crud/language.py | 12 ++++++++---- backend/app/crud/stt_evaluations/dataset.py | 12 ++++-------- 4 files changed, 22 insertions(+), 20 deletions(-) diff --git a/backend/app/api/routes/languages.py b/backend/app/api/routes/languages.py index eb0d06042..9a184ea81 100644 --- a/backend/app/api/routes/languages.py +++ b/backend/app/api/routes/languages.py @@ -1,6 +1,6 @@ import logging -from fastapi import APIRouter, HTTPException +from fastapi import APIRouter from app.api.deps import AuthContextDep, SessionDep from app.crud.language import get_language_by_id, get_languages @@ -37,9 +37,10 @@ def get_language(session: SessionDep, auth_context: AuthContextDep, language_id: """ Retrieve a language by ID. """ - try: - language = get_language_by_id(session=session, language_id=language_id) - except HTTPException: - logger.error(f"[get_language] Language not found | language_id={language_id}") - raise HTTPException(status_code=404, detail="Language not found") + language = get_language_by_id( + session=session, + language_id=language_id, + status_code=404, + detail="Language not found", + ) return APIResponse.success_response(language) diff --git a/backend/app/api/routes/stt_evaluations/dataset.py b/backend/app/api/routes/stt_evaluations/dataset.py index 368d8587c..b42a4b17e 100644 --- a/backend/app/api/routes/stt_evaluations/dataset.py +++ b/backend/app/api/routes/stt_evaluations/dataset.py @@ -12,7 +12,6 @@ from app.crud.stt_evaluations import ( get_samples_by_dataset_id, get_stt_dataset_by_id, - get_stt_sample_by_id, list_stt_datasets, update_stt_sample, ) @@ -162,7 +161,6 @@ def get_dataset( session=session, project_id=auth_context.project_.id ) - samples = [] for s in sample_records: signed_url = None if storage and s.file_id in file_map: @@ -236,6 +234,9 @@ def update_sample( ground_truth=sample_update.ground_truth, ) + if not sample: + raise HTTPException(status_code=404, detail="Sample not found") + return APIResponse.success_response( data=STTSamplePublic( id=sample.id, diff --git a/backend/app/crud/language.py b/backend/app/crud/language.py index 9585587e2..765e50542 100644 --- a/backend/app/crud/language.py +++ b/backend/app/crud/language.py @@ -17,14 +17,18 @@ def get_languages(session: Session, skip: int = 0, limit: int = 100) -> list[Lan return list(session.exec(statement).all()) -def get_language_by_id(session: Session, language_id: int) -> Language: +def get_language_by_id( + session: Session, + language_id: int, + *, + status_code: int = 400, + detail: str = "Invalid language_id: language not found", +) -> Language: """Retrieve a language by its ID. Raises HTTPException if not found.""" statement = select(Language).where(Language.id == language_id) language = session.exec(statement).first() if not language: - raise HTTPException( - status_code=400, detail="Invalid language_id: language not found" - ) + raise HTTPException(status_code=status_code, detail=detail) return language diff --git a/backend/app/crud/stt_evaluations/dataset.py b/backend/app/crud/stt_evaluations/dataset.py index 665c65239..2c559fb10 100644 --- a/backend/app/crud/stt_evaluations/dataset.py +++ b/backend/app/crud/stt_evaluations/dataset.py @@ -239,7 +239,7 @@ def update_stt_sample( project_id: int, language_id: int | None = None, ground_truth: str | None = None, -) -> STTSample: +) -> STTSample | None: """Update an STT sample's language and/or ground truth. Args: @@ -251,10 +251,7 @@ def update_stt_sample( ground_truth: Optional new ground truth transcription Returns: - STTSample: Updated sample - - Raises: - HTTPException: If sample not found + STTSample | None: Updated sample, or None if not found """ sample = get_stt_sample_by_id( session=session, @@ -264,7 +261,7 @@ def update_stt_sample( ) if not sample: - raise HTTPException(status_code=404, detail="Sample not found") + return None if language_id is not None: sample.language_id = language_id @@ -275,8 +272,7 @@ def update_stt_sample( sample.updated_at = now() session.add(sample) - session.commit() - session.refresh(sample) + session.flush() logger.info( f"[update_stt_sample] Sample updated | "