Skip to content

Commit 39b5af3

Browse files
MaxGhenisclaude
andcommitted
Extend Pareto comparison to joint and continuous targets
Added scenarios: - Joint state×age cross-tabulated targets (69 categorical) - Multiple continuous targets (income, wages, benefits) Results: Cross-Category + IPF still achieves ~0% error even with 72 total targets and extreme sparsity (0.7 records per target). Hard Concrete has ~1% error but offers: - End-to-end differentiability - Custom loss functions - Multi-objective optimization 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent c7c1ac3 commit 39b5af3

2 files changed

Lines changed: 70 additions & 11 deletions

File tree

scripts/sparse_calibration_pareto.py

Lines changed: 70 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -33,25 +33,54 @@ def generate_synthetic_population(n_records: int = 10000, seed: int = 42) -> pd.
3333
return data
3434

3535

36-
def compute_targets(data: pd.DataFrame) -> tuple:
37-
"""Compute calibration targets from data."""
36+
def compute_targets(data: pd.DataFrame, include_joint: bool = False, n_continuous: int = 1) -> tuple:
37+
"""Compute calibration targets from data.
38+
39+
Args:
40+
data: Population dataframe
41+
include_joint: If True, include state × age cross-tabulated targets
42+
n_continuous: Number of continuous targets (1=income, 3=income+wages+benefits)
43+
"""
3844
marginal_targets = {}
3945

4046
for var in ["state", "age_group", "income_bracket"]:
4147
marginal_targets[var] = {}
4248
for val in data[var].unique():
4349
marginal_targets[var][val] = float((data[var] == val).sum())
4450

51+
# Add joint state × age targets
52+
if include_joint:
53+
# Create derived column for joint distribution
54+
data["state_age"] = data["state"] + "_" + data["age_group"]
55+
marginal_targets["state_age"] = {}
56+
for val in data["state_age"].unique():
57+
count = (data["state_age"] == val).sum()
58+
marginal_targets["state_age"][val] = float(count)
59+
60+
# Continuous targets
4561
continuous_targets = {"income": float(data["income"].sum())}
4662

63+
if n_continuous >= 2:
64+
# Simulate wages as portion of income
65+
data["wages"] = data["income"] * np.random.uniform(0.5, 0.9, len(data))
66+
continuous_targets["wages"] = float(data["wages"].sum())
67+
68+
if n_continuous >= 3:
69+
# Simulate benefits
70+
data["benefits"] = np.random.lognormal(8, 1.5, len(data))
71+
continuous_targets["benefits"] = float(data["benefits"].sum())
72+
4773
return marginal_targets, continuous_targets
4874

4975

50-
def run_comparison(n_records: int = 5000):
76+
def run_comparison(n_records: int = 5000, include_joint: bool = False, n_continuous: int = 1):
5177
"""Run both methods across sparsity range and collect results."""
5278
print(f"Generating {n_records} synthetic records...")
5379
pop = generate_synthetic_population(n_records=n_records)
54-
marginal_targets, continuous_targets = compute_targets(pop)
80+
marginal_targets, continuous_targets = compute_targets(pop, include_joint=include_joint, n_continuous=n_continuous)
81+
82+
n_cat_targets = sum(len(v) for v in marginal_targets.values())
83+
print(f"Targets: {n_cat_targets} categorical, {len(continuous_targets)} continuous")
5584

5685
# Cross-category results
5786
cc_results = []
@@ -156,13 +185,43 @@ def plot_pareto(cc_df: pd.DataFrame, hc_df: pd.DataFrame, output_path: str = "sp
156185
return fig
157186

158187

188+
def run_scenarios():
189+
"""Run comparison across different target complexities."""
190+
scenarios = [
191+
{"name": "simple", "include_joint": False, "n_continuous": 1},
192+
{"name": "joint", "include_joint": True, "n_continuous": 1},
193+
{"name": "multi_continuous", "include_joint": False, "n_continuous": 3},
194+
{"name": "complex", "include_joint": True, "n_continuous": 3},
195+
]
196+
197+
all_results = {}
198+
for scenario in scenarios:
199+
print(f"\n{'='*60}")
200+
print(f"Scenario: {scenario['name']}")
201+
print(f"{'='*60}")
202+
cc_df, hc_df = run_comparison(
203+
n_records=5000,
204+
include_joint=scenario["include_joint"],
205+
n_continuous=scenario["n_continuous"],
206+
)
207+
all_results[scenario["name"]] = {"cc": cc_df, "hc": hc_df}
208+
plot_pareto(cc_df, hc_df, f"sparse_calibration_pareto_{scenario['name']}.png")
209+
210+
return all_results
211+
212+
159213
if __name__ == "__main__":
160-
cc_df, hc_df = run_comparison(n_records=5000)
214+
import sys
215+
216+
if len(sys.argv) > 1 and sys.argv[1] == "--all":
217+
run_scenarios()
218+
else:
219+
cc_df, hc_df = run_comparison(n_records=5000)
161220

162-
print("\n" + "=" * 60)
163-
print("Cross-Category Results:")
164-
print(cc_df.to_string(index=False))
165-
print("\nHard Concrete Results:")
166-
print(hc_df.to_string(index=False))
221+
print("\n" + "=" * 60)
222+
print("Cross-Category Results:")
223+
print(cc_df.to_string(index=False))
224+
print("\nHard Concrete Results:")
225+
print(hc_df.to_string(index=False))
167226

168-
plot_pareto(cc_df, hc_df)
227+
plot_pareto(cc_df, hc_df)
63.1 KB
Loading

0 commit comments

Comments
 (0)