|
| 1 | +"""Comprehensive comparison of synthesis methods on CPS-like data.""" |
| 2 | + |
| 3 | +import sys |
| 4 | +import time |
| 5 | +import warnings |
| 6 | +from pathlib import Path |
| 7 | + |
| 8 | +import numpy as np |
| 9 | +import pandas as pd |
| 10 | + |
| 11 | +sys.path.insert(0, str(Path(__file__).parent.parent / "src")) |
| 12 | +warnings.filterwarnings("ignore") |
| 13 | + |
| 14 | +from run_cps_benchmark import generate_cps_like_data |
| 15 | +from compare_qrf import SequentialQRFWithZeroInflation |
| 16 | +from multivariate_metrics import ( |
| 17 | + compute_mmd, compute_energy_distance, normalize_data, |
| 18 | + compute_authenticity_distance, compute_coverage_distance |
| 19 | +) |
| 20 | +from microplex import Synthesizer |
| 21 | + |
# Build one CPS-like population and split it positionally:
# the first 20,000 rows are used for fitting, the remaining rows for evaluation.
print("Generating CPS-like data (20k train, 5k test)...")
full_data = generate_cps_like_data(25000, seed=42)
train_data, test_data = full_data.iloc[:20000].copy(), full_data.iloc[20000:].copy()

# Columns every method has to synthesize.
target_vars = [
    "wage_income",
    "self_emp_income",
    "ssi_income",
    "uc_income",
    "snap_benefit",
    "eitc",
    "agi",
    "federal_tax",
]
# Columns handed to the conditional methods as inputs.
condition_vars = [
    "age",
    "education",
    "is_employed",
    "marital_status",
]
all_vars = [*target_vars, *condition_vars]

# Conditioning columns of the evaluation split, copied so later code
# cannot write back into test_data.
test_conditions = test_data[condition_vars].copy()
| 34 | + |
def evaluate(synthetic, name):
    """Score one synthetic dataset and return its metrics as a flat dict.

    The module-level ``train_data`` and ``test_data`` are normalized together
    with ``synthetic`` over ``target_vars``; each metric is then computed
    between the normalized test split and the normalized synthetic data.

    Args:
        synthetic: Synthetic dataset to score; it is handed to
            ``normalize_data`` alongside the train and test splits.
        name: Label stored under the ``"method"`` key of the result.

    Returns:
        A dict with the keys ``method``, ``mmd``, ``energy_dist``,
        ``authenticity`` and ``coverage``. The last two hold the ``"mean"``
        entry of the corresponding metric result.
    """
    # NOTE(review): the normalized training split is discarded, so
    # authenticity is measured against the test split — confirm this is
    # intended rather than a comparison against the training data.
    _, holdout, synth, _ = normalize_data(
        train_data, test_data, synthetic, target_vars
    )

    # Metrics are evaluated in the same order as before, in case any of the
    # helpers consumes shared random state.
    scores = {"method": name}
    scores["mmd"] = compute_mmd(holdout, synth)
    scores["energy_dist"] = compute_energy_distance(holdout, synth)
    scores["authenticity"] = compute_authenticity_distance(synth, holdout)["mean"]
    scores["coverage"] = compute_coverage_distance(holdout, synth)["mean"]
    return scores
| 51 | + |
# One metrics dict per method that finished without raising.
results = []

# 1. microplex (tuned)
print("\n[1/7] microplex (tuned: L=8, H=128, epochs=100)...")
try:
    t0 = time.time()
    synthesizer = Synthesizer(
        target_vars=target_vars,
        condition_vars=condition_vars,
        n_layers=8,
        hidden_dim=128,
        zero_inflated=True,
    )
    synthesizer.fit(train_data, epochs=100, batch_size=256, verbose=False)
    generated = synthesizer.generate(test_conditions)
    # The recorded time spans both fitting and generation.
    elapsed = time.time() - t0

    metrics = {**evaluate(generated, "microplex (tuned)"), "time": elapsed}
    results.append(metrics)
    print(
        f" ✓ MMD={metrics['mmd']:.4f}, "
        f"Energy={metrics['energy_dist']:.4f}, "
        f"Coverage={metrics['coverage']:.4f}"
    )
except Exception as exc:
    # Best effort: a failing method is reported and the benchmark continues.
    print(f" ✗ Failed: {exc}")
| 72 | + |
# 2. QRF+ZI (tuned)
print("\n[2/7] QRF+ZI (tuned: n=200, depth=15)...")
try:
    t0 = time.time()
    qrf = SequentialQRFWithZeroInflation(
        target_vars,
        condition_vars,
        n_estimators=200,
        max_depth=15,
    )
    qrf.fit(train_data, verbose=False)
    generated = qrf.generate(test_conditions)
    # The recorded time spans both fitting and generation.
    elapsed = time.time() - t0

    metrics = {**evaluate(generated, "QRF+ZI (tuned)"), "time": elapsed}
    results.append(metrics)
    print(
        f" ✓ MMD={metrics['mmd']:.4f}, "
        f"Energy={metrics['energy_dist']:.4f}, "
        f"Coverage={metrics['coverage']:.4f}"
    )
except Exception as exc:
    # Best effort: a failing method is reported and the benchmark continues.
    print(f" ✗ Failed: {exc}")
| 90 | + |
# 3. TabPFN-based
print("\n[3/7] TabPFN-based synthesis...")
try:
    from tabpfn import TabPFNRegressor

    start = time.time()

    # TabPFN works on small context, so we sample training data
    train_sample = train_data.sample(n=min(1000, len(train_data)), random_state=42)

    # The feature matrices are the same for every target, so extract them
    # once instead of once per loop iteration.
    X_train = train_sample[condition_vars].values
    X_test = test_conditions.values

    # One independent regressor per target, each using only the condition
    # variables. The output of predict() is used directly as the synthetic
    # value; no zero-inflation or sampling step is applied in this block.
    predicted_columns = []
    for target in target_vars:
        model = TabPFNRegressor(device="cpu", n_estimators=4)
        model.fit(X_train, train_sample[target].values)
        predicted_columns.append(pd.Series(model.predict(X_test), name=target))

    # Targets first, then the conditioning values copied positionally.
    synthetic = pd.concat(predicted_columns, axis=1)
    synthetic[condition_vars] = test_conditions.values
    # The recorded time spans fitting and prediction for all targets.
    train_time = time.time() - start

    res = evaluate(synthetic, "TabPFN")
    res["time"] = train_time
    results.append(res)
    print(f" ✓ MMD={res['mmd']:.4f}, Energy={res['energy_dist']:.4f}, Coverage={res['coverage']:.4f}")
except Exception as e:
    # Best effort: a missing package or failing fit is reported and the
    # benchmark continues with the next method.
    print(f" ✗ Failed: {e}")
| 122 | + |
def _run_sdv_synthesizer(class_name, label, **synth_kwargs):
    """Fit one SDV single-table synthesizer, sample from it and score it.

    Args:
        class_name: Name of the synthesizer class in ``sdv.single_table``.
        label: Method name stored in the result row.
        **synth_kwargs: Extra keyword arguments for the synthesizer
            constructor (passed after the metadata object).

    Returns:
        The dict produced by ``evaluate`` with an added ``"time"`` entry that
        covers metadata detection, fitting and sampling.

    Raises:
        Whatever the import, the synthesizer or ``evaluate`` raises; the
        caller decides how to report it.
    """
    # Imported lazily so a missing ``sdv`` install only fails these methods.
    import sdv.single_table as sdv_single_table
    from sdv.metadata import SingleTableMetadata

    synthesizer_cls = getattr(sdv_single_table, class_name)

    start = time.time()

    metadata = SingleTableMetadata()
    metadata.detect_from_dataframe(train_data[all_vars])

    model = synthesizer_cls(metadata, **synth_kwargs)
    model.fit(train_data[all_vars])

    # Sampling is unconditional: only the number of rows matches the test
    # split; the rows are not conditioned on or filtered by test_conditions.
    synthetic = model.sample(len(test_conditions))
    train_time = time.time() - start

    res = evaluate(synthetic, label)
    res["time"] = train_time
    return res


# 4. CT-GAN, 5. TVAE, 6. Gaussian Copula
# The three SDV models share the same fit / sample / score procedure and
# differ only in class and constructor arguments.
for _step, _class_name, _label, _kwargs in (
    (4, "CTGANSynthesizer", "CT-GAN", {"epochs": 50, "verbose": False}),
    (5, "TVAESynthesizer", "TVAE", {"epochs": 50, "verbose": False}),
    (6, "GaussianCopulaSynthesizer", "Gaussian Copula", {}),
):
    print(f"\n[{_step}/7] {_label}...")
    try:
        res = _run_sdv_synthesizer(_class_name, _label, **_kwargs)
        results.append(res)
        print(f" ✓ MMD={res['mmd']:.4f}, Energy={res['energy_dist']:.4f}, Coverage={res['coverage']:.4f}")
    except Exception as e:
        # Best effort: one failing model must not stop the remaining ones.
        print(f" ✗ Failed: {e}")
| 196 | + |
# 7. XGBoost Sequential (similar to QRF but with XGBoost)
print("\n[7/7] XGBoost Sequential + Zero-Inflation...")
try:
    import xgboost as xgb
    from sklearn.linear_model import LogisticRegression

    start = time.time()

    # Dedicated seeded generator: the zero/positive draws below are then
    # reproducible, in line with the seeded classifier and regressor.
    rng = np.random.default_rng(42)

    synthetic_data = test_conditions.copy()
    available_features = list(condition_vars)

    # Targets are generated one at a time; each generated target becomes an
    # input feature for the targets that follow it.
    for target in target_vars:
        X_train = train_data[available_features].values
        y_train = train_data[target].values
        X_test = synthetic_data[available_features].values

        mask = y_train > 0
        n_positive = int(mask.sum())

        # Zero classifier
        if 0 < n_positive < len(y_train):
            clf = LogisticRegression(max_iter=1000, random_state=42)
            clf.fit(X_train, mask.astype(int))
            p_positive = clf.predict_proba(X_test)[:, 1]
        else:
            # Only one class is present (all zero or all positive).
            # LogisticRegression cannot be fitted on a single class, so use
            # the degenerate probability directly.
            p_positive = np.full(len(X_test), 1.0 if n_positive else 0.0)

        # Regressor for positive values
        if n_positive > 10:
            reg = xgb.XGBRegressor(
                n_estimators=100, max_depth=6,
                learning_rate=0.1, random_state=42, verbosity=0
            )
            reg.fit(X_train[mask], y_train[mask])
            pred_positive = reg.predict(X_test)
        else:
            # Too few positive rows to fit a model: fall back to their mean,
            # or to 0 when there are none.
            pred_positive = np.full(len(X_test), y_train[mask].mean() if n_positive > 0 else 0)

        # Sample zeros
        is_positive = rng.random(len(X_test)) < p_positive
        predictions = np.where(is_positive, pred_positive, 0)
        predictions = np.maximum(predictions, 0)  # Clip negatives

        synthetic_data[target] = predictions
        available_features.append(target)

    # The recorded time spans fitting and generation for all targets.
    train_time = time.time() - start

    res = evaluate(synthetic_data, "XGBoost+ZI")
    res["time"] = train_time
    results.append(res)
    print(f" ✓ MMD={res['mmd']:.4f}, Energy={res['energy_dist']:.4f}, Coverage={res['coverage']:.4f}")
except Exception as e:
    # Best effort: a missing package or failing fit is reported and the
    # benchmark continues to the summary.
    print(f" ✗ Failed: {e}")
| 247 | + |
# Summary
print("\n" + "=" * 90)
print("COMPREHENSIVE COMPARISON RESULTS")
print("=" * 90)
df = pd.DataFrame(results)

if df.empty:
    # Every method failed: there are no metric columns to sort or rank,
    # so report that instead of raising a KeyError on "mmd".
    print("\nNo method completed successfully; nothing to rank or save.")
else:
    df = df.sort_values("mmd")

    # Format nicely
    print(f"\n{'Method':<25} {'MMD':>10} {'Energy':>10} {'Auth':>10} {'Coverage':>10} {'Time(s)':>10}")
    print("-" * 75)
    for _, row in df.iterrows():
        print(f"{row['method']:<25} {row['mmd']:>10.4f} {row['energy_dist']:>10.4f} "
              f"{row['authenticity']:>10.4f} {row['coverage']:>10.4f} {row['time']:>10.1f}")

    print("\n" + "=" * 90)
    print("WINNER BY METRIC (lower is better except for context)")
    print("=" * 90)
    # For each column, the winner is the method with the smallest value.
    for _label, _column in (
        ("Best MMD (joint dist):", "mmd"),
        ("Best Energy Distance:", "energy_dist"),
        ("Best Authenticity:", "authenticity"),
        ("Best Coverage:", "coverage"),
        ("Fastest:", "time"),
    ):
        print(f"{_label} {df.loc[df[_column].idxmin(), 'method']}")

    # Save results
    output_path = Path(__file__).parent / "results" / "comprehensive_comparison.csv"
    # to_csv does not create missing directories, so make sure the target
    # folder exists before writing.
    output_path.parent.mkdir(parents=True, exist_ok=True)
    df.to_csv(output_path, index=False)
    print(f"\nResults saved to: {output_path}")
0 commit comments