From 70a61bb2ad7b110b802e1980a8339bebbfc64b6a Mon Sep 17 00:00:00 2001 From: Kerem Tezcan Date: Thu, 28 May 2026 17:31:42 +0200 Subject: [PATCH 1/2] added sampling diagnostics --- src/weathergen/model/diffusion.py | 47 ++++++++++++++++++++++++++----- src/weathergen/model/model.py | 1 + 2 files changed, 41 insertions(+), 7 deletions(-) diff --git a/src/weathergen/model/diffusion.py b/src/weathergen/model/diffusion.py index 7909534d1..a05d6e0dd 100644 --- a/src/weathergen/model/diffusion.py +++ b/src/weathergen/model/diffusion.py @@ -343,6 +343,9 @@ def inference_forward( "l2_to_target": [], "cosine_to_target": [], "c_skip": [], + "d_cur_norm": [], + "d_cur_step_norm": [], + "residual_std": [], "x": [x.cpu()], } @@ -371,12 +374,12 @@ def inference_forward( denoised = self.denoise(x=x_hat, c=c, sigma=t_hat, fstep=fstep, coords=coords) d_cur = (x_hat - denoised) / t_hat x_next = x_hat + (t_next - t_hat) * d_cur - - # Apply 2nd order correction. - if i < num_steps - 1: - denoised = self.denoise(x=x_next, c=c, sigma=t_next, fstep=fstep, coords=coords) - d_prime = (x_next - denoised) / t_next - x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime) + + # # Apply 2nd order correction. + # if i < num_steps - 1: + # denoised = self.denoise(x=x_next, c=c, sigma=t_next, fstep=fstep, coords=coords) + # d_prime = (x_next - denoised) / t_next + # x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime) # --- Record diagnostics --- with torch.no_grad(): @@ -385,6 +388,9 @@ def inference_forward( track["c_skip"].append(self.sigma_data**2 / (s**2 + self.sigma_data**2)) track["x_std"].append(x_next.std().item()) track["denoised_std"].append(denoised.std().item()) + track["d_cur_norm"].append(d_cur.norm().item()) + track["d_cur_step_norm"].append(((t_next - t_hat) * d_cur).norm().item()) + track["residual_std"].append((x_hat - denoised).std().item()) track["x"].append(x_next.cpu()) if self.cur_token is not None: track["l2_to_target"].append((x_next - self.cur_token).norm().item()) @@ -407,7 +413,7 @@ def _plot_sampling_diagnostics(self, track: dict, num_steps: int) -> None: steps = list(range(len(track["sigma"]))) has_target = len(track["l2_to_target"]) > 0 - n_plots = 3 + n_plots = 7 fig, axes = plt.subplots(n_plots, 1, figsize=(10, 3 * n_plots), sharex=True) @@ -442,6 +448,33 @@ def _plot_sampling_diagnostics(self, track: dict, num_steps: int) -> None: axes[2].set_ylabel("L2 error to target") axes[2].grid(True, alpha=0.3) + # 4) d_cur norm and step norm + axes[3].semilogy(steps, track["d_cur_norm"], "o-", markersize=3, label="||d_cur||") + axes[3].semilogy(steps, track["d_cur_step_norm"], "^-", markersize=3, label="||(t_next - t_hat) * d_cur||") + axes[3].set_ylabel("norm (log scale)") + axes[3].set_title("ODE drift norms") + axes[3].legend(fontsize=8) + axes[3].grid(True, alpha=0.3) + + # 5) Residual std: Std(x_hat - denoised) + axes[4].semilogy(steps, track["residual_std"], "s-", markersize=3, color="tab:orange") + axes[4].set_ylabel("std (log scale)") + axes[4].set_title("Std(x_hat - denoised)") + axes[4].grid(True, alpha=0.3) + + # 6) Residual std zoomed to [0, 1] + axes[5].plot(steps, track["residual_std"], "s-", markersize=3, color="tab:orange") + axes[5].set_ylim(0, 1) + axes[5].set_ylabel("std (clipped to 1)") + axes[5].set_title("Std(x_hat - denoised) [y ≤ 1]") + axes[5].grid(True, alpha=0.3) + + # 7) Std of x_next over sampling steps + axes[6].semilogy(steps, track["x_std"], "o-", markersize=3, color="tab:blue") + axes[6].set_ylabel("std (log scale)") + axes[6].set_title("Std of x_next over denoising steps") + axes[6].grid(True, alpha=0.3) + axes[-1].set_xlabel("sampling step") fig.tight_layout() diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 20590e5bf..537c93983 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -758,6 +758,7 @@ def forward(self, model_params: ModelParams, batch: ModelBatch) -> ModelOutput: step, meta_info=batch.samples[0].meta_info, coords=model_params.rope_coords, + # num_steps=30 ) # Diffusion inference returns the per-ODE-step intermediate denoised tokens as a From e4f4892809ca3652ee36e7071cbc0f043b563686 Mon Sep 17 00:00:00 2001 From: Kerem Tezcan Date: Thu, 28 May 2026 17:34:04 +0200 Subject: [PATCH 2/2] put back Heun correction --- src/weathergen/model/diffusion.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/weathergen/model/diffusion.py b/src/weathergen/model/diffusion.py index a05d6e0dd..f298f118c 100644 --- a/src/weathergen/model/diffusion.py +++ b/src/weathergen/model/diffusion.py @@ -375,11 +375,11 @@ def inference_forward( d_cur = (x_hat - denoised) / t_hat x_next = x_hat + (t_next - t_hat) * d_cur - # # Apply 2nd order correction. - # if i < num_steps - 1: - # denoised = self.denoise(x=x_next, c=c, sigma=t_next, fstep=fstep, coords=coords) - # d_prime = (x_next - denoised) / t_next - # x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime) + # Apply 2nd order correction. + if i < num_steps - 1: + denoised = self.denoise(x=x_next, c=c, sigma=t_next, fstep=fstep, coords=coords) + d_prime = (x_next - denoised) / t_next + x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime) # --- Record diagnostics --- with torch.no_grad():