Skip to content

Commit ed8b8e9

Browse files
authored
Support negative dimensions on "aten.split_with_sizes_copy.default" (#16006)
### Summary Support a negative dimension for MPS SplitWithSizes implementation. ### Test plan I re-registered this node visitor on a model that previously failed to export with the error message ``` RuntimeError: split_copy: dim -1 out of range for input tensor with 2 dimensions ``` and it succeeded.
1 parent 488d761 commit ed8b8e9

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

backends/apple/mps/operators/shape_ops.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,11 +242,14 @@ def define_node(
242242
output_ids = self.define_tensor_list(node, mps_graph)
243243
split_sizes = eval_shape(cast(torch.SymInt, node.args[1]))
244244
dim = cast(int, node.args[2])
245+
orig_dim = dim
245246
input_shape = get_shape(get_input_node(node, 0))
247+
if dim < 0:
248+
dim += len(input_shape)
246249

247250
if dim < 0 or dim >= len(input_shape):
248251
raise RuntimeError(
249-
f"split_copy: dim {dim} out of range for input tensor with {len(input_shape)} dimensions"
252+
f"split_copy: dim {orig_dim} out of range for input tensor with {len(input_shape)} dimensions"
250253
)
251254

252255
mps_node = MPSNode(

0 commit comments

Comments
 (0)