
vijk777 commented Jan 22, 2026

We add two modes:

  • time_aligned: data available at 0, tu, 2tu, ...
  • staggered_random: each neuron is still sampled at an interval of tu, as above, but with a random per-neuron acquisition offset (phase) drawn from 0, 1, ..., tu-1. This matches the zapbench data. (A minimal index-generation sketch follows the diagrams below.)
  time_aligned mode (tu=5)
  ============================
  All neurons observed simultaneously at regular intervals

  Time:     0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20
           ┌──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┐
  Neuron 1 │ X│  │  │  │  │ X│  │  │  │  │ X│  │  │  │  │ X│  │  │  │  │ X│
  Neuron 2 │ X│  │  │  │  │ X│  │  │  │  │ X│  │  │  │  │ X│  │  │  │  │ X│
  Neuron 3 │ X│  │  │  │  │ X│  │  │  │  │ X│  │  │  │  │ X│  │  │  │  │ X│
  Neuron 4 │ X│  │  │  │  │ X│  │  │  │  │ X│  │  │  │  │ X│  │  │  │  │ X│
  Neuron 5 │ X│  │  │  │  │ X│  │  │  │  │ X│  │  │  │  │ X│  │  │  │  │ X│
           └──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┘
           ↑              ↑              ↑              ↑              ↑
         t=0            t=5           t=10           t=15           t=20


  staggered_random mode (tu=5, different phases per neuron)
  ==========================================================
  Each neuron observed at different phase + k*tu

  Time:     0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20
           ┌──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┐
  Neuron 1 │ X│  │  │  │  │ X│  │  │  │  │ X│  │  │  │  │ X│  │  │  │  │ X│  phase=0
  Neuron 2 │  │ X│  │  │  │  │ X│  │  │  │  │ X│  │  │  │  │ X│  │  │  │  │  phase=1
  Neuron 3 │  │  │  │ X│  │  │  │  │ X│  │  │  │  │ X│  │  │  │  │ X│  │  │  phase=3
  Neuron 4 │  │  │  │  │ X│  │  │  │  │ X│  │  │  │  │ X│  │  │  │  │ X│  │  phase=4
  Neuron 5 │  │  │ X│  │  │  │  │ X│  │  │  │  │ X│  │  │  │  │ X│  │  │  │  phase=2
           └──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┘
           ↑  ↑  ↑  ↑  ↑                                                
        phases 0-4 spread across first tu timesteps
        (then repeat every tu steps)
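
For concreteness, here is a minimal sketch of how the per-neuron observation indices can be generated for both modes. The function name, the signature, and the truncation to full periods are illustrative assumptions, not the actual acquisition.py API:

```python
import numpy as np

def sample_observation_indices(mode, n_neurons, n_timesteps, time_units, seed=0):
    """Sketch: return a (n_neurons, n_obs) array of observation indices."""
    rng = np.random.default_rng(seed)
    # Keep only full periods so that phase + k*time_units always stays in range.
    base = np.arange(0, n_timesteps - time_units + 1, time_units)
    if mode == "time_aligned":
        # All neurons observed simultaneously at 0, tu, 2*tu, ...
        return np.tile(base, (n_neurons, 1))
    if mode == "staggered_random":
        # Fixed random per-neuron phase in {0, ..., tu - 1}, then
        # observations at phase, phase + tu, phase + 2*tu, ...
        phases = rng.integers(0, time_units, size=n_neurons)
        return base[None, :] + phases[:, None]
    raise ValueError(f"unknown acquisition mode: {mode!r}")
```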

Deprecate intermediate loss points; the loss is now computed only at multiples of time_units.

Intermediate results are quite promising for time-aligned acquisition (middle column), while naively applying the current model to staggered acquisition fails (right column). The left column is the baseline using all the data; results are shown at epoch 20/100. Errors are ~1e-2, roughly 100x better than the constant baseline (also shown).
[Figure: reconstruction error at epoch 20/100. Left: baseline using all data; middle: time_aligned; right: staggered_random; constant baseline shown for reference.]

This is different from what we saw on the old DAVIS data, where time-aligned acquisition resulted in poor generalization and divergence.

vijk777 and others added 11 commits January 22, 2026 00:31
- mark intermediate_loss_steps as deprecated with validator
- remove intermediate loss computation logic from train_step
- loss now only computed at time_units multiples
- simplifies training loop
- add acquisition.py module with acquisition mode types (preparation for future work)
- add time_units parameter to RandomChunkLoader
- chunk starts are now aligned to time_units multiples
- get_next_chunk now returns (chunk_start, chunk_data, chunk_stim); see the sketch after this list
- update all call sites in latent.py
- update all tests for new 3-tuple return signature
- add acquisition_mode config field to TrainingConfig
- verifies chunk starts are multiples of time_units
- ensures randomness is preserved with alignment constraint
- test_acquisition_mode_batch_sampling_bounds: verifies observation indices stay within chunk bounds for both time_aligned and staggered_random modes, includes np.diff check for time_units spacing
- test_staggered_complete_coverage: validates phase distribution
- test_acquisition_with_chunk_boundaries: stress tests boundary conditions
- Move compute_neuron_phases() call before reconstruction warmup
- Remove duplicate neuron_phases computation after warmup
- Remove unused sample_batch_within_chunk import
- Ensures neuron_phases is available for both warmup and main training
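
A minimal sketch of the alignment constraint and the new 3-tuple contract. Only the `(chunk_start, chunk_data, chunk_stim)` return value and the `time_units` alignment come from the commits above; everything else (fields, constructor signature) is assumed:

```python
import numpy as np

class RandomChunkLoader:
    """Sketch: random chunks whose start indices are multiples of time_units."""

    def __init__(self, data, stim, chunk_len, time_units, seed=0):
        self.data, self.stim = data, stim              # data: (T, N), stim: (T, S)
        self.chunk_len, self.time_units = chunk_len, time_units
        self.rng = np.random.default_rng(seed)

    def get_next_chunk(self):
        # Sample uniformly over the aligned starts (rather than snapping a
        # uniform start down), so every multiple of time_units is equally likely.
        max_start = self.data.shape[0] - self.chunk_len
        n_aligned = max_start // self.time_units + 1
        start = self.time_units * int(self.rng.integers(0, n_aligned))
        return (start,
                self.data[start:start + self.chunk_len],
                self.stim[start:start + self.chunk_len])
```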

Fix tensor indexing in train_step functions where observation_indices
is used. The old indexing `train_data[observation_indices]`, with
train_data of shape (T, N) and observation_indices of shape (b, N),
selects whole rows for every index and so creates intermediate tensors
of shape (b, N, N), causing OOM errors (~24 GB per indexing operation
with N=13741).

Use proper 2D advanced indexing with neuron_indices to get the correct
(b, N) output:
- train_data[observation_indices, neuron_indices]

This fixes OOM errors when running acquisition modes with realistic
neuron counts.
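
For reference, both the shape blow-up and the fix can be reproduced with plain NumPy advanced indexing (the semantics are the same for the tensors here; variable names follow the commit message):

```python
import numpy as np

b, T, N = 32, 1000, 13741
train_data = np.zeros((T, N), dtype=np.float32)
observation_indices = np.random.randint(0, T, size=(b, N))  # per-neuron times

# Bad: a (b, N) integer index on axis 0 selects whole rows -> (b, N, N).
# With b = 32 and N = 13741 that is ~24 GB of float32, hence the OOM.
# bloated = train_data[observation_indices]

# Good: pair each time index with its neuron index -> (b, N).
neuron_indices = np.arange(N)            # (N,) broadcasts against (b, N)
batch = train_data[observation_indices, neuron_indices]
assert batch.shape == (b, N)
```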

Add test assertion to verify that all observation indices in time_aligned
mode are divisible by time_units, not just that spacing is correct.

This ensures time_aligned mode correctly enforces observations at
0, tu, 2*tu, 3*tu, ... as specified.
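
Reusing the hypothetical `sample_observation_indices` sketch from above, the assertion amounts to:

```python
time_units = 5
obs = sample_observation_indices("time_aligned", n_neurons=4,
                                 n_timesteps=40, time_units=time_units)
assert np.all(obs % time_units == 0)  # every index is 0, tu, 2*tu, ...
```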

Add test assertion to verify that for each neuron in staggered_random
mode, when you subtract its phase from its observation indices, the
result is divisible by time_units.

This ensures staggered_random mode correctly enforces observations at
phase_n + 0*tu, phase_n + 1*tu, phase_n + 2*tu, ... for each neuron n.
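
The staggered companion check, again against the sketch above (each neuron's phase is recovered from its first observation):

```python
obs = sample_observation_indices("staggered_random", n_neurons=4,
                                 n_timesteps=40, time_units=time_units, seed=1)
phases = obs[:, 0] % time_units                  # first obs sits at the phase
assert np.all((obs - phases[:, None]) % time_units == 0)
```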

vijk777 merged commit 5521cad into main Jan 22, 2026
4 checks passed
vijk777 deleted the vj/acquisition branch January 22, 2026 13:41