Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit 5660cea

Browse files
authored
Fix bugs with resume & checkpointing in torchvision (#1190) (#1191)
1 parent 219e668 commit 5660cea

File tree

1 file changed

+2
-2
lines changed
  • src/sparseml/pytorch/torchvision

1 file changed

+2
-2
lines changed

src/sparseml/pytorch/torchvision/train.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -440,7 +440,7 @@ def collate_fn(batch):
440440

441441
# load params
442442
if checkpoint is not None:
443-
model.load_state_dict(checkpoint["model"])
443+
model.load_state_dict(checkpoint["state_dict"])
444444
optimizer.load_state_dict(checkpoint["optimizer"])
445445
if model_ema and "model_ema" in checkpoint:
446446
model_ema.load_state_dict(checkpoint["model_ema"])
@@ -782,7 +782,7 @@ def new_func(*args, **kwargs):
782782
)
783783
@click.option("--print-freq", default=10, type=int, help="print frequency")
784784
@click.option("--output-dir", default=".", type=str, help="path to save outputs")
785-
@click.option("--resume", default="", type=str, help="path of checkpoint")
785+
@click.option("--resume", default=None, type=str, help="path of checkpoint")
786786
@click.option(
787787
"--checkpoint-path",
788788
default=None,

0 commit comments

Comments (0)