
Commit 6a47673

Revised distillation modifier (#398)

* Revised distillation modifier
* Move teacher model's logic back to modifier

1 parent a877698, commit 6a47673

File tree: 5 files changed, +75 -84 lines

src/sparseml/pytorch/optim/manager.py (5 additions, 2 deletions)

@@ -455,6 +455,7 @@ def loss_update(
         optimizer: Optimizer,
         epoch: float,
         steps_per_epoch: int,
+        **kwargs,
     ) -> Tensor:
         """
         Optional call that can be made on the optimizer to update the contained
@@ -468,13 +469,15 @@ def loss_update(
             (calculate batch number using this and epoch)
         :return: the modified loss tensor
         """
-        super().loss_update(loss, module, optimizer, epoch, steps_per_epoch)
+        super().loss_update(loss, module, optimizer, epoch, steps_per_epoch, **kwargs)

         for mod in self._modifiers:
             if not mod.enabled:
                 continue

-            loss = mod.loss_update(loss, module, optimizer, epoch, steps_per_epoch)
+            loss = mod.loss_update(
+                loss, module, optimizer, epoch, steps_per_epoch, **kwargs
+            )

         return loss
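With **kwargs now forwarded by the manager, a training loop can hand the distillation tensors straight through to every enabled modifier. A minimal sketch, assuming an already-initialized ScheduledModifierManager plus a criterion and data loader that are not part of this commit:

from torch.nn import Module
from torch.optim import Optimizer
from torch.utils.data import DataLoader


def train_epoch(
    manager,            # assumed: an initialized ScheduledModifierManager
    model: Module,
    optimizer: Optimizer,
    criterion,          # e.g. torch.nn.CrossEntropyLoss()
    loader: DataLoader,
    epoch: int,
    steps_per_epoch: int,
):
    model.train()
    for inputs, labels in loader:
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        # extra keyword arguments flow through the manager into each enabled
        # modifier; DistillationModifier reads student_outputs / teacher_inputs
        loss = manager.loss_update(
            loss,
            model,
            optimizer,
            epoch,
            steps_per_epoch,
            student_outputs=outputs,
            teacher_inputs=inputs,
        )
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()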

src/sparseml/pytorch/optim/modifier.py (1 addition, 0 deletions)

@@ -287,6 +287,7 @@ def loss_update(
         optimizer: Optimizer,
         epoch: float,
         steps_per_epoch: int,
+        **kwargs,
     ):
         """
         Optional call that can be made on the optimizer to update the modifiers
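Because the base signature now takes **kwargs, custom modifiers can opt in to extra data while ignoring keywords they do not recognize. A hypothetical sketch (LossScaleModifier and its loss_scale kwarg are illustrative only, not part of SparseML):

from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer

from sparseml.pytorch.optim.modifier import ScheduledModifier


class LossScaleModifier(ScheduledModifier):
    """Illustrative modifier that scales the loss by an optional kwarg."""

    def loss_update(
        self,
        loss: Tensor,
        module: Module,
        optimizer: Optimizer,
        epoch: float,
        steps_per_epoch: int,
        **kwargs,
    ) -> Tensor:
        super().loss_update(loss, module, optimizer, epoch, steps_per_epoch, **kwargs)
        # keywords not meant for this modifier are simply ignored
        scale = kwargs.get("loss_scale", 1.0)
        return loss * scale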

src/sparseml/pytorch/optim/modifier_distillation.py (35 additions, 67 deletions)

@@ -19,16 +19,17 @@

 import logging
 from copy import deepcopy
-from typing import Any, Dict, Iterable, List, Optional
+from typing import Any, Dict, Iterable, List, Optional, Union

+import torch
 import torch.nn.functional as TF
 from torch import Tensor
 from torch.nn import Module
 from torch.optim import Optimizer

 from sparseml.optim import ModifierProp
 from sparseml.pytorch.optim.modifier import PyTorchModifierYAML, ScheduledModifier
-from sparseml.pytorch.utils import BaseLogger, tensors_module_forward
+from sparseml.pytorch.utils import BaseLogger, device_of, tensors_module_forward


 __all__ = [
@@ -221,6 +222,9 @@ def loss_update(
         optimizer: Optimizer,
         epoch: float,
         steps_per_epoch: int,
+        student_outputs: Union[Tensor, Dict, Iterable] = None,
+        teacher_inputs: Union[Tensor, Iterable[Tensor], Dict[Any, Tensor]] = None,
+        **kwargs,
     ) -> Tensor:
         """
         Updates the bass loss with the distillation loss
@@ -233,69 +237,57 @@
             (calculate batch number using this and epoch)
         :return: loss tensor with knowledge distillation loss added
         """
-        loss = super().loss_update(loss, module, optimizer, epoch, steps_per_epoch)
+        loss = super().loss_update(
+            loss, module, optimizer, epoch, steps_per_epoch, **kwargs
+        )

         if not self._distillation_enabled or self._disable_distillation:
             return loss

-        if self._student_outputs is None or self._student_inputs is None:
-            raise RuntimeError(
-                "A forward pass of the module must be run before calling loss_update "
-                "with a DistillationModifier"
+        if student_outputs is None or teacher_inputs is None:
+            raise ValueError(
+                "Student outputs and teacher inputs are required for "
+                "distillation loss update"
             )

         # ensure that teacher model is in eval mode and on correct device
         self._teacher.eval()
-        target_device = (
-            self._student_inputs.device
-            if isinstance(self._student_inputs, Tensor)
-            else self._student_inputs[0].device
-            if isinstance(self._student_inputs, Iterable)
-            else [
-                tens.device
-                for tens in self._student_inputs.values()
-                if isinstance(tens, Tensor)
-            ][0]
-        )
+        target_device = device_of(teacher_inputs)
         self._teacher.to(target_device)
+        with torch.no_grad():
+            teacher_outputs = tensors_module_forward(
+                teacher_inputs, self._teacher, check_feat_lab_inp=False
+            )

-        teacher_outputs = tensors_module_forward(
-            self._student_inputs, self._teacher, check_feat_lab_inp=False
-        )
-
-        assert type(self._student_outputs) == type(
-            teacher_outputs
-        ), "Student and teacher models must have the same output type"
+        if type(student_outputs) != type(teacher_outputs):
+            raise ValueError(
+                "Student and teacher models must have the same output type"
+            )

         distill_losses = []
-        if isinstance(self._student_outputs, Tensor):
+        if isinstance(student_outputs, Tensor):
             distill_losses.append(
-                self._calc_distill_loss(self._student_outputs, teacher_outputs)
+                self._calc_distill_loss(student_outputs, teacher_outputs)
             )
-        elif isinstance(self._student_outputs, Dict):
-            for key in self._distill_output_keys or self._student_outputs:
+        elif isinstance(student_outputs, Dict):
+            for key in self._distill_output_keys or student_outputs:
                 distill_losses.append(
-                    self._calc_distill_loss(
-                        self._student_outputs[key], teacher_outputs[key]
-                    )
+                    self._calc_distill_loss(student_outputs[key], teacher_outputs[key])
                 )
-        elif isinstance(self._student_outputs, Iterable):
-            for idx in self._distill_output_keys or range(len(self._student_outputs)):
+        elif isinstance(student_outputs, Iterable):
+            for idx in self._distill_output_keys or range(len(student_outputs)):
                 distill_losses.append(
-                    self._calc_distill_loss(
-                        self._student_outputs[idx], teacher_outputs[idx]
-                    )
+                    self._calc_distill_loss(student_outputs[idx], teacher_outputs[idx])
                 )

         # get distillation loss as average of individual output distillation loss values
         teacher_loss = sum(distill_losses) / len(distill_losses)
         distillation_loss = ((1.0 - self._hardness) * loss) + (
             self._hardness * teacher_loss
         )
-
-        _log_losses(
-            self.loggers, epoch, steps_per_epoch, loss, teacher_loss, distillation_loss
-        )
+        global_step = kwargs.get("global_step")
+        global_step = epoch * steps_per_epoch if global_step is None else global_step
+        _log_losses(self.loggers, global_step, loss, teacher_loss, distillation_loss)
         return distillation_loss

     def finalize(
@@ -340,46 +332,22 @@ def _check_distillation_update(
                 "Using self distillation with copy of the module's current state"
             )
             self._teacher = deepcopy(module)
-            self._set_student_hook(module)
             self._distillation_enabled = True

         if self.end_pending(epoch, steps_per_epoch):
-            self._disable_student_hook()
             self._distillation_enabled = False

-    def _set_student_hook(self, module: Module):
-        # delete hook if already exists
-        self._disable_student_hook()
-
-        def _track_inputs_and_outputs_hook(mod, inputs, outputs):
-            self._student_inputs = inputs
-            self._student_outputs = outputs
-
-        self._track_student_hook = module.register_forward_hook(
-            _track_inputs_and_outputs_hook
-        )
-
-    def _disable_student_hook(self):
-        if self._track_student_hook is not None:
-            self._track_student_hook.remove()
-            self._track_student_hook = None
-            self._student_inputs = None
-            self._student_outputs = None
-
     def _is_distillation_epoch(self, epoch):
         return self.start_epoch <= epoch < self.end_epoch


 def _log_losses(
     loggers: List[BaseLogger],
-    epoch: float,
-    steps_per_epoch: int,
+    global_step: int,
     original_loss: float,
     teacher_loss: float,
     distillation_loss: float,
 ):
-    step = round(epoch) if steps_per_epoch <= 0 else round(epoch * steps_per_epoch)
-
     losses = {
         "original_loss": original_loss,
         "teacher_loss": teacher_loss,
@@ -388,4 +356,4 @@ def _log_losses(

     for logger in loggers:
         for (name, loss) in losses.items():
-            logger.log_scalar(f"DistillationModifier/{name}", loss.item(), step)
+            logger.log_scalar(f"DistillationModifier/{name}", loss.item(), global_step)

src/sparseml/pytorch/utils/helpers.py (14 additions, 0 deletions)

@@ -42,6 +42,7 @@

 __all__ = [
     "default_device",
+    "device_of",
     "get_optim_learning_rate",
     "get_optim_groups_learning_rates",
     "set_optim_learning_rate",
@@ -98,6 +99,19 @@ def default_device() -> str:
     return "cuda:{}".format(",".join(device_ids))


+def device_of(inputs: Any):
+    if isinstance(inputs, Tensor):
+        return inputs.device
+    elif isinstance(inputs, Dict):
+        for tens in inputs.values():
+            return device_of(tens)
+    elif isinstance(inputs, Iterable):
+        return device_of(inputs[0])
+    else:
+        raise RuntimeError("Unknown type of inputs to device_of function")
+    return default_device()
+
+
 ##############################
 #
 # pytorch optim helpers
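A quick usage sketch of the new helper: it walks into dicts and sequences and returns the device of the first tensor it finds (the tensor names below are arbitrary examples):

import torch

from sparseml.pytorch.utils import device_of

single = torch.ones(2, 2)
as_dict = {"pixel_values": torch.zeros(1, 3, 224, 224)}
as_list = [torch.zeros(1, 10), torch.zeros(1)]

print(device_of(single))   # device of the tensor itself, e.g. cpu
print(device_of(as_dict))  # device of the first value in the dict
print(device_of(as_list))  # device of the first element in the sequence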

tests/sparseml/pytorch/optim/test_modifier_distillation.py (20 additions, 15 deletions)

@@ -70,9 +70,18 @@ def test_lifecycle(

         # test distillation has been applied
         # fake forward pass
-        fake_loss = model(self._get_fake_batch(model_lambda)).mean()
+        student_inputs = self._get_fake_batch(model_lambda)
+        student_outputs = model(student_inputs)
+        teacher_outputs = student_outputs + 0.5  # fake teacher model's outputs
+        fake_loss = student_outputs.mean()
         updated_loss = modifier.loss_update(
-            fake_loss, model, optimizer, -1, test_steps_per_epoch
+            fake_loss,
+            model,
+            optimizer,
+            -1,
+            test_steps_per_epoch,
+            student_outputs,
+            teacher_outputs,
         )

         assert isinstance(updated_loss, torch.Tensor)
@@ -98,23 +107,19 @@ def test_loss_update(
         model = model_lambda()
         optimizer = optim_lambda(model)

-        with pytest.raises(RuntimeError):
-            modifier.loss_update(
-                test_loss, model, optimizer, test_epoch, test_steps_per_epoch
-            )
-
         self.initialize_helper(modifier, model)

-        # should fail until a forward pass is run
-        with pytest.raises(RuntimeError):
-            modifier.loss_update(
-                test_loss, model, optimizer, test_epoch, test_steps_per_epoch
-            )
-
         # run fake forward pass and try updating the loss
-        _ = model(self._get_fake_batch(model_lambda))
+        inputs = self._get_fake_batch(model_lambda)
+        student_outputs = model(inputs)
         new_loss = modifier.loss_update(
-            test_loss, model, optimizer, test_epoch, test_steps_per_epoch
+            test_loss,
+            model,
+            optimizer,
+            test_epoch,
+            test_steps_per_epoch,
+            student_outputs,
+            inputs,
         )

         assert isinstance(new_loss, Tensor)
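The fake-teacher trick in test_lifecycle (offsetting the student outputs by a constant) exercises the loss path without a second network. A standalone, SparseML-free sketch of the same idea, using a mean-squared gap as a stand-in for the modifier's distillation term:

import torch
import torch.nn.functional as F


def test_fake_teacher_changes_the_blended_loss():
    torch.manual_seed(0)
    student_outputs = torch.randn(4, 10)
    teacher_outputs = student_outputs + 0.5  # fake teacher, as in the test above

    task_loss = student_outputs.mean()
    # stand-in distillation term: mean squared gap between student and teacher
    distill_term = F.mse_loss(student_outputs, teacher_outputs)

    hardness = 0.5
    blended = (1.0 - hardness) * task_loss + hardness * distill_term

    assert blended.shape == ()  # still a scalar loss
    assert not torch.equal(blended, task_loss)  # the teacher term moved it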
