1+ import functools
12from typing import List , Tuple
23
34import numpy as np
45
5- from pytensor import Variable , as_symbolic
6+ from pytensor import Variable , as_symbolic , clone_replace
67from pytensor .graph import FunctionGraph
8+ from pytensor .graph .basic import Constant , truncated_graph_inputs
79from pytensor .loop .op import Scan
810from pytensor .scan .utils import until
9- from pytensor .tensor import as_tensor , empty_like
11+ from pytensor .tensor import as_tensor , constant , empty_like , minimum
1012
1113
1214def scan (
@@ -20,6 +22,8 @@ def scan(
2022 if sequences is None and n_steps is None :
2123 raise ValueError ("Must provide n_steps when scanning without sequences" )
2224
25+ # TODO: init_states should be made opaque to the inner function,
26+ # since any relationship to the outer graph no longer holds
2327 if init_states is None :
2428 init_states = []
2529 else :
@@ -34,20 +38,31 @@ def scan(
3438 sequences = [sequences ]
3539 sequences = [as_tensor (s ) for s in sequences ]
3640
41+ if sequences :
42+ leading_dims = [seq .shape [0 ] for seq in sequences ]
43+ shortest_dim = functools .reduce (minimum , leading_dims )
44+ if n_steps is None :
45+ n_steps = shortest_dim
46+ else :
47+ n_steps = minimum (n_steps , shortest_dim )
48+
3749 if non_sequences is None :
3850 non_sequences = []
3951 else :
4052 if not isinstance (non_sequences , (tuple , list )):
4153 non_sequences = [non_sequences ]
4254 non_sequences = [as_symbolic (n ) for n in non_sequences ]
4355
56+ # Create subsequence inputs for the inner function
57+ idx = constant (0 , dtype = "int64" , name = "idx" )
58+ symbolic_idx = idx .type (name = "idx" )
59+ subsequences = [s [symbolic_idx ] for s in sequences ]
4460 # Note: Old scan order is sequences + init + non_sequences, whereas fn here is called with init_states + subsequences + non_sequences
45- inner_sequences = [s [0 ] for s in sequences ]
46- inner_inputs = [i .type () for i in init_states + inner_sequences + non_sequences ]
47- inner_outputs = fn (* inner_inputs )
48- if not isinstance (inner_outputs , (tuple , list )):
49- inner_outputs = [inner_outputs ]
50- next_states = [out for out in inner_outputs if not isinstance (out , until )]
61+ fn_inputs = init_states + subsequences + non_sequences
62+ fn_outputs = fn (* fn_inputs )
63+ if not isinstance (fn_outputs , (tuple , list )):
64+ fn_outputs = [fn_outputs ]
65+ next_states = [out for out in fn_outputs if not isinstance (out , until )]
5166
5267 if len (next_states ) > len (init_states ):
5368 if not init_states :
@@ -61,27 +76,45 @@ def scan(
6176 prev_states = []
6277 for i , (init_state , next_state ) in enumerate (zip (init_states , next_states )):
6378 if init_state is None :
79+ # next_state may reference symbolic_idx; replace it with idx (the constant initial index, 0)
80+ [next_state ] = clone_replace (
81+ output = [next_state ], replace = {symbolic_idx : idx }
82+ )
6483 init_state = empty_like (next_state )
65- init_state .name = "empty_init_state"
66- inner_inputs .insert (i , init_state .type ())
84+ init_state .name = (
85+ "empty_init_state" # add 1 offset, since idx is the first state
86+ )
6787 prev_states .append (init_state )
6888
69- until_condition = [out .condition for out in inner_outputs if isinstance (out , until )]
89+ until_condition = [out .condition for out in fn_outputs if isinstance (out , until )]
7090 if not until_condition :
7191 until_condition = [as_tensor (np .array (True ))]
7292 if len (until_condition ) > 1 :
7393 raise ValueError ("Only one until condition can be returned" )
7494
75- update_fg = FunctionGraph (
76- inputs = inner_inputs , outputs = until_condition + next_states
95+ fgraph_inputs = [symbolic_idx ] + prev_states + sequences + non_sequences
96+ fgraph_outputs = until_condition + [symbolic_idx + 1 ] + next_states
97+
98+ all_fgraph_inputs = truncated_graph_inputs (
99+ fgraph_outputs , ancestors_to_include = fgraph_inputs
100+ )
101+ extra_fgraph_inputs = [
102+ inp
103+ for inp in all_fgraph_inputs
104+ if (not isinstance (inp , Constant ) and inp not in fgraph_inputs )
105+ ]
106+ fgraph_inputs = fgraph_inputs + extra_fgraph_inputs
107+ update_fg = FunctionGraph (inputs = fgraph_inputs , outputs = fgraph_outputs )
108+
109+ scan_op = Scan (update_fg = update_fg )
110+ scan_outs = scan_op (
111+ n_steps , idx , * prev_states , * sequences , * non_sequences , * extra_fgraph_inputs
77112 )
78- scan_op = Scan (update_fg = update_fg , n_sequences = len (sequences ))
79- scan_outs = scan_op (n_steps , * prev_states , * sequences , * non_sequences )
80113 assert isinstance (scan_outs , list )
81114 last_states = scan_outs [: scan_op .n_states ]
82115 traces = scan_outs [scan_op .n_states :]
83-
84- return last_states , traces
116+ # Don't return the inner index state
117+ return last_states [ 1 :] , traces [ 1 :]
85118
86119
87120def map (
0 commit comments