
Commit 1cfc4b8

separate layer-wise and global WoodFisher pruning modifiers (#294)

* define final sparsity by parameter in torch pruning modifiers
* separate layer-wise and global WoodFisher pruning modifiers
* get_num_grads support for multiple applied sparsities
* fixing tests after rebase

1 parent 48f7c5e commit 1cfc4b8
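To make the split concrete, here is a minimal usage sketch based on this diff: MFACPruningModifier now prunes layer-wise by default (global_sparsity=False), while the new MFACGlobalPruningModifier pins global_sparsity=True. The import path and argument values are assumptions for illustration, not taken from a released version.

# Hypothetical sketch of the two modifiers after this commit.
from sparseml.pytorch.optim import (  # import path assumed
    MFACGlobalPruningModifier,
    MFACPruningModifier,
)

# layer-wise: each matched parameter ramps to final_sparsity independently
layerwise = MFACPruningModifier(
    init_sparsity=0.05,
    final_sparsity=0.8,
    start_epoch=0.0,
    end_epoch=10.0,
    update_frequency=1.0,
    params="__ALL_PRUNABLE__",
)

# global: one sparsity budget is scored across all matched parameters together
global_pruning = MFACGlobalPruningModifier(
    init_sparsity=0.05,
    final_sparsity=0.8,
    start_epoch=0.0,
    end_epoch=10.0,
    update_frequency=1.0,
    params="__ALL_PRUNABLE__",
)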

File tree

3 files changed: +208 −10 lines

src/sparseml/pytorch/optim/modifier_pruning.py

Lines changed: 112 additions & 8 deletions
@@ -63,6 +63,7 @@
     "GMPruningModifier",
     "MagnitudePruningModifier",
     "MFACPruningModifier",
+    "MFACGlobalPruningModifier",
     "MovementPruningModifier",
     "GlobalMagnitudePruningModifier",
     "LayerPruningModifier",
@@ -825,6 +826,7 @@ def _check_mask_update(self, module: Module, epoch: float, steps_per_epoch: int)
         self._pre_step_completed = True
 
         if started:
+            # set the mask tensors according to the new sparsity
             if isinstance(self._final_sparsity, List):
                 self._applied_sparsity = [
                     interpolate(
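For context on the branch the new comment annotates: when final_sparsity is a list (one target per parameter), an applied sparsity is interpolated per entry at each update step. Below is a self-contained sketch of a cubic ramp, the default inter_func; it is written for illustration and does not claim to match sparseml's interpolate helper exactly.

# Hedged sketch: ease-out cubic interpolation of applied sparsity per target.
def cubic_interpolate(epoch, start_epoch, end_epoch, init_val, final_val):
    # clamp progress to [0, 1], then apply an ease-out cubic ramp
    t = min(max((epoch - start_epoch) / (end_epoch - start_epoch), 0.0), 1.0)
    return init_val + (final_val - init_val) * (1 - (1 - t) ** 3)

final_sparsity = [0.8, 0.9]  # hypothetical per-parameter targets
applied = [cubic_interpolate(5.0, 0.0, 10.0, 0.05, fs) for fs in final_sparsity]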
@@ -1294,6 +1296,8 @@ class MFACPruningModifier(GMPruningModifier):
     :param mask_type: String to define type of sparsity (options: ['unstructured',
         'channel', 'filter']), List to define block shape of a parameter's in and out
         channels, or a SparsityMaskCreator object. default is 'unstructured'
+    :param global_sparsity: set True to enable global pruning. if False, pruning will
+        be layer-wise. Default is False
     :param mfac_options: Dictionary of key words specifying arguments for the M-FAC
         pruning run. num_grads controls the number of gradient samples that are kept,
         fisher_block_size specifies the block size to break the M-FAC computation into
@@ -1316,6 +1320,7 @@ def __init__(
         phased: bool = False,
         log_types: Union[str, List[str]] = ALL_TOKEN,
         mask_type: Union[str, List[int], PruningMaskCreator] = "unstructured",
+        global_sparsity: bool = False,
         mfac_options: Dict[str, Any] = None,
     ):
         super().__init__(
@@ -1330,18 +1335,11 @@
             phased=phased,
             log_types=log_types,
             mask_type=mask_type,
-            global_sparsity=True,
+            global_sparsity=global_sparsity,
             score_type="MFAC",
         )
         self._mfac_options = mfac_options or {}
 
-    @ModifierProp(serializable=False)
-    def global_sparsity(self) -> bool:
-        """
-        :return: True if global pruning is enabled, False otherwise
-        """
-        return self._global_sparsity
-
     @ModifierProp(serializable=False)
     def score_type(self) -> str:
         """
@@ -1374,6 +1372,112 @@ def _create_pruning_mask(
         )
 
 
+@PyTorchModifierYAML()
+class MFACGlobalPruningModifier(MFACPruningModifier):
+    """
+    Gradually applies kernel sparsity to a given parameter or parameters from
+    init_sparsity until final_sparsity is reached over a given amount of time
+    and applied with an interpolated function for each step taken.
+
+    Uses the Matrix-Free Approximate Curvature (M-FAC) algorithm for solving
+    for optimal pruning updates by estimating the inverse Hessian matrix to the
+    loss over time under the Optimal Brain Surgeon (OBS) framework.
+    A link to the paper will be included here in an upcoming update.
+
+    | Sample yaml:
+    |   !MFACGlobalPruningModifier
+    |       init_sparsity: 0.05
+    |       final_sparsity: 0.8
+    |       start_epoch: 0.0
+    |       end_epoch: 10.0
+    |       update_frequency: 1.0
+    |       params: ["re:.*weight"]
+    |       leave_enabled: True
+    |       inter_func: cubic
+    |       log_types: __ALL__
+    |       mask_type: unstructured
+    |       mfac_options:
+    |           num_grads: {0.0: 64, 0.5: 128, 0.75: 256, 0.85: 512}
+    |           fisher_block_size: 10000
+    |           available_gpus: ["cuda:0"]
+
+    :param init_sparsity: the initial sparsity for the param to start with at
+        start_epoch
+    :param final_sparsity: the final sparsity for the param to end with at end_epoch.
+        Can also be a Dict of final sparsity values to a list of parameters to apply
+        them to. If given a Dict, then params must be set to [] and the params to
+        be pruned will be read from the final_sparsity Dict
+    :param start_epoch: The epoch to start the modifier at
+    :param end_epoch: The epoch to end the modifier at
+    :param update_frequency: The number of epochs or fraction of epochs to update at
+        between start and end
+    :param params: A list of full parameter names or regex patterns of names to apply
+        pruning to. Regex patterns must be specified with the prefix 're:'. __ALL__
+        will match to all parameters. __ALL_PRUNABLE__ will match to all ConvNd
+        and Linear layers' weights. If a sparsity to param mapping is defined by
+        final_sparsity, then params should be set to []
+    :param leave_enabled: True to continue masking the weights after end_epoch,
+        False to stop masking. Should be set to False if exporting the result
+        immediately after or doing some other prune
+    :param inter_func: the type of interpolation function to use:
+        [linear, cubic, inverse_cubic]
+    :param phased: True to enable a phased approach where pruning will
+        turn on and off with the update_frequency. Starts with pruning on
+        at start_epoch, off at start_epoch + update_frequency, and so on.
+    :param log_types: The loggers to allow the learning rate to be logged to,
+        default is __ALL__
+    :param mask_type: String to define type of sparsity (options: ['unstructured',
+        'channel', 'filter']), List to define block shape of a parameter's in and out
+        channels, or a SparsityMaskCreator object. default is 'unstructured'
+    :param mfac_options: Dictionary of key words specifying arguments for the M-FAC
+        pruning run. num_grads controls the number of gradient samples that are kept,
+        fisher_block_size specifies the block size to break the M-FAC computation into
+        (default is 2000, use None for no blocks), available_gpus specifies a list
+        of device ids that can be used for computation. For a full list of options,
+        see the MFACOptions dataclass documentation. Default configuration uses
+        CPU for computation without blocked computation
+    """
+
+    def __init__(
+        self,
+        init_sparsity: float,
+        final_sparsity: Union[float, Dict[float, List[str]]],
+        start_epoch: float,
+        end_epoch: float,
+        update_frequency: float,
+        params: Union[str, List[str]],
+        leave_enabled: bool = True,
+        inter_func: str = "cubic",
+        phased: bool = False,
+        log_types: Union[str, List[str]] = ALL_TOKEN,
+        mask_type: Union[str, List[int], PruningMaskCreator] = "unstructured",
+        mfac_options: Dict[str, Any] = None,
+    ):
+        super().__init__(
+            init_sparsity=init_sparsity,
+            final_sparsity=final_sparsity,
+            start_epoch=start_epoch,
+            end_epoch=end_epoch,
+            update_frequency=update_frequency,
+            params=params,
+            leave_enabled=leave_enabled,
+            inter_func=inter_func,
+            phased=phased,
+            log_types=log_types,
+            mask_type=mask_type,
+            global_sparsity=True,
+            mfac_options=mfac_options,
+        )
+        self._mfac_options = mfac_options or {}
+
+    @ModifierProp(serializable=False)
+    def global_sparsity(self) -> bool:
+        """
+        :return: True if global pruning is enabled, False otherwise
+        """
+        return self._global_sparsity
+
+
 @PyTorchModifierYAML()
 class LayerPruningModifier(ScheduledUpdateModifier):
     """

src/sparseml/pytorch/utils/mfac_helpers.py

Lines changed: 3 additions & 1 deletion
@@ -68,9 +68,11 @@ class MFACOptions:
     num_pages: int = 1  # break computation into pages when block size is None
     available_gpus: List[str] = field(default_factory=list)
 
-    def get_num_grads_for_sparsity(self, sparsity: float) -> int:
+    def get_num_grads_for_sparsity(self, sparsity: Union[float, List[float]]) -> int:
         if isinstance(self.num_grads, int):
             return self.num_grads
+        if isinstance(sparsity, List):
+            sparsity = sum(sparsity) / len(sparsity)
 
         sparsity_thresholds = list(sorted(self.num_grads, key=lambda key: float(key)))
         if 0.0 not in sparsity_thresholds:
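To illustrate the new list handling: with a num_grads schedule keyed by sparsity thresholds, multiple applied sparsities are first reduced to their mean before the lookup. The threshold-selection rule below (largest key not exceeding the current sparsity) is an assumption, since the diff truncates the original selection logic; this is a standalone sketch, not the MFACOptions class itself.

# Hypothetical standalone version of the lookup for illustration.
num_grads = {0.0: 64, 0.5: 128, 0.75: 256, 0.85: 512}

def get_num_grads_for_sparsity(sparsity):
    if isinstance(sparsity, list):
        # new in this commit: average the per-parameter sparsities first
        sparsity = sum(sparsity) / len(sparsity)
    thresholds = sorted(num_grads, key=float)
    # assumed rule: pick the largest threshold not exceeding current sparsity
    chosen = max(t for t in thresholds if float(t) <= sparsity)
    return num_grads[chosen]

print(get_num_grads_for_sparsity(0.6))         # -> 128
print(get_num_grads_for_sparsity([0.7, 0.9]))  # mean 0.8 -> 256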

tests/sparseml/pytorch/optim/test_modifier_pruning.py

Lines changed: 93 additions & 1 deletion
@@ -26,6 +26,7 @@
     GMPruningModifier,
     LayerPruningModifier,
     MagnitudePruningModifier,
+    MFACGlobalPruningModifier,
     MFACPruningModifier,
     MovementPruningModifier,
     load_mask_creator,
@@ -357,6 +358,7 @@ def test_lifecycle(
                 applied_sparsities, last_sparsities
             )
         )
+
         last_sparsities = applied_sparsities
 
         _ = model(test_batch)  # check forward pass
@@ -813,6 +815,7 @@ def test_mfac_pruning_yaml():
     params = "__ALL_PRUNABLE__"
     inter_func = "cubic"
     mask_type = "unstructured"
+    global_sparsity = False
    mfac_options = {"num_grads": 64, "available_gpus": ["cuda:0", "cuda:1"]}
     yaml_str = f"""
     !MFACPruningModifier
@@ -824,9 +827,10 @@ def test_mfac_pruning_yaml():
         params: {params}
         inter_func: {inter_func}
         mask_type: {mask_type}
+        global_sparsity: {global_sparsity}
         mfac_options: {mfac_options}
     """
-    yaml_modifier = MFACPruningModifier.load_obj(yaml_str)  # type: MFACPruningModifier
+    yaml_modifier = MFACPruningModifier.load_obj(yaml_str)
     serialized_modifier = MFACPruningModifier.load_obj(
         str(yaml_modifier)
     )  # type: MFACPruningModifier
@@ -839,6 +843,7 @@ def test_mfac_pruning_yaml():
         params=params,
         inter_func=inter_func,
         mask_type=mask_type,
+        global_sparsity=global_sparsity,
         mfac_options=mfac_options,
     )

@@ -879,6 +884,93 @@ def test_mfac_pruning_yaml():
         == str(serialized_modifier.mask_type)
         == str(obj_modifier.mask_type)
     )
+    assert (
+        str(yaml_modifier.global_sparsity)
+        == str(serialized_modifier.global_sparsity)
+        == str(obj_modifier.global_sparsity)
+    )
+    assert (
+        yaml_modifier.mfac_options
+        == serialized_modifier.mfac_options
+        == obj_modifier.mfac_options
+    )
+
+
+def test_global_mfac_pruning_yaml():
+    init_sparsity = 0.05
+    final_sparsity = 0.8
+    start_epoch = 5.0
+    end_epoch = 15.0
+    update_frequency = 1.0
+    params = "__ALL_PRUNABLE__"
+    inter_func = "cubic"
+    mask_type = "unstructured"
+    mfac_options = {"num_grads": 64, "available_gpus": ["cuda:0", "cuda:1"]}
+    yaml_str = f"""
+    !MFACGlobalPruningModifier
+        init_sparsity: {init_sparsity}
+        final_sparsity: {final_sparsity}
+        start_epoch: {start_epoch}
+        end_epoch: {end_epoch}
+        update_frequency: {update_frequency}
+        params: {params}
+        inter_func: {inter_func}
+        mask_type: {mask_type}
+        mfac_options: {mfac_options}
+    """
+    yaml_modifier = MFACGlobalPruningModifier.load_obj(yaml_str)
+    serialized_modifier = MFACGlobalPruningModifier.load_obj(
+        str(yaml_modifier)
+    )  # type: MFACGlobalPruningModifier
+    obj_modifier = MFACGlobalPruningModifier(
+        init_sparsity=init_sparsity,
+        final_sparsity=final_sparsity,
+        start_epoch=start_epoch,
+        end_epoch=end_epoch,
+        update_frequency=update_frequency,
+        params=params,
+        inter_func=inter_func,
+        mask_type=mask_type,
+        mfac_options=mfac_options,
+    )
+
+    assert isinstance(yaml_modifier, MFACGlobalPruningModifier)
+    assert (
+        yaml_modifier.init_sparsity
+        == serialized_modifier.init_sparsity
+        == obj_modifier.init_sparsity
+    )
+    assert (
+        yaml_modifier.final_sparsity
+        == serialized_modifier.final_sparsity
+        == obj_modifier.final_sparsity
+    )
+    assert (
+        yaml_modifier.start_epoch
+        == serialized_modifier.start_epoch
+        == obj_modifier.start_epoch
+    )
+    assert (
+        yaml_modifier.end_epoch
+        == serialized_modifier.end_epoch
+        == obj_modifier.end_epoch
+    )
+    assert (
+        yaml_modifier.update_frequency
+        == serialized_modifier.update_frequency
+        == obj_modifier.update_frequency
+    )
+    assert yaml_modifier.params == serialized_modifier.params == obj_modifier.params
+    assert (
+        yaml_modifier.inter_func
+        == serialized_modifier.inter_func
+        == obj_modifier.inter_func
+    )
+    assert (
+        str(yaml_modifier.mask_type)
+        == str(serialized_modifier.mask_type)
+        == str(obj_modifier.mask_type)
+    )
     assert (
         yaml_modifier.mfac_options
         == serialized_modifier.mfac_options
