@@ -738,41 +738,43 @@ def check_allocs_in_fgraph(fgraph, n):
738738 def setup_method (self ):
739739 self .rng = np .random .default_rng (seed = utt .fetch_seed ())
740740
741- def test_alloc_constant_folding (self ):
741+ @pytest .mark .parametrize (
742+ "subtensor_fn, expected_grad_n_alloc" ,
743+ [
744+ # IncSubtensor1
745+ (lambda x : x [:60 ], 1 ),
746+ # AdvancedIncSubtensor1
747+ (lambda x : x [np .arange (60 )], 1 ),
748+ # AdvancedIncSubtensor
749+ (lambda x : x [np .arange (50 ), np .arange (50 )], 1 ),
750+ ],
751+ )
752+ def test_alloc_constant_folding (self , subtensor_fn , expected_grad_n_alloc ):
742753 test_params = np .asarray (self .rng .standard_normal (50 * 60 ), self .dtype )
743754
744755 some_vector = vector ("some_vector" , dtype = self .dtype )
745756 some_matrix = some_vector .reshape ((60 , 50 ))
746757 variables = self .shared (np .ones ((50 ,), dtype = self .dtype ))
747- idx = constant (np .arange (50 ))
748-
749- for alloc_ , (subtensor , n_alloc ) in zip (
750- self .allocs ,
751- [
752- # IncSubtensor1
753- (some_matrix [:60 ], 2 ),
754- # AdvancedIncSubtensor1
755- (some_matrix [arange (60 )], 2 ),
756- # AdvancedIncSubtensor
757- (some_matrix [idx , idx ], 1 ),
758- ],
759- strict = True ,
760- ):
761- derp = pt_sum (dense_dot (subtensor , variables ))
762758
763- fobj = pytensor .function ([some_vector ], derp , mode = self .mode )
764- grad_derp = pytensor .grad (derp , some_vector )
765- fgrad = pytensor .function ([some_vector ], grad_derp , mode = self .mode )
759+ subtensor = subtensor_fn (some_matrix )
766760
767- topo_obj = fobj .maker .fgraph .toposort ()
768- assert sum (isinstance (node .op , type (alloc_ )) for node in topo_obj ) == 0
761+ derp = pt_sum (dense_dot (subtensor , variables ))
762+ fobj = pytensor .function ([some_vector ], derp , mode = self .mode )
763+ assert (
764+ sum (isinstance (node .op , Alloc ) for node in fobj .maker .fgraph .apply_nodes )
765+ == 0
766+ )
767+ # TODO: Assert something about the value if we bothered to call it?
768+ fobj (test_params )
769769
770- topo_grad = fgrad .maker .fgraph .toposort ()
771- assert (
772- sum (isinstance (node .op , type (alloc_ )) for node in topo_grad ) == n_alloc
773- ), (alloc_ , subtensor , n_alloc , topo_grad )
774- fobj (test_params )
775- fgrad (test_params )
770+ grad_derp = pytensor .grad (derp , some_vector )
771+ fgrad = pytensor .function ([some_vector ], grad_derp , mode = self .mode )
772+ assert (
773+ sum (isinstance (node .op , Alloc ) for node in fgrad .maker .fgraph .apply_nodes )
774+ == expected_grad_n_alloc
775+ )
776+ # TODO: Assert something about the value if we bothered to call it?
777+ fgrad (test_params )
776778
777779 def test_alloc_output (self ):
778780 val = constant (self .rng .standard_normal ((1 , 1 )), dtype = self .dtype )
0 commit comments