Skip to content

Commit 4278e3b

Browse files
authored
Correlation metrics keras2.10 fix (#2751)
* Fix for keras/tf v2.10
1 parent b756950 commit 4278e3b

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

tensorflow_addons/metrics/tests/streaming_correlations_test.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,10 +113,13 @@ def test_keras_binary_classification_model(self, correlation_type):
113113
model(x)[:, 0], y[:, 0]
114114
)[0]
115115

116-
history = model.fit(x, y, epochs=1, verbose=0, batch_size=32)
116+
history = model.fit(
117+
x, y, epochs=1, verbose=0, batch_size=32, validation_data=(x, y)
118+
)
117119

118120
# the training should increase the correlation metric
119-
assert np.all(history.history[metric.name] > initial_correlation)
121+
metric_history = history.history["val_" + metric.name]
122+
assert np.all(metric_history > initial_correlation)
120123

121124
preds = model(x)
122125
metric.reset_state()
@@ -125,6 +128,7 @@ def test_keras_binary_classification_model(self, correlation_type):
125128
tf.function(metric.update_state)(y, preds)
126129
metric_value = tf.function(metric.result)()
127130
scipy_value = self.scipy_corr[correlation_type](preds[:, 0], y[:, 0])[0]
131+
np.testing.assert_almost_equal(metric_value, metric_history[-1])
128132
np.testing.assert_almost_equal(metric_value, scipy_value, decimal=2)
129133

130134
@pytest.mark.parametrize("correlation_type", testing_types)

0 commit comments

Comments
 (0)