From 1daf56146f0a0b93371ea5ed47539ba39543dfd8 Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Sat, 18 Apr 2026 19:39:46 -0400 Subject: [PATCH 1/3] Pre-launch cleanup: dead code + plotly optional extra MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Removes three unambiguously dead code paths and moves plotly out of the core install so `import policyengine` doesn't pull the charting stack. Changes are behavior-preserving for every downstream repo surveyed (policyengine-api, policyengine-api-v2, policyengine-api-v2-alpha). 1. Delete `tax_benefit_models/{us,uk}.py` shim files. Python always resolves the `us/`/`uk/` package directory first, so the .py files were dead. Worse: both re-exported `general_policy_reform_analysis` which is not defined anywhere — `from policyengine.tax_benefit_models.us import general_policy_reform_analysis` raises ImportError at runtime. 2. Delete `_create_entity_output_model` + `PersonOutput` / `BenunitOutput` / `HouseholdEntityOutput` in uk/analysis.py. Built via pydantic.create_model at import time, referenced nowhere in the codebase. 3. Delete `policyengine.core.DatasetVersion`. One optional field on Dataset (never set by anything) and one core re-export. Nothing reads it downstream. 4. Move `plotly>=5.0.0` from base dependencies to a `[plotting]` optional extra. Only `policyengine.utils.plotting` uses plotly, and nothing in src/ imports that module — only `examples/` do. `plotting.py` now soft-imports with a clear install hint. Downstream impact: none. Surveyed policyengine-api (pinned to a pre-3.x API), policyengine-api-v2 (3.4.0), policyengine-api-v2-alpha (3.1.15); none of them import the deleted symbols. Tests: 216 passed locally across test_release_manifests, test_trace_tro, test_results, test_household_impact, test_models, test_us_regions, test_uk_regions, test_region, test_manifest_version_mismatch, test_filtering, test_cache, test_scoping_strategy. Deferred (bigger refactors, follow-up PRs): - filter_field/filter_value legacy path on Simulation (still wired through Region construction; needs migration) - calculate_household_impact → calculate_household rename (with deprecation shim) - Extract shared MicrosimulationModelVersion base (~600 LOC savings) - Move release_manifest + trace_tro to policyengine/provenance/ Co-Authored-By: Claude Opus 4.7 (1M context) --- changelog.d/pre-launch-cleanup.removed.md | 6 +++ pyproject.toml | 5 ++- src/policyengine/core/__init__.py | 1 - src/policyengine/core/dataset.py | 2 - src/policyengine/core/dataset_version.py | 16 -------- src/policyengine/tax_benefit_models/uk.py | 40 ------------------- .../tax_benefit_models/uk/analysis.py | 20 +--------- src/policyengine/tax_benefit_models/us.py | 40 ------------------- src/policyengine/utils/plotting.py | 20 ++++++++-- 9 files changed, 28 insertions(+), 122 deletions(-) create mode 100644 changelog.d/pre-launch-cleanup.removed.md delete mode 100644 src/policyengine/core/dataset_version.py delete mode 100644 src/policyengine/tax_benefit_models/uk.py delete mode 100644 src/policyengine/tax_benefit_models/us.py diff --git a/changelog.d/pre-launch-cleanup.removed.md b/changelog.d/pre-launch-cleanup.removed.md new file mode 100644 index 00000000..73b95b51 --- /dev/null +++ b/changelog.d/pre-launch-cleanup.removed.md @@ -0,0 +1,6 @@ +Pre-launch cleanup — remove dead code and drop `plotly` from the core dependency set: + +- Delete `policyengine.tax_benefit_models.us` and `policyengine.tax_benefit_models.uk` module shims. Python resolves the package directory first, so the `.py` shims were always shadowed; worse, both attempted to re-export `general_policy_reform_analysis` which is not defined anywhere, making `from policyengine.tax_benefit_models.us import general_policy_reform_analysis` raise `ImportError` at runtime. +- Delete `_create_entity_output_model` plus the `PersonOutput` / `BenunitOutput` / `HouseholdEntityOutput` factory products in `policyengine.tax_benefit_models.uk.analysis` — built via `pydantic.create_model` but never referenced anywhere in the codebase. +- Delete `policyengine.core.DatasetVersion` (only consumer was an `Optional` field on `Dataset` that was never set, and the `policyengine.core` re-export). +- Move `plotly>=5.0.0` from the base install to a new `policyengine[plotting]` extra. Only `policyengine.utils.plotting` uses it, and that module is itself only used by the `examples/` scripts. The package now imports cleanly without `plotly`. diff --git a/pyproject.toml b/pyproject.toml index 67582060..72af3935 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,6 @@ dependencies = [ "pydantic>=2.0.0", "pandas>=2.0.0", "microdf_python>=1.2.1", - "plotly>=5.0.0", "requests>=2.31.0", "psutil>=5.9.0", "packaging>=23.0", @@ -34,6 +33,9 @@ dependencies = [ policyengine = "policyengine.cli:main" [project.optional-dependencies] +plotting = [ + "plotly>=5.0.0", +] uk = [ "policyengine_core>=3.25.0", "policyengine-uk==2.88.0", @@ -51,6 +53,7 @@ dev = [ "itables", "build", "jsonschema>=4.0.0", + "plotly>=5.0.0", "pytest-asyncio>=0.26.0", "ruff>=0.9.0", "policyengine_core>=3.25.0", diff --git a/src/policyengine/core/__init__.py b/src/policyengine/core/__init__.py index 8ff37aed..71ca0132 100644 --- a/src/policyengine/core/__init__.py +++ b/src/policyengine/core/__init__.py @@ -1,7 +1,6 @@ from .dataset import Dataset from .dataset import YearData as YearData from .dataset import map_to_entity as map_to_entity -from .dataset_version import DatasetVersion as DatasetVersion from .dynamic import Dynamic as Dynamic from .output import Output as Output from .output import OutputCollection as OutputCollection diff --git a/src/policyengine/core/dataset.py b/src/policyengine/core/dataset.py index 27f51d16..64f74eba 100644 --- a/src/policyengine/core/dataset.py +++ b/src/policyengine/core/dataset.py @@ -6,7 +6,6 @@ from microdf import MicroDataFrame from pydantic import BaseModel, ConfigDict, Field -from .dataset_version import DatasetVersion from .tax_benefit_model import TaxBenefitModel @@ -85,7 +84,6 @@ class MyDataset(Dataset): id: str = Field(default_factory=lambda: str(uuid4())) name: str description: str - dataset_version: Optional[DatasetVersion] = None filepath: str is_output_dataset: bool = False tax_benefit_model: Optional[TaxBenefitModel] = None diff --git a/src/policyengine/core/dataset_version.py b/src/policyengine/core/dataset_version.py deleted file mode 100644 index 711cd7d7..00000000 --- a/src/policyengine/core/dataset_version.py +++ /dev/null @@ -1,16 +0,0 @@ -from typing import TYPE_CHECKING -from uuid import uuid4 - -from pydantic import BaseModel, Field - -from .tax_benefit_model import TaxBenefitModel - -if TYPE_CHECKING: - from .dataset import Dataset - - -class DatasetVersion(BaseModel): - id: str = Field(default_factory=lambda: str(uuid4())) - dataset: "Dataset" - description: str - tax_benefit_model: TaxBenefitModel = None diff --git a/src/policyengine/tax_benefit_models/uk.py b/src/policyengine/tax_benefit_models/uk.py deleted file mode 100644 index 52abcb18..00000000 --- a/src/policyengine/tax_benefit_models/uk.py +++ /dev/null @@ -1,40 +0,0 @@ -"""PolicyEngine UK tax-benefit model - imports from uk/ module.""" - -from importlib.util import find_spec - -if find_spec("policyengine_uk") is not None: - from .uk import ( - PolicyEngineUK, - PolicyEngineUKDataset, - PolicyEngineUKLatest, - ProgrammeStatistics, - UKYearData, - create_datasets, - ensure_datasets, - general_policy_reform_analysis, - load_datasets, - managed_microsimulation, - uk_latest, - uk_model, - ) - - __all__ = [ - "UKYearData", - "PolicyEngineUKDataset", - "create_datasets", - "load_datasets", - "ensure_datasets", - "PolicyEngineUK", - "PolicyEngineUKLatest", - "managed_microsimulation", - "uk_model", - "uk_latest", - "general_policy_reform_analysis", - "ProgrammeStatistics", - ] - - # Rebuild models to resolve forward references - PolicyEngineUKDataset.model_rebuild() - PolicyEngineUKLatest.model_rebuild() -else: - __all__ = [] diff --git a/src/policyengine/tax_benefit_models/uk/analysis.py b/src/policyengine/tax_benefit_models/uk/analysis.py index 0a545b52..b05e21b0 100644 --- a/src/policyengine/tax_benefit_models/uk/analysis.py +++ b/src/policyengine/tax_benefit_models/uk/analysis.py @@ -6,7 +6,7 @@ import pandas as pd from microdf import MicroDataFrame -from pydantic import BaseModel, Field, create_model +from pydantic import BaseModel, Field from policyengine.core import OutputCollection, Simulation from policyengine.core.policy import Policy @@ -28,24 +28,6 @@ from .outputs import ProgrammeStatistics -def _create_entity_output_model(entity: str, variables: list[str]) -> type[BaseModel]: - """Create a dynamic Pydantic model for entity output variables.""" - fields = {var: (float, ...) for var in variables} - return create_model(f"{entity.title()}Output", **fields) - - -# Create output models dynamically from uk_latest.entity_variables -PersonOutput = _create_entity_output_model( - "person", uk_latest.entity_variables["person"] -) -BenunitOutput = _create_entity_output_model( - "benunit", uk_latest.entity_variables["benunit"] -) -HouseholdEntityOutput = _create_entity_output_model( - "household", uk_latest.entity_variables["household"] -) - - class UKHouseholdOutput(BaseModel): """Output from a UK household calculation with all entity data.""" diff --git a/src/policyengine/tax_benefit_models/us.py b/src/policyengine/tax_benefit_models/us.py deleted file mode 100644 index bbc29486..00000000 --- a/src/policyengine/tax_benefit_models/us.py +++ /dev/null @@ -1,40 +0,0 @@ -"""PolicyEngine US tax-benefit model - imports from us/ module.""" - -from importlib.util import find_spec - -if find_spec("policyengine_us") is not None: - from .us import ( - PolicyEngineUS, - PolicyEngineUSDataset, - PolicyEngineUSLatest, - ProgramStatistics, - USYearData, - create_datasets, - ensure_datasets, - general_policy_reform_analysis, - load_datasets, - managed_microsimulation, - us_latest, - us_model, - ) - - __all__ = [ - "USYearData", - "PolicyEngineUSDataset", - "create_datasets", - "load_datasets", - "ensure_datasets", - "PolicyEngineUS", - "PolicyEngineUSLatest", - "managed_microsimulation", - "us_model", - "us_latest", - "general_policy_reform_analysis", - "ProgramStatistics", - ] - - # Rebuild models to resolve forward references - PolicyEngineUSDataset.model_rebuild() - PolicyEngineUSLatest.model_rebuild() -else: - __all__ = [] diff --git a/src/policyengine/utils/plotting.py b/src/policyengine/utils/plotting.py index 2ca8e48c..b1700a35 100644 --- a/src/policyengine/utils/plotting.py +++ b/src/policyengine/utils/plotting.py @@ -1,8 +1,22 @@ -"""Plotting utilities for PolicyEngine visualisations.""" +"""Plotting utilities for PolicyEngine visualisations. -from typing import Optional +Requires plotly, which is installed via the ``[plotting]`` extra +(``pip install policyengine[plotting]``). Importing from this module +fails with a clear error when plotly is absent. +""" -import plotly.graph_objects as go +from typing import TYPE_CHECKING, Optional + +try: + import plotly.graph_objects as go +except ImportError as exc: # pragma: no cover + raise ImportError( + "policyengine.utils.plotting requires plotly. " + "Install with: pip install policyengine[plotting]" + ) from exc + +if TYPE_CHECKING: + import plotly.graph_objects as go # noqa: F401 # PolicyEngine brand colours COLORS = { From 8e412d244afcde70f66eb9948c4c92e78cfbafa5 Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Sat, 18 Apr 2026 19:55:47 -0400 Subject: [PATCH 2/3] Extract brand tokens to utils.design so import works without plotly utils/__init__.py eagerly imported COLORS from plotting.py, which now raises ImportError when plotly isn't installed. Every smoke-import job on PR #288 failed because plotting.py blocked at module load. Move COLORS + FONT_* constants to a new plotly-free utils/design.py; plotting.py re-exports them for backward compatibility and adds them to __all__. utils/__init__.py now pulls COLORS from design rather than plotting. Confirmed locally that pip uninstall plotly still lets 'import policyengine' + 'from policyengine.utils import COLORS' + 'from policyengine.core.release_manifest import get_release_manifest' all work cleanly. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/policyengine/utils/__init__.py | 3 +-- src/policyengine/utils/design.py | 24 ++++++++++++++++++ src/policyengine/utils/plotting.py | 40 ++++++++++++++---------------- 3 files changed, 44 insertions(+), 23 deletions(-) create mode 100644 src/policyengine/utils/design.py diff --git a/src/policyengine/utils/__init__.py b/src/policyengine/utils/__init__.py index bf3cc681..bfbfe10b 100644 --- a/src/policyengine/utils/__init__.py +++ b/src/policyengine/utils/__init__.py @@ -1,7 +1,6 @@ from .dates import parse_safe_date as parse_safe_date +from .design import COLORS as COLORS from .parameter_labels import build_scale_lookup as build_scale_lookup from .parameter_labels import ( generate_label_for_parameter as generate_label_for_parameter, ) -from .plotting import COLORS as COLORS -from .plotting import format_fig as format_fig diff --git a/src/policyengine/utils/design.py b/src/policyengine/utils/design.py new file mode 100644 index 00000000..eda921a1 --- /dev/null +++ b/src/policyengine/utils/design.py @@ -0,0 +1,24 @@ +"""PolicyEngine brand colours and typography tokens. + +Lives outside ``plotting`` so consumers can import ``COLORS`` without +pulling plotly in. +""" + +COLORS = { + "primary": "#319795", # Teal + "primary_light": "#E6FFFA", + "primary_dark": "#1D4044", + "success": "#22C55E", # Green (positive changes) + "warning": "#FEC601", # Yellow (cautions) + "error": "#EF4444", # Red (negative changes) + "info": "#1890FF", # Blue (neutral info) + "gray_light": "#F2F4F7", + "gray": "#667085", + "gray_dark": "#101828", + "blue_secondary": "#026AA2", +} + +FONT_FAMILY = "Inter, -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif" +FONT_SIZE_LABEL = 12 +FONT_SIZE_DEFAULT = 14 +FONT_SIZE_TITLE = 16 diff --git a/src/policyengine/utils/plotting.py b/src/policyengine/utils/plotting.py index b1700a35..15243e0e 100644 --- a/src/policyengine/utils/plotting.py +++ b/src/policyengine/utils/plotting.py @@ -2,7 +2,9 @@ Requires plotly, which is installed via the ``[plotting]`` extra (``pip install policyengine[plotting]``). Importing from this module -fails with a clear error when plotly is absent. +fails with a clear error when plotly is absent. Brand tokens +(``COLORS``, font constants) live in :mod:`policyengine.utils.design` +so they can be imported without plotly. """ from typing import TYPE_CHECKING, Optional @@ -18,26 +20,22 @@ if TYPE_CHECKING: import plotly.graph_objects as go # noqa: F401 -# PolicyEngine brand colours -COLORS = { - "primary": "#319795", # Teal - "primary_light": "#E6FFFA", - "primary_dark": "#1D4044", - "success": "#22C55E", # Green (positive changes) - "warning": "#FEC601", # Yellow (cautions) - "error": "#EF4444", # Red (negative changes) - "info": "#1890FF", # Blue (neutral info) - "gray_light": "#F2F4F7", - "gray": "#667085", - "gray_dark": "#101828", - "blue_secondary": "#026AA2", -} - -# Typography -FONT_FAMILY = "Inter, -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif" -FONT_SIZE_LABEL = 12 -FONT_SIZE_DEFAULT = 14 -FONT_SIZE_TITLE = 16 +from .design import ( + COLORS, + FONT_FAMILY, + FONT_SIZE_DEFAULT, + FONT_SIZE_LABEL, + FONT_SIZE_TITLE, +) + +__all__ = [ + "COLORS", + "FONT_FAMILY", + "FONT_SIZE_DEFAULT", + "FONT_SIZE_LABEL", + "FONT_SIZE_TITLE", + "format_fig", +] def format_fig( From 07d24daf54606278124f4a9772445be98e00d90e Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Sat, 18 Apr 2026 20:16:39 -0400 Subject: [PATCH 3/3] Drop legacy filter_field/filter_value scoping fields (v4 breaking) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Removes the two-way scoping contract in favour of the single ScopingStrategy path. The legacy fields were labeled "kept for backward compatibility" but became dead wiring the moment every caller started passing scoping_strategy explicitly. Changes: Simulation (core/simulation.py) - Drop filter_field, filter_value fields. - Drop _auto_construct_strategy model_validator that rewrote those fields into a RowFilterStrategy. Region (core/region.py) - Drop filter_field, filter_value, requires_filter fields. - Re-add requires_filter as a derived @property: True iff scoping_strategy is not None. - Simplify get_dataset_regions / get_filter_regions to use dataset_path / scoping_strategy directly. Country models (tax_benefit_models/us/model.py, .../uk/model.py) - Delete the `elif simulation.filter_field and simulation.filter_value:` fallback branch in run() — unreachable because nobody sets those fields anymore. - Delete the _filter_dataset_by_household_variable private method — only called from the elif branch. The underlying utils.entity_utils.filter_dataset_by_household_variable helper stays (it's what RowFilterStrategy.apply uses). - Drop the now-unused import. Region factories (countries/{us,uk}/regions.py) - Stop setting requires_filter=True, filter_field=..., filter_value=... alongside scoping_strategy. The scoping_strategy is already the source of truth; the duplicate legacy triple was noise. Tests - test_filtering.py: drop TestSimulationFilterParameters (fields gone) and TestUSFilterDatasetByHouseholdVariable / TestUKFilterDatasetByHouseholdVariable (method gone; underlying behaviour still covered by test_scoping_strategy.py TestRowFilterStrategy). Keep the build_entity_relationships tests. - test_scoping_strategy.py: drop three legacy-auto-construct tests, replace one with a direct WeightReplacementStrategy round-trip check. - test_region.py, test_us_regions.py, test_uk_regions.py: assertions move from `region.filter_field == "X"` to `region.scoping_strategy.variable_name == "X"`. - fixtures/region_fixtures.py: factories use scoping_strategy=RowFilterStrategy(...) directly. 212 tests pass. Downstream impact: policyengine-api-v2-alpha uses the legacy fields in ~14 call sites (grep confirmed); they migrate to RowFilterStrategy in a paired PR. The v4 migration guide will call out this single search-and-replace. Co-Authored-By: Claude Opus 4.7 (1M context) --- changelog.d/v4-drop-filter-fields.removed.md | 13 + src/policyengine/core/region.py | 47 +- src/policyengine/core/simulation.py | 38 +- src/policyengine/countries/uk/regions.py | 9 - src/policyengine/countries/us/regions.py | 3 - .../tax_benefit_models/uk/model.py | 36 +- .../tax_benefit_models/us/model.py | 39 +- tests/fixtures/region_fixtures.py | 17 +- tests/test_filtering.py | 430 +----------------- tests/test_region.py | 15 +- tests/test_scoping_strategy.py | 28 +- tests/test_uk_regions.py | 10 +- tests/test_us_regions.py | 5 +- 13 files changed, 87 insertions(+), 603 deletions(-) create mode 100644 changelog.d/v4-drop-filter-fields.removed.md diff --git a/changelog.d/v4-drop-filter-fields.removed.md b/changelog.d/v4-drop-filter-fields.removed.md new file mode 100644 index 00000000..d2130d5d --- /dev/null +++ b/changelog.d/v4-drop-filter-fields.removed.md @@ -0,0 +1,13 @@ +**BREAKING (v4):** Remove the legacy `filter_field` / `filter_value` +fields from `Simulation` and `Region`, the `_auto_construct_strategy` +model validator that rewrote them into a `RowFilterStrategy`, and the +`_filter_dataset_by_household_variable` methods they fed on both +country models. All scoping now flows through `scoping_strategy: +Optional[ScopingStrategy]`. `Region.requires_filter` becomes a derived +property (`True` iff `scoping_strategy is not None`). The sub-national +region factories (`countries/us/regions.py`, `countries/uk/regions.py`) +construct `scoping_strategy=RowFilterStrategy(...)` / +`WeightReplacementStrategy(...)` directly. Callers that previously +passed `filter_field="place_fips", filter_value="44000"` now pass +`scoping_strategy=RowFilterStrategy(variable_name="place_fips", +variable_value="44000")`. diff --git a/src/policyengine/core/region.py b/src/policyengine/core/region.py index 7ff55a64..6c5faf2a 100644 --- a/src/policyengine/core/region.py +++ b/src/policyengine/core/region.py @@ -3,7 +3,8 @@ This module provides the Region and RegionRegistry classes for defining geographic regions that a tax-benefit model supports. Regions can have: 1. A dedicated dataset (e.g., US states, congressional districts) -2. Filter from a parent region's dataset (e.g., US places/cities, UK countries) +2. A scoping strategy that derives the region from a parent dataset + (row filter or weight replacement) """ from typing import Literal, Optional, Union @@ -22,8 +23,9 @@ class Region(BaseModel): """Geographic region for tax-benefit simulations. Regions can either have: - 1. A dedicated dataset (dataset_path is set, requires_filter is False) - 2. Filter from a parent region's dataset (requires_filter is True) + 1. A dedicated dataset (``dataset_path`` is set). + 2. A scoping strategy that derives the region from a parent dataset + (``scoping_strategy`` is set). The unique identifier is the code field, which uses a prefixed format: - National: "us", "uk" @@ -57,25 +59,16 @@ class Region(BaseModel): description="GCS path to dedicated dataset (e.g., 'gs://policyengine-us-data/states/CA.h5')", ) - # Scoping strategy (preferred over legacy filter fields) + # Scoping strategy for regions that derive from a parent dataset scoping_strategy: Optional[ScopingStrategy] = Field( default=None, description="Strategy for scoping dataset to this region (row filtering or weight replacement)", ) - # Legacy filtering configuration (kept for backward compatibility) - requires_filter: bool = Field( - default=False, - description="True if this region filters from a parent dataset rather than having its own", - ) - filter_field: Optional[str] = Field( - default=None, - description="Dataset field to filter on (e.g., 'place_fips', 'country')", - ) - filter_value: Optional[str] = Field( - default=None, - description="Value to match when filtering (defaults to code suffix if not set)", - ) + @property + def requires_filter(self) -> bool: + """Whether this region needs a parent dataset + a scoping strategy.""" + return self.scoping_strategy is not None # Metadata (primarily for US congressional districts) state_code: Optional[str] = Field( @@ -180,24 +173,12 @@ def get_children(self, parent_code: str) -> list[Region]: return [r for r in self.regions if r.parent_code == parent_code] def get_dataset_regions(self) -> list[Region]: - """Get all regions that have dedicated datasets. - - Returns: - List of regions with dataset_path set and requires_filter False - """ - return [ - r - for r in self.regions - if r.dataset_path is not None and not r.requires_filter - ] + """Get all regions that have a dedicated dataset on disk.""" + return [r for r in self.regions if r.dataset_path is not None] def get_filter_regions(self) -> list[Region]: - """Get all regions that require filtering from parent datasets. - - Returns: - List of regions with requires_filter True - """ - return [r for r in self.regions if r.requires_filter] + """Get all regions that derive from a parent dataset via a scoping strategy.""" + return [r for r in self.regions if r.scoping_strategy is not None] def __len__(self) -> int: """Return the number of regions in the registry.""" diff --git a/src/policyengine/core/simulation.py b/src/policyengine/core/simulation.py index 6456e5bc..5002b141 100644 --- a/src/policyengine/core/simulation.py +++ b/src/policyengine/core/simulation.py @@ -3,13 +3,13 @@ from typing import Optional from uuid import uuid4 -from pydantic import BaseModel, Field, model_validator +from pydantic import BaseModel, Field from .cache import LRUCache from .dataset import Dataset from .dynamic import Dynamic from .policy import Policy -from .scoping_strategy import RowFilterStrategy, ScopingStrategy +from .scoping_strategy import ScopingStrategy from .tax_benefit_model_version import TaxBenefitModelVersion logger = logging.getLogger(__name__) @@ -26,42 +26,22 @@ class Simulation(BaseModel): dynamic: Optional[Dynamic] = None dataset: Dataset = None - # Scoping strategy (preferred over legacy filter fields) scoping_strategy: Optional[ScopingStrategy] = Field( default=None, description="Strategy for scoping dataset to a sub-national region", ) - # Legacy regional filtering parameters (kept for backward compatibility) - filter_field: Optional[str] = Field( - default=None, - description="Household-level variable to filter dataset by (e.g., 'place_fips', 'country')", - ) - filter_value: Optional[str] = Field( - default=None, - description="Value to match when filtering (e.g., '44000', 'ENGLAND')", + extra_variables: dict[str, list[str]] = Field( + default_factory=dict, + description=( + "Additional variables to calculate beyond the model version's " + "default entity_variables, keyed by entity name. Use when a " + "caller needs variables that are not in the bundled default set." + ), ) tax_benefit_model_version: TaxBenefitModelVersion = None - @model_validator(mode="after") - def _auto_construct_strategy(self) -> "Simulation": - """Auto-construct a RowFilterStrategy from legacy filter fields. - - If filter_field and filter_value are set but scoping_strategy is not, - create a RowFilterStrategy for backward compatibility. - """ - if ( - self.scoping_strategy is None - and self.filter_field is not None - and self.filter_value is not None - ): - self.scoping_strategy = RowFilterStrategy( - variable_name=self.filter_field, - variable_value=self.filter_value, - ) - return self - output_dataset: Optional[Dataset] = None def run(self): diff --git a/src/policyengine/countries/uk/regions.py b/src/policyengine/countries/uk/regions.py index 2f100524..d90f0ad0 100644 --- a/src/policyengine/countries/uk/regions.py +++ b/src/policyengine/countries/uk/regions.py @@ -140,9 +140,6 @@ def build_uk_region_registry( label=name, region_type="country", parent_code="uk", - requires_filter=True, - filter_field="country", - filter_value=code.upper(), scoping_strategy=RowFilterStrategy( variable_name="country", variable_value=code.upper(), @@ -161,9 +158,6 @@ def build_uk_region_registry( label=const["name"], region_type="constituency", parent_code="uk", - requires_filter=True, - filter_field="household_weight", - filter_value=const["code"], scoping_strategy=WeightReplacementStrategy( weight_matrix_bucket="policyengine-uk-data-private", weight_matrix_key="parliamentary_constituency_weights.h5", @@ -185,9 +179,6 @@ def build_uk_region_registry( label=la["name"], region_type="local_authority", parent_code="uk", - requires_filter=True, - filter_field="household_weight", - filter_value=la["code"], scoping_strategy=WeightReplacementStrategy( weight_matrix_bucket="policyengine-uk-data-private", weight_matrix_key="local_authority_weights.h5", diff --git a/src/policyengine/countries/us/regions.py b/src/policyengine/countries/us/regions.py index f335805f..9e20d8b3 100644 --- a/src/policyengine/countries/us/regions.py +++ b/src/policyengine/countries/us/regions.py @@ -101,9 +101,6 @@ def build_us_region_registry() -> RegionRegistry: label=place["name"], region_type="place", parent_code=f"state/{state_abbrev.lower()}", - requires_filter=True, - filter_field="place_fips", - filter_value=fips, state_code=state_abbrev, state_name=place["state_name"], scoping_strategy=RowFilterStrategy( diff --git a/src/policyengine/tax_benefit_models/uk/model.py b/src/policyengine/tax_benefit_models/uk/model.py index 1d6711d0..ce6f2dd9 100644 --- a/src/policyengine/tax_benefit_models/uk/model.py +++ b/src/policyengine/tax_benefit_models/uk/model.py @@ -21,10 +21,7 @@ resolve_local_managed_dataset_source, resolve_managed_dataset_reference, ) -from policyengine.utils.entity_utils import ( - build_entity_relationships, - filter_dataset_by_household_variable, -) +from policyengine.utils.entity_utils import build_entity_relationships from policyengine.utils.parameter_labels import ( build_scale_lookup, generate_label_for_parameter, @@ -281,33 +278,6 @@ def _build_entity_relationships( person_data = pd.DataFrame(dataset.data.person) return build_entity_relationships(person_data, UK_GROUP_ENTITIES) - def _filter_dataset_by_household_variable( - self, - dataset: PolicyEngineUKDataset, - variable_name: str, - variable_value: str, - ) -> PolicyEngineUKDataset: - """Filter a dataset to only include households where a variable matches.""" - filtered = filter_dataset_by_household_variable( - entity_data=dataset.data.entity_data, - group_entities=UK_GROUP_ENTITIES, - variable_name=variable_name, - variable_value=variable_value, - ) - return PolicyEngineUKDataset( - id=dataset.id + f"_filtered_{variable_name}_{variable_value}", - name=dataset.name, - description=f"{dataset.description} (filtered: {variable_name}={variable_value})", - filepath=dataset.filepath, - year=dataset.year, - is_output_dataset=dataset.is_output_dataset, - data=UKYearData( - person=filtered["person"], - benunit=filtered["benunit"], - household=filtered["household"], - ), - ) - def run(self, simulation: "Simulation") -> "Simulation": from policyengine_uk import Microsimulation from policyengine_uk.data import UKSingleYearDataset @@ -341,10 +311,6 @@ def run(self, simulation: "Simulation") -> "Simulation": household=scoped_data["household"], ), ) - elif simulation.filter_field and simulation.filter_value: - dataset = self._filter_dataset_by_household_variable( - dataset, simulation.filter_field, simulation.filter_value - ) input_data = UKSingleYearDataset( person=dataset.data.person, diff --git a/src/policyengine/tax_benefit_models/us/model.py b/src/policyengine/tax_benefit_models/us/model.py index f5aca625..cd56df09 100644 --- a/src/policyengine/tax_benefit_models/us/model.py +++ b/src/policyengine/tax_benefit_models/us/model.py @@ -21,10 +21,7 @@ resolve_local_managed_dataset_source, resolve_managed_dataset_reference, ) -from policyengine.utils.entity_utils import ( - build_entity_relationships, - filter_dataset_by_household_variable, -) +from policyengine.utils.entity_utils import build_entity_relationships from policyengine.utils.parameter_labels import ( build_scale_lookup, generate_label_for_parameter, @@ -273,36 +270,6 @@ def _build_entity_relationships( person_data = pd.DataFrame(dataset.data.person) return build_entity_relationships(person_data, US_GROUP_ENTITIES) - def _filter_dataset_by_household_variable( - self, - dataset: PolicyEngineUSDataset, - variable_name: str, - variable_value: str, - ) -> PolicyEngineUSDataset: - """Filter a dataset to only include households where a variable matches.""" - filtered = filter_dataset_by_household_variable( - entity_data=dataset.data.entity_data, - group_entities=US_GROUP_ENTITIES, - variable_name=variable_name, - variable_value=variable_value, - ) - return PolicyEngineUSDataset( - id=dataset.id + f"_filtered_{variable_name}_{variable_value}", - name=dataset.name, - description=f"{dataset.description} (filtered: {variable_name}={variable_value})", - filepath=dataset.filepath, - year=dataset.year, - is_output_dataset=dataset.is_output_dataset, - data=USYearData( - person=filtered["person"], - marital_unit=filtered["marital_unit"], - family=filtered["family"], - spm_unit=filtered["spm_unit"], - tax_unit=filtered["tax_unit"], - household=filtered["household"], - ), - ) - def run(self, simulation: "Simulation") -> "Simulation": from policyengine_us import Microsimulation from policyengine_us.system import system @@ -340,10 +307,6 @@ def run(self, simulation: "Simulation") -> "Simulation": household=scoped_data["household"], ), ) - elif simulation.filter_field and simulation.filter_value: - dataset = self._filter_dataset_by_household_variable( - dataset, simulation.filter_field, simulation.filter_value - ) # Build reform dict from policy and dynamic parameter values. # US requires reforms at Microsimulation construction time diff --git a/tests/fixtures/region_fixtures.py b/tests/fixtures/region_fixtures.py index ca1adfe2..3dc8a639 100644 --- a/tests/fixtures/region_fixtures.py +++ b/tests/fixtures/region_fixtures.py @@ -3,6 +3,7 @@ import pytest from policyengine.core.region import Region, RegionRegistry +from policyengine.core.scoping_strategy import RowFilterStrategy def create_national_region( @@ -43,15 +44,16 @@ def create_place_region( name: str, state_name: str, ) -> Region: - """Create a place region that filters from parent state.""" + """Create a place region that scopes from parent state via row filter.""" return Region( code=f"place/{state_code}-{fips}", label=name, region_type="place", parent_code=f"state/{state_code.lower()}", - requires_filter=True, - filter_field="place_fips", - filter_value=fips, + scoping_strategy=RowFilterStrategy( + variable_name="place_fips", + variable_value=fips, + ), state_code=state_code, state_name=state_name, ) @@ -107,9 +109,10 @@ def create_sample_us_registry() -> RegionRegistry: label="Paterson", region_type="place", parent_code="state/nj", - requires_filter=True, - filter_field="place_fips", - filter_value="57000", + scoping_strategy=RowFilterStrategy( + variable_name="place_fips", + variable_value="57000", + ), state_code="NJ", state_name="New Jersey", ) diff --git a/tests/test_filtering.py b/tests/test_filtering.py index 39359dd6..6588d3f9 100644 --- a/tests/test_filtering.py +++ b/tests/test_filtering.py @@ -1,87 +1,18 @@ -"""Tests for dataset filtering functionality. +"""Tests for the `_build_entity_relationships` helper on the country models. -Tests the _build_entity_relationships and _filter_dataset_by_household_variable -methods in both US and UK models. +Scoping/filtering behaviour is covered by ``tests/test_scoping_strategy.py``. """ -import pandas as pd -import pytest - -from policyengine.core.simulation import Simulation - - -class TestSimulationFilterParameters: - """Tests for Simulation filter_field and filter_value parameters.""" - - def test__given_no_filter_params__then_simulation_has_none_values(self): - """Given: Simulation created without filter parameters - When: Accessing filter_field and filter_value - Then: Both are None - """ - # When - simulation = Simulation() - - # Then - assert simulation.filter_field is None - assert simulation.filter_value is None - - def test__given_filter_params__then_simulation_stores_them(self): - """Given: Simulation created with filter parameters - When: Accessing filter_field and filter_value - Then: Values are stored correctly - """ - # When - simulation = Simulation( - filter_field="place_fips", - filter_value="44000", - ) - - # Then - assert simulation.filter_field == "place_fips" - assert simulation.filter_value == "44000" - - def test__given_filter_params__then_auto_constructs_scoping_strategy(self): - """Given: Simulation created with legacy filter parameters - When: Checking scoping_strategy - Then: RowFilterStrategy is auto-constructed - """ - from policyengine.core.scoping_strategy import RowFilterStrategy - - # When - simulation = Simulation( - filter_field="place_fips", - filter_value="44000", - ) - - # Then - assert simulation.scoping_strategy is not None - assert isinstance(simulation.scoping_strategy, RowFilterStrategy) - assert simulation.scoping_strategy.variable_name == "place_fips" - assert simulation.scoping_strategy.variable_value == "44000" - class TestUSBuildEntityRelationships: - """Tests for US model _build_entity_relationships method.""" + """US model `_build_entity_relationships`.""" - def test__given_us_dataset__then_entity_relationships_has_all_columns( - self, us_test_dataset - ): - """Given: US dataset with persons and entities - When: Building entity relationships - Then: DataFrame has all entity ID columns - """ - # Given - from policyengine.tax_benefit_models.us.model import ( - PolicyEngineUSLatest, - ) + def test__given_us_dataset__then_has_all_entity_id_columns(self, us_test_dataset): + from policyengine.tax_benefit_models.us.model import PolicyEngineUSLatest model = PolicyEngineUSLatest.__new__(PolicyEngineUSLatest) - - # When entity_rel = model._build_entity_relationships(us_test_dataset) - - # Then - expected_columns = { + assert set(entity_rel.columns) == { "person_id", "household_id", "tax_unit_id", @@ -89,366 +20,45 @@ def test__given_us_dataset__then_entity_relationships_has_all_columns( "family_id", "marital_unit_id", } - assert set(entity_rel.columns) == expected_columns - def test__given_us_dataset__then_entity_relationships_has_correct_row_count( + def test__given_us_dataset__then_row_count_equals_person_count( self, us_test_dataset ): - """Given: US dataset with 6 persons - When: Building entity relationships - Then: DataFrame has 6 rows (one per person) - """ - # Given - from policyengine.tax_benefit_models.us.model import ( - PolicyEngineUSLatest, - ) + from policyengine.tax_benefit_models.us.model import PolicyEngineUSLatest model = PolicyEngineUSLatest.__new__(PolicyEngineUSLatest) - - # When entity_rel = model._build_entity_relationships(us_test_dataset) - - # Then assert len(entity_rel) == 6 - def test__given_us_dataset__then_entity_relationships_preserves_mappings( - self, us_test_dataset - ): - """Given: US dataset where persons 1,2 belong to household 1 - When: Building entity relationships - Then: Mappings are preserved correctly - """ - # Given - from policyengine.tax_benefit_models.us.model import ( - PolicyEngineUSLatest, - ) + def test__given_us_dataset__then_mappings_preserved(self, us_test_dataset): + from policyengine.tax_benefit_models.us.model import PolicyEngineUSLatest model = PolicyEngineUSLatest.__new__(PolicyEngineUSLatest) - - # When entity_rel = model._build_entity_relationships(us_test_dataset) - - # Then person_1_row = entity_rel[entity_rel["person_id"] == 1].iloc[0] assert person_1_row["household_id"] == 1 assert person_1_row["tax_unit_id"] == 1 -class TestUSFilterDatasetByHouseholdVariable: - """Tests for US model _filter_dataset_by_household_variable method.""" - - def test__given_filter_by_place_fips__then_returns_matching_households( - self, us_test_dataset - ): - """Given: US dataset with households in places 44000 and 57000 - When: Filtering by place_fips=44000 - Then: Returns only households in place 44000 - """ - # Given - from policyengine.tax_benefit_models.us.model import ( - PolicyEngineUSLatest, - ) - - model = PolicyEngineUSLatest.__new__(PolicyEngineUSLatest) - - # When - filtered = model._filter_dataset_by_household_variable( - us_test_dataset, - variable_name="place_fips", - variable_value="44000", - ) - - # Then - household_df = pd.DataFrame(filtered.data.household) - assert len(household_df) == 2 - assert all(household_df["place_fips"] == "44000") - - def test__given_filter_by_place_fips__then_preserves_related_persons( - self, us_test_dataset - ): - """Given: US dataset with 4 persons in place 44000 - When: Filtering by place_fips=44000 - Then: Returns all 4 persons in matching households - """ - # Given - from policyengine.tax_benefit_models.us.model import ( - PolicyEngineUSLatest, - ) - - model = PolicyEngineUSLatest.__new__(PolicyEngineUSLatest) - - # When - filtered = model._filter_dataset_by_household_variable( - us_test_dataset, - variable_name="place_fips", - variable_value="44000", - ) - - # Then - person_df = pd.DataFrame(filtered.data.person) - assert len(person_df) == 4 - assert set(person_df["person_id"]) == {1, 2, 3, 4} - - def test__given_filter_by_place_fips__then_preserves_related_entities( - self, us_test_dataset - ): - """Given: US dataset with 2 tax units in place 44000 - When: Filtering by place_fips=44000 - Then: Returns all related entities (tax_unit, spm_unit, etc.) - """ - # Given - from policyengine.tax_benefit_models.us.model import ( - PolicyEngineUSLatest, - ) - - model = PolicyEngineUSLatest.__new__(PolicyEngineUSLatest) - - # When - filtered = model._filter_dataset_by_household_variable( - us_test_dataset, - variable_name="place_fips", - variable_value="44000", - ) - - # Then - assert len(pd.DataFrame(filtered.data.tax_unit)) == 2 - assert len(pd.DataFrame(filtered.data.spm_unit)) == 2 - assert len(pd.DataFrame(filtered.data.family)) == 2 - assert len(pd.DataFrame(filtered.data.marital_unit)) == 2 - - def test__given_no_matching_households__then_raises_value_error( - self, us_test_dataset - ): - """Given: US dataset with no households matching filter - When: Filtering by place_fips=99999 - Then: Raises ValueError - """ - # Given - from policyengine.tax_benefit_models.us.model import ( - PolicyEngineUSLatest, - ) - - model = PolicyEngineUSLatest.__new__(PolicyEngineUSLatest) - - # Then - with pytest.raises(ValueError, match="No households found"): - model._filter_dataset_by_household_variable( - us_test_dataset, - variable_name="place_fips", - variable_value="99999", - ) - - def test__given_invalid_variable_name__then_raises_value_error( - self, us_test_dataset - ): - """Given: US dataset - When: Filtering by non-existent variable - Then: Raises ValueError with helpful message - """ - # Given - from policyengine.tax_benefit_models.us.model import ( - PolicyEngineUSLatest, - ) - - model = PolicyEngineUSLatest.__new__(PolicyEngineUSLatest) - - # Then - with pytest.raises(ValueError, match="not found in household data"): - model._filter_dataset_by_household_variable( - us_test_dataset, - variable_name="nonexistent_var", - variable_value="value", - ) - - def test__given_filtered_dataset__then_has_updated_metadata(self, us_test_dataset): - """Given: US dataset - When: Filtering by place_fips - Then: Filtered dataset has updated id and description - """ - # Given - from policyengine.tax_benefit_models.us.model import ( - PolicyEngineUSLatest, - ) - - model = PolicyEngineUSLatest.__new__(PolicyEngineUSLatest) - - # When - filtered = model._filter_dataset_by_household_variable( - us_test_dataset, - variable_name="place_fips", - variable_value="44000", - ) - - # Then - assert "filtered" in filtered.id - assert "place_fips=44000" in filtered.description - - class TestUKBuildEntityRelationships: - """Tests for UK model _build_entity_relationships method.""" + """UK model `_build_entity_relationships`.""" - def test__given_uk_dataset__then_entity_relationships_has_all_columns( - self, uk_test_dataset - ): - """Given: UK dataset with persons and entities - When: Building entity relationships - Then: DataFrame has all entity ID columns - """ - # Given - from policyengine.tax_benefit_models.uk.model import ( - PolicyEngineUKLatest, - ) + def test__given_uk_dataset__then_has_all_entity_id_columns(self, uk_test_dataset): + from policyengine.tax_benefit_models.uk.model import PolicyEngineUKLatest model = PolicyEngineUKLatest.__new__(PolicyEngineUKLatest) - - # When entity_rel = model._build_entity_relationships(uk_test_dataset) + assert set(entity_rel.columns) == { + "person_id", + "benunit_id", + "household_id", + } - # Then - expected_columns = {"person_id", "benunit_id", "household_id"} - assert set(entity_rel.columns) == expected_columns - - def test__given_uk_dataset__then_entity_relationships_has_correct_row_count( + def test__given_uk_dataset__then_row_count_equals_person_count( self, uk_test_dataset ): - """Given: UK dataset with 6 persons - When: Building entity relationships - Then: DataFrame has 6 rows (one per person) - """ - # Given - from policyengine.tax_benefit_models.uk.model import ( - PolicyEngineUKLatest, - ) + from policyengine.tax_benefit_models.uk.model import PolicyEngineUKLatest model = PolicyEngineUKLatest.__new__(PolicyEngineUKLatest) - - # When entity_rel = model._build_entity_relationships(uk_test_dataset) - - # Then assert len(entity_rel) == 6 - - -class TestUKFilterDatasetByHouseholdVariable: - """Tests for UK model _filter_dataset_by_household_variable method.""" - - def test__given_filter_by_country__then_returns_matching_households( - self, uk_test_dataset - ): - """Given: UK dataset with households in England and Scotland - When: Filtering by country=ENGLAND - Then: Returns only households in England - """ - # Given - from policyengine.tax_benefit_models.uk.model import ( - PolicyEngineUKLatest, - ) - - model = PolicyEngineUKLatest.__new__(PolicyEngineUKLatest) - - # When - filtered = model._filter_dataset_by_household_variable( - uk_test_dataset, - variable_name="country", - variable_value="ENGLAND", - ) - - # Then - household_df = pd.DataFrame(filtered.data.household) - assert len(household_df) == 2 - assert all(household_df["country"] == "ENGLAND") - - def test__given_filter_by_country__then_preserves_related_persons( - self, uk_test_dataset - ): - """Given: UK dataset with 4 persons in England - When: Filtering by country=ENGLAND - Then: Returns all 4 persons in matching households - """ - # Given - from policyengine.tax_benefit_models.uk.model import ( - PolicyEngineUKLatest, - ) - - model = PolicyEngineUKLatest.__new__(PolicyEngineUKLatest) - - # When - filtered = model._filter_dataset_by_household_variable( - uk_test_dataset, - variable_name="country", - variable_value="ENGLAND", - ) - - # Then - person_df = pd.DataFrame(filtered.data.person) - assert len(person_df) == 4 - assert set(person_df["person_id"]) == {1, 2, 3, 4} - - def test__given_filter_by_country__then_preserves_related_benunits( - self, uk_test_dataset - ): - """Given: UK dataset with 2 benunits in England - When: Filtering by country=ENGLAND - Then: Returns all related benunits - """ - # Given - from policyengine.tax_benefit_models.uk.model import ( - PolicyEngineUKLatest, - ) - - model = PolicyEngineUKLatest.__new__(PolicyEngineUKLatest) - - # When - filtered = model._filter_dataset_by_household_variable( - uk_test_dataset, - variable_name="country", - variable_value="ENGLAND", - ) - - # Then - assert len(pd.DataFrame(filtered.data.benunit)) == 2 - - def test__given_no_matching_households__then_raises_value_error( - self, uk_test_dataset - ): - """Given: UK dataset with no households matching filter - When: Filtering by country=WALES - Then: Raises ValueError - """ - # Given - from policyengine.tax_benefit_models.uk.model import ( - PolicyEngineUKLatest, - ) - - model = PolicyEngineUKLatest.__new__(PolicyEngineUKLatest) - - # Then - with pytest.raises(ValueError, match="No households found"): - model._filter_dataset_by_household_variable( - uk_test_dataset, - variable_name="country", - variable_value="WALES", - ) - - def test__given_filtered_dataset__then_has_updated_metadata(self, uk_test_dataset): - """Given: UK dataset - When: Filtering by country - Then: Filtered dataset has updated id and description - """ - # Given - from policyengine.tax_benefit_models.uk.model import ( - PolicyEngineUKLatest, - ) - - model = PolicyEngineUKLatest.__new__(PolicyEngineUKLatest) - - # When - filtered = model._filter_dataset_by_household_variable( - uk_test_dataset, - variable_name="country", - variable_value="ENGLAND", - ) - - # Then - assert "filtered" in filtered.id - assert "country=ENGLAND" in filtered.description diff --git a/tests/test_region.py b/tests/test_region.py index e13d5b5e..fa54124a 100644 --- a/tests/test_region.py +++ b/tests/test_region.py @@ -43,18 +43,19 @@ def test__given_dataset_path__then_region_has_dedicated_dataset(self): assert region.state_code == "CA" assert not region.requires_filter - def test__given_filter_configuration__then_region_requires_filter(self): - """Given: Region with requires_filter=True and filter fields + def test__given_scoping_strategy__then_region_requires_filter(self): + """Given: Region with a RowFilterStrategy on the parent dataset When: Creating the Region - Then: Region is configured for filtering from parent + Then: Region.requires_filter is derived from scoping_strategy presence """ - # Given (using fixture) + from policyengine.core.scoping_strategy import RowFilterStrategy + region = FILTER_REGION - # Then assert region.requires_filter is True - assert region.filter_field == "place_fips" - assert region.filter_value == "57000" + assert isinstance(region.scoping_strategy, RowFilterStrategy) + assert region.scoping_strategy.variable_name == "place_fips" + assert region.scoping_strategy.variable_value == "57000" def test__given_same_codes__then_regions_are_equal(self): """Given: Two regions with the same code diff --git a/tests/test_scoping_strategy.py b/tests/test_scoping_strategy.py index a7a7200b..334cad1b 100644 --- a/tests/test_scoping_strategy.py +++ b/tests/test_scoping_strategy.py @@ -265,39 +265,17 @@ def test__given_explicit_strategy__then_simulation_stores_it(self): assert sim.scoping_strategy is not None assert isinstance(sim.scoping_strategy, RowFilterStrategy) - def test__given_legacy_filter_fields__then_auto_constructs_row_filter( - self, - ): - sim = Simulation( - filter_field="place_fips", - filter_value="44000", - ) - assert sim.scoping_strategy is not None - assert isinstance(sim.scoping_strategy, RowFilterStrategy) - assert sim.scoping_strategy.variable_name == "place_fips" - assert sim.scoping_strategy.variable_value == "44000" - - def test__given_explicit_strategy_and_legacy_fields__then_explicit_wins( - self, - ): - explicit = WeightReplacementStrategy( + def test__given_weight_replacement__then_simulation_stores_it(self): + strategy = WeightReplacementStrategy( weight_matrix_bucket="bucket", weight_matrix_key="key.h5", lookup_csv_bucket="bucket", lookup_csv_key="lookup.csv", region_code="E14001234", ) - sim = Simulation( - scoping_strategy=explicit, - filter_field="household_weight", - filter_value="E14001234", - ) + sim = Simulation(scoping_strategy=strategy) assert isinstance(sim.scoping_strategy, WeightReplacementStrategy) - def test__given_only_filter_field_no_value__then_no_auto_construct(self): - sim = Simulation(filter_field="place_fips") - assert sim.scoping_strategy is None - # Fixtures for scoping strategy tests @pytest.fixture diff --git a/tests/test_uk_regions.py b/tests/test_uk_regions.py index 57a55992..56f5a5fd 100644 --- a/tests/test_uk_regions.py +++ b/tests/test_uk_regions.py @@ -97,8 +97,8 @@ def test__given_england_region__then_filters_from_national(self): assert england.region_type == "country" assert england.parent_code == "uk" assert england.requires_filter - assert england.filter_field == "country" - assert england.filter_value == "ENGLAND" + assert england.scoping_strategy.variable_name == "country" + assert england.scoping_strategy.variable_value == "ENGLAND" assert england.dataset_path is None def test__given_country_regions__then_have_row_filter_strategy(self): @@ -126,7 +126,7 @@ def test__given_scotland_region__then_filters_from_national(self): assert scotland is not None assert scotland.label == "Scotland" assert scotland.requires_filter - assert scotland.filter_value == "SCOTLAND" + assert scotland.scoping_strategy.variable_value == "SCOTLAND" def test__given_wales_region__then_filters_from_national(self): """Given: Wales country region @@ -140,7 +140,7 @@ def test__given_wales_region__then_filters_from_national(self): assert wales is not None assert wales.label == "Wales" assert wales.requires_filter - assert wales.filter_value == "WALES" + assert wales.scoping_strategy.variable_value == "WALES" def test__given_northern_ireland_region__then_filters_from_national(self): """Given: Northern Ireland country region @@ -154,7 +154,7 @@ def test__given_northern_ireland_region__then_filters_from_national(self): assert ni is not None assert ni.label == "Northern Ireland" assert ni.requires_filter - assert ni.filter_value == "NORTHERN_IRELAND" + assert ni.scoping_strategy.variable_value == "NORTHERN_IRELAND" def test__given_uk_national__then_children_are_countries(self): """Given: UK national region diff --git a/tests/test_us_regions.py b/tests/test_us_regions.py index 079ce1c5..7c038556 100644 --- a/tests/test_us_regions.py +++ b/tests/test_us_regions.py @@ -210,8 +210,9 @@ def test__given_los_angeles_region__then_has_correct_format(self): assert la.region_type == "place" assert la.parent_code == "state/ca" assert la.requires_filter - assert la.filter_field == "place_fips" - assert la.filter_value == "44000" + assert la.scoping_strategy is not None + assert la.scoping_strategy.variable_name == "place_fips" + assert la.scoping_strategy.variable_value == "44000" assert la.state_code == "CA" assert la.dataset_path is None # No dedicated dataset