
import logging
from copy import deepcopy
-from typing import Any, Dict, Iterable, List, Optional
+from typing import Any, Dict, Iterable, List, Optional, Union

+import torch
import torch.nn.functional as TF
from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer

from sparseml.optim import ModifierProp
from sparseml.pytorch.optim.modifier import PyTorchModifierYAML, ScheduledModifier
-from sparseml.pytorch.utils import BaseLogger, tensors_module_forward
+from sparseml.pytorch.utils import BaseLogger, device_of, tensors_module_forward


__all__ = [
@@ -221,6 +222,9 @@ def loss_update(
        optimizer: Optimizer,
        epoch: float,
        steps_per_epoch: int,
+        student_outputs: Union[Tensor, Dict, Iterable] = None,
+        teacher_inputs: Union[Tensor, Iterable[Tensor], Dict[Any, Tensor]] = None,
+        **kwargs,
    ) -> Tensor:
        """
        Updates the base loss with the distillation loss
@@ -233,69 +237,57 @@ def loss_update(
            (calculate batch number using this and epoch)
        :return: loss tensor with knowledge distillation loss added
        """
-        loss = super().loss_update(loss, module, optimizer, epoch, steps_per_epoch)
+        loss = super().loss_update(
+            loss, module, optimizer, epoch, steps_per_epoch, **kwargs
+        )

        if not self._distillation_enabled or self._disable_distillation:
            return loss

-        if self._student_outputs is None or self._student_inputs is None:
-            raise RuntimeError(
-                "A forward pass of the module must be run before calling loss_update "
-                "with a DistillationModifier"
+        if student_outputs is None or teacher_inputs is None:
+            raise ValueError(
+                "Student outputs and teacher inputs are required for "
+                "distillation loss update"
            )

        # ensure that teacher model is in eval mode and on correct device
        self._teacher.eval()
-        target_device = (
-            self._student_inputs.device
-            if isinstance(self._student_inputs, Tensor)
-            else self._student_inputs[0].device
-            if isinstance(self._student_inputs, Iterable)
-            else [
-                tens.device
-                for tens in self._student_inputs.values()
-                if isinstance(tens, Tensor)
-            ][0]
-        )
+        target_device = device_of(teacher_inputs)
        self._teacher.to(target_device)
+        with torch.no_grad():
+            teacher_outputs = tensors_module_forward(
+                teacher_inputs, self._teacher, check_feat_lab_inp=False
+            )

-        teacher_outputs = tensors_module_forward(
-            self._student_inputs, self._teacher, check_feat_lab_inp=False
-        )
-
-        assert type(self._student_outputs) == type(
-            teacher_outputs
-        ), "Student and teacher models must have the same output type"
+        if type(student_outputs) != type(teacher_outputs):
+            raise ValueError(
+                "Student and teacher models must have the same output type"
+            )

        distill_losses = []
-        if isinstance(self._student_outputs, Tensor):
+        if isinstance(student_outputs, Tensor):
            distill_losses.append(
-                self._calc_distill_loss(self._student_outputs, teacher_outputs)
+                self._calc_distill_loss(student_outputs, teacher_outputs)
            )
-        elif isinstance(self._student_outputs, Dict):
-            for key in self._distill_output_keys or self._student_outputs:
+        elif isinstance(student_outputs, Dict):
+            for key in self._distill_output_keys or student_outputs:
                distill_losses.append(
-                    self._calc_distill_loss(
-                        self._student_outputs[key], teacher_outputs[key]
-                    )
+                    self._calc_distill_loss(student_outputs[key], teacher_outputs[key])
                )
-        elif isinstance(self._student_outputs, Iterable):
-            for idx in self._distill_output_keys or range(len(self._student_outputs)):
+        elif isinstance(student_outputs, Iterable):
+            for idx in self._distill_output_keys or range(len(student_outputs)):
                distill_losses.append(
-                    self._calc_distill_loss(
-                        self._student_outputs[idx], teacher_outputs[idx]
-                    )
+                    self._calc_distill_loss(student_outputs[idx], teacher_outputs[idx])
                )

        # get distillation loss as average of individual output distillation loss values
        teacher_loss = sum(distill_losses) / len(distill_losses)
        distillation_loss = ((1.0 - self._hardness) * loss) + (
            self._hardness * teacher_loss
        )
-
-        _log_losses(
-            self.loggers, epoch, steps_per_epoch, loss, teacher_loss, distillation_loss
-        )
+        global_step = kwargs.get("global_step")
+        global_step = epoch * steps_per_epoch if global_step is None else global_step
+        _log_losses(self.loggers, global_step, loss, teacher_loss, distillation_loss)
        return distillation_loss

    def finalize(
@@ -340,46 +332,22 @@ def _check_distillation_update(
340332 "Using self distillation with copy of the module's current state"
341333 )
342334 self ._teacher = deepcopy (module )
343- self ._set_student_hook (module )
344335 self ._distillation_enabled = True
345336
346337 if self .end_pending (epoch , steps_per_epoch ):
347- self ._disable_student_hook ()
348338 self ._distillation_enabled = False
349339
-    def _set_student_hook(self, module: Module):
-        # delete hook if already exists
-        self._disable_student_hook()
-
-        def _track_inputs_and_outputs_hook(mod, inputs, outputs):
-            self._student_inputs = inputs
-            self._student_outputs = outputs
-
-        self._track_student_hook = module.register_forward_hook(
-            _track_inputs_and_outputs_hook
-        )
-
-    def _disable_student_hook(self):
-        if self._track_student_hook is not None:
-            self._track_student_hook.remove()
-            self._track_student_hook = None
-        self._student_inputs = None
-        self._student_outputs = None
-
    def _is_distillation_epoch(self, epoch):
        return self.start_epoch <= epoch < self.end_epoch


def _log_losses(
    loggers: List[BaseLogger],
-    epoch: float,
-    steps_per_epoch: int,
+    global_step: int,
    original_loss: float,
    teacher_loss: float,
    distillation_loss: float,
):
-    step = round(epoch) if steps_per_epoch <= 0 else round(epoch * steps_per_epoch)
-
    losses = {
        "original_loss": original_loss,
        "teacher_loss": teacher_loss,
@@ -388,4 +356,4 @@ def _log_losses(

    for logger in loggers:
        for (name, loss) in losses.items():
-            logger.log_scalar(f"DistillationModifier/{name}", loss.item(), step)
+            logger.log_scalar(f"DistillationModifier/{name}", loss.item(), global_step)
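
With this change, loss_update no longer depends on a forward hook to capture the student's inputs and outputs; the caller passes student_outputs and teacher_inputs explicitly and may pass global_step for logging. Below is a minimal training-loop sketch of the new call pattern; it assumes a DistillationModifier instance (modifier), a student model, an optimizer, and a train_loader of (inputs, labels) batches are already set up, and the variable names are illustrative only, not part of this diff.

import torch.nn.functional as F

steps_per_epoch = len(train_loader)  # assumed DataLoader of (inputs, labels) batches
for step, (inputs, labels) in enumerate(train_loader):
    epoch = step / steps_per_epoch  # fractional epoch, as loss_update expects
    optimizer.zero_grad()

    student_outputs = model(inputs)                  # student forward pass
    loss = F.cross_entropy(student_outputs, labels)  # base (hard-label) loss

    # Distillation tensors are now passed explicitly instead of being captured
    # by a forward hook on the student module.
    loss = modifier.loss_update(
        loss,
        model,
        optimizer,
        epoch,
        steps_per_epoch,
        student_outputs=student_outputs,
        teacher_inputs=inputs,
        global_step=step,  # optional; falls back to epoch * steps_per_epoch
    )

    loss.backward()
    optimizer.step()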