From a424d9ae8dd610ca0120f1fa3b23431b69bf4c2d Mon Sep 17 00:00:00 2001 From: Prashant Vasudevan <71649489+vprashrex@users.noreply.github.com> Date: Mon, 9 Mar 2026 16:12:39 +0530 Subject: [PATCH 1/7] Enhance dataset API: Add support for signed URLs and improve validation error handling --- .../app/api/docs/evaluation/get_dataset.md | 2 + backend/app/api/routes/evaluations/dataset.py | 20 ++++- backend/app/core/exception_handlers.py | 74 +++++++++++++++++-- backend/app/models/evaluation.py | 3 + backend/app/utils.py | 30 +++++++- 5 files changed, 116 insertions(+), 13 deletions(-) diff --git a/backend/app/api/docs/evaluation/get_dataset.md b/backend/app/api/docs/evaluation/get_dataset.md index a1a27276a..17530148e 100644 --- a/backend/app/api/docs/evaluation/get_dataset.md +++ b/backend/app/api/docs/evaluation/get_dataset.md @@ -1,3 +1,5 @@ Get details of a specific dataset by ID. Returns comprehensive dataset information including metadata (ID, name, item counts, duplication factor), Langfuse integration details (dataset ID), and the object store URL for the CSV file. + +Use `include_signed_url=true` to include a time-limited signed URL for downloading the dataset CSV directly from object storage. The signed URL expires after 1 hour. 
diff --git a/backend/app/api/routes/evaluations/dataset.py b/backend/app/api/routes/evaluations/dataset.py index 1ce42742a..c5b2015b1 100644 --- a/backend/app/api/routes/evaluations/dataset.py +++ b/backend/app/api/routes/evaluations/dataset.py @@ -14,6 +14,7 @@ from app.api.deps import AuthContextDep, SessionDep from app.api.permissions import Permission, require_permission +from app.core.cloud import get_cloud_storage from app.crud.evaluations import ( get_dataset_by_id, list_datasets as list_evaluation_datasets, @@ -34,7 +35,9 @@ router = APIRouter(prefix="/evaluations/datasets", tags=["Evaluation"]) -def _dataset_to_response(dataset: EvaluationDataset) -> DatasetUploadResponse: +def _dataset_to_response( + dataset: EvaluationDataset, signed_url: str | None = None +) -> DatasetUploadResponse: """Convert a dataset model to a DatasetUploadResponse.""" return DatasetUploadResponse( dataset_id=dataset.id, @@ -44,6 +47,7 @@ def _dataset_to_response(dataset: EvaluationDataset) -> DatasetUploadResponse: duplication_factor=dataset.dataset_metadata.get("duplication_factor", 1), langfuse_dataset_id=dataset.langfuse_dataset_id, object_store_url=dataset.object_store_url, + signed_url=signed_url, ) @@ -124,6 +128,9 @@ def get_dataset( dataset_id: int, _session: SessionDep, auth_context: AuthContextDep, + include_signed_url: bool = Query( + False, description="Include signed URLs for dataset" + ), ) -> APIResponse[DatasetUploadResponse]: """Get a specific evaluation dataset.""" logger.info( @@ -144,7 +151,16 @@ def get_dataset( status_code=404, detail=f"Dataset {dataset_id} not found or not accessible" ) - return APIResponse.success_response(data=_dataset_to_response(dataset)) + signed_url = None + if include_signed_url and dataset.object_store_url: + storage = get_cloud_storage( + session=_session, project_id=auth_context.project_.id + ) + signed_url = storage.get_signed_url(dataset.object_store_url) + + return APIResponse.success_response( + 
data=_dataset_to_response(dataset, signed_url=signed_url) + ) @router.delete( diff --git a/backend/app/core/exception_handlers.py b/backend/app/core/exception_handlers.py index f6e614f5d..a5ecd78d9 100644 --- a/backend/app/core/exception_handlers.py +++ b/backend/app/core/exception_handlers.py @@ -1,30 +1,90 @@ +import re +from collections import defaultdict + from fastapi import FastAPI, Request, HTTPException -from fastapi.responses import JSONResponse from fastapi.exceptions import RequestValidationError -from app.utils import APIResponse +from fastapi.responses import JSONResponse from starlette.status import ( HTTP_422_UNPROCESSABLE_ENTITY, HTTP_500_INTERNAL_SERVER_ERROR, ) +from app.utils import APIResponse + +_BRANCH_PATTERN = re.compile(r"^[A-Z]|[\[\]()]") + + +def _is_branch_identifier(part: str) -> bool: + return bool(part and isinstance(part, str) and _BRANCH_PATTERN.search(part)) + + +def _filter_union_branch_errors(errors: list[dict]) -> list[dict]: + try: + branch_errors: dict[str, dict[str, list[dict]]] = defaultdict( + lambda: defaultdict(list) + ) + non_union_errors: list[dict] = [] + + for err in errors: + loc = err.get("loc", ()) + branch_name = None + parent_field = None + for i, part in enumerate(loc): + if _is_branch_identifier(part): + branch_name = part + parent_field = loc[:i] if i > 0 else ("root",) + break + + if branch_name and parent_field: + branch_errors[str(parent_field)][branch_name].append(err) + else: + non_union_errors.append(err) + + filtered = list(non_union_errors) + for _parent, branches in branch_errors.items(): + if len(branches) <= 1: + for errs in branches.values(): + filtered.extend(errs) + else: + best_branch = min( + branches.items(), + key=lambda item: ( + sum(1 for e in item[1] if e.get("type") == "literal_error"), + ), + ) + filtered.extend(best_branch[1]) + + for err in filtered: + loc = err.get("loc", ()) + err["loc"] = tuple(p for p in loc if not _is_branch_identifier(p)) + + return filtered or errors + except 
Exception: + return errors + -def register_exception_handlers(app: FastAPI): +def register_exception_handlers(app: FastAPI) -> None: @app.exception_handler(RequestValidationError) - async def validation_error_handler(request: Request, exc: RequestValidationError): + async def validation_error_handler( + request: Request, exc: RequestValidationError + ) -> JSONResponse: + errors = _filter_union_branch_errors(exc.errors()) return JSONResponse( status_code=HTTP_422_UNPROCESSABLE_ENTITY, - content=APIResponse.failure_response(exc.errors()).model_dump(), + content=APIResponse.failure_response(errors).model_dump(), ) @app.exception_handler(HTTPException) - async def http_exception_handler(request: Request, exc: HTTPException): + async def http_exception_handler( + request: Request, exc: HTTPException + ) -> JSONResponse: return JSONResponse( status_code=exc.status_code, content=APIResponse.failure_response(exc.detail).model_dump(), ) @app.exception_handler(Exception) - async def generic_error_handler(request: Request, exc: Exception): + async def generic_error_handler(request: Request, exc: Exception) -> JSONResponse: return JSONResponse( status_code=HTTP_500_INTERNAL_SERVER_ERROR, content=APIResponse.failure_response( diff --git a/backend/app/models/evaluation.py b/backend/app/models/evaluation.py index 8944779d3..18e7749bf 100644 --- a/backend/app/models/evaluation.py +++ b/backend/app/models/evaluation.py @@ -43,6 +43,9 @@ class DatasetUploadResponse(BaseModel): object_store_url: str | None = Field( None, description="Object store URL if uploaded" ) + signed_url: str | None = Field( + None, description="A signed URL for downloading the dataset" + ) class EvaluationResult(BaseModel): diff --git a/backend/app/utils.py b/backend/app/utils.py index 9c1be2a11..9a6659abe 100644 --- a/backend/app/utils.py +++ b/backend/app/utils.py @@ -44,10 +44,16 @@ T = TypeVar("T") +class ValidationErrorDetail(BaseModel): + field: str + message: str + + class APIResponse(BaseModel, 
Generic[T]): success: bool data: Optional[T] = None error: Optional[str] = None + errors: Optional[list[ValidationErrorDetail]] = None metadata: Optional[Dict[str, Any]] = None @classmethod @@ -64,11 +70,27 @@ def failure_response( metadata: Optional[Dict[str, Any]] = None, ) -> "APIResponse[None]": if isinstance(error, list): # to handle cases when error is a list of errors - error_message = "\n".join([f"{err['loc']}: {err['msg']}" for err in error]) - else: - error_message = error + structured_errors = [] + for err in error: + loc = err.get("loc", ()) + parts = [str(p) for p in loc if p != "body"] + field = ".".join(parts) if parts else "unknown" + structured_errors.append( + ValidationErrorDetail( + field=str(field), message=str(err.get("msg", "")) + ) + ) + + return cls( + success=False, + data=data, + error="Validation failed", + errors=structured_errors, + metadata=metadata, + ) - return cls(success=False, data=data, error=error_message, metadata=metadata) + else: + return cls(success=False, data=data, error=error, metadata=metadata) @dataclass From 75478641c64e72bcb280928b8474d6789aed0c9c Mon Sep 17 00:00:00 2001 From: Prashant Vasudevan <71649489+vprashrex@users.noreply.github.com> Date: Mon, 9 Mar 2026 16:43:05 +0530 Subject: [PATCH 2/7] Refactor validation error assertions in tests to use structured error format --- .../collections/test_create_collections.py | 6 +- .../api/routes/test_doc_transformation_job.py | 45 ++- .../app/tests/api/routes/test_evaluation.py | 15 +- .../app/tests/api/routes/test_onboarding.py | 42 +- backend/app/tests/api/routes/test_threads.py | 3 +- .../app/tests/core/test_exception_handlers.py | 362 ++++++++++++++++++ 6 files changed, 431 insertions(+), 42 deletions(-) create mode 100644 backend/app/tests/core/test_exception_handlers.py diff --git a/backend/app/tests/api/routes/collections/test_create_collections.py b/backend/app/tests/api/routes/collections/test_create_collections.py index 12db54ff3..66775bd8b 100644 --- 
a/backend/app/tests/api/routes/collections/test_create_collections.py +++ b/backend/app/tests/api/routes/collections/test_create_collections.py @@ -124,7 +124,9 @@ def test_collection_creation_vector_only_request_validation_error( assert body["success"] is False assert body["data"] is None assert body["metadata"] is None - assert ( + assert body["errors"] + assert any( "To create an Assistant, provide BOTH 'model' and 'instructions'" - in body["error"] + in e["message"] + for e in body["errors"] ) diff --git a/backend/app/tests/api/routes/test_doc_transformation_job.py b/backend/app/tests/api/routes/test_doc_transformation_job.py index cea74ff75..298f73ec0 100644 --- a/backend/app/tests/api/routes/test_doc_transformation_job.py +++ b/backend/app/tests/api/routes/test_doc_transformation_job.py @@ -55,7 +55,9 @@ def test_get_job_invalid_uuid_422( ) assert resp.status_code == 422 body = resp.json() - assert "error" in body and "valid UUID" in body["error"] + assert body["errors"] and any( + "valid UUID" in e["message"] for e in body["errors"] + ) def test_get_job_different_project_404( self, @@ -207,9 +209,12 @@ def test_get_jobs_with_empty_string( body = response.json() assert body["success"] is False assert body["data"] is None - assert "error" in body - assert "valid UUID" in body["error"] or "expected length" in body["error"] - assert "job_ids" in body["error"] + assert body["errors"] + assert any( + "valid UUID" in e["message"] or "expected length" in e["message"] + for e in body["errors"] + ) + assert any("job_ids" in e["field"] for e in body["errors"]) def test_get_jobs_with_whitespace_only( self, client: TestClient, user_api_key: TestAuthContext @@ -224,8 +229,8 @@ def test_get_jobs_with_whitespace_only( body = response.json() assert body["success"] is False assert body["data"] is None - assert "error" in body - assert "valid UUID" in body["error"] + assert body["errors"] + assert any("valid UUID" in e["message"] for e in body["errors"]) def 
test_get_jobs_invalid_uuid_format_422( self, client: TestClient, user_api_key: TestAuthContext @@ -242,9 +247,12 @@ def test_get_jobs_invalid_uuid_format_422( body = response.json() assert body["success"] is False assert body["data"] is None - assert "error" in body - assert "valid UUID" in body["error"] or "expected length" in body["error"] - assert "job_ids" in body["error"] + assert body["errors"] + assert any( + "valid UUID" in e["message"] or "expected length" in e["message"] + for e in body["errors"] + ) + assert any("job_ids" in e["field"] for e in body["errors"]) def test_get_jobs_mixed_valid_invalid_uuid_422( self, client: TestClient, db: Session, user_api_key: TestAuthContext @@ -266,12 +274,13 @@ def test_get_jobs_mixed_valid_invalid_uuid_422( body = response.json() assert body["success"] is False assert body["data"] is None - assert "error" in body - assert "job_ids" in body["error"] - assert ( - "valid UUID" in body["error"] - or "invalid character" in body["error"] - or "invalid length" in body["error"] + assert body["errors"] + assert any("job_ids" in e["field"] for e in body["errors"]) + assert any( + "valid UUID" in e["message"] + or "invalid character" in e["message"] + or "invalid length" in e["message"] + for e in body["errors"] ) def test_get_jobs_missing_parameter_422( @@ -287,9 +296,9 @@ def test_get_jobs_missing_parameter_422( body = response.json() assert body["success"] is False assert body["data"] is None - assert "error" in body - assert "Field required" in body["error"] - assert "job_ids" in body["error"] + assert body["errors"] + assert any("Field required" in e["message"] for e in body["errors"]) + assert any("job_ids" in e["field"] for e in body["errors"]) def test_get_jobs_different_project_not_found( self, diff --git a/backend/app/tests/api/routes/test_evaluation.py b/backend/app/tests/api/routes/test_evaluation.py index 0129fa44a..7a49fe80b 100644 --- a/backend/app/tests/api/routes/test_evaluation.py +++ 
b/backend/app/tests/api/routes/test_evaluation.py @@ -347,9 +347,11 @@ def test_upload_with_duplication_factor_below_minimum( assert response.status_code == 422 response_data = response.json() - # Check that the error mentions validation and minimum value - assert "error" in response_data - assert "greater than or equal to 1" in response_data["error"] + assert response_data["errors"] + assert any( + "greater than or equal to 1" in e["message"] + for e in response_data["errors"] + ) def test_upload_with_duplication_factor_above_maximum( self, @@ -372,9 +374,10 @@ def test_upload_with_duplication_factor_above_maximum( assert response.status_code == 422 response_data = response.json() - # Check that the error mentions validation and maximum value - assert "error" in response_data - assert "less than or equal to 5" in response_data["error"] + assert response_data["errors"] + assert any( + "less than or equal to 5" in e["message"] for e in response_data["errors"] + ) def test_upload_with_duplication_factor_boundary_minimum( self, diff --git a/backend/app/tests/api/routes/test_onboarding.py b/backend/app/tests/api/routes/test_onboarding.py index 0b3f14971..3a6f13b76 100644 --- a/backend/app/tests/api/routes/test_onboarding.py +++ b/backend/app/tests/api/routes/test_onboarding.py @@ -197,8 +197,10 @@ def test_onboard_project_invalid_provider( assert response.status_code == 422 error_response = response.json() - assert "error" in error_response - assert "credential validation failed" in error_response["error"] + assert error_response["errors"] + assert any( + "credential validation failed" in e["message"] for e in error_response["errors"] + ) def test_onboard_project_non_dict_values_in_credential( @@ -227,9 +229,13 @@ def test_onboard_project_non_dict_values_in_credential( assert response.status_code == 422 error_response = response.json() - assert "error" in error_response - assert "credential validation failed" in error_response["error"] - assert "must be an object/dict" 
in error_response["error"] + assert error_response["errors"] + assert any( + "credential validation failed" in e["message"] for e in error_response["errors"] + ) + assert any( + "must be an object/dict" in e["message"] for e in error_response["errors"] + ) def test_onboard_project_missing_required_fields_for_openai( @@ -258,9 +264,11 @@ def test_onboard_project_missing_required_fields_for_openai( assert response.status_code == 422 error_response = response.json() - assert "error" in error_response - assert "credential validation failed" in error_response["error"] - assert "openai" in error_response["error"] + assert error_response["errors"] + assert any( + "credential validation failed" in e["message"] for e in error_response["errors"] + ) + assert any("openai" in e["message"] for e in error_response["errors"]) def test_onboard_project_missing_required_fields_for_langfuse( @@ -291,9 +299,11 @@ def test_onboard_project_missing_required_fields_for_langfuse( assert response.status_code == 422 error_response = response.json() - assert "error" in error_response - assert "credential validation failed" in error_response["error"] - assert "langfuse" in error_response["error"] + assert error_response["errors"] + assert any( + "credential validation failed" in e["message"] for e in error_response["errors"] + ) + assert any("langfuse" in e["message"] for e in error_response["errors"]) def test_onboard_project_aggregates_multiple_credential_errors( @@ -325,7 +335,9 @@ def test_onboard_project_aggregates_multiple_credential_errors( assert response.status_code == 422 error_response = response.json() - assert "error" in error_response - assert "credential validation failed" in error_response["error"] - assert "[0]" in error_response["error"] - assert "[1]" in error_response["error"] + assert error_response["errors"] + assert any( + "credential validation failed" in e["message"] for e in error_response["errors"] + ) + assert any("[0]" in e["message"] for e in 
error_response["errors"]) + assert any("[1]" in e["message"] for e in error_response["errors"]) diff --git a/backend/app/tests/api/routes/test_threads.py b/backend/app/tests/api/routes/test_threads.py index e9a11fad8..9a1f297a6 100644 --- a/backend/app/tests/api/routes/test_threads.py +++ b/backend/app/tests/api/routes/test_threads.py @@ -601,4 +601,5 @@ def test_threads_start_missing_question( assert response.status_code == 422 # Unprocessable Entity (FastAPI will raise 422) error_response = response.json() assert error_response["success"] is False - assert "question" in error_response["error"] + assert error_response["errors"] + assert any("question" in e["field"] for e in error_response["errors"]) diff --git a/backend/app/tests/core/test_exception_handlers.py b/backend/app/tests/core/test_exception_handlers.py new file mode 100644 index 000000000..9ed3fad9e --- /dev/null +++ b/backend/app/tests/core/test_exception_handlers.py @@ -0,0 +1,362 @@ +from unittest.mock import patch + +from fastapi.testclient import TestClient +from sqlmodel import Session + +from app.core.config import settings +from app.core.exception_handlers import _filter_union_branch_errors +from app.tests.utils.auth import TestAuthContext +from app.tests.utils.test_data import create_test_evaluation_dataset + + +# --------------------------------------------------------------------------- +# Unit tests for _filter_union_branch_errors +# --------------------------------------------------------------------------- + + +class TestFilterUnionBranchErrors: + """Unit tests for the discriminated union branch error filter.""" + + def test_no_union_errors_returned_unchanged(self) -> None: + """Non-union errors pass through unchanged.""" + errors = [ + {"type": "missing", "loc": ("body", "name"), "msg": "Field required"}, + { + "type": "missing", + "loc": ("body", "config_blob"), + "msg": "Field required", + }, + ] + result = _filter_union_branch_errors(errors) + assert result == errors + + def 
test_single_branch_errors_passed_through(self) -> None: + """When only one branch has errors it is included without filtering.""" + errors = [ + { + "type": "missing", + "loc": ("body", "completion", "KaapiCompletionConfig", "type"), + "msg": "Field required", + } + ] + result = _filter_union_branch_errors(errors) + assert len(result) == 1 + # Branch identifier stripped from loc + assert "KaapiCompletionConfig" not in result[0]["loc"] + + def test_picks_branch_with_fewer_literal_errors(self) -> None: + """When multiple branches exist, the one with fewer literal_errors wins.""" + errors = [ + # NativeCompletionConfig branch — provider literal_error (wrong value) + { + "type": "literal_error", + "loc": ("body", "completion", "NativeCompletionConfig", "provider"), + "msg": "Input should be 'openai-native'", + }, + { + "type": "missing", + "loc": ("body", "completion", "NativeCompletionConfig", "params"), + "msg": "Field required", + }, + # KaapiCompletionConfig branch — no literal_error (provider matched) + { + "type": "missing", + "loc": ("body", "completion", "KaapiCompletionConfig", "type"), + "msg": "Field required", + }, + { + "type": "missing", + "loc": ("body", "completion", "KaapiCompletionConfig", "params"), + "msg": "Field required", + }, + ] + result = _filter_union_branch_errors(errors) + # Only KaapiCompletionConfig errors should remain + assert len(result) == 2 + for err in result: + assert "NativeCompletionConfig" not in err["loc"] + assert "KaapiCompletionConfig" not in err["loc"] + + def test_branch_identifiers_stripped_from_loc(self) -> None: + """Branch class names and pydantic internals are removed from loc tuples.""" + errors = [ + { + "type": "missing", + "loc": ( + "body", + "config_blob", + "completion", + "function-after[validate_params(), KaapiCompletionConfig]", + "params", + ), + "msg": "Field required", + } + ] + result = _filter_union_branch_errors(errors) + assert len(result) == 1 + loc = result[0]["loc"] + assert 
"function-after[validate_params(), KaapiCompletionConfig]" not in loc + assert loc == ("body", "config_blob", "completion", "params") + + def test_non_union_errors_preserved_alongside_union_errors(self) -> None: + """Top-level field errors coexist with filtered union branch errors.""" + errors = [ + # Top-level missing field (not a union branch error) + {"type": "missing", "loc": ("body", "name"), "msg": "Field required"}, + # Union branch errors + { + "type": "literal_error", + "loc": ("body", "completion", "NativeCompletionConfig", "provider"), + "msg": "Input should be 'openai-native'", + }, + { + "type": "missing", + "loc": ("body", "completion", "KaapiCompletionConfig", "type"), + "msg": "Field required", + }, + ] + result = _filter_union_branch_errors(errors) + # name error + KaapiCompletionConfig error + assert len(result) == 2 + locs = [r["loc"] for r in result] + assert ("body", "name") in locs + + def test_empty_errors_list(self) -> None: + """Empty list returns empty list without raising.""" + assert _filter_union_branch_errors([]) == [] + + def test_fallback_on_malformed_input(self) -> None: + """Malformed errors are returned as-is via the try/except fallback.""" + # Passing non-dict items — should not raise, returns original list + malformed = [None, 42] # type: ignore[list-item] + result = _filter_union_branch_errors(malformed) + assert result == malformed + + +# --------------------------------------------------------------------------- +# Integration tests — validation error response format via API +# --------------------------------------------------------------------------- + + +class TestValidationErrorResponseFormat: + """Test that the structured errors array is returned correctly by the API.""" + + def test_missing_required_field_returns_structured_errors( + self, + client: TestClient, + user_api_key: TestAuthContext, + ) -> None: + """Missing required field returns {field, message} structured error.""" + # config_blob is present but name is 
missing + response = client.post( + f"{settings.API_V1_STR}/configs/", + headers={"X-API-KEY": user_api_key.key}, + json={ + "config_blob": { + "completion": { + "provider": "openai", + "type": "text", + "params": {"model": "gpt-4o-mini"}, + } + } + }, + ) + assert response.status_code == 422 + body = response.json() + assert body["success"] is False + assert body["error"] == "Validation failed" + assert body["errors"] is not None + assert isinstance(body["errors"], list) + + fields = [e["field"] for e in body["errors"]] + assert "name" in fields + + name_error = next(e for e in body["errors"] if e["field"] == "name") + assert "required" in name_error["message"].lower() + + def test_union_branch_noise_not_in_response( + self, + client: TestClient, + user_api_key: TestAuthContext, + ) -> None: + """NativeCompletionConfig errors must not appear when using openai provider.""" + response = client.post( + f"{settings.API_V1_STR}/configs/", + headers={"X-API-KEY": user_api_key.key}, + json={ + "name": "test-config", + "config_blob": { + "completion": { + "provider": "openai", + # type and params are intentionally missing to trigger errors + } + }, + }, + ) + assert response.status_code == 422 + body = response.json() + assert body["errors"] is not None + + # No NativeCompletionConfig literal errors should be in the response + for error in body["errors"]: + assert "openai-native" not in error["message"] + assert "NativeCompletionConfig" not in error["field"] + + def test_nested_field_path_in_error( + self, + client: TestClient, + user_api_key: TestAuthContext, + ) -> None: + """Field path shows full dotted path, not just the last segment.""" + response = client.post( + f"{settings.API_V1_STR}/configs/", + headers={"X-API-KEY": user_api_key.key}, + json={ + "name": "test-config", + "config_blob": { + "completion": { + "provider": "openai", + "type": "text", + # params missing — error should show config_blob.completion.params + } + }, + }, + ) + assert response.status_code 
== 422 + body = response.json() + fields = [e["field"] for e in body["errors"]] + # Should show full path, not just "params" + assert any("." in f for f in fields) + assert any("params" in f for f in fields) + + def test_error_response_structure( + self, + client: TestClient, + user_api_key: TestAuthContext, + ) -> None: + """Validation error response always has success=False, error summary, and errors array.""" + response = client.post( + f"{settings.API_V1_STR}/configs/", + headers={"X-API-KEY": user_api_key.key}, + json={}, + ) + assert response.status_code == 422 + body = response.json() + assert body["success"] is False + assert body["data"] is None + assert body["error"] == "Validation failed" + assert isinstance(body["errors"], list) + assert len(body["errors"]) > 0 + for err in body["errors"]: + assert "field" in err + assert "message" in err + + +# --------------------------------------------------------------------------- +# Integration tests — dataset signed URL +# --------------------------------------------------------------------------- + + +class TestDatasetSignedUrl: + """Test GET /evaluations/datasets/{id} signed URL feature.""" + + def test_get_dataset_without_signed_url( + self, + db: Session, + client: TestClient, + user_api_key: TestAuthContext, + ) -> None: + """By default signed_url is not included in the response.""" + dataset = create_test_evaluation_dataset( + db=db, + organization_id=user_api_key.organization_id, + project_id=user_api_key.project_id, + ) + + response = client.get( + f"{settings.API_V1_STR}/evaluations/datasets/{dataset.id}", + headers={"X-API-KEY": user_api_key.key}, + ) + + assert response.status_code == 200 + body = response.json() + assert body["success"] is True + assert body["data"]["dataset_id"] == dataset.id + assert body["data"].get("signed_url") is None + + def test_get_dataset_with_signed_url( + self, + db: Session, + client: TestClient, + user_api_key: TestAuthContext, + ) -> None: + """include_signed_url=true 
returns a presigned URL.""" + dataset = create_test_evaluation_dataset( + db=db, + organization_id=user_api_key.organization_id, + project_id=user_api_key.project_id, + ) + + mock_signed_url = "https://s3.amazonaws.com/bucket/key?X-Amz-Signature=abc123" + + with patch( + "app.api.routes.evaluations.dataset.get_cloud_storage" + ) as mock_get_storage: + mock_storage = mock_get_storage.return_value + mock_storage.get_signed_url.return_value = mock_signed_url + + response = client.get( + f"{settings.API_V1_STR}/evaluations/datasets/{dataset.id}", + headers={"X-API-KEY": user_api_key.key}, + params={"include_signed_url": True}, + ) + + assert response.status_code == 200 + body = response.json() + assert body["success"] is True + assert body["data"]["signed_url"] == mock_signed_url + + def test_get_dataset_signed_url_none_when_no_object_store_url( + self, + db: Session, + client: TestClient, + user_api_key: TestAuthContext, + ) -> None: + """signed_url is None when dataset has no object_store_url.""" + dataset = create_test_evaluation_dataset( + db=db, + organization_id=user_api_key.organization_id, + project_id=user_api_key.project_id, + ) + # Ensure no object_store_url + dataset.object_store_url = None + db.add(dataset) + db.commit() + + with patch( + "app.api.routes.evaluations.dataset.get_cloud_storage" + ) as mock_get_storage: + response = client.get( + f"{settings.API_V1_STR}/evaluations/datasets/{dataset.id}", + headers={"X-API-KEY": user_api_key.key}, + params={"include_signed_url": True}, + ) + mock_get_storage.assert_not_called() + + assert response.status_code == 200 + body = response.json() + assert body["data"].get("signed_url") is None + + def test_get_dataset_not_found( + self, + client: TestClient, + user_api_key: TestAuthContext, + ) -> None: + """Non-existent dataset returns 404.""" + response = client.get( + f"{settings.API_V1_STR}/evaluations/datasets/999999", + headers={"X-API-KEY": user_api_key.key}, + ) + assert response.status_code == 404 + body = 
response.json() + assert body["success"] is False From a912a3e5061897b9436c4ce20e3953103b84376c Mon Sep 17 00:00:00 2001 From: Prashant Vasudevan <71649489+vprashrex@users.noreply.github.com> Date: Mon, 9 Mar 2026 17:53:25 +0530 Subject: [PATCH 3/7] Refactor error filtering in _filter_union_branch_errors to improve uniqueness and handling of literal errors --- backend/app/core/exception_handlers.py | 30 ++++++++++++++++++++------ 1 file changed, 23 insertions(+), 7 deletions(-) diff --git a/backend/app/core/exception_handlers.py b/backend/app/core/exception_handlers.py index a5ecd78d9..c507fa977 100644 --- a/backend/app/core/exception_handlers.py +++ b/backend/app/core/exception_handlers.py @@ -19,6 +19,10 @@ def _is_branch_identifier(part: str) -> bool: def _filter_union_branch_errors(errors: list[dict]) -> list[dict]: +    """When a field is a Union type, pydantic returns errors for every possible branch. + + This picks the branch where the validation error happened. + """ try: branch_errors: dict[str, dict[str, list[dict]]] = defaultdict( lambda: defaultdict(list) @@ -46,19 +50,31 @@ def _filter_union_branch_errors(errors: list[dict]) -> list[dict]: for errs in branches.values(): filtered.extend(errs) else: - best_branch = min( - branches.items(), - key=lambda item: ( - sum(1 for e in item[1] if e.get("type") == "literal_error"), - ), + # NOTE: Keep all branches tied for fewest literal errors + best_count = min( + sum(1 for e in errs if e.get("type") == "literal_error") + for errs in branches.values() ) - filtered.extend(best_branch[1]) + for errs in branches.values(): + if ( + sum(1 for e in errs if e.get("type") == "literal_error") + <= best_count + ): + filtered.extend(errs) for err in filtered: loc = err.get("loc", ()) err["loc"] = tuple(p for p in loc if not _is_branch_identifier(p)) - return filtered or errors + seen_errors: set[tuple] = set() + unique_errors: list[dict] = [] + for error in filtered: + error_key = (tuple(error.get("loc", ())), error.get("msg", 
"")) + if error_key not in seen_errors: + seen_errors.add(error_key) + unique_errors.append(error) + + return unique_errors or errors except Exception: return errors From eea7a4875d9cfd4d1cd54c9267d9ca9e17bcd178 Mon Sep 17 00:00:00 2001 From: Prashant Vasudevan <71649489+vprashrex@users.noreply.github.com> Date: Mon, 9 Mar 2026 18:08:55 +0530 Subject: [PATCH 4/7] Add tests for dataset signed URL handling and validation error responses --- .../app/tests/api/routes/test_evaluation.py | 65 ++++ .../app/tests/core/test_exception_handlers.py | 314 +++--------------- 2 files changed, 111 insertions(+), 268 deletions(-) diff --git a/backend/app/tests/api/routes/test_evaluation.py b/backend/app/tests/api/routes/test_evaluation.py index 7a49fe80b..4b751a59a 100644 --- a/backend/app/tests/api/routes/test_evaluation.py +++ b/backend/app/tests/api/routes/test_evaluation.py @@ -1379,6 +1379,71 @@ def test_get_dataset_not_found( ) assert "not found" in error_str.lower() or "not accessible" in error_str.lower() + def test_default_no_signed_url( + self, + db: Session, + client: TestClient, + user_api_key: TestAuthContext, + ) -> None: + """Test that signed_url is not included by default.""" + dataset = create_test_evaluation_dataset( + db=db, + organization_id=user_api_key.organization_id, + project_id=user_api_key.project_id, + ) + response = client.get( + f"/api/v1/evaluations/datasets/{dataset.id}", + headers={"X-API-KEY": user_api_key.key}, + ) + assert response.status_code == 200 + assert response.json()["data"].get("signed_url") is None + + def test_include_signed_url( + self, + db: Session, + client: TestClient, + user_api_key: TestAuthContext, + ) -> None: + """Test that signed_url is returned when requested.""" + dataset = create_test_evaluation_dataset( + db=db, + organization_id=user_api_key.organization_id, + project_id=user_api_key.project_id, + ) + with patch("app.api.routes.evaluations.dataset.get_cloud_storage") as mock: + 
mock.return_value.get_signed_url.return_value = "https://signed.url" + response = client.get( + f"/api/v1/evaluations/datasets/{dataset.id}", + headers={"X-API-KEY": user_api_key.key}, + params={"include_signed_url": True}, + ) + assert response.json()["data"]["signed_url"] == "https://signed.url" + + def test_no_object_store_url_skips_signing( + self, + db: Session, + client: TestClient, + user_api_key: TestAuthContext, + ) -> None: + """Test that signing is skipped when dataset has no object_store_url.""" + dataset = create_test_evaluation_dataset( + db=db, + organization_id=user_api_key.organization_id, + project_id=user_api_key.project_id, + ) + dataset.object_store_url = None + db.add(dataset) + db.commit() + + with patch("app.api.routes.evaluations.dataset.get_cloud_storage") as mock: + response = client.get( + f"/api/v1/evaluations/datasets/{dataset.id}", + headers={"X-API-KEY": user_api_key.key}, + params={"include_signed_url": True}, + ) + mock.assert_not_called() + assert response.json()["data"].get("signed_url") is None + class TestDeleteDataset: """Test DELETE /evaluations/datasets/{dataset_id} endpoint.""" diff --git a/backend/app/tests/core/test_exception_handlers.py b/backend/app/tests/core/test_exception_handlers.py index 9ed3fad9e..8971f1b47 100644 --- a/backend/app/tests/core/test_exception_handlers.py +++ b/backend/app/tests/core/test_exception_handlers.py @@ -1,362 +1,140 @@ -from unittest.mock import patch - from fastapi.testclient import TestClient -from sqlmodel import Session from app.core.config import settings from app.core.exception_handlers import _filter_union_branch_errors from app.tests.utils.auth import TestAuthContext -from app.tests.utils.test_data import create_test_evaluation_dataset - - -# --------------------------------------------------------------------------- -# Unit tests for _filter_union_branch_errors -# --------------------------------------------------------------------------- class TestFilterUnionBranchErrors: - 
"""Unit tests for the discriminated union branch error filter.""" + """Unit tests for _filter_union_branch_errors.""" - def test_no_union_errors_returned_unchanged(self) -> None: - """Non-union errors pass through unchanged.""" + def test_non_union_errors_pass_through(self) -> None: errors = [ {"type": "missing", "loc": ("body", "name"), "msg": "Field required"}, - { - "type": "missing", - "loc": ("body", "config_blob"), - "msg": "Field required", - }, - ] - result = _filter_union_branch_errors(errors) - assert result == errors - - def test_single_branch_errors_passed_through(self) -> None: - """When only one branch has errors it is included without filtering.""" - errors = [ - { - "type": "missing", - "loc": ("body", "completion", "KaapiCompletionConfig", "type"), - "msg": "Field required", - } ] - result = _filter_union_branch_errors(errors) - assert len(result) == 1 - # Branch identifier stripped from loc - assert "KaapiCompletionConfig" not in result[0]["loc"] + assert _filter_union_branch_errors(errors) == errors def test_picks_branch_with_fewer_literal_errors(self) -> None: - """When multiple branches exist, the one with fewer literal_errors wins.""" errors = [ - # NativeCompletionConfig branch — provider literal_error (wrong value) { "type": "literal_error", - "loc": ("body", "completion", "NativeCompletionConfig", "provider"), - "msg": "Input should be 'openai-native'", + "loc": ("body", "c", "NativeConfig", "provider"), + "msg": "bad", }, { "type": "missing", - "loc": ("body", "completion", "NativeCompletionConfig", "params"), + "loc": ("body", "c", "NativeConfig", "params"), "msg": "Field required", }, - # KaapiCompletionConfig branch — no literal_error (provider matched) { "type": "missing", - "loc": ("body", "completion", "KaapiCompletionConfig", "type"), + "loc": ("body", "c", "KaapiConfig", "type"), "msg": "Field required", }, { "type": "missing", - "loc": ("body", "completion", "KaapiCompletionConfig", "params"), + "loc": ("body", "c", "KaapiConfig", 
"params"), "msg": "Field required", }, ] result = _filter_union_branch_errors(errors) - # Only KaapiCompletionConfig errors should remain assert len(result) == 2 for err in result: - assert "NativeCompletionConfig" not in err["loc"] - assert "KaapiCompletionConfig" not in err["loc"] + assert "NativeConfig" not in err["loc"] + + def test_tied_branches_keep_both_and_dedup(self) -> None: + """When two branches have the same literal_error count, both are kept but duplicates removed.""" + errors = [ + { + "type": "missing", + "loc": ("body", "c", "BranchA", "x"), + "msg": "Field required", + }, + { + "type": "missing", + "loc": ("body", "c", "BranchB", "x"), + "msg": "Field required", + }, + ] + result = _filter_union_branch_errors(errors) + assert len(result) == 1 + assert result[0]["loc"] == ("body", "c", "x") - def test_branch_identifiers_stripped_from_loc(self) -> None: - """Branch class names and pydantic internals are removed from loc tuples.""" + def test_strips_branch_identifiers_from_loc(self) -> None: errors = [ { "type": "missing", "loc": ( "body", - "config_blob", + "cfg", "completion", - "function-after[validate_params(), KaapiCompletionConfig]", + "function-after[validate_params(), Foo]", "params", ), "msg": "Field required", } ] result = _filter_union_branch_errors(errors) - assert len(result) == 1 - loc = result[0]["loc"] - assert "function-after[validate_params(), KaapiCompletionConfig]" not in loc - assert loc == ("body", "config_blob", "completion", "params") + assert result[0]["loc"] == ("body", "cfg", "completion", "params") - def test_non_union_errors_preserved_alongside_union_errors(self) -> None: - """Top-level field errors coexist with filtered union branch errors.""" + def test_non_union_preserved_with_union(self) -> None: errors = [ - # Top-level missing field (not a union branch error) {"type": "missing", "loc": ("body", "name"), "msg": "Field required"}, - # Union branch errors { "type": "literal_error", - "loc": ("body", "completion", 
"NativeCompletionConfig", "provider"), - "msg": "Input should be 'openai-native'", + "loc": ("body", "c", "NativeConfig", "p"), + "msg": "bad", }, { "type": "missing", - "loc": ("body", "completion", "KaapiCompletionConfig", "type"), + "loc": ("body", "c", "KaapiConfig", "t"), "msg": "Field required", }, ] result = _filter_union_branch_errors(errors) - # name error + KaapiCompletionConfig error assert len(result) == 2 locs = [r["loc"] for r in result] assert ("body", "name") in locs - def test_empty_errors_list(self) -> None: - """Empty list returns empty list without raising.""" + def test_empty_list(self) -> None: assert _filter_union_branch_errors([]) == [] def test_fallback_on_malformed_input(self) -> None: - """Malformed errors are returned as-is via the try/except fallback.""" - # Passing non-dict items — should not raise, returns original list malformed = [None, 42] # type: ignore[list-item] result = _filter_union_branch_errors(malformed) assert result == malformed -# --------------------------------------------------------------------------- -# Integration tests — validation error response format via API -# --------------------------------------------------------------------------- - - -class TestValidationErrorResponseFormat: - """Test that the structured errors array is returned correctly by the API.""" +class TestValidationErrorResponse: + """Integration: structured errors via configs endpoint.""" - def test_missing_required_field_returns_structured_errors( - self, - client: TestClient, - user_api_key: TestAuthContext, + def test_structured_error_format( + self, client: TestClient, user_api_key: TestAuthContext ) -> None: - """Missing required field returns {field, message} structured error.""" - # config_blob is present but name is missing response = client.post( f"{settings.API_V1_STR}/configs/", headers={"X-API-KEY": user_api_key.key}, - json={ - "config_blob": { - "completion": { - "provider": "openai", - "type": "text", - "params": {"model": 
"gpt-4o-mini"}, - } - } - }, + json={}, ) assert response.status_code == 422 body = response.json() assert body["success"] is False assert body["error"] == "Validation failed" - assert body["errors"] is not None assert isinstance(body["errors"], list) + assert all("field" in e and "message" in e for e in body["errors"]) - fields = [e["field"] for e in body["errors"]] - assert "name" in fields - - name_error = next(e for e in body["errors"] if e["field"] == "name") - assert "required" in name_error["message"].lower() - - def test_union_branch_noise_not_in_response( - self, - client: TestClient, - user_api_key: TestAuthContext, + def test_union_noise_filtered( + self, client: TestClient, user_api_key: TestAuthContext ) -> None: - """NativeCompletionConfig errors must not appear when using openai provider.""" response = client.post( f"{settings.API_V1_STR}/configs/", headers={"X-API-KEY": user_api_key.key}, json={ "name": "test-config", - "config_blob": { - "completion": { - "provider": "openai", - # type and params are intentionally missing to trigger errors - } - }, + "config_blob": {"completion": {"provider": "openai"}}, }, ) assert response.status_code == 422 - body = response.json() - assert body["errors"] is not None - - # No NativeCompletionConfig literal errors should be in the response - for error in body["errors"]: + for error in response.json()["errors"]: assert "openai-native" not in error["message"] assert "NativeCompletionConfig" not in error["field"] - - def test_nested_field_path_in_error( - self, - client: TestClient, - user_api_key: TestAuthContext, - ) -> None: - """Field path shows full dotted path, not just the last segment.""" - response = client.post( - f"{settings.API_V1_STR}/configs/", - headers={"X-API-KEY": user_api_key.key}, - json={ - "name": "test-config", - "config_blob": { - "completion": { - "provider": "openai", - "type": "text", - # params missing — error should show config_blob.completion.params - } - }, - }, - ) - assert 
response.status_code == 422 - body = response.json() - fields = [e["field"] for e in body["errors"]] - # Should show full path, not just "params" - assert any("." in f for f in fields) - assert any("params" in f for f in fields) - - def test_error_response_structure( - self, - client: TestClient, - user_api_key: TestAuthContext, - ) -> None: - """Validation error response always has success=False, error summary, and errors array.""" - response = client.post( - f"{settings.API_V1_STR}/configs/", - headers={"X-API-KEY": user_api_key.key}, - json={}, - ) - assert response.status_code == 422 - body = response.json() - assert body["success"] is False - assert body["data"] is None - assert body["error"] == "Validation failed" - assert isinstance(body["errors"], list) - assert len(body["errors"]) > 0 - for err in body["errors"]: - assert "field" in err - assert "message" in err - - -# --------------------------------------------------------------------------- -# Integration tests — dataset signed URL -# --------------------------------------------------------------------------- - - -class TestDatasetSignedUrl: - """Test GET /evaluations/datasets/{id} signed URL feature.""" - - def test_get_dataset_without_signed_url( - self, - db: Session, - client: TestClient, - user_api_key: TestAuthContext, - ) -> None: - """By default signed_url is not included in the response.""" - dataset = create_test_evaluation_dataset( - db=db, - organization_id=user_api_key.organization_id, - project_id=user_api_key.project_id, - ) - - response = client.get( - f"{settings.API_V1_STR}/evaluations/datasets/{dataset.id}", - headers={"X-API-KEY": user_api_key.key}, - ) - - assert response.status_code == 200 - body = response.json() - assert body["success"] is True - assert body["data"]["dataset_id"] == dataset.id - assert body["data"].get("signed_url") is None - - def test_get_dataset_with_signed_url( - self, - db: Session, - client: TestClient, - user_api_key: TestAuthContext, - ) -> None: - 
"""include_signed_url=true returns a presigned URL.""" - dataset = create_test_evaluation_dataset( - db=db, - organization_id=user_api_key.organization_id, - project_id=user_api_key.project_id, - ) - - mock_signed_url = "https://s3.amazonaws.com/bucket/key?X-Amz-Signature=abc123" - - with patch( - "app.api.routes.evaluations.dataset.get_cloud_storage" - ) as mock_get_storage: - mock_storage = mock_get_storage.return_value - mock_storage.get_signed_url.return_value = mock_signed_url - - response = client.get( - f"{settings.API_V1_STR}/evaluations/datasets/{dataset.id}", - headers={"X-API-KEY": user_api_key.key}, - params={"include_signed_url": True}, - ) - - assert response.status_code == 200 - body = response.json() - assert body["success"] is True - assert body["data"]["signed_url"] == mock_signed_url - - def test_get_dataset_signed_url_none_when_no_object_store_url( - self, - db: Session, - client: TestClient, - user_api_key: TestAuthContext, - ) -> None: - """signed_url is None when dataset has no object_store_url.""" - dataset = create_test_evaluation_dataset( - db=db, - organization_id=user_api_key.organization_id, - project_id=user_api_key.project_id, - ) - # Ensure no object_store_url - dataset.object_store_url = None - db.add(dataset) - db.commit() - - with patch( - "app.api.routes.evaluations.dataset.get_cloud_storage" - ) as mock_get_storage: - response = client.get( - f"{settings.API_V1_STR}/evaluations/datasets/{dataset.id}", - headers={"X-API-KEY": user_api_key.key}, - params={"include_signed_url": True}, - ) - mock_get_storage.assert_not_called() - - assert response.status_code == 200 - body = response.json() - assert body["data"].get("signed_url") is None - - def test_get_dataset_not_found( - self, - client: TestClient, - user_api_key: TestAuthContext, - ) -> None: - """Non-existent dataset returns 404.""" - response = client.get( - f"{settings.API_V1_STR}/evaluations/datasets/999999", - headers={"X-API-KEY": user_api_key.key}, - ) - assert 
response.status_code == 404 - body = response.json() - assert body["success"] is False From 6769017d144fd5614bdcf2bb13bd3d7a3b53af5a Mon Sep 17 00:00:00 2001 From: Prashant Vasudevan <71649489+vprashrex@users.noreply.github.com> Date: Tue, 10 Mar 2026 00:52:36 +0530 Subject: [PATCH 5/7] Fix: Correct session parameter in get_signed_url call for dataset retrieval --- backend/app/api/routes/evaluations/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/app/api/routes/evaluations/dataset.py b/backend/app/api/routes/evaluations/dataset.py index 582c2b3df..9bf5e1e44 100644 --- a/backend/app/api/routes/evaluations/dataset.py +++ b/backend/app/api/routes/evaluations/dataset.py @@ -154,7 +154,7 @@ def get_dataset( signed_url = None if include_signed_url and dataset.object_store_url: storage = get_cloud_storage( - session=_session, project_id=auth_context.project_.id + session=session, project_id=auth_context.project_.id ) signed_url = storage.get_signed_url(dataset.object_store_url) From 634658dddd7352004b02ed8a2f3a903f1bcd5b2e Mon Sep 17 00:00:00 2001 From: Prashant Vasudevan <71649489+vprashrex@users.noreply.github.com> Date: Wed, 11 Mar 2026 11:42:18 +0530 Subject: [PATCH 6/7] change URLs to URL --- backend/app/api/routes/evaluations/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/app/api/routes/evaluations/dataset.py b/backend/app/api/routes/evaluations/dataset.py index 9bf5e1e44..a63eeba42 100644 --- a/backend/app/api/routes/evaluations/dataset.py +++ b/backend/app/api/routes/evaluations/dataset.py @@ -129,7 +129,7 @@ def get_dataset( session: SessionDep, auth_context: AuthContextDep, include_signed_url: bool = Query( - False, description="Include signed URLs for dataset" + False, description="Include signed URL for dataset" ), ) -> APIResponse[DatasetUploadResponse]: """Get a specific evaluation dataset.""" From d8432e1a48d61dec34374427a8ce6561fcd6f3a7 Mon Sep 17 00:00:00 2001 From: Prashant 
Vasudevan <71649489+vprashrex@users.noreply.github.com> Date: Wed, 11 Mar 2026 13:41:39 +0530 Subject: [PATCH 7/7] Fix: Clarify docstring for _filter_union_branch_errors function --- backend/app/core/exception_handlers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/app/core/exception_handlers.py b/backend/app/core/exception_handlers.py index c507fa977..1ac5053e8 100644 --- a/backend/app/core/exception_handlers.py +++ b/backend/app/core/exception_handlers.py @@ -21,7 +21,7 @@ def _is_branch_identifier(part: str) -> bool: def _filter_union_branch_errors(errors: list[dict]) -> list[dict]: """When a field is a Union type, pydantic returns errors for every possible branch. - This picks the branch where the validation error happend. + This function picks the branch where the validation error happened. """ try: branch_errors: dict[str, dict[str, list[dict]]] = defaultdict(