diff --git a/src/policyengine/tax_benefit_models/common/__init__.py b/src/policyengine/tax_benefit_models/common/__init__.py index 654f350d..744bf21d 100644 --- a/src/policyengine/tax_benefit_models/common/__init__.py +++ b/src/policyengine/tax_benefit_models/common/__init__.py @@ -6,6 +6,9 @@ """ from .extra_variables import dispatch_extra_variables as dispatch_extra_variables +from .household import ( + validate_annual_household_inputs as validate_annual_household_inputs, +) from .model_version import ( MicrosimulationModelVersion as MicrosimulationModelVersion, ) diff --git a/src/policyengine/tax_benefit_models/common/household.py b/src/policyengine/tax_benefit_models/common/household.py new file mode 100644 index 00000000..e3aced33 --- /dev/null +++ b/src/policyengine/tax_benefit_models/common/household.py @@ -0,0 +1,60 @@ +from __future__ import annotations + +from collections.abc import Mapping, Sequence +from typing import Any + + +def validate_annual_household_inputs( + *, + year: Any, + entities: Mapping[str, Sequence[Mapping[str, Any]]], +) -> int: + """Validate annual-only household calculator inputs.""" + validated_year = _validate_annual_year(year) + _validate_unperiodized_values(entities) + return validated_year + + +def _validate_annual_year(year: Any) -> int: + if isinstance(year, bool): + raise _annual_period_error(year) + if isinstance(year, int): + return year + if isinstance(year, str) and year.isdecimal() and len(year) == 4: + return int(year) + raise _annual_period_error(year) + + +def _annual_period_error(year: Any) -> ValueError: + return ValueError( + "Household calculations require a calendar year as an integer, " + "for example year=2026. " + "Monthly periods are not supported by calculate_household. " + f"Received year={year!r}." + ) + + +def _validate_unperiodized_values( + entities: Mapping[str, Sequence[Mapping[str, Any]]], +) -> None: + for entity, records in entities.items(): + for index, record in enumerate(records): + for variable, value in record.items(): + if variable != "id" and isinstance(value, Mapping): + raise ValueError( + "Periodized household inputs are not supported by " + "calculate_household. Pass annual scalar input values " + f"only; received a periodized value for " + f"{_input_location(entity, index, len(records), variable)}." + ) + + +def _input_location( + entity: str, + index: int, + record_count: int, + variable: str, +) -> str: + if record_count == 1 and entity != "people": + return f"{entity}.{variable}" + return f"{entity}[{index}].{variable}" diff --git a/src/policyengine/tax_benefit_models/uk/household.py b/src/policyengine/tax_benefit_models/uk/household.py index 5dbd71bb..53c386d1 100644 --- a/src/policyengine/tax_benefit_models/uk/household.py +++ b/src/policyengine/tax_benefit_models/uk/household.py @@ -28,6 +28,7 @@ HouseholdResult, compile_reform, dispatch_extra_variables, + validate_annual_household_inputs, ) from policyengine.utils.household_validation import validate_household_input @@ -131,19 +132,28 @@ def calculate_household( :class:`HouseholdResult` with dot-accessible entity results. Raises: - ValueError: on unknown or mis-placed variable names, or - unknown reform parameter paths. + ValueError: on unknown or mis-placed variable names, + unknown reform parameter paths, non-annual ``year`` values, + or periodized household input values. TypeError: on US-only kwargs (``tax_unit``, etc.) or other unsupported keyword arguments. """ if unexpected: _raise_unexpected_kwargs(unexpected) - from policyengine_uk import Simulation - people = list(people) benunit_dict = dict(benunit or {}) household_dict = dict(household or {}) + year = validate_annual_household_inputs( + year=year, + entities={ + "people": people, + "benunit": [benunit_dict], + "household": [household_dict], + }, + ) + + from policyengine_uk import Simulation validate_household_input( model_version=uk_latest, diff --git a/src/policyengine/tax_benefit_models/us/household.py b/src/policyengine/tax_benefit_models/us/household.py index 5258043a..76902f00 100644 --- a/src/policyengine/tax_benefit_models/us/household.py +++ b/src/policyengine/tax_benefit_models/us/household.py @@ -45,6 +45,7 @@ HouseholdResult, compile_reform, dispatch_extra_variables, + validate_annual_household_inputs, ) from policyengine.utils.household_validation import validate_household_input @@ -181,13 +182,12 @@ def calculate_household( if a variable is placed on the wrong entity (e.g. ``filing_status`` on ``people``), or if ``extra_variables`` / ``reform`` names a variable or parameter path not defined - on the US model. + on the US model. Raises if ``year`` is not an annual calendar + year or if household input values are already periodized. """ if unexpected: _raise_unexpected_kwargs(unexpected) - from policyengine_us import Simulation - people = list(people) entities = { "marital_unit": dict(marital_unit or {}), @@ -196,6 +196,15 @@ def calculate_household( "tax_unit": dict(tax_unit or {}), "household": dict(household or {}), } + year = validate_annual_household_inputs( + year=year, + entities={ + "people": people, + **{name: [value] for name, value in entities.items()}, + }, + ) + + from policyengine_us import Simulation validate_household_input( model_version=us_latest, diff --git a/tests/test_household_impact.py b/tests/test_household_impact.py index d99d144b..88444ebc 100644 --- a/tests/test_household_impact.py +++ b/tests/test_household_impact.py @@ -9,7 +9,11 @@ import pytest import policyengine as pe -from policyengine.tax_benefit_models.common import EntityResult, HouseholdResult +from policyengine.tax_benefit_models.common import ( + EntityResult, + HouseholdResult, + validate_annual_household_inputs, +) class TestUKCalculateHousehold: @@ -65,6 +69,34 @@ def test__reform_changes_child_benefit__then_dict_compiles_and_applies(self): assert isinstance(reformed.benunit.child_benefit, float) assert isinstance(baseline.benunit.child_benefit, float) + def test__monthly_year_period__then_raises_before_calculation(self): + with pytest.raises(ValueError, match="Monthly periods are not supported"): + pe.uk.calculate_household( + people=[{"age": 30}], + year="2026-01", + ) + + def test__periodized_person_input__then_raises_before_calculation(self): + with pytest.raises( + ValueError, + match=r"Periodized household inputs.*people\[0\]\.employment_income", + ): + pe.uk.calculate_household( + people=[{"age": 30, "employment_income": {"2026-01": 1_000}}], + year=2026, + ) + + def test__periodized_group_input__then_raises_before_calculation(self): + with pytest.raises( + ValueError, + match=r"Periodized household inputs.*benunit\.would_claim_child_benefit", + ): + pe.uk.calculate_household( + people=[{"age": 30}], + benunit={"would_claim_child_benefit": {"2026-01": True}}, + year=2026, + ) + class TestUSCalculateHousehold: def test__single_adult__then_returns_result_with_net_income(self): @@ -119,8 +151,55 @@ def test__reform_compiles_effective_date_form(self): ) assert result.tax_unit.ctc >= 0 + def test__monthly_year_period__then_raises_before_calculation(self): + with pytest.raises(ValueError, match="Monthly periods are not supported"): + pe.us.calculate_household( + people=[{"age": 30, "is_tax_unit_head": True}], + year="2026-01", + ) + + def test__periodized_person_input__then_raises_before_calculation(self): + with pytest.raises( + ValueError, + match=r"Periodized household inputs.*people\[0\]\.employment_income", + ): + pe.us.calculate_household( + people=[ + { + "age": 30, + "is_tax_unit_head": True, + "employment_income": {"2026-01": 1_000}, + } + ], + year=2026, + ) + + def test__periodized_group_input__then_raises_before_calculation(self): + with pytest.raises( + ValueError, + match=r"Periodized household inputs.*household\.state_code", + ): + pe.us.calculate_household( + people=[{"age": 30, "is_tax_unit_head": True}], + household={"state_code": {"2026-01": "CA"}}, + year=2026, + ) + class TestHouseholdInputValidation: + def test__annual_year_string__then_normalizes_to_int(self): + assert ( + validate_annual_household_inputs(year="2026", entities={"people": []}) + == 2026 + ) + + def test__non_annual_year__then_error_includes_received_year(self): + with pytest.raises(ValueError, match=r"Received year='2026-01'"): + validate_annual_household_inputs( + year="2026-01", + entities={"people": []}, + ) + def test__unknown_person_variable__then_raises_with_suggestion(self): with pytest.raises(ValueError, match="employment_incme"): pe.us.calculate_household(