From ccfb1e4b2b2972d0111873517b2aaf8eb925c918 Mon Sep 17 00:00:00 2001 From: Nikhil Woodruff Date: Wed, 30 Apr 2025 07:27:00 -0400 Subject: [PATCH 1/2] Bug prevents state tax calculation in some cases Fixes #113 --- changelog_entry.yaml | 6 ++++ policyengine/constants.py | 35 ++++++++++++++++--- .../macro/single/calculate_single_economy.py | 2 +- policyengine/simulation.py | 19 +++++----- 4 files changed, 47 insertions(+), 15 deletions(-) diff --git a/changelog_entry.yaml b/changelog_entry.yaml index e69de29b..53ad29f2 100644 --- a/changelog_entry.yaml +++ b/changelog_entry.yaml @@ -0,0 +1,6 @@ +- bump: patch + changes: + fixed: + - Bug in state tax revenue calculation. + added: + - Default dataset handling (extra backups added). diff --git a/policyengine/constants.py b/policyengine/constants.py index 651de219..8bd67713 100644 --- a/policyengine/constants.py +++ b/policyengine/constants.py @@ -1,5 +1,8 @@ """Mainly simulation options and parameters.""" +from policyengine_core.data import Dataset +from policyengine.utils.data_download import download + # Datasets ENHANCED_FRS = "hf://policyengine/policyengine-uk-data/enhanced_frs_2022_23.h5" @@ -8,7 +11,31 @@ CPS = "hf://policyengine/policyengine-us-data/cps_2023.h5" POOLED_CPS = "hf://policyengine/policyengine-us-data/pooled_3_year_cps_2023.h5" -DEFAULT_DATASETS_BY_COUNTRY = { - "uk": ENHANCED_FRS, - "us": CPS, -} +def get_default_dataset(country: str, region: str): + if country == "uk": + data_file = download( + filepath="enhanced_frs_2022_23.h5", + huggingface_repo="policyengine-uk-data", + gcs_bucket="policyengine-uk-data-private", + ) + time_period = None + elif country == "us": + if region is not None and region != "us": + data_file = download( + filepath="pooled_3_year_cps_2023.h5", + huggingface_repo="policyengine-us-data", + gcs_bucket="policyengine-us-data", + ) + time_period = 2023 + else: + data_file = download( + filepath="cps_2023.h5", + huggingface_repo="policyengine-us-data", + gcs_bucket="policyengine-us-data", + ) + time_period = 2023 + + return Dataset.from_file( + file_path=data_file, + time_period=time_period, + ) diff --git a/policyengine/outputs/macro/single/calculate_single_economy.py b/policyengine/outputs/macro/single/calculate_single_economy.py index 7a21133a..3ed9f517 100644 --- a/policyengine/outputs/macro/single/calculate_single_economy.py +++ b/policyengine/outputs/macro/single/calculate_single_economy.py @@ -376,7 +376,7 @@ def calculate_single_economy( if country_id == "us": try: - total_state_tax = simulation.calculate( + total_state_tax = task_manager.simulation.calculate( "household_state_income_tax" ).sum() except: diff --git a/policyengine/simulation.py b/policyengine/simulation.py index c53d30f0..61f00acb 100644 --- a/policyengine/simulation.py +++ b/policyengine/simulation.py @@ -2,7 +2,7 @@ from pydantic import BaseModel, Field from typing import Literal -from .constants import DEFAULT_DATASETS_BY_COUNTRY +from .constants import get_default_dataset from policyengine_core.simulations import Simulation as CountrySimulation from policyengine_core.simulations import ( Microsimulation as CountryMicrosimulation, @@ -73,11 +73,6 @@ class Simulation: def __init__(self, **options: SimulationOptions): self.options = SimulationOptions(**options) - if self.options.data is None: - self.options.data = DEFAULT_DATASETS_BY_COUNTRY[ - self.options.country - ] - self._set_data() self._initialise_simulations() self._add_output_functions() @@ -115,11 +110,12 @@ def _add_output_functions(self): def _set_data(self): if self.options.data is None: - self.options.data = DEFAULT_DATASETS_BY_COUNTRY[ - self.options.country - ] + self.options.data = get_default_dataset( + country=self.options.country, + region=self.options.region, + ) - if isinstance(self.options.data, str): + elif isinstance(self.options.data, str): filename = self.options.data if "://" in self.options.data: bucket = None @@ -129,6 +125,7 @@ def _set_data(self): bucket, filename = self.options.data.split("://")[ -1 ].split("/") + hf_org = "policyengine" elif "hf://" in self.options.data: hf_org, hf_repo, filename = self.options.data.split("://")[ -1 @@ -221,6 +218,8 @@ def _initialise_simulation( if subsample is not None: simulation = simulation.subsample(subsample) + simulation.default_calculation_period = time_period + return simulation def _apply_region_to_simulation( From dbd849f465201d2edfc4626270cc9c3640575942 Mon Sep 17 00:00:00 2001 From: Nikhil Woodruff Date: Wed, 30 Apr 2025 07:33:08 -0400 Subject: [PATCH 2/2] Format --- policyengine/constants.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/policyengine/constants.py b/policyengine/constants.py index 8bd67713..c6bca554 100644 --- a/policyengine/constants.py +++ b/policyengine/constants.py @@ -11,6 +11,7 @@ CPS = "hf://policyengine/policyengine-us-data/cps_2023.h5" POOLED_CPS = "hf://policyengine/policyengine-us-data/pooled_3_year_cps_2023.h5" + def get_default_dataset(country: str, region: str): if country == "uk": data_file = download( @@ -34,7 +35,7 @@ def get_default_dataset(country: str, region: str): gcs_bucket="policyengine-us-data", ) time_period = 2023 - + return Dataset.from_file( file_path=data_file, time_period=time_period,