2626
2727@test_utils .run_all_in_graph_and_eager_modes
2828class WeightNormalizationTest (tf .test .TestCase ):
29+ # TODO: Get data init to work with tf_function compile #428
2930 def test_weightnorm_dense_train (self ):
3031 model = tf .keras .models .Sequential ()
3132 model .add (
3233 wrappers .WeightNormalization (
3334 tf .keras .layers .Dense (2 ), input_shape = (3 , 4 )))
3435 model .compile (
3536 optimizer = tf .keras .optimizers .RMSprop (learning_rate = 0.001 ),
36- loss = 'mse' )
37+ loss = 'mse' ,
38+ experimental_run_tf_function = False )
3739 model .fit (
3840 np .random .random ((10 , 3 , 4 )),
3941 np .random .random ((10 , 3 , 2 )),
@@ -58,6 +60,7 @@ def test_weightnorm_dense_train_notinit(self):
5860 self .assertTrue (hasattr (model .layers [0 ], 'g' ))
5961
6062 def test_weightnorm_conv2d (self ):
63+ # TODO: Get data init to work with tf_function compile #428
6164 model = tf .keras .models .Sequential ()
6265 model .add (
6366 wrappers .WeightNormalization (
@@ -67,7 +70,8 @@ def test_weightnorm_conv2d(self):
6770 model .add (tf .keras .layers .Activation ('relu' ))
6871 model .compile (
6972 optimizer = tf .keras .optimizers .RMSprop (learning_rate = 0.001 ),
70- loss = 'mse' )
73+ loss = 'mse' ,
74+ experimental_run_tf_function = False )
7175 model .fit (
7276 np .random .random ((2 , 4 , 4 , 3 )),
7377 np .random .random ((2 , 4 , 4 , 5 )),
0 commit comments