@vijk777 vijk777 commented Jan 22, 2026

  • fix post_run_analyze

  • add a new baseline: linear interpolation between observations

  • add a plot of the intervening MSE alongside the constant baseline
    (time_aligned_mse_latent)

  • add early stopping when the intervening MSE stops improving
    meaningfully (and the rollout has some stability)

vijk777 and others added 12 commits January 22, 2026 08:31
add early stopping criterion that monitors max intervening mse between
steps 0 and tu-1. stops training early if:
- first divergence > min_divergence (default 1000 steps)
- max intervening mse doesn't improve by 10% in patience epochs (default 10)

this prevents wasted training time when the model cannot learn to
predict intermediate steps between observation points.

config parameters:
- early_stop_intervening_mse: enable/disable (default false)
- early_stop_patience_epochs: patience for improvement (default 10)
- early_stop_min_divergence: min divergence threshold (default 1000)

new metric tracked in diagnostics:
- time_aligned_mse_{rollout_type}_max_intervening_0_to_tu: max mse in
  steps 0 to tu-2 (worst case over starts, avg over neurons)

early stopping checkpoint saved as checkpoint_early_stop.pt when triggered.
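The criterion described above can be sketched as a small stateful class. This is an illustrative reading of the commit message, not the repo's actual implementation; the class name, `step()` signature, and the 10% improvement threshold wiring are assumptions (the config keys and defaults are taken from the commit).

```python
class InterveningMSEEarlyStopper:
    """Sketch of the early-stop criterion from the commit message:
    stop when the max intervening MSE plateaus, but only once the
    rollout is stable (first divergence beyond min_divergence steps).
    Hypothetical helper, not the project's actual implementation."""

    def __init__(self, patience_epochs=10, min_divergence=1000,
                 improvement_frac=0.10):
        self.patience_epochs = patience_epochs    # epochs without improvement
        self.min_divergence = min_divergence      # min first-divergence step
        self.improvement_frac = improvement_frac  # require a 10% improvement
        self.best = float("inf")
        self.epochs_since_improvement = 0

    def step(self, max_intervening_mse, first_divergence_step):
        """Call once per epoch; returns True if training should stop."""
        # never stop while the rollout diverges before min_divergence steps
        if first_divergence_step <= self.min_divergence:
            return False
        if max_intervening_mse < self.best * (1 - self.improvement_frac):
            self.best = max_intervening_mse
            self.epochs_since_improvement = 0
        else:
            self.epochs_since_improvement += 1
        return self.epochs_since_improvement >= self.patience_epochs
```

With the defaults this only fires after the metric has failed to improve by 10% for 10 consecutive epochs on a stable rollout, which matches the intent of avoiding wasted training time.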

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
use load_val_only() instead of load_dataset() and load neuron_data
separately using load_metadata() + NeuronData.from_metadata().

fixes ValueError: not enough values to unpack (expected 7, got 5)

neuron_data is guaranteed to be the same across configs, so just
reuse the already-loaded neuron_data instead of loading it again
the alpha for linear interpolation should be relative to t_start and t_end,
not include start_idx. start_idx is the absolute time index while t_start
and t_end are relative time points (0, tu, 2*tu, etc).

fixes incorrect linear interpolation MSE (~1e4 instead of reasonable values)

…rvations

interpolate between observations at t_start and t_end using:
alpha = (actual_time - t_start) / (t_end - t_start)

this ensures MSE=0 at observation points (tu, 2*tu, 3*tu, ...)
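A minimal NumPy sketch of the corrected baseline, assuming observations sit at multiples of `tu` in the ground-truth array; the function name and array layout are illustrative, only the alpha formula comes from the commit:

```python
import numpy as np

def linear_interp_baseline(x_gt, tu):
    """Linearly interpolate between observations at 0, tu, 2*tu, ...

    alpha is computed relative to t_start and t_end (not the absolute
    start_idx), so the baseline exactly matches x_gt at observation
    points and its MSE there is 0.  Illustrative sketch only.
    """
    n_segments = (len(x_gt) - 1) // tu          # full segments available
    pred = np.asarray(x_gt, dtype=float).copy()
    for i in range(n_segments):
        t_start, t_end = i * tu, (i + 1) * tu
        for t in range(t_start, t_end + 1):
            alpha = (t - t_start) / (t_end - t_start)
            pred[t] = (1 - alpha) * x_gt[t_start] + alpha * x_gt[t_end]
    return pred[: n_segments * tu + 1]
```

Because alpha is 0 at `t_start` and 1 at `t_end`, the prediction coincides with the observations at every multiple of `tu`, which is the MSE=0 property the commit describes.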
fixes:
- use correct observation indices: x_gt[i*tu] instead of x_gt[i]
- only compute mse on predictions (first total_steps), not final observation
- convert to numpy array before accumulating

ensures mse is 0 at observation points (tu, 2*tu, 3*tu, ...)
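The fixed metric computation might look like the following sketch. Function names, array shapes, and the exact masking of observation points are assumptions (the commit gives only "avg over neurons", "first total_steps", and "steps 0 to tu-2"); one plausible reading of "intervening" is every step that is not a multiple of `tu`:

```python
import numpy as np

def time_aligned_mse(x_pred, x_gt):
    """Per-step MSE averaged over neurons.  Only the prediction steps
    (first total_steps) are scored, not the final observation.
    Hypothetical helper illustrating the commit's fix."""
    total_steps = x_pred.shape[0]
    err = np.asarray(x_pred, dtype=float) - np.asarray(x_gt[:total_steps],
                                                       dtype=float)
    # average over all non-time axes (neurons), keep the time axis
    return (err ** 2).mean(axis=tuple(range(1, err.ndim)))

def max_intervening_mse(per_step_mse, tu):
    """Worst-case MSE over the intervening steps, excluding the
    observation points at multiples of tu (where loss is applied)."""
    mask = np.arange(len(per_step_mse)) % tu != 0
    return per_step_mse[mask].max()
```

Converting to a NumPy array up front (as the commit does) avoids accumulating framework tensors across rollout starts.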
adds small dots (markersize=3) to the model mse line to show all
individual time steps, while keeping large markers (s=100) for
observation points where loss is applied

increase markersize from 3 to 5 and use circle markers for better visibility

use scatter with s=50 for all data points and s=100 for training points,
making it clearer which points are observation points vs training points

- add s=30 markers to linear interpolation baseline
- reduce model mse markers from s=50 to s=30
- remove separate legend entry for data points
- keep training points at s=100
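The final marker scheme from these plotting commits can be sketched with matplotlib; the function and variable names are illustrative, not from the repo, and only the `s=30` / `s=100` split is taken from the commits:

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # headless backend, e.g. for CI
import matplotlib.pyplot as plt

def plot_time_aligned_mse(steps, model_mse, interp_mse, tu):
    """Plot per-step MSE with small (s=30) markers on every step and
    large (s=100) markers on the training/observation points at
    multiples of tu.  All arguments are NumPy arrays; illustrative."""
    fig, ax = plt.subplots()
    ax.plot(steps, model_mse, label="model")
    ax.scatter(steps, model_mse, s=30)                  # every time step
    ax.plot(steps, interp_mse, "--", label="linear interp baseline")
    ax.scatter(steps, interp_mse, s=30)
    obs = steps % tu == 0                               # training points
    ax.scatter(steps[obs], model_mse[obs], s=100, zorder=3,
               label="training points")
    ax.set_xlabel("rollout step")
    ax.set_ylabel("time-aligned MSE")
    ax.legend()
    return fig
```

The single s=100 scatter replaces a separate "data points" legend entry, matching the final commit in the series.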

@vijk777 vijk777 merged commit 2feccc8 into main Jan 22, 2026
2 checks passed
@vijk777 vijk777 deleted the vj/mse_metrics branch January 22, 2026 17:39