Skip to content

Commit 44065d2

Browse files
Fixed one of the tests of cyclycal learning rate. (#1606)
1 parent 6be19d2 commit 44065d2

File tree

1 file changed

+28
-29
lines changed

1 file changed

+28
-29
lines changed

tensorflow_addons/optimizers/tests/cyclical_learning_rate_test.py

Lines changed: 28 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from absl.testing import parameterized
1818

19+
import pytest
1920
import tensorflow as tf
2021
from tensorflow_addons.utils import test_utils
2122
import numpy as np
@@ -25,42 +26,40 @@
2526

2627
def _maybe_serialized(lr_decay, serialize_and_deserialize):
2728
if serialize_and_deserialize:
28-
serialized = tf.keras.optimizers.learning_rate_schedule.serialize(lr_decay)
29-
return tf.keras.optimizers.learning_rate_schedule.deserialize(serialized)
29+
serialized = tf.keras.optimizers.schedules.serialize(lr_decay)
30+
return tf.keras.optimizers.schedules.deserialize(serialized)
3031
else:
3132
return lr_decay
3233

3334

34-
@test_utils.run_all_in_graph_and_eager_modes
35-
@parameterized.named_parameters(("NotSerialized", False), ("Serialized", True))
36-
class CyclicalLearningRateTest(tf.test.TestCase, parameterized.TestCase):
37-
def testTriangularCyclicalLearningRate(self, serialize):
38-
self.skipTest("Failing. See https://github.com/tensorflow/addons/issues/1203")
39-
initial_learning_rate = 0.1
40-
maximal_learning_rate = 1
41-
step_size = 4000
42-
step = tf.resource_variable_ops.ResourceVariable(0)
43-
triangular_cyclical_lr = cyclical_learning_rate.TriangularCyclicalLearningRate(
44-
initial_learning_rate=initial_learning_rate,
45-
maximal_learning_rate=maximal_learning_rate,
46-
step_size=step_size,
47-
)
48-
triangular_cyclical_lr = _maybe_serialized(triangular_cyclical_lr, serialize)
35+
@pytest.mark.parametrize("serialize", [True, False])
36+
def test_triangular_cyclical_learning_rate(serialize):
37+
initial_learning_rate = 0.1
38+
max_learning_rate = 1
39+
step_size = 40
40+
triangular_cyclical_lr = cyclical_learning_rate.TriangularCyclicalLearningRate(
41+
initial_learning_rate=initial_learning_rate,
42+
maximal_learning_rate=max_learning_rate,
43+
step_size=step_size,
44+
)
45+
triangular_cyclical_lr = _maybe_serialized(triangular_cyclical_lr, serialize)
4946

50-
self.evaluate(tf.compat.v1.global_variables_initializer())
51-
expected = np.concatenate(
52-
[
53-
np.linspace(initial_learning_rate, maximal_learning_rate, num=2001)[1:],
54-
np.linspace(maximal_learning_rate, initial_learning_rate, num=2001)[1:],
55-
]
56-
)
47+
expected = np.concatenate(
48+
[
49+
np.linspace(initial_learning_rate, max_learning_rate, num=step_size + 1),
50+
np.linspace(max_learning_rate, initial_learning_rate, num=step_size + 1)[
51+
1:
52+
],
53+
]
54+
)
55+
56+
for step, expected_value in enumerate(expected):
57+
np.testing.assert_allclose(triangular_cyclical_lr(step), expected_value, 1e-6)
5758

58-
for expected_value in expected:
59-
self.assertAllClose(
60-
self.evaluate(triangular_cyclical_lr(step)), expected_value, 1e-6
61-
)
62-
self.evaluate(step.assign_add(1))
6359

60+
@test_utils.run_all_in_graph_and_eager_modes
61+
@parameterized.named_parameters(("NotSerialized", False), ("Serialized", True))
62+
class CyclicalLearningRateTest(tf.test.TestCase, parameterized.TestCase):
6463
def testTriangular2CyclicalLearningRate(self, serialize):
6564
self.skipTest("Failing. See https://github.com/tensorflow/addons/issues/1203")
6665
initial_learning_rate = 0.1

0 commit comments

Comments
 (0)