Skip to content

Commit 7e0f343

Browse files
facaiy authored and WindQAQ committed
bugfix: force WeightNormalization to execute in order (#458)
* BUG: force code to execute in order * CLN: clean codes * TST: clean
1 parent f540233 commit 7e0f343

File tree

2 files changed

+72
-90
lines changed

2 files changed

+72
-90
lines changed

tensorflow_addons/layers/wrappers.py

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -92,32 +92,39 @@ def build(self, input_shape):
9292
trainable=False)
9393

9494
if self.data_init:
95-
self._naked_layer = tf.keras.layers.deserialize(
96-
tf.keras.layers.serialize(self.layer))
97-
self._naked_layer.build(input_shape)
98-
self._naked_layer.set_weights(self.layer.get_weights())
99-
self._naked_layer.activation = None
95+
# Used for data initialization in self._data_dep_init.
96+
layer_config = tf.keras.layers.serialize(self.layer)
97+
layer_config['config']['trainable'] = False
98+
self._naked_clone_layer = tf.keras.layers.deserialize(layer_config)
99+
self._naked_clone_layer.build(input_shape)
100+
self._naked_clone_layer.set_weights(self.layer.get_weights())
101+
self._naked_clone_layer.activation = None
100102

101103
self.built = True
102104

103105
def call(self, inputs):
104106
"""Call `Layer`"""
105107

106108
def _do_nothing():
107-
return inputs
109+
return tf.identity(self.g)
108110

109111
def _update_weights():
110-
self._initialize_weights(inputs)
111-
return inputs
112+
# Ensure we read `self.g` after _update_weights.
113+
with tf.control_dependencies(self._initialize_weights(inputs)):
114+
return tf.identity(self.g)
112115

113-
inputs = tf.cond(self._initialized, _do_nothing, _update_weights)
116+
g = tf.cond(self._initialized, _do_nothing, _update_weights)
114117

115118
with tf.name_scope('compute_weights'):
116119
# Replace kernel by normalized weight variable.
117120
self.layer.kernel = tf.nn.l2_normalize(
118-
self.v, axis=self.kernel_norm_axes) * self.g
121+
self.v, axis=self.kernel_norm_axes) * g
119122

120-
return self.layer(inputs)
123+
# Ensure we calculate result after updating kernel.
124+
update_kernel = tf.identity(self.layer.kernel)
125+
with tf.control_dependencies([update_kernel]):
126+
outputs = self.layer(inputs)
127+
return outputs
121128

122129
def compute_output_shape(self, input_shape):
123130
return tf.TensorShape(
@@ -136,31 +143,36 @@ def _initialize_weights(self, inputs):
136143
message='The layer has been initialized.')
137144
]):
138145
if self.data_init:
139-
self._data_dep_init(inputs)
146+
assign_tensors = self._data_dep_init(inputs)
140147
else:
141-
self._init_norm()
142-
self._initialized.assign(True)
148+
assign_tensors = self._init_norm()
149+
assign_tensors.append(self._initialized.assign(True))
150+
return assign_tensors
143151

144152
def _init_norm(self):
145153
"""Set the weight g with the norm of the weight vector."""
146154
with tf.name_scope('init_norm'):
147155
v_flat = tf.reshape(self.v, [-1, self.layer_depth])
148156
v_norm = tf.linalg.norm(v_flat, axis=0)
149-
self.g.assign(tf.reshape(v_norm, (self.layer_depth,)))
157+
g_tensor = self.g.assign(tf.reshape(v_norm, (self.layer_depth,)))
158+
return [g_tensor]
150159

151160
def _data_dep_init(self, inputs):
152161
"""Data dependent initialization."""
153162
with tf.name_scope('data_dep_init'):
154163
# Generate data dependent init values
155-
x_init = self._naked_layer(inputs)
164+
x_init = self._naked_clone_layer(inputs)
156165
data_norm_axes = list(range(x_init.shape.rank - 1))
157166
m_init, v_init = tf.nn.moments(x_init, data_norm_axes)
158167
scale_init = 1. / tf.math.sqrt(v_init + 1e-10)
159168

160169
# Assign data dependent init values
161-
self.g.assign(self.g * scale_init)
170+
g_tensor = self.g.assign(self.g * scale_init)
162171
if hasattr(self.layer, 'bias'):
163-
self.layer.bias.assign(-m_init * scale_init)
172+
bias_tensor = self.layer.bias.assign(-m_init * scale_init)
173+
return [g_tensor, bias_tensor]
174+
else:
175+
return [g_tensor]
164176

165177
def get_config(self):
166178
config = {'data_init': self.data_init}

tensorflow_addons/layers/wrappers_test.py

Lines changed: 42 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -26,82 +26,52 @@
2626

2727
@test_utils.run_all_in_graph_and_eager_modes
2828
class WeightNormalizationTest(tf.test.TestCase):
29-
def test_weightnorm_dense_train(self):
30-
model = tf.keras.models.Sequential()
31-
model.add(
32-
wrappers.WeightNormalization(
33-
tf.keras.layers.Dense(2), input_shape=(3, 4)))
34-
model.compile(
35-
optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.001),
36-
loss='mse')
37-
model.fit(
38-
np.random.random((10, 3, 4)),
39-
np.random.random((10, 3, 2)),
40-
epochs=3,
41-
batch_size=10)
42-
self.assertTrue(hasattr(model.layers[0], 'g'))
43-
44-
def test_weightnorm_dense_train_notinit(self):
45-
model = tf.keras.models.Sequential()
46-
model.add(
47-
wrappers.WeightNormalization(
48-
tf.keras.layers.Dense(2), input_shape=(3, 4), data_init=False))
49-
50-
model.compile(
51-
optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.001),
52-
loss='mse')
53-
model.fit(
54-
np.random.random((10, 3, 4)),
55-
np.random.random((10, 3, 2)),
56-
epochs=3,
57-
batch_size=10)
58-
self.assertTrue(hasattr(model.layers[0], 'g'))
59-
60-
def test_weightnorm_conv2d(self):
61-
model = tf.keras.models.Sequential()
62-
model.add(
63-
wrappers.WeightNormalization(
64-
tf.keras.layers.Conv2D(5, (2, 2), padding='same'),
65-
input_shape=(4, 4, 3)))
66-
67-
model.add(tf.keras.layers.Activation('relu'))
68-
model.compile(
69-
optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.001),
70-
loss='mse')
71-
model.fit(
72-
np.random.random((2, 4, 4, 3)),
73-
np.random.random((2, 4, 4, 5)),
74-
epochs=3,
75-
batch_size=10)
76-
77-
self.assertTrue(hasattr(model.layers[0], 'g'))
78-
79-
def test_weightnorm_applylayer(self):
80-
images = tf.random.uniform((2, 4, 4, 3))
81-
wn_wrapper = wrappers.WeightNormalization(
82-
tf.keras.layers.Conv2D(32, [2, 2]), input_shape=(4, 4, 3))
83-
wn_wrapper.apply(images)
84-
self.assertTrue(hasattr(wn_wrapper, 'g'))
85-
86-
def test_weightnorm_nonlayer(self):
87-
images = tf.random.uniform((2, 4, 43))
88-
with self.assertRaises(AssertionError):
89-
wrappers.WeightNormalization(images)
90-
91-
def test_weightnorm_nokernel(self):
92-
with self.assertRaises(ValueError):
93-
wrappers.WeightNormalization(tf.keras.layers.MaxPooling2D(
94-
2, 2)).build((2, 2))
95-
96-
def test_weightnorm_keras(self):
97-
input_data = np.random.random((10, 3, 4)).astype(np.float32)
29+
def test_weightnorm(self):
30+
test_utils.layer_test(
31+
wrappers.WeightNormalization,
32+
kwargs={
33+
'layer': tf.keras.layers.Conv2D(5, (2, 2)),
34+
},
35+
input_shape=(2, 4, 4, 3))
36+
37+
def _check_data_init(self, data_init, input_data, expected_output):
38+
layer = tf.keras.layers.Dense(
39+
input_data.shape[-1],
40+
activation=None,
41+
kernel_initializer='identity',
42+
bias_initializer='zeros')
9843
test_utils.layer_test(
9944
wrappers.WeightNormalization,
10045
kwargs={
101-
'layer': tf.keras.layers.Dense(2),
102-
'input_shape': (3, 4)
46+
'layer': layer,
47+
'data_init': data_init,
10348
},
104-
input_data=input_data)
49+
input_data=input_data,
50+
expected_output=expected_output)
51+
52+
def test_weightnorm_with_data_init_is_false(self):
53+
input_data = np.array([[[-4, -4], [4, 4]]], dtype=np.float32)
54+
self._check_data_init(
55+
data_init=False, input_data=input_data, expected_output=input_data)
56+
57+
def test_weightnorm_with_data_init_is_true(self):
58+
input_data = np.array([[[-4, -4], [4, 4]]], dtype=np.float32)
59+
self._check_data_init(
60+
data_init=True,
61+
input_data=input_data,
62+
expected_output=input_data / 4)
63+
64+
def test_weightnorm_non_layer(self):
65+
images = tf.random.uniform((2, 4, 43))
66+
with self.assertRaises(AssertionError):
67+
wrappers.WeightNormalization(images)
68+
69+
def test_weightnorm_non_kernel_layer(self):
70+
images = tf.random.uniform((2, 2, 2))
71+
with self.assertRaisesRegexp(ValueError, 'contains a `kernel`'):
72+
non_kernel_layer = tf.keras.layers.MaxPooling2D(2, 2)
73+
wn_wrapper = wrappers.WeightNormalization(non_kernel_layer)
74+
wn_wrapper(images)
10575

10676

10777
if __name__ == "__main__":

0 commit comments

Comments (0)