Skip to content

Commit 327cb79

Browse files
committed
Check for valid sizes in numba implementation of Split
1 parent e449e1e commit 327cb79

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

pytensor/link/numba/dispatch/tensor_basic.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -139,10 +139,16 @@ def join(axis, *tensors):
139139
@register_funcify_default_op_cache_key(Split)
140140
def numba_funcify_Split(op, **kwargs):
141141
@numba_basic.numba_njit
142-
def split(tensor, axis, indices):
143-
return np.split(tensor, np.cumsum(indices)[:-1], axis=axis.item())
144-
145-
return split
142+
def split(tensor, axis, sizes):
143+
if (sizes < 0).any():
144+
raise ValueError("Split sizes must be non-negative")
145+
axis = axis.item()
146+
split_indices = np.cumsum(sizes)
147+
if split_indices[-1] != tensor.shape[axis]:
148+
raise ValueError("Split sizes do not add up to the length of the tensor.")
149+
return np.split(tensor, split_indices[:-1], axis=axis)
150+
151+
return split, 1
146152

147153

148154
@register_funcify_default_op_cache_key(ExtractDiag)

0 commit comments

Comments
 (0)