diff --git a/backend/app/alembic/versions/055_add_assessment_manager_table.py b/backend/app/alembic/versions/055_add_assessment_manager_table.py new file mode 100644 index 000000000..840c25e6d --- /dev/null +++ b/backend/app/alembic/versions/055_add_assessment_manager_table.py @@ -0,0 +1,227 @@ +"""add assessment and assessment_run tables + +Revision ID: 055 +Revises: 054 +Create Date: 2026-03-26 23:30:00.000000 + +""" + +import sqlalchemy as sa +import sqlmodel.sql.sqltypes +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = "055" +down_revision = "054" +branch_labels = None +depends_on = None + + +def upgrade(): + op.create_table( + "assessment", + sa.Column( + "id", + sa.Integer(), + nullable=False, + comment="Unique identifier for the assessment", + ), + sa.Column( + "experiment_name", + sqlmodel.sql.sqltypes.AutoString(), + nullable=False, + comment="Name of the experiment grouping its config runs", + ), + sa.Column( + "dataset_id", + sa.Integer(), + nullable=False, + comment="Reference to the evaluation dataset", + ), + sa.Column( + "status", + sqlmodel.sql.sqltypes.AutoString(), + nullable=False, + server_default="pending", + comment=( + "Aggregate status: pending, processing, completed, " + "completed_with_errors, failed" + ), + ), + sa.Column( + "organization_id", + sa.Integer(), + nullable=False, + comment="Reference to the organization", + ), + sa.Column( + "project_id", + sa.Integer(), + nullable=False, + comment="Reference to the project", + ), + sa.Column( + "inserted_at", + sa.DateTime(), + nullable=False, + comment="Timestamp when the assessment was created", + ), + sa.Column( + "updated_at", + sa.DateTime(), + nullable=False, + comment="Timestamp when the assessment was last updated", + ), + sa.ForeignKeyConstraint( + ["dataset_id"], + ["evaluation_dataset.id"], + name="fk_assessment_dataset_id", + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["organization_id"], + ["organization.id"], + name="fk_assessment_organization_id", + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["project_id"], + ["project.id"], + name="fk_assessment_project_id", + ondelete="CASCADE", + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + op.f("ix_assessment_experiment_name"), + "assessment", + ["experiment_name"], + unique=False, + ) + op.create_index( + "idx_assessment_org_project", + "assessment", + ["organization_id", "project_id", "inserted_at"], + unique=False, + ) + op.create_index( + "idx_assessment_status", + "assessment", + ["status"], + unique=False, + ) + + op.create_table( + "assessment_run", + sa.Column( + "id", + sa.Integer(), + nullable=False, + comment="Unique identifier for the assessment run", + ), + sa.Column( + "assessment_id", + sa.Integer(), + nullable=False, + comment="Reference to the parent assessment", + ), + sa.Column( + "config_id", + sa.Uuid(), + nullable=False, + comment="Reference to the stored config used", + ), + sa.Column( + "config_version", + sa.Integer(), + nullable=False, + comment="Version of the config used", + ), + sa.Column( + "status", + sqlmodel.sql.sqltypes.AutoString(), + nullable=False, + server_default="pending", + comment="Run status: pending, processing, completed, failed", + ), + sa.Column( + "batch_job_id", + sa.Integer(), + nullable=True, + comment="Reference to the batch job processing this run", + ), + sa.Column( + "total_items", + sa.Integer(), + nullable=False, + server_default="0", + comment="Total number of dataset items in this run", + ), + 
sa.Column( + "input", + postgresql.JSONB(astext_type=sa.Text()), + nullable=False, + comment=( + "Assessment input: prompt_template, text_columns, attachments, " + "output_schema" + ), + ), + sa.Column( + "object_store_url", + sqlmodel.sql.sqltypes.AutoString(), + nullable=True, + comment="S3 URL of processed batch results", + ), + sa.Column( + "error_message", + sa.Text(), + nullable=True, + comment="Error message if the run failed", + ), + sa.Column( + "inserted_at", + sa.DateTime(), + nullable=False, + comment="Timestamp when the run was created", + ), + sa.Column( + "updated_at", + sa.DateTime(), + nullable=False, + comment="Timestamp when the run was last updated", + ), + sa.ForeignKeyConstraint( + ["assessment_id"], + ["assessment.id"], + name="fk_assessment_run_assessment_id", + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["config_id"], + ["config.id"], + name="fk_assessment_run_config_id", + ), + sa.ForeignKeyConstraint( + ["batch_job_id"], + ["batch_job.id"], + name="fk_assessment_run_batch_job_id", + ondelete="SET NULL", + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + "idx_assessment_run_assessment_id", + "assessment_run", + ["assessment_id"], + unique=False, + ) + + +def downgrade(): + op.drop_index("idx_assessment_run_assessment_id", table_name="assessment_run") + op.drop_table("assessment_run") + op.drop_index("idx_assessment_status", table_name="assessment") + op.drop_index("idx_assessment_org_project", table_name="assessment") + op.drop_index(op.f("ix_assessment_experiment_name"), table_name="assessment") + op.drop_table("assessment") diff --git a/backend/app/alembic/versions/056_add_config_tag.py b/backend/app/alembic/versions/056_add_config_tag.py new file mode 100644 index 000000000..a374afcbe --- /dev/null +++ b/backend/app/alembic/versions/056_add_config_tag.py @@ -0,0 +1,82 @@ +"""add tag column to config table + +Revision ID: 056 +Revises: 055 +Create Date: 2026-05-03 12:00:00.000000 + +""" + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = "056" +down_revision = "055" +branch_labels = None +depends_on = None + + +CONFIG_TAG_VALUES = ("default", "ASSESSMENT") +DEFAULT_TAG_SERVER_DEFAULT = sa.text("'default'::config_tag") + + +def upgrade(): + config_tag = postgresql.ENUM( + *CONFIG_TAG_VALUES, + name="config_tag", + create_type=False, + ) + config_tag.create(op.get_bind(), checkfirst=True) + + with op.get_context().autocommit_block(): + op.execute("ALTER TYPE config_tag ADD VALUE IF NOT EXISTS 'default'") + op.execute("ALTER TYPE config_tag ADD VALUE IF NOT EXISTS 'ASSESSMENT'") + + op.add_column( + "config", + sa.Column( + "tag", + config_tag, + nullable=False, + server_default=DEFAULT_TAG_SERVER_DEFAULT, + comment=( + "Tag classifying the config: " + "'default' for general use, 'ASSESSMENT' for configs used in assessments. 
" + ), + ), + ) + + op.execute( + """ + UPDATE config + SET tag = 'ASSESSMENT' + FROM ( + SELECT DISTINCT config_id + FROM assessment_run + ) AS assessment_configs + WHERE config.id = assessment_configs.config_id + """ + ) + + with op.get_context().autocommit_block(): + op.create_index( + "idx_config_project_id_tag_active", + "config", + ["project_id", "tag", sa.text("updated_at DESC")], + unique=False, + postgresql_where=sa.text("deleted_at IS NULL"), + postgresql_concurrently=True, + ) + + +def downgrade(): + with op.get_context().autocommit_block(): + op.drop_index( + "idx_config_project_id_tag_active", + table_name="config", + postgresql_concurrently=True, + ) + + op.drop_column("config", "tag") + sa.Enum(name="config_tag").drop(op.get_bind(), checkfirst=True) diff --git a/backend/app/api/docs/assessment/create_run.md b/backend/app/api/docs/assessment/create_run.md new file mode 100644 index 000000000..0986a6e8c --- /dev/null +++ b/backend/app/api/docs/assessment/create_run.md @@ -0,0 +1,7 @@ +Start an assessment across one or more stored config versions. + +Creates an assessment and one child assessment run per config, then submits each +run to batch processing. + +Optional `system_instruction` is forwarded into each generated provider request +as the system/developer instruction for that assessment run. diff --git a/backend/app/api/docs/assessment/delete_dataset.md b/backend/app/api/docs/assessment/delete_dataset.md new file mode 100644 index 000000000..0d36f7a7a --- /dev/null +++ b/backend/app/api/docs/assessment/delete_dataset.md @@ -0,0 +1,4 @@ +Delete an assessment dataset. + +This removes dataset metadata and associated storage references for the +given dataset in the current organization and project. diff --git a/backend/app/api/docs/assessment/export_assessment_results.md b/backend/app/api/docs/assessment/export_assessment_results.md new file mode 100644 index 000000000..f832400d3 --- /dev/null +++ b/backend/app/api/docs/assessment/export_assessment_results.md @@ -0,0 +1,4 @@ +Export results for all child runs under an assessment. + +For `json`, returns a flat list in the API response. For `csv`/`xlsx`, +returns one file for a single run or a ZIP archive when multiple runs exist. diff --git a/backend/app/api/docs/assessment/export_run_results.md b/backend/app/api/docs/assessment/export_run_results.md new file mode 100644 index 000000000..4387bea08 --- /dev/null +++ b/backend/app/api/docs/assessment/export_run_results.md @@ -0,0 +1,3 @@ +Export results for a single assessment run. + +Supports `json`, `csv`, and `xlsx` output formats. diff --git a/backend/app/api/docs/assessment/get_assessment.md b/backend/app/api/docs/assessment/get_assessment.md new file mode 100644 index 000000000..dcfdf0285 --- /dev/null +++ b/backend/app/api/docs/assessment/get_assessment.md @@ -0,0 +1,3 @@ +Get an assessment by ID. + +Returns aggregate run counts and status metadata for the assessment. diff --git a/backend/app/api/docs/assessment/get_dataset.md b/backend/app/api/docs/assessment/get_dataset.md new file mode 100644 index 000000000..5ba766d5b --- /dev/null +++ b/backend/app/api/docs/assessment/get_dataset.md @@ -0,0 +1,3 @@ +Get a single assessment dataset by ID. + +Optionally include a signed URL to download the original uploaded file. 
diff --git a/backend/app/api/docs/assessment/get_run.md b/backend/app/api/docs/assessment/get_run.md new file mode 100644 index 000000000..b5f182534 --- /dev/null +++ b/backend/app/api/docs/assessment/get_run.md @@ -0,0 +1,3 @@ +Get a single assessment run by ID. + +Returns run metadata, status, config reference, and assessment input payload. diff --git a/backend/app/api/docs/assessment/list_assessments.md b/backend/app/api/docs/assessment/list_assessments.md new file mode 100644 index 000000000..3e0f6f2c1 --- /dev/null +++ b/backend/app/api/docs/assessment/list_assessments.md @@ -0,0 +1,3 @@ +List assessments for the current organization/project. + +Each record includes aggregate status counters across its child runs. diff --git a/backend/app/api/docs/assessment/list_datasets.md b/backend/app/api/docs/assessment/list_datasets.md new file mode 100644 index 000000000..b915171b2 --- /dev/null +++ b/backend/app/api/docs/assessment/list_datasets.md @@ -0,0 +1,3 @@ +List assessment datasets for the current organization and project. + +Supports pagination via `limit` and `offset`. diff --git a/backend/app/api/docs/assessment/list_runs.md b/backend/app/api/docs/assessment/list_runs.md new file mode 100644 index 000000000..26249d011 --- /dev/null +++ b/backend/app/api/docs/assessment/list_runs.md @@ -0,0 +1,4 @@ +List assessment runs for the current organization/project. + +Optionally filter by `assessment_id` to list runs for a specific parent +assessment. diff --git a/backend/app/api/docs/assessment/retry_assessment.md b/backend/app/api/docs/assessment/retry_assessment.md new file mode 100644 index 000000000..56a4652cd --- /dev/null +++ b/backend/app/api/docs/assessment/retry_assessment.md @@ -0,0 +1,4 @@ +Retry an existing assessment. + +Reuses the original dataset and config references from the selected +assessment and creates a fresh assessment with new child runs. diff --git a/backend/app/api/docs/assessment/retry_run.md b/backend/app/api/docs/assessment/retry_run.md new file mode 100644 index 000000000..e2448252c --- /dev/null +++ b/backend/app/api/docs/assessment/retry_run.md @@ -0,0 +1,4 @@ +Retry a single assessment run. + +Creates a new assessment using the same dataset and config used by the +selected child run. diff --git a/backend/app/api/docs/assessment/upload_dataset.md b/backend/app/api/docs/assessment/upload_dataset.md new file mode 100644 index 000000000..358803c5f --- /dev/null +++ b/backend/app/api/docs/assessment/upload_dataset.md @@ -0,0 +1,4 @@ +Upload a CSV or Excel dataset for assessment workflows. + +The file is stored in object storage and indexed as an assessment dataset +for the current organization and project. diff --git a/backend/app/api/docs/config/create_version.md b/backend/app/api/docs/config/create_version.md index 3abe6471c..205de1444 100644 --- a/backend/app/api/docs/config/create_version.md +++ b/backend/app/api/docs/config/create_version.md @@ -6,6 +6,10 @@ create a new version under the same configuration with an incremented version nu Version numbers are automatically incremented sequentially (1, 2, 3, etc.) and cannot be manually set or skipped. +When `tag` is omitted, this endpoint only resolves general configurations: +configs tagged `default`. Pass an explicit +tag such as `ASSESSMENT` for tagged config surfaces. + ## Important - This endpoint accepts partial updates using dict[str, Any] for config_blob. - Only the fields that need to be updated should be provided.
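Reviewer note: the doc change above introduces the `tag` query parameter for the config endpoints. A minimal usage sketch follows, assuming an `httpx` client, a placeholder host and API prefix, an API-key header, and that the config router is mounted at `/configs` — none of those details are visible in this diff, so treat paths and the version payload shape as assumptions.

```python
import httpx

BASE = "https://api.example.com/api/v1"  # placeholder host and prefix (assumption)
HEADERS = {"X-API-KEY": "<project-api-key>"}  # assumed auth scheme

# With `tag` omitted, the config endpoints resolve only general configs
# (those tagged 'default').
general = httpx.get(
    f"{BASE}/configs",
    params={"skip": 0, "limit": 100},
    headers=HEADERS,
)
general.raise_for_status()

# Pass tag=ASSESSMENT explicitly to target assessment-scoped configs,
# e.g. when creating a new version with a partial config_blob update.
# The route path and body fields here are illustrative assumptions.
new_version = httpx.post(
    f"{BASE}/configs/7c9e6679-7425-40de-944b-e07fc1f90ae7/versions",
    params={"tag": "ASSESSMENT"},
    json={"config_blob": {"completion": {"params": {"temperature": 0.2}}}},
    headers=HEADERS,
)
new_version.raise_for_status()
```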
diff --git a/backend/app/api/docs/config/get_version.md b/backend/app/api/docs/config/get_version.md index 15b83f49c..46b30f331 100644 --- a/backend/app/api/docs/config/get_version.md +++ b/backend/app/api/docs/config/get_version.md @@ -1,4 +1,8 @@ Retrieve a specific version of a configuration. +When `tag` is omitted, this endpoint only resolves versions for general +configurations: configs tagged `default`. Pass +an explicit tag such as `ASSESSMENT` for tagged config surfaces. + Returns the complete version details including the full configuration blob (config_blob) with all LLM parameters. diff --git a/backend/app/api/docs/config/list.md b/backend/app/api/docs/config/list.md index dc057aa8d..c7c28d147 100644 --- a/backend/app/api/docs/config/list.md +++ b/backend/app/api/docs/config/list.md @@ -4,6 +4,10 @@ limit: Maximum number of records to return (default: 100, max: 100) Retrieve all configurations for the current project. -Returns a paginated list of configurations ordered by most recently -updated first. Each configuration includes metadata (name, description, -timestamps) but excludes version details for performance. +When `tag` is omitted, this endpoint returns only general configurations: +configs tagged `default`. Pass an explicit +tag such as `ASSESSMENT` to list configs for that tagged surface. + +Returns a paginated list of configurations ordered by most recently updated +first. Each configuration includes metadata (name, description, timestamps) +but excludes version details for performance. diff --git a/backend/app/api/docs/config/list_versions.md b/backend/app/api/docs/config/list_versions.md index a77d7d4d6..05cc7b29c 100644 --- a/backend/app/api/docs/config/list_versions.md +++ b/backend/app/api/docs/config/list_versions.md @@ -1,4 +1,8 @@ List all versions for a specific configuration. +When `tag` is omitted, versions are available only for general configurations: +configs tagged `default`. Pass an explicit +tag such as `ASSESSMENT` to access versions for that tagged config surface. + Returns versions in descending order (newest first), allowing you to see the evolution of configuration parameters over time. 
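Reviewer note: before the router wiring below, here is a hedged end-to-end sketch of the new assessment endpoints documented above (dataset upload, run creation, and result export). It assumes an `httpx` client, a placeholder base URL and API-key header, the `{"data": ...}` envelope of `APIResponse`, and that the run-creation body carries the input fields referenced elsewhere in this PR (`prompt_template`, `text_columns`, `system_instruction`); the exact field names inside `configs` and the response fields are assumptions.

```python
import httpx

BASE = "https://api.example.com/api/v1/assessment"  # placeholder host/prefix
HEADERS = {"X-API-KEY": "<project-api-key>"}  # assumed auth scheme

# 1. Upload a CSV dataset (no column schema is enforced at upload time).
with open("complaints.csv", "rb") as fh:
    upload = httpx.post(
        f"{BASE}/datasets",
        files={"file": ("complaints.csv", fh, "text/csv")},
        data={"dataset_name": "complaints-march", "description": "March complaints"},
        headers=HEADERS,
    )
upload.raise_for_status()
dataset_id = upload.json()["data"]["dataset_id"]

# 2. Start an assessment with one child run per stored config version.
payload = {
    "experiment_name": "prompt-comparison-v1",
    "dataset_id": dataset_id,
    "configs": [  # per-config entry shape is an assumption
        {"config_id": "7c9e6679-7425-40de-944b-e07fc1f90ae7", "version": 3},
        {"config_id": "7c9e6679-7425-40de-944b-e07fc1f90ae7", "version": 4},
    ],
    "prompt_template": "Summarise this complaint: {complaint_text}",
    "text_columns": ["complaint_text"],
    "system_instruction": "Respond in English with a single paragraph.",
}
created = httpx.post(f"{BASE}/runs", json=payload, headers=HEADERS)
created.raise_for_status()
assessment_id = created.json()["data"]["assessment_id"]  # field name assumed

# 3. Check aggregate status, then export results: json returns a flat list,
#    csv/xlsx stream a single file or a ZIP when multiple runs exist.
status = httpx.get(f"{BASE}/assessments/{assessment_id}", headers=HEADERS)
results = httpx.get(
    f"{BASE}/assessments/{assessment_id}/results",
    params={"export_format": "csv"},
    headers=HEADERS,
)
```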
diff --git a/backend/app/api/main.py b/backend/app/api/main.py index b82fd45ad..0b1b338cb 100644 --- a/backend/app/api/main.py +++ b/backend/app/api/main.py @@ -3,33 +3,36 @@ from app.api.routes import ( api_keys, assistants, + auth, + collection_job, collections, config, + credentials, + cron, doc_transformation_job, documents, - auth, - login, + evaluations, + fine_tuning, languages, llm, llm_chain, - organization, + login, + model_config, + model_evaluation, + onboarding, openai_conversation, + organization, + private, project, responses, - private, threads, user_project, users, utils, - onboarding, - credentials, - cron, - fine_tuning, - model_evaluation, - collection_job, - model_config, ) -from app.api.routes import evaluations +from app.api.routes import ( + assessment as assessment_routes, +) from app.core.config import settings api_router = APIRouter() @@ -60,6 +63,7 @@ api_router.include_router(fine_tuning.router) api_router.include_router(model_evaluation.router) api_router.include_router(model_config.router) +api_router.include_router(assessment_routes.router) if settings.ENVIRONMENT in ["development", "testing"]: api_router.include_router(private.router) diff --git a/backend/app/api/routes/assessment/__init__.py b/backend/app/api/routes/assessment/__init__.py new file mode 100644 index 000000000..8c88b2e80 --- /dev/null +++ b/backend/app/api/routes/assessment/__init__.py @@ -0,0 +1,13 @@ +"""Main router for assessment API routes.""" + +from fastapi import APIRouter + +from app.api.routes.assessment import assessments, datasets, runs + +router = APIRouter(prefix="/assessment", tags=["Assessment"]) + +router.include_router(datasets.router) +router.include_router(assessments.router) +router.include_router(runs.router) + +__all__ = ["router"] diff --git a/backend/app/api/routes/assessment/assessments.py b/backend/app/api/routes/assessment/assessments.py new file mode 100644 index 000000000..89649f189 --- /dev/null +++ b/backend/app/api/routes/assessment/assessments.py @@ -0,0 +1,175 @@ +"""Parent-assessment endpoints""" + +import logging +from typing import Any, Literal + +from fastapi import APIRouter, Depends, Query +from fastapi.responses import StreamingResponse + +from sqlmodel import Session + +from app.api.deps import AuthContextDep, SessionDep +from app.api.permissions import Permission, require_permission +from app.crud.assessment import ( + build_run_stats, + compute_run_counts, + derive_aggregate_error, + get_assessment_by_id, + get_assessment_runs_for_assessment, + list_assessments as list_assessments_crud, +) +from app.models.assessment import ( + Assessment, + AssessmentPublic, + AssessmentResponse, +) +from app.models.evaluation import EvaluationDataset +from app.services.assessment.service import retry_assessment as retry_assessment_service +from app.services.assessment.utils import build_assessment_results_response +from app.utils import APIResponse, load_description + +logger = logging.getLogger(__name__) + +router = APIRouter() + + +def _build_assessment_public( + session: Session, + assessment: Assessment, +) -> AssessmentPublic: + """Build AssessmentPublic with derived counts and run_stats.""" + runs = get_assessment_runs_for_assessment( + session=session, assessment_id=assessment.id + ) + counts = compute_run_counts(runs) + dataset = session.get(EvaluationDataset, assessment.dataset_id) + return AssessmentPublic( + id=assessment.id, + experiment_name=assessment.experiment_name, + dataset_id=assessment.dataset_id, + dataset_name=dataset.name if dataset else 
None, + status=assessment.status, + counts=counts, + run_stats=build_run_stats(runs), + error_message=derive_aggregate_error(counts), + organization_id=assessment.organization_id, + project_id=assessment.project_id, + inserted_at=assessment.inserted_at, + updated_at=assessment.updated_at, + ) + + +@router.post( + "/assessments/{assessment_id}/retry", + summary="Retry Assessment", + description=load_description("assessment/retry_assessment.md"), + response_model=APIResponse[AssessmentResponse], + dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], +) +def retry_assessment( + assessment_id: int, + session: SessionDep, + auth_context: AuthContextDep, +) -> APIResponse[AssessmentResponse]: + """Retry a parent assessment using the same dataset/config inputs.""" + assessment = get_assessment_by_id( + session=session, + assessment_id=assessment_id, + organization_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + ) + + result = retry_assessment_service( + session=session, + assessment=assessment, + organization_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + ) + return APIResponse.success_response(data=result) + + +@router.get( + "/assessments", + summary="List Assessments Parent details", + description=load_description("assessment/list_assessments.md"), + response_model=APIResponse[list[AssessmentPublic]], + dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], +) +def list_assessments( + session: SessionDep, + auth_context: AuthContextDep, + limit: int = Query(default=50, ge=1, le=100), + offset: int = Query(default=0, ge=0), +) -> APIResponse[list[AssessmentPublic]]: + """List assessments.""" + assessments = list_assessments_crud( + session=session, + organization_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + limit=limit, + offset=offset, + ) + + return APIResponse.success_response( + data=[ + _build_assessment_public(session, assessment) for assessment in assessments + ] + ) + + +@router.get( + "/assessments/{assessment_id}", + summary="Get Parent Assessment Information", + description=load_description("assessment/get_assessment.md"), + response_model=APIResponse[AssessmentPublic], + dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], +) +def get_assessment( + assessment_id: int, + session: SessionDep, + auth_context: AuthContextDep, +) -> APIResponse[AssessmentPublic]: + """Get a specific assessment.""" + assessment = get_assessment_by_id( + session=session, + assessment_id=assessment_id, + organization_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + ) + + return APIResponse.success_response( + data=_build_assessment_public(session, assessment) + ) + + +@router.get( + "/assessments/{assessment_id}/results", + summary="Export Assessment Results", + description=load_description("assessment/export_assessment_results.md"), + response_model=APIResponse[list[dict[str, Any]]], + dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], +) +def export_assessment_results( + assessment_id: int, + session: SessionDep, + auth_context: AuthContextDep, + export_format: Literal["json", "csv", "xlsx"] = Query(default="json"), +) -> APIResponse[list[dict[str, Any]]] | StreamingResponse: + """Return child-run results. 
For CSV/XLSX with multiple runs, returns a ZIP.""" + assessment = get_assessment_by_id( + session=session, + assessment_id=assessment_id, + organization_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + ) + + runs = get_assessment_runs_for_assessment( + session=session, assessment_id=assessment_id + ) + + return build_assessment_results_response( + session=session, + assessment=assessment, + runs=runs, + export_format=export_format, + ) diff --git a/backend/app/api/routes/assessment/datasets.py b/backend/app/api/routes/assessment/datasets.py new file mode 100644 index 000000000..d4e71d184 --- /dev/null +++ b/backend/app/api/routes/assessment/datasets.py @@ -0,0 +1,166 @@ +"""Assessment dataset endpoints.""" + +import logging + +from fastapi import APIRouter, Depends, File, Form, HTTPException, Query, UploadFile + +from app.api.deps import AuthContextDep, SessionDep +from app.api.permissions import Permission, require_permission +from app.core.cloud import get_cloud_storage +from app.crud.assessment.dataset import ( + delete_assessment_dataset, + get_assessment_dataset_by_id, + list_assessment_datasets, +) +from app.models.assessment import AssessmentDatasetResponse +from app.models.evaluation import EvaluationDataset +from app.services.assessment.dataset import upload_dataset as upload_assessment_dataset +from app.services.assessment.validators import validate_dataset_file +from app.utils import APIResponse, load_description + +logger = logging.getLogger(__name__) + +router = APIRouter() + + +def _dataset_to_response( + dataset: EvaluationDataset, + signed_url: str | None = None, +) -> AssessmentDatasetResponse: + metadata = dataset.dataset_metadata or {} + return AssessmentDatasetResponse( + dataset_id=dataset.id, + dataset_name=dataset.name, + description=dataset.description, + total_items=metadata.get("total_items_count", 0), + file_extension=metadata.get("file_extension"), + object_store_url=dataset.object_store_url, + signed_url=signed_url, + ) + + +@router.post( + "/datasets", + description=load_description("assessment/upload_dataset.md"), + response_model=APIResponse[AssessmentDatasetResponse], + dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], +) +async def upload_dataset( + session: SessionDep, + auth_context: AuthContextDep, + file: UploadFile = File( + ..., description="CSV or Excel file to upload as a dataset" + ), + dataset_name: str = Form(..., description="Name for the dataset"), + description: str | None = Form(None, description="Optional dataset description"), +) -> APIResponse[AssessmentDatasetResponse]: + """Upload an assessment dataset (any CSV/Excel file, no column requirements).""" + file_content, file_ext = await validate_dataset_file(file) + + dataset = upload_assessment_dataset( + session=session, + file_content=file_content, + file_ext=file_ext, + dataset_name=dataset_name, + description=description, + organization_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + ) + + return APIResponse.success_response(data=_dataset_to_response(dataset)) + + +@router.get( + "/datasets", + description=load_description("assessment/list_datasets.md"), + response_model=APIResponse[list[AssessmentDatasetResponse]], + dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], +) +def list_datasets( + session: SessionDep, + auth_context: AuthContextDep, + limit: int = Query( + default=50, ge=1, le=100, description="Maximum number of datasets to return" + ), + offset: int = Query(default=0, ge=0, 
description="Number of datasets to skip"), +) -> APIResponse[list[AssessmentDatasetResponse]]: + """List assessment datasets.""" + datasets = list_assessment_datasets( + session=session, + organization_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + limit=limit, + offset=offset, + ) + + return APIResponse.success_response( + data=[_dataset_to_response(dataset) for dataset in datasets] + ) + + +@router.get( + "/datasets/{dataset_id}", + description=load_description("assessment/get_dataset.md"), + response_model=APIResponse[AssessmentDatasetResponse], + dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], +) +def get_dataset( + dataset_id: int, + session: SessionDep, + auth_context: AuthContextDep, + include_signed_url: bool = Query( + False, description="Include a signed URL for downloading the raw file from S3" + ), +) -> APIResponse[AssessmentDatasetResponse]: + """Get a specific assessment dataset.""" + dataset = get_assessment_dataset_by_id( + session=session, + dataset_id=dataset_id, + organization_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + ) + + signed_url = None + if include_signed_url and dataset.object_store_url: + storage = get_cloud_storage( + session=session, project_id=auth_context.project_.id + ) + signed_url = storage.get_signed_url(dataset.object_store_url) + + return APIResponse.success_response( + data=_dataset_to_response(dataset, signed_url=signed_url) + ) + + +@router.delete( + "/datasets/{dataset_id}", + description=load_description("assessment/delete_dataset.md"), + response_model=APIResponse[dict], + dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], +) +def delete_dataset( + dataset_id: int, + session: SessionDep, + auth_context: AuthContextDep, +) -> APIResponse[dict]: + """Delete an assessment dataset.""" + dataset = get_assessment_dataset_by_id( + session=session, + dataset_id=dataset_id, + organization_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + ) + + dataset_name = dataset.name + error = delete_assessment_dataset(session=session, dataset=dataset) + if error: + raise HTTPException(status_code=400, detail=error) + + return APIResponse.success_response( + data={ + "message": ( + f"Successfully deleted dataset '{dataset_name}' (id={dataset_id})" + ), + "dataset_id": dataset_id, + } + ) diff --git a/backend/app/api/routes/assessment/runs.py b/backend/app/api/routes/assessment/runs.py new file mode 100644 index 000000000..18a9be60e --- /dev/null +++ b/backend/app/api/routes/assessment/runs.py @@ -0,0 +1,223 @@ +"""Assessment run endpoints — one row per config-run inside a parent assessment.""" + +import logging +from typing import Any, Literal + +from fastapi import APIRouter, Depends, Query +from fastapi.responses import StreamingResponse + +from app.api.deps import AuthContextDep, SessionDep +from app.api.permissions import Permission, require_permission +from app.crud.assessment import ( + get_assessment_by_id, + get_assessment_run_by_id as get_run_by_id, + list_assessment_runs as list_runs, +) +from app.models.assessment import ( + Assessment, + AssessmentCreate, + AssessmentResponse, + AssessmentRun, + AssessmentRunPublic, +) +from app.models.evaluation import EvaluationDataset +from app.services.assessment.service import ( + retry_assessment_run as retry_run, +) +from app.services.assessment.service import ( + start_assessment, +) +from app.services.assessment.utils import ( + build_export_response, + build_json_export_rows, + 
load_export_rows_for_run, + sort_export_rows, +) +from app.utils import APIResponse, load_description + +logger = logging.getLogger(__name__) + +router = APIRouter() + + +def _build_run_public( + session: SessionDep, + run: AssessmentRun, +) -> AssessmentRunPublic: + """Build AssessmentRunPublic with parent-derived experiment/dataset info.""" + parent = session.get(Assessment, run.assessment_id) + if parent is None: + logger.warning( + "[_build_run_public] Parent assessment %s not found for run %s", + run.assessment_id, + run.id, + ) + dataset = session.get(EvaluationDataset, parent.dataset_id) if parent else None + return AssessmentRunPublic( + id=run.id, + assessment_id=run.assessment_id, + experiment_name=parent.experiment_name if parent else None, + dataset_id=parent.dataset_id if parent else None, + dataset_name=dataset.name if dataset else None, + config_id=run.config_id, + config_version=run.config_version, + status=run.status, + total_items=run.total_items, + error_message=run.error_message, + input=run.input, + inserted_at=run.inserted_at, + updated_at=run.updated_at, + ) + + +@router.post( + "/runs", + description=load_description("assessment/create_run.md"), + response_model=APIResponse[AssessmentResponse], + dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], +) +def create_assessment_runs( + request: AssessmentCreate, + session: SessionDep, + auth_context: AuthContextDep, +) -> APIResponse[AssessmentResponse]: + """Submit an assessment and create one child run per config.""" + logger.info( + "[create_assessment_runs] Assessment run submission | experiment=%s | dataset_id=%s | configs=%s", + request.experiment_name, + request.dataset_id, + len(request.configs), + ) + + result = start_assessment( + session=session, + request=request, + organization_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + ) + + return APIResponse.success_response(data=result) + + +@router.post( + "/runs/{run_id}/retry", + description=load_description("assessment/retry_run.md"), + response_model=APIResponse[AssessmentResponse], + dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], +) +def retry_assessment_run( + run_id: int, + session: SessionDep, + auth_context: AuthContextDep, +) -> APIResponse[AssessmentResponse]: + """Retry a single child assessment run using the same inputs.""" + run = get_run_by_id( + session=session, + run_id=run_id, + organization_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + ) + + result = retry_run( + session=session, + run=run, + organization_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + ) + return APIResponse.success_response(data=result) + + +@router.get( + "/runs", + description=load_description("assessment/list_runs.md"), + response_model=APIResponse[list[AssessmentRunPublic]], + dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], +) +def list_assessment_runs( + session: SessionDep, + auth_context: AuthContextDep, + assessment_id: int | None = Query(default=None, ge=1), + limit: int = Query(default=50, ge=1, le=100), + offset: int = Query(default=0, ge=0), +) -> APIResponse[list[AssessmentRunPublic]]: + """List assessment runs.""" + runs = list_runs( + session=session, + organization_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + assessment_id=assessment_id, + limit=limit, + offset=offset, + ) + + return APIResponse.success_response( + data=[_build_run_public(session, run) for run in runs] + ) + + 
+@router.get( + "/runs/{run_id}", + description=load_description("assessment/get_run.md"), + response_model=APIResponse[AssessmentRunPublic], + dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], +) +def get_assessment_run( + run_id: int, + session: SessionDep, + auth_context: AuthContextDep, +) -> APIResponse[AssessmentRunPublic]: + """Get a specific assessment run.""" + run = get_run_by_id( + session=session, + run_id=run_id, + organization_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + ) + + return APIResponse.success_response(data=_build_run_public(session, run)) + + +@router.get( + "/runs/{run_id}/results", + description=load_description("assessment/export_run_results.md"), + response_model=APIResponse[list[dict[str, Any]]], + dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], +) +def export_assessment_run_results( + run_id: int, + session: SessionDep, + auth_context: AuthContextDep, + export_format: Literal["json", "csv", "xlsx"] = Query(default="json"), +) -> APIResponse[list[dict[str, Any]]] | StreamingResponse: + """Return flattened results for a single child assessment run.""" + run = get_run_by_id( + session=session, + run_id=run_id, + organization_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + ) + + assessment = get_assessment_by_id( + session=session, + assessment_id=run.assessment_id, + organization_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + ) + + export_rows = sort_export_rows( + load_export_rows_for_run( + session=session, + run=run, + assessment=assessment, + ) + ) + + base_label = assessment.experiment_name if assessment else f"run_{run.id}" + if export_format != "json": + return build_export_response( + export_rows=export_rows, + export_format=export_format, + base_name=f"{base_label}_run_{run.id}_results", + ) + + return APIResponse.success_response(data=build_json_export_rows(export_rows)) diff --git a/backend/app/api/routes/config/config.py b/backend/app/api/routes/config/config.py index 5f819d042..020a91d16 100644 --- a/backend/app/api/routes/config/config.py +++ b/backend/app/api/routes/config/config.py @@ -1,19 +1,19 @@ from uuid import UUID -from fastapi import APIRouter, Depends, Query, HTTPException -from app.api.deps import SessionDep, AuthContextDep +from fastapi import APIRouter, Depends, Query + +from app.api.deps import AuthContextDep, SessionDep +from app.api.permissions import Permission, require_permission from app.crud.config import ConfigCrud from app.models import ( - Config, ConfigCreate, - ConfigUpdate, ConfigPublic, + ConfigUpdate, ConfigWithVersion, - ConfigVersion, Message, ) +from app.models.config.config import ConfigTag from app.utils import APIResponse, load_description -from app.api.permissions import Permission, require_permission router = APIRouter() @@ -56,14 +56,27 @@ def list_configs( query: str | None = Query(None, description="search query"), skip: int = Query(0, ge=0, description="Number of records to skip"), limit: int = Query(100, ge=1, le=100, description="Maximum records to return"), + tag: ConfigTag = Query( + ConfigTag.DEFAULT, + description=( + "Config scope. Use 'default' for general configs or 'ASSESSMENT' " + "for assessment configs. " + "Supported values: 'default', 'ASSESSMENT'." + ), + ), ) -> APIResponse[list[ConfigPublic]]: """ List all configurations for the current project. Ordered by updated_at in descending order. 
""" config_crud = ConfigCrud(session=session, project_id=current_user.project_.id) - configs, has_more = config_crud.read_all(query=query, skip=skip, limit=limit) - return APIResponse.success_response(data=configs, metadata=dict(has_more=has_more)) + configs, has_more = config_crud.read_all( + query=query, + skip=skip, + limit=limit, + tag=tag, + ) + return APIResponse.success_response(data=configs, metadata={"has_more": has_more}) @router.get( diff --git a/backend/app/api/routes/config/version.py b/backend/app/api/routes/config/version.py index fd5e057f7..944df59c3 100644 --- a/backend/app/api/routes/config/version.py +++ b/backend/app/api/routes/config/version.py @@ -1,16 +1,18 @@ from uuid import UUID -from fastapi import APIRouter, Depends, Query, HTTPException, Path -from app.api.deps import SessionDep, AuthContextDep -from app.crud.config import ConfigCrud, ConfigVersionCrud +from fastapi import APIRouter, Depends, Path, Query + +from app.api.deps import AuthContextDep, SessionDep +from app.api.permissions import Permission, require_permission +from app.crud.config import ConfigVersionCrud from app.models import ( - ConfigVersionUpdate, + ConfigVersionItems, ConfigVersionPublic, + ConfigVersionUpdate, Message, - ConfigVersionItems, ) +from app.models.config.config import ConfigTag from app.utils import APIResponse, load_description -from app.api.permissions import Permission, require_permission router = APIRouter() @@ -27,6 +29,13 @@ def create_version( version_create: ConfigVersionUpdate, current_user: AuthContextDep, session: SessionDep, + tag: ConfigTag = Query( + ConfigTag.DEFAULT, + description=( + "Config scope. Use 'default' for general configs or 'ASSESSMENT' " + "for assessment configs." + ), + ), ): """ Create a new version for an existing configuration. @@ -36,7 +45,10 @@ def create_version( Type is inherited from existing config and cannot be changed. """ version_crud = ConfigVersionCrud( - session=session, project_id=current_user.project_.id, config_id=config_id + session=session, + project_id=current_user.project_.id, + config_id=config_id, + tag=tag, ) version = version_crud.create_or_raise(version_create=version_create) @@ -58,13 +70,23 @@ def list_versions( session: SessionDep, skip: int = Query(0, ge=0, description="Number of records to skip"), limit: int = Query(100, ge=1, le=100, description="Maximum records to return"), + tag: ConfigTag = Query( + ConfigTag.DEFAULT, + description=( + "Config scope. Use 'default' for general configs or 'ASSESSMENT' " + "for assessment configs." + ), + ), ): """ List all versions for a specific configuration. Ordered by version number in descending order. """ version_crud = ConfigVersionCrud( - session=session, project_id=current_user.project_.id, config_id=config_id + session=session, + project_id=current_user.project_.id, + config_id=config_id, + tag=tag, ) versions = version_crud.read_all( skip=skip, @@ -89,12 +111,22 @@ def get_version( version_number: int = Path( ..., ge=1, description="The version number of the config" ), + tag: ConfigTag = Query( + ConfigTag.DEFAULT, + description=( + "Config scope. Use 'default' for general configs or 'ASSESSMENT' " + "for assessment configs." + ), + ), ): """ Get a specific version of a config. 
""" version_crud = ConfigVersionCrud( - session=session, project_id=current_user.project_.id, config_id=config_id + session=session, + project_id=current_user.project_.id, + config_id=config_id, + tag=tag, ) version = version_crud.exists_or_raise(version_number=version_number) return APIResponse.success_response( @@ -116,12 +148,22 @@ def delete_version( version_number: int = Path( ..., ge=1, description="The version number of the config" ), + tag: ConfigTag = Query( + ConfigTag.DEFAULT, + description=( + "Config scope. Use 'default' for general configs or 'ASSESSMENT' " + "for assessment configs." + ), + ), ): """ Delete a specific version of a config. """ version_crud = ConfigVersionCrud( - session=session, project_id=current_user.project_.id, config_id=config_id + session=session, + project_id=current_user.project_.id, + config_id=config_id, + tag=tag, ) version_crud.delete_or_raise(version_number=version_number) diff --git a/backend/app/api/routes/cron.py b/backend/app/api/routes/cron.py index 819185fce..6c7c6b031 100644 --- a/backend/app/api/routes/cron.py +++ b/backend/app/api/routes/cron.py @@ -1,14 +1,13 @@ import logging import sentry_sdk -from sentry_sdk.types import MonitorConfig - -from app.api.permissions import Permission, require_permission from fastapi import APIRouter, Depends +from sentry_sdk.types import MonitorConfig from app.api.deps import SessionDep +from app.api.permissions import Permission, require_permission from app.core.config import settings -from app.crud.evaluations import process_all_pending_evaluations_sync +from app.crud.evaluations import process_all_pending_evaluations logger = logging.getLogger(__name__) @@ -43,7 +42,7 @@ monitor_slug="evaluation-cron-job", monitor_config=EVALUATION_CRON_MONITOR_CONFIG, ) -def evaluation_cron_job( +async def evaluation_cron_job( session: SessionDep, ) -> dict: """ @@ -53,7 +52,8 @@ def evaluation_cron_job( 1. Fetches all evaluation runs with status='processing' 2. Groups them by project_id 3. Processes each project with its OpenAI/Langfuse clients - 4. Returns aggregated results + 4. Also polls pending assessment evaluations + 5. Returns aggregated results Hidden from Swagger documentation. Requires authentication via FIRST_SUPERUSER credentials. 
@@ -61,7 +61,35 @@ def evaluation_cron_job( logger.info("[evaluation_cron_job] Cron job invoked") try: - result = process_all_pending_evaluations_sync(session=session) + # Process all pending evaluations across all organizations + result = await process_all_pending_evaluations(session=session) + + try: + from app.crud.assessment.cron import ( + poll_all_pending_assessment_evaluations, + ) + + assessment_result = await poll_all_pending_assessment_evaluations( + session=session + ) + + # Merge assessment results into the main result + result["assessment"] = assessment_result + result["total_processed"] = result.get( + "total_processed", 0 + ) + assessment_result.get("processed", 0) + result["total_failed"] = result.get( + "total_failed", 0 + ) + assessment_result.get("failed", 0) + result["total_still_processing"] = result.get( + "total_still_processing", 0 + ) + assessment_result.get("still_processing", 0) + except Exception as ae: + logger.error( + f"[evaluation_cron_job] Assessment polling failed: {ae}", + exc_info=True, + ) + result["assessment_error"] = str(ae) logger.info( f"[evaluation_cron_job] Completed: " diff --git a/backend/app/core/batch/gemini.py b/backend/app/core/batch/gemini.py index f8a61fc6b..f121bb9ee 100644 --- a/backend/app/core/batch/gemini.py +++ b/backend/app/core/batch/gemini.py @@ -114,7 +114,9 @@ def create_batch( try: # Create JSONL content - jsonl_content = "\n".join(json.dumps(item) for item in jsonl_data) + jsonl_content = "\n".join( + json.dumps(item, ensure_ascii=False) for item in jsonl_data + ) # Upload JSONL file to Gemini File API uploaded_file = self.upload_file(jsonl_content, purpose="batch") diff --git a/backend/app/core/batch/openai.py b/backend/app/core/batch/openai.py index 77a9e8235..9a406ee3d 100644 --- a/backend/app/core/batch/openai.py +++ b/backend/app/core/batch/openai.py @@ -57,7 +57,9 @@ def create_batch( try: # Step 1: Upload file file_id = self.upload_file( - content="\n".join([json.dumps(line) for line in jsonl_data]), + content="\n".join( + json.dumps(line, ensure_ascii=False) for line in jsonl_data + ), purpose="batch", ) diff --git a/backend/app/crud/assessment/__init__.py b/backend/app/crud/assessment/__init__.py new file mode 100644 index 000000000..cd71bff91 --- /dev/null +++ b/backend/app/crud/assessment/__init__.py @@ -0,0 +1,46 @@ +"""Assessment-related CRUD operations.""" + +from app.crud.assessment.core import ( + build_run_stats, + compute_run_counts, + create_assessment, + create_assessment_run, + derive_aggregate_error, + derive_assessment_status, + get_assessment_by_id, + get_assessment_run_by_id, + get_assessment_runs_for_assessment, + list_assessment_runs, + list_assessments, + recompute_assessment_status, + update_assessment_run_status, +) +from app.crud.assessment.dataset import ( + create_assessment_dataset, + delete_assessment_dataset, + get_assessment_dataset_by_id, + list_assessment_datasets, +) +from app.models.assessment import AssessmentRunCounts, AssessmentRunStat + +__all__ = [ + "AssessmentRunCounts", + "AssessmentRunStat", + "build_run_stats", + "compute_run_counts", + "create_assessment_dataset", + "create_assessment", + "create_assessment_run", + "delete_assessment_dataset", + "derive_aggregate_error", + "derive_assessment_status", + "get_assessment_by_id", + "get_assessment_dataset_by_id", + "get_assessment_run_by_id", + "get_assessment_runs_for_assessment", + "list_assessment_runs", + "list_assessment_datasets", + "list_assessments", + "recompute_assessment_status", + "update_assessment_run_status", +] 
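Reviewer note: the `ensure_ascii=False` changes in the OpenAI and Gemini batch providers above keep non-ASCII dataset text readable (and smaller) in the generated JSONL instead of escaping every character. A standard-library illustration:

```python
import json

record = {"text": "ग्राहक शिकायत: बिल गलत है"}

# Default serialisation escapes each non-ASCII character.
print(json.dumps(record))
# {"text": "\u0917\u094d\u0930\u093e\u0939\u0915 ..."}

# With ensure_ascii=False the JSONL line stays human-readable, which matters
# when prompts are built from non-English dataset rows.
print(json.dumps(record, ensure_ascii=False))
# {"text": "ग्राहक शिकायत: बिल गलत है"}
```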
diff --git a/backend/app/crud/assessment/batch.py b/backend/app/crud/assessment/batch.py new file mode 100644 index 000000000..7b5966d5e --- /dev/null +++ b/backend/app/crud/assessment/batch.py @@ -0,0 +1,500 @@ +"""Assessment batch JSONL construction and submission. + +Builds provider-specific JSONL files from dataset rows + config, +then submits them via the core batch infrastructure. +""" + +import csv +import io +import logging +from typing import Any + +import openpyxl +from openpyxl.utils.exceptions import InvalidFileException +from sqlmodel import Session + +from app.core.batch import BATCH_KEY, start_batch_job +from app.core.batch.openai import OpenAIBatchProvider +from app.core.cloud import get_cloud_storage +from app.models.assessment import ( + Assessment, + AssessmentAttachment, + AssessmentRun, +) +from app.models.batch_job import BatchJob, BatchJobType +from app.models.evaluation import EvaluationDataset +from app.models.llm.request import ConfigBlob +from app.services.assessment.mappers import ( + map_kaapi_to_google_params, + map_kaapi_to_openai_params, + normalize_llm_text, +) +from app.services.assessment.utils.attachments import ( + resolve_attachment_values, + resolve_image_mime_and_payload, + split_attachment_urls, + split_data_url, + to_direct_attachment_url, +) +from app.services.llm.providers.registry import LLMProvider + +logger = logging.getLogger(__name__) + + +def _load_dataset_rows( + session: Session, + dataset: EvaluationDataset, +) -> list[dict[str, str]]: + """Load dataset rows from object store. + + Returns a list of dicts (one per row) with column-name keys. + """ + if not dataset.object_store_url: + raise ValueError(f"Dataset {dataset.id} has no object_store_url") + + storage = get_cloud_storage(session=session, project_id=dataset.project_id) + + # Download the file content via stream() + body = storage.stream(dataset.object_store_url) + file_content = body.read() + if not file_content: + raise ValueError(f"Failed to download dataset from {dataset.object_store_url}") + + metadata = dataset.dataset_metadata or {} + file_ext = metadata.get("file_extension", ".csv") + + if file_ext == ".xls": + raise ValueError( + "Legacy Excel format (.xls) is not supported. Please upload .xlsx or .csv." 
+ ) + if file_ext == ".xlsx": + return _parse_excel_rows(file_content) + return _parse_csv_rows(file_content) + + +def _parse_csv_rows(content: bytes) -> list[dict[str, str]]: + """Parse CSV content into list of row dicts.""" + for encoding in ("utf-8-sig", "utf-8", "latin-1"): + try: + text = content.decode(encoding) + break + except (UnicodeDecodeError, ValueError): + continue + else: + text = content.decode("utf-8", errors="replace") + + reader = csv.DictReader(io.StringIO(text)) + return [row for row in reader if any(v and v.strip() for v in row.values())] + + +def _parse_excel_rows(content: bytes) -> list[dict[str, str]]: + """Parse Excel content into list of row dicts.""" + wb = None + try: + wb = openpyxl.load_workbook(io.BytesIO(content), read_only=True, data_only=True) + ws = wb.active + if ws is None: + return [] + + rows_iter = ws.iter_rows(values_only=True) + header = next(rows_iter, None) + if header is None: + return [] + + columns = [ + str(col_header) if col_header is not None else f"col_{idx}" + for idx, col_header in enumerate(header) + ] + result = [] + for row in rows_iter: + if row and any(cell is not None for cell in row): + row_dict = { + columns[idx]: str(cell) if cell is not None else "" + for idx, cell in enumerate(row) + if idx < len(columns) + } + result.append(row_dict) + + return result + except InvalidFileException as e: + logger.warning("[_parse_excel_rows] Invalid XLSX file content: %s", e) + raise + except Exception as e: + logger.warning( + "[_parse_excel_rows] Failed to parse XLSX rows | %s", e, exc_info=True + ) + raise ValueError("Failed to parse XLSX dataset rows") from e + finally: + if wb is not None: + wb.close() + + +def _build_text_prompt( + row: dict[str, str], + text_columns: list[str], + prompt_template: str | None, +) -> str: + """Build the text prompt for a single row. + + If prompt_template is provided, placeholders like {column_name} are replaced. + Otherwise, all text column values are concatenated with newlines. + """ + if prompt_template: + prompt = normalize_llm_text(prompt_template) + for col in text_columns: + placeholder = "{" + col + "}" + prompt = prompt.replace(placeholder, normalize_llm_text(row.get(col, ""))) + return prompt + + # No template: concatenate text columns + parts = [ + normalize_llm_text(row.get(col, "")) + for col in text_columns + if row.get(col, "").strip() + ] + return "\n".join(parts) + + +def build_openai_jsonl( + rows: list[dict[str, str]], + text_columns: list[str], + attachments: list[AssessmentAttachment], + prompt_template: str | None, + openai_params: dict, +) -> list[dict[str, Any]]: + """Build OpenAI batch JSONL data from dataset rows. 
+ + Each line follows the OpenAI batch format: + { + "custom_id": "row_0", + "method": "POST", + "url": "/v1/responses", + "body": { model, instructions, temperature, input: [{role, content: [...]}] } + } + """ + jsonl_data = [] + + for idx, row in enumerate(rows): + # Build input array + input_parts: list[dict[str, Any]] = [] + + # Text prompt + text_prompt = _build_text_prompt(row, text_columns, prompt_template) + if text_prompt.strip(): + input_parts.append({"type": "input_text", "text": text_prompt}) + + # Attachments + for att in attachments: + cell_value = row.get(att.column, "") + input_parts.extend(resolve_attachment_values(cell_value, att)) + + if not input_parts: + logger.warning("[build_openai_jsonl] Skipping empty row | idx=%s", idx) + continue + + # Build body from mapped params + body = dict(openai_params) + body["input"] = [ + { + "role": "user", + "content": input_parts, + } + ] + + jsonl_data.append( + { + BATCH_KEY: f"row_{idx}", + "method": "POST", + "url": "/v1/responses", + "body": body, + } + ) + + return jsonl_data + + +def build_google_jsonl( + rows: list[dict[str, str]], + text_columns: list[str], + attachments: list[AssessmentAttachment], + prompt_template: str | None, + google_params: dict, +) -> list[dict[str, Any]]: + """Build Google (Gemini) batch JSONL data from dataset rows. + + Each line follows the Gemini batch format: + { + "key": "row_0", + "request": { "contents": [{ "parts": [...], "role": "user" }] } + } + """ + jsonl_data = [] + + for idx, row in enumerate(rows): + parts: list[dict[str, Any]] = [] + + # Text prompt + text_prompt = _build_text_prompt(row, text_columns, prompt_template) + if text_prompt.strip(): + parts.append({"text": text_prompt}) + + # Attachments (Gemini uses file_data for inline content) + for att in attachments: + cell_value = row.get(att.column, "").strip() + if not cell_value: + continue + + cell_values = ( + split_attachment_urls(cell_value) + if att.format == "url" + else [cell_value] + ) + + for item_value in cell_values: + normalized_value = ( + to_direct_attachment_url(item_value, att.type) + if att.format == "url" + else item_value + ) + if att.type == "image": + mime_type, payload = resolve_image_mime_and_payload( + normalized_value, + att.format, + ) + if att.format == "url": + parts.append( + { + "fileData": { + "mimeType": mime_type, + "fileUri": normalized_value, + } + } + ) + else: + parts.append( + { + "inlineData": { + "mimeType": mime_type, + "data": payload, + } + } + ) + elif att.type == "pdf": + if att.format == "url": + parts.append( + { + "fileData": { + "mimeType": "application/pdf", + "fileUri": normalized_value, + } + } + ) + else: + parts.append( + { + "inlineData": { + "mimeType": "application/pdf", + "data": split_data_url(normalized_value)[1], + } + } + ) + + if not parts: + logger.warning("[build_google_jsonl] Skipping empty row | idx=%s", idx) + continue + + system_instruction = google_params.get("instructions") + request: dict[str, Any] = { + "contents": [{"parts": parts, "role": "user"}], + } + if system_instruction: + request["systemInstruction"] = {"parts": [{"text": system_instruction}]} + + generation_config: dict[str, Any] = {} + temperature = google_params.get("temperature") + if temperature is not None: + generation_config["temperature"] = temperature + top_p = google_params.get("top_p") + if top_p is not None: + generation_config["topP"] = top_p + max_output_tokens = google_params.get("max_output_tokens") + if max_output_tokens is not None: + generation_config["maxOutputTokens"] = 
max_output_tokens + thinking_config = google_params.get("thinking_config") + if thinking_config: + generation_config["thinkingConfig"] = thinking_config + output_schema = google_params.get("output_schema") + if output_schema: + generation_config["responseMimeType"] = "application/json" + generation_config["responseSchema"] = output_schema + if generation_config: + request["generationConfig"] = generation_config + + jsonl_data.append( + { + "metadata": {"key": f"row_{idx}"}, + "request": request, + } + ) + + return jsonl_data + + +def submit_assessment_batch( + session: Session, + run: AssessmentRun, + assessment: Assessment, + dataset: EvaluationDataset, + config_blob: ConfigBlob, + assessment_input: dict[str, Any], + organization_id: int, + project_id: int, +) -> BatchJob: + """Build JSONL and submit a batch for one assessment run. + + Args: + session: Database session + run: The AssessmentRun to process + dataset: The dataset to read rows from + config_blob: Resolved configuration blob + assessment_input: Assessment input config (prompt_template, text_columns, etc.) + organization_id: Organization ID + project_id: Project ID + + Returns: + Created BatchJob record + """ + text_columns = assessment_input.get("text_columns", []) + prompt_template = assessment_input.get("prompt_template") + system_instruction = assessment_input.get("system_instruction") + attachments_raw = assessment_input.get("attachments", []) + output_schema = assessment_input.get("output_schema") + attachments = [AssessmentAttachment(**a) for a in attachments_raw] + + # Load dataset rows + rows = _load_dataset_rows(session, dataset) + if not rows: + raise ValueError(f"Dataset {dataset.id} has no rows") + + logger.info( + "[submit_assessment_batch] Building JSONL | run_id=%s | rows=%s | provider=%s", + run.id, + len(rows), + config_blob.completion.provider, + ) + + # Determine provider and build params + completion = config_blob.completion + provider_name = completion.provider or "openai" + + params = dict(completion.params) + params.pop("instructions", None) + params.pop("system_instruction", None) + if isinstance(system_instruction, str) and system_instruction.strip(): + params["instructions"] = system_instruction + if output_schema: + params["output_schema"] = output_schema + + # Determine the base provider (openai or google) + base_provider = provider_name.replace("-native", "") + + if base_provider == LLMProvider.OPENAI: + mapped_params, warnings = map_kaapi_to_openai_params( + session=session, + kaapi_params=params, + ) + if warnings: + logger.info("[submit_assessment_batch] Mapper warnings: %s", warnings) + + jsonl_data = build_openai_jsonl( + rows=rows, + text_columns=text_columns, + attachments=attachments, + prompt_template=prompt_template, + openai_params=mapped_params, + ) + + # Get OpenAI client and submit + from app.utils import get_openai_client + + openai_client = get_openai_client( + session=session, + org_id=organization_id, + project_id=project_id, + ) + provider = OpenAIBatchProvider(client=openai_client) + + batch_config = { + "endpoint": "/v1/responses", + "description": f"Assessment: {assessment.experiment_name}", + "completion_window": "24h", + } + + batch_job = start_batch_job( + session=session, + provider=provider, + provider_name="openai", + job_type=BatchJobType.ASSESSMENT, + organization_id=organization_id, + project_id=project_id, + jsonl_data=jsonl_data, + config=batch_config, + ) + + elif base_provider == LLMProvider.GOOGLE: + mapped_params, warnings = map_kaapi_to_google_params(params) + 
if warnings: + logger.info("[submit_assessment_batch] Mapper warnings: %s", warnings) + + jsonl_data = build_google_jsonl( + rows=rows, + text_columns=text_columns, + attachments=attachments, + prompt_template=prompt_template, + google_params=mapped_params, + ) + + # Get Gemini client and submit + from app.core.batch import GeminiBatchProvider + from app.core.batch.client import GeminiClient + + gemini_client = GeminiClient.from_credentials( + session=session, + org_id=organization_id, + project_id=project_id, + ) + provider = GeminiBatchProvider( + client=gemini_client.client, + model=f"models/{mapped_params.get('model', 'gemini-2.5-pro')}", + ) + + batch_config = { + "display_name": f"assessment-{assessment.experiment_name}", + "model": f"models/{mapped_params.get('model', 'gemini-2.5-pro')}", + } + + batch_job = start_batch_job( + session=session, + provider=provider, + provider_name="google", + job_type=BatchJobType.ASSESSMENT, + organization_id=organization_id, + project_id=project_id, + jsonl_data=jsonl_data, + config=batch_config, + ) + + else: + raise ValueError( + f"Unsupported provider for assessment batches: {provider_name}" + ) + + logger.info( + "[submit_assessment_batch] Submitted batch | run_id=%s | batch_job_id=%s | provider=%s | items=%s", + run.id, + batch_job.id, + base_provider, + len(jsonl_data), + ) + + return batch_job diff --git a/backend/app/crud/assessment/core.py b/backend/app/crud/assessment/core.py new file mode 100644 index 000000000..c91626660 --- /dev/null +++ b/backend/app/crud/assessment/core.py @@ -0,0 +1,319 @@ +"""Assessment CRUD — operations for Assessment and AssessmentRun tables.""" + +import logging +from typing import Any +from uuid import UUID + +from fastapi import HTTPException +from sqlmodel import Session, select + +from app.core.util import now +from app.models.assessment import ( + Assessment, + AssessmentRun, + AssessmentRunCounts, + AssessmentRunStat, +) + +logger = logging.getLogger(__name__) + + +def create_assessment( + session: Session, + experiment_name: str, + dataset_id: int, + organization_id: int, + project_id: int, +) -> Assessment: + """Create a parent assessment row.""" + assessment = Assessment( + experiment_name=experiment_name, + dataset_id=dataset_id, + status="pending", + organization_id=organization_id, + project_id=project_id, + inserted_at=now(), + updated_at=now(), + ) + + session.add(assessment) + try: + session.commit() + session.refresh(assessment) + except Exception as e: + session.rollback() + logger.error(f"[create_assessment] Failed: {e}", exc_info=True) + raise + + logger.info( + f"[create_assessment] Created assessment id={assessment.id} | " + f"experiment={experiment_name}" + ) + return assessment + + +def get_assessment_by_id( + session: Session, + assessment_id: int, + organization_id: int, + project_id: int, +) -> Assessment: + """Get a specific parent assessment row.""" + statement = ( + select(Assessment) + .where(Assessment.id == assessment_id) + .where(Assessment.organization_id == organization_id) + .where(Assessment.project_id == project_id) + ) + assessment = session.exec(statement).first() + if not assessment: + raise HTTPException( + status_code=404, + detail=f"Assessment {assessment_id} not found or not accessible", + ) + return assessment + + +def list_assessments( + session: Session, + organization_id: int, + project_id: int, + limit: int = 50, + offset: int = 0, +) -> list[Assessment]: + """List parent assessment rows.""" + statement = ( + select(Assessment) + .where(Assessment.organization_id 
== organization_id) + .where(Assessment.project_id == project_id) + .order_by(Assessment.inserted_at.desc()) + .limit(limit) + .offset(offset) + ) + return list(session.exec(statement).all()) + + +def create_assessment_run( + session: Session, + assessment_id: int, + config_id: UUID, + config_version: int, + assessment_input: dict[str, Any], +) -> AssessmentRun: + """Create an assessment run record under a parent assessment.""" + run = AssessmentRun( + assessment_id=assessment_id, + config_id=config_id, + config_version=config_version, + status="pending", + total_items=0, + input=assessment_input, + inserted_at=now(), + updated_at=now(), + ) + + session.add(run) + try: + session.commit() + session.refresh(run) + except Exception as e: + session.rollback() + logger.error(f"[create_assessment_run] Failed: {e}", exc_info=True) + raise + + logger.info( + f"[create_assessment_run] Created run id={run.id} | " + f"assessment_id={assessment_id} | " + f"config_id={config_id} v{config_version}" + ) + return run + + +def get_assessment_run_by_id( + session: Session, + run_id: int, + organization_id: int, + project_id: int, +) -> AssessmentRun: + """Get a specific assessment run by ID, scoped via parent organization/project.""" + statement = ( + select(AssessmentRun) + .join(Assessment, Assessment.id == AssessmentRun.assessment_id) + .where(AssessmentRun.id == run_id) + .where(Assessment.organization_id == organization_id) + .where(Assessment.project_id == project_id) + ) + run = session.exec(statement).first() + if not run: + raise HTTPException( + status_code=404, + detail=f"Assessment run {run_id} not found or not accessible", + ) + return run + + +def get_assessment_runs_for_assessment( + session: Session, + assessment_id: int, +) -> list[AssessmentRun]: + """List child runs for a parent assessment, ordered by id.""" + statement = ( + select(AssessmentRun) + .where(AssessmentRun.assessment_id == assessment_id) + .order_by(AssessmentRun.id.asc()) + ) + return list(session.exec(statement).all()) + + +def list_assessment_runs( + session: Session, + organization_id: int, + project_id: int, + assessment_id: int | None = None, + limit: int = 50, + offset: int = 0, +) -> list[AssessmentRun]: + """List assessment runs, optionally filtered by assessment_id.""" + statement = ( + select(AssessmentRun) + .join(Assessment, Assessment.id == AssessmentRun.assessment_id) + .where(Assessment.organization_id == organization_id) + .where(Assessment.project_id == project_id) + ) + if assessment_id is not None: + statement = statement.where(AssessmentRun.assessment_id == assessment_id) + + statement = ( + statement.order_by(AssessmentRun.inserted_at.desc()).limit(limit).offset(offset) + ) + return list(session.exec(statement).all()) + + +def update_assessment_run_status( + session: Session, + run: AssessmentRun, + status: str, + error_message: str | None = None, + batch_job_id: int | None = None, + total_items: int | None = None, + object_store_url: str | None = None, +) -> AssessmentRun: + """Update an assessment run's status and optional fields.""" + run.status = status + run.updated_at = now() + + if error_message is not None: + run.error_message = error_message + if batch_job_id is not None: + run.batch_job_id = batch_job_id + if total_items is not None: + run.total_items = total_items + if object_store_url is not None: + run.object_store_url = object_store_url + + session.add(run) + try: + session.commit() + session.refresh(run) + except Exception as e: + session.rollback() + 
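+        # Roll back the partial update before logging so the session remains usable
+        # by callers that catch the re-raised exception.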
logger.error(f"[update_assessment_run_status] Failed: {e}", exc_info=True) + raise + + return run + + +def compute_run_counts(runs: list[AssessmentRun]) -> AssessmentRunCounts: + """Aggregate child run statuses into counters.""" + return AssessmentRunCounts( + total=len(runs), + pending=sum(1 for run in runs if run.status == "pending"), + processing=sum( + 1 for run in runs if run.status in {"processing", "in_progress"} + ), + completed=sum(1 for run in runs if run.status == "completed"), + failed=sum(1 for run in runs if run.status == "failed"), + ) + + +def derive_assessment_status(counts: AssessmentRunCounts) -> str: + """Compute parent assessment status from child run counters.""" + if counts.total == 0: + return "pending" + if counts.completed == counts.total: + return "completed" + if counts.failed == counts.total: + return "failed" + if ( + counts.completed > 0 + and counts.failed > 0 + and counts.pending == 0 + and counts.processing == 0 + ): + return "completed_with_errors" + if counts.pending > 0 and counts.pending == counts.total: + return "pending" + return "processing" + + +def build_run_stats(runs: list[AssessmentRun]) -> list[AssessmentRunStat]: + """Build per-run summary entries for embedding in parent responses.""" + return [ + AssessmentRunStat( + run_id=run.id, + config_id=str(run.config_id) if run.config_id else None, + config_version=run.config_version, + status=run.status, + total_items=run.total_items, + error_message=run.error_message, + updated_at=run.updated_at, + ) + for run in runs + ] + + +def derive_aggregate_error(counts: AssessmentRunCounts) -> str | None: + """Build an aggregate error summary string for parent assessments.""" + if counts.failed > 0: + return f"{counts.failed} of {counts.total} run(s) failed" + return None + + +def recompute_assessment_status( + session: Session, + assessment_id: int, + organization_id: int | None = None, + project_id: int | None = None, +) -> Assessment: + """Recompute the parent's `status` from its child runs. + + Counters and run_stats are derived on-read; only `status` is persisted so + cron's `WHERE status IN (...)` filter remains index-friendly. 
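+
+    For example (illustrative): child runs in states {completed, completed, failed}
+    with nothing pending or processing derive to "completed_with_errors", which is
+    the status persisted here.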
+ """ + if organization_id is None and project_id is None: + assessment = session.get(Assessment, assessment_id) + else: + statement = select(Assessment).where(Assessment.id == assessment_id) + if organization_id is not None: + statement = statement.where(Assessment.organization_id == organization_id) + if project_id is not None: + statement = statement.where(Assessment.project_id == project_id) + assessment = session.exec(statement).first() + if not assessment: + raise ValueError(f"Assessment {assessment_id} not found") + + runs = get_assessment_runs_for_assessment(session, assessment_id) + counts = compute_run_counts(runs) + assessment.status = derive_assessment_status(counts) + assessment.updated_at = now() + + session.add(assessment) + try: + session.commit() + session.refresh(assessment) + except Exception as e: + session.rollback() + logger.error(f"[recompute_assessment_status] Failed: {e}", exc_info=True) + raise + + return assessment diff --git a/backend/app/crud/assessment/cron.py b/backend/app/crud/assessment/cron.py new file mode 100644 index 000000000..c69b3157e --- /dev/null +++ b/backend/app/crud/assessment/cron.py @@ -0,0 +1,177 @@ +"""Cron processing functions for assessment evaluations.""" + +import logging +from typing import Any + +from sqlmodel import Session, select + +from app.crud.assessment import ( + compute_run_counts, + get_assessment_runs_for_assessment, + recompute_assessment_status, + update_assessment_run_status, +) +from app.crud.assessment.processing import ( + check_and_process_assessment, + format_assessment_failure_message, +) +from app.models.assessment import Assessment, AssessmentRun + +logger = logging.getLogger(__name__) + + +def _log_config_progress( + result: dict[str, Any], run: AssessmentRun, assessment: Assessment +) -> None: + """Emit explicit config-level logs for grouped assessment experiments.""" + action = result.get("action") + if action not in {"processed", "failed"}: + return + + logger.info( + "[poll_all_pending_assessment_evaluations] Experiment config update | " + "experiment=%s | assessment_id=%s | run_id=%s | config_id=%s | " + "config_version=%s | action=%s | status=%s | provider_status=%s", + assessment.experiment_name, + run.assessment_id, + run.id, + run.config_id, + run.config_version, + action, + result.get("current_status"), + result.get("provider_status"), + ) + + +async def poll_all_pending_assessment_evaluations( + session: Session, +) -> dict[str, Any]: + """Poll all non-terminal parent assessments and their active child runs.""" + statement = select(Assessment).where( + Assessment.status.in_(("pending", "processing")), + ) + pending_assessments = list(session.exec(statement).all()) + + if not pending_assessments: + logger.info( + "[poll_all_pending_assessment_evaluations] " "No active assessments found" + ) + return { + "total": 0, + "processed": 0, + "failed": 0, + "still_processing": 0, + "details": [], + } + + logger.info( + "[poll_all_pending_assessment_evaluations] Found %s active assessments", + len(pending_assessments), + ) + + all_results: list[dict[str, Any]] = [] + processed = 0 + failed = 0 + still_processing = 0 + + for assessment in pending_assessments: + runs = get_assessment_runs_for_assessment( + session=session, assessment_id=assessment.id + ) + active_runs = [run for run in runs if run.status == "processing"] + + if not active_runs: + refreshed = recompute_assessment_status( + session=session, assessment_id=assessment.id + ) + counts = compute_run_counts(runs) + logger.info( + 
"[poll_all_pending_assessment_evaluations] No active runs for assessment %s | " + "recomputed status=%s | total_runs=%s | completed=%s | failed=%s", + assessment.id, + refreshed.status, + counts.total, + counts.completed, + counts.failed, + ) + if refreshed.status in {"pending", "processing"}: + still_processing += 1 + continue + + for run in active_runs: + try: + result = await check_and_process_assessment( + run=run, + session=session, + ) + all_results.append(result) + _log_config_progress(result, run, assessment) + + if result["action"] == "processed": + processed += 1 + elif result["action"] == "failed": + failed += 1 + else: + still_processing += 1 + + except Exception as e: + error_msg = format_assessment_failure_message(e) + logger.error( + "[poll_all_pending_assessment_evaluations] Failed run %s | " + "experiment=%s | assessment_id=%s | config_id=%s | config_version=%s | error=%s", + run.id, + assessment.experiment_name, + run.assessment_id, + run.config_id, + run.config_version, + error_msg, + exc_info=True, + ) + try: + update_assessment_run_status( + session=session, + run=run, + status="failed", + error_message=error_msg, + ) + recompute_assessment_status( + session=session, assessment_id=assessment.id + ) + failure_result = { + "assessment_id": run.assessment_id, + "run_id": run.id, + "experiment_name": assessment.experiment_name, + "config_id": str(run.config_id) if run.config_id else None, + "config_version": run.config_version, + "action": "failed", + "error": error_msg, + "current_status": "failed", + } + all_results.append(failure_result) + failed += 1 + except Exception as cleanup_exc: + logger.error( + "[poll_all_pending_assessment_evaluations] Cleanup failed for run %s | " + "assessment_id=%s | experiment=%s | error=%s", + run.id, + run.assessment_id, + assessment.experiment_name, + cleanup_exc, + exc_info=True, + ) + failed += 1 + + logger.info( + "[poll_all_pending_assessment_evaluations] Summary | processed=%s | failed=%s | still_processing=%s", + processed, + failed, + still_processing, + ) + + return { + "total": len(pending_assessments), + "processed": processed, + "failed": failed, + "still_processing": still_processing, + "details": all_results, + } diff --git a/backend/app/crud/assessment/dataset.py b/backend/app/crud/assessment/dataset.py new file mode 100644 index 000000000..7562f3814 --- /dev/null +++ b/backend/app/crud/assessment/dataset.py @@ -0,0 +1,162 @@ +"""CRUD operations for assessment datasets.""" + +import logging +from typing import Any + +from fastapi import HTTPException +from sqlalchemy.exc import IntegrityError +from sqlmodel import Session, select + +from app.core.util import now +from app.models.assessment import Assessment +from app.models.evaluation import EvaluationDataset +from app.models.stt_evaluation import EvaluationType + +logger = logging.getLogger(__name__) + + +def create_assessment_dataset( + *, + session: Session, + name: str, + dataset_metadata: dict[str, Any], + organization_id: int, + project_id: int, + description: str | None = None, + object_store_url: str | None = None, +) -> EvaluationDataset: + """Create an assessment dataset backed by the shared evaluation_dataset table.""" + dataset = EvaluationDataset( + name=name, + description=description, + type=EvaluationType.ASSESSMENT.value, + dataset_metadata=dataset_metadata, + object_store_url=object_store_url, + langfuse_dataset_id=None, + organization_id=organization_id, + project_id=project_id, + inserted_at=now(), + updated_at=now(), + ) + + try: + session.add(dataset) 
+ session.commit() + session.refresh(dataset) + except IntegrityError as e: + session.rollback() + logger.error( + "[create_assessment_dataset] Dataset name already exists | " + "name=%s | org_id=%s | project_id=%s", + name, + organization_id, + project_id, + exc_info=True, + ) + raise HTTPException( + status_code=409, + detail=( + f"Dataset with name '{name}' already exists in this " + "organization and project. Please choose a different name." + ), + ) from e + except Exception as e: + session.rollback() + logger.error( + "[create_assessment_dataset] Failed to create dataset | name=%s", + name, + exc_info=True, + ) + raise HTTPException( + status_code=500, + detail=f"Failed to save assessment dataset metadata", + ) from e + + logger.info( + "[create_assessment_dataset] Created assessment dataset | " + "id=%s | name=%s | org_id=%s | project_id=%s", + dataset.id, + name, + organization_id, + project_id, + ) + return dataset + + +def get_assessment_dataset_by_id( + *, + session: Session, + dataset_id: int, + organization_id: int, + project_id: int, +) -> EvaluationDataset: + """Fetch an assessment dataset by ID, scoped to organization and project.""" + statement = ( + select(EvaluationDataset) + .where(EvaluationDataset.id == dataset_id) + .where(EvaluationDataset.organization_id == organization_id) + .where(EvaluationDataset.project_id == project_id) + .where(EvaluationDataset.type == EvaluationType.ASSESSMENT.value) + ) + dataset = session.exec(statement).first() + if not dataset: + raise HTTPException( + status_code=404, + detail=f"Dataset {dataset_id} not found or not accessible", + ) + return dataset + + +def list_assessment_datasets( + *, + session: Session, + organization_id: int, + project_id: int, + limit: int = 50, + offset: int = 0, +) -> list[EvaluationDataset]: + """List assessment datasets for an organization and project.""" + statement = ( + select(EvaluationDataset) + .where(EvaluationDataset.organization_id == organization_id) + .where(EvaluationDataset.project_id == project_id) + .where(EvaluationDataset.type == EvaluationType.ASSESSMENT.value) + .order_by(EvaluationDataset.inserted_at.desc()) + .limit(limit) + .offset(offset) + ) + return list(session.exec(statement).all()) + + +def delete_assessment_dataset( + *, session: Session, dataset: EvaluationDataset +) -> str | None: + """Delete an unused assessment dataset.""" + statement = select(Assessment).where(Assessment.dataset_id == dataset.id) + assessments = session.exec(statement).all() + if assessments: + return ( + f"Cannot delete dataset {dataset.id}: it is being used by " + f"{len(assessments)} assessment(s). Please delete the assessments first." + ) + + try: + dataset_id = dataset.id + dataset_name = dataset.name + session.delete(dataset) + session.commit() + except Exception as e: + session.rollback() + logger.error( + "[delete_assessment_dataset] Failed to delete dataset | dataset_id=%s", + dataset.id, + exc_info=True, + ) + return f"Failed to delete dataset: {e}" + + logger.info( + "[delete_assessment_dataset] Deleted assessment dataset | id=%s | name=%s", + dataset_id, + dataset_name, + ) + return None diff --git a/backend/app/crud/assessment/processing.py b/backend/app/crud/assessment/processing.py new file mode 100644 index 000000000..46d442354 --- /dev/null +++ b/backend/app/crud/assessment/processing.py @@ -0,0 +1,495 @@ +"""Assessment batch result processing and polling. + +processing but adapted for multi-provider (OpenAI + Google) support. 
+""" + +import json +import logging +from typing import Any + +from fastapi import HTTPException +from sqlmodel import Session + +from app.core.batch import ( + BATCH_KEY, + GeminiBatchProvider, + OpenAIBatchProvider, + download_batch_results, + poll_batch_status, + upload_batch_results_to_object_store, +) +from app.core.batch.base import BatchProvider +from app.core.batch.client import GeminiClient +from app.core.batch.gemini import BatchJobState, extract_text_from_response_dict +from app.crud.assessment import ( + recompute_assessment_status, + update_assessment_run_status, +) +from app.crud.job import get_batch_job +from app.models.assessment import Assessment, AssessmentRun +from app.services.llm.providers.registry import LLMProvider +from app.utils import get_openai_client + +logger = logging.getLogger(__name__) + + +def format_assessment_failure_message(exc: Exception) -> str: + """Extract a DB-safe error message from assessment polling exceptions.""" + if isinstance(exc, HTTPException): + detail = exc.detail + if isinstance(detail, str): + message = detail.strip() + if message: + return message + elif detail: + try: + return json.dumps(detail, ensure_ascii=False) + except (TypeError, ValueError): + pass + + message = str(exc).strip() + return message or exc.__class__.__name__ + + +def _sanitize_json_output(raw: str) -> str: + """Escape control characters inside JSON string values that the model emitted literally. + + Strict structured-output mode should prevent this, but long Indic-language + responses sometimes contain literal newlines / tabs inside string values, + making the JSON unparseable. This function walks the raw text once and + replaces any bare control characters found while inside a JSON string with + their JSON escape equivalents, producing valid JSON without touching the + surrounding structure. + """ + result: list[str] = [] + in_string = False + escape_next = False + + for ch in raw: + if escape_next: + result.append(ch) + escape_next = False + elif ch == "\\": + result.append(ch) + escape_next = True + elif ch == '"': + in_string = not in_string + result.append(ch) + elif in_string and ch == "\n": + result.append("\\n") + elif in_string and ch == "\r": + result.append("\\r") + elif in_string and ch == "\t": + result.append("\\t") + else: + result.append(ch) + + return "".join(result) + + +def _get_batch_provider( + session: Session, + provider_name: str, + organization_id: int, + project_id: int, +) -> BatchProvider: + """Get the appropriate batch provider instance.""" + if provider_name in (LLMProvider.OPENAI, LLMProvider.OPENAI_NATIVE): + openai_client = get_openai_client( + session=session, + org_id=organization_id, + project_id=project_id, + ) + return OpenAIBatchProvider(client=openai_client) + + if provider_name in (LLMProvider.GOOGLE, LLMProvider.GOOGLE_NATIVE): + gemini_client = GeminiClient.from_credentials( + session=session, + org_id=organization_id, + project_id=project_id, + ) + return GeminiBatchProvider(client=gemini_client.client) + + raise ValueError(f"Unsupported provider for assessment polling: {provider_name}") + + +def parse_assessment_output( + raw_results: list[dict[str, Any]], + provider_name: str, +) -> list[dict[str, Any]]: + """Parse batch results into assessment output format. + + Args: + raw_results: Raw results from batch provider + provider_name: Provider name ('openai' or 'google') + + Returns: + List of parsed results with row_id, output text, usage, etc. 
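+
+        Illustrative entry (shape only):
+        {"row_id": "row_0", "output": "...", "error": None, "usage": {...},
+        "response_id": "..."}; the response_id key is only present for OpenAI
+        results.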
+ """ + results = [] + + for result in raw_results: + row_id = result.get(BATCH_KEY) or result.get("key", "unknown") + + if provider_name in (LLMProvider.OPENAI, LLMProvider.OPENAI_NATIVE): + response = result.get("response", {}) + response_status = response.get("status_code") + response_body = result.get("response", {}).get("body", {}) + error = result.get("error") + + if error: + results.append( + { + "row_id": row_id, + "output": None, + "error": error.get("message", str(error)), + "usage": None, + } + ) + continue + + if response_status and response_status >= 400: + response_error = response_body.get("error", {}) + results.append( + { + "row_id": row_id, + "output": None, + "error": response_error.get( + "message", f"Request failed with status {response_status}" + ), + "usage": None, + "response_id": response_body.get("id"), + } + ) + continue + + # Prefer the convenience field when present; otherwise concatenate all + # output_text fragments so structured JSON isn't truncated mid-object. + generated_text = response_body.get("output_text") or "" + + if not isinstance(generated_text, str) or not generated_text: + output = response_body.get("output", "") + text_chunks: list[str] = [] + + if isinstance(output, list): + for item in output: + if isinstance(item, dict) and item.get("type") == "message": + for content in item.get("content", []): + if ( + isinstance(content, dict) + and content.get("type") == "output_text" + ): + text = content.get("text") + if isinstance(text, str) and text: + text_chunks.append(text) + generated_text = "".join(text_chunks) + elif isinstance(output, str): + generated_text = output + + if generated_text: + try: + generated_text = json.dumps( + json.loads(generated_text), ensure_ascii=False + ) + except (json.JSONDecodeError, TypeError): + # Model emitted literal control characters inside string values. + # Sanitize and retry once. + try: + sanitized = _sanitize_json_output(generated_text) + generated_text = json.dumps( + json.loads(sanitized), ensure_ascii=False + ) + except (json.JSONDecodeError, TypeError): + pass + + results.append( + { + "row_id": row_id, + "output": generated_text, + "error": None if generated_text else "Empty response output", + "usage": response_body.get("usage"), + "response_id": response_body.get("id"), + } + ) + + elif provider_name in (LLMProvider.GOOGLE, LLMProvider.GOOGLE_NATIVE): + response = result.get("response") + error = result.get("error") + + if error: + results.append( + { + "row_id": row_id, + "output": None, + "error": str(error), + "usage": None, + } + ) + continue + + if response: + text = extract_text_from_response_dict(response) + results.append( + { + "row_id": row_id, + "output": text if text else None, + "error": None if text else "Empty response output", + "usage": None, + } + ) + else: + results.append( + { + "row_id": row_id, + "output": None, + "error": "Empty response", + "usage": None, + } + ) + + else: + logger.error( + "[parse_assessment_output] Unknown provider '%s' for row_id=%s — skipping", + provider_name, + row_id, + ) + + logger.info( + "[parse_assessment_output] Parsed %s results | provider=%s", + len(results), + provider_name, + ) + return results + + +async def check_and_process_assessment( + run: AssessmentRun, + session: Session, +) -> dict[str, Any]: + """Check assessment batch status and process if completed. 
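+
+    Polls the provider for batch status; on completion it downloads the results,
+    uploads them to the object store, parses per-row outputs, and recomputes the
+    parent assessment's aggregate status.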
+ + Args: + run: AssessmentRun to check + session: Database session + + Returns: + Dict with status information + """ + log_prefix = f"[check_and_process_assessment][assessment_run={run.id}]" + previous_status = run.status + parent_pre = session.get(Assessment, run.assessment_id) + experiment_name_pre = parent_pre.experiment_name if parent_pre else None + + try: + if not run.batch_job_id: + raise ValueError(f"Assessment run {run.id} has no batch_job_id") + + batch_job = get_batch_job(session=session, batch_job_id=run.batch_job_id) + if not batch_job: + raise ValueError(f"BatchJob {run.batch_job_id} not found") + + parent = parent_pre + if not parent: + raise ValueError(f"Parent assessment {run.assessment_id} not found") + + # Get provider and poll status + provider = _get_batch_provider( + session=session, + provider_name=batch_job.provider, + organization_id=parent.organization_id, + project_id=parent.project_id, + ) + status_result = poll_batch_status( + session=session, + provider=provider, + batch_job=batch_job, + ) + session.refresh(batch_job) + + provider_status = batch_job.provider_status + + if ( + provider_status == "completed" + or provider_status == BatchJobState.SUCCEEDED.value + ): + if not batch_job.provider_output_file_id: + request_counts = status_result.get("request_counts") or {} + error_file_id = status_result.get("error_file_id") + failed_count = request_counts.get("failed", 0) + completed_count = request_counts.get("completed", 0) + total_count = request_counts.get("total", 0) + + if error_file_id and failed_count > 0 and completed_count == 0: + error_msg = ( + f"Batch completed with {failed_count} failed request(s)" + f" and no successful outputs" + ) + if total_count: + error_msg += f" out of {total_count}" + error_msg += f" (error_file_id: {error_file_id})" + + update_assessment_run_status( + session=session, + run=run, + status="failed", + error_message=error_msg, + ) + recompute_assessment_status( + session=session, assessment_id=run.assessment_id + ) + + return { + "run_id": run.id, + "assessment_id": run.assessment_id, + "experiment_name": experiment_name_pre, + "previous_status": previous_status, + "current_status": "failed", + "provider_status": provider_status, + "action": "failed", + "error": error_msg, + } + + logger.info( + f"{log_prefix} Batch completed but output file is not ready yet | " + f"batch_job_id={batch_job.id} | provider_status={provider_status}" + ) + return { + "run_id": run.id, + "assessment_id": run.assessment_id, + "experiment_name": experiment_name_pre, + "previous_status": previous_status, + "current_status": run.status, + "provider_status": provider_status, + "action": "no_change", + } + + # Download and process results + raw_results = download_batch_results(provider=provider, batch_job=batch_job) + + # Upload raw results to object store + object_store_url = None + try: + object_store_url = upload_batch_results_to_object_store( + session=session, batch_job=batch_job, results=raw_results + ) + except Exception as e: + logger.error( + "%s Object store upload failed — results may be unrecoverable " + "if the provider deletes the output file before next poll: %s", + log_prefix, + e, + exc_info=True, + ) + + # Parse results + parsed = parse_assessment_output(raw_results, batch_job.provider) + error_count = sum(1 for result in parsed if result.get("error")) + success_count = sum(1 for result in parsed if not result.get("error")) + + # Update run status + error_msg = f"{error_count} item(s) failed" if error_count > 0 else None + run_status = ( 
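+                # Mark the run failed only when every parsed item errored;
+                # partial failures still complete, with error_message recording the count.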
+ "failed" + if parsed and success_count == 0 and error_count > 0 + else "completed" + ) + + if not parsed: + run_status = "failed" + error_msg = "Batch completed but no valid results were produced" + + update_assessment_run_status( + session=session, + run=run, + status=run_status, + error_message=error_msg, + object_store_url=object_store_url, + ) + recompute_assessment_status( + session=session, assessment_id=run.assessment_id + ) + + return { + "run_id": run.id, + "assessment_id": run.assessment_id, + "experiment_name": experiment_name_pre, + "previous_status": previous_status, + "current_status": run_status, + "provider_status": provider_status, + "action": "processed" if run_status == "completed" else "failed", + "total_results": len(parsed), + "errors": error_count, + } + + elif provider_status in ( + "failed", + "expired", + "cancelled", + BatchJobState.FAILED.value, + BatchJobState.CANCELLED.value, + BatchJobState.EXPIRED.value, + ): + error_msg = batch_job.error_message or f"Batch {provider_status}" + update_assessment_run_status( + session=session, + run=run, + status="failed", + error_message=error_msg, + ) + recompute_assessment_status( + session=session, assessment_id=run.assessment_id + ) + + return { + "run_id": run.id, + "assessment_id": run.assessment_id, + "experiment_name": experiment_name_pre, + "previous_status": previous_status, + "current_status": "failed", + "provider_status": provider_status, + "action": "failed", + "error": error_msg, + } + + else: + # Still processing + return { + "run_id": run.id, + "assessment_id": run.assessment_id, + "experiment_name": experiment_name_pre, + "previous_status": previous_status, + "current_status": run.status, + "provider_status": provider_status, + "action": "no_change", + } + + except Exception as e: + error_msg = format_assessment_failure_message(e) + logger.error( + f"{log_prefix} Error checking assessment: {error_msg}", + exc_info=True, + ) + update_assessment_run_status( + session=session, + run=run, + status="failed", + error_message=error_msg, + ) + recompute_assessment_status(session=session, assessment_id=run.assessment_id) + return { + "run_id": run.id, + "assessment_id": run.assessment_id, + "experiment_name": experiment_name_pre, + "previous_status": previous_status, + "current_status": "failed", + "provider_status": "unknown", + "action": "failed", + "error": error_msg, + } + + +async def poll_all_pending_assessments(session: Session) -> dict[str, Any]: + """Backward-compatible wrapper for parent-first assessment polling.""" + from app.crud.assessment.cron import poll_all_pending_assessment_evaluations + + return await poll_all_pending_assessment_evaluations(session=session) diff --git a/backend/app/crud/config/config.py b/backend/app/crud/config/config.py index ea1e849d6..12a1a60fe 100644 --- a/backend/app/crud/config/config.py +++ b/backend/app/crud/config/config.py @@ -1,17 +1,17 @@ import logging from uuid import UUID -from typing import Tuple -from sqlmodel import Session, select, and_ from fastapi import HTTPException +from sqlmodel import Session, and_, select +from app.core.util import now from app.models import ( Config, ConfigCreate, ConfigUpdate, ConfigVersion, ) -from app.core.util import now +from app.models.config.config import ConfigTag logger = logging.getLogger(__name__) @@ -27,7 +27,7 @@ def __init__(self, session: Session, project_id: int): def create_or_raise( self, config_create: ConfigCreate - ) -> Tuple[Config, ConfigVersion]: + ) -> tuple[Config, ConfigVersion]: """ Create a new 
configuration with an initial version. """ @@ -38,6 +38,7 @@ def create_or_raise( name=config_create.name, description=config_create.description, project_id=self.project_id, + tag=config_create.tag, ) self.session.add(config) @@ -72,7 +73,7 @@ def create_or_raise( ) raise HTTPException( status_code=500, - detail=f"Unexpected error occurred: failed to create config", + detail="Unexpected error occurred: failed to create config", ) def read_one(self, config_id: UUID) -> Config | None: @@ -86,7 +87,11 @@ def read_one(self, config_id: UUID) -> Config | None: return self.session.exec(statement).one_or_none() def read_all( - self, query: str | None, skip: int = 0, limit: int = 100 + self, + query: str | None, + skip: int = 0, + limit: int = 100, + tag: ConfigTag = ConfigTag.DEFAULT, ) -> tuple[list[Config], bool]: filters = [ Config.project_id == self.project_id, @@ -96,6 +101,8 @@ def read_all( if query: filters.append(Config.name.ilike(f"{query}%")) + filters.append(self._tag_scope_filter(tag)) + statement = ( select(Config) .where(and_(*filters)) @@ -152,6 +159,29 @@ def exists_or_raise(self, config_id: UUID) -> Config: return config + def exists_in_tag_scope_or_raise( + self, config_id: UUID, tag: ConfigTag = ConfigTag.DEFAULT + ) -> Config: + statement = select(Config).where( + and_( + Config.id == config_id, + Config.project_id == self.project_id, + Config.deleted_at.is_(None), + self._tag_scope_filter(tag), + ) + ) + config = self.session.exec(statement).one_or_none() + if config is None: + raise HTTPException( + status_code=404, + detail=f"config with id '{config_id}' not found", + ) + + return config + + def _tag_scope_filter(self, tag: ConfigTag): + return Config.tag == tag + def _check_unique_name_or_raise(self, name: str) -> None: if self._read_by_name(name): raise HTTPException( diff --git a/backend/app/crud/config/version.py b/backend/app/crud/config/version.py index e1335d171..378ce7291 100644 --- a/backend/app/crud/config/version.py +++ b/backend/app/crud/config/version.py @@ -1,23 +1,25 @@ import logging -from uuid import UUID from typing import Any +from uuid import UUID -from sqlmodel import Session, select, and_, func from fastapi import HTTPException -from sqlalchemy.orm import defer from pydantic import ValidationError +from sqlalchemy.orm import defer +from sqlmodel import Session, and_, select -from .config import ConfigCrud from app.core.util import now from app.models import ( Config, ConfigVersion, ConfigVersionCreate, - ConfigVersionUpdate, ConfigVersionItems, + ConfigVersionUpdate, ) +from app.models.config.config import ConfigTag from app.models.llm.request import ConfigBlob +from .config import ConfigCrud + logger = logging.getLogger(__name__) @@ -26,10 +28,17 @@ class ConfigVersionCrud: CRUD operations for configuration versions scoped to a project. 
""" - def __init__(self, session: Session, config_id: UUID, project_id: int): + def __init__( + self, + session: Session, + config_id: UUID, + project_id: int, + tag: ConfigTag = ConfigTag.DEFAULT, + ): self.session = session self.project_id = project_id self.config_id = config_id + self.tag = tag def create_or_raise(self, version_create: ConfigVersionUpdate) -> ConfigVersion: """ @@ -243,7 +252,10 @@ def _get_next_version(self, config_id: UUID) -> int | None: def _config_exists_or_raise(self, config_id: UUID) -> Config: """Check if a config exists in the project.""" config_crud = ConfigCrud(session=self.session, project_id=self.project_id) - config_crud.exists_or_raise(config_id) + return config_crud.exists_in_tag_scope_or_raise( + config_id=config_id, + tag=self.tag, + ) def _validate_config_type_unchanged( self, version_create: ConfigVersionCreate diff --git a/backend/app/crud/evaluations/core.py b/backend/app/crud/evaluations/core.py index 6374dca76..7ae103f51 100644 --- a/backend/app/crud/evaluations/core.py +++ b/backend/app/crud/evaluations/core.py @@ -13,6 +13,7 @@ from app.crud.evaluations.langfuse import fetch_trace_scores_from_langfuse from app.crud.evaluations.score import EvaluationScore from app.models import EvaluationRun, EvaluationRunUpdate +from app.models.config.config import ConfigTag from app.models.llm.request import ConfigBlob, LLMCallConfig from app.models.stt_evaluation import EvaluationType from app.services.llm.jobs import resolve_config_blob @@ -25,6 +26,7 @@ def resolve_evaluation_config( config_id: UUID, config_version: int, project_id: int, + tag: ConfigTag = ConfigTag.DEFAULT, ) -> tuple[ConfigBlob | None, str | None]: """ Resolve config blob from stored config management. @@ -42,6 +44,7 @@ def resolve_evaluation_config( session=session, config_id=config_id, project_id=project_id, + tag=tag, ) return resolve_config_blob( diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index 05f39032e..cb1089a49 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -1,5 +1,7 @@ from sqlmodel import SQLModel +from app.models.assessment import Assessment, AssessmentRun # noqa: F401 + from .api_key import ( APIKey, APIKeyBase, diff --git a/backend/app/models/assessment.py b/backend/app/models/assessment.py new file mode 100644 index 000000000..78035a738 --- /dev/null +++ b/backend/app/models/assessment.py @@ -0,0 +1,345 @@ +"""Assessment models — DB tables, Pydantic schemas, and LLM param wrappers.""" + +from datetime import datetime +from typing import TYPE_CHECKING, Any, Literal, Optional +from uuid import UUID + +from pydantic import BaseModel, Field +from sqlalchemy import Column, Index, Text +from sqlalchemy.dialects.postgresql import JSONB +from sqlmodel import Field as SQLField +from sqlmodel import Relationship, SQLModel + +from app.core.util import now +from app.models.llm.request import TextLLMParams + +if TYPE_CHECKING: + from app.models.batch_job import BatchJob + + +class Assessment(SQLModel, table=True): + """Parent assessment — one experiment over a dataset, grouping N config runs.""" + + __tablename__ = "assessment" + __table_args__ = ( + Index( + "idx_assessment_org_project", + "organization_id", + "project_id", + "inserted_at", + ), + Index("idx_assessment_status", "status"), + ) + + id: int | None = SQLField( + default=None, + primary_key=True, + sa_column_kwargs={"comment": "Unique identifier for the assessment"}, + ) + experiment_name: str = SQLField( + index=True, + sa_column_kwargs={"comment": 
"Name of the experiment grouping its config runs"}, + ) + dataset_id: int = SQLField( + foreign_key="evaluation_dataset.id", + nullable=False, + ondelete="CASCADE", + sa_column_kwargs={"comment": "Reference to the evaluation dataset"}, + ) + status: str = SQLField( + default="pending", + sa_column_kwargs={ + "comment": ( + "Aggregate status: pending, processing, completed, " + "completed_with_errors, failed" + ) + }, + ) + organization_id: int = SQLField( + foreign_key="organization.id", + nullable=False, + ondelete="CASCADE", + sa_column_kwargs={"comment": "Reference to the organization"}, + ) + project_id: int = SQLField( + foreign_key="project.id", + nullable=False, + ondelete="CASCADE", + sa_column_kwargs={"comment": "Reference to the project"}, + ) + inserted_at: datetime = SQLField( + default_factory=now, + nullable=False, + sa_column_kwargs={"comment": "Timestamp when the assessment was created"}, + ) + updated_at: datetime = SQLField( + default_factory=now, + nullable=False, + sa_column_kwargs={"comment": "Timestamp when the assessment was last updated"}, + ) + + +class AssessmentRun(SQLModel, table=True): + """Child run — a single config evaluation against the parent's dataset.""" + + __tablename__ = "assessment_run" + __table_args__ = (Index("idx_assessment_run_assessment_id", "assessment_id"),) + + id: int | None = SQLField( + default=None, + primary_key=True, + sa_column_kwargs={"comment": "Unique identifier for the assessment run"}, + ) + assessment_id: int = SQLField( + foreign_key="assessment.id", + nullable=False, + ondelete="CASCADE", + sa_column_kwargs={"comment": "Reference to the parent assessment"}, + ) + config_id: UUID = SQLField( + foreign_key="config.id", + nullable=False, + sa_column_kwargs={"comment": "Reference to the stored config used"}, + ) + config_version: int = SQLField( + nullable=False, + sa_column_kwargs={"comment": "Version of the config used"}, + ) + status: str = SQLField( + default="pending", + sa_column_kwargs={ + "comment": "Run status: pending, processing, completed, failed" + }, + ) + batch_job_id: int | None = SQLField( + default=None, + foreign_key="batch_job.id", + nullable=True, + ondelete="SET NULL", + sa_column_kwargs={"comment": "Reference to the batch job processing this run"}, + ) + total_items: int = SQLField( + default=0, + nullable=False, + sa_column_kwargs={"comment": "Total number of dataset items in this run"}, + ) + input: dict[str, Any] = SQLField( + sa_column=Column( + JSONB, + nullable=False, + comment=( + "Assessment input: prompt_template, system_instruction, " + "text_columns, attachments, output_schema" + ), + ), + ) + object_store_url: str | None = SQLField( + default=None, + nullable=True, + sa_column_kwargs={"comment": "S3 URL of processed batch results"}, + ) + error_message: str | None = SQLField( + default=None, + sa_column=Column( + Text, + nullable=True, + comment="Error message if the run failed", + ), + ) + inserted_at: datetime = SQLField( + default_factory=now, + nullable=False, + sa_column_kwargs={"comment": "Timestamp when the run was created"}, + ) + updated_at: datetime = SQLField( + default_factory=now, + nullable=False, + sa_column_kwargs={"comment": "Timestamp when the run was last updated"}, + ) + + batch_job: Optional["BatchJob"] = Relationship( + sa_relationship_kwargs={"foreign_keys": "[AssessmentRun.batch_job_id]"} + ) + assessment: Optional["Assessment"] = Relationship( + sa_relationship_kwargs={"foreign_keys": "[AssessmentRun.assessment_id]"} + ) + + +class AssessmentRunCounts(BaseModel): + 
"""Derived counters for a parent assessment, computed from its child runs.""" + + total: int = 0 + pending: int = 0 + processing: int = 0 + completed: int = 0 + failed: int = 0 + + +class AssessmentRunStat(BaseModel): + """Summary entry for one child run, embedded in parent responses.""" + + run_id: int + config_id: str | None + config_version: int | None + status: str + total_items: int + error_message: str | None = None + updated_at: datetime | None = None + + +class AssessmentPublic(BaseModel): + """Public model for a parent assessment row, with derived run aggregates.""" + + id: int + experiment_name: str + dataset_id: int + dataset_name: str | None = None + status: str + counts: AssessmentRunCounts = AssessmentRunCounts() + run_stats: list[AssessmentRunStat] = [] + error_message: str | None = None + organization_id: int + project_id: int + inserted_at: datetime + updated_at: datetime + + +class AssessmentRunPublic(BaseModel): + """Public view of an assessment run.""" + + id: int + assessment_id: int + experiment_name: str | None = None + dataset_id: int | None = None + dataset_name: str | None = None + config_id: UUID + config_version: int + status: str + total_items: int + error_message: str | None = None + input: dict[str, Any] | None = Field( + None, + description=( + "Assessment input config: prompt_template, system_instruction, " + "text_columns, attachments, output_schema" + ), + ) + inserted_at: datetime + updated_at: datetime + + +class AssessmentTextLLMParams(TextLLMParams): + """TextLLMParams extended with response_format and output_schema for assessments.""" + + response_format: Literal["text", "json_object"] = Field( + default="text", + description="Response format: 'text' or 'json_object'", + ) + output_schema: dict[str, Any] | None = Field( + default=None, + description="JSON Schema for structured output", + ) + + +class AssessmentAttachment(BaseModel): + """Attachment column configuration.""" + + column: str = Field(..., description="Column name containing the attachment data") + type: Literal["image", "pdf"] = Field(..., description="Attachment type") + format: Literal["url", "base64"] = Field(..., description="Data format") + + +class AssessmentConfigRef(BaseModel): + """Reference to a stored config version.""" + + config_id: UUID = Field(..., description="Stored config UUID") + config_version: int = Field(..., ge=1, description="Config version number") + + +class AssessmentCreate(BaseModel): + """Request body for creating an assessment and child runs.""" + + experiment_name: str = Field( + ..., min_length=1, description="Name for this assessment experiment" + ) + dataset_id: int = Field(..., description="ID of the uploaded dataset") + prompt_template: str | None = Field( + None, + description=( + "Prompt template with {column} placeholders. " + "If null, all text columns are concatenated." 
+ ), + ) + system_instruction: str | None = Field( + None, + description="System instruction used when generating assessment outputs", + ) + text_columns: list[str] = Field( + default_factory=list, description="Column names mapped as text input" + ) + attachments: list[AssessmentAttachment] = Field( + default_factory=list, description="Attachment column configurations" + ) + output_schema: dict[str, Any] | None = Field( + None, description="JSON Schema for structured output" + ) + configs: list[AssessmentConfigRef] = Field( + ..., min_length=1, max_length=4, description="Config versions to run" + ) + + +class AssessmentRunSummary(BaseModel): + """Summary of a single assessment run created for one config.""" + + run_id: int + assessment_id: int + config_id: str + config_version: int + status: str + + +class AssessmentResponse(BaseModel): + """Response after submitting an assessment run request.""" + + assessment_id: int + experiment_name: str + dataset_id: int + dataset_name: str | None + num_configs: int + runs: list[AssessmentRunSummary] + + +class AssessmentExportRow(BaseModel): + """Flattened assessment result row for CSV/XLSX export.""" + + assessment_id: int + experiment_name: str + dataset_id: int | None + dataset_name: str | None + run_id: int + run_name: str + run_status: str + config_id: UUID | None + config_version: int | None + row_id: str + result_status: str + input_data: dict[str, str] | None = None + output: str | None = None + error: str | None = None + response_id: str | None = None + input_tokens: int | None = None + output_tokens: int | None = None + total_tokens: int | None = None + updated_at: datetime + + +class AssessmentDatasetResponse(BaseModel): + """Response model for assessment dataset.""" + + dataset_id: int + dataset_name: str + description: str | None = None + total_items: int = 0 + file_extension: str | None = None + object_store_url: str | None = None + signed_url: str | None = None diff --git a/backend/app/models/batch_job.py b/backend/app/models/batch_job.py index a01667831..426d44f59 100644 --- a/backend/app/models/batch_job.py +++ b/backend/app/models/batch_job.py @@ -16,6 +16,7 @@ class BatchJobType(str, Enum): STT_EVALUATION = "stt_evaluation" TTS_EVALUATION = "tts_evaluation" EMBEDDING = "embedding" + ASSESSMENT = "assessment" if TYPE_CHECKING: diff --git a/backend/app/models/config/config.py b/backend/app/models/config/config.py index df3577e45..8ee56cdb4 100644 --- a/backend/app/models/config/config.py +++ b/backend/app/models/config/config.py @@ -1,15 +1,33 @@ -from uuid import UUID, uuid4 from datetime import datetime -from typing import TYPE_CHECKING, Any +from enum import StrEnum +from uuid import UUID, uuid4 -from sqlmodel import Field, SQLModel, Index, text +import sqlalchemy as sa from pydantic import field_validator +from sqlalchemy.dialects import postgresql +from sqlmodel import Field, Index, SQLModel, text from app.core.util import now from app.models.llm.request import ConfigBlob + from .version import ConfigVersionPublic +class ConfigTag(StrEnum): + """Config classification tag.""" + + DEFAULT = "default" + ASSESSMENT = "ASSESSMENT" + + +_CONFIG_TAG_PG_ENUM = postgresql.ENUM( + ConfigTag, + name="config_tag", + values_callable=lambda enum_cls: [member.value for member in enum_cls], + create_type=False, +) + + class ConfigBase(SQLModel): """Base model for LLM configuration metadata""" @@ -45,6 +63,13 @@ class Config(ConfigBase, table=True): "updated_at", postgresql_where=text("deleted_at IS NULL"), ), + Index( + 
"idx_config_project_id_tag_active", + "project_id", + "tag", + text("updated_at DESC"), + postgresql_where=text("deleted_at IS NULL"), + ), ) id: UUID = Field( @@ -60,6 +85,19 @@ class Config(ConfigBase, table=True): sa_column_kwargs={"comment": "Reference to the project"}, ) + tag: ConfigTag = Field( + default=ConfigTag.DEFAULT, + sa_column=sa.Column( + _CONFIG_TAG_PG_ENUM, + nullable=False, + server_default=sa.text("'default'::config_tag"), + comment=( + "Tag classifying the config: 'default' for general use, " + "'ASSESSMENT' for assessment use." + ), + ), + ) + inserted_at: datetime = Field( default_factory=now, nullable=False, @@ -90,6 +128,13 @@ class ConfigCreate(ConfigBase): max_length=512, description="Optional message describing the changes in this version", ) + tag: ConfigTag = Field( + default=ConfigTag.DEFAULT, + description=( + "Optional tag for classifying this config. Omit to store 'default'; " + "set 'ASSESSMENT' for assessment use." + ), + ) @field_validator("config_blob") def validate_blob_not_empty(cls, value): @@ -103,6 +148,10 @@ class ConfigUpdate(SQLModel): description: str | None = Field( default=None, max_length=512, description="Optional description" ) + tag: ConfigTag | None = Field( + default=None, + description=("Optional tag for classifying this config. "), + ) class ConfigPublic(ConfigBase): diff --git a/backend/app/models/evaluation.py b/backend/app/models/evaluation.py index c9130d3c3..202e58af8 100644 --- a/backend/app/models/evaluation.py +++ b/backend/app/models/evaluation.py @@ -108,8 +108,8 @@ class EvaluationDataset(SQLModel, table=True): type: str = SQLField( default="text", max_length=20, - description="Evaluation type: text, stt, or tts", - sa_column_kwargs={"comment": "Evaluation type: text, stt, or tts"}, + description="Evaluation type: text, assessment, stt, or tts", + sa_column_kwargs={"comment": "Evaluation type: text, assessment, stt, or tts"}, ) language_id: int | None = SQLField( default=None, @@ -213,8 +213,8 @@ class EvaluationRun(SQLModel, table=True): type: str = SQLField( default="text", max_length=20, - description="Evaluation type: text, stt, or tts", - sa_column_kwargs={"comment": "Evaluation type: text, stt, or tts"}, + description="Evaluation type: text, assessment, stt, or tts", + sa_column_kwargs={"comment": "Evaluation type: text, assessment, stt, or tts"}, ) language_id: int | None = SQLField( default=None, diff --git a/backend/app/models/llm/request.py b/backend/app/models/llm/request.py index a5c337a44..80aaa3008 100644 --- a/backend/app/models/llm/request.py +++ b/backend/app/models/llm/request.py @@ -31,11 +31,32 @@ class TextLLMParams(SQLModel): default=None, description="Reasoning configuration or instructions", ) + effort: Literal["none", "minimal", "low", "medium", "high", "xhigh"] | None = Field( + default=None, + description="Model-specific reasoning effort setting for reasoning-capable models", + ) + summary: Literal["auto", "detailed", "concise"] | None = Field( + default=None, + description=( + "Model-specific reasoning summary preference. " "Use null/None to disable." 
+ ), + ) temperature: float | None = Field( default=0.1, ge=0.0, le=2.0, ) + top_p: float | None = Field( + default=None, + ge=0.0, + le=1.0, + description="Nucleus sampling parameter", + ) + max_output_tokens: int | None = Field( + default=None, + ge=1, + description="Maximum tokens to generate in the response", + ) max_num_results: int | None = Field( default=None, ge=1, diff --git a/backend/app/models/stt_evaluation.py b/backend/app/models/stt_evaluation.py index 5e953e36d..0e3dc06a5 100644 --- a/backend/app/models/stt_evaluation.py +++ b/backend/app/models/stt_evaluation.py @@ -1,7 +1,7 @@ """STT Evaluation models for Speech-to-Text evaluation feature.""" from datetime import datetime -from enum import Enum +from enum import StrEnum from typing import Any from pydantic import BaseModel, Field, field_validator @@ -18,10 +18,11 @@ SUPPORTED_STT_MODELS = ["gemini-2.5-pro"] -class EvaluationType(str, Enum): +class EvaluationType(StrEnum): """Type of evaluation dataset/run.""" TEXT = "text" + ASSESSMENT = "assessment" STT = "stt" TTS = "tts" diff --git a/backend/app/services/assessment/__init__.py b/backend/app/services/assessment/__init__.py new file mode 100644 index 000000000..2565a6c8f --- /dev/null +++ b/backend/app/services/assessment/__init__.py @@ -0,0 +1 @@ +"""Assessment services package.""" diff --git a/backend/app/services/assessment/dataset.py b/backend/app/services/assessment/dataset.py new file mode 100644 index 000000000..943fb34d4 --- /dev/null +++ b/backend/app/services/assessment/dataset.py @@ -0,0 +1,209 @@ +"""Dataset management service for assessments (CSV + XLSX). + +Upload stores files directly to object store as-is (no column validation, +no format conversion). Row count is computed for metadata. +""" + +import csv +import io +import logging + +from fastapi import HTTPException +from sqlmodel import Session + +from app.core.cloud import get_cloud_storage +from app.core.storage_utils import generate_timestamped_filename, upload_to_object_store +from app.crud.assessment.dataset import create_assessment_dataset +from app.models.evaluation import EvaluationDataset +from app.services.evaluations.validators import sanitize_dataset_name + +logger = logging.getLogger(__name__) + +try: + from openpyxl.utils.exceptions import InvalidFileException +except Exception: # pragma: no cover - openpyxl is expected in runtime deps + + class InvalidFileException(Exception): + pass + + +_MIME_TYPES = { + ".csv": "text/csv", + ".xlsx": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", +} + + +def _upload_file_to_object_store( + session: Session, + project_id: int, + file_content: bytes, + file_ext: str, + dataset_name: str, +) -> str | None: + """Upload the raw file to object store, preserving original format.""" + extension = file_ext.lstrip(".") + filename = generate_timestamped_filename(dataset_name, extension=extension) + content_type = _MIME_TYPES.get(file_ext, "application/octet-stream") + + try: + storage = get_cloud_storage(session=session, project_id=project_id) + return upload_to_object_store( + storage=storage, + content=file_content, + filename=filename, + subdirectory="datasets", + content_type=content_type, + ) + except Exception as e: + logger.warning( + f"[_upload_file_to_object_store] Failed to upload | {e}", + exc_info=True, + ) + return None + + +def _count_csv_rows(content: bytes) -> int: + """Count data rows in a CSV file (excluding header).""" + try: + for encoding in ("utf-8-sig", "utf-8", "latin-1"): + try: + text = content.decode(encoding) + 
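+                    # Stop at the first encoding that decodes cleanly; the for/else
+                    # below falls back to utf-8 with replacement characters.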
break + except (UnicodeDecodeError, ValueError): + continue + else: + text = content.decode("utf-8", errors="replace") + + reader = csv.reader(io.StringIO(text)) + next(reader, None) + return sum(1 for row in reader if any(cell.strip() for cell in row)) + except Exception as e: + logger.warning(f"[_count_csv_rows] Failed to count rows | {e}") + return 0 + + +def _count_excel_rows(content: bytes) -> int: + """Count data rows in an Excel file (excluding header).""" + wb = None + try: + import openpyxl + + wb = openpyxl.load_workbook(io.BytesIO(content), read_only=True, data_only=True) + ws = wb.active + if ws is None: + return 0 + + rows_iter = ws.iter_rows(values_only=True) + header = next(rows_iter, None) + if header is None: + return 0 + + return sum( + 1 for row in rows_iter if row and any(cell is not None for cell in row) + ) + except InvalidFileException as e: + logger.warning("[_count_excel_rows] Invalid XLSX file content: %s", e) + raise + except Exception as e: + logger.warning( + "[_count_excel_rows] Failed to count rows | %s", e, exc_info=True + ) + raise ValueError("Failed to parse XLSX file") from e + finally: + if wb is not None: + wb.close() + + +def _count_rows(content: bytes, file_ext: str) -> int: + """Count data rows in a file (CSV or XLSX), excluding the header.""" + if file_ext == ".xls": + raise ValueError( + "Legacy Excel format (.xls) is not supported. Please upload .xlsx or .csv." + ) + if file_ext == ".xlsx": + return _count_excel_rows(content) + return _count_csv_rows(content) + + +def upload_dataset( + session: Session, + file_content: bytes, + file_ext: str, + dataset_name: str, + description: str | None, + organization_id: int, + project_id: int, +) -> EvaluationDataset: + """Upload a dataset file directly to object store and record metadata.""" + original_name = dataset_name + try: + dataset_name = sanitize_dataset_name(dataset_name) + except ValueError as e: + raise HTTPException(status_code=422, detail=f"Invalid dataset name: {str(e)}") + + if original_name != dataset_name: + logger.info( + f"[upload_dataset] Dataset name sanitized | '{original_name}' -> '{dataset_name}'" + ) + + try: + row_count = _count_rows(file_content, file_ext) + except InvalidFileException as e: + raise HTTPException( + status_code=422, + detail="Invalid XLSX file content. Please upload a valid .xlsx file.", + ) from e + except ValueError as e: + raise HTTPException(status_code=422, detail=str(e)) from e + except Exception as e: + raise HTTPException( + status_code=422, + detail="Unable to parse dataset file. Please upload a valid CSV or XLSX file.", + ) from e + + logger.info( + f"[upload_dataset] Uploading dataset | dataset={dataset_name} | " + f"file_type={file_ext} | rows={row_count} | " + f"org_id={organization_id} | project_id={project_id}" + ) + + object_store_url = _upload_file_to_object_store( + session=session, + project_id=project_id, + file_content=file_content, + file_ext=file_ext, + dataset_name=dataset_name, + ) + if not object_store_url: + logger.error( + f"[upload_dataset] Object store upload failed | dataset={dataset_name} | " + f"org_id={organization_id} | project_id={project_id}" + ) + raise HTTPException( + status_code=500, + detail="Failed to upload dataset file. 
Please try again.", + ) + + metadata = { + "file_extension": file_ext, + "file_size_bytes": len(file_content), + "total_items_count": row_count, + } + + dataset = create_assessment_dataset( + session=session, + name=dataset_name, + description=description, + dataset_metadata=metadata, + object_store_url=object_store_url, + langfuse_dataset_id=None, + organization_id=organization_id, + project_id=project_id, + ) + + logger.info( + f"[upload_dataset] Created dataset record | " + f"id={dataset.id} | name={dataset_name} | rows={row_count}" + ) + + return dataset diff --git a/backend/app/services/assessment/mappers.py b/backend/app/services/assessment/mappers.py new file mode 100644 index 000000000..8756ee08c --- /dev/null +++ b/backend/app/services/assessment/mappers.py @@ -0,0 +1,237 @@ +import logging +import unicodedata + +from google.genai import _transformers as genai_transformers +from sqlmodel import Session + +from app.crud.model_config import is_reasoning_model + +logger = logging.getLogger(__name__) + + +def normalize_llm_text(text: str) -> str: + if not isinstance(text, str) or not text: + return text + + text = text.replace("\\n", "\n") + text = text.replace("\\t", "\t") + text = text.replace("\\r", "\r") + text = text.replace('\\"', '"') + text = text.replace("\\\\", "\\") + + text = unicodedata.normalize("NFC", text) + + return text + + +def _ensure_openai_strict_schema(schema: dict) -> dict: + """Recursively add additionalProperties: false for OpenAI strict JSON schema validation.""" + normalized = dict(schema) + + if normalized.get("type") == "object": + normalized["additionalProperties"] = False + + if "properties" in normalized: + normalized["properties"] = { + key: _ensure_openai_strict_schema(value) + if isinstance(value, dict) + else value + for key, value in normalized["properties"].items() + } + + items = normalized.get("items") + if isinstance(items, dict): + normalized["items"] = _ensure_openai_strict_schema(items) + + return normalized + + +def _strip_additional_properties(schema: dict) -> dict: + """Recursively strip additionalProperties — unsupported by Google GenAI.""" + normalized_schema = dict(schema) + normalized_schema.pop("additionalProperties", None) + + if "properties" in normalized_schema: + normalized_schema["properties"] = { + property_name: _strip_additional_properties(property_schema) + if isinstance(property_schema, dict) + else property_schema + for property_name, property_schema in normalized_schema[ + "properties" + ].items() + } + + if "items" in normalized_schema and isinstance(normalized_schema["items"], dict): + normalized_schema["items"] = _strip_additional_properties( + normalized_schema["items"] + ) + + return normalized_schema + + +def _convert_json_schema_to_google(schema: dict) -> dict: + """Convert a JSON Schema dict to Google GenAI's OpenAPI-style schema. + + Strips unsupported fields, then normalizes the schema through the Gemini SDK + so enum/type values match Gemini's expected OpenAPI-flavored shape. 
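+ 
+ Illustrative example (hypothetical input): {"type": "object", "properties": {"score": {"type": "integer"}}, "required": ["score"], "additionalProperties": false}
+ comes back without "additionalProperties" and with "propertyOrdering": ["score"] added.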
+ """ + normalized_schema = _strip_additional_properties(schema) + converted = genai_transformers.t_schema(None, normalized_schema) + google_schema = ( + converted.model_dump(mode="json", exclude_none=True) + if converted is not None + else normalized_schema + ) + + if "properties" in google_schema and "propertyOrdering" not in google_schema: + google_schema["propertyOrdering"] = list( + normalized_schema.get("required", []) + ) or list(google_schema["properties"].keys()) + + return google_schema + + +def map_kaapi_to_openai_params( + session: Session, kaapi_params: dict +) -> tuple[dict, list[str]]: + """Map Kaapi-abstracted parameters to OpenAI batch assessment API parameters. + + Extends the base LLM mapper with structured output schema support via + ``output_schema`` → ``text.format.json_schema`` (strict mode). + + Returns: + Tuple of (OpenAI API params dict, list of warning strings) + """ + openai_params: dict = {} + warnings: list[str] = [] + + model = kaapi_params.get("model") + reasoning = kaapi_params.get("reasoning") + effort = kaapi_params.get("effort") or reasoning + summary = kaapi_params.get("summary") + temperature = kaapi_params.get("temperature") + top_p = kaapi_params.get("top_p") + + instructions = normalize_llm_text(kaapi_params.get("instructions")) + knowledge_base_ids = kaapi_params.get("knowledge_base_ids") + max_num_results = kaapi_params.get("max_num_results") + response_format = kaapi_params.get("response_format") + output_schema = kaapi_params.get("output_schema") + + support_reasoning = bool(model) and is_reasoning_model( + session=session, + provider="openai", + model_name=model, + ) + + # max_output_tokens is intentionally omitted for batch assessment — + # Indic feedback responses can be long and a stored token limit would truncate them. + + if support_reasoning: + reasoning_payload: dict[str, object] = {} + if effort is not None: + reasoning_payload["effort"] = effort + if summary is not None: + reasoning_payload["summary"] = None if summary == "null" else summary + if reasoning_payload: + openai_params["reasoning"] = reasoning_payload + if temperature is not None: + warnings.append( + "Parameter 'temperature' was suppressed because the selected model " + "supports reasoning, and temperature is ignored when reasoning is enabled." + ) + if top_p is not None: + warnings.append( + "Parameter 'top_p' was suppressed because the selected model " + "supports reasoning, and top_p is ignored when reasoning is enabled." + ) + else: + if effort is not None or summary is not None: + warnings.append( + "Parameters 'effort'/'summary' were suppressed because the selected model " + "does not support reasoning." 
+ ) + if temperature is not None: + openai_params["temperature"] = temperature + if top_p is not None: + openai_params["top_p"] = top_p + + if model: + openai_params["model"] = model + + if instructions: + openai_params["instructions"] = instructions + + if output_schema is not None: + openai_params["text"] = { + "format": { + "type": "json_schema", + "name": "output", + "strict": True, + "schema": _ensure_openai_strict_schema(output_schema), + } + } + elif response_format and response_format != "text": + openai_params["text"] = {"format": {"type": response_format}} + + if knowledge_base_ids: + openai_params["tools"] = [ + { + "type": "file_search", + "vector_store_ids": knowledge_base_ids, + "max_num_results": max_num_results or 20, + } + ] + + return openai_params, warnings + + +def map_kaapi_to_google_params(kaapi_params: dict) -> tuple[dict, list[str]]: + """Map Kaapi-abstracted parameters to Google AI (Gemini) API parameters. + + Returns: + Tuple of (Google AI params dict, list of warning strings) + """ + google_params: dict = {} + warnings: list[str] = [] + + model = kaapi_params.get("model") + if not model: + return {}, ["Missing required 'model' parameter"] + + google_params["model"] = model + + instructions = normalize_llm_text(kaapi_params.get("instructions")) + if instructions: + google_params["instructions"] = instructions + + temperature = kaapi_params.get("temperature") + if temperature is not None: + google_params["temperature"] = temperature + + top_p = kaapi_params.get("top_p") + if top_p is not None: + google_params["top_p"] = top_p + + max_output_tokens = kaapi_params.get("max_output_tokens") + if max_output_tokens is not None: + google_params["max_output_tokens"] = max_output_tokens + + thinking_level = kaapi_params.get("thinking_level") + if thinking_level: + google_params["thinking_config"] = {"thinking_level": thinking_level} + + reasoning = kaapi_params.get("reasoning") + if reasoning: + google_params["reasoning"] = reasoning + + output_schema = kaapi_params.get("output_schema") + if output_schema is not None: + google_params["output_schema"] = _convert_json_schema_to_google(output_schema) + + if kaapi_params.get("knowledge_base_ids"): + warnings.append( + "Parameter 'knowledge_base_ids' is not supported by Google AI and was ignored." 
+ ) + + return google_params, warnings diff --git a/backend/app/services/assessment/service.py b/backend/app/services/assessment/service.py new file mode 100644 index 000000000..45a283ea5 --- /dev/null +++ b/backend/app/services/assessment/service.py @@ -0,0 +1,304 @@ +"""Assessment run orchestration service.""" + +import logging +from typing import Any +from uuid import UUID + +from fastapi import HTTPException +from sqlmodel import Session + +from app.crud.assessment import ( + create_assessment, + create_assessment_run, + get_assessment_dataset_by_id, + get_assessment_runs_for_assessment, + recompute_assessment_status, + update_assessment_run_status, +) +from app.crud.assessment.batch import submit_assessment_batch +from app.crud.config import ConfigCrud +from app.crud.evaluations.core import resolve_evaluation_config +from app.models.assessment import ( + Assessment, + AssessmentAttachment, + AssessmentConfigRef, + AssessmentCreate, + AssessmentResponse, + AssessmentRun, + AssessmentRunSummary, +) +from app.models.config.config import ConfigTag +from app.services.llm.providers.registry import LLMProvider + +logger = logging.getLogger(__name__) + +_SUPPORTED_BATCH_PROVIDERS = { + LLMProvider.OPENAI, + LLMProvider.OPENAI_NATIVE, + LLMProvider.GOOGLE, + LLMProvider.GOOGLE_NATIVE, +} + + +def _build_retry_request( + *, + experiment_name: str, + dataset_id: int, + runs: list[AssessmentRun], +) -> AssessmentCreate: + if not runs: + raise HTTPException(status_code=400, detail="No assessment runs found to retry") + + first_run = runs[0] + assessment_input = first_run.input + if not isinstance(assessment_input, dict): + raise HTTPException( + status_code=400, + detail="Assessment input configuration is missing for retry", + ) + + attachments = assessment_input.get("attachments") or [] + configs: list[AssessmentConfigRef] = [] + for run in runs: + if not run.config_id or run.config_version is None: + raise HTTPException( + status_code=400, + detail=f"Config reference is missing for run {run.id}", + ) + configs.append( + AssessmentConfigRef( + config_id=UUID(str(run.config_id)), + config_version=run.config_version, + ) + ) + + return AssessmentCreate( + experiment_name=experiment_name, + dataset_id=dataset_id, + prompt_template=assessment_input.get("prompt_template"), + system_instruction=assessment_input.get("system_instruction"), + text_columns=list(assessment_input.get("text_columns") or []), + attachments=[AssessmentAttachment.model_validate(item) for item in attachments], + output_schema=assessment_input.get("output_schema"), + configs=configs, + ) + + +def start_assessment( + session: Session, + request: AssessmentCreate, + organization_id: int, + project_id: int, +) -> AssessmentResponse: + """Start an assessment run request. + + Validates the dataset, resolves each config, creates one AssessmentRun per config, + and kicks off batch processing for each. 
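+ 
+ Raises HTTPException when a referenced config is not tagged for assessment, cannot be resolved, or resolves to a provider without batch support.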
+ """ + logger.info( + "[start_assessment] Starting | experiment=%s | dataset_id=%s | configs=%s | org_id=%s", + request.experiment_name, + request.dataset_id, + len(request.configs), + organization_id, + ) + + dataset = get_assessment_dataset_by_id( + session=session, + dataset_id=request.dataset_id, + organization_id=organization_id, + project_id=project_id, + ) + + assessment_input: dict[str, Any] = { + "prompt_template": request.prompt_template, + "system_instruction": request.system_instruction, + "text_columns": request.text_columns, + "attachments": [att.model_dump() for att in request.attachments], + } + if request.output_schema: + assessment_input["output_schema"] = request.output_schema + + config_crud = ConfigCrud(session=session, project_id=project_id) + + resolved_configs = [] + for cfg in request.configs: + # Assessment runs must use configs explicitly tagged for assessment use. + parent_config = config_crud.read_one(cfg.config_id) + if parent_config is not None and parent_config.tag != ConfigTag.ASSESSMENT: + tag_value = ( + parent_config.tag.value + if parent_config.tag is not None + else ConfigTag.DEFAULT.value + ) + raise HTTPException( + status_code=422, + detail=( + f"Config {cfg.config_id} has tag '{tag_value}' " + f"and cannot be used for assessment. " + f"Only configs tagged 'ASSESSMENT' are allowed." + ), + ) + + config_blob, error = resolve_evaluation_config( + session=session, + config_id=cfg.config_id, + config_version=cfg.config_version, + project_id=project_id, + tag=ConfigTag.ASSESSMENT, + ) + if error or config_blob is None: + raise HTTPException( + status_code=400, + detail=( + f"Failed to resolve config {cfg.config_id} " + f"v{cfg.config_version}: {error}" + ), + ) + provider = config_blob.completion.provider or LLMProvider.OPENAI + if provider not in _SUPPORTED_BATCH_PROVIDERS: + raise HTTPException( + status_code=422, + detail=( + f"Config {cfg.config_id} v{cfg.config_version} uses provider " + f"'{provider}', which is not supported for batch assessment. " + f"Supported providers: {sorted(_SUPPORTED_BATCH_PROVIDERS)}" + ), + ) + resolved_configs.append((cfg, config_blob)) + + assessment = create_assessment( + session=session, + experiment_name=request.experiment_name, + dataset_id=request.dataset_id, + organization_id=organization_id, + project_id=project_id, + ) + + runs: list[AssessmentRun] = [] + try: + for cfg, config_blob in resolved_configs: + run = create_assessment_run( + session=session, + assessment_id=assessment.id, + config_id=cfg.config_id, + config_version=cfg.config_version, + assessment_input=assessment_input, + ) + + try: + batch_job = submit_assessment_batch( + session=session, + run=run, + assessment=assessment, + dataset=dataset, + config_blob=config_blob, + assessment_input=assessment_input, + organization_id=organization_id, + project_id=project_id, + ) + + run = update_assessment_run_status( + session=session, + run=run, + status="processing", + batch_job_id=batch_job.id, + total_items=batch_job.total_items, + ) + + except Exception as e: + logger.error( + "[start_assessment] Failed to submit batch for run %s: %s", + run.id, + e, + exc_info=True, + ) + run = update_assessment_run_status( + session=session, + run=run, + status="failed", + error_message="Batch submission failed. 
Please try again or contact support.", + ) + + runs.append(run) + except Exception: + recompute_assessment_status(session=session, assessment_id=assessment.id) + raise + + recompute_assessment_status(session=session, assessment_id=assessment.id) + + logger.info( + "[start_assessment] Created assessment %s with %s runs | run_ids=%s", + assessment.id, + len(runs), + [run.id for run in runs], + ) + + return AssessmentResponse( + assessment_id=assessment.id, + experiment_name=request.experiment_name, + dataset_id=request.dataset_id, + dataset_name=dataset.name, + num_configs=len(runs), + runs=[ + AssessmentRunSummary( + run_id=completed_run.id, + assessment_id=completed_run.assessment_id, + config_id=str(completed_run.config_id), + config_version=completed_run.config_version, + status=completed_run.status, + ) + for completed_run in runs + ], + ) + + +def retry_assessment( + session: Session, + assessment: Assessment, + organization_id: int, + project_id: int, +) -> AssessmentResponse: + """Create a new assessment using the same parent assessment inputs.""" + runs = get_assessment_runs_for_assessment( + session=session, assessment_id=assessment.id + ) + request = _build_retry_request( + experiment_name=assessment.experiment_name, + dataset_id=assessment.dataset_id, + runs=runs, + ) + return start_assessment( + session=session, + request=request, + organization_id=organization_id, + project_id=project_id, + ) + + +def retry_assessment_run( + session: Session, + run: AssessmentRun, + organization_id: int, + project_id: int, +) -> AssessmentResponse: + """Create a new assessment using the same inputs as a single child run.""" + parent = getattr(run, "assessment", None) or session.get( + Assessment, run.assessment_id + ) + if not parent: + raise HTTPException( + status_code=404, + detail=f"Parent assessment {run.assessment_id} not found", + ) + request = _build_retry_request( + experiment_name=parent.experiment_name, + dataset_id=parent.dataset_id, + runs=[run], + ) + return start_assessment( + session=session, + request=request, + organization_id=organization_id, + project_id=project_id, + ) diff --git a/backend/app/services/assessment/utils/__init__.py b/backend/app/services/assessment/utils/__init__.py new file mode 100644 index 000000000..1db8e89a1 --- /dev/null +++ b/backend/app/services/assessment/utils/__init__.py @@ -0,0 +1,25 @@ +"""Assessment utility functions.""" + +from app.services.assessment.utils.export import ( + build_assessment_results_response, + build_export_response, + build_json_export_rows, + load_export_rows_for_run, + serialize_export_rows, + sort_export_rows, +) +from app.services.assessment.utils.parsing import ( + parse_stored_results, + usage_totals, +) + +__all__ = [ + "build_assessment_results_response", + "build_export_response", + "build_json_export_rows", + "load_export_rows_for_run", + "parse_stored_results", + "serialize_export_rows", + "sort_export_rows", + "usage_totals", +] diff --git a/backend/app/services/assessment/utils/attachments.py b/backend/app/services/assessment/utils/attachments.py new file mode 100644 index 000000000..5a141a757 --- /dev/null +++ b/backend/app/services/assessment/utils/attachments.py @@ -0,0 +1,183 @@ +"""Attachment resolution utilities for assessment batch builds. + +Handles MIME type detection, base64 decoding, Google Drive URL normalization, +data-URL parsing, and conversion of dataset cell values into provider input objects. 
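+ 
+ For example, an image cell holding "https://drive.google.com/file/d/<id>/view" is rewritten to a direct lh3.googleusercontent.com URL, and a "data:image/png;base64,..." cell becomes an input_image payload.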
+""" + +import base64 +import binascii +import re +from typing import Any +from urllib.parse import urlparse + +from app.models.assessment import AssessmentAttachment + +_IMAGE_MIME_BY_EXT = { + ".png": "image/png", + ".jpg": "image/jpeg", + ".jpeg": "image/jpeg", + ".webp": "image/webp", + ".gif": "image/gif", + ".bmp": "image/bmp", + ".tif": "image/tiff", + ".tiff": "image/tiff", + ".heic": "image/heic", + ".heif": "image/heif", +} + + +def split_attachment_urls(value: str) -> list[str]: + """Split comma/newline separated attachment URLs from a single dataset cell.""" + return [part.strip() for part in re.split(r"[\n,]+", value) if part.strip()] + + +def to_direct_attachment_url(url: str, attachment_type: str) -> str: + """Normalize share-page attachment URLs into provider-fetchable direct URLs. + + This currently handles common Google Drive share URL shapes. The file must + still be publicly accessible to the model provider. + """ + url = url.strip() + file_id = None + + match = re.match(r"https://drive\.google\.com/file/d/([^/]+)", url) + if match: + file_id = match.group(1) + + if not file_id: + match = re.search(r"[?&]id=([a-zA-Z0-9_-]+)", url) + if match and ( + "drive.google.com" in url or "drive.usercontent.google.com" in url + ): + file_id = match.group(1) + + if not file_id: + return url + + if attachment_type == "image": + return f"https://lh3.googleusercontent.com/d/{file_id}" + + return f"https://drive.google.com/uc?export=download&id={file_id}" + + +def split_data_url(value: str) -> tuple[str | None, str]: + """Return (mime_type, base64_payload) for a data URL; otherwise (None, value).""" + match = re.match( + r"^data:([^;]+);base64,(.+)$", + value.strip(), + flags=re.IGNORECASE | re.DOTALL, + ) + if not match: + return None, value.strip() + return match.group(1).strip().lower(), match.group(2).strip() + + +def _guess_image_mime_from_url(url: str) -> str | None: + path = urlparse(url).path or "" + for ext, mime in _IMAGE_MIME_BY_EXT.items(): + if path.lower().endswith(ext): + return mime + return None + + +def _decode_base64_prefix(payload: str, max_chars: int = 256) -> bytes | None: + compact = re.sub(r"\s+", "", payload) + if not compact: + return None + sample = compact[:max_chars] + padding = "=" * (-len(sample) % 4) + try: + return base64.b64decode(sample + padding, validate=False) + except (binascii.Error, ValueError): + return None + + +def _guess_image_mime_from_base64(payload: str) -> str | None: + blob = _decode_base64_prefix(payload) + if not blob: + return None + if blob.startswith(b"\x89PNG\r\n\x1a\n"): + return "image/png" + if blob.startswith(b"\xff\xd8\xff"): + return "image/jpeg" + if blob.startswith((b"GIF87a", b"GIF89a")): + return "image/gif" + if blob.startswith(b"BM"): + return "image/bmp" + if len(blob) >= 12 and blob[:4] == b"RIFF" and blob[8:12] == b"WEBP": + return "image/webp" + if blob.startswith((b"II*\x00", b"MM\x00*")): + return "image/tiff" + return None + + +def resolve_image_mime_and_payload( + value: str, + format_type: str, +) -> tuple[str, str]: + """Resolve image mime type and raw base64 payload (for base64 format).""" + if format_type == "url": + return _guess_image_mime_from_url(value) or "image/png", value + + data_url_mime, payload = split_data_url(value) + if data_url_mime and data_url_mime.startswith("image/"): + return data_url_mime, payload + + return _guess_image_mime_from_base64(payload) or "image/png", payload + + +def resolve_attachment_values( + value: str, + att: AssessmentAttachment, +) -> list[dict[str, Any]]: + 
"""Convert one dataset cell into one or more OpenAI-style input objects.""" + value = value.strip() + if not value: + return [] + + if att.format == "url": + values = split_attachment_urls(value) + else: + values = [value] + + resolved: list[dict[str, Any]] = [] + for item_value in values: + normalized_value = ( + to_direct_attachment_url(item_value, att.type) + if att.format == "url" + else item_value + ) + + if att.type == "image": + if att.format == "url": + resolved.append({"type": "input_image", "image_url": normalized_value}) + else: + mime_type, payload = resolve_image_mime_and_payload( + normalized_value, + "base64", + ) + resolved.append( + { + "type": "input_image", + "image_url": f"data:{mime_type};base64,{payload}", + } + ) + elif att.type == "pdf": + if att.format == "url": + resolved.append( + { + "type": "input_file", + "file_url": normalized_value, + } + ) + else: + _, payload = split_data_url(normalized_value) + resolved.append( + { + "type": "input_file", + "file_data": f"data:application/pdf;base64,{payload}", + "filename": "document.pdf", + } + ) + + return resolved diff --git a/backend/app/services/assessment/utils/export.py b/backend/app/services/assessment/utils/export.py new file mode 100644 index 000000000..ca273afc6 --- /dev/null +++ b/backend/app/services/assessment/utils/export.py @@ -0,0 +1,546 @@ +"""Export utilities for assessment results (CSV, XLSX, JSON).""" + +import csv +import io +import json +import logging +import re +import zipfile +from typing import Any, Literal + +from fastapi import HTTPException +from fastapi.responses import StreamingResponse +from sqlmodel import Session + +from app.core.cloud import get_cloud_storage +from app.core.storage_utils import generate_timestamped_filename +from app.crud.assessment.processing import parse_assessment_output +from app.crud.job import get_batch_job +from app.models.assessment import Assessment, AssessmentExportRow, AssessmentRun +from app.models.batch_job import BatchJob +from app.models.evaluation import EvaluationDataset +from app.services.assessment.utils.parsing import parse_stored_results, usage_totals +from app.utils import APIResponse + +logger = logging.getLogger(__name__) + + +def _load_dataset_rows( + session: Session, + dataset: EvaluationDataset, +) -> list[dict[str, str]]: + from app.crud.assessment.batch import _load_dataset_rows as load_dataset_rows + + return load_dataset_rows(session, dataset) + + +def _safe_filename_part(value: str) -> str: + """Build a filesystem-safe filename component.""" + sanitized = re.sub(r"[^A-Za-z0-9._-]+", "_", value).strip("._") + return sanitized or "assessment_results" + + +def _expand_input_columns( + row_payload: list[dict[str, Any]], +) -> tuple[list[dict[str, Any]], list[str]]: + """Expand ``input_data`` dict into separate input columns. + + Uses the original column names from the dataset (no prefix). 
+ + Returns: + (expanded_rows with input_data replaced by individual columns, + ordered list of input column names) + """ + input_keys: list[str] = [] + seen_keys: dict[str, None] = {} + + for row in row_payload: + input_data = row.get("input_data") + if isinstance(input_data, dict): + for input_key in input_data: + if input_key not in seen_keys: + seen_keys[input_key] = None + input_keys.append(input_key) + + if not input_keys: + for row in row_payload: + row.pop("input_data", None) + return row_payload, [] + + reserved_fields = set(AssessmentExportRow.model_fields.keys()) - {"input_data"} + key_map: dict[str, str] = {} + for input_key in input_keys: + col = f"input_{input_key}" if input_key in reserved_fields else input_key + key_map[input_key] = col + + collisions = {key: value for key, value in key_map.items() if key != value} + if collisions: + logger.warning( + "[_expand_input_columns] Input dataset columns conflict with reserved " + "export fields and were namespaced: %s", + collisions, + ) + + expanded: list[dict[str, Any]] = [] + for row in row_payload: + input_data = row.pop("input_data", None) or {} + new_row = {} + for input_key in input_keys: + new_row[key_map[input_key]] = input_data.get(input_key) + new_row.update(row) + expanded.append(new_row) + + return expanded, [key_map[input_key] for input_key in input_keys] + + +def _drop_empty_columns( + rows: list[dict[str, Any]], + fieldnames: list[str], +) -> tuple[list[dict[str, Any]], list[str]]: + """Remove columns where every row has a null or empty-string value.""" + non_empty_fields: list[str] = [] + for field in fieldnames: + if any( + row.get(field) is not None and str(row.get(field, "")).strip() != "" + for row in rows + ): + non_empty_fields.append(field) + + if len(non_empty_fields) == len(fieldnames): + return rows, fieldnames + + pruned = [{field: row.get(field) for field in non_empty_fields} for row in rows] + return pruned, non_empty_fields + + +def _expand_output_columns( + row_payload: list[dict[str, Any]], +) -> tuple[list[dict[str, Any]], list[str]]: + """Expand the ``output`` field into separate columns when it contains valid JSON. 
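+ 
+ For example (hypothetical output), '{"score": 4, "feedback": "ok"}' becomes separate "score" and "feedback" columns; outputs that are not valid JSON objects are kept in an "output_raw" column instead.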
+ + Returns: + (expanded_rows, ordered_fieldnames) + """ + # First expand input columns + row_payload, input_col_names = _expand_input_columns(row_payload) + + base_fields = [ + field + for field in AssessmentExportRow.model_fields.keys() + if field not in ("output", "input_data") + ] + + parsed_outputs: list[dict[str, Any] | None] = [] + output_keys: list[str] = [] + seen_keys: dict[str, None] = {} # ordered set + has_unparsed_output = False + + for row in row_payload: + raw = row.get("output") + if raw is None: + parsed_outputs.append(None) + continue + + if isinstance(raw, str): + try: + parsed = json.loads(raw) + except (json.JSONDecodeError, TypeError): + parsed = None + elif isinstance(raw, dict): + parsed = raw + else: + parsed = None + + if not isinstance(parsed, dict): + has_unparsed_output = True + parsed_outputs.append(None) + continue + + parsed_outputs.append(parsed) + for output_key in parsed: + if output_key not in seen_keys: + seen_keys[output_key] = None + output_keys.append(output_key) + + if not output_keys: + # Keep original layout with output as a single column + fieldnames = input_col_names + list(AssessmentExportRow.model_fields.keys()) + fieldnames = [field for field in fieldnames if field != "input_data"] + return row_payload, fieldnames + + # Build expanded rows + expanded: list[dict[str, Any]] = [] + for row, parsed in zip(row_payload, parsed_outputs, strict=True): + new_row = {col: val for col, val in row.items() if col != "output"} + if parsed: + for output_key in output_keys: + new_row[output_key] = parsed.get(output_key) + else: + for output_key in output_keys: + new_row[output_key] = None + if row.get("output") is not None: + new_row["output_raw"] = row.get("output") + expanded.append(new_row) + + # Build fieldnames: input columns + base fields + output columns + output_idx = base_fields.index("result_status") + 1 # after result_status + fieldnames = ( + input_col_names + + base_fields[:output_idx] + + output_keys + + base_fields[output_idx:] + ) + if has_unparsed_output: + fieldnames.insert( + len(input_col_names) + output_idx + len(output_keys), "output_raw" + ) + + return expanded, fieldnames + + +def serialize_export_rows( + export_rows: list[AssessmentExportRow], + export_format: Literal["json", "csv", "xlsx"], +) -> tuple[bytes, str]: + """Serialize export rows into the requested file format.""" + row_payload = [row.model_dump(mode="json") for row in export_rows] + + if export_format == "json": + expanded, _ = _expand_output_columns(row_payload) + return ( + json.dumps(expanded, ensure_ascii=False, indent=2).encode("utf-8"), + "application/json", + ) + + # For CSV/XLSX, expand output keys into separate columns + expanded, fieldnames = _expand_output_columns(row_payload) + + if export_format == "csv": + output = io.StringIO() + writer = csv.DictWriter(output, fieldnames=fieldnames) + writer.writeheader() + writer.writerows(expanded) + return output.getvalue().encode("utf-8"), "text/csv" + + try: + import pandas as pd + except ImportError as exc: + raise HTTPException( + status_code=500, + detail="XLSX export requires pandas/openpyxl support in the backend runtime", + ) from exc + + # XLSX shows input columns + output columns only (no metadata fields). 
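+ # e.g. an export with input column "question" and output keys "score"/"feedback" yields a sheet with just those three columns (empty ones are dropped below).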
+ metadata_fields = { + field + for field in AssessmentExportRow.model_fields.keys() + if field not in ("output", "input_data") + } + excel_fields = [field for field in fieldnames if field not in metadata_fields] + if not excel_fields: + excel_fields = ["output"] + + # Drop columns where every row is null/empty + expanded, excel_fields = _drop_empty_columns(expanded, excel_fields) + + buf = io.BytesIO() + data_frame = pd.DataFrame(expanded, columns=excel_fields) + with pd.ExcelWriter(buf) as writer: + data_frame.to_excel(writer, index=False, sheet_name="results") + return ( + buf.getvalue(), + "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + ) + + +def build_json_export_rows( + export_rows: list[AssessmentExportRow], +) -> list[dict[str, Any]]: + """Return JSON rows with structured output expanded into top-level keys.""" + row_payload = [row.model_dump(mode="json") for row in export_rows] + expanded, _ = _expand_output_columns(row_payload) + return expanded + + +def build_export_response( + export_rows: list[AssessmentExportRow], + export_format: Literal["json", "csv", "xlsx"], + base_name: str, +) -> StreamingResponse: + """Return a file download response for assessment exports.""" + payload, media_type = serialize_export_rows(export_rows, export_format) + filename = generate_timestamped_filename( + _safe_filename_part(base_name), + extension=export_format, + ) + return StreamingResponse( + io.BytesIO(payload), + media_type=media_type, + headers={"Content-Disposition": f'attachment; filename="{filename}"'}, + ) + + +def _load_parsed_results_for_run( + session: Session, + run: AssessmentRun, + batch_job: BatchJob, +) -> list[dict[str, Any]] | None: + """Fetch and parse the stored batch results for a run. + + Tries object store first; falls back to downloading directly from the + batch provider (e.g. OpenAI file API) when the S3 copy is unavailable. + """ + parent = session.get(Assessment, run.assessment_id) + if not parent: + logger.warning( + "[_load_parsed_results_for_run] Parent assessment not found for run id=%s", + run.id, + ) + return None + + # 1. Try object store (S3) + if run.object_store_url: + try: + storage = get_cloud_storage(session, project_id=parent.project_id) + body = storage.stream(run.object_store_url) + raw_results = parse_stored_results(body.read().decode("utf-8")) + if raw_results: + return parse_assessment_output(raw_results, batch_job.provider) + logger.warning( + "[_load_parsed_results_for_run] S3 file was empty for run id=%s", + run.id, + ) + except Exception as exc: + logger.warning( + "[_load_parsed_results_for_run] S3 download failed for run id=%s: %s", + run.id, + exc, + ) + + # 2. 
Fallback: download directly from batch provider + if batch_job.provider_output_file_id: + try: + from app.core.batch import download_batch_results + from app.crud.assessment.processing import _get_batch_provider + + provider = _get_batch_provider( + session=session, + provider_name=batch_job.provider, + organization_id=parent.organization_id, + project_id=parent.project_id, + ) + raw_results = download_batch_results(provider=provider, batch_job=batch_job) + return parse_assessment_output(raw_results, batch_job.provider) + except Exception as exc: + logger.error( + "[_load_parsed_results_for_run] Provider download also failed for run id=%s: %s", + run.id, + exc, + exc_info=True, + ) + + logger.warning( + "[_load_parsed_results_for_run] No results available for run id=%s " + "(object_store_url=%s, provider_output_file_id=%s)", + run.id, + run.object_store_url, + batch_job.provider_output_file_id, + ) + return None + + +def _load_dataset_rows_for_run( + session: Session, + run: AssessmentRun, + assessment: Assessment, +) -> list[dict[str, str]]: + """Load original dataset rows for input-output correlation. + + Returns an empty list if the dataset is not available. + """ + try: + dataset = session.get(EvaluationDataset, assessment.dataset_id) + if not dataset or not dataset.object_store_url: + logger.warning( + "[_load_dataset_rows_for_run] Dataset not available for run id=%s", + run.id, + ) + return [] + return _load_dataset_rows(session, dataset) + except Exception as exc: + logger.warning( + "[_load_dataset_rows_for_run] Failed to load dataset for run id=%s: %s", + run.id, + exc, + ) + return [] + + +def load_export_rows_for_run( + session: Session, + run: AssessmentRun, + assessment: Assessment | None = None, +) -> list[AssessmentExportRow]: + """Load flattened export rows for a single child assessment run.""" + if not run.batch_job_id: + logger.warning( + "[load_export_rows_for_run] No batch_job_id for run id=%s", run.id + ) + return [] + + batch_job = get_batch_job(session=session, batch_job_id=run.batch_job_id) + if not batch_job: + logger.warning( + "[load_export_rows_for_run] Missing batch job for run id=%s", + run.id, + ) + return [] + + if assessment is None: + assessment = session.get(Assessment, run.assessment_id) + if assessment is None: + logger.warning( + "[load_export_rows_for_run] Parent assessment missing for run id=%s", + run.id, + ) + return [] + + parsed_results = _load_parsed_results_for_run( + session=session, + run=run, + batch_job=batch_job, + ) + if parsed_results is None: + return [] + + if not parsed_results: + logger.warning( + "[load_export_rows_for_run] Parsed results empty for run id=%s", run.id + ) + return [] + + dataset_rows = _load_dataset_rows_for_run(session, run, assessment) + dataset = session.get(EvaluationDataset, assessment.dataset_id) + dataset_name = dataset.name if dataset else None + + export_rows: list[AssessmentExportRow] = [] + for item in parsed_results: + input_tokens, output_tokens, total_tokens = usage_totals(item.get("usage")) + + # Correlate with original input row via row_id (format: "row_{idx}") + input_data: dict[str, str] | None = None + row_id_str = str(item.get("row_id", "")) + if dataset_rows and row_id_str.startswith("row_"): + try: + row_idx = int(row_id_str.split("_", 1)[1]) + if 0 <= row_idx < len(dataset_rows): + input_data = dataset_rows[row_idx] + except (ValueError, IndexError): + pass + + export_rows.append( + AssessmentExportRow( + assessment_id=run.assessment_id, + experiment_name=assessment.experiment_name, + 
dataset_id=assessment.dataset_id, + dataset_name=dataset_name, + run_id=run.id, + run_name=assessment.experiment_name, + run_status=run.status, + config_id=run.config_id, + config_version=run.config_version, + row_id=row_id_str, + result_status="failed" if item.get("error") else "passed", + input_data=input_data, + output=item.get("output"), + error=item.get("error"), + response_id=item.get("response_id"), + input_tokens=input_tokens, + output_tokens=output_tokens, + total_tokens=total_tokens, + updated_at=run.updated_at, + ) + ) + + return export_rows + + +def sort_export_rows( + export_rows: list[AssessmentExportRow], +) -> list[AssessmentExportRow]: + """Sort exported rows for stable downloads across runs/configs.""" + + def _row_index(row_id: str) -> int: + if not row_id.startswith("row_"): + return 0 + try: + return int(row_id.split("_", 1)[1]) + except (ValueError, IndexError): + return 0 + + export_rows.sort( + key=lambda row: ( + row.config_version or 0, + _row_index(row.row_id), + row.run_id, + ) + ) + return export_rows + + +def build_assessment_results_response( + session: Session, + assessment: Assessment, + runs: list[AssessmentRun], + export_format: Literal["json", "csv", "xlsx"], +) -> APIResponse[list[dict[str, Any]]] | StreamingResponse: + """Bundle child-run results for a parent assessment into a download response. + + JSON returns a flat list. CSV/XLSX with one run returns a single file; + multiple runs are zipped one-file-per-run. + """ + runs_with_rows: list[tuple[AssessmentRun, list[AssessmentExportRow]]] = [] + all_rows: list[AssessmentExportRow] = [] + for run in runs: + rows = load_export_rows_for_run(session=session, run=run, assessment=assessment) + if rows: + runs_with_rows.append((run, sort_export_rows(rows))) + all_rows.extend(rows) + + all_rows = sort_export_rows(all_rows) + + if export_format == "json": + return APIResponse.success_response(data=build_json_export_rows(all_rows)) + + if len(runs_with_rows) <= 1: + return build_export_response( + export_rows=all_rows, + export_format=export_format, + base_name=( + f"{assessment.experiment_name}_assessment_{assessment.id}_results" + ), + ) + + zip_buffer = io.BytesIO() + with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zf: + for run, rows in runs_with_rows: + config_label = ( + f"config_v{run.config_version}" + if run.config_version + else f"run_{run.id}" + ) + config_id_short = str(run.config_id)[:8] if run.config_id else "" + file_base = _safe_filename_part(f"{config_label}_{config_id_short}") + file_bytes, _ = serialize_export_rows(rows, export_format) + zf.writestr(f"{file_base}.{export_format}", file_bytes) + + zip_buffer.seek(0) + zip_filename = generate_timestamped_filename( + _safe_filename_part(f"{assessment.experiment_name}_assessment_{assessment.id}"), + extension="zip", + ) + return StreamingResponse( + zip_buffer, + media_type="application/zip", + headers={"Content-Disposition": f'attachment; filename="{zip_filename}"'}, + ) diff --git a/backend/app/services/assessment/utils/parsing.py b/backend/app/services/assessment/utils/parsing.py new file mode 100644 index 000000000..fbb064484 --- /dev/null +++ b/backend/app/services/assessment/utils/parsing.py @@ -0,0 +1,32 @@ +"""Parsing utilities for assessment batch results.""" + +import json +from typing import Any + + +def parse_stored_results(raw_content: str) -> list[dict[str, Any]]: + """Parse stored batch results from JSONL or JSON array.""" + content = raw_content.strip() + if not content: + return [] + + if content.startswith("["): + 
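+ # Some stored results are a single JSON array rather than JSONL.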
parsed = json.loads(content) + return parsed if isinstance(parsed, list) else [] + + return [json.loads(line) for line in content.splitlines() if line.strip()] + + +def usage_totals(usage: Any) -> tuple[int | None, int | None, int | None]: + """Extract common token totals from provider usage payloads.""" + if not isinstance(usage, dict): + return None, None, None + + input_tokens = usage.get("input_tokens") or usage.get("prompt_tokens") + output_tokens = usage.get("output_tokens") or usage.get("completion_tokens") + total_tokens = usage.get("total_tokens") + + if total_tokens is None and input_tokens is not None and output_tokens is not None: + total_tokens = input_tokens + output_tokens + + return input_tokens, output_tokens, total_tokens diff --git a/backend/app/services/assessment/validators.py b/backend/app/services/assessment/validators.py new file mode 100644 index 000000000..3d3f992f2 --- /dev/null +++ b/backend/app/services/assessment/validators.py @@ -0,0 +1,71 @@ +"""Validation utilities for assessment dataset file uploads (CSV + XLSX). + +Only validates file type and size — no column requirements. +""" + +import logging +from pathlib import Path + +from fastapi import HTTPException, UploadFile + +logger = logging.getLogger(__name__) + +MAX_FILE_SIZE = 10 * 1024 * 1024 # 10 MB + +ALLOWED_EXTENSIONS = {".csv", ".xlsx"} +ALLOWED_MIME_TYPES = { + "text/csv", + "application/csv", + "text/plain", + "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", +} + + +async def validate_dataset_file(file: UploadFile) -> tuple[bytes, str]: + """Validate an uploaded dataset file (CSV or XLSX). + + Only checks file type and size — does NOT inspect columns. + + Returns: + Tuple of (file content as bytes, file extension) + + Raises: + HTTPException: If validation fails + """ + if not file.filename: + raise HTTPException(status_code=422, detail="File must have a filename") + + file_ext = Path(file.filename).suffix.lower() + if file_ext == ".xls": + raise HTTPException( + status_code=422, + detail="Legacy Excel format (.xls) is not supported. Please upload .xlsx or .csv.", + ) + if file_ext not in ALLOWED_EXTENSIONS: + raise HTTPException( + status_code=422, + detail=f"Invalid file type. Allowed: CSV, XLSX. Got: {file_ext}", + ) + + content_type = file.content_type + if content_type not in ALLOWED_MIME_TYPES: + logger.warning( + f"[validate_dataset_file] Unexpected content type '{content_type}' " + f"for extension '{file_ext}', proceeding based on extension" + ) + + file.file.seek(0, 2) + file_size = file.file.tell() + file.file.seek(0) + + if file_size > MAX_FILE_SIZE: + raise HTTPException( + status_code=413, + detail=f"File too large. 
Maximum size: {MAX_FILE_SIZE / (1024 * 1024):.0f}MB", + ) + + if file_size == 0: + raise HTTPException(status_code=422, detail="Empty file uploaded") + + content = await file.read() + return content, file_ext diff --git a/backend/app/tests/api/routes/configs/test_config.py b/backend/app/tests/api/routes/configs/test_config.py index e466bfc3b..8a0d7e7fd 100644 --- a/backend/app/tests/api/routes/configs/test_config.py +++ b/backend/app/tests/api/routes/configs/test_config.py @@ -4,6 +4,7 @@ from sqlmodel import Session from app.core.config import settings +from app.models.config.config import ConfigTag from app.tests.utils.auth import TestAuthContext from app.tests.utils.test_data import create_test_config, create_test_project @@ -148,6 +149,65 @@ def test_list_configs( assert config.name in config_names +def test_list_configs_without_tag_returns_default_configs( + db: Session, + client: TestClient, + user_api_key: TestAuthContext, +) -> None: + """Test default config list returns default configs.""" + implicit_default_config = create_test_config( + db=db, + project_id=user_api_key.project_id, + name="api-implicit-default-config", + ) + default_config = create_test_config( + db=db, + project_id=user_api_key.project_id, + name="api-default-config", + tag=ConfigTag.DEFAULT, + ) + + response = client.get( + f"{settings.API_V1_STR}/configs/", + headers={"X-API-KEY": user_api_key.key}, + ) + + assert response.status_code == 200 + config_names = [c["name"] for c in response.json()["data"]] + assert implicit_default_config.name in config_names + assert default_config.name in config_names + + +def test_list_configs_with_explicit_tag_returns_matching_tag( + db: Session, + client: TestClient, + user_api_key: TestAuthContext, +) -> None: + """Test explicit config tag query returns matching tagged configs.""" + default_config = create_test_config( + db=db, + project_id=user_api_key.project_id, + name="api-default-config", + tag=ConfigTag.DEFAULT, + ) + implicit_default_config = create_test_config( + db=db, + project_id=user_api_key.project_id, + name="api-implicit-default-config", + ) + + response = client.get( + f"{settings.API_V1_STR}/configs/", + headers={"X-API-KEY": user_api_key.key}, + params={"tag": ConfigTag.DEFAULT.value}, + ) + + assert response.status_code == 200 + config_names = [c["name"] for c in response.json()["data"]] + assert default_config.name in config_names + assert implicit_default_config.name in config_names + + def test_list_configs_with_pagination( db: Session, client: TestClient, diff --git a/backend/app/tests/api/routes/configs/test_version.py b/backend/app/tests/api/routes/configs/test_version.py index dd952858b..2bdbc0025 100644 --- a/backend/app/tests/api/routes/configs/test_version.py +++ b/backend/app/tests/api/routes/configs/test_version.py @@ -4,14 +4,15 @@ from sqlmodel import Session from app.core.config import settings +from app.models import ConfigBlob +from app.models.config.config import ConfigTag +from app.models.llm.request import NativeCompletionConfig from app.tests.utils.auth import TestAuthContext from app.tests.utils.test_data import ( create_test_config, create_test_project, create_test_version, ) -from app.models import ConfigBlob -from app.models.llm.request import NativeCompletionConfig def test_create_version_success( @@ -266,6 +267,57 @@ def test_list_versions_different_project_fails( assert response.status_code == 404 +def test_list_versions_with_explicit_default_tag_allows_implicit_default_config( + db: Session, + client: TestClient, + 
user_api_key: TestAuthContext, +) -> None: + """Test explicit default tag scope allows configs created without an explicit tag.""" + config = create_test_config( + db=db, + project_id=user_api_key.project_id, + name="implicit-default-version-config", + ) + + response = client.get( + f"{settings.API_V1_STR}/configs/{config.id}/versions", + headers={"X-API-KEY": user_api_key.key}, + params={"tag": ConfigTag.DEFAULT.value}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + assert len(data["data"]) == 1 + assert data["data"][0]["version"] == 1 + + +def test_list_versions_with_explicit_default_tag_allows_default_config( + db: Session, + client: TestClient, + user_api_key: TestAuthContext, +) -> None: + """Test explicit tag query allows version listing for matching tagged config.""" + config = create_test_config( + db=db, + project_id=user_api_key.project_id, + name="default-version-config", + tag=ConfigTag.DEFAULT, + ) + + response = client.get( + f"{settings.API_V1_STR}/configs/{config.id}/versions", + headers={"X-API-KEY": user_api_key.key}, + params={"tag": ConfigTag.DEFAULT.value}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + assert len(data["data"]) == 1 + assert data["data"][0]["version"] == 1 + + def test_get_version_by_number( db: Session, client: TestClient, diff --git a/backend/app/tests/api/routes/test_cron.py b/backend/app/tests/api/routes/test_cron.py index 858834dc9..0ed9b6404 100644 --- a/backend/app/tests/api/routes/test_cron.py +++ b/backend/app/tests/api/routes/test_cron.py @@ -1,4 +1,4 @@ -from unittest.mock import patch +from unittest.mock import AsyncMock, patch from fastapi.testclient import TestClient @@ -26,8 +26,13 @@ def test_evaluation_cron_job_success( } with patch( - "app.api.routes.cron.process_all_pending_evaluations_sync", - return_value=mock_result, + "app.api.routes.cron.process_all_pending_evaluations", + new=AsyncMock(return_value=mock_result), + ), patch( + "app.crud.assessment.cron.poll_all_pending_assessment_evaluations", + new=AsyncMock( + return_value={"processed": 0, "failed": 0, "still_processing": 0} + ), ): response = client.get( f"{settings.API_V1_STR}/cron/evaluations", @@ -56,8 +61,13 @@ def test_evaluation_cron_job_no_pending( } with patch( - "app.api.routes.cron.process_all_pending_evaluations_sync", - return_value=mock_result, + "app.api.routes.cron.process_all_pending_evaluations", + new=AsyncMock(return_value=mock_result), + ), patch( + "app.crud.assessment.cron.poll_all_pending_assessment_evaluations", + new=AsyncMock( + return_value={"processed": 0, "failed": 0, "still_processing": 0} + ), ): response = client.get( f"{settings.API_V1_STR}/cron/evaluations", @@ -91,8 +101,13 @@ def test_evaluation_cron_job_with_failures( } with patch( - "app.api.routes.cron.process_all_pending_evaluations_sync", - return_value=mock_result, + "app.api.routes.cron.process_all_pending_evaluations", + new=AsyncMock(return_value=mock_result), + ), patch( + "app.crud.assessment.cron.poll_all_pending_assessment_evaluations", + new=AsyncMock( + return_value={"processed": 0, "failed": 0, "still_processing": 0} + ), ): response = client.get( f"{settings.API_V1_STR}/cron/evaluations", @@ -106,6 +121,75 @@ def test_evaluation_cron_job_with_failures( assert data["total_processed"] == 3 +def test_evaluation_cron_job_merges_assessment_totals( + client: TestClient, + superuser_api_key: TestAuthContext, +) -> None: + """Test cron job aggregates standard + assessment 
poll totals.""" + mock_result = { + "status": "success", + "total_processed": 2, + "total_failed": 1, + "total_still_processing": 3, + "results": [], + } + assessment_result = {"processed": 4, "failed": 2, "still_processing": 1} + + with patch( + "app.api.routes.cron.process_all_pending_evaluations", + new=AsyncMock(return_value=mock_result), + ), patch( + "app.crud.assessment.cron.poll_all_pending_assessment_evaluations", + new=AsyncMock(return_value=assessment_result), + ): + response = client.get( + f"{settings.API_V1_STR}/cron/evaluations", + headers={"X-API-KEY": superuser_api_key.key}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "success" + assert data["total_processed"] == 6 + assert data["total_failed"] == 3 + assert data["total_still_processing"] == 4 + assert data["assessment"] == assessment_result + + +def test_evaluation_cron_job_assessment_polling_failure( + client: TestClient, + superuser_api_key: TestAuthContext, +) -> None: + """Test cron keeps response successful when assessment polling fails.""" + mock_result = { + "status": "success", + "total_processed": 3, + "total_failed": 0, + "total_still_processing": 2, + "results": [], + } + + with patch( + "app.api.routes.cron.process_all_pending_evaluations", + new=AsyncMock(return_value=mock_result), + ), patch( + "app.crud.assessment.cron.poll_all_pending_assessment_evaluations", + new=AsyncMock(side_effect=RuntimeError("assessment poll failure")), + ): + response = client.get( + f"{settings.API_V1_STR}/cron/evaluations", + headers={"X-API-KEY": superuser_api_key.key}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "success" + assert data["total_processed"] == 3 + assert data["total_failed"] == 0 + assert data["total_still_processing"] == 2 + assert "assessment poll failure" in data["assessment_error"] + + def test_evaluation_cron_job_requires_superuser( client: TestClient, user_api_key: TestAuthContext, diff --git a/backend/app/tests/assessment/__init__.py b/backend/app/tests/assessment/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/app/tests/assessment/test_batch.py b/backend/app/tests/assessment/test_batch.py new file mode 100644 index 000000000..b91e59b2c --- /dev/null +++ b/backend/app/tests/assessment/test_batch.py @@ -0,0 +1,425 @@ +"""Tests for assessment/batch.py provider routing in submit_assessment_batch.""" + +import io +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from openpyxl import Workbook +from openpyxl.utils.exceptions import InvalidFileException + +from app.crud.assessment.batch import ( + _build_text_prompt, + _load_dataset_rows, + _parse_excel_rows, + build_google_jsonl, + build_openai_jsonl, + submit_assessment_batch, +) +from app.models.assessment import AssessmentAttachment +from app.services.assessment.utils.attachments import ( + _decode_base64_prefix, + _guess_image_mime_from_base64, + _guess_image_mime_from_url, + resolve_attachment_values, + resolve_image_mime_and_payload, + split_attachment_urls, + split_data_url, + to_direct_attachment_url, +) + + +def _make_run() -> MagicMock: + run = MagicMock() + run.id = 99 + return run + + +def _make_assessment() -> MagicMock: + assessment = MagicMock() + assessment.id = 21 + assessment.experiment_name = "exp-v1" + return assessment + + +def _make_dataset() -> MagicMock: + dataset = MagicMock() + dataset.id = 8 + return dataset + + +class 
TestSubmitAssessmentBatchProviderRouting: + def test_openai_native_routes_to_openai_batch(self) -> None: + session = MagicMock() + run = _make_run() + dataset = _make_dataset() + config_blob = SimpleNamespace( + completion=SimpleNamespace( + provider="openai-native", + params={"instructions": "config system"}, + ) + ) + batch_job = MagicMock() + batch_job.id = 1 + batch_job.total_items = 1 + + with ( + patch( + "app.crud.assessment.batch._load_dataset_rows", + return_value=[{"question": "q1"}], + ), + patch( + "app.crud.assessment.batch.map_kaapi_to_openai_params", + return_value=({}, []), + ) as map_params, + patch( + "app.crud.assessment.batch.build_openai_jsonl", + return_value=[{"custom_id": "row_0"}], + ), + patch( + "app.utils.get_openai_client", + return_value=MagicMock(), + ), + patch( + "app.crud.assessment.batch.OpenAIBatchProvider", + return_value=MagicMock(), + ), + patch( + "app.crud.assessment.batch.start_batch_job", + return_value=batch_job, + ) as start_batch, + ): + result = submit_assessment_batch( + session=session, + run=run, + assessment=_make_assessment(), + dataset=dataset, + config_blob=config_blob, + assessment_input={ + "text_columns": ["question"], + "attachments": [], + "system_instruction": "request system", + }, + organization_id=1, + project_id=1, + ) + + assert result.id == 1 + assert map_params.call_args.kwargs["session"] is session + assert map_params.call_args.kwargs["kaapi_params"]["instructions"] == ( + "request system" + ) + assert start_batch.call_args.kwargs["provider_name"] == "openai" + + def test_config_instruction_is_not_used_without_request_instruction(self) -> None: + session = MagicMock() + run = _make_run() + dataset = _make_dataset() + config_blob = SimpleNamespace( + completion=SimpleNamespace( + provider="openai", + params={"instructions": "config system", "model": "gpt-4.1-mini"}, + ) + ) + batch_job = MagicMock() + batch_job.id = 3 + batch_job.total_items = 1 + + with ( + patch( + "app.crud.assessment.batch._load_dataset_rows", + return_value=[{"question": "q1"}], + ), + patch( + "app.crud.assessment.batch.map_kaapi_to_openai_params", + return_value=({"model": "gpt-4.1-mini"}, []), + ) as map_params, + patch( + "app.crud.assessment.batch.build_openai_jsonl", + return_value=[{"custom_id": "row_0"}], + ), + patch( + "app.utils.get_openai_client", + return_value=MagicMock(), + ), + patch( + "app.crud.assessment.batch.OpenAIBatchProvider", + return_value=MagicMock(), + ), + patch( + "app.crud.assessment.batch.start_batch_job", + return_value=batch_job, + ), + ): + submit_assessment_batch( + session=session, + run=run, + assessment=_make_assessment(), + dataset=dataset, + config_blob=config_blob, + assessment_input={"text_columns": ["question"], "attachments": []}, + organization_id=1, + project_id=1, + ) + + assert map_params.call_args.kwargs["session"] is session + assert "instructions" not in map_params.call_args.kwargs["kaapi_params"] + + def test_google_native_routes_to_google_batch(self) -> None: + session = MagicMock() + run = _make_run() + dataset = _make_dataset() + config_blob = SimpleNamespace( + completion=SimpleNamespace( + provider="google-native", + params={"instructions": "config system"}, + ) + ) + batch_job = MagicMock() + batch_job.id = 2 + batch_job.total_items = 1 + gemini_client = MagicMock() + gemini_client.client = MagicMock() + + with ( + patch( + "app.crud.assessment.batch._load_dataset_rows", + return_value=[{"question": "q1"}], + ), + patch( + "app.crud.assessment.batch.map_kaapi_to_google_params", + 
return_value=({"model": "gemini-2.5-pro"}, []), + ) as map_params, + patch( + "app.crud.assessment.batch.build_google_jsonl", + return_value=[{"key": "row_0"}], + ), + patch("app.core.batch.client.GeminiClient") as gemini_cls, + patch( + "app.core.batch.GeminiBatchProvider", + return_value=MagicMock(), + ), + patch( + "app.crud.assessment.batch.start_batch_job", + return_value=batch_job, + ) as start_batch, + ): + gemini_cls.from_credentials.return_value = gemini_client + result = submit_assessment_batch( + session=session, + run=run, + assessment=_make_assessment(), + dataset=dataset, + config_blob=config_blob, + assessment_input={ + "text_columns": ["question"], + "attachments": [], + "system_instruction": "request system", + }, + organization_id=1, + project_id=1, + ) + + assert result.id == 2 + assert map_params.call_args.args[0]["instructions"] == "request system" + assert start_batch.call_args.kwargs["provider_name"] == "google" + + +class TestBatchDatasetParsing: + def test_load_dataset_rows_routes_xlsx_to_excel_parser(self) -> None: + session = MagicMock() + dataset = MagicMock() + dataset.id = 8 + dataset.project_id = 1 + dataset.object_store_url = "s3://bucket/key" + dataset.dataset_metadata = {"file_extension": ".xlsx"} + + storage = MagicMock() + stream_body = MagicMock() + stream_body.read.return_value = b"xlsx-content" + storage.stream.return_value = stream_body + + expected = [{"question": "q1"}] + with ( + patch("app.crud.assessment.batch.get_cloud_storage", return_value=storage), + patch( + "app.crud.assessment.batch._parse_excel_rows", + return_value=expected, + ) as parse_excel, + ): + result = _load_dataset_rows(session=session, dataset=dataset) + + assert result == expected + parse_excel.assert_called_once_with(b"xlsx-content") + + def test_load_dataset_rows_rejects_legacy_xls(self) -> None: + session = MagicMock() + dataset = MagicMock() + dataset.id = 8 + dataset.project_id = 1 + dataset.object_store_url = "s3://bucket/key" + dataset.dataset_metadata = {"file_extension": ".xls"} + + storage = MagicMock() + stream_body = MagicMock() + stream_body.read.return_value = b"legacy-xls-content" + storage.stream.return_value = stream_body + + with patch("app.crud.assessment.batch.get_cloud_storage", return_value=storage): + with pytest.raises(ValueError, match="Legacy Excel format"): + _load_dataset_rows(session=session, dataset=dataset) + + def test_parse_excel_rows_invalid_payload_raises(self) -> None: + with pytest.raises((ValueError, InvalidFileException)): + _parse_excel_rows(b"not-a-valid-xlsx") + + def test_parse_excel_rows_success(self) -> None: + wb = Workbook() + ws = wb.active + assert ws is not None + ws.append(["question", "answer"]) + ws.append(["What is 2+2?", "4"]) + ws.append(["", None]) # empty row should be skipped + buf = io.BytesIO() + wb.save(buf) + wb.close() + + rows = _parse_excel_rows(buf.getvalue()) + assert rows == [{"question": "What is 2+2?", "answer": "4"}] + + def test_parse_excel_rows_returns_empty_when_sheet_missing(self) -> None: + fake_wb = MagicMock() + fake_wb.active = None + with patch( + "app.crud.assessment.batch.openpyxl.load_workbook", return_value=fake_wb + ): + assert _parse_excel_rows(b"irrelevant") == [] + fake_wb.close.assert_called_once() + + def test_parse_excel_rows_returns_empty_when_header_missing(self) -> None: + fake_ws = MagicMock() + fake_ws.iter_rows.return_value = iter([]) + fake_wb = MagicMock() + fake_wb.active = fake_ws + with patch( + "app.crud.assessment.batch.openpyxl.load_workbook", return_value=fake_wb + ): + 
assert _parse_excel_rows(b"irrelevant") == [] + fake_wb.close.assert_called_once() + + def test_parse_excel_rows_invalid_file_exception_re_raises(self) -> None: + with patch( + "app.crud.assessment.batch.openpyxl.load_workbook", + side_effect=InvalidFileException("bad xlsx"), + ): + with pytest.raises(InvalidFileException): + _parse_excel_rows(b"bad") + + def test_parse_excel_rows_unexpected_exception_raises_value_error(self) -> None: + with patch( + "app.crud.assessment.batch.openpyxl.load_workbook", + side_effect=RuntimeError("boom"), + ): + with pytest.raises(ValueError, match="Failed to parse XLSX dataset rows"): + _parse_excel_rows(b"bad") + + +class TestBatchHelpers: + def test_build_text_prompt_template_and_concat(self) -> None: + row = {"q": " What? ", "ctx": "Context"} + templated = _build_text_prompt(row, ["q", "ctx"], "Q:{q}\nC:{ctx}") + assert "Q:" in templated + assert "What?" in templated + concatenated = _build_text_prompt(row, ["q", "ctx"], None) + assert "What?" in concatenated + assert concatenated.endswith("\nContext") + + def test_split_and_direct_urls(self) -> None: + urls = split_attachment_urls(" https://a.com\nhttps://b.com , https://c.com ") + assert urls == ["https://a.com", "https://b.com", "https://c.com"] + image_url = to_direct_attachment_url( + "https://drive.google.com/file/d/abc123/view?usp=sharing", "image" + ) + assert "googleusercontent.com" in image_url + pdf_url = to_direct_attachment_url( + "https://drive.google.com/open?id=abc123", "pdf" + ) + assert "drive.google.com/uc" in pdf_url + + def test_data_url_and_mime_guessers(self) -> None: + mime, payload = split_data_url("data:image/png;base64,AAAA") + assert mime == "image/png" + assert payload == "AAAA" + none_mime, raw = split_data_url("rawbase64") + assert none_mime is None + assert raw == "rawbase64" + assert _guess_image_mime_from_url("https://x/y/file.jpeg") == "image/jpeg" + assert _guess_image_mime_from_url("https://x/y/file.unknown") is None + + def test_base64_guess_and_decode(self) -> None: + png_head = "iVBORw0KGgoAAAANSUhEUg==" + assert _guess_image_mime_from_base64(png_head) == "image/png" + assert _decode_base64_prefix("###") == b"" + + def testresolve_image_mime_and_payload(self) -> None: + mime, payload = resolve_image_mime_and_payload("https://x/y/file.webp", "url") + assert mime == "image/webp" + assert payload.endswith("file.webp") + mime2, payload2 = resolve_image_mime_and_payload( + "data:image/jpeg;base64,AAAA", "base64" + ) + assert mime2 == "image/jpeg" + assert payload2 == "AAAA" + + def testresolve_attachment_values(self) -> None: + image_url_att = AssessmentAttachment(column="img", type="image", format="url") + image_b64_att = AssessmentAttachment( + column="img", type="image", format="base64" + ) + pdf_url_att = AssessmentAttachment(column="pdf", type="pdf", format="url") + pdf_b64_att = AssessmentAttachment(column="pdf", type="pdf", format="base64") + + values = resolve_attachment_values( + "https://example.com/a.png,https://example.com/b.png", image_url_att + ) + assert len(values) == 2 + assert values[0]["type"] == "input_image" + + values = resolve_attachment_values("data:image/png;base64,AAAA", image_b64_att) + assert values[0]["image_url"].startswith("data:image/png;base64,") + + values = resolve_attachment_values("https://example.com/a.pdf", pdf_url_att) + assert values[0]["type"] == "input_file" + assert "file_url" in values[0] + + values = resolve_attachment_values( + "data:application/pdf;base64,AAAA", pdf_b64_att + ) + assert 
values[0]["file_data"].startswith("data:application/pdf;base64,") + + def test_build_openai_and_google_jsonl(self) -> None: + rows = [{"q": "What is 2+2?", "img": "https://example.com/a.png"}] + attachments = [AssessmentAttachment(column="img", type="image", format="url")] + + openai_jsonl = build_openai_jsonl( + rows=rows, + text_columns=["q"], + attachments=attachments, + prompt_template=None, + openai_params={"model": "gpt-4.1-mini"}, + ) + assert len(openai_jsonl) == 1 + assert openai_jsonl[0]["custom_id"] == "row_0" + + google_jsonl = build_google_jsonl( + rows=rows, + text_columns=["q"], + attachments=attachments, + prompt_template=None, + google_params={"temperature": 0.2, "instructions": "system"}, + ) + assert len(google_jsonl) == 1 + assert google_jsonl[0]["metadata"]["key"] == "row_0" + assert google_jsonl[0]["request"]["systemInstruction"] == { + "parts": [{"text": "system"}] + } diff --git a/backend/app/tests/assessment/test_cron.py b/backend/app/tests/assessment/test_cron.py new file mode 100644 index 000000000..c9407bd5c --- /dev/null +++ b/backend/app/tests/assessment/test_cron.py @@ -0,0 +1,190 @@ +"""Tests for assessment/cron.py helper functions.""" + +from datetime import datetime +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from app.crud.assessment.cron import ( + _log_config_progress, + poll_all_pending_assessment_evaluations, +) + + +def _make_assessment(*, id: int = 1, status: str = "processing") -> MagicMock: + a = MagicMock() + a.id = id + a.status = status + a.experiment_name = "exp" + return a + + +def _make_run( + *, + id: int = 10, + assessment_id: int = 1, + config_id=None, + config_version: int | None = 1, + updated_at=None, +) -> MagicMock: + r = MagicMock() + r.id = id + r.assessment_id = assessment_id + r.config_id = config_id + r.config_version = config_version + r.updated_at = updated_at or datetime(2024, 6, 1, 12, 0, 0) + return r + + +class TestLogConfigProgress: + def test_no_log_for_no_change_action(self) -> None: + run = _make_run() + assessment = _make_assessment() + assert _log_config_progress({"action": "no_change"}, run, assessment) is None + + def test_no_log_for_still_processing(self) -> None: + run = _make_run() + assessment = _make_assessment() + assert ( + _log_config_progress({"action": "still_processing"}, run, assessment) + is None + ) + + def test_processed_action_does_not_raise(self) -> None: + run = _make_run() + assessment = _make_assessment() + _log_config_progress( + { + "action": "processed", + "current_status": "completed", + "provider_status": "completed", + }, + run, + assessment, + ) + + def test_failed_action_does_not_raise(self) -> None: + run = _make_run() + assessment = _make_assessment() + _log_config_progress( + { + "action": "failed", + "current_status": "failed", + "provider_status": "failed", + }, + run, + assessment, + ) + + +class TestPollAllPendingAssessmentEvaluations: + @pytest.mark.asyncio + async def test_no_pending_assessments(self) -> None: + session = MagicMock() + session.exec.return_value.all.return_value = [] + result = await poll_all_pending_assessment_evaluations(session=session) + assert result["total"] == 0 + assert result["processed"] == 0 + + @pytest.mark.asyncio + async def test_no_active_runs_recompute(self) -> None: + session = MagicMock() + assessment = _make_assessment(id=1, status="processing") + session.exec.return_value.all.return_value = [assessment] + refreshed = _make_assessment(id=1, status="processing") + + run = _make_run(id=11, config_version=1) + 
run.status = "completed" + + with patch( + "app.crud.assessment.cron.get_assessment_runs_for_assessment", + return_value=[run], + ), patch( + "app.crud.assessment.cron.recompute_assessment_status", + return_value=refreshed, + ), patch( + "app.crud.assessment.cron.check_and_process_assessment", new=AsyncMock() + ): + result = await poll_all_pending_assessment_evaluations(session=session) + + assert result["total"] == 1 + assert result["still_processing"] == 1 + + @pytest.mark.asyncio + async def test_active_run_processed(self) -> None: + session = MagicMock() + assessment = _make_assessment(id=1, status="processing") + run = _make_run(id=11) + run.status = "processing" + session.exec.return_value.all.return_value = [assessment] + + with patch( + "app.crud.assessment.cron.get_assessment_runs_for_assessment", + return_value=[run], + ), patch( + "app.crud.assessment.cron.check_and_process_assessment", + new=AsyncMock( + return_value={ + "action": "processed", + "current_status": "completed", + "provider_status": "completed", + } + ), + ): + result = await poll_all_pending_assessment_evaluations(session=session) + + assert result["processed"] == 1 + + @pytest.mark.asyncio + async def test_active_run_failure_and_cleanup_failure(self) -> None: + session = MagicMock() + assessment = _make_assessment(id=1, status="processing") + run = _make_run(id=11) + run.status = "processing" + session.exec.return_value.all.return_value = [assessment] + + with patch( + "app.crud.assessment.cron.get_assessment_runs_for_assessment", + return_value=[run], + ), patch( + "app.crud.assessment.cron.check_and_process_assessment", + new=AsyncMock(side_effect=RuntimeError("boom")), + ), patch( + "app.crud.assessment.cron.update_assessment_run_status", + side_effect=RuntimeError("cleanup-failed"), + ), patch( + "app.crud.assessment.cron.recompute_assessment_status", + ): + result = await poll_all_pending_assessment_evaluations(session=session) + + assert result["failed"] == 1 + + @pytest.mark.asyncio + async def test_active_run_failure_updates_db_with_same_error_message(self) -> None: + session = MagicMock() + assessment = _make_assessment(id=1, status="processing") + run = _make_run(id=11) + run.status = "processing" + session.exec.return_value.all.return_value = [assessment] + + with patch( + "app.crud.assessment.cron.get_assessment_runs_for_assessment", + return_value=[run], + ), patch( + "app.crud.assessment.cron.check_and_process_assessment", + new=AsyncMock(side_effect=RuntimeError("gemini quota exceeded")), + ), patch( + "app.crud.assessment.cron.update_assessment_run_status", + ) as update_run, patch( + "app.crud.assessment.cron.recompute_assessment_status", + ): + result = await poll_all_pending_assessment_evaluations(session=session) + + assert result["failed"] == 1 + assert result["details"][0]["error"] == "gemini quota exceeded" + update_run.assert_called_once_with( + session=session, + run=run, + status="failed", + error_message="gemini quota exceeded", + ) diff --git a/backend/app/tests/assessment/test_crud.py b/backend/app/tests/assessment/test_crud.py new file mode 100644 index 000000000..f77bcd8be --- /dev/null +++ b/backend/app/tests/assessment/test_crud.py @@ -0,0 +1,297 @@ +"""Tests for assessment/crud.py.""" + +from datetime import datetime +from types import SimpleNamespace +from unittest.mock import MagicMock +from uuid import UUID + +import pytest +from fastapi import HTTPException + +from app.crud.assessment import ( + AssessmentRunCounts, + build_run_stats, + compute_run_counts, + 
create_assessment, + create_assessment_dataset, + create_assessment_run, + derive_aggregate_error, + derive_assessment_status, + get_assessment_by_id, + get_assessment_dataset_by_id, + get_assessment_run_by_id, + get_assessment_runs_for_assessment, + list_assessment_runs, + list_assessments, + recompute_assessment_status, + update_assessment_run_status, +) +from app.models.stt_evaluation import EvaluationType + + +def _counts(total=0, pending=0, processing=0, completed=0, failed=0): + return AssessmentRunCounts( + total=total, + pending=pending, + processing=processing, + completed=completed, + failed=failed, + ) + + +class TestDeriveAssessmentStatus: + def test_status_variants(self) -> None: + assert derive_assessment_status(_counts()) == "pending" + assert derive_assessment_status(_counts(total=2, completed=2)) == "completed" + assert derive_assessment_status(_counts(total=2, failed=2)) == "failed" + assert ( + derive_assessment_status(_counts(total=2, completed=1, failed=1)) + == "completed_with_errors" + ) + assert derive_assessment_status(_counts(total=2, pending=2)) == "pending" + assert ( + derive_assessment_status(_counts(total=2, pending=1, processing=1)) + == "processing" + ) + + +class TestCrudBasicQueries: + def test_get_and_list_helpers(self) -> None: + session = MagicMock() + session.exec.return_value.first.return_value = "assessment" + session.exec.return_value.all.return_value = ["a1", "a2"] + + assert get_assessment_by_id(session, 1, 1, 1) == "assessment" + assert list_assessments(session, 1, 1, 10, 0) == ["a1", "a2"] + assert get_assessment_run_by_id(session, 1, 1, 1) == "assessment" + assert list_assessment_runs(session, 1, 1, None, 10, 0) == ["a1", "a2"] + + def test_get_assessment_by_id_not_found(self) -> None: + session = MagicMock() + session.exec.return_value.first.return_value = None + with pytest.raises(HTTPException) as exc_info: + get_assessment_by_id(session, 99, 1, 1) + assert exc_info.value.status_code == 404 + assert "99" in exc_info.value.detail + + def test_get_assessment_run_by_id_not_found(self) -> None: + session = MagicMock() + session.exec.return_value.first.return_value = None + with pytest.raises(HTTPException) as exc_info: + get_assessment_run_by_id(session, 99, 1, 1) + assert exc_info.value.status_code == 404 + assert "99" in exc_info.value.detail + + def test_get_assessment_dataset_by_id_not_found(self) -> None: + session = MagicMock() + session.exec.return_value.first.return_value = None + with pytest.raises(HTTPException) as exc_info: + get_assessment_dataset_by_id( + session=session, + dataset_id=99, + organization_id=1, + project_id=1, + ) + assert exc_info.value.status_code == 404 + assert "99" in exc_info.value.detail + + def test_get_assessment_runs_for_assessment(self) -> None: + session = MagicMock() + session.exec.return_value.all.return_value = ["r1", "r2"] + assert get_assessment_runs_for_assessment(session, 10) == ["r1", "r2"] + + +class TestCrudWrites: + def test_create_assessment_dataset_uses_assessment_type(self) -> None: + session = MagicMock() + result = create_assessment_dataset( + session=session, + name="dataset", + description="desc", + dataset_metadata={"total_items_count": 2}, + object_store_url="s3://datasets/file.csv", + organization_id=1, + project_id=1, + ) + + assert result.type == EvaluationType.ASSESSMENT.value + session.add.assert_called_once() + session.commit.assert_called_once() + session.refresh.assert_called_once() + + def test_create_assessment_success(self) -> None: + session = MagicMock() + result = 
create_assessment( + session=session, + experiment_name="exp", + dataset_id=1, + organization_id=1, + project_id=1, + ) + assert result.experiment_name == "exp" + session.add.assert_called_once() + session.commit.assert_called_once() + session.refresh.assert_called_once() + + def test_create_assessment_commit_failure_rolls_back(self) -> None: + session = MagicMock() + session.commit.side_effect = RuntimeError("db error") + with pytest.raises(RuntimeError): + create_assessment(session, "exp", 1, 1, 1) + session.rollback.assert_called_once() + + def test_create_assessment_run_success_and_failure(self) -> None: + session = MagicMock() + run = create_assessment_run( + session=session, + assessment_id=10, + config_id=UUID("00000000-0000-0000-0000-000000000001"), + config_version=1, + assessment_input={"k": "v"}, + ) + assert run.assessment_id == 10 + assert run.input == {"k": "v"} + + session2 = MagicMock() + session2.commit.side_effect = RuntimeError("db error") + with pytest.raises(RuntimeError): + create_assessment_run( + session=session2, + assessment_id=10, + config_id=UUID("00000000-0000-0000-0000-000000000001"), + config_version=1, + assessment_input={}, + ) + session2.rollback.assert_called_once() + + def test_update_assessment_run_status(self) -> None: + session = MagicMock() + run = SimpleNamespace( + status="pending", + updated_at=None, + error_message=None, + batch_job_id=None, + total_items=0, + object_store_url=None, + ) + updated = update_assessment_run_status( + session=session, + run=run, + status="processing", + error_message="e", + batch_job_id=11, + total_items=9, + object_store_url="s3://x", + ) + assert updated.status == "processing" + assert updated.error_message == "e" + assert updated.batch_job_id == 11 + assert updated.total_items == 9 + assert updated.object_store_url == "s3://x" + + def test_update_assessment_run_status_failure_rolls_back(self) -> None: + session = MagicMock() + session.commit.side_effect = RuntimeError("db error") + run = SimpleNamespace( + status="pending", + updated_at=None, + error_message=None, + batch_job_id=None, + total_items=0, + object_store_url=None, + ) + with pytest.raises(RuntimeError): + update_assessment_run_status(session, run, "failed") + session.rollback.assert_called_once() + + +class TestDerivedAggregates: + def test_compute_run_counts(self) -> None: + runs = [ + SimpleNamespace(status="completed"), + SimpleNamespace(status="failed"), + SimpleNamespace(status="processing"), + SimpleNamespace(status="pending"), + ] + counts = compute_run_counts(runs) + assert counts.total == 4 + assert counts.completed == 1 + assert counts.failed == 1 + assert counts.processing == 1 + assert counts.pending == 1 + + def test_build_run_stats(self) -> None: + runs = [ + SimpleNamespace( + id=1, + config_id=UUID("00000000-0000-0000-0000-000000000001"), + config_version=1, + status="completed", + total_items=2, + error_message=None, + updated_at=datetime(2024, 1, 1), + ), + ] + stats = build_run_stats(runs) + assert len(stats) == 1 + assert stats[0].run_id == 1 + assert stats[0].status == "completed" + + def test_derive_aggregate_error(self) -> None: + assert derive_aggregate_error(_counts(total=2, completed=2)) is None + assert ( + derive_aggregate_error(_counts(total=3, completed=1, failed=2)) + == "2 of 3 run(s) failed" + ) + + +class TestRecomputeAssessmentStatus: + def test_recompute_not_found(self) -> None: + session = MagicMock() + session.get.return_value = None + with pytest.raises(ValueError, match="not found"): + 
recompute_assessment_status(session=session, assessment_id=1) + + def test_recompute_success_persists_status_only(self) -> None: + session = MagicMock() + assessment = SimpleNamespace( + id=1, status="pending", updated_at=datetime(2024, 1, 1) + ) + runs = [ + SimpleNamespace( + id=1, + config_id=UUID("00000000-0000-0000-0000-000000000001"), + config_version=1, + status="completed", + total_items=2, + error_message=None, + updated_at=datetime(2024, 1, 1), + ), + SimpleNamespace( + id=2, + config_id=None, + config_version=2, + status="failed", + total_items=2, + error_message="bad", + updated_at=datetime(2024, 1, 2), + ), + ] + session.get.return_value = assessment + session.exec.return_value.all.return_value = runs + + result = recompute_assessment_status(session=session, assessment_id=1) + assert result.status == "completed_with_errors" + session.commit.assert_called_once() + + def test_recompute_commit_failure_rolls_back(self) -> None: + session = MagicMock() + assessment = SimpleNamespace( + id=1, status="pending", updated_at=datetime(2024, 1, 1) + ) + session.get.return_value = assessment + session.exec.return_value.all.return_value = [] + session.commit.side_effect = RuntimeError("db error") + with pytest.raises(RuntimeError): + recompute_assessment_status(session=session, assessment_id=1) + session.rollback.assert_called_once() diff --git a/backend/app/tests/assessment/test_dataset.py b/backend/app/tests/assessment/test_dataset.py new file mode 100644 index 000000000..ceca9c854 --- /dev/null +++ b/backend/app/tests/assessment/test_dataset.py @@ -0,0 +1,152 @@ +"""Tests for assessment/dataset.py upload and row counting behavior.""" + +from unittest.mock import MagicMock, patch + +import pytest +from fastapi import HTTPException +from openpyxl.utils.exceptions import InvalidFileException + +from app.services.assessment.dataset import ( + _count_csv_rows, + _count_excel_rows, + _count_rows, + upload_dataset, +) + + +class TestCountRows: + def test_legacy_xls_rejected(self) -> None: + with pytest.raises(ValueError, match="Legacy Excel format"): + _count_rows(b"legacy-xls-content", ".xls") + + def test_count_excel_rows_invalid_file_re_raises(self) -> None: + with patch( + "openpyxl.load_workbook", + side_effect=InvalidFileException("bad xlsx"), + ): + with pytest.raises(InvalidFileException): + _count_excel_rows(b"bad") + + def test_count_excel_rows_unexpected_error_raises_value_error(self) -> None: + with patch("openpyxl.load_workbook", side_effect=RuntimeError("boom")): + with pytest.raises(ValueError, match="Failed to parse XLSX file"): + _count_excel_rows(b"bad") + + def test_count_csv_rows(self) -> None: + assert _count_csv_rows(b"a,b\n1,2\n\n3,4\n") == 2 + + def test_count_rows_csv_and_xlsx(self) -> None: + with patch("app.services.assessment.dataset._count_excel_rows", return_value=5): + assert _count_rows(b"x", ".xlsx") == 5 + assert _count_rows(b"a,b\n1,2\n", ".csv") == 1 + + +class TestUploadDataset: + def test_invalid_xlsx_returns_422(self) -> None: + session = MagicMock() + with patch( + "app.services.assessment.dataset.sanitize_dataset_name", return_value="ds-1" + ), patch( + "app.services.assessment.dataset._count_rows", + side_effect=InvalidFileException("bad xlsx"), + ): + with pytest.raises(HTTPException) as exc_info: + upload_dataset( + session=session, + file_content=b"invalid-xlsx", + file_ext=".xlsx", + dataset_name="ds-1", + description=None, + organization_id=1, + project_id=1, + ) + assert exc_info.value.status_code == 422 + assert "Invalid XLSX file content" in 
exc_info.value.detail + + def test_count_rows_value_error_returns_422(self) -> None: + session = MagicMock() + with patch( + "app.services.assessment.dataset.sanitize_dataset_name", return_value="ds-1" + ), patch( + "app.services.assessment.dataset._count_rows", + side_effect=ValueError("Legacy Excel format (.xls) is not supported."), + ): + with pytest.raises(HTTPException) as exc_info: + upload_dataset( + session=session, + file_content=b"bad", + file_ext=".xls", + dataset_name="ds-1", + description=None, + organization_id=1, + project_id=1, + ) + assert exc_info.value.status_code == 422 + assert "Legacy Excel format" in exc_info.value.detail + + def test_count_rows_unexpected_error_returns_generic_422(self) -> None: + session = MagicMock() + with patch( + "app.services.assessment.dataset.sanitize_dataset_name", return_value="ds-1" + ), patch( + "app.services.assessment.dataset._count_rows", + side_effect=RuntimeError("unexpected"), + ): + with pytest.raises(HTTPException) as exc_info: + upload_dataset( + session=session, + file_content=b"bad", + file_ext=".xlsx", + dataset_name="ds-1", + description=None, + organization_id=1, + project_id=1, + ) + assert exc_info.value.status_code == 422 + assert "Unable to parse dataset file" in exc_info.value.detail + + def test_upload_dataset_success(self) -> None: + session = MagicMock() + created = MagicMock() + created.id = 9 + with patch( + "app.services.assessment.dataset.sanitize_dataset_name", return_value="ds-1" + ), patch("app.services.assessment.dataset._count_rows", return_value=2), patch( + "app.services.assessment.dataset._upload_file_to_object_store", + return_value="s3://datasets/file.csv", + ), patch( + "app.services.assessment.dataset.create_assessment_dataset", + return_value=created, + ) as create_ds: + result = upload_dataset( + session=session, + file_content=b"a,b\n1,2\n", + file_ext=".csv", + dataset_name="ds-1", + description="desc", + organization_id=1, + project_id=1, + ) + assert result.id == 9 + create_ds.assert_called_once() + assert create_ds.call_args.kwargs["dataset_metadata"]["total_items_count"] == 2 + + def test_upload_dataset_object_store_failure_returns_500(self) -> None: + session = MagicMock() + with patch( + "app.services.assessment.dataset.sanitize_dataset_name", return_value="ds-1" + ), patch("app.services.assessment.dataset._count_rows", return_value=1), patch( + "app.services.assessment.dataset._upload_file_to_object_store", + return_value=None, + ): + with pytest.raises(HTTPException) as exc_info: + upload_dataset( + session=session, + file_content=b"a,b\n1,2\n", + file_ext=".csv", + dataset_name="ds-1", + description=None, + organization_id=1, + project_id=1, + ) + assert exc_info.value.status_code == 500 diff --git a/backend/app/tests/assessment/test_export.py b/backend/app/tests/assessment/test_export.py new file mode 100644 index 000000000..3ace89dbd --- /dev/null +++ b/backend/app/tests/assessment/test_export.py @@ -0,0 +1,650 @@ +"""Tests for assessment/utils/export.py helper functions.""" + +import json +from datetime import datetime +from unittest.mock import MagicMock, patch + +from app.models.assessment import AssessmentExportRow +from app.services.assessment.utils.export import ( + _drop_empty_columns, + _expand_input_columns, + _expand_output_columns, + _load_dataset_rows_for_run, + _load_parsed_results_for_run, + _safe_filename_part, + build_json_export_rows, + load_export_rows_for_run, + serialize_export_rows, + sort_export_rows, +) + + +def _make_row( + *, + run_id: int = 1, + row_id: str = 
"row_0", + output: str | None = None, + input_data: dict | None = None, + result_status: str = "passed", + config_version: int | None = None, +) -> AssessmentExportRow: + return AssessmentExportRow( + assessment_id=1, + experiment_name="exp", + dataset_id=1, + dataset_name="ds", + run_id=run_id, + run_name="run", + run_status="completed", + config_id=None, + config_version=config_version, + row_id=row_id, + result_status=result_status, + input_data=input_data, + output=output, + error=None, + response_id=None, + input_tokens=None, + output_tokens=None, + total_tokens=None, + updated_at=datetime(2024, 1, 1), + ) + + +class TestSafeFilenamePart: + def test_alphanumeric_unchanged(self) -> None: + assert _safe_filename_part("my_export") == "my_export" + + def test_spaces_replaced(self) -> None: + result = _safe_filename_part("my export file") + assert " " not in result + + def test_special_chars_replaced(self) -> None: + result = _safe_filename_part("hello/world:test") + assert "/" not in result + assert ":" not in result + + def test_empty_string_returns_default(self) -> None: + assert _safe_filename_part("") == "assessment_results" + + def test_only_special_chars_returns_default(self) -> None: + assert _safe_filename_part("!!!") == "assessment_results" + + def test_preserves_dots_and_hyphens(self) -> None: + result = _safe_filename_part("my-file.v2") + assert "." in result + assert "-" in result + + +class TestExpandInputColumns: + def test_no_input_data_removes_key(self) -> None: + rows = [{"output": "x", "input_data": None}] + expanded, keys = _expand_input_columns(rows) + assert keys == [] + assert "input_data" not in expanded[0] + + def test_input_data_dict_expanded(self) -> None: + rows = [{"input_data": {"question": "q1", "context": "c1"}, "output": "x"}] + expanded, keys = _expand_input_columns(rows) + assert "question" in keys + assert "context" in keys + assert expanded[0]["question"] == "q1" + assert expanded[0]["context"] == "c1" + assert "input_data" not in expanded[0] + + def test_multiple_rows_union_of_keys(self) -> None: + rows = [ + {"input_data": {"a": "1"}, "output": "x"}, + {"input_data": {"b": "2"}, "output": "y"}, + ] + _, keys = _expand_input_columns(rows) + assert "a" in keys + assert "b" in keys + + def test_missing_key_in_row_gets_none(self) -> None: + rows = [ + {"input_data": {"a": "1", "b": "2"}, "output": "x"}, + {"input_data": {"a": "3"}, "output": "y"}, + ] + expanded, _ = _expand_input_columns(rows) + assert expanded[1].get("b") is None + + def test_reserved_field_collision_namespaced(self) -> None: + rows = [ + { + "input_data": {"output": "expected answer", "question": "q1"}, + "output": "model answer", + } + ] + expanded, keys = _expand_input_columns(rows) + assert "input_output" in keys + assert "question" in keys + assert expanded[0]["input_output"] == "expected answer" + assert expanded[0]["output"] == "model answer" + + +class TestDropEmptyColumns: + def test_keeps_non_empty_columns(self) -> None: + rows = [{"a": "val", "b": None}, {"a": "val2", "b": ""}] + result_rows, result_fields = _drop_empty_columns(rows, ["a", "b"]) + assert "a" in result_fields + assert "b" not in result_fields + + def test_no_change_when_all_have_values(self) -> None: + rows = [{"a": "1", "b": "2"}] + result_rows, result_fields = _drop_empty_columns(rows, ["a", "b"]) + assert result_fields == ["a", "b"] + + def test_all_empty_drops_all(self) -> None: + rows = [{"a": None, "b": None}] + _, result_fields = _drop_empty_columns(rows, ["a", "b"]) + assert result_fields == [] + + 
+class TestExpandOutputColumns: + def test_plain_string_output_not_expanded(self) -> None: + rows = [{"output": "plain text", "input_data": None}] + expanded, fieldnames = _expand_output_columns(rows) + assert "output" in fieldnames + + def test_json_dict_output_expanded(self) -> None: + rows = [ + {"output": json.dumps({"score": 5, "reason": "good"}), "input_data": None} + ] + expanded, fieldnames = _expand_output_columns(rows) + assert "score" in fieldnames + assert "reason" in fieldnames + assert expanded[0]["score"] == 5 + + def test_mixed_parsed_and_unparsed_adds_output_raw(self) -> None: + rows = [ + {"output": json.dumps({"score": 3}), "input_data": None}, + {"output": "not json", "input_data": None}, + ] + expanded, fieldnames = _expand_output_columns(rows) + assert "output_raw" in fieldnames + # Second row that didn't parse should get output_raw + assert expanded[1].get("output_raw") == "not json" + + def test_none_output_handled(self) -> None: + rows = [{"output": None, "input_data": None}] + expanded, fieldnames = _expand_output_columns(rows) + assert expanded[0].get("output") is None + + +class TestSerializeExportRows: + def _make_rows(self) -> list[AssessmentExportRow]: + return [ + _make_row(row_id="row_0", output=json.dumps({"score": 4})), + _make_row(row_id="row_1", output=json.dumps({"score": 2})), + ] + + def test_json_format_returns_json_bytes(self) -> None: + rows = self._make_rows() + payload, media_type = serialize_export_rows(rows, "json") + assert media_type == "application/json" + parsed = json.loads(payload) + assert isinstance(parsed, list) + assert len(parsed) == 2 + + def test_csv_format_returns_csv_bytes(self) -> None: + rows = self._make_rows() + payload, media_type = serialize_export_rows(rows, "csv") + assert media_type == "text/csv" + content = payload.decode("utf-8") + assert "score" in content + + def test_csv_contains_all_rows(self) -> None: + rows = self._make_rows() + payload, _ = serialize_export_rows(rows, "csv") + lines = [line for line in payload.decode("utf-8").splitlines() if line.strip()] + assert len(lines) == 3 # header + 2 data rows + + def test_json_with_no_output(self) -> None: + rows = [_make_row(output=None)] + payload, media_type = serialize_export_rows(rows, "json") + assert media_type == "application/json" + parsed = json.loads(payload) + assert len(parsed) == 1 + + def test_csv_with_input_data(self) -> None: + rows = [_make_row(input_data={"question": "What?", "context": "Some context"})] + payload, _ = serialize_export_rows(rows, "csv") + content = payload.decode("utf-8") + assert "question" in content + assert "context" in content + + +class TestSortExportRows: + def test_sorts_by_config_version_then_numeric_row_index(self) -> None: + rows = [ + _make_row(run_id=1, row_id="row_1", config_version=2), + _make_row(run_id=2, row_id="row_0", config_version=1), + _make_row(run_id=3, row_id="row_10", config_version=2), + _make_row(run_id=4, row_id="row_2", config_version=2), + ] + sorted_rows = sort_export_rows(rows) + assert sorted_rows[0].config_version == 1 + assert sorted_rows[1].config_version == 2 + assert [r.row_id for r in sorted_rows[1:]] == ["row_1", "row_2", "row_10"] + + def test_none_config_version_treated_as_zero(self) -> None: + rows = [ + _make_row(run_id=1, row_id="row_0", config_version=1), + _make_row(run_id=2, row_id="row_0", config_version=None), + ] + sorted_rows = sort_export_rows(rows) + assert sorted_rows[0].config_version is None + + def test_invalid_row_id_suffix_falls_back_to_zero(self) -> None: + rows = [ + 
_make_row(run_id=3, row_id="row_2", config_version=1), + _make_row(run_id=2, row_id="row_xyz", config_version=1), + _make_row(run_id=1, row_id="bad", config_version=1), + ] + sorted_rows = sort_export_rows(rows) + assert [r.run_id for r in sorted_rows] == [1, 2, 3] + + def test_empty_list(self) -> None: + assert sort_export_rows([]) == [] + + +class TestExpandOutputColumnsDictOutput: + def test_dict_output_expanded_directly(self) -> None: + # raw output is already a dict (not a JSON string) + rows = [{"output": {"score": 9, "label": "good"}, "input_data": None}] + expanded, fieldnames = _expand_output_columns(rows) + assert "score" in fieldnames + assert expanded[0]["score"] == 9 + + def test_non_dict_non_string_output_treated_as_unparsed(self) -> None: + rows = [{"output": 42, "input_data": None}] + expanded, fieldnames = _expand_output_columns(rows) + # 42 is not a dict/string, treated as unparsed → output stays as-is + assert "output" in fieldnames + + +class TestSerializeExportRowsXlsx: + def test_xlsx_format_returns_xlsx_bytes(self) -> None: + rows = [_make_row(output=json.dumps({"score": 3}))] + payload, media_type = serialize_export_rows(rows, "xlsx") + assert ( + media_type + == "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" + ) + assert len(payload) > 0 + + def test_xlsx_no_excel_fields_falls_back_to_output(self) -> None: + # Row with no output — excel_fields may be empty after filtering metadata + rows = [_make_row(output=None)] + _, media_type = serialize_export_rows(rows, "xlsx") + assert ( + media_type + == "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" + ) + + +class TestBuildJsonExportRows: + def test_returns_expanded_list(self) -> None: + rows = [_make_row(output=json.dumps({"score": 7}))] + result = build_json_export_rows(rows) + assert isinstance(result, list) + assert result[0]["score"] == 7 + + def test_empty_input_returns_empty_list(self) -> None: + assert build_json_export_rows([]) == [] + + +class TestBuildExportResponse: + def test_returns_streaming_response_with_disposition(self) -> None: + from app.services.assessment.utils.export import build_export_response + + rows = [_make_row(output=json.dumps({"score": 3}))] + with patch( + "app.services.assessment.utils.export.generate_timestamped_filename", + return_value="export_2024.csv", + ): + response = build_export_response(rows, "csv", "my experiment") + + assert response.media_type == "text/csv" + assert "export_2024.csv" in response.headers["content-disposition"] + + def test_json_format_returns_json_response(self) -> None: + from app.services.assessment.utils.export import build_export_response + + rows = [_make_row(output='{"score": 5}')] + with patch( + "app.services.assessment.utils.export.generate_timestamped_filename", + return_value="export_2024.json", + ): + response = build_export_response(rows, "json", "exp") + + assert response.media_type == "application/json" + + +class TestLoadParsedResultsForRun: + def _make_run(self, *, object_store_url: str | None = None) -> MagicMock: + run = MagicMock() + run.id = 1 + run.project_id = 1 + run.organization_id = 1 + run.object_store_url = object_store_url + return run + + def _make_batch_job( + self, *, provider: str = "openai", provider_output_file_id: str | None = None + ) -> MagicMock: + job = MagicMock() + job.provider = provider + job.provider_output_file_id = provider_output_file_id + return job + + def test_no_url_no_file_id_returns_none(self) -> None: + session = MagicMock() + run = self._make_run() + batch_job = 
self._make_batch_job() + result = _load_parsed_results_for_run( + session=session, run=run, batch_job=batch_job + ) + assert result is None + + def test_s3_success_returns_parsed(self) -> None: + session = MagicMock() + run = self._make_run(object_store_url="s3://bucket/file.jsonl") + batch_job = self._make_batch_job() + + raw_line = json.dumps( + { + "custom_id": "row_0", + "response": { + "status_code": 200, + "body": {"output_text": "hello", "usage": {}}, + }, + "error": None, + } + ) + mock_body = MagicMock() + mock_body.read.return_value = raw_line.encode() + mock_storage = MagicMock() + mock_storage.stream.return_value = mock_body + + with patch( + "app.services.assessment.utils.export.get_cloud_storage", + return_value=mock_storage, + ): + result = _load_parsed_results_for_run( + session=session, run=run, batch_job=batch_job + ) + + assert result is not None + assert result[0]["row_id"] == "row_0" + + def test_s3_failure_falls_back_to_none_when_no_file_id(self) -> None: + session = MagicMock() + run = self._make_run(object_store_url="s3://bucket/file.jsonl") + batch_job = self._make_batch_job(provider_output_file_id=None) + + with patch( + "app.services.assessment.utils.export.get_cloud_storage", + side_effect=Exception("S3 down"), + ): + result = _load_parsed_results_for_run( + session=session, run=run, batch_job=batch_job + ) + + assert result is None + + def test_s3_failure_falls_back_to_provider_download(self) -> None: + session = MagicMock() + run = self._make_run(object_store_url="s3://bucket/file.jsonl") + batch_job = self._make_batch_job( + provider="openai", provider_output_file_id="file_abc" + ) + + raw = [ + { + "custom_id": "row_0", + "response": { + "status_code": 200, + "body": {"output_text": "hi", "usage": {}}, + }, + "error": None, + } + ] + with patch( + "app.services.assessment.utils.export.get_cloud_storage", + side_effect=Exception("S3 down"), + ), patch( + "app.crud.assessment.processing._get_batch_provider", + return_value=MagicMock(), + ), patch( + "app.core.batch.download_batch_results", return_value=raw + ): + result = _load_parsed_results_for_run( + session=session, run=run, batch_job=batch_job + ) + + assert result is not None + assert result[0]["row_id"] == "row_0" + + def test_s3_empty_falls_back_logs_warning(self) -> None: + session = MagicMock() + run = self._make_run(object_store_url="s3://bucket/file.jsonl") + batch_job = self._make_batch_job(provider_output_file_id=None) + + mock_body = MagicMock() + mock_body.read.return_value = b"" + mock_storage = MagicMock() + mock_storage.stream.return_value = mock_body + + with patch( + "app.services.assessment.utils.export.get_cloud_storage", + return_value=mock_storage, + ): + result = _load_parsed_results_for_run( + session=session, run=run, batch_job=batch_job + ) + + assert result is None + + +class TestLoadDatasetRowsForRun: + def _make_run(self) -> MagicMock: + run = MagicMock() + run.id = 1 + return run + + def _make_assessment(self, dataset_id: int = 1) -> MagicMock: + assessment = MagicMock() + assessment.id = 10 + assessment.dataset_id = dataset_id + return assessment + + def test_dataset_not_found_returns_empty(self) -> None: + session = MagicMock() + session.get.return_value = None + result = _load_dataset_rows_for_run( + session=session, run=self._make_run(), assessment=self._make_assessment() + ) + assert result == [] + + def test_dataset_no_url_returns_empty(self) -> None: + session = MagicMock() + dataset = MagicMock() + dataset.object_store_url = None + session.get.return_value = dataset + 
result = _load_dataset_rows_for_run( + session=session, run=self._make_run(), assessment=self._make_assessment() + ) + assert result == [] + + def test_exception_returns_empty(self) -> None: + session = MagicMock() + session.get.side_effect = Exception("DB error") + result = _load_dataset_rows_for_run( + session=session, run=self._make_run(), assessment=self._make_assessment() + ) + assert result == [] + + def test_valid_dataset_returns_rows(self) -> None: + session = MagicMock() + dataset = MagicMock() + dataset.object_store_url = "s3://bucket/ds.csv" + session.get.return_value = dataset + with patch( + "app.services.assessment.utils.export._load_dataset_rows", + return_value=[{"q": "hi"}], + ): + result = _load_dataset_rows_for_run( + session=session, + run=self._make_run(), + assessment=self._make_assessment(), + ) + assert result == [{"q": "hi"}] + + +class TestLoadExportRowsForRun: + def _make_run(self) -> MagicMock: + run = MagicMock() + run.id = 1 + run.assessment_id = 10 + run.batch_job_id = 5 + run.status = "completed" + run.config_id = None + run.config_version = 1 + run.object_store_url = None + run.updated_at = datetime(2024, 1, 1) + return run + + def _make_assessment(self) -> MagicMock: + assessment = MagicMock() + assessment.id = 10 + assessment.experiment_name = "exp_v1" + assessment.dataset_id = 2 + return assessment + + def test_no_batch_job_id_returns_empty(self) -> None: + session = MagicMock() + run = self._make_run() + run.batch_job_id = None + result = load_export_rows_for_run(session=session, run=run) + assert result == [] + + def test_batch_job_not_found_returns_empty(self) -> None: + session = MagicMock() + run = self._make_run() + with patch( + "app.services.assessment.utils.export.get_batch_job", return_value=None + ): + result = load_export_rows_for_run( + session=session, run=run, assessment=self._make_assessment() + ) + assert result == [] + + def test_no_parsed_results_returns_empty(self) -> None: + session = MagicMock() + run = self._make_run() + with patch( + "app.services.assessment.utils.export.get_batch_job", + return_value=MagicMock(), + ), patch( + "app.services.assessment.utils.export._load_parsed_results_for_run", + return_value=None, + ): + result = load_export_rows_for_run( + session=session, run=run, assessment=self._make_assessment() + ) + assert result == [] + + def test_parsed_results_build_export_rows(self) -> None: + session = MagicMock() + dataset = MagicMock() + dataset.name = "ds" + session.get.return_value = dataset + run = self._make_run() + parsed = [ + { + "row_id": "row_0", + "output": '{"score": 5}', + "error": None, + "usage": None, + "response_id": "r1", + } + ] + with patch( + "app.services.assessment.utils.export.get_batch_job", + return_value=MagicMock(), + ), patch( + "app.services.assessment.utils.export._load_parsed_results_for_run", + return_value=parsed, + ), patch( + "app.services.assessment.utils.export._load_dataset_rows_for_run", + return_value=[], + ): + result = load_export_rows_for_run( + session=session, run=run, assessment=self._make_assessment() + ) + assert len(result) == 1 + assert result[0].result_status == "passed" + assert result[0].row_id == "row_0" + + def test_error_result_sets_failed_status(self) -> None: + session = MagicMock() + dataset = MagicMock() + dataset.name = "ds" + session.get.return_value = dataset + run = self._make_run() + parsed = [ + { + "row_id": "row_0", + "output": None, + "error": "timeout", + "usage": None, + "response_id": None, + } + ] + with patch( + 
"app.services.assessment.utils.export.get_batch_job", + return_value=MagicMock(), + ), patch( + "app.services.assessment.utils.export._load_parsed_results_for_run", + return_value=parsed, + ), patch( + "app.services.assessment.utils.export._load_dataset_rows_for_run", + return_value=[], + ): + result = load_export_rows_for_run( + session=session, run=run, assessment=self._make_assessment() + ) + assert result[0].result_status == "failed" + + def test_input_data_correlated_via_row_id(self) -> None: + session = MagicMock() + dataset = MagicMock() + dataset.name = "ds" + session.get.return_value = dataset + run = self._make_run() + parsed = [ + { + "row_id": "row_1", + "output": "x", + "error": None, + "usage": None, + "response_id": None, + } + ] + dataset_rows = [{"q": "first"}, {"q": "second"}] + with patch( + "app.services.assessment.utils.export.get_batch_job", + return_value=MagicMock(), + ), patch( + "app.services.assessment.utils.export._load_parsed_results_for_run", + return_value=parsed, + ), patch( + "app.services.assessment.utils.export._load_dataset_rows_for_run", + return_value=dataset_rows, + ): + result = load_export_rows_for_run( + session=session, run=run, assessment=self._make_assessment() + ) + assert result[0].input_data == {"q": "second"} diff --git a/backend/app/tests/assessment/test_mappers.py b/backend/app/tests/assessment/test_mappers.py new file mode 100644 index 000000000..dd4579868 --- /dev/null +++ b/backend/app/tests/assessment/test_mappers.py @@ -0,0 +1,272 @@ +"""Tests for assessment/mappers.py.""" + +from unittest.mock import MagicMock, patch + +from app.services.assessment.mappers import ( + _ensure_openai_strict_schema, + _strip_additional_properties, + map_kaapi_to_google_params, + map_kaapi_to_openai_params, + normalize_llm_text, +) + + +class TestNormalizeLlmText: + def test_non_string_returns_as_is(self) -> None: + assert normalize_llm_text(None) is None # type: ignore[arg-type] + assert normalize_llm_text(42) == 42 # type: ignore[arg-type] + + def test_empty_string_returns_as_is(self) -> None: + assert normalize_llm_text("") == "" + + def test_escaped_newline_replaced(self) -> None: + assert normalize_llm_text("line1\\nline2") == "line1\nline2" + + def test_escaped_tab_replaced(self) -> None: + assert normalize_llm_text("col1\\tcol2") == "col1\tcol2" + + def test_escaped_quote_replaced(self) -> None: + assert normalize_llm_text('\\"quoted\\"') == '"quoted"' + + def test_double_backslash_collapsed(self) -> None: + assert normalize_llm_text("a\\\\b") == "a\\b" + + def test_nfc_normalization_applied(self) -> None: + # Combining character sequence → precomposed form + import unicodedata + + text = "é" # e + combining acute accent + result = normalize_llm_text(text) + assert result == unicodedata.normalize("NFC", text) + + +class TestEnsureOpenAIStrictSchema: + def test_object_type_gets_additional_properties_false(self) -> None: + schema = {"type": "object", "properties": {"name": {"type": "string"}}} + result = _ensure_openai_strict_schema(schema) + assert result["additionalProperties"] is False + + def test_nested_object_also_gets_flag(self) -> None: + schema = { + "type": "object", + "properties": { + "address": { + "type": "object", + "properties": {"city": {"type": "string"}}, + } + }, + } + result = _ensure_openai_strict_schema(schema) + assert result["properties"]["address"]["additionalProperties"] is False + + def test_array_items_processed_recursively(self) -> None: + schema = { + "type": "array", + "items": {"type": "object", "properties": {"x": 
{"type": "number"}}}, + } + result = _ensure_openai_strict_schema(schema) + assert result["items"]["additionalProperties"] is False + + def test_non_object_type_not_modified(self) -> None: + schema = {"type": "string"} + result = _ensure_openai_strict_schema(schema) + assert "additionalProperties" not in result + + +class TestStripAdditionalProperties: + def test_removes_additional_properties(self) -> None: + schema = {"type": "object", "additionalProperties": False, "properties": {}} + result = _strip_additional_properties(schema) + assert "additionalProperties" not in result + + def test_nested_removal(self) -> None: + schema = { + "type": "object", + "additionalProperties": False, + "properties": {"child": {"type": "object", "additionalProperties": False}}, + } + result = _strip_additional_properties(schema) + assert "additionalProperties" not in result["properties"]["child"] + + def test_array_items_processed(self) -> None: + schema = { + "type": "array", + "items": {"type": "object", "additionalProperties": False}, + } + result = _strip_additional_properties(schema) + assert "additionalProperties" not in result["items"] + + +class TestMapKaapiToOpenAIParams: + def _call(self, params: dict, supports_reasoning: bool = False): + with patch( + "app.services.assessment.mappers.is_reasoning_model", + return_value=supports_reasoning, + ): + return map_kaapi_to_openai_params(session=MagicMock(), kaapi_params=params) + + def test_basic_model_passed_through(self) -> None: + result, warnings = self._call({"model": "gpt-4o"}) + assert result["model"] == "gpt-4o" + assert warnings == [] + + def test_instructions_normalized_and_set(self) -> None: + result, _ = self._call({"model": "gpt-4o", "instructions": "Be helpful\\n"}) + assert result["instructions"] == "Be helpful\n" + + def test_temperature_set_for_non_reasoning_model(self) -> None: + result, _ = self._call({"model": "gpt-4o", "temperature": 0.7}) + assert result["temperature"] == 0.7 + + def test_temperature_suppressed_for_reasoning_model(self) -> None: + result, warnings = self._call( + {"model": "o1", "temperature": 0.5}, supports_reasoning=True + ) + assert "temperature" not in result + assert any("temperature" in w for w in warnings) + + def test_top_p_suppressed_for_reasoning_model(self) -> None: + result, warnings = self._call( + {"model": "o1", "top_p": 0.9}, supports_reasoning=True + ) + assert "top_p" not in result + assert any("top_p" in w for w in warnings) + + def test_effort_set_for_reasoning_model(self) -> None: + result, _ = self._call( + {"model": "o1", "effort": "high"}, supports_reasoning=True + ) + assert result["reasoning"]["effort"] == "high" + + def test_effort_suppressed_for_non_reasoning_model(self) -> None: + result, warnings = self._call({"model": "gpt-4o", "effort": "high"}) + assert "reasoning" not in result + assert any("effort" in w for w in warnings) + + def test_output_schema_sets_text_format(self) -> None: + schema = {"type": "object", "properties": {"score": {"type": "integer"}}} + result, _ = self._call({"model": "gpt-4o", "output_schema": schema}) + assert result["text"]["format"]["type"] == "json_schema" + assert result["text"]["format"]["strict"] is True + + def test_response_format_text_not_set(self) -> None: + result, _ = self._call({"model": "gpt-4o", "response_format": "text"}) + assert "text" not in result + + def test_knowledge_base_ids_sets_tools(self) -> None: + result, _ = self._call( + {"model": "gpt-4o", "knowledge_base_ids": ["vs_123"], "max_num_results": 10} + ) + assert 
result["tools"][0]["type"] == "file_search" + assert result["tools"][0]["max_num_results"] == 10 + + def test_summary_null_string_sets_none(self) -> None: + result, _ = self._call( + {"model": "o1", "summary": "null"}, supports_reasoning=True + ) + assert result["reasoning"]["summary"] is None + + def test_top_p_set_for_non_reasoning_model(self) -> None: + result, _ = self._call({"model": "gpt-4o", "top_p": 0.85}) + assert result["top_p"] == 0.85 + + +class TestMapKaapiToGoogleParams: + def _call(self, params: dict): + mock_schema = MagicMock() + mock_schema.model_dump.return_value = {} + with patch( + "app.services.assessment.mappers.genai_transformers.t_schema", + return_value=mock_schema, + ): + return map_kaapi_to_google_params(params) + + def test_missing_model_returns_warning(self) -> None: + result, warnings = map_kaapi_to_google_params({}) + assert result == {} + assert any("model" in w for w in warnings) + + def test_basic_model_set(self) -> None: + result, _ = self._call({"model": "gemini-1.5-pro"}) + assert result["model"] == "gemini-1.5-pro" + + def test_temperature_set(self) -> None: + result, _ = self._call({"model": "gemini-1.5-pro", "temperature": 0.3}) + assert result["temperature"] == 0.3 + + def test_top_p_set(self) -> None: + result, _ = self._call({"model": "gemini-1.5-pro", "top_p": 0.8}) + assert result["top_p"] == 0.8 + + def test_thinking_level_set(self) -> None: + result, _ = self._call( + {"model": "gemini-2.0-flash-thinking", "thinking_level": "high"} + ) + assert result["thinking_config"] == {"thinking_level": "high"} + + def test_knowledge_base_ids_warns(self) -> None: + result, warnings = self._call( + {"model": "gemini-1.5-pro", "knowledge_base_ids": ["kb_1"]} + ) + assert any("knowledge_base_ids" in w for w in warnings) + + def test_output_schema_set(self) -> None: + schema = {"type": "object", "properties": {"score": {"type": "integer"}}} + result, _ = self._call({"model": "gemini-1.5-pro", "output_schema": schema}) + assert "output_schema" in result + + def test_instructions_normalized(self) -> None: + result, _ = self._call( + {"model": "gemini-1.5-pro", "instructions": "Be kind\\n"} + ) + assert result["instructions"] == "Be kind\n" + + def test_max_output_tokens_set(self) -> None: + result, _ = self._call({"model": "gemini-1.5-pro", "max_output_tokens": 512}) + assert result["max_output_tokens"] == 512 + + def test_reasoning_set(self) -> None: + result, _ = self._call( + {"model": "gemini-2.0-flash-thinking", "reasoning": "high"} + ) + assert result["reasoning"] == "high" + + +class TestConvertJsonSchemaToGoogle: + def _call(self, schema: dict) -> dict: + mock_result = MagicMock() + mock_result.model_dump.return_value = {"properties": {"score": {}}} + with patch( + "app.services.assessment.mappers.genai_transformers.t_schema", + return_value=mock_result, + ): + from app.services.assessment.mappers import _convert_json_schema_to_google + + return _convert_json_schema_to_google(schema) + + def test_property_ordering_added_from_required(self) -> None: + schema = { + "type": "object", + "required": ["score", "reason"], + "properties": {"score": {}, "reason": {}}, + } + result = self._call(schema) + assert result["propertyOrdering"] == ["score", "reason"] + + def test_property_ordering_falls_back_to_keys(self) -> None: + schema = {"type": "object", "properties": {"a": {}, "b": {}}} + result = self._call(schema) + assert "propertyOrdering" in result + + +class TestOpenAIResponseFormat: + def _call(self, params: dict): + with patch( + 
"app.services.assessment.mappers.is_reasoning_model", + return_value=False, + ): + return map_kaapi_to_openai_params(session=MagicMock(), kaapi_params=params) + + def test_non_text_response_format_sets_text_field(self) -> None: + result, _ = self._call({"model": "gpt-4o", "response_format": "json_object"}) + assert result["text"]["format"]["type"] == "json_object" diff --git a/backend/app/tests/assessment/test_parsing.py b/backend/app/tests/assessment/test_parsing.py new file mode 100644 index 000000000..28970dc5b --- /dev/null +++ b/backend/app/tests/assessment/test_parsing.py @@ -0,0 +1,78 @@ +"""Tests for assessment/utils/parsing.py.""" + +import json + +from app.services.assessment.utils.parsing import parse_stored_results, usage_totals + + +class TestParseStoredResults: + def test_empty_string_returns_empty_list(self) -> None: + assert parse_stored_results("") == [] + + def test_whitespace_only_returns_empty_list(self) -> None: + assert parse_stored_results(" \n ") == [] + + def test_json_array_format(self) -> None: + data = [ + {"row_id": "row_0", "output": "hello"}, + {"row_id": "row_1", "output": "world"}, + ] + result = parse_stored_results(json.dumps(data)) + assert result == data + + def test_jsonl_single_object_parsed_as_one_entry(self) -> None: + # Does NOT start with '[', so treated as a single JSONL line + result = parse_stored_results(json.dumps({"key": "value"})) + assert result == [{"key": "value"}] + + def test_jsonl_format(self) -> None: + lines = [{"row_id": "row_0", "output": "a"}, {"row_id": "row_1", "output": "b"}] + raw = "\n".join(json.dumps(line) for line in lines) + result = parse_stored_results(raw) + assert result == lines + + def test_jsonl_skips_blank_lines(self) -> None: + line = {"row_id": "row_0", "output": "x"} + raw = f"\n{json.dumps(line)}\n\n" + result = parse_stored_results(raw) + assert result == [line] + + def test_jsonl_single_line(self) -> None: + line = {"k": "v"} + result = parse_stored_results(json.dumps(line)) + assert result == [line] + + +class TestUsageTotals: + def test_non_dict_returns_nones(self) -> None: + assert usage_totals(None) == (None, None, None) + assert usage_totals("string") == (None, None, None) + assert usage_totals(42) == (None, None, None) + + def test_openai_style_keys(self) -> None: + usage = {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30} + assert usage_totals(usage) == (10, 20, 30) + + def test_anthropic_style_keys(self) -> None: + usage = {"input_tokens": 5, "output_tokens": 15} + assert usage_totals(usage) == (5, 15, 20) + + def test_total_tokens_computed_when_missing(self) -> None: + usage = {"input_tokens": 3, "output_tokens": 7} + inp, out, total = usage_totals(usage) + assert total == 10 + + def test_explicit_total_tokens_not_overridden(self) -> None: + usage = {"input_tokens": 3, "output_tokens": 7, "total_tokens": 99} + _, _, total = usage_totals(usage) + assert total == 99 + + def test_missing_tokens_return_none(self) -> None: + assert usage_totals({}) == (None, None, None) + + def test_partial_tokens_no_total_computed(self) -> None: + usage = {"input_tokens": 5} + inp, out, total = usage_totals(usage) + assert inp == 5 + assert out is None + assert total is None diff --git a/backend/app/tests/assessment/test_processing.py b/backend/app/tests/assessment/test_processing.py new file mode 100644 index 000000000..958ab3019 --- /dev/null +++ b/backend/app/tests/assessment/test_processing.py @@ -0,0 +1,436 @@ +"""Tests for assessment/processing.py pure functions.""" + +import json +from 
unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from app.crud.assessment.processing import ( + _get_batch_provider, + _sanitize_json_output, + check_and_process_assessment, + parse_assessment_output, + poll_all_pending_assessments, +) + + +class TestSanitizeJsonOutput: + def test_valid_json_unchanged(self) -> None: + raw = '{"key": "value"}' + assert _sanitize_json_output(raw) == raw + + def test_bare_newline_inside_string_escaped(self) -> None: + raw = '{"text": "line1\nline2"}' + result = _sanitize_json_output(raw) + parsed = json.loads(result) + assert parsed["text"] == "line1\nline2" + + def test_bare_tab_inside_string_escaped(self) -> None: + raw = '{"text": "col1\tcol2"}' + result = _sanitize_json_output(raw) + parsed = json.loads(result) + assert parsed["text"] == "col1\tcol2" + + def test_bare_carriage_return_escaped(self) -> None: + raw = '{"text": "a\rb"}' + result = _sanitize_json_output(raw) + parsed = json.loads(result) + assert parsed["text"] == "a\rb" + + def test_escaped_chars_outside_string_not_changed(self) -> None: + raw = '{"a": 1, "b": 2}' + assert _sanitize_json_output(raw) == raw + + def test_already_escaped_newline_not_double_escaped(self) -> None: + raw = '{"text": "line1\\nline2"}' + result = _sanitize_json_output(raw) + assert result == raw + + def test_empty_string(self) -> None: + assert _sanitize_json_output("") == "" + + +class TestParseAssessmentOutputOpenAI: + def _make_result(self, custom_id: str, output_text: str) -> dict: + return { + "custom_id": custom_id, + "response": { + "status_code": 200, + "body": { + "id": "resp_abc", + "output_text": output_text, + "usage": {"input_tokens": 10, "output_tokens": 5}, + }, + }, + "error": None, + } + + def test_successful_result_parsed(self) -> None: + raw = [self._make_result("row_0", "some output")] + results = parse_assessment_output(raw, "openai") + assert len(results) == 1 + assert results[0]["row_id"] == "row_0" + assert results[0]["output"] == "some output" + assert results[0]["error"] is None + + def test_error_in_result(self) -> None: + raw = [ + { + "custom_id": "row_1", + "response": {"status_code": 200, "body": {}}, + "error": {"message": "rate limit exceeded"}, + } + ] + results = parse_assessment_output(raw, "openai") + assert results[0]["error"] == "rate limit exceeded" + assert results[0]["output"] is None + + def test_4xx_status_code_is_error(self) -> None: + raw = [ + { + "custom_id": "row_2", + "response": { + "status_code": 400, + "body": {"error": {"message": "invalid request"}}, + }, + "error": None, + } + ] + results = parse_assessment_output(raw, "openai") + assert results[0]["error"] == "invalid request" + + def test_json_output_text_re_serialized(self) -> None: + payload = {"score": 4, "reason": "good"} + raw = [self._make_result("row_0", json.dumps(payload))] + results = parse_assessment_output(raw, "openai") + re_parsed = json.loads(results[0]["output"]) + assert re_parsed["score"] == 4 + + def test_output_text_from_output_list(self) -> None: + raw = [ + { + "custom_id": "row_0", + "response": { + "status_code": 200, + "body": { + "id": "resp_abc", + "output": [ + { + "type": "message", + "content": [ + {"type": "output_text", "text": "hello world"} + ], + } + ], + }, + }, + "error": None, + } + ] + results = parse_assessment_output(raw, "openai") + assert results[0]["output"] == "hello world" + + def test_empty_output_text_sets_error(self) -> None: + raw = [self._make_result("row_0", "")] + results = parse_assessment_output(raw, "openai") + assert 
results[0]["error"] == "Empty response output" + + def test_sanitize_fallback_on_bad_json(self) -> None: + # JSON with literal newline inside a string value + bad_json = '{"text": "line1\nline2"}' + raw = [self._make_result("row_0", bad_json)] + results = parse_assessment_output(raw, "openai") + # Should not raise; output should be a valid JSON string + assert results[0]["output"] is not None + + def test_multiple_results(self) -> None: + raw = [ + self._make_result("row_0", "out0"), + self._make_result("row_1", "out1"), + ] + results = parse_assessment_output(raw, "openai") + assert len(results) == 2 + assert results[1]["row_id"] == "row_1" + + def test_openai_native_provider_accepted(self) -> None: + raw = [self._make_result("row_0", "out")] + results = parse_assessment_output(raw, "openai-native") + assert results[0]["output"] == "out" + + +class TestParseAssessmentOutputGoogle: + def test_successful_google_result(self) -> None: + from unittest.mock import patch + + with patch( + "app.crud.assessment.processing.extract_text_from_response_dict", + return_value="gemini output", + ): + raw = [ + {"key": "row_0", "response": {"text": "gemini output"}, "error": None} + ] + results = parse_assessment_output(raw, "google") + + assert results[0]["row_id"] == "row_0" + assert results[0]["output"] == "gemini output" + assert results[0]["error"] is None + + def test_google_error_result(self) -> None: + raw = [{"key": "row_0", "response": None, "error": "quota exceeded"}] + results = parse_assessment_output(raw, "google") + assert results[0]["error"] == "quota exceeded" + assert results[0]["output"] is None + + def test_google_empty_response(self) -> None: + raw = [{"key": "row_0", "response": None, "error": None}] + results = parse_assessment_output(raw, "google") + assert results[0]["error"] == "Empty response" + + def test_google_empty_text_from_response(self) -> None: + from unittest.mock import patch + + with patch( + "app.crud.assessment.processing.extract_text_from_response_dict", + return_value="", + ): + raw = [{"key": "row_0", "response": {"candidates": []}, "error": None}] + results = parse_assessment_output(raw, "google") + assert results[0]["output"] is None + assert results[0]["error"] == "Empty response output" + + def test_google_native_provider_accepted(self) -> None: + from unittest.mock import patch + + with patch( + "app.crud.assessment.processing.extract_text_from_response_dict", + return_value="out", + ): + raw = [{"key": "row_0", "response": {"x": 1}, "error": None}] + results = parse_assessment_output(raw, "google-native") + assert results[0]["output"] == "out" + + +class TestGetBatchProvider: + def test_unsupported_provider_raises(self) -> None: + session = MagicMock() + with pytest.raises(ValueError, match="Unsupported provider"): + _get_batch_provider( + session=session, + provider_name="anthropic", + organization_id=1, + project_id=1, + ) + + def test_openai_provider_returned(self) -> None: + session = MagicMock() + mock_client = MagicMock() + with patch( + "app.crud.assessment.processing.get_openai_client", return_value=mock_client + ), patch("app.crud.assessment.processing.OpenAIBatchProvider") as mock_cls: + _get_batch_provider( + session=session, + provider_name="openai", + organization_id=1, + project_id=1, + ) + mock_cls.assert_called_once_with(client=mock_client) + + def test_google_provider_returned(self) -> None: + session = MagicMock() + mock_gemini = MagicMock() + with patch("app.crud.assessment.processing.GeminiClient") as mock_cls, patch( + 
"app.crud.assessment.processing.GeminiBatchProvider" + ) as mock_batch_cls: + mock_cls.from_credentials.return_value = mock_gemini + _get_batch_provider( + session=session, + provider_name="google", + organization_id=1, + project_id=1, + ) + mock_batch_cls.assert_called_once_with(client=mock_gemini.client) + + +class TestPollAllPendingAssessments: + @pytest.mark.asyncio + async def test_delegates_to_cron(self) -> None: + session = MagicMock() + expected = {"processed": 2, "failed": 0} + with patch( + "app.crud.assessment.cron.poll_all_pending_assessment_evaluations", + new=AsyncMock(return_value=expected), + ): + result = await poll_all_pending_assessments(session=session) + assert result == expected + + +class TestCheckAndProcessAssessment: + def _make_run(self) -> MagicMock: + run = MagicMock() + run.id = 1 + run.batch_job_id = 99 + run.status = "processing" + run.assessment_id = 10 + run.organization_id = 1 + run.project_id = 1 + run.run_name = "exp" + return run + + @pytest.mark.asyncio + async def test_completed_with_no_output_file_and_failed_counts(self) -> None: + session = MagicMock() + run = self._make_run() + batch_job = MagicMock() + batch_job.provider = "openai" + batch_job.provider_status = "completed" + batch_job.provider_output_file_id = None + batch_job.id = 99 + + with patch( + "app.crud.assessment.processing.get_batch_job", return_value=batch_job + ), patch( + "app.crud.assessment.processing._get_batch_provider", + return_value=MagicMock(), + ), patch( + "app.crud.assessment.processing.poll_batch_status", + return_value={ + "request_counts": {"failed": 3, "completed": 0, "total": 3}, + "error_file_id": "err-1", + }, + ), patch( + "app.crud.assessment.processing.update_assessment_run_status" + ), patch( + "app.crud.assessment.processing.recompute_assessment_status" + ): + result = await check_and_process_assessment(run=run, session=session) + + assert result["action"] == "failed" + assert result["current_status"] == "failed" + + @pytest.mark.asyncio + async def test_completed_with_no_output_file_not_ready(self) -> None: + session = MagicMock() + run = self._make_run() + batch_job = MagicMock() + batch_job.provider = "openai" + batch_job.provider_status = "completed" + batch_job.provider_output_file_id = None + batch_job.id = 99 + + with patch( + "app.crud.assessment.processing.get_batch_job", return_value=batch_job + ), patch( + "app.crud.assessment.processing._get_batch_provider", + return_value=MagicMock(), + ), patch( + "app.crud.assessment.processing.poll_batch_status", + return_value={"request_counts": {"failed": 0, "completed": 1, "total": 1}}, + ): + result = await check_and_process_assessment(run=run, session=session) + + assert result["action"] == "no_change" + + @pytest.mark.asyncio + async def test_completed_with_output_file_processes_results(self) -> None: + session = MagicMock() + run = self._make_run() + batch_job = MagicMock() + batch_job.provider = "openai" + batch_job.provider_status = "completed" + batch_job.provider_output_file_id = "file-1" + batch_job.id = 99 + + with patch( + "app.crud.assessment.processing.get_batch_job", return_value=batch_job + ), patch( + "app.crud.assessment.processing._get_batch_provider", + return_value=MagicMock(), + ), patch( + "app.crud.assessment.processing.poll_batch_status", + return_value={}, + ), patch( + "app.crud.assessment.processing.download_batch_results", + return_value=[{"custom_id": "row_0"}], + ), patch( + "app.crud.assessment.processing.upload_batch_results_to_object_store", + return_value="s3://results", + ), 
patch( + "app.crud.assessment.processing.parse_assessment_output", + return_value=[{"row_id": "row_0", "error": None}], + ), patch( + "app.crud.assessment.processing.update_assessment_run_status" + ), patch( + "app.crud.assessment.processing.recompute_assessment_status" + ): + result = await check_and_process_assessment(run=run, session=session) + + assert result["action"] == "processed" + + @pytest.mark.asyncio + async def test_terminal_provider_status_marks_failed(self) -> None: + session = MagicMock() + run = self._make_run() + batch_job = MagicMock() + batch_job.provider = "openai" + batch_job.provider_status = "failed" + batch_job.error_message = "provider failed" + + with patch( + "app.crud.assessment.processing.get_batch_job", return_value=batch_job + ), patch( + "app.crud.assessment.processing._get_batch_provider", + return_value=MagicMock(), + ), patch( + "app.crud.assessment.processing.poll_batch_status", return_value={} + ), patch( + "app.crud.assessment.processing.update_assessment_run_status" + ), patch( + "app.crud.assessment.processing.recompute_assessment_status" + ): + result = await check_and_process_assessment(run=run, session=session) + + assert result["action"] == "failed" + assert result["provider_status"] == "failed" + + @pytest.mark.asyncio + async def test_still_processing_returns_no_change(self) -> None: + session = MagicMock() + run = self._make_run() + batch_job = MagicMock() + batch_job.provider = "openai" + batch_job.provider_status = "in_progress" + + with patch( + "app.crud.assessment.processing.get_batch_job", return_value=batch_job + ), patch( + "app.crud.assessment.processing._get_batch_provider", + return_value=MagicMock(), + ), patch( + "app.crud.assessment.processing.poll_batch_status", return_value={} + ): + result = await check_and_process_assessment(run=run, session=session) + + assert result["action"] == "no_change" + + @pytest.mark.asyncio + async def test_exception_path_marks_failed(self) -> None: + session = MagicMock() + run = self._make_run() + run.batch_job_id = None + + with patch( + "app.crud.assessment.processing.update_assessment_run_status" + ) as update_run, patch( + "app.crud.assessment.processing.recompute_assessment_status" + ): + result = await check_and_process_assessment(run=run, session=session) + + assert result["action"] == "failed" + assert result["provider_status"] == "unknown" + assert result["error"] == "Assessment run 1 has no batch_job_id" + update_run.assert_called_once_with( + session=session, + run=run, + status="failed", + error_message="Assessment run 1 has no batch_job_id", + ) diff --git a/backend/app/tests/assessment/test_routes.py b/backend/app/tests/assessment/test_routes.py new file mode 100644 index 000000000..0271f8a2f --- /dev/null +++ b/backend/app/tests/assessment/test_routes.py @@ -0,0 +1,488 @@ +"""Tests for assessment route endpoints (split into datasets/assessments/runs).""" + +from datetime import datetime +from types import SimpleNamespace +from unittest.mock import MagicMock, patch +from uuid import UUID + +import pytest +from fastapi import HTTPException +from fastapi.responses import StreamingResponse + +from app.api.routes.assessment.assessments import ( + export_assessment_results, + get_assessment, + list_assessments, + retry_assessment, +) +from app.api.routes.assessment.datasets import ( + _dataset_to_response, + delete_dataset, + get_dataset, + list_datasets, +) +from app.api.routes.assessment.runs import ( + create_assessment_runs, + export_assessment_run_results, + get_assessment_run, + 
list_assessment_runs, + retry_assessment_run, +) +from app.models.assessment import AssessmentCreate, AssessmentExportRow + +# ─── Fixtures ──────────────────────────────────────────────────────────────── + + +def _auth_context() -> SimpleNamespace: + return SimpleNamespace( + organization_=SimpleNamespace(id=1), + project_=SimpleNamespace(id=1), + ) + + +def _dataset() -> SimpleNamespace: + return SimpleNamespace( + id=7, + name="ds", + description="d", + dataset_metadata={"total_items_count": 2, "file_extension": ".csv"}, + object_store_url="s3://x", + ) + + +def _assessment() -> SimpleNamespace: + return SimpleNamespace( + id=10, + experiment_name="exp", + dataset_id=7, + status="processing", + organization_id=1, + project_id=1, + inserted_at=datetime(2024, 1, 1), + updated_at=datetime(2024, 1, 1), + ) + + +def _run() -> SimpleNamespace: + return SimpleNamespace( + id=22, + assessment_id=10, + config_id=UUID("00000000-0000-0000-0000-000000000001"), + config_version=1, + status="completed", + total_items=1, + error_message=None, + input=None, + batch_job_id=None, + inserted_at=datetime(2024, 1, 1), + updated_at=datetime(2024, 1, 1), + ) + + +def _row(run_id: int = 22) -> AssessmentExportRow: + return AssessmentExportRow( + assessment_id=10, + experiment_name="exp", + dataset_id=7, + dataset_name="ds", + run_id=run_id, + run_name="exp", + run_status="completed", + config_id=None, + config_version=1, + row_id="row_0", + result_status="passed", + input_data={"q": "x"}, + output='{"score":1}', + error=None, + response_id="r", + input_tokens=1, + output_tokens=1, + total_tokens=2, + updated_at=datetime(2024, 1, 1), + ) + + +# ─── Helpers ───────────────────────────────────────────────────────────────── + + +class TestRouteHelpers: + def test_dataset_to_response(self) -> None: + resp = _dataset_to_response(_dataset(), signed_url="signed") + assert resp.dataset_id == 7 + assert resp.signed_url == "signed" + + +# ─── Datasets ──────────────────────────────────────────────────────────────── + + +class TestDatasetRoutes: + def test_list_datasets(self) -> None: + with patch( + "app.api.routes.assessment.datasets.list_assessment_datasets", + return_value=[_dataset()], + ): + resp = list_datasets(session=MagicMock(), auth_context=_auth_context()) + assert resp.success is True + assert len(resp.data or []) == 1 + + def test_get_dataset_not_found(self) -> None: + with patch( + "app.api.routes.assessment.datasets.get_assessment_dataset_by_id", + side_effect=HTTPException( + status_code=404, + detail="Dataset 1 not found or not accessible", + ), + ): + with pytest.raises(HTTPException, match="not found"): + get_dataset(1, session=MagicMock(), auth_context=_auth_context()) + + def test_get_dataset_with_signed_url(self) -> None: + storage = MagicMock() + storage.get_signed_url.return_value = "signed-url" + with patch( + "app.api.routes.assessment.datasets.get_assessment_dataset_by_id", + return_value=_dataset(), + ), patch( + "app.api.routes.assessment.datasets.get_cloud_storage", return_value=storage + ): + resp = get_dataset( + 7, + session=MagicMock(), + auth_context=_auth_context(), + include_signed_url=True, + ) + assert resp.success is True + assert resp.data is not None + assert resp.data.signed_url == "signed-url" + + def test_delete_dataset_success_and_error(self) -> None: + with patch( + "app.api.routes.assessment.datasets.get_assessment_dataset_by_id", + return_value=_dataset(), + ), patch( + "app.api.routes.assessment.datasets.delete_assessment_dataset", + return_value=None, + ): + resp = 
delete_dataset(7, session=MagicMock(), auth_context=_auth_context()) + assert resp.success is True + + with patch( + "app.api.routes.assessment.datasets.get_assessment_dataset_by_id", + return_value=_dataset(), + ), patch( + "app.api.routes.assessment.datasets.delete_assessment_dataset", + return_value="cannot delete", + ): + with pytest.raises(HTTPException, match="cannot delete"): + delete_dataset(7, session=MagicMock(), auth_context=_auth_context()) + + +# ─── Runs — POST + retry ───────────────────────────────────────────────────── + + +class TestRunRoutes: + def test_create_assessment_runs(self) -> None: + request = AssessmentCreate( + experiment_name="exp", + dataset_id=7, + configs=[ + { + "config_id": "00000000-0000-0000-0000-000000000001", + "config_version": 1, + } + ], + ) + result = SimpleNamespace( + assessment_id=10, + experiment_name="exp", + dataset_id=7, + dataset_name="ds", + num_configs=1, + runs=[], + ) + with patch( + "app.api.routes.assessment.runs.start_assessment", return_value=result + ): + resp = create_assessment_runs( + request, session=MagicMock(), auth_context=_auth_context() + ) + assert resp.success is True + + def test_retry_endpoints(self) -> None: + result = SimpleNamespace( + assessment_id=10, + experiment_name="exp", + dataset_id=7, + dataset_name="ds", + num_configs=1, + runs=[], + ) + with patch( + "app.api.routes.assessment.assessments.get_assessment_by_id", + return_value=_assessment(), + ), patch( + "app.api.routes.assessment.assessments.retry_assessment_service", + return_value=result, + ): + resp = retry_assessment( + 10, session=MagicMock(), auth_context=_auth_context() + ) + assert resp.success is True + + with patch( + "app.api.routes.assessment.runs.get_run_by_id", + return_value=_run(), + ), patch( + "app.api.routes.assessment.runs.retry_run", + return_value=result, + ): + resp = retry_assessment_run( + 22, session=MagicMock(), auth_context=_auth_context() + ) + assert resp.success is True + + +# ─── Assessments (parents) — list/get + Runs list/get ─────────────────────── + + +class TestAssessmentAndRunRoutes: + def test_list_and_get_assessments(self) -> None: + public_stub = MagicMock() + with patch( + "app.api.routes.assessment.assessments.list_assessments_crud", + return_value=[_assessment()], + ), patch( + "app.api.routes.assessment.assessments._build_assessment_public", + return_value=public_stub, + ): + resp = list_assessments( + session=MagicMock(), + auth_context=_auth_context(), + ) + assert resp.success is True + assert len(resp.data or []) == 1 + + with patch( + "app.api.routes.assessment.assessments.get_assessment_by_id", + return_value=_assessment(), + ), patch( + "app.api.routes.assessment.assessments._build_assessment_public", + return_value=public_stub, + ): + resp = get_assessment( + 10, + session=MagicMock(), + auth_context=_auth_context(), + ) + assert resp.success is True + + with patch( + "app.api.routes.assessment.assessments.get_assessment_by_id", + side_effect=HTTPException( + status_code=404, detail="Assessment 10 not found or not accessible" + ), + ): + with pytest.raises(HTTPException, match="not found"): + get_assessment(10, session=MagicMock(), auth_context=_auth_context()) + + def test_list_and_get_runs(self) -> None: + public_stub = MagicMock() + with patch( + "app.api.routes.assessment.runs.list_runs", + return_value=[_run()], + ), patch( + "app.api.routes.assessment.runs._build_run_public", + return_value=public_stub, + ): + resp = list_assessment_runs( + session=MagicMock(), auth_context=_auth_context() + ) 
+ assert resp.success is True + + with patch( + "app.api.routes.assessment.runs.get_run_by_id", + return_value=_run(), + ), patch( + "app.api.routes.assessment.runs._build_run_public", + return_value=public_stub, + ): + resp = get_assessment_run( + 22, session=MagicMock(), auth_context=_auth_context() + ) + assert resp.success is True + + with patch( + "app.api.routes.assessment.runs.get_run_by_id", + side_effect=HTTPException( + status_code=404, detail="Assessment run 22 not found or not accessible" + ), + ): + with pytest.raises(HTTPException, match="not found"): + get_assessment_run( + 22, session=MagicMock(), auth_context=_auth_context() + ) + + +# ─── Export endpoints ──────────────────────────────────────────────────────── + + +class TestExportRoutes: + def test_export_assessment_results_delegates_to_util(self) -> None: + """Parent export route delegates JSON/single-file/ZIP packaging to utils.""" + with patch( + "app.api.routes.assessment.assessments.get_assessment_by_id", + return_value=_assessment(), + ), patch( + "app.api.routes.assessment.assessments.get_assessment_runs_for_assessment", + return_value=[_run()], + ), patch( + "app.api.routes.assessment.assessments.build_assessment_results_response", + return_value="stub-response", + ) as build: + result = export_assessment_results( + 10, + session=MagicMock(), + auth_context=_auth_context(), + export_format="json", + ) + assert result == "stub-response" + assert build.call_args.kwargs["export_format"] == "json" + + def test_export_assessment_run_results_json_and_file(self) -> None: + run = _run() + with patch( + "app.api.routes.assessment.runs.get_run_by_id", + return_value=run, + ), patch( + "app.api.routes.assessment.runs.get_assessment_by_id", + return_value=_assessment(), + ), patch( + "app.api.routes.assessment.runs.load_export_rows_for_run", + return_value=[_row()], + ), patch( + "app.api.routes.assessment.runs.sort_export_rows", + side_effect=lambda rows: rows, + ), patch( + "app.api.routes.assessment.runs.build_json_export_rows", + return_value=[{"x": 1}], + ): + json_resp = export_assessment_run_results( + 22, + session=MagicMock(), + auth_context=_auth_context(), + export_format="json", + ) + assert json_resp.success is True + + with patch( + "app.api.routes.assessment.runs.get_run_by_id", + return_value=run, + ), patch( + "app.api.routes.assessment.runs.get_assessment_by_id", + return_value=_assessment(), + ), patch( + "app.api.routes.assessment.runs.load_export_rows_for_run", + return_value=[_row()], + ), patch( + "app.api.routes.assessment.runs.sort_export_rows", + side_effect=lambda rows: rows, + ), patch( + "app.api.routes.assessment.runs.build_export_response", + return_value=StreamingResponse(iter([b"x"])), + ): + file_resp = export_assessment_run_results( + 22, + session=MagicMock(), + auth_context=_auth_context(), + export_format="csv", + ) + assert isinstance(file_resp, StreamingResponse) + + def test_export_not_found(self) -> None: + with patch( + "app.api.routes.assessment.assessments.get_assessment_by_id", + side_effect=HTTPException( + status_code=404, detail="Assessment 10 not found or not accessible" + ), + ): + with pytest.raises(HTTPException, match="not found"): + export_assessment_results( + 10, + session=MagicMock(), + auth_context=_auth_context(), + ) + with patch( + "app.api.routes.assessment.runs.get_run_by_id", + side_effect=HTTPException( + status_code=404, detail="Assessment run 22 not found or not accessible" + ), + ): + with pytest.raises(HTTPException, match="not found"): + 
export_assessment_run_results( + 22, + session=MagicMock(), + auth_context=_auth_context(), + ) + + +# ─── New: util-level test for the extracted ZIP/single-file logic ────────── + + +class TestBuildAssessmentResultsResponse: + """Verify the extracted util builds the right shape for json / single-file / zip.""" + + def test_json_returns_apiresponse(self) -> None: + from app.services.assessment.utils.export import ( + build_assessment_results_response, + ) + + with patch( + "app.services.assessment.utils.export.load_export_rows_for_run", + return_value=[_row()], + ), patch( + "app.services.assessment.utils.export.sort_export_rows", + side_effect=lambda rows: rows, + ), patch( + "app.services.assessment.utils.export.build_json_export_rows", + return_value=[{"x": 1}], + ): + resp = build_assessment_results_response( + session=MagicMock(), + assessment=_assessment(), + runs=[_run()], + export_format="json", + ) + assert resp.success is True + + def test_csv_multi_run_returns_zip(self) -> None: + from app.services.assessment.utils.export import ( + build_assessment_results_response, + ) + + run1 = _run() + run2 = _run() + run2.id = 23 + run2.config_version = 2 + + with patch( + "app.services.assessment.utils.export.load_export_rows_for_run", + side_effect=[[_row(run_id=22)], [_row(run_id=23)]], + ), patch( + "app.services.assessment.utils.export.sort_export_rows", + side_effect=lambda rows: rows, + ), patch( + "app.services.assessment.utils.export.serialize_export_rows", + return_value=(b"csv", "text/csv"), + ), patch( + "app.services.assessment.utils.export.generate_timestamped_filename", + return_value="out.zip", + ): + resp = build_assessment_results_response( + session=MagicMock(), + assessment=_assessment(), + runs=[run1, run2], + export_format="csv", + ) + assert isinstance(resp, StreamingResponse) + assert resp.media_type == "application/zip" diff --git a/backend/app/tests/assessment/test_service.py b/backend/app/tests/assessment/test_service.py new file mode 100644 index 000000000..b3654fa9b --- /dev/null +++ b/backend/app/tests/assessment/test_service.py @@ -0,0 +1,405 @@ +"""Tests for assessment/service.py orchestration behavior.""" + +from types import SimpleNamespace +from unittest.mock import MagicMock, patch +from uuid import UUID + +import pytest +from fastapi import HTTPException + +from app.models.assessment import AssessmentConfigRef, AssessmentCreate +from app.models.config.config import ConfigTag +from app.services.assessment.service import ( + _build_retry_request, + retry_assessment, + retry_assessment_run, + start_assessment, +) + + +def _make_request(provider_config_id: UUID) -> AssessmentCreate: + return AssessmentCreate( + experiment_name="exp-1", + dataset_id=7, + prompt_template="Answer: {question}", + system_instruction="Assess strictly", + text_columns=["question"], + attachments=[], + configs=[ + AssessmentConfigRef(config_id=provider_config_id, config_version=1), + ], + ) + + +def _make_dataset() -> MagicMock: + dataset = MagicMock() + dataset.id = 7 + dataset.name = "dataset-1" + return dataset + + +def _make_run() -> MagicMock: + run = MagicMock() + run.id = 11 + run.assessment_id = 21 + run.config_id = UUID("00000000-0000-0000-0000-000000000001") + run.config_version = 1 + run.status = "processing" + return run + + +def _assessment_config_crud_patch(): + """Patch ConfigCrud so the bare tag-check in start_assessment short-circuits. + + Returns a config tagged for assessment use, which the check accepts. 
+ """ + crud = MagicMock() + crud.read_one.return_value = SimpleNamespace( + id=UUID("00000000-0000-0000-0000-000000000001"), + tag=ConfigTag.ASSESSMENT, + ) + return patch("app.services.assessment.service.ConfigCrud", return_value=crud) + + +class TestStartAssessment: + def test_dataset_not_found(self) -> None: + session = MagicMock() + request = _make_request(UUID("00000000-0000-0000-0000-000000000001")) + with patch( + "app.services.assessment.service.get_assessment_dataset_by_id", + side_effect=HTTPException( + status_code=404, + detail="Dataset 7 not found or not accessible", + ), + ): + with pytest.raises(HTTPException, match="not found"): + start_assessment( + session=session, + request=request, + organization_id=1, + project_id=1, + ) + + def test_config_resolution_failure(self) -> None: + session = MagicMock() + request = _make_request(UUID("00000000-0000-0000-0000-000000000001")) + with ( + patch( + "app.services.assessment.service.get_assessment_dataset_by_id", + return_value=_make_dataset(), + ), + patch( + "app.services.assessment.service.resolve_evaluation_config", + return_value=(None, "missing"), + ), + _assessment_config_crud_patch(), + ): + with pytest.raises(HTTPException, match="Failed to resolve config"): + start_assessment( + session=session, + request=request, + organization_id=1, + project_id=1, + ) + + def test_rejects_unsupported_provider(self) -> None: + session = MagicMock() + request = _make_request(UUID("00000000-0000-0000-0000-000000000001")) + config_blob = SimpleNamespace(completion=SimpleNamespace(provider="anthropic")) + + with ( + patch( + "app.services.assessment.service.get_assessment_dataset_by_id", + return_value=_make_dataset(), + ), + patch( + "app.services.assessment.service.resolve_evaluation_config", + return_value=(config_blob, None), + ), + patch( + "app.services.assessment.service.create_assessment" + ) as create_assessment, + _assessment_config_crud_patch(), + ): + with pytest.raises( + HTTPException, match="not supported for batch assessment" + ): + start_assessment( + session=session, + request=request, + organization_id=1, + project_id=1, + ) + create_assessment.assert_not_called() + + def test_google_provider_is_supported(self) -> None: + session = MagicMock() + request = _make_request(UUID("00000000-0000-0000-0000-000000000001")) + dataset = _make_dataset() + assessment = MagicMock() + assessment.id = 21 + run = _make_run() + config_blob = SimpleNamespace( + completion=SimpleNamespace(provider="google", params={"model": "gemini"}) + ) + batch_job = MagicMock() + batch_job.id = 101 + batch_job.total_items = 3 + + with ( + patch( + "app.services.assessment.service.get_assessment_dataset_by_id", + return_value=dataset, + ), + patch( + "app.services.assessment.service.resolve_evaluation_config", + return_value=(config_blob, None), + ), + patch( + "app.services.assessment.service.create_assessment", + return_value=assessment, + ), + patch( + "app.services.assessment.service.create_assessment_run", + return_value=run, + ), + patch( + "app.services.assessment.service.submit_assessment_batch", + return_value=batch_job, + ) as submit_batch, + patch( + "app.services.assessment.service.update_assessment_run_status", + return_value=run, + ), + patch("app.services.assessment.service.recompute_assessment_status"), + _assessment_config_crud_patch(), + ): + response = start_assessment( + session=session, + request=request, + organization_id=1, + project_id=1, + ) + + assert response.num_configs == 1 + assert submit_batch.call_args.kwargs["config_blob"] 
is config_blob + + def test_defaults_missing_provider_to_openai(self) -> None: + session = MagicMock() + request = _make_request(UUID("00000000-0000-0000-0000-000000000001")) + dataset = _make_dataset() + assessment = MagicMock() + assessment.id = 21 + run = _make_run() + config_blob = SimpleNamespace( + completion=SimpleNamespace(provider=None, params={"model": "gpt-4.1-mini"}) + ) + batch_job = MagicMock() + batch_job.id = 101 + batch_job.total_items = 3 + + with ( + patch( + "app.services.assessment.service.get_assessment_dataset_by_id", + return_value=dataset, + ), + patch( + "app.services.assessment.service.resolve_evaluation_config", + return_value=(config_blob, None), + ), + patch( + "app.services.assessment.service.create_assessment", + return_value=assessment, + ), + patch( + "app.services.assessment.service.create_assessment_run", + return_value=run, + ) as create_run, + patch( + "app.services.assessment.service.submit_assessment_batch", + return_value=batch_job, + ) as submit_batch, + patch( + "app.services.assessment.service.update_assessment_run_status", + return_value=run, + ), + patch("app.services.assessment.service.recompute_assessment_status"), + _assessment_config_crud_patch(), + ): + response = start_assessment( + session=session, + request=request, + organization_id=1, + project_id=1, + ) + + assert response.assessment_id == 21 + assert response.num_configs == 1 + assert response.runs[0].run_id == 11 + assessment_input = create_run.call_args.kwargs["assessment_input"] + assert assessment_input["system_instruction"] == "Assess strictly" + assert ( + submit_batch.call_args.kwargs["assessment_input"]["system_instruction"] + == "Assess strictly" + ) + submit_batch.assert_called_once() + + def test_rejects_default_tagged_config(self) -> None: + """Configs explicitly tagged 'default' must be rejected for assessment.""" + session = MagicMock() + request = _make_request(UUID("00000000-0000-0000-0000-000000000001")) + + crud = MagicMock() + crud.read_one.return_value = SimpleNamespace( + id=UUID("00000000-0000-0000-0000-000000000001"), + tag=ConfigTag.DEFAULT, + ) + + with ( + patch( + "app.services.assessment.service.get_assessment_dataset_by_id", + return_value=_make_dataset(), + ), + patch("app.services.assessment.service.ConfigCrud", return_value=crud), + patch( + "app.services.assessment.service.resolve_evaluation_config" + ) as resolve, + ): + with pytest.raises( + HTTPException, + match="cannot be used for assessment", + ): + start_assessment( + session=session, + request=request, + organization_id=1, + project_id=1, + ) + # Tag check must fire BEFORE config resolution. 
+ resolve.assert_not_called() + + def test_batch_submission_failure_marks_run_failed(self) -> None: + session = MagicMock() + request = _make_request(UUID("00000000-0000-0000-0000-000000000001")) + dataset = _make_dataset() + assessment = MagicMock() + assessment.id = 21 + run = _make_run() + run.status = "failed" + config_blob = SimpleNamespace( + completion=SimpleNamespace( + provider="openai", params={"model": "gpt-4.1-mini"} + ) + ) + + with ( + patch( + "app.services.assessment.service.get_assessment_dataset_by_id", + return_value=dataset, + ), + patch( + "app.services.assessment.service.resolve_evaluation_config", + return_value=(config_blob, None), + ), + patch( + "app.services.assessment.service.create_assessment", + return_value=assessment, + ), + patch( + "app.services.assessment.service.create_assessment_run", + return_value=run, + ), + patch( + "app.services.assessment.service.submit_assessment_batch", + side_effect=RuntimeError("submit failed"), + ), + patch( + "app.services.assessment.service.update_assessment_run_status", + return_value=run, + ) as update_run, + patch("app.services.assessment.service.recompute_assessment_status"), + _assessment_config_crud_patch(), + ): + response = start_assessment( + session=session, + request=request, + organization_id=1, + project_id=1, + ) + assert response.num_configs == 1 + assert update_run.called + + +class TestRetryHelpers: + def test_build_retry_request_errors_and_success(self) -> None: + with pytest.raises(HTTPException, match="No assessment runs"): + _build_retry_request(experiment_name="exp", dataset_id=1, runs=[]) + + run = MagicMock() + run.input = None + with pytest.raises(HTTPException, match="missing for retry"): + _build_retry_request(experiment_name="exp", dataset_id=1, runs=[run]) + + run2 = MagicMock() + run2.id = 1 + run2.input = {"prompt_template": "p", "text_columns": ["q"], "attachments": []} + run2.config_id = None + run2.config_version = None + with pytest.raises(HTTPException, match="Config reference is missing"): + _build_retry_request(experiment_name="exp", dataset_id=1, runs=[run2]) + + run3 = MagicMock() + run3.id = 2 + run3.input = { + "prompt_template": "p", + "system_instruction": "sys", + "text_columns": ["q"], + "attachments": [], + "output_schema": {"type": "object"}, + } + run3.config_id = UUID("00000000-0000-0000-0000-000000000001") + run3.config_version = 1 + req = _build_retry_request(experiment_name="exp", dataset_id=1, runs=[run3]) + assert req.experiment_name == "exp" + assert req.system_instruction == "sys" + assert len(req.configs) == 1 + + def test_retry_assessment_wrappers(self) -> None: + session = MagicMock() + assessment = MagicMock() + assessment.id = 21 + assessment.experiment_name = "exp" + assessment.dataset_id = 7 + run = MagicMock() + run.assessment_id = 21 + run.assessment = assessment + run.input = {"prompt_template": "p", "text_columns": [], "attachments": []} + run.config_id = UUID("00000000-0000-0000-0000-000000000001") + run.config_version = 1 + + result = SimpleNamespace( + assessment_id=1, + experiment_name="exp", + dataset_id=7, + dataset_name="ds", + num_configs=1, + runs=[], + ) + + with ( + patch( + "app.services.assessment.service.get_assessment_runs_for_assessment", + return_value=[run], + ), + patch( + "app.services.assessment.service.start_assessment", return_value=result + ), + ): + resp = retry_assessment(session, assessment, 1, 1) + assert resp.assessment_id == 1 + + with patch( + "app.services.assessment.service.start_assessment", return_value=result + ): + resp2 
= retry_assessment_run(session, run, 1, 1) + assert resp2.assessment_id == 1 diff --git a/backend/app/tests/assessment/test_validators.py b/backend/app/tests/assessment/test_validators.py new file mode 100644 index 000000000..87b2b2a2e --- /dev/null +++ b/backend/app/tests/assessment/test_validators.py @@ -0,0 +1,103 @@ +"""Tests for assessment/validators.py.""" + +import io + +import pytest +from fastapi import UploadFile + +from app.services.assessment.validators import MAX_FILE_SIZE, validate_dataset_file + + +def _make_upload( + filename: str, + content: bytes, + content_type: str = "text/csv", +) -> UploadFile: + return UploadFile( + filename=filename, + file=io.BytesIO(content), + headers={"content-type": content_type}, + ) + + +class TestValidateDatasetFile: + @pytest.mark.asyncio + async def test_valid_csv_accepted(self) -> None: + file = _make_upload("data.csv", b"col1,col2\nval1,val2") + content, ext = await validate_dataset_file(file) + assert ext == ".csv" + assert content == b"col1,col2\nval1,val2" + + @pytest.mark.asyncio + async def test_valid_xlsx_accepted(self) -> None: + file = _make_upload( + "data.xlsx", + b"PK\x03\x04fake_xlsx_content", + content_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + ) + content, ext = await validate_dataset_file(file) + assert ext == ".xlsx" + + @pytest.mark.asyncio + async def test_xls_rejected_with_clear_error(self) -> None: + file = _make_upload( + "data.xls", + b"fake_xls", + content_type="application/vnd.ms-excel", + ) + from fastapi import HTTPException + + with pytest.raises(HTTPException) as exc_info: + await validate_dataset_file(file) + assert exc_info.value.status_code == 422 + assert "Legacy Excel format (.xls) is not supported" in exc_info.value.detail + + @pytest.mark.asyncio + async def test_missing_filename_raises_422(self) -> None: + from fastapi import HTTPException + + file = _make_upload("", b"data") + file.filename = None # type: ignore[assignment] + with pytest.raises(HTTPException) as exc_info: + await validate_dataset_file(file) + assert exc_info.value.status_code == 422 + + @pytest.mark.asyncio + async def test_invalid_extension_raises_422(self) -> None: + from fastapi import HTTPException + + file = _make_upload("data.txt", b"some data", content_type="text/plain") + with pytest.raises(HTTPException) as exc_info: + await validate_dataset_file(file) + assert exc_info.value.status_code == 422 + assert "Invalid file type" in exc_info.value.detail + + @pytest.mark.asyncio + async def test_empty_file_raises_422(self) -> None: + from fastapi import HTTPException + + file = _make_upload("data.csv", b"") + with pytest.raises(HTTPException) as exc_info: + await validate_dataset_file(file) + assert exc_info.value.status_code == 422 + assert "Empty" in exc_info.value.detail + + @pytest.mark.asyncio + async def test_file_too_large_raises_413(self) -> None: + from fastapi import HTTPException + + oversized = b"x" * (MAX_FILE_SIZE + 1) + file = _make_upload("data.csv", oversized) + with pytest.raises(HTTPException) as exc_info: + await validate_dataset_file(file) + assert exc_info.value.status_code == 413 + assert "too large" in exc_info.value.detail + + @pytest.mark.asyncio + async def test_unexpected_content_type_still_accepted_by_extension(self) -> None: + # Unknown MIME type but valid extension — should proceed with a warning log + file = _make_upload( + "data.csv", b"a,b\n1,2", content_type="application/octet-stream" + ) + content, ext = await validate_dataset_file(file) + assert ext == ".csv" diff 
--git a/backend/app/tests/core/batch/test_gemini.py b/backend/app/tests/core/batch/test_gemini.py index 59e46bdee..6a00fd135 100644 --- a/backend/app/tests/core/batch/test_gemini.py +++ b/backend/app/tests/core/batch/test_gemini.py @@ -99,6 +99,37 @@ def test_create_batch_with_default_config(self, provider, mock_genai_client): assert result["total_items"] == 1 mock_genai_client.batches.create.assert_called_once() + def test_create_batch_preserves_unicode_text_in_jsonl( + self, provider, mock_genai_client + ): + """Test JSONL upload keeps non-ASCII prompt text UTF-8 encoded.""" + jsonl_data = [ + { + "key": "req-1", + "request": { + "contents": [ + { + "parts": [{"text": "# Problem Statement\n\nఎండ్"}], + "role": "user", + } + ] + }, + } + ] + config = {"display_name": "test"} + + provider.upload_file = MagicMock(return_value="files/uploaded-123") + mock_batch_job = MagicMock() + mock_batch_job.name = "batches/batch-123" + mock_batch_job.state.name = "JOB_STATE_PENDING" + mock_genai_client.batches.create.return_value = mock_batch_job + + provider.create_batch(jsonl_data, config) + + uploaded_content = provider.upload_file.call_args.args[0] + assert "ఎండ్" in uploaded_content + assert "\\u0c" not in uploaded_content + def test_create_batch_file_upload_error(self, provider, mock_genai_client): """Test handling of file upload error during batch creation.""" jsonl_data = [{"key": "req-1", "request": {}}] diff --git a/backend/app/tests/core/batch/test_openai.py b/backend/app/tests/core/batch/test_openai.py index 57dac1fdd..308ea7599 100644 --- a/backend/app/tests/core/batch/test_openai.py +++ b/backend/app/tests/core/batch/test_openai.py @@ -89,6 +89,48 @@ def test_create_batch_with_default_config(self, provider, mock_openai_client): assert result["total_items"] == 1 + def test_create_batch_preserves_unicode_text_in_jsonl( + self, provider, mock_openai_client + ): + """Test JSONL upload keeps non-ASCII prompt text UTF-8 encoded.""" + jsonl_data = [ + { + "custom_id": "req-1", + "method": "POST", + "url": "/v1/responses", + "body": { + "input": [ + { + "role": "user", + "content": [ + { + "type": "input_text", + "text": "# Problem Statement\n\nఎండ్", + } + ], + } + ] + }, + } + ] + config = {"endpoint": "/v1/responses"} + + mock_file_response = MagicMock() + mock_file_response.id = "file-123" + mock_openai_client.files.create.return_value = mock_file_response + + mock_batch = MagicMock() + mock_batch.id = "batch-456" + mock_batch.status = "validating" + mock_openai_client.batches.create.return_value = mock_batch + + provider.create_batch(jsonl_data, config) + + uploaded_bytes = mock_openai_client.files.create.call_args.kwargs["file"][1] + uploaded_content = uploaded_bytes.decode("utf-8") + assert "ఎండ్" in uploaded_content + assert "\\u0c" not in uploaded_content + def test_create_batch_file_upload_error(self, provider, mock_openai_client): """Test handling of file upload error during batch creation.""" jsonl_data = [{"custom_id": "req-1"}] diff --git a/backend/app/tests/crud/config/test_config.py b/backend/app/tests/crud/config/test_config.py index 3f83f8a42..c588dca97 100644 --- a/backend/app/tests/crud/config/test_config.py +++ b/backend/app/tests/crud/config/test_config.py @@ -1,18 +1,19 @@ from uuid import uuid4 import pytest -from sqlmodel import Session from fastapi import HTTPException +from sqlmodel import Session +from app.crud.config import ConfigCrud from app.models import ( - ConfigBlob, CompletionConfig, + ConfigBlob, ConfigCreate, ConfigUpdate, ) +from app.models.config.config 
import ConfigTag from app.models.llm.request import NativeCompletionConfig -from app.crud.config import ConfigCrud -from app.tests.utils.test_data import create_test_project, create_test_config +from app.tests.utils.test_data import create_test_config, create_test_project from app.tests.utils.utils import random_lower_string @@ -185,6 +186,48 @@ def test_read_all_configs(db: Session) -> None: assert config3.id in config_ids +def test_read_all_configs_without_tag_returns_default_configs( + db: Session, +) -> None: + """Test unscoped config lists include default configs.""" + project = create_test_project(db) + + implicit_default_config = create_test_config( + db, project_id=project.id, name="implicit-default-config" + ) + default_config = create_test_config( + db, project_id=project.id, name="default-config", tag=ConfigTag.DEFAULT + ) + + config_crud = ConfigCrud(session=db, project_id=project.id) + configs, _ = config_crud.read_all(query=None) + + config_ids = [c.id for c in configs] + assert implicit_default_config.id in config_ids + assert default_config.id in config_ids + + +def test_read_all_configs_with_explicit_default_tag_returns_default_configs( + db: Session, +) -> None: + """Test explicit default tag lists default configs.""" + project = create_test_project(db) + + implicit_default_config = create_test_config( + db, project_id=project.id, name="implicit-default-config" + ) + default_config = create_test_config( + db, project_id=project.id, name="default-config", tag=ConfigTag.DEFAULT + ) + + config_crud = ConfigCrud(session=db, project_id=project.id) + configs, _ = config_crud.read_all(query=None, tag=ConfigTag.DEFAULT) + + config_ids = [c.id for c in configs] + assert implicit_default_config.id in config_ids + assert default_config.id in config_ids + + def test_read_all_configs_pagination(db: Session) -> None: """Test reading configurations with pagination.""" project = create_test_project(db) diff --git a/backend/app/tests/crud/config/test_version.py b/backend/app/tests/crud/config/test_version.py index dccb54f46..efc844fb2 100644 --- a/backend/app/tests/crud/config/test_version.py +++ b/backend/app/tests/crud/config/test_version.py @@ -1,15 +1,16 @@ from uuid import uuid4 import pytest -from sqlmodel import Session from fastapi import HTTPException +from sqlmodel import Session -from app.models import ConfigVersionUpdate, ConfigBlob -from app.models.llm.request import NativeCompletionConfig from app.crud.config import ConfigVersionCrud +from app.models import ConfigBlob, ConfigVersionUpdate +from app.models.config.config import ConfigTag +from app.models.llm.request import NativeCompletionConfig from app.tests.utils.test_data import ( - create_test_project, create_test_config, + create_test_project, create_test_version, ) @@ -504,3 +505,56 @@ def test_read_all_versions_config_not_found(db: Session) -> None: HTTPException, match=f"config with id '{non_existent_config_id}' not found" ): version_crud.read_all() + + +def test_read_all_versions_without_tag_uses_default_scope( + db: Session, +) -> None: + """Test omitted tag scope allows default config versions.""" + config = create_test_config(db) + version_crud = ConfigVersionCrud( + session=db, + project_id=config.project_id, + config_id=config.id, + ) + + versions = version_crud.read_all() + + assert len(versions) == 1 + assert versions[0].version == 1 + + +def test_read_all_versions_with_explicit_default_tag_allows_default_config( + db: Session, +) -> None: + """Test explicit default tag scope allows default config versions.""" 
+ config = create_test_config(db) + version_crud = ConfigVersionCrud( + session=db, + project_id=config.project_id, + config_id=config.id, + tag=ConfigTag.DEFAULT, + ) + + versions = version_crud.read_all() + + assert len(versions) == 1 + assert versions[0].version == 1 + + +def test_read_all_versions_with_explicit_tag_allows_matching_config( + db: Session, +) -> None: + """Test explicit tag scope allows matching tagged config versions.""" + config = create_test_config(db, tag=ConfigTag.DEFAULT) + version_crud = ConfigVersionCrud( + session=db, + project_id=config.project_id, + config_id=config.id, + tag=ConfigTag.DEFAULT, + ) + + versions = version_crud.read_all() + + assert len(versions) == 1 + assert versions[0].version == 1 diff --git a/backend/app/tests/utils/test_data.py b/backend/app/tests/utils/test_data.py index 9b144c2f4..03e24226f 100644 --- a/backend/app/tests/utils/test_data.py +++ b/backend/app/tests/utils/test_data.py @@ -20,6 +20,7 @@ ConfigVersionUpdate, EvaluationDataset, ) +from app.models.config.config import ConfigTag from app.models.llm import KaapiLLMParams, KaapiCompletionConfig, NativeCompletionConfig from app.crud import ( create_organization, @@ -242,6 +243,7 @@ def create_test_config( description: str | None = None, config_blob: ConfigBlob | None = None, use_kaapi_schema: bool = False, + tag: ConfigTag = ConfigTag.DEFAULT, ) -> Config: """ Creates and returns a test configuration with an initial version. @@ -255,6 +257,7 @@ def create_test_config( description: Config description config_blob: Config blob (creates default if None) use_kaapi_schema: If True, creates Kaapi-format config; if False, creates native format + tag: Config classification tag. Defaults to `default`. """ if project_id is None: project = create_test_project(db) @@ -295,6 +298,7 @@ def create_test_config( description=description or "Test configuration description", config_blob=config_blob, commit_message="Initial version", + tag=tag, ) config_crud = ConfigCrud(session=db, project_id=project_id) diff --git a/backend/pyproject.toml b/backend/pyproject.toml index b96cd0050..03af443a6 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -52,6 +52,7 @@ dependencies = [ "elevenlabs>=2.38.1", "google-auth>=2.49.1", "gevent>=25.9.1", + "openpyxl>=3.1.5", ] [tool.uv] diff --git a/backend/uv.lock b/backend/uv.lock index ef8cf856c..26bf23616 100644 --- a/backend/uv.lock +++ b/backend/uv.lock @@ -240,6 +240,7 @@ dependencies = [ { name = "numpy" }, { name = "openai" }, { name = "openai-responses" }, + { name = "openpyxl" }, { name = "opentelemetry-api" }, { name = "opentelemetry-instrumentation" }, { name = "opentelemetry-instrumentation-celery" }, @@ -302,6 +303,7 @@ requires-dist = [ { name = "numpy", specifier = ">=1.24.0" }, { name = "openai", specifier = ">=1.100.1" }, { name = "openai-responses" }, + { name = "openpyxl", specifier = ">=3.1.5" }, { name = "opentelemetry-api", specifier = ">=1.30.0" }, { name = "opentelemetry-instrumentation", specifier = ">=0.51b0" }, { name = "opentelemetry-instrumentation-celery", specifier = ">=0.51b0" }, @@ -929,6 +931,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/55/7e/b648d640d88d31de49e566832aca9cce025c52d6349b0a0fc65e9df1f4c5/emails-0.6-py2.py3-none-any.whl", hash = "sha256:72c1e3198075709cc35f67e1b49e2da1a2bc087e9b444073db61a379adfb7f3c", size = 56250, upload-time = "2020-06-19T11:20:40.466Z" }, ] +[[package]] +name = "et-xmlfile" +version = "2.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = 
"https://files.pythonhosted.org/packages/d3/38/af70d7ab1ae9d4da450eeec1fa3918940a5fafb9055e934af8d6eb0c2313/et_xmlfile-2.0.0.tar.gz", hash = "sha256:dab3f4764309081ce75662649be815c4c9081e88f0837825f90fd28317d4da54", size = 17234, upload-time = "2024-10-25T17:25:40.039Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c1/8b/5fe2cc11fee489817272089c4203e679c63b570a5aaeb18d852ae3cbba6a/et_xmlfile-2.0.0-py3-none-any.whl", hash = "sha256:7a91720bc756843502c3b7504c77b8fe44217c85c537d85037f0f536151b2caa", size = 18059, upload-time = "2024-10-25T17:25:39.051Z" }, +] + [[package]] name = "fastapi" version = "0.135.1" @@ -2267,6 +2278,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/08/98/50c755503a55550f170d0211297bc8791b8bf10bf04cb16b4b95ca71d1e3/openai_responses-0.13.1-py3-none-any.whl", hash = "sha256:b5fc7fb15f546b551757864c1dfaeb01b8a4fc0c353961bd6d0d45ff26389721", size = 51887, upload-time = "2025-12-02T21:37:40.15Z" }, ] +[[package]] +name = "openpyxl" +version = "3.1.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "et-xmlfile" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/3d/f9/88d94a75de065ea32619465d2f77b29a0469500e99012523b91cc4141cd1/openpyxl-3.1.5.tar.gz", hash = "sha256:cf0e3cf56142039133628b5acffe8ef0c12bc902d2aadd3e0fe5878dc08d1050", size = 186464, upload-time = "2024-06-28T14:03:44.161Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c0/da/977ded879c29cbd04de313843e76868e6e13408a94ed6b987245dc7c8506/openpyxl-3.1.5-py2.py3-none-any.whl", hash = "sha256:5282c12b107bffeef825f4617dc029afaf41d0ea60823bbb665ef3079dc79de2", size = 250910, upload-time = "2024-06-28T14:03:41.161Z" }, +] + [[package]] name = "opentelemetry-api" version = "1.41.0"