Skip to content

Commit 7bb6a39

Browse files
MaxGhenis and claude committed
Add comprehensive synthesis method comparison
- New benchmarks: comprehensive_comparison.py (7 methods), parallel_benchmark.py (multiprocessing), tune_comparison.py (hyperparameter search)
- Added multivariate metrics on CPS: MMD, energy distance, coverage, authenticity
- Results: microplex wins on MMD (0.1145) and coverage (0.2627)
- New modules: Calibrator, CPSSyntheticGenerator with tests

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent e012acb commit 7bb6a39

12 files changed

Lines changed: 3044 additions & 0 deletions
Lines changed: 274 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,274 @@
1+
"""Comprehensive comparison of synthesis methods on CPS-like data."""
2+
3+
import sys
4+
import time
5+
import warnings
6+
from pathlib import Path
7+
8+
import numpy as np
9+
import pandas as pd
10+
11+
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
12+
warnings.filterwarnings("ignore")
13+
14+
from run_cps_benchmark import generate_cps_like_data
15+
from compare_qrf import SequentialQRFWithZeroInflation
16+
from multivariate_metrics import (
17+
compute_mmd, compute_energy_distance, normalize_data,
18+
compute_authenticity_distance, compute_coverage_distance
19+
)
20+
from microplex import Synthesizer
21+
22+
# Generate data: one CPS-like frame, split by position into the first
# 20,000 rows (training) and every remaining row (testing).
print("Generating CPS-like data (20k train, 5k test)...")
full_data = generate_cps_like_data(25000, seed=42)
train_data, test_data = (
    full_data.iloc[:20000].copy(),
    full_data.iloc[20000:].copy(),
)

# Columns each method has to synthesize.
target_vars = [
    "wage_income",
    "self_emp_income",
    "ssi_income",
    "uc_income",
    "snap_benefit",
    "eitc",
    "agi",
    "federal_tax",
]
# Columns the methods are conditioned on.
condition_vars = [
    "age",
    "education",
    "is_employed",
    "marital_status",
]
all_vars = [*target_vars, *condition_vars]

# Conditioning rows handed to the fitted models at generation time.
test_conditions = test_data.loc[:, condition_vars].copy()
34+
35+
def evaluate(synthetic, name):
    """Score one synthetic dataset against the held-out test data.

    The module-level ``train_data`` and ``test_data`` frames and the given
    ``synthetic`` frame are passed through ``normalize_data`` over
    ``target_vars``; each metric is then computed between the normalized
    test and synthetic outputs.

    Args:
        synthetic: Synthetic dataset produced by one method.
        name: Label stored under the ``"method"`` key of the result.

    Returns:
        A dict with keys ``method``, ``mmd``, ``energy_dist``,
        ``authenticity`` and ``coverage``. The last two hold the ``"mean"``
        entry of their respective metric results.
    """
    # Only the normalized test and synthetic outputs are used by the metrics.
    _, test_norm, synth_norm, _ = normalize_data(
        train_data, test_data, synthetic, target_vars
    )
    return {
        "method": name,
        "mmd": compute_mmd(test_norm, synth_norm),
        "energy_dist": compute_energy_distance(test_norm, synth_norm),
        # Argument order differs here: synthetic first, then test.
        "authenticity": compute_authenticity_distance(synth_norm, test_norm)["mean"],
        "coverage": compute_coverage_distance(test_norm, synth_norm)["mean"],
    }
51+
52+
# Metric rows: one dict per method that completes successfully.
results = []

# 1. microplex (tuned)
print("\n[1/7] microplex (tuned: L=8, H=128, epochs=100)...")
try:
    t0 = time.time()
    model = Synthesizer(
        target_vars=target_vars,
        condition_vars=condition_vars,
        n_layers=8,
        hidden_dim=128,
        zero_inflated=True,
    )
    model.fit(train_data, epochs=100, batch_size=256, verbose=False)
    synthetic = model.generate(test_conditions)
    # Wall-clock seconds for fit plus generation; evaluation is not timed.
    elapsed = time.time() - t0
    res = {**evaluate(synthetic, "microplex (tuned)"), "time": elapsed}
    results.append(res)
    print(f" ✓ MMD={res['mmd']:.4f}, Energy={res['energy_dist']:.4f}, Coverage={res['coverage']:.4f}")
except Exception as exc:
    # Best effort: report the failure and let the remaining methods run.
    print(f" ✗ Failed: {exc}")
72+
73+
# 2. QRF+ZI (tuned)
print("\n[2/7] QRF+ZI (tuned: n=200, depth=15)...")
try:
    t0 = time.time()
    model = SequentialQRFWithZeroInflation(
        target_vars,
        condition_vars,
        n_estimators=200,
        max_depth=15,
    )
    model.fit(train_data, verbose=False)
    synthetic = model.generate(test_conditions)
    # Wall-clock seconds for fit plus generation; evaluation is not timed.
    elapsed = time.time() - t0
    res = {**evaluate(synthetic, "QRF+ZI (tuned)"), "time": elapsed}
    results.append(res)
    print(f" ✓ MMD={res['mmd']:.4f}, Energy={res['energy_dist']:.4f}, Coverage={res['coverage']:.4f}")
except Exception as exc:
    # Best effort: report the failure and let the remaining methods run.
    print(f" ✗ Failed: {exc}")
90+
91+
# 3. TabPFN-based
print("\n[3/7] TabPFN-based synthesis...")
try:
    # Optional dependency: imported inside the try so a missing package is
    # reported as a failed method instead of aborting the script.
    from tabpfn import TabPFNRegressor

    t0 = time.time()

    # TabPFN works on small context, so fit on at most 1000 sampled rows.
    train_sample = train_data.sample(n=min(1000, len(train_data)), random_state=42)

    # The feature matrices do not depend on the target, so build them once.
    X_train = train_sample[condition_vars].values
    X_test = test_conditions.values

    # One regressor per target, each fit on the conditioning columns only.
    synthetic_rows = []
    for target in target_vars:
        model = TabPFNRegressor(device="cpu", n_estimators=4)
        model.fit(X_train, train_sample[target].values)
        synthetic_rows.append(pd.Series(model.predict(X_test), name=target))

    synthetic = pd.concat(synthetic_rows, axis=1)
    synthetic[condition_vars] = test_conditions.values
    # Wall-clock seconds for all fits and predictions; evaluation is not timed.
    elapsed = time.time() - t0

    res = {**evaluate(synthetic, "TabPFN"), "time": elapsed}
    results.append(res)
    print(f" ✓ MMD={res['mmd']:.4f}, Energy={res['energy_dist']:.4f}, Coverage={res['coverage']:.4f}")
except Exception as exc:
    # Best effort: report the failure and let the remaining methods run.
    print(f" ✗ Failed: {exc}")
122+
123+
# 4. CT-GAN
print("\n[4/7] CT-GAN...")
try:
    # Optional dependency: imported inside the try so a missing package is
    # reported as a failed method instead of aborting the script.
    from sdv.single_table import CTGANSynthesizer
    from sdv.metadata import SingleTableMetadata

    t0 = time.time()

    # Metadata is detected from the same columns the model is fit on.
    fit_frame = train_data[all_vars]
    metadata = SingleTableMetadata()
    metadata.detect_from_dataframe(fit_frame)

    model = CTGANSynthesizer(metadata, epochs=50, verbose=False)
    model.fit(fit_frame)

    # Draw as many rows as there are test rows. The sample is not
    # conditioned on, or filtered against, test_conditions.
    synthetic = model.sample(len(test_conditions))
    # Wall-clock seconds for metadata detection, fit and sampling.
    elapsed = time.time() - t0

    res = {**evaluate(synthetic, "CT-GAN"), "time": elapsed}
    results.append(res)
    print(f" ✓ MMD={res['mmd']:.4f}, Energy={res['energy_dist']:.4f}, Coverage={res['coverage']:.4f}")
except Exception as exc:
    # Best effort: report the failure and let the remaining methods run.
    print(f" ✗ Failed: {exc}")
148+
149+
# 5. TVAE
print("\n[5/7] TVAE...")
try:
    # Optional dependency: imported inside the try so a missing package is
    # reported as a failed method instead of aborting the script.
    from sdv.single_table import TVAESynthesizer
    from sdv.metadata import SingleTableMetadata

    t0 = time.time()

    # Metadata is detected from the same columns the model is fit on.
    fit_frame = train_data[all_vars]
    metadata = SingleTableMetadata()
    metadata.detect_from_dataframe(fit_frame)

    model = TVAESynthesizer(metadata, epochs=50, verbose=False)
    model.fit(fit_frame)

    # Draw as many rows as there are test rows.
    synthetic = model.sample(len(test_conditions))
    # Wall-clock seconds for metadata detection, fit and sampling.
    elapsed = time.time() - t0

    res = {**evaluate(synthetic, "TVAE"), "time": elapsed}
    results.append(res)
    print(f" ✓ MMD={res['mmd']:.4f}, Energy={res['energy_dist']:.4f}, Coverage={res['coverage']:.4f}")
except Exception as exc:
    # Best effort: report the failure and let the remaining methods run.
    print(f" ✗ Failed: {exc}")
172+
173+
# 6. Gaussian Copula
print("\n[6/7] Gaussian Copula...")
try:
    # Optional dependency: imported inside the try so a missing package is
    # reported as a failed method instead of aborting the script.
    from sdv.single_table import GaussianCopulaSynthesizer
    from sdv.metadata import SingleTableMetadata

    t0 = time.time()

    # Metadata is detected from the same columns the model is fit on.
    fit_frame = train_data[all_vars]
    metadata = SingleTableMetadata()
    metadata.detect_from_dataframe(fit_frame)

    model = GaussianCopulaSynthesizer(metadata)
    model.fit(fit_frame)

    # Draw as many rows as there are test rows.
    synthetic = model.sample(len(test_conditions))
    # Wall-clock seconds for metadata detection, fit and sampling.
    elapsed = time.time() - t0

    res = {**evaluate(synthetic, "Gaussian Copula"), "time": elapsed}
    results.append(res)
    print(f" ✓ MMD={res['mmd']:.4f}, Energy={res['energy_dist']:.4f}, Coverage={res['coverage']:.4f}")
except Exception as exc:
    # Best effort: report the failure and let the remaining methods run.
    print(f" ✗ Failed: {exc}")
196+
197+
# 7. XGBoost Sequential (similar to QRF but with XGBoost)
#
# Targets are generated one at a time. Each target gets a logistic
# "is it positive?" classifier and an XGBoost regressor trained on the
# positive rows only; every generated target is then appended to the
# feature set used for the next target.
print("\n[7/7] XGBoost Sequential + Zero-Inflation...")
try:
    # Optional dependencies: imported inside the try so a missing package is
    # reported as a failed method instead of aborting the script.
    import xgboost as xgb
    from sklearn.linear_model import LogisticRegression

    start = time.time()

    # Dedicated seeded generator for the zero/positive draws, so this method
    # is reproducible like the other seeded components of the benchmark
    # (the unseeded global np.random state gave different results per run).
    rng = np.random.default_rng(42)

    synthetic_data = test_conditions.copy()
    available_features = list(condition_vars)

    for target in target_vars:
        X_train = train_data[available_features].values
        y_train = train_data[target].values
        X_test = synthetic_data[available_features].values

        mask = y_train > 0
        n_positive = int(mask.sum())

        # Zero classifier: probability that the target is positive.
        if 0 < n_positive < len(y_train):
            clf = LogisticRegression(max_iter=1000, random_state=42)
            clf.fit(X_train, mask.astype(int))
            p_positive = clf.predict_proba(X_test)[:, 1]
        else:
            # Only one class is present, which LogisticRegression cannot
            # fit; use the observed constant probability instead.
            p_positive = np.full(len(X_test), 1.0 if n_positive else 0.0)

        # Regressor for positive values; with 10 or fewer positive rows,
        # fall back to their mean (or 0 when there are none).
        if n_positive > 10:
            reg = xgb.XGBRegressor(
                n_estimators=100,
                max_depth=6,
                learning_rate=0.1,
                random_state=42,
                verbosity=0,
            )
            reg.fit(X_train[mask], y_train[mask])
            pred_positive = reg.predict(X_test)
        else:
            fallback = y_train[mask].mean() if n_positive else 0
            pred_positive = np.full(len(X_test), fallback)

        # Sample zeros: keep the regression value with probability p_positive.
        is_positive = rng.random(len(X_test)) < p_positive
        predictions = np.where(is_positive, pred_positive, 0)
        predictions = np.maximum(predictions, 0)  # Clip negatives

        synthetic_data[target] = predictions
        available_features.append(target)

    # Wall-clock seconds for all fits and predictions; evaluation is not timed.
    train_time = time.time() - start

    res = evaluate(synthetic_data, "XGBoost+ZI")
    res["time"] = train_time
    results.append(res)
    print(f" ✓ MMD={res['mmd']:.4f}, Energy={res['energy_dist']:.4f}, Coverage={res['coverage']:.4f}")
except Exception as e:
    # Best effort: report the failure and let the summary still run.
    print(f" ✗ Failed: {e}")
247+
248+
# Summary
# Every method runs inside its own try/except, so all of them may have
# failed. Exit with a clear message instead of hitting a KeyError when
# sorting an empty frame by "mmd".
if not results:
    raise SystemExit("No synthesis method completed successfully; nothing to summarize.")

print("\n" + "=" * 90)
print("COMPREHENSIVE COMPARISON RESULTS")
print("=" * 90)
# One row per successful method, best (lowest) MMD first.
df = pd.DataFrame(results).sort_values("mmd")

# Format nicely
print(f"\n{'Method':<25} {'MMD':>10} {'Energy':>10} {'Auth':>10} {'Coverage':>10} {'Time(s)':>10}")
print("-" * 75)
for _, row in df.iterrows():
    print(f"{row['method']:<25} {row['mmd']:>10.4f} {row['energy_dist']:>10.4f} "
          f"{row['authenticity']:>10.4f} {row['coverage']:>10.4f} {row['time']:>10.1f}")

print("\n" + "=" * 90)
print("WINNER BY METRIC (lower is better except for context)")
print("=" * 90)
# NOTE(review): every column, including authenticity, is ranked with idxmin;
# confirm that lower really is better for authenticity.
for label, column in (
    ("Best MMD (joint dist)", "mmd"),
    ("Best Energy Distance", "energy_dist"),
    ("Best Authenticity", "authenticity"),
    ("Best Coverage", "coverage"),
    ("Fastest", "time"),
):
    print(f"{label}: {df.loc[df[column].idxmin(), 'method']}")
270+
271+
# Save results
output_path = Path(__file__).parent / "results" / "comprehensive_comparison.csv"
# to_csv does not create missing directories; make sure "results/" exists so
# a fresh checkout does not lose the whole benchmark run at the last step.
output_path.parent.mkdir(parents=True, exist_ok=True)
df.to_csv(output_path, index=False)
print(f"\nResults saved to: {output_path}")

0 commit comments

Comments
 (0)