Skip to content

Commit 3458a69

Browse files
committed
Don't import Test for helper
This triggers the remaining tests in pytest
1 parent 9de97c6 commit 3458a69

File tree

3 files changed

+31
-31
lines changed

3 files changed

+31
-31
lines changed

tests/link/jax/test_tensor_basic.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from pytensor.graph.op import get_test_value
1515
from pytensor.tensor.type import iscalar, matrix, scalar, vector
1616
from tests.link.jax.test_basic import compare_jax_and_py
17-
from tests.tensor.test_basic import TestAlloc
17+
from tests.tensor.test_basic import check_alloc_runtime_broadcast
1818

1919

2020
def test_jax_Alloc():
@@ -54,7 +54,7 @@ def compare_shape_dtype(x, y):
5454

5555

5656
def test_alloc_runtime_broadcast():
57-
TestAlloc.check_runtime_broadcast(get_mode("JAX"))
57+
check_alloc_runtime_broadcast(get_mode("JAX"))
5858

5959

6060
def test_jax_MakeVector():

tests/link/numba/test_tensor_basic.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
compare_shape_dtype,
1717
set_test_value,
1818
)
19-
from tests.tensor.test_basic import TestAlloc
19+
from tests.tensor.test_basic import check_alloc_runtime_broadcast
2020

2121

2222
pytest.importorskip("numba")
@@ -52,7 +52,7 @@ def test_Alloc(v, shape):
5252

5353

5454
def test_alloc_runtime_broadcast():
55-
TestAlloc.check_runtime_broadcast(get_mode("NUMBA"))
55+
check_alloc_runtime_broadcast(get_mode("NUMBA"))
5656

5757

5858
def test_AllocEmpty():

tests/tensor/test_basic.py

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -717,6 +717,32 @@ def test_masked_array_not_implemented(
717717
ptb.as_tensor(x)
718718

719719

720+
def check_alloc_runtime_broadcast(mode):
721+
"""Check we emmit a clear error when runtime broadcasting would occur according to Numpy rules."""
722+
floatX = config.floatX
723+
x_v = vector("x", shape=(None,))
724+
725+
out = alloc(x_v, 5, 3)
726+
f = pytensor.function([x_v], out, mode=mode)
727+
TestAlloc.check_allocs_in_fgraph(f.maker.fgraph, 1)
728+
729+
np.testing.assert_array_equal(
730+
f(x=np.zeros((3,), dtype=floatX)),
731+
np.zeros((5, 3), dtype=floatX),
732+
)
733+
with pytest.raises(ValueError, match="Runtime broadcasting not allowed"):
734+
f(x=np.zeros((1,), dtype=floatX))
735+
736+
out = alloc(specify_shape(x_v, (1,)), 5, 3)
737+
f = pytensor.function([x_v], out, mode=mode)
738+
TestAlloc.check_allocs_in_fgraph(f.maker.fgraph, 1)
739+
740+
np.testing.assert_array_equal(
741+
f(x=np.zeros((1,), dtype=floatX)),
742+
np.zeros((5, 3), dtype=floatX),
743+
)
744+
745+
720746
class TestAlloc:
721747
dtype = config.floatX
722748
mode = mode_opt
@@ -730,32 +756,6 @@ def check_allocs_in_fgraph(fgraph, n):
730756
== n
731757
)
732758

733-
@staticmethod
734-
def check_runtime_broadcast(mode):
735-
"""Check we emmit a clear error when runtime broadcasting would occur according to Numpy rules."""
736-
floatX = config.floatX
737-
x_v = vector("x", shape=(None,))
738-
739-
out = alloc(x_v, 5, 3)
740-
f = pytensor.function([x_v], out, mode=mode)
741-
TestAlloc.check_allocs_in_fgraph(f.maker.fgraph, 1)
742-
743-
np.testing.assert_array_equal(
744-
f(x=np.zeros((3,), dtype=floatX)),
745-
np.zeros((5, 3), dtype=floatX),
746-
)
747-
with pytest.raises(ValueError, match="Runtime broadcasting not allowed"):
748-
f(x=np.zeros((1,), dtype=floatX))
749-
750-
out = alloc(specify_shape(x_v, (1,)), 5, 3)
751-
f = pytensor.function([x_v], out, mode=mode)
752-
TestAlloc.check_allocs_in_fgraph(f.maker.fgraph, 1)
753-
754-
np.testing.assert_array_equal(
755-
f(x=np.zeros((1,), dtype=floatX)),
756-
np.zeros((5, 3), dtype=floatX),
757-
)
758-
759759
def setup_method(self):
760760
self.rng = np.random.default_rng(seed=utt.fetch_seed())
761761

@@ -912,7 +912,7 @@ def test_alloc_of_view_linker(self):
912912

913913
@pytest.mark.parametrize("mode", (Mode("py"), Mode("c")))
914914
def test_runtime_broadcast(self, mode):
915-
self.check_runtime_broadcast(mode)
915+
check_alloc_runtime_broadcast(mode)
916916

917917

918918
def test_infer_static_shape():

0 commit comments

Comments
 (0)