@@ -21,22 +21,18 @@ function matmul!(output, a)
     return KernelAbstractions.synchronize(backend)
 end
 
-# https://github.com/EnzymeAD/Reactant.jl/issues/614
-const skip_non_cuda_tests = true
-
-@static if !Sys.isapple()
+# TODO: raising fails on TPU CI.
+# https://github.com/EnzymeAD/Reactant.jl/pull/1923#discussion_r2580461294
+if !Reactant.Accelerators.TPU.has_tpu()
     @testset "KernelAbstractions Matmul" begin
         A = Reactant.to_rarray(ones(100, 100))
         out = Reactant.to_rarray(ones(100, 100))
-        if CUDA.functional()
-            @test all(Array(@jit(matmul!(out, A))) .≈ 100) broken = true
-        else
-            @static if skip_non_cuda_tests
-                @test false broken = true
-            else
-                @code_hlo optimize = :before_kernel matmul!(out, A)
-            end
-        end
+        platform_name = Reactant.XLA.platform_name(Reactant.XLA.default_backend())
+        raise = platform_name ∉ ("cpu", "cuda")
+        @jit raise = raise matmul!(out, A)
+        out_c = Array(out)
+        A_c = Array(A)
+        @test out_c ≈ A_c * A_c'
     end
 end
 
@@ -54,34 +50,26 @@ function square(x)
     return y
 end
 
-@static if !Sys.isapple()
-    @testset "KernelAbstractions Square" begin
-        x = Reactant.to_rarray(collect(1:1:64) ./ 64)
+@testset "KernelAbstractions Square" begin
+    x = Reactant.to_rarray(collect(1:1:64) ./ 64)
 
-        b = get_backend(x)
-        @test b isa
-            Base.get_extension(Reactant, :ReactantKernelAbstractionsExt).ReactantBackend
-        let y = allocate(b, Float32, (100, 10))
-            @test y isa ConcreteRArray{Float32,2}
-            @test size(y) == (100, 10)
-        end
-        let y = KernelAbstractions.zeros(b, Float32, (100, 10))
-            @test y isa ConcreteRArray{Float32,2}
-            @test Array(y) == zeros(Float32, 100, 10)
-        end
-        let y = KernelAbstractions.ones(b, Float32, (100, 10))
-            @test y isa ConcreteRArray{Float32,2}
-            @test Array(y) == ones(Float32, 100, 10)
-        end
+    platform_name = Reactant.XLA.platform_name(Reactant.XLA.default_backend())
+    raise = platform_name ∉ ("cpu", "cuda")
 
-        if CUDA.functional()
-            @test all(Array(@jit(square(x))) .≈ Array(x) .* Array(x))
-        else
-            @static if skip_non_cuda_tests
-                @test false broken = true
-            else
-                @code_hlo optimize = :before_kernel square(x)
-            end
-        end
+    b = get_backend(x)
+    @test b isa Base.get_extension(Reactant, :ReactantKernelAbstractionsExt).ReactantBackend
+    let y = allocate(b, Float32, (100, 10))
+        @test y isa ConcreteRArray{Float32,2}
+        @test size(y) == (100, 10)
+    end
+    let y = KernelAbstractions.zeros(b, Float32, (100, 10))
+        @test y isa ConcreteRArray{Float32,2}
+        @test Array(y) == zeros(Float32, 100, 10)
+    end
     end
+    let y = KernelAbstractions.ones(b, Float32, (100, 10))
+        @test y isa ConcreteRArray{Float32,2}
+        @test Array(y) == ones(Float32, 100, 10)
+    end
+
+    @test all(Array(@jit(raise = raise, square(x))) .≈ Array(x) .* Array(x))
 end
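Similarly, this testset assumes a host wrapper square(x) that launches an element-wise kernel and returns its output, matching the "return y" context line in the hunk header and the x .* x assertion. A minimal sketch, with the kernel name and body assumed:

using KernelAbstractions

# Illustrative element-wise square kernel; only square(x)'s signature and
# return value are taken from the diff.
@kernel function square_kernel!(y, @Const(x))
    i = @index(Global, Linear)
    @inbounds y[i] = x[i] * x[i]
end

function square(x)
    backend = get_backend(x)
    y = similar(x)
    square_kernel!(backend)(y, x; ndrange=length(x))
    KernelAbstractions.synchronize(backend)
    return y
end

As in the matmul testset, raise is derived from the XLA platform name and forwarded to @jit, so platforms other than cpu and cuda opt into kernel raising.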