diff --git a/changelog.d/budget-window-batch.fixed.md b/changelog.d/budget-window-batch.fixed.md new file mode 100644 index 000000000..9bd03006f --- /dev/null +++ b/changelog.d/budget-window-batch.fixed.md @@ -0,0 +1 @@ +Added a budget-window economy endpoint that batches yearly impact calculations with bounded server-side concurrency and returns aggregated progress plus totals. diff --git a/changelog.d/fix-silent-exception-swallowing.fixed.md b/changelog.d/fix-silent-exception-swallowing.fixed.md new file mode 100644 index 000000000..4b10062e5 --- /dev/null +++ b/changelog.d/fix-silent-exception-swallowing.fixed.md @@ -0,0 +1 @@ +Log exceptions instead of silently swallowing them during household calculations. diff --git a/policyengine_api/api.py b/policyengine_api/api.py index 112cce9ac..eb3eba9ee 100644 --- a/policyengine_api/api.py +++ b/policyengine_api/api.py @@ -4,6 +4,7 @@ import time import sys +import os start_time = time.time() @@ -157,8 +158,11 @@ def log_timing(message): app.register_blueprint(user_profile_bp) log_timing("User profile routes registered") -app.route("/simulations", methods=["GET"])(get_simulations) -log_timing("Simulations endpoint registered") +if os.environ.get("FLASK_DEBUG") == "1": + app.route("/simulations", methods=["GET"])(get_simulations) + log_timing("Simulations endpoint registered") +else: + log_timing("Simulations endpoint skipped outside debug mode") app.register_blueprint(tracer_analysis_bp) log_timing("Tracer analysis routes registered") diff --git a/policyengine_api/country.py b/policyengine_api/country.py index 4278637d8..430df888c 100644 --- a/policyengine_api/country.py +++ b/policyengine_api/country.py @@ -1,5 +1,6 @@ import importlib import inspect +import logging import json from policyengine_core.taxbenefitsystems import TaxBenefitSystem from typing import Union, Optional @@ -429,11 +430,9 @@ def calculate( entity_result ) except Exception as e: - if "axes" in household: - pass - else: + logging.exception(f"Error computing {variable_name} for {entity_id}") + if "axes" not in household: household[entity_plural][entity_id][variable_name][period] = None - print(f"Error computing {variable_name} for {entity_id}: {e}") tracer_output = simulation.tracer.computation_log log_lines = tracer_output.lines(aggregate=False, max_depth=10) diff --git a/policyengine_api/data/__init__.py b/policyengine_api/data/__init__.py index 15673afdb..94703ee36 100644 --- a/policyengine_api/data/__init__.py +++ b/policyengine_api/data/__init__.py @@ -1 +1,6 @@ -from .data import PolicyEngineDatabase, database, local_database +from .data import ( + PolicyEngineDatabase, + database, + get_remote_database, + local_database, +) diff --git a/policyengine_api/data/data.py b/policyengine_api/data/data.py index 6b16e713e..78cdb5459 100644 --- a/policyengine_api/data/data.py +++ b/policyengine_api/data/data.py @@ -19,6 +19,7 @@ class _ResultProxy: Provides fetchone()/fetchall() with dict-like row access.""" def __init__(self, cursor_result): + self.rowcount = getattr(cursor_result, "rowcount", -1) try: # Use .mappings() so rows behave like dicts self._rows = list(cursor_result.mappings()) @@ -105,16 +106,20 @@ def _create_pool(self): with open(".dbpw") as f: db_pass = f.read().strip() db_name = "policyengine" - conn = self.connector.connect( - instance_connection_string=instance_connection_name, - driver="pymysql", - db=db_name, - user=db_user, - password=db_pass, - ) + + def get_connection(): + return self.connector.connect( + instance_connection_string=instance_connection_name, + driver="pymysql", + db=db_name, + user=db_user, + password=db_pass, + ) + self.pool = sqlalchemy.create_engine( "mysql+pymysql://", - creator=lambda: conn, + creator=get_connection, + pool_pre_ping=True, ) def _close_pool(self): @@ -259,3 +264,11 @@ def initialize(self): database = PolicyEngineDatabase(local=False, initialize=False) local_database = PolicyEngineDatabase(local=True, initialize=False) +remote_database = None + + +def get_remote_database() -> PolicyEngineDatabase: + global remote_database + if remote_database is None: + remote_database = PolicyEngineDatabase(local=False, initialize=False) + return remote_database diff --git a/policyengine_api/endpoints/simulation.py b/policyengine_api/endpoints/simulation.py index a0d9bd70d..d300ae80b 100644 --- a/policyengine_api/endpoints/simulation.py +++ b/policyengine_api/endpoints/simulation.py @@ -1,4 +1,4 @@ -from policyengine_api.data import local_database +from policyengine_api.data import get_remote_database """ @@ -42,10 +42,14 @@ def get_simulations( max_results = _DEFAULT_SIMULATION_RESULTS max_results = max(1, min(max_results, _MAX_SIMULATION_RESULTS)) - result = local_database.query( - "SELECT * FROM reform_impact ORDER BY start_time DESC LIMIT ?", - (max_results,), - ).fetchall() + result = ( + get_remote_database() + .query( + "SELECT * FROM reform_impact ORDER BY start_time DESC LIMIT ?", + (max_results,), + ) + .fetchall() + ) # Format into [{}] diff --git a/policyengine_api/libs/simulation_api_modal.py b/policyengine_api/libs/simulation_api_modal.py index 3d7660791..c567de06f 100644 --- a/policyengine_api/libs/simulation_api_modal.py +++ b/policyengine_api/libs/simulation_api_modal.py @@ -7,7 +7,7 @@ import os import sys -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Optional import httpx @@ -42,6 +42,28 @@ def name(self) -> str: return self.job_id +@dataclass +class ModalBudgetWindowBatchExecution: + """ + Represents a budget-window batch execution in the Modal simulation API. + """ + + batch_job_id: str + status: str + progress: Optional[int] = None + completed_years: list[str] = field(default_factory=list) + running_years: list[str] = field(default_factory=list) + queued_years: list[str] = field(default_factory=list) + failed_years: list[str] = field(default_factory=list) + result: Optional[dict] = None + error: Optional[str] = None + + @property + def name(self) -> str: + """Alias for batch_job_id.""" + return self.batch_job_id + + class SimulationAPIModal: """ HTTP client for the Modal Simulation API. @@ -144,10 +166,51 @@ def run(self, payload: dict) -> ModalSimulationExecution: ) raise + def run_budget_window_batch(self, payload: dict) -> ModalBudgetWindowBatchExecution: + """ + Submit a budget-window batch job to the Modal API. + """ + try: + modal_payload = dict(payload) + if "model_version" in modal_payload: + modal_payload["version"] = modal_payload.pop("model_version") + modal_payload.pop("data_version", None) + + response = self.client.post( + f"{self.base_url}/simulate/economy/budget-window", + json=modal_payload, + ) + response.raise_for_status() + data = response.json() + + logger.log_struct( + { + "message": "Modal budget-window batch submitted", + "batch_job_id": data.get("batch_job_id"), + "status": data.get("status"), + }, + severity="INFO", + ) + + return ModalBudgetWindowBatchExecution( + batch_job_id=data["batch_job_id"], + status=data["status"], + ) + + except httpx.HTTPStatusError as e: + logger.log_struct( + { + "message": f"Modal batch API HTTP error: {e.response.status_code}", + "response_text": e.response.text[:500], + }, + severity="ERROR", + ) + raise + except httpx.RequestError as e: logger.log_struct( { - "message": f"Modal API request error: {str(e)}", + "message": f"Modal batch API request error: {str(e)}", "run_id": (payload.get("_telemetry") or {}).get("run_id"), }, severity="ERROR", @@ -226,10 +289,44 @@ def get_execution_by_id(self, job_id: str) -> ModalSimulationExecution: ) raise + def get_budget_window_batch_by_id( + self, batch_job_id: str + ) -> ModalBudgetWindowBatchExecution: + """ + Poll the Modal API for the current status of a budget-window batch. + """ + try: + response = self.client.get( + f"{self.base_url}/budget-window-jobs/{batch_job_id}" + ) + data = response.json() + + return ModalBudgetWindowBatchExecution( + batch_job_id=batch_job_id, + status=data["status"], + progress=data.get("progress"), + completed_years=data.get("completed_years", []), + running_years=data.get("running_years", []), + queued_years=data.get("queued_years", []), + failed_years=data.get("failed_years", []), + result=data.get("result"), + error=data.get("error"), + ) + + except httpx.HTTPStatusError as e: + logger.log_struct( + { + "message": f"Modal batch API HTTP error polling job {batch_job_id}: {e.response.status_code}", + "response_text": e.response.text[:500], + }, + severity="ERROR", + ) + raise + except httpx.RequestError as e: logger.log_struct( { - "message": f"Modal API request error polling job {job_id}: {str(e)}", + "message": f"Modal batch API request error polling job {batch_job_id}: {str(e)}", }, severity="ERROR", ) diff --git a/policyengine_api/openapi_spec.yaml b/policyengine_api/openapi_spec.yaml index a49268c8c..77daadc9e 100644 --- a/policyengine_api/openapi_spec.yaml +++ b/policyengine_api/openapi_spec.yaml @@ -660,6 +660,138 @@ paths: type: string message: type: string + /{country_id}/economy/{policy_id}/over/{baseline_policy_id}/budget-window: + get: + summary: Calculate budget-window economic impacts + operationId: get_budget_window_economic_impact + description: Calculate annual and total budget impacts for a policy over a multi-year budget window. + parameters: + - name: country_id + in: path + description: The country ID. + required: true + schema: + type: string + - name: policy_id + in: path + description: The reform policy ID. + required: true + schema: + type: string + - name: baseline_policy_id + in: path + description: The baseline policy ID. + required: true + schema: + type: string + - name: region + in: query + description: The sub-national region. + required: true + schema: + type: string + - name: start_year + in: query + description: First year in the budget window. + required: true + schema: + type: string + - name: window_size + in: query + description: Number of years to include in the budget window. + required: true + schema: + type: integer + - name: dataset + in: query + description: Dataset selection. + required: false + schema: + type: string + default: default + - name: version + in: query + description: API version number. + required: false + schema: + type: string + - name: include_district_breakdowns + in: query + description: Whether to include congressional district breakdowns for US national simulations. + required: false + schema: + type: boolean + default: false + - name: target + in: query + description: Impact target. Budget-window calculations only support general impacts. + required: false + schema: + type: string + default: general + responses: + 200: + description: Budget-window economic impact, progress, or error state. + content: + application/json: + schema: + type: object + properties: + status: + type: string + enum: + - ok + - computing + - error + message: + type: string + nullable: true + result: + type: object + nullable: true + progress: + type: integer + nullable: true + completed_years: + type: array + items: + type: string + computing_years: + type: array + items: + type: string + queued_years: + type: array + items: + type: string + error: + type: string + nullable: true + 400: + description: Invalid budget-window request. + content: + application/json: + schema: + type: object + properties: + status: + type: string + message: + type: string + result: + type: object + nullable: true + 404: + description: Invalid country ID. + content: + text/html: + schema: + type: object + properties: + status: + type: string + message: + type: string /{country_id}/analysis: post: summary: Get or trigger policy analysis diff --git a/policyengine_api/routes/economy_routes.py b/policyengine_api/routes/economy_routes.py index 1807416f2..d772697c2 100644 --- a/policyengine_api/routes/economy_routes.py +++ b/policyengine_api/routes/economy_routes.py @@ -2,6 +2,7 @@ from policyengine_api.services.economy_service import ( EconomyService, EconomicImpactResult, + BudgetWindowEconomicImpactResult, ) from policyengine_api.utils import get_current_law_policy_id from policyengine_api.utils.payload_validators import validate_country @@ -13,6 +14,25 @@ economy_service = EconomyService() +def _json_response(payload: dict, status: int = 200) -> Response: + return Response( + json.dumps(payload), + status=status, + mimetype="application/json", + ) + + +def _bad_request_response(message: str) -> Response: + return _json_response( + { + "status": "error", + "message": message, + "result": None, + }, + status=400, + ) + + @economy_bp.route( "//economy//over/", methods=["GET"], @@ -56,14 +76,93 @@ def get_economic_impact(country_id: str, policy_id: int, baseline_policy_id: int result_dict: dict[str, str | dict | None] = economic_impact_result.to_dict() - return Response( - json.dumps( - { - "status": result_dict["status"], - "message": None, - "result": result_dict["data"], - } - ), - status=200, - mimetype="application/json", + return _json_response( + { + "status": result_dict["status"], + "message": None, + "result": result_dict["data"], + } + ) + + +@economy_bp.route( + "//economy//over//budget-window", + methods=["GET"], +) +@validate_country +def get_budget_window_economic_impact( + country_id: str, policy_id: int, baseline_policy_id: int +): + policy_id = int(policy_id or get_current_law_policy_id(country_id)) + baseline_policy_id = int( + baseline_policy_id or get_current_law_policy_id(country_id) + ) + + query_parameters = request.args + options = dict(query_parameters) + options = json.loads(json.dumps(options)) + region = options.pop("region", None) + if not region: + return _bad_request_response("Missing required query parameter: region") + + dataset = options.pop("dataset", "default") + start_year = options.pop("start_year", None) + if not start_year: + return _bad_request_response("Missing required query parameter: start_year") + + window_size_raw = options.pop("window_size", None) + if window_size_raw is None: + return _bad_request_response("Missing required query parameter: window_size") + + try: + window_size = int(window_size_raw) + except (TypeError, ValueError): + return _bad_request_response("window_size must be an integer") + + include_district_breakdowns_raw = options.pop( + "include_district_breakdowns", "false" + ) + include_district_breakdowns = include_district_breakdowns_raw.lower() == "true" + if include_district_breakdowns and country_id == "us" and region == "us": + dataset = "national-with-breakdowns" + + target: Literal["general", "cliff"] = options.pop("target", "general") + if target != "general": + return _bad_request_response( + "Budget-window calculations only support target=general" + ) + + api_version = options.pop("version", COUNTRY_PACKAGE_VERSIONS.get(country_id)) + + try: + economic_impact_result: BudgetWindowEconomicImpactResult = ( + economy_service.get_budget_window_economic_impact( + country_id=country_id, + policy_id=policy_id, + baseline_policy_id=baseline_policy_id, + region=region, + dataset=dataset, + start_year=start_year, + window_size=window_size, + options=options, + api_version=api_version, + target=target, + ) + ) + except ValueError as error: + return _bad_request_response(str(error)) + + result_dict = economic_impact_result.to_dict() + + return _json_response( + { + "status": result_dict["status"], + "message": result_dict["message"], + "result": result_dict["data"], + "progress": result_dict["progress"], + "completed_years": result_dict["completed_years"], + "computing_years": result_dict["computing_years"], + "queued_years": result_dict["queued_years"], + "error": result_dict["error"], + } ) diff --git a/policyengine_api/services/economy_service.py b/policyengine_api/services/economy_service.py index 871c896cc..1ec5f969b 100644 --- a/policyengine_api/services/economy_service.py +++ b/policyengine_api/services/economy_service.py @@ -26,9 +26,9 @@ import datetime import hashlib import uuid -from typing import Literal, Any, Optional, Annotated +from typing import Literal, Any, Optional, Annotated, Union from dotenv import load_dotenv -from pydantic import BaseModel +from pydantic import BaseModel, Field import numpy as np from enum import Enum @@ -57,6 +57,7 @@ class ImpactAction(Enum): COMPLETED = "completed" COMPUTING = "computing" CREATE = "create" + ERROR = "error" class ImpactStatus(Enum): @@ -71,6 +72,14 @@ class ImpactStatus(Enum): COMPLETE_STATUSES = [ImpactStatus.OK.value, ImpactStatus.ERROR.value] COMPUTING_STATUS = ImpactStatus.COMPUTING.value +BUDGET_WINDOW_MAX_ACTIVE_YEARS = 20 +BUDGET_WINDOW_MAX_YEARS = 75 +BUDGET_WINDOW_MAX_END_YEAR = 2099 +PENDING_EXECUTION_ID_PREFIX = "pending:" +PROVISIONAL_CLAIM_TTL_SECONDS = 90 +STALE_PROVISIONAL_IMPACT_MESSAGE = ( + "Simulation claim expired before job submission completed" +) class EconomicImpactSetupOptions(BaseModel): @@ -99,6 +108,7 @@ class EconomicImpactResult(BaseModel): status: ImpactStatus data: Optional[dict] = None + message: Optional[str] = None model_config = {"frozen": True} # Make model immutable @@ -131,7 +141,80 @@ def error(cls, message: str) -> "EconomicImpactResult": Create an EconomicImpactResult for an error in the impact calculation. """ logger.log_struct({"message": message}, severity="ERROR") - return cls(status=ImpactStatus.ERROR, data=None) + return cls(status=ImpactStatus.ERROR, data=None, message=message) + + +class BudgetWindowEconomicImpactResult(BaseModel): + """ + Model for a batch budget-window economic impact response. + """ + + status: ImpactStatus + data: Optional[dict] = None + progress: Optional[int] = None + completed_years: list[str] = Field(default_factory=list) + computing_years: list[str] = Field(default_factory=list) + queued_years: list[str] = Field(default_factory=list) + message: Optional[str] = None + error: Optional[str] = None + + model_config = {"frozen": True} + + def to_dict(self) -> dict[str, Any]: + return { + "status": self.status.value, + "data": self.data, + "progress": self.progress, + "completed_years": self.completed_years, + "computing_years": self.computing_years, + "queued_years": self.queued_years, + "message": self.message, + "error": self.error, + } + + @classmethod + def completed(cls, data: dict) -> "BudgetWindowEconomicImpactResult": + return cls(status=ImpactStatus.OK, data=data, progress=100) + + @classmethod + def computing( + cls, + *, + progress: int, + completed_years: list[str], + computing_years: list[str], + queued_years: list[str], + message: str, + ) -> "BudgetWindowEconomicImpactResult": + return cls( + status=ImpactStatus.COMPUTING, + data=None, + progress=progress, + completed_years=completed_years, + computing_years=computing_years, + queued_years=queued_years, + message=message, + ) + + @classmethod + def failed( + cls, + message: str, + *, + completed_years: Optional[list[str]] = None, + computing_years: Optional[list[str]] = None, + queued_years: Optional[list[str]] = None, + ) -> "BudgetWindowEconomicImpactResult": + logger.log_struct({"message": message}, severity="ERROR") + return cls( + status=ImpactStatus.ERROR, + data=None, + completed_years=completed_years or [], + computing_years=computing_years or [], + queued_years=queued_years or [], + message=message, + error=message, + ) class EconomyService: @@ -168,133 +251,678 @@ def get_economic_impact( if country_id == "us": region = normalize_us_region(region) - # Set up logging - process_id: str = self._create_process_id() - - country_package_version = COUNTRY_PACKAGE_VERSIONS.get(country_id) - cache_version = get_economy_impact_cache_version(country_id, api_version) - resolved_dataset = self._setup_data( + economic_impact_setup_options = self._build_economic_impact_setup_options( country_id=country_id, + policy_id=policy_id, + baseline_policy_id=baseline_policy_id, region=region, dataset=dataset, + time_period=time_period, + options=options, + api_version=api_version, + target=target, + ) + + return self._get_or_create_economic_impact( + setup_options=economic_impact_setup_options, + ) + + except Exception as e: + print(f"Error getting economic impact: {str(e)}") + raise e + + def get_budget_window_economic_impact( + self, + country_id: str, + policy_id: int, + baseline_policy_id: int, + region: str, + dataset: str, + start_year: str, + window_size: int, + options: dict, + api_version: str, + target: Literal["general", "cliff"] = "general", + max_active_years: int = BUDGET_WINDOW_MAX_ACTIVE_YEARS, + ) -> BudgetWindowEconomicImpactResult: + try: + if country_id == "us": + region = normalize_us_region(region) + + if target != "general": + raise ValueError( + "Budget-window calculations only support target='general'" + ) + + start_year_int = int(start_year) + if not 1 <= window_size <= BUDGET_WINDOW_MAX_YEARS: + raise ValueError( + f"window_size must be between 1 and {BUDGET_WINDOW_MAX_YEARS}" + ) + end_year = start_year_int + window_size - 1 + if end_year > BUDGET_WINDOW_MAX_END_YEAR: + raise ValueError( + f"budget-window end_year must be {BUDGET_WINDOW_MAX_END_YEAR} or earlier" + ) + + start_year = str(start_year_int) + years = self._build_budget_window_years( + start_year=start_year, + window_size=window_size, ) - resolved_model_version = country_package_version - resolved_data_version = self._extract_dataset_version(resolved_dataset) - options_hash = self._build_options_hash( + tracking_setup_options = self._build_budget_window_tracking_setup_options( + country_id=country_id, + policy_id=policy_id, + baseline_policy_id=baseline_policy_id, + region=region, + dataset=dataset, + start_year=start_year, + window_size=window_size, options=options, - model_version=resolved_model_version, - dataset=resolved_dataset, + api_version=api_version, + target=target, ) - economic_impact_setup_options = EconomicImpactSetupOptions.model_validate( - { - "process_id": process_id, - "country_id": country_id, - "reform_policy_id": policy_id, - "baseline_policy_id": baseline_policy_id, - "region": region, - "dataset": resolved_dataset, - "time_period": time_period, - "options": options, - "api_version": cache_version, - "target": target, - "model_version": resolved_model_version, - "policyengine_version": None, - "data_version": resolved_data_version, - "runtime_app_name": None, - "options_hash": options_hash, - } + most_recent_impact = self._get_budget_window_tracking_impact( + tracking_setup_options + ) + if most_recent_impact is None: + self._start_budget_window_batch( + setup_options=tracking_setup_options, + start_year=start_year, + window_size=window_size, + max_parallel=max_active_years, + ) + return self._build_budget_window_computing_result( + total_years=len(years), + completed_years=[], + computing_years=[], + queued_years=years, + progress=0, + ) + + return self._get_budget_window_result_from_tracking_impact( + setup_options=tracking_setup_options, + most_recent_impact=most_recent_impact, + total_years=len(years), + queued_years_on_submit=years, ) + except Exception as e: + print(f"Error getting budget-window economic impact: {str(e)}") + raise e + + def _build_budget_window_years( + self, + *, + start_year: str, + window_size: int, + ) -> list[str]: + start_year_int = int(start_year) + return [str(start_year_int + index) for index in range(window_size)] + + def _build_budget_window_tracking_time_period( + self, + *, + start_year: str, + window_size: int, + ) -> str: + return f"budget_window:{start_year}:{window_size}" + + def _build_budget_window_tracking_setup_options( + self, + *, + country_id: str, + policy_id: int, + baseline_policy_id: int, + region: str, + dataset: str, + start_year: str, + window_size: int, + options: dict, + api_version: str, + target: Literal["general", "cliff"], + ) -> EconomicImpactSetupOptions: + return self._build_economic_impact_setup_options( + country_id=country_id, + policy_id=policy_id, + baseline_policy_id=baseline_policy_id, + region=region, + dataset=dataset, + time_period=self._build_budget_window_tracking_time_period( + start_year=start_year, + window_size=window_size, + ), + options=dict(options), + api_version=api_version, + target=target, + ) - # Logging that we've received a request + def _build_budget_window_batch_payload( + self, + *, + setup_options: EconomicImpactSetupOptions, + start_year: str, + window_size: int, + max_parallel: int, + ) -> dict[str, Any]: + baseline_policy = policy_service.get_policy_json( + setup_options.country_id, + setup_options.baseline_policy_id, + ) + reform_policy = policy_service.get_policy_json( + setup_options.country_id, + setup_options.reform_policy_id, + ) + sim_config: SimulationOptions = self._setup_sim_options( + country_id=setup_options.country_id, + reform_policy=reform_policy, + baseline_policy=baseline_policy, + region=setup_options.region, + time_period=start_year, + dataset=setup_options.dataset, + scope="macro", + include_cliffs=False, + model_version=setup_options.model_version, + data_version=setup_options.data_version, + ) + sim_params = sim_config.model_dump() + sim_params.pop("time_period", None) + sim_params["start_year"] = start_year + sim_params["window_size"] = window_size + sim_params["max_parallel"] = max_parallel + sim_params["target"] = setup_options.target + return sim_params + + def _get_budget_window_tracking_impact( + self, + setup_options: EconomicImpactSetupOptions, + ) -> dict | None: + return self._get_exact_reform_impact(setup_options) + + def _start_budget_window_batch( + self, + *, + setup_options: EconomicImpactSetupOptions, + start_year: str, + window_size: int, + max_parallel: int, + ) -> None: + sim_params = self._build_budget_window_batch_payload( + setup_options=setup_options, + start_year=start_year, + window_size=window_size, + max_parallel=max_parallel, + ) + + logger.log_struct( + { + "message": "Submitting budget-window batch job", + **setup_options.model_dump(), + "start_year": start_year, + "window_size": window_size, + "max_parallel": max_parallel, + }, + severity="INFO", + ) + + batch_execution = simulation_api.run_budget_window_batch(sim_params) + self._set_reform_impact_computing( + setup_options=setup_options, + execution_id=batch_execution.batch_job_id, + ) + + def _get_budget_window_result_from_tracking_impact( + self, + *, + setup_options: EconomicImpactSetupOptions, + most_recent_impact: dict, + total_years: int, + queued_years_on_submit: list[str], + ) -> BudgetWindowEconomicImpactResult: + impact_status = most_recent_impact.get("status") + if impact_status == ImpactStatus.OK.value: + return BudgetWindowEconomicImpactResult.completed( + json.loads(most_recent_impact["reform_impact_json"]) + ) + + execution_id = most_recent_impact.get("execution_id") + if not execution_id: + return BudgetWindowEconomicImpactResult.failed( + most_recent_impact.get("message") + or "Budget-window batch tracking row is missing execution_id", + queued_years=queued_years_on_submit, + ) + + try: + batch_execution = simulation_api.get_budget_window_batch_by_id(execution_id) + except Exception: + if impact_status == ImpactStatus.ERROR.value: + return BudgetWindowEconomicImpactResult.failed( + most_recent_impact.get("message") or "Budget-window batch failed", + queued_years=queued_years_on_submit, + ) + raise + + if batch_execution.status in EXECUTION_STATUSES_SUCCESS: + result = batch_execution.result or {} + self._set_reform_impact_complete( + setup_options=setup_options, + reform_impact_json=json.dumps(result), + execution_id=execution_id, + ) + return BudgetWindowEconomicImpactResult.completed(result) + + if batch_execution.status in EXECUTION_STATUSES_FAILURE: + error_message = batch_execution.error or ( + most_recent_impact.get("message") or "Budget-window batch failed" + ) + self._set_reform_impact_error( + setup_options=setup_options, + message=error_message, + execution_id=execution_id, + ) + return BudgetWindowEconomicImpactResult.failed( + error_message, + completed_years=batch_execution.completed_years, + computing_years=batch_execution.running_years, + queued_years=batch_execution.queued_years, + ) + + if batch_execution.status in EXECUTION_STATUSES_PENDING: + return self._build_budget_window_computing_result( + total_years=total_years, + completed_years=batch_execution.completed_years, + computing_years=batch_execution.running_years, + queued_years=batch_execution.queued_years, + progress=batch_execution.progress, + ) + + raise ValueError( + f"Unexpected budget-window batch execution state: {batch_execution.status}" + ) + + def _build_budget_window_computing_result( + self, + *, + total_years: int, + completed_years: list[str], + computing_years: list[str], + queued_years: list[str], + progress: Optional[int] = None, + ) -> BudgetWindowEconomicImpactResult: + resolved_progress = progress + if resolved_progress is None: + resolved_progress = round((len(completed_years) / total_years) * 100) + + return BudgetWindowEconomicImpactResult.computing( + progress=resolved_progress, + completed_years=completed_years, + computing_years=computing_years, + queued_years=queued_years, + message=self._build_budget_window_progress_message( + completed_years=completed_years, + total_years=total_years, + computing_years=computing_years, + queued_years=queued_years, + ), + ) + + def _build_economic_impact_setup_options( + self, + *, + country_id: str, + policy_id: int, + baseline_policy_id: int, + region: str, + dataset: str, + time_period: str, + options: dict, + api_version: str, + target: Literal["general", "cliff"] = "general", + ) -> EconomicImpactSetupOptions: + process_id: str = self._create_process_id() + cache_version = get_economy_impact_cache_version(country_id, api_version) + country_package_version = COUNTRY_PACKAGE_VERSIONS.get(country_id) + resolved_dataset = self._setup_data( + country_id=country_id, + region=region, + dataset=dataset, + ) + resolved_data_version = self._extract_dataset_version(resolved_dataset) + options_hash = self._build_options_hash( + options=options, + model_version=country_package_version, + dataset=resolved_dataset, + ) + + return EconomicImpactSetupOptions.model_validate( + { + "process_id": process_id, + "country_id": country_id, + "reform_policy_id": policy_id, + "baseline_policy_id": baseline_policy_id, + "region": region, + "dataset": resolved_dataset, + "time_period": time_period, + "options": options, + "api_version": cache_version, + "target": target, + "model_version": country_package_version, + "policyengine_version": None, + "data_version": resolved_data_version, + "runtime_app_name": None, + "options_hash": options_hash, + } + ) + + def _get_or_create_economic_impact( + self, setup_options: EconomicImpactSetupOptions + ) -> EconomicImpactResult: + logger.log_struct( + { + "message": "Received request for economic impact; checking if already in reform_impacts table", + **setup_options.model_dump(), + }, + severity="INFO", + ) + + most_recent_impact: dict | None = self._get_most_recent_impact( + setup_options=setup_options + ) + + if most_recent_impact and self._should_refresh_cached_impact( + setup_options=setup_options, + most_recent_impact=most_recent_impact, + ): + most_recent_impact = self._get_most_recent_impact(setup_options) + if ( + not most_recent_impact + or most_recent_impact.get("options_hash") != setup_options.options_hash + ): + most_recent_impact = None + + impact_action: ImpactAction = self._determine_impact_action( + most_recent_impact=most_recent_impact + ) + + if impact_action == ImpactAction.COMPLETED: logger.log_struct( { - "message": "Received request for economic impact; checking if already in reform_impacts table", - **economic_impact_setup_options.model_dump(), + "message": "Found completed economic impact in db; returning result", + **setup_options.model_dump(), }, severity="INFO", ) - - most_recent_impact: dict | None = self._get_most_recent_impact( - setup_options=economic_impact_setup_options, + return self._handle_completed_impact( + setup_options=setup_options, + most_recent_impact=most_recent_impact, ) - if most_recent_impact and self._should_refresh_cached_impact( - setup_options=economic_impact_setup_options, + if impact_action == ImpactAction.COMPUTING: + logger.log_struct( + { + "message": "Found computing economic impact record in db; confirming this is still computing", + **setup_options.model_dump(), + }, + severity="INFO", + ) + return self._handle_computing_impact( + setup_options=setup_options, most_recent_impact=most_recent_impact, - ): - most_recent_impact = self._get_most_recent_impact( - economic_impact_setup_options - ) - if ( - not most_recent_impact - or most_recent_impact.get("options_hash") - != economic_impact_setup_options.options_hash - ): - most_recent_impact = None + ) - impact_action: ImpactAction = self._determine_impact_action( + if impact_action == ImpactAction.ERROR: + logger.log_struct( + { + "message": "Found failed economic impact in db; returning error", + **setup_options.model_dump(), + }, + severity="INFO", + ) + return self._handle_error_impact( + setup_options=setup_options, most_recent_impact=most_recent_impact, ) - if impact_action == ImpactAction.COMPLETED: - logger.log_struct( - { - "message": "Found completed economic impact in db; returning result", - **economic_impact_setup_options.model_dump(), - }, - severity="INFO", - ) - return self._handle_completed_impact( - setup_options=economic_impact_setup_options, - most_recent_impact=most_recent_impact, - ) - - if impact_action == ImpactAction.COMPUTING: - logger.log_struct( - { - "message": "Found computing economic impact record in db; confirming this is still computing", - **economic_impact_setup_options.model_dump(), - }, - severity="INFO", - ) - return self._handle_computing_impact( - setup_options=economic_impact_setup_options, - most_recent_impact=most_recent_impact, - ) + if impact_action == ImpactAction.CREATE: + self._resolve_runtime_bundle_for_setup_options(setup_options) + try: + with reform_impacts_service.claim_lock( + country_id=setup_options.country_id, + policy_id=setup_options.reform_policy_id, + baseline_policy_id=setup_options.baseline_policy_id, + region=setup_options.region, + dataset=setup_options.dataset, + time_period=setup_options.time_period, + options_hash=setup_options.options_hash, + api_version=setup_options.api_version, + ): + most_recent_impact = self._get_exact_reform_impact( + setup_options=setup_options + ) + impact_action = self._determine_impact_action( + most_recent_impact=most_recent_impact + ) - if impact_action == ImpactAction.CREATE: - if economic_impact_setup_options.runtime_app_name is None: - ( - economic_impact_setup_options.runtime_app_name, - economic_impact_setup_options.model_version, - ) = simulation_api.resolve_app_name( - country_id, - economic_impact_setup_options.model_version, + if impact_action == ImpactAction.COMPLETED: + logger.log_struct( + { + "message": "Found completed economic impact in db after locking; returning result", + **setup_options.model_dump(), + }, + severity="INFO", + ) + return self._handle_completed_impact( + setup_options=setup_options, + most_recent_impact=most_recent_impact, + ) + + if impact_action == ImpactAction.COMPUTING: + logger.log_struct( + { + "message": "Found computing economic impact in db after locking; returning progress", + **setup_options.model_dump(), + }, + severity="INFO", + ) + return self._handle_computing_impact( + setup_options=setup_options, + most_recent_impact=most_recent_impact, + ) + + if impact_action == ImpactAction.ERROR: + logger.log_struct( + { + "message": "Found failed economic impact in db after locking; returning error", + **setup_options.model_dump(), + }, + severity="INFO", + ) + return self._handle_error_impact( + setup_options=setup_options, + most_recent_impact=most_recent_impact, + ) + + stale_provisional_execution_id = None + if self._is_stale_provisional_impact(most_recent_impact): + stale_provisional_execution_id = most_recent_impact.get( + "execution_id" + ) + + provisional_execution_id = self._build_provisional_execution_id( + setup_options.process_id ) - economic_impact_setup_options.options_hash = self._build_options_hash( - options=options, - model_version=economic_impact_setup_options.model_version, - dataset=resolved_dataset, - data_version=resolved_data_version, - runtime_app_name=economic_impact_setup_options.runtime_app_name, + self._set_reform_impact_computing( + setup_options=setup_options, + execution_id=provisional_execution_id, ) + if stale_provisional_execution_id: + self._expire_stale_provisional_impact( + setup_options=setup_options, + execution_id=stale_provisional_execution_id, + ) + except TimeoutError: logger.log_struct( { - "message": "No previous economic impact record found in db; creating new simulation run", - **economic_impact_setup_options.model_dump(), + "message": "Timed out waiting for economic impact claim lock; re-checking existing claim", + **setup_options.model_dump(), }, - severity="INFO", + severity="WARNING", ) - return self._handle_create_impact( - setup_options=economic_impact_setup_options, + existing_impact = self._get_existing_economic_impact( + setup_options=setup_options ) + if existing_impact is not None: + return existing_impact + return EconomicImpactResult.computing() - raise ValueError(f"Unexpected impact action: {impact_action}") + logger.log_struct( + { + "message": "No previous economic impact record found in db; creating new simulation run", + **setup_options.model_dump(), + }, + severity="INFO", + ) + return self._handle_create_impact( + setup_options=setup_options, + provisional_execution_id=provisional_execution_id, + ) - except Exception as e: - print(f"Error getting economic impact: {str(e)}") - raise e + raise ValueError(f"Unexpected impact action: {impact_action}") + + def _resolve_runtime_bundle_for_setup_options( + self, + setup_options: EconomicImpactSetupOptions, + ) -> None: + if setup_options.runtime_app_name is None: + ( + setup_options.runtime_app_name, + setup_options.model_version, + ) = simulation_api.resolve_app_name( + setup_options.country_id, + setup_options.model_version, + ) + + setup_options.options_hash = self._build_options_hash( + options=setup_options.options, + model_version=setup_options.model_version, + dataset=setup_options.dataset, + data_version=setup_options.data_version, + policyengine_version=setup_options.policyengine_version, + runtime_app_name=setup_options.runtime_app_name, + ) + + def _get_existing_economic_impact( + self, setup_options: EconomicImpactSetupOptions + ) -> Optional[EconomicImpactResult]: + most_recent_impact = self._get_exact_reform_impact(setup_options=setup_options) + if not most_recent_impact: + return None + + status = most_recent_impact.get("status") + if status == ImpactStatus.ERROR.value: + return self._handle_error_impact( + setup_options=setup_options, + most_recent_impact=most_recent_impact, + ) + + if status == ImpactStatus.OK.value: + return self._handle_completed_impact( + setup_options=setup_options, + most_recent_impact=most_recent_impact, + ) + + if status == ImpactStatus.COMPUTING.value: + if self._is_stale_provisional_impact(most_recent_impact): + return None + return self._handle_computing_impact( + setup_options=setup_options, + most_recent_impact=most_recent_impact, + ) + + raise ValueError(f"Unknown impact status: {status}") + + def _get_economic_impact_error_message( + self, result: EconomicImpactResult, year: str + ) -> str: + if result.message: + return result.message + + if isinstance(result.data, dict): + data_message = result.data.get("message") + if isinstance(data_message, str) and data_message: + return data_message + + return f"Budget-window calculation failed for {year}" + + def _extract_budget_window_annual_impact( + self, year: str, impact_data: dict + ) -> dict[str, Union[str, int, float]]: + budget = impact_data.get("budget", {}) + state_tax_revenue_impact = budget.get("state_tax_revenue_impact", 0) + tax_revenue_impact = budget.get("tax_revenue_impact", 0) + + return { + "year": year, + "taxRevenueImpact": tax_revenue_impact, + "federalTaxRevenueImpact": tax_revenue_impact - state_tax_revenue_impact, + "stateTaxRevenueImpact": state_tax_revenue_impact, + "benefitSpendingImpact": budget.get("benefit_spending_impact", 0), + "budgetaryImpact": budget.get("budgetary_impact", 0), + } + + def _sum_budget_window_annual_impacts(self, annual_impacts: list[dict]) -> dict: + totals = { + "year": "Total", + "taxRevenueImpact": 0, + "federalTaxRevenueImpact": 0, + "stateTaxRevenueImpact": 0, + "benefitSpendingImpact": 0, + "budgetaryImpact": 0, + } + + for annual_impact in annual_impacts: + totals["taxRevenueImpact"] += annual_impact["taxRevenueImpact"] + totals["federalTaxRevenueImpact"] += annual_impact[ + "federalTaxRevenueImpact" + ] + totals["stateTaxRevenueImpact"] += annual_impact["stateTaxRevenueImpact"] + totals["benefitSpendingImpact"] += annual_impact["benefitSpendingImpact"] + totals["budgetaryImpact"] += annual_impact["budgetaryImpact"] + + return totals + + def _build_budget_window_output( + self, *, start_year: str, window_size: int, annual_impacts: list[dict] + ) -> dict: + return { + "kind": "budgetWindow", + "startYear": start_year, + "endYear": str(int(start_year) + window_size - 1), + "windowSize": window_size, + "annualImpacts": annual_impacts, + "totals": self._sum_budget_window_annual_impacts(annual_impacts), + } + + def _build_budget_window_progress_message( + self, + *, + completed_years: list[str], + total_years: int, + computing_years: list[str], + queued_years: list[str], + ) -> str: + completed_count = len(completed_years) + if computing_years: + active_years = ", ".join(computing_years[:2]) + if len(computing_years) > 2: + active_years = f"{active_years} + {len(computing_years) - 2} more" + return f"Scoring {active_years} ({completed_count} of {total_years} complete)..." + + if queued_years: + return f"Queued {queued_years[0]} ({completed_count} of {total_years} complete)..." + + return f"Scoring budget window ({completed_count} of {total_years} complete)..." def _get_previous_impacts( self, @@ -324,8 +952,37 @@ def _get_previous_impacts( api_version, ) ) + if previous_impacts: + return previous_impacts + + return reform_impacts_service.get_all_reform_impacts( + country_id, + policy_id, + baseline_policy_id, + region, + dataset, + time_period, + options_hash, + api_version, + ) - return previous_impacts + def _get_exact_reform_impact( + self, + setup_options: EconomicImpactSetupOptions, + ) -> dict | None: + previous_impacts = reform_impacts_service.get_all_reform_impacts( + setup_options.country_id, + setup_options.reform_policy_id, + setup_options.baseline_policy_id, + setup_options.region, + setup_options.dataset, + setup_options.time_period, + setup_options.options_hash, + setup_options.api_version, + ) + if not previous_impacts: + return None + return previous_impacts[0] def _get_most_recent_impact( self, @@ -355,6 +1012,62 @@ def _get_most_recent_impact( return previous_impacts[0] + def _build_provisional_execution_id(self, process_id: str) -> str: + return f"{PENDING_EXECUTION_ID_PREFIX}{process_id}" + + def _is_provisional_execution_id(self, execution_id: Any) -> bool: + return isinstance(execution_id, str) and execution_id.startswith( + PENDING_EXECUTION_ID_PREFIX + ) + + def _coerce_impact_start_time(self, start_time: Any) -> Optional[datetime.datetime]: + if start_time is None: + return None + + if isinstance(start_time, str): + parsed_start_time = datetime.datetime.fromisoformat(start_time) + elif hasattr(start_time, "tzinfo") and hasattr(start_time, "isoformat"): + parsed_start_time = start_time + else: + return None + + if parsed_start_time.tzinfo is None: + return parsed_start_time.replace(tzinfo=datetime.timezone.utc) + + return parsed_start_time.astimezone(datetime.timezone.utc) + + def _is_stale_provisional_impact(self, impact: dict | None) -> bool: + if not impact: + return False + + if not self._is_provisional_execution_id(impact.get("execution_id")): + return False + + start_time = self._coerce_impact_start_time(impact.get("start_time")) + if start_time is None: + return False + + current_time = datetime.datetime.now(datetime.timezone.utc) + if current_time.tzinfo is None: + current_time = current_time.replace(tzinfo=datetime.timezone.utc) + + claim_age = current_time - start_time + return claim_age.total_seconds() > PROVISIONAL_CLAIM_TTL_SECONDS + + def _expire_stale_provisional_impact( + self, + setup_options: EconomicImpactSetupOptions, + execution_id: str, + ) -> None: + if not self._is_provisional_execution_id(execution_id): + return + + self._set_reform_impact_error( + setup_options=setup_options, + message=STALE_PROVISIONAL_IMPACT_MESSAGE, + execution_id=execution_id, + ) + def _determine_impact_action( self, most_recent_impact: dict | None, @@ -363,9 +1076,13 @@ def _determine_impact_action( return ImpactAction.CREATE status = most_recent_impact.get("status") - if status in [ImpactStatus.OK.value, ImpactStatus.ERROR.value]: + if status == ImpactStatus.OK.value: return ImpactAction.COMPLETED + elif status == ImpactStatus.ERROR.value: + return ImpactAction.ERROR elif status == ImpactStatus.COMPUTING.value: + if self._is_stale_provisional_impact(most_recent_impact): + return ImpactAction.CREATE return ImpactAction.COMPUTING else: raise ValueError(f"Unknown impact status: {status}") @@ -444,14 +1161,30 @@ def _handle_completed_impact( ) ) - def _handle_computing_impact( + def _handle_error_impact( self, setup_options: EconomicImpactSetupOptions, most_recent_impact: dict, ) -> EconomicImpactResult: - execution = simulation_api.get_execution_by_id( - most_recent_impact["execution_id"] + error_message = most_recent_impact.get("message") or ( + f"Economic impact failed for {setup_options.time_period}" + ) + return EconomicImpactResult( + status=ImpactStatus.ERROR, + data=None, + message=error_message, ) + + def _handle_computing_impact( + self, + setup_options: EconomicImpactSetupOptions, + most_recent_impact: dict, + ) -> EconomicImpactResult: + execution_id = most_recent_impact["execution_id"] + if self._is_provisional_execution_id(execution_id): + return EconomicImpactResult.computing() + + execution = simulation_api.get_execution_by_id(execution_id) execution_state = simulation_api.get_execution_status(execution) return self._handle_execution_state( execution_state=execution_state, @@ -463,65 +1196,76 @@ def _handle_computing_impact( def _handle_create_impact( self, setup_options: EconomicImpactSetupOptions, + provisional_execution_id: str, ) -> EconomicImpactResult: - baseline_policy = policy_service.get_policy_json( - setup_options.country_id, setup_options.baseline_policy_id - ) - reform_policy = policy_service.get_policy_json( - setup_options.country_id, setup_options.reform_policy_id - ) + try: + baseline_policy = policy_service.get_policy_json( + setup_options.country_id, setup_options.baseline_policy_id + ) + reform_policy = policy_service.get_policy_json( + setup_options.country_id, setup_options.reform_policy_id + ) - sim_config: SimulationOptions = self._setup_sim_options( - country_id=setup_options.country_id, - reform_policy=reform_policy, - baseline_policy=baseline_policy, - region=setup_options.region, - time_period=setup_options.time_period, - dataset=setup_options.dataset, - scope="macro", - include_cliffs=setup_options.target == "cliff", - model_version=setup_options.model_version, - data_version=setup_options.data_version, - ) + sim_config: SimulationOptions = self._setup_sim_options( + country_id=setup_options.country_id, + reform_policy=reform_policy, + baseline_policy=baseline_policy, + region=setup_options.region, + time_period=setup_options.time_period, + dataset=setup_options.dataset, + scope="macro", + include_cliffs=setup_options.target == "cliff", + model_version=setup_options.model_version, + data_version=setup_options.data_version, + ) - sim_params = sim_config.model_dump(mode="json") - telemetry = self._build_simulation_telemetry( - setup_options=setup_options, - sim_config=sim_params, - ) + sim_params = sim_config.model_dump(mode="json") + telemetry = self._build_simulation_telemetry( + setup_options=setup_options, + sim_config=sim_params, + ) - logger.log_struct( - { - "message": "Setting up sim API job", - "run_id": telemetry["run_id"], - **setup_options.model_dump(), + logger.log_struct( + { + "message": "Setting up sim API job", + "run_id": telemetry["run_id"], + **setup_options.model_dump(), + } + ) + + # Preserve both legacy metadata and the new telemetry envelope. + sim_params["_metadata"] = { + "reform_policy_id": setup_options.reform_policy_id, + "baseline_policy_id": setup_options.baseline_policy_id, + "process_id": setup_options.process_id, + "model_version": setup_options.model_version, + "policyengine_version": setup_options.policyengine_version, + "data_version": setup_options.data_version, + "dataset": setup_options.dataset, + "resolved_app_name": setup_options.runtime_app_name, } - ) + sim_params["_telemetry"] = telemetry + + # The simulation gateway (policyengine-api-v2 PR #458) requires + # ``time_period`` as a string, but the upstream ``policyengine`` + # package (``TimePeriodType = int``) coerces the value to int during + # ``model_validate`` and ``model_dump`` re-emits it as int. Cast back + # to str at the request boundary so the gateway's strict schema + # validates instead of returning 422. + if sim_params.get("time_period") is not None: + sim_params["time_period"] = str(sim_params["time_period"]) + + sim_api_execution = simulation_api.run(sim_params) + execution_id = simulation_api.get_execution_id(sim_api_execution) + except Exception as error: + error_message = f"Failed to start simulation API job: {str(error)}" + self._set_reform_impact_error( + setup_options=setup_options, + message=error_message, + execution_id=provisional_execution_id, + ) + return EconomicImpactResult.error(message=error_message) - # Preserve both legacy metadata and the new telemetry envelope. - sim_params["_metadata"] = { - "reform_policy_id": setup_options.reform_policy_id, - "baseline_policy_id": setup_options.baseline_policy_id, - "process_id": setup_options.process_id, - "model_version": setup_options.model_version, - "policyengine_version": setup_options.policyengine_version, - "data_version": setup_options.data_version, - "dataset": setup_options.dataset, - "resolved_app_name": setup_options.runtime_app_name, - } - sim_params["_telemetry"] = telemetry - - # The simulation gateway (policyengine-api-v2 PR #458) requires - # ``time_period`` as a string, but the upstream ``policyengine`` - # package (``TimePeriodType = int``) coerces the value to int during - # ``model_validate`` and ``model_dump`` re-emits it as int. Cast back - # to str at the request boundary so the gateway's strict schema - # validates instead of returning 422. - if sim_params.get("time_period") is not None: - sim_params["time_period"] = str(sim_params["time_period"]) - - sim_api_execution = simulation_api.run(sim_params) - execution_id = simulation_api.get_execution_id(sim_api_execution) run_id = getattr(sim_api_execution, "run_id", None) or telemetry["run_id"] progress_log = { @@ -532,13 +1276,117 @@ def _handle_create_impact( } logger.log_struct(progress_log, severity="INFO") - self._set_reform_impact_computing( - setup_options=setup_options, - execution_id=execution_id, - ) + try: + updated_rows = self._update_reform_impact_execution_id( + setup_options=setup_options, + current_execution_id=provisional_execution_id, + new_execution_id=execution_id, + ) + except Exception as error: + logger.log_struct( + { + "message": "Failed to promote provisional reform impact row; inserting replacement tracking row", + **setup_options.model_dump(), + "execution_id": execution_id, + "provisional_execution_id": provisional_execution_id, + "error": str(error), + }, + severity="WARNING", + ) + updated_rows = 0 + + if updated_rows != 1: + self._recover_failed_execution_id_promotion( + setup_options=setup_options, + provisional_execution_id=provisional_execution_id, + execution_id=execution_id, + updated_rows=updated_rows, + ) return EconomicImpactResult.computing() + def _recover_failed_execution_id_promotion( + self, + *, + setup_options: EconomicImpactSetupOptions, + provisional_execution_id: str, + execution_id: str, + updated_rows: int | None, + ) -> None: + logger.log_struct( + { + "message": "Provisional reform impact row was not updated; checking whether tracking has already been superseded", + **setup_options.model_dump(), + "execution_id": execution_id, + "provisional_execution_id": provisional_execution_id, + "updated_rows": updated_rows, + }, + severity="WARNING", + ) + + try: + with reform_impacts_service.claim_lock( + country_id=setup_options.country_id, + policy_id=setup_options.reform_policy_id, + baseline_policy_id=setup_options.baseline_policy_id, + region=setup_options.region, + dataset=setup_options.dataset, + time_period=setup_options.time_period, + options_hash=setup_options.options_hash, + api_version=setup_options.api_version, + ): + most_recent_impact = self._get_exact_reform_impact( + setup_options=setup_options + ) + if most_recent_impact is not None: + impact_status = most_recent_impact.get("status") + tracked_execution_id = most_recent_impact.get("execution_id") + if tracked_execution_id == execution_id: + return + + if ( + impact_status == ImpactStatus.COMPUTING.value + and tracked_execution_id == provisional_execution_id + ): + retry_updated_rows = self._update_reform_impact_execution_id( + setup_options=setup_options, + current_execution_id=provisional_execution_id, + new_execution_id=execution_id, + ) + if retry_updated_rows == 1: + return + elif impact_status in ( + ImpactStatus.OK.value, + ImpactStatus.COMPUTING.value, + ): + logger.log_struct( + { + "message": "Skipping replacement tracking row because another claim is already authoritative", + **setup_options.model_dump(), + "execution_id": execution_id, + "provisional_execution_id": provisional_execution_id, + "tracked_execution_id": tracked_execution_id, + "tracked_status": impact_status, + }, + severity="WARNING", + ) + return + + self._set_reform_impact_computing( + setup_options=setup_options, + execution_id=execution_id, + ) + except TimeoutError: + logger.log_struct( + { + "message": "Timed out while recovering failed provisional promotion; leaving the newer claim authoritative", + **setup_options.model_dump(), + "execution_id": execution_id, + "provisional_execution_id": provisional_execution_id, + }, + severity="WARNING", + ) + def _setup_sim_options( self, country_id: str, @@ -619,7 +1467,7 @@ def _should_refresh_cached_impact( setup_options: EconomicImpactSetupOptions, most_recent_impact: dict, ) -> bool: - if most_recent_impact.get("status") == ImpactStatus.COMPUTING.value: + if most_recent_impact.get("status") != ImpactStatus.OK.value: return False cached_result = self._extract_cached_result(most_recent_impact) @@ -855,6 +1703,9 @@ def _set_reform_impact_computing( In the reform_impact table, set the status of the impact to "computing". """ try: + start_time = datetime.datetime.now(datetime.timezone.utc).replace( + tzinfo=None + ) reform_impacts_service.set_reform_impact( country_id=setup_options.country_id, policy_id=setup_options.reform_policy_id, @@ -867,7 +1718,7 @@ def _set_reform_impact_computing( status=ImpactStatus.COMPUTING.value, api_version=setup_options.api_version, reform_impact_json=json.dumps({}), - start_time=datetime.datetime.now(), + start_time=start_time, execution_id=execution_id, ) except Exception as e: @@ -879,6 +1730,33 @@ def _set_reform_impact_computing( ) raise e + def _update_reform_impact_execution_id( + self, + setup_options: EconomicImpactSetupOptions, + current_execution_id: str, + new_execution_id: str, + ) -> int | None: + try: + return reform_impacts_service.update_reform_impact_execution_id( + country_id=setup_options.country_id, + policy_id=setup_options.reform_policy_id, + baseline_policy_id=setup_options.baseline_policy_id, + region=setup_options.region, + dataset=setup_options.dataset, + time_period=setup_options.time_period, + options_hash=setup_options.options_hash, + current_execution_id=current_execution_id, + new_execution_id=new_execution_id, + ) + except Exception as e: + logger.log_struct( + { + "message": f"Error updating reform impact execution id: {str(e)}", + **setup_options.model_dump(), + } + ) + raise e + def _set_reform_impact_complete( self, setup_options: EconomicImpactSetupOptions, diff --git a/policyengine_api/services/reform_impacts_service.py b/policyengine_api/services/reform_impacts_service.py index 0f41352f3..91928495e 100644 --- a/policyengine_api/services/reform_impacts_service.py +++ b/policyengine_api/services/reform_impacts_service.py @@ -1,14 +1,130 @@ -from policyengine_api.data import local_database +from contextlib import contextmanager +import hashlib +from threading import Lock +from policyengine_api.data import database import datetime +LOCAL_REFORM_IMPACT_LOCK = Lock() +REFORM_IMPACT_SCHEMA_LOCK = Lock() +REFORM_IMPACT_LOCK_TIMEOUT_SECONDS = 5 + + class ReformImpactsService: """ Service for storing and retrieving economy-wide reform impacts; - this is connected to the locally-stored reform_impact table - and no existing route + this is connected to the shared reform_impact table. """ + def __init__(self): + self._schema_checked = False + + def _ensure_remote_schema(self) -> None: + if database.local or self._schema_checked: + return + + with REFORM_IMPACT_SCHEMA_LOCK: + if self._schema_checked: + return + + existing_columns = { + row["Field"] + for row in database.query("SHOW COLUMNS FROM reform_impact").fetchall() + } + required_columns = { + "dataset": ( + "ALTER TABLE reform_impact " + "ADD COLUMN dataset VARCHAR(255) NOT NULL DEFAULT 'default'" + ), + "execution_id": ( + "ALTER TABLE reform_impact " + "ADD COLUMN execution_id VARCHAR(255) NULL" + ), + "end_time": ( + "ALTER TABLE reform_impact ADD COLUMN end_time DATETIME NULL" + ), + } + + for column_name, alter_query in required_columns.items(): + if column_name in existing_columns: + continue + try: + database.query(alter_query) + except Exception as error: + if "Duplicate column name" not in str(error): + raise + + self._schema_checked = True + + def _build_lock_name( + self, + country_id, + policy_id, + baseline_policy_id, + region, + dataset, + time_period, + options_hash, + api_version, + ) -> str: + raw_key = ( + f"{country_id}:{policy_id}:{baseline_policy_id}:{region}:{dataset}:" + f"{time_period}:{options_hash}:{api_version}" + ) + digest = hashlib.sha256(raw_key.encode("utf-8")).hexdigest() + return f"ri:{digest[:61]}" + + @contextmanager + def claim_lock( + self, + *, + country_id, + policy_id, + baseline_policy_id, + region, + dataset, + time_period, + options_hash, + api_version, + timeout_seconds: int = REFORM_IMPACT_LOCK_TIMEOUT_SECONDS, + ): + if database.local: + with LOCAL_REFORM_IMPACT_LOCK: + yield + return + + lock_name = self._build_lock_name( + country_id=country_id, + policy_id=policy_id, + baseline_policy_id=baseline_policy_id, + region=region, + dataset=dataset, + time_period=time_period, + options_hash=options_hash, + api_version=api_version, + ) + with database.pool.connect() as conn: + acquired = ( + conn.exec_driver_sql( + "SELECT GET_LOCK(%s, %s) AS acquired", + (lock_name, timeout_seconds), + ) + .mappings() + .first() + ) + if acquired is None or acquired["acquired"] != 1: + raise TimeoutError( + f"Could not acquire reform impact lock for {country_id}/{policy_id}/{time_period}" + ) + + try: + yield + finally: + conn.exec_driver_sql( + "SELECT RELEASE_LOCK(%s) AS released", (lock_name,) + ) + conn.commit() + def get_all_reform_impacts( self, country_id, @@ -21,13 +137,15 @@ def get_all_reform_impacts( api_version, ): try: + self._ensure_remote_schema() query = ( "SELECT reform_impact_json, status, message, start_time, execution_id FROM " "reform_impact WHERE country_id = ? AND reform_policy_id = ? AND " "baseline_policy_id = ? AND region = ? AND time_period = ? AND " - "options_hash = ? AND api_version = ? AND dataset = ?" + "options_hash = ? AND api_version = ? AND dataset = ? " + "ORDER BY start_time DESC, reform_impact_id DESC" ) - return local_database.query( + return database.query( query, ( country_id, @@ -57,6 +175,7 @@ def get_all_reform_impacts_by_options_hash_prefix( api_version, ): try: + self._ensure_remote_schema() query = ( "SELECT reform_impact_json, status, message, start_time, execution_id, options_hash FROM " "reform_impact WHERE country_id = ? AND reform_policy_id = ? AND " @@ -64,7 +183,7 @@ def get_all_reform_impacts_by_options_hash_prefix( "(options_hash = ? OR options_hash LIKE ? ESCAPE '\\') AND api_version = ? AND dataset = ? " "ORDER BY CASE WHEN options_hash = ? THEN 0 ELSE 1 END, start_time DESC" ) - return local_database.query( + return database.query( query, ( country_id, @@ -100,12 +219,13 @@ def set_reform_impact( execution_id: str, ): try: + self._ensure_remote_schema() query = ( "INSERT INTO reform_impact (country_id, reform_policy_id, baseline_policy_id, " "region, dataset, time_period, options_json, options_hash, status, api_version, " "reform_impact_json, start_time, execution_id) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)" ) - local_database.query( + database.query( query, ( country_id, @@ -127,6 +247,45 @@ def set_reform_impact( print(f"Error setting reform impact: {str(e)}") raise e + def update_reform_impact_execution_id( + self, + country_id, + policy_id, + baseline_policy_id, + region, + dataset, + time_period, + options_hash, + current_execution_id, + new_execution_id, + ): + try: + self._ensure_remote_schema() + query = ( + "UPDATE reform_impact SET execution_id = ? WHERE country_id = ? AND " + "reform_policy_id = ? AND baseline_policy_id = ? AND region = ? AND " + "time_period = ? AND options_hash = ? AND dataset = ? AND " + "execution_id = ? AND status = 'computing'" + ) + result = database.query( + query, + ( + new_execution_id, + country_id, + policy_id, + baseline_policy_id, + region, + time_period, + options_hash, + dataset, + current_execution_id, + ), + ) + return getattr(result, "rowcount", None) + except Exception as e: + print(f"Error updating reform impact execution id: {str(e)}") + raise e + def delete_reform_impact( self, country_id, @@ -138,6 +297,7 @@ def delete_reform_impact( options_hash, ): try: + self._ensure_remote_schema() query = ( "DELETE FROM reform_impact WHERE country_id = ? AND " "reform_policy_id = ? AND baseline_policy_id = ? AND " @@ -145,7 +305,7 @@ def delete_reform_impact( "dataset = ? AND status = 'computing'" ) - local_database.query( + database.query( query, ( country_id, @@ -174,13 +334,14 @@ def set_error_reform_impact( execution_id: str, ): try: + self._ensure_remote_schema() query = ( "UPDATE reform_impact SET status = ?, message = ?, end_time = ? WHERE " "country_id = ? AND reform_policy_id = ? AND baseline_policy_id = ? AND " "region = ? AND time_period = ? AND options_hash = ? AND dataset = ? AND " "execution_id = ?" ) - local_database.query( + database.query( query, ( "error", @@ -218,13 +379,14 @@ def set_complete_reform_impact( execution_id, ): try: + self._ensure_remote_schema() query = ( "UPDATE reform_impact SET status = ?, message = ?, end_time = ?, " "reform_impact_json = ? WHERE country_id = ? AND reform_policy_id = ? AND " "baseline_policy_id = ? AND region = ? AND time_period = ? AND " "options_hash = ? AND dataset = ? AND execution_id = ?" ) - local_database.query( + database.query( query, ( "ok", diff --git a/tests/fixtures/libs/simulation_api_modal.py b/tests/fixtures/libs/simulation_api_modal.py index 6d514a7e5..a9b4ce45e 100644 --- a/tests/fixtures/libs/simulation_api_modal.py +++ b/tests/fixtures/libs/simulation_api_modal.py @@ -19,6 +19,7 @@ # Mock data constants MOCK_MODAL_JOB_ID = "fc-abc123xyz" MOCK_RUN_ID = "run-abc123xyz" +MOCK_BATCH_JOB_ID = "fc-batch123xyz" MOCK_MODAL_BASE_URL = "https://test-modal-api.modal.run" MOCK_SIMULATION_PAYLOAD = { @@ -87,6 +88,54 @@ MOCK_HEALTH_RESPONSE = {"status": "healthy"} +MOCK_BATCH_SUBMIT_RESPONSE_SUCCESS = { + "batch_job_id": MOCK_BATCH_JOB_ID, + "status": MODAL_EXECUTION_STATUS_SUBMITTED, + "poll_url": f"/budget-window-jobs/{MOCK_BATCH_JOB_ID}", + "country": "us", + "version": "1.459.0", +} + +MOCK_BATCH_POLL_RESPONSE_RUNNING = { + "status": MODAL_EXECUTION_STATUS_RUNNING, + "progress": 33, + "completed_years": ["2026"], + "running_years": ["2027"], + "queued_years": ["2028"], + "failed_years": [], + "result": None, + "error": None, +} + +MOCK_BATCH_POLL_RESPONSE_COMPLETE = { + "status": MODAL_EXECUTION_STATUS_COMPLETE, + "progress": 100, + "completed_years": ["2026", "2027", "2028"], + "running_years": [], + "queued_years": [], + "failed_years": [], + "result": { + "kind": "budgetWindow", + "startYear": "2026", + "endYear": "2028", + "windowSize": 3, + "annualImpacts": [], + "totals": {}, + }, + "error": None, +} + +MOCK_BATCH_POLL_RESPONSE_FAILED = { + "status": MODAL_EXECUTION_STATUS_FAILED, + "progress": 33, + "completed_years": ["2026"], + "running_years": [], + "queued_years": ["2028"], + "failed_years": ["2027"], + "result": None, + "error": "Budget window failed", +} + def create_mock_httpx_response( status_code: int = 200, diff --git a/tests/fixtures/services/economy_service.py b/tests/fixtures/services/economy_service.py index cf41873ed..f29eb88b9 100644 --- a/tests/fixtures/services/economy_service.py +++ b/tests/fixtures/services/economy_service.py @@ -2,6 +2,7 @@ from unittest.mock import patch, MagicMock import json import datetime +from contextlib import nullcontext from policyengine_api.constants import ( MODAL_EXECUTION_STATUS_SUBMITTED, @@ -123,8 +124,10 @@ def mock_reform_impacts_service(): mock_service.get_all_reform_impacts_by_options_hash_prefix.return_value = [] mock_service.get_all_reform_impacts.return_value = [] mock_service.set_reform_impact.return_value = None + mock_service.update_reform_impact_execution_id.return_value = 1 mock_service.set_complete_reform_impact.return_value = None mock_service.set_error_reform_impact.return_value = None + mock_service.claim_lock.side_effect = lambda **kwargs: nullcontext() with patch( "policyengine_api.services.economy_service.reform_impacts_service", @@ -138,6 +141,7 @@ def mock_simulation_api(): """Mock SimulationAPIModal with all required methods.""" mock_api = MagicMock() mock_execution = create_mock_modal_execution() + mock_batch_execution = create_mock_budget_window_batch_execution() mock_api._setup_sim_options.return_value = MOCK_SIM_CONFIG mock_api.run.return_value = mock_execution @@ -149,6 +153,8 @@ def mock_simulation_api(): mock_api.get_execution_by_id.return_value = mock_execution mock_api.get_execution_status.return_value = MODAL_EXECUTION_STATUS_RUNNING mock_api.get_execution_result.return_value = MOCK_REFORM_IMPACT_DATA + mock_api.run_budget_window_batch.return_value = mock_batch_execution + mock_api.get_budget_window_batch_by_id.return_value = mock_batch_execution with patch( "policyengine_api.services.economy_service.simulation_api", mock_api @@ -187,6 +193,9 @@ def create_mock_reform_impact( reform_impact_json=None, execution_id=MOCK_MODAL_JOB_ID, options_hash=MOCK_OPTIONS_HASH, + start_time=None, + time_period=MOCK_TIME_PERIOD, + message=None, ): """Helper function to create mock reform impact records.""" default_reform_impact_json = json.dumps( @@ -208,13 +217,14 @@ def create_mock_reform_impact( "baseline_policy_id": MOCK_BASELINE_POLICY_ID, "region": MOCK_REGION, "dataset": MOCK_RESOLVED_DATASET, - "time_period": MOCK_TIME_PERIOD, + "time_period": time_period, "options_hash": options_hash, "status": status, "api_version": MOCK_API_VERSION, "reform_impact_json": reform_impact_json or default_reform_impact_json, "execution_id": execution_id, - "start_time": datetime.datetime(2025, 6, 26, 12, 0, 0), + "message": message, + "start_time": start_time or datetime.datetime(2025, 6, 26, 12, 0, 0), "end_time": ( datetime.datetime(2025, 6, 26, 12, 5, 0) if status == "ok" else None ), @@ -259,6 +269,32 @@ def create_mock_modal_execution( return mock_execution +def create_mock_budget_window_batch_execution( + batch_job_id=MOCK_MODAL_JOB_ID, + status=MODAL_EXECUTION_STATUS_SUBMITTED, + progress=None, + completed_years=None, + running_years=None, + queued_years=None, + failed_years=None, + result=None, + error=None, +): + """Helper function to create mock batch execution objects.""" + mock_execution = MagicMock() + mock_execution.batch_job_id = batch_job_id + mock_execution.name = batch_job_id + mock_execution.status = status + mock_execution.progress = progress + mock_execution.completed_years = completed_years or [] + mock_execution.running_years = running_years or [] + mock_execution.queued_years = queued_years or [] + mock_execution.failed_years = failed_years or [] + mock_execution.result = result + mock_execution.error = error + return mock_execution + + @pytest.fixture def mock_simulation_api_modal(): """Mock SimulationAPIModal with all required methods.""" diff --git a/tests/to_refactor/python/test_economy_budget_window_routes.py b/tests/to_refactor/python/test_economy_budget_window_routes.py new file mode 100644 index 000000000..f185f8c28 --- /dev/null +++ b/tests/to_refactor/python/test_economy_budget_window_routes.py @@ -0,0 +1,133 @@ +import json +from unittest.mock import Mock, patch + + +@patch( + "policyengine_api.routes.economy_routes.economy_service.get_budget_window_economic_impact" +) +def test_budget_window_route_rejects_cliff_target( + mock_get_budget_window_economic_impact, rest_client +): + response = rest_client.get( + "/us/economy/123/over/456/budget-window" + "?region=us&start_year=2026&window_size=10&target=cliff" + ) + + data = json.loads(response.data) + + assert response.status_code == 400 + assert data["status"] == "error" + assert "target=general" in data["message"] + mock_get_budget_window_economic_impact.assert_not_called() + + +@patch( + "policyengine_api.routes.economy_routes.economy_service.get_budget_window_economic_impact" +) +def test_budget_window_route_requires_window_size( + mock_get_budget_window_economic_impact, rest_client +): + response = rest_client.get( + "/us/economy/123/over/456/budget-window?region=us&start_year=2026" + ) + + data = json.loads(response.data) + + assert response.status_code == 400 + assert data["status"] == "error" + assert "window_size" in data["message"] + mock_get_budget_window_economic_impact.assert_not_called() + + +@patch( + "policyengine_api.routes.economy_routes.economy_service.get_budget_window_economic_impact" +) +def test_budget_window_route_requires_integer_window_size( + mock_get_budget_window_economic_impact, rest_client +): + response = rest_client.get( + "/us/economy/123/over/456/budget-window" + "?region=us&start_year=2026&window_size=abc" + ) + + data = json.loads(response.data) + + assert response.status_code == 400 + assert data["status"] == "error" + assert "window_size must be an integer" == data["message"] + mock_get_budget_window_economic_impact.assert_not_called() + + +def test_budget_window_route_rejects_oversized_window(rest_client): + response = rest_client.get( + "/us/economy/123/over/456/budget-window" + "?region=us&start_year=2026&window_size=999" + ) + + data = json.loads(response.data) + + assert response.status_code == 400 + assert data["status"] == "error" + assert "window_size must be between 1 and" in data["message"] + + +def test_budget_window_route_rejects_end_year_after_2099(rest_client): + response = rest_client.get( + "/us/economy/123/over/456/budget-window" + "?region=us&start_year=2090&window_size=20" + ) + + data = json.loads(response.data) + + assert response.status_code == 400 + assert data["status"] == "error" + assert "budget-window end_year must be 2099 or earlier" == data["message"] + + +@patch( + "policyengine_api.routes.economy_routes.economy_service.get_budget_window_economic_impact" +) +def test_budget_window_route_passes_version_to_service( + mock_get_budget_window_economic_impact, rest_client +): + mock_result = Mock() + mock_result.to_dict.return_value = { + "status": "ok", + "message": None, + "data": { + "kind": "budgetWindow", + "startYear": "2026", + "endYear": "2027", + "windowSize": 2, + "annualImpacts": [], + "totals": {}, + }, + "progress": 100, + "completed_years": ["2026", "2027"], + "computing_years": [], + "queued_years": [], + "error": None, + } + mock_get_budget_window_economic_impact.return_value = mock_result + + response = rest_client.get( + "/us/economy/123/over/456/budget-window" + "?region=us&start_year=2026&window_size=2&version=1.2.3" + ) + + data = json.loads(response.data) + + assert response.status_code == 200 + assert data["status"] == "ok" + mock_get_budget_window_economic_impact.assert_called_once_with( + country_id="us", + policy_id=123, + baseline_policy_id=456, + region="us", + dataset="default", + start_year="2026", + window_size=2, + options={}, + api_version="1.2.3", + target="general", + ) diff --git a/tests/unit/data/test_sqlalchemy_v2.py b/tests/unit/data/test_sqlalchemy_v2.py index 3882bb0f7..2ea63f0f0 100644 --- a/tests/unit/data/test_sqlalchemy_v2.py +++ b/tests/unit/data/test_sqlalchemy_v2.py @@ -12,6 +12,7 @@ import pytest import sqlalchemy +from unittest.mock import MagicMock from policyengine_api.data.data import _ResultProxy, PolicyEngineDatabase @@ -180,3 +181,34 @@ def test_remote_delete(self): db._execute_remote(["DELETE FROM test_table WHERE id = ?", (1,)]) result = db._execute_remote(["SELECT * FROM test_table WHERE id = ?", (1,)]) assert result.fetchone() is None + + +class TestRemotePoolCreation: + def test_create_pool_uses_fresh_connection_creator(self, monkeypatch): + first_connection = MagicMock(name="first_connection") + second_connection = MagicMock(name="second_connection") + mock_connector = MagicMock() + mock_connector.connect.side_effect = [first_connection, second_connection] + + captured_kwargs = {} + + def fake_create_engine(url, **kwargs): + captured_kwargs.update(kwargs) + return MagicMock() + + monkeypatch.setenv("POLICYENGINE_DB_PASSWORD", "test-password") + monkeypatch.setattr( + "policyengine_api.data.data.Connector", lambda: mock_connector + ) + monkeypatch.setattr( + "policyengine_api.data.data.sqlalchemy.create_engine", + fake_create_engine, + ) + + db = PolicyEngineDatabase.__new__(PolicyEngineDatabase) + db._create_pool() + + creator = captured_kwargs["creator"] + assert creator() is first_connection + assert creator() is second_connection + assert captured_kwargs["pool_pre_ping"] is True diff --git a/tests/unit/endpoints/test_simulation.py b/tests/unit/endpoints/test_simulation.py new file mode 100644 index 000000000..e2936de11 --- /dev/null +++ b/tests/unit/endpoints/test_simulation.py @@ -0,0 +1,20 @@ +from unittest.mock import MagicMock, patch + +from policyengine_api.endpoints.simulation import get_simulations + + +def test_get_simulations_reads_from_remote_database(): + mock_database = MagicMock() + mock_database.query.return_value.fetchall.return_value = [{"id": 1}] + + with patch( + "policyengine_api.endpoints.simulation.get_remote_database", + return_value=mock_database, + ): + result = get_simulations() + + mock_database.query.assert_called_once_with( + "SELECT * FROM reform_impact ORDER BY start_time DESC LIMIT ?", + (100,), + ) + assert result == {"result": [{"id": 1}]} diff --git a/tests/unit/libs/test_simulation_api_modal.py b/tests/unit/libs/test_simulation_api_modal.py index 26b321135..c80ab373e 100644 --- a/tests/unit/libs/test_simulation_api_modal.py +++ b/tests/unit/libs/test_simulation_api_modal.py @@ -21,6 +21,7 @@ os.environ.setdefault("FLASK_DEBUG", "1") from policyengine_api.libs.simulation_api_modal import ( # noqa: E402 + ModalBudgetWindowBatchExecution, ModalSimulationExecution, SimulationAPIModal, ) @@ -32,6 +33,7 @@ ) from tests.fixtures.libs.simulation_api_modal import ( # noqa: E402 MOCK_MODAL_JOB_ID, + MOCK_BATCH_JOB_ID, MOCK_MODAL_BASE_URL, MOCK_SIMULATION_PAYLOAD, MOCK_SIMULATION_PAYLOAD_WITH_TELEMETRY, @@ -44,6 +46,10 @@ MOCK_POLL_RESPONSE_COMPLETE, MOCK_POLL_RESPONSE_FAILED, MOCK_HEALTH_RESPONSE, + MOCK_BATCH_SUBMIT_RESPONSE_SUCCESS, + MOCK_BATCH_POLL_RESPONSE_RUNNING, + MOCK_BATCH_POLL_RESPONSE_COMPLETE, + MOCK_BATCH_POLL_RESPONSE_FAILED, create_mock_httpx_response, ) @@ -117,6 +123,18 @@ def test__given_failed_execution__then_error_attribute_populated(self): assert execution.result is None +class TestModalBudgetWindowBatchExecution: + """Tests for the ModalBudgetWindowBatchExecution dataclass.""" + + def test__given_batch_job_id__then_name_returns_batch_job_id(self): + execution = ModalBudgetWindowBatchExecution( + batch_job_id=MOCK_BATCH_JOB_ID, + status=MODAL_EXECUTION_STATUS_SUBMITTED, + ) + + assert execution.name == MOCK_BATCH_JOB_ID + + class TestSimulationAPIModal: """Tests for the SimulationAPIModal class.""" @@ -322,6 +340,40 @@ def test__given_country_and_version__then_returns_registered_app( assert app_name == MOCK_RESOLVED_APP_NAME assert resolved_version == "1.459.0" + class TestRunBudgetWindowBatch: + def test__given_valid_payload__then_returns_batch_execution( + self, + mock_httpx_client, + mock_modal_logger, + ): + mock_httpx_client.post.return_value = create_mock_httpx_response( + status_code=202, + json_data=MOCK_BATCH_SUBMIT_RESPONSE_SUCCESS, + ) + api = SimulationAPIModal() + + execution = api.run_budget_window_batch(MOCK_SIMULATION_PAYLOAD) + + assert execution.batch_job_id == MOCK_BATCH_JOB_ID + assert execution.status == MODAL_EXECUTION_STATUS_SUBMITTED + call_args = mock_httpx_client.post.call_args + assert "/simulate/economy/budget-window" in call_args[0][0] + + def test__given_http_error__then_raises_exception( + self, + mock_httpx_client, + mock_modal_logger, + ): + mock_response = create_mock_httpx_response( + status_code=400, + json_data={"error": "Invalid request"}, + ) + mock_httpx_client.post.return_value = mock_response + api = SimulationAPIModal() + + with pytest.raises(httpx.HTTPStatusError): + api.run_budget_window_batch(MOCK_SIMULATION_PAYLOAD) + class TestGetExecutionById: def test__given_running_job__then_returns_running_status( self, @@ -416,6 +468,59 @@ def test__given_unexpected_http_error__then_raises_exception( with pytest.raises(httpx.HTTPStatusError): api.get_execution_by_id(MOCK_MODAL_JOB_ID) + class TestGetBudgetWindowBatchById: + def test__given_running_batch__then_returns_running_status( + self, + mock_httpx_client, + mock_modal_logger, + ): + mock_httpx_client.get.return_value = create_mock_httpx_response( + status_code=202, + json_data=MOCK_BATCH_POLL_RESPONSE_RUNNING, + ) + api = SimulationAPIModal() + + execution = api.get_budget_window_batch_by_id(MOCK_BATCH_JOB_ID) + + assert execution.batch_job_id == MOCK_BATCH_JOB_ID + assert execution.status == MODAL_EXECUTION_STATUS_RUNNING + assert execution.completed_years == ["2026"] + assert execution.running_years == ["2027"] + assert execution.queued_years == ["2028"] + + def test__given_complete_batch__then_returns_result( + self, + mock_httpx_client, + mock_modal_logger, + ): + mock_httpx_client.get.return_value = create_mock_httpx_response( + status_code=200, + json_data=MOCK_BATCH_POLL_RESPONSE_COMPLETE, + ) + api = SimulationAPIModal() + + execution = api.get_budget_window_batch_by_id(MOCK_BATCH_JOB_ID) + + assert execution.status == MODAL_EXECUTION_STATUS_COMPLETE + assert execution.result == MOCK_BATCH_POLL_RESPONSE_COMPLETE["result"] + + def test__given_failed_batch__then_returns_error( + self, + mock_httpx_client, + mock_modal_logger, + ): + mock_httpx_client.get.return_value = create_mock_httpx_response( + status_code=500, + json_data=MOCK_BATCH_POLL_RESPONSE_FAILED, + ) + api = SimulationAPIModal() + + execution = api.get_budget_window_batch_by_id(MOCK_BATCH_JOB_ID) + + assert execution.status == MODAL_EXECUTION_STATUS_FAILED + assert execution.failed_years == ["2027"] + assert execution.error == "Budget window failed" + class TestGetExecutionId: def test__given_execution__then_returns_job_id(self, mock_httpx_client): # Given diff --git a/tests/unit/services/test_economy_service.py b/tests/unit/services/test_economy_service.py index d036ab296..fd9c64416 100644 --- a/tests/unit/services/test_economy_service.py +++ b/tests/unit/services/test_economy_service.py @@ -1,14 +1,74 @@ +import datetime import json +import sys import pytest from unittest.mock import patch, MagicMock from typing import Literal +from types import ModuleType + +try: + from policyengine.simulation import SimulationOptions # noqa: F401 +except ModuleNotFoundError: + policyengine_module = sys.modules.setdefault( + "policyengine", ModuleType("policyengine") + ) + simulation_module = ModuleType("policyengine.simulation") + utils_module = ModuleType("policyengine.utils") + data_module = ModuleType("policyengine.utils.data") + datasets_module = ModuleType("policyengine.utils.data.datasets") + + class _StubSimulationOptions: + def __init__(self, payload): + self._payload = payload + + @classmethod + def model_validate(cls, payload): + return cls(payload) + + def model_dump(self): + return dict(self._payload) + + simulation_module.SimulationOptions = _StubSimulationOptions + policyengine_module.simulation = simulation_module + + def _stub_get_default_dataset(country, region): + if country == "us": + if region == "us": + return "gs://policyengine-us-data/enhanced_cps_2024.h5" + if region == "state/ca": + return "gs://policyengine-us-data/states/CA.h5" + if region == "state/ut": + return "gs://policyengine-us-data/states/UT.h5" + if region == "place/NJ-57000": + return "gs://policyengine-us-data/states/NJ.h5" + if region == "congressional_district/CA-37": + return "gs://policyengine-us-data/districts/CA-37.h5" + if country == "uk" and region == "uk": + return "gs://policyengine-uk-data-private/enhanced_frs_2023_24.h5" + raise ValueError( + f"Error getting default dataset for country={country}, region={region}: unsupported in test stub" + ) + + datasets_module.get_default_dataset = _stub_get_default_dataset + data_module.datasets = datasets_module + utils_module.data = data_module + policyengine_module.utils = utils_module + sys.modules["policyengine.simulation"] = simulation_module + sys.modules["policyengine.utils"] = utils_module + sys.modules["policyengine.utils.data"] = data_module + sys.modules["policyengine.utils.data.datasets"] = datasets_module from policyengine_api.services.economy_service import ( + BUDGET_WINDOW_MAX_END_YEAR, + BUDGET_WINDOW_MAX_YEARS, EconomyService, EconomicImpactResult, EconomicImpactSetupOptions, ImpactAction, ImpactStatus, + PENDING_EXECUTION_ID_PREFIX, + PROVISIONAL_CLAIM_TTL_SECONDS, + STALE_PROVISIONAL_IMPACT_MESSAGE, ) from tests.fixtures.services.economy_service import ( MOCK_COUNTRY_ID, @@ -30,12 +90,30 @@ MOCK_REFORM_IMPACT_DATA, MOCK_RESOLVED_DATASET, MOCK_RESOLVED_APP_NAME, + create_mock_budget_window_batch_execution, create_mock_reform_impact, ) pytest_plugins = ("tests.fixtures.services.economy_service",) +def make_mock_budget_impact_data( + *, + tax_revenue_impact: int, + state_tax_revenue_impact: int, + benefit_spending_impact: int, + budgetary_impact: int, +): + return { + "budget": { + "tax_revenue_impact": tax_revenue_impact, + "state_tax_revenue_impact": state_tax_revenue_impact, + "benefit_spending_impact": benefit_spending_impact, + "budgetary_impact": budgetary_impact, + } + } + + class TestEconomyService: class TestGetEconomicImpact: @pytest.fixture @@ -125,6 +203,37 @@ def test__given_legacy_completed_impact__refreshes_cache( ) mock_simulation_api.run.assert_called_once() + def test__given_error_impact__returns_error_result( + self, + economy_service, + base_params, + mock_country_package_versions, + mock_get_dataset_version, + mock_get_policyengine_version, + mock_policy_service, + mock_reform_impacts_service, + mock_simulation_api, + mock_logger, + mock_datetime, + mock_numpy_random, + ): + error_impact = create_mock_reform_impact( + status="error", + reform_impact_json=json.dumps({}), + ) + error_impact["message"] = "Failed to start simulation API job" + mock_reform_impacts_service.get_all_reform_impacts_by_options_hash_prefix.return_value = [ + error_impact + ] + + result = economy_service.get_economic_impact(**base_params) + + assert result.status == ImpactStatus.ERROR + assert result.data is None + assert result.message == "Failed to start simulation API job" + mock_reform_impacts_service.get_all_reform_impacts_by_options_hash_prefix.assert_called_once() + mock_simulation_api.run.assert_not_called() + def test__given_computing_impact_with_succeeded_execution__returns_completed_result( self, economy_service, @@ -238,6 +347,21 @@ def test__given_no_previous_impact__creates_new_simulation( assert result.data is None mock_simulation_api.run.assert_called_once() mock_reform_impacts_service.set_reform_impact.assert_called_once() + assert any( + call.args == (datetime.timezone.utc,) + for call in mock_datetime.now.call_args_list + ) + mock_reform_impacts_service.update_reform_impact_execution_id.assert_called_once_with( + country_id=MOCK_COUNTRY_ID, + policy_id=MOCK_POLICY_ID, + baseline_policy_id=MOCK_BASELINE_POLICY_ID, + region=MOCK_REGION, + dataset=MOCK_RESOLVED_DATASET, + time_period=MOCK_TIME_PERIOD, + options_hash=MOCK_OPTIONS_HASH, + current_execution_id=f"{PENDING_EXECUTION_ID_PREFIX}{MOCK_PROCESS_ID}", + new_execution_id=MOCK_EXECUTION_ID, + ) def test__given_no_previous_impact__includes_metadata_in_simulation_params( self, @@ -309,6 +433,226 @@ def test__given_no_previous_impact__includes_telemetry_in_simulation_params( mock_logger.log_struct.call_args_list[-1].kwargs["severity"] == "INFO" ) + def test__given_simulation_api_submission_failure__marks_provisional_claim_error( + self, + economy_service, + base_params, + mock_country_package_versions, + mock_get_dataset_version, + mock_policy_service, + mock_reform_impacts_service, + mock_simulation_api, + mock_logger, + mock_datetime, + mock_numpy_random, + ): + mock_reform_impacts_service.get_all_reform_impacts.return_value = [] + mock_simulation_api.run.side_effect = RuntimeError("gateway unavailable") + + result = economy_service.get_economic_impact(**base_params) + + assert result.status == ImpactStatus.ERROR + assert ( + result.message + == "Failed to start simulation API job: gateway unavailable" + ) + mock_reform_impacts_service.set_reform_impact.assert_called_once() + mock_reform_impacts_service.set_error_reform_impact.assert_called_once_with( + country_id=MOCK_COUNTRY_ID, + policy_id=MOCK_POLICY_ID, + baseline_policy_id=MOCK_BASELINE_POLICY_ID, + region=MOCK_REGION, + dataset=MOCK_RESOLVED_DATASET, + time_period=MOCK_TIME_PERIOD, + options_hash=MOCK_OPTIONS_HASH, + message="Failed to start simulation API job: gateway unavailable", + execution_id=f"{PENDING_EXECUTION_ID_PREFIX}{MOCK_PROCESS_ID}", + ) + mock_reform_impacts_service.update_reform_impact_execution_id.assert_not_called() + + def test__given_simulation_setup_failure__marks_provisional_claim_error( + self, + economy_service, + base_params, + mock_country_package_versions, + mock_get_dataset_version, + mock_policy_service, + mock_reform_impacts_service, + mock_simulation_api, + mock_logger, + mock_datetime, + mock_numpy_random, + ): + mock_reform_impacts_service.get_all_reform_impacts.return_value = [] + with patch.object( + economy_service, + "_setup_sim_options", + side_effect=ValueError("Invalid US state: 'zz'"), + ): + result = economy_service.get_economic_impact(**base_params) + + assert result.status == ImpactStatus.ERROR + assert ( + result.message + == "Failed to start simulation API job: Invalid US state: 'zz'" + ) + mock_reform_impacts_service.set_reform_impact.assert_called_once() + mock_reform_impacts_service.set_error_reform_impact.assert_called_once_with( + country_id=MOCK_COUNTRY_ID, + policy_id=MOCK_POLICY_ID, + baseline_policy_id=MOCK_BASELINE_POLICY_ID, + region=MOCK_REGION, + dataset=MOCK_RESOLVED_DATASET, + time_period=MOCK_TIME_PERIOD, + options_hash=MOCK_OPTIONS_HASH, + message="Failed to start simulation API job: Invalid US state: 'zz'", + execution_id=f"{PENDING_EXECUTION_ID_PREFIX}{MOCK_PROCESS_ID}", + ) + mock_simulation_api.run.assert_not_called() + mock_reform_impacts_service.update_reform_impact_execution_id.assert_not_called() + + def test__given_claim_lock_timeout_and_existing_provisional_claim__returns_computing( + self, + economy_service, + base_params, + mock_country_package_versions, + mock_get_dataset_version, + mock_policy_service, + mock_reform_impacts_service, + mock_simulation_api, + mock_logger, + mock_numpy_random, + ): + provisional_impact = create_mock_reform_impact( + status="computing", + execution_id=f"{PENDING_EXECUTION_ID_PREFIX}job_other", + start_time=datetime.datetime.now(datetime.timezone.utc), + ) + mock_reform_impacts_service.get_all_reform_impacts.side_effect = [ + [], + [provisional_impact], + ] + mock_reform_impacts_service.claim_lock.side_effect = TimeoutError( + "lock busy" + ) + + result = economy_service.get_economic_impact(**base_params) + + assert result.status == ImpactStatus.COMPUTING + mock_simulation_api.run.assert_not_called() + + def test__given_stale_provisional_claim__expires_and_recreates_simulation( + self, + economy_service, + base_params, + mock_country_package_versions, + mock_get_dataset_version, + mock_policy_service, + mock_reform_impacts_service, + mock_simulation_api, + mock_logger, + ): + stale_start_time = datetime.datetime.now( + datetime.timezone.utc + ) - datetime.timedelta(seconds=PROVISIONAL_CLAIM_TTL_SECONDS + 1) + stale_provisional_impact = create_mock_reform_impact( + status="computing", + execution_id=f"{PENDING_EXECUTION_ID_PREFIX}job_stale", + start_time=stale_start_time, + ) + mock_reform_impacts_service.get_all_reform_impacts.side_effect = [ + [stale_provisional_impact], + [stale_provisional_impact], + ] + + result = economy_service.get_economic_impact(**base_params) + + assert result.status == ImpactStatus.COMPUTING + mock_reform_impacts_service.set_error_reform_impact.assert_called_once_with( + country_id=MOCK_COUNTRY_ID, + policy_id=MOCK_POLICY_ID, + baseline_policy_id=MOCK_BASELINE_POLICY_ID, + region=MOCK_REGION, + dataset=MOCK_RESOLVED_DATASET, + time_period=MOCK_TIME_PERIOD, + options_hash=MOCK_OPTIONS_HASH, + message=STALE_PROVISIONAL_IMPACT_MESSAGE, + execution_id=f"{PENDING_EXECUTION_ID_PREFIX}job_stale", + ) + mock_reform_impacts_service.set_reform_impact.assert_called_once() + mock_simulation_api.run.assert_called_once() + + def test__given_provisional_promotion_updates_zero_rows__inserts_replacement_tracking_row( + self, + economy_service, + base_params, + mock_country_package_versions, + mock_get_dataset_version, + mock_policy_service, + mock_reform_impacts_service, + mock_simulation_api, + mock_logger, + mock_datetime, + mock_numpy_random, + ): + mock_reform_impacts_service.get_all_reform_impacts.return_value = [] + mock_reform_impacts_service.update_reform_impact_execution_id.return_value = 0 + + result = economy_service.get_economic_impact(**base_params) + + assert result.status == ImpactStatus.COMPUTING + assert mock_reform_impacts_service.set_reform_impact.call_count == 2 + first_insert = mock_reform_impacts_service.set_reform_impact.call_args_list[ + 0 + ] + second_insert = ( + mock_reform_impacts_service.set_reform_impact.call_args_list[1] + ) + assert ( + first_insert.kwargs["execution_id"] + == f"{PENDING_EXECUTION_ID_PREFIX}{MOCK_PROCESS_ID}" + ) + assert second_insert.kwargs["execution_id"] == MOCK_EXECUTION_ID + + def test__given_provisional_promotion_updates_zero_rows_but_newer_claim_exists__does_not_insert_fallback( + self, + economy_service, + base_params, + mock_country_package_versions, + mock_get_dataset_version, + mock_policy_service, + mock_reform_impacts_service, + mock_simulation_api, + mock_logger, + mock_datetime, + mock_numpy_random, + ): + replacement_impact = create_mock_reform_impact( + status="computing", + execution_id=f"{PENDING_EXECUTION_ID_PREFIX}job_replacement", + start_time=datetime.datetime.now(datetime.timezone.utc), + ) + mock_reform_impacts_service.get_all_reform_impacts.side_effect = [ + [], + [], + [replacement_impact], + ] + mock_reform_impacts_service.update_reform_impact_execution_id.return_value = 0 + + result = economy_service.get_economic_impact(**base_params) + + assert result.status == ImpactStatus.COMPUTING + assert mock_reform_impacts_service.set_reform_impact.call_count == 1 + inserted_execution_id = ( + mock_reform_impacts_service.set_reform_impact.call_args.kwargs[ + "execution_id" + ] + ) + assert ( + inserted_execution_id + == f"{PENDING_EXECUTION_ID_PREFIX}{MOCK_PROCESS_ID}" + ) + def test__given_runtime_cache_version__uses_versioned_economy_cache_key( self, economy_service, @@ -617,6 +961,336 @@ def test__given_uk_request__preserves_model_version_in_bundle( sim_params = mock_simulation_api.run.call_args[0][0] assert sim_params["_metadata"]["model_version"] == "2.7.8" + class TestGetBudgetWindowEconomicImpact: + @pytest.fixture + def economy_service( + self, + mock_country_package_versions, + mock_get_dataset_version, + mock_policy_service, + mock_logger, + mock_datetime, + mock_numpy_random, + ): + return EconomyService() + + @pytest.fixture + def base_params(self): + return { + "country_id": MOCK_COUNTRY_ID, + "policy_id": MOCK_POLICY_ID, + "baseline_policy_id": MOCK_BASELINE_POLICY_ID, + "region": MOCK_REGION, + "dataset": MOCK_DATASET, + "start_year": "2026", + "window_size": 3, + "options": MOCK_OPTIONS, + "api_version": MOCK_API_VERSION, + "target": "general", + } + + def test__given_no_tracking_row__submits_parent_batch_and_returns_queued_result( + self, + economy_service, + base_params, + mock_reform_impacts_service, + mock_simulation_api, + ): + batch_execution = create_mock_budget_window_batch_execution( + batch_job_id="fc-budget-123", + status="submitted", + ) + mock_simulation_api.run_budget_window_batch.return_value = batch_execution + + result = economy_service.get_budget_window_economic_impact(**base_params) + + assert result.status == ImpactStatus.COMPUTING + assert result.progress == 0 + assert result.completed_years == [] + assert result.computing_years == [] + assert result.queued_years == ["2026", "2027", "2028"] + assert "Queued 2026" in result.message + mock_simulation_api.run_budget_window_batch.assert_called_once() + submitted_payload = ( + mock_simulation_api.run_budget_window_batch.call_args.args[0] + ) + assert submitted_payload["start_year"] == "2026" + assert submitted_payload["window_size"] == 3 + assert submitted_payload["max_parallel"] == 20 + assert submitted_payload["target"] == "general" + assert "time_period" not in submitted_payload + mock_reform_impacts_service.set_reform_impact.assert_called_once() + assert ( + mock_reform_impacts_service.set_reform_impact.call_args.kwargs[ + "execution_id" + ] + == "fc-budget-123" + ) + + def test__given_completed_tracking_row__returns_completed_batch_result( + self, + economy_service, + base_params, + mock_reform_impacts_service, + mock_simulation_api, + ): + completed_result = { + "kind": "budgetWindow", + "startYear": "2026", + "endYear": "2028", + "windowSize": 3, + "annualImpacts": [ + { + "year": "2026", + "taxRevenueImpact": 100, + "federalTaxRevenueImpact": 80, + "stateTaxRevenueImpact": 20, + "benefitSpendingImpact": -10, + "budgetaryImpact": 90, + } + ], + "totals": { + "year": "Total", + "taxRevenueImpact": 100, + "federalTaxRevenueImpact": 80, + "stateTaxRevenueImpact": 20, + "benefitSpendingImpact": -10, + "budgetaryImpact": 90, + }, + } + mock_reform_impacts_service.get_all_reform_impacts.return_value = [ + create_mock_reform_impact( + status="ok", + execution_id="fc-budget-123", + reform_impact_json=json.dumps(completed_result), + time_period="budget_window:2026:3", + ) + ] + + result = economy_service.get_budget_window_economic_impact(**base_params) + + assert result.status == ImpactStatus.OK + assert result.progress == 100 + assert result.data == completed_result + mock_simulation_api.get_budget_window_batch_by_id.assert_not_called() + + def test__given_running_tracking_row__returns_running_batch_progress( + self, + economy_service, + base_params, + mock_reform_impacts_service, + mock_simulation_api, + ): + mock_reform_impacts_service.get_all_reform_impacts.return_value = [ + create_mock_reform_impact( + status="computing", + execution_id="fc-budget-123", + time_period="budget_window:2026:3", + ) + ] + mock_simulation_api.get_budget_window_batch_by_id.return_value = ( + create_mock_budget_window_batch_execution( + batch_job_id="fc-budget-123", + status="running", + progress=33, + completed_years=["2026"], + running_years=["2027"], + queued_years=["2028"], + ) + ) + + result = economy_service.get_budget_window_economic_impact(**base_params) + + assert result.status == ImpactStatus.COMPUTING + assert result.progress == 33 + assert result.completed_years == ["2026"] + assert result.computing_years == ["2027"] + assert result.queued_years == ["2028"] + assert "1 of 3 complete" in result.message + + def test__given_completed_batch_poll__persists_result_and_returns_completed( + self, + economy_service, + base_params, + mock_reform_impacts_service, + mock_simulation_api, + ): + completed_result = { + "kind": "budgetWindow", + "startYear": "2026", + "endYear": "2028", + "windowSize": 3, + "annualImpacts": [], + "totals": {}, + } + mock_reform_impacts_service.get_all_reform_impacts.return_value = [ + create_mock_reform_impact( + status="computing", + execution_id="fc-budget-123", + time_period="budget_window:2026:3", + ) + ] + mock_simulation_api.get_budget_window_batch_by_id.return_value = ( + create_mock_budget_window_batch_execution( + batch_job_id="fc-budget-123", + status="complete", + progress=100, + completed_years=["2026", "2027", "2028"], + result=completed_result, + ) + ) + + result = economy_service.get_budget_window_economic_impact(**base_params) + + assert result.status == ImpactStatus.OK + assert result.data == completed_result + mock_reform_impacts_service.set_complete_reform_impact.assert_called_once() + call_kwargs = ( + mock_reform_impacts_service.set_complete_reform_impact.call_args.kwargs + ) + assert call_kwargs["execution_id"] == "fc-budget-123" + assert json.loads(call_kwargs["reform_impact_json"]) == completed_result + + def test__given_failed_batch_poll__persists_error_and_returns_failed( + self, + economy_service, + base_params, + mock_reform_impacts_service, + mock_simulation_api, + ): + mock_reform_impacts_service.get_all_reform_impacts.return_value = [ + create_mock_reform_impact( + status="computing", + execution_id="fc-budget-123", + time_period="budget_window:2026:3", + ) + ] + mock_simulation_api.get_budget_window_batch_by_id.return_value = ( + create_mock_budget_window_batch_execution( + batch_job_id="fc-budget-123", + status="failed", + progress=33, + completed_years=["2026"], + queued_years=["2028"], + failed_years=["2027"], + error="Budget window failed for 2027", + ) + ) + + result = economy_service.get_budget_window_economic_impact(**base_params) + + assert result.status == ImpactStatus.ERROR + assert result.error == "Budget window failed for 2027" + assert result.completed_years == ["2026"] + assert result.computing_years == [] + assert result.queued_years == ["2028"] + mock_reform_impacts_service.set_error_reform_impact.assert_called_once() + assert ( + mock_reform_impacts_service.set_error_reform_impact.call_args.kwargs[ + "execution_id" + ] + == "fc-budget-123" + ) + + def test__given_cliff_target__raises_value_error( + self, economy_service, base_params + ): + base_params["target"] = "cliff" + + with pytest.raises( + ValueError, + match="Budget-window calculations only support target='general'", + ): + economy_service.get_budget_window_economic_impact(**base_params) + + def test__given_oversized_window__raises_value_error( + self, economy_service, base_params + ): + base_params["window_size"] = BUDGET_WINDOW_MAX_YEARS + 1 + + with pytest.raises( + ValueError, + match=(f"window_size must be between 1 and {BUDGET_WINDOW_MAX_YEARS}"), + ): + economy_service.get_budget_window_economic_impact(**base_params) + + def test__given_end_year_after_2099__raises_value_error( + self, economy_service, base_params + ): + base_params["start_year"] = "2090" + base_params["window_size"] = 20 + + with pytest.raises( + ValueError, + match=( + f"budget-window end_year must be {BUDGET_WINDOW_MAX_END_YEAR} or earlier" + ), + ): + economy_service.get_budget_window_economic_impact(**base_params) + + def test__given_failed_tracking_row_and_unavailable_batch__returns_stored_error( + self, + economy_service, + base_params, + mock_reform_impacts_service, + mock_simulation_api, + ): + mock_reform_impacts_service.get_all_reform_impacts.return_value = [ + create_mock_reform_impact( + status="error", + execution_id="fc-budget-123", + time_period="budget_window:2026:3", + message="Stored batch failure", + ) + ] + mock_simulation_api.get_budget_window_batch_by_id.side_effect = Exception( + "batch lookup failed" + ) + + result = economy_service.get_budget_window_economic_impact(**base_params) + + assert result.status == ImpactStatus.ERROR + assert result.error == "Stored batch failure" + assert result.queued_years == ["2026", "2027", "2028"] + + def test__given_runtime_cache_version__uses_versioned_cache_key_for_budget_window( + self, + economy_service, + base_params, + mock_country_package_versions, + mock_get_dataset_version, + mock_reform_impacts_service, + mock_simulation_api, + mock_logger, + mock_datetime, + mock_numpy_random, + monkeypatch, + ): + cache_version = "e1cache01" + + monkeypatch.setattr( + "policyengine_api.services.economy_service.get_economy_impact_cache_version", + lambda country_id, api_version=None: cache_version, + ) + result = economy_service.get_budget_window_economic_impact(**base_params) + + assert result.status == ImpactStatus.COMPUTING + mock_reform_impacts_service.get_all_reform_impacts.assert_called_once() + assert ( + mock_reform_impacts_service.get_all_reform_impacts.call_args.args[5] + == "budget_window:2026:3" + ) + assert ( + mock_reform_impacts_service.get_all_reform_impacts.call_args.args[7] + == cache_version + ) + assert ( + mock_reform_impacts_service.set_reform_impact.call_args.kwargs[ + "api_version" + ] + == cache_version + ) + class TestGetPreviousImpacts: @pytest.fixture def economy_service(self): @@ -714,6 +1388,47 @@ def test__given_no_impacts__returns_none( # Assert assert result is None + class TestGetExistingEconomicImpact: + @pytest.fixture + def economy_service(self): + return EconomyService() + + @pytest.fixture + def setup_options(self): + return EconomicImpactSetupOptions( + process_id=MOCK_PROCESS_ID, + country_id=MOCK_COUNTRY_ID, + reform_policy_id=MOCK_POLICY_ID, + baseline_policy_id=MOCK_BASELINE_POLICY_ID, + region=MOCK_REGION, + dataset=MOCK_DATASET, + time_period=MOCK_TIME_PERIOD, + options=MOCK_OPTIONS, + api_version=MOCK_API_VERSION, + target="general", + options_hash=MOCK_OPTIONS_HASH, + ) + + def test__given_stale_provisional_impact__returns_none( + self, + economy_service, + setup_options, + mock_reform_impacts_service, + ): + stale_impact = create_mock_reform_impact( + status="computing", + execution_id=f"{PENDING_EXECUTION_ID_PREFIX}job_stale", + start_time=datetime.datetime.now(datetime.timezone.utc) + - datetime.timedelta(seconds=PROVISIONAL_CLAIM_TTL_SECONDS + 1), + ) + mock_reform_impacts_service.get_all_reform_impacts.return_value = [ + stale_impact + ] + + result = economy_service._get_existing_economic_impact(setup_options) + + assert result is None + class TestDetermineImpactAction: @pytest.fixture def economy_service(self): @@ -731,12 +1446,12 @@ def test__given_ok_status__returns_completed(self, economy_service): assert result == ImpactAction.COMPLETED - def test__given_error_status__returns_completed(self, economy_service): + def test__given_error_status__returns_error(self, economy_service): impact = create_mock_reform_impact(status="error") result = economy_service._determine_impact_action(impact) - assert result == ImpactAction.COMPLETED + assert result == ImpactAction.ERROR def test__given_computing_status__returns_computing(self, economy_service): impact = create_mock_reform_impact(status="computing") @@ -745,6 +1460,20 @@ def test__given_computing_status__returns_computing(self, economy_service): assert result == ImpactAction.COMPUTING + def test__given_stale_provisional_computing_status__returns_create( + self, economy_service + ): + impact = create_mock_reform_impact( + status="computing", + execution_id=f"{PENDING_EXECUTION_ID_PREFIX}job_stale", + start_time=datetime.datetime.now(datetime.timezone.utc) + - datetime.timedelta(seconds=PROVISIONAL_CLAIM_TTL_SECONDS + 1), + ) + + result = economy_service._determine_impact_action(impact) + + assert result == ImpactAction.CREATE + def test__given_unknown_status__raises_error(self, economy_service): impact = create_mock_reform_impact(status="unknown") @@ -818,6 +1547,7 @@ def test__given_failed_state__returns_error_result( assert result.status == ImpactStatus.ERROR assert result.data is None + assert result.message == "Simulation API execution failed" mock_reform_impacts_service.set_error_reform_impact.assert_called_once() def test__given_active_state__returns_computing_result( @@ -832,6 +1562,21 @@ def test__given_active_state__returns_computing_result( assert result.status == ImpactStatus.COMPUTING assert result.data is None + def test__given_provisional_claim__returns_computing_without_polling( + self, economy_service, setup_options, mock_simulation_api, mock_logger + ): + reform_impact = create_mock_reform_impact( + status="computing", + execution_id=f"{PENDING_EXECUTION_ID_PREFIX}job_pending", + ) + + result = economy_service._handle_computing_impact( + setup_options, reform_impact + ) + + assert result.status == ImpactStatus.COMPUTING + mock_simulation_api.get_execution_by_id.assert_not_called() + def test__given_unknown_state__raises_error( self, economy_service, setup_options ): @@ -894,6 +1639,7 @@ def test__given_modal_failed_state__then_returns_error_result( # Then assert result.status == ImpactStatus.ERROR assert result.data is None + assert result.message == "Simulation API execution failed" mock_reform_impacts_service.set_error_reform_impact.assert_called_once() def test__given_modal_failed_state_with_error_message__then_includes_error_in_message( @@ -915,6 +1661,10 @@ def test__given_modal_failed_state_with_error_message__then_includes_error_in_me # Then assert result.status == ImpactStatus.ERROR + assert ( + result.message + == "Simulation API execution failed: Simulation timed out" + ) # Verify the error message was passed to the service call_args = mock_reform_impacts_service.set_error_reform_impact.call_args assert "Simulation timed out" in call_args[1]["message"] @@ -1012,6 +1762,7 @@ def test__given_error__creates_correct_instance_and_logs(self): assert result.status == ImpactStatus.ERROR assert result.data is None + assert result.message == "Test error message" mock_logger.log_struct.assert_called_once() diff --git a/tests/unit/services/test_reform_impacts_service.py b/tests/unit/services/test_reform_impacts_service.py new file mode 100644 index 000000000..106cf8757 --- /dev/null +++ b/tests/unit/services/test_reform_impacts_service.py @@ -0,0 +1,189 @@ +from unittest.mock import MagicMock + +import pytest + +from policyengine_api.services.reform_impacts_service import ReformImpactsService + + +class TestReformImpactsService: + def test__given_remote_database_missing_columns__ensure_remote_schema_adds_them( + self, monkeypatch + ): + service = ReformImpactsService() + + show_columns_result = MagicMock() + show_columns_result.fetchall.return_value = [ + {"Field": "reform_impact_id"}, + {"Field": "status"}, + {"Field": "start_time"}, + ] + alter_dataset_result = MagicMock() + alter_execution_result = MagicMock() + alter_end_time_result = MagicMock() + + mock_database = MagicMock() + mock_database.local = False + mock_database.query.side_effect = [ + show_columns_result, + alter_dataset_result, + alter_execution_result, + alter_end_time_result, + ] + + monkeypatch.setattr( + "policyengine_api.services.reform_impacts_service.database", + mock_database, + ) + + service._ensure_remote_schema() + + assert mock_database.query.call_args_list[0].args == ( + "SHOW COLUMNS FROM reform_impact", + ) + assert mock_database.query.call_args_list[1].args == ( + "ALTER TABLE reform_impact ADD COLUMN dataset VARCHAR(255) NOT NULL DEFAULT 'default'", + ) + assert mock_database.query.call_args_list[2].args == ( + "ALTER TABLE reform_impact ADD COLUMN execution_id VARCHAR(255) NULL", + ) + assert mock_database.query.call_args_list[3].args == ( + "ALTER TABLE reform_impact ADD COLUMN end_time DATETIME NULL", + ) + + def test__given_remote_database_existing_columns__ensure_remote_schema_skips_alter( + self, monkeypatch + ): + service = ReformImpactsService() + + show_columns_result = MagicMock() + show_columns_result.fetchall.return_value = [ + {"Field": "reform_impact_id"}, + {"Field": "status"}, + {"Field": "start_time"}, + {"Field": "dataset"}, + {"Field": "execution_id"}, + {"Field": "end_time"}, + ] + + mock_database = MagicMock() + mock_database.local = False + mock_database.query.return_value = show_columns_result + + monkeypatch.setattr( + "policyengine_api.services.reform_impacts_service.database", + mock_database, + ) + + service._ensure_remote_schema() + + mock_database.query.assert_called_once_with("SHOW COLUMNS FROM reform_impact") + + def test__given_remote_database__claim_lock_uses_advisory_lock(self, monkeypatch): + service = ReformImpactsService() + + acquired_result = MagicMock() + acquired_result.mappings.return_value.first.return_value = {"acquired": 1} + release_result = MagicMock() + + mock_connection = MagicMock() + mock_connection.exec_driver_sql.side_effect = [ + acquired_result, + release_result, + ] + + mock_connection_context = MagicMock() + mock_connection_context.__enter__.return_value = mock_connection + mock_connection_context.__exit__.return_value = False + + mock_pool = MagicMock() + mock_pool.connect.return_value = mock_connection_context + + mock_database = MagicMock() + mock_database.local = False + mock_database.pool = mock_pool + + monkeypatch.setattr( + "policyengine_api.services.reform_impacts_service.database", + mock_database, + ) + + with service.claim_lock( + country_id="us", + policy_id=123, + baseline_policy_id=456, + region="us", + dataset="enhanced_cps", + time_period="2026", + options_hash="[option=value]", + api_version="e1cache01", + ): + pass + + assert mock_connection.exec_driver_sql.call_count == 2 + + acquire_call = mock_connection.exec_driver_sql.call_args_list[0] + assert acquire_call.args == ( + "SELECT GET_LOCK(%s, %s) AS acquired", + ( + service._build_lock_name( + country_id="us", + policy_id=123, + baseline_policy_id=456, + region="us", + dataset="enhanced_cps", + time_period="2026", + options_hash="[option=value]", + api_version="e1cache01", + ), + 5, + ), + ) + assert len(acquire_call.args[1][0]) <= 64 + + release_call = mock_connection.exec_driver_sql.call_args_list[1] + assert release_call.args == ( + "SELECT RELEASE_LOCK(%s) AS released", + (acquire_call.args[1][0],), + ) + mock_connection.commit.assert_called_once() + + def test__given_remote_database_lock_timeout__claim_lock_raises(self, monkeypatch): + service = ReformImpactsService() + + acquired_result = MagicMock() + acquired_result.mappings.return_value.first.return_value = {"acquired": 0} + + mock_connection = MagicMock() + mock_connection.exec_driver_sql.return_value = acquired_result + + mock_connection_context = MagicMock() + mock_connection_context.__enter__.return_value = mock_connection + mock_connection_context.__exit__.return_value = False + + mock_pool = MagicMock() + mock_pool.connect.return_value = mock_connection_context + + mock_database = MagicMock() + mock_database.local = False + mock_database.pool = mock_pool + + monkeypatch.setattr( + "policyengine_api.services.reform_impacts_service.database", + mock_database, + ) + + with pytest.raises( + TimeoutError, + match="Could not acquire reform impact lock", + ): + with service.claim_lock( + country_id="us", + policy_id=123, + baseline_policy_id=456, + region="us", + dataset="enhanced_cps", + time_period="2026", + options_hash="[option=value]", + api_version="e1cache01", + ): + pass diff --git a/uv.lock b/uv.lock index 8bb11c5e4..778b61515 100644 --- a/uv.lock +++ b/uv.lock @@ -2622,7 +2622,7 @@ sdist = { url = "https://files.pythonhosted.org/packages/a0/f3/eeea7dab690e46cd9 [[package]] name = "policyengine-api" -version = "3.40.7" +version = "3.40.8" source = { editable = "." } dependencies = [ { name = "anthropic" },