Patch notes (text before the first "diff --git" header is skipped by git-apply/patch).

* src/litdata/streaming/dataset.py, __getitem__:
  Stop running the transform list as soon as a transform returns None, and in
  that case return self.__next__() instead of the item. The fallback is keyed
  on `item is None` rather than truthiness, so valid falsy items (0, "", empty
  containers) are still returned and array-like items are never evaluated in a
  boolean context.

* src/litdata/streaming/downloader.py, download_file:
  Accept "<repo_name>@<revision>" in the repository segment of the remote path
  and forward the revision to hf_hub_download. The segment is split at the
  first "@" only; when no revision is given, revision=None is passed.

NOTE(review): line breaks, indentation and blank context lines were
reconstructed from a whitespace-collapsed copy of this patch -- confirm against
the target files before applying. The post-image blob ids on the "index" lines
were not recomputed.

diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py
index ced49b97..7fe94afc 100644
--- a/src/litdata/streaming/dataset.py
+++ b/src/litdata/streaming/dataset.py
@@ -464,10 +464,13 @@ def __getitem__(self, index: ChunkedIndex | int | slice) -> Any:
             if isinstance(self.transform, list):
                 for transform_fn in self.transform:
                     item = transform_fn(item)
+                    if item is None:
+                        break
             else:
                 item = self.transform(item)

-        return item
+        # Only None means "no item": falsy values (0, "", {}) and array-likes are returned as-is.
+        return item if item is not None else self.__next__()

     def __next__(self) -> Any:
         # check if we have reached the end of the dataset (i.e., all the chunks have been processed)
diff --git a/src/litdata/streaming/downloader.py b/src/litdata/streaming/downloader.py
index 323b7626..1e8dec4b 100644
--- a/src/litdata/streaming/downloader.py
+++ b/src/litdata/streaming/downloader.py
@@ -591,11 +591,14 @@ def download_file(self, remote_filepath: str, local_filepath: str) -> None:
             FileLock(local_filepath + ".lock", timeout=0),
             tempfile.TemporaryDirectory() as tmpdir,
         ):
-            _, _, _, repo_org, repo_name, path = remote_filepath.split("/", 5)
-            repo_id = f"{repo_org}/{repo_name}"
+            _, _, _, repo_org, repo_name_revision, path = remote_filepath.split("/", 5)
+            # The repo segment may carry a revision ("<repo_name>@<revision>"); split at the first "@" only.
+            repo_name, _, revision = repo_name_revision.partition("@")
+            repo_id = f"{repo_org}/{repo_name}"
             downloaded_path = hf_hub_download(
                 repo_id,
                 path,
+                revision=revision or None,
                 cache_dir=tmpdir,
                 repo_type="dataset",
                 **self._storage_options,