diff --git a/GeneralisedFilters/src/GeneralisedFilters.jl b/GeneralisedFilters/src/GeneralisedFilters.jl
index 6dc97c6b..e1801c37 100644
--- a/GeneralisedFilters/src/GeneralisedFilters.jl
+++ b/GeneralisedFilters/src/GeneralisedFilters.jl
@@ -17,6 +17,11 @@ include("callbacks.jl")
 include("containers.jl")
 include("resamplers.jl")
 
+# Batching utilities
+include("batching/batching.jl")
+include("batching/batched_CUDA.jl")
+include("batching/batched_SA.jl")
+
 ## FILTERING BASE ##########################################################################
 
 abstract type AbstractFilter <: AbstractSampler end
diff --git a/GeneralisedFilters/src/algorithms/kalman.jl b/GeneralisedFilters/src/algorithms/kalman.jl
index 878ea846..5dc7acae 100644
--- a/GeneralisedFilters/src/algorithms/kalman.jl
+++ b/GeneralisedFilters/src/algorithms/kalman.jl
@@ -1,7 +1,7 @@
-export KalmanFilter, filter, BatchKalmanFilter
+export KalmanFilter, filter
 using GaussianDistributions
 using CUDA: i32
-import LinearAlgebra: hermitianpart
+import LinearAlgebra: hermitianpart, transpose, cholesky
 
 export KalmanFilter, KF, KalmanSmoother, KS
 
@@ -27,7 +27,7 @@ function predict(
 )
     μ, Σ = GaussianDistributions.pair(state)
     A, b, Q = calc_params(model.dyn, iter; kwargs...)
-    return Gaussian(A * μ + b, A * Σ * A' + Q)
+    return Gaussian(A * μ + b, A * Σ * transpose(A) + Q)
 end
 
 function update(
@@ -44,89 +44,21 @@
 )
     # Update state
     m = H * μ + c
     y = observation - m
-    S = hermitianpart(H * Σ * H' + R)
-    K = Σ * H' / S
+    S = H * Σ * transpose(H) + R
+    S = (S + transpose(S)) / 2  # force symmetry
+    S_chol = cholesky(S)
+    KT = S_chol \ (H * Σ)  # TODO: only using `\` for better integration with CuSolver
 
-    state = Gaussian(μ + K * y, Σ - K * H * Σ)
+    state = Gaussian(μ + transpose(KT) * y, Σ - transpose(KT) * H * Σ)
 
     # Compute log-likelihood
-    ll = logpdf(MvNormal(m, S), observation)
+    ll = gaussian_likelihood(m, S, observation)
 
     return state, ll
 end
 
-struct BatchKalmanFilter <: AbstractBatchFilter
-    batch_size::Int
-end
-
-function initialise(
-    rng::AbstractRNG,
-    model::LinearGaussianStateSpaceModel{T},
-    algo::BatchKalmanFilter;
-    kwargs...,
-) where {T}
-    μ0s, Σ0s = batch_calc_initial(model.dyn, algo.batch_size; kwargs...)
-    return BatchGaussianDistribution(μ0s, Σ0s)
-end
-
-function predict(
-    rng::AbstractRNG,
-    model::LinearGaussianStateSpaceModel{T},
-    algo::BatchKalmanFilter,
-    iter::Integer,
-    state::BatchGaussianDistribution,
-    observation;
-    kwargs...,
-) where {T}
-    μs, Σs = state.μs, state.Σs
-    As, bs, Qs = batch_calc_params(model.dyn, iter, algo.batch_size; kwargs...)
-    μ̂s = NNlib.batched_vec(As, μs) .+ bs
-    Σ̂s = NNlib.batched_mul(NNlib.batched_mul(As, Σs), NNlib.batched_transpose(As)) .+ Qs
-    return BatchGaussianDistribution(μ̂s, Σ̂s)
-end
-
-function update(
-    model::LinearGaussianStateSpaceModel{T},
-    algo::BatchKalmanFilter,
-    iter::Integer,
-    state::BatchGaussianDistribution,
-    observation;
-    kwargs...,
-) where {T}
-    μs, Σs = state.μs, state.Σs
-    Hs, cs, Rs = batch_calc_params(model.obs, iter, algo.batch_size; kwargs...)
-    D = size(observation, 1)
-
-    m = NNlib.batched_vec(Hs, μs) .+ cs
-    y_res = cu(observation) .- m
-    S = NNlib.batched_mul(Hs, NNlib.batched_mul(Σs, NNlib.batched_transpose(Hs))) .+ Rs
-
-    ΣH_T = NNlib.batched_mul(Σs, NNlib.batched_transpose(Hs))
-
-    S_inv = CUDA.similar(S)
-    d_ipiv, _, d_S = CUDA.CUBLAS.getrf_strided_batched(S, true)
-    CUDA.CUBLAS.getri_strided_batched!(d_S, S_inv, d_ipiv)
-
-    diags = CuArray{eltype(S)}(undef, size(S, 1), size(S, 3))
-    for i in 1:size(S, 1)
-        diags[i, :] .= d_S[i, i, :]
-    end
-
-    log_dets = sum(log ∘ abs, diags; dims=1)
-
-    K = NNlib.batched_mul(ΣH_T, S_inv)
-
-    μ_filt = μs .+ NNlib.batched_vec(K, y_res)
-    Σ_filt = Σs .- NNlib.batched_mul(K, NNlib.batched_mul(Hs, Σs))
-
-    inv_term = NNlib.batched_vec(S_inv, y_res)
-    log_likes = -T(0.5) * NNlib.batched_vec(reshape(y_res, 1, D, size(S, 3)), inv_term)
-    log_likes = log_likes .- T(0.5) * (log_dets .+ D * log(T(2π)))
-
-    # HACK: only errors seems to be from numerical stability so will just overwrite
-    log_likes[isnan.(log_likes)] .= -Inf
-
-    return BatchGaussianDistribution(μ_filt, Σ_filt), dropdims(log_likes; dims=1)
+function gaussian_likelihood(m::AbstractVector, S::AbstractMatrix, y::AbstractVector)
+    return logpdf(MvNormal(m, S), y)
 end
 
 ## KALMAN SMOOTHER #########################################################################
diff --git a/GeneralisedFilters/src/algorithms/particles.jl b/GeneralisedFilters/src/algorithms/particles.jl
index 879e1570..46468199 100644
--- a/GeneralisedFilters/src/algorithms/particles.jl
+++ b/GeneralisedFilters/src/algorithms/particles.jl
@@ -122,26 +122,34 @@ function predict(
     ref_state::Union{Nothing,AbstractVector}=nothing,
     kwargs...,
 )
-    proposed_particles = map(enumerate(state.particles)) do (i, particle)
-        if !isnothing(ref_state) && i == 1
-            ref_state[iter]
-        else
-            simulate(rng, model, filter.proposal, iter, particle, observation; kwargs...)
-        end
+    proposed_particles =
+        SSMProblems.simulate.(
+            Ref(rng),
+            Ref(model),
+            Ref(filter.proposal),
+            Ref(iter),
+            state.particles,
+            Ref(observation);
+            kwargs...,
+        )
+    if !isnothing(ref_state)
+        proposed_particles[1] = ref_state[iter]
     end
 
-    state.log_weights +=
-        map(zip(proposed_particles, state.particles)) do (new_state, prev_state)
-            log_f = SSMProblems.logdensity(
-                model.dyn, iter, prev_state, new_state; kwargs...
-            )
-
-            log_q = SSMProblems.logdensity(
-                model, filter.proposal, iter, prev_state, new_state, observation; kwargs...
-            )
-
-            (log_f - log_q)
-        end
+    state.log_weights .+=
+        SSMProblems.logdensity.(
+            Ref(model.dyn), Ref(iter), state.particles, proposed_particles; kwargs...
+        )
+    state.log_weights .-=
+        SSMProblems.logdensity.(
+            Ref(model),
+            Ref(filter.proposal),
+            Ref(iter),
+            state.particles,
+            proposed_particles,
+            Ref(observation);
+            kwargs...,
+        )
 
     state.particles = proposed_particles
 
@@ -156,10 +164,10 @@ function update(
     observation;
     kwargs...,
 ) where {T}
-    log_increments = map(
-        x -> SSMProblems.logdensity(model.obs, iter, x, observation; kwargs...),
-        state.particles,
-    )
+    log_increments =
+        SSMProblems.logdensity.(
+            Ref(model.obs), Ref(iter), state.particles, Ref(observation); kwargs...
+        )
 
     state.log_weights += log_increments
 
@@ -207,12 +215,12 @@ function predict(
     ref_state::Union{Nothing,AbstractVector}=nothing,
     kwargs...,
 )
-    state.particles = map(enumerate(state.particles)) do (i, particle)
-        if !isnothing(ref_state) && i == 1
-            ref_state[iter]
-        else
-            SSMProblems.simulate(rng, model.dyn, iter, particle; kwargs...)
-        end
+    state.particles =
+        SSMProblems.simulate.(
+            Ref(rng), Ref(model.dyn), Ref(iter), state.particles; kwargs...
+        )
+    if !isnothing(ref_state)
+        state.particles[1] = ref_state[iter]
     end
 
     return state
@@ -233,3 +241,45 @@ function filter(
 )
     return filter(rng, ssm, algo, observations; ref_state=ref_state, kwargs...)
 end
+
+# Broadcast wrappers for batched types
+# TODO: these can likely be replaced with a custom Broadcast.BroadcastStyle
+function Base.Broadcast.broadcasted(
+    ::typeof(SSMProblems.simulate),
+    rng_ref::Base.RefValue,
+    model_dyn_ref::Base.RefValue,
+    iter_ref::Base.RefValue,
+    particles::BatchedVector;
+    kwargs...,
+)
+    # Extract values from Ref and call non-broadcasted version
+    return SSMProblems.simulate(
+        rng_ref[], model_dyn_ref[], iter_ref[], particles; kwargs...
+    )
+end
+function Base.Broadcast.broadcasted(
+    ::typeof(SSMProblems.logdensity),
+    model_obs_ref::Base.RefValue,
+    iter_ref::Base.RefValue,
+    particles::BatchedVector,
+    observation::Base.RefValue;
+    kwargs...,
+)
+    # Extract values from Ref and call non-broadcasted version
+    return SSMProblems.logdensity(
+        model_obs_ref[], iter_ref[], particles, observation[]; kwargs...
+    )
+end
+function Base.Broadcast.broadcasted(
+    ::typeof(SSMProblems.logdensity),
+    model_dyn_ref::Base.RefValue,
+    iter_ref::Base.RefValue,
+    prev_particles::BatchedVector,
+    new_particles::BatchedVector;
+    kwargs...,
+)
+    # Extract values from Ref and call non-broadcasted version
+    return SSMProblems.logdensity(
+        model_dyn_ref[], iter_ref[], prev_particles, new_particles; kwargs...
+    )
+end
diff --git a/GeneralisedFilters/src/batching/batched_CUDA.jl b/GeneralisedFilters/src/batching/batched_CUDA.jl
new file mode 100644
index 00000000..715179fe
--- /dev/null
+++ b/GeneralisedFilters/src/batching/batched_CUDA.jl
@@ -0,0 +1,382 @@
+import Base: *, +, -, transpose, getindex
+import LinearAlgebra: Transpose, cholesky, \, /, I, UniformScaling, dot
+import Distributions: logpdf
+import Random: rand
+
+export BatchedCuVector, BatchedCuMatrix, BatchedCuCholesky
+
+###########################
+#### VECTOR OPERATIONS ####
+###########################
+
+struct BatchedCuVector{T} <: BatchedVector{T}
+    data::CuArray{T,2,CUDA.DeviceMemory}
+    ptrs::CuVector{CuPtr{T},CUDA.DeviceMemory}
+end
+function BatchedCuVector(data::CuArray{T,2}) where {T}
+    ptrs = CUDA.CUBLAS.unsafe_strided_batch(data)
+    return BatchedCuVector{T}(data, ptrs)
+end
+Base.eltype(::BatchedCuVector{T}) where {T} = T
+Base.length(x::BatchedCuVector) = size(x.data, 2)
+Base.getindex(x::BatchedCuVector, idxs) = BatchedCuVector(x.data[:, idxs])
+
+function +(x::BatchedCuVector{T}, y::BatchedCuVector{T}) where {T}
+    z_data = x.data .+ y.data
+    return BatchedCuVector(z_data)
+end
+
+function -(x::BatchedCuVector{T}, y::BatchedCuVector{T}) where {T}
+    z_data = x.data .- y.data
+    return BatchedCuVector(z_data)
+end
+
+function dot(x::BatchedCuVector{T}, y::BatchedCuVector{T}) where {T}
+    if size(x.data, 1) != size(y.data, 1)
+        throw(DimensionMismatch("Vectors must have the same length for dot product"))
+    end
+    xy = x.data .* y.data
+    return dropdims(sum(xy; dims=1); dims=1)
+end
+
+###########################
+#### MATRIX OPERATIONS ####
+###########################
+
+struct BatchedCuMatrix{T} <: BatchedMatrix{T}
+    data::CuArray{T,3,CUDA.DeviceMemory}
+    ptrs::CuVector{CuPtr{T},CUDA.DeviceMemory}
+end
+function BatchedCuMatrix(data::CuArray{T,3}) where {T}
+    ptrs = CUDA.CUBLAS.unsafe_strided_batch(data)
+    return BatchedCuMatrix{T}(data, ptrs)
+end
+Base.eltype(::BatchedCuMatrix{T}) where {T} = T
+Base.length(A::BatchedCuMatrix) = size(A.data, 3)
+Base.getindex(A::BatchedCuMatrix, idxs) = BatchedCuMatrix(A.data[:, :, idxs])
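+
+# Example usage (illustrative only; assumes a functional CUDA device): a batch of
+# K length-D vectors is stored as the columns of a D×K CuArray, and a batch of K
+# D×D matrices as the pages of a D×D×K CuArray, so a single CUBLAS batched kernel
+# serves the whole batch:
+#
+#     xs = BatchedCuVector(CUDA.rand(Float32, 3, 16))     # 16 vectors of length 3
+#     As = BatchedCuMatrix(CUDA.rand(Float32, 3, 3, 16))  # 16 matrices of size 3×3
+#     ys = As * xs                                        # one batched GEMV call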
+
+transpose(A::BatchedCuMatrix{T}) where {T} = Transpose{T,BatchedCuMatrix{T}}(A)
+
+function *(A::BatchedCuMatrix{T}, B::BatchedCuMatrix{T}) where {T}
+    C_data = CUDA.CUBLAS.gemm_strided_batched('N', 'N', A.data, B.data)
+    return BatchedCuMatrix(C_data)
+end
+function *(A::Transpose{T,BatchedCuMatrix{T}}, B::BatchedCuMatrix{T}) where {T}
+    C_data = CUDA.CUBLAS.gemm_strided_batched('T', 'N', A.parent.data, B.data)
+    return BatchedCuMatrix(C_data)
+end
+function *(A::BatchedCuMatrix{T}, B::Transpose{T,BatchedCuMatrix{T}}) where {T}
+    C_data = CUDA.CUBLAS.gemm_strided_batched('N', 'T', A.data, B.parent.data)
+    return BatchedCuMatrix(C_data)
+end
+function *(A::Transpose{T,BatchedCuMatrix{T}}, B::Transpose{T,BatchedCuMatrix{T}}) where {T}
+    C_data = CUDA.CUBLAS.gemm_strided_batched('T', 'T', A.parent.data, B.parent.data)
+    return BatchedCuMatrix(C_data)
+end
+
+function +(A::BatchedCuMatrix{T}, B::BatchedCuMatrix{T}) where {T}
+    C_data = A.data .+ B.data
+    return BatchedCuMatrix(C_data)
+end
+function +(A::BatchedCuMatrix{T}, B::Transpose{T,BatchedCuMatrix{T}}) where {T}
+    C_data = A.data .+ permutedims(B.parent.data, (2, 1, 3))
+    return BatchedCuMatrix(C_data)
+end
+
+function -(A::BatchedCuMatrix{T}, B::BatchedCuMatrix{T}) where {T}
+    C_data = A.data .- B.data
+    return BatchedCuMatrix(C_data)
+end
+
+function +(A::BatchedCuMatrix{T}, J::UniformScaling{<:Union{T,Bool}}) where {T}
+    m, n = size(A.data, 1), size(A.data, 2)
+    m == n || throw(DimensionMismatch("Matrix must be square for UniformScaling addition"))
+    B_data = copy(A.data)
+    for i in 1:m
+        B_data[i, i, :] .+= J.λ
+    end
+    return BatchedCuMatrix(B_data)
+end
+
+##################################
+#### MATRIX-VECTOR OPERATIONS ####
+##################################
+
+function *(A::BatchedCuMatrix{T}, x::BatchedCuVector{T}) where {T}
+    y_data = CuArray{T}(undef, size(A.data, 1), size(x.data, 2))
+    CUDA.CUBLAS.gemv_strided_batched!('N', T(1.0), A.data, x.data, T(0.0), y_data)
+    return BatchedCuVector(y_data)
+end
+function *(A::Transpose{T,BatchedCuMatrix{T}}, x::BatchedCuVector{T}) where {T}
+    y_data = CuArray{T}(undef, size(A.parent.data, 2), size(x.data, 2))
+    CUDA.CUBLAS.gemv_strided_batched!('T', T(1.0), A.parent.data, x.data, T(0.0), y_data)
+    return BatchedCuVector(y_data)
+end
+
+function getindex(A::BatchedCuMatrix{T}, i::Int, ::Colon) where {T}
+    row_data = A.data[i, :, :]
+    return BatchedCuVector(row_data)
+end
+function getindex(A::BatchedCuMatrix{T}, ::Colon, j::Int) where {T}
+    col_data = A.data[:, j, :]
+    return BatchedCuVector(col_data)
+end
+
+###########################
+#### SCALAR OPERATIONS ####
+###########################
+
+function /(A::BatchedCuMatrix, s::Number)
+    C_data = A.data ./ s
+    return BatchedCuMatrix(C_data)
+end
+
+#########################
+#### POTR OPERATIONS ####
+#########################
+
+struct BatchedCuCholesky{T} <: BatchedMatrix{T}
+    data::CuArray{T,3,CUDA.DeviceMemory}
+    ptrs::CuVector{CuPtr{T},CUDA.DeviceMemory}
+end
+function BatchedCuCholesky(data::CuArray{T,3}) where {T}
+    ptrs = CUDA.CUBLAS.unsafe_strided_batch(data)
+    return BatchedCuCholesky{T}(data, ptrs)
+end
+Base.eltype(::BatchedCuCholesky{T}) where {T} = T
+Base.length(P::BatchedCuCholesky) = size(P.data, 3)
+Base.getindex(P::BatchedCuCholesky, idxs) = BatchedCuCholesky(P.data[:, :, idxs])
+
+for (fname, elty) in (
+    (:cusolverDnSpotrfBatched, :Float32),
+    (:cusolverDnDpotrfBatched, :Float64),
+    (:cusolverDnCpotrfBatched, :ComplexF32),
+    (:cusolverDnZpotrfBatched, :ComplexF64),
+)
+    @eval begin
+        function cholesky(A::BatchedCuMatrix{$elty})
+            # HACK: assuming A is positive definite
+            m, n, b = size(A.data)
+            m == n ||
+                throw(DimensionMismatch("Matrix must be square for Cholesky decomposition"))
+
+            P_data = copy(A.data)
+            P = BatchedCuCholesky(P_data)
+
+            dh = CUDA.CUSOLVER.dense_handle()
+            info = CuVector{Cint}(undef, b)
+            CUDA.CUSOLVER.$fname(dh, 'L', m, P.ptrs, m, info, b)
+
+            return P
+        end
+    end
+end
+
+# TODO: CUSOLVER does not support matrix RHS for potrs; replace with MAGMA
+for (fname, elty) in (
+    (:cusolverDnSpotrsBatched, :Float32),
+    (:cusolverDnDpotrsBatched, :Float64),
+    (:cusolverDnCpotrsBatched, :ComplexF32),
+    (:cusolverDnZpotrsBatched, :ComplexF64),
+)
+    @eval begin
+        function \(P::BatchedCuCholesky{$elty}, A::BatchedCuMatrix{$elty})
+            m, n, b = size(A.data)
+            # CUSOLVER does not support matrix RHS for potrs so solve each column separately
+            bs_data = Vector{CuMatrix{$elty}}(undef, n)
+            for i in 1:n
+                bs_data[i] = A.data[:, i, :]
+                b_ptr = CUDA.CUBLAS.unsafe_strided_batch(bs_data[i])
+
+                dh = CUDA.CUSOLVER.dense_handle()
+                info = CuVector{Cint}(undef, b)
+                CUDA.CUSOLVER.$fname(dh, 'L', m, 1, P.ptrs, m, b_ptr, m, info, b)
+            end
+
+            B_data = stack(bs_data; dims=2)
+            return BatchedCuMatrix(B_data)
+        end
+    end
+end
+
+for (fname, elty) in (
+    (:cusolverDnSpotrsBatched, :Float32),
+    (:cusolverDnDpotrsBatched, :Float64),
+    (:cusolverDnCpotrsBatched, :ComplexF32),
+    (:cusolverDnZpotrsBatched, :ComplexF64),
+)
+    @eval begin
+        function \(P::BatchedCuCholesky{$elty}, x::BatchedCuVector{$elty})
+            m, b = size(x.data)
+            y_data = copy(x.data)
+            y = BatchedCuVector(y_data)
+            dh = CUDA.CUSOLVER.dense_handle()
+            info = CuVector{Cint}(undef, b)
+            CUDA.CUSOLVER.$fname(dh, 'L', m, 1, P.ptrs, m, y.ptrs, m, info, b)
+            return y
+        end
+    end
+end
+
+##########################
+#### MIXED OPERATIONS ####
+##########################
+
+function -(x::CuVector{T}, y::BatchedCuVector{T}) where {T}
+    z_data = x .- y.data
+    return BatchedCuVector(z_data)
+end
+function +(x::CuVector{T}, y::BatchedCuVector{T}) where {T}
+    z_data = x .+ y.data
+    return BatchedCuVector(z_data)
+end
+function +(x::BatchedCuVector{T}, y::CuVector{T}) where {T}
+    z_data = x.data .+ y
+    return BatchedCuVector(z_data)
+end
+# TODO: these need to be generated automatically and call a common function
+# TODO: are we better off using the strided or non-strided versions? The former
+# doesn't need pointer duplication.
+for (fname, elty, gemv_batched) in (
+    (:cublasSgemvBatched_64, :Float32, CUDA.CUBLAS.cublasSgemvBatched_64),
+    (:cublasDgemvBatched_64, :Float64, CUDA.CUBLAS.cublasDgemvBatched_64),
+    (:cublasCgemvBatched_64, :ComplexF32, CUDA.CUBLAS.cublasCgemvBatched_64),
+    (:cublasZgemvBatched_64, :ComplexF64, CUDA.CUBLAS.cublasZgemvBatched_64),
+)
+    @eval begin
+        function *(A::BatchedCuMatrix{$elty}, x::CuVector{$elty})
+            m, n, b = size(A.data)
+            y_data = CuArray{$elty}(undef, m, b)
+            y = BatchedCuVector(y_data)
+
+            # Call gemv directly
+            x_ptrs = batch_singleton(x, b)
+            h = CUDA.CUBLAS.handle()
+            $gemv_batched(
+                h, 'N', m, n, $elty(1.0), A.ptrs, m, x_ptrs, 1, $elty(0.0), y.ptrs, 1, b
+            )
+            return y
+        end
+
+        function *(A::CuMatrix{$elty}, x::BatchedCuVector{$elty})
+            m, n = size(A)
+            b = size(x.data, 2)
+            y_data = CuArray{$elty}(undef, m, b)
+            y = BatchedCuVector(y_data)
+
+            # Call gemv directly
+            A_ptrs = batch_singleton(A, b)
+            h = CUDA.CUBLAS.handle()
+            $gemv_batched(
+                h, 'N', m, n, $elty(1.0), A_ptrs, m, x.ptrs, 1, $elty(0.0), y.ptrs, 1, b
+            )
+            return y
+        end
+    end
+end
+
+@inline function batch_singleton(array::DenseCuArray{T}, N::Int) where {T}
+    ptrs = CuArray{CuPtr{T}}(undef, N)
+    function compute_pointers()
+        i = (blockIdx().x - 1i32) * blockDim().x + threadIdx().x
+        grid_stride = gridDim().x * blockDim().x
+        while i <= length(ptrs)
+            @inbounds ptrs[i] = reinterpret(CuPtr{T}, pointer(array))
+            i += grid_stride
+        end
+        return nothing
+    end
+    kernel = @cuda launch = false compute_pointers()
+    config = launch_configuration(kernel.fun)
+    threads = min(config.threads, N)
+    blocks = min(config.blocks, cld(N, threads))
+    @cuda threads = threads blocks = blocks compute_pointers()
+    return ptrs
+end
+
+###################################
+#### DISTRIBUTIONAL OPERATIONS ####
+###################################
+
+# Can likely replace by using Gaussian
+function gaussian_likelihood(
+    m::BatchedCuVector{T}, S::BatchedCuMatrix{T}, y::Union{BatchedCuVector{T},CuVector{T}}
+) where {T}
+    D = size(S.data, 1)
+    y_res = y - m
+
+    # TODO: avoid recomputing Cholesky decomposition
+    S_chol = cholesky(S)
+
+    diags = CuArray{T}(undef, size(S.data, 1), size(S.data, 3))
+    for i in 1:size(S.data, 1)
+        diags[i, :] = S_chol.data[i, i, :]
+    end
+    log_dets = T(2) * dropdims(sum(log.(diags); dims=1); dims=1)
+
+    inv_term = S_chol \ y_res
+    log_likes = -T(0.5) * dot(y_res, inv_term)
+    log_likes .-= T(0.5) * (log_dets .+ D * log(T(2π)))
+
+    # HACK: the only errors seem to be from numerical instability, so just overwrite them
+    log_likes[isnan.(log_likes)] .= -Inf
+
+    return log_likes
+end
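+
+# Example (illustrative shapes only): for a batch of K two-dimensional Gaussians,
+#
+#     m = BatchedCuVector(CUDA.zeros(Float32, 2, K))  # K mean vectors
+#     S = BatchedCuMatrix(Σ_data)                     # K SPD covariances in a 2×2×K array
+#     ll = gaussian_likelihood(m, S, y)               # length-K CuVector of log-densities
+#
+# i.e. logpdf(MvNormal(mₖ, Sₖ), yₖ) evaluated for every k using batched kernels.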
+
+# HACK: this is more hard-coded than it needs to be. Can generalise to other batched types
+# by taking advantage of the internal calls used in GaussianDistributions.jl
+function Distributions.logpdf(
+    P::Gaussian{BatchedCuVector{T},BatchedCuMatrix{T}},
+    y::Union{BatchedCuVector{T},CuVector{T}},
+) where {T}
+    return gaussian_likelihood(P.μ, P.Σ, y)
+end
+# MAJOR HACK: this only handles a special case for the bootstrap filter unit test until
+# we have a general approach to this
+function Distributions.logpdf(
+    P::Gaussian{BatchedCuVector{T},<:CuMatrix{T}}, y::CuVector{T}
+) where {T}
+    # Stack Σ to form a batched matrix
+    Σ_data = CuArray{T}(undef, size(P.Σ)..., size(P.μ.data, 2))
+    Σ_data[:, :, :] .= P.Σ
+    return gaussian_likelihood(P.μ, BatchedCuMatrix(Σ_data), y)
+end
+
+# TODO: need to generalise to only one argument being batched
+# NOTE: the passed rng is ignored; sampling uses CUDA's default RNG
+function Random.rand(
+    ::AbstractRNG, P::Gaussian{BatchedCuVector{T},BatchedCuMatrix{T}}
+) where {T}
+    D, N = size(P.μ.data)
+    Σ_chol = cholesky(P.Σ)
+    Z = BatchedCuVector(CUDA.randn(T, D, N))
+    # HACK: CUBLAS doesn't have a batched trmv, so zero the upper triangle and use a
+    # batched gemv for now. Should later replace with MAGMA
+    L = BatchedCuMatrix(Σ_chol.data)
+    zero_upper_triangle!(L.data)
+    return P.μ + L * Z
+end
+# TODO: the singleton Cholesky should probably be handled on the CPU
+function Random.rand(::AbstractRNG, P::Gaussian{BatchedCuVector{T},<:CuMatrix{T}}) where {T}
+    D, N = size(P.μ.data)
+    Σ_L = cholesky(P.Σ).L
+    Z = BatchedCuVector(CUDA.randn(T, D, N))
+    return P.μ + CuArray(Σ_L) * Z
+end
+
+# NOTE: assumes D × D does not exceed the maximum threads per block
+function zero_upper_triangle!(A::CuArray{T,3}) where {T}
+    D, _, N = size(A)
+
+    function kernel_zero_upper_triangle!(A, D)
+        i = threadIdx().x
+        j = threadIdx().y
+        k = blockIdx().x
+
+        if i < j && i <= D && j <= D
+            A[i, j, k] = zero(eltype(A))
+        end
+        return nothing
+    end
+
+    @cuda threads = (D, D) blocks = N kernel_zero_upper_triangle!(A, D)
+    return nothing
+end
diff --git a/GeneralisedFilters/src/batching/batched_SA.jl b/GeneralisedFilters/src/batching/batched_SA.jl
new file mode 100644
index 00000000..9fb4e74f
--- /dev/null
+++ b/GeneralisedFilters/src/batching/batched_SA.jl
@@ -0,0 +1,8 @@
+import StaticArrays: SVector, SMatrix
+
+# Placeholder for a CPU (StaticArrays-backed) batched vector type
+struct BatchedSAVector{T} <: BatchedVector{T}
+    data::Array{T,2}
+end
+Base.eltype(::BatchedSAVector{T}) where {T} = T
diff --git a/GeneralisedFilters/src/batching/batching.jl b/GeneralisedFilters/src/batching/batching.jl
new file mode 100644
index 00000000..7cee4234
--- /dev/null
+++ b/GeneralisedFilters/src/batching/batching.jl
@@ -0,0 +1,2 @@
+abstract type BatchedVector{T} end
+abstract type BatchedMatrix{T} end
diff --git a/GeneralisedFilters/src/containers.jl b/GeneralisedFilters/src/containers.jl
index 9920b9be..e829e546 100644
--- a/GeneralisedFilters/src/containers.jl
+++ b/GeneralisedFilters/src/containers.jl
@@ -8,15 +8,20 @@
 
 A container for particle filters which composes the weighted sample into a
 distribution-like object, with the states (or particles) distributed according
 to their log-weights.
""" -mutable struct ParticleDistribution{PT,WT<:Real} - particles::Vector{PT} - ancestors::Vector{Int} - log_weights::Vector{WT} +mutable struct ParticleDistribution{PT,AT,WT} + particles::PT + ancestors::AT + log_weights::WT end -function ParticleDistribution(particles::Vector{PT}, log_weights::Vector{WT}) where {PT,WT} +# TODO: these helpers might be more confusing than they're worth — and don't cover all cases +function ParticleDistribution(particles, log_weights::Vector) N = length(particles) return ParticleDistribution(particles, Vector{Int}(1:N), log_weights) end +function ParticleDistribution(particles, log_weights::CuVector) + N = length(particles) + return ParticleDistribution(particles, CuVector{Int}(1:N), log_weights) +end StatsBase.weights(state::ParticleDistribution) = softmax(state.log_weights) diff --git a/GeneralisedFilters/src/models/linear_gaussian.jl b/GeneralisedFilters/src/models/linear_gaussian.jl index da4a2eff..59a9cedb 100644 --- a/GeneralisedFilters/src/models/linear_gaussian.jl +++ b/GeneralisedFilters/src/models/linear_gaussian.jl @@ -52,23 +52,24 @@ end #### DISTRIBUTIONS #### ####################### +# We choose Gaussian over MvNormal since it allows for batched types function SSMProblems.distribution(dyn::LinearGaussianLatentDynamics; kwargs...) μ0, Σ0 = calc_initial(dyn; kwargs...) - return MvNormal(μ0, Σ0) + return Gaussian(μ0, Σ0) end function SSMProblems.distribution( - dyn::LinearGaussianLatentDynamics, step::Integer, state::AbstractVector; kwargs... + dyn::LinearGaussianLatentDynamics, step::Integer, state; kwargs... ) A, b, Q = calc_params(dyn, step; kwargs...) - return MvNormal(A * state + b, Q) + return Gaussian(A * state + b, Q) end function SSMProblems.distribution( - obs::LinearGaussianObservationProcess, step::Integer, state::AbstractVector; kwargs... + obs::LinearGaussianObservationProcess, step::Integer, state; kwargs... ) H, c, R = calc_params(obs, step; kwargs...) - return MvNormal(H * state + c, R) + return Gaussian(H * state + c, R) end ########################################### @@ -76,12 +77,17 @@ end ########################################### struct HomogeneousLinearGaussianLatentDynamics{ - T<:Real,ΣT<:AbstractMatrix{T},AT<:AbstractMatrix{T},QT<:AbstractMatrix{T} + T<:Real, + μT<:Union{AbstractVector{T},BatchedVector{T}}, + ΣT<:Union{AbstractMatrix{T},BatchedMatrix{T}}, + AT<:Union{AbstractMatrix{T},BatchedMatrix{T}}, + bT<:Union{AbstractVector{T},BatchedVector{T}}, + QT<:Union{AbstractMatrix{T},BatchedMatrix{T}}, } <: LinearGaussianLatentDynamics{T} - μ0::Vector{T} + μ0::μT Σ0::ΣT A::AT - b::Vector{T} + b::bT Q::QT end calc_μ0(dyn::HomogeneousLinearGaussianLatentDynamics; kwargs...) = dyn.μ0 @@ -91,10 +97,13 @@ calc_b(dyn::HomogeneousLinearGaussianLatentDynamics, ::Integer; kwargs...) = dyn calc_Q(dyn::HomogeneousLinearGaussianLatentDynamics, ::Integer; kwargs...) = dyn.Q struct HomogeneousLinearGaussianObservationProcess{ - T<:Real,HT<:AbstractMatrix{T},RT<:AbstractMatrix{T} + T<:Real, + HT<:Union{AbstractMatrix{T},BatchedMatrix{T}}, + cT<:Union{AbstractVector{T},BatchedVector{T}}, + RT<:Union{AbstractMatrix{T},BatchedMatrix{T}}, } <: LinearGaussianObservationProcess{T} H::HT - c::Vector{T} + c::cT R::RT end calc_H(obs::HomogeneousLinearGaussianObservationProcess, ::Integer; kwargs...) 
diff --git a/GeneralisedFilters/src/resamplers.jl b/GeneralisedFilters/src/resamplers.jl
index 143a2da4..a8f2e4c9 100644
--- a/GeneralisedFilters/src/resamplers.jl
+++ b/GeneralisedFilters/src/resamplers.jl
@@ -23,7 +23,7 @@ function resample(
 end
 
-function construct_new_state(states::ParticleDistribution{PT,WT}, idxs) where {PT,WT}
-    return ParticleDistribution(states.particles[idxs], idxs, zeros(WT, length(states)))
+function construct_new_state(states::ParticleDistribution, idxs)
+    return ParticleDistribution(states.particles[idxs], idxs, zero(states.log_weights))
 end
 
 function construct_new_state(
diff --git a/GeneralisedFilters/test/batch_kalman_test.jl b/GeneralisedFilters/test/batch_kalman_test.jl
index 07ba7353..18f28804 100644
--- a/GeneralisedFilters/test/batch_kalman_test.jl
+++ b/GeneralisedFilters/test/batch_kalman_test.jl
@@ -32,114 +32,49 @@
     ]
 
     T = 5
-    Ys = [[rand(rng, Dy) for _ in 1:T] for _ in 1:K]
+    ys_cpu = [rand(rng, Dy) for _ in 1:T]
+    Ys = [ys_cpu for _ in 1:K]
 
     outputs = [
        GeneralisedFilters.filter(rng, models[k], KalmanFilter(), Ys[k]) for k in 1:K
     ]
     states = first.(outputs)
     log_likelihoods = last.(outputs)
 
-    struct BatchLinearGaussianDynamics{T,MT} <: LinearGaussianLatentDynamics{T}
-        μ0s::CuArray{T,2,MT}
-        Σ0s::CuArray{T,3,MT}
-        As::CuArray{T,3,MT}
-        bs::CuArray{T,2,MT}
-        Qs::CuArray{T,3,MT}
-    end
-
-    function BatchLinearGaussianDynamics(
-        μ0s::Vector{Vector{T}},
-        Σ0s::Vector{Matrix{T}},
-        As::Vector{Matrix{T}},
-        bs::Vector{Vector{T}},
-        Qs::Vector{Matrix{T}},
-    ) where {T}
-        μ0s = CuArray(stack(μ0s))
-        Σ0s = CuArray(stack(Σ0s))
-        As = CuArray(stack(As))
-        bs = CuArray(stack(bs))
-        Qs = CuArray(stack(Qs))
-        return BatchLinearGaussianDynamics(μ0s, Σ0s, As, bs, Qs)
-    end
-
-    function GeneralisedFilters.batch_calc_μ0s(
-        dyn::BatchLinearGaussianDynamics, ::Integer; kwargs...
-    )
-        return dyn.μ0s
-    end
-    function GeneralisedFilters.batch_calc_Σ0s(
-        dyn::BatchLinearGaussianDynamics, ::Integer; kwargs...
-    )
-        return dyn.Σ0s
-    end
-    function GeneralisedFilters.batch_calc_As(
-        dyn::BatchLinearGaussianDynamics, ::Integer, ::Integer; kwargs...
-    )
-        return dyn.As
-    end
-    function GeneralisedFilters.batch_calc_bs(
-        dyn::BatchLinearGaussianDynamics, ::Integer, ::Integer; kwargs...
-    )
-        return dyn.bs
-    end
-    function GeneralisedFilters.batch_calc_Qs(
-        dyn::BatchLinearGaussianDynamics, ::Integer, ::Integer; kwargs...
-    )
-        return dyn.Qs
-    end
-
-    struct BatchLinearGaussianObservations{T,MT} <: LinearGaussianObservationProcess{T}
-        Hs::CuArray{T,3,MT}
-        cs::CuArray{T,2,MT}
-        Rs::CuArray{T,3,MT}
-    end
-
-    function BatchLinearGaussianObservations(
-        Hs::Vector{Matrix{T}}, cs::Vector{Vector{T}}, Rs::Vector{Matrix{T}}
+    # Define batched model
+    μ0 = BatchedCuVector(cu(stack(μ0s)))
+    Σ0 = BatchedCuMatrix(cu(stack(Σ0s)))
+    A = BatchedCuMatrix(cu(stack(As)))
+    b = BatchedCuVector(cu(stack(bs)))
+    Q = BatchedCuMatrix(cu(stack(Qs)))
+
+    dyn = GeneralisedFilters.HomogeneousLinearGaussianLatentDynamics(μ0, Σ0, A, b, Q)
+
+    H = BatchedCuMatrix(cu(stack(Hs)))
+    c = BatchedCuVector(cu(stack(cs)))
+    R = BatchedCuMatrix(cu(stack(Rs)))
+
+    obs = GeneralisedFilters.HomogeneousLinearGaussianObservationProcess(H, c, R)
+
+    ssm = StateSpaceModel(dyn, obs)
+    ys = cu.(ys_cpu)
+
+    # HACK: manually define the log-evidence initialisation for this model
+    function GeneralisedFilters.initialise_log_evidence(
+        ::KalmanFilter,
+        model::StateSpaceModel{
+            T,
+            <:GeneralisedFilters.HomogeneousLinearGaussianLatentDynamics{
+                T,<:BatchedCuVector
+            },
+        },
     ) where {T}
-        Hs = CuArray(stack(Hs))
-        cs = CuArray(stack(cs))
-        Rs = CuArray(stack(Rs))
-        return BatchLinearGaussianObservations(Hs, cs, Rs)
-    end
-
-    function GeneralisedFilters.batch_calc_Hs(
-        obs::BatchLinearGaussianObservations, ::Integer, ::Integer; kwargs...
-    )
-        return obs.Hs
-    end
-    function GeneralisedFilters.batch_calc_cs(
-        obs::BatchLinearGaussianObservations, ::Integer, ::Integer; kwargs...
-    )
-        return obs.cs
+        K = size(model.dyn.μ0.data, 2)  # batch size
+        return CUDA.zeros(T, K)
     end
-    function GeneralisedFilters.batch_calc_Rs(
-        obs::BatchLinearGaussianObservations, ::Integer, ::Integer; kwargs...
-    )
-        return obs.Rs
-    end
-
-    batch_model = GeneralisedFilters.StateSpaceModel(
-        BatchLinearGaussianDynamics(μ0s, Σ0s, As, bs, Qs),
-        BatchLinearGaussianObservations(Hs, cs, Rs),
-    )
-
-    Ys_batch = Vector{Matrix{Float64}}(undef, T)
-    for t in 1:T
-        Ys_batch[t] = stack(Ys[k][t] for k in 1:K)
-    end
-
-    batch_output = GeneralisedFilters.filter(
-        rng, batch_model, BatchKalmanFilter(K), Ys_batch
-    )
-
-    # println("Batch log-likelihood: ", batch_output[2])
-    # println("Individual log-likelihoods: ", log_likelihoods)
-    # println("Batch states: ", batch_output[1].μs')
-    # println("Individual states: ", getproperty.(states, :μ))
+    state, ll = GeneralisedFilters.filter(rng, ssm, KalmanFilter(), ys)
 
-    @test Array(batch_output[2])[end] .≈ log_likelihoods[end] rtol = 1e-5
-    @test Array(batch_output[1].μs) ≈ stack(getproperty.(states, :μ)) rtol = 1e-5
+    @test all(isapprox.(Array(ll), log_likelihoods; rtol=1e-5))
+    @test Array(state.μ.data) ≈ stack(getproperty.(states, :μ)) rtol = 1e-5
 end
diff --git a/GeneralisedFilters/test/runtests.jl b/GeneralisedFilters/test/runtests.jl
index ea8dfa26..bfa37edb 100644
--- a/GeneralisedFilters/test/runtests.jl
+++ b/GeneralisedFilters/test/runtests.jl
@@ -102,6 +102,59 @@ end
     @test llkf ≈ llbf atol = 1e-1
 end
 
+@testitem "GPU bootstrap filter test" tags = [:gpu] begin
+    using SSMProblems
+    using StableRNGs
+    using LogExpFunctions: softmax
+    using CUDA
+    using LinearAlgebra
+
+    rng = StableRNG(1234)
+    model = GeneralisedFilters.GFTest.create_linear_gaussian_model(rng, 1, 1)
+    _, _, ys = sample(rng, model, 10)
+
+    # HACK: convert model to CUDA
+    gpu_model = StateSpaceModel(
+        GeneralisedFilters.HomogeneousLinearGaussianLatentDynamics(
+            cu(model.dyn.μ0),
+            cu(model.dyn.Σ0),
+            cu(model.dyn.A),
+            cu(model.dyn.b),
+            cu(model.dyn.Q),
+        ),
+        GeneralisedFilters.HomogeneousLinearGaussianObservationProcess(
+            cu(model.obs.H), cu(model.obs.c), cu(model.obs.R)
+        ),
+    )
+
+    bf = BF(2^12; threshold=1.0)
+    # HACK: run BF manually until initialisation interface is finalised
+    # Initialisation
+    Z = BatchedCuVector(CUDA.randn(Float32, 1, bf.N))
+    # NOTE: the model is one-dimensional, so using the dense Cholesky factor data is safe
+    Σ_L = cu(cholesky(model.dyn.Σ0).L)
+    initial_particles = cu(model.dyn.μ0) + Σ_L.data * Z
+    bf_state = GeneralisedFilters.ParticleDistribution(
+        initial_particles, CUDA.zeros(Float32, bf.N)
+    )
+    llbf = 0.0
+    for t in eachindex(ys)
+        global bf_state, llbf
+        bf_state, ll_increment = GeneralisedFilters.step(
+            rng, gpu_model, bf, t, bf_state, cu(ys[t])
+        )
+        llbf += ll_increment
+    end
+
+    kf_state, llkf = GeneralisedFilters.filter(rng, model, KF(), ys)
+
+    xs = bf_state.particles.data[1, :]
+    ws = softmax(bf_state.log_weights)
+
+    # Compare log-likelihood and states
+    @test first(kf_state.μ) ≈ sum(xs .* ws) rtol = 1e-2
+    @test llkf ≈ llbf atol = 1e-1
+end
+
 @testitem "Guided filter test" begin
     using SSMProblems
     using LogExpFunctions: softmax
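Reviewer note: the refactored scalar `update` computes the transposed Kalman gain via a Cholesky left-solve, `KT = S_chol \ (H * Σ)`, instead of the right-division `K = Σ * H' / S`, since batched CUSOLVER only exposes left solves. For symmetric `S` the two are algebraically identical; a standalone CPU sanity check of that algebra (a sketch, not part of the diff):

```julia
using LinearAlgebra

H = [1.0 0.0]
Σ = [1.0 0.2; 0.2 0.5]
R = fill(0.1, 1, 1)

S = H * Σ * transpose(H) + R
S = (S + transpose(S)) / 2    # force symmetry, as in the new update
KT = cholesky(S) \ (H * Σ)    # new path: Cholesky left-solve
K = Σ * transpose(H) / S      # old path: right-division
@assert transpose(KT) ≈ K     # K = Σ Hᵀ S⁻¹ = (S⁻¹ H Σ)ᵀ for symmetric S, Σ
```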