Skip to content

Commit d507743

Browse files
committed
Fix AdvancedSubtensor static shape with newaxis
1 parent 6d83c24 commit d507743

File tree

2 files changed

+2
-1
lines changed

2 files changed

+2
-1
lines changed

pytensor/tensor/subtensor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2630,7 +2630,7 @@ def make_node(self, x, *indices):
26302630
adv_group_axis = None
26312631
last_adv_group_axis = None
26322632
expanded_x_shape = tuple(
2633-
np.insert(np.array(x.type.shape, dtype=object), 1, new_axes)
2633+
np.insert(np.array(x.type.shape, dtype=object), new_axes, values=1)
26342634
)
26352635
for i, (idx, dim_length) in enumerate(
26362636
zip_longest(explicit_indices, expanded_x_shape, fillvalue=NoneSliceConst)

tests/tensor/test_subtensor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1854,6 +1854,7 @@ def test_static_shape(self):
18541854

18551855
assert x[idx1].type.shape == (10, None)
18561856
assert x[:, idx1].type.shape == (None, 10)
1857+
assert x[None, :, idx1].type.shape == (1, None, 10)
18571858
assert x[idx2, :5].type.shape == (3, None, None)
18581859
assert specify_shape(x, (None, 7))[idx2, :5].type.shape == (3, None, 5)
18591860
assert specify_shape(x, (None, 3))[idx2, :5].type.shape == (3, None, 3)

0 commit comments

Comments
 (0)