Commit cca938c

Fix KA when CUDA is non-functional & re-enable tests (#1923)
* Fix KA when CUDA is non-functional & re-enable tests
* disable for tpu

1 parent 07719f9 · commit cca938c

File tree

2 files changed: +29 -41 lines changed


ext/ReactantCUDAExt.jl

Lines changed: 1 addition & 1 deletion
```diff
@@ -519,7 +519,7 @@ function Reactant.ka_with_reactant(ndrange, workgroupsize, obj, args...)
 
     # figure out the optimal workgroupsize automatically
     if KA.workgroupsize(obj) <: KA.DynamicSize && workgroupsize === nothing
-        if !Reactant.Compiler.PartitionKA[] || raising()
+        if !Reactant.Compiler.PartitionKA[] || raising() || !CUDA.functional()
             threads = prod(ndrange)
         else
             config = CUDA.launch_configuration(kernel.fun; max_threads=prod(ndrange))
```
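
The one-line change above makes the KernelAbstractions path fall back to a flat thread count whenever CUDA cannot actually run, instead of querying the occupancy API on a broken setup. A minimal sketch of the same guard pattern outside Reactant (the `pick_threads` helper and its arguments are hypothetical; `CUDA.functional` and `CUDA.launch_configuration` are real CUDA.jl calls):

```julia
using CUDA

# Hypothetical helper illustrating the guard from the diff above:
# only consult CUDA's occupancy API when a device is actually usable.
function pick_threads(ndrange, kernel_fun)
    if !CUDA.functional()
        # Non-functional CUDA (no driver/device): fall back to a
        # flat thread count, as the patched branch now does.
        return prod(ndrange)
    end
    # Functional CUDA: let the occupancy API suggest a thread count.
    config = CUDA.launch_configuration(kernel_fun; max_threads=prod(ndrange))
    return min(prod(ndrange), config.threads)
end
```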

test/integration/kernelabstractions.jl

Lines changed: 28 additions & 40 deletions
```diff
@@ -21,22 +21,18 @@ function matmul!(output, a)
     return KernelAbstractions.synchronize(backend)
 end
 
-# https://github.com/EnzymeAD/Reactant.jl/issues/614
-const skip_non_cuda_tests = true
-
-@static if !Sys.isapple()
+# TODO: raising fails on TPU CI.
+# https://github.com/EnzymeAD/Reactant.jl/pull/1923#discussion_r2580461294
+if !Reactant.Accelerators.TPU.has_tpu()
     @testset "KernelAbstractions Matmul" begin
         A = Reactant.to_rarray(ones(100, 100))
         out = Reactant.to_rarray(ones(100, 100))
-        if CUDA.functional()
-            @test all(Array(@jit(matmul!(out, A))) .≈ 100) broken = true
-        else
-            @static if skip_non_cuda_tests
-                @test false broken = true
-            else
-                @code_hlo optimize = :before_kernel matmul!(out, A)
-            end
-        end
+        platform_name = Reactant.XLA.platform_name(Reactant.XLA.default_backend())
+        raise = platform_name ∈ ("cpu", "cuda")
+        @jit raise = raise matmul!(out, A)
+        out_c = Array(out)
+        A_c = Array(A)
+        @test out_c ≈ A_c * A_c'
     end
 end
 
@@ -54,34 +50,26 @@ function square(x)
     return y
 end
 
-@static if !Sys.isapple()
-    @testset "KernelAbstractions Square" begin
-        x = Reactant.to_rarray(collect(1:1:64) ./ 64)
+@testset "KernelAbstractions Square" begin
+    x = Reactant.to_rarray(collect(1:1:64) ./ 64)
 
-        b = get_backend(x)
-        @test b isa
-            Base.get_extension(Reactant, :ReactantKernelAbstractionsExt).ReactantBackend
-        let y = allocate(b, Float32, (100, 10))
-            @test y isa ConcreteRArray{Float32,2}
-            @test size(y) == (100, 10)
-        end
-        let y = KernelAbstractions.zeros(b, Float32, (100, 10))
-            @test y isa ConcreteRArray{Float32,2}
-            @test Array(y) == zeros(Float32, 100, 10)
-        end
-        let y = KernelAbstractions.ones(b, Float32, (100, 10))
-            @test y isa ConcreteRArray{Float32,2}
-            @test Array(y) == ones(Float32, 100, 10)
-        end
+    platform_name = Reactant.XLA.platform_name(Reactant.XLA.default_backend())
+    raise = platform_name ∈ ("cpu", "cuda")
 
-        if CUDA.functional()
-            @test all(Array(@jit(square(x))) .≈ Array(x) .* Array(x))
-        else
-            @static if skip_non_cuda_tests
-                @test false broken = true
-            else
-                @code_hlo optimize = :before_kernel square(x)
-            end
-        end
+    b = get_backend(x)
+    @test b isa Base.get_extension(Reactant, :ReactantKernelAbstractionsExt).ReactantBackend
+    let y = allocate(b, Float32, (100, 10))
+        @test y isa ConcreteRArray{Float32,2}
+        @test size(y) == (100, 10)
+    end
+    let y = KernelAbstractions.zeros(b, Float32, (100, 10))
+        @test y isa ConcreteRArray{Float32,2}
+        @test Array(y) == zeros(Float32, 100, 10)
     end
+    let y = KernelAbstractions.ones(b, Float32, (100, 10))
+        @test y isa ConcreteRArray{Float32,2}
+        @test Array(y) == ones(Float32, 100, 10)
+    end
+
+    @test all(Array(@jit(raise = raise, square(x))) .≈ Array(x) .* Array(x))
 end
```
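
The rewritten tests replace the old `CUDA.functional()` / `skip_non_cuda_tests` branching with a single pattern: decide whether to raise the kernel from the active XLA platform, then check numerics against plain Julia. A hedged standalone sketch of that pattern (assumes a working Reactant install; `square` is the kernel wrapper defined earlier in the test file):

```julia
using Reactant

# Raise the KA kernel only on platforms where raising works in CI
# (per the diff: "cpu" and "cuda"; TPU is excluded by the has_tpu() guard).
platform_name = Reactant.XLA.platform_name(Reactant.XLA.default_backend())
raise = platform_name ∈ ("cpu", "cuda")

x = Reactant.to_rarray(collect(1:1:64) ./ 64)
y = @jit raise = raise square(x)

# Compare against the plain-Julia reference on the host.
@assert Array(y) ≈ Array(x) .* Array(x)
```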
