@@ -104,57 +104,68 @@ def train_model(df, epochs=50, batch_size=2048, lr=1e-3, device='mps'):
104104 print ("Preparing training data..." )
105105 print ("=" * 60 )
106106
107- # Identify variable types
107+ # Predictor variables - demographics + lags for transition modeling
108+ # Core predictors (all surveys)
109+ predictors = ['age' , 'is_male' ]
110+
111+ # Add lag predictors (SIPP only, but that's where transitions matter)
112+ lag_predictors = ['job1_income_lag1' , 'total_income_lag1' ]
113+ for lag in lag_predictors :
114+ if lag in df .columns and df [lag ].notna ().sum () > 10000 :
115+ predictors .append (lag )
116+
117+ # Filter to rows with valid predictors FIRST
118+ mask = df [predictors ].notna ().all (axis = 1 )
119+ train_subset = df [mask ]
120+ print (f" Training subset: { len (train_subset ):,} rows (have all predictors)" )
121+
122+ # Identify variable types - only use vars observed in training subset
108123 # Continuous (ZI-QDNN) - income variables that can be zero or positive
109124 zi_vars = []
110125 for col in ['wage_income' , 'self_employment_income' , 'interest_income' ,
111126 'dividend_income' , 'rental_income' , 'farm_income' ,
112127 'total_income' , 'job1_income' , 'job2_income' , 'job3_income' ,
113128 'tip_income' , 'social_security' , 'total_family_income' ]:
114- if col in df .columns and df [col ].notna ().sum () > 1000 :
129+ if col in train_subset .columns and train_subset [col ].notna ().sum () > 1000 :
115130 zi_vars .append (col )
116131
117- # Categorical - discrete with >2 classes
132+ # Categorical - discrete with >2 classes (use train_subset)
118133 cat_vars = {}
119134 for col in ['education' , 'race' , 'marital_status' , 'relationship' ,
120135 'state_fips' , 'job1_occ' , 'job1_ind' , 'job2_occ' , 'job2_ind' ]:
121- if col in df .columns and df [col ].notna ().sum () > 1000 :
122- n_classes = int (df [col ].max ()) + 1
136+ if col in train_subset .columns and train_subset [col ].notna ().sum () > 1000 :
137+ n_classes = int (train_subset [col ].max ()) + 1
123138 if n_classes > 2 and n_classes < 100 : # Reasonable number of classes
124139 cat_vars [col ] = n_classes
125140
126- # Binary
141+ # Binary (use train_subset) - exclude is_male since it's a predictor
127142 binary_vars = []
128- for col in ['is_male' , ' hispanic' , 'job_loss' , 'job_gain' ]:
129- if col in df .columns and df [col ].notna ().sum () > 1000 :
143+ for col in ['hispanic' , 'job_loss' , 'job_gain' ]:
144+ if col in train_subset .columns and train_subset [col ].notna ().sum () > 1000 :
130145 binary_vars .append (col )
131146
132147 print (f"\n Variable types:" )
133148 print (f" ZI-QDNN (continuous): { zi_vars } " )
134149 print (f" Categorical: { list (cat_vars .keys ())} " )
135150 print (f" Binary: { binary_vars } " )
151+ print (f" Predictors: { predictors } " )
136152
137- # Condition variables - use only those available across ALL surveys
138- # age and is_male are in all surveys
139- cond_vars = ['age' , 'is_male' ]
140-
141- print (f" Conditioning: { cond_vars } " )
142-
143- # Build tensors
144- all_vars = cond_vars + zi_vars + list (cat_vars .keys ()) + binary_vars
145-
146- # Use only rows with valid conditioning vars
147- mask = df [cond_vars ].notna ().all (axis = 1 )
148- train_df = df [mask ].copy ()
153+ # Use the pre-filtered train_subset
154+ train_df = train_subset .copy ()
155+ all_vars = predictors + zi_vars + list (cat_vars .keys ()) + binary_vars
149156 print (f"\n Training rows: { len (train_df ):,} (of { len (df ):,} )" )
150157
151- # Normalize conditioning variables
152- cond_data = train_df [cond_vars ].values .astype (np .float32 )
153- cond_means = np .nanmean (cond_data , axis = 0 )
154- cond_stds = np .nanstd (cond_data , axis = 0 ) + 1e-6
155- cond_data = (cond_data - cond_means ) / cond_stds
158+ # Normalize predictor variables (including lags)
159+ pred_data = train_df [predictors ].values .astype (np .float32 )
160+ # Log-transform income lags before normalizing
161+ for i , p in enumerate (predictors ):
162+ if 'income' in p :
163+ pred_data [:, i ] = np .log1p (np .maximum (pred_data [:, i ], 0 ))
164+ pred_means = np .nanmean (pred_data , axis = 0 )
165+ pred_stds = np .nanstd (pred_data , axis = 0 ) + 1e-6
166+ pred_data = (pred_data - pred_means ) / pred_stds
156167
157- X = torch .tensor (cond_data , dtype = torch .float32 ).to (device )
168+ X = torch .tensor (pred_data , dtype = torch .float32 ).to (device )
158169
159170 # Build target tensors
160171 targets = {}
@@ -185,7 +196,7 @@ def train_model(df, epochs=50, batch_size=2048, lr=1e-3, device='mps'):
185196
186197 # Model
187198 model = MixedSynth (
188- n_cond = len (cond_vars ),
199+ n_cond = len (predictors ),
189200 zi_vars = zi_vars ,
190201 cat_vars = cat_vars ,
191202 binary_vars = binary_vars ,
@@ -301,9 +312,9 @@ def train_model(df, epochs=50, batch_size=2048, lr=1e-3, device='mps'):
301312 'zi_vars' : zi_vars ,
302313 'cat_vars' : cat_vars ,
303314 'binary_vars' : binary_vars ,
304- 'cond_vars ' : cond_vars ,
305- 'cond_means ' : cond_means ,
306- 'cond_stds ' : cond_stds ,
315+ 'predictors ' : predictors ,
316+ 'pred_means ' : pred_means ,
317+ 'pred_stds ' : pred_stds ,
307318 'quantiles' : quantiles ,
308319 }
309320
@@ -316,30 +327,34 @@ def train_model(df, epochs=50, batch_size=2048, lr=1e-3, device='mps'):
316327
317328
318329def generate_synthetic (model , model_info , n_samples , train_df , device = 'mps' ):
319- """Generate synthetic samples by sampling conditioning from training data."""
330+ """Generate synthetic samples by sampling predictors from training data."""
320331 print (f"\n Generating { n_samples :,} synthetic samples..." )
321332 gen_start = time .time ()
322333
323334 model .eval ()
324335
325- # Sample conditioning values from training data (not random!)
326- cond_vars = model_info ['cond_vars ' ]
327- train_cond = train_df [cond_vars ].dropna ()
328- sample_idx = np .random .choice (len (train_cond ), n_samples , replace = True )
329- cond_data_raw = train_cond .iloc [sample_idx ].values .astype (np .float32 )
336+ # Sample predictor values from training data
337+ predictors = model_info ['predictors ' ]
338+ train_pred = train_df [predictors ].dropna ()
339+ sample_idx = np .random .choice (len (train_pred ), n_samples , replace = True )
340+ pred_data_raw = train_pred .iloc [sample_idx ].values .astype (np .float32 )
330341
331- # Normalize
332- cond_data = (cond_data_raw - model_info ['cond_means' ]) / model_info ['cond_stds' ]
333- X = torch .tensor (cond_data , dtype = torch .float32 ).to (device )
342+ # Log-transform income predictors, then normalize
343+ pred_data = pred_data_raw .copy ()
344+ for i , p in enumerate (predictors ):
345+ if 'income' in p :
346+ pred_data [:, i ] = np .log1p (np .maximum (pred_data [:, i ], 0 ))
347+ pred_data = (pred_data - model_info ['pred_means' ]) / model_info ['pred_stds' ]
348+ X = torch .tensor (pred_data , dtype = torch .float32 ).to (device )
334349
335350 with torch .no_grad ():
336351 out = model (X )
337352
338353 result = {}
339354
340- # Keep original conditioning values
341- for i , var in enumerate (cond_vars ):
342- result [var ] = cond_data_raw [:, i ]
355+ # Keep original predictor values (for reference)
356+ for i , var in enumerate (predictors ):
357+ result [var ] = pred_data_raw [:, i ]
343358
344359 # Sample ZI vars
345360 quantiles = model_info ['quantiles' ]
0 commit comments