This repository contains the metric implementations and released CSV results used in the paper "Diffusing in the Right Space: A Systematic Study of Latent Diffusability".

Velocity Irreducible Variance (VIV):

    import torch
    from metrics.norm import per_channel_norm
    from metrics.viv import velocity_irreducible_variance

    latents = torch.randn(50000, 32, 16, 16)   # (B, C, H, W)
    labels = torch.randint(0, 1000, (50000,))  # class labels

    latents = per_channel_norm(latents)
    viv = velocity_irreducible_variance(
        latents,
        labels,
        output_folder="outputs/viv_params",
    )

If output_folder is provided, the per-class spectra are saved to latents_params.pt.
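
Each usage snippet in this README passes the latents through per_channel_norm before computing a metric. The repository's implementation is not reproduced here; the following is only a minimal pure-Python sketch of one plausible reading, assuming the function standardizes every channel to zero mean and unit variance over the batch and spatial dimensions. The helper name per_channel_standardize and the eps parameter are hypothetical and do not come from the repository.

```python
from statistics import fmean, pstdev


def per_channel_standardize(latents, eps=1e-6):
    """Standardize each channel of a nested list shaped (N, C, H, W).

    ASSUMPTION: statistics are pooled over all samples and spatial
    positions of a channel. The real per_channel_norm may differ.
    """
    num_channels = len(latents[0])
    out = [[None] * num_channels for _ in latents]
    for c in range(num_channels):
        # Gather every value of channel c across the batch and the grid.
        values = [v for sample in latents for row in sample[c] for v in row]
        mean, std = fmean(values), pstdev(values)
        for n, sample in enumerate(latents):
            out[n][c] = [[(v - mean) / (std + eps) for v in row]
                         for row in sample[c]]
    return out


# Toy input: 2 samples, 1 channel, a 1x2 grid each (made-up values).
toy = [[[[1.0, 2.0]]], [[[3.0, 4.0]]]]
normed = per_channel_standardize(toy)
flat = [v for sample in normed for row in sample[0] for v in row]
```

With real tensors the same idea is usually a mean and standard deviation taken over dimensions 0, 2, and 3; check metrics/norm.py for the actual behavior.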

Latent Neighbor Consistency (LNC):

    import torch
    from metrics.norm import per_channel_norm
    from metrics.lnc import latent_neighbor_consistency

    latents = torch.randn(5000, 32, 16, 16)        # (N, C, h, w)
    masks = torch.randint(0, 2, (5000, 256, 256))  # foreground masks
    labels = torch.randint(0, 1000, (5000,))       # class labels

    latents = per_channel_norm(latents)
    lnc = latent_neighbor_consistency(
        latents,
        masks,
        labels,
        num_images_per_class=50,
        sim_metric="cosine",
        balance_classes=False,
    )

Spectral Energy Concentration (SEC):

    import torch
    from metrics.norm import per_channel_norm
    from metrics.sec import spectral_energy_concentration

    latents = torch.randn(16, 32, 64, 64)  # (B, C, H, W)

    latents = per_channel_norm(latents)
    sec = spectral_energy_concentration(
        latents,
        thresholds=[0.25, 0.5],
        dist_type="Manhattan",
    )

For CDS, LDS, and SRSS, we follow the official iREPA implementation and apply per_channel_norm to the latent features before computing the metrics.
For iFID, we follow the official implementation. The values reported in the paper are computed on the 50K ImageNet validation images.
For Normalized Entropy, Density CV, and Gini Coefficient, we follow the latent visualization implementation from VAVAE and apply per_channel_norm before computing the statistics.
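
For orientation, the three density statistics named above have common textbook definitions on a histogram of bin counts. The sketch below uses those definitions only; it is not taken from the VAVAE visualization code, which may bin the latents or normalize the values differently. The function names are hypothetical, and the counts are assumed to be non-negative, to have a positive sum, and to cover at least two bins.

```python
import math


def normalized_entropy(counts):
    """Shannon entropy of the bin probabilities divided by log(num_bins)."""
    total = sum(counts)
    probs = [c / total for c in counts if c > 0]
    entropy = -sum(p * math.log(p) for p in probs)
    return entropy / math.log(len(counts))


def coefficient_of_variation(counts):
    """Population standard deviation of the counts divided by their mean."""
    mean = sum(counts) / len(counts)
    variance = sum((c - mean) ** 2 for c in counts) / len(counts)
    return math.sqrt(variance) / mean


def gini_coefficient(counts):
    """Mean absolute difference between all pairs, divided by 2 * mean."""
    n = len(counts)
    mean = sum(counts) / n
    abs_diff_sum = sum(abs(a - b) for a in counts for b in counts)
    return abs_diff_sum / (2 * n * n * mean)


# Made-up histograms: perfectly even vs. all mass in one bin.
uniform = [5, 5, 5, 5]
peaked = [0, 0, 0, 4]
for name, hist in [("uniform", uniform), ("peaked", peaked)]:
    print(name,
          normalized_entropy(hist),
          coefficient_of_variation(hist),
          gini_coefficient(hist))
```

Under these definitions an evenly filled histogram gives entropy near 1 with CV and Gini at 0, while mass concentrated in few bins pushes entropy toward 0 and the other two upward.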
All raw data used in the paper is organized in the data folder:
- Latent property: data/metrics.csv
- Generation quality: data/generate.csv
- ODE straightness: data/straightness.csv

plot_correlation.py merges CSV files, plots one metric against another,
and prints/saves the correlation scatter plot.
Example: plot generation quality (gFID) against VIV, restricted to the rows with Diffusion = 'SiT-B', CFG = 1.0, and n_samples = 50000:

    python plot_correlation.py \
        --m1_csv data/generate.csv \
        --m1_col FID --m1_name gFID --inv_m1 \
        --m2_csv data/metrics.csv \
        --m2_col VIV \
        --id_col VAE \
        --cls_col Cluster --cls_color_csv misc/cls_color.csv \
        --where "Diffusion='SiT-B' AND CFG = 1.0 AND n_samples = 50000" \
        -o results/fid_vs_viv.png
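
The merge-and-correlate step can also be reproduced by hand when a plot is not needed. The following is a rough standard-library sketch, not the code of plot_correlation.py: it filters one table, joins it to the other on the VAE column, and computes a Pearson coefficient. The column names are taken from the command above; the in-memory tables, the VAE names, and all numbers are made-up placeholders, and whether the script itself reports Pearson or another coefficient is not stated here.

```python
import csv
import io
import math

# Toy stand-ins for data/generate.csv and data/metrics.csv (values made up).
GENERATE_CSV = """VAE,Diffusion,CFG,n_samples,FID
vae_a,SiT-B,1.0,50000,10.0
vae_b,SiT-B,1.0,50000,20.0
vae_c,SiT-B,1.0,50000,30.0
vae_a,SiT-B,1.5,50000,5.0
"""
METRICS_CSV = """VAE,VIV
vae_a,1.0
vae_b,2.0
vae_c,3.0
"""


def read_rows(text):
    """Parse CSV text into a list of dicts keyed by the header row."""
    return list(csv.DictReader(io.StringIO(text)))


def merge_on(left, right, key):
    """Inner-join two row lists on a shared identifier column."""
    right_by_key = {row[key]: row for row in right}
    return [{**row, **right_by_key[row[key]]}
            for row in left if row[key] in right_by_key]


def pearson(xs, ys):
    """Pearson correlation coefficient of two equal-length sequences."""
    mean_x, mean_y = sum(xs) / len(xs), sum(ys) / len(ys)
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    norm_x = math.sqrt(sum((x - mean_x) ** 2 for x in xs))
    norm_y = math.sqrt(sum((y - mean_y) ** 2 for y in ys))
    return cov / (norm_x * norm_y)


# Same selection as the --where clause in the command above.
generate_rows = [
    row for row in read_rows(GENERATE_CSV)
    if row["Diffusion"] == "SiT-B"
    and float(row["CFG"]) == 1.0
    and int(row["n_samples"]) == 50000
]
merged = merge_on(generate_rows, read_rows(METRICS_CSV), key="VAE")
fid = [float(row["FID"]) for row in merged]
viv = [float(row["VIV"]) for row in merged]
r = pearson(fid, viv)
print(f"{len(merged)} rows, Pearson r = {r:.3f}")
```

To run this on the released files, replace the two strings with the contents of data/generate.csv and data/metrics.csv read from disk.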