From 610205d2a1804c94d1f65500e015f5e1187a2e5f Mon Sep 17 00:00:00 2001 From: nhz2 Date: Fri, 15 May 2026 18:15:47 -0400 Subject: [PATCH 1/3] Avoid extra philox calls with ElementRNG --- src/host/random.jl | 33 +++++++++++++++------------------ test/testsuite/random.jl | 16 ++++++++++++++++ 2 files changed, 31 insertions(+), 18 deletions(-) diff --git a/src/host/random.jl b/src/host/random.jl index 5beb7d31..8cbc50ab 100644 --- a/src/host/random.jl +++ b/src/host/random.jl @@ -416,19 +416,15 @@ end @inline Random.rng_native_52(::ElementRNG) = UInt64 -@inline function Random.rand(rng::ElementRNG, ::Random.SamplerType{UInt64}) - sc = unsafe_load(rng.ctr0_ptr) + rng.nthreads - unsafe_store!(rng.ctr0_ptr, sc) - a1, a2, _, _ = philox4x32_10(sc, rng.counter, rng.seed) - UInt64(a1) | UInt64(a2) << 32 -end - -@inline function Random.rand(rng::ElementRNG, ::Random.SamplerType{UInt128}) - UInt128(rand(rng, Random.SamplerType{UInt64}())) | - UInt128(rand(rng, Random.SamplerType{UInt64}())) << 64 +for T in (UInt64, UInt128, Int128) + @eval @inline function Random.rand(rng::ElementRNG, ::Random.SamplerType{$T}) + sc = unsafe_load(rng.ctr0_ptr) + rng.nthreads + unsafe_store!(rng.ctr0_ptr, sc) + first(philox_to_vals($T, philox4x32_10(sc, rng.counter, rng.seed)...)) + end end -@inline Random.rand(rng::ElementRNG, ::Random.SamplerType{T}) where T <: Union{Bool,Base.BitInteger} = +@inline Random.rand(rng::ElementRNG, ::Random.SamplerType{T}) where T <: Union{UInt8, UInt16, UInt32, Int8, Int16, Int32, Int64, Bool} = rand(rng, Random.SamplerType{UInt64}()) % T @@ -447,7 +443,7 @@ end end end -function Random.rand!(rng::RNG, A::AnyGPUArray{T}) where T <: Number +function Random.rand!(rng::RNG, A::AnyGPUArray) isempty(A) && return A rand_generic_kernel!(get_backend(A))(rng.seed, rng.counter, A; ndrange=length(A)) advance_counter!(rng) @@ -546,12 +542,13 @@ end # tables aren't device-accessible, and on Metal the Float64 tables can't even # be loaded. Reached via Base's Complex recursion when the element type is # e.g. Complex{Float16}. -@inline Random.randn(rng::ElementRNG, ::Type{Float16}) = - first(boxmuller(Float16, rand(rng, UInt32), rand(rng, UInt32))) -@inline Random.randn(rng::ElementRNG, ::Type{Float32}) = - first(boxmuller(Float32, rand(rng, UInt32), rand(rng, UInt32))) -@inline Random.randn(rng::ElementRNG, ::Type{Float64}) = - first(boxmuller(Float64, rand(rng, UInt64), rand(rng, UInt64))) +for T in (Float16, Float32, Float64, ComplexF32, ComplexF64) + @eval @inline function Random.randn(rng::ElementRNG, ::Type{$T}) + sc = unsafe_load(rng.ctr0_ptr) + rng.nthreads + unsafe_store!(rng.ctr0_ptr, sc) + first(philox_to_normals($T, philox4x32_10(sc, rng.counter, rng.seed)...)) + end +end @kernel function randn_generic_kernel!(seed::UInt64, counter::UInt64, A::AbstractArray{T}) where T gid = @index(Global, Linear) diff --git a/test/testsuite/random.jl b/test/testsuite/random.jl index 2cd65a3c..aeae2b86 100644 --- a/test/testsuite/random.jl +++ b/test/testsuite/random.jl @@ -27,6 +27,22 @@ fill!(A, true) rand!(rng, A) @test false in Array(A) + + # Complex{Int128} + A = AT{Complex{Int128}}(undef, 1024) + rand!(rng, A) + @test count(x -> real(x) < 0, Array(A)) > 0 + out = Array(A) + @test real(out[1]) != imag(out[1]) + + # Tuples + A = AT{NTuple{2, Int128}}(undef, 1024) + rand!(rng, A) + @test count(x -> first(x) < 0, Array(A)) > 0 + + A = AT{NTuple{2, Int16}}(undef, 1024) + rand!(rng, A) + @test count(x -> first(x) < 0, Array(A)) > 0 end @testset "randn" begin # normally-distributed From 0782d379f0fab885b7fc1b8ddd9db1a443f9a73a Mon Sep 17 00:00:00 2001 From: nhz2 Date: Fri, 15 May 2026 22:40:06 -0400 Subject: [PATCH 2/3] Only test Int128 with JLArray --- test/testsuite/random.jl | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/test/testsuite/random.jl b/test/testsuite/random.jl index aeae2b86..0a97ec28 100644 --- a/test/testsuite/random.jl +++ b/test/testsuite/random.jl @@ -28,18 +28,20 @@ rand!(rng, A) @test false in Array(A) - # Complex{Int128} - A = AT{Complex{Int128}}(undef, 1024) - rand!(rng, A) - @test count(x -> real(x) < 0, Array(A)) > 0 - out = Array(A) - @test real(out[1]) != imag(out[1]) - - # Tuples - A = AT{NTuple{2, Int128}}(undef, 1024) - rand!(rng, A) - @test count(x -> first(x) < 0, Array(A)) > 0 + # Int128 is not supported on many backends yet + if nameof(AT) == :JLArray + # Complex{Int128} + A = AT{Complex{Int128}}(undef, 1024) + rand!(rng, A) + @test count(x -> real(x) < 0, Array(A)) > 0 + out = Array(A) + @test real(out[1]) != imag(out[1]) + # Tuples + A = AT{NTuple{2, Int128}}(undef, 1024) + rand!(rng, A) + @test count(x -> first(x) < 0, Array(A)) > 0 + end A = AT{NTuple{2, Int16}}(undef, 1024) rand!(rng, A) @test count(x -> first(x) < 0, Array(A)) > 0 From bdebfe20e4fa30e86e82e719a75870634e95c2be Mon Sep 17 00:00:00 2001 From: nhz2 Date: Sat, 16 May 2026 13:22:16 -0400 Subject: [PATCH 3/3] Tuple rand needs at least Julia 1.11 --- test/testsuite/random.jl | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/test/testsuite/random.jl b/test/testsuite/random.jl index 0a97ec28..057809a4 100644 --- a/test/testsuite/random.jl +++ b/test/testsuite/random.jl @@ -36,15 +36,13 @@ @test count(x -> real(x) < 0, Array(A)) > 0 out = Array(A) @test real(out[1]) != imag(out[1]) - - # Tuples - A = AT{NTuple{2, Int128}}(undef, 1024) + end + # rand support for Tuple requires at least Julia 1.11 + if VERSION ≥ v"1.11" + A = AT{NTuple{5, Int64}}(undef, 1024) rand!(rng, A) - @test count(x -> first(x) < 0, Array(A)) > 0 + @test allunique(collect(Iterators.flatten(Array(A)))) end - A = AT{NTuple{2, Int16}}(undef, 1024) - rand!(rng, A) - @test count(x -> first(x) < 0, Array(A)) > 0 end @testset "randn" begin # normally-distributed