Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/three-pass-disable-salt.fixed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Run `--disable-salt` in three passes so PE's federal Schedule A keeps state-tax SALT (matching TAXSIM-35's single-pass methodology) while state computation remains SALT-disabled. Eliminates the iterated-vs-single-pass state-tax mismatch in PE-vs-TAXSIM federal comparisons.
134 changes: 116 additions & 18 deletions policyengine_taxsim/runners/policyengine_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -911,6 +911,9 @@ def __init__(
self.logs = logs
self.disable_salt = disable_salt
self.assume_w2_wages = assume_w2_wages
# Per-row state_and_local_sales_or_income_tax override (Pass B of
# three-pass --disable-salt). Maps taxsimid -> dollar value.
self._state_tax_override = None
self.mappings = load_variable_mappings()

def _ensure_required_columns(self, df):
Expand Down Expand Up @@ -957,14 +960,31 @@ def _run_chunk(self, chunk_df: pd.DataFrame) -> pd.DataFrame:
dataset.generate()
sim = Microsimulation(dataset=dataset)

if self.disable_salt:
# Resolve the state_and_local_sales_or_income_tax override for
# this chunk. Possible sources, in priority order:
# 1. self._state_tax_override (Pass B of three-pass: per-row
# values produced by Pass A, keyed by taxsimid)
# 2. self.disable_salt (zero out for state-only computation)
salt_override = None
if self._state_tax_override is not None:
ids = chunk_df["taxsimid"].astype(float).astype(int).values
# Look each id up in the override map; fall back to 0 if
# the id is unexpectedly missing.
salt_override = np.array(
[self._state_tax_override.get(int(i), 0.0) for i in ids],
dtype=float,
)
elif self.disable_salt:
salt_override = np.zeros(len(chunk_df), dtype=float)

if salt_override is not None:
years = sorted(set(chunk_df["year"].unique()))
for year in years:
year_mask = chunk_df["year"] == year
n_year_records = year_mask.sum()
year_values = salt_override[year_mask.values]
sim.set_input(
variable_name="state_and_local_sales_or_income_tax",
value=np.zeros(n_year_records),
value=year_values,
period=str(
int(year)
if isinstance(year, (float, np.floating))
Expand Down Expand Up @@ -993,18 +1013,40 @@ def _run_chunk(self, chunk_df: pd.DataFrame) -> pd.DataFrame:
finally:
dataset.cleanup()

def run(self, show_progress: bool = True, on_progress=None) -> pd.DataFrame:
"""
Run PolicyEngine Microsimulation on all records, chunked by year
and then by CHUNK_SIZE to avoid memory issues with large datasets.

Args:
show_progress: Whether to show tqdm progress bar.
on_progress: Optional callback(chunks_done, total_chunks, rows_done, total_rows).
# Output columns that belong to the state side of PE-US. When
# --disable-salt is set, `run` performs two PE simulation passes: a
# SALT-zeroed pass whose values for these columns are kept, and a
# federal-side pass that receives the first pass's state tax as a fixed
# state_and_local_sales_or_income_tax input. `run` then copies the
# columns listed here from the SALT-zeroed pass over the federal-side
# results, so state outputs stay SALT-disabled while federal Schedule A
# still sees a state-tax SALT amount.
_STATE_OUTPUT_COLUMNS = frozenset(
    ["siitax", "srate"]
    # v32 through v44 inclusive.
    + [f"v{n}" for n in range(32, 45)]
    + ["staxbc", "srebate", "senergy", "sctc", "sptcr", "samt"]
)

Returns:
DataFrame with TAXSIM-formatted output variables
"""
def _run_once(self, show_progress: bool, on_progress) -> pd.DataFrame:
"""Single PE pass with the current self.disable_salt setting."""
if show_progress:
print(
f"Running PolicyEngine Microsimulation on {len(self.input_df)} records",
Expand All @@ -1014,8 +1056,6 @@ def run(self, show_progress: bool = True, on_progress=None) -> pd.DataFrame:
# Ensure years are integers to handle decimal values like 2021.0
self.input_df["year"] = self.input_df["year"].apply(lambda x: int(float(x)))

# Split by year first (required for correct dataset generation),
# then by chunk size within each year.
frames = []
years = sorted(self.input_df["year"].unique())
total_chunks = sum(
Expand Down Expand Up @@ -1044,12 +1084,70 @@ def run(self, show_progress: bool = True, on_progress=None) -> pd.DataFrame:
on_progress(chunks_done, total_chunks, rows_done, total_rows)

results_df = pd.concat(frames, ignore_index=True)

if show_progress:
print("PolicyEngine Microsimulation completed", file=sys.stderr)

return results_df

def run(self, show_progress: bool = True, on_progress=None) -> pd.DataFrame:
    """
    Run PolicyEngine Microsimulation on all records.

    When ``disable_salt`` is set, runs PE in three steps ("three-pass")
    to match TAXSIM-35's single-pass state↔federal SALT methodology:

      Pass A: state-side run with state_and_local_sales_or_income_tax
              zeroed. Produces state outputs that ignore federal SALT
              iteration.
      Pass B: federal-side run with state_and_local_sales_or_income_tax
              set as an explicit input to Pass A's ``siitax`` per
              record, so federal SALT uses that fixed value.
      Stitch: state columns (``_STATE_OUTPUT_COLUMNS``) from Pass A,
              every other column from Pass B.

    Without ``disable_salt``, runs a single PE pass.

    Args:
        show_progress: Forwarded unchanged to each ``_run_once`` call.
        on_progress: Optional callback
            ``(chunks_done, total_chunks, rows_done, total_rows)``,
            forwarded unchanged to each ``_run_once`` call. With
            ``disable_salt`` set it is therefore driven once per pass
            (Pass A, then Pass B).

    Returns:
        DataFrame with TAXSIM-formatted output variables, in the row
        order produced by the (last) PE pass.

    Raises:
        ValueError: If ``disable_salt`` is set and Pass A's results
            contain duplicate ``taxsimid`` values. Both the per-record
            override and the stitch are keyed by ``taxsimid`` alone, so
            duplicates cannot be matched unambiguously.
    """
    if not self.disable_salt:
        return self._run_once(show_progress, on_progress)

    # Pass A — state-side: zeros SALT internally.
    state_results = self._run_once(show_progress, on_progress)

    # Normalise ids the same way the override lookup does (float -> int),
    # so "1", 1.0 and 1 all resolve to the same key.
    state_ids = state_results["taxsimid"].astype(float).astype(int)

    # Fail fast on duplicate ids. A dict keyed by taxsimid would silently
    # keep only the last value per id (feeding wrong SALT inputs to
    # Pass B), and the stitch below would then fail with an opaque
    # length-mismatch error only after the second simulation has run.
    duplicated = state_ids.duplicated()
    if duplicated.any():
        examples = sorted(state_ids[duplicated].unique().tolist())[:5]
        raise ValueError(
            "disable_salt requires unique taxsimid values to match the "
            "state-side and federal-side passes; duplicated taxsimid(s) "
            f"include: {examples}"
        )

    # Per-taxsimid state tax from Pass A's siitax, used as the fixed
    # SALT input in Pass B.
    state_tax_by_id = dict(
        zip(
            state_ids.tolist(),
            state_results["siitax"].astype(float).tolist(),
        )
    )

    # Pass B — federal-side: use Pass-A state tax as fixed SALT input,
    # no further iteration. The runner's flags are restored even if the
    # pass raises, so the instance stays reusable.
    original_disable_salt = self.disable_salt
    original_override = self._state_tax_override
    try:
        self.disable_salt = False
        self._state_tax_override = state_tax_by_id
        federal_results = self._run_once(show_progress, on_progress)
    finally:
        self.disable_salt = original_disable_salt
        self._state_tax_override = original_override

    # Stitch: federal columns from Pass B, state-side columns from
    # Pass A. Pass A is re-ordered to Pass B's taxsimid order first so
    # the positional column copy below lines up row-for-row.
    combined = federal_results.copy()
    aligned_state = state_results.set_index("taxsimid").loc[
        combined["taxsimid"].values
    ]
    for col in self._STATE_OUTPUT_COLUMNS:
        if col in aligned_state.columns and col in combined.columns:
            combined[col] = aligned_state[col].values
    return combined

def _is_year_restricted_variable(self, variable_name: str, year: int) -> bool:
"""
Check if a variable has year restrictions and should not be computed for the given year.
Expand Down
22 changes: 14 additions & 8 deletions tests/test_performance.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,8 @@ class TestBenchmark:
"""Performance benchmarks. Run with: pytest -m slow"""

def test_benchmark_500_records(self):
"""500 records should complete in under 30 seconds."""
"""500 records should complete in under 2 minutes with the
three-pass --disable-salt code path (two PE Microsim invocations)."""
records = _make_synthetic_records(500, seed=77)
runner = PolicyEngineRunner(records, logs=False, disable_salt=True)

Expand All @@ -224,7 +225,7 @@ def test_benchmark_500_records(self):
elapsed = time.time() - start

assert len(result) == 500
assert elapsed < 60, f"500 records took {elapsed:.1f}s, expected < 60s"
assert elapsed < 120, f"500 records took {elapsed:.1f}s, expected < 120s"
print(f"\nBenchmark: 500 records in {elapsed:.1f}s")

def test_benchmark_cps_like(self):
Expand Down Expand Up @@ -285,7 +286,8 @@ def test_benchmark_cps_like(self):
f"\nBenchmark (CPS-like): {n} records, {records['state'].nunique()} states, idtl=2"
)
print(f" Total: {elapsed:.1f}s")
assert elapsed < 120, f"CPS-like benchmark took {elapsed:.1f}s, expected < 120s"
# 2x ceiling accounts for the three-pass --disable-salt code path.
assert elapsed < 240, f"CPS-like benchmark took {elapsed:.1f}s, expected < 240s"


class TestStateVariableEfficiency:
Expand Down Expand Up @@ -330,12 +332,16 @@ def counted_calc_tu(self_runner, sim, var_name, period):
result = runner.run(show_progress=False)

unique_states = records["state"].nunique()
# With unified state vars: ~30-60 _calc_tax_unit calls
# With per-state iteration: ~10 state vars * 47 states = 470+ calls
assert calc_count["n"] < 100, (
# With unified state vars: ~30-60 _calc_tax_unit calls per PE pass.
# When `disable_salt=True`, the runner makes two PE passes
# (state-side + federal-side, see PolicyEngineRunner.run docstring),
# so the expected ceiling roughly doubles.
# With per-state iteration: ~10 state vars * 47 states = 470+ calls.
assert calc_count["n"] < 200, (
f"_calc_tax_unit() called {calc_count['n']} times for {n} records "
f"across {unique_states} states. Expected < 100 with unified state "
f"variables, but got a number suggesting per-state iteration."
f"across {unique_states} states. Expected < 200 with unified state "
f"variables (×2 for the disable_salt three-pass), but got a number "
f"suggesting per-state iteration."
)

def test_state_variable_values_match(self):
Expand Down
97 changes: 97 additions & 0 deletions tests/test_three_pass_disable_salt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
"""
Tests for the three-pass --disable-salt mode that aligns PE's federal
SALT calculation with TAXSIM-35's single-pass methodology.

Background: even after two-pass `--disable-salt` (where PE's federal pass
keeps SALT), PE's iterated state-tax value differs from TAXSIM's
single-pass value, producing residual federal mismatches on every record
where state income tax was the SALT driver.

Three-pass eliminates this:
Pass A: PE with disable_salt=True
→ state_income_tax computed against zero-SALT federal base
(matches TAXSIM's first-pass state tax)
Pass B: PE with state_and_local_sales_or_income_tax explicitly set
to Pass-A state_income_tax — no recomputation
(mimics TAXSIM: federal SALT uses fixed state tax, no
iteration)
Stitch: federal-side outputs from Pass B, state-side from Pass A.
"""

import pandas as pd
import numpy as np

from policyengine_taxsim.runners.policyengine_runner import PolicyEngineRunner


def _ny_filer_with_mortgage(**overrides):
    """Build a one-row TAXSIM input frame for the NY mortgage filer.

    Defaults describe a NY single filer with $84K wages and a $37K
    mortgage (TAXSIM v17 case 5436). Keyword arguments replace the
    matching default fields or add new columns.
    """
    defaults = {
        "taxsimid": 1,
        "year": 2024,
        "state": 33,
        "mstat": 1,
        "page": 40,
        "sage": 0,
        "depx": 0,
        "pwages": 84000.0,
        "mortgage": 37000.0,
        "idtl": 2,
    }
    # Later keys win, so overrides replace defaults in place and any
    # unknown keys are appended as extra columns.
    record = {**defaults, **overrides}
    return pd.DataFrame([record])


class TestThreePassDisableSalt:
    """Behavioural checks for ``PolicyEngineRunner.run`` with
    ``disable_salt=True`` (state pass, federal pass, then stitch)."""

    def test_state_tax_unchanged_vs_two_pass(self):
        """Three-pass state output must match the SALT-disabled run.
        (We're only changing the federal pass's SALT input.)"""
        df = _ny_filer_with_mortgage()
        # Reference: a single SALT-disabled PE pass, i.e. exactly what
        # run() uses as its state-side pass.
        salt_off = PolicyEngineRunner(df.copy(), disable_salt=True)._run_once(
            False, None
        )
        with_flag = PolicyEngineRunner(df.copy(), disable_salt=True).run(
            show_progress=False
        )
        reference = salt_off["siitax"].iloc[0]
        stitched = with_flag["siitax"].iloc[0]
        assert np.isfinite(stitched)
        # $1 tolerance, matching the other comparisons in this class.
        assert abs(stitched - reference) < 1.0, (
            f"three-pass siitax={stitched} differs from SALT-disabled "
            f"single-pass siitax={reference}"
        )

    def test_federal_salt_uses_pass_a_state_tax(self):
        """Federal v17 (itemized) should include exactly Pass-A's
        state_income_tax in SALT, not PE's iterated value. We can detect
        this by checking that PE's v17 doesn't include any extra iteration:
        v17 should be <= mortgage + Pass-A siitax (capped at $10K SALT
        cap)."""
        df = _ny_filer_with_mortgage()
        result = PolicyEngineRunner(df.copy(), disable_salt=True).run(
            show_progress=False
        )
        siitax = result["siitax"].iloc[0]
        v17 = result["v17"].iloc[0]
        # Read the mortgage from the fixture so the bound cannot drift
        # from the input actually simulated.
        mortgage = float(df["mortgage"].iloc[0])
        salt_cap = 10000.0
        # A negative state liability must not shrink the bound below the
        # mortgage amount, so the SALT contribution is floored at zero.
        expected_salt = max(0.0, min(siitax, salt_cap))
        # v17 should be mortgage + capped state tax (no iteration extra)
        # Allow $5 tolerance for rounding.
        assert v17 <= mortgage + expected_salt + 5, (
            f"v17={v17} exceeds mortgage+capped_state_salt = "
            f"{mortgage + expected_salt}; suggests iteration leaked in"
        )

    def test_results_stable_idempotent(self):
        """Two independent runs with disable_salt=True should produce the
        same result — the three-pass shouldn't add nondeterminism — and
        run() must leave the runner's SALT flags as it found them."""
        df = _ny_filer_with_mortgage()
        first_runner = PolicyEngineRunner(df.copy(), disable_salt=True)
        r1 = first_runner.run(show_progress=False)
        # run() temporarily flips these for the federal pass; they must
        # be restored afterwards.
        assert first_runner.disable_salt is True
        assert first_runner._state_tax_override is None
        r2 = PolicyEngineRunner(df.copy(), disable_salt=True).run(show_progress=False)
        for col in ["fiitax", "siitax", "v17", "v18"]:
            assert abs(r1[col].iloc[0] - r2[col].iloc[0]) < 1.0

    def test_no_disable_salt_unchanged(self):
        """Without --disable-salt, behavior must be untouched (single
        pass, no override)."""
        df = _ny_filer_with_mortgage()
        result = PolicyEngineRunner(df.copy(), disable_salt=False).run(
            show_progress=False
        )
        assert len(result) == 1
        assert np.isfinite(result["fiitax"].iloc[0])
Loading