diff --git a/src/weathergen/model/diffusion.py b/src/weathergen/model/diffusion.py index 7909534d1..f298f118c 100644 --- a/src/weathergen/model/diffusion.py +++ b/src/weathergen/model/diffusion.py @@ -343,6 +343,9 @@ def inference_forward( "l2_to_target": [], "cosine_to_target": [], "c_skip": [], + "d_cur_norm": [], + "d_cur_step_norm": [], + "residual_std": [], "x": [x.cpu()], } @@ -371,7 +374,7 @@ def inference_forward( denoised = self.denoise(x=x_hat, c=c, sigma=t_hat, fstep=fstep, coords=coords) d_cur = (x_hat - denoised) / t_hat x_next = x_hat + (t_next - t_hat) * d_cur - + # Apply 2nd order correction. if i < num_steps - 1: denoised = self.denoise(x=x_next, c=c, sigma=t_next, fstep=fstep, coords=coords) @@ -385,6 +388,9 @@ def inference_forward( track["c_skip"].append(self.sigma_data**2 / (s**2 + self.sigma_data**2)) track["x_std"].append(x_next.std().item()) track["denoised_std"].append(denoised.std().item()) + track["d_cur_norm"].append(d_cur.norm().item()) + track["d_cur_step_norm"].append(((t_next - t_hat) * d_cur).norm().item()) + track["residual_std"].append((x_hat - denoised).std().item()) track["x"].append(x_next.cpu()) if self.cur_token is not None: track["l2_to_target"].append((x_next - self.cur_token).norm().item()) @@ -407,7 +413,7 @@ def _plot_sampling_diagnostics(self, track: dict, num_steps: int) -> None: steps = list(range(len(track["sigma"]))) has_target = len(track["l2_to_target"]) > 0 - n_plots = 3 + n_plots = 7 fig, axes = plt.subplots(n_plots, 1, figsize=(10, 3 * n_plots), sharex=True) @@ -442,6 +448,33 @@ def _plot_sampling_diagnostics(self, track: dict, num_steps: int) -> None: axes[2].set_ylabel("L2 error to target") axes[2].grid(True, alpha=0.3) + # 4) d_cur norm and step norm + axes[3].semilogy(steps, track["d_cur_norm"], "o-", markersize=3, label="||d_cur||") + axes[3].semilogy(steps, track["d_cur_step_norm"], "^-", markersize=3, label="||(t_next - t_hat) * d_cur||") + axes[3].set_ylabel("norm (log scale)") + axes[3].set_title("ODE drift norms") + axes[3].legend(fontsize=8) + axes[3].grid(True, alpha=0.3) + + # 5) Residual std: Std(x_hat - denoised) + axes[4].semilogy(steps, track["residual_std"], "s-", markersize=3, color="tab:orange") + axes[4].set_ylabel("std (log scale)") + axes[4].set_title("Std(x_hat - denoised)") + axes[4].grid(True, alpha=0.3) + + # 6) Residual std zoomed to [0, 1] + axes[5].plot(steps, track["residual_std"], "s-", markersize=3, color="tab:orange") + axes[5].set_ylim(0, 1) + axes[5].set_ylabel("std (clipped to 1)") + axes[5].set_title("Std(x_hat - denoised) [y ≤ 1]") + axes[5].grid(True, alpha=0.3) + + # 7) Std of x_next over sampling steps + axes[6].semilogy(steps, track["x_std"], "o-", markersize=3, color="tab:blue") + axes[6].set_ylabel("std (log scale)") + axes[6].set_title("Std of x_next over denoising steps") + axes[6].grid(True, alpha=0.3) + axes[-1].set_xlabel("sampling step") fig.tight_layout() diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 20590e5bf..537c93983 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -758,6 +758,7 @@ def forward(self, model_params: ModelParams, batch: ModelBatch) -> ModelOutput: step, meta_info=batch.samples[0].meta_info, coords=model_params.rope_coords, + # num_steps=30 ) # Diffusion inference returns the per-ODE-step intermediate denoised tokens as a