
Commit 28ec07f

[Cherry Pick] M-FAC fixes (#586)
* Fix mfac gradcheck (#573)
* Fix: proper handling of dictionary type num_grads for MFAC (#579)
* M-FAC Indexing fix (#583)
* Update: docs to use correct mfac options format (#584)
1 parent e18b8d3 commit 28ec07f

6 files changed, +45 -39 lines changed

research/mfac/README.md

Lines changed: 6 additions & 12 deletions
@@ -31,9 +31,9 @@ techniques on a variety of one-shot and gradual pruning tasks.
 SparseML makes it easy to use the M-FAC pruning algorithm as part of sparsification
 recipes to improve pruning recovery by providing an `MFACPruningModifier`.
 The `MFACPruningModifier` contains the same settings as the magnitude
-pruning modifiers and contains extra settings for the M-FAC algorithm under the
-`mfac_options` parameter. `mfac_options` should be provided as a YAML dictionary and
-details of the main options are provided below.
+pruning modifiers and contains extra settings for the M-FAC algorithm including
+`num_grads`, `fisher_block_size`, and `available_gpus`. Ideal values will depend
+on the system available to run on and model to be pruned.

 ### Example M-FAC Recipe
 The following is an example `MFACPruningModifier` to be used in place of other
@@ -48,17 +48,11 @@ pruning_modifiers:
 start_epoch: 1.0
 end_epoch: 61.0
 update_frequency: 4.0
-mfac_options:
-num_grads: {0.0: 256, 0.5: 512, 0.75: 1024, 0.83: 1400}
-fisher_block_size: 10000
-available_gpus: ["cuda:0"]
+num_grads: {0.0: 256, 0.5: 512, 0.75: 1024, 0.83: 1400}
+fisher_block_size: 10000
+available_gpus: ["cuda:0"]
 ```

-### mfac_options Parameters
-The following parameters can be specified under the `mfac_options` parameter to control
-how the M-FAC calculations are made. Ideal values will depend on the system
-available to run on and model to be pruned.
-
 #### num_grads
 To approximate the second order information in the M-FAC algorithm, first order
 gradients are used. `num_grads` specifies the number of recent gradient samples to store
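
One of the cherry-picked fixes (#579) makes `num_grads` work when given as a dictionary keyed by sparsity level, as in the recipe above. The sketch below only illustrates how such a mapping can be resolved to a concrete gradient count; the function name `resolve_num_grads` and the threshold rule are assumptions, while the actual lookup in SparseML is `_get_num_grads_for_sparsity` in `modifier_pruning_mfac.py`.

```python
# Illustrative sketch only -- not the SparseML implementation.
# Assumption: each key is a sparsity threshold, and the value for the largest
# threshold not exceeding the current sparsity is the number of gradients to use.
from typing import Dict, Union


def resolve_num_grads(num_grads: Union[int, Dict[float, int]], sparsity: float) -> int:
    """Resolve an int or sparsity-keyed dict num_grads to a sample count."""
    if isinstance(num_grads, int):
        return num_grads
    reached = [threshold for threshold in num_grads if threshold <= sparsity]
    key = max(reached) if reached else min(num_grads)
    return num_grads[key]


# With the recipe's mapping, 80% sparsity falls in the 0.75 bucket:
print(resolve_num_grads({0.0: 256, 0.5: 512, 0.75: 1024, 0.83: 1400}, 0.80))  # 1024
```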

research/mfac/recipes/pruning-mnistnet-one_shot-mfac.md

Lines changed: 4 additions & 5 deletions
@@ -23,13 +23,12 @@ pruning_modifiers:
 start_epoch: 0.0
 end_epoch: 1.0
 update_frequency: 1.0
-mfac_options:
-num_grads: 512
-fisher_block_size: 2000
+num_grads: 512
+fisher_block_size: 2000
 ---

 # Pruning MNISTNet with M-FAC
 This recipe prunes a model to 35% sparsity using the M-FAC pruning algorithm.
 It is intended for use with MNISTNet but could be used to prune other models
-in one shot, however the `final_sparsity` and `mfac_options` should be adjusted
-accordingly.
+in one shot, however the `final_sparsity`, `num_grads`, and `fisher_block_size`
+should be adjusted accordingly.

research/mfac/recipes/pruning-mobilenet-imagenette-mfac-short-95.md

Lines changed: 3 additions & 4 deletions
@@ -50,10 +50,9 @@ pruning_modifiers:
 end_epoch: *pruning_end_epoch
 update_frequency: *pruning_update_frequency
 mask_type: *pruning_mask_type
-mfac_options:
-num_grads: 256
-fisher_block_size: 2000
-available_gpus: ["cuda:0"]
+num_grads: 256
+fisher_block_size: 2000
+available_gpus: ["cuda:0"]

 - !SetWeightDecayModifier
 weight_decay: 0.0

src/sparseml/pytorch/sparsification/pruning/modifier_pruning_mfac.py

Lines changed: 20 additions & 16 deletions
@@ -404,6 +404,7 @@ def __init__(
         self._grad_buffer = None  # type: Tensor
         self._grads = None  # placeholder for all grads across buffers
         self._buffer_idx = 0
+        self._grads_collected = 0
         self._latest_h_inv_diag = None  # type: tuple

         # scale num_grads by number of DDP processes
@@ -434,13 +435,13 @@ def score_parameters(self) -> List[Tensor]:
             H^-1, scores will be W^2 / (2 * diag(H^-1))
         """

-        if self._grad_buffer is None or torch.any(
-            torch.all(self._grad_buffer == 0.0, dim=1)
+        if self._grads_collected < _get_num_grads_for_sparsity(
+            self._num_grads, self._last_applied_sparsity
         ):
             # raise Exception if grad buffer is not full
             raise RuntimeError(
-                "MFAC pruning step called, but not enough gradient samples have been "
-                f"collected. Expected {self._num_grads} samples"
+                f"MFAC pruning step called, but only {self._grads_collected} were "
+                f"collected from the expected {self._num_grads}."
             )

         if self._is_ddp:
@@ -519,6 +520,7 @@ def pre_optim_step_update(self, masks: List[Tensor]):
         # update buffer idx
         self._buffer_idx += 1
         self._buffer_idx %= self._grad_buffer.size(0)
+        self._grads_collected += 1

     @torch.no_grad()
     def mask_update(self, masks: List[Tensor], mask_diffs: List[Tensor]):
@@ -635,6 +637,7 @@ def _setup_grad_buffer(self, masks: Tensor):
             device="cpu",
         )
         self._buffer_idx = 0
+        self._grads_collected = 0


 """
@@ -1260,25 +1263,26 @@ def mul_blocked(self, x: Tensor, call_idx: int, device: str) -> Tensor:

         # Get the H^-1 values corresponding to the number of blocks used here.
         # It's clunky compared to torch.cat()[idx], but avoids duplicating
-        # the memory of H^-1
-        start_block = sum(self._num_blocks_per_device_call[:call_idx])
-        end_block = sum(self._num_blocks_per_device_call[: call_idx + 1])
+        # the memory of H^-1. Most of the logic deals with indexing into a list of
+        # tensors as one continuous tensor, to grab slices that may span separate
+        # tensors in the list
+        block_start = sum(self._num_blocks_per_device_call[:call_idx])
+        block_end = sum(self._num_blocks_per_device_call[: call_idx + 1])
         t_hinv = []
-        tensor_start = 0
-        tensor_end = 0
+        cont_end_idx = 0
         for tensor in self._hinvs:
-            tensor_end += len(tensor)
-            if start_block > tensor_end:
+            cont_start_idx = cont_end_idx
+            cont_end_idx += len(tensor)
+            if block_start > cont_end_idx:
                 continue
-            if end_block < tensor_end:
+            if block_end < cont_end_idx:
                 t_hinv.append(
-                    tensor[start_block - tensor_start : end_block - tensor_start]
+                    tensor[block_start - cont_start_idx : block_end - cont_start_idx]
                 )
                 break
             else:
-                t_hinv.append(tensor[start_block - tensor_start :])
-                start_block = tensor_end
-                tensor_start = tensor_end
+                t_hinv.append(tensor[block_start - cont_start_idx :])
+                block_start = cont_end_idx

         mul_slice = (
             torch.bmm(torch.cat(t_hinv).to(device), x_slice)
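
The indexing fix (#583) shown above slices a block range out of `self._hinvs`, a list of per-device tensors treated as one contiguous stack of blocks, without concatenating them first. Below is a standalone sketch of that slicing pattern under assumed names (`slice_block_range` is not part of SparseML); it should return the same result as `torch.cat(tensors)[block_start:block_end]` while only materializing the pieces it needs.

```python
# Illustrative sketch of slicing a contiguous block range out of a list of
# tensors without concatenating them all; mirrors the corrected indexing logic
# above but is not the SparseML implementation.
from typing import List

import torch


def slice_block_range(
    tensors: List[torch.Tensor], block_start: int, block_end: int
) -> torch.Tensor:
    pieces = []
    cont_end_idx = 0  # running end of the current tensor in "continuous" indexing
    for tensor in tensors:
        cont_start_idx = cont_end_idx
        cont_end_idx += len(tensor)
        if block_start >= cont_end_idx:
            continue  # requested range starts after this tensor
        if block_end <= cont_end_idx:
            # range ends inside this tensor: take the final partial slice
            pieces.append(tensor[block_start - cont_start_idx : block_end - cont_start_idx])
            break
        # range continues past this tensor: take its tail and keep going
        pieces.append(tensor[block_start - cont_start_idx :])
        block_start = cont_end_idx
    return torch.cat(pieces)


blocks = [torch.randn(2, 4, 4), torch.randn(3, 4, 4)]  # 5 blocks split across two tensors
assert torch.equal(slice_block_range(blocks, 1, 4), torch.cat(blocks)[1:4])
```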

tests/sparseml/pytorch/sparsification/pruning/test_mfac_inverse.py

Lines changed: 1 addition & 2 deletions
@@ -47,8 +47,7 @@
         pytest.param(
             ["cuda:0"],
             marks=pytest.mark.skipif(
-                "CUDA_VISIBLE_DEVICES" not in os.environ
-                or not os.getenv("CUDA_VISIBLE_DEVICES"),
+                not torch.cuda.is_available(),
                 reason="No CUDA devices available",
             ),
         ),

tests/sparseml/pytorch/sparsification/pruning/test_modifier_pruning_mfac.py

Lines changed: 11 additions & 0 deletions
@@ -96,6 +96,17 @@ def _build_gradient_sampler(
             num_grads=8,
             available_devices=["cpu"],
         ),
+        lambda: MFACPruningModifier(
+            params=["seq.fc1.weight", "seq.fc2.weight"],
+            init_sparsity=0.5,
+            final_sparsity=0.95,
+            start_epoch=2.0,
+            end_epoch=5.0,
+            update_frequency=1.0,
+            inter_func="cubic",
+            num_grads=8,
+            global_sparsity=True,
+        ),
     ],
     scope="function",
 )
