@@ -130,9 +130,24 @@ def create_sparsity_masks(
130130 masks .append (tensor .new_ones (tensor .shape ))
131131 continue
132132
133+ num_elem = tensor .numel ()
134+ target_num_mask = round (num_elem * sparsity_target )
133135 min_val = tensor .min ().item ()
136+
134137 if threshold .item () > min_val :
135- masks .append ((tensor > threshold ).type (tensor .type ()))
138+ threshold_mask = (tensor > threshold ).type (tensor .type ())
139+
140+ num_masked = num_elem - torch .sum (threshold_mask ).item ()
141+ if num_masked != target_num_mask :
142+ # attempt to reconcile expected number of masked weights
143+ # may occur if multiple values have the threshold weight
144+ num_to_flip = abs (num_masked - target_num_mask )
145+ over_masked = num_masked > target_num_mask
146+ threshold_mask = self ._flip_threshold_mask_vals (
147+ threshold_mask , tensor , threshold , num_to_flip , over_masked
148+ )
149+
150+ masks .append (threshold_mask )
136151 continue
137152
138153 # too many zeros so will go over the already given sparsity
@@ -141,9 +156,7 @@ def create_sparsity_masks(
141156 rand_indices = list (range (zero_indices .shape [0 ]))
142157 local_rng = random .Random (42 )
143158 local_rng .shuffle (rand_indices )
144- num_elem = tensor .numel ()
145- num_mask = round (num_elem * sparsity_target )
146- rand_indices = rand_indices [:num_mask ]
159+ rand_indices = rand_indices [:target_num_mask ]
147160 rand_indices = tensor .new_tensor (rand_indices , dtype = torch .int64 )
148161 zero_indices = zero_indices [rand_indices , :]
149162 mask = tensor .new_ones (tensor .shape ).type (tensor .type ())
@@ -173,7 +186,7 @@ def _threshold_from_sparsity(self, tensor: Tensor, sparsity: float) -> Tensor:
173186 return tensor .new_tensor ([])
174187
175188 sorted_vals , _ = torch .sort (tensor .view (- 1 ))
176- lookup_index = round (sparsity * ( tensor .numel () - 1 ))
189+ lookup_index = round (sparsity * tensor .numel ()) - 1
177190
178191 if lookup_index < 0 :
179192 lookup_index = 0
@@ -218,6 +231,35 @@ def _unstack_flattened_tensors(
218231
219232 return unstacked_tensors
220233
234+ def _flip_threshold_mask_vals (
235+ self ,
236+ mask : Tensor ,
237+ tensor : Tensor ,
238+ threshold : Tensor ,
239+ max_flip : int ,
240+ over_masked : bool ,
241+ ) -> Tensor :
242+ # flip mask values where tensor == threshold until mask has desired
243+ # number of 0s/1s
244+ threshold_idxs = torch .nonzero (tensor == threshold , as_tuple = False )
245+ num_flipped = 0
246+ for threshold_elem_idx in threshold_idxs :
247+ # make tensor returned by nonzero() indexable
248+ threshold_elem_idx = threshold_elem_idx .split (1 )
249+ threshold_mask_elem = mask [threshold_elem_idx ]
250+
251+ # flip mask val at threshold index if necessary
252+ if over_masked and threshold_mask_elem == 0 :
253+ mask [threshold_elem_idx ] = 1
254+ num_flipped += 1
255+ elif not over_masked and threshold_mask_elem == 1 :
256+ mask [threshold_elem_idx ] = 0
257+ num_flipped += 1
258+
259+ if num_flipped >= max_flip :
260+ break
261+ return mask
262+
221263
222264class GroupedPruningMaskCreator (UnstructuredPruningMaskCreator ):
223265 """
0 commit comments