39 | 39 | download_framework_model_by_recipe_type, |
40 | 40 | torch_distributed_zero_first, |
41 | 41 | ) |
| 42 | +from sparseml.pytorch.utils.model import load_model |
42 | 43 | from sparsezoo import Model |
43 | 44 | |
44 | 45 | |
@@ -332,6 +333,8 @@ def collate_fn(batch): |
332 | 333 | model = torchvision.models.__dict__[args.arch_key]( |
333 | 334 | pretrained=args.pretrained, num_classes=num_classes |
334 | 335 | ) |
| 336 | + if args.checkpoint_path is not None: |
| 337 | + load_model(args.checkpoint_path, model, strict=True) |
335 | 338 | else: |
336 | 339 | raise ValueError( |
337 | 340 | f"Unable to find {args.arch_key} in ModelRegistry or in torchvision.models" |
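The hunk above wires `--checkpoint-path` into the torchvision branch: when the architecture is resolved from `torchvision.models`, weights are now restored with sparseml's `load_model(..., strict=True)` before training starts. Below is a minimal sketch of the intended effect using plain PyTorch calls rather than the helper itself; the ResNet-50 stand-in, the checkpoint path, and the `"state_dict"` unwrapping are illustrative assumptions, not the helper's documented behavior.

```python
# Illustrative approximation of restoring torchvision weights from a local
# checkpoint; the real script delegates this to sparseml's load_model helper.
import torch
import torchvision

model = torchvision.models.resnet50(num_classes=1000)  # stand-in for args.arch_key

state = torch.load("checkpoint.pth", map_location="cpu")  # stand-in for args.checkpoint_path
# assumption: training checkpoints may wrap the weights under a "state_dict" key
state_dict = state["state_dict"] if "state_dict" in state else state
model.load_state_dict(state_dict, strict=True)  # strict=True mirrors the diff
```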
@@ -418,35 +421,54 @@ def collate_fn(batch): |
418 | 421 | checkpoint = _load_checkpoint(args.checkpoint_path) |
419 | 422 | |
420 | 423 | # restore state from prior recipe |
421 | | - manager = ScheduledModifierManager.from_yaml(args.recipe) |
422 | | - checkpoint_manager = ScheduledModifierManager.from_yaml( |
423 | | - checkpoint["checkpoint_recipe"] |
| 424 | + manager = ( |
| 425 | + ScheduledModifierManager.from_yaml(args.recipe) |
| 426 | + if args.recipe is not None |
| 427 | + else None |
424 | 428 | ) |
425 | | - checkpoint_manager.apply_structure(model, epoch=checkpoint["epoch"]) |
| 429 | + checkpoint_manager = ScheduledModifierManager.from_yaml(checkpoint["recipe"]) |
426 | 430 | elif args.resume: |
427 | 431 | checkpoint = _load_checkpoint(args.resume) |
428 | 432 | |
429 | 433 | # NOTE: override manager with the checkpoint's manager |
430 | | - manager = ScheduledModifierManager.from_yaml(checkpoint["checkpoint_recipe"]) |
| 434 | + manager = ScheduledModifierManager.from_yaml(checkpoint["recipe"]) |
431 | 435 | checkpoint_manager = None |
432 | 436 | manager.initialize(model, epoch=checkpoint["epoch"]) |
433 | 437 | |
434 | 438 | # NOTE: override start epoch |
435 | 439 | args.start_epoch = checkpoint["epoch"] + 1 |
436 | 440 | else: |
| 441 | + if args.recipe is None: |
| 442 | + raise ValueError("Must specify --recipe if not loading from a checkpoint") |
437 | 443 | checkpoint = None |
438 | 444 | manager = ScheduledModifierManager.from_yaml(args.recipe) |
439 | 445 | checkpoint_manager = None |
440 | 446 | |
441 | 447 | # load params |
442 | 448 | if checkpoint is not None: |
443 | | - model.load_state_dict(checkpoint["state_dict"]) |
444 | 449 | optimizer.load_state_dict(checkpoint["optimizer"]) |
445 | 450 | if model_ema and "model_ema" in checkpoint: |
446 | 451 | model_ema.load_state_dict(checkpoint["model_ema"]) |
447 | 452 | if scaler and "scaler" in checkpoint: |
448 | 453 | scaler.load_state_dict(checkpoint["scaler"]) |
449 | 454 | |
| 455 | + if args.test_only: |
| 456 | + # We disable the cudnn benchmarking because it can |
| 457 | + # noticeably affect the accuracy |
| 458 | + torch.backends.cudnn.benchmark = False |
| 459 | + torch.backends.cudnn.deterministic = True |
| 460 | + if model_ema: |
| 461 | + evaluate( |
| 462 | + model_ema, |
| 463 | + criterion, |
| 464 | + data_loader_test, |
| 465 | + device, |
| 466 | + log_suffix="EMA", |
| 467 | + ) |
| 468 | + else: |
| 469 | + evaluate(model, criterion, data_loader_test, device) |
| 470 | + return |
| 471 | + |
450 | 472 | optimizer = manager.modify(model, optimizer, len(data_loader)) |
451 | 473 | |
452 | 474 | if manager.learning_rate_modifiers: |
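Taken together, the hunk above makes `--recipe` optional whenever a checkpoint already carries its own recipe (now stored under the `"recipe"` key rather than `"checkpoint_recipe"`), and runs the `--test-only` evaluation before `manager.modify(...)` wraps the optimizer. The following is a condensed, stand-alone sketch of the three loading branches; the import path, the function signature, and the argument names are assumptions for illustration, not an excerpt of the script.

```python
# Condensed sketch of the three branches above: start from a checkpoint with a
# new recipe, resume a run from its own recipe, or train from a recipe alone.
import torch
from sparseml.pytorch.optim import ScheduledModifierManager  # assumed import path


def resolve_managers(model, recipe=None, checkpoint_path=None, resume=None):
    if checkpoint_path:
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        # --recipe may now be omitted; the checkpoint's own recipe still applies
        manager = ScheduledModifierManager.from_yaml(recipe) if recipe else None
        checkpoint_manager = ScheduledModifierManager.from_yaml(checkpoint["recipe"])
    elif resume:
        checkpoint = torch.load(resume, map_location="cpu")
        # the checkpoint's recipe overrides whatever was passed on the CLI
        manager = ScheduledModifierManager.from_yaml(checkpoint["recipe"])
        checkpoint_manager = None
        manager.initialize(model, epoch=checkpoint["epoch"])
    else:
        if recipe is None:
            raise ValueError("Must specify --recipe if not loading from a checkpoint")
        checkpoint = None
        manager = ScheduledModifierManager.from_yaml(recipe)
        checkpoint_manager = None

    return checkpoint, manager, checkpoint_manager
```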
@@ -503,23 +525,6 @@ def collate_fn(batch): |
503 | 525 | if args.resume and checkpoint: |
504 | 526 | lr_scheduler.load_state_dict(checkpoint["lr_scheduler"]) |
505 | 527 | |
506 | | - if args.test_only: |
507 | | - # We disable the cudnn benchmarking because it can |
508 | | - # noticeably affect the accuracy |
509 | | - torch.backends.cudnn.benchmark = False |
510 | | - torch.backends.cudnn.deterministic = True |
511 | | - if model_ema: |
512 | | - evaluate( |
513 | | - model_ema, |
514 | | - criterion, |
515 | | - data_loader_test, |
516 | | - device, |
517 | | - log_suffix="EMA", |
518 | | - ) |
519 | | - else: |
520 | | - evaluate(model, criterion, data_loader_test, device) |
521 | | - return |
522 | | - |
523 | 528 | model_without_ddp = model |
524 | 529 | if args.distributed: |
525 | 530 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) |
@@ -580,12 +585,12 @@ def collate_fn(batch): |
580 | 585 | if epoch == manager.max_epochs - 1 |
581 | 586 | else epoch + checkpoint_manager.max_epochs |
582 | 587 | ) |
583 | | - checkpoint["checkpoint_recipe"] = str( |
| 588 | + checkpoint["recipe"] = str( |
584 | 589 | ScheduledModifierManager.compose_staged(checkpoint_manager, manager) |
585 | 590 | ) |
586 | 591 | else: |
587 | 592 | checkpoint["epoch"] = -1 if epoch == manager.max_epochs - 1 else epoch |
588 | | - checkpoint["checkpoint_recipe"] = str(manager) |
| 593 | + checkpoint["recipe"] = str(manager) |
589 | 594 | |
590 | 595 | file_names = ["checkpoint.pth"] |
591 | 596 | if is_new_best: |
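On the save side, the hunk above stores the recipe that produced the checkpoint under the same `"recipe"` key, composing the prior checkpoint recipe with the current one when training started from an earlier recipe. A small sketch of that round trip follows; the recipe file names are placeholders, and only the `"recipe"` and `"epoch"` entries of the real checkpoint dict are shown.

```python
# Sketch of the save-side round trip: serialize the (possibly staged) recipe
# under "recipe" so a later run can rebuild the full schedule from it.
from sparseml.pytorch.optim import ScheduledModifierManager  # assumed import path

checkpoint_manager = ScheduledModifierManager.from_yaml("prior_recipe.yaml")  # placeholder path
manager = ScheduledModifierManager.from_yaml("current_recipe.yaml")           # placeholder path

composed = ScheduledModifierManager.compose_staged(checkpoint_manager, manager)
checkpoint = {"epoch": -1, "recipe": str(composed)}  # epoch -1 marks a finished run, per the diff

# a later invocation can reconstruct the staged schedule directly from the checkpoint
restored = ScheduledModifierManager.from_yaml(checkpoint["recipe"])
```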
@@ -658,7 +663,7 @@ def new_func(*args, **kwargs): |
658 | 663 | allow_extra_args=True, |
659 | 664 | ) |
660 | 665 | ) |
661 | | -@click.option("--recipe", required=True, type=str, help="Path to recipe") |
| 666 | +@click.option("--recipe", default=None, type=str, help="Path to recipe") |
662 | 667 | @click.option("--dataset-path", required=True, type=str, help="dataset path") |
663 | 668 | @click.option( |
664 | 669 | "--arch-key", |