Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit fe5523f

Browse files
authored
PyTorch Global Magnitude Pruning (#142)
* refactor mask creators and mask objects to accept lists of parameters instead of single * global magnitude option for mask creators * add global pruning option to GMPruningModifier * create PruningModifier child classes
1 parent 36799e9 commit fe5523f

File tree

7 files changed

+1036
-432
lines changed

7 files changed

+1036
-432
lines changed

src/sparseml/pytorch/optim/mask_creator_pruning.py

Lines changed: 208 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -40,33 +40,41 @@ class PruningMaskCreator(ABC):
4040
Subclasses should define all methods for creating masks
4141
"""
4242

43-
def create_sparsity_mask_from_tensor(self, tensor: Tensor) -> Tensor:
43+
def create_sparsity_masks_from_tensor(self, tensors: List[Tensor]) -> List[Tensor]:
4444
"""
45-
:param tensor: the tensor to calculate a mask based on its values
46-
:return: a mask derived from the values of tensor
45+
:param tensors: list of tensors to calculate masks based on their values
46+
:return: list of masks derived from each of the given tensors
4747
"""
48-
return torch.ne(tensor, 0.0).type(tensor.type())
48+
return [torch.ne(tensor, 0.0).type(tensor.type()) for tensor in tensors]
4949

5050
@abstractmethod
51-
def create_sparsity_mask_from_abs_threshold(
52-
self, tensor: Tensor, threshold: Union[float, Tensor]
53-
) -> Tensor:
51+
def create_sparsity_masks_from_abs_threshold(
52+
self, tensors: List[Tensor], threshold: Union[float, Tensor]
53+
) -> List[Tensor]:
5454
"""
55-
:param tensor: the tensor to calculate a mask from based on the contained
55+
:param tensors: list of tensors to calculate masks based on their contained
5656
values
5757
:param threshold: a threshold to determine cutoff for sparsification
58-
:return: a mask derived from the tensor and the threshold
58+
:return: list of masks derived from each of the given tensors and the threshold
5959
"""
6060
raise NotImplementedError()
6161

6262
@abstractmethod
63-
def create_sparsity_mask(self, tensor: Tensor, sparsity: float) -> Tensor:
63+
def create_sparsity_masks(
64+
self, tensors: List[Tensor], sparsity: float, global_sparsity: bool = False
65+
) -> List[Tensor]:
6466
"""
65-
:param tensor: the tensor to calculate a mask from based on the contained values
67+
:param tensors: list of tensors to calculate masks based on their contained
68+
values
6669
:param sparsity: the desired sparsity to reach within the mask
6770
(decimal fraction of zeros)
68-
:return: a mask (0.0 for values that are masked, 1.0 for values that are
69-
unmasked) calculated from the tens such that the desired number of zeros
71+
:param global_sparsity: if True, sparsity masks will be created such that the
72+
average sparsity across all given tensors is the target sparsity with the
73+
lowest global values masked. If False, each tensor will be masked to the
74+
target sparsity ranking values within each individual tensor. Default is
75+
False
76+
:return: list of masks (0.0 for values that are masked, 1.0 for values that are
77+
unmasked) calculated from the tensors such that the desired number of zeros
7078
matches the sparsity.
7179
"""
7280
raise NotImplementedError()
@@ -76,54 +84,88 @@ class UnstructuredPruningMaskCreator(PruningMaskCreator):
7684
"""
7785
Class for creating unstructured sparsity masks.
7886
Masks will be created using unstructured sparsity by pruning weights ranked
79-
by their magnitude.
87+
by their magnitude. Each mask will correspond to the given tensor.
8088
"""
8189

82-
def create_sparsity_mask_from_abs_threshold(
83-
self, tensor: Tensor, threshold: Union[float, Tensor]
84-
) -> Tensor:
90+
def create_sparsity_masks_from_abs_threshold(
91+
self, tensors: List[Tensor], threshold: Union[float, Tensor]
92+
) -> List[Tensor]:
8593
"""
86-
:param tensor: the tensor to calculate a mask from based on the contained values
94+
:param tensors: list of tensors to calculate masks based on their contained
95+
values
8796
:param threshold: a threshold at which to mask abs(values) if they are
8897
less than it or equal
89-
:return: a mask (0.0 for values that are masked, 1.0 for values that are
90-
unmasked) calculated from the tens abs(values) <= threshold are masked,
98+
:return: list of masks (0.0 for values that are masked, 1.0 for values that are
99+
unmasked) calculated from the tensors; abs(values) <= threshold are masked,
91100
all others are unmasked
92101
"""
93-
return (torch.abs(tensor) > threshold).type(tensor.type())
102+
return [
103+
(torch.abs(tensor) > threshold).type(tensor.type()) for tensor in tensors
104+
]
94105

95-
def create_sparsity_mask(self, tensor: Tensor, sparsity: float) -> Tensor:
106+
def create_sparsity_masks(
107+
self,
108+
tensors: List[Tensor],
109+
sparsity: float,
110+
global_sparsity: bool = False,
111+
) -> List[Tensor]:
96112
"""
97-
:param tensor: the tensor to calculate a mask from based on the contained values
113+
:param tensors: list of tensors to calculate masks from, based on their
114+
contained values
98115
:param sparsity: the desired sparsity to reach within the mask
99116
(decimal fraction of zeros)
100-
:return: a mask (0.0 for values that are masked, 1.0 for values that are
101-
unmasked) calculated from the tens such that the desired number of zeros
102-
matches the sparsity. removes the abs lowest values if there are more zeros
103-
in the tens than desired sparsity, then will randomly choose the zeros
104-
"""
105-
threshold = self._abs_threshold_from_sparsity(tensor, sparsity)
106-
107-
if threshold.numel() < 1:
108-
return tensor.new_ones(tensor.shape)
109-
110-
if threshold.item() > 0.0:
111-
return (torch.abs(tensor) > threshold).type(tensor.type())
112-
113-
# too many zeros so will go over the already given sparsity
114-
# and choose which zeros to not keep in mask at random
115-
zero_indices = (tensor == 0.0).nonzero()
116-
rand_indices = list(range(zero_indices.shape[0]))
117-
random.shuffle(rand_indices)
118-
num_elem = tensor.numel()
119-
num_mask = int(num_elem * sparsity)
120-
rand_indices = rand_indices[:num_mask]
121-
rand_indices = tensor.new_tensor(rand_indices, dtype=torch.int64)
122-
zero_indices = zero_indices[rand_indices, :]
123-
mask = tensor.new_ones(tensor.shape).type(tensor.type())
124-
mask[zero_indices.split(1, dim=1)] = 0
125-
126-
return mask.type(tensor.type())
117+
:param global_sparsity: if True, sparsity masks will be created such that the
118+
average sparsity across all given tensors is the target sparsity with the
119+
lowest global values masked. If False, each tensor will be masked to the
120+
target sparsity ranking values within each individual tensor. Default is
121+
False
122+
:return: list of masks (0.0 for values that are masked, 1.0 for values that are
123+
unmasked) calculated from the tensors such that the desired number of zeros
124+
matches the sparsity. If there are more zeros than the desired sparsity,
125+
zeros will be randomly chosen to match the target sparsity
126+
"""
127+
if global_sparsity:
128+
# create tensor to make global mask with
129+
original_tensors = tensors
130+
tensors = [self._flatten_and_stack_tensors(tensors)]
131+
else:
132+
original_tensors = None
133+
134+
masks = []
135+
136+
for tensor in tensors:
137+
threshold = self._abs_threshold_from_sparsity(tensor, sparsity)
138+
139+
if threshold.numel() < 1:
140+
masks.append(tensor.new_ones(tensor.shape))
141+
continue
142+
143+
if threshold.item() > 0.0:
144+
masks.append((torch.abs(tensor) > threshold).type(tensor.type()))
145+
continue
146+
147+
# too many zeros so will go over the already given sparsity
148+
# and choose which zeros to not keep in mask at random
149+
zero_indices = (tensor == 0.0).nonzero()
150+
rand_indices = list(range(zero_indices.shape[0]))
151+
random.shuffle(rand_indices)
152+
num_elem = tensor.numel()
153+
num_mask = int(num_elem * sparsity)
154+
rand_indices = rand_indices[:num_mask]
155+
rand_indices = tensor.new_tensor(rand_indices, dtype=torch.int64)
156+
zero_indices = zero_indices[rand_indices, :]
157+
mask = tensor.new_ones(tensor.shape).type(tensor.type())
158+
mask[zero_indices.split(1, dim=1)] = 0
159+
160+
masks.append(mask.type(tensor.type()))
161+
162+
if global_sparsity:
163+
# unpack global mask into tensor-masks with the original shapes
164+
global_mask = masks[0]
165+
masks = self._unstack_flattened_tensors(global_mask, original_tensors)
166+
del global_mask
167+
168+
return masks
127169

128170
def _abs_threshold_from_sparsity(self, tensor: Tensor, sparsity: float) -> Tensor:
129171
"""
@@ -146,6 +188,42 @@ def _abs_threshold_from_sparsity(self, tensor: Tensor, sparsity: float) -> Tenso
146188

147189
return sorted_vals[lookup_index]
148190

191+
def _flatten_and_stack_tensors(self, tensors: List[Tensor]) -> Tensor:
    """
    Concatenate the values of all given tensors into a single 1-D tensor.

    :param tensors: non-empty list of tensors to flatten and join, in order
    :return: a new 1-D tensor holding every element of each given tensor,
        flattened and laid out in list order. The result takes the dtype and
        device of the first tensor, does not require grad, and is a copy of
        the input values
    :raises IndexError: if tensors is empty
    """
    # the first tensor fixes the dtype/device of the stacked result so that
    # every value ends up in one homogeneous tensor
    reference = tensors[0]

    # detach so the stacked copy never joins the autograd graph; reshape(-1)
    # flattens each tensor and .to(reference) is a no-op when dtype and device
    # already match
    return torch.cat(
        [tensor.detach().reshape(-1).to(reference) for tensor in tensors]
    )
206+
207+
def _unstack_flattened_tensors(
    self, stacked_tensor: Tensor, original_tensors: List[Tensor]
) -> List[Tensor]:
    """
    Split a flat, stacked tensor back into tensors shaped like the originals.

    :param stacked_tensor: flat (expected 1-D) tensor whose values are laid
        out in the order of original_tensors, each original flattened
    :param original_tensors: tensors whose shape, dtype, and device each
        unstacked tensor should match
    :return: list of new tensors, one per original tensor, each filled with
        the corresponding slice of stacked_tensor. The results do not require
        grad and are copies of the stacked values
    :raises RuntimeError: if the slice of stacked_tensor available for an
        original tensor cannot be reshaped to that tensor's shape
        (i.e. stacked_tensor has too few elements)
    """
    # the unpacked copies are plain values; keep them out of the autograd graph
    source = stacked_tensor.detach()

    unstacked_tensors = []
    offset = 0
    for tensor in original_tensors:
        count = tensor.numel()
        # new_empty already matches the original's dtype and device, and copy_
        # casts on the way in, so no separate type conversion is needed
        unstacked_tensor = tensor.new_empty(tensor.shape)
        # reshaping the slice first makes a too-short stacked_tensor fail
        # loudly instead of being broadcast into the destination by copy_
        unstacked_tensor.copy_(
            source[offset : offset + count].reshape(tensor.shape)
        )

        unstacked_tensors.append(unstacked_tensor)
        offset += count

    return unstacked_tensors
226+
149227
def __str__(self):
150228
return "unstructured"
151229

@@ -213,44 +291,71 @@ def _map_mask_to_tensor(
213291
"""
214292
raise NotImplementedError()
215293

216-
def create_sparsity_mask_from_tensor(self, tensor: Tensor) -> Tensor:
294+
def create_sparsity_masks_from_tensor(self, tensors: List[Tensor]) -> List[Tensor]:
217295
"""
218-
:param tensor: the tensor to calculate a mask based on its values
219-
:return: a mask derived from the values of tensor grouped by the group_tensor
296+
:param tensors: list of tensors to calculate masks based on their values
297+
:return: list of masks derived from the values of the tensors grouped by
298+
the group_tensor function.
220299
"""
221-
grouped_tensor = self.group_tensor(tensor)
222-
grouped_mask = super().create_sparsity_mask_from_tensor(grouped_tensor)
223-
return self._map_mask_to_tensor(grouped_mask, tensor.shape)
300+
masks = []
301+
for tensor in tensors:
302+
grouped_tensor = self.group_tensor(tensor)
303+
grouped_mask = super().create_sparsity_masks_from_tensor([grouped_tensor])[
304+
0
305+
]
306+
masks.append(self._map_mask_to_tensor(grouped_mask, tensor.shape))
307+
return masks
224308

225-
def create_sparsity_mask_from_abs_threshold(
226-
self, tensor: Tensor, threshold: Union[float, Tensor]
227-
) -> Tensor:
309+
def create_sparsity_masks_from_abs_threshold(
310+
self, tensors: List[Tensor], threshold: Union[float, Tensor]
311+
) -> List[Tensor]:
228312
"""
229-
:param tensor: the tensor to calculate a mask from based on the contained
313+
:param tensors: list of tensors to calculate masks from based on their contained
230314
values
231315
:param threshold: a threshold of group_tensor values to determine cutoff
232316
for sparsification
233-
:return: a mask derived from the tensor and the grouped threshold
234-
"""
235-
grouped_tensor = self.group_tensor(tensor)
236-
grouped_mask = super().create_sparsity_mask_from_abs_threshold(
237-
grouped_tensor, threshold
238-
)
239-
return self._map_mask_to_tensor(grouped_mask, tensor.shape)
240-
241-
def create_sparsity_mask(self, tensor: Tensor, sparsity: float) -> Tensor:
317+
:return: list of masks derived from the tensors and the grouped threshold
318+
"""
319+
masks = []
320+
for tensor in tensors:
321+
grouped_tensor = self.group_tensor(tensor)
322+
grouped_mask = super().create_sparsity_masks_from_abs_threshold(
323+
[grouped_tensor], threshold
324+
)[0]
325+
masks.append(self._map_mask_to_tensor(grouped_mask, tensor.shape))
326+
return masks
327+
328+
def create_sparsity_masks(
329+
self,
330+
tensors: List[Tensor],
331+
sparsity: float,
332+
global_sparsity: bool = False,
333+
) -> List[Tensor]:
242334
"""
243-
:param tensor: the tensor to calculate a mask from based on the contained values
335+
:param tensors: list of tensors to calculate masks from based on their contained
336+
values
244337
:param sparsity: the desired sparsity to reach within the mask
245338
(decimal fraction of zeros)
246-
:return: a mask (0.0 for values that are masked, 1.0 for values that are
247-
unmasked) calculated from the tens such that the desired number of zeros
248-
matches the sparsity and all values mapped to the same group have the
249-
same value.
250-
"""
251-
grouped_tensor = self.group_tensor(tensor)
252-
grouped_mask = super().create_sparsity_mask(grouped_tensor, sparsity)
253-
return self._map_mask_to_tensor(grouped_mask, tensor.shape)
339+
:param global_sparsity: if True, sparsity masks will be created such that the
340+
average sparsity across all given tensors is the target sparsity with the
341+
lowest global values masked. If False, each tensor will be masked to the
342+
target sparsity ranking values within each individual tensor. Default is
343+
False
344+
:return: list of masks (0.0 for values that are masked, 1.0 for values that are
345+
unmasked) calculated from the tensors such that the desired number of zeros
346+
matches the sparsity and all values mapped to the same group have the same
347+
value
348+
"""
349+
grouped_tensors = [self.group_tensor(tensor) for tensor in tensors]
350+
grouped_masks = super().create_sparsity_masks(
351+
grouped_tensors, sparsity, global_sparsity
352+
)
353+
masks = [
354+
self._map_mask_to_tensor(grouped_mask, tensor.shape)
355+
for grouped_mask, tensor in zip(grouped_masks, tensors)
356+
]
357+
358+
return masks
254359

255360

256361
class DimensionSparsityMaskCreator(GroupedPruningMaskCreator):
@@ -286,6 +391,31 @@ def __init__(
286391
)
287392
self._dim = dim # List[int]
288393

394+
def create_sparsity_masks(
    self,
    tensors: List[Tensor],
    sparsity: float,
    global_sparsity: bool = False,
) -> List[Tensor]:
    """
    Create per-tensor sparsity masks; global sparsity is rejected.

    :param tensors: list of tensors to calculate masks from based on their
        contained values
    :param sparsity: the desired sparsity to reach within the mask
        (decimal fraction of zeros)
    :param global_sparsity: must be left False; a ValueError is raised if True
        because global sparsity is unsupported for DimensionSparsityMaskCreator
    :return: list of masks as produced by the parent class's
        create_sparsity_masks with global_sparsity=False
    :raises ValueError: if global_sparsity is True
    """
    if not global_sparsity:
        # per-tensor masking is delegated unchanged to the parent class
        return super().create_sparsity_masks(tensors, sparsity, global_sparsity=False)

    # global sparsity unsupported because channel dims may vary across layers
    raise ValueError("global_sparsity not supported for DimensionSparsityMaskCreator")
418+
289419
def group_tensor(self, tensor: Tensor) -> Tensor:
290420
"""
291421
:param tensor: The tensor to transform

0 commit comments

Comments
 (0)