diff --git a/backend/app/api/docs/evaluation/get_dataset.md b/backend/app/api/docs/evaluation/get_dataset.md index a1a27276a..17530148e 100644 --- a/backend/app/api/docs/evaluation/get_dataset.md +++ b/backend/app/api/docs/evaluation/get_dataset.md @@ -1,3 +1,5 @@ Get details of a specific dataset by ID. Returns comprehensive dataset information including metadata (ID, name, item counts, duplication factor), Langfuse integration details (dataset ID), and the object store URL for the CSV file. + +Use `include_signed_url=true` to include a time-limited signed URL for downloading the dataset CSV directly from object storage. The signed URL expires after 1 hour. diff --git a/backend/app/api/routes/evaluations/dataset.py b/backend/app/api/routes/evaluations/dataset.py index 558a92f2c..a63eeba42 100644 --- a/backend/app/api/routes/evaluations/dataset.py +++ b/backend/app/api/routes/evaluations/dataset.py @@ -14,6 +14,7 @@ from app.api.deps import AuthContextDep, SessionDep from app.api.permissions import Permission, require_permission +from app.core.cloud import get_cloud_storage from app.crud.evaluations import ( get_dataset_by_id, list_datasets as list_evaluation_datasets, @@ -34,7 +35,9 @@ router = APIRouter(prefix="/evaluations/datasets", tags=["Evaluation"]) -def _dataset_to_response(dataset: EvaluationDataset) -> DatasetUploadResponse: +def _dataset_to_response( + dataset: EvaluationDataset, signed_url: str | None = None +) -> DatasetUploadResponse: """Convert a dataset model to a DatasetUploadResponse.""" return DatasetUploadResponse( dataset_id=dataset.id, @@ -44,6 +47,7 @@ def _dataset_to_response(dataset: EvaluationDataset) -> DatasetUploadResponse: duplication_factor=dataset.dataset_metadata.get("duplication_factor", 1), langfuse_dataset_id=dataset.langfuse_dataset_id, object_store_url=dataset.object_store_url, + signed_url=signed_url, ) @@ -124,6 +128,9 @@ def get_dataset( dataset_id: int, session: SessionDep, auth_context: AuthContextDep, + 
include_signed_url: bool = Query( + False, description="Include signed URL for dataset" + ), ) -> APIResponse[DatasetUploadResponse]: """Get a specific evaluation dataset.""" logger.info( @@ -144,7 +151,16 @@ def get_dataset( status_code=404, detail=f"Dataset {dataset_id} not found or not accessible" ) - return APIResponse.success_response(data=_dataset_to_response(dataset)) + signed_url = None + if include_signed_url and dataset.object_store_url: + storage = get_cloud_storage( + session=session, project_id=auth_context.project_.id + ) + signed_url = storage.get_signed_url(dataset.object_store_url) + + return APIResponse.success_response( + data=_dataset_to_response(dataset, signed_url=signed_url) + ) @router.delete( diff --git a/backend/app/core/exception_handlers.py b/backend/app/core/exception_handlers.py index f6e614f5d..1ac5053e8 100644 --- a/backend/app/core/exception_handlers.py +++ b/backend/app/core/exception_handlers.py @@ -1,30 +1,106 @@ +import re +from collections import defaultdict + from fastapi import FastAPI, Request, HTTPException -from fastapi.responses import JSONResponse from fastapi.exceptions import RequestValidationError -from app.utils import APIResponse +from fastapi.responses import JSONResponse from starlette.status import ( HTTP_422_UNPROCESSABLE_ENTITY, HTTP_500_INTERNAL_SERVER_ERROR, ) +from app.utils import APIResponse + +_BRANCH_PATTERN = re.compile(r"^[A-Z]|[\[\]()]") + + +def _is_branch_identifier(part: str) -> bool: + return bool(part and isinstance(part, str) and _BRANCH_PATTERN.search(part)) + + +def _filter_union_branch_errors(errors: list[dict]) -> list[dict]: + """When a field is a Union type, pydantic returns errors for every possible branch. + + This function picks the branch where the validation error happened. 
+ """ + try: + branch_errors: dict[str, dict[str, list[dict]]] = defaultdict( + lambda: defaultdict(list) + ) + non_union_errors: list[dict] = [] + + for err in errors: + loc = err.get("loc", ()) + branch_name = None + parent_field = None + for i, part in enumerate(loc): + if _is_branch_identifier(part): + branch_name = part + parent_field = loc[:i] if i > 0 else ("root",) + break + + if branch_name and parent_field: + branch_errors[str(parent_field)][branch_name].append(err) + else: + non_union_errors.append(err) + + filtered = list(non_union_errors) + for _parent, branches in branch_errors.items(): + if len(branches) <= 1: + for errs in branches.values(): + filtered.extend(errs) + else: + # NOTE: Keep all branches tied for fewest literal errors + best_count = min( + sum(1 for e in errs if e.get("type") == "literal_error") + for errs in branches.values() + ) + for errs in branches.values(): + if ( + sum(1 for e in errs if e.get("type") == "literal_error") + <= best_count + ): + filtered.extend(errs) + + for err in filtered: + loc = err.get("loc", ()) + err["loc"] = tuple(p for p in loc if not _is_branch_identifier(p)) + + seen_errors: set[tuple] = set() + unique_errors: list[dict] = [] + for error in filtered: + error_key = (tuple(error.get("loc", ())), error.get("msg", "")) + if error_key not in seen_errors: + seen_errors.add(error_key) + unique_errors.append(error) + + return unique_errors or errors + except Exception: + return errors + -def register_exception_handlers(app: FastAPI): +def register_exception_handlers(app: FastAPI) -> None: @app.exception_handler(RequestValidationError) - async def validation_error_handler(request: Request, exc: RequestValidationError): + async def validation_error_handler( + request: Request, exc: RequestValidationError + ) -> JSONResponse: + errors = _filter_union_branch_errors(exc.errors()) return JSONResponse( status_code=HTTP_422_UNPROCESSABLE_ENTITY, - content=APIResponse.failure_response(exc.errors()).model_dump(), + 
content=APIResponse.failure_response(errors).model_dump(), ) @app.exception_handler(HTTPException) - async def http_exception_handler(request: Request, exc: HTTPException): + async def http_exception_handler( + request: Request, exc: HTTPException + ) -> JSONResponse: return JSONResponse( status_code=exc.status_code, content=APIResponse.failure_response(exc.detail).model_dump(), ) @app.exception_handler(Exception) - async def generic_error_handler(request: Request, exc: Exception): + async def generic_error_handler(request: Request, exc: Exception) -> JSONResponse: return JSONResponse( status_code=HTTP_500_INTERNAL_SERVER_ERROR, content=APIResponse.failure_response( diff --git a/backend/app/models/evaluation.py b/backend/app/models/evaluation.py index 8944779d3..18e7749bf 100644 --- a/backend/app/models/evaluation.py +++ b/backend/app/models/evaluation.py @@ -43,6 +43,9 @@ class DatasetUploadResponse(BaseModel): object_store_url: str | None = Field( None, description="Object store URL if uploaded" ) + signed_url: str | None = Field( + None, description="A signed URL for downloading the dataset" + ) class EvaluationResult(BaseModel): diff --git a/backend/app/tests/api/routes/collections/test_create_collections.py b/backend/app/tests/api/routes/collections/test_create_collections.py index 12db54ff3..66775bd8b 100644 --- a/backend/app/tests/api/routes/collections/test_create_collections.py +++ b/backend/app/tests/api/routes/collections/test_create_collections.py @@ -124,7 +124,9 @@ def test_collection_creation_vector_only_request_validation_error( assert body["success"] is False assert body["data"] is None assert body["metadata"] is None - assert ( + assert body["errors"] + assert any( "To create an Assistant, provide BOTH 'model' and 'instructions'" - in body["error"] + in e["message"] + for e in body["errors"] ) diff --git a/backend/app/tests/api/routes/test_doc_transformation_job.py b/backend/app/tests/api/routes/test_doc_transformation_job.py index 
cea74ff75..298f73ec0 100644 --- a/backend/app/tests/api/routes/test_doc_transformation_job.py +++ b/backend/app/tests/api/routes/test_doc_transformation_job.py @@ -55,7 +55,9 @@ def test_get_job_invalid_uuid_422( ) assert resp.status_code == 422 body = resp.json() - assert "error" in body and "valid UUID" in body["error"] + assert body["errors"] and any( + "valid UUID" in e["message"] for e in body["errors"] + ) def test_get_job_different_project_404( self, @@ -207,9 +209,12 @@ def test_get_jobs_with_empty_string( body = response.json() assert body["success"] is False assert body["data"] is None - assert "error" in body - assert "valid UUID" in body["error"] or "expected length" in body["error"] - assert "job_ids" in body["error"] + assert body["errors"] + assert any( + "valid UUID" in e["message"] or "expected length" in e["message"] + for e in body["errors"] + ) + assert any("job_ids" in e["field"] for e in body["errors"]) def test_get_jobs_with_whitespace_only( self, client: TestClient, user_api_key: TestAuthContext @@ -224,8 +229,8 @@ def test_get_jobs_with_whitespace_only( body = response.json() assert body["success"] is False assert body["data"] is None - assert "error" in body - assert "valid UUID" in body["error"] + assert body["errors"] + assert any("valid UUID" in e["message"] for e in body["errors"]) def test_get_jobs_invalid_uuid_format_422( self, client: TestClient, user_api_key: TestAuthContext @@ -242,9 +247,12 @@ def test_get_jobs_invalid_uuid_format_422( body = response.json() assert body["success"] is False assert body["data"] is None - assert "error" in body - assert "valid UUID" in body["error"] or "expected length" in body["error"] - assert "job_ids" in body["error"] + assert body["errors"] + assert any( + "valid UUID" in e["message"] or "expected length" in e["message"] + for e in body["errors"] + ) + assert any("job_ids" in e["field"] for e in body["errors"]) def test_get_jobs_mixed_valid_invalid_uuid_422( self, client: TestClient, db: 
Session, user_api_key: TestAuthContext @@ -266,12 +274,13 @@ def test_get_jobs_mixed_valid_invalid_uuid_422( body = response.json() assert body["success"] is False assert body["data"] is None - assert "error" in body - assert "job_ids" in body["error"] - assert ( - "valid UUID" in body["error"] - or "invalid character" in body["error"] - or "invalid length" in body["error"] + assert body["errors"] + assert any("job_ids" in e["field"] for e in body["errors"]) + assert any( + "valid UUID" in e["message"] + or "invalid character" in e["message"] + or "invalid length" in e["message"] + for e in body["errors"] ) def test_get_jobs_missing_parameter_422( @@ -287,9 +296,9 @@ def test_get_jobs_missing_parameter_422( body = response.json() assert body["success"] is False assert body["data"] is None - assert "error" in body - assert "Field required" in body["error"] - assert "job_ids" in body["error"] + assert body["errors"] + assert any("Field required" in e["message"] for e in body["errors"]) + assert any("job_ids" in e["field"] for e in body["errors"]) def test_get_jobs_different_project_not_found( self, diff --git a/backend/app/tests/api/routes/test_evaluation.py b/backend/app/tests/api/routes/test_evaluation.py index 0129fa44a..4b751a59a 100644 --- a/backend/app/tests/api/routes/test_evaluation.py +++ b/backend/app/tests/api/routes/test_evaluation.py @@ -347,9 +347,11 @@ def test_upload_with_duplication_factor_below_minimum( assert response.status_code == 422 response_data = response.json() - # Check that the error mentions validation and minimum value - assert "error" in response_data - assert "greater than or equal to 1" in response_data["error"] + assert response_data["errors"] + assert any( + "greater than or equal to 1" in e["message"] + for e in response_data["errors"] + ) def test_upload_with_duplication_factor_above_maximum( self, @@ -372,9 +374,10 @@ def test_upload_with_duplication_factor_above_maximum( assert response.status_code == 422 response_data = 
response.json() - # Check that the error mentions validation and maximum value - assert "error" in response_data - assert "less than or equal to 5" in response_data["error"] + assert response_data["errors"] + assert any( + "less than or equal to 5" in e["message"] for e in response_data["errors"] + ) def test_upload_with_duplication_factor_boundary_minimum( self, @@ -1376,6 +1379,71 @@ def test_get_dataset_not_found( ) assert "not found" in error_str.lower() or "not accessible" in error_str.lower() + def test_default_no_signed_url( + self, + db: Session, + client: TestClient, + user_api_key: TestAuthContext, + ) -> None: + """Test that signed_url is not included by default.""" + dataset = create_test_evaluation_dataset( + db=db, + organization_id=user_api_key.organization_id, + project_id=user_api_key.project_id, + ) + response = client.get( + f"/api/v1/evaluations/datasets/{dataset.id}", + headers={"X-API-KEY": user_api_key.key}, + ) + assert response.status_code == 200 + assert response.json()["data"].get("signed_url") is None + + def test_include_signed_url( + self, + db: Session, + client: TestClient, + user_api_key: TestAuthContext, + ) -> None: + """Test that signed_url is returned when requested.""" + dataset = create_test_evaluation_dataset( + db=db, + organization_id=user_api_key.organization_id, + project_id=user_api_key.project_id, + ) + with patch("app.api.routes.evaluations.dataset.get_cloud_storage") as mock: + mock.return_value.get_signed_url.return_value = "https://signed.url" + response = client.get( + f"/api/v1/evaluations/datasets/{dataset.id}", + headers={"X-API-KEY": user_api_key.key}, + params={"include_signed_url": True}, + ) + assert response.json()["data"]["signed_url"] == "https://signed.url" + + def test_no_object_store_url_skips_signing( + self, + db: Session, + client: TestClient, + user_api_key: TestAuthContext, + ) -> None: + """Test that signing is skipped when dataset has no object_store_url.""" + dataset = 
create_test_evaluation_dataset( + db=db, + organization_id=user_api_key.organization_id, + project_id=user_api_key.project_id, + ) + dataset.object_store_url = None + db.add(dataset) + db.commit() + + with patch("app.api.routes.evaluations.dataset.get_cloud_storage") as mock: + response = client.get( + f"/api/v1/evaluations/datasets/{dataset.id}", + headers={"X-API-KEY": user_api_key.key}, + params={"include_signed_url": True}, + ) + mock.assert_not_called() + assert response.json()["data"].get("signed_url") is None + class TestDeleteDataset: """Test DELETE /evaluations/datasets/{dataset_id} endpoint.""" diff --git a/backend/app/tests/api/routes/test_onboarding.py b/backend/app/tests/api/routes/test_onboarding.py index 0b3f14971..3a6f13b76 100644 --- a/backend/app/tests/api/routes/test_onboarding.py +++ b/backend/app/tests/api/routes/test_onboarding.py @@ -197,8 +197,10 @@ def test_onboard_project_invalid_provider( assert response.status_code == 422 error_response = response.json() - assert "error" in error_response - assert "credential validation failed" in error_response["error"] + assert error_response["errors"] + assert any( + "credential validation failed" in e["message"] for e in error_response["errors"] + ) def test_onboard_project_non_dict_values_in_credential( @@ -227,9 +229,13 @@ def test_onboard_project_non_dict_values_in_credential( assert response.status_code == 422 error_response = response.json() - assert "error" in error_response - assert "credential validation failed" in error_response["error"] - assert "must be an object/dict" in error_response["error"] + assert error_response["errors"] + assert any( + "credential validation failed" in e["message"] for e in error_response["errors"] + ) + assert any( + "must be an object/dict" in e["message"] for e in error_response["errors"] + ) def test_onboard_project_missing_required_fields_for_openai( @@ -258,9 +264,11 @@ def test_onboard_project_missing_required_fields_for_openai( assert response.status_code 
== 422 error_response = response.json() - assert "error" in error_response - assert "credential validation failed" in error_response["error"] - assert "openai" in error_response["error"] + assert error_response["errors"] + assert any( + "credential validation failed" in e["message"] for e in error_response["errors"] + ) + assert any("openai" in e["message"] for e in error_response["errors"]) def test_onboard_project_missing_required_fields_for_langfuse( @@ -291,9 +299,11 @@ def test_onboard_project_missing_required_fields_for_langfuse( assert response.status_code == 422 error_response = response.json() - assert "error" in error_response - assert "credential validation failed" in error_response["error"] - assert "langfuse" in error_response["error"] + assert error_response["errors"] + assert any( + "credential validation failed" in e["message"] for e in error_response["errors"] + ) + assert any("langfuse" in e["message"] for e in error_response["errors"]) def test_onboard_project_aggregates_multiple_credential_errors( @@ -325,7 +335,9 @@ def test_onboard_project_aggregates_multiple_credential_errors( assert response.status_code == 422 error_response = response.json() - assert "error" in error_response - assert "credential validation failed" in error_response["error"] - assert "[0]" in error_response["error"] - assert "[1]" in error_response["error"] + assert error_response["errors"] + assert any( + "credential validation failed" in e["message"] for e in error_response["errors"] + ) + assert any("[0]" in e["message"] for e in error_response["errors"]) + assert any("[1]" in e["message"] for e in error_response["errors"]) diff --git a/backend/app/tests/api/routes/test_threads.py b/backend/app/tests/api/routes/test_threads.py index e9a11fad8..9a1f297a6 100644 --- a/backend/app/tests/api/routes/test_threads.py +++ b/backend/app/tests/api/routes/test_threads.py @@ -601,4 +601,5 @@ def test_threads_start_missing_question( assert response.status_code == 422 # Unprocessable 
Entity (FastAPI will raise 422) error_response = response.json() assert error_response["success"] is False - assert "question" in error_response["error"] + assert error_response["errors"] + assert any("question" in e["field"] for e in error_response["errors"]) diff --git a/backend/app/tests/core/test_exception_handlers.py b/backend/app/tests/core/test_exception_handlers.py new file mode 100644 index 000000000..8971f1b47 --- /dev/null +++ b/backend/app/tests/core/test_exception_handlers.py @@ -0,0 +1,140 @@ +from fastapi.testclient import TestClient + +from app.core.config import settings +from app.core.exception_handlers import _filter_union_branch_errors +from app.tests.utils.auth import TestAuthContext + + +class TestFilterUnionBranchErrors: + """Unit tests for _filter_union_branch_errors.""" + + def test_non_union_errors_pass_through(self) -> None: + errors = [ + {"type": "missing", "loc": ("body", "name"), "msg": "Field required"}, + ] + assert _filter_union_branch_errors(errors) == errors + + def test_picks_branch_with_fewer_literal_errors(self) -> None: + errors = [ + { + "type": "literal_error", + "loc": ("body", "c", "NativeConfig", "provider"), + "msg": "bad", + }, + { + "type": "missing", + "loc": ("body", "c", "NativeConfig", "params"), + "msg": "Field required", + }, + { + "type": "missing", + "loc": ("body", "c", "KaapiConfig", "type"), + "msg": "Field required", + }, + { + "type": "missing", + "loc": ("body", "c", "KaapiConfig", "params"), + "msg": "Field required", + }, + ] + result = _filter_union_branch_errors(errors) + assert len(result) == 2 + for err in result: + assert "NativeConfig" not in err["loc"] + + def test_tied_branches_keep_both_and_dedup(self) -> None: + """When two branches have the same literal_error count, both are kept but duplicates removed.""" + errors = [ + { + "type": "missing", + "loc": ("body", "c", "BranchA", "x"), + "msg": "Field required", + }, + { + "type": "missing", + "loc": ("body", "c", "BranchB", "x"), + "msg": 
"Field required", + }, + ] + result = _filter_union_branch_errors(errors) + assert len(result) == 1 + assert result[0]["loc"] == ("body", "c", "x") + + def test_strips_branch_identifiers_from_loc(self) -> None: + errors = [ + { + "type": "missing", + "loc": ( + "body", + "cfg", + "completion", + "function-after[validate_params(), Foo]", + "params", + ), + "msg": "Field required", + } + ] + result = _filter_union_branch_errors(errors) + assert result[0]["loc"] == ("body", "cfg", "completion", "params") + + def test_non_union_preserved_with_union(self) -> None: + errors = [ + {"type": "missing", "loc": ("body", "name"), "msg": "Field required"}, + { + "type": "literal_error", + "loc": ("body", "c", "NativeConfig", "p"), + "msg": "bad", + }, + { + "type": "missing", + "loc": ("body", "c", "KaapiConfig", "t"), + "msg": "Field required", + }, + ] + result = _filter_union_branch_errors(errors) + assert len(result) == 2 + locs = [r["loc"] for r in result] + assert ("body", "name") in locs + + def test_empty_list(self) -> None: + assert _filter_union_branch_errors([]) == [] + + def test_fallback_on_malformed_input(self) -> None: + malformed = [None, 42] # type: ignore[list-item] + result = _filter_union_branch_errors(malformed) + assert result == malformed + + +class TestValidationErrorResponse: + """Integration: structured errors via configs endpoint.""" + + def test_structured_error_format( + self, client: TestClient, user_api_key: TestAuthContext + ) -> None: + response = client.post( + f"{settings.API_V1_STR}/configs/", + headers={"X-API-KEY": user_api_key.key}, + json={}, + ) + assert response.status_code == 422 + body = response.json() + assert body["success"] is False + assert body["error"] == "Validation failed" + assert isinstance(body["errors"], list) + assert all("field" in e and "message" in e for e in body["errors"]) + + def test_union_noise_filtered( + self, client: TestClient, user_api_key: TestAuthContext + ) -> None: + response = client.post( + 
f"{settings.API_V1_STR}/configs/", + headers={"X-API-KEY": user_api_key.key}, + json={ + "name": "test-config", + "config_blob": {"completion": {"provider": "openai"}}, + }, + ) + assert response.status_code == 422 + for error in response.json()["errors"]: + assert "openai-native" not in error["message"] + assert "NativeCompletionConfig" not in error["field"] diff --git a/backend/app/utils.py b/backend/app/utils.py index 9c1be2a11..9a6659abe 100644 --- a/backend/app/utils.py +++ b/backend/app/utils.py @@ -44,10 +44,16 @@ T = TypeVar("T") +class ValidationErrorDetail(BaseModel): + field: str + message: str + + class APIResponse(BaseModel, Generic[T]): success: bool data: Optional[T] = None error: Optional[str] = None + errors: Optional[list[ValidationErrorDetail]] = None metadata: Optional[Dict[str, Any]] = None @classmethod @@ -64,11 +70,27 @@ def failure_response( metadata: Optional[Dict[str, Any]] = None, ) -> "APIResponse[None]": if isinstance(error, list): # to handle cases when error is a list of errors - error_message = "\n".join([f"{err['loc']}: {err['msg']}" for err in error]) - else: - error_message = error + structured_errors = [] + for err in error: + loc = err.get("loc", ()) + parts = [str(p) for p in loc if p != "body"] + field = ".".join(parts) if parts else "unknown" + structured_errors.append( + ValidationErrorDetail( + field=str(field), message=str(err.get("msg", "")) + ) + ) + + return cls( + success=False, + data=data, + error="Validation failed", + errors=structured_errors, + metadata=metadata, + ) - return cls(success=False, data=data, error=error_message, metadata=metadata) + else: + return cls(success=False, data=data, error=error, metadata=metadata) @dataclass