|
23 | 23 | import warnings |
24 | 24 | from functools import update_wrapper |
25 | 25 | from types import SimpleNamespace |
26 | | -from typing import Callable |
| 26 | +from typing import Callable, Optional |
27 | 27 |
|
28 | 28 | import torch |
29 | 29 | import torch.utils.data |
@@ -64,6 +64,7 @@ def train_one_epoch( |
64 | 64 | epoch: int, |
65 | 65 | args, |
66 | 66 | log_metrics_fn: Callable[[str, utils.MetricLogger, int, int], None], |
| 67 | + manager=None, |
67 | 68 | model_ema=None, |
68 | 69 | scaler=None, |
69 | 70 | ) -> utils.MetricLogger: |
@@ -92,13 +93,24 @@ def train_one_epoch( |
92 | 93 | start_time = time.time() |
93 | 94 | image, target = image.to(device), target.to(device) |
94 | 95 | with torch.cuda.amp.autocast(enabled=scaler is not None): |
95 | | - output = model(image) |
| 96 | + outputs = output = model(image) |
96 | 97 | if isinstance(output, tuple): |
97 | 98 | # NOTE: sparseml models return two things (logits & probs) |
98 | 99 | output = output[0] |
99 | 100 | loss = criterion(output, target) |
100 | 101 |
|
101 | 102 | if steps_accumulated % accum_steps == 0: |
| 103 | + if manager is not None: |
| 104 | + loss = manager.loss_update( |
| 105 | + loss=loss, |
| 106 | + module=model, |
| 107 | + optimizer=optimizer, |
| 108 | + epoch=epoch, |
| 109 | + steps_per_epoch=len(data_loader) / accum_steps, |
| 110 | + student_outputs=outputs, |
| 111 | + student_inputs=image, |
| 112 | + ) |
| 113 | + |
102 | 114 | # first: do training to consume gradients |
103 | 115 | if scaler is not None: |
104 | 116 | scaler.scale(loss).backward() |
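For orientation, here is a minimal sketch (not part of the diff) of how the manager hooks used across these hunks fit together: `initialize` with a `distillation_teacher`, `modify` to wrap the optimizer, and `loss_update` inside the training step. It assumes `manager` is a SparseML `ScheduledModifierManager` built from a recipe that contains a distillation modifier; `recipe.yaml` and the tiny model below are placeholders, not names from this PR.

import torch
from sparseml.pytorch.optim import ScheduledModifierManager

model = torch.nn.Linear(8, 2)                                # stand-in for the real architecture
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = torch.nn.CrossEntropyLoss()

manager = ScheduledModifierManager.from_yaml("recipe.yaml")  # placeholder recipe with a distillation modifier
manager.initialize(model, epoch=0, distillation_teacher="self")
optimizer = manager.modify(model, optimizer, steps_per_epoch=100, epoch=0)

image, target = torch.randn(4, 8), torch.randint(0, 2, (4,))
outputs = model(image)
loss = criterion(outputs, target)
loss = manager.loss_update(                                  # folds the distillation terms into the task loss
    loss=loss,
    module=model,
    optimizer=optimizer,
    epoch=0,
    steps_per_epoch=100,
    student_outputs=outputs,
    student_inputs=image,
)
loss.backward()
optimizer.step()
manager.finalize(model)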
@@ -355,27 +367,28 @@ def collate_fn(batch): |
355 | 367 | ) |
356 | 368 |
|
357 | 369 | _LOGGER.info("Creating model") |
358 | | - if args.arch_key in ModelRegistry.available_keys(): |
359 | | - with torch_distributed_zero_first(args.rank if args.distributed else None): |
360 | | - model = ModelRegistry.create( |
361 | | - key=args.arch_key, |
362 | | - pretrained=args.pretrained, |
363 | | - pretrained_path=args.checkpoint_path, |
364 | | - pretrained_dataset=args.pretrained_dataset, |
365 | | - num_classes=num_classes, |
366 | | - ) |
367 | | - elif args.arch_key in torchvision.models.__dict__: |
368 | | - # fall back to torchvision |
369 | | - model = torchvision.models.__dict__[args.arch_key]( |
370 | | - pretrained=args.pretrained, num_classes=num_classes |
371 | | - ) |
372 | | - if args.checkpoint_path is not None: |
373 | | - load_model(args.checkpoint_path, model, strict=True) |
374 | | - else: |
375 | | - raise ValueError( |
376 | | - f"Unable to find {args.arch_key} in ModelRegistry or in torchvision.models" |
| 370 | + local_rank = args.rank if args.distributed else None |
| 371 | + model = _create_model( |
| 372 | + arch_key=args.arch_key, |
| 373 | + local_rank=local_rank, |
| 374 | + pretrained=args.pretrained, |
| 375 | + checkpoint_path=args.checkpoint_path, |
| 376 | + pretrained_dataset=args.pretrained_dataset, |
| 377 | + device=device, |
| 378 | + num_classes=num_classes, |
| 379 | + ) |
| 380 | + |
| 381 | + if args.distill_teacher not in ["self", "disable", None]: |
| 382 | + _LOGGER.info("Instantiating teacher") |
| 383 | + args.distill_teacher = _create_model( |
| 384 | + arch_key=args.teacher_arch_key, |
| 385 | + local_rank=local_rank, |
| 386 | + pretrained=True, # teacher is always pretrained |
| 387 | + pretrained_dataset=args.pretrained_teacher_dataset, |
| 388 | + checkpoint_path=args.distill_teacher, |
| 389 | + device=device, |
| 390 | + num_classes=num_classes, |
377 | 391 | ) |
378 | | - model.to(device) |
379 | 392 |
|
380 | 393 | if args.distributed and args.sync_bn: |
381 | 394 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) |
@@ -549,7 +562,12 @@ def log_metrics(tag: str, metrics: utils.MetricLogger, epoch: int, epoch_step: i |
549 | 562 | ) |
550 | 563 |
|
551 | 564 | if manager is not None: |
552 | | - manager.initialize(model, epoch=args.start_epoch, loggers=logger) |
| 565 | + manager.initialize( |
| 566 | + model, |
| 567 | + epoch=args.start_epoch, |
| 568 | + loggers=logger, |
| 569 | + distillation_teacher=args.distill_teacher, |
| 570 | + ) |
553 | 571 | optimizer = manager.modify( |
554 | 572 | model, optimizer, steps_per_epoch=steps_per_epoch, epoch=args.start_epoch |
555 | 573 | ) |
@@ -586,6 +604,7 @@ def log_metrics(tag: str, metrics: utils.MetricLogger, epoch: int, epoch_step: i |
586 | 604 | epoch, |
587 | 605 | args, |
588 | 606 | log_metrics, |
| 607 | + manager=manager, |
589 | 608 | model_ema=model_ema, |
590 | 609 | scaler=scaler, |
591 | 610 | ) |
@@ -658,6 +677,39 @@ def log_metrics(tag: str, metrics: utils.MetricLogger, epoch: int, epoch_step: i |
658 | 677 | _LOGGER.info(f"Training time {total_time_str}") |
659 | 678 |
|
660 | 679 |
|
| 680 | +def _create_model( |
| 681 | + arch_key: Optional[str] = None, |
| 682 | + local_rank=None, |
| 683 | + pretrained: Optional[bool] = False, |
| 684 | + checkpoint_path: Optional[str] = None, |
| 685 | + pretrained_dataset: Optional[str] = None, |
| 686 | + device=None, |
| 687 | + num_classes=None, |
| 688 | +): |
| 689 | + if arch_key in ModelRegistry.available_keys(): |
| 690 | + with torch_distributed_zero_first(local_rank): |
| 691 | + model = ModelRegistry.create( |
| 692 | + key=arch_key, |
| 693 | + pretrained=pretrained, |
| 694 | + pretrained_path=checkpoint_path, |
| 695 | + pretrained_dataset=pretrained_dataset, |
| 696 | + num_classes=num_classes, |
| 697 | + ) |
| 698 | + elif arch_key in torchvision.models.__dict__: |
| 699 | + # fall back to torchvision |
| 700 | + model = torchvision.models.__dict__[arch_key]( |
| 701 | + pretrained=pretrained, num_classes=num_classes |
| 702 | + ) |
| 703 | + if checkpoint_path is not None: |
| 704 | + load_model(checkpoint_path, model, strict=True) |
| 705 | + else: |
| 706 | + raise ValueError( |
| 707 | + f"Unable to find {arch_key} in ModelRegistry or in torchvision.models" |
| 708 | + ) |
| 709 | + model.to(device) |
| 710 | + return model |
| 711 | + |
| 712 | + |
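As a quick illustration of the fallback branch in `_create_model`: when an architecture key is not registered with SparseML's `ModelRegistry`, the model is looked up by name in `torchvision.models`. The snippet below (a sketch, not part of the diff) shows that lookup in isolation with a hypothetical key and class count.

import torchvision

arch_key = "resnet18"  # hypothetical key, assumed absent from ModelRegistry
if arch_key in torchvision.models.__dict__:
    # same lookup _create_model uses for the torchvision fallback
    model = torchvision.models.__dict__[arch_key](pretrained=False, num_classes=10)
    print(type(model).__name__, sum(p.numel() for p in model.parameters()))
else:
    raise ValueError(f"Unable to find {arch_key} in torchvision.models")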
661 | 713 | def _get_lr_scheduler(args, optimizer, checkpoint=None, manager=None): |
662 | 714 | lr_scheduler = None |
663 | 715 |
|
@@ -1040,6 +1092,34 @@ def new_func(*args, **kwargs): |
1040 | 1092 | help="Save the best validation result after the given " |
1041 | 1093 | "epoch completes until the end of training", |
1042 | 1094 | ) |
| 1095 | +@click.option( |
| 1096 | + "--distill-teacher", |
| 1097 | + default=None, |
| 1098 | + type=str, |
| 1099 | + help="Teacher model for distillation (a trained image classification model)" |
| 1100 | + " can be set to 'self' for self-distillation and 'disable' to switch-off" |
| 1101 | + " distillation, additionally can also take in a SparseZoo stub", |
| 1102 | +) |
| 1103 | +@click.option( |
| 1104 | + "--pretrained-teacher-dataset", |
| 1105 | + default=None, |
| 1106 | + type=str, |
| 1107 | + help=( |
| 1108 | + "The dataset to load pretrained weights for the teacher" |
| 1109 | + "Load the default dataset for the architecture if set to None. " |
| 1110 | + "examples:`imagenet`, `cifar10`, etc..." |
| 1111 | + ), |
| 1112 | +) |
| 1113 | +@click.option( |
| 1114 | + "--teacher-arch-key", |
| 1115 | + default=None, |
| 1116 | + type=str, |
| 1117 | + help=( |
| 1118 | + "The architecture key for teacher image classification model; " |
| 1119 | + "example: `resnet50`, `mobilenet`. " |
| 1120 | + "Note: Will be read from the checkpoint if not specified" |
| 1121 | + ), |
| 1122 | +) |
1043 | 1123 | @click.pass_context |
1044 | 1124 | def cli(ctx, **kwargs): |
1045 | 1125 | """ |
|