@@ -717,6 +717,32 @@ def test_masked_array_not_implemented(
717717 ptb .as_tensor (x )
718718
719719
720+ def check_alloc_runtime_broadcast (mode ):
721+ """Check we emmit a clear error when runtime broadcasting would occur according to Numpy rules."""
722+ floatX = config .floatX
723+ x_v = vector ("x" , shape = (None ,))
724+
725+ out = alloc (x_v , 5 , 3 )
726+ f = pytensor .function ([x_v ], out , mode = mode )
727+ TestAlloc .check_allocs_in_fgraph (f .maker .fgraph , 1 )
728+
729+ np .testing .assert_array_equal (
730+ f (x = np .zeros ((3 ,), dtype = floatX )),
731+ np .zeros ((5 , 3 ), dtype = floatX ),
732+ )
733+ with pytest .raises (ValueError , match = "Runtime broadcasting not allowed" ):
734+ f (x = np .zeros ((1 ,), dtype = floatX ))
735+
736+ out = alloc (specify_shape (x_v , (1 ,)), 5 , 3 )
737+ f = pytensor .function ([x_v ], out , mode = mode )
738+ TestAlloc .check_allocs_in_fgraph (f .maker .fgraph , 1 )
739+
740+ np .testing .assert_array_equal (
741+ f (x = np .zeros ((1 ,), dtype = floatX )),
742+ np .zeros ((5 , 3 ), dtype = floatX ),
743+ )
744+
745+
720746class TestAlloc :
721747 dtype = config .floatX
722748 mode = mode_opt
@@ -730,32 +756,6 @@ def check_allocs_in_fgraph(fgraph, n):
730756 == n
731757 )
732758
733- @staticmethod
734- def check_runtime_broadcast (mode ):
735- """Check we emmit a clear error when runtime broadcasting would occur according to Numpy rules."""
736- floatX = config .floatX
737- x_v = vector ("x" , shape = (None ,))
738-
739- out = alloc (x_v , 5 , 3 )
740- f = pytensor .function ([x_v ], out , mode = mode )
741- TestAlloc .check_allocs_in_fgraph (f .maker .fgraph , 1 )
742-
743- np .testing .assert_array_equal (
744- f (x = np .zeros ((3 ,), dtype = floatX )),
745- np .zeros ((5 , 3 ), dtype = floatX ),
746- )
747- with pytest .raises (ValueError , match = "Runtime broadcasting not allowed" ):
748- f (x = np .zeros ((1 ,), dtype = floatX ))
749-
750- out = alloc (specify_shape (x_v , (1 ,)), 5 , 3 )
751- f = pytensor .function ([x_v ], out , mode = mode )
752- TestAlloc .check_allocs_in_fgraph (f .maker .fgraph , 1 )
753-
754- np .testing .assert_array_equal (
755- f (x = np .zeros ((1 ,), dtype = floatX )),
756- np .zeros ((5 , 3 ), dtype = floatX ),
757- )
758-
759759 def setup_method (self ):
760760 self .rng = np .random .default_rng (seed = utt .fetch_seed ())
761761
@@ -912,7 +912,7 @@ def test_alloc_of_view_linker(self):
912912
913913 @pytest .mark .parametrize ("mode" , (Mode ("py" ), Mode ("c" )))
914914 def test_runtime_broadcast (self , mode ):
915- self . check_runtime_broadcast (mode )
915+ check_alloc_runtime_broadcast (mode )
916916
917917
918918def test_infer_static_shape ():
0 commit comments