2 changes: 2 additions & 0 deletions backend/app/api/docs/evaluation/get_dataset.md
@@ -1,3 +1,5 @@
Get details of a specific dataset by ID.

Returns comprehensive dataset information including metadata (ID, name, item counts, duplication factor), Langfuse integration details (dataset ID), and the object store URL for the CSV file.

Use `include_signed_url=true` to include a time-limited signed URL for downloading the dataset CSV directly from object storage. The signed URL expires after 1 hour.
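
For illustration, a client call might look like this (a sketch using Python's `requests`; the base URL and API key are placeholders, while the path, header, and response shape follow the tests added in this PR):

```python
import requests

BASE_URL = "https://kaapi.example.com"  # placeholder host

# Request the dataset along with a time-limited signed download URL.
resp = requests.get(
    f"{BASE_URL}/api/v1/evaluations/datasets/42",
    headers={"X-API-KEY": "<your-api-key>"},
    params={"include_signed_url": True},
)
resp.raise_for_status()
data = resp.json()["data"]
print(data["signed_url"])  # time-limited URL (expires after 1 hour), or None if unavailable
```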
20 changes: 18 additions & 2 deletions backend/app/api/routes/evaluations/dataset.py
@@ -14,6 +14,7 @@

from app.api.deps import AuthContextDep, SessionDep
from app.api.permissions import Permission, require_permission
from app.core.cloud import get_cloud_storage
from app.crud.evaluations import (
    get_dataset_by_id,
    list_datasets as list_evaluation_datasets,
@@ -34,7 +35,9 @@
router = APIRouter(prefix="/evaluations/datasets", tags=["Evaluation"])


def _dataset_to_response(dataset: EvaluationDataset) -> DatasetUploadResponse:
def _dataset_to_response(
    dataset: EvaluationDataset, signed_url: str | None = None
) -> DatasetUploadResponse:
    """Convert a dataset model to a DatasetUploadResponse."""
    return DatasetUploadResponse(
        dataset_id=dataset.id,
@@ -44,6 +47,7 @@ def _dataset_to_response(dataset: EvaluationDataset) -> DatasetUploadResponse:
        duplication_factor=dataset.dataset_metadata.get("duplication_factor", 1),
        langfuse_dataset_id=dataset.langfuse_dataset_id,
        object_store_url=dataset.object_store_url,
        signed_url=signed_url,
    )


@@ -124,6 +128,9 @@ def get_dataset(
    dataset_id: int,
    session: SessionDep,
    auth_context: AuthContextDep,
    include_signed_url: bool = Query(
        False, description="Include signed URL for dataset"
    ),
) -> APIResponse[DatasetUploadResponse]:
    """Get a specific evaluation dataset."""
    logger.info(
@@ -144,7 +151,16 @@ def get_dataset(
            status_code=404, detail=f"Dataset {dataset_id} not found or not accessible"
        )

    return APIResponse.success_response(data=_dataset_to_response(dataset))
    signed_url = None
    if include_signed_url and dataset.object_store_url:
        storage = get_cloud_storage(
            session=session, project_id=auth_context.project_.id
        )
        signed_url = storage.get_signed_url(dataset.object_store_url)
Comment on lines +154 to +159
⚠️ Potential issue | 🟡 Minor

Missing error handling for cloud storage failures.

According to backend/app/core/cloud/storage.py:267-284, get_cloud_storage() can raise ValueError for invalid projects or propagate exceptions when AWS initialization fails. If cloud storage is not configured (e.g., in development environments without AWS credentials), this will cause a 500 error even when the dataset exists.

Consider gracefully handling these failures by returning the dataset without the signed URL instead of failing the entire request.

🛡️ Proposed fix with error handling
     signed_url = None
     if include_signed_url and dataset.object_store_url:
-        storage = get_cloud_storage(
-            session=session, project_id=auth_context.project_.id
-        )
-        signed_url = storage.get_signed_url(dataset.object_store_url)
+        try:
+            storage = get_cloud_storage(
+                session=session, project_id=auth_context.project_.id
+            )
+            signed_url = storage.get_signed_url(dataset.object_store_url)
+        except Exception as e:
+            logger.warning(
+                f"[get_dataset] Failed to generate signed URL | dataset_id={dataset_id} | error={str(e)}"
+            )
+            # Continue without signed_url rather than failing the request
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
    signed_url = None
    if include_signed_url and dataset.object_store_url:
        storage = get_cloud_storage(
            session=session, project_id=auth_context.project_.id
        )
        signed_url = storage.get_signed_url(dataset.object_store_url)
    signed_url = None
    if include_signed_url and dataset.object_store_url:
        try:
            storage = get_cloud_storage(
                session=session, project_id=auth_context.project_.id
            )
            signed_url = storage.get_signed_url(dataset.object_store_url)
        except Exception as e:
            logger.warning(
                f"[get_dataset] Failed to generate signed URL | dataset_id={dataset_id} | error={str(e)}"
            )
            # Continue without signed_url rather than failing the request
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@backend/app/api/routes/evaluations/dataset.py` around lines 154 - 159, the
current signed URL retrieval (guarded by include_signed_url and
dataset.object_store_url) calls get_cloud_storage(...) and
storage.get_signed_url(...) but can raise ValueError or other exceptions; update
the block to catch ValueError and a broad Exception around get_cloud_storage and
storage.get_signed_url (referencing get_cloud_storage, storage.get_signed_url,
include_signed_url, dataset.object_store_url, auth_context.project_.id,
session=session), log a warning or debug message with the error, and simply
leave signed_url as None so the endpoint returns the dataset without failing
when cloud storage is not configured or initialization fails.


    return APIResponse.success_response(
        data=_dataset_to_response(dataset, signed_url=signed_url)
    )


@router.delete(
90 changes: 83 additions & 7 deletions backend/app/core/exception_handlers.py
@@ -1,30 +1,106 @@
import re
from collections import defaultdict

from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import JSONResponse
from fastapi.exceptions import RequestValidationError
from app.utils import APIResponse
from fastapi.responses import JSONResponse
from starlette.status import (
    HTTP_422_UNPROCESSABLE_ENTITY,
    HTTP_500_INTERNAL_SERVER_ERROR,
)

from app.utils import APIResponse

_BRANCH_PATTERN = re.compile(r"^[A-Z]|[\[\]()]")


def _is_branch_identifier(part: str) -> bool:
    return bool(part and isinstance(part, str) and _BRANCH_PATTERN.search(part))
Comment on lines +14 to +18
⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🌐 Web query:

Pydantic v2 discriminated union validation error loc format

💡 Result:

In Pydantic v2, ValidationError.errors() returns each error with:

  • loc: a tuple path of str | int segments (field names and container indexes). (docs.pydantic.dev)
  • When the error happens inside a discriminated (tagged) union, Pydantic inserts the selected union tag (i.e., the discriminator result) into the location path, between the union field and the nested field(s).

What the loc looks like for discriminated unions

If your model field is pet and the discriminator selects tag "dog", then a missing field error inside that variant will have a location conceptually like:

  • ('pet', 'dog', 'barks') → displayed as pet.dog.barks (docs.pydantic.dev)

For nested discriminated unions, tags can stack:

  • ('pet', 'cat', 'black', 'black_name') → displayed as pet.cat.black.black_name (docs.pydantic.dev)

This is the same behavior people notice in practice, e.g. ("child", "a", "field_a") where "a" is the discriminator/tag. (stackoverflow.com)

When the discriminator itself is wrong/missing

For discriminator failures, Pydantic uses error types like:

  • union_tag_not_found — the discriminator field is missing from the input
  • union_tag_invalid — the discriminator value doesn't match any known tag

(Those errors are raised at the union field; the printed error path you see includes the tag when it’s known, e.g. pet.cat in the nested example.) (docs.pydantic.dev)

🏁 Script executed:

# First, check the full context of the exception handler file
cat -n backend/app/core/exception_handlers.py | head -50

Repository: ProjectTech4DevAI/kaapi-backend

Length of output: 2025


🏁 Script executed:

# Check the request.py file to verify the discriminator values mentioned
cat -n backend/app/models/llm/request.py | grep -A 5 -B 5 "discriminator\|text\|audio\|openai"

Repository: ProjectTech4DevAI/kaapi-backend

Length of output: 9327


The union branch identifier pattern is insufficient for lowercase discriminators used in the codebase.

The regex ^[A-Z]|[\[\]()] only matches strings starting with uppercase letters or containing brackets. However, according to Pydantic v2 documentation, validation error loc tuples include discriminator tag values—and your codebase uses lowercase literal discriminators like "text", "audio", "openai", "openai-native" (from backend/app/models/llm/request.py lines 109, 114, 119, 125 and 197, 217).

Since the pattern won't match these lowercase discriminators when they appear in error locations, _is_branch_identifier() will fail to identify them as branch identifiers. This causes errors from discriminated unions with lowercase discriminators to be misclassified as non-union errors instead of being properly grouped and filtered by union branch (see lines 38–47).
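
A quick sanity check of the pattern against these values (a sketch; the regex is copied from this PR's diff, and the tag names are the ones cited above):

```python
import re

_BRANCH_PATTERN = re.compile(r"^[A-Z]|[\[\]()]")  # pattern added in this PR

# Uppercase model names match; lowercase discriminator tags do not.
for part in ["NativeCompletionConfig", "text", "audio", "openai-native"]:
    print(part, bool(_BRANCH_PATTERN.search(part)))
# NativeCompletionConfig True
# text False
# audio False
# openai-native False
```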

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@backend/app/core/exception_handlers.py` around lines 14 - 18, the current
_BRANCH_PATTERN only matches uppercase starters and brackets, so
_is_branch_identifier fails for lowercase discriminators like "text" or
"openai-native"; update the regex used by _BRANCH_PATTERN to include lowercase
letters (for example change r"^[A-Z]|[\[\]()]" to r"^[A-Za-z]|[\[\]()]") so
_is_branch_identifier(part: str) correctly recognizes lowercase discriminators
and hyphenated tags used in discriminated unions (leave _is_branch_identifier
implementation as-is and add/verify tests for tags such as "text", "audio",
"openai", "openai-native").



def _filter_union_branch_errors(errors: list[dict]) -> list[dict]:
Collaborator
what exactly do you mean by "branch" here

Collaborator Author
@vprashrex vprashrex Mar 11, 2026

Here, “branch” refers to each possible model inside a Union / discriminated union.
For example, CompletionConfig can be either NativeCompletionConfig or KaapiCompletionConfig.

CompletionConfig = Annotated[
    Union[NativeCompletionConfig, KaapiCompletionConfig],
    Field(discriminator="provider"),
]

When a validation error occurs, Pydantic often returns errors from all union branches, even though the issue actually belongs to only one branch. This can result in confusing error messages.

The _filter_union_branch_errors function filters these results and keeps only the errors from the branch where the validation actually failed.

Note: If errors genuinely occur in multiple branches, errors from multiple branches will still be returned.
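
A minimal sketch of that behavior with a plain (non-discriminated) union — the model names mirror the example above but are not the actual kaapi models:

```python
from typing import Literal, Union
from pydantic import BaseModel, ValidationError

class NativeCompletionConfig(BaseModel):
    provider: Literal["openai-native"]
    model: str

class KaapiCompletionConfig(BaseModel):
    provider: Literal["openai"]
    temperature: float

class CompletionRequest(BaseModel):
    completion_config: Union[NativeCompletionConfig, KaapiCompletionConfig]

try:
    # The input clearly targets KaapiCompletionConfig but has a bad temperature.
    CompletionRequest(
        completion_config={"provider": "openai", "temperature": "not-a-float"}
    )
except ValidationError as e:
    for err in e.errors():
        print(err["loc"], err["type"])
# ('completion_config', 'NativeCompletionConfig', 'provider') literal_error
# ('completion_config', 'NativeCompletionConfig', 'model') missing
# ('completion_config', 'KaapiCompletionConfig', 'temperature') float_parsing
```

Only the KaapiCompletionConfig error reflects the real mistake; the two NativeCompletionConfig errors are noise from the branch the input never intended to match.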

Collaborator Author

"""When a field is a Union type, pydantic returns errors for every possible branch.

This function picks the branch where the validation error happend.
"""
    try:
        branch_errors: dict[str, dict[str, list[dict]]] = defaultdict(
            lambda: defaultdict(list)
        )
        non_union_errors: list[dict] = []

        for err in errors:
            loc = err.get("loc", ())
            branch_name = None
            parent_field = None
            for i, part in enumerate(loc):
                if _is_branch_identifier(part):
                    branch_name = part
                    parent_field = loc[:i] if i > 0 else ("root",)
                    break

            if branch_name and parent_field:
                branch_errors[str(parent_field)][branch_name].append(err)
            else:
                non_union_errors.append(err)

        filtered = list(non_union_errors)
        for _parent, branches in branch_errors.items():
            if len(branches) <= 1:
                for errs in branches.values():
                    filtered.extend(errs)
            else:
                # NOTE: Keep all branches tied for fewest literal errors
                best_count = min(
                    sum(1 for e in errs if e.get("type") == "literal_error")
                    for errs in branches.values()
                )
                for errs in branches.values():
                    if (
                        sum(1 for e in errs if e.get("type") == "literal_error")
                        <= best_count
                    ):
                        filtered.extend(errs)

        for err in filtered:
            loc = err.get("loc", ())
            err["loc"] = tuple(p for p in loc if not _is_branch_identifier(p))

        seen_errors: set[tuple] = set()
        unique_errors: list[dict] = []
        for error in filtered:
            error_key = (tuple(error.get("loc", ())), error.get("msg", ""))
            if error_key not in seen_errors:
                seen_errors.add(error_key)
                unique_errors.append(error)

        return unique_errors or errors
    except Exception:
        return errors
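
Fed the three errors from the plain-union sketch in the thread above, the filter keeps only the branch with the fewest literal_error entries and strips the branch name from loc (illustrative; assumes the helper is importable from this module):

```python
from app.core.exception_handlers import _filter_union_branch_errors  # this PR's helper

errors = [
    {"loc": ("completion_config", "NativeCompletionConfig", "provider"),
     "msg": "Input should be 'openai-native'", "type": "literal_error"},
    {"loc": ("completion_config", "NativeCompletionConfig", "model"),
     "msg": "Field required", "type": "missing"},
    {"loc": ("completion_config", "KaapiCompletionConfig", "temperature"),
     "msg": "Input should be a valid number", "type": "float_parsing"},
]

print(_filter_union_branch_errors(errors))
# [{'loc': ('completion_config', 'temperature'),
#   'msg': 'Input should be a valid number', 'type': 'float_parsing'}]
```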


def register_exception_handlers(app: FastAPI):
def register_exception_handlers(app: FastAPI) -> None:
    @app.exception_handler(RequestValidationError)
    async def validation_error_handler(request: Request, exc: RequestValidationError):
    async def validation_error_handler(
        request: Request, exc: RequestValidationError
    ) -> JSONResponse:
        errors = _filter_union_branch_errors(exc.errors())
        return JSONResponse(
            status_code=HTTP_422_UNPROCESSABLE_ENTITY,
            content=APIResponse.failure_response(exc.errors()).model_dump(),
            content=APIResponse.failure_response(errors).model_dump(),
        )

    @app.exception_handler(HTTPException)
    async def http_exception_handler(request: Request, exc: HTTPException):
    async def http_exception_handler(
        request: Request, exc: HTTPException
    ) -> JSONResponse:
        return JSONResponse(
            status_code=exc.status_code,
            content=APIResponse.failure_response(exc.detail).model_dump(),
        )

    @app.exception_handler(Exception)
    async def generic_error_handler(request: Request, exc: Exception):
    async def generic_error_handler(request: Request, exc: Exception) -> JSONResponse:
        return JSONResponse(
            status_code=HTTP_500_INTERNAL_SERVER_ERROR,
            content=APIResponse.failure_response(
3 changes: 3 additions & 0 deletions backend/app/models/evaluation.py
@@ -43,6 +43,9 @@ class DatasetUploadResponse(BaseModel):
    object_store_url: str | None = Field(
        None, description="Object store URL if uploaded"
    )
    signed_url: str | None = Field(
        None, description="A signed URL for downloading the dataset"
    )


class EvaluationResult(BaseModel):
@@ -124,7 +124,9 @@ def test_collection_creation_vector_only_request_validation_error(
        assert body["success"] is False
        assert body["data"] is None
        assert body["metadata"] is None
        assert (
        assert body["errors"]
        assert any(
            "To create an Assistant, provide BOTH 'model' and 'instructions'"
            in body["error"]
            in e["message"]
            for e in body["errors"]
        )
45 changes: 27 additions & 18 deletions backend/app/tests/api/routes/test_doc_transformation_job.py
@@ -55,7 +55,9 @@ def test_get_job_invalid_uuid_422(
        )
        assert resp.status_code == 422
        body = resp.json()
        assert "error" in body and "valid UUID" in body["error"]
        assert body["errors"] and any(
            "valid UUID" in e["message"] for e in body["errors"]
        )

    def test_get_job_different_project_404(
        self,
@@ -207,9 +209,12 @@ def test_get_jobs_with_empty_string(
        body = response.json()
        assert body["success"] is False
        assert body["data"] is None
        assert "error" in body
        assert "valid UUID" in body["error"] or "expected length" in body["error"]
        assert "job_ids" in body["error"]
        assert body["errors"]
        assert any(
            "valid UUID" in e["message"] or "expected length" in e["message"]
            for e in body["errors"]
        )
        assert any("job_ids" in e["field"] for e in body["errors"])

    def test_get_jobs_with_whitespace_only(
        self, client: TestClient, user_api_key: TestAuthContext
@@ -224,8 +229,8 @@ def test_get_jobs_with_whitespace_only(
        body = response.json()
        assert body["success"] is False
        assert body["data"] is None
        assert "error" in body
        assert "valid UUID" in body["error"]
        assert body["errors"]
        assert any("valid UUID" in e["message"] for e in body["errors"])

    def test_get_jobs_invalid_uuid_format_422(
        self, client: TestClient, user_api_key: TestAuthContext
@@ -242,9 +247,12 @@ def test_get_jobs_invalid_uuid_format_422(
        body = response.json()
        assert body["success"] is False
        assert body["data"] is None
        assert "error" in body
        assert "valid UUID" in body["error"] or "expected length" in body["error"]
        assert "job_ids" in body["error"]
        assert body["errors"]
        assert any(
            "valid UUID" in e["message"] or "expected length" in e["message"]
            for e in body["errors"]
        )
        assert any("job_ids" in e["field"] for e in body["errors"])

    def test_get_jobs_mixed_valid_invalid_uuid_422(
        self, client: TestClient, db: Session, user_api_key: TestAuthContext
@@ -266,12 +274,13 @@ def test_get_jobs_mixed_valid_invalid_uuid_422(
        body = response.json()
        assert body["success"] is False
        assert body["data"] is None
        assert "error" in body
        assert "job_ids" in body["error"]
        assert (
            "valid UUID" in body["error"]
            or "invalid character" in body["error"]
            or "invalid length" in body["error"]
        assert body["errors"]
        assert any("job_ids" in e["field"] for e in body["errors"])
        assert any(
            "valid UUID" in e["message"]
            or "invalid character" in e["message"]
            or "invalid length" in e["message"]
            for e in body["errors"]
        )

    def test_get_jobs_missing_parameter_422(
@@ -287,9 +296,9 @@ def test_get_jobs_missing_parameter_422(
        body = response.json()
        assert body["success"] is False
        assert body["data"] is None
        assert "error" in body
        assert "Field required" in body["error"]
        assert "job_ids" in body["error"]
        assert body["errors"]
        assert any("Field required" in e["message"] for e in body["errors"])
        assert any("job_ids" in e["field"] for e in body["errors"])

    def test_get_jobs_different_project_not_found(
        self,
80 changes: 74 additions & 6 deletions backend/app/tests/api/routes/test_evaluation.py
@@ -347,9 +347,11 @@ def test_upload_with_duplication_factor_below_minimum(

        assert response.status_code == 422
        response_data = response.json()
        # Check that the error mentions validation and minimum value
        assert "error" in response_data
        assert "greater than or equal to 1" in response_data["error"]
        assert response_data["errors"]
        assert any(
            "greater than or equal to 1" in e["message"]
            for e in response_data["errors"]
        )

    def test_upload_with_duplication_factor_above_maximum(
        self,
@@ -372,9 +374,10 @@ def test_upload_with_duplication_factor_above_maximum(

        assert response.status_code == 422
        response_data = response.json()
        # Check that the error mentions validation and maximum value
        assert "error" in response_data
        assert "less than or equal to 5" in response_data["error"]
        assert response_data["errors"]
        assert any(
            "less than or equal to 5" in e["message"] for e in response_data["errors"]
        )

    def test_upload_with_duplication_factor_boundary_minimum(
        self,
@@ -1376,6 +1379,71 @@ def test_get_dataset_not_found(
        )
        assert "not found" in error_str.lower() or "not accessible" in error_str.lower()

    def test_default_no_signed_url(
        self,
        db: Session,
        client: TestClient,
        user_api_key: TestAuthContext,
    ) -> None:
        """Test that signed_url is not included by default."""
        dataset = create_test_evaluation_dataset(
            db=db,
            organization_id=user_api_key.organization_id,
            project_id=user_api_key.project_id,
        )
        response = client.get(
            f"/api/v1/evaluations/datasets/{dataset.id}",
            headers={"X-API-KEY": user_api_key.key},
        )
        assert response.status_code == 200
        assert response.json()["data"].get("signed_url") is None

    def test_include_signed_url(
        self,
        db: Session,
        client: TestClient,
        user_api_key: TestAuthContext,
    ) -> None:
        """Test that signed_url is returned when requested."""
        dataset = create_test_evaluation_dataset(
            db=db,
            organization_id=user_api_key.organization_id,
            project_id=user_api_key.project_id,
        )
        with patch("app.api.routes.evaluations.dataset.get_cloud_storage") as mock:
            mock.return_value.get_signed_url.return_value = "https://signed.url"
            response = client.get(
                f"/api/v1/evaluations/datasets/{dataset.id}",
                headers={"X-API-KEY": user_api_key.key},
                params={"include_signed_url": True},
            )
        assert response.json()["data"]["signed_url"] == "https://signed.url"

    def test_no_object_store_url_skips_signing(
        self,
        db: Session,
        client: TestClient,
        user_api_key: TestAuthContext,
    ) -> None:
        """Test that signing is skipped when dataset has no object_store_url."""
        dataset = create_test_evaluation_dataset(
            db=db,
            organization_id=user_api_key.organization_id,
            project_id=user_api_key.project_id,
        )
        dataset.object_store_url = None
        db.add(dataset)
        db.commit()

        with patch("app.api.routes.evaluations.dataset.get_cloud_storage") as mock:
            response = client.get(
                f"/api/v1/evaluations/datasets/{dataset.id}",
                headers={"X-API-KEY": user_api_key.key},
                params={"include_signed_url": True},
            )
        mock.assert_not_called()
        assert response.json()["data"].get("signed_url") is None


class TestDeleteDataset:
"""Test DELETE /evaluations/datasets/{dataset_id} endpoint."""