     ScheduledUpdateModifier,
 )
 from sparseml.pytorch.utils import (
+    GradSampler,
     MFACOptions,
     NamedLayerParam,
     get_layer,
@@ -255,7 +256,7 @@ def initialize(
             )
         )

-        self._check_mask_update(module, epoch, steps_per_epoch=1)
+        self._check_mask_update(module, epoch, steps_per_epoch=1, **kwargs)

     def finalize(
         self, module: Optional[Module] = None, reset_loggers: bool = True, **kwargs
@@ -332,7 +333,9 @@ def optimizer_post_step(
             self._module_masks.apply()

     @abstractmethod
-    def _check_mask_update(self, module: Module, epoch: float, steps_per_epoch: int):
+    def _check_mask_update(
+        self, module: Module, epoch: float, steps_per_epoch: int, **kwargs
+    ):
         raise NotImplementedError()

     def _should_log(
@@ -444,7 +447,9 @@ def __init__(
             log_types=log_types,
         )

-    def _check_mask_update(self, module: Module, epoch: float, steps_per_epoch: int):
+    def _check_mask_update(
+        self, module: Module, epoch: float, steps_per_epoch: int, **kwargs
+    ):
         if self.start_pending(epoch, steps_per_epoch):
             self._module_masks.set_param_masks_from_weights()
             self._module_masks.enabled = True
@@ -807,7 +812,9 @@ def validate(self):
                 ).format(self._inter_func, INTERPOLATION_FUNCS, self.__class__.__name__)
             )

-    def _check_mask_update(self, module: Module, epoch: float, steps_per_epoch: int):
+    def _check_mask_update(
+        self, module: Module, epoch: float, steps_per_epoch: int, **kwargs
+    ):
         """
         Check for updating the pruning masks at the given epoch.
         Called from both initialize and update.
@@ -822,8 +829,9 @@ def _check_mask_update(self, module: Module, epoch: float, steps_per_epoch: int)
             self._module_masks.enabled = True
             started = True

-        self._module_masks.pre_optim_step_update()
-        self._pre_step_completed = True
+        if not self._pre_step_completed:
+            self._module_masks.pre_optim_step_update()
+            self._pre_step_completed = True

         if started:
             # set the mask tensors according to the new sparsity
@@ -1359,6 +1367,16 @@ def mfac_options(self) -> Dict[str, Any]:
         """
         return self._mfac_options

+    def _check_mask_update(
+        self, module: Module, epoch: float, steps_per_epoch: int, **kwargs
+    ):
+        # create grads for one-shot pruning
+        if "grad_sampler" in kwargs:
+            self._collect_grad_samples(module, kwargs["grad_sampler"])
+            self._pre_step_completed = True
+
+        super()._check_mask_update(module, epoch, steps_per_epoch, **kwargs)
+
     def _create_pruning_mask(
         self, layers: List[Module], layer_names: List[str], param_names: List[str]
     ) -> ModuleParamPruningMask:
@@ -1371,6 +1389,23 @@ def _create_pruning_mask(
             score_type=MFACOptions(**self._mfac_options),
         )

+    def _collect_grad_samples(
+        self,
+        module: Module,
+        grad_sampler: GradSampler,
+    ):
+        if not isinstance(grad_sampler, GradSampler):
+            raise ValueError(
+                "One-shot MFAC pruning requires a GradSampler object given by the "
+                f"grad_sampler kwarg. Given an object of type {type(grad_sampler)}"
+            )
+        num_grads = MFACOptions(**self._mfac_options).get_num_grads_for_sparsity(
+            self._applied_sparsity or 0.0
+        )
+
+        for _ in grad_sampler.iter_module_backwards(module, num_grads):
+            self._module_masks.pre_optim_step_update()
+

 @PyTorchModifierYAML()
 class MFACGlobalPruningModifier(MFACPruningModifier):
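For context, a minimal sketch of how the new kwarg could be exercised for one-shot MFAC pruning. This is not part of the diff: model, train_loader, and modifier are assumed to exist, and the (forward_args, forward_kwargs, loss_target) batch format is an assumption about what GradSampler consumes.

import torch.nn.functional as F

from sparseml.pytorch.utils import GradSampler


def batch_generator():
    # Assumed: train_loader yields (inputs, labels) pairs; GradSampler is fed
    # (forward_args, forward_kwargs, loss_target) tuples built from them.
    for inputs, labels in train_loader:
        yield [inputs], {}, labels


grad_sampler = GradSampler(batch_generator(), F.cross_entropy)

# The new **kwargs plumbing forwards grad_sampler from initialize() into
# _check_mask_update(), which calls _collect_grad_samples() to run the
# forward/backward passes MFAC needs before applying the masks one-shot.
modifier.initialize(model, epoch=0.0, grad_sampler=grad_sampler)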