@@ -40,33 +40,41 @@ class PruningMaskCreator(ABC):
4040 Subclasses should define all methods for creating masks
4141 """
4242
def create_sparsity_masks_from_tensor(self, tensors: List[Tensor]) -> List[Tensor]:
    """
    Create one mask per tensor from the values the tensor currently holds.

    :param tensors: list of tensors to calculate masks from based on their values
    :return: list of masks, one per given tensor and in the same order; each
        mask is 1.0 where the tensor is non-zero and 0.0 where it is exactly
        zero, and is returned in the tensor's own type
    """
    return [torch.ne(tensor, 0.0).type(tensor.type()) for tensor in tensors]
4949
@abstractmethod
def create_sparsity_masks_from_abs_threshold(
    self, tensors: List[Tensor], threshold: Union[float, Tensor]
) -> List[Tensor]:
    """
    Create one mask per tensor using a fixed cutoff instead of a target sparsity.

    :param tensors: list of tensors to calculate masks from based on their
        contained values
    :param threshold: a threshold to determine cutoff for sparsification
    :return: list of masks derived from each of the given tensors and the threshold
    :raises NotImplementedError: always; subclasses must provide the implementation
    """
    raise NotImplementedError()
6161
@abstractmethod
def create_sparsity_masks(
    self, tensors: List[Tensor], sparsity: float, global_sparsity: bool = False
) -> List[Tensor]:
    """
    Create one mask per tensor such that the masks reach the target sparsity.

    :param tensors: list of tensors to calculate masks from based on their
        contained values
    :param sparsity: the desired sparsity to reach within the mask
        (decimal fraction of zeros)
    :param global_sparsity: if True, sparsity masks will be created such that the
        average sparsity across all given tensors is the target sparsity with the
        lowest global values masked. If False, each tensor will be masked to the
        target sparsity ranking values within each individual tensor. Default is
        False
    :return: list of masks (0.0 for values that are masked, 1.0 for values that are
        unmasked) calculated from the tensors such that the desired number of zeros
        matches the sparsity.
    :raises NotImplementedError: always; subclasses must provide the implementation
    """
    raise NotImplementedError()
@@ -76,54 +84,88 @@ class UnstructuredPruningMaskCreator(PruningMaskCreator):
7684 """
7785 Class for creating unstructured sparsity masks.
7886 Masks will be created using unstructured sparsity by pruning weights ranked
79- by their magnitude.
87+ by their magnitude. Each mask will correspond to the given tensor.
8088 """
8189
def create_sparsity_masks_from_abs_threshold(
    self, tensors: List[Tensor], threshold: Union[float, Tensor]
) -> List[Tensor]:
    """
    Mask every value whose magnitude does not exceed the given threshold.

    :param tensors: list of tensors to calculate masks from based on their
        contained values
    :param threshold: a threshold at which to mask abs(values) if they are
        less than it or equal
    :return: list of masks (0.0 for values that are masked, 1.0 for values that are
        unmasked), one per given tensor and in the tensor's own type; values with
        abs(value) <= threshold are masked, all others are unmasked
    """
    return [
        (torch.abs(tensor) > threshold).type(tensor.type()) for tensor in tensors
    ]
94105
def create_sparsity_masks(
    self,
    tensors: List[Tensor],
    sparsity: float,
    global_sparsity: bool = False,
) -> List[Tensor]:
    """
    Create magnitude-based masks that bring the given tensors to a target sparsity.

    :param tensors: list of tensors to calculate masks from based on their
        contained values
    :param sparsity: the desired sparsity to reach within the mask
        (decimal fraction of zeros)
    :param global_sparsity: if True, sparsity masks will be created such that the
        average sparsity across all given tensors is the target sparsity with the
        lowest global values masked. If False, each tensor will be masked to the
        target sparsity ranking values within each individual tensor. Default is
        False
    :return: list of masks (0.0 for values that are masked, 1.0 for values that are
        unmasked) calculated from the tensors such that the desired number of zeros
        matches the sparsity. If there are more zeros than the desired sparsity,
        zeros will be randomly chosen to match the target sparsity. An empty
        input list gives an empty list
    """
    if not tensors:
        # nothing to mask; without this guard the global path would fail while
        # indexing the first tensor of an empty list
        return []

    original_tensors = tensors
    if global_sparsity:
        # rank all values against each other by masking one flat, stacked tensor
        tensors = [self._flatten_and_stack_tensors(original_tensors)]

    masks = []
    for tensor in tensors:
        threshold = self._abs_threshold_from_sparsity(tensor, sparsity)

        if threshold.numel() < 1:
            # the helper produced no threshold value: leave everything unmasked
            masks.append(tensor.new_ones(tensor.shape))
        elif threshold.item() > 0.0:
            masks.append((torch.abs(tensor) > threshold).type(tensor.type()))
        else:
            # too many zeros so will go over the already given sparsity
            # and choose which zeros to not keep in mask at random
            zero_indices = (tensor == 0.0).nonzero()
            rand_indices = list(range(zero_indices.shape[0]))
            random.shuffle(rand_indices)
            num_mask = int(tensor.numel() * sparsity)
            rand_indices = tensor.new_tensor(
                rand_indices[:num_mask], dtype=torch.int64
            )
            zero_indices = zero_indices[rand_indices, :]
            # new_ones already matches the tensor's dtype and device
            mask = tensor.new_ones(tensor.shape)
            # one index tensor per dimension, i.e. advanced indexing by coordinate
            mask[zero_indices.split(1, dim=1)] = 0
            masks.append(mask)

    if global_sparsity:
        # unpack global mask into tensor-masks with the original shapes
        masks = self._unstack_flattened_tensors(masks[0], original_tensors)

    return masks
127169
128170 def _abs_threshold_from_sparsity (self , tensor : Tensor , sparsity : float ) -> Tensor :
129171 """
@@ -146,6 +188,42 @@ def _abs_threshold_from_sparsity(self, tensor: Tensor, sparsity: float) -> Tenso
146188
147189 return sorted_vals [lookup_index ]
148190
def _flatten_and_stack_tensors(self, tensors: List[Tensor]) -> Tensor:
    """
    Join the values of all given tensors into a single 1-D tensor.

    :param tensors: non-empty list of tensors to flatten and concatenate, in order
    :return: a new 1-D tensor holding every element of each given tensor, laid
        out tensor after tensor, with the dtype and device of the first tensor
        and detached from autograd
    """
    reference = tensors[0]

    flattened = []
    for tensor in tensors:
        # detach first: copying values from a tensor that requires grad would
        # otherwise make the stacked result part of the autograd graph
        values = tensor.detach().reshape(-1)
        flattened.append(values.to(dtype=reference.dtype, device=reference.device))

    return torch.cat(flattened)
206+
def _unstack_flattened_tensors(
    self, stacked_tensor: Tensor, original_tensors: List[Tensor]
) -> List[Tensor]:
    """
    Split a flat, stacked tensor back into pieces shaped like the originals.

    :param stacked_tensor: 1-D tensor whose elements are laid out tensor after
        tensor in the order of original_tensors
    :param original_tensors: the tensors whose element counts, shapes, dtypes and
        devices the unstacked pieces should match
    :return: list of new tensors, one per original tensor, each holding the
        corresponding slice of stacked_tensor reshaped to the original's shape
    """
    unstacked_tensors = []
    offset = 0
    for tensor in original_tensors:
        num_elements = tensor.numel()
        chunk = stacked_tensor[offset : offset + num_elements].detach()
        # new_empty + copy_ gives a fresh tensor with the original's dtype and
        # device (copy_ converts as needed), so no extra type() cast is required
        unstacked = tensor.new_empty(num_elements).copy_(chunk)
        unstacked_tensors.append(unstacked.reshape(tensor.shape))
        offset += num_elements

    return unstacked_tensors
226+
def __str__(self):
    """
    :return: the string name identifying this mask creator, "unstructured"
    """
    return "unstructured"
151229
@@ -213,44 +291,71 @@ def _map_mask_to_tensor(
213291 """
214292 raise NotImplementedError ()
215293
def create_sparsity_masks_from_tensor(self, tensors: List[Tensor]) -> List[Tensor]:
    """
    Create masks from the current values of the tensors, one group at a time.

    :param tensors: list of tensors to calculate masks based on their values
    :return: list of masks derived from the values of the tensors grouped by
        the group_tensor function, each mapped back to its tensor's shape
    """
    grouped_tensors = [self.group_tensor(tensor) for tensor in tensors]
    # one parent call for the whole list, the same way create_sparsity_masks
    # does it; kept outside the comprehension because zero-argument super()
    # is not available inside comprehension scopes on older Python versions
    grouped_masks = super().create_sparsity_masks_from_tensor(grouped_tensors)

    return [
        self._map_mask_to_tensor(grouped_mask, tensor.shape)
        for grouped_mask, tensor in zip(grouped_masks, tensors)
    ]
224308
def create_sparsity_masks_from_abs_threshold(
    self, tensors: List[Tensor], threshold: Union[float, Tensor]
) -> List[Tensor]:
    """
    Create masks by applying a threshold to the grouped values of each tensor.

    :param tensors: list of tensors to calculate masks from based on their
        contained values
    :param threshold: a threshold of group_tensor values to determine cutoff
        for sparsification
    :return: list of masks derived from the tensors and the grouped threshold,
        each mapped back to its tensor's shape
    """
    grouped_tensors = [self.group_tensor(tensor) for tensor in tensors]
    # one parent call for the whole list, the same way create_sparsity_masks
    # does it; kept outside the comprehension because zero-argument super()
    # is not available inside comprehension scopes on older Python versions
    grouped_masks = super().create_sparsity_masks_from_abs_threshold(
        grouped_tensors, threshold
    )

    return [
        self._map_mask_to_tensor(grouped_mask, tensor.shape)
        for grouped_mask, tensor in zip(grouped_masks, tensors)
    ]
327+
def create_sparsity_masks(
    self,
    tensors: List[Tensor],
    sparsity: float,
    global_sparsity: bool = False,
) -> List[Tensor]:
    """
    Create masks that reach the target sparsity by masking whole groups of values.

    :param tensors: list of tensors to calculate masks from based on their
        contained values
    :param sparsity: the desired sparsity to reach within the mask
        (decimal fraction of zeros)
    :param global_sparsity: if True, sparsity masks will be created such that the
        average sparsity across all given tensors is the target sparsity with the
        lowest global values masked. If False, each tensor will be masked to the
        target sparsity ranking values within each individual tensor. Default is
        False
    :return: list of masks (0.0 for values that are masked, 1.0 for values that are
        unmasked) calculated from the tensors such that the desired number of zeros
        matches the sparsity and all values mapped to the same group have the same
        value
    """
    # the parent ranks the grouped values; the result is then expanded back so
    # each mask has the shape of the tensor it was computed for
    grouped_tensors = [self.group_tensor(tensor) for tensor in tensors]
    grouped_masks = super().create_sparsity_masks(
        grouped_tensors, sparsity, global_sparsity
    )

    return [
        self._map_mask_to_tensor(grouped_mask, tensor.shape)
        for grouped_mask, tensor in zip(grouped_masks, tensors)
    ]
254359
255360
256361class DimensionSparsityMaskCreator (GroupedPruningMaskCreator ):
@@ -286,6 +391,31 @@ def __init__(
286391 )
287392 self ._dim = dim # List[int]
288393
def create_sparsity_masks(
    self,
    tensors: List[Tensor],
    sparsity: float,
    global_sparsity: bool = False,
) -> List[Tensor]:
    """
    Create per-tensor masks that reach the target sparsity along the grouped dims.

    :param tensors: list of tensors to calculate masks from based on their
        contained values
    :param sparsity: the desired sparsity to reach within the mask
        (decimal fraction of zeros)
    :param global_sparsity: do not set True, unsupported for
        DimensionSparsityMaskCreator
    :return: list of masks (0.0 for values that are masked, 1.0 for values that are
        unmasked) calculated from the tensors such that the desired number of zeros
        matches the sparsity and all values mapped to the same group have the same
        value
    :raises ValueError: if global_sparsity is True
    """
    if global_sparsity:
        # global sparsity unsupported because channel dims may vary across layers
        raise ValueError(
            "global_sparsity not supported for DimensionSparsityMaskCreator"
        )
    return super().create_sparsity_masks(tensors, sparsity, global_sparsity=False)
418+
289419 def group_tensor (self , tensor : Tensor ) -> Tensor :
290420 """
291421 :param tensor: The tensor to transform
0 commit comments