diff --git a/src/host/random.jl b/src/host/random.jl index 5beb7d31..8cbc50ab 100644 --- a/src/host/random.jl +++ b/src/host/random.jl @@ -416,19 +416,15 @@ end @inline Random.rng_native_52(::ElementRNG) = UInt64 -@inline function Random.rand(rng::ElementRNG, ::Random.SamplerType{UInt64}) - sc = unsafe_load(rng.ctr0_ptr) + rng.nthreads - unsafe_store!(rng.ctr0_ptr, sc) - a1, a2, _, _ = philox4x32_10(sc, rng.counter, rng.seed) - UInt64(a1) | UInt64(a2) << 32 -end - -@inline function Random.rand(rng::ElementRNG, ::Random.SamplerType{UInt128}) - UInt128(rand(rng, Random.SamplerType{UInt64}())) | - UInt128(rand(rng, Random.SamplerType{UInt64}())) << 64 +for T in (UInt64, UInt128, Int128) + @eval @inline function Random.rand(rng::ElementRNG, ::Random.SamplerType{$T}) + sc = unsafe_load(rng.ctr0_ptr) + rng.nthreads + unsafe_store!(rng.ctr0_ptr, sc) + first(philox_to_vals($T, philox4x32_10(sc, rng.counter, rng.seed)...)) + end end -@inline Random.rand(rng::ElementRNG, ::Random.SamplerType{T}) where T <: Union{Bool,Base.BitInteger} = +@inline Random.rand(rng::ElementRNG, ::Random.SamplerType{T}) where T <: Union{UInt8, UInt16, UInt32, Int8, Int16, Int32, Int64, Bool} = rand(rng, Random.SamplerType{UInt64}()) % T @@ -447,7 +443,7 @@ end end end -function Random.rand!(rng::RNG, A::AnyGPUArray{T}) where T <: Number +function Random.rand!(rng::RNG, A::AnyGPUArray) isempty(A) && return A rand_generic_kernel!(get_backend(A))(rng.seed, rng.counter, A; ndrange=length(A)) advance_counter!(rng) @@ -546,12 +542,13 @@ end # tables aren't device-accessible, and on Metal the Float64 tables can't even # be loaded. Reached via Base's Complex recursion when the element type is # e.g. Complex{Float16}. -@inline Random.randn(rng::ElementRNG, ::Type{Float16}) = - first(boxmuller(Float16, rand(rng, UInt32), rand(rng, UInt32))) -@inline Random.randn(rng::ElementRNG, ::Type{Float32}) = - first(boxmuller(Float32, rand(rng, UInt32), rand(rng, UInt32))) -@inline Random.randn(rng::ElementRNG, ::Type{Float64}) = - first(boxmuller(Float64, rand(rng, UInt64), rand(rng, UInt64))) +for T in (Float16, Float32, Float64, ComplexF32, ComplexF64) + @eval @inline function Random.randn(rng::ElementRNG, ::Type{$T}) + sc = unsafe_load(rng.ctr0_ptr) + rng.nthreads + unsafe_store!(rng.ctr0_ptr, sc) + first(philox_to_normals($T, philox4x32_10(sc, rng.counter, rng.seed)...)) + end +end @kernel function randn_generic_kernel!(seed::UInt64, counter::UInt64, A::AbstractArray{T}) where T gid = @index(Global, Linear) diff --git a/test/testsuite/random.jl b/test/testsuite/random.jl index 2cd65a3c..057809a4 100644 --- a/test/testsuite/random.jl +++ b/test/testsuite/random.jl @@ -27,6 +27,22 @@ fill!(A, true) rand!(rng, A) @test false in Array(A) + + # Int128 is not supported on many backends yet + if nameof(AT) == :JLArray + # Complex{Int128} + A = AT{Complex{Int128}}(undef, 1024) + rand!(rng, A) + @test count(x -> real(x) < 0, Array(A)) > 0 + out = Array(A) + @test real(out[1]) != imag(out[1]) + end + # rand support for Tuple requires at least Julia 1.11 + if VERSION ≥ v"1.11" + A = AT{NTuple{5, Int64}}(undef, 1024) + rand!(rng, A) + @test allunique(collect(Iterators.flatten(Array(A)))) + end end @testset "randn" begin # normally-distributed