Skip to content

Commit 79239cb

Browse files
committed
bool_idx_to_nonzero rewrite...
1 parent d507743 commit 79239cb

File tree

2 files changed

+6
-1
lines changed

2 files changed

+6
-1
lines changed

pytensor/link/numba/dispatch/subtensor.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
IncSubtensor,
3131
Subtensor,
3232
)
33-
from pytensor.tensor.type_other import MakeSlice
33+
from pytensor.tensor.type_other import MakeSlice, NoneTypeT
3434

3535

3636
def slice_new(self, start, stop, step):
@@ -270,6 +270,8 @@ def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
270270
not must_ignore_duplicates
271271
and len(adv_idxs) >= 1
272272
and all(adv_idx["dtype"] != "bool" for adv_idx in adv_idxs)
273+
# Implementation does not support newaxis
274+
and not any(isinstance(idx.type, NoneTypeT) for idx in idxs)
273275
):
274276
return vector_integer_advanced_indexing(op, node, **kwargs)
275277

@@ -310,6 +312,8 @@ def vector_integer_advanced_indexing(
310312
):
311313
"""Implement all forms of advanced indexing (and assignment) that combine basic and vector integer indices.
312314
315+
It does not support `newaxis` in basic indices
316+
313317
It handles += like `np.add.at` would, accumulating add for duplicate indices.
314318
315319
Examples

pytensor/tensor/rewriting/subtensor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1782,6 +1782,7 @@ def bool_idx_to_nonzero(fgraph, node):
17821782
bool_idx_to_nonzero.__name__,
17831783
bool_idx_to_nonzero,
17841784
"numba",
1785+
"shape_unsafe", # It can mask invalid mask sizes
17851786
use_db_name_as_tag=False, # Not included if only "specialize" is requested
17861787
)
17871788

0 commit comments

Comments
 (0)