From b72ff8fdf91c479c632dfaf6787137bafb30349b Mon Sep 17 00:00:00 2001 From: "Zeumer, Moritz" Date: Mon, 8 Dec 2025 16:27:23 +0100 Subject: [PATCH] add simulated filter (optional) to scenario endpoint --- api/src/api/app/apis/scenarios_api.py | 6 ++++-- api/src/api/app/controller/scenario_controller.py | 3 ++- api/src/api/app/db/models.py | 2 ++ api/src/api/app/db/tasks.py | 10 +++++++++- api/src/api/app/models/reduced_scenario.py | 4 +++- api/src/api/app/models/scenario.py | 6 ++++-- 6 files changed, 24 insertions(+), 7 deletions(-) diff --git a/api/src/api/app/apis/scenarios_api.py b/api/src/api/app/apis/scenarios_api.py index 42ccfab..63de599 100644 --- a/api/src/api/app/apis/scenarios_api.py +++ b/api/src/api/app/apis/scenarios_api.py @@ -136,9 +136,11 @@ async def import_scenario_data( tags=["Scenarios"], response_model_by_alias=True, ) -async def list_scenarios() -> List[ReducedScenario]: +async def list_scenarios( + simulatedFilter: Annotated[Optional[bool], Field(description="Filter for scenario query; None - no filter, return all scenarios, True - only return simulated, False - only return unsimulated")] = Query(None, description="Return only simulated scenarios; if False, only return unsimulated scenarios", alias="simulated") +) -> List[ReducedScenario]: """List all available scenarios.""" - return await controller.list_scenarios() + return await controller.list_scenarios(simulatedFilter) # a toy endpoint to test authorization @router.post( diff --git a/api/src/api/app/controller/scenario_controller.py b/api/src/api/app/controller/scenario_controller.py index 92eb620..0b93da7 100644 --- a/api/src/api/app/controller/scenario_controller.py +++ b/api/src/api/app/controller/scenario_controller.py @@ -85,9 +85,10 @@ async def get_scenario( async def list_scenarios( self, + simulatedFilter: bool | None, ) -> List[ReducedScenario]: """List all available scenarios.""" - return scenario_get_all() + return scenario_get_all(simulatedFilter) async def get_infection_data( self, diff --git a/api/src/api/app/db/models.py b/api/src/api/app/db/models.py index 940fb81..3dc9e32 100644 --- a/api/src/api/app/db/models.py +++ b/api/src/api/app/db/models.py @@ -31,6 +31,8 @@ class Scenario(SQLModel, table=True): creatorUserId: Optional[uuid.UUID] = Field(default=None, nullable=True) creatorOrgId: Optional[str] = Field(default=None, nullable=True) + whitelist: Optional[str] = Field(default=None, nullable=True) + class ParameterDefinition(SQLModel, table=True): id: Optional[uuid.UUID] = Field(default_factory=uuid.uuid4, primary_key=True, nullable=False) name: Optional[str] = Field(nullable=False) diff --git a/api/src/api/app/db/tasks.py b/api/src/api/app/db/tasks.py index e7b4deb..60bb237 100644 --- a/api/src/api/app/db/tasks.py +++ b/api/src/api/app/db/tasks.py @@ -542,8 +542,16 @@ def scenario_get_by_id(id: StrictStr) -> Scenario: creator_org_id=scenario.creatorOrgId ) -def scenario_get_all() -> List[ReducedScenario]: +def scenario_get_all(simulatedFilter: bool | None) -> List[ReducedScenario]: query = select(db.Scenario) + # Check simulated filter and add where if filter is present + if simulatedFilter is None: + pass + elif simulatedFilter: + query = query.where(db.Scenario.timestampSimulated is not None) + else: + query = query.where(db.Scenario.timestampSimulated is None) + with next(get_session()) as session: scenarios: List[db.Scenario] = session.exec(query).all() return [ReducedScenario( diff --git a/api/src/api/app/models/reduced_scenario.py b/api/src/api/app/models/reduced_scenario.py index 71a0265..c66ddf1 100644 --- a/api/src/api/app/models/reduced_scenario.py +++ b/api/src/api/app/models/reduced_scenario.py @@ -40,7 +40,8 @@ class ReducedScenario(BaseModel): percentiles: Optional[List[StrictInt]] = Field(default=None, alias="percentiles", description="List of available percentiles for this scenario") timestamp_submitted: Optional[datetime] = Field(default=None, alias="timestampSubmitted", description="Timestamp when the scenario was added/created") timestamp_simulated: Optional[datetime] = Field(default=None, alias="timestampSimulated", description="Timestamp when the scenario was finished simulating and data is available") - __properties: ClassVar[List[str]] = ["id", "name", "description", "startDate", "endDate", "timestamp_submitted", "timestamp_simulated"] + whitelist: Optional[List[StrictStr]] = Field(default=None, alias="whitelist", description="Whitelist of Organizations with access to this scenario") + __properties: ClassVar[List[str]] = ["id", "name", "description", "startDate", "endDate", "percentiles", "timestamp_submitted", "timestamp_simulated", "whitelist"] model_config = { "populate_by_name": True, @@ -101,6 +102,7 @@ def from_dict(cls, obj: Dict) -> Self: "percentiles": obj.get("percentiles"), "timestamp_submitted": obj.get("timestamp_submitted"), "timestamp_simulated": obj.get("timestamp_simulated"), + "whitelist": obj.get("whitelist"), }) return _obj diff --git a/api/src/api/app/models/scenario.py b/api/src/api/app/models/scenario.py index edf5e0c..c399a00 100644 --- a/api/src/api/app/models/scenario.py +++ b/api/src/api/app/models/scenario.py @@ -49,7 +49,8 @@ class Scenario(BaseModel): timestamp_simulated: Optional[datetime] = Field(default=None, alias="timestampSimulated", description="Timestamp when the scenario was finished simulating and data is available") creator_user_id: Optional[str] = Field(default=None, alias="creatorUserId", description="ID of the user who submitted the scenario") creator_org_id: Optional[str] = Field(default=None, alias="creatorOrgId", description="ID of the organization the submitting user belongs to") - __properties: ClassVar[List[str]] = ["id", "name", "description", "startDate", "endDate", "modelId", "modelParameters", "nodeListId", "linkedInterventions", "percentiles", "timestampSubmitted", "timestampSimulated", "creatorUserId", "creatorOrgId"] + whitelist: Optional[List[StrictStr]] = Field(default=None, alias="whitelist", description="Whitelist of Organizations with access to this scenario") + __properties: ClassVar[List[str]] = ["id", "name", "description", "startDate", "endDate", "modelId", "modelParameters", "nodeListId", "linkedInterventions", "percentiles", "timestampSubmitted", "timestampSimulated", "creatorUserId", "creatorOrgId", "whitelist"] model_config = { "populate_by_name": True, "validate_assignment": True, @@ -130,6 +131,7 @@ def from_dict(cls, obj: Dict) -> Self: "linkedInterventions": [InterventionImplementation.from_dict(_item) for _item in obj.get("linkedInterventions")] if obj.get("linkedInterventions") is not None else None, "percentiles": obj.get("percentiles"), "timestampSubmitted": obj.get("timestampSubmitted"), - "timestampSimulated": obj.get("timestampSimulated") + "timestampSimulated": obj.get("timestampSimulated"), + "whitelist": obj.get("whitelist"), }) return _obj