 from sparseml.pytorch.nn import Identity
 from sparseml.pytorch.optim.analyzer_pruning import ModulePruningAnalyzer
 from sparseml.pytorch.optim.mask_creator_pruning import (
-    DimensionSparsityMaskCreator,
     PruningMaskCreator,
     load_mask_creator,
 )
@@ -65,7 +64,6 @@
     "LegacyGMPruningModifier",
     "MovementPruningModifier",
     "LayerPruningModifier",
-    "StructuredPruningModifier",
 ]

@@ -874,191 +872,15 @@ def score_type(self) -> str:
         return self._score_type

-@PyTorchModifierYAML()
-class StructuredPruningModifier(_GMPruningModifier):
-    """
-    Gradually applies structured kernel sparsity to a given parameter or parameters
-    from init_sparsity until final_sparsity is reached over a given amount of time,
-    applied with an interpolated function at each step taken. Channel and filter
-    pruning are supported.
-
-    A param_groups list may be provided that groups the names of prunable parameters
-    whose dimensions should be pruned together, so that the same indices are removed
-    across every parameter in a group when those parameters are pruned.
-
-    | Sample yaml:
-    |   !StructuredPruningModifier
-    |       param_groups: [
-    |           ["param.1.name", "param.2.name"], ["param.3.name", "param.4.name"]
-    |       ]
-    |       init_sparsity: 0.05
-    |       final_sparsity: 0.8
-    |       start_epoch: 0.0
-    |       end_epoch: 10.0
-    |       update_frequency: 1.0
-    |       params: __ALL_PRUNABLE__
-    |       leave_enabled: True
-    |       inter_func: cubic
-    |       log_types: __ALL__
-    |       mask_type: filter
-    |       score_type: magnitude
-
-    :param init_sparsity: the initial sparsity for the param to start with at
-        start_epoch
-    :param final_sparsity: the final sparsity for the param to end with at end_epoch
-    :param start_epoch: The epoch to start the modifier at
-    :param end_epoch: The epoch to end the modifier at
-    :param update_frequency: The number of epochs or fraction of epochs to update at
-        between start and end
-    :param param_groups: list of lists of parameter names that should be pruned
-        together during structured pruning so that the same indices may be removed.
-        May be useful for structures such as residual blocks or grouped convolutions.
-        Can be generated from an ONNX export of the target module with
-        sparseml.onnx.optim.get_param_structured_pruning_group_dependencies by
-        splitting its comma-separated keys into lists,
-        i.e. [["param.1.name", "param.2.name"], ["param.3.name", "param.4.name"]]
-    :param params: A list of full parameter names or regex patterns of names to apply
-        pruning to. Regex patterns must be specified with the prefix 're:'. __ALL__
-        will match to all parameters. __ALL_PRUNABLE__ will match to all ConvNd
-        and Linear layers' weights. Default is __ALL_PRUNABLE__
-    :param leave_enabled: True to continue masking the weights after end_epoch,
-        False to stop masking. Should be set to False if exporting the result
-        immediately after or doing some other prune
-    :param inter_func: the type of interpolation function to use:
-        [linear, cubic, inverse_cubic]
-    :param phased: True to enable a phased approach where pruning will
-        turn on and off with the update_frequency. Starts with pruning on
-        at start_epoch, off at start_epoch + update_frequency, and so on.
-    :param log_types: The loggers to allow the learning rate to be logged to,
-        default is __ALL__
-    :param mask_type: String to define the type of structured sparsity (options:
-        ['channel', 'filter']), or a DimensionSparsityMaskCreator object.
-        default is 'filter'
-    :param score_type: Method used to score parameters for masking, i.e.
-        'magnitude', 'movement'. Default is 'magnitude'
-    """
-
-    def __init__(
-        self,
-        init_sparsity: float,
-        final_sparsity: float,
-        start_epoch: float,
-        end_epoch: float,
-        update_frequency: float,
-        param_groups: List[List[str]] = None,
-        params: Union[str, List[str]] = ALL_PRUNABLE_TOKEN,
-        leave_enabled: bool = True,
-        inter_func: str = "cubic",
-        phased: bool = False,
-        log_types: Union[str, List[str]] = ALL_TOKEN,
-        mask_type: Union[str, DimensionSparsityMaskCreator] = "filter",
-        score_type: str = "magnitude",
-    ):
-        if not isinstance(mask_type, DimensionSparsityMaskCreator) and (
-            mask_type not in ["channel", "filter"]
-        ):
-            raise ValueError(
-                "StructuredPruningModifier mask_type must be a "
-                "DimensionSparsityMaskCreator or designate 'channel' or 'filter' "
-                f"found {mask_type}"
-            )
-        super().__init__(
-            init_sparsity=init_sparsity,
-            final_sparsity=final_sparsity,
-            start_epoch=start_epoch,
-            end_epoch=end_epoch,
-            update_frequency=update_frequency,
-            params=params,
-            leave_enabled=leave_enabled,
-            inter_func=inter_func,
-            phased=phased,
-            log_types=log_types,
-            mask_type=mask_type,
-            global_sparsity=False,
-            score_type=score_type,
-        )
-
-        self._param_groups = param_groups or []
-
-    @BaseGMPruningModifier.sparsification_types.getter
-    def sparsification_types(self) -> List[SparsificationTypes]:
-        """
-        :return: the sparsification types this modifier instance will apply
-        """
-        return [SparsificationTypes.pruning, SparsificationTypes.structured]
-
-    @ModifierProp()
-    def param_groups(self) -> List[List[str]]:
-        """
-        :return: list of lists of parameter names that should be pruned together
-            during structured pruning so that the same indices may be removed. May be
-            useful for structures such as residual blocks or grouped convolutions
-        """
-        return self._param_groups
-
-    @ModifierProp(serializable=False)
-    def global_sparsity(self) -> bool:
-        """
-        :return: True if global pruning is enabled, False otherwise
-        """
-        return self._global_sparsity
-
-    def _create_pruning_mask(
-        self, layers: List[Module], layer_names: List[str], param_names: List[str]
-    ) -> ModuleParamPruningMask:
-        # find and validate parameter groups for structured pruning
-        full_param_names = [
-            f"{layer_name}.{param_name}"
-            for layer_name, param_name in zip(layer_names, param_names)
-        ]
-        param_name_to_idx = dict(zip(full_param_names, range(len(full_param_names))))
-        param_group_idxs = []
-        added_idxs = set()
-
-        for param_group in self._param_groups:
-            group_idxs = []
-            for param_name in param_group:
-                if param_name not in param_name_to_idx:
-                    raise ValueError(
-                        f"param {param_name} from param_groups "
-                        f"not found in pruning modifier params {full_param_names}"
-                    )
-                param_idx = param_name_to_idx[param_name]
-                if param_idx in added_idxs:
-                    raise ValueError(
-                        "found repeated param name in param_groups " f"{param_name}"
-                    )
-                group_idxs.append(param_idx)
-                added_idxs.add(param_idx)
-            param_group_idxs.append(group_idxs)
-        for idx in range(len(full_param_names)):
-            if idx not in added_idxs:
-                param_group_idxs.append([idx])
-
-        self._mask_creator.set_tensor_group_idxs(param_group_idxs)
-
-        return ModuleParamPruningMask(
-            layers,
-            param_names,
-            layer_names=layer_names,
-            mask_creator=self._mask_creator,
-            global_sparsity=False,
-            score_type=self._score_type,
-        )
-
-
 @PyTorchModifierYAML()
 class LayerPruningModifier(ScheduledUpdateModifier):
     """
     Class for pruning away layers within a module
     (replaces with sparseml.pytorch.nn.Identity).
-
     | Sample yaml:
     |   !LayerPruningModifier
     |       layers: ['bert.encoder.layer.6', 'bert.encoder.layer.7']
     |
-
     :param layers: A list of full layer names to apply pruning to.
         __ALL__ will match to all layers. __ALL_PRUNABLE__ will match to all ConvNd
         and Linear layers
@@ -1124,7 +946,6 @@ def initialize(
     ):
         """
         Grab the layers and apply if epoch in range to control pruning for.
-
         :param module: the PyTorch model/module to modify
         :param epoch: The epoch to initialize the modifier and module at.
             Defaults to 0 (start of the training process)
@@ -1145,7 +966,6 @@ def finalize(
     ):
         """
         Cleans up any remaining hooks
-
         :param module: The model/module to finalize the modifier for.
             Marked optional so state can still be cleaned up on delete,
             but generally should always be passed in.
@@ -1165,7 +985,6 @@ def update(
     ):
         """
         Update to enable and disable the layers when chosen.
-
         :param module: module to modify
         :param optimizer: optimizer to modify
         :param epoch: current epoch and progress within the current epoch
@@ -1180,7 +999,6 @@ def log_update(
     ):
         """
         Check whether to log an update for the state of the modifier.
-
         :param module: module to modify
         :param optimizer: optimizer to modify
         :param epoch: current epoch and progress within the current epoch
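
For reference, the removed _create_pruning_mask reduced to one grouping step: map each fully qualified parameter name to an index, turn every entry of param_groups into a list of those indices (validating that the names exist and are not repeated), and give every ungrouped parameter its own singleton group before handing the result to the mask creator. Below is a minimal standalone sketch of that step under those assumptions; the helper name group_param_indices and the example parameter names are illustrative only, not part of the sparseml API.

from typing import Dict, List


def group_param_indices(
    full_param_names: List[str], param_groups: List[List[str]]
) -> List[List[int]]:
    # map each fully qualified name ("layer.weight") to its position in the param list
    name_to_idx: Dict[str, int] = {name: i for i, name in enumerate(full_param_names)}
    group_idxs: List[List[int]] = []
    seen = set()

    # each declared group becomes one list of indices; every name must exist
    # and may appear in at most one group
    for group in param_groups:
        idxs = []
        for name in group:
            if name not in name_to_idx:
                raise ValueError(f"param {name} not found in {full_param_names}")
            if name_to_idx[name] in seen:
                raise ValueError(f"repeated param name in param_groups: {name}")
            idxs.append(name_to_idx[name])
            seen.add(name_to_idx[name])
        group_idxs.append(idxs)

    # parameters never mentioned in param_groups are pruned on their own
    for idx in range(len(full_param_names)):
        if idx not in seen:
            group_idxs.append([idx])
    return group_idxs


# e.g. tying the two conv weights of a residual block together:
print(
    group_param_indices(
        ["block.conv1.weight", "block.conv2.weight", "head.fc.weight"],
        [["block.conv1.weight", "block.conv2.weight"]],
    )
)
# -> [[0, 1], [2]]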