This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit 4bf5d02

Created base class for knowledge distillation (#854)
* Created base class for knowledge distillation
* Quality and style fixes
* Moved kl computing methods out of the base class
* Update modifier_distillation.py
* Style and quality fixes
1 parent 6366d07 commit 4bf5d02

File tree

2 files changed: +417 -305 lines changed


src/sparseml/pytorch/sparsification/distillation/modifier_distillation.py

Lines changed: 26 additions & 305 deletions
@@ -18,36 +18,25 @@


import logging
-from collections.abc import Mapping
-from copy import deepcopy
-from typing import Any, Dict, Iterable, List, Optional, Union
+from typing import Any, List

-import torch
-import torch.nn.functional as TF
-from torch import Tensor
-from torch.nn import Module
-from torch.optim import Optimizer
-
-from sparseml.optim import BaseModifier, ModifierProp
-from sparseml.pytorch.sparsification.modifier import (
-    PyTorchModifierYAML,
-    ScheduledModifier,
-    ScheduledUpdateModifier,
+from sparseml.optim import ModifierProp
+from sparseml.pytorch.sparsification.distillation.modifier_distillation_base import (
+    BaseDistillationModifier,
+    kldiv_loss,
)
-from sparseml.pytorch.utils import BaseLogger, device_of, tensors_module_forward
-from sparseml.sparsification import SparsificationTypes
+from sparseml.pytorch.sparsification.modifier import PyTorchModifierYAML


__all__ = [
    "DistillationModifier",
]

-
_LOGGER = logging.getLogger(__name__)


@PyTorchModifierYAML()
-class DistillationModifier(ScheduledUpdateModifier):
+class DistillationModifier(BaseDistillationModifier):
    """
    Adds a knowledge distillation loss based on a teacher model during the
    loss_update phase of the SparseML lifecycle. A distillation_teacher
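Note: the parent class changes from ScheduledUpdateModifier to the new BaseDistillationModifier, defined in modifier_distillation_base.py (the commit's other changed file, not shown here). Going only by what this diff imports and overrides, a minimal sketch of the interface that base class presumably exposes is below; it is an assumption for illustration, not the actual contents of that file.

    # Hypothetical sketch of modifier_distillation_base.py as implied by this diff.
    from typing import Any, List

    from sparseml.pytorch.sparsification.modifier import ScheduledUpdateModifier


    class BaseDistillationModifier(ScheduledUpdateModifier):
        """Handles teacher setup, scheduling, and logging; subclasses supply the loss math."""

        def __init__(
            self,
            start_epoch: float = -1.0,
            end_epoch: float = -1.0,
            distill_output_keys: List[Any] = None,
            teacher_input_keys: List[Any] = None,
            update_frequency: float = -1.0,
        ):
            super().__init__(
                start_epoch=start_epoch,
                end_epoch=end_epoch,
                update_frequency=update_frequency,
                end_comparator=-1,  # assumption: moved here from the subclass's removed super() call
            )
            self._distill_output_keys = distill_output_keys
            self._teacher_input_keys = teacher_input_keys

        def compute_distillation_loss(self, student_outputs, teacher_outputs, **kwargs):
            # overridden by concrete modifiers such as DistillationModifier below
            raise NotImplementedError()

        def compute_total_loss(self, loss, distillation_loss):
            # overridden to combine the task loss with the distillation loss
            raise NotImplementedError()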
@@ -64,50 +53,39 @@ class DistillationModifier(ScheduledUpdateModifier):
    |       distill_output_keys: [0]

    :param start_epoch: The epoch to start the modifier at
-    :param hardness: how much to weight the distillation loss vs the base loss
-        (e.g. hardness of 0.6 will return 0.6 * distill_loss + 0.4 * base_loss).
-        Default is 0.5
-    :param temperature: temperature applied to teacher and student softmax for
-        distillation
+    :param end_epoch: The epoch to end the modifier at
    :param distill_output_keys: list of keys for the module outputs to use for
        distillation if multiple outputs are present. None or empty list defaults
        to using all available outputs
    :param teacher_input_keys: list of keys to filter the inputs by before
        passing into the teacher. None or empty list defaults to using
        all available inputs
+    :param hardness: how much to weight the distillation loss vs the base loss
+        (e.g. hardness of 0.6 will return 0.6 * distill_loss + 0.4 * base_loss).
+        Default is 0.5
+    :param temperature: temperature applied to teacher and student softmax for
+        distillation
    """

    def __init__(
        self,
        start_epoch: float = -1.0,
        end_epoch: float = -1.0,
-        hardness: float = 0.5,
-        temperature: float = 2.0,
        distill_output_keys: List[Any] = None,
        teacher_input_keys: List[Any] = None,
        update_frequency: float = -1.0,
+        hardness: float = 0.5,
+        temperature: float = 2.0,
    ):
        super().__init__(
            start_epoch=start_epoch,
            end_epoch=end_epoch,
-            end_comparator=-1,
+            distill_output_keys=distill_output_keys,
+            teacher_input_keys=teacher_input_keys,
+            update_frequency=update_frequency,
        )
        self._hardness = hardness
        self._temperature = temperature
-        self._distill_output_keys = distill_output_keys
-        self._teacher_input_keys = teacher_input_keys
-
-        self._teacher = None
-        self._distillation_enabled = False
-
-        self._logged_loss_terms = {}
-
-    @BaseModifier.sparsification_types.getter
-    def sparsification_types(self) -> List[SparsificationTypes]:
-        """
-        :return: the sparsification types this modifier instance will apply
-        """
-        return [SparsificationTypes.distillation]

    @ModifierProp()
    def hardness(self) -> float:
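The constructor above only reorders keyword arguments and forwards the scheduling and key-filtering arguments to the base class, so existing call sites keep working. A minimal construction sketch using the documented defaults (the values are illustrative, and the import path assumes the package layout shown in this commit):

    from sparseml.pytorch.sparsification.distillation.modifier_distillation import (
        DistillationModifier,
    )

    modifier = DistillationModifier(
        start_epoch=0.0,
        end_epoch=10.0,
        distill_output_keys=[0],  # distill only the first head output, as in the yaml sample
        teacher_input_keys=None,  # pass every student input through to the teacher
        hardness=0.5,             # total loss = 0.5 * distill_loss + 0.5 * base_loss
        temperature=2.0,          # softens both student and teacher softmax distributions
    )
    assert modifier.hardness == 0.5 and modifier.temperature == 2.0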
@@ -141,270 +119,13 @@ def temperature(self, value: float):
        """
        self._temperature = value

-    @ModifierProp()
-    def distill_output_keys(self) -> Optional[List[Any]]:
-        """
-        :return: list of keys for the module outputs to use for
-            distillation if multiple outputs are present. None or empty list defaults
-            to using all available outputs
-        """
-        return self._distill_output_keys
-
-    @distill_output_keys.setter
-    def distill_output_keys(self, value: Optional[List[Any]]):
-        """
-        :params value: list of keys for the module outputs to use for
-            distillation if multiple outputs are present. None or empty list defaults
-            to using all available outputs
-        """
-        self._distill_output_keys = value
-
-    @ModifierProp()
-    def teacher_input_keys(self) -> Optional[List[Any]]:
-        """
-        :return: list of keys to filter the inputs by before
-            passing into the teacher. None or empty list defaults to using
-            all available inputs
-        """
-        return self._teacher_input_keys
-
-    @teacher_input_keys.setter
-    def teacher_input_keys(self, value: Optional[List[Any]]):
-        """
-        :params value: list of keys to filter the inputs by before
-            passing into the teacher. None or empty list defaults to using
-            all available inputs
-        """
-        self._teacher_input_keys = value
-
-    def initialize(
-        self,
-        module: Module,
-        epoch: float = 0,
-        loggers: Optional[List[BaseLogger]] = None,
-        distillation_teacher: Module = "disable",
-        **kwargs,
-    ):
-        """
-        Store the teacher model for distillation if provided
-
-        :param module: the PyTorch model/module to modify
-        :param epoch: The epoch to initialize the modifier and module at.
-            Defaults to 0 (start of the training process)
-        :param loggers: Optional list of loggers to log the modification process to
-        :param distillation_teacher: teacher module to perform knowledge distillation
-            with. If not provided, self distillation will be used with a teacher
-            from a copy of the given module at the start epoch. If given string
-            "disable" this modifier will not apply distillation of any kind,
-            even in the active epoch range
-        :param kwargs: Optional kwargs to support specific arguments
-            for individual modifiers.
-        """
-        super().initialize(module, epoch, loggers, **kwargs)
-
-        if distillation_teacher == "disable":
-            _LOGGER.warning(
-                "distillation_teacher set to disable, disabling distillation modifier"
-            )
-            self._distillation_enabled = False
-        elif distillation_teacher == "self":
-            self._distillation_enabled = True
-            _LOGGER.info(
-                "distillation_teacher set to self attention, "
-                "instantiating self distillation at start_epoch"
-            )
-        elif callable(distillation_teacher):
-            self._teacher = distillation_teacher
-            self._distillation_enabled = True
-            _LOGGER.info("distillation modifier using distillation_teacher object")
-        else:
-            raise ValueError(
-                "unrecognized value for distillation_modifier given of "
-                f"{distillation_teacher}. "
-                "To disable set to 'disable' and for self attention set to 'self'"
-            )
-        self._latest_student_loss = None
-        self._latest_teacher_loss = None
-        self._latest_distillation_loss = None
-
-    def update_ready(self, epoch: float, steps_per_epoch: int) -> bool:
-        """
-        :param epoch: current epoch and progress within the current epoch
-        :param steps_per_epoch: number of steps taken within each epoch
-            (calculate batch number using this and epoch)
-        :return: True if the modifier is pending an update and update() should be called
-        """
-        if not self._initialized:
-            raise RuntimeError("modifier must be initialized first")
-
-        return self._distillation_enabled and super().update_ready(
-            epoch, steps_per_epoch
-        )
-
-    @ScheduledModifier.log_call
-    def loss_update(
-        self,
-        loss: Tensor,
-        module: Module,
-        optimizer: Optimizer,
-        epoch: float,
-        steps_per_epoch: int,
-        student_outputs: Union[Tensor, Dict, Iterable] = None,
-        student_inputs: Union[Tensor, Iterable[Tensor], Dict[Any, Tensor]] = None,
-        teacher_inputs: Union[Tensor, Iterable[Tensor], Dict[Any, Tensor]] = None,
-        **kwargs,
-    ) -> Tensor:
-        """
-        Updates the loss with the distillation loss
-
-        :param loss: The calculated loss tensor
-        :param module: module to modify
-        :param optimizer: optimizer to modify
-        :param epoch: current epoch and progress within the current epoch
-        :param steps_per_epoch: number of steps taken within each epoch
-            (calculate batch number using this and epoch)
-        :return: loss tensor with knowledge distillation loss added
-        """
-        loss = super().loss_update(
-            loss, module, optimizer, epoch, steps_per_epoch, **kwargs
+    def compute_distillation_loss(self, student_outputs, teacher_outputs, **kwargs):
+        return kldiv_loss(
+            student_outputs,
+            teacher_outputs,
+            self.temperature,
+            self._distill_output_keys,
        )
-        self._logged_loss_terms["task_loss"] = loss
-
-        if not self.update_ready(epoch, steps_per_epoch):
-            return loss
-
-        if student_outputs is None or student_inputs is None:
-            raise ValueError(
-                "Student outputs and student inputs are required for "
-                "distillation loss update"
-            )
-
-        if teacher_inputs is None:
-            teacher_inputs = (
-                student_inputs
-                if not self._teacher_input_keys
-                else {key: student_inputs[key] for key in self._teacher_input_keys}
-            )
-
-        # copy to keep from updating student's inputs
-        teacher_inputs = deepcopy(teacher_inputs)
-
-        if self._teacher == "self":
-            _LOGGER.info("Copying current models state for self distillation")
-            self._teacher = deepcopy(module)
-
-        # ensure that teacher model is in eval mode and on correct device
-        self._teacher.eval()
-        teacher_device = next(self._teacher.parameters()).device
-        inputs_device = device_of(teacher_inputs)
-
-        if teacher_device != inputs_device:
-            _LOGGER.info(
-                f"Teacher device {teacher_device} does not match "
-                f"inputs device {inputs_device}, moving teacher to correct device"
-            )
-            self._teacher.to(inputs_device)
-
-        with torch.no_grad():
-            teacher_outputs = tensors_module_forward(
-                teacher_inputs, self._teacher, check_feat_lab_inp=False
-            )
-
-        if type(student_outputs) != type(teacher_outputs):
-            raise ValueError(
-                f"Student output type of {type(student_outputs)} must match "
-                f"teacher output type of {type(teacher_outputs)}"
-            )

-        teacher_loss = self._kldiv_output_loss(student_outputs, teacher_outputs)
-        total_loss = ((1.0 - self._hardness) * loss) + (self._hardness * teacher_loss)
-        self._logged_loss_terms.update(
-            {"teacher_loss": teacher_loss, "total_loss": total_loss}
-        )
-
-        return total_loss
-
-    def log_update(
-        self,
-        module: Module,
-        optimizer: Optimizer,
-        epoch: float,
-        steps_per_epoch: int,
-    ):
-        """
-        log the latest set of losses
-
-        :param module: module to modify
-        :param optimizer: optimizer to modify
-        :param epoch: current epoch and progress within the current epoch
-        :param steps_per_epoch: number of steps taken within each epoch
-            (calculate batch number using this and epoch)
-        """
-        super().log_update(module, optimizer, epoch, steps_per_epoch)
-
-        self.log_named_scalars(
-            name_value_pairs=self._logged_loss_terms.items(),
-            epoch=epoch,
-            steps_per_epoch=steps_per_epoch,
-        )
-
-    def finalize(
-        self, module: Optional[Module] = None, reset_loggers: bool = True, **kwargs
-    ):
-        """
-        Cleans up any state and hooks
-
-        :param module: The model/module to finalize the modifier for.
-            Marked optional so state can still be cleaned up on delete,
-            but generally should always be passed in.
-        :param reset_loggers: True to remove any currently attached loggers (default),
-            False to keep the loggers attached.
-        :param kwargs: Optional kwargs to support specific arguments
-            for individual modifiers.
-        """
-        super().finalize(module, reset_loggers, **kwargs)
-        self._teacher = None
-        self._distillation_enabled = False
-
-    def _calc_distill_head_output_loss(
-        self, student_val: Tensor, teacher_val: Tensor
-    ) -> Tensor:
-        v = (
-            TF.kl_div(
-                input=TF.log_softmax(student_val / self._temperature, dim=-1),
-                target=TF.log_softmax(teacher_val / self._temperature, dim=-1),
-                log_target=True,
-                reduction="sum",
-            )
-            * (self._temperature ** 2)
-            / (student_val.numel() / student_val.shape[-1])
-        )
-        return v
-
-    def _kldiv_output_loss(self, student_outputs, teacher_outputs):
-        # Distillation loss from the head outputs
-        distill_head_output_losses = []
-        if isinstance(student_outputs, Tensor):
-            distill_head_output_losses.append(
-                self._calc_distill_head_output_loss(student_outputs, teacher_outputs)
-            )
-        elif isinstance(student_outputs, Mapping):
-            for key in self._distill_output_keys or student_outputs:
-                distill_head_output_losses.append(
-                    self._calc_distill_head_output_loss(
-                        student_outputs[key], teacher_outputs[key]
-                    )
-                )
-        elif isinstance(student_outputs, Iterable):
-            for idx in self._distill_output_keys or range(len(student_outputs)):
-                distill_head_output_losses.append(
-                    self._calc_distill_head_output_loss(
-                        student_outputs[idx], teacher_outputs[idx]
-                    )
-                )
-        kldiv_output_loss = (
-            sum(distill_head_output_losses) / len(distill_head_output_losses)
-            if distill_head_output_losses
-            else 0.0
-        )
-        return kldiv_output_loss
+    def compute_total_loss(self, loss, distillation_loss):
+        return ((1.0 - self.hardness) * loss) + (self.hardness * distillation_loss)
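The deleted _calc_distill_head_output_loss and _kldiv_output_loss methods are presumably what became the kldiv_loss helper now imported from modifier_distillation_base (the commit message notes the KL computation was moved out of the class). Its real signature is only known from the call site above; a standalone sketch that reproduces the removed computation under that assumption:

    from collections.abc import Mapping

    import torch.nn.functional as TF
    from torch import Tensor


    def kldiv_loss(student_outputs, teacher_outputs, temperature, distill_output_keys=None):
        # Sketch of the moved-out helper: per-head KL divergence between
        # temperature-softened student and teacher outputs, averaged over heads.
        def _head_loss(student_val: Tensor, teacher_val: Tensor) -> Tensor:
            return (
                TF.kl_div(
                    input=TF.log_softmax(student_val / temperature, dim=-1),
                    target=TF.log_softmax(teacher_val / temperature, dim=-1),
                    log_target=True,
                    reduction="sum",
                )
                * (temperature ** 2)
                / (student_val.numel() / student_val.shape[-1])  # normalize by batch size
            )

        losses = []
        if isinstance(student_outputs, Tensor):
            losses.append(_head_loss(student_outputs, teacher_outputs))
        elif isinstance(student_outputs, Mapping):
            for key in distill_output_keys or student_outputs:
                losses.append(_head_loss(student_outputs[key], teacher_outputs[key]))
        else:  # any other iterable of per-head tensors
            for idx in distill_output_keys or range(len(student_outputs)):
                losses.append(_head_loss(student_outputs[idx], teacher_outputs[idx]))

        return sum(losses) / len(losses) if losses else 0.0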

0 commit comments