Skip to content

Commit ebc43b5

Browse files
hawkinsptensorflower-gardener
authored andcommitted
[numpy] Fix test failures under NumPy 2.1.
PiperOrigin-RevId: 681078015
1 parent c6c86e7 commit ebc43b5

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

tensorflow_probability/python/internal/test_util.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -613,7 +613,7 @@ def assertAllInRange(self,
613613
'The value of %s does not have an ordered numeric type, instead it '
614614
'has type: %s' % (target, target.dtype))
615615

616-
nan_subscripts = np.where(np.isnan(target))
616+
nan_subscripts = np.where(np.atleast_1d(np.isnan(target)))
617617
if np.size(nan_subscripts):
618618
raise AssertionError(
619619
'%d of the %d element(s) are NaN. '
@@ -631,7 +631,7 @@ def assertAllInRange(self,
631631
violations,
632632
np.greater_equal(target, upper_bound)
633633
if open_upper_bound else np.greater(target, upper_bound))
634-
violation_subscripts = np.where(violations)
634+
violation_subscripts = np.where(np.atleast_1d(violations))
635635
if np.size(violation_subscripts):
636636
raise AssertionError(
637637
'%d of the %d element(s) are outside the range %s. ' %

0 commit comments

Comments
 (0)