Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 31 additions & 3 deletions src/lmms_engine/datasets/iterable/multimodal_iterable_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ def __init__(self, config) -> None:
self.storage_client = BlobServiceClient(account_url=SAS_URL, retry_policy=RETRY_POLICY)
self.bucket_name = self.config.bucket_name
self.cur_idx = 0
# Set to True by load_state_dict; consumed (and cleared) by __iter__ so
# that the next iteration pass continues from the saved cur_idx instead
# of being reset to 0.
self._resuming = False
if not dist.is_initialized():
logger.info("Distributed environment not initialized, setting rank and world size to 0 and 1")
self.rank = 0
Expand Down Expand Up @@ -166,8 +170,12 @@ def __iter__(self):
curr_data_folder = curr_data_folder[iter_start:iter_end]

if self.config.packing:
# Reset index at the start of each iteration pass
self.cur_idx = 0
# Reset index only when starting a fresh pass. When resuming from a
# checkpoint via StatefulDataLoader, load_state_dict() has set
# self.cur_idx and self._resuming=True, so we keep cur_idx as-is.
if not self._resuming:
self.cur_idx = 0
self._resuming = False
packing_length = self.config.packing_length
packing_kwargs = self.config.extra_kwargs.get("packing_kwargs", {})
packer = build_online_packer(
Expand Down Expand Up @@ -204,7 +212,9 @@ def __iter__(self):
# Flush remaining buffer
yield from packer.flush()
else:
self.cur_idx = 0
if not self._resuming:
self.cur_idx = 0
self._resuming = False
while self.cur_idx < len(curr_data_list):
try:
yield self.get_one_sample(self.cur_idx, curr_data_folder[self.cur_idx], curr_data_list)
Expand All @@ -215,6 +225,24 @@ def __iter__(self):
continue
self.cur_idx += 1

def state_dict(self):
"""Stateful protocol for torchdata.StatefulDataLoader.

Called by StatefulDataLoader inside each worker process. We only need
to persist the per-worker iteration cursor; rank/worker sharding is
deterministic given the data_seed + world_size + num_workers.
"""
return {"cur_idx": self.cur_idx}

def load_state_dict(self, state_dict):
"""Restore the per-worker iteration cursor saved by state_dict().

Sets ``_resuming=True`` so that the next ``__iter__`` pass does NOT
reset ``cur_idx`` to 0 and instead continues from the restored position.
"""
self.cur_idx = int(state_dict.get("cur_idx", 0))
self._resuming = True

def load_from_json(self, data, data_folder=None):
"""
Default implementation for loading from JSON format.
Expand Down
Loading