Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit 5131c84

Browse files
Extend arguments of ImageNet (#1040)
* Add resize_scale and resize_mode arguments to allow for pre-processing flow of EfficientNets * Update src/sparseml/pytorch/datasets/classification/imagenet.py Co-authored-by: Rahul Tuli <rahul@neuralmagic.com> * Inserted import * Added type check to eval_mode Co-authored-by: Rahul Tuli <rahul@neuralmagic.com>
1 parent 3c042f3 commit 5131c84

File tree

1 file changed

+13
-3
lines changed
  • src/sparseml/pytorch/datasets/classification

1 file changed

+13
-3
lines changed

src/sparseml/pytorch/datasets/classification/imagenet.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@
3131
ImageFolder = object # default for constructor
3232
torchvision_import_error = torchvision_error
3333

34+
from typing import Union
35+
3436
from sparseml.pytorch.datasets.image_classification.ffcv_dataset import (
3537
FFCVImageNetDataset,
3638
)
@@ -72,20 +74,28 @@ def __init__(
7274
train: bool = True,
7375
rand_trans: bool = False,
7476
image_size: int = 224,
77+
resize_scale: float = 1.143,
78+
resize_mode: Union[str, "transforms.InterpolationMode"] = "bilinear",
7579
):
7680
if torchvision_import_error is not None:
7781
raise torchvision_import_error
7882

7983
root = clean_path(root)
80-
non_rand_resize_scale = 256.0 / 224.0 # standard used
84+
if type(resize_mode) is str and resize_mode.lower() in ["linear", "bilinear"]:
85+
interpolation = transforms.InterpolationMode.BILINEAR
86+
elif type(resize_mode) is str and resize_mode.lower() in ["cubic", "bicubic"]:
87+
interpolation = transforms.InterpolationMode.BICUBIC
88+
8189
init_trans = (
8290
[
83-
transforms.RandomResizedCrop(image_size),
91+
transforms.RandomResizedCrop(image_size, interpolation=interpolation),
8492
transforms.RandomHorizontalFlip(),
8593
]
8694
if rand_trans
8795
else [
88-
transforms.Resize(round(non_rand_resize_scale * image_size)),
96+
transforms.Resize(
97+
round(resize_scale * image_size), interpolation=interpolation
98+
),
8999
transforms.CenterCrop(image_size),
90100
]
91101
)

0 commit comments

Comments
 (0)