import logging
-from collections.abc import Mapping
-from copy import deepcopy
-from typing import Any, Dict, Iterable, List, Optional, Union
+from typing import Any, List

-import torch
-import torch.nn.functional as TF
-from torch import Tensor
-from torch.nn import Module
-from torch.optim import Optimizer
-
-from sparseml.optim import BaseModifier, ModifierProp
-from sparseml.pytorch.sparsification.modifier import (
-    PyTorchModifierYAML,
-    ScheduledModifier,
-    ScheduledUpdateModifier,
+from sparseml.optim import ModifierProp
+from sparseml.pytorch.sparsification.distillation.modifier_distillation_base import (
+    BaseDistillationModifier,
+    kldiv_loss,
)
-from sparseml.pytorch.utils import BaseLogger, device_of, tensors_module_forward
-from sparseml.sparsification import SparsificationTypes
+from sparseml.pytorch.sparsification.modifier import PyTorchModifierYAML


__all__ = [
    "DistillationModifier",
]

-
_LOGGER = logging.getLogger(__name__)


@PyTorchModifierYAML()
-class DistillationModifier(ScheduledUpdateModifier):
+class DistillationModifier(BaseDistillationModifier):
    """
    Adds a knowledge distillation loss based on a teacher model during the
    loss_update phase of the SparseML lifecycle. A distillation_teacher
@@ -64,50 +53,39 @@ class DistillationModifier(ScheduledUpdateModifier):
    |       distill_output_keys: [0]

    :param start_epoch: The epoch to start the modifier at
-    :param hardness: how much to weight the distillation loss vs the base loss
-        (e.g. hardness of 0.6 will return 0.6 * distill_loss + 0.4 * base_loss).
-        Default is 0.5
-    :param temperature: temperature applied to teacher and student softmax for
-        distillation
+    :param end_epoch: The epoch to end the modifier at
    :param distill_output_keys: list of keys for the module outputs to use for
        distillation if multiple outputs are present. None or empty list defaults
        to using all available outputs
    :param teacher_input_keys: list of keys to filter the inputs by before
        passing into the teacher. None or empty list defaults to using
        all available inputs
+    :param hardness: how much to weight the distillation loss vs the base loss
+        (e.g. hardness of 0.6 will return 0.6 * distill_loss + 0.4 * base_loss).
+        Default is 0.5
+    :param temperature: temperature applied to teacher and student softmax for
+        distillation
    """

    def __init__(
        self,
        start_epoch: float = -1.0,
        end_epoch: float = -1.0,
-        hardness: float = 0.5,
-        temperature: float = 2.0,
        distill_output_keys: List[Any] = None,
        teacher_input_keys: List[Any] = None,
        update_frequency: float = -1.0,
+        hardness: float = 0.5,
+        temperature: float = 2.0,
    ):
        super().__init__(
            start_epoch=start_epoch,
            end_epoch=end_epoch,
-            end_comparator=-1,
+            distill_output_keys=distill_output_keys,
+            teacher_input_keys=teacher_input_keys,
+            update_frequency=update_frequency,
        )
        self._hardness = hardness
        self._temperature = temperature
-        self._distill_output_keys = distill_output_keys
-        self._teacher_input_keys = teacher_input_keys
-
-        self._teacher = None
-        self._distillation_enabled = False
-
-        self._logged_loss_terms = {}
-
-    @BaseModifier.sparsification_types.getter
-    def sparsification_types(self) -> List[SparsificationTypes]:
-        """
-        :return: the sparsification types this modifier instance will apply
-        """
-        return [SparsificationTypes.distillation]

    @ModifierProp()
    def hardness(self) -> float:
@@ -141,270 +119,13 @@ def temperature(self, value: float):
        """
        self._temperature = value

-    @ModifierProp()
-    def distill_output_keys(self) -> Optional[List[Any]]:
-        """
-        :return: list of keys for the module outputs to use for
-            distillation if multiple outputs are present. None or empty list defaults
-            to using all available outputs
-        """
-        return self._distill_output_keys
-
-    @distill_output_keys.setter
-    def distill_output_keys(self, value: Optional[List[Any]]):
-        """
-        :params value: list of keys for the module outputs to use for
-            distillation if multiple outputs are present. None or empty list defaults
-            to using all available outputs
-        """
-        self._distill_output_keys = value
-
-    @ModifierProp()
-    def teacher_input_keys(self) -> Optional[List[Any]]:
-        """
-        :return: list of keys to filter the inputs by before
-            passing into the teacher. None or empty list defaults to using
-            all available inputs
-        """
-        return self._teacher_input_keys
-
-    @teacher_input_keys.setter
-    def teacher_input_keys(self, value: Optional[List[Any]]):
-        """
-        :params value: list of keys to filter the inputs by before
-            passing into the teacher. None or empty list defaults to using
-            all available inputs
-        """
-        self._teacher_input_keys = value
-
-    def initialize(
-        self,
-        module: Module,
-        epoch: float = 0,
-        loggers: Optional[List[BaseLogger]] = None,
-        distillation_teacher: Module = "disable",
-        **kwargs,
-    ):
-        """
-        Store the teacher model for distillation if provided
-
-        :param module: the PyTorch model/module to modify
-        :param epoch: The epoch to initialize the modifier and module at.
-            Defaults to 0 (start of the training process)
-        :param loggers: Optional list of loggers to log the modification process to
-        :param distillation_teacher: teacher module to perform knowledge distillation
-            with. If not provided, self distillation will be used with a teacher
-            from a copy of the given module at the start epoch. If given string
-            "disable" this modifier will not apply distillation of any kind,
-            even in the active epoch range
-        :param kwargs: Optional kwargs to support specific arguments
-            for individual modifiers.
-        """
-        super().initialize(module, epoch, loggers, **kwargs)
-
-        if distillation_teacher == "disable":
-            _LOGGER.warning(
-                "distillation_teacher set to disable, disabling distillation modifier"
-            )
-            self._distillation_enabled = False
-        elif distillation_teacher == "self":
-            self._distillation_enabled = True
-            _LOGGER.info(
-                "distillation_teacher set to self attention, "
-                "instantiating self distillation at start_epoch"
-            )
-        elif callable(distillation_teacher):
-            self._teacher = distillation_teacher
-            self._distillation_enabled = True
-            _LOGGER.info("distillation modifier using distillation_teacher object")
-        else:
-            raise ValueError(
-                "unrecognized value for distillation_modifier given of "
-                f"{distillation_teacher}. "
-                "To disable set to 'disable' and for self attention set to 'self'"
-            )
-        self._latest_student_loss = None
-        self._latest_teacher_loss = None
-        self._latest_distillation_loss = None
-
-    def update_ready(self, epoch: float, steps_per_epoch: int) -> bool:
-        """
-        :param epoch: current epoch and progress within the current epoch
-        :param steps_per_epoch: number of steps taken within each epoch
-            (calculate batch number using this and epoch)
-        :return: True if the modifier is pending an update and update() should be called
-        """
-        if not self._initialized:
-            raise RuntimeError("modifier must be initialized first")
-
-        return self._distillation_enabled and super().update_ready(
-            epoch, steps_per_epoch
-        )
-
-    @ScheduledModifier.log_call
-    def loss_update(
-        self,
-        loss: Tensor,
-        module: Module,
-        optimizer: Optimizer,
-        epoch: float,
-        steps_per_epoch: int,
-        student_outputs: Union[Tensor, Dict, Iterable] = None,
-        student_inputs: Union[Tensor, Iterable[Tensor], Dict[Any, Tensor]] = None,
-        teacher_inputs: Union[Tensor, Iterable[Tensor], Dict[Any, Tensor]] = None,
-        **kwargs,
-    ) -> Tensor:
-        """
-        Updates the loss with the distillation loss
-
-        :param loss: The calculated loss tensor
-        :param module: module to modify
-        :param optimizer: optimizer to modify
-        :param epoch: current epoch and progress within the current epoch
-        :param steps_per_epoch: number of steps taken within each epoch
-            (calculate batch number using this and epoch)
-        :return: loss tensor with knowledge distillation loss added
-        """
-        loss = super().loss_update(
-            loss, module, optimizer, epoch, steps_per_epoch, **kwargs
+    def compute_distillation_loss(self, student_outputs, teacher_outputs, **kwargs):
+        return kldiv_loss(
+            student_outputs,
+            teacher_outputs,
+            self.temperature,
+            self._distill_output_keys,
        )
-        self._logged_loss_terms["task_loss"] = loss
-
-        if not self.update_ready(epoch, steps_per_epoch):
-            return loss
-
-        if student_outputs is None or student_inputs is None:
-            raise ValueError(
-                "Student outputs and student inputs are required for "
-                "distillation loss update"
-            )
-
-        if teacher_inputs is None:
-            teacher_inputs = (
-                student_inputs
-                if not self._teacher_input_keys
-                else {key: student_inputs[key] for key in self._teacher_input_keys}
-            )
-
-        # copy to keep from updating student's inputs
-        teacher_inputs = deepcopy(teacher_inputs)
-
-        if self._teacher == "self":
-            _LOGGER.info("Copying current models state for self distillation")
-            self._teacher = deepcopy(module)
-
-        # ensure that teacher model is in eval mode and on correct device
-        self._teacher.eval()
-        teacher_device = next(self._teacher.parameters()).device
-        inputs_device = device_of(teacher_inputs)
-
-        if teacher_device != inputs_device:
-            _LOGGER.info(
-                f"Teacher device {teacher_device} does not match "
-                f"inputs device {inputs_device}, moving teacher to correct device"
-            )
-            self._teacher.to(inputs_device)
-
-        with torch.no_grad():
-            teacher_outputs = tensors_module_forward(
-                teacher_inputs, self._teacher, check_feat_lab_inp=False
-            )
-
-        if type(student_outputs) != type(teacher_outputs):
-            raise ValueError(
-                f"Student output type of {type(student_outputs)} must match "
-                f"teacher output type of {type(teacher_outputs)}"
-            )

-        teacher_loss = self._kldiv_output_loss(student_outputs, teacher_outputs)
-        total_loss = ((1.0 - self._hardness) * loss) + (self._hardness * teacher_loss)
-        self._logged_loss_terms.update(
-            {"teacher_loss": teacher_loss, "total_loss": total_loss}
-        )
-
-        return total_loss
-
-    def log_update(
-        self,
-        module: Module,
-        optimizer: Optimizer,
-        epoch: float,
-        steps_per_epoch: int,
-    ):
-        """
-        log the latest set of losses
-
-        :param module: module to modify
-        :param optimizer: optimizer to modify
-        :param epoch: current epoch and progress within the current epoch
-        :param steps_per_epoch: number of steps taken within each epoch
-            (calculate batch number using this and epoch)
-        """
-        super().log_update(module, optimizer, epoch, steps_per_epoch)
-
-        self.log_named_scalars(
-            name_value_pairs=self._logged_loss_terms.items(),
-            epoch=epoch,
-            steps_per_epoch=steps_per_epoch,
-        )
-
-    def finalize(
-        self, module: Optional[Module] = None, reset_loggers: bool = True, **kwargs
-    ):
-        """
-        Cleans up any state and hooks
-
-        :param module: The model/module to finalize the modifier for.
-            Marked optional so state can still be cleaned up on delete,
-            but generally should always be passed in.
-        :param reset_loggers: True to remove any currently attached loggers (default),
-            False to keep the loggers attached.
-        :param kwargs: Optional kwargs to support specific arguments
-            for individual modifiers.
-        """
-        super().finalize(module, reset_loggers, **kwargs)
-        self._teacher = None
-        self._distillation_enabled = False
-
-    def _calc_distill_head_output_loss(
-        self, student_val: Tensor, teacher_val: Tensor
-    ) -> Tensor:
-        v = (
-            TF.kl_div(
-                input=TF.log_softmax(student_val / self._temperature, dim=-1),
-                target=TF.log_softmax(teacher_val / self._temperature, dim=-1),
-                log_target=True,
-                reduction="sum",
-            )
-            * (self._temperature ** 2)
-            / (student_val.numel() / student_val.shape[-1])
-        )
-        return v
-
-    def _kldiv_output_loss(self, student_outputs, teacher_outputs):
-        # Distillation loss from the head outputs
-        distill_head_output_losses = []
-        if isinstance(student_outputs, Tensor):
-            distill_head_output_losses.append(
-                self._calc_distill_head_output_loss(student_outputs, teacher_outputs)
-            )
-        elif isinstance(student_outputs, Mapping):
-            for key in self._distill_output_keys or student_outputs:
-                distill_head_output_losses.append(
-                    self._calc_distill_head_output_loss(
-                        student_outputs[key], teacher_outputs[key]
-                    )
-                )
-        elif isinstance(student_outputs, Iterable):
-            for idx in self._distill_output_keys or range(len(student_outputs)):
-                distill_head_output_losses.append(
-                    self._calc_distill_head_output_loss(
-                        student_outputs[idx], teacher_outputs[idx]
-                    )
-                )
-        kldiv_output_loss = (
-            sum(distill_head_output_losses) / len(distill_head_output_losses)
-            if distill_head_output_losses
-            else 0.0
-        )
-        return kldiv_output_loss
+    def compute_total_loss(self, loss, distillation_loss):
+        return ((1.0 - self.hardness) * loss) + (self.hardness * distillation_loss)
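
The refactor delegates the loss math to the shared kldiv_loss helper and the base class hooks, so only the hardness blend remains in this file. The standalone sketch below (not part of the diff) reproduces the two terms, assuming kldiv_loss mirrors the removed _calc_distill_head_output_loss; the function names distillation_loss_sketch and total_loss_sketch are illustrative, not part of the SparseML API.

import torch
import torch.nn.functional as TF
from torch import Tensor


def distillation_loss_sketch(
    student_logits: Tensor,
    teacher_logits: Tensor,
    temperature: float = 2.0,
) -> Tensor:
    # Temperature-softened KL divergence of the student log-probabilities
    # against the teacher's (teacher as target), scaled by T^2 and normalized
    # by the number of distributions (the effective batch size), matching the
    # removed _calc_distill_head_output_loss.
    return (
        TF.kl_div(
            input=TF.log_softmax(student_logits / temperature, dim=-1),
            target=TF.log_softmax(teacher_logits / temperature, dim=-1),
            log_target=True,
            reduction="sum",
        )
        * (temperature ** 2)
        / (student_logits.numel() / student_logits.shape[-1])
    )


def total_loss_sketch(
    task_loss: Tensor, distill_loss: Tensor, hardness: float = 0.5
) -> Tensor:
    # Hardness-weighted blend as in compute_total_loss: hardness of 0.6
    # returns 0.6 * distill_loss + 0.4 * task_loss.
    return (1.0 - hardness) * task_loss + hardness * distill_loss


if __name__ == "__main__":
    # Toy usage: 8 samples, 10 classes, a fixed task loss.
    student = torch.randn(8, 10)
    teacher = torch.randn(8, 10)
    task_loss = torch.tensor(1.25)
    distill = distillation_loss_sketch(student, teacher, temperature=2.0)
    print(total_loss_sketch(task_loss, distill, hardness=0.5))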