Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit d1303cc

Browse files
rahul-tuli authored and bfineran committed
Refactor: torchvision save checkpoint code, to include the checkpoint (#1334)
Refactor: Model creation code now tries `ModelRegistry.create` even when `arch_key` is `None`. This attempts to read the `arch_key` from the model checkpoint; if the key is not present, the relevant function still raises the respective errors.
1 parent f7cc6a6 commit d1303cc

File tree

1 file changed

+7
-3
lines changed
  • src/sparseml/pytorch/torchvision

1 file changed

+7
-3
lines changed

src/sparseml/pytorch/torchvision/train.py

Lines changed: 7 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -368,7 +368,7 @@ def collate_fn(batch):
368368

369369
_LOGGER.info("Creating model")
370370
local_rank = args.rank if args.distributed else None
371-
model = _create_model(
371+
model, arch_key = _create_model(
372372
arch_key=args.arch_key,
373373
local_rank=local_rank,
374374
pretrained=args.pretrained,
@@ -652,6 +652,7 @@ def log_metrics(tag: str, metrics: utils.MetricLogger, epoch: int, epoch_step: i
652652
"state_dict": model_without_ddp.state_dict(),
653653
"optimizer": optimizer.state_dict(),
654654
"args": args,
655+
"arch_key": arch_key,
655656
}
656657
if lr_scheduler:
657658
checkpoint["lr_scheduler"] = lr_scheduler.state_dict()
@@ -703,7 +704,7 @@ def _create_model(
703704
device=None,
704705
num_classes=None,
705706
):
706-
if arch_key in ModelRegistry.available_keys():
707+
if not arch_key or arch_key in ModelRegistry.available_keys():
707708
with torch_distributed_zero_first(local_rank):
708709
model = ModelRegistry.create(
709710
key=arch_key,
@@ -712,6 +713,9 @@ def _create_model(
712713
pretrained_dataset=pretrained_dataset,
713714
num_classes=num_classes,
714715
)
716+
717+
if isinstance(model, tuple):
718+
model, arch_key = model
715719
elif arch_key in torchvision.models.__dict__:
716720
# fall back to torchvision
717721
model = torchvision.models.__dict__[arch_key](
@@ -724,7 +728,7 @@ def _create_model(
724728
f"Unable to find {arch_key} in ModelRegistry or in torchvision.models"
725729
)
726730
model.to(device)
727-
return model
731+
return model, arch_key
728732

729733

730734
def _get_lr_scheduler(args, optimizer, checkpoint=None, manager=None):

0 commit comments

Comments
 (0)