@@ -759,40 +759,43 @@ def check_allocs_in_fgraph(fgraph, n):
759759 def setup_method (self ):
760760 self .rng = np .random .default_rng (seed = utt .fetch_seed ())
761761
762- def test_alloc_constant_folding (self ):
762+ @pytest .mark .parametrize (
763+ "subtensor_fn, expected_grad_n_alloc" ,
764+ [
765+ # IncSubtensor1
766+ (lambda x : x [:60 ], 1 ),
767+ # AdvancedIncSubtensor1
768+ (lambda x : x [np .arange (60 )], 1 ),
769+ # AdvancedIncSubtensor
770+ (lambda x : x [np .arange (50 ), np .arange (50 )], 1 ),
771+ ],
772+ )
773+ def test_alloc_constant_folding (self , subtensor_fn , expected_grad_n_alloc ):
763774 test_params = np .asarray (self .rng .standard_normal (50 * 60 ), self .dtype )
764775
765776 some_vector = vector ("some_vector" , dtype = self .dtype )
766777 some_matrix = some_vector .reshape ((60 , 50 ))
767778 variables = self .shared (np .ones ((50 ,), dtype = self .dtype ))
768- idx = constant (np .arange (50 ))
769779
770- for alloc_ , (subtensor , n_alloc ) in zip (
771- self .allocs ,
772- [
773- # IncSubtensor1
774- (some_matrix [:60 ], 2 ),
775- # AdvancedIncSubtensor1
776- (some_matrix [arange (60 )], 2 ),
777- # AdvancedIncSubtensor
778- (some_matrix [idx , idx ], 1 ),
779- ],
780- ):
781- derp = pt_sum (dense_dot (subtensor , variables ))
780+ subtensor = subtensor_fn (some_matrix )
782781
783- fobj = pytensor .function ([some_vector ], derp , mode = self .mode )
784- grad_derp = pytensor .grad (derp , some_vector )
785- fgrad = pytensor .function ([some_vector ], grad_derp , mode = self .mode )
786-
787- topo_obj = fobj .maker .fgraph .toposort ()
788- assert sum (isinstance (node .op , type (alloc_ )) for node in topo_obj ) == 0
782+ derp = pt_sum (dense_dot (subtensor , variables ))
783+ fobj = pytensor .function ([some_vector ], derp , mode = self .mode )
784+ assert (
785+ sum (isinstance (node .op , Alloc ) for node in fobj .maker .fgraph .apply_nodes )
786+ == 0
787+ )
788+ # TODO: Assert something about the value if we bothered to call it?
789+ fobj (test_params )
789790
790- topo_grad = fgrad .maker .fgraph .toposort ()
791- assert (
792- sum (isinstance (node .op , type (alloc_ )) for node in topo_grad ) == n_alloc
793- ), (alloc_ , subtensor , n_alloc , topo_grad )
794- fobj (test_params )
795- fgrad (test_params )
791+ grad_derp = pytensor .grad (derp , some_vector )
792+ fgrad = pytensor .function ([some_vector ], grad_derp , mode = self .mode )
793+ assert (
794+ sum (isinstance (node .op , Alloc ) for node in fgrad .maker .fgraph .apply_nodes )
795+ == expected_grad_n_alloc
796+ )
797+ # TODO: Assert something about the value if we bothered to call it?
798+ fgrad (test_params )
796799
797800 def test_alloc_output (self ):
798801 val = constant (self .rng .standard_normal ((1 , 1 )), dtype = self .dtype )
0 commit comments