99from pytensor .compile .mode import get_default_mode
1010from pytensor .configdefaults import config
1111from pytensor .gradient import grad , jacobian
12- from pytensor .graph .basic import equal_computations
12+ from pytensor .graph .basic import Constant , equal_computations
1313from pytensor .graph .fg import FunctionGraph
1414from pytensor .graph .replace import clone_replace
1515from pytensor .scan .op import Scan
@@ -1208,7 +1208,7 @@ def test_inplace3(self):
12081208
12091209
12101210class TestSaveMem :
1211- mode = get_default_mode ().including ("scan_save_mem" , "scan_save_mem" )
1211+ mode = get_default_mode ().including ("scan_save_mem" )
12121212
12131213 def test_save_mem (self ):
12141214 rng = np .random .default_rng (utt .fetch_seed ())
@@ -1295,11 +1295,27 @@ def f_rnn(u_t):
12951295 [x1 [:2 ], x2 [4 ], x3 [idx ], x4 [:idx ], x5 [- 10 ], x6 [- jdx ], x7 [:- jdx ]],
12961296 updates = updates ,
12971297 allow_input_downcast = True ,
1298- mode = self .mode ,
1298+ mode = self .mode . excluding ( "scan_push_out_seq" ) ,
12991299 )
1300+ # Check we actually have a Scan in the compiled function
1301+ [scan_node ] = [
1302+ node for node in f2 .maker .fgraph .toposort () if isinstance (node .op , Scan )
1303+ ]
1304+
13001305 # get random initial values
13011306 rng = np .random .default_rng (utt .fetch_seed ())
1302- v_u = rng .uniform (- 5.0 , 5.0 , size = (20 ,))
1307+ v_u = rng .uniform (- 5.0 , 5.0 , size = (20 ,)).astype (u .type .dtype )
1308+
1309+ # Check the number of steps is actually reduced from 20
1310+ n_steps = scan_node .inputs [0 ]
1311+ n_steps_fn = pytensor .function (
1312+ [u , idx , jdx ], n_steps , accept_inplace = True , on_unused_input = "ignore"
1313+ )
1314+ assert n_steps_fn (u = v_u , idx = 3 , jdx = 15 ) == 11 # x5[const=-10] requires 11 steps
1315+ assert n_steps_fn (u = v_u , idx = 3 , jdx = 3 ) == 18 # x6[jdx=-3] requires 18 steps
1316+ assert n_steps_fn (u = v_u , idx = 16 , jdx = 15 ) == 17 # x3[idx=16] requires 17 steps
1317+ assert n_steps_fn (u = v_u , idx = - 5 , jdx = 15 ) == 16 # x3[idx=-5] requires 16 steps
1318+ assert n_steps_fn (u = v_u , idx = 19 , jdx = 15 ) == 20 # x3[idx=19] requires 20 steps
13031319
13041320 # compute the output in numpy
13051321 tx1 , tx2 , tx3 , tx4 , tx5 , tx6 , tx7 = f2 (v_u , 3 , 15 )
@@ -1312,6 +1328,49 @@ def f_rnn(u_t):
13121328 utt .assert_allclose (tx6 , v_u [- 15 ] + 6.0 )
13131329 utt .assert_allclose (tx7 , v_u [:- 15 ] + 7.0 )
13141330
def test_save_mem_reduced_number_of_steps_constant(self):
    """Check that a constant ``n_steps`` is lowered when only a prefix is used.

    A 10-step scan is compiled with only ``xs[:5]`` requested, using
    ``self.mode``. The test expects the first input of the compiled Scan
    node (its number of steps) to be the constant 5, and the numerical
    result to match the first five values of the full sequence.
    """

    def increment(previous):
        return previous + 1

    x0 = pt.scalar("x0")
    sequence, _updates = scan(increment, outputs_info=[x0], n_steps=10)

    compiled = function([x0], sequence[:5], mode=self.mode)

    # Exactly one Scan node must remain in the compiled graph; the tuple
    # unpacking fails loudly otherwise.
    scan_nodes = [
        apply_node
        for apply_node in compiled.maker.fgraph.toposort()
        if isinstance(apply_node.op, Scan)
    ]
    (scan_node,) = scan_nodes

    steps_var = scan_node.inputs[0]
    assert isinstance(steps_var, Constant)
    assert steps_var.data == 5

    expected = np.arange(1, 11)[:5]
    np.testing.assert_allclose(compiled(0), expected)
1347+
def test_save_mem_cannot_reduce_constant_number_of_steps(self):
    """Check that a constant ``n_steps`` is kept when the last step is needed.

    Two recurrences share one 10-step scan. Only ``xs[:5]`` is requested
    from the first, but ``ys[-1]`` is requested from the second. The test
    expects the first input of the compiled Scan node (its number of
    steps) to stay the constant 10, and both results to be correct.
    """

    def step(x_previous, y_previous):
        return x_previous + 1, y_previous - 1

    x0 = pt.scalar("x0")
    outputs, _updates = scan(step, outputs_info=[x0, x0], n_steps=10)
    xs, ys = outputs

    # ys[-1] depends on the final iteration, so every step is required
    # even though xs is only read up to index 5.
    requested = [xs[:5], ys[-1]]
    compiled = function([x0], requested, mode=self.mode)

    # Exactly one Scan node must remain in the compiled graph.
    scan_nodes = [
        apply_node
        for apply_node in compiled.maker.fgraph.toposort()
        if isinstance(apply_node.op, Scan)
    ]
    (scan_node,) = scan_nodes

    steps_var = scan_node.inputs[0]
    assert isinstance(steps_var, Constant)
    assert steps_var.data == 10

    reference = np.arange(1, 11)
    x_result, y_result = compiled(0)
    np.testing.assert_allclose(x_result, reference[:5])
    np.testing.assert_allclose(y_result, -reference[-1])
1373+
13151374 def test_save_mem_store_steps (self ):
13161375 def f_rnn (u_t , x1_tm1 , x1_tm3 , x2_tm1 , x3tm2 , x3_tm1 , x4_tm1 ):
13171376 return (
0 commit comments