Skip to content

Commit c40fdd9

Browse files
committed
.WIP: Tweak Boolean AdvancedSubtensor tests
1 parent 79239cb commit c40fdd9

File tree

1 file changed

+54
-35
lines changed

1 file changed

+54
-35
lines changed

tests/tensor/test_subtensor.py

Lines changed: 54 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,27 @@
11
import logging
22
import re
33
import sys
4+
from contextlib import nullcontext
45
from io import StringIO
56

67
import numpy as np
78
import pytest
89
from numpy.testing import assert_array_equal
9-
from packaging import version
1010

1111
import pytensor
1212
import pytensor.scalar as scal
1313
import pytensor.tensor.basic as ptb
1414
from pytensor import function
1515
from pytensor.compile import DeepCopyOp, shared
1616
from pytensor.compile.io import In
17-
from pytensor.compile.mode import Mode
17+
from pytensor.compile.mode import Mode, get_default_mode
1818
from pytensor.configdefaults import config
1919
from pytensor.gradient import grad
2020
from pytensor.graph import Constant
2121
from pytensor.graph.basic import equal_computations
2222
from pytensor.graph.op import get_test_value
2323
from pytensor.graph.rewriting.utils import is_same_graph
24+
from pytensor.link.numba import NumbaLinker
2425
from pytensor.printing import pprint
2526
from pytensor.scalar.basic import as_scalar, int16
2627
from pytensor.tensor import as_tensor, constant, get_vector_length, vectorize
@@ -368,7 +369,7 @@ def setup_method(self):
368369
"local_replace_AdvancedSubtensor",
369370
"local_AdvancedIncSubtensor_to_AdvancedIncSubtensor1",
370371
"local_useless_subtensor",
371-
)
372+
).excluding("bool_idx_to_nonzero")
372373
self.fast_compile = config.mode == "FAST_COMPILE"
373374

374375
def function(
@@ -755,36 +756,46 @@ def numpy_inc_subtensor(x, idx, a):
755756
test_array[mask].eval()
756757
with pytest.raises(IndexError):
757758
inc_subtensor(test_array[mask], 1).eval()
758-
# - too large, padded with False (this works in NumPy < 0.13.0)
759+
# - too large, padded with False
760+
# When padded with False converting boolean to nonzero() will not fail
761+
# We exclude that rewrite by excluding `shape_unsafe` more generally
762+
# However numba doesn't enforce masked array sizes: https://github.com/numba/numba/issues/10374
763+
# So the tests that use numba native impl will not fail.
764+
shape_safe_mode = get_default_mode().excluding("shape_unsafe")
765+
linker_dependent_expectation = (
766+
nullcontext()
767+
if isinstance(get_default_mode().linker, NumbaLinker)
768+
else pytest.raises(IndexError)
769+
)
759770
mask = np.array([True, False, False])
760-
# with pytest.raises(IndexError):
761-
test_array[mask].eval()
762-
# with pytest.raises(IndexError):
763-
test_array[mask, ...].eval()
764-
# with pytest.raises(IndexError):
765-
inc_subtensor(test_array[mask], 1).eval()
766-
# with pytest.raises(IndexError):
767-
inc_subtensor(test_array[mask, ...], 1).eval()
771+
with linker_dependent_expectation:
772+
test_array[mask].eval(mode=shape_safe_mode)
773+
with linker_dependent_expectation:
774+
test_array[mask, ...].eval(mode=shape_safe_mode)
775+
with linker_dependent_expectation:
776+
inc_subtensor(test_array[mask], 1).eval(mode=shape_safe_mode)
777+
with linker_dependent_expectation:
778+
inc_subtensor(test_array[mask, ...], 1).eval(mode=shape_safe_mode)
768779
mask = np.array([[True, False, False, False], [False, True, False, False]])
769-
# with pytest.raises(IndexError):
770-
test_array[mask].eval()
771-
# with pytest.raises(IndexError):
772-
inc_subtensor(test_array[mask], 1).eval()
780+
with pytest.raises(IndexError):
781+
test_array[mask].eval(mode=shape_safe_mode)
782+
with pytest.raises(IndexError):
783+
inc_subtensor(test_array[mask], 1).eval(mode=shape_safe_mode)
773784
# - mask too small
774785
mask = np.array([True])
775-
# with pytest.raises(IndexError):
776-
test_array[mask].eval()
777-
# with pytest.raises(IndexError):
778-
test_array[mask, ...].eval()
779-
# with pytest.raises(IndexError):
780-
inc_subtensor(test_array[mask], 1).eval()
781-
# with pytest.raises(IndexError):
782-
inc_subtensor(test_array[mask, ...], 1).eval()
786+
with linker_dependent_expectation:
787+
test_array[mask].eval(mode=shape_safe_mode)
788+
with linker_dependent_expectation:
789+
test_array[mask, ...].eval(mode=shape_safe_mode)
790+
with linker_dependent_expectation:
791+
inc_subtensor(test_array[mask], 1).eval(mode=shape_safe_mode)
792+
with linker_dependent_expectation:
793+
inc_subtensor(test_array[mask, ...], 1).eval(mode=shape_safe_mode)
783794
mask = np.array([[True], [True]])
784-
# with pytest.raises(IndexError):
785-
test_array[mask].eval()
786-
# with pytest.raises(IndexError):
787-
inc_subtensor(test_array[mask], 1).eval()
795+
with pytest.raises(IndexError):
796+
test_array[mask].eval(mode=shape_safe_mode)
797+
with pytest.raises(IndexError):
798+
inc_subtensor(test_array[mask], 1).eval(mode=shape_safe_mode)
788799
# - too many dimensions
789800
mask = np.array([[[True, False, False], [False, True, False]]])
790801
with pytest.raises(IndexError):
@@ -1348,10 +1359,6 @@ def test_advanced1_inc_and_set(self):
13481359
# you enable the debug code above.
13491360
assert np.allclose(f_out, output_num), (params, f_out, output_num)
13501361

1351-
@pytest.mark.skipif(
1352-
version.parse(np.__version__) < version.parse("2.0"),
1353-
reason="Legacy C-implementation did not check for runtime broadcast",
1354-
)
13551362
@pytest.mark.parametrize("func", (advanced_inc_subtensor1, advanced_set_subtensor1))
13561363
def test_advanced1_inc_runtime_broadcast(self, func):
13571364
y = matrix("y", dtype="float64", shape=(None, None))
@@ -1366,10 +1373,20 @@ def test_advanced1_inc_runtime_broadcast(self, func):
13661373
"(Runtime broadcasting not allowed\\. AdvancedIncSubtensor1 was asked"
13671374
"|The number of indices and values must match)"
13681375
)
1369-
with pytest.raises(ValueError, match=err_message):
1376+
numba_linker = isinstance(f.maker.linker, NumbaLinker)
1377+
# Numba implementation does not raise for runtime broadcasting
1378+
with (
1379+
nullcontext()
1380+
if numba_linker
1381+
else pytest.raises(ValueError, match=err_message)
1382+
):
13701383
f(np.ones((1, 5)))
1371-
# with pytest.raises(ValueError, match=err_message):
1372-
f(np.ones((20, 1)))
1384+
with (
1385+
nullcontext()
1386+
if numba_linker
1387+
else pytest.raises(ValueError, match=err_message)
1388+
):
1389+
f(np.ones((20, 1)))
13731390

13741391
def test_adv_constant_arg(self):
13751392
# Test case provided (and bug detected, gh-607) by John Salvatier
@@ -2391,6 +2408,8 @@ def test_boolean_scalar_raises(self):
23912408

23922409

23932410
class TestInferShape(utt.InferShapeTester):
2411+
mode = get_default_mode().excluding("bool_idx_to_nonzero")
2412+
23942413
@staticmethod
23952414
def random_bool_mask(shape, rng=None):
23962415
if rng is None:

0 commit comments

Comments
 (0)