Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit a612e7b

Browse files
authored
fix mask from threshold logic to mask exactly target number (#678)
* fix mask from threshold logic to mask exactly target number * off by 1 fix from review
1 parent 589f5f9 commit a612e7b

File tree

2 files changed

+68
-9
lines changed

2 files changed

+68
-9
lines changed

src/sparseml/pytorch/sparsification/pruning/mask_creator.py

Lines changed: 47 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -130,9 +130,24 @@ def create_sparsity_masks(
130130
masks.append(tensor.new_ones(tensor.shape))
131131
continue
132132

133+
num_elem = tensor.numel()
134+
target_num_mask = round(num_elem * sparsity_target)
133135
min_val = tensor.min().item()
136+
134137
if threshold.item() > min_val:
135-
masks.append((tensor > threshold).type(tensor.type()))
138+
threshold_mask = (tensor > threshold).type(tensor.type())
139+
140+
num_masked = num_elem - torch.sum(threshold_mask).item()
141+
if num_masked != target_num_mask:
142+
# attempt to reconcile expected number of masked weights
143+
# may occur if multiple values have the threshold weight
144+
num_to_flip = abs(num_masked - target_num_mask)
145+
over_masked = num_masked > target_num_mask
146+
threshold_mask = self._flip_threshold_mask_vals(
147+
threshold_mask, tensor, threshold, num_to_flip, over_masked
148+
)
149+
150+
masks.append(threshold_mask)
136151
continue
137152

138153
# too many zeros so will go over the already given sparsity
@@ -141,9 +156,7 @@ def create_sparsity_masks(
141156
rand_indices = list(range(zero_indices.shape[0]))
142157
local_rng = random.Random(42)
143158
local_rng.shuffle(rand_indices)
144-
num_elem = tensor.numel()
145-
num_mask = round(num_elem * sparsity_target)
146-
rand_indices = rand_indices[:num_mask]
159+
rand_indices = rand_indices[:target_num_mask]
147160
rand_indices = tensor.new_tensor(rand_indices, dtype=torch.int64)
148161
zero_indices = zero_indices[rand_indices, :]
149162
mask = tensor.new_ones(tensor.shape).type(tensor.type())
@@ -173,7 +186,7 @@ def _threshold_from_sparsity(self, tensor: Tensor, sparsity: float) -> Tensor:
173186
return tensor.new_tensor([])
174187

175188
sorted_vals, _ = torch.sort(tensor.view(-1))
176-
lookup_index = round(sparsity * (tensor.numel() - 1))
189+
lookup_index = round(sparsity * tensor.numel()) - 1
177190

178191
if lookup_index < 0:
179192
lookup_index = 0
@@ -218,6 +231,35 @@ def _unstack_flattened_tensors(
218231

219232
return unstacked_tensors
220233

234+
def _flip_threshold_mask_vals(
235+
self,
236+
mask: Tensor,
237+
tensor: Tensor,
238+
threshold: Tensor,
239+
max_flip: int,
240+
over_masked: bool,
241+
) -> Tensor:
242+
# flip mask values where tensor == threshold until mask has desired
243+
# number of 0s/1s
244+
threshold_idxs = torch.nonzero(tensor == threshold, as_tuple=False)
245+
num_flipped = 0
246+
for threshold_elem_idx in threshold_idxs:
247+
# make tensor returned by nonzero() indexable
248+
threshold_elem_idx = threshold_elem_idx.split(1)
249+
threshold_mask_elem = mask[threshold_elem_idx]
250+
251+
# flip mask val at threshold index if necessary
252+
if over_masked and threshold_mask_elem == 0:
253+
mask[threshold_elem_idx] = 1
254+
num_flipped += 1
255+
elif not over_masked and threshold_mask_elem == 1:
256+
mask[threshold_elem_idx] = 0
257+
num_flipped += 1
258+
259+
if num_flipped >= max_flip:
260+
break
261+
return mask
262+
221263

222264
class GroupedPruningMaskCreator(UnstructuredPruningMaskCreator):
223265
"""

tests/sparseml/pytorch/sparsification/pruning/helpers.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -126,18 +126,35 @@ def sparsity_mask_creator_test(tensor_shapes, mask_creator, sparsity_val, device
126126
for update_mask, target_sparsity in zip(update_masks, sparsity_val):
127127
assert abs(tensor_sparsity(update_mask) - target_sparsity) < 1e-2
128128

129+
if not isinstance(mask_creator, GroupedPruningMaskCreator):
130+
_test_num_masked(update_mask, target_sparsity)
131+
129132
if isinstance(mask_creator, GroupedPruningMaskCreator):
130-
grouped_masks_test(update_masks, mask_creator)
133+
grouped_masks_test(update_masks, mask_creator, sparsity_val)
131134

132135
return update_masks
133136

134137

135-
def grouped_masks_test(masks, mask_creator):
138+
def grouped_masks_test(masks, mask_creator, sparsity_val=None):
136139
# Check that every value in the mask_creator grouping
137140
# is the same within the mask. Assumes grouping applies
138-
# an absolte mean to each grouping
139-
for mask in masks:
141+
# an absolute mean to each grouping
142+
# also checks that the grouped mask matches the target sparsity exactly
143+
144+
if sparsity_val is None:
145+
sparsity_val = [sparsity_val] * len(masks)
146+
147+
for mask, target_sparsity in zip(masks, sparsity_val):
140148
grouped_mask = mask_creator.group_tensor(mask)
141149
grouped_mask /= max(torch.max(grouped_mask).item(), 1.0)
142150
mask_vals_are_grouped = torch.all((grouped_mask == 0.0) | (grouped_mask == 1.0))
143151
assert mask_vals_are_grouped
152+
153+
if target_sparsity is not None:
154+
_test_num_masked(grouped_mask, target_sparsity)
155+
156+
157+
def _test_num_masked(mask, target_sparsity):
158+
# tests that the number of masked values is exactly the number expected
159+
expected_num_masked = round(target_sparsity * mask.numel())
160+
assert torch.sum(1 - mask).item() == expected_num_masked

0 commit comments

Comments
 (0)