Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit 0031883

Browse files
corey-nmbfineran
authored and committed
Torchvision scaler (#1312)
* Passing scaler into torchvision manager.modify * Properly wrapping optim/scaler for amp * Properly disabling scaler
1 parent 401af9f commit 0031883

File tree

1 file changed

+11
-3
lines changed
  • src/sparseml/pytorch/torchvision

1 file changed

+11
-3
lines changed

src/sparseml/pytorch/torchvision/train.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -574,9 +574,17 @@ def log_metrics(tag: str, metrics: utils.MetricLogger, epoch: int, epoch_step: i
574574
loggers=logger,
575575
distillation_teacher=args.distill_teacher,
576576
)
577-
optimizer = manager.modify(
578-
model, optimizer, steps_per_epoch=steps_per_epoch, epoch=args.start_epoch
577+
step_wrapper = manager.modify(
578+
model,
579+
optimizer,
580+
steps_per_epoch=steps_per_epoch,
581+
epoch=args.start_epoch,
582+
wrap_optim=scaler,
579583
)
584+
if scaler is None:
585+
optimizer = step_wrapper
586+
else:
587+
scaler = step_wrapper
580588

581589
lr_scheduler = _get_lr_scheduler(
582590
args, optimizer, checkpoint=checkpoint, manager=manager
@@ -597,7 +605,7 @@ def log_metrics(tag: str, metrics: utils.MetricLogger, epoch: int, epoch_step: i
597605
if args.distributed:
598606
train_sampler.set_epoch(epoch)
599607
if manager is not None and manager.qat_active(epoch=epoch):
600-
scaler = None
608+
scaler._enabled = False
601609
model_ema = None
602610

603611
train_metrics = train_one_epoch(

0 commit comments

Comments (0)