@@ -211,3 +211,198 @@ def test_ir_checked_vs_unchecked(kernel, expected_mask):
211211 store_ops = [op for op in root_block .traverse () if isinstance (op , StorePointerTokenOrdered )]
212212 assert len (store_ops ) == 1
213213 assert (store_ops [0 ].mask is not None ) == expected_mask
214+
215+
216+ # ============================================================================
217+ # Tests for custom mask parameter
218+ # ============================================================================
219+
@ct.kernel
def gather_with_custom_mask_1d(x, y, mask_array):
    """Kernel: gather from x under a per-lane boolean mask loaded from memory."""
    idx = ct.arange(8, dtype=ct.int32)
    # First pull the boolean mask tile out of global memory.
    m = ct.gather(mask_array, idx)
    # Masked-off lanes take the padding value; bounds checking is disabled.
    loaded = ct.gather(x, idx, mask=m, padding_value=-999.0, check_bounds=False)
    ct.scatter(y, idx, loaded)
229+
230+
def test_gather_with_custom_mask_1d():
    """Masked gather: even lanes load real data, odd lanes receive the padding value."""
    x = torch.arange(8, dtype=torch.float32, device="cuda")
    y = torch.zeros(8, dtype=torch.float32, device="cuda")
    # Mask selects only the even positions.
    mask = torch.tensor(
        [True, False, True, False, True, False, True, False],
        dtype=torch.bool,
        device="cuda",
    )

    ct.launch(torch.cuda.current_stream(), (1,), gather_with_custom_mask_1d, (x, y, mask))

    # Even positions carry x's values; odd positions carry the -999.0 padding.
    expected = torch.tensor(
        [0.0, -999.0, 2.0, -999.0, 4.0, -999.0, 6.0, -999.0], device="cuda"
    )
    assert_equal(expected, y)
245+
246+
@ct.kernel
def gather_with_mask_and_bounds_check(x, y, indices_array, mask_array):
    """Kernel: combine a user-supplied mask with automatic bounds checking."""
    idx = ct.arange(8, dtype=ct.int32)
    ind = ct.gather(indices_array, idx)
    m = ct.gather(mask_array, idx)
    # A lane loads only when its mask bit is set AND its index is in range;
    # every other lane gets the padding value.
    vals = ct.gather(x, ind, mask=m, padding_value=-1.0, check_bounds=True)
    ct.scatter(y, idx, vals)
256+
257+
def test_gather_with_mask_and_bounds_check():
    """A lane yields data only when both the mask bit and the bounds check pass."""
    x = torch.arange(10, dtype=torch.float32, device="cuda")  # 10-element source
    y = torch.zeros(8, dtype=torch.float32, device="cuda")
    # Indices 15 and 20 are out of bounds for a 10-element array.
    indices = torch.tensor(
        [0, 1, 15, 3, 4, 20, 6, 7], dtype=torch.int32, device="cuda"
    )
    mask = torch.tensor(
        [True, True, True, False, True, True, False, True],
        dtype=torch.bool,
        device="cuda",
    )

    ct.launch(
        torch.cuda.current_stream(),
        (1,),
        gather_with_mask_and_bounds_check,
        (x, y, indices, mask),
    )

    # Per lane:
    #   0: masked in,  in-bounds  -> x[0] = 0.0
    #   1: masked in,  in-bounds  -> x[1] = 1.0
    #   2: masked in,  OOB (15)   -> padding -1.0
    #   3: masked out             -> padding -1.0
    #   4: masked in,  in-bounds  -> x[4] = 4.0
    #   5: masked in,  OOB (20)   -> padding -1.0
    #   6: masked out             -> padding -1.0
    #   7: masked in,  in-bounds  -> x[7] = 7.0
    expected = torch.tensor(
        [0.0, 1.0, -1.0, -1.0, 4.0, -1.0, -1.0, 7.0], device="cuda"
    )
    assert_equal(expected, y)
283+
284+
@ct.kernel
def scatter_with_custom_mask(x, y, mask_array):
    """Kernel: store gathered values back out under a boolean mask."""
    idx = ct.arange(8, dtype=ct.int32)
    m = ct.gather(mask_array, idx)
    vals = ct.gather(x, idx)
    # Only lanes whose mask bit is set perform the store.
    ct.scatter(y, idx, vals, mask=m, check_bounds=False)
293+
294+
def test_scatter_with_custom_mask():
    """Masked scatter: only the masked-in positions of the destination are written."""
    # Source values 100..107.
    x = torch.arange(100, 108, dtype=torch.float32, device="cuda")
    y = torch.zeros(8, dtype=torch.float32, device="cuda")
    # Only positions 0, 2, 4, 6 are stored.
    mask = torch.tensor(
        [True, False, True, False, True, False, True, False],
        dtype=torch.bool,
        device="cuda",
    )

    ct.launch(torch.cuda.current_stream(), (1,), scatter_with_custom_mask, (x, y, mask))

    # Unwritten positions keep their original zeros.
    expected = torch.tensor(
        [100.0, 0.0, 102.0, 0.0, 104.0, 0.0, 106.0, 0.0], device="cuda"
    )
    assert_equal(expected, y)
309+
310+
@ct.kernel
def gather_2d_with_broadcast_mask(x, y, mask_array):
    """Kernel: 2D gather where a (4, 1) mask broadcasts against (4, 4) indices."""
    # Row index as a column vector so it broadcasts against the row of columns.
    rows = ct.arange(4, dtype=ct.int32)[:, None]  # (4, 1)
    cols = ct.arange(4, dtype=ct.int32)           # (4,)
    # The mask tile comes out shaped (4, 1).
    m = ct.gather(mask_array, (rows, 0))
    # (4, 1) mask broadcasts across the (4, 4) gather.
    tile = ct.gather(x, (rows, cols), mask=m, padding_value=0.0, check_bounds=False)
    # Write the result back as a flat 16-element vector.
    ct.scatter(y, ct.arange(16, dtype=ct.int32), ct.reshape(tile, (16,)))
323+
324+
def test_gather_2d_with_broadcast_mask():
    """A (4, 1) mask must broadcast row-wise over the (4, 4) gathered tile."""
    x = torch.arange(16, dtype=torch.float32, device="cuda").reshape(4, 4)
    y = torch.zeros(16, dtype=torch.float32, device="cuda")
    # Per-row mask, shape (4, 1): rows 0 and 2 enabled, rows 1 and 3 disabled.
    mask = torch.tensor(
        [[True], [False], [True], [False]], dtype=torch.bool, device="cuda"
    )

    ct.launch(torch.cuda.current_stream(), (1,), gather_2d_with_broadcast_mask, (x, y, mask))

    # Broadcast picture:
    #   row indices (4,1) x col indices (4,) -> full (4,4) index grid
    #   mask (4,1) -> [[T]*4, [F]*4, [T]*4, [F]*4]
    # Flattened result:
    #   row 0 enabled  -> x[0, :] = [0, 1, 2, 3]
    #   row 1 disabled -> padding [0, 0, 0, 0]
    #   row 2 enabled  -> x[2, :] = [8, 9, 10, 11]
    #   row 3 disabled -> padding [0, 0, 0, 0]
    expected = torch.tensor(
        [0, 1, 2, 3, 0, 0, 0, 0, 8, 9, 10, 11, 0, 0, 0, 0],
        dtype=torch.float32,
        device="cuda",
    )
    assert_equal(expected, y)
350+
351+
@ct.kernel
def gather_with_scalar_mask(x, y, mask_val: ct.Constant[bool]):
    """Kernel: gather where the mask is a single compile-time boolean."""
    idx = ct.arange(8, dtype=ct.int32)
    # A scalar mask applies uniformly to every lane.
    vals = ct.gather(x, idx, mask=mask_val, padding_value=-1.0, check_bounds=False)
    ct.scatter(y, idx, vals)
358+
359+
@pytest.mark.parametrize("mask_val", [True, False])
def test_gather_with_scalar_mask(mask_val):
    """A scalar mask either passes every lane through or pads every lane."""
    x = torch.arange(8, dtype=torch.float32, device="cuda")
    y = torch.zeros(8, dtype=torch.float32, device="cuda")

    ct.launch(torch.cuda.current_stream(), (1,), gather_with_scalar_mask, (x, y, mask_val))

    if mask_val:
        # All lanes enabled: output mirrors the input.
        expected = x
    else:
        # All lanes disabled: output is uniformly the padding value.
        expected = torch.full_like(x, -1.0)

    assert_equal(expected, y)
376+
377+
def test_mask_type_error():
    """A non-boolean mask tile must be rejected with TileTypeError."""
    @ct.kernel
    def gather_with_int_mask(x, y):
        idx = ct.arange(8, dtype=ct.int32)
        # Deliberately wrong: an int32 tile where a boolean mask is required.
        bad_mask = ct.arange(8, dtype=ct.int32)
        vals = ct.gather(x, idx, mask=bad_mask, check_bounds=False)
        ct.scatter(y, idx, vals)

    x = torch.arange(8, dtype=torch.float32, device="cuda")
    y = torch.zeros(8, dtype=torch.float32, device="cuda")

    with pytest.raises(TileTypeError, match="boolean"):
        ct.launch(torch.cuda.current_stream(), (1,), gather_with_int_mask, (x, y))
392+
393+
def test_mask_shape_error():
    """A mask whose shape cannot broadcast to the index shape must raise TileTypeError."""
    @ct.kernel
    def gather_with_wrong_shape_mask(x, y):
        idx = ct.arange(8, dtype=ct.int32)
        # Boolean tile of shape (4,), which cannot broadcast to the (8,) indices.
        bad_mask = ct.arange(4, dtype=ct.int32) > 0
        vals = ct.gather(x, idx, mask=bad_mask, check_bounds=False)
        ct.scatter(y, idx, vals)

    x = torch.arange(8, dtype=torch.float32, device="cuda")
    y = torch.zeros(8, dtype=torch.float32, device="cuda")

    with pytest.raises(TileTypeError, match="not broadcastable"):
        ct.launch(torch.cuda.current_stream(), (1,), gather_with_wrong_shape_mask, (x, y))
0 commit comments