Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit c8aa89a

Browse files
authored
[cherry-pick] release fixes 6/7/22 (#852)
* bug fix for image classificaiton eval whith no recipe provided (#850) * add click as pytorch dep (#851) for running image_classification scripts
1 parent 8cf5aa9 commit c8aa89a

File tree

3 files changed

+22
-13
lines changed

3 files changed

+22
-13
lines changed

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262
"tensorboard>=1.0",
6363
"tensorboardX>=1.0",
6464
"gputils",
65+
"click<8.1",
6566
]
6667
_pytorch_vision_deps = _pytorch_deps + ["torchvision>=0.3.0,<=0.10.1"]
6768
_tensorflow_v1_deps = ["tensorflow<2.0.0", "tensorboard<2.0.0", "tf2onnx>=1.0.0,<1.6"]

src/sparseml/pytorch/image_classification/train.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -545,18 +545,22 @@ def main(
545545
train_batch_size = train_batch_size // world_size
546546
helpers.set_seeds(local_rank=local_rank)
547547

548-
train_dataset, train_loader, = helpers.get_dataset_and_dataloader(
549-
dataset_name=dataset,
550-
dataset_path=dataset_path,
551-
batch_size=train_batch_size,
552-
image_size=image_size,
553-
dataset_kwargs=dataset_kwargs,
554-
training=True,
555-
loader_num_workers=loader_num_workers,
556-
loader_pin_memory=loader_pin_memory,
557-
ffcv=ffcv,
558-
device=device,
559-
)
548+
if not eval_mode:
549+
train_dataset, train_loader, = helpers.get_dataset_and_dataloader(
550+
dataset_name=dataset,
551+
dataset_path=dataset_path,
552+
batch_size=train_batch_size,
553+
image_size=image_size,
554+
dataset_kwargs=dataset_kwargs,
555+
training=True,
556+
loader_num_workers=loader_num_workers,
557+
loader_pin_memory=loader_pin_memory,
558+
ffcv=ffcv,
559+
device=device,
560+
)
561+
else:
562+
train_dataset = None
563+
train_loader = None
560564

561565
val_dataset, val_loader = (
562566
helpers.get_dataset_and_dataloader(

src/sparseml/pytorch/image_classification/utils/trainer.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,11 @@ def run_one_epoch(
195195
if not (train_mode or validation_mode):
196196
raise ValueError(f"Invalid train mode '{mode}', must be 'train' or 'val'")
197197

198-
if torch.__version__ < "1.9" and self.manager.qat_active(epoch=self.epoch):
198+
if (
199+
torch.__version__ < "1.9"
200+
and self.manager
201+
and (self.manager.qat_active(epoch=self.epoch))
202+
):
199203
# switch off fp16
200204
self._device_context.use_mixed_precision = False
201205

0 commit comments

Comments
 (0)