This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit be4a6a7

Torchvision eval fixes (#1212) (#1217)
* Changing checkpoint_recipe to recipe, and --recipe no longer required
* Changing recipe order for checkpoint path
* Fixing model loading
* Styling
1 parent 2099d33 commit be4a6a7
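
The central change is that checkpoints saved by train.py now store their sparsification recipe under the key "recipe" (previously "checkpoint_recipe"), and --recipe may be omitted when the checkpoint already embeds one. As an illustration only (not part of this diff), here is a minimal sketch of reading the renamed key back and rebuilding a manager from it; the checkpoint path is a placeholder, and the ScheduledModifierManager import path is assumed from the rest of sparseml rather than shown in this commit.

import torch

from sparseml.pytorch.optim import ScheduledModifierManager  # assumed import path

# Placeholder path to a checkpoint written by train.py after this change
ckpt = torch.load("checkpoint.pth", map_location="cpu")

# The embedded recipe now lives under "recipe" (previously "checkpoint_recipe");
# from_yaml accepts the serialized recipe string stored there.
manager = ScheduledModifierManager.from_yaml(ckpt["recipe"])

print("last completed epoch:", ckpt["epoch"])
print("recipe max epochs:", manager.max_epochs)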

File tree

1 file changed (+31, -26 lines)

  • src/sparseml/pytorch/torchvision


src/sparseml/pytorch/torchvision/train.py

Lines changed: 31 additions & 26 deletions
@@ -39,6 +39,7 @@
     download_framework_model_by_recipe_type,
     torch_distributed_zero_first,
 )
+from sparseml.pytorch.utils.model import load_model
 from sparsezoo import Model


@@ -332,6 +333,8 @@ def collate_fn(batch):
         model = torchvision.models.__dict__[args.arch_key](
             pretrained=args.pretrained, num_classes=num_classes
         )
+        if args.checkpoint_path is not None:
+            load_model(args.checkpoint_path, model, strict=True)
     else:
         raise ValueError(
             f"Unable to find {args.arch_key} in ModelRegistry or in torchvision.models"
@@ -418,35 +421,54 @@ def collate_fn(batch):
         checkpoint = _load_checkpoint(args.checkpoint_path)

         # restore state from prior recipe
-        manager = ScheduledModifierManager.from_yaml(args.recipe)
-        checkpoint_manager = ScheduledModifierManager.from_yaml(
-            checkpoint["checkpoint_recipe"]
+        manager = (
+            ScheduledModifierManager.from_yaml(args.recipe)
+            if args.recipe is not None
+            else None
         )
-        checkpoint_manager.apply_structure(model, epoch=checkpoint["epoch"])
+        checkpoint_manager = ScheduledModifierManager.from_yaml(checkpoint["recipe"])
     elif args.resume:
         checkpoint = _load_checkpoint(args.resume)

         # NOTE: override manager with the checkpoint's manager
-        manager = ScheduledModifierManager.from_yaml(checkpoint["checkpoint_recipe"])
+        manager = ScheduledModifierManager.from_yaml(checkpoint["recipe"])
         checkpoint_manager = None
         manager.initialize(model, epoch=checkpoint["epoch"])

         # NOTE: override start epoch
         args.start_epoch = checkpoint["epoch"] + 1
     else:
+        if args.recipe is None:
+            raise ValueError("Must specify --recipe if not loading from a checkpoint")
         checkpoint = None
         manager = ScheduledModifierManager.from_yaml(args.recipe)
         checkpoint_manager = None

     # load params
     if checkpoint is not None:
-        model.load_state_dict(checkpoint["state_dict"])
         optimizer.load_state_dict(checkpoint["optimizer"])
         if model_ema and "model_ema" in checkpoint:
             model_ema.load_state_dict(checkpoint["model_ema"])
         if scaler and "scaler" in checkpoint:
             scaler.load_state_dict(checkpoint["scaler"])

+    if args.test_only:
+        # We disable the cudnn benchmarking because it can
+        # noticeably affect the accuracy
+        torch.backends.cudnn.benchmark = False
+        torch.backends.cudnn.deterministic = True
+        if model_ema:
+            evaluate(
+                model_ema,
+                criterion,
+                data_loader_test,
+                device,
+                log_suffix="EMA",
+            )
+        else:
+            evaluate(model, criterion, data_loader_test, device)
+        return
+
     optimizer = manager.modify(model, optimizer, len(data_loader))

     if manager.learning_rate_modifiers:
@@ -503,23 +525,6 @@ def collate_fn(batch):
     if args.resume and checkpoint:
         lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])

-    if args.test_only:
-        # We disable the cudnn benchmarking because it can
-        # noticeably affect the accuracy
-        torch.backends.cudnn.benchmark = False
-        torch.backends.cudnn.deterministic = True
-        if model_ema:
-            evaluate(
-                model_ema,
-                criterion,
-                data_loader_test,
-                device,
-                log_suffix="EMA",
-            )
-        else:
-            evaluate(model, criterion, data_loader_test, device)
-        return
-
     model_without_ddp = model
     if args.distributed:
         model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
@@ -580,12 +585,12 @@ def collate_fn(batch):
                 if epoch == manager.max_epochs - 1
                 else epoch + checkpoint_manager.max_epochs
             )
-            checkpoint["checkpoint_recipe"] = str(
+            checkpoint["recipe"] = str(
                 ScheduledModifierManager.compose_staged(checkpoint_manager, manager)
            )
         else:
             checkpoint["epoch"] = -1 if epoch == manager.max_epochs - 1 else epoch
-            checkpoint["checkpoint_recipe"] = str(manager)
+            checkpoint["recipe"] = str(manager)

         file_names = ["checkpoint.pth"]
         if is_new_best:
@@ -658,7 +663,7 @@ def new_func(*args, **kwargs):
         allow_extra_args=True,
     )
 )
-@click.option("--recipe", required=True, type=str, help="Path to recipe")
+@click.option("--recipe", default=None, type=str, help="Path to recipe")
 @click.option("--dataset-path", required=True, type=str, help="dataset path")
 @click.option(
     "--arch-key",

0 commit comments
