
Commit eb898d0

rahul-tuli authored and corey-nm committed
Distillation support for torchvision script (#1310)
* Add support for `self` distillation and `disable`
* Pull out model creation into a method
* Add support to distill with another model
* Add modifier loss update before backward pass
* bugfix, set loss
* Update src/sparseml/pytorch/torchvision/train.py

Co-authored-by: corey-nm <109536191+corey-nm@users.noreply.github.com>
1 parent a83a67a commit eb898d0

File tree

1 file changed: +103 -23 lines

  • src/sparseml/pytorch/torchvision/train.py

src/sparseml/pytorch/torchvision/train.py

Lines changed: 103 additions & 23 deletions
```diff
@@ -23,7 +23,7 @@
 import warnings
 from functools import update_wrapper
 from types import SimpleNamespace
-from typing import Callable
+from typing import Callable, Optional

 import torch
 import torch.utils.data
```
```diff
@@ -64,6 +64,7 @@ def train_one_epoch(
     epoch: int,
     args,
     log_metrics_fn: Callable[[str, utils.MetricLogger, int, int], None],
+    manager=None,
     model_ema=None,
     scaler=None,
 ) -> utils.MetricLogger:
```
```diff
@@ -92,13 +93,24 @@ def train_one_epoch(
         start_time = time.time()
         image, target = image.to(device), target.to(device)
         with torch.cuda.amp.autocast(enabled=scaler is not None):
-            output = model(image)
+            outputs = output = model(image)
             if isinstance(output, tuple):
                 # NOTE: sparseml models return two things (logits & probs)
                 output = output[0]
             loss = criterion(output, target)

         if steps_accumulated % accum_steps == 0:
+            if manager is not None:
+                loss = manager.loss_update(
+                    loss=loss,
+                    module=model,
+                    optimizer=optimizer,
+                    epoch=epoch,
+                    steps_per_epoch=len(data_loader) / accum_steps,
+                    student_outputs=outputs,
+                    student_inputs=image,
+                )
+
             # first: do training to consume gradients
             if scaler is not None:
                 scaler.scale(loss).backward()
```
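The `manager.loss_update(...)` hook above is where recipe modifiers (for example a distillation modifier) get a chance to replace the task loss before the backward pass. As a point of reference only, a typical knowledge-distillation loss blends the hard-label loss with a temperature-scaled KL term against the teacher's logits, roughly like the sketch below; this is a generic illustration with made-up `hardness`/`temperature` parameters, not sparseml's modifier implementation.

```python
import torch.nn.functional as F


def distillation_loss(student_logits, teacher_logits, targets, hardness=0.5, temperature=2.0):
    """Generic KD loss sketch: blend hard-label CE with a soft KL term."""
    # Loss against the ground-truth labels
    hard_loss = F.cross_entropy(student_logits, targets)
    # KL divergence between temperature-softened student and teacher distributions
    soft_loss = F.kl_div(
        F.log_softmax(student_logits / temperature, dim=-1),
        F.softmax(teacher_logits / temperature, dim=-1),
        reduction="batchmean",
    ) * (temperature ** 2)
    # hardness weights the original task loss against the distillation term
    return hardness * hard_loss + (1.0 - hardness) * soft_loss
```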
```diff
@@ -355,27 +367,28 @@ def collate_fn(batch):
     )

     _LOGGER.info("Creating model")
-    if args.arch_key in ModelRegistry.available_keys():
-        with torch_distributed_zero_first(args.rank if args.distributed else None):
-            model = ModelRegistry.create(
-                key=args.arch_key,
-                pretrained=args.pretrained,
-                pretrained_path=args.checkpoint_path,
-                pretrained_dataset=args.pretrained_dataset,
-                num_classes=num_classes,
-            )
-    elif args.arch_key in torchvision.models.__dict__:
-        # fall back to torchvision
-        model = torchvision.models.__dict__[args.arch_key](
-            pretrained=args.pretrained, num_classes=num_classes
-        )
-        if args.checkpoint_path is not None:
-            load_model(args.checkpoint_path, model, strict=True)
-    else:
-        raise ValueError(
-            f"Unable to find {args.arch_key} in ModelRegistry or in torchvision.models"
+    local_rank = args.rank if args.distributed else None
+    model = _create_model(
+        arch_key=args.arch_key,
+        local_rank=local_rank,
+        pretrained=args.pretrained,
+        checkpoint_path=args.checkpoint_path,
+        pretrained_dataset=args.pretrained_dataset,
+        device=device,
+        num_classes=num_classes,
+    )
+
+    if args.distill_teacher not in ["self", "disable", None]:
+        _LOGGER.info("Instantiating teacher")
+        args.distill_teacher = _create_model(
+            arch_key=args.teacher_arch_key,
+            local_rank=local_rank,
+            pretrained=True,  # teacher is always pretrained
+            pretrained_dataset=args.pretrained_teacher_dataset,
+            checkpoint_path=args.distill_teacher,
+            device=device,
+            num_classes=num_classes,
         )
-    model.to(device)

     if args.distributed and args.sync_bn:
         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
```
```diff
@@ -549,7 +562,12 @@ def log_metrics(tag: str, metrics: utils.MetricLogger, epoch: int, epoch_step: i
     )

     if manager is not None:
-        manager.initialize(model, epoch=args.start_epoch, loggers=logger)
+        manager.initialize(
+            model,
+            epoch=args.start_epoch,
+            loggers=logger,
+            distillation_teacher=args.distill_teacher,
+        )
         optimizer = manager.modify(
             model, optimizer, steps_per_epoch=steps_per_epoch, epoch=args.start_epoch
         )
```
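Putting the pieces of this commit together, the manager lifecycle the script now follows is roughly the outline below. This is a condensed sketch, not the script itself: the function name and arguments are placeholders, and `ScheduledModifierManager.from_yaml` / `finalize` are the standard sparseml recipe entry points rather than something added in this diff.

```python
from sparseml.pytorch.optim import ScheduledModifierManager


def train_with_recipe(model, teacher, optimizer, criterion, data_loader, recipe_path, epochs):
    """Condensed outline of how the manager hooks are wired in this commit.

    `teacher` may be a torch.nn.Module, or the strings "self" / "disable".
    """
    steps_per_epoch = len(data_loader)
    manager = ScheduledModifierManager.from_yaml(recipe_path)
    manager.initialize(model, epoch=0, distillation_teacher=teacher)
    optimizer = manager.modify(model, optimizer, steps_per_epoch=steps_per_epoch, epoch=0)

    for epoch in range(epochs):
        for images, targets in data_loader:
            outputs = model(images)
            loss = criterion(outputs, targets)
            # distillation (and other) modifiers may rewrite the loss before backward
            loss = manager.loss_update(
                loss=loss,
                module=model,
                optimizer=optimizer,
                epoch=epoch,
                steps_per_epoch=steps_per_epoch,
                student_outputs=outputs,
                student_inputs=images,
            )
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    manager.finalize(model)
```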
```diff
@@ -586,6 +604,7 @@ def log_metrics(tag: str, metrics: utils.MetricLogger, epoch: int, epoch_step: i
             epoch,
             args,
             log_metrics,
+            manager=manager,
             model_ema=model_ema,
             scaler=scaler,
         )
```
```diff
@@ -658,6 +677,39 @@ def log_metrics(tag: str, metrics: utils.MetricLogger, epoch: int, epoch_step: i
     _LOGGER.info(f"Training time {total_time_str}")


+def _create_model(
+    arch_key: Optional[str] = None,
+    local_rank=None,
+    pretrained: Optional[bool] = False,
+    checkpoint_path: Optional[str] = None,
+    pretrained_dataset: Optional[str] = None,
+    device=None,
+    num_classes=None,
+):
+    if arch_key in ModelRegistry.available_keys():
+        with torch_distributed_zero_first(local_rank):
+            model = ModelRegistry.create(
+                key=arch_key,
+                pretrained=pretrained,
+                pretrained_path=checkpoint_path,
+                pretrained_dataset=pretrained_dataset,
+                num_classes=num_classes,
+            )
+    elif arch_key in torchvision.models.__dict__:
+        # fall back to torchvision
+        model = torchvision.models.__dict__[arch_key](
+            pretrained=pretrained, num_classes=num_classes
+        )
+        if checkpoint_path is not None:
+            load_model(checkpoint_path, model, strict=True)
+    else:
+        raise ValueError(
+            f"Unable to find {arch_key} in ModelRegistry or in torchvision.models"
+        )
+    model.to(device)
+    return model
+
+
 def _get_lr_scheduler(args, optimizer, checkpoint=None, manager=None):
     lr_scheduler = None
```
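For reference, the new `_create_model` helper centralizes the registry lookup, the torchvision fallback, checkpoint loading, and device placement. A hypothetical call, mirroring how the teacher is built from a checkpoint, might look like this (the architecture key, paths, and device are illustrative only):

```python
# Hypothetical call to the helper above; arch key, checkpoint path, and device are illustrative.
teacher = _create_model(
    arch_key="resnet50",
    local_rank=None,                        # single-process run
    pretrained=True,                        # teacher is always pretrained
    checkpoint_path="/path/to/teacher.pth",
    pretrained_dataset="imagenet",
    device="cuda",
    num_classes=1000,
)
```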

```diff
@@ -1040,6 +1092,34 @@ def new_func(*args, **kwargs):
     help="Save the best validation result after the given "
     "epoch completes until the end of training",
 )
+@click.option(
+    "--distill-teacher",
+    default=None,
+    type=str,
+    help="Teacher model for distillation (a trained image classification model)"
+    " can be set to 'self' for self-distillation and 'disable' to switch-off"
+    " distillation, additionally can also take in a SparseZoo stub",
+)
+@click.option(
+    "--pretrained-teacher-dataset",
+    default=None,
+    type=str,
+    help=(
+        "The dataset to load pretrained weights for the teacher"
+        "Load the default dataset for the architecture if set to None. "
+        "examples:`imagenet`, `cifar10`, etc..."
+    ),
+)
+@click.option(
+    "--teacher-arch-key",
+    default=None,
+    type=str,
+    help=(
+        "The architecture key for teacher image classification model; "
+        "example: `resnet50`, `mobilenet`. "
+        "Note: Will be read from the checkpoint if not specified"
+    ),
+)
 @click.pass_context
 def cli(ctx, **kwargs):
     """
```

0 commit comments
