Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit f366825

Browse files
authored
Native Keras support (#74)
* Enable native keras import * Rebase, address comments * Import keras from sparseml in tests
1 parent 5396662 commit f366825

File tree

15 files changed

+271
-241
lines changed

15 files changed

+271
-241
lines changed

src/sparseml/keras/__init__.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,24 @@
2222
try:
    import tensorflow

    # Compare numerically, not as strings: lexicographic comparison would
    # reject e.g. TF "2.10.0" because "2.10" sorts before "2.2".
    # Pre-release suffixes ("2.4.0-rc1") are stripped from the patch field.
    _tf_version = tuple(
        int(part.split("-")[0]) for part in tensorflow.__version__.split(".")[:2]
    )
    if _tf_version < (2, 1):
        raise RuntimeError(
            "TensorFlow >= 2.1.0 is required, found {}".format(
                tensorflow.__version__
            )
        )
except ImportError:
    # Only an import failure means tensorflow is absent; a version mismatch
    # above propagates its own, more specific RuntimeError.
    raise RuntimeError(
        "Unable to import tensorflow. TensorFlow>=2.1.0 is required"
        " to use sparseml.keras."
    )


try:
    # Native keras is optional: if present and new enough it will be used,
    # otherwise sparseml falls back to tf.keras, so failures here are ignored.
    import keras

    _keras_version = tuple(
        int(part.split("-")[0]) for part in keras.__version__.split(".")[:3]
    )
    if _keras_version < (2, 4, 3):
        raise RuntimeError(
            "Native keras is found and will be used, but required >= 2.4.3, found {}".format(
                keras.__version__
            )
        )
except ImportError:
    pass

src/sparseml/keras/optim/manager.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,10 @@
2020

2121
from typing import List, Union
2222

23-
import tensorflow as tf
23+
from tensorflow import Tensor
2424

2525
from sparseml.keras.optim.modifier import Modifier, ScheduledModifier
26+
from sparseml.keras.utils.compat import keras
2627
from sparseml.keras.utils.logger import KerasLogger
2728
from sparseml.optim import BaseManager
2829
from sparseml.utils import load_recipe_yaml_str
@@ -71,11 +72,11 @@ def __init__(self, modifiers: List[ScheduledModifier]):
7172

7273
def modify(
7374
self,
74-
model: Union[tf.keras.Model, tf.keras.Sequential],
75-
optimizer: tf.keras.optimizers.Optimizer,
75+
model: Union[keras.Model, keras.Sequential],
76+
optimizer: keras.optimizers.Optimizer,
7677
steps_per_epoch: int,
7778
loggers: Union[KerasLogger, List[KerasLogger]] = None,
78-
input_tensors: tf.Tensor = None,
79+
input_tensors: Tensor = None,
7980
):
8081
"""
8182
Modify the model and optimizer based on the requirements of modifiers
@@ -106,14 +107,14 @@ def modify(
106107
continue
107108
if isinstance(callback, list):
108109
callbacks = callbacks + callback
109-
elif isinstance(callback, tf.keras.callbacks.Callback):
110+
elif isinstance(callback, keras.callbacks.Callback):
110111
callbacks.append(callback)
111112
else:
112113
raise RuntimeError("Invalid callback type")
113114
self._optimizer = optimizer
114115
return model, optimizer, callbacks
115116

116-
def finalize(self, model: tf.keras.Model):
117+
def finalize(self, model: keras.Model):
117118
"""
118119
Remove extra information related to the modifier from the model that is
119120
not necessary for exporting

src/sparseml/keras/optim/mask_pruning.py

Lines changed: 54 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,13 @@
1717
import inspect
1818
from typing import List, Union
1919

20-
import tensorflow as tf
20+
import tensorflow
2121

2222
from sparseml.keras.optim.mask_pruning_creator import (
2323
PruningMaskCreator,
2424
load_mask_creator,
2525
)
26+
from sparseml.keras.utils import keras
2627

2728

2829
__all__ = [
@@ -80,7 +81,7 @@ def deserialize(cls, config):
8081
if "class_name" not in config:
8182
raise ValueError("The 'class_name' not found in config: {}".format(config))
8283
class_name = config["class_name"]
83-
return tf.keras.utils.deserialize_keras_object(
84+
return keras.utils.deserialize_keras_object(
8485
config,
8586
module_objects=globals(),
8687
custom_objects={class_name: PruningScheduler._REGISTRY[class_name]},
@@ -112,7 +113,7 @@ def __init__(
112113
pruning_vars: List[MaskedParamInfo],
113114
pruning_scheduler: PruningScheduler,
114115
mask_creator: PruningMaskCreator,
115-
global_step: tf.Tensor,
116+
global_step: tensorflow.Tensor,
116117
):
117118
self._pruning_vars = pruning_vars
118119
self._pruning_scheduler = pruning_scheduler
@@ -121,36 +122,36 @@ def __init__(
121122
self._update_ready = None
122123

123124
def _is_pruning_step(self) -> bool:
124-
global_step_val = tf.keras.backend.get_value(self._global_step)
125+
global_step_val = keras.backend.get_value(self._global_step)
125126
assert global_step_val >= 0
126127
update_ready = self._pruning_scheduler.should_prune(global_step_val)
127128
return update_ready
128129

129130
def _conditional_training_update(self):
130131
def _no_update_masks_and_weights():
131-
return tf.no_op("no_update")
132+
return tensorflow.no_op("no_update")
132133

133134
def _update_masks_and_weights():
134135
assignments = []
135-
global_step_val = tf.keras.backend.get_value(self._global_step)
136+
global_step_val = keras.backend.get_value(self._global_step)
136137
for masked_param_info in self._pruning_vars:
137138
new_sparsity = self._pruning_scheduler.target_sparsity(global_step_val)
138139
new_mask = self._mask_creator.create_sparsity_mask(
139140
masked_param_info.param, new_sparsity
140141
)
141142
assignments.append(masked_param_info.mask.assign(new_mask))
142143
assignments.append(masked_param_info.sparsity.assign(new_sparsity))
143-
masked_param = tf.math.multiply(
144+
masked_param = tensorflow.math.multiply(
144145
masked_param_info.param, masked_param_info.mask
145146
)
146147
assignments.append(masked_param_info.param.assign(masked_param))
147-
return tf.group(assignments)
148+
return tensorflow.group(assignments)
148149

149150
update_ready = self._is_pruning_step()
150151

151152
self._update_ready = update_ready
152-
return tf.cond(
153-
tf.cast(update_ready, tf.bool),
153+
return tensorflow.cond(
154+
tensorflow.cast(update_ready, tensorflow.bool),
154155
_update_masks_and_weights,
155156
_no_update_masks_and_weights,
156157
)
@@ -161,11 +162,11 @@ def apply_masks(self):
161162
"""
162163
assignments = []
163164
for masked_param_info in self._pruning_vars:
164-
masked_param = tf.math.multiply(
165+
masked_param = tensorflow.math.multiply(
165166
masked_param_info.param, masked_param_info.mask
166167
)
167168
assignments.append(masked_param_info.param.assign(masked_param))
168-
return tf.group(assignments)
169+
return tensorflow.group(assignments)
169170

170171
def conditional_update(self, training=None):
171172
"""
@@ -175,32 +176,34 @@ def conditional_update(self, training=None):
175176
"""
176177

177178
def _update():
178-
with tf.control_dependencies([self._conditional_training_update()]):
179-
return tf.no_op("update")
179+
with tensorflow.control_dependencies([self._conditional_training_update()]):
180+
return tensorflow.no_op("update")
180181

181182
def _no_update():
182-
return tf.no_op("no_update")
183+
return tensorflow.no_op("no_update")
183184

184-
training = tf.keras.backend.learning_phase() if training is None else training
185-
return tf.cond(tf.cast(training, tf.bool), _update, _no_update)
185+
training = keras.backend.learning_phase() if training is None else training
186+
return tensorflow.cond(
187+
tensorflow.cast(training, tensorflow.bool), _update, _no_update
188+
)
186189

187190

188191
_LAYER_PRUNABLE_PARAMS_MAP = {
189-
tf.keras.layers.Conv1D: ["kernel"],
190-
tf.keras.layers.Conv2D: ["kernel"],
191-
tf.keras.layers.Conv2DTranspose: ["kernel"],
192-
tf.keras.layers.Conv3D: ["kernel"],
193-
tf.keras.layers.Conv3DTranspose: ["kernel"],
194-
tf.keras.layers.Dense: ["kernel"],
195-
tf.keras.layers.Embedding: ["embeddings"],
196-
tf.keras.layers.LocallyConnected1D: ["kernel"],
197-
tf.keras.layers.LocallyConnected2D: ["kernel"],
198-
tf.keras.layers.SeparableConv1D: ["pointwise_kernel"],
199-
tf.keras.layers.SeparableConv2D: ["pointwise_kernel"],
192+
keras.layers.Conv1D: ["kernel"],
193+
keras.layers.Conv2D: ["kernel"],
194+
keras.layers.Conv2DTranspose: ["kernel"],
195+
keras.layers.Conv3D: ["kernel"],
196+
keras.layers.Conv3DTranspose: ["kernel"],
197+
keras.layers.Dense: ["kernel"],
198+
keras.layers.Embedding: ["embeddings"],
199+
keras.layers.LocallyConnected1D: ["kernel"],
200+
keras.layers.LocallyConnected2D: ["kernel"],
201+
keras.layers.SeparableConv1D: ["pointwise_kernel"],
202+
keras.layers.SeparableConv2D: ["pointwise_kernel"],
200203
}
201204

202205

203-
def _get_default_prunable_params(layer: tf.keras.layers.Layer):
206+
def _get_default_prunable_params(layer: keras.layers.Layer):
204207
if layer.__class__ in _LAYER_PRUNABLE_PARAMS_MAP:
205208
prunable_param_names = _LAYER_PRUNABLE_PARAMS_MAP[layer.__class__]
206209
return {
@@ -216,7 +219,7 @@ def _get_default_prunable_params(layer: tf.keras.layers.Layer):
216219
)
217220

218221

219-
class MaskedLayer(tf.keras.layers.Wrapper):
222+
class MaskedLayer(keras.layers.Wrapper):
220223
"""
221224
Masked layer is a layer wrapping around another layer with a mask; the mask however
222225
is shared if the enclosed layer is again of MaskedLayer type
@@ -229,13 +232,13 @@ class MaskedLayer(tf.keras.layers.Wrapper):
229232

230233
def __init__(
231234
self,
232-
layer: tf.keras.layers.Layer,
235+
layer: keras.layers.Layer,
233236
pruning_scheduler: PruningScheduler,
234237
mask_type: Union[str, List[int]] = "unstructured",
235238
**kwargs,
236239
):
237240
if not isinstance(layer, MaskedLayer) and not isinstance(
238-
layer, tf.keras.layers.Layer
241+
layer, keras.layers.Layer
239242
):
240243
raise ValueError(
241244
"Invalid layer passed in, expected MaskedLayer or a keras Layer, "
@@ -257,8 +260,8 @@ def build(self, input_shape):
257260
self._global_step = self.add_weight(
258261
"global_step",
259262
shape=[],
260-
initializer=tf.keras.initializers.Constant(-1),
261-
dtype=tf.int64,
263+
initializer=keras.initializers.Constant(-1),
264+
dtype=tensorflow.int64,
262265
trainable=False,
263266
)
264267
self._mask_updater = MaskAndWeightUpdater(
@@ -276,43 +279,43 @@ def _reuse_or_create_pruning_vars(
276279
# for the "core", inner-most, Keras built-in layer
277280
return self._layer.pruning_vars
278281

279-
assert isinstance(self._layer, tf.keras.layers.Layer)
282+
assert isinstance(self._layer, keras.layers.Layer)
280283
prunable_params = _get_default_prunable_params(self._layer)
281284

282285
pruning_vars = []
283286
for name, param in prunable_params.items():
284287
mask = self.add_weight(
285288
"mask",
286289
shape=param.shape,
287-
initializer=tf.keras.initializers.get("ones"),
290+
initializer=keras.initializers.get("ones"),
288291
dtype=param.dtype,
289292
trainable=False,
290293
)
291294
sparsity = self.add_weight(
292295
"sparsity",
293296
shape=[],
294-
initializer=tf.keras.initializers.get("zeros"),
297+
initializer=keras.initializers.get("zeros"),
295298
dtype=param.dtype,
296299
trainable=False,
297300
)
298301
pruning_vars.append(MaskedParamInfo(name, param, mask, sparsity))
299302
return pruning_vars
300303

301-
def call(self, inputs: tf.Tensor, training=None):
304+
def call(self, inputs: tensorflow.Tensor, training=None):
302305
"""
303306
Forward function for calling layer instance as function
304307
"""
305-
training = tf.keras.backend.learning_phase() if training is None else training
308+
training = keras.backend.learning_phase() if training is None else training
306309

307310
def _apply_masks_to_weights():
308-
with tf.control_dependencies([self._mask_updater.apply_masks()]):
309-
return tf.no_op("update")
311+
with tensorflow.control_dependencies([self._mask_updater.apply_masks()]):
312+
return tensorflow.no_op("update")
310313

311314
def _no_apply_masks_to_weights():
312-
return tf.no_op("no_update_masks")
315+
return tensorflow.no_op("no_update_masks")
313316

314-
tf.cond(
315-
tf.cast(training, tf.bool),
317+
tensorflow.cond(
318+
tensorflow.cast(training, tensorflow.bool),
316319
_apply_masks_to_weights,
317320
_no_apply_masks_to_weights,
318321
)
@@ -327,7 +330,7 @@ def get_config(self):
327330
"""
328331
Get layer config
329332
Serialization and deserialization should be done using
330-
tf.keras.serialize/deserialize, which create and retrieve the "class_name"
333+
keras.serialize/deserialize, which create and retrieve the "class_name"
331334
field automatically.
332335
The resulting config below therefore does not contain the field.
333336
"""
@@ -345,11 +348,11 @@ def get_config(self):
345348
@classmethod
346349
def from_config(cls, config):
347350
config = config.copy()
348-
layer = tf.keras.layers.deserialize(
351+
layer = keras.layers.deserialize(
349352
config.pop("layer"), custom_objects={"MaskedLayer": MaskedLayer}
350353
)
351354
if not isinstance(layer, MaskedLayer) and not isinstance(
352-
layer, tf.keras.layers.Layer
355+
layer, keras.layers.Layer
353356
):
354357
raise RuntimeError("Unexpected layer created from config")
355358
pruning_scheduler = PruningScheduler.deserialize(
@@ -384,7 +387,7 @@ def pruning_vars(self):
384387
def pruned_layer(self):
385388
if isinstance(self._layer, MaskedLayer):
386389
return self._layer.pruned_layer
387-
elif isinstance(self._layer, tf.keras.layers.Layer):
390+
elif isinstance(self._layer, keras.layers.Layer):
388391
return self._layer
389392
else:
390393
raise RuntimeError("Unrecognized layer")
@@ -394,7 +397,7 @@ def masked_layer(self):
394397
return self._layer
395398

396399

397-
def remove_pruning_masks(model: tf.keras.Model):
400+
def remove_pruning_masks(model: keras.Model):
398401
"""
399402
Remove pruning masks from a model that was pruned using the MaskedLayer logic
400403
:param model: a model that was pruned using MaskedLayer
@@ -410,7 +413,7 @@ def _get_pruned_layer(layer):
410413
) or layer.__class__.__name__.endswith("MaskedLayer")
411414
if is_masked_layer:
412415
return _get_pruned_layer(layer.layer)
413-
elif isinstance(layer, tf.keras.layers.Layer):
416+
elif isinstance(layer, keras.layers.Layer):
414417
return layer
415418
else:
416419
raise ValueError("Unknown layer type")
@@ -425,6 +428,6 @@ def _remove_pruning_masks(layer):
425428

426429
# TODO: while the resulting model could be exported to ONNX, its built status
427430
# is removed
428-
return tf.keras.models.clone_model(
431+
return keras.models.clone_model(
429432
model, input_tensors=None, clone_function=_remove_pruning_masks
430433
)

0 commit comments

Comments (0)