6363 "GMPruningModifier" ,
6464 "MagnitudePruningModifier" ,
6565 "MFACPruningModifier" ,
66+ "MFACGlobalPruningModifier" ,
6667 "MovementPruningModifier" ,
6768 "GlobalMagnitudePruningModifier" ,
6869 "LayerPruningModifier" ,
@@ -825,6 +826,7 @@ def _check_mask_update(self, module: Module, epoch: float, steps_per_epoch: int)
         self._pre_step_completed = True

         if started:
+            # set the mask tensors according to the new sparsity
             if isinstance(self._final_sparsity, List):
                 self._applied_sparsity = [
                     interpolate(
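Editor's note on the hunk above: `interpolate` maps the current epoch to a sparsity level between init_sparsity and final_sparsity according to inter_func. A minimal sketch of that behavior, assuming a cubic ease-out curve; the helper name and exact easing below are illustrative assumptions, not sparseml's implementation:

```python
# Sketch (not from this diff): how an interpolated sparsity schedule behaves.
def interpolate_sketch(epoch, start_epoch, end_epoch, init_val, final_val):
    """Cubic ease-out from init_val to final_val over [start_epoch, end_epoch]."""
    progress = min(max((epoch - start_epoch) / (end_epoch - start_epoch), 0.0), 1.0)
    # cubic: fast sparsity growth early, slow refinement near the end
    scaled = 1.0 - (1.0 - progress) ** 3
    return init_val + scaled * (final_val - init_val)

# e.g. init 0.05 -> final 0.8 over epochs 0..10
for epoch in (0, 2, 5, 10):
    print(epoch, round(interpolate_sketch(epoch, 0.0, 10.0, 0.05, 0.8), 3))
```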
@@ -1294,6 +1296,8 @@ class MFACPruningModifier(GMPruningModifier):
     :param mask_type: String to define the type of sparsity (options: ['unstructured',
         'channel', 'filter']), List to define the block shape of a parameter's in and
         out channels, or a SparsityMaskCreator object. Default is 'unstructured'
+    :param global_sparsity: set True to enable global pruning. If False, pruning will
+        be layer-wise. Default is False
     :param mfac_options: Dictionary of keyword arguments for the M-FAC
         pruning run. num_grads controls the number of gradient samples that are kept,
         fisher_block_size specifies the block size to break the M-FAC computation into
@@ -1316,6 +1320,7 @@ def __init__(
         phased: bool = False,
         log_types: Union[str, List[str]] = ALL_TOKEN,
         mask_type: Union[str, List[int], PruningMaskCreator] = "unstructured",
+        global_sparsity: bool = False,
         mfac_options: Dict[str, Any] = None,
     ):
         super().__init__(
@@ -1330,18 +1335,11 @@ def __init__(
             phased=phased,
             log_types=log_types,
             mask_type=mask_type,
-            global_sparsity=True,
+            global_sparsity=global_sparsity,
             score_type="MFAC",
         )
         self._mfac_options = mfac_options or {}

-    @ModifierProp(serializable=False)
-    def global_sparsity(self) -> bool:
-        """
-        :return: True if global pruning is enabled, False otherwise
-        """
-        return self._global_sparsity
-
     @ModifierProp(serializable=False)
     def score_type(self) -> str:
         """
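With the changes above, global pruning becomes opt-in for MFACPruningModifier rather than always-on. A hedged usage sketch: the keyword arguments mirror the signature in this diff, while the import path is an assumption:

```python
# Illustrative only: import path assumed, not confirmed by this diff.
from sparseml.pytorch.optim import MFACPruningModifier

# Layer-wise (the new default): every matched param is pruned to 0.8 on its own.
layerwise = MFACPruningModifier(
    init_sparsity=0.05,
    final_sparsity=0.8,
    start_epoch=0.0,
    end_epoch=10.0,
    update_frequency=1.0,
    params=["re:.*weight"],
    global_sparsity=False,
)

# Global: 0.8 sparsity is distributed across all matched params combined.
global_mfac = MFACPruningModifier(
    init_sparsity=0.05,
    final_sparsity=0.8,
    start_epoch=0.0,
    end_epoch=10.0,
    update_frequency=1.0,
    params=["re:.*weight"],
    global_sparsity=True,
)
```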
@@ -1374,6 +1372,112 @@ def _create_pruning_mask(
         )


+@PyTorchModifierYAML()
+class MFACGlobalPruningModifier(MFACPruningModifier):
+    """
+    Gradually applies kernel sparsity to a given parameter or parameters from
+    init_sparsity until final_sparsity is reached over a given amount of time
+    and applied with an interpolated function for each step taken.
+
+    Uses the Matrix-Free Approximate Curvature (M-FAC) algorithm for solving
+    for optimal pruning updates by estimating the inverse Hessian matrix to the
+    loss over time under the Optimal Brain Surgeon (OBS) framework.
+    A link to the paper will be included here in an upcoming update.
+
+    | Sample yaml:
+    |   !MFACGlobalPruningModifier
+    |       init_sparsity: 0.05
+    |       final_sparsity: 0.8
+    |       start_epoch: 0.0
+    |       end_epoch: 10.0
+    |       update_frequency: 1.0
+    |       params: ["re:.*weight"]
+    |       leave_enabled: True
+    |       inter_func: cubic
+    |       log_types: __ALL__
+    |       mask_type: unstructured
+    |       mfac_options:
+    |           num_grads: {0.0: 64, 0.5: 128, 0.75: 256, 0.85: 512}
+    |           fisher_block_size: 10000
+    |           available_gpus: ["cuda:0"]
+
+    :param init_sparsity: the initial sparsity for the param to start with at
+        start_epoch
+    :param final_sparsity: the final sparsity for the param to end with at end_epoch.
+        Can also be a Dict of final sparsity values to a list of parameters to apply
+        them to. If given a Dict, then params must be set to [] and the params to
+        be pruned will be read from the final_sparsity Dict
+    :param start_epoch: The epoch to start the modifier at
+    :param end_epoch: The epoch to end the modifier at
+    :param update_frequency: The number of epochs or fraction of epochs to update at
+        between start and end
+    :param params: A list of full parameter names or regex patterns of names to apply
+        pruning to. Regex patterns must be specified with the prefix 're:'. __ALL__
+        will match to all parameters. __ALL_PRUNABLE__ will match to all ConvNd
+        and Linear layers' weights. If a sparsity to param mapping is defined by
+        final_sparsity, then params should be set to []
+    :param leave_enabled: True to continue masking the weights after end_epoch,
+        False to stop masking. Should be set to False if exporting the result
+        immediately after or doing some other prune
+    :param inter_func: the type of interpolation function to use:
+        [linear, cubic, inverse_cubic]
+    :param phased: True to enable a phased approach where pruning will
+        turn on and off with the update_frequency. Starts with pruning on
+        at start_epoch, off at start_epoch + update_frequency, and so on.
+    :param log_types: The loggers to allow the learning rate to be logged to,
+        default is __ALL__
+    :param mask_type: String to define the type of sparsity (options: ['unstructured',
+        'channel', 'filter']), List to define the block shape of a parameter's in and
+        out channels, or a SparsityMaskCreator object. Default is 'unstructured'
+    :param mfac_options: Dictionary of keyword arguments for the M-FAC
+        pruning run. num_grads controls the number of gradient samples that are kept,
+        fisher_block_size specifies the block size to break the M-FAC computation into
+        (default is 2000, use None for no blocks), available_gpus specifies a list
+        of device ids that can be used for computation. For a full list of options,
+        see the MFACOptions dataclass documentation. Default configuration uses
+        CPU for computation without blocked computation
+    """
+
+    def __init__(
+        self,
+        init_sparsity: float,
+        final_sparsity: Union[float, Dict[float, List[str]]],
+        start_epoch: float,
+        end_epoch: float,
+        update_frequency: float,
+        params: Union[str, List[str]],
+        leave_enabled: bool = True,
+        inter_func: str = "cubic",
+        phased: bool = False,
+        log_types: Union[str, List[str]] = ALL_TOKEN,
+        mask_type: Union[str, List[int], PruningMaskCreator] = "unstructured",
+        mfac_options: Dict[str, Any] = None,
+    ):
+        super().__init__(
+            init_sparsity=init_sparsity,
+            final_sparsity=final_sparsity,
+            start_epoch=start_epoch,
+            end_epoch=end_epoch,
+            update_frequency=update_frequency,
+            params=params,
+            leave_enabled=leave_enabled,
+            inter_func=inter_func,
+            phased=phased,
+            log_types=log_types,
+            mask_type=mask_type,
+            global_sparsity=True,
+            mfac_options=mfac_options,
+        )
+        self._mfac_options = mfac_options or {}
+
+    @ModifierProp(serializable=False)
+    def global_sparsity(self) -> bool:
+        """
+        :return: True if global pruning is enabled, False otherwise
+        """
+        return self._global_sparsity
+
+
 @PyTorchModifierYAML()
 class LayerPruningModifier(ScheduledUpdateModifier):
     """
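Taken together, `MFACGlobalPruningModifier` is a thin convenience subclass: it forwards every argument to `MFACPruningModifier` and hardcodes `global_sparsity=True` in its `super().__init__` call. A hedged equivalence sketch (import path assumed):

```python
# Illustrative only: import path assumed, not confirmed by this diff.
from sparseml.pytorch.optim import MFACGlobalPruningModifier, MFACPruningModifier

kwargs = dict(
    init_sparsity=0.05,
    final_sparsity=0.8,
    start_epoch=0.0,
    end_epoch=10.0,
    update_frequency=1.0,
    params=["re:.*weight"],
)

# These two constructions prune identically: the subclass simply pins the flag.
a = MFACGlobalPruningModifier(**kwargs)
b = MFACPruningModifier(global_sparsity=True, **kwargs)
assert a.global_sparsity  # property added in this diff returns self._global_sparsity
```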