From e279471b448a0402b65ecf91a06a6b1e9c32b9b2 Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Thu, 28 May 2026 01:55:20 -0400 Subject: [PATCH] Add Arch target core helpers --- pyproject.toml | 2 +- src/microplex/__init__.py | 10 +- src/microplex/targets/__init__.py | 40 +++ src/microplex/targets/arch.py | 171 +++++++++++++ src/microplex/targets/rollups.py | 353 ++++++++++++++++++++++++++ src/microplex/tax_units.py | 403 ++++++++++++++++++++++++++++++ tests/targets/test_arch.py | 111 ++++++++ tests/targets/test_rollups.py | 137 ++++++++++ tests/test_tax_units.py | 162 ++++++++++++ 9 files changed, 1386 insertions(+), 3 deletions(-) create mode 100644 src/microplex/targets/arch.py create mode 100644 src/microplex/targets/rollups.py create mode 100644 src/microplex/tax_units.py create mode 100644 tests/targets/test_arch.py create mode 100644 tests/targets/test_rollups.py create mode 100644 tests/test_tax_units.py diff --git a/pyproject.toml b/pyproject.toml index ef78ec0..70a07fe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "microplex" -version = "0.1.0" +version = "0.2.0" description = "Microdata synthesis and reweighting using normalizing flows" readme = "README.md" license = "MIT" diff --git a/src/microplex/__init__.py b/src/microplex/__init__.py index 65fad0e..2c278b8 100644 --- a/src/microplex/__init__.py +++ b/src/microplex/__init__.py @@ -65,7 +65,7 @@ CategoricalModel, DiscreteModelCollection, ) -from microplex.flows import AffineCouplingLayer, ConditionalMAF, MADE +from microplex.flows import MADE, AffineCouplingLayer, ConditionalMAF from microplex.fusion import ( FusionConfig, FusionResult, @@ -84,6 +84,10 @@ create_synthesizer, ) from microplex.synthesizer import Synthesizer +from microplex.tax_units import ( + PreservedTaxUnitTables, + build_preserved_tax_unit_tables, +) from microplex.transforms import ( LogTransform, MultiVariableTransformer, @@ -105,13 +109,15 @@ # pipelines or 
custom loss functions. DefaultSparseCalibrator = SparseCalibrator -__version__ = "0.1.0" +__version__ = "0.2.0" __all__ = [ # Core synthesis "Synthesizer", "HierarchicalSynthesizer", "HouseholdSchema", + "PreservedTaxUnitTables", + "build_preserved_tax_unit_tables", # Calibration "Reweighter", "Calibrator", diff --git a/src/microplex/targets/__init__.py b/src/microplex/targets/__init__.py index a904989..9964a88 100644 --- a/src/microplex/targets/__init__.py +++ b/src/microplex/targets/__init__.py @@ -1,5 +1,17 @@ """Target primitives for microplex.""" +from microplex.targets.arch import ( + ARCH_CONSUMER_FACT_SCHEMA_VERSION, + ArchConsumerFact, + arch_consumer_fact_concept, + arch_consumer_fact_numeric_value, + arch_consumer_fact_period, + arch_consumer_fact_source_record_id, + iter_arch_consumer_facts, + load_arch_consumer_fact_jsonl_rows, + load_arch_consumer_facts, + mapping_value, +) from microplex.targets.artifacts import ( BenchmarkArtifactValidationResult, assert_valid_benchmark_artifact_manifest, @@ -65,6 +77,16 @@ reweight_to_target_constraints, sparse_constraint_abs_rel_error, ) +from microplex.targets.rollups import ( + TabularRollupSpec, + TabularRollupTargetProvider, + as_string_tuple, + build_tabular_rollup_targets, + normalize_rollup_id, + resolve_rollup_keys, + tabular_rollup_target, + target_name_fragment, +) from microplex.targets.spec import ( FilterOperator, TargetAggregation, @@ -83,6 +105,16 @@ "BenchmarkArtifactValidationResult", "validate_benchmark_artifact_manifest", "assert_valid_benchmark_artifact_manifest", + "ARCH_CONSUMER_FACT_SCHEMA_VERSION", + "ArchConsumerFact", + "arch_consumer_fact_concept", + "arch_consumer_fact_numeric_value", + "arch_consumer_fact_period", + "arch_consumer_fact_source_record_id", + "iter_arch_consumer_facts", + "load_arch_consumer_fact_jsonl_rows", + "load_arch_consumer_facts", + "mapping_value", "TargetProvider", "TargetQuery", "StaticTargetProvider", @@ -136,4 +168,12 @@ "numeric_series", "TargetSet", 
ARCH_CONSUMER_FACT_SCHEMA_VERSION = "arch.consumer_fact.v1"


@dataclass(frozen=True)
class ArchConsumerFact:
    """Read-only view over one Arch consumer-contract fact row.

    ``row`` holds the decoded JSON object. ``path`` and ``line_number`` record
    where the row was read from when it was loaded from a JSONL file.
    """

    row: Mapping[str, Any]
    path: str | None = None
    line_number: int | None = None

    @property
    def concept(self) -> str | None:
        """Return the canonical concept, else the observed source concept."""
        return arch_consumer_fact_concept(self.row)

    @property
    def period(self) -> int:
        """Return the year represented by this fact."""
        return arch_consumer_fact_period(self.row)

    @property
    def value(self) -> float:
        """Return the fact's ``value`` entry as a float."""
        return arch_consumer_fact_numeric_value(self.row.get("value"))

    @property
    def geography(self) -> Mapping[str, Any]:
        """Return the ``geography`` payload, or an empty mapping."""
        return mapping_value(self.row.get("geography"))

    @property
    def source(self) -> Mapping[str, Any]:
        """Return the ``source`` payload, or an empty mapping."""
        return mapping_value(self.row.get("source"))

    @property
    def source_record_id(self) -> str | None:
        """Return ``lineage.source_record_id`` as a string, when present."""
        return arch_consumer_fact_source_record_id(self.row)


def load_arch_consumer_fact_jsonl_rows(
    paths: Iterable[str | Path],
    *,
    period: int | None = None,
    schema_version: str = ARCH_CONSUMER_FACT_SCHEMA_VERSION,
) -> tuple[dict[str, Any], ...]:
    """Load validated Arch consumer fact rows as plain dicts.

    Args:
        paths: JSONL files to read, in order.
        period: When given, keep only rows whose period year equals it.
        schema_version: Required value of each row's ``schema_version``.

    Returns:
        A shallow ``dict`` copy of every accepted row, in file order.

    Raises:
        ValueError: If a row is not a JSON object or has another schema.
    """
    return tuple(
        dict(fact.row)
        for path in paths
        for fact in iter_arch_consumer_facts(
            path,
            period=period,
            schema_version=schema_version,
        )
    )


def load_arch_consumer_facts(
    paths: Iterable[str | Path],
    *,
    period: int | None = None,
    schema_version: str = ARCH_CONSUMER_FACT_SCHEMA_VERSION,
) -> tuple[ArchConsumerFact, ...]:
    """Load validated Arch consumer facts from one or more JSONL files.

    Takes the same arguments and raises the same errors as
    :func:`iter_arch_consumer_facts`, applied to each path in turn.
    """
    return tuple(
        fact
        for path in paths
        for fact in iter_arch_consumer_facts(
            path,
            period=period,
            schema_version=schema_version,
        )
    )


def iter_arch_consumer_facts(
    pathlike: str | Path,
    *,
    period: int | None = None,
    schema_version: str = ARCH_CONSUMER_FACT_SCHEMA_VERSION,
) -> Iterable[ArchConsumerFact]:
    """Yield validated Arch consumer facts from one JSONL file.

    Blank lines are skipped. Line numbers are 1-based and count blank lines.

    Args:
        pathlike: JSONL file to read.
        period: When given, yield only rows whose period year equals it.
        schema_version: Required value of each row's ``schema_version``.

    Raises:
        ValueError: If a line is not a JSON object, or its schema version
            differs from ``schema_version``. ``json.JSONDecodeError`` (a
            ``ValueError`` subclass) propagates for malformed JSON.
    """
    path = Path(pathlike)
    wanted_period = None if period is None else int(period)
    # JSON interchange text is UTF-8; do not depend on the locale default.
    with path.open(encoding="utf-8") as file:
        for line_number, line in enumerate(file, start=1):
            if not line.strip():
                continue
            row = json.loads(line)
            if not isinstance(row, Mapping):
                # Fail with file/line context instead of an AttributeError
                # from ``row.get`` below.
                raise ValueError(
                    "Arch consumer fact is not a JSON object "
                    f"({type(row).__name__}) in {path} line {line_number}."
                )
            observed_schema_version = row.get("schema_version")
            if observed_schema_version != schema_version:
                raise ValueError(
                    "Unsupported Arch consumer fact schema "
                    f"{observed_schema_version!r} in {path} line {line_number}; "
                    f"expected {schema_version!r}."
                )
            if (
                wanted_period is not None
                and arch_consumer_fact_period(row) != wanted_period
            ):
                continue
            yield ArchConsumerFact(
                row=row,
                path=str(path),
                line_number=line_number,
            )


def arch_consumer_fact_concept(row: Mapping[str, Any]) -> str | None:
    """Return a row's canonical concept, falling back to source concept.

    A falsy ``concept_alignment.canonical_concept`` (missing, ``None``, empty)
    falls through to ``observed_measure.source_concept``.
    """
    concept_alignment = mapping_value(row.get("concept_alignment"))
    observed_measure = mapping_value(row.get("observed_measure"))
    concept = concept_alignment.get("canonical_concept") or observed_measure.get(
        "source_concept"
    )
    return None if concept is None else str(concept)


def arch_consumer_fact_period(row: Mapping[str, Any]) -> int:
    """Return a consumer fact period as an integer year.

    For ``type == "month"`` with a string value, the year is the text before
    the first ``-`` (``"2025-01"`` -> ``2025``).

    Raises:
        KeyError: If the period payload has no ``value`` entry.
        ValueError: If the value cannot be converted to ``int``.
    """
    period = mapping_value(row.get("period"))
    value = period["value"]
    if period.get("type") == "month" and isinstance(value, str):
        return int(value.split("-", maxsplit=1)[0])
    return int(value)


def arch_consumer_fact_source_record_id(row: Mapping[str, Any]) -> str | None:
    """Return ``lineage.source_record_id`` as a string, or ``None``."""
    lineage = mapping_value(row.get("lineage"))
    source_record_id = lineage.get("source_record_id")
    return None if source_record_id is None else str(source_record_id)


def arch_consumer_fact_numeric_value(value: Any) -> float:
    """Return a numeric consumer fact value as a float.

    Accepts ``int``, ``float`` and numeric strings. ``bool`` is rejected even
    though it subclasses ``int``.

    Raises:
        ValueError: If ``value`` is ``None``, a bool, another type, or a
            string that ``float`` cannot parse.
    """
    if isinstance(value, (int, float, str)) and not isinstance(value, bool):
        return float(value)
    raise ValueError(f"Arch consumer fact value is not numeric: {value!r}")


def mapping_value(value: Any) -> Mapping[str, Any]:
    """Return ``value`` if it is a mapping, otherwise an empty mapping."""
    return value if isinstance(value, Mapping) else {}
@dataclass(frozen=True)
class TabularRollupSpec:
    """Describe one family of rollup targets derivable from a table."""

    # Geography level label copied into each target's ``geo_level`` metadata.
    geo_level: str
    # Column to group by; ``None`` yields a single whole-table total.
    source_column: str | None
    # Feature used for the equality filter on per-geography targets;
    # ``None`` means no filter is attached.
    filter_feature: str | None
    # Value stored under ``target_group`` in target metadata.
    group_name: str
    # Target name, or its prefix when a geography ID is appended.
    name_prefix: str
    # Extra metadata merged last into every target built from this spec.
    metadata: Mapping[str, Any] = field(default_factory=dict)


class TabularRollupTargetProvider:
    """Target provider that aggregates a table into per-rollup target specs.

    The table is either supplied directly (``data``) or loaded on demand via
    ``data_loader(data_path)``. When ``prepare_data`` is given it is applied
    to the table before every build.
    """

    def __init__(
        self,
        data: pd.DataFrame | None = None,
        *,
        data_path: str | Path | None = None,
        data_loader: Callable[[str | Path | None], pd.DataFrame] | None = None,
        prepare_data: Callable[[pd.DataFrame], pd.DataFrame] | None = None,
        rollups: Mapping[str, TabularRollupSpec],
        value_column: str,
        variable: str,
        entity: EntityType | str,
        period: int | str,
        source: str | None = None,
        units: str | None = None,
        aggregation: TargetAggregation | str = TargetAggregation.COUNT,
        measure: str | None = None,
        default_geo_levels: Iterable[str] | None = None,
        variable_aliases: Iterable[str] = (),
        base_metadata: Mapping[str, Any] | None = None,
        min_value: float | None = None,
        normalize_geographic_id: Callable[[Any], str] | None = None,
    ) -> None:
        # Table access and preprocessing hooks.
        self._data = data
        self.data_path = None if data_path is None else Path(data_path)
        self.data_loader = data_loader or _read_parquet
        self.prepare_data = prepare_data

        # Rollup configuration, snapshotted into containers owned by the
        # provider. Default levels fall back to every configured rollup key.
        self.rollups = dict(rollups)
        self.default_geo_levels = tuple(default_geo_levels or self.rollups)
        self.normalize_geographic_id = normalize_geographic_id or normalize_rollup_id
        self.min_value = min_value

        # Fields forwarded to every generated target.
        self.value_column = value_column
        self.variable = variable
        self.variable_aliases = tuple(variable_aliases)
        self.entity = entity
        self.period = period
        self.source = source
        self.units = units
        self.aggregation = aggregation
        self.measure = measure
        self.base_metadata = dict(base_metadata or {})

    def load_target_set(self, query: TargetQuery | None = None) -> TargetSet:
        """Build rollup targets honouring the query's provider filters.

        Recognised provider filters are ``variables``, ``geo_levels`` (or
        ``geographic_levels`` when ``geo_levels`` is absent) and
        ``geographic_ids``. The remaining query fields (period, entity, names,
        metadata filters) are applied to the built targets afterwards.
        """
        active_query = query or TargetQuery()
        filters = dict(active_query.provider_filters)

        # A variables filter that names neither this provider's variable nor
        # one of its aliases selects nothing.
        requested_variables = as_string_tuple(filters.get("variables"))
        if requested_variables:
            known_variables = {self.variable, *self.variable_aliases}
            if known_variables.isdisjoint(requested_variables):
                return TargetSet([])

        level_key = "geo_levels" if "geo_levels" in filters else "geographic_levels"
        geo_levels = resolve_rollup_keys(
            filters.get(level_key),
            rollups=self.rollups,
            default_geo_levels=self.default_geo_levels,
        )
        geographic_ids = as_string_tuple(filters.get("geographic_ids"))

        built = build_tabular_rollup_targets(
            self._load_data(),
            rollups=self.rollups,
            value_column=self.value_column,
            variable=self.variable,
            entity=self.entity,
            period=self.period,
            source=self.source,
            units=self.units,
            aggregation=self.aggregation,
            measure=self.measure,
            geo_levels=geo_levels,
            # An empty ID filter means "no restriction".
            geographic_ids=geographic_ids or None,
            base_metadata=self.base_metadata,
            min_value=self.min_value,
            normalize_geographic_id=self.normalize_geographic_id,
        )
        # Provider filters were consumed above; pass on only the generic ones.
        generic_query = TargetQuery(
            period=active_query.period,
            entity=active_query.entity,
            names=active_query.names,
            metadata_filters=active_query.metadata_filters,
        )
        return apply_target_query(TargetSet(built), generic_query)

    def _load_data(self) -> pd.DataFrame:
        """Return a private copy of the table, or load it, then prepare it."""
        if self._data is not None:
            table = self._data.copy()
        else:
            table = self.data_loader(self.data_path)
        if self.prepare_data is None:
            return table
        return self.prepare_data(table)
def build_tabular_rollup_targets(
    data: pd.DataFrame,
    *,
    rollups: Mapping[str, TabularRollupSpec],
    value_column: str,
    variable: str,
    entity: EntityType | str,
    period: int | str,
    source: str | None = None,
    units: str | None = None,
    aggregation: TargetAggregation | str = TargetAggregation.COUNT,
    measure: str | None = None,
    geo_levels: Iterable[str] | None = None,
    geographic_ids: Iterable[str] | None = None,
    base_metadata: Mapping[str, Any] | None = None,
    min_value: float | None = None,
    normalize_geographic_id: Callable[[Any], str] | None = None,
) -> list[TargetSpec]:
    """Aggregate ``value_column`` per rollup and emit one target per cell.

    Rollups without a ``source_column`` produce one whole-table total and are
    skipped when a non-empty ``geographic_ids`` filter is given. Rollups whose
    ``source_column`` is absent from ``data`` are skipped silently. Cells
    whose total is less than or equal to ``min_value`` are dropped.

    Raises:
        ValueError: If ``value_column`` is missing from ``data`` or an
            unknown geo level is requested.
    """
    if value_column not in data.columns:
        raise ValueError(f"Tabular rollup data must include {value_column!r}")

    to_geo_id = normalize_geographic_id or normalize_rollup_id
    keys = resolve_rollup_keys(
        geo_levels,
        rollups=rollups,
        default_geo_levels=tuple(rollups),
    )
    id_filter = (
        None
        if geographic_ids is None
        else {to_geo_id(item) for item in geographic_ids}
    )

    # Work on a copy so numeric coercion never touches the caller's frame.
    frame = data.copy()
    frame[value_column] = pd.to_numeric(frame[value_column], errors="coerce")

    # Arguments shared by every target built below.
    shared = {
        "variable": variable,
        "entity": entity,
        "period": period,
        "source": source,
        "units": units,
        "aggregation": aggregation,
        "measure": measure,
        "base_metadata": base_metadata,
    }

    def at_or_below_floor(amount: float) -> bool:
        return min_value is not None and amount <= min_value

    targets: list[TargetSpec] = []
    for key in keys:
        spec = rollups[key]
        column = spec.source_column

        if column is None:
            # Whole-table total; meaningless under a geography ID filter.
            if not id_filter:
                total = float(frame[value_column].sum())
                if not at_or_below_floor(total):
                    targets.append(
                        tabular_rollup_target(
                            rollup=spec,
                            geographic_id=None,
                            value=total,
                            **shared,
                        )
                    )
            continue

        if column not in frame.columns:
            continue

        cells = frame[[column, value_column]].dropna(subset=[column, value_column])
        cells[column] = cells[column].map(to_geo_id)
        # Normalisation maps unusable IDs to "", which is falsy.
        cells = cells[cells[column].astype(bool)]
        if id_filter is not None:
            cells = cells[cells[column].isin(id_filter)]

        totals = cells.groupby(column, dropna=True)[value_column].sum()
        for geo_id, raw_amount in totals.items():
            amount = float(raw_amount)
            if at_or_below_floor(amount):
                continue
            targets.append(
                tabular_rollup_target(
                    rollup=spec,
                    geographic_id=str(geo_id),
                    value=amount,
                    **shared,
                )
            )
    return targets


def tabular_rollup_target(
    *,
    rollup: TabularRollupSpec,
    geographic_id: str | None,
    value: float,
    variable: str,
    entity: EntityType | str,
    period: int | str,
    source: str | None = None,
    units: str | None = None,
    aggregation: TargetAggregation | str = TargetAggregation.COUNT,
    measure: str | None = None,
    base_metadata: Mapping[str, Any] | None = None,
) -> TargetSpec:
    """Build the target spec for a single rollup cell.

    The name is ``rollup.name_prefix``, suffixed with a sanitised geography ID
    when one is given. An equality filter on ``rollup.filter_feature`` is
    attached only when both the feature and the geography ID are present.
    Metadata precedence, lowest to highest: built-in keys, ``base_metadata``,
    ``rollup.metadata``.
    """
    has_filter = geographic_id is not None and rollup.filter_feature is not None
    if has_filter:
        filters = (
            TargetFilter(
                feature=rollup.filter_feature,
                operator=FilterOperator.EQ,
                value=geographic_id,
            ),
        )
    else:
        filters = ()

    if geographic_id is None:
        name = rollup.name_prefix
    else:
        name = f"{rollup.name_prefix}_{target_name_fragment(geographic_id)}"

    metadata = {
        "variable": variable,
        "geo_level": rollup.geo_level,
        "geographic_id": geographic_id,
        "target_group": rollup.group_name,
        "tabular_rollup": True,
    }
    metadata.update(dict(base_metadata or {}))
    metadata.update(dict(rollup.metadata))

    return TargetSpec(
        name=name,
        entity=entity,
        value=value,
        period=period,
        measure=measure,
        aggregation=aggregation,
        filters=filters,
        source=source,
        units=units,
        metadata=metadata,
    )
def resolve_rollup_keys(
    geo_levels: Iterable[str] | Any | None,
    *,
    rollups: Mapping[str, TabularRollupSpec],
    default_geo_levels: Iterable[str],
) -> tuple[str, ...]:
    """Return the rollup keys to build, validated against ``rollups``.

    ``None`` selects ``default_geo_levels``; the single value ``"all"``
    selects every key of ``rollups``.

    Raises:
        ValueError: If any requested key is not a key of ``rollups``.
    """
    if geo_levels is None:
        requested = tuple(default_geo_levels)
    else:
        requested = as_string_tuple(geo_levels)
    if requested == ("all",):
        requested = tuple(rollups)

    unknown = sorted(key for key in set(requested) if key not in rollups)
    if unknown:
        raise ValueError(f"Unsupported tabular rollup geo levels: {unknown}")
    return requested


def as_string_tuple(value: Any) -> tuple[str, ...]:
    """Turn a provider filter value into a tuple of strings.

    ``None`` becomes ``()``, a string becomes a one-tuple, an iterable is
    stringified element by element, and any other scalar is stringified.
    """
    if value is None:
        return ()
    if isinstance(value, str):
        return (value,)
    try:
        return tuple(map(str, value))
    except TypeError:
        # Not iterable: treat it as a single scalar.
        return (str(value),)


def normalize_rollup_id(value: Any) -> str:
    """Return a stable string geography ID for a pandas scalar.

    Missing values and blank text become ``""``. Text consisting of digits
    followed by ``.0`` loses that suffix (``"6.0"`` -> ``"6"``), undoing the
    float coercion of integer IDs.
    """
    if pd.isna(value):
        return ""
    text = str(value).strip()
    stem = text[:-2]
    if text[-2:] == ".0" and stem.isdigit():
        return stem
    return text


def target_name_fragment(geographic_id: str) -> str:
    """Return ``geographic_id`` with spaces, ``-``, ``/`` and ``.`` as ``_``."""
    return geographic_id.translate(str.maketrans(" -/.", "____"))


def _read_parquet(path: str | Path | None) -> pd.DataFrame:
    """Default loader: read the table at ``path`` with ``pd.read_parquet``."""
    if path is not None:
        return pd.read_parquet(path)
    raise ValueError("A data_path or custom data_loader is required")
@dataclass(frozen=True)
class PreservedTaxUnitTables:
    """Result bundle returned by :func:`build_preserved_tax_unit_tables`."""

    # One row per preserved tax unit.
    tax_units: pd.DataFrame
    # Copy of the kept person rows, with tax-unit IDs renumbered.
    persons: pd.DataFrame
    # Household IDs whose person rows were kept.
    preserved_households: set[Any]


def build_preserved_tax_unit_tables(
    persons: pd.DataFrame,
    *,
    household_id_col: str = "household_id",
    tax_unit_id_col: str = "tax_unit_id",
    person_id_col: str = "person_id",
    income_col: str = "income",
    complete_households_only: bool = True,
    deduplicated_sum_columns: Sequence[str] = (
        "health_savings_account_ald",
        "unrecaptured_section_1250_gain",
    ),
) -> PreservedTaxUnitTables:
    """Build tax-unit rows from existing person-level tax-unit IDs.

    Use this when tax-unit membership already exists and should be preserved.
    It normalizes household/tax-unit pairs to dense integer IDs, then
    aggregates with NumPy reductions instead of pandas groupby/MultiIndex
    construction. Use :class:`microplex.hierarchical.TaxUnitOptimizer` when
    memberships need to be inferred from household composition.

    Args:
        persons: Person-level table. The input frame is not modified.
        household_id_col: Column holding household IDs.
        tax_unit_id_col: Column holding existing tax-unit IDs. In the returned
            person table it is overwritten with dense IDs ``1..n``.
        person_id_col: Column that must be present in ``persons``.
        income_col: Column summed into ``total_income``; a missing column
            contributes zeros.
        complete_households_only: When true, drop every household that has at
            least one person without a tax-unit ID. When false, any missing
            tax-unit ID is an error.
        deduplicated_sum_columns: Person columns copied to the tax-unit table
            with a deduplicating sum, when present.

    Returns:
        The tax-unit table, the kept person rows, and the kept household IDs.

    Raises:
        KeyError: If a required ID column is missing.
        ValueError: If ``complete_households_only`` is false and a tax-unit ID
            is missing.
    """
    _require_columns(
        persons,
        [person_id_col, household_id_col, tax_unit_id_col],
    )
    if persons.empty:
        return PreservedTaxUnitTables(
            tax_units=_empty_tax_units(),
            persons=persons.copy(),
            preserved_households=set(),
        )

    has_tax_unit_id = _non_missing_id_mask(persons[tax_unit_id_col])
    if complete_households_only:
        row_mask, preserved_households = _complete_household_row_mask(
            persons[household_id_col],
            has_tax_unit_id,
        )
    else:
        if not bool(has_tax_unit_id.all()):
            # Name the column actually inspected; it is caller-configurable.
            raise ValueError(
                f"{tax_unit_id_col} contains missing values; pass "
                "complete_households_only=True to drop incomplete households"
            )
        row_mask = np.ones(len(persons), dtype=bool)
        preserved_households = set(persons[household_id_col].tolist())

    person_rows = persons.loc[row_mask].copy()
    if person_rows.empty:
        return PreservedTaxUnitTables(
            tax_units=_empty_tax_units(),
            persons=person_rows,
            preserved_households=set(),
        )

    # The same raw tax-unit ID in two households yields two distinct units.
    unit_codes = _factorize_household_entity_pairs(
        person_rows[household_id_col],
        person_rows[tax_unit_id_col],
    )
    n_units = int(unit_codes.max()) + 1
    first_pos = _first_positions(unit_codes, n_units)
    tax_unit_ids = np.arange(1, n_units + 1, dtype=np.int64)

    person_rows[tax_unit_id_col] = tax_unit_ids[unit_codes]
    household_ids = person_rows[household_id_col].to_numpy()
    tax_units = pd.DataFrame(
        {
            "tax_unit_id": tax_unit_ids,
            # Household of the first person row seen for each unit.
            "household_id": household_ids[first_pos],
            "total_income": _group_sum(
                _numeric_column(person_rows, income_col),
                unit_codes,
                n_units,
            ),
            "tax_liability": np.zeros(n_units, dtype=float),
            "n_dependents": _tax_unit_dependent_count(
                person_rows,
                unit_codes,
                n_units,
            ),
            "filing_status": _tax_unit_filing_status(
                person_rows,
                unit_codes,
                n_units,
                first_pos,
            ),
        }
    )

    for column in deduplicated_sum_columns:
        if column in person_rows.columns:
            tax_units[column] = _deduplicating_group_sum(
                _numeric_column(person_rows, column),
                unit_codes,
                n_units,
            )

    _sync_joint_separated_flags(tax_units, person_rows, unit_codes)
    _add_mortgage_columns(tax_units, person_rows, unit_codes, n_units)
    return PreservedTaxUnitTables(
        tax_units=tax_units,
        persons=person_rows,
        preserved_households=preserved_households,
    )
def _require_columns(df: pd.DataFrame, columns: Sequence[str]) -> None:
    """Raise ``KeyError`` naming (sorted) every column missing from ``df``."""
    absent = sorted(name for name in columns if name not in df.columns)
    if not absent:
        return
    raise KeyError("Missing required column(s): " + ", ".join(absent))


def _empty_tax_units() -> pd.DataFrame:
    """Return a zero-row tax-unit table with the standard columns and dtypes."""
    schema = (
        ("tax_unit_id", np.int64),
        ("household_id", np.int64),
        ("total_income", float),
        ("tax_liability", float),
        ("n_dependents", np.int64),
        ("filing_status", object),
    )
    return pd.DataFrame(
        {name: np.array([], dtype=dtype) for name, dtype in schema}
    )


def _non_missing_id_mask(values: pd.Series) -> np.ndarray:
    """Return a boolean array marking usable IDs.

    Null entries are unusable; for object/string columns, so are entries that
    are empty after stripping whitespace.
    """
    present = values.notna().to_numpy(dtype=bool, copy=True)
    dtype = values.dtype
    is_textual = pd.api.types.is_object_dtype(
        dtype
    ) or pd.api.types.is_string_dtype(dtype)
    if not is_textual:
        return present
    non_blank = values.astype("string").str.strip().ne("").fillna(False)
    present &= non_blank.to_numpy(dtype=bool)
    return present


def _complete_household_row_mask(
    household_ids: pd.Series,
    has_entity_id: np.ndarray,
) -> tuple[np.ndarray, set[Any]]:
    """Select rows of households in which every member has an entity ID.

    Rows with a missing household ID are never selected.

    Returns:
        The per-row keep mask and the set of fully covered household IDs.
    """
    codes, uniques = pd.factorize(household_ids, sort=False)
    known = codes >= 0  # factorize marks missing household IDs with -1
    known_codes = codes[known]

    # AND together the per-person flags of each household.
    household_complete = np.ones(len(uniques), dtype=bool)
    np.logical_and.at(household_complete, known_codes, has_entity_id[known])

    keep = np.zeros(len(household_ids), dtype=bool)
    keep[known] = household_complete[known_codes]
    kept_households = set(pd.Series(uniques[household_complete]).tolist())
    return keep, kept_households


def _factorize_household_entity_pairs(
    household_ids: pd.Series,
    entity_ids: pd.Series,
) -> np.ndarray:
    """Return dense int64 codes, one per distinct (household, entity) pair.

    Codes are assigned in order of first appearance.

    Raises:
        ValueError: If either ID series contains a missing value.
        OverflowError: If the pair keys cannot be encoded in int64.
    """
    hh_codes, hh_uniques = pd.factorize(household_ids, sort=False)
    ent_codes, ent_uniques = pd.factorize(entity_ids, sort=False)
    any_missing = (hh_codes < 0).any() or (ent_codes < 0).any()
    if bool(any_missing):
        raise ValueError("Cannot factorize missing household/entity IDs")

    n_entities = int(len(ent_uniques))
    if n_entities == 0:
        return np.array([], dtype=np.int64)
    if len(hh_uniques) > np.iinfo(np.int64).max // n_entities:
        raise OverflowError("Too many household/entity pairs to encode")

    # Mixed-radix key: unique per pair because ent_codes < n_entities.
    keys = hh_codes.astype(np.int64) * n_entities + ent_codes
    dense, _ = pd.factorize(keys, sort=False)
    return dense.astype(np.int64, copy=False)
def _first_positions(group_codes: np.ndarray, n_groups: int) -> np.ndarray:
    """Return, per group code, the index of its first occurrence.

    Raises:
        ValueError: If the number of distinct codes differs from ``n_groups``.
    """
    first_seen = np.unique(group_codes, return_index=True)[1]
    if len(first_seen) == n_groups:
        return first_seen
    raise ValueError("Group codes are not dense")


def _numeric_column(
    df: pd.DataFrame,
    column: str,
    *,
    default: float = 0.0,
) -> np.ndarray:
    """Return ``df[column]`` as a float array.

    Unparseable and missing entries become ``default``; a missing column
    yields an array filled with ``default``.
    """
    if column in df.columns:
        coerced = pd.to_numeric(df[column], errors="coerce")
        return coerced.fillna(default).to_numpy(dtype=float)
    return np.full(len(df), default, dtype=float)


def _group_sum(
    values: np.ndarray,
    group_codes: np.ndarray,
    n_groups: int,
) -> np.ndarray:
    """Return the sum of ``values`` within each group code."""
    return np.bincount(group_codes, weights=values, minlength=n_groups)


def _group_max(
    values: np.ndarray,
    group_codes: np.ndarray,
    n_groups: int,
) -> np.ndarray:
    """Return the maximum of ``values`` within each group code.

    Groups with no rows, and any non-finite maximum, are reported as ``0.0``.
    """
    peaks = np.full(n_groups, -np.inf, dtype=float)
    np.maximum.at(peaks, group_codes, values)
    return np.where(np.isfinite(peaks), peaks, 0.0)


def _tax_unit_dependent_count(
    person_rows: pd.DataFrame,
    unit_codes: np.ndarray,
    n_units: int,
) -> np.ndarray:
    """Return the dependent count of each tax unit as int64.

    Prefers the per-unit maximum of ``tax_unit_count_dependents``; otherwise
    counts rows with a positive ``is_tax_unit_dependent``; otherwise zeros.
    """
    available = person_rows.columns
    if "tax_unit_count_dependents" in available:
        reported = _numeric_column(person_rows, "tax_unit_count_dependents")
        return _group_max(reported, unit_codes, n_units).astype(np.int64)
    if "is_tax_unit_dependent" not in available:
        return np.zeros(n_units, dtype=np.int64)
    is_dependent = _numeric_column(person_rows, "is_tax_unit_dependent") > 0
    counts = np.bincount(
        unit_codes,
        weights=is_dependent.astype(np.int64),
        minlength=n_units,
    )
    return counts.astype(np.int64)
def _tax_unit_filing_status(
    person_rows: pd.DataFrame,
    unit_codes: np.ndarray,
    n_units: int,
    first_pos: np.ndarray,
) -> np.ndarray:
    """Return one filing-status label per tax unit.

    Uses the normalised ``filing_status`` of each unit's first row when that
    column exists; otherwise ``JOINT`` for units where any row has a positive
    ``tax_unit_is_joint``; otherwise ``SINGLE`` everywhere.
    """
    available = person_rows.columns
    if "filing_status" in available:
        per_person = person_rows["filing_status"].to_numpy(dtype=object)
        return _normalize_filing_status(per_person[first_pos])
    if "tax_unit_is_joint" not in available:
        return np.full(n_units, "SINGLE", dtype=object)
    joint_flag = _numeric_column(person_rows, "tax_unit_is_joint")
    is_joint = _group_max(joint_flag, unit_codes, n_units) > 0
    return np.where(is_joint, "JOINT", "SINGLE")


def _normalize_filing_status(values: Sequence[Any]) -> np.ndarray:
    """Map raw filing-status values onto the canonical labels.

    Matching ignores case and surrounding whitespace and treats spaces as
    underscores. Anything unrecognised, including ``None``, is ``SINGLE``.
    """
    canonical = {
        "JOINT": "JOINT",
        "MARRIED_FILING_JOINTLY": "JOINT",
        "SEPARATE": "SEPARATE",
        "MARRIED_FILING_SEPARATELY": "SEPARATE",
        "HEAD_OF_HOUSEHOLD": "HEAD_OF_HOUSEHOLD",
        "HEAD": "HEAD_OF_HOUSEHOLD",
        "SURVIVING_SPOUSE": "SURVIVING_SPOUSE",
        "WIDOW": "SURVIVING_SPOUSE",
        "WIDOWER": "SURVIVING_SPOUSE",
    }

    def label(raw: Any) -> str:
        if raw is None:
            return "SINGLE"
        key = str(raw).strip().upper().replace(" ", "_")
        return canonical.get(key, "SINGLE")

    return np.asarray([label(raw) for raw in values], dtype=object)


def _deduplicating_group_sum(
    values: np.ndarray,
    group_codes: np.ndarray,
    n_groups: int,
) -> np.ndarray:
    """Sum ``values`` per group, counting a repeated identical amount once.

    When a group has more than one non-zero value and they are all (nearly)
    equal, the result for that group is that single value rather than the sum.
    """
    totals = np.bincount(group_codes, weights=values, minlength=n_groups)

    is_nonzero = ~np.isclose(values, 0.0)
    hits = np.bincount(
        group_codes,
        weights=is_nonzero.astype(np.int64),
        minlength=n_groups,
    )

    lowest = np.full(n_groups, np.inf, dtype=float)
    highest = np.full(n_groups, -np.inf, dtype=float)
    if is_nonzero.any():
        live_codes = group_codes[is_nonzero]
        live_values = values[is_nonzero]
        np.minimum.at(lowest, live_codes, live_values)
        np.maximum.at(highest, live_codes, live_values)

    # min == max across several non-zero rows means one amount was repeated.
    duplicated = (hits > 1) & np.isfinite(lowest) & np.isclose(lowest, highest)
    return np.where(duplicated, highest, totals)
def _sync_joint_separated_flags(
    tax_units: pd.DataFrame,
    person_rows: pd.DataFrame,
    unit_codes: np.ndarray,
) -> None:
    """Clear ``is_separated`` for heads and spouses of JOINT tax units.

    Mutates ``person_rows`` in place. Does nothing when ``is_separated`` is
    absent, when neither ``is_tax_unit_head`` nor ``is_tax_unit_spouse``
    exists, or when no unit files jointly.
    """
    available = person_rows.columns
    if "is_separated" not in available:
        return
    if (
        "is_tax_unit_head" not in available
        and "is_tax_unit_spouse" not in available
    ):
        return

    status = tax_units["filing_status"].astype("string").str.upper()
    unit_is_joint = status.eq("JOINT").to_numpy(dtype=bool)
    if not bool(np.any(unit_is_joint)):
        return

    is_head = _numeric_column(person_rows, "is_tax_unit_head") > 0.0
    is_spouse = _numeric_column(person_rows, "is_tax_unit_spouse") > 0.0
    in_joint_unit = unit_is_joint[unit_codes]
    person_rows.loc[(is_head | is_spouse) & in_joint_unit, "is_separated"] = False


def _add_mortgage_columns(
    tax_units: pd.DataFrame,
    person_rows: pd.DataFrame,
    unit_codes: np.ndarray,
    n_units: int,
) -> None:
    """Add first-home mortgage columns to ``tax_units`` in place.

    Interest is summed per unit from ``first_home_mortgage_interest``, else
    ``home_mortgage_interest``. Nothing is added when neither column exists or
    no unit has positive interest. A non-positive summed balance is replaced
    by ``max(interest, 1.0)``. The origination year, when that column exists
    and holds a positive value, is the first positive year per unit, else 0.
    """
    candidates = ("first_home_mortgage_interest", "home_mortgage_interest")
    source = next(
        (name for name in candidates if name in person_rows.columns),
        None,
    )
    if source is None:
        return

    interest = _group_sum(_numeric_column(person_rows, source), unit_codes, n_units)
    if not bool(np.any(interest > 0)):
        return
    tax_units["first_home_mortgage_interest"] = interest

    balance = _group_sum(
        _numeric_column(person_rows, "first_home_mortgage_balance"),
        unit_codes,
        n_units,
    )
    fallback_balance = np.maximum(interest, 1.0)
    tax_units["first_home_mortgage_balance"] = np.where(
        balance > 0.0,
        balance,
        fallback_balance,
    )

    year_column = "first_home_mortgage_origination_year"
    if year_column not in person_rows.columns:
        return
    years = _numeric_column(person_rows, year_column)
    positive_rows = np.flatnonzero(years > 0)
    if not len(positive_rows):
        return

    # ``len(years)`` is an out-of-range sentinel meaning "no positive year".
    sentinel = len(years)
    first_row = np.full(n_units, sentinel, dtype=np.int64)
    np.minimum.at(first_row, unit_codes[positive_rows], positive_rows)
    found = first_row < sentinel
    unit_years = np.zeros(n_units, dtype=float)
    unit_years[found] = years[first_row[found]]
    tax_units[year_column] = unit_years
def _consumer_fact(
    key: str,
    *,
    period: dict | None = None,
    concept_alignment: dict | None = None,
) -> dict:
    """Return a schema-valid consumer fact row whose keys and IDs embed ``key``."""
    fact_period = period or {"type": "calendar_year", "value": 2024}
    alignment = concept_alignment or {}
    publisher = {"source_name": "publisher", "source_table": "Table 1"}
    return {
        "schema_version": "arch.consumer_fact.v1",
        "aggregate_fact_key": f"arch.aggregate_fact.v2:{key}",
        "semantic_fact_key": f"arch.semantic_fact.v2:{key}",
        "value": "123.5",
        "period": fact_period,
        "geography": {"level": "country", "id": "0100000US", "name": "US"},
        "observed_measure": {
            "source_concept": "publisher.population",
            **publisher,
            "unit": "count",
        },
        "concept_alignment": alignment,
        "source": dict(publisher),
        "lineage": {
            "source_record_id": f"publisher.{key}",
            "source_cell_keys": [f"arch.source_cell.v1:{key}"],
        },
    }


def test_load_arch_consumer_facts_validates_and_parses_rows(tmp_path) -> None:
    """Loaders skip blank lines, filter by period and expose parsed fields."""
    annual = _consumer_fact(
        "state",
        concept_alignment={"canonical_concept": "canonical.population"},
    )
    monthly = _consumer_fact(
        "month",
        period={"type": "month", "value": "2025-01"},
    )
    # The empty entry produces a blank line between the two rows.
    lines = [
        json.dumps(annual, sort_keys=True),
        "",
        json.dumps(monthly, sort_keys=True),
    ]
    path = tmp_path / "consumer_facts.jsonl"
    path.write_text("\n".join(lines) + "\n")

    rows = load_arch_consumer_fact_jsonl_rows((path,), period=2024)
    facts = load_arch_consumer_facts((path,))

    assert len(rows) == 1
    assert len(facts) == 2
    first, second = facts
    assert isinstance(first, ArchConsumerFact)
    assert first.concept == "canonical.population"
    assert first.period == 2024
    assert first.value == 123.5
    assert first.source_record_id == "publisher.state"
    assert first.path == str(path)
    assert first.line_number == 1
    assert second.period == 2025


def test_arch_consumer_fact_accessors_fall_back_to_observed_concept() -> None:
    """Without a canonical concept the observed source concept is used."""
    fact_row = _consumer_fact("fallback")

    assert arch_consumer_fact_concept(fact_row) == "publisher.population"
    assert arch_consumer_fact_period(fact_row) == 2024
    assert arch_consumer_fact_source_record_id(fact_row) == "publisher.fallback"
    assert arch_consumer_fact_numeric_value("42") == 42


def test_load_arch_consumer_facts_rejects_wrong_schema(tmp_path) -> None:
    """A row with another schema version fails and reports its line number."""
    outdated = {**_consumer_fact("bad"), "schema_version": "arch.consumer_fact.v0"}
    path = tmp_path / "consumer_facts.jsonl"
    path.write_text(json.dumps(outdated) + "\n")

    with pytest.raises(ValueError, match="line 1"):
        load_arch_consumer_fact_jsonl_rows((path,))


def test_arch_consumer_fact_numeric_value_rejects_bool() -> None:
    """Booleans are not accepted as numeric fact values."""
    with pytest.raises(ValueError, match="not numeric"):
        arch_consumer_fact_numeric_value(True)
+import pandas as pd
+
+from microplex.core import EntityType
+from microplex.targets import (
+    TabularRollupSpec,
+    TabularRollupTargetProvider,
+    TargetAggregation,
+    TargetFilter,
+    TargetQuery,
+    as_string_tuple,
+    build_tabular_rollup_targets,
+)
+
+ROLLUPS = {  # rollup specs shared by the tests below, keyed by geo level
+    "national": TabularRollupSpec(
+        geo_level="national",
+        source_column=None,
+        filter_feature=None,
+        group_name="people_national",
+        name_prefix="people_national",
+    ),
+    "region": TabularRollupSpec(
+        geo_level="region",
+        source_column="region_code",
+        filter_feature="region",
+        group_name="people_region",
+        name_prefix="people_region",
+    ),
+}
+
+
+def _rows() -> pd.DataFrame:  # Four-row input table; the last row has no region code.
+    return pd.DataFrame(
+        {
+            "region_code": ["01", "01", "02", None],
+            "population": [10, 5, 7, 3],
+        }
+    )
+
+
+def test_build_tabular_rollup_targets_groups_and_filters() -> None:
+    targets = build_tabular_rollup_targets(
+        _rows(),
+        rollups=ROLLUPS,
+        value_column="population",
+        variable="person_count",
+        entity=EntityType.PERSON,
+        period=2026,
+        source="publisher",
+        units="persons",
+        geo_levels=("national", "region"),
+        geographic_ids=("01",),  # only the region "01" target is expected to remain
+        base_metadata={"source_year": 2024},
+    )
+
+    assert [target.name for target in targets] == ["people_region_01"]
+    target = targets[0]
+    assert target.value == 15  # 10 + 5 from the two "01" rows
+    assert target.entity is EntityType.PERSON
+    assert target.aggregation is TargetAggregation.COUNT
+    assert target.filters == (
+        TargetFilter(feature="region", operator="==", value="01"),
+    )
+    assert target.metadata["variable"] == "person_count"
+    assert target.metadata["geo_level"] == "region"
+    assert target.metadata["geographic_id"] == "01"
+    assert target.metadata["target_group"] == "people_region"
+    assert target.metadata["tabular_rollup"] is True
+    assert target.metadata["source_year"] == 2024  # carried over from base_metadata
+
+
+def test_tabular_rollup_provider_honors_query_and_variable_aliases() -> None:
+    provider = TabularRollupTargetProvider(
+        _rows(),
+        rollups=ROLLUPS,
+        value_column="population",
+        variable="person_count",
+        variable_aliases=("population",),
+        entity=EntityType.PERSON,
+        period=2026,
+        default_geo_levels=("region",),
+    )
+
+    target_set = provider.load_target_set(
+        TargetQuery(
+            provider_filters={
+                "variables": ["population"],  # the alias, not the primary variable name
+                "geographic_levels": ["region"],
+                "geographic_ids": ["02"],
+            }
+        )
+    )
+
+    assert [target.name for target in target_set.targets] == ["people_region_02"]
+    assert target_set.targets[0].value == 7  # the single "02" row
+
+
+def test_tabular_rollup_provider_returns_empty_for_unrelated_variable() -> None:
+    provider = TabularRollupTargetProvider(
+        _rows(),
+        rollups=ROLLUPS,
+        value_column="population",
+        variable="person_count",
+        entity=EntityType.PERSON,
+        period=2026,
+    )
+
+    target_set = provider.load_target_set(
+        TargetQuery(provider_filters={"variables": ["household_count"]})  # neither the variable nor an alias
+    )
+
+    assert target_set.targets == []
+
+
+def test_tabular_rollup_targets_keep_zero_values_by_default() -> None:
+    targets = build_tabular_rollup_targets(
+        pd.DataFrame({"region_code": ["03"], "population": [0]}),
+        rollups=ROLLUPS,
+        value_column="population",
+        variable="person_count",
+        entity=EntityType.PERSON,
+        period=2026,
+        geo_levels=("region",),
+    )
+
+    assert [target.value for target in targets] == [0]  # the zero-valued target is still emitted
+
+
+def test_as_string_tuple_accepts_scalar_provider_filters() -> None:
+    assert as_string_tuple(None) == ()
+    assert as_string_tuple("state") == ("state",)  # a bare string is one element, not split into characters
+    assert as_string_tuple(6) == ("6",)
+    assert as_string_tuple(["01", 2]) == ("01", "2")
diff --git a/tests/test_tax_units.py b/tests/test_tax_units.py
new file mode 100644
index 0000000..32f0459
--- /dev/null
+++ b/tests/test_tax_units.py
@@ -0,0 +1,162 @@
+import numpy as np
+import pandas as pd
+import pytest
+
+from microplex.tax_units import build_preserved_tax_unit_tables
+
+
+def test_preserved_tax_unit_tables_split_reused_string_ids_by_household():
+    persons = pd.DataFrame(
+        {
+            "person_id": [1, 2, 3],
+            "household_id": [10, 10, 20],
+            "tax_unit_id": ["100:clone", "100:clone", "100:clone"],  # same id reused in households 10 and 20
+            "income": [60_000.0, 15_000.0, 25_000.0],
+            "tax_unit_is_joint": [1.0, 1.0, 0.0],
+            "is_tax_unit_dependent": [0.0, 0.0, 0.0],
+            "health_savings_account_ald": [60.0, 15.0, 5.0],
+        }
+    )
+
+    result = build_preserved_tax_unit_tables(persons)
+
+    assert result.persons["tax_unit_id"].tolist() == [1, 1, 2]  # one new integer id per household
+    assert result.preserved_households == {10, 20}
+    tax_units = result.tax_units.sort_values("tax_unit_id").reset_index(
+        drop=True
+    )
+    assert tax_units["household_id"].tolist() == [10, 20]
+    assert tax_units["filing_status"].tolist() == ["JOINT", "SINGLE"]
+    assert tax_units["total_income"].tolist() == [75_000.0, 25_000.0]
+    assert tax_units["n_dependents"].tolist() == [0, 0]
+    assert tax_units["health_savings_account_ald"].tolist() == [75.0, 5.0]  # 60 + 15 differ, so they are summed
+
+
+def test_preserved_tax_unit_tables_drop_incomplete_households():
+    persons = pd.DataFrame(
+        {
+            "person_id": [1, 2, 3, 4],
+            "household_id": [10, 10, 20, 20],
+            "tax_unit_id": ["a", None, "b", "b"],  # household 10 has one missing id
+            "income": [10.0, 20.0, 30.0, 40.0],
+        }
+    )
+
+    result = build_preserved_tax_unit_tables(persons)
+
+    assert result.preserved_households == {20}
+    assert result.persons["person_id"].tolist() == [3, 4]  # both household-10 rows are gone
+    assert result.tax_units["total_income"].tolist() == [70.0]
+
+
+def test_preserved_tax_unit_tables_use_dependency_count_when_available():
+    persons = pd.DataFrame(
+        {
+            "person_id": [1, 2, 3],
+            "household_id": [10, 10, 10],
+            "tax_unit_id": [100, 100, 100],
+            "income": [10.0, 0.0, 0.0],
+            "tax_unit_count_dependents": [2, 2, 2],  # same count repeated on every row of the unit
+        }
+    )
+
+    result = build_preserved_tax_unit_tables(persons)
+
+    assert result.tax_units["n_dependents"].tolist() == [2]
+
+
+def test_preserved_tax_unit_tables_deduplicate_repeated_unit_values():
+    persons = pd.DataFrame(
+        {
+            "person_id": [1, 2],
+            "household_id": [10, 10],
+            "tax_unit_id": [100, 100],
+            "income": [10.0, 20.0],
+            "health_savings_account_ald": [500.0, 500.0],  # identical repeats: expected once
+            "unrecaptured_section_1250_gain": [100.0, 25.0],  # differing values: expected summed
+        }
+    )
+
+    result = build_preserved_tax_unit_tables(persons)
+
+    row = result.tax_units.iloc[0]
+    assert row["health_savings_account_ald"] == 500.0
+    assert row["unrecaptured_section_1250_gain"] == 125.0
+
+
+def test_preserved_tax_unit_tables_clear_separated_flag_for_joint_filers():
+    persons = pd.DataFrame(
+        {
+            "person_id": [1, 2],
+            "household_id": [10, 10],
+            "tax_unit_id": ["joint", "joint"],
+            "income": [10.0, 20.0],
+            "tax_unit_is_joint": [1.0, 1.0],
+            "is_tax_unit_head": [1.0, 0.0],
+            "is_tax_unit_spouse": [0.0, 1.0],
+            "is_separated": [True, True],  # both flags start True and are expected to be cleared
+        }
+    )
+
+    result = build_preserved_tax_unit_tables(persons)
+
+    assert result.tax_units["filing_status"].tolist() == ["JOINT"]
+    assert result.persons["is_separated"].tolist() == [False, False]
+
+
+def test_preserved_tax_unit_tables_add_mortgage_balance_floor():
+    persons = pd.DataFrame(
+        {
+            "person_id": [1, 2],
+            "household_id": [10, 10],
+            "tax_unit_id": [100, 100],
+            "income": [10.0, 20.0],
+            "home_mortgage_interest": [250.0, 50.0],  # no first_home_mortgage_interest column is supplied
+            "first_home_mortgage_origination_year": [0.0, 2020.0],
+        }
+    )
+
+    result = build_preserved_tax_unit_tables(persons)
+
+    row = result.tax_units.iloc[0]
+    assert row["first_home_mortgage_interest"] == 300.0  # 250 + 50
+    assert row["first_home_mortgage_balance"] == 300.0  # no balance column supplied: equals the interest total
+    assert row["first_home_mortgage_origination_year"] == 2020.0  # the only positive year in the unit
+
+
+def test_preserved_tax_unit_tables_do_not_call_pandas_groupby(monkeypatch):
+    def fail_groupby(self, *args, **kwargs):  # stand-in that fails the test on any groupby call
+        raise AssertionError("groupby should not be used in this hot path")
+
+    monkeypatch.setattr(pd.DataFrame, "groupby", fail_groupby)
+    monkeypatch.setattr(pd.Series, "groupby", fail_groupby)
+    persons = pd.DataFrame(
+        {
+            "person_id": np.arange(6),
+            "household_id": [1, 1, 2, 2, 3, 3],
+            "tax_unit_id": ["a", "a", "a", "a", "b", "b"],  # "a" appears in households 1 and 2
+            "income": [1, 2, 3, 4, 5, 6],
+            "is_tax_unit_dependent": [0, 1, 0, 0, 0, 1],
+        }
+    )
+
+    result = build_preserved_tax_unit_tables(persons)
+
+    assert result.persons["tax_unit_id"].tolist() == [1, 1, 2, 2, 3, 3]
+    assert result.tax_units["total_income"].tolist() == [3.0, 7.0, 11.0]
+
+
+def test_preserved_tax_unit_tables_require_complete_ids_when_not_dropping():
+    persons = pd.DataFrame(
+        {
+            "person_id": [1],
+            "household_id": [10],
+            "tax_unit_id": [None],
+        }
+    )
+
+    with pytest.raises(ValueError):  # a missing id is an error when incomplete households are not dropped
+        build_preserved_tax_unit_tables(
+            persons,
+            complete_households_only=False,
+        )