Skip to content

Commit 7bad878

Browse files
committed
Add custom mask to gather/scatter
Signed-off-by: Ziheng Deng <zihengd@nvidia.com>
1 parent 2ac37c2 commit 7bad878

4 files changed

Lines changed: 285 additions & 18 deletions

File tree

changelog.d/gather-scatter-mask.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
<!--- SPDX-FileCopyrightText: Copyright (c) <2026> NVIDIA CORPORATION & AFFILIATES. All rights reserved. -->
2+
<!--- SPDX-License-Identifier: Apache-2.0 -->
3+
4+
- Added optional `mask` parameter to `ct.gather()` and `ct.scatter()` for custom boolean masking.

src/cuda/tile/_ir/ops.py

Lines changed: 61 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2340,9 +2340,9 @@ def pointer_offset(pointer: Var, offset: Var) -> Var:
23402340

23412341

23422342
@impl(ct.gather)
2343-
def gather_impl(array: Var, indices: Var, padding_value: Var, check_bounds: Var,
2344-
latency: Var) -> Var:
2345-
pointer, mask = _gather_scatter_pointer_and_mask(array, indices, check_bounds)
2343+
def gather_impl(array: Var, indices: Var, mask: Var, padding_value: Var,
2344+
check_bounds: Var, latency: Var) -> Var:
2345+
pointer, final_mask = _gather_scatter_pointer_and_mask(array, indices, check_bounds, mask)
23462346
pointer_ty = pointer.get_type()
23472347
pointer_shape = pointer_ty.shape_value if isinstance(pointer_ty, TileTy) else ()
23482348

@@ -2360,12 +2360,13 @@ def gather_impl(array: Var, indices: Var, padding_value: Var, check_bounds: Var,
23602360
# Handle the latency hint
23612361
latency = require_optional_constant_int(latency)
23622362
check_load_store_hints(latency)
2363-
return load_pointer(pointer, mask, padding_value, latency)
2363+
return load_pointer(pointer, final_mask, padding_value, latency)
23642364

23652365

23662366
@impl(ct.scatter)
2367-
def scatter_impl(array: Var, indices: Var, value: Var, check_bounds: Var, latency: Var):
2368-
pointer, mask = _gather_scatter_pointer_and_mask(array, indices, check_bounds)
2367+
def scatter_impl(array: Var, indices: Var, value: Var, mask: Var,
2368+
check_bounds: Var, latency: Var):
2369+
pointer, final_mask = _gather_scatter_pointer_and_mask(array, indices, check_bounds, mask)
23692370
pointer_ty = pointer.get_type()
23702371
pointer_shape = pointer_ty.shape_value if isinstance(pointer_ty, TileTy) else ()
23712372

@@ -2377,7 +2378,7 @@ def scatter_impl(array: Var, indices: Var, value: Var, check_bounds: Var, latenc
23772378
latency = require_optional_constant_int(latency)
23782379
check_load_store_hints(latency)
23792380

2380-
store_pointer(pointer, value, mask, latency)
2381+
store_pointer(pointer, value, final_mask, latency)
23812382

23822383

23832384
def _get_scatter_value(value: Var, pointer_shape: Tuple[int, ...], array_dtype: DType,
@@ -2395,9 +2396,52 @@ def _get_scatter_value(value: Var, pointer_shape: Tuple[int, ...], array_dtype:
23952396
return broadcast_to(value, pointer_shape)
23962397

23972398

2398-
def _gather_scatter_pointer_and_mask(array: Var,
2399-
indices: Var,
2400-
check_bounds: Var) -> Tuple[Var, Optional[Var]]:
2399+
def _process_custom_mask(mask: Optional[Var], bounds_mask: Optional[Var],
2400+
pointer_shape: Tuple[int, ...]) -> Optional[Var]:
2401+
"""
2402+
Process and validate the custom mask parameter for gather/scatter operations.
2403+
2404+
Args:
2405+
mask: The user-provided mask (can be Python None or Var containing None)
2406+
bounds_mask: The generated bounds-checking mask based on indices (or None)
2407+
pointer_shape: The target shape that the mask should be broadcast to
2408+
2409+
Returns:
2410+
The final mask to use (custom AND bounds, or just one of them, or None)
2411+
"""
2412+
# Check if mask is None (either Python None or Var containing None)
2413+
if mask is None or (mask.is_constant() and mask.get_constant() is None):
2414+
# No custom mask provided, return the bounds mask
2415+
return bounds_mask
2416+
2417+
# Validate the mask type
2418+
mask_ty = require_tile_or_scalar_type(mask)
2419+
mask_dtype = get_dtype(mask_ty)
2420+
2421+
if not is_boolean(mask_dtype):
2422+
raise TileTypeError(f"Custom mask must have boolean dtype, but got {mask_dtype}")
2423+
2424+
# Check that mask shape is broadcastable
2425+
mask_shape = mask_ty.shape_value if isinstance(mask_ty, TileTy) else ()
2426+
if not is_shape_broadcastable_to(mask_shape, pointer_shape):
2427+
raise TileTypeError(f"Custom mask shape {mask_shape} is not broadcastable"
2428+
f" to the index shape {pointer_shape}")
2429+
2430+
# Broadcast the mask to the pointer shape
2431+
mask = broadcast_to(mask, pointer_shape)
2432+
2433+
# Combine with bounds mask if both exist
2434+
if bounds_mask is None:
2435+
return mask
2436+
else:
2437+
return binary_bitwise("and_", bounds_mask, mask)
2438+
2439+
2440+
def _gather_scatter_pointer_and_mask(
2441+
array: Var,
2442+
indices: Var,
2443+
check_bounds: Var,
2444+
custom_mask: Optional[Var] = None) -> Tuple[Var, Optional[Var]]:
24012445
check_bounds = require_constant_bool(check_bounds)
24022446
array_ty = require_array_type(array)
24032447
indices_ty = require_index_or_index_tuple_type(indices,
@@ -2475,10 +2519,15 @@ def _gather_scatter_pointer_and_mask(array: Var,
24752519
# Offset the base pointer
24762520
if offset is None:
24772521
# 0-D array case
2478-
return array_val.base_ptr, None
2522+
pointer = array_val.base_ptr
2523+
pointer_shape = ()
24792524
else:
24802525
pointer = pointer_offset(array_val.base_ptr, offset)
2481-
return pointer, mask
2526+
pointer_shape = common_shape
2527+
2528+
# Process custom mask and combine with bounds mask
2529+
final_mask = _process_custom_mask(custom_mask, mask, pointer_shape)
2530+
return pointer, final_mask
24822531

24832532

24842533
@memory_effect(MemoryEffect.STORE)

src/cuda/tile/_stub.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -625,7 +625,8 @@ def store(array: Array, /,
625625

626626

627627
@function
628-
def gather(array, indices, /, *, padding_value=0, check_bounds=True, latency=None) -> Tile:
628+
def gather(array, indices, /, *, mask=None, padding_value=0, check_bounds=True,
629+
latency=None) -> Tile:
629630
"""
630631
Loads a tile from the `array` elements specified by `indices`.
631632
@@ -651,10 +652,19 @@ def gather(array, indices, /, *, padding_value=0, check_bounds=True, latency=Non
651652
652653
>>> ct.gather(array, ind0) # equivalent to ct.gather(array, (ind0,))
653654
655+
A custom boolean `mask` can be provided to control which elements are loaded.
656+
The mask must be a scalar or a tile whose shape is broadcastable to the common shape
657+
of indices. Where the mask is ``False``, `padding_value` is returned instead of loading
658+
from the array.
659+
654660
`gather()` checks that indices are within the bounds of the array. For indices
655661
that are out of bounds, `padding_value` will be returned (zero by default).
656662
It must be a scalar or a tile whose shape is broadcastable to the common shape of indices.
657663
664+
If both `mask` and `check_bounds=True` are provided, the effective mask is the logical
665+
AND of both the custom mask and the bounds-checking mask. This means an element is only
666+
loaded if both the custom mask is ``True`` AND the index is within bounds.
667+
658668
To disable bounds checking, set `check_bounds` to ``False``.
659669
In this mode, the caller is responsible for ensuring that all indices are within the bounds
660670
of the array, and any out-of-bounds access will result in undefined behavior.
@@ -665,7 +675,7 @@ def gather(array, indices, /, *, padding_value=0, check_bounds=True, latency=Non
665675

666676

667677
@function
668-
def scatter(array, indices, value, /, *, check_bounds=True, latency=None):
678+
def scatter(array, indices, value, /, *, mask=None, check_bounds=True, latency=None):
669679
"""
670680
Stores a tile `value` into the `array` elements specified by `indices`.
671681
@@ -692,11 +702,20 @@ def scatter(array, indices, value, /, *, check_bounds=True, latency=None):
692702
693703
>>> ct.scatter(array, ind0, value) # equivalent to ct.scatter(array, (ind0,), value)
694704
705+
A custom boolean `mask` can be provided to control which elements are stored.
706+
The mask must be a scalar or a tile whose shape is broadcastable to the common shape
707+
of indices. Where the mask is ``False``, no store occurs.
708+
695709
`scatter()` checks that indices are within the bounds of the array. For indices
696-
that are out of bounds, nothing is stored. To disable bounds checking,
697-
set `check_bounds` to ``False``. In this mode, the caller is responsible for ensuring that
698-
all indices are within the bounds of the array, and any out-of-bounds access
699-
will result in undefined behavior.
710+
that are out of bounds, nothing is stored.
711+
712+
If both `mask` and `check_bounds=True` are provided, the effective mask is the logical
713+
AND of both the custom mask and the bounds-checking mask. This means an element is only
714+
stored if both the custom mask is ``True`` AND the index is within bounds.
715+
716+
To disable bounds checking, set `check_bounds` to ``False``. In this mode, the caller
717+
is responsible for ensuring that all indices are within the bounds of the array, and
718+
any out-of-bounds access will result in undefined behavior.
700719
"""
701720

702721

test/test_gather_scatter.py

Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,3 +211,198 @@ def test_ir_checked_vs_unchecked(kernel, expected_mask):
211211
store_ops = [op for op in root_block.traverse() if isinstance(op, StorePointerTokenOrdered)]
212212
assert len(store_ops) == 1
213213
assert (store_ops[0].mask is not None) == expected_mask
214+
215+
216+
# ============================================================================
217+
# Tests for custom mask parameter
218+
# ============================================================================
219+
220+
@ct.kernel
221+
def gather_with_custom_mask_1d(x, y, mask_array):
222+
"""Test gather with custom boolean mask."""
223+
indices = ct.arange(8, dtype=ct.int32)
224+
# Load mask from array
225+
mask_tile = ct.gather(mask_array, indices)
226+
# Gather with custom mask, no bounds checking needed
227+
tx = ct.gather(x, indices, mask=mask_tile, padding_value=-999.0, check_bounds=False)
228+
ct.scatter(y, indices, tx)
229+
230+
231+
def test_gather_with_custom_mask_1d():
232+
"""Test gather with a custom mask that selectively loads elements."""
233+
x = torch.arange(8, dtype=torch.float32, device="cuda")
234+
y = torch.zeros(8, dtype=torch.float32, device="cuda")
235+
# Create a mask: load only even indices
236+
mask = torch.tensor([True, False, True, False, True, False, True, False],
237+
dtype=torch.bool, device="cuda")
238+
239+
ct.launch(torch.cuda.current_stream(), (1,), gather_with_custom_mask_1d, (x, y, mask))
240+
241+
# Expected: even indices get their values, odd indices get padding value -999.0
242+
expected = torch.tensor([0.0, -999.0, 2.0, -999.0, 4.0, -999.0, 6.0, -999.0],
243+
device="cuda")
244+
assert_equal(expected, y)
245+
246+
247+
@ct.kernel
248+
def gather_with_mask_and_bounds_check(x, y, indices_array, mask_array):
249+
"""Test gather with both custom mask and bounds checking."""
250+
idx = ct.arange(8, dtype=ct.int32)
251+
ind = ct.gather(indices_array, idx)
252+
mask_tile = ct.gather(mask_array, idx)
253+
# Both custom mask AND bounds checking
254+
tx = ct.gather(x, ind, mask=mask_tile, padding_value=-1.0, check_bounds=True)
255+
ct.scatter(y, idx, tx)
256+
257+
258+
def test_gather_with_mask_and_bounds_check():
259+
"""Test that custom mask AND bounds checking are combined correctly."""
260+
x = torch.arange(10, dtype=torch.float32, device="cuda") # array size 10
261+
y = torch.zeros(8, dtype=torch.float32, device="cuda")
262+
# Mix of valid indices, out-of-bounds indices, and masked indices
263+
# 15, 20 are OOB
264+
indices = torch.tensor([0, 1, 15, 3, 4, 20, 6, 7], dtype=torch.int32,
265+
device="cuda")
266+
mask = torch.tensor([True, True, True, False, True, True, False, True],
267+
dtype=torch.bool, device="cuda")
268+
269+
ct.launch(torch.cuda.current_stream(), (1,),
270+
gather_with_mask_and_bounds_check, (x, y, indices, mask))
271+
272+
# Expected behavior:
273+
# idx 0: mask=True, in-bounds (0<10) → load x[0]=0.0
274+
# idx 1: mask=True, in-bounds (1<10) → load x[1]=1.0
275+
# idx 2: mask=True, OOB (15>=10) → padding -1.0
276+
# idx 3: mask=False, in-bounds → padding -1.0
277+
# idx 4: mask=True, in-bounds (4<10) → load x[4]=4.0
278+
# idx 5: mask=True, OOB (20>=10) → padding -1.0
279+
# idx 6: mask=False, in-bounds → padding -1.0
280+
# idx 7: mask=True, in-bounds (7<10) → load x[7]=7.0
281+
expected = torch.tensor([0.0, 1.0, -1.0, -1.0, 4.0, -1.0, -1.0, 7.0], device="cuda")
282+
assert_equal(expected, y)
283+
284+
285+
@ct.kernel
286+
def scatter_with_custom_mask(x, y, mask_array):
287+
"""Test scatter with custom mask."""
288+
indices = ct.arange(8, dtype=ct.int32)
289+
mask_tile = ct.gather(mask_array, indices)
290+
values = ct.gather(x, indices)
291+
# Scatter with custom mask
292+
ct.scatter(y, indices, values, mask=mask_tile, check_bounds=False)
293+
294+
295+
def test_scatter_with_custom_mask():
296+
"""Test scatter with a custom mask that selectively stores elements."""
297+
# [100, 101, ..., 107]
298+
x = torch.arange(100, 108, dtype=torch.float32, device="cuda")
299+
y = torch.zeros(8, dtype=torch.float32, device="cuda")
300+
# Create a mask: store only at indices 0, 2, 4, 6
301+
mask = torch.tensor([True, False, True, False, True, False, True, False],
302+
dtype=torch.bool, device="cuda")
303+
304+
ct.launch(torch.cuda.current_stream(), (1,), scatter_with_custom_mask, (x, y, mask))
305+
306+
# Expected: only masked positions are written
307+
expected = torch.tensor([100.0, 0.0, 102.0, 0.0, 104.0, 0.0, 106.0, 0.0], device="cuda")
308+
assert_equal(expected, y)
309+
310+
311+
@ct.kernel
312+
def gather_2d_with_broadcast_mask(x, y, mask_array):
313+
"""Test gather with 2D indices and broadcasted mask."""
314+
# Create 2D indices that broadcast
315+
ind0 = ct.arange(4, dtype=ct.int32)[:, None] # shape (4, 1)
316+
ind1 = ct.arange(4, dtype=ct.int32) # shape (4,)
317+
# Load mask - it's already (4, 1) shaped
318+
mask_tile = ct.gather(mask_array, (ct.arange(4, dtype=ct.int32)[:, None], 0))
319+
# Gather with broadcasted mask: mask (4,1) broadcasts to (4,4)
320+
t = ct.gather(x, (ind0, ind1), mask=mask_tile, padding_value=0.0, check_bounds=False)
321+
# Flatten and store result
322+
ct.scatter(y, ct.arange(16, dtype=ct.int32), ct.reshape(t, (16,)))
323+
324+
325+
def test_gather_2d_with_broadcast_mask():
326+
"""Test that mask broadcasting works correctly with 2D indices."""
327+
x = torch.arange(16, dtype=torch.float32, device="cuda").reshape(4, 4)
328+
y = torch.zeros(16, dtype=torch.float32, device="cuda")
329+
# Mask shape (4, 1) - prepared outside kernel
330+
mask = torch.tensor([[True], [False], [True], [False]], dtype=torch.bool,
331+
device="cuda")
332+
333+
ct.launch(torch.cuda.current_stream(), (1,), gather_2d_with_broadcast_mask, (x, y, mask))
334+
335+
# ind0 (4,1): [[0], [1], [2], [3]]
336+
# ind1 (4,): [0, 1, 2, 3]
337+
# Broadcast to (4,4):
338+
# ind0: [[0,0,0,0], [1,1,1,1], [2,2,2,2], [3,3,3,3]]
339+
# ind1: [[0,1,2,3], [0,1,2,3], [0,1,2,3], [0,1,2,3]]
340+
# Mask (4,1) broadcasts to (4,4):
341+
# [[T,T,T,T], [F,F,F,F], [T,T,T,T], [F,F,F,F]]
342+
# Expected gathered values (flattened):
343+
# Row 0 (mask=True): x[0,0], x[0,1], x[0,2], x[0,3] = [0, 1, 2, 3]
344+
# Row 1 (mask=False): [0, 0, 0, 0]
345+
# Row 2 (mask=True): x[2,0], x[2,1], x[2,2], x[2,3] = [8, 9, 10, 11]
346+
# Row 3 (mask=False): [0, 0, 0, 0]
347+
expected = torch.tensor([0, 1, 2, 3, 0, 0, 0, 0, 8, 9, 10, 11, 0, 0, 0, 0],
348+
dtype=torch.float32, device="cuda")
349+
assert_equal(expected, y)
350+
351+
352+
@ct.kernel
353+
def gather_with_scalar_mask(x, y, mask_val: ct.Constant[bool]):
354+
"""Test gather with scalar mask."""
355+
indices = ct.arange(8, dtype=ct.int32)
356+
tx = ct.gather(x, indices, mask=mask_val, padding_value=-1.0, check_bounds=False)
357+
ct.scatter(y, indices, tx)
358+
359+
360+
@pytest.mark.parametrize("mask_val", [True, False])
361+
def test_gather_with_scalar_mask(mask_val):
362+
"""Test that scalar masks work correctly."""
363+
x = torch.arange(8, dtype=torch.float32, device="cuda")
364+
y = torch.zeros(8, dtype=torch.float32, device="cuda")
365+
366+
ct.launch(torch.cuda.current_stream(), (1,), gather_with_scalar_mask, (x, y, mask_val))
367+
368+
if mask_val:
369+
# mask=True: all elements should be loaded
370+
expected = x
371+
else:
372+
# mask=False: all elements should be padding value
373+
expected = torch.full_like(x, -1.0)
374+
375+
assert_equal(expected, y)
376+
377+
378+
def test_mask_type_error():
379+
"""Test that providing non-boolean mask raises TileTypeError."""
380+
@ct.kernel
381+
def gather_with_int_mask(x, y):
382+
indices = ct.arange(8, dtype=ct.int32)
383+
mask = ct.arange(8, dtype=ct.int32) # Wrong: integer mask instead of boolean
384+
tx = ct.gather(x, indices, mask=mask, check_bounds=False)
385+
ct.scatter(y, indices, tx)
386+
387+
x = torch.arange(8, dtype=torch.float32, device="cuda")
388+
y = torch.zeros(8, dtype=torch.float32, device="cuda")
389+
390+
with pytest.raises(TileTypeError, match="boolean"):
391+
ct.launch(torch.cuda.current_stream(), (1,), gather_with_int_mask, (x, y))
392+
393+
394+
def test_mask_shape_error():
395+
"""Test that incompatible mask shape raises TileTypeError."""
396+
@ct.kernel
397+
def gather_with_wrong_shape_mask(x, y):
398+
indices = ct.arange(8, dtype=ct.int32)
399+
# Create mask with wrong shape: (4,) not broadcastable to (8,)
400+
mask_tile = ct.arange(4, dtype=ct.int32) > 0 # shape (4,), bool
401+
tx = ct.gather(x, indices, mask=mask_tile, check_bounds=False)
402+
ct.scatter(y, indices, tx)
403+
404+
x = torch.arange(8, dtype=torch.float32, device="cuda")
405+
y = torch.zeros(8, dtype=torch.float32, device="cuda")
406+
407+
with pytest.raises(TileTypeError, match="not broadcastable"):
408+
ct.launch(torch.cuda.current_stream(), (1,), gather_with_wrong_shape_mask, (x, y))

0 commit comments

Comments
 (0)