File tree Expand file tree Collapse file tree 1 file changed +10
-4
lines changed
pytensor/link/numba/dispatch Expand file tree Collapse file tree 1 file changed +10
-4
lines changed Original file line number Diff line number Diff line change @@ -139,10 +139,16 @@ def join(axis, *tensors):
139139@register_funcify_default_op_cache_key (Split )
140140def numba_funcify_Split (op , ** kwargs ):
141141 @numba_basic .numba_njit
142- def split (tensor , axis , indices ):
143- return np .split (tensor , np .cumsum (indices )[:- 1 ], axis = axis .item ())
144-
145- return split
142+ def split (tensor , axis , sizes ):
143+ if (sizes < 0 ).any ():
144+ raise ValueError ("Split sizes must be non-negative" )
145+ axis = axis .item ()
146+ split_indices = np .cumsum (sizes )
147+ if split_indices [- 1 ] != tensor .shape [axis ]:
148+ raise ValueError ("Split sizes do not add up to the length of the tensor." )
149+ return np .split (tensor , split_indices [:- 1 ], axis = axis )
150+
151+ return split , 1
146152
147153
148154@register_funcify_default_op_cache_key (ExtractDiag )
You can’t perform that action at this time.
0 commit comments