
fix(iterable-dataset): implement stateful protocol for resume#172

Merged
kcz358 merged 1 commit into main from fix/iterable-dataset-stateful
May 18, 2026

@kcz358 kcz358 commented May 18, 2026

Problem

MultiModalIterableDataset did not implement state_dict() / load_state_dict(), so StatefulDataLoader had no way to persist the per-worker iteration cursor. On resume, __iter__ unconditionally reset self.cur_idx = 0. Even though global_step and the dataloader-level fetch state were restored, the dataset re-walked its shard from the start, so the trainer effectively naive-forwarded a chunk of data before producing the first useful batch.

Fix

  • Add state_dict() / load_state_dict() that persist {cur_idx} (called per worker by StatefulDataLoader).
  • Guard the cur_idx = 0 resets in both __iter__ branches (packing / non-packing) behind a _resuming flag that load_state_dict sets and __iter__ consumes once.
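The two changes above can be illustrated with a minimal sketch. It is not the code from this PR: ResumableIterable is a hypothetical stand-in that omits packing and sharding, and the real dataset would subclass torch.utils.data.IterableDataset. It only shows how a _resuming flag set by load_state_dict can make __iter__ skip the cur_idx reset exactly once.

```python
from typing import Any, Dict, Iterator, List


class ResumableIterable:
    """Toy stand-in (hypothetical) for an iterable dataset with a resumable cursor."""

    def __init__(self, samples: List[Any]) -> None:
        self.samples = samples
        self.cur_idx = 0
        self._resuming = False  # set by load_state_dict, consumed once by __iter__

    def state_dict(self) -> Dict[str, int]:
        # Only the cursor is persisted.
        return {"cur_idx": self.cur_idx}

    def load_state_dict(self, state: Dict[str, int]) -> None:
        self.cur_idx = state["cur_idx"]
        self._resuming = True

    def __iter__(self) -> Iterator[Any]:
        if self._resuming:
            self._resuming = False  # keep the restored cursor for this pass only
        else:
            self.cur_idx = 0  # fresh pass: start from the beginning
        while self.cur_idx < len(self.samples):
            sample = self.samples[self.cur_idx]
            self.cur_idx += 1  # advance first so a saved cursor points at the next sample
            yield sample


# Consume three samples, checkpoint, then resume in a new instance.
ds = ResumableIterable(list(range(10)))
it = iter(ds)
first = [next(it) for _ in range(3)]
saved = ds.state_dict()

resumed = ResumableIterable(list(range(10)))
resumed.load_state_dict(saved)
rest = list(resumed)        # continues from the saved cursor
next_epoch = list(resumed)  # flag already consumed, so this pass restarts at 0
```

Consuming the flag inside __iter__ means only the first pass after a restore keeps the saved cursor; later passes start from 0 as before.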

Per-worker sharding is deterministic given data_seed + world_size + num_workers, so persisting cur_idx alone is sufficient as long as those don't change across resume.
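To see why the cursor alone can be enough, here is a hedged sketch of one possible deterministic sharding scheme. The PR does not show the repository's actual sharding logic, so shard_indices and its seeded-shuffle-plus-stride layout are assumptions made for illustration only.

```python
import random
from typing import List


def shard_indices(num_samples: int, data_seed: int, world_size: int,
                  rank: int, num_workers: int, worker_id: int) -> List[int]:
    """Hypothetical scheme: seeded shuffle, then a stride over all (rank, worker) pairs."""
    order = list(range(num_samples))
    random.Random(data_seed).shuffle(order)  # same seed gives the same order every run
    num_shards = world_size * num_workers
    shard_id = rank * num_workers + worker_id
    return order[shard_id::num_shards]


# Same inputs before and after a restart give the same shard,
# so a saved cur_idx points at the same sample.
before = shard_indices(32, data_seed=0, world_size=4, rank=0, num_workers=2, worker_id=1)
after = shard_indices(32, data_seed=0, world_size=4, rank=0, num_workers=2, worker_id=1)

# Changing num_workers changes the shard, so the saved cursor no longer lines up.
changed = shard_indices(32, data_seed=0, world_size=4, rank=0, num_workers=4, worker_id=1)

# Every sample lands in exactly one shard.
all_shards = [
    i
    for r in range(4)
    for w in range(2)
    for i in shard_indices(32, 0, 4, r, 2, w)
]
```

Under this assumed scheme, a change to data_seed, world_size, or num_workers between save and resume would make the restored cur_idx index into a different sequence.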

Verification

Two-phase test with scripts/launch/qwen3_vl_test.sh on 4 GPUs / 2 workers per rank, packing enabled, save_steps=2:

| Phase | Command | rank0 worker_0 cur_idx | rank0 worker_1 cur_idx |
| --- | --- | --- | --- |
| 1. Fresh train 2 steps | max_steps=2 save_steps=2 | 43 | 24 |
| 2. Resume + 2 more steps | max_steps=4 save_steps=2 | 68 (+25) | 66 (+42) |

Values were read directly from checkpoint-*/dataloader_state/*.pt, under _snapshot._worker_snapshots[*].dataset_state. Workers continue from the saved position instead of restarting at 0.
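A small helper along these lines could pull the cursors out of a loaded dataloader state. This is a sketch under assumptions: the nested layout follows the key path quoted above, the worker key names (worker_0, worker_1) are assumed, and the dictionaries below are built by hand from the table values rather than loaded from a real checkpoint file (which would normally go through torch.load).

```python
from typing import Any, Dict


def worker_cursors(loader_state: Dict[str, Any]) -> Dict[str, int]:
    """Return {worker_name: cur_idx} from a dataloader state dict (assumed layout)."""
    snapshots = loader_state["_snapshot"]["_worker_snapshots"]
    return {
        name: snap["dataset_state"]["cur_idx"]
        for name, snap in sorted(snapshots.items())
    }


def make_state(idx0: int, idx1: int) -> Dict[str, Any]:
    # Hand-built stand-in for a loaded checkpoint; worker key names are an assumption.
    return {
        "_snapshot": {
            "_worker_snapshots": {
                "worker_0": {"dataset_state": {"cur_idx": idx0}},
                "worker_1": {"dataset_state": {"cur_idx": idx1}},
            }
        }
    }


step2 = worker_cursors(make_state(43, 24))  # phase 1 values from the table
step4 = worker_cursors(make_state(68, 66))  # phase 2 values from the table

# Progress made by each worker between the two checkpoints.
progress = {name: step4[name] - step2[name] for name in step2}
```

A positive delta for every worker is the signal that the resumed run continued from the saved cursor rather than from 0.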

Notes

  • Packer internal buffer is not persisted, so the very first packed batch after resume may differ slightly from a non-interrupted run. If exact reproducibility is needed, a follow-up can add state_dict to the online packer too.
  • No trainer-side changes needed.
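The first note can be made concrete with a toy example. BufferedPacker below is hypothetical and is not the repository's online packer; it greedily groups sample lengths into packs of at most max_len. The sketch shows how a resume with only cur_idx restored can change the first pack, and how a packer-level state_dict (the possible follow-up) would avoid that.

```python
from typing import Any, Dict, List, Optional


class BufferedPacker:
    """Hypothetical greedy packer over sample lengths."""

    def __init__(self, max_len: int) -> None:
        self.max_len = max_len
        self.buffer: List[int] = []

    def add(self, n: int) -> Optional[List[int]]:
        # Emit the current buffer when the next sample would overflow it.
        if self.buffer and sum(self.buffer) + n > self.max_len:
            out, self.buffer = self.buffer, [n]
            return out
        self.buffer.append(n)
        return None

    def state_dict(self) -> Dict[str, Any]:
        return {"buffer": list(self.buffer)}

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        self.buffer = list(state["buffer"])


def run(packer: BufferedPacker, lengths: List[int]) -> List[List[int]]:
    packs = []
    for n in lengths:
        out = packer.add(n)
        if out is not None:
            packs.append(out)
    return packs


lengths = [4, 3, 5, 2, 6, 1]
uninterrupted = run(BufferedPacker(8), lengths)

# Interrupt after 4 samples (cur_idx == 4); the buffer still holds [5, 2].
first_half = BufferedPacker(8)
packs_before = run(first_half, lengths[:4])
saved = first_half.state_dict()

# Resume with only the cursor restored: the buffered samples are dropped.
packs_cursor_only = run(BufferedPacker(8), lengths[4:])

# Resume with the packer buffer restored as well.
restored = BufferedPacker(8)
restored.load_state_dict(saved)
packs_with_buffer = run(restored, lengths[4:])
```

In this toy setup the cursor-only resume loses the two buffered samples, while restoring the buffer reproduces the uninterrupted packs.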

MultiModalIterableDataset did not implement state_dict / load_state_dict,
so StatefulDataLoader could not persist the per-worker iteration cursor.
On resume, __iter__ unconditionally reset cur_idx=0 and the trainer
naive-forwarded from the start of the data even though global_step was
restored, wasting compute.

- Add state_dict() / load_state_dict() returning {cur_idx}.
- Guard the cur_idx=0 resets in __iter__ behind a _resuming flag that
  load_state_dict sets and __iter__ consumes once.

Verified by checkpointing at step 2 (workers at cur_idx=43/24) and
resuming for 2 more steps: workers continue to 68/66 instead of
restarting from 0.
@kcz358 kcz358 merged commit 4ed3da8 into main May 18, 2026
3 checks passed
@kcz358 kcz358 deleted the fix/iterable-dataset-stateful branch May 18, 2026 11:28