Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 35 additions & 2 deletions src/weathergen/model/diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,9 @@ def inference_forward(
"l2_to_target": [],
"cosine_to_target": [],
"c_skip": [],
"d_cur_norm": [],
"d_cur_step_norm": [],
"residual_std": [],
"x": [x.cpu()],
}

Expand Down Expand Up @@ -371,7 +374,7 @@ def inference_forward(
denoised = self.denoise(x=x_hat, c=c, sigma=t_hat, fstep=fstep, coords=coords)
d_cur = (x_hat - denoised) / t_hat
x_next = x_hat + (t_next - t_hat) * d_cur

# Apply 2nd order correction.
if i < num_steps - 1:
denoised = self.denoise(x=x_next, c=c, sigma=t_next, fstep=fstep, coords=coords)
Expand All @@ -385,6 +388,9 @@ def inference_forward(
track["c_skip"].append(self.sigma_data**2 / (s**2 + self.sigma_data**2))
track["x_std"].append(x_next.std().item())
track["denoised_std"].append(denoised.std().item())
track["d_cur_norm"].append(d_cur.norm().item())
track["d_cur_step_norm"].append(((t_next - t_hat) * d_cur).norm().item())
track["residual_std"].append((x_hat - denoised).std().item())
track["x"].append(x_next.cpu())
if self.cur_token is not None:
track["l2_to_target"].append((x_next - self.cur_token).norm().item())
Expand All @@ -407,7 +413,7 @@ def _plot_sampling_diagnostics(self, track: dict, num_steps: int) -> None:

steps = list(range(len(track["sigma"])))
has_target = len(track["l2_to_target"]) > 0
n_plots = 3
n_plots = 7

fig, axes = plt.subplots(n_plots, 1, figsize=(10, 3 * n_plots), sharex=True)

Expand Down Expand Up @@ -442,6 +448,33 @@ def _plot_sampling_diagnostics(self, track: dict, num_steps: int) -> None:
axes[2].set_ylabel("L2 error to target")
axes[2].grid(True, alpha=0.3)

# 4) d_cur norm and step norm
axes[3].semilogy(steps, track["d_cur_norm"], "o-", markersize=3, label="||d_cur||")
axes[3].semilogy(steps, track["d_cur_step_norm"], "^-", markersize=3, label="||(t_next - t_hat) * d_cur||")
axes[3].set_ylabel("norm (log scale)")
axes[3].set_title("ODE drift norms")
axes[3].legend(fontsize=8)
axes[3].grid(True, alpha=0.3)

# 5) Residual std: Std(x_hat - denoised)
axes[4].semilogy(steps, track["residual_std"], "s-", markersize=3, color="tab:orange")
axes[4].set_ylabel("std (log scale)")
axes[4].set_title("Std(x_hat - denoised)")
axes[4].grid(True, alpha=0.3)

# 6) Residual std zoomed to [0, 1]
axes[5].plot(steps, track["residual_std"], "s-", markersize=3, color="tab:orange")
axes[5].set_ylim(0, 1)
axes[5].set_ylabel("std (clipped to 1)")
axes[5].set_title("Std(x_hat - denoised) [y ≤ 1]")
axes[5].grid(True, alpha=0.3)

# 7) Std of x_next over sampling steps
axes[6].semilogy(steps, track["x_std"], "o-", markersize=3, color="tab:blue")
axes[6].set_ylabel("std (log scale)")
axes[6].set_title("Std of x_next over denoising steps")
axes[6].grid(True, alpha=0.3)

axes[-1].set_xlabel("sampling step")
fig.tight_layout()

Expand Down
1 change: 1 addition & 0 deletions src/weathergen/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -758,6 +758,7 @@ def forward(self, model_params: ModelParams, batch: ModelBatch) -> ModelOutput:
step,
meta_info=batch.samples[0].meta_info,
coords=model_params.rope_coords,
# num_steps=30
)

# Diffusion inference returns the per-ODE-step intermediate denoised tokens as a
Expand Down
Loading