1717import inspect
1818from typing import List , Union
1919
20- import tensorflow as tf
20+ import tensorflow
2121
2222from sparseml .keras .optim .mask_pruning_creator import (
2323 PruningMaskCreator ,
2424 load_mask_creator ,
2525)
26+ from sparseml .keras .utils import keras
2627
2728
2829__all__ = [
@@ -80,7 +81,7 @@ def deserialize(cls, config):
8081 if "class_name" not in config :
8182 raise ValueError ("The 'class_name' not found in config: {}" .format (config ))
8283 class_name = config ["class_name" ]
83- return tf . keras .utils .deserialize_keras_object (
84+ return keras .utils .deserialize_keras_object (
8485 config ,
8586 module_objects = globals (),
8687 custom_objects = {class_name : PruningScheduler ._REGISTRY [class_name ]},
@@ -112,7 +113,7 @@ def __init__(
112113 pruning_vars : List [MaskedParamInfo ],
113114 pruning_scheduler : PruningScheduler ,
114115 mask_creator : PruningMaskCreator ,
115- global_step : tf .Tensor ,
116+ global_step : tensorflow .Tensor ,
116117 ):
117118 self ._pruning_vars = pruning_vars
118119 self ._pruning_scheduler = pruning_scheduler
@@ -121,36 +122,36 @@ def __init__(
121122 self ._update_ready = None
122123
123124 def _is_pruning_step (self ) -> bool :
124- global_step_val = tf . keras .backend .get_value (self ._global_step )
125+ global_step_val = keras .backend .get_value (self ._global_step )
125126 assert global_step_val >= 0
126127 update_ready = self ._pruning_scheduler .should_prune (global_step_val )
127128 return update_ready
128129
129130 def _conditional_training_update (self ):
130131 def _no_update_masks_and_weights ():
131- return tf .no_op ("no_update" )
132+ return tensorflow .no_op ("no_update" )
132133
133134 def _update_masks_and_weights ():
134135 assignments = []
135- global_step_val = tf . keras .backend .get_value (self ._global_step )
136+ global_step_val = keras .backend .get_value (self ._global_step )
136137 for masked_param_info in self ._pruning_vars :
137138 new_sparsity = self ._pruning_scheduler .target_sparsity (global_step_val )
138139 new_mask = self ._mask_creator .create_sparsity_mask (
139140 masked_param_info .param , new_sparsity
140141 )
141142 assignments .append (masked_param_info .mask .assign (new_mask ))
142143 assignments .append (masked_param_info .sparsity .assign (new_sparsity ))
143- masked_param = tf .math .multiply (
144+ masked_param = tensorflow .math .multiply (
144145 masked_param_info .param , masked_param_info .mask
145146 )
146147 assignments .append (masked_param_info .param .assign (masked_param ))
147- return tf .group (assignments )
148+ return tensorflow .group (assignments )
148149
149150 update_ready = self ._is_pruning_step ()
150151
151152 self ._update_ready = update_ready
152- return tf .cond (
153- tf .cast (update_ready , tf .bool ),
153+ return tensorflow .cond (
154+ tensorflow .cast (update_ready , tensorflow .bool ),
154155 _update_masks_and_weights ,
155156 _no_update_masks_and_weights ,
156157 )
@@ -161,11 +162,11 @@ def apply_masks(self):
161162 """
162163 assignments = []
163164 for masked_param_info in self ._pruning_vars :
164- masked_param = tf .math .multiply (
165+ masked_param = tensorflow .math .multiply (
165166 masked_param_info .param , masked_param_info .mask
166167 )
167168 assignments .append (masked_param_info .param .assign (masked_param ))
168- return tf .group (assignments )
169+ return tensorflow .group (assignments )
169170
170171 def conditional_update (self , training = None ):
171172 """
@@ -175,32 +176,34 @@ def conditional_update(self, training=None):
175176 """
176177
177178 def _update ():
178- with tf .control_dependencies ([self ._conditional_training_update ()]):
179- return tf .no_op ("update" )
179+ with tensorflow .control_dependencies ([self ._conditional_training_update ()]):
180+ return tensorflow .no_op ("update" )
180181
181182 def _no_update ():
182- return tf .no_op ("no_update" )
183+ return tensorflow .no_op ("no_update" )
183184
184- training = tf .keras .backend .learning_phase () if training is None else training
185- return tf .cond (tf .cast (training , tf .bool ), _update , _no_update )
185+ training = keras .backend .learning_phase () if training is None else training
186+ return tensorflow .cond (
187+ tensorflow .cast (training , tensorflow .bool ), _update , _no_update
188+ )
186189
187190
188191_LAYER_PRUNABLE_PARAMS_MAP = {
189- tf . keras .layers .Conv1D : ["kernel" ],
190- tf . keras .layers .Conv2D : ["kernel" ],
191- tf . keras .layers .Conv2DTranspose : ["kernel" ],
192- tf . keras .layers .Conv3D : ["kernel" ],
193- tf . keras .layers .Conv3DTranspose : ["kernel" ],
194- tf . keras .layers .Dense : ["kernel" ],
195- tf . keras .layers .Embedding : ["embeddings" ],
196- tf . keras .layers .LocallyConnected1D : ["kernel" ],
197- tf . keras .layers .LocallyConnected2D : ["kernel" ],
198- tf . keras .layers .SeparableConv1D : ["pointwise_kernel" ],
199- tf . keras .layers .SeparableConv2D : ["pointwise_kernel" ],
192+ keras .layers .Conv1D : ["kernel" ],
193+ keras .layers .Conv2D : ["kernel" ],
194+ keras .layers .Conv2DTranspose : ["kernel" ],
195+ keras .layers .Conv3D : ["kernel" ],
196+ keras .layers .Conv3DTranspose : ["kernel" ],
197+ keras .layers .Dense : ["kernel" ],
198+ keras .layers .Embedding : ["embeddings" ],
199+ keras .layers .LocallyConnected1D : ["kernel" ],
200+ keras .layers .LocallyConnected2D : ["kernel" ],
201+ keras .layers .SeparableConv1D : ["pointwise_kernel" ],
202+ keras .layers .SeparableConv2D : ["pointwise_kernel" ],
200203}
201204
202205
203- def _get_default_prunable_params (layer : tf . keras .layers .Layer ):
206+ def _get_default_prunable_params (layer : keras .layers .Layer ):
204207 if layer .__class__ in _LAYER_PRUNABLE_PARAMS_MAP :
205208 prunable_param_names = _LAYER_PRUNABLE_PARAMS_MAP [layer .__class__ ]
206209 return {
@@ -216,7 +219,7 @@ def _get_default_prunable_params(layer: tf.keras.layers.Layer):
216219 )
217220
218221
219- class MaskedLayer (tf . keras .layers .Wrapper ):
222+ class MaskedLayer (keras .layers .Wrapper ):
220223 """
221224 Masked layer is a layer wrapping around another layer with a mask; the mask however
222225 is shared if the enclosed layer is again of MaskedLayer type
@@ -229,13 +232,13 @@ class MaskedLayer(tf.keras.layers.Wrapper):
229232
230233 def __init__ (
231234 self ,
232- layer : tf . keras .layers .Layer ,
235+ layer : keras .layers .Layer ,
233236 pruning_scheduler : PruningScheduler ,
234237 mask_type : Union [str , List [int ]] = "unstructured" ,
235238 ** kwargs ,
236239 ):
237240 if not isinstance (layer , MaskedLayer ) and not isinstance (
238- layer , tf . keras .layers .Layer
241+ layer , keras .layers .Layer
239242 ):
240243 raise ValueError (
241244 "Invalid layer passed in, expected MaskedLayer or a keras Layer, "
@@ -257,8 +260,8 @@ def build(self, input_shape):
257260 self ._global_step = self .add_weight (
258261 "global_step" ,
259262 shape = [],
260- initializer = tf . keras .initializers .Constant (- 1 ),
261- dtype = tf .int64 ,
263+ initializer = keras .initializers .Constant (- 1 ),
264+ dtype = tensorflow .int64 ,
262265 trainable = False ,
263266 )
264267 self ._mask_updater = MaskAndWeightUpdater (
@@ -276,43 +279,43 @@ def _reuse_or_create_pruning_vars(
276279 # for the "core", inner-most, Keras built-in layer
277280 return self ._layer .pruning_vars
278281
279- assert isinstance (self ._layer , tf . keras .layers .Layer )
282+ assert isinstance (self ._layer , keras .layers .Layer )
280283 prunable_params = _get_default_prunable_params (self ._layer )
281284
282285 pruning_vars = []
283286 for name , param in prunable_params .items ():
284287 mask = self .add_weight (
285288 "mask" ,
286289 shape = param .shape ,
287- initializer = tf . keras .initializers .get ("ones" ),
290+ initializer = keras .initializers .get ("ones" ),
288291 dtype = param .dtype ,
289292 trainable = False ,
290293 )
291294 sparsity = self .add_weight (
292295 "sparsity" ,
293296 shape = [],
294- initializer = tf . keras .initializers .get ("zeros" ),
297+ initializer = keras .initializers .get ("zeros" ),
295298 dtype = param .dtype ,
296299 trainable = False ,
297300 )
298301 pruning_vars .append (MaskedParamInfo (name , param , mask , sparsity ))
299302 return pruning_vars
300303
301- def call (self , inputs : tf .Tensor , training = None ):
304+ def call (self , inputs : tensorflow .Tensor , training = None ):
302305 """
303306 Forward function for calling layer instance as function
304307 """
305- training = tf . keras .backend .learning_phase () if training is None else training
308+ training = keras .backend .learning_phase () if training is None else training
306309
307310 def _apply_masks_to_weights ():
308- with tf .control_dependencies ([self ._mask_updater .apply_masks ()]):
309- return tf .no_op ("update" )
311+ with tensorflow .control_dependencies ([self ._mask_updater .apply_masks ()]):
312+ return tensorflow .no_op ("update" )
310313
311314 def _no_apply_masks_to_weights ():
312- return tf .no_op ("no_update_masks" )
315+ return tensorflow .no_op ("no_update_masks" )
313316
314- tf .cond (
315- tf .cast (training , tf .bool ),
317+ tensorflow .cond (
318+ tensorflow .cast (training , tensorflow .bool ),
316319 _apply_masks_to_weights ,
317320 _no_apply_masks_to_weights ,
318321 )
@@ -327,7 +330,7 @@ def get_config(self):
327330 """
328331 Get layer config
329332 Serialization and deserialization should be done using
330- tf. keras.serialize/deserialize, which create and retrieve the "class_name"
333+ keras.serialize/deserialize, which create and retrieve the "class_name"
331334 field automatically.
332335 The resulting config below therefore does not contain the field.
333336 """
@@ -345,11 +348,11 @@ def get_config(self):
345348 @classmethod
346349 def from_config (cls , config ):
347350 config = config .copy ()
348- layer = tf . keras .layers .deserialize (
351+ layer = keras .layers .deserialize (
349352 config .pop ("layer" ), custom_objects = {"MaskedLayer" : MaskedLayer }
350353 )
351354 if not isinstance (layer , MaskedLayer ) and not isinstance (
352- layer , tf . keras .layers .Layer
355+ layer , keras .layers .Layer
353356 ):
354357 raise RuntimeError ("Unexpected layer created from config" )
355358 pruning_scheduler = PruningScheduler .deserialize (
@@ -384,7 +387,7 @@ def pruning_vars(self):
384387 def pruned_layer (self ):
385388 if isinstance (self ._layer , MaskedLayer ):
386389 return self ._layer .pruned_layer
387- elif isinstance (self ._layer , tf . keras .layers .Layer ):
390+ elif isinstance (self ._layer , keras .layers .Layer ):
388391 return self ._layer
389392 else :
390393 raise RuntimeError ("Unrecognized layer" )
@@ -394,7 +397,7 @@ def masked_layer(self):
394397 return self ._layer
395398
396399
397- def remove_pruning_masks (model : tf . keras .Model ):
400+ def remove_pruning_masks (model : keras .Model ):
398401 """
399402 Remove pruning masks from a model that was pruned using the MaskedLayer logic
400403 :param model: a model that was pruned using MaskedLayer
@@ -410,7 +413,7 @@ def _get_pruned_layer(layer):
410413 ) or layer .__class__ .__name__ .endswith ("MaskedLayer" )
411414 if is_masked_layer :
412415 return _get_pruned_layer (layer .layer )
413- elif isinstance (layer , tf . keras .layers .Layer ):
416+ elif isinstance (layer , keras .layers .Layer ):
414417 return layer
415418 else :
416419 raise ValueError ("Unknown layer type" )
@@ -425,6 +428,6 @@ def _remove_pruning_masks(layer):
425428
426429 # TODO: while the resulting model could be exported to ONNX, its built status
427430 # is removed
428- return tf . keras .models .clone_model (
431+ return keras .models .clone_model (
429432 model , input_tensors = None , clone_function = _remove_pruning_masks
430433 )
0 commit comments