@@ -33,25 +33,54 @@ def generate_synthetic_population(n_records: int = 10000, seed: int = 42) -> pd.
3333 return data
3434
3535
def _category_counts(column: pd.Series) -> dict:
    """Map each distinct value of *column* to its row count as a float."""
    # One counting pass over the column instead of one full scan per value.
    counts = column.value_counts()
    # Iterate unique() so keys keep first-appearance order. A NaN key is
    # absent from value_counts() and falls back to 0, which is what an
    # equality comparison against NaN would have counted.
    return {val: float(counts.get(val, 0)) for val in column.unique()}


def compute_targets(
    data: pd.DataFrame,
    include_joint: bool = False,
    n_continuous: int = 1,
    rng=None,
) -> tuple:
    """Compute calibration targets from data.

    NOTE: this function adds derived columns to ``data`` in place
    ("state_age" when ``include_joint`` is set, "wages" when
    ``n_continuous >= 2``, "benefits" when ``n_continuous >= 3``).
    NOTE(review): the in-place writes are kept because code after the call
    site may read these columns -- confirm before changing this to a copy.

    Args:
        data: Population dataframe. Must contain the columns "state",
            "age_group", "income_bracket" and "income".
        include_joint: If True, include state × age cross-tabulated targets
            under the "state_age" key.
        n_continuous: Number of continuous targets (1=income,
            2=income+wages, 3=income+wages+benefits). Values below 2 yield
            income only; values above 3 behave like 3.
        rng: Optional ``numpy.random.Generator`` used to simulate wages and
            benefits. When None, the global ``np.random`` state is used, so
            results depend on whatever seeding the caller did beforehand.

    Returns:
        A ``(marginal_targets, continuous_targets)`` tuple.
        ``marginal_targets`` maps variable name -> {category value: row
        count as float}; ``continuous_targets`` maps variable name ->
        column sum as float.
    """
    # np.random (module) and Generator expose uniform/lognormal with the
    # same positional signature, so either can serve as the sampler.
    sampler = np.random if rng is None else rng
    n_rows = len(data)

    marginal_targets = {
        var: _category_counts(data[var])
        for var in ["state", "age_group", "income_bracket"]
    }

    # Add joint state × age targets
    if include_joint:
        # Derived column for the joint distribution, e.g. "<state>_<age_group>".
        data["state_age"] = data["state"] + "_" + data["age_group"]
        marginal_targets["state_age"] = _category_counts(data["state_age"])

    # Continuous targets
    continuous_targets = {"income": float(data["income"].sum())}

    if n_continuous >= 2:
        # Simulate wages as a per-row share (50%-90%) of income.
        data["wages"] = data["income"] * sampler.uniform(0.5, 0.9, n_rows)
        continuous_targets["wages"] = float(data["wages"].sum())

    if n_continuous >= 3:
        # Simulate benefits as log-normal draws (underlying normal has
        # mean 8 and standard deviation 1.5), independent of income.
        data["benefits"] = sampler.lognormal(8, 1.5, n_rows)
        continuous_targets["benefits"] = float(data["benefits"].sum())

    return marginal_targets, continuous_targets
4874
4975
50- def run_comparison (n_records : int = 5000 ):
76+ def run_comparison (n_records : int = 5000 , include_joint : bool = False , n_continuous : int = 1 ):
5177 """Run both methods across sparsity range and collect results."""
5278 print (f"Generating { n_records } synthetic records..." )
5379 pop = generate_synthetic_population (n_records = n_records )
54- marginal_targets , continuous_targets = compute_targets (pop )
80+ marginal_targets , continuous_targets = compute_targets (pop , include_joint = include_joint , n_continuous = n_continuous )
81+
82+ n_cat_targets = sum (len (v ) for v in marginal_targets .values ())
83+ print (f"Targets: { n_cat_targets } categorical, { len (continuous_targets )} continuous" )
5584
5685 # Cross-category results
5786 cc_results = []
@@ -156,13 +185,43 @@ def plot_pareto(cc_df: pd.DataFrame, hc_df: pd.DataFrame, output_path: str = "sp
156185 return fig
157186
158187
def run_scenarios(n_records: int = 5000):
    """Run comparison across different target complexities.

    Runs ``run_comparison`` once per scenario (each scenario is a
    combination of ``include_joint`` and ``n_continuous``) and writes one
    Pareto plot per scenario via ``plot_pareto``.

    Args:
        n_records: Number of synthetic records passed to ``run_comparison``
            for every scenario.

    Returns:
        Dict keyed by scenario name; each value is
        ``{"cc": cc_df, "hc": hc_df}`` holding the two dataframes returned
        by ``run_comparison`` for that scenario.
    """
    scenarios = [
        {"name": "simple", "include_joint": False, "n_continuous": 1},
        {"name": "joint", "include_joint": True, "n_continuous": 1},
        {"name": "multi_continuous", "include_joint": False, "n_continuous": 3},
        {"name": "complex", "include_joint": True, "n_continuous": 3},
    ]

    all_results = {}
    for scenario in scenarios:
        name = scenario["name"]

        print(f"\n {'=' * 60} ")
        print(f"Scenario: {name} ")
        print(f"{'=' * 60} ")

        cc_df, hc_df = run_comparison(
            n_records=n_records,
            include_joint=scenario["include_joint"],
            n_continuous=scenario["n_continuous"],
        )
        all_results[name] = {"cc": cc_df, "hc": hc_df}

        # No whitespace between the scenario name and the extension, so the
        # output path is a single clean token.
        plot_pareto(cc_df, hc_df, f"sparse_calibration_pareto_{name}.png")

    return all_results
211+
212+
if __name__ == "__main__":
    import sys

    # Only the first command-line argument is inspected; "--all" selects the
    # multi-scenario run, anything else (or nothing) runs a single comparison.
    wants_all_scenarios = sys.argv[1:2] == ["--all"]

    if wants_all_scenarios:
        run_scenarios()
    else:
        cc_df, hc_df = run_comparison(n_records=5000)

        # Dump both result tables, then write the Pareto plot.
        banner = "\n " + "=" * 60
        print(banner)
        print("Cross-Category Results:")
        print(cc_df.to_string(index=False))
        print("\n Hard Concrete Results:")
        print(hc_df.to_string(index=False))

        plot_pareto(cc_df, hc_df)
0 commit comments