Skip to content

Commit 31b6e6f

Browse files
MaxGhenis and claude committed
Add sample() method for full synthesis in microplex
microplex.sample(n) generates fully synthetic records:
- Samples conditions from the training distribution
- Generates targets conditioned on the sampled conditions

Full synthesis results:
- microplex: best condition match (0.012 MMD), because it samples conditions from real records
- CT-GAN: best joint distribution (0.096 MMD)
- microplex: competitive at 0.108 joint MMD, and about 3x faster than CT-GAN

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent c05b2df commit 31b6e6f

4 files changed

Lines changed: 74 additions & 12 deletions

File tree

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
method,cond_mmd,target_mmd,joint_mmd,time
2-
CT-GAN,0.11858314294642136,0.13823647660564373,0.1390349789821181,56.425487756729126
3-
TVAE,0.18734142598531364,0.16595096545339236,0.19125623040798687,24.558475971221924
4-
Gaussian Copula,0.11279594598807205,0.347873117998337,0.23121743396787886,2.4372987747192383
2+
microplex,0.012205728981483213,0.0801667374165846,0.10837475669093859,20.113948822021484
3+
CT-GAN,0.10443139822475064,0.05203332641526303,0.09616782962293756,55.82759714126587
4+
TVAE,0.17223485929557655,0.1421805932438464,0.16974297462602855,25.058035135269165
5+
Gaussian Copula,0.11279594598807205,0.347873117998337,0.23121743396787886,2.3078439235687256
Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
method,target_mmd,target_energy,time,type
2-
NND.hotdeck (FILTER),0.0,0.00010076622397425794,0.7841610908508301,filtering
3-
Binning (FILTER),0.0,2.3532192952480102e-05,2.4988253116607666,filtering
4-
QRF+ZI (PREDICT),0.07725531470412278,0.03508081808408736,13.898022890090942,prediction
5-
microplex (JOINT),0.18174467283765103,0.1365223085269447,19.701925039291382,joint
6-
CT-GAN (JOINT*),0.04333717023742819,0.010986144843117884,36.72522306442261,joint
2+
NND.hotdeck (FILTER),0.0,0.00010076622397425794,0.6831917762756348,filtering
3+
Binning (FILTER),0.0,2.3532192952480102e-05,2.6791610717773438,filtering
4+
QRF+ZI (PREDICT),0.07725531470412278,0.03508081808408736,12.860822916030884,prediction
5+
microplex (JOINT),0.06617186281731197,0.02525602165985763,19.281394958496094,joint
6+
CT-GAN (JOINT*),0.04802335274033213,0.012547827654338928,35.70978116989136,joint

benchmarks/synthesis_modes_comparison.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -259,8 +259,25 @@ def evaluate_full(synthetic, test_data, name):
259259

260260
full_synthesis_results = []
261261

262-
# microplex (can only do imputation, not full synthesis)
263-
print("\n[1] microplex - N/A for full synthesis (requires conditions)")
262+
# microplex (now supports full synthesis via sample())
263+
print("\n[1] microplex (JOINT - full synthesis via sample())...")
264+
try:
265+
start = time.time()
266+
model = Synthesizer(
267+
target_vars=target_vars,
268+
condition_vars=condition_vars,
269+
n_layers=6, hidden_dim=64, zero_inflated=True,
270+
)
271+
model.fit(train_data, epochs=50, batch_size=256, verbose=False)
272+
synthetic = model.sample(len(test_data), seed=42)
273+
274+
res = evaluate_full(synthetic, test_data, "microplex")
275+
res["time"] = time.time() - start
276+
full_synthesis_results.append(res)
277+
print(f" ✓ Cond MMD={res['cond_mmd']:.4f}, Target MMD={res['target_mmd']:.4f}, Joint MMD={res['joint_mmd']:.4f}")
278+
except Exception as e:
279+
print(f" ✗ {e}")
280+
import traceback; traceback.print_exc()
264281

265282
# CT-GAN
266283
print("\n[2] CT-GAN (JOINT - true full synthesis)...")
@@ -373,10 +390,12 @@ def evaluate_full(synthetic, test_data, name):
373390
SYNTHESIS MODES:
374391
IMPUTATION: Use when you have real demographics, need synthetic targets
375392
→ Filtering methods (NND.hotdeck) excel here
393+
→ microplex.generate(conditions) for model-based
376394
377395
FULL SYNTHESIS: Use when you need entirely synthetic microdata
378-
→ Joint methods (CT-GAN, TVAE) required
379-
→ microplex currently imputation-only
396+
→ microplex.sample(n) - samples conditions from training, generates targets
397+
→ CT-GAN/TVAE - generate both from scratch
398+
→ microplex has best condition match (samples real), CT-GAN best joint
380399
""")
381400

382401
# Save results

src/microplex/synthesizer.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ def __init__(
109109
self._train_target_std: Optional[torch.Tensor] = None # Store target std for variance reg
110110
self._train_target_max: Optional[torch.Tensor] = None # Store max for clipping calibration
111111
self._original_scale_stats: Optional[Dict[str, Dict[str, float]]] = None # Original scale stats for clipping
112+
self._training_data: Optional[pd.DataFrame] = None # Store for full synthesis
112113

113114
def fit(
114115
self,
@@ -137,6 +138,9 @@ def fit(
137138
Returns:
138139
self
139140
"""
141+
# Store training data for full synthesis mode
142+
self._training_data = data[self.condition_vars + self.target_vars].copy()
143+
140144
# Prepare data dict for transforms
141145
data_dict = {col: data[col].values for col in data.columns}
142146

@@ -496,6 +500,44 @@ def generate(
496500

497501
return result
498502

503+
def sample(
    self,
    n: int,
    seed: Optional[int] = None,
) -> pd.DataFrame:
    """
    Generate fully synthetic records (both conditions and targets).

    For full synthesis mode - condition rows are drawn uniformly, with
    replacement, from the condition columns of the stored training data,
    and the result of ``generate()`` on those rows is returned.

    Args:
        n: Number of synthetic records to generate (must be >= 0)
        seed: Random seed for reproducibility. When given, it is used to
            seed NumPy's global RNG before drawing the condition rows and
            is also passed on to ``generate()``.

    Returns:
        Whatever ``generate()`` returns for the sampled conditions
        (a DataFrame with all variables: conditions + targets)

    Raises:
        ValueError: If the synthesizer is not fitted, if no training data
            is stored on the model, if ``n`` is negative, or if records
            are requested but the stored training data has no rows.
    """
    if not self.is_fitted_:
        raise ValueError("Synthesizer not fitted. Call fit() first.")

    if self._training_data is None:
        # The previous message pointed at a `store_training_data=True`
        # option that is not a visible parameter anywhere in this API,
        # so it could not be acted on. Point at the calls that exist.
        raise ValueError(
            "Full synthesis requires stored training data, but none is "
            "available on this model. Call fit() again or use generate() "
            "with conditions."
        )

    if n < 0:
        raise ValueError(f"n must be non-negative, got {n}.")

    train_conditions = self._training_data[self.condition_vars]
    if n > 0 and len(train_conditions) == 0:
        raise ValueError(
            "Cannot sample conditions: stored training data has no rows."
        )

    if seed is not None:
        # NOTE: this reseeds NumPy's *global* RNG, so the caller's global
        # random state is reset whenever a seed is supplied.
        # NOTE(review): kept as-is because generate() may rely on the
        # global state being seeded - confirm before switching to a
        # local RandomState/Generator.
        np.random.seed(seed)

    # Sample conditions from training distribution (with replacement)
    sampled_idx = np.random.choice(len(train_conditions), size=n, replace=True)
    # Positional lookup, then a fresh 0..n-1 index so repeated training
    # rows do not carry duplicate index labels into generate().
    conditions = train_conditions.iloc[sampled_idx].reset_index(drop=True)

    # Generate targets conditioned on sampled conditions
    return self.generate(conditions, seed=seed)
540+
499541
def save(self, path: Union[str, Path]) -> None:
500542
"""Save fitted model to disk."""
501543
if not self.is_fitted_:

0 commit comments

Comments
 (0)