Skip to content

Commit cada1c9

Browse files
seanpmorganWindQAQ
authored andcommitted
Hotfix: Compile Keras WN models without tf_function (#429)
1 parent 651a227 commit cada1c9

File tree

2 files changed

+7
-3
lines changed

2 files changed

+7
-3
lines changed

tensorflow_addons/layers/wrappers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,6 @@ def build(self, input_shape):
8888

8989
super(WeightNormalization, self).build()
9090

91-
@tf.function
9291
def call(self, inputs):
9392
"""Call `Layer`"""
9493
if not self._initialized:
@@ -131,6 +130,7 @@ def _init_norm(self):
131130
self.g.assign(
132131
tf.reshape(tf.linalg.norm(flat, axis=0), (self.layer_depth,)))
133132

133+
# TODO: Get data init to work with tf_function compile #428
134134
def _data_dep_init(self, inputs):
135135
"""Data dependent initialization."""
136136

tensorflow_addons/layers/wrappers_test.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,16 @@
2626

2727
@test_utils.run_all_in_graph_and_eager_modes
2828
class WeightNormalizationTest(tf.test.TestCase):
29+
# TODO: Get data init to work with tf_function compile #428
2930
def test_weightnorm_dense_train(self):
3031
model = tf.keras.models.Sequential()
3132
model.add(
3233
wrappers.WeightNormalization(
3334
tf.keras.layers.Dense(2), input_shape=(3, 4)))
3435
model.compile(
3536
optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.001),
36-
loss='mse')
37+
loss='mse',
38+
experimental_run_tf_function=False)
3739
model.fit(
3840
np.random.random((10, 3, 4)),
3941
np.random.random((10, 3, 2)),
@@ -58,6 +60,7 @@ def test_weightnorm_dense_train_notinit(self):
5860
self.assertTrue(hasattr(model.layers[0], 'g'))
5961

6062
def test_weightnorm_conv2d(self):
63+
# TODO: Get data init to work with tf_function compile #428
6164
model = tf.keras.models.Sequential()
6265
model.add(
6366
wrappers.WeightNormalization(
@@ -67,7 +70,8 @@ def test_weightnorm_conv2d(self):
6770
model.add(tf.keras.layers.Activation('relu'))
6871
model.compile(
6972
optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.001),
70-
loss='mse')
73+
loss='mse',
74+
experimental_run_tf_function=False)
7175
model.fit(
7276
np.random.random((2, 4, 4, 3)),
7377
np.random.random((2, 4, 4, 5)),

0 commit comments

Comments
 (0)