@@ -58,7 +58,6 @@ class WeightNormalization(tf.keras.layers.Wrapper):
     def __init__(self, layer, data_init=True, **kwargs):
         super(WeightNormalization, self).__init__(layer, **kwargs)
         self.data_init = data_init
-        self._initialized = False
         self._track_trackable(layer, name='layer')
 
     def build(self, input_shape):
@@ -69,85 +68,99 @@ def build(self, input_shape):
         if not self.layer.built:
             self.layer.build(input_shape)
 
-        if not hasattr(self.layer, 'kernel'):
-            raise ValueError('`WeightNormalization` must wrap a layer that'
-                             ' contains a `kernel` for weights')
+        if not hasattr(self.layer, 'kernel'):
+            raise ValueError('`WeightNormalization` must wrap a layer that'
+                             ' contains a `kernel` for weights')
+
+        # The kernel's filter or unit dimension is -1
+        self.layer_depth = int(self.layer.kernel.shape[-1])
+        self.kernel_norm_axes = list(range(self.layer.kernel.shape.rank - 1))
+
+        self.g = self.add_variable(
+            name='g',
+            shape=(self.layer_depth,),
+            initializer='ones',
+            dtype=self.layer.kernel.dtype,
+            trainable=True)
+        self.v = self.layer.kernel
+
+        self._initialized = self.add_variable(
+            name='initialized',
+            shape=None,
+            initializer='zeros',
+            dtype=tf.dtypes.bool,
+            trainable=False)
 
-        # The kernel's filter or unit dimension is -1
-        self.layer_depth = int(self.layer.kernel.shape[-1])
-        self.kernel_norm_axes = list(
-            range(self.layer.kernel.shape.rank - 1))
-
-        self.v = self.layer.kernel
-        self.g = self.add_variable(
-            name="g",
-            shape=(self.layer_depth,),
-            initializer=tf.keras.initializers.get('ones'),
-            dtype=self.layer.kernel.dtype,
-            trainable=True)
+        if self.data_init:
+            self._naked_layer = tf.keras.layers.deserialize(
+                tf.keras.layers.serialize(self.layer))
+            self._naked_layer.build(input_shape)
+            self._naked_layer.set_weights(self.layer.get_weights())
+            self._naked_layer.activation = None
 
-        super(WeightNormalization, self).build()
+        self.built = True
 
     def call(self, inputs):
         """Call `Layer`"""
-        if not self._initialized:
-            self._initialize_weights(inputs)
 
-        self._compute_weights()  # Recompute weights for each forward pass
-        output = self.layer(inputs)
-        return output
+        def _do_nothing():
+            return inputs
 
-    def compute_output_shape(self, input_shape):
-        return tf.TensorShape(
-            self.layer.compute_output_shape(input_shape).as_list())
+        def _update_weights():
+            self._initialize_weights(inputs)
+            return inputs
 
-    def _compute_weights(self):
-        """Generate normalized weights.
+        inputs = tf.cond(self._initialized, _do_nothing, _update_weights)
 
-        This method will update the value of self.layer.kernel with the
-        normalized value, so that the layer is ready for call().
-        """
         with tf.name_scope('compute_weights'):
+            # Replace kernel by normalized weight variable.
             self.layer.kernel = tf.nn.l2_normalize(
                 self.v, axis=self.kernel_norm_axes) * self.g
 
+        return self.layer(inputs)
+
+    def compute_output_shape(self, input_shape):
+        return tf.TensorShape(
+            self.layer.compute_output_shape(input_shape).as_list())
+
     def _initialize_weights(self, inputs):
         """Initialize weight g.
 
         The initial value of g could either come from the initial value
         in v, or from the input values if self.data_init is True.
         """
-        if self.data_init:
-            self._data_dep_init(inputs)
-        else:
-            self._init_norm()
-        self._initialized = True
+        with tf.control_dependencies([
+                tf.debugging.assert_equal(  # pylint: disable=bad-continuation
+                    self._initialized,
+                    False,
+                    message='The layer has been initialized.')
+        ]):
+            if self.data_init:
+                self._data_dep_init(inputs)
+            else:
+                self._init_norm()
+            self._initialized.assign(True)
 
     def _init_norm(self):
         """Set the weight g with the norm of the weight vector."""
         with tf.name_scope('init_norm'):
-            flat = tf.reshape(self.v, [-1, self.layer_depth])
-            self.g.assign(
-                tf.reshape(tf.linalg.norm(flat, axis=0), (self.layer_depth,)))
+            v_flat = tf.reshape(self.v, [-1, self.layer_depth])
+            v_norm = tf.linalg.norm(v_flat, axis=0)
+            self.g.assign(tf.reshape(v_norm, (self.layer_depth,)))
 
-    # TODO: Get data init to work with tf_function compile #428
     def _data_dep_init(self, inputs):
         """Data dependent initialization."""
-
         with tf.name_scope('data_dep_init'):
             # Generate data dependent init values
-            existing_activation = self.layer.activation
-            self.layer.activation = None
-            x_init = self.layer(inputs)
+            x_init = self._naked_layer(inputs)
             data_norm_axes = list(range(x_init.shape.rank - 1))
             m_init, v_init = tf.nn.moments(x_init, data_norm_axes)
             scale_init = 1. / tf.math.sqrt(v_init + 1e-10)
 
-            # Assign data dependent init values
-            self.g = self.g * scale_init
-            if hasattr(self.layer, 'bias'):
-                self.layer.bias = -m_init * scale_init
-            self.layer.activation = existing_activation
+            # Assign data dependent init values
+            self.g.assign(self.g * scale_init)
+            if hasattr(self.layer, 'bias'):
+                self.layer.bias.assign(-m_init * scale_init)
 
     def get_config(self):
         config = {'data_init': self.data_init}
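
The core of this change is replacing the Python-level `self._initialized` bool with a boolean `tf.Variable` and branching on it with `tf.cond`, so the one-time (possibly data-dependent) initialization still fires when `call()` is traced by `tf.function`: a Python bool is frozen at trace time, while a variable is read at runtime. Below is a minimal, self-contained sketch of that pattern; the `OneTimeInit` layer and its `scale` variable are illustrative names, not part of this diff.

```python
import tensorflow as tf


class OneTimeInit(tf.keras.layers.Layer):
    """Runs a data-dependent init exactly once, even under tf.function."""

    def build(self, input_shape):
        # A graph-visible flag instead of a Python bool: tf.cond branches
        # on its runtime value rather than on its trace-time value.
        self._initialized = self.add_weight(
            name='initialized', shape=None, dtype=tf.bool,
            initializer='zeros', trainable=False)
        self.scale = self.add_weight(
            name='scale', shape=(), initializer='ones', trainable=False)
        self.built = True

    def call(self, inputs):
        def _do_nothing():
            return inputs

        def _update_weights():
            # Data-dependent init; executes only on the first call.
            self.scale.assign(1. / (tf.math.reduce_std(inputs) + 1e-10))
            self._initialized.assign(True)
            return inputs

        inputs = tf.cond(self._initialized, _do_nothing, _update_weights)
        return inputs * self.scale
```

Both branches of `tf.cond` are traced, but the assignments in `_update_weights` only execute when the flag is still False at runtime, which is why the diff also guards `_initialize_weights` with an `assert_equal` on `self._initialized`.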
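For reference, a usage sketch of the wrapper after this change, assuming it is exposed as `tfa.layers.WeightNormalization` in the `tensorflow_addons` package. With `data_init=True`, the first forward pass performs the data-dependent initialization even inside a compiled function, which the removed TODO (#428) previously ruled out.

```python
import tensorflow as tf
import tensorflow_addons as tfa

model = tf.keras.Sequential([
    tfa.layers.WeightNormalization(
        tf.keras.layers.Dense(16, activation='relu'), data_init=True),
    tf.keras.layers.Dense(1),
])


@tf.function  # data init previously broke under tf.function (see #428)
def forward(x):
    return model(x)


print(forward(tf.random.normal([8, 4])).shape)  # => (8, 1)
```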