This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit c2903b4

StructuredPruningModifier Refactor (#560)
* StructuredPruningModifier Refactor
* migrate packages
* style
* nit fix
* rebase test fixes
* rebase fixes + mask type for mfac
* Fix: test MFAC mask type
* Update src/sparseml/pytorch/sparsification/pruning/modifier_pruning_structured.py
* style
1 parent 8db0994 commit c2903b4

18 files changed: +831 -298 lines changed

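The practical effect of the refactor is that structured pruning moves out of sparseml.pytorch.optim and into the sparseml.pytorch.sparsification.pruning package (new module modifier_pruning_structured.py, re-exported from that package's __init__). A rough migration sketch for downstream code, assuming the class keeps the StructuredPruningModifier name in the new module (its contents are not shown in this diff) and a constructor shaped like the removed class:

# Hypothetical migration sketch -- the module path is confirmed by the diff below;
# the class name and constructor arguments are assumed from the removed class.

# before: from sparseml.pytorch.optim.modifier_pruning import StructuredPruningModifier
from sparseml.pytorch.sparsification.pruning import StructuredPruningModifier

modifier = StructuredPruningModifier(
    init_sparsity=0.05,      # values mirror the sample yaml in the removed docstring
    final_sparsity=0.8,
    start_epoch=0.0,
    end_epoch=10.0,
    update_frequency=1.0,
    params="__ALL_PRUNABLE__",
    mask_type="filter",
)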
src/sparseml/pytorch/optim/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -36,7 +36,6 @@
 from .modifier_pruning import *
 from .modifier_quantization import *
 from .modifier_regularizer import *
-from .modifier_thinning import *
 from .optimizer import *
 from .sensitivity_as import *
 from .sensitivity_lr import *

src/sparseml/pytorch/optim/modifier_pruning.py

Lines changed: 0 additions & 182 deletions
@@ -30,7 +30,6 @@
 from sparseml.pytorch.nn import Identity
 from sparseml.pytorch.optim.analyzer_pruning import ModulePruningAnalyzer
 from sparseml.pytorch.optim.mask_creator_pruning import (
-    DimensionSparsityMaskCreator,
     PruningMaskCreator,
     load_mask_creator,
 )
@@ -65,7 +64,6 @@
     "LegacyGMPruningModifier",
     "MovementPruningModifier",
     "LayerPruningModifier",
-    "StructuredPruningModifier",
 ]


@@ -874,191 +872,15 @@ def score_type(self) -> str:
         return self._score_type


-@PyTorchModifierYAML()
-class StructuredPruningModifier(_GMPruningModifier):
-    """
-    Gradually applies structured kernel sparsity to a given parameter or parameters
-    from init_sparsity until final_sparsity is reached over a given amount of time
-    and applied with an interpolated function for each step taken. Channel and filter
-    pruning supported.
-
-    A param_group_dependency_map must be provided that maps
-    groups of prunable parameter names that should have their dimensions pruned
-    together to a list of module parameter names that should be updated accordingly
-    when those parameters are pruned.
-
-    | Sample yaml:
-    |   !StructuredPruningModifier
-    |       param_groups: [
-    |           ["param.1.name","param.2.name"], ["param.3.name", "param.4.name"]
-    |       ]
-    |       init_sparsity: 0.05
-    |       final_sparsity: 0.8
-    |       start_epoch: 0.0
-    |       end_epoch: 10.0
-    |       update_frequency: 1.0
-    |       params: __ALL_PRUNABLE__
-    |       leave_enabled: True
-    |       inter_func: cubic
-    |       log_types: __ALL__
-    |       mask_type: filter
-    |       score_type: magnitude
-
-    :param init_sparsity: the initial sparsity for the param to start with at
-        start_epoch
-    :param final_sparsity: the final sparsity for the param to end with at end_epoch
-    :param start_epoch: The epoch to start the modifier at
-    :param end_epoch: The epoch to end the modifier at
-    :param update_frequency: The number of epochs or fraction of epochs to update at
-        between start and end
-    :param param_groups: list of list of parameter names that should be pruned together
-        during structured pruning so that their same indices may be removed. May be
-        useful for structures such as residual blocks or grouped convolutions. Can be
-        generated from an onnx export of the target module with
-        sparseml.onnx.optim.get_param_structured_pruning_group_dependencies by
-        splitting its comma separated keys into lists.
-        i.e. [["param.1.name","param.2.name"], ["param.3.name", "param.4.name"]]
-    :param params: A list of full parameter names or regex patterns of names to apply
-        pruning to. Regex patterns must be specified with the prefix 're:'. __ALL__
-        will match to all parameters. __ALL_PRUNABLE__ will match to all ConvNd
-        and Linear layers' weights. Defualt is __ALL_PRUNABLE__
-    :param leave_enabled: True to continue masking the weights after end_epoch,
-        False to stop masking. Should be set to False if exporting the result
-        immediately after or doing some other prune
-    :param inter_func: the type of interpolation function to use:
-        [linear, cubic, inverse_cubic]
-    :param phased: True to enable a phased approach where pruning will
-        turn on and off with the update_frequency. Starts with pruning on
-        at start_epoch, off at start_epoch + update_frequency, and so on.
-    :param log_types: The loggers to allow the learning rate to be logged to,
-        default is __ALL__
-    :param mask_type: String to define type of structured sparsity (options: [
-        'channel', 'filter']), or a DimensionSparsityMaskCreator object.
-        default is 'filter'
-    :param score_type: Method used to score parameters for masking, i.e.
-        'magnitude', 'movement'. Default is 'magnitude'
-    """
-
-    def __init__(
-        self,
-        init_sparsity: float,
-        final_sparsity: float,
-        start_epoch: float,
-        end_epoch: float,
-        update_frequency: float,
-        param_groups: List[List[str]] = None,
-        params: Union[str, List[str]] = ALL_PRUNABLE_TOKEN,
-        leave_enabled: bool = True,
-        inter_func: str = "cubic",
-        phased: bool = False,
-        log_types: Union[str, List[str]] = ALL_TOKEN,
-        mask_type: Union[str, DimensionSparsityMaskCreator] = "filter",
-        score_type: str = "magnitude",
-    ):
-        if not isinstance(mask_type, DimensionSparsityMaskCreator) and (
-            mask_type not in ["channel", "filter"]
-        ):
-            raise ValueError(
-                "StructuredPruningModifier mask_type must be a "
-                "DimensionSparsityMaskCreator or designate 'channel' or 'filter' "
-                f"found {mask_type}"
-            )
-        super().__init__(
-            init_sparsity=init_sparsity,
-            final_sparsity=final_sparsity,
-            start_epoch=start_epoch,
-            end_epoch=end_epoch,
-            update_frequency=update_frequency,
-            params=params,
-            leave_enabled=leave_enabled,
-            inter_func=inter_func,
-            phased=phased,
-            log_types=log_types,
-            mask_type=mask_type,
-            global_sparsity=False,
-            score_type=score_type,
-        )
-
-        self._param_groups = param_groups or []
-
-    @BaseGMPruningModifier.sparsification_types.getter
-    def sparsification_types(self) -> List[SparsificationTypes]:
-        """
-        :return: the sparsification types this modifier instance will apply
-        """
-        return [SparsificationTypes.pruning, SparsificationTypes.structured]
-
-    @ModifierProp()
-    def param_groups(self) -> List[List[str]]:
-        """
-        :return: list of list of parameter names that should be pruned together
-            during structured pruning so that their same indices may be removed. May be
-            useful for structures such as residual blocks or grouped convolutions
-        """
-        return self._param_groups
-
-    @ModifierProp(serializable=False)
-    def global_sparsity(self) -> bool:
-        """
-        :return: True if global pruning is enabled, False otherwise
-        """
-        return self._global_sparsity
-
-    def _create_pruning_mask(
-        self, layers: List[Module], layer_names: List[str], param_names: List[str]
-    ) -> ModuleParamPruningMask:
-        # find and validate parameter groups for structured pruning
-        full_param_names = [
-            f"{layer_name}.{param_name}"
-            for layer_name, param_name in zip(layer_names, param_names)
-        ]
-        param_name_to_idx = dict(zip(full_param_names, range(len(full_param_names))))
-        param_group_idxs = []
-        added_idxs = set()
-
-        for param_group in self._param_groups:
-            group_idxs = []
-            for param_name in param_group:
-                if param_name not in param_name_to_idx:
-                    raise ValueError(
-                        f"param {param_name} from param_groups "
-                        f"not found in pruning modifier params {full_param_names}"
-                    )
-                param_idx = param_name_to_idx[param_name]
-                if param_idx in added_idxs:
-                    raise ValueError(
-                        "found repeated param name in param_groups " f"{param_name}"
-                    )
-                group_idxs.append(param_idx)
-                added_idxs.add(param_idx)
-            param_group_idxs.append(group_idxs)
-        for idx in range(len(full_param_names)):
-            if idx not in added_idxs:
-                param_group_idxs.append([idx])
-
-        self._mask_creator.set_tensor_group_idxs(param_group_idxs)
-
-        return ModuleParamPruningMask(
-            layers,
-            param_names,
-            layer_names=layer_names,
-            mask_creator=self._mask_creator,
-            global_sparsity=False,
-            score_type=self._score_type,
-        )
-
-
 @PyTorchModifierYAML()
 class LayerPruningModifier(ScheduledUpdateModifier):
     """
     Class for pruning away layers within a module
     (replaces with sparseml.pytorch.nn.Identity).
-
     | Sample yaml:
     |   !LayerPruningModifier
     |       layers: ['bert.encoder.layer.6', 'bert.encoder.layer.7']
     |
-
     :param layers: A list of full layer names to apply pruning to.
         __ALL_ will match to all layers. __ALL_PRUNABLE__ will match to all ConvNd
         and Linear layers
@@ -1124,7 +946,6 @@ def initialize(
     ):
         """
         Grab the layers and apply if epoch in range to control pruning for.
-
         :param module: the PyTorch model/module to modify
         :param epoch: The epoch to initialize the modifier and module at.
            Defaults to 0 (start of the training process)
@@ -1145,7 +966,6 @@ def finalize(
     ):
         """
         Cleans up any remaining hooks
-
         :param module: The model/module to finalize the modifier for.
            Marked optional so state can still be cleaned up on delete,
            but generally should always be passed in.
@@ -1165,7 +985,6 @@ def update(
     ):
         """
         Update to enable and disable the layers when chosen.
-
         :param module: module to modify
         :param optimizer: optimizer to modify
         :param epoch: current epoch and progress within the current epoch
@@ -1180,7 +999,6 @@ def log_update(
     ):
         """
         Check whether to log an update for the state of the modifier.
-
         :param module: module to modify
         :param optimizer: optimizer to modify
         :param epoch: current epoch and progress within the current epoch

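For context on what was removed above: _create_pruning_mask validated param_groups and turned them into index groups, so parameters named together share one structured mask group while every remaining parameter is pruned on its own. A standalone sketch of just that grouping step (plain Python, no SparseML imports; the function name is illustrative):

from typing import List


def build_group_idxs(
    full_param_names: List[str], param_groups: List[List[str]]
) -> List[List[int]]:
    """Standalone sketch of the grouping performed by the removed _create_pruning_mask."""
    name_to_idx = {name: idx for idx, name in enumerate(full_param_names)}
    group_idxs, added = [], set()
    for group in param_groups:
        idxs = []
        for name in group:
            if name not in name_to_idx:
                raise ValueError(f"param {name} not found in {full_param_names}")
            if name_to_idx[name] in added:
                raise ValueError(f"found repeated param name in param_groups {name}")
            idxs.append(name_to_idx[name])
            added.add(name_to_idx[name])
        group_idxs.append(idxs)
    # every parameter not named in a group is pruned independently
    group_idxs.extend([idx] for idx in range(len(full_param_names)) if idx not in added)
    return group_idxs


# e.g. build_group_idxs(["a.weight", "b.weight", "c.weight"], [["a.weight", "b.weight"]])
# -> [[0, 1], [2]]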
src/sparseml/pytorch/sparsification/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -22,4 +22,5 @@
 # flake8: noqa

 from .info import *
+from .modifier_thinning import *
 from .pruning import *
File renamed without changes.

src/sparseml/pytorch/sparsification/pruning/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -26,4 +26,5 @@
 from .modifier_pruning_constant import *
 from .modifier_pruning_magnitude import *
 from .modifier_pruning_mfac import *
+from .modifier_pruning_structured import *
 from .scorer import *

src/sparseml/pytorch/sparsification/pruning/modifier_pruning_acdc.py

Lines changed: 5 additions & 1 deletion
@@ -217,8 +217,12 @@ def _get_scorer(self, params: List[Parameter]) -> PruningParamsScorer:
         """
         return MagnitudePruningParamsScorer(params)

-    def _get_mask_creator(self) -> PruningMaskCreator:
+    def _get_mask_creator(
+        self, param_names: List[str], params: List[Parameter]
+    ) -> PruningMaskCreator:
         """
+        :param names: full names of parameters to be pruned
+        :param params: list of Parameters to be masked
         :return: mask creator object to be used by this pruning algorithm
         """
         if self._mask_type == "unstructured":

src/sparseml/pytorch/sparsification/pruning/modifier_pruning_base.py

Lines changed: 15 additions & 9 deletions
@@ -174,8 +174,12 @@ def sparsification_types(self) -> List[SparsificationTypes]:
         return [SparsificationTypes.pruning]

     @abstractmethod
-    def _get_mask_creator(self) -> PruningMaskCreator:
+    def _get_mask_creator(
+        self, param_names: List[str], params: List[Parameter]
+    ) -> PruningMaskCreator:
         """
+        :param names: full names of parameters to be pruned
+        :param params: list of Parameters to be masked
         :return: mask creator object to be used by this pruning algorithm
         """
         raise NotImplementedError()
@@ -241,7 +245,7 @@ def mask_creator(self) -> Optional[PruningMaskCreator]:
         """
         :return: mask creator object used by this pruning algorithm
         """
-        raise self._mask_creator
+        return self._mask_creator

     @property
     def scorer(self) -> Optional[PruningParamsScorer]:
@@ -295,13 +299,15 @@ def initialize(
         layer_names = [nlp.layer_name for nlp in named_layers_and_params]

         # initialize mask_creator and scorer
-        self._mask_creator = self._get_mask_creator()
-        self._scorer = self._get_scorer(
-            params=[
-                getattr(layer, param_name)
-                for layer, param_name in zip(layers, param_names)
-            ]
-        )
+        params = [
+            getattr(layer, param_name) for layer, param_name in zip(layers, param_names)
+        ]
+        full_param_names = [
+            f"{layer_name}.{param_name}"
+            for layer_name, param_name in zip(layer_names, param_names)
+        ]
+        self._mask_creator = self._get_mask_creator(full_param_names, params)
+        self._scorer = self._get_scorer(params)

         self._module_masks = self._create_pruning_mask(layers, layer_names, param_names)
         self._analyzers = self._create_analyzers(layers, layer_names, param_names)

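The base-class change above widens the _get_mask_creator hook: initialize now resolves both the Parameter objects and their dotted "layer.param" names before constructing the mask creator and scorer. A small self-contained sketch of that calling convention (the layers and names below are made up for illustration):

from torch.nn import Linear

# stand-in layers and names; a real modifier gets these from the wrapped module
layers = [Linear(8, 4), Linear(4, 2)]
layer_names = ["fc1", "fc2"]
param_names = ["weight", "weight"]

params = [
    getattr(layer, param_name) for layer, param_name in zip(layers, param_names)
]
full_param_names = [
    f"{layer_name}.{param_name}"
    for layer_name, param_name in zip(layer_names, param_names)
]
print(full_param_names)  # ['fc1.weight', 'fc2.weight']

# a concrete modifier now receives both lists:
#     mask_creator = self._get_mask_creator(full_param_names, params)
#     scorer = self._get_scorer(params)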
src/sparseml/pytorch/sparsification/pruning/modifier_pruning_constant.py

Lines changed: 6 additions & 2 deletions
@@ -21,7 +21,7 @@

 import torch
 from torch import Tensor
-from torch.nn import Module
+from torch.nn import Module, Parameter

 from sparseml.pytorch.optim.modifier import (
     ModifierProp,
@@ -135,8 +135,12 @@ def __init__(
             parent_class_kwarg_names=["params"],
         )

-    def _get_mask_creator(self) -> PruningMaskCreator:
+    def _get_mask_creator(
+        self, param_names: List[str], params: List[Parameter]
+    ) -> PruningMaskCreator:
         """
+        :param names: full names of parameters to be pruned
+        :param params: list of Parameters to be masked
         :return: mask creator object to be used by this pruning algorithm
         """
         return ConstantMaskCreator()

src/sparseml/pytorch/sparsification/pruning/modifier_pruning_magnitude.py

Lines changed: 12 additions & 2 deletions
@@ -135,11 +135,21 @@ def __init__(
             end_comparator=-1,
             global_sparsity=self._use_global_sparsity,
             allow_reintroduction=False,
-            parent_class_kwarg_names=["init_sparsity", "final_sparsity", "params"],
+            parent_class_kwarg_names=[
+                "init_sparsity",
+                "final_sparsity",
+                "params",
+                "leave_enabled",
+                "mask_type",
+            ],
         )

-    def _get_mask_creator(self) -> PruningMaskCreator:
+    def _get_mask_creator(
+        self, param_names: List[str], params: List[Parameter]
+    ) -> PruningMaskCreator:
         """
+        :param names: full names of parameters to be pruned
+        :param params: list of parameters to be masked
         :return: mask creator object to be used by this pruning algorithm
         """
         if self.mask_type == "unstructured":

0 commit comments
