This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit 24db8ee
Base Modifier classes and refactor for framework impls (#399)
1 parent dbad7de
19 files changed: +759 -809 lines

src/sparseml/keras/optim/modifier_epoch.py

Lines changed: 3 additions & 2 deletions
@@ -17,13 +17,14 @@
 """
 
 from sparseml.keras.optim.modifier import KerasModifierYAML, ScheduledModifier
+from sparseml.sparsification import EpochRangeModifier as BaseEpochRangeModifier
 
 
 __all__ = ["EpochRangeModifier"]
 
 
 @KerasModifierYAML()
-class EpochRangeModifier(ScheduledModifier):
+class EpochRangeModifier(BaseEpochRangeModifier, ScheduledModifier):
     """
     Simple modifier to set the range of epochs to train over for
     the recalibration process.
@@ -45,6 +46,6 @@ def __init__(
         :param start_epoch: The epoch to start the modifier at
         :param end_epoch: The epoch to end the modifier at
         """
-        super().__init__(
+        super(EpochRangeModifier, self).__init__(
             start_epoch=start_epoch, end_epoch=end_epoch, end_comparator=-1
         )
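
The refactored class now has two bases, so the single super().__init__ call is resolved through Python's method resolution order (MRO); in Python 3, super(EpochRangeModifier, self) is equivalent to the zero-argument super(). A minimal runnable sketch of this cooperative-inheritance pattern, using toy stand-ins whose bodies are illustrative assumptions rather than the real sparseml implementations:

# Toy stand-ins for the sparseml classes; names mirror the diff but the
# bodies are assumptions for illustration, not the real implementations.
class BaseEpochRangeModifier:
    def __init__(self, start_epoch, end_epoch, **kwargs):
        super().__init__(**kwargs)  # keep the cooperative chain going
        self.start_epoch = start_epoch
        self.end_epoch = end_epoch


class ScheduledModifier:
    def __init__(self, end_comparator=0, **kwargs):
        super().__init__(**kwargs)
        self.end_comparator = end_comparator


class EpochRangeModifier(BaseEpochRangeModifier, ScheduledModifier):
    def __init__(self, start_epoch=-1.0, end_epoch=-1.0):
        # MRO: EpochRangeModifier -> BaseEpochRangeModifier
        #      -> ScheduledModifier -> object
        super().__init__(
            start_epoch=start_epoch, end_epoch=end_epoch, end_comparator=-1
        )


mod = EpochRangeModifier(start_epoch=0.0, end_epoch=10.0)
print([c.__name__ for c in EpochRangeModifier.__mro__])
# ['EpochRangeModifier', 'BaseEpochRangeModifier', 'ScheduledModifier', 'object']
print(mod.start_epoch, mod.end_epoch, mod.end_comparator)  # 0.0 10.0 -1

Each __init__ consumes its own keyword arguments and forwards the rest via **kwargs, which is what lets the subclass pass start_epoch, end_epoch, and end_comparator in a single call.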

src/sparseml/keras/optim/modifier_lr.py

Lines changed: 8 additions & 5 deletions
@@ -26,7 +26,10 @@
     ScheduledUpdateModifier,
 )
 from sparseml.keras.utils import KerasLogger, LoggerSettingCallback, LoggingMode, keras
-from sparseml.optim import LearningRate, SetLearningRate
+from sparseml.sparsification import LearningRateModifier as BaseLearningRateModifier
+from sparseml.sparsification import (
+    SetLearningRateModifier as BaseSetLearningRateModifier,
+)
 from sparseml.utils import ALL_TOKEN
 
 
@@ -186,7 +189,7 @@ def _is_logging_step(self):
 
 
 @KerasModifierYAML()
-class SetLearningRateModifier(ScheduledModifier, SetLearningRate):
+class SetLearningRateModifier(BaseSetLearningRateModifier, ScheduledModifier):
     """
     Modifier to set the learning rate to a specific value at a certain point
     in the training process. Once that point is reached, will update the optimizer's
@@ -212,7 +215,7 @@ def __init__(
         end_epoch: float = -1,
         log_types: Union[str, List[str]] = ALL_TOKEN,
     ):
-        super().__init__(
+        super(SetLearningRateModifier, self).__init__(
             learning_rate=learning_rate,
             log_types=log_types,
             start_epoch=start_epoch,
@@ -316,7 +319,7 @@ def get_config(self):
 
 
 @KerasModifierYAML()
-class LearningRateModifier(ScheduledUpdateModifier, LearningRate):
+class LearningRateModifier(BaseLearningRateModifier, ScheduledUpdateModifier):
     """
     Modifier to set the learning rate to follow specific schedulers
     within a period of epochs.
@@ -358,7 +361,7 @@ def __init__(
         update_frequency: float = -1.0,
         log_types: Union[str, List[str]] = ALL_TOKEN,
     ):
-        super().__init__(
+        super(LearningRateModifier, self).__init__(
             lr_class=lr_class,
             lr_kwargs=lr_kwargs,
             init_lr=init_lr,
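
Note the base ordering in both new class statements: the shared sparsification base is listed first, before the Keras ScheduledModifier or ScheduledUpdateModifier. Python looks up attributes left to right along the MRO, so anything defined on the shared base shadows a same-named definition on the framework class. A toy sketch of that lookup (the source property is hypothetical, purely for illustration):

# Toy classes showing why base ordering matters; `source` is made up.
class BaseSetLearningRateModifier:
    @property
    def source(self):
        return "sparsification base"


class ScheduledModifier:
    @property
    def source(self):
        return "keras framework"


class SetLearningRateModifier(BaseSetLearningRateModifier, ScheduledModifier):
    pass


mod = SetLearningRateModifier()
# Attribute lookup walks the MRO left to right, so the shared base wins:
print(mod.source)  # "sparsification base"
print([c.__name__ for c in SetLearningRateModifier.__mro__])
# ['SetLearningRateModifier', 'BaseSetLearningRateModifier',
#  'ScheduledModifier', 'object']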

src/sparseml/keras/optim/modifier_params.py

Lines changed: 11 additions & 69 deletions
@@ -21,14 +21,13 @@
 
 from tensorflow import Tensor
 
-from sparseml.keras.optim.modifier import (
-    KerasModifierYAML,
-    ModifierProp,
-    ScheduledModifier,
-)
+from sparseml.keras.optim.modifier import KerasModifierYAML, ScheduledModifier
 from sparseml.keras.optim.utils import get_layer_name_from_param
 from sparseml.keras.utils import keras
-from sparseml.utils import ALL_TOKEN, convert_to_bool, flatten_iterable
+from sparseml.sparsification import (
+    TrainableParamsModifier as BaseTrainableParamsModifier,
+)
+from sparseml.utils import ALL_TOKEN, flatten_iterable
 
 
 __all__ = ["TrainableParamsModifier"]
@@ -62,7 +61,7 @@ def on_train_batch_end(self, batch, logs=None):
 
 
 @KerasModifierYAML()
-class TrainableParamsModifier(ScheduledModifier):
+class TrainableParamsModifier(BaseTrainableParamsModifier, ScheduledModifier):
     """
     Modifier to control the params for a given list of parameters.
     Applies the trainability over all epochs.
@@ -93,16 +92,15 @@ def __init__(
         end_epoch: float = -1.0,
     ):
         super(TrainableParamsModifier, self).__init__(
-            start_epoch=-1,
-            end_epoch=-1,
+            params=self._validate_params(params),
+            trainable=trainable,
+            params_strict=params_strict,
+            start_epoch=start_epoch,
+            end_epoch=end_epoch,
             end_comparator=-1,
         )
-        self._params = self._validate_params(params)
         self._layer_names = [get_layer_name_from_param(p) for p in self._params]
-        self._trainable = convert_to_bool(trainable)
-        self._params_strict = convert_to_bool(params_strict)
         self._vars_to_trainable_orig = {}
-        self.validate()
 
     def _validate_params(self, params: Union[str, List[Union[int, str]]]):
         if isinstance(params, str):
@@ -118,66 +116,10 @@ def _validate_params(self, params: Union[str, List[Union[int, str]]]):
             )
         )
 
-    @ModifierProp()
-    def params(self) -> Union[str, List[str]]:
-        """
-        :return: A list of full parameter names or regex patterns of names to apply
-            pruning to. Regex patterns must be specified with the prefix 're:'. __ALL__
-            will match to all parameters. Can also use the token __ALL__ to specify all
-            params
-        """
-        return self._params
-
-    @params.setter
-    def params(self, value: Union[str, List[str]]):
-        """
-        :param value: A list of full parameter names or regex patterns of names to apply
-            pruning to. Regex patterns must be specified with the prefix 're:'. __ALL__
-            will match to all parameters. Can also use the token __ALL__ to specify all
-            params
-        """
-        self._params = self._validate_params(value)
-        self.validate()
-
     @property
     def layer_names(self) -> List[str]:
         return self._layer_names
 
-    @ModifierProp()
-    def trainable(self) -> bool:
-        """
-        :return: True if the param(s) should be made trainable,
-            False to make them non-trainable
-        """
-        return self._trainable
-
-    @trainable.setter
-    def trainable(self, value: bool):
-        """
-        :param value: True if the param(s) should be made trainable,
-            False to make them non-trainable
-        """
-        self._trainable = value
-        self.validate()
-
-    @ModifierProp()
-    def params_strict(self) -> bool:
-        """
-        :return: True if the given param(s) must be found in each layer and
-            will raise an err if not found,
-            False if missing params are ok and will not raise an err
-        """
-        return self._params_strict
-
-    @params_strict.setter
-    def params_strict(self, value: bool):
-        """
-        :param value: True if the given param(s) must be found in each layer and
-            will raise an err if not found,
-            False if missing params are ok and will not raise an err
-        """
-        self._params_strict = value
-
     def validate(self):
         """
         Validate the values of the params for the current instance are valid
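
Most of the 69 deleted lines are the params, trainable, and params_strict properties, which now live on the shared base class; the Keras subclass validates its inputs, hands the values to super().__init__ for storage, and keeps only framework-specific state such as layer names and the original trainable flags. A condensed runnable sketch of that split, with made-up minimal bodies (the real base is sparseml.sparsification.TrainableParamsModifier):

from typing import List, Union


class BaseTrainableParamsModifier:
    # Toy base: holds the shared state and read-only props the Keras
    # subclass used to define itself (bodies are illustrative assumptions).
    def __init__(self, params, trainable, params_strict, start_epoch,
                 end_epoch, end_comparator=0, **kwargs):
        super().__init__(**kwargs)
        self._params = params
        self._trainable = bool(trainable)
        self._params_strict = bool(params_strict)
        self.start_epoch = start_epoch
        self.end_epoch = end_epoch
        self.end_comparator = end_comparator

    @property
    def params(self) -> Union[str, List[str]]:
        return self._params

    @property
    def trainable(self) -> bool:
        return self._trainable


class TrainableParamsModifier(BaseTrainableParamsModifier):
    # Keras-side subclass: validates inputs, delegates storage to the base,
    # and keeps only framework-specific state.
    def __init__(self, params, trainable, params_strict=True,
                 start_epoch=-1.0, end_epoch=-1.0):
        super().__init__(
            params=self._validate_params(params),
            trainable=trainable,
            params_strict=params_strict,
            start_epoch=start_epoch,
            end_epoch=end_epoch,
            end_comparator=-1,
        )
        self._vars_to_trainable_orig = {}

    def _validate_params(self, params):
        # Stand-in for the real validation: always return a list.
        return [params] if isinstance(params, str) else list(params)


mod = TrainableParamsModifier(params=["re:.*dense.*"], trainable=False)
print(mod.params, mod.trainable)  # ['re:.*dense.*'] False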
