@@ -346,6 +346,52 @@ def test_groupnorm_convnet_no_center_no_scale():
346346 )
347347
348348
@pytest.mark.usefixtures("maybe_run_functions_eagerly")
@pytest.mark.parametrize("center", [True, False])
@pytest.mark.parametrize("scale", [True, False])
def test_group_norm_compute_output_shape(center, scale):
    """GroupNormalization must create the same variables no matter how it is built.

    build(), compute_output_shape(), and __call__ should each leave the layer
    with one variable per enabled affine parameter (beta for ``center``,
    gamma for ``scale``), and all of them trainable.
    """
    # One weight per enabled affine transform: beta (center) and gamma (scale).
    expected_len = int(center) + int(scale)

    # Exercise the three construction paths; each must yield identical variables.
    for build_layer in (
        lambda layer: layer.build(input_shape=[8, 28, 28, 16]),  # build()
        lambda layer: layer.compute_output_shape(
            input_shape=[8, 28, 28, 16]
        ),  # compute_output_shape()
        lambda layer: layer(tf.random.normal(shape=[8, 28, 28, 16])),  # call()
    ):
        layer = GroupNormalization(groups=2, center=center, scale=scale)
        build_layer(layer)
        assert len(layer.variables) == expected_len
        assert len(layer.trainable_variables) == expected_len
371+
@pytest.mark.usefixtures("maybe_run_functions_eagerly")
@pytest.mark.parametrize("center", [True, False])
@pytest.mark.parametrize("scale", [True, False])
def test_instance_norm_compute_output_shape(center, scale):
    """InstanceNormalization must create the same variables no matter how it is built.

    build(), compute_output_shape(), and __call__ should each leave the layer
    with one variable per enabled affine parameter (beta for ``center``,
    gamma for ``scale``), and all of them trainable.
    """
    # One weight per enabled affine transform: beta (center) and gamma (scale).
    expected_len = int(center) + int(scale)

    # Exercise the three construction paths; each must yield identical variables.
    for build_layer in (
        lambda layer: layer.build(input_shape=[8, 28, 28, 16]),  # build()
        lambda layer: layer.compute_output_shape(
            input_shape=[8, 28, 28, 16]
        ),  # compute_output_shape()
        lambda layer: layer(tf.random.normal(shape=[8, 28, 28, 16])),  # call()
    ):
        layer = InstanceNormalization(groups=2, center=center, scale=scale)
        build_layer(layer)
        assert len(layer.variables) == expected_len
        assert len(layer.trainable_variables) == expected_len
394+
349395def calculate_frn (
350396 x , beta = 0.2 , gamma = 1 , eps = 1e-6 , learned_epsilon = False , dtype = np .float32
351397):
@@ -471,3 +517,23 @@ def test_filter_response_normalization_save(tmpdir):
471517 model .save (filepath , save_format = "h5" )
472518 filepath = str (tmpdir / "test" )
473519 model .save (filepath , save_format = "tf" )
520+
521+
@pytest.mark.usefixtures("maybe_run_functions_eagerly")
def test_filter_response_norm_compute_output_shape():
    """FilterResponseNormalization must create the same variables however built.

    build(), compute_output_shape(), and __call__ should each leave the layer
    with exactly two trainable variables (beta and gamma).
    """
    # FRN always learns both beta and gamma.
    expected_len = 2

    # Exercise the three construction paths; each must yield identical variables.
    for build_layer in (
        lambda layer: layer.build(input_shape=[8, 28, 28, 16]),  # build()
        lambda layer: layer.compute_output_shape(
            input_shape=[8, 28, 28, 16]
        ),  # compute_output_shape()
        lambda layer: layer(tf.random.normal(shape=[8, 28, 28, 16])),  # call()
    ):
        layer = FilterResponseNormalization()
        build_layer(layer)
        assert len(layer.variables) == expected_len
        assert len(layer.trainable_variables) == expected_len
0 commit comments