Skip to content

Commit 18c8367

Browse files
authored
fix compute_output_shape behavior (#2678)
1 parent da14c3b commit 18c8367

File tree

2 files changed

+66
-6
lines changed

2 files changed

+66
-6
lines changed

tensorflow_addons/layers/normalizations.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -155,9 +155,6 @@ def get_config(self):
155155
base_config = super().get_config()
156156
return {**base_config, **config}
157157

158-
def compute_output_shape(self, input_shape):
159-
return input_shape
160-
161158
def _reshape_into_groups(self, inputs, input_shape, tensor_input_shape):
162159

163160
group_shape = [tensor_input_shape[i] for i in range(len(input_shape))]
@@ -447,9 +444,6 @@ def call(self, inputs):
447444
normalized_inputs = inputs * tf.math.rsqrt(nu2 + epsilon)
448445
return self.gamma * normalized_inputs + self.beta
449446

450-
def compute_output_shape(self, input_shape):
451-
return input_shape
452-
453447
def get_config(self):
454448
config = {
455449
"axis": self.axis,

tensorflow_addons/layers/tests/normalizations_test.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,52 @@ def test_groupnorm_convnet_no_center_no_scale():
346346
)
347347

348348

349+
@pytest.mark.usefixtures("maybe_run_functions_eagerly")
350+
@pytest.mark.parametrize("center", [True, False])
351+
@pytest.mark.parametrize("scale", [True, False])
352+
def test_group_norm_compute_output_shape(center, scale):
353+
354+
target_variables_len = [center, scale].count(True)
355+
target_trainable_variables_len = [center, scale].count(True)
356+
layer1 = GroupNormalization(groups=2, center=center, scale=scale)
357+
layer1.build(input_shape=[8, 28, 28, 16]) # build()
358+
assert len(layer1.variables) == target_variables_len
359+
assert len(layer1.trainable_variables) == target_trainable_variables_len
360+
361+
layer2 = GroupNormalization(groups=2, center=center, scale=scale)
362+
layer2.compute_output_shape(input_shape=[8, 28, 28, 16]) # compute_output_shape()
363+
assert len(layer2.variables) == target_variables_len
364+
assert len(layer2.trainable_variables) == target_trainable_variables_len
365+
366+
layer3 = GroupNormalization(groups=2, center=center, scale=scale)
367+
layer3(tf.random.normal(shape=[8, 28, 28, 16])) # call()
368+
assert len(layer3.variables) == target_variables_len
369+
assert len(layer3.trainable_variables) == target_trainable_variables_len
370+
371+
372+
@pytest.mark.usefixtures("maybe_run_functions_eagerly")
373+
@pytest.mark.parametrize("center", [True, False])
374+
@pytest.mark.parametrize("scale", [True, False])
375+
def test_instance_norm_compute_output_shape(center, scale):
376+
377+
target_variables_len = [center, scale].count(True)
378+
target_trainable_variables_len = [center, scale].count(True)
379+
layer1 = InstanceNormalization(groups=2, center=center, scale=scale)
380+
layer1.build(input_shape=[8, 28, 28, 16]) # build()
381+
assert len(layer1.variables) == target_variables_len
382+
assert len(layer1.trainable_variables) == target_trainable_variables_len
383+
384+
layer2 = InstanceNormalization(groups=2, center=center, scale=scale)
385+
layer2.compute_output_shape(input_shape=[8, 28, 28, 16]) # compute_output_shape()
386+
assert len(layer2.variables) == target_variables_len
387+
assert len(layer2.trainable_variables) == target_trainable_variables_len
388+
389+
layer3 = InstanceNormalization(groups=2, center=center, scale=scale)
390+
layer3(tf.random.normal(shape=[8, 28, 28, 16])) # call()
391+
assert len(layer3.variables) == target_variables_len
392+
assert len(layer3.trainable_variables) == target_trainable_variables_len
393+
394+
349395
def calculate_frn(
350396
x, beta=0.2, gamma=1, eps=1e-6, learned_epsilon=False, dtype=np.float32
351397
):
@@ -471,3 +517,23 @@ def test_filter_response_normalization_save(tmpdir):
471517
model.save(filepath, save_format="h5")
472518
filepath = str(tmpdir / "test")
473519
model.save(filepath, save_format="tf")
520+
521+
522+
@pytest.mark.usefixtures("maybe_run_functions_eagerly")
523+
def test_filter_response_norm_compute_output_shape():
524+
target_variables_len = 2
525+
target_trainable_variables_len = 2
526+
layer1 = FilterResponseNormalization()
527+
layer1.build(input_shape=[8, 28, 28, 16]) # build()
528+
assert len(layer1.variables) == target_variables_len
529+
assert len(layer1.trainable_variables) == target_trainable_variables_len
530+
531+
layer2 = FilterResponseNormalization()
532+
layer2.compute_output_shape(input_shape=[8, 28, 28, 16]) # compute_output_shape()
533+
assert len(layer2.variables) == target_variables_len
534+
assert len(layer2.trainable_variables) == target_trainable_variables_len
535+
536+
layer3 = FilterResponseNormalization()
537+
layer3(tf.random.normal(shape=[8, 28, 28, 16])) # call()
538+
assert len(layer3.variables) == target_variables_len
539+
assert len(layer3.trainable_variables) == target_trainable_variables_len

0 commit comments

Comments (0)