From 0b76a3f8c6483c4f20867c52e2702adffd44f0ff Mon Sep 17 00:00:00 2001 From: kcz358 Date: Mon, 18 May 2026 04:22:41 -0700 Subject: [PATCH] fix(iterable-dataset): implement stateful protocol for resume MultiModalIterableDataset did not implement state_dict / load_state_dict, so StatefulDataLoader could not persist the per-worker iteration cursor. On resume, __iter__ unconditionally reset cur_idx=0 and the trainer naive-forwarded from the start of the data even though global_step was restored, wasting compute. - Add state_dict() / load_state_dict() returning {cur_idx}. - Guard the cur_idx=0 resets in __iter__ behind a _resuming flag that load_state_dict sets and __iter__ consumes once. Verified by checkpointing at step 2 (workers at cur_idx=43/24) and resuming for 2 more steps: workers continue to 68/66 instead of restarting from 0. --- .../iterable/multimodal_iterable_dataset.py | 34 +++++++++++++++++-- 1 file changed, 31 insertions(+), 3 deletions(-) diff --git a/src/lmms_engine/datasets/iterable/multimodal_iterable_dataset.py b/src/lmms_engine/datasets/iterable/multimodal_iterable_dataset.py index 63e98408..48a4b516 100644 --- a/src/lmms_engine/datasets/iterable/multimodal_iterable_dataset.py +++ b/src/lmms_engine/datasets/iterable/multimodal_iterable_dataset.py @@ -51,6 +51,10 @@ def __init__(self, config) -> None: self.storage_client = BlobServiceClient(account_url=SAS_URL, retry_policy=RETRY_POLICY) self.bucket_name = self.config.bucket_name self.cur_idx = 0 + # Set to True by load_state_dict; consumed (and cleared) by __iter__ so + # that the next iteration pass continues from the saved cur_idx instead + # of being reset to 0. + self._resuming = False if not dist.is_initialized(): logger.info("Distributed environment not initialized, setting rank and world size to 0 and 1") self.rank = 0 @@ -166,8 +170,12 @@ def __iter__(self): curr_data_folder = curr_data_folder[iter_start:iter_end] if self.config.packing: - # Reset index at the start of each iteration pass - self.cur_idx = 0 + # Reset index only when starting a fresh pass. When resuming from a + # checkpoint via StatefulDataLoader, load_state_dict() has set + # self.cur_idx and self._resuming=True, so we keep cur_idx as-is. + if not self._resuming: + self.cur_idx = 0 + self._resuming = False packing_length = self.config.packing_length packing_kwargs = self.config.extra_kwargs.get("packing_kwargs", {}) packer = build_online_packer( @@ -204,7 +212,9 @@ def __iter__(self): # Flush remaining buffer yield from packer.flush() else: - self.cur_idx = 0 + if not self._resuming: + self.cur_idx = 0 + self._resuming = False while self.cur_idx < len(curr_data_list): try: yield self.get_one_sample(self.cur_idx, curr_data_folder[self.cur_idx], curr_data_list) @@ -215,6 +225,24 @@ def __iter__(self): continue self.cur_idx += 1 + def state_dict(self): + """Stateful protocol for torchdata.StatefulDataLoader. + + Called by StatefulDataLoader inside each worker process. We only need + to persist the per-worker iteration cursor; rank/worker sharding is + deterministic given the data_seed + world_size + num_workers. + """ + return {"cur_idx": self.cur_idx} + + def load_state_dict(self, state_dict): + """Restore the per-worker iteration cursor saved by state_dict(). + + Sets ``_resuming=True`` so that the next ``__iter__`` pass does NOT + reset ``cur_idx`` to 0 and instead continues from the restored position. + """ + self.cur_idx = int(state_dict.get("cur_idx", 0)) + self._resuming = True + def load_from_json(self, data, data_folder=None): """ Default implementation for loading from JSON format.