Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 15 additions & 18 deletions src/host/random.jl
Original file line number Diff line number Diff line change
Expand Up @@ -416,19 +416,15 @@ end

@inline Random.rng_native_52(::ElementRNG) = UInt64

@inline function Random.rand(rng::ElementRNG, ::Random.SamplerType{UInt64})
sc = unsafe_load(rng.ctr0_ptr) + rng.nthreads
unsafe_store!(rng.ctr0_ptr, sc)
a1, a2, _, _ = philox4x32_10(sc, rng.counter, rng.seed)
UInt64(a1) | UInt64(a2) << 32
end

@inline function Random.rand(rng::ElementRNG, ::Random.SamplerType{UInt128})
UInt128(rand(rng, Random.SamplerType{UInt64}())) |
UInt128(rand(rng, Random.SamplerType{UInt64}())) << 64
for T in (UInt64, UInt128, Int128)
@eval @inline function Random.rand(rng::ElementRNG, ::Random.SamplerType{$T})
sc = unsafe_load(rng.ctr0_ptr) + rng.nthreads
unsafe_store!(rng.ctr0_ptr, sc)
first(philox_to_vals($T, philox4x32_10(sc, rng.counter, rng.seed)...))
end
end

@inline Random.rand(rng::ElementRNG, ::Random.SamplerType{T}) where T <: Union{Bool,Base.BitInteger} =
@inline Random.rand(rng::ElementRNG, ::Random.SamplerType{T}) where T <: Union{UInt8, UInt16, UInt32, Int8, Int16, Int32, Int64, Bool} =
rand(rng, Random.SamplerType{UInt64}()) % T


Expand All @@ -447,7 +443,7 @@ end
end
end

function Random.rand!(rng::RNG, A::AnyGPUArray{T}) where T <: Number
function Random.rand!(rng::RNG, A::AnyGPUArray)
isempty(A) && return A
rand_generic_kernel!(get_backend(A))(rng.seed, rng.counter, A; ndrange=length(A))
advance_counter!(rng)
Expand Down Expand Up @@ -546,12 +542,13 @@ end
# tables aren't device-accessible, and on Metal the Float64 tables can't even
# be loaded. Reached via Base's Complex recursion when the element type is
# e.g. Complex{Float16}.
@inline Random.randn(rng::ElementRNG, ::Type{Float16}) =
first(boxmuller(Float16, rand(rng, UInt32), rand(rng, UInt32)))
@inline Random.randn(rng::ElementRNG, ::Type{Float32}) =
first(boxmuller(Float32, rand(rng, UInt32), rand(rng, UInt32)))
@inline Random.randn(rng::ElementRNG, ::Type{Float64}) =
first(boxmuller(Float64, rand(rng, UInt64), rand(rng, UInt64)))
for T in (Float16, Float32, Float64, ComplexF32, ComplexF64)
@eval @inline function Random.randn(rng::ElementRNG, ::Type{$T})
sc = unsafe_load(rng.ctr0_ptr) + rng.nthreads
unsafe_store!(rng.ctr0_ptr, sc)
first(philox_to_normals($T, philox4x32_10(sc, rng.counter, rng.seed)...))
end
end

@kernel function randn_generic_kernel!(seed::UInt64, counter::UInt64, A::AbstractArray{T}) where T
gid = @index(Global, Linear)
Expand Down
16 changes: 16 additions & 0 deletions test/testsuite/random.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,22 @@
fill!(A, true)
rand!(rng, A)
@test false in Array(A)

# Int128 is not supported on many backends yet
if nameof(AT) == :JLArray
# Complex{Int128}
A = AT{Complex{Int128}}(undef, 1024)
rand!(rng, A)
@test count(x -> real(x) < 0, Array(A)) > 0
out = Array(A)
@test real(out[1]) != imag(out[1])
end
# rand support for Tuple requires at least Julia 1.11
if VERSION ≥ v"1.11"
A = AT{NTuple{5, Int64}}(undef, 1024)
rand!(rng, A)
@test allunique(collect(Iterators.flatten(Array(A))))
end
end

@testset "randn" begin # normally-distributed
Expand Down
Loading