
Commit cc89403

facaiy authored and seanpmorgan committed

WIP: Fix: WeightNormalization data init fails (#453)

* TST: test cases pass
* BUG: fix related bugs
* BUG: fix test_weightnorm_keras

1 parent d9ed9d0 commit cc89403
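
For context, `WeightNormalization` reparameterizes the wrapped layer's kernel as w = g * v / ||v||: `v` carries the direction and `g` the per-unit scale, and the patched `call` below rebuilds the kernel on every forward pass as `tf.nn.l2_normalize(self.v, axis=self.kernel_norm_axes) * self.g`. A minimal sketch of that identity (the kernel shape is illustrative, not taken from the diff):

import tensorflow as tf

# Sketch of the reparameterization w = g * v / ||v||. The norm runs over
# every axis except the last, matching `kernel_norm_axes` in the diff.
v = tf.random.normal([4, 3])            # direction (the wrapped layer's kernel)
g = tf.norm(v, axis=0)                  # per-output-unit scale, shape (3,)
w = tf.nn.l2_normalize(v, axis=0) * g   # reconstructed kernel
tf.debugging.assert_near(w, v)          # initializing g to ||v|| recovers v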

File tree

2 files changed: +64 -58 lines changed

tensorflow_addons/layers/wrappers.py

Lines changed: 61 additions & 48 deletions
@@ -58,7 +58,6 @@ class WeightNormalization(tf.keras.layers.Wrapper):
     def __init__(self, layer, data_init=True, **kwargs):
         super(WeightNormalization, self).__init__(layer, **kwargs)
         self.data_init = data_init
-        self._initialized = False
         self._track_trackable(layer, name='layer')
 
     def build(self, input_shape):
@@ -69,85 +68,99 @@ def build(self, input_shape):
         if not self.layer.built:
             self.layer.build(input_shape)
 
-        if not hasattr(self.layer, 'kernel'):
-            raise ValueError('`WeightNormalization` must wrap a layer that'
-                             ' contains a `kernel` for weights')
+        if not hasattr(self.layer, 'kernel'):
+            raise ValueError('`WeightNormalization` must wrap a layer that'
+                             ' contains a `kernel` for weights')
+
+        # The kernel's filter or unit dimension is -1
+        self.layer_depth = int(self.layer.kernel.shape[-1])
+        self.kernel_norm_axes = list(range(self.layer.kernel.shape.rank - 1))
+
+        self.g = self.add_variable(
+            name='g',
+            shape=(self.layer_depth,),
+            initializer='ones',
+            dtype=self.layer.kernel.dtype,
+            trainable=True)
+        self.v = self.layer.kernel
+
+        self._initialized = self.add_variable(
+            name='initialized',
+            shape=None,
+            initializer='zeros',
+            dtype=tf.dtypes.bool,
+            trainable=False)
 
-        # The kernel's filter or unit dimension is -1
-        self.layer_depth = int(self.layer.kernel.shape[-1])
-        self.kernel_norm_axes = list(
-            range(self.layer.kernel.shape.rank - 1))
-
-        self.v = self.layer.kernel
-        self.g = self.add_variable(
-            name="g",
-            shape=(self.layer_depth,),
-            initializer=tf.keras.initializers.get('ones'),
-            dtype=self.layer.kernel.dtype,
-            trainable=True)
+        if self.data_init:
+            self._naked_layer = tf.keras.layers.deserialize(
+                tf.keras.layers.serialize(self.layer))
+            self._naked_layer.build(input_shape)
+            self._naked_layer.set_weights(self.layer.get_weights())
+            self._naked_layer.activation = None
 
-        super(WeightNormalization, self).build()
+        self.built = True
 
     def call(self, inputs):
         """Call `Layer`"""
-        if not self._initialized:
-            self._initialize_weights(inputs)
 
-        self._compute_weights()  # Recompute weights for each forward pass
-        output = self.layer(inputs)
-        return output
+        def _do_nothing():
+            return inputs
 
-    def compute_output_shape(self, input_shape):
-        return tf.TensorShape(
-            self.layer.compute_output_shape(input_shape).as_list())
+        def _update_weights():
+            self._initialize_weights(inputs)
+            return inputs
 
-    def _compute_weights(self):
-        """Generate normalized weights.
+        inputs = tf.cond(self._initialized, _do_nothing, _update_weights)
 
-        This method will update the value of self.layer.kernel with the
-        normalized value, so that the layer is ready for call().
-        """
         with tf.name_scope('compute_weights'):
+            # Replace kernel by normalized weight variable.
             self.layer.kernel = tf.nn.l2_normalize(
                 self.v, axis=self.kernel_norm_axes) * self.g
 
+        return self.layer(inputs)
+
+    def compute_output_shape(self, input_shape):
+        return tf.TensorShape(
+            self.layer.compute_output_shape(input_shape).as_list())
+
     def _initialize_weights(self, inputs):
         """Initialize weight g.
 
         The initial value of g could either from the initial value in v,
         or by the input value if self.data_init is True.
         """
-        if self.data_init:
-            self._data_dep_init(inputs)
-        else:
-            self._init_norm()
-        self._initialized = True
+        with tf.control_dependencies([
+                tf.debugging.assert_equal(  # pylint: disable=bad-continuation
+                    self._initialized,
+                    False,
+                    message='The layer has been initialized.')
+        ]):
+            if self.data_init:
+                self._data_dep_init(inputs)
+            else:
+                self._init_norm()
+            self._initialized.assign(True)
 
     def _init_norm(self):
         """Set the weight g with the norm of the weight vector."""
         with tf.name_scope('init_norm'):
-            flat = tf.reshape(self.v, [-1, self.layer_depth])
-            self.g.assign(
-                tf.reshape(tf.linalg.norm(flat, axis=0), (self.layer_depth,)))
+            v_flat = tf.reshape(self.v, [-1, self.layer_depth])
+            v_norm = tf.linalg.norm(v_flat, axis=0)
+            self.g.assign(tf.reshape(v_norm, (self.layer_depth,)))
 
-    # TODO: Get data init to work with tf_function compile #428
     def _data_dep_init(self, inputs):
         """Data dependent initialization."""
-
         with tf.name_scope('data_dep_init'):
             # Generate data dependent init values
-            existing_activation = self.layer.activation
-            self.layer.activation = None
-            x_init = self.layer(inputs)
+            x_init = self._naked_layer(inputs)
             data_norm_axes = list(range(x_init.shape.rank - 1))
             m_init, v_init = tf.nn.moments(x_init, data_norm_axes)
             scale_init = 1. / tf.math.sqrt(v_init + 1e-10)
 
-            # Assign data dependent init values
-            self.g = self.g * scale_init
-            if hasattr(self.layer, 'bias'):
-                self.layer.bias = -m_init * scale_init
-            self.layer.activation = existing_activation
+            # Assign data dependent init values
+            self.g.assign(self.g * scale_init)
+            if hasattr(self.layer, 'bias'):
+                self.layer.bias.assign(-m_init * scale_init)
 
     def get_config(self):
         config = {'data_init': self.data_init}
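
The heart of the fix is in `build` and `call` above: the Python-level `self._initialized` flag becomes a boolean `tf.Variable`, and `call` branches on it with `tf.cond`, so the data-dependent initialization runs exactly once even inside a compiled `tf.function` graph, where a plain Python flag would be frozen at trace time. A minimal sketch of that pattern; the class and helper names (`OneTimeScale`, `_update`) are hypothetical, not part of the patch:

import tensorflow as tf

class OneTimeScale(tf.Module):
    """Toy module showing the graph-safe one-time-init pattern."""

    def __init__(self):
        # A graph-visible flag, like the patch's `initialized` variable:
        # tf.cond re-checks it on every call, unlike a Python bool.
        self._initialized = tf.Variable(False, trainable=False)
        self.g = tf.Variable(1.0)

    @tf.function
    def __call__(self, x):
        def _do_nothing():
            return x

        def _update():
            # Data-dependent init: the assigns execute only when this
            # branch of tf.cond is taken, i.e. on the first call.
            self.g.assign(tf.math.reduce_std(x) + 1e-10)
            self._initialized.assign(True)
            return x

        x = tf.cond(self._initialized, _do_nothing, _update)
        return x / self.g

The `_naked_layer` built via a serialize/deserialize round-trip serves the same goal: it lets `_data_dep_init` measure the wrapped layer's pre-activation output without toggling `self.layer.activation` from inside the graph.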

tensorflow_addons/layers/wrappers_test.py

Lines changed: 3 additions & 10 deletions
@@ -26,16 +26,14 @@
 
 @test_utils.run_all_in_graph_and_eager_modes
 class WeightNormalizationTest(tf.test.TestCase):
-    # TODO: Get data init to work with tf_function compile #428
     def test_weightnorm_dense_train(self):
         model = tf.keras.models.Sequential()
         model.add(
             wrappers.WeightNormalization(
                 tf.keras.layers.Dense(2), input_shape=(3, 4)))
         model.compile(
             optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.001),
-            loss='mse',
-            experimental_run_tf_function=False)
+            loss='mse')
         model.fit(
             np.random.random((10, 3, 4)),
             np.random.random((10, 3, 2)),
@@ -60,7 +58,6 @@ def test_weightnorm_dense_train_notinit(self):
         self.assertTrue(hasattr(model.layers[0], 'g'))
 
     def test_weightnorm_conv2d(self):
-        # TODO: Get data init to work with tf_function compile #428
         model = tf.keras.models.Sequential()
         model.add(
             wrappers.WeightNormalization(
@@ -70,8 +67,7 @@ def test_weightnorm_conv2d(self):
         model.add(tf.keras.layers.Activation('relu'))
         model.compile(
             optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.001),
-            loss='mse',
-            experimental_run_tf_function=False)
+            loss='mse')
         model.fit(
             np.random.random((2, 4, 4, 3)),
             np.random.random((2, 4, 4, 5)),
@@ -105,10 +101,7 @@ def test_weightnorm_keras(self):
                 'layer': tf.keras.layers.Dense(2),
                 'input_shape': (3, 4)
             },
-            input_data=input_data,
-            # TODO: Fix the bug thats causing layer test to run a
-            # graph Tensor in eager mode.
-            validate_training=False)
+            input_data=input_data)
 
 
 if __name__ == "__main__":
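
With the fix, the wrapper trains under Keras's default `tf.function` execution path, which is why the `experimental_run_tf_function=False` and `validate_training=False` workarounds are dropped above. A usage sketch mirroring `test_weightnorm_dense_train`; the `epochs` and `batch_size` values are illustrative, not from the diff:

import numpy as np
import tensorflow as tf
from tensorflow_addons.layers import wrappers

# Wrap a Dense layer; data-dependent init runs once, on the first batch.
model = tf.keras.models.Sequential([
    wrappers.WeightNormalization(
        tf.keras.layers.Dense(2), input_shape=(3, 4)),
])
model.compile(
    optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.001),
    loss='mse')  # no experimental_run_tf_function=False workaround needed
model.fit(np.random.random((10, 3, 4)),
          np.random.random((10, 3, 2)),
          epochs=2, batch_size=5)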
