Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit 33d764a

Browse files
authored
mask creator refactor (#599) (#601)
1 parent 5ffa7d7 commit 33d764a

File tree

7 files changed

+213
-50
lines changed

7 files changed

+213
-50
lines changed

src/sparseml/pytorch/sparsification/pruning/mask_creator.py

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
"GroupedPruningMaskCreator",
3232
"UnstructuredPruningMaskCreator",
3333
"FourBlockMaskCreator",
34+
"BlockMaskCreator",
3435
]
3536

3637

@@ -422,3 +423,148 @@ def _map_mask_to_tensor(
422423
permute_val.insert(1, len(permute_val))
423424
block_mask = block_mask.permute(*permute_val)
424425
return block_mask
426+
427+
428+
class BlockMaskCreator(GroupedPruningMaskCreator):
    """
    Structured sparsity mask creator that groups the input tensor into blocks of
    shape block_shape.

    :param block_shape: The shape in and out channel should take in blocks. Must
        contain at least two integers that divide the input tensors evenly on the
        channel dimensions; any entries past the first two must be 1.
        -1 for a dimension blocks across the entire dimension
    :param grouping_fn_name: The name of the torch grouping function to reduce
        dimensions by
    """

    def __init__(
        self,
        block_shape: List[int],
        grouping_fn_name: str = "mean",
    ):
        if len(block_shape) < 2:
            raise ValueError(
                (
                    "Invalid block_shape: {}, "
                    "block_shape must have length == 2 for in and out channels"
                ).format(block_shape)
            )

        if len(block_shape) > 2 and not all(shape == 1 for shape in block_shape[2:]):
            # after in and out channels, only 1 can be used for other dimensions
            raise ValueError(
                (
                    "Invalid block_shape: {}, "
                    "block_shape for indices not in [0, 1] must be equal to 1"
                ).format(block_shape)
            )

        # store a private copy so this object never aliases the caller's list
        self._block_shape = list(block_shape)
        self._grouping_fn_name = grouping_fn_name

    def group_tensor(self, tensor: Tensor) -> Tensor:
        """
        :param tensor: The tensor to transform
        :return: The values of the tensor grouped by blocks of shape
            self._block_shape and reduced per block with the grouping function
            named by grouping_fn_name, cast back to the type of the input tensor
        """
        blocked_tens_shape = self._get_blocked_tens_shape_and_validate(tensor.shape)
        blocked_tensor = tensor.reshape(blocked_tens_shape)
        # axis 1 of the blocked view holds the members of each block
        reduced_blocks = GroupedPruningMaskCreator.reduce_tensor(
            blocked_tensor, 1, self._grouping_fn_name
        )
        return reduced_blocks.type(tensor.type())

    def _map_mask_to_tensor(
        self,
        grouped_mask: Tensor,
        original_tensor_shape: torch.Size,
        tensor_idx: Optional[int] = None,
    ) -> Tensor:
        """
        :param grouped_mask: A binary mask the size of a tensor from group_tensor
        :param original_tensor_shape: Shape of the original tensor grouped_mask
            derives from
        :param tensor_idx: optional index this tensor was passed into a tensor
            list for mask creation (not used by this implementation)
        :return: The values from grouped_mask mapped to a tensor of size
            original_tensor_shape
        """
        blocked_tens_shape = self._get_blocked_tens_shape_and_validate(
            original_tensor_shape
        )
        # expand so every element has a corresponding value in the original tensor
        block_mask = grouped_mask.reshape(blocked_tens_shape[0], blocked_tens_shape[2])
        block_mask = block_mask.unsqueeze(1)
        block_mask = block_mask.expand(*blocked_tens_shape).contiguous()
        return block_mask.reshape(original_tensor_shape)

    def _get_blocked_tens_shape_and_validate(
        self,
        tens_shape: torch.Size,
    ) -> List[int]:
        """
        :param tens_shape: The shape of the tensor to group in blocks
        :return: shape of tens when blocked by block_shape, as a three element
            list [number of blocks, elements per block, -1]
        :raise: ValueError if we are unable to block tens by shape block_shape
        """
        n_dims = len(tens_shape)
        # validate rank first so the -1 lookups below never index past tens_shape
        if n_dims < 2:
            raise ValueError(
                "Invalid tensor shape {}."
                " BlockSparsityMaskCreator can only create masks from tensors with 2 or"
                " more dimensions, tensor has {}.".format(tens_shape, n_dims)
            )

        # Work on a copy: the configured block shape must stay untouched so that
        # -1 entries are resolved again for every tensor this creator sees
        block_shape = list(self._block_shape)
        # Conv will have block shape [X, Y, 1, ..., 1]
        block_shape.extend([1] * (n_dims - len(block_shape)))
        # __init__ only allows values other than 1 in the first two slots, and
        # n_dims >= 2 here, so tens_shape[idx] exists wherever -1 can appear
        block_shape = [
            tens_shape[idx] if block_dim == -1 else block_dim
            for idx, block_dim in enumerate(block_shape)
        ]

        for tens_dim, block_dim in zip(tens_shape, block_shape):
            if block_dim < 1:
                # reject explicitly instead of failing on the modulo below
                raise ValueError(
                    f"Invalid block_shape {block_shape} for parameter shape "
                    f"{tens_shape}. Elements of block_shape must be positive "
                    f"integers or -1"
                )
            if tens_dim % block_dim != 0:
                raise ValueError(
                    f"Invalid block_shape {block_shape} for parameter shape "
                    f"{tens_shape}. Elements of block_shape must divide parameter "
                    f"shape evenly"
                )

        # Compute blocked tensor shape
        # NOTE(review): when both leading block dims exceed 1 the reshape groups
        # block_shape[0] * block_shape[1] consecutive row-major elements of the
        # first two dims rather than a 2D tile - confirm this is intended
        if block_shape[1] > 1:
            block_size = block_shape[0] * block_shape[1]
            return [tens_shape[0] * tens_shape[1] // block_size, block_size, -1]
        return [tens_shape[0] // block_shape[0], block_shape[0], -1]
541+
542+
543+
def get_mask_creator_default(mask_type: Union[str, List[int]]) -> PruningMaskCreator:
    """
    :param mask_type: type of mask creator to use, can be 'unstructured', for
        unstructured mask creator, 'block4' for 1x4 block pruning, or a list of two
        integers for custom block pruning (does not support padding)
    :return: mask creator object created from the mask type
    :raise: ValueError if mask_type is a list that does not hold exactly two
        integers, or is not one of the supported values
    """
    if mask_type == "unstructured":
        return UnstructuredPruningMaskCreator()
    if mask_type == "block4":
        return FourBlockMaskCreator()
    if isinstance(mask_type, list):
        if not all(isinstance(val, int) for val in mask_type):
            raise ValueError(
                "all values in list specification of BlockMaskCreator must be integers "
                f"found {mask_type}"
            )
        if len(mask_type) != 2:
            raise ValueError(
                "expected list of length 2 for specification of BlockMaskCreator, "
                f"got list with length {len(mask_type)}, mask_type={mask_type}"
            )
        # hand over a copy so the caller's list is never aliased by the creator
        return BlockMaskCreator(list(mask_type))

    # list every option accepted above so the message matches the behavior
    raise ValueError(
        f"Unknown mask_type {mask_type}. Supported mask types include "
        "'unstructured', 'block4', or a list of two integers for a custom "
        "block shape"
    )

src/sparseml/pytorch/sparsification/pruning/modifier_pruning_acdc.py

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,8 @@
2020

2121
from sparseml.pytorch.optim.modifier import ModifierProp, PyTorchModifierYAML
2222
from sparseml.pytorch.sparsification.pruning.mask_creator import (
23-
FourBlockMaskCreator,
2423
PruningMaskCreator,
25-
UnstructuredPruningMaskCreator,
24+
get_mask_creator_default,
2625
)
2726
from sparseml.pytorch.sparsification.pruning.modifier_pruning_base import (
2827
BasePruningModifier,
@@ -70,9 +69,9 @@ class ACDCPruningModifier(BasePruningModifier):
7069
immediately after or doing some other prune. Default is True
7170
:param log_types: The loggers to allow the learning rate to be logged to,
7271
default is __ALL__
73-
:param mask_type: String to define type of sparsity (options: ['unstructured',
74-
'channel', 'filter']), List to define block shape of a parameters in and out
75-
channels, or a SparsityMaskCreator object. default is 'unstructured'
72+
:param mask_type: String to define type of sparsity to apply. May be 'unstructured'
73+
for unstructured pruning or 'block4' for four block pruning or a list of two
74+
integers for a custom block shape. Default is 'unstructured'
7675
:param momentum_buffer_reset: set True to reset momentum buffer
7776
before algorithm enters a consecutive decompression phase.
7877
According to the paper:
@@ -225,15 +224,7 @@ def _get_mask_creator(
225224
:param params: list of Parameters to be masked
226225
:return: mask creator object to be used by this pruning algorithm
227226
"""
228-
if self._mask_type == "unstructured":
229-
return UnstructuredPruningMaskCreator()
230-
elif self._mask_type == "block":
231-
return FourBlockMaskCreator()
232-
else:
233-
raise ValueError(
234-
f"Unknown mask_type {self._mask_type}. Supported mask types include "
235-
"'unstructured' and 'block'"
236-
)
227+
return get_mask_creator_default(self.mask_type)
237228

238229
@staticmethod
239230
def _reset_momentum_buffer(optimizer):

src/sparseml/pytorch/sparsification/pruning/modifier_pruning_magnitude.py

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,8 @@
2424

2525
from sparseml.pytorch.optim.modifier import PyTorchModifierYAML
2626
from sparseml.pytorch.sparsification.pruning.mask_creator import (
27-
FourBlockMaskCreator,
2827
PruningMaskCreator,
29-
UnstructuredPruningMaskCreator,
28+
get_mask_creator_default,
3029
)
3130
from sparseml.pytorch.sparsification.pruning.modifier_pruning_base import (
3231
BaseGradualPruningModifier,
@@ -104,8 +103,9 @@ class GMPruningModifier(BaseGradualPruningModifier, BaseGMPruningModifier):
104103
[linear, cubic, inverse_cubic]
105104
:param log_types: The loggers to allow the learning rate to be logged to,
106105
default is __ALL__
107-
:param mask_type: String to define type of sparsity to apply. May be 'unstructured'
108-
for unstructured pruning or 'block' for four block pruning
106+
:param mask_type: String to define type of sparsity to apply. May be 'unstructured'
107+
for unstructured pruning or 'block4' for four block pruning or a list of two
108+
integers for a custom block shape. Default is 'unstructured'
109109
"""
110110

111111
def __init__(
@@ -152,15 +152,7 @@ def _get_mask_creator(
152152
:param params: list of parameters to be masked
153153
:return: mask creator object to be used by this pruning algorithm
154154
"""
155-
if self.mask_type == "unstructured":
156-
return UnstructuredPruningMaskCreator()
157-
elif self.mask_type == "block":
158-
return FourBlockMaskCreator()
159-
else:
160-
raise ValueError(
161-
f"Unknown mask_type {self.mask_type}. Supported mask types include "
162-
"'unstructured' and 'block'"
163-
)
155+
return get_mask_creator_default(self.mask_type)
164156

165157
def _get_scorer(self, params: List[Parameter]) -> PruningParamsScorer:
166158
"""
@@ -221,7 +213,8 @@ class MagnitudePruningModifier(GMPruningModifier):
221213
:param log_types: The loggers to allow the learning rate to be logged to,
222214
default is __ALL__
223215
:param mask_type: String to define type of sparsity to apply. May be 'unstructured'
224-
for unstructured pruning or 'block' for four block pruning
216+
for unstructured pruning or 'block4' for four block pruning or a list of two
217+
integers for a custom block shape. Default is 'unstructured'
225218
"""
226219

227220
# just an alias for GMPruningModifier
@@ -274,7 +267,8 @@ class GlobalMagnitudePruningModifier(GMPruningModifier):
274267
:param log_types: The loggers to allow the learning rate to be logged to,
275268
default is __ALL__
276269
:param mask_type: String to define type of sparsity to apply. May be 'unstructured'
277-
for unstructured pruning or 'block' for four block pruning
270+
for unstructured pruning or 'block4' for four block pruning or a list of two
271+
integers for a custom block shape. Default is 'unstructured'
278272
"""
279273

280274
def __init__(

src/sparseml/pytorch/sparsification/pruning/modifier_pruning_mfac.py

Lines changed: 8 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,8 @@
3232
import GPUtil
3333
from sparseml.pytorch.optim.modifier import ModifierProp, PyTorchModifierYAML
3434
from sparseml.pytorch.sparsification.pruning.mask_creator import (
35-
FourBlockMaskCreator,
3635
PruningMaskCreator,
37-
UnstructuredPruningMaskCreator,
36+
get_mask_creator_default,
3837
)
3938
from sparseml.pytorch.sparsification.pruning.modifier_pruning_base import (
4039
BaseGradualPruningModifier,
@@ -109,9 +108,9 @@ class MFACPruningModifier(BaseGradualPruningModifier):
109108
[linear, cubic, inverse_cubic]
110109
:param log_types: The loggers to allow the learning rate to be logged to,
111110
default is __ALL__
112-
:param mask_type: String to define type of sparsity (options: ['unstructured',
113-
'channel', 'filter']), List to define block shape of a parameters in and out
114-
channels, or a SparsityMaskCreator object. default is 'unstructured'
111+
:param mask_type: String to define type of sparsity to apply. May be 'unstructured'
112+
for unstructured pruning or 'block4' for four block pruning or a list of two
113+
integers for a custom block shape. Default is 'unstructured'
115114
:param global_sparsity: set True to enable global pruning. if False, pruning will
116115
be layer-wise. Default is False
117116
:param use_gradient_buffering: Optional bool to use gradient buffering instead of
@@ -132,9 +131,9 @@ class MFACPruningModifier(BaseGradualPruningModifier):
132131
Default is 1
133132
:param available_devices: list of device names to perform computation on. Default
134133
is empty
135-
:param mask_type: String to define type of sparsity (options: ['unstructured',
136-
'block']), List to define block shape of a parameters in and out
137-
channels, or a SparsityMaskCreator object. default is 'unstructured'
134+
:param mask_type: String to define type of sparsity to apply. May be 'unstructured'
135+
for unstructured pruning or 'block4' for four block pruning or a list of two
136+
integers for a custom block shape. Default is 'unstructured'
138137
"""
139138

140139
def __init__(
@@ -296,15 +295,7 @@ def _get_mask_creator(
296295
:param params: list of Parameters to be masked
297296
:return: mask creator object to be used by this pruning algorithm
298297
"""
299-
if self._mask_type == "unstructured":
300-
return UnstructuredPruningMaskCreator()
301-
elif self._mask_type == "block":
302-
return FourBlockMaskCreator()
303-
else:
304-
raise ValueError(
305-
f"Unknown mask_type {self._mask_type}. Supported mask types include "
306-
"'unstructured' and 'block'"
307-
)
298+
return get_mask_creator_default(self.mask_type)
308299

309300
def _get_scorer(self, params: List[Parameter]) -> PruningParamsGradScorer:
310301
"""

tests/sparseml/pytorch/sparsification/pruning/test_mask_creator.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import torch
1717

1818
from sparseml.pytorch.sparsification.pruning import (
19+
BlockMaskCreator,
1920
FourBlockMaskCreator,
2021
GroupedPruningMaskCreator,
2122
UnstructuredPruningMaskCreator,
@@ -35,6 +36,7 @@
3536
([[64, 513]], FourBlockMaskCreator()),
3637
([[64, 512, 3, 3]], FourBlockMaskCreator()),
3738
([[63, 513, 3, 3]], FourBlockMaskCreator()),
39+
([[64, 512, 3, 3]], BlockMaskCreator([1, 4])),
3840
],
3941
)
4042
@pytest.mark.parametrize("sparsity_val", [0.0, 0.4, 0.6, 0.9, 0.99, 1.0])
@@ -47,6 +49,7 @@ def test_sparsity_mask_creator(tensor_shape, mask_creator, sparsity_val):
4749
[
4850
([[64, 64, 3, 3]], UnstructuredPruningMaskCreator()),
4951
([[64, 512, 3, 3]], FourBlockMaskCreator()),
52+
([[64, 512, 3, 3]], BlockMaskCreator([1, 4])),
5053
],
5154
)
5255
@pytest.mark.parametrize("sparsity_val", [0.0, 0.4, 0.6, 0.9, 0.99, 1.0])
@@ -74,6 +77,14 @@ def test_sparsity_mask_creator_cuda(tensor_shape, mask_creator, sparsity_val):
7477
[i * torch.randn(64, 64, 3, 3) for i in range(1, 6)],
7578
FourBlockMaskCreator(),
7679
),
80+
(
81+
[torch.randn(128, 128, 3, 3), 3 * torch.randn(64, 512)],
82+
BlockMaskCreator([1, 4]),
83+
),
84+
(
85+
[i * torch.randn(64, 64, 3, 3) for i in range(1, 6)],
86+
BlockMaskCreator([1, 4]),
87+
),
7788
],
7889
)
7990
@pytest.mark.parametrize("sparsity_val", [0.0, 0.4, 0.6, 0.9, 0.99, 1.0])
@@ -116,7 +127,37 @@ def test_global_sparsity_mask_creator(tensors, mask_creator, sparsity_val):
116127
FourBlockMaskCreator(),
117128
[0.4, 0.6, 0.8, 0.9, 0.95, 0.99],
118129
),
130+
(
131+
[[128, 128, 3, 3], [64, 512]],
132+
BlockMaskCreator([1, 4]),
133+
[0.8, 0.9],
134+
),
135+
(
136+
[[64, 64, 3, 3]] * 6,
137+
BlockMaskCreator([1, 4]),
138+
[0.4, 0.6, 0.8, 0.9, 0.95, 0.99],
139+
),
119140
],
120141
)
121142
def test_sparsity_mask_creator_mult_tensor(tensor_shapes, mask_creator, sparsity_val):
122143
sparsity_mask_creator_test(tensor_shapes, mask_creator, sparsity_val, "cpu")
144+
145+
146+
@pytest.mark.parametrize(
    ("tensors"),
    [
        [torch.randn(128, 128), torch.randn(128, 512, 3, 3)],
        [torch.randn(5, 64, 3, 3)],
    ],
)
@pytest.mark.parametrize("sparsity_val", [0.0, 0.4, 0.6, 0.9, 0.99, 1.0])
def test_four_block_mask_creator_matches_block(tensors, sparsity_val):
    """
    Check that a BlockMaskCreator configured with block shape [1, 4] produces
    the same masks as FourBlockMaskCreator for the same tensors and sparsity.
    """
    four_block_creator, generic_creator = (
        FourBlockMaskCreator(),
        BlockMaskCreator([1, 4]),
    )

    reference_masks = four_block_creator.create_sparsity_masks(tensors, sparsity_val)
    candidate_masks = generic_creator.create_sparsity_masks(tensors, sparsity_val)

    for reference, candidate in zip(reference_masks, candidate_masks):
        # shapes are compared first so the elementwise check below is well defined
        assert reference.shape == candidate.shape
        assert torch.all(reference == candidate)

0 commit comments

Comments
 (0)