Skip to content

Commit 377edb7

Browse files
authored
FIX: weightnorm variables (#219)
* FIX: weightnorm variables
1 parent f45f606 commit 377edb7

File tree

2 files changed

+4
-3
lines changed

2 files changed

+4
-3
lines changed

tensorflow_addons/layers/wrappers.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ def build(self, input_shape):
8888

8989
super(WeightNormalization, self).build()
9090

91+
@tf.function
9192
def call(self, inputs):
9293
"""Call `Layer`"""
9394
if not self._initialized:
@@ -143,9 +144,9 @@ def _data_dep_init(self, inputs):
143144
scale_init = 1. / tf.math.sqrt(v_init + 1e-10)
144145

145146
# Assign data dependent init values
146-
self.g.assign(self.g * scale_init)
147+
self.g = self.g * scale_init
147148
if hasattr(self.layer, 'bias'):
148-
self.layer.bias.assign(-m_init * scale_init)
149+
self.layer.bias = -m_init * scale_init
149150
self.layer.activation = existing_activation
150151

151152
def get_config(self):

tensorflow_addons/layers/wrappers_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def test_weightnorm_conv2d(self):
7575

7676
self.assertTrue(hasattr(model.layers[0], 'g'))
7777

78-
def test_weightnorm_tflayers(self):
78+
def test_weightnorm_applylayer(self):
7979
images = tf.random.uniform((2, 4, 4, 3))
8080
wn_wrapper = wrappers.WeightNormalization(
8181
tf.keras.layers.Conv2D(32, [2, 2]), input_shape=(4, 4, 3))

0 commit comments

Comments
 (0)