2121
2222from tensorflow import Tensor
2323
24- from sparseml .keras .optim .modifier import (
25- KerasModifierYAML ,
26- ModifierProp ,
27- ScheduledModifier ,
28- )
24+ from sparseml .keras .optim .modifier import KerasModifierYAML , ScheduledModifier
2925from sparseml .keras .optim .utils import get_layer_name_from_param
3026from sparseml .keras .utils import keras
31- from sparseml .utils import ALL_TOKEN , convert_to_bool , flatten_iterable
27+ from sparseml .sparsification import (
28+ TrainableParamsModifier as BaseTrainableParamsModifier ,
29+ )
30+ from sparseml .utils import ALL_TOKEN , flatten_iterable
3231
3332
3433__all__ = ["TrainableParamsModifier" ]
@@ -62,7 +61,7 @@ def on_train_batch_end(self, batch, logs=None):
6261
6362
6463@KerasModifierYAML ()
65- class TrainableParamsModifier (ScheduledModifier ):
64+ class TrainableParamsModifier (BaseTrainableParamsModifier , ScheduledModifier ):
6665 """
6766 Modifier to control the params for a given list of parameters.
6867 Applies the trainability over all epochs.
@@ -93,16 +92,15 @@ def __init__(
9392 end_epoch : float = - 1.0 ,
9493 ):
9594 super (TrainableParamsModifier , self ).__init__ (
96- start_epoch = - 1 ,
97- end_epoch = - 1 ,
95+ params = self ._validate_params (params ),
96+ trainable = trainable ,
97+ params_strict = params_strict ,
98+ start_epoch = start_epoch ,
99+ end_epoch = end_epoch ,
98100 end_comparator = - 1 ,
99101 )
100- self ._params = self ._validate_params (params )
101102 self ._layer_names = [get_layer_name_from_param (p ) for p in self ._params ]
102- self ._trainable = convert_to_bool (trainable )
103- self ._params_strict = convert_to_bool (params_strict )
104103 self ._vars_to_trainable_orig = {}
105- self .validate ()
106104
107105 def _validate_params (self , params : Union [str , List [Union [int , str ]]]):
108106 if isinstance (params , str ):
@@ -118,66 +116,10 @@ def _validate_params(self, params: Union[str, List[Union[int, str]]]):
118116 )
119117 )
120118
121- @ModifierProp ()
122- def params (self ) -> Union [str , List [str ]]:
123- """
124- :return: A list of full parameter names or regex patterns of names to apply
125- pruning to. Regex patterns must be specified with the prefix 're:'. __ALL__
126- will match to all parameters. Can also use the token __ALL__ to specify all
127- params
128- """
129- return self ._params
130-
131- @params .setter
132- def params (self , value : Union [str , List [str ]]):
133- """
134- :param value: A list of full parameter names or regex patterns of names to apply
135- pruning to. Regex patterns must be specified with the prefix 're:'. __ALL__
136- will match to all parameters. Can also use the token __ALL__ to specify all
137- params
138- """
139- self ._params = self ._validate_params (value )
140- self .validate ()
141-
142119 @property
143120 def layer_names (self ) -> List [str ]:
144121 return self ._layer_names
145122
146- @ModifierProp ()
147- def trainable (self ) -> bool :
148- """
149- :return: True if the param(s) should be made trainable,
150- False to make them non-trainable
151- """
152- return self ._trainable
153-
154- @trainable .setter
155- def trainable (self , value : bool ):
156- """
157- :param value: True if the param(s) should be made trainable,
158- False to make them non-trainable
159- """
160- self ._trainable = value
161- self .validate ()
162-
163- @ModifierProp ()
164- def params_strict (self ) -> bool :
165- """
166- :return: True if the given param(s) must be found in each layer and
167- will raise an err if not found,
168- False if missing params are ok and will not raise an err
169- """
170- return self ._params_strict
171-
172- @params_strict .setter
173- def params_strict (self , value : bool ):
174- """
175- :param value: True if the given param(s) must be found in each layer and
176- will raise an err if not found,
177- False if missing params are ok and will not raise an err
178- """
179- self ._params_strict = value
180-
181123 def validate (self ):
182124 """
183125 Validate the values of the params for the current instance are valid
0 commit comments