88from pytensor .compile import PYTORCH
99from pytensor .compile .builders import OpFromGraph
1010from pytensor .compile .ops import DeepCopyOp
11+ from pytensor .graph .basic import Constant
1112from pytensor .graph .fg import FunctionGraph
1213from pytensor .ifelse import IfElse
1314from pytensor .link .utils import fgraph_to_python
1920 Eye ,
2021 Join ,
2122 MakeVector ,
23+ Split ,
2224 TensorFromScalar ,
2325)
2426
@@ -120,14 +122,23 @@ def arange(start, stop, step):
120122
121123
122124@pytorch_funcify .register (Join )
123- def pytorch_funcify_Join (op , ** kwargs ):
124- def join (axis , * tensors ):
125- # tensors could also be tuples, and in this case they don't have a ndim
126- tensors = [torch .tensor (tensor ) for tensor in tensors ]
125+ def pytorch_funcify_Join (op , node , ** kwargs ):
126+ axis = node .inputs [0 ]
127127
128- return torch .cat (tensors , dim = axis )
128+ if isinstance (axis , Constant ):
129+ axis = int (axis .data )
129130
130- return join
131+ def join_constant_axis (_ , * tensors ):
132+ return torch .cat (tensors , dim = axis )
133+
134+ return join_constant_axis
135+
136+ else :
137+
138+ def join (axis , * tensors ):
139+ return torch .cat (tensors , dim = axis )
140+
141+ return join
131142
132143
133144@pytorch_funcify .register (Eye )
@@ -172,7 +183,6 @@ def ifelse(cond, *true_and_false, n_outs=n_outs):
172183@pytorch_funcify .register (OpFromGraph )
173184def pytorch_funcify_OpFromGraph (op , node , ** kwargs ):
174185 kwargs .pop ("storage_map" , None )
175-
176186 # Apply inner rewrites
177187 PYTORCH .optimizer (op .fgraph )
178188 fgraph_fn = pytorch_funcify (op .fgraph , ** kwargs , squeeze_output = True )
@@ -185,3 +195,23 @@ def tensorfromscalar(x):
185195 return torch .as_tensor (x )
186196
187197 return tensorfromscalar
198+
199+
200+ @pytorch_funcify .register (Split )
201+ def pytorch_funcify_Split (op , node , ** kwargs ):
202+ x , dim , split_sizes = node .inputs
203+ if isinstance (dim , Constant ) and isinstance (split_sizes , Constant ):
204+ dim = int (dim .data )
205+ split_sizes = tuple (int (size ) for size in split_sizes .data )
206+
207+ def split_constant_axis_and_sizes (x , * _ ):
208+ return x .split (split_sizes , dim = dim )
209+
210+ return split_constant_axis_and_sizes
211+
212+ else :
213+
214+ def inner_fn (x , dim , split_amounts ):
215+ return x .split (split_amounts .tolist (), dim = dim .item ())
216+
217+ return inner_fn
0 commit comments