1+ import functools
12from typing import List , Tuple
23
34import numpy as np
45
5- from pytensor import Variable , as_symbolic
6+ from pytensor import Variable , as_symbolic , clone_replace
67from pytensor .graph import FunctionGraph
8+ from pytensor .graph .basic import Constant , truncated_graph_inputs
79from pytensor .loop .op import Scan
810from pytensor .scan .utils import until
9- from pytensor .tensor import as_tensor , empty_like
11+ from pytensor .tensor import as_tensor , constant , empty_like , minimum
1012
1113
1214def scan (
@@ -20,6 +22,8 @@ def scan(
2022 if sequences is None and n_steps is None :
2123 raise ValueError ("Must provide n_steps when scanning without sequences" )
2224
25+ # TODO: init_states should be made opaque to the inner function,
26+ # since any relationship to the outer graph no longer holds
2327 if init_states is None :
2428 init_states = []
2529 else :
@@ -34,20 +38,31 @@ def scan(
3438 sequences = [sequences ]
3539 sequences = [as_tensor (s ) for s in sequences ]
3640
41+ if sequences :
42+ leading_dims = [seq .shape [0 ] for seq in sequences ]
43+ shortest_dim = functools .reduce (minimum , leading_dims )
44+ if n_steps is None :
45+ n_steps = shortest_dim
46+ else :
47+ n_steps = minimum (n_steps , shortest_dim )
48+
3749 if non_sequences is None :
3850 non_sequences = []
3951 else :
4052 if not isinstance (non_sequences , (tuple , list )):
4153 non_sequences = [non_sequences ]
4254 non_sequences = [as_symbolic (n ) for n in non_sequences ]
4355
56+ # Create subsequence inputs for the inner function
57+ idx = constant (0 , dtype = "int64" , name = "idx" )
58+ symbolic_idx = idx .type (name = "idx" )
59+ subsequences = [s [symbolic_idx ] for s in sequences ]
4460 # Note: Old scan order is sequences + init + non_sequences
45- inner_sequences = [s [0 ] for s in sequences ]
46- inner_inputs = [i .type () for i in init_states + inner_sequences + non_sequences ]
47- inner_outputs = fn (* inner_inputs )
48- if not isinstance (inner_outputs , (tuple , list )):
49- inner_outputs = [inner_outputs ]
50- next_states = [out for out in inner_outputs if not isinstance (out , until )]
61+ fn_inputs = init_states + subsequences + non_sequences
62+ fn_outputs = fn (* fn_inputs )
63+ if not isinstance (fn_outputs , (tuple , list )):
64+ fn_outputs = [fn_outputs ]
65+ next_states = [out for out in fn_outputs if not isinstance (out , until )]
5166
5267 if len (next_states ) > len (init_states ):
5368 if not init_states :
@@ -61,27 +76,43 @@ def scan(
6176 prev_states = []
6277 for i , (init_state , next_state ) in enumerate (zip (init_states , next_states )):
6378 if init_state is None :
79+ # next_state may reference idx, let's replace that by the initial value
80+ [next_state ] = clone_replace (
81+ output = [next_state ], replace = {symbolic_idx : idx }
82+ )
6483 init_state = empty_like (next_state )
6584 init_state .name = "empty_init_state"
66- inner_inputs .insert (i , init_state .type ())
6785 prev_states .append (init_state )
6886
69- until_condition = [out .condition for out in inner_outputs if isinstance (out , until )]
87+ until_condition = [out .condition for out in fn_outputs if isinstance (out , until )]
7088 if not until_condition :
7189 until_condition = [as_tensor (np .array (True ))]
7290 if len (until_condition ) > 1 :
7391 raise ValueError ("Only one until condition can be returned" )
7492
75- update_fg = FunctionGraph (
76- inputs = inner_inputs , outputs = until_condition + next_states
93+ fgraph_inputs = [symbolic_idx ] + prev_states + sequences + non_sequences
94+ fgraph_outputs = until_condition + [symbolic_idx + 1 ] + next_states
95+
96+ all_fgraph_inputs = truncated_graph_inputs (
97+ fgraph_outputs , ancestors_to_include = fgraph_inputs
98+ )
99+ extra_fgraph_inputs = [
100+ inp
101+ for inp in all_fgraph_inputs
102+ if (not isinstance (inp , Constant ) and inp not in fgraph_inputs )
103+ ]
104+ fgraph_inputs = fgraph_inputs + extra_fgraph_inputs
105+ update_fg = FunctionGraph (inputs = fgraph_inputs , outputs = fgraph_outputs )
106+
107+ scan_op = Scan (update_fg = update_fg )
108+ scan_outs = scan_op (
109+ n_steps , idx , * prev_states , * sequences , * non_sequences , * extra_fgraph_inputs
77110 )
78- scan_op = Scan (update_fg = update_fg , n_sequences = len (sequences ))
79- scan_outs = scan_op (n_steps , * prev_states , * sequences , * non_sequences )
80111 assert isinstance (scan_outs , list )
81112 last_states = scan_outs [: scan_op .n_states ]
82113 traces = scan_outs [scan_op .n_states :]
83-
84- return last_states , traces
114+ # Don't return the inner index state
115+ return last_states [ 1 :] , traces [ 1 :]
85116
86117
87118def map (
0 commit comments