This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit b176894

One shot MFAC pruning support (#305)
* One-shot MFAC pruning support
* GradSampler class
* using tensors_module_forward
* fix extra grad collection issue
* perturbation bug fix
* updates from review, rebasing onto latest
1 parent 2b84085 commit b176894

File tree

2 files changed, +137 -7 lines changed

  src/sparseml/pytorch/optim/modifier_pruning.py
  src/sparseml/pytorch/utils/mfac_helpers.py

src/sparseml/pytorch/optim/modifier_pruning.py

Lines changed: 41 additions & 6 deletions
@@ -39,6 +39,7 @@
     ScheduledUpdateModifier,
 )
 from sparseml.pytorch.utils import (
+    GradSampler,
     MFACOptions,
     NamedLayerParam,
     get_layer,
@@ -255,7 +256,7 @@ def initialize(
                 )
             )
 
-        self._check_mask_update(module, epoch, steps_per_epoch=1)
+        self._check_mask_update(module, epoch, steps_per_epoch=1, **kwargs)
 
     def finalize(
         self, module: Optional[Module] = None, reset_loggers: bool = True, **kwargs
@@ -332,7 +333,9 @@ def optimizer_post_step(
             self._module_masks.apply()
 
     @abstractmethod
-    def _check_mask_update(self, module: Module, epoch: float, steps_per_epoch: int):
+    def _check_mask_update(
+        self, module: Module, epoch: float, steps_per_epoch: int, **kwargs
+    ):
         raise NotImplementedError()
 
     def _should_log(
@@ -444,7 +447,9 @@ def __init__(
             log_types=log_types,
         )
 
-    def _check_mask_update(self, module: Module, epoch: float, steps_per_epoch: int):
+    def _check_mask_update(
+        self, module: Module, epoch: float, steps_per_epoch: int, **kwargs
+    ):
         if self.start_pending(epoch, steps_per_epoch):
             self._module_masks.set_param_masks_from_weights()
             self._module_masks.enabled = True
@@ -807,7 +812,9 @@ def validate(self):
                 ).format(self._inter_func, INTERPOLATION_FUNCS, self.__class__.__name__)
             )
 
-    def _check_mask_update(self, module: Module, epoch: float, steps_per_epoch: int):
+    def _check_mask_update(
+        self, module: Module, epoch: float, steps_per_epoch: int, **kwargs
+    ):
         """
         Check for updating the pruning masks at the given epoch.
         Called from both initialize and update.
@@ -822,8 +829,9 @@ def _check_mask_update(self, module: Module, epoch: float, steps_per_epoch: int)
             self._module_masks.enabled = True
             started = True
 
-        self._module_masks.pre_optim_step_update()
-        self._pre_step_completed = True
+        if not self._pre_step_completed:
+            self._module_masks.pre_optim_step_update()
+            self._pre_step_completed = True
 
         if started:
             # set the mask tensors according to the new sparsity
@@ -1359,6 +1367,16 @@ def mfac_options(self) -> Dict[str, Any]:
         """
         return self._mfac_options
 
+    def _check_mask_update(
+        self, module: Module, epoch: float, steps_per_epoch: int, **kwargs
+    ):
+        # create grads for one-shot pruning
+        if "grad_sampler" in kwargs:
+            self._collect_grad_samples(module, kwargs["grad_sampler"])
+            self._pre_step_completed = True
+
+        super()._check_mask_update(module, epoch, steps_per_epoch, **kwargs)
+
     def _create_pruning_mask(
         self, layers: List[Module], layer_names: List[str], param_names: List[str]
     ) -> ModuleParamPruningMask:
@@ -1371,6 +1389,23 @@ def _create_pruning_mask(
             score_type=MFACOptions(**self._mfac_options),
         )
 
+    def _collect_grad_samples(
+        self,
+        module: Module,
+        grad_sampler: GradSampler,
+    ):
+        if not isinstance(grad_sampler, GradSampler):
+            raise ValueError(
+                "One-shot MFAC pruning requires a GradSampler object given by the "
+                f"grad_sampler kwarg. Given an object of type {type(grad_sampler)}"
+            )
+        num_grads = MFACOptions(**self._mfac_options).get_num_grads_for_sparsity(
+            self._applied_sparsity or 0.0
+        )
+
+        for _ in grad_sampler.iter_module_backwards(module, num_grads):
+            self._module_masks.pre_optim_step_update()
+
 
 @PyTorchModifierYAML()
 class MFACGlobalPruningModifier(MFACPruningModifier):
src/sparseml/pytorch/utils/mfac_helpers.py

Lines changed: 96 additions & 1 deletion
@@ -21,14 +21,29 @@
 import threading
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
-from typing import Dict, List, Optional, Union
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Generator,
+    Iterable,
+    Iterator,
+    List,
+    Optional,
+    Tuple,
+    Union,
+)
 
 import torch
 from torch import Tensor
+from torch.nn import Module
 from torch.nn.parallel.parallel_apply import parallel_apply
 
+from sparseml.pytorch.utils import tensors_module_forward
+
 
 __all__ = [
+    "GradSampler",
     "MFACOptions",
     "FisherInverse",
     "FisherInverseFast",
@@ -38,6 +53,86 @@
 ]
 
 
+class GradSampler:
+    """
+    Class for computing gradient samples for a Model given a sample data loader and
+    loss function.
+
+    :param data_loader: iterator of data samples to use as model inputs and their loss
+        targets. Samples can either be single tensors as model input or a list of
+        inputs and should be iterated in tuples with their targets
+    :param loss_fn: function to be called on model outputs to compute the loss at
+        each step
+    """
+
+    def __init__(
+        self,
+        data_loader: Iterator[Tuple[Union[Tensor, List[Tensor]], Any]],
+        loss_fn: Callable[[Tensor], Tensor],
+    ):
+        if not isinstance(data_loader, Iterable):
+            raise ValueError(
+                "data_loader for GradSampler must be Iterable, received object of "
+                f"type {type(data_loader)}"
+            )
+        if not callable(loss_fn):
+            raise ValueError(
+                "loss_fn for GradSampler must be callable, given input "
+                f"with type {type(loss_fn)}"
+            )
+
+        self._data_loader = data_loader
+        self._loss_fn = loss_fn
+
+    def module_forward(self, module: Module, data: Union[Tensor, List[Tensor]]) -> Any:
+        """
+        :param module: module to perform forward pass with
+        :param data: single data sample to pass to module
+        :return: output(s) of the module forward pass
+        """
+        if isinstance(data, Tensor):
+            data = [data]
+
+        return tensors_module_forward(*data, module)
+
+    def module_backward(self, module_outputs: Any, targets: Any):
+        """
+        Computes module loss based on the given module outputs, target data and loss
+        function
+
+        :param module_outputs: outputs of a forward pass from a module
+        :param targets: target outputs for the module to be used for the loss function
+        """
+        loss = self._loss_fn(module_outputs, targets)
+        loss.backward()
+
+    def iter_module_backwards(
+        self, module: Module, num_grads: int
+    ) -> Generator[int, None, None]:
+        """
+
+        :param module: module to compute gradients for
+        :param num_grads: number of gradient samples to compute
+        :return: generator that yields after every gradient is computed with the index
+            of the gradient sample number
+        """
+        computed_grads = 0
+
+        while computed_grads < num_grads:
+            for sample, target in self._data_loader:
+                # run sample forward and backwards pass
+                model_outputs = self.module_forward(module, sample)
+                self.module_backward(model_outputs, target)
+
+                # yield so gradients can be collected
+                computed_grads += 1
+                yield computed_grads
+
+                if computed_grads >= num_grads:
+                    break
+                module.zero_grad()
+
+
 @dataclass
 class MFACOptions:
     """
