From 2917471022533ba59e568e84a7fc8ba3ee405bf2 Mon Sep 17 00:00:00 2001 From: Tim Hargreaves Date: Fri, 6 Jun 2025 16:01:14 +0100 Subject: [PATCH 1/7] Create prototype batching interface and wrappers --- GeneralisedFilters/src/GeneralisedFilters.jl | 3 + .../src/batching/batched_CUDA.jl | 181 ++++++++++++++++++ 2 files changed, 184 insertions(+) create mode 100644 GeneralisedFilters/src/batching/batched_CUDA.jl diff --git a/GeneralisedFilters/src/GeneralisedFilters.jl b/GeneralisedFilters/src/GeneralisedFilters.jl index 6dc97c6b..e10b3fe8 100644 --- a/GeneralisedFilters/src/GeneralisedFilters.jl +++ b/GeneralisedFilters/src/GeneralisedFilters.jl @@ -17,6 +17,9 @@ include("callbacks.jl") include("containers.jl") include("resamplers.jl") +# Batching utilities +include("batching/batched_CUDA.jl") + ## FILTERING BASE ########################################################################## abstract type AbstractFilter <: AbstractSampler end diff --git a/GeneralisedFilters/src/batching/batched_CUDA.jl b/GeneralisedFilters/src/batching/batched_CUDA.jl new file mode 100644 index 00000000..46f5ff20 --- /dev/null +++ b/GeneralisedFilters/src/batching/batched_CUDA.jl @@ -0,0 +1,181 @@ +import Base: *, +, -, transpose, getindex +import LinearAlgebra: Transpose, cholesky, \, I, UniformScaling + +export BatchedCuVector, BatchedCuMatrix, BatchedCuCholesky + +########################### +#### VECTOR OPERATIONS #### +########################### + +struct BatchedCuVector{T} + data::CuArray{T,2,CUDA.DeviceMemory} + ptrs::CuVector{CuPtr{T},CUDA.DeviceMemory} +end +function BatchedCuVector(data::CuArray{T,2}) where {T} + ptrs = CUDA.CUBLAS.unsafe_strided_batch(data) + return BatchedCuVector{T}(data, ptrs) +end +Base.eltype(::BatchedCuVector{T}) where {T} = T + +function +(x::BatchedCuVector{T}, y::BatchedCuVector{T}) where {T} + z_data = x.data .+ y.data + return BatchedCuVector(z_data) +end + +function -(x::BatchedCuVector{T}, y::BatchedCuVector{T}) where {T} + z_data = x.data .- y.data + return BatchedCuVector(z_data) +end + +########################### +#### MATRIX OPERATIONS #### +########################### + +struct BatchedCuMatrix{T} + data::CuArray{T,3,CUDA.DeviceMemory} + ptrs::CuVector{CuPtr{T},CUDA.DeviceMemory} +end +function BatchedCuMatrix(data::CuArray{T,3}) where {T} + ptrs = CUDA.CUBLAS.unsafe_strided_batch(data) + return BatchedCuMatrix{T}(data, ptrs) +end +Base.eltype(::BatchedCuMatrix{T}) where {T} = T + +transpose(A::BatchedCuMatrix{T}) where {T} = Transpose{T,BatchedCuMatrix{T}}(A) + +function *(A::BatchedCuMatrix{T}, B::BatchedCuMatrix{T}) where {T} + C_data = CUDA.CUBLAS.gemm_strided_batched('N', 'N', A.data, B.data) + return BatchedCuMatrix(C_data) +end +function *(A::Transpose{T,BatchedCuMatrix{T}}, B::BatchedCuMatrix{T}) where {T} + C_data = CUDA.CUBLAS.gemm_strided_batched('T', 'N', A.parent.data, B.data) + return BatchedCuMatrix(C_data) +end +function *(A::BatchedCuMatrix{T}, B::Transpose{T,BatchedCuMatrix{T}}) where {T} + C_data = CUDA.CUBLAS.gemm_strided_batched('N', 'T', A.data, B.parent.data) + return BatchedCuMatrix(C_data) +end +function *(A::Transpose{T,BatchedCuMatrix{T}}, B::Transpose{T,BatchedCuMatrix{T}}) where {T} + C_data = CUDA.CUBLAS.gemm_strided_batched('T', 'T', A.parent.data, B.parent.data) + return BatchedCuMatrix(C_data) +end + +function +(A::BatchedCuMatrix{T}, B::BatchedCuMatrix{T}) where {T} + C_data = A.data .+ B.data + return BatchedCuMatrix(C_data) +end + +function -(A::BatchedCuMatrix{T}, B::BatchedCuMatrix{T}) where {T} + C_data = A.data .- B.data + 
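    # Note: wrapping the result re-derives the batch pointer array via
    # unsafe_strided_batch, so every elementwise op allocates a fresh
    # pointer vector alongside the new data.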
    return BatchedCuMatrix(C_data)
end

function +(A::BatchedCuMatrix{T}, J::UniformScaling{<:Union{T,Bool}}) where {T}
    m, n = size(A.data, 1), size(A.data, 2)
    m == n || throw(DimensionMismatch("Matrix must be square for UniformScaling addition"))
    B_data = copy(A.data)
    for i in 1:m
        B_data[i, i, :] .+= J.λ
    end
    return BatchedCuMatrix(B_data)
end

##################################
#### MATRIX-VECTOR OPERATIONS ####
##################################

function *(A::BatchedCuMatrix{T}, x::BatchedCuVector{T}) where {T}
    y_data = CuArray{T}(undef, size(A.data, 1), size(x.data, 2))
    CUDA.CUBLAS.gemv_strided_batched('N', T(1.0), A.data, x.data, T(0.0), y_data)
    return BatchedCuVector(y_data)
end
function *(A::Transpose{T,BatchedCuMatrix{T}}, x::BatchedCuVector{T}) where {T}
    y_data = CuArray{T}(undef, size(A.data, 2), size(x.data, 2))
    CUDA.CUBLAS.gemv_strided_batched('T', T(1.0), A.data, x.data, T(0.0), y_data)
    return BatchedCuVector(y_data)
end

function getindex(A::BatchedCuMatrix{T}, i::Int, ::Colon) where {T}
    row_data = A.data[i, :, :]
    return BatchedCuVector(row_data)
end
function getindex(A::BatchedCuMatrix{T}, ::Colon, j::Int) where {T}
    col_data = A.data[:, j, :]
    return BatchedCuVector(col_data)
end

#########################
#### POTR OPERATIONS ####
#########################

struct BatchedCuCholesky{T}
    data::CuArray{T,3,CUDA.DeviceMemory}
    ptrs::CuVector{CuPtr{T},CUDA.DeviceMemory}
end
function BatchedCuCholesky(data::CuArray{T,3}) where {T}
    ptrs = CUDA.CUBLAS.unsafe_strided_batch(data)
    return BatchedCuCholesky{T}(data, ptrs)
end
Base.eltype(::BatchedCuCholesky{T}) where {T} = T

for (fname, elty) in (
    (:cusolverDnSpotrfBatched, :Float32),
    (:cusolverDnDpotrfBatched, :Float64),
    (:cusolverDnCpotrfBatched, :ComplexF32),
    (:cusolverDnZpotrfBatched, :ComplexF64),
)
    @eval begin
        function cholesky(A::BatchedCuMatrix{$elty})
            # HACK: assuming A is positive definite
            m, n, b = size(A.data)
            m == n ||
                throw(DimensionMismatch("Matrix must be square for Cholesky decomposition"))

            L_data = copy(A.data)
            L = BatchedCuCholesky(L_data)

            dh = CUDA.CUSOLVER.dense_handle()
            info = CuVector{Int}(undef, b)
            CUDA.CUSOLVER.$fname(dh, 'L', m, L.ptrs, m, info, b)

            return L
        end
    end
end

# TODO: CUSOLVER does not support matrix RHS for potrs; replace with MAGMA
for (fname, elty) in (
    (:cusolverDnSpotrsBatched, :Float32),
    (:cusolverDnDpotrsBatched, :Float64),
    (:cusolverDnCpotrsBatched, :ComplexF32),
    (:cusolverDnZpotrsBatched, :ComplexF64),
)
    @eval begin
        function \(L::BatchedCuCholesky{$elty}, A::BatchedCuMatrix{$elty})
            m, n, b, = size(A.data)
            # CUSOLVER does not support matrix RHS for potrs so solve each column separately
            bs_data = Vector{CuMatrix{$elty}}(undef, n)
            for i in 1:n
                bs_data[i] = A.data[:, i, :]
                b_ptr = CUDA.CUBLAS.unsafe_strided_batch(bs_data[i])

                dh = CUDA.CUSOLVER.dense_handle()
                info = CuVector{Int}(undef, b)
                CUDA.CUSOLVER.$fname(dh, 'L', m, 1, L.ptrs, m, b_ptr, m, info, b)
            end

            B_data = stack(bs_data; dims=2)
            return BatchedCuMatrix(B_data)
        end
    end
end

##########################
#### MIXED OPERATIONS ####
##########################

function -(x::CuVector{T}, y::BatchedCuVector{T}) where {T}
    z_data = x .- y.data
    return BatchedCuVector(z_data)
end

From 646f4de924c6f4296e19e9b6b0c9595011f324d6 Mon Sep 17 00:00:00 2001
From: Tim Hargreaves
Date: Fri, 6 Jun 2025 16:19:22 +0100
Subject: [PATCH 2/7] Generalise LGSSM types to include batched arrays

---
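Roughly, the calling pattern these wrappers are aiming for (illustrative only: the sizes and inputs below are made up, and it assumes the CUSOLVER-backed cholesky/solve paths from the previous patch):

    using CUDA, LinearAlgebra, GeneralisedFilters

    K, D = 16, 3  # batch size and per-system dimension (arbitrary)
    A = BatchedCuMatrix(cu(stack([randn(Float32, D, D) for _ in 1:K])))
    x = BatchedCuVector(CUDA.randn(Float32, D, K))  # one column per batch member

    y = A * x                 # strided-batched GEMV over all K systems at once
    S = A * transpose(A) + I  # batched GEMM, then UniformScaling added to each diagonal
    C = cholesky(S)           # batched potrf; S is positive definite by construction
    Z = C \ S                 # batched potrs, solved column by column

Each wrapper keeps its K members packed in one contiguous CuArray (vectors as D×K, matrices as D×D×K) together with a device vector of per-member pointers for the batched CUBLAS/CUSOLVER calls.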
.../src/batching/batched_CUDA.jl | 20 ++++++++++--------- .../src/models/linear_gaussian.jl | 18 ++++++++++++----- 2 files changed, 24 insertions(+), 14 deletions(-) diff --git a/GeneralisedFilters/src/batching/batched_CUDA.jl b/GeneralisedFilters/src/batching/batched_CUDA.jl index 46f5ff20..89eac95b 100644 --- a/GeneralisedFilters/src/batching/batched_CUDA.jl +++ b/GeneralisedFilters/src/batching/batched_CUDA.jl @@ -3,11 +3,14 @@ import LinearAlgebra: Transpose, cholesky, \, I, UniformScaling export BatchedCuVector, BatchedCuMatrix, BatchedCuCholesky +abstract type BatchedVector{T} end +abstract type BatchedMatrix{T} end + ########################### #### VECTOR OPERATIONS #### ########################### -struct BatchedCuVector{T} +struct BatchedCuVector{T} <: BatchedVector{T} data::CuArray{T,2,CUDA.DeviceMemory} ptrs::CuVector{CuPtr{T},CUDA.DeviceMemory} end @@ -31,7 +34,7 @@ end #### MATRIX OPERATIONS #### ########################### -struct BatchedCuMatrix{T} +struct BatchedCuMatrix{T} <: BatchedMatrix{T} data::CuArray{T,3,CUDA.DeviceMemory} ptrs::CuVector{CuPtr{T},CUDA.DeviceMemory} end @@ -108,7 +111,7 @@ end #### POTR OPERATIONS #### ######################### -struct BatchedCuCholesky{T} +struct BatchedCuCholesky{T} <: BatchedMatrix{T} data::CuArray{T,3,CUDA.DeviceMemory} ptrs::CuVector{CuPtr{T},CUDA.DeviceMemory} end @@ -131,14 +134,14 @@ for (fname, elty) in ( m == n || throw(DimensionMismatch("Matrix must be square for Cholesky decomposition")) - L_data = copy(A.data) - L = BatchedCuCholesky(L_data) + P_data = copy(A.data) + P = BatchedCuCholesky(L_data) dh = CUDA.CUSOLVER.dense_handle() info = CuVector{Int}(undef, b) CUDA.CUSOLVER.$fname(dh, 'L', m, L.ptrs, m, info, b) - return L + return P end end end @@ -151,7 +154,7 @@ for (fname, elty) in ( (:cusolverDnZpotrsBatched, :ComplexF64), ) @eval begin - function \(L::BatchedCuCholesky{$elty}, A::BatchedCuMatrix{$elty}) + function \(P::BatchedCuCholesky{$elty}, A::BatchedCuMatrix{$elty}) m, n, b, = size(A.data) # CUSOLVER does not support matrix RHS for potrs so solve each column separately bs_data = Vector{CuMatrix{$elty}}(undef, n) @@ -161,7 +164,7 @@ for (fname, elty) in ( dh = CUDA.CUSOLVER.dense_handle() info = CuVector{Int}(undef, b) - CUDA.CUSOLVER.$fname(dh, 'L', m, 1, L.ptrs, m, b_ptr, m, info, b) + CUDA.CUSOLVER.$fname(dh, 'L', m, 1, P.ptrs, m, b_ptr, m, info, b) end B_data = stack(bs_data; dims=2) @@ -178,4 +181,3 @@ function -(x::CuVector{T}, y::BatchedCuVector{T}) where {T} z_data = x .- y.data return BatchedCuVector(z_data) end - diff --git a/GeneralisedFilters/src/models/linear_gaussian.jl b/GeneralisedFilters/src/models/linear_gaussian.jl index da4a2eff..50aa291b 100644 --- a/GeneralisedFilters/src/models/linear_gaussian.jl +++ b/GeneralisedFilters/src/models/linear_gaussian.jl @@ -76,12 +76,17 @@ end ########################################### struct HomogeneousLinearGaussianLatentDynamics{ - T<:Real,ΣT<:AbstractMatrix{T},AT<:AbstractMatrix{T},QT<:AbstractMatrix{T} + T<:Real, + μT<:Union{AbstractVector{T},BatchedVector{T}}, + ΣT<:Union{AbstractMatrix{T},BatchedMatrix{T}}, + AT<:Union{AbstractMatrix{T},BatchedMatrix{T}}, + bT<:Union{AbstractVector{T},BatchedVector{T}}, + QT<:Union{AbstractMatrix{T},BatchedMatrix{T}}, } <: LinearGaussianLatentDynamics{T} - μ0::Vector{T} + μ0::μT Σ0::ΣT A::AT - b::Vector{T} + b::bT Q::QT end calc_μ0(dyn::HomogeneousLinearGaussianLatentDynamics; kwargs...) = dyn.μ0 @@ -91,10 +96,13 @@ calc_b(dyn::HomogeneousLinearGaussianLatentDynamics, ::Integer; kwargs...) 
= dyn.b
 calc_Q(dyn::HomogeneousLinearGaussianLatentDynamics, ::Integer; kwargs...) = dyn.Q

 struct HomogeneousLinearGaussianObservationProcess{
-    T<:Real,HT<:AbstractMatrix{T},RT<:AbstractMatrix{T}
+    T<:Real,
+    HT<:Union{AbstractMatrix{T},BatchedMatrix{T}},
+    cT<:Union{AbstractVector{T},BatchedVector{T}},
+    RT<:Union{AbstractMatrix{T},BatchedMatrix{T}},
 } <: LinearGaussianObservationProcess{T}
     H::HT
-    c::Vector{T}
+    c::cT
     R::RT
 end
 calc_H(obs::HomogeneousLinearGaussianObservationProcess, ::Integer; kwargs...) = obs.H

From 70d20562e39f79c62133d13ae9dc10b2ab4c464e Mon Sep 17 00:00:00 2001
From: Tim Hargreaves
Date: Fri, 6 Jun 2025 17:17:03 +0100
Subject: [PATCH 3/7] Combine batch and regular KF

---
 GeneralisedFilters/src/algorithms/kalman.jl | 88 +++----------------
 .../src/batching/batched_CUDA.jl            | 81 +++++++++++++++--
 2 files changed, 85 insertions(+), 84 deletions(-)

diff --git a/GeneralisedFilters/src/algorithms/kalman.jl b/GeneralisedFilters/src/algorithms/kalman.jl
index 878ea846..5dc7acae 100644
--- a/GeneralisedFilters/src/algorithms/kalman.jl
+++ b/GeneralisedFilters/src/algorithms/kalman.jl
@@ -1,7 +1,7 @@
 export KalmanFilter, filter, BatchKalmanFilter
 using GaussianDistributions
 using CUDA: i32
-import LinearAlgebra: hermitianpart
+import LinearAlgebra: hermitianpart, transpose, Cholesky

 export KalmanFilter, KF, KalmanSmoother, KS

@@ -27,7 +27,7 @@ function predict(
 )
     μ, Σ = GaussianDistributions.pair(state)
     A, b, Q = calc_params(model.dyn, iter; kwargs...)
-    return Gaussian(A * μ + b, A * Σ * A' + Q)
+    return Gaussian(A * μ + b, A * Σ * transpose(A) + Q)
 end

 function update(
     # Update state
     m = H * μ + c
     y = observation - m
-    S = hermitianpart(H * Σ * H' + R)
-    K = Σ * H' / S
+    S = H * Σ * transpose(H) + R
+    S = (S + transpose(S)) / 2 # force symmetry
+    S_chol = cholesky(S)
+    KT = S_chol \ H * Σ # TODO: only using `\` for better integration with CuSolver

-    state = Gaussian(μ + K * y, Σ - K * H * Σ)
+    state = Gaussian(μ + transpose(KT) * y, Σ - transpose(KT) * H * Σ)

     # Compute log-likelihood
-    ll = logpdf(MvNormal(m, S), observation)
+    ll = gaussian_likelihood(m, S, observation)

     return state, ll
 end

-struct BatchKalmanFilter <: AbstractBatchFilter
-    batch_size::Int
-end
-
-function initialise(
-    rng::AbstractRNG,
-    model::LinearGaussianStateSpaceModel{T},
-    algo::BatchKalmanFilter;
-    kwargs...,
-) where {T}
-    μ0s, Σ0s = batch_calc_initial(model.dyn, algo.batch_size; kwargs...)
-    return BatchGaussianDistribution(μ0s, Σ0s)
-end
-
-function predict(
-    rng::AbstractRNG,
-    model::LinearGaussianStateSpaceModel{T},
-    algo::BatchKalmanFilter,
-    iter::Integer,
-    state::BatchGaussianDistribution,
-    observation;
-    kwargs...,
-) where {T}
-    μs, Σs = state.μs, state.Σs
-    As, bs, Qs = batch_calc_params(model.dyn, iter, algo.batch_size; kwargs...)
-    μ̂s = NNlib.batched_vec(As, μs) .+ bs
-    Σ̂s = NNlib.batched_mul(NNlib.batched_mul(As, Σs), NNlib.batched_transpose(As)) .+ Qs
-    return BatchGaussianDistribution(μ̂s, Σ̂s)
-end
-
-function update(
-    model::LinearGaussianStateSpaceModel{T},
-    algo::BatchKalmanFilter,
-    iter::Integer,
-    state::BatchGaussianDistribution,
-    observation;
-    kwargs...,
-) where {T}
-    μs, Σs = state.μs, state.Σs
-    Hs, cs, Rs = batch_calc_params(model.obs, iter, algo.batch_size; kwargs...)
- D = size(observation, 1) - - m = NNlib.batched_vec(Hs, μs) .+ cs - y_res = cu(observation) .- m - S = NNlib.batched_mul(Hs, NNlib.batched_mul(Σs, NNlib.batched_transpose(Hs))) .+ Rs - - ΣH_T = NNlib.batched_mul(Σs, NNlib.batched_transpose(Hs)) - - S_inv = CUDA.similar(S) - d_ipiv, _, d_S = CUDA.CUBLAS.getrf_strided_batched(S, true) - CUDA.CUBLAS.getri_strided_batched!(d_S, S_inv, d_ipiv) - - diags = CuArray{eltype(S)}(undef, size(S, 1), size(S, 3)) - for i in 1:size(S, 1) - diags[i, :] .= d_S[i, i, :] - end - - log_dets = sum(log ∘ abs, diags; dims=1) - - K = NNlib.batched_mul(ΣH_T, S_inv) - - μ_filt = μs .+ NNlib.batched_vec(K, y_res) - Σ_filt = Σs .- NNlib.batched_mul(K, NNlib.batched_mul(Hs, Σs)) - - inv_term = NNlib.batched_vec(S_inv, y_res) - log_likes = -T(0.5) * NNlib.batched_vec(reshape(y_res, 1, D, size(S, 3)), inv_term) - log_likes = log_likes .- T(0.5) * (log_dets .+ D * log(T(2π))) - - # HACK: only errors seems to be from numerical stability so will just overwrite - log_likes[isnan.(log_likes)] .= -Inf - - return BatchGaussianDistribution(μ_filt, Σ_filt), dropdims(log_likes; dims=1) +function gaussian_likelihood(m::AbstractVector, S::AbstractMatrix, y::AbstractVector) + return logpdf(MvNormal(m, S), y) end ## KALMAN SMOOTHER ######################################################################### diff --git a/GeneralisedFilters/src/batching/batched_CUDA.jl b/GeneralisedFilters/src/batching/batched_CUDA.jl index 89eac95b..d53b756f 100644 --- a/GeneralisedFilters/src/batching/batched_CUDA.jl +++ b/GeneralisedFilters/src/batching/batched_CUDA.jl @@ -1,5 +1,5 @@ import Base: *, +, -, transpose, getindex -import LinearAlgebra: Transpose, cholesky, \, I, UniformScaling +import LinearAlgebra: Transpose, cholesky, \, /, I, UniformScaling, dot export BatchedCuVector, BatchedCuMatrix, BatchedCuCholesky @@ -30,6 +30,14 @@ function -(x::BatchedCuVector{T}, y::BatchedCuVector{T}) where {T} return BatchedCuVector(z_data) end +function dot(x::BatchedCuVector{T}, y::BatchedCuVector{T}) where {T} + if size(x.data, 1) != size(y.data, 1) + throw(DimensionMismatch("Vectors must have the same length for dot product")) + end + xy = x.data .* y.data + return dropdims(sum(xy; dims=1); dims=1) +end + ########################### #### MATRIX OPERATIONS #### ########################### @@ -67,6 +75,10 @@ function +(A::BatchedCuMatrix{T}, B::BatchedCuMatrix{T}) where {T} C_data = A.data .+ B.data return BatchedCuMatrix(C_data) end +function +(A::BatchedCuMatrix{T}, B::Transpose{T,BatchedCuMatrix{T}}) where {T} + C_data = A.data .+ permutedims(B.parent.data, (2, 1, 3)) + return BatchedCuMatrix(C_data) +end function -(A::BatchedCuMatrix{T}, B::BatchedCuMatrix{T}) where {T} C_data = A.data .- B.data @@ -89,12 +101,12 @@ end function *(A::BatchedCuMatrix{T}, x::BatchedCuVector{T}) where {T} y_data = CuArray{T}(undef, size(A.data, 1), size(x.data, 2)) - CUDA.CUBLAS.gemv_strided_batched('N', T(1.0), A.data, x.data, T(0.0), y_data) + CUDA.CUBLAS.gemv_strided_batched!('N', T(1.0), A.data, x.data, T(0.0), y_data) return BatchedCuVector(y_data) end function *(A::Transpose{T,BatchedCuMatrix{T}}, x::BatchedCuVector{T}) where {T} - y_data = CuArray{T}(undef, size(A.data, 2), size(x.data, 2)) - CUDA.CUBLAS.gemv_strided_batched('T', T(1.0), A.data, x.data, T(0.0), y_data) + y_data = CuArray{T}(undef, size(A.parent.data, 2), size(x.data, 2)) + CUDA.CUBLAS.gemv_strided_batched!('T', T(1.0), A.parent.data, x.data, T(0.0), y_data) return BatchedCuVector(y_data) end @@ -107,6 +119,15 @@ function 
getindex(A::BatchedCuMatrix{T}, ::Colon, j::Int) where {T} return BatchedCuVector(col_data) end +########################### +#### SCALAR OPERATIONS #### +########################### + +function /(A::BatchedCuMatrix, s::Number) + C_data = A.data ./ s + return BatchedCuMatrix(C_data) +end + ######################### #### POTR OPERATIONS #### ######################### @@ -135,11 +156,11 @@ for (fname, elty) in ( throw(DimensionMismatch("Matrix must be square for Cholesky decomposition")) P_data = copy(A.data) - P = BatchedCuCholesky(L_data) + P = BatchedCuCholesky(P_data) dh = CUDA.CUSOLVER.dense_handle() info = CuVector{Int}(undef, b) - CUDA.CUSOLVER.$fname(dh, 'L', m, L.ptrs, m, info, b) + CUDA.CUSOLVER.$fname(dh, 'L', m, P.ptrs, m, info, b) return P end @@ -173,6 +194,25 @@ for (fname, elty) in ( end end +for (fname, elty) in ( + (:cusolverDnSpotrsBatched, :Float32), + (:cusolverDnDpotrsBatched, :Float64), + (:cusolverDnCpotrsBatched, :ComplexF32), + (:cusolverDnZpotrsBatched, :ComplexF64), +) + @eval begin + function \(P::BatchedCuCholesky{$elty}, x::BatchedCuVector{$elty}) + m, b = size(x.data) + y_data = copy(x.data) + y = BatchedCuVector(y_data) + dh = CUDA.CUSOLVER.dense_handle() + info = CuVector{Int}(undef, b) + CUDA.CUSOLVER.$fname(dh, 'L', m, 1, P.ptrs, m, y.ptrs, m, info, b) + return y + end + end +end + ########################## #### MIXED OPERATIONS #### ########################## @@ -181,3 +221,32 @@ function -(x::CuVector{T}, y::BatchedCuVector{T}) where {T} z_data = x .- y.data return BatchedCuVector(z_data) end + +################################### +#### DISTRIBUTIONAL OPERATIONS #### +################################### + +function gaussian_likelihood( + m::BatchedCuVector{T}, S::BatchedCuMatrix{T}, y::Union{BatchedCuVector{T},CuVector{T}} +) where {T} + D = size(S.data, 1) + y_res = y - m + + # TODO: avoid recomputing Cholesky decomposition + S_chol = cholesky(S) + + diags = CuArray{T}(undef, size(S.data, 1), size(S.data, 3)) + for i in 1:size(S.data, 1) + diags[i, :] = S_chol.data[i, i, :] + end + log_dets = T(2) * dropdims(sum(log.(diags); dims=1); dims=1) + + inv_term = S_chol \ y_res + log_likes = -T(0.5) * dot(y_res, inv_term) + log_likes .-= T(0.5) * (log_dets .+ D * log(T(2π))) + + # HACK: only errors seems to be from numerical stability so will just overwrite + log_likes[isnan.(log_likes)] .= -Inf + + return log_likes +end From dec6b20ad40174cfae9f107e47c98a93eb2d49f1 Mon Sep 17 00:00:00 2001 From: Tim Hargreaves Date: Fri, 6 Jun 2025 17:17:15 +0100 Subject: [PATCH 4/7] Modify batch KF test to new interface --- GeneralisedFilters/test/batch_kalman_test.jl | 133 +++++-------------- 1 file changed, 34 insertions(+), 99 deletions(-) diff --git a/GeneralisedFilters/test/batch_kalman_test.jl b/GeneralisedFilters/test/batch_kalman_test.jl index 07ba7353..18f28804 100644 --- a/GeneralisedFilters/test/batch_kalman_test.jl +++ b/GeneralisedFilters/test/batch_kalman_test.jl @@ -32,114 +32,49 @@ ] T = 5 - Ys = [[rand(rng, Dy) for _ in 1:T] for _ in 1:K] + ys_cpu = [rand(rng, Dy) for _ in 1:T] + Ys = [ys_cpu for _ in 1:K] outputs = [ GeneralisedFilters.filter(rng, models[k], KalmanFilter(), Ys[k]) for k in 1:K ] - states = first.(outputs) log_likelihoods = last.(outputs) - struct BatchLinearGaussianDynamics{T,MT} <: LinearGaussianLatentDynamics{T} - μ0s::CuArray{T,2,MT} - Σ0s::CuArray{T,3,MT} - As::CuArray{T,3,MT} - bs::CuArray{T,2,MT} - Qs::CuArray{T,3,MT} - end - - function BatchLinearGaussianDynamics( - μ0s::Vector{Vector{T}}, - Σ0s::Vector{Matrix{T}}, - 
As::Vector{Matrix{T}}, - bs::Vector{Vector{T}}, - Qs::Vector{Matrix{T}}, - ) where {T} - μ0s = CuArray(stack(μ0s)) - Σ0s = CuArray(stack(Σ0s)) - As = CuArray(stack(As)) - bs = CuArray(stack(bs)) - Qs = CuArray(stack(Qs)) - return BatchLinearGaussianDynamics(μ0s, Σ0s, As, bs, Qs) - end - - function GeneralisedFilters.batch_calc_μ0s( - dyn::BatchLinearGaussianDynamics, ::Integer; kwargs... - ) - return dyn.μ0s - end - function GeneralisedFilters.batch_calc_Σ0s( - dyn::BatchLinearGaussianDynamics, ::Integer; kwargs... - ) - return dyn.Σ0s - end - function GeneralisedFilters.batch_calc_As( - dyn::BatchLinearGaussianDynamics, ::Integer, ::Integer; kwargs... - ) - return dyn.As - end - function GeneralisedFilters.batch_calc_bs( - dyn::BatchLinearGaussianDynamics, ::Integer, ::Integer; kwargs... - ) - return dyn.bs - end - function GeneralisedFilters.batch_calc_Qs( - dyn::BatchLinearGaussianDynamics, ::Integer, ::Integer; kwargs... - ) - return dyn.Qs - end - - struct BatchLinearGaussianObservations{T,MT} <: LinearGaussianObservationProcess{T} - Hs::CuArray{T,3,MT} - cs::CuArray{T,2,MT} - Rs::CuArray{T,3,MT} - end - - function BatchLinearGaussianObservations( - Hs::Vector{Matrix{T}}, cs::Vector{Vector{T}}, Rs::Vector{Matrix{T}} + # Define batched model + μ0 = BatchedCuVector(cu(stack(μ0s))) + Σ0 = BatchedCuMatrix(cu(stack(Σ0s))) + A = BatchedCuMatrix(cu(stack(As))) + b = BatchedCuVector(cu(stack(bs))) + Q = BatchedCuMatrix(cu(stack(Qs))) + + dyn = GeneralisedFilters.HomogeneousLinearGaussianLatentDynamics(μ0, Σ0, A, b, Q) + + H = BatchedCuMatrix(cu(stack(Hs))) + c = BatchedCuVector(cu(stack(cs))) + R = BatchedCuMatrix(cu(stack(Rs))) + + obs = GeneralisedFilters.HomogeneousLinearGaussianObservationProcess(H, c, R) + + ssm = StateSpaceModel(dyn, obs) + ys = cu.(ys_cpu) + + # Hack: manually setting of initialisation for this model + function GeneralisedFilters.initialise_log_evidence( + ::KalmanFilter, + model::StateSpaceModel{ + T, + <:GeneralisedFilters.HomogeneousLinearGaussianLatentDynamics{ + T,<:BatchedCuVector + }, + }, ) where {T} - Hs = CuArray(stack(Hs)) - cs = CuArray(stack(cs)) - Rs = CuArray(stack(Rs)) - return BatchLinearGaussianObservations(Hs, cs, Rs) - end - - function GeneralisedFilters.batch_calc_Hs( - obs::BatchLinearGaussianObservations, ::Integer, ::Integer; kwargs... - ) - return obs.Hs - end - function GeneralisedFilters.batch_calc_cs( - obs::BatchLinearGaussianObservations, ::Integer, ::Integer; kwargs... - ) - return obs.cs + D = size(model.dyn.μ0.data, 2) + return CUDA.zeros(T, D) end - function GeneralisedFilters.batch_calc_Rs( - obs::BatchLinearGaussianObservations, ::Integer, ::Integer; kwargs... 
-    )
-        return obs.Rs
-    end
-
-    batch_model = GeneralisedFilters.StateSpaceModel(
-        BatchLinearGaussianDynamics(μ0s, Σ0s, As, bs, Qs),
-        BatchLinearGaussianObservations(Hs, cs, Rs),
-    )
-
-    Ys_batch = Vector{Matrix{Float64}}(undef, T)
-    for t in 1:T
-        Ys_batch[t] = stack(Ys[k][t] for k in 1:K)
-    end
-    batch_output = GeneralisedFilters.filter(
-        rng, batch_model, BatchKalmanFilter(K), Ys_batch
-    )
-
-    # println("Batch log-likelihood: ", batch_output[2])
-    # println("Individual log-likelihoods: ", log_likelihoods)
-    # println("Batch states: ", batch_output[1].μs')
-    # println("Individual states: ", getproperty.(states, :μ))
+    state, ll = GeneralisedFilters.filter(rng, ssm, KalmanFilter(), ys)

-    @test Array(batch_output[2])[end] .≈ log_likelihoods[end] rtol = 1e-5
-    @test Array(batch_output[1].μs) ≈ stack(getproperty.(states, :μ)) rtol = 1e-5
+    @test all(isapprox.(Array(ll), log_likelihoods; rtol=1e-5))
+    @test Array(state.μ.data) ≈ stack(getproperty.(first.(outputs), :μ)) rtol = 1e-5
 end

From dcd2f4b619cf34cc24f4c7bf750cb9448ccaf125 Mon Sep 17 00:00:00 2001
From: Tim Hargreaves
Date: Mon, 9 Jun 2025 10:59:00 +0100
Subject: [PATCH 5/7] Update particle filter predict/update to
 batching-compatible broadcast style

---
 .../src/algorithms/particles.jl | 68 +++++++++++--------
 1 file changed, 40 insertions(+), 28 deletions(-)

diff --git a/GeneralisedFilters/src/algorithms/particles.jl b/GeneralisedFilters/src/algorithms/particles.jl
index 879e1570..0fcc9f25 100644
--- a/GeneralisedFilters/src/algorithms/particles.jl
+++ b/GeneralisedFilters/src/algorithms/particles.jl
@@ -122,26 +122,38 @@ function predict(
     ref_state::Union{Nothing,AbstractVector}=nothing,
     kwargs...,
 )
-    proposed_particles = map(enumerate(state.particles)) do (i, particle)
-        if !isnothing(ref_state) && i == 1
-            ref_state[iter]
-        else
-            simulate(rng, model, filter.proposal, iter, particle, observation; kwargs...)
-        end
+    proposed_particles =
+        SSMProblems.simulate.(
+            Ref(rng),
+            Ref(model),
+            Ref(filter.proposal),
+            Ref(iter),
+            state.particles,
+            Ref(observation),
+            kwargs...,
+        )
+    if !isnothing(ref_state)
+        proposed_particles[1] = ref_state[iter]
     end

-    state.log_weights +=
-        map(zip(proposed_particles, state.particles)) do (new_state, prev_state)
-            log_f = SSMProblems.logdensity(
-                model.dyn, iter, prev_state, new_state; kwargs...
-            )
-
-            log_q = SSMProblems.logdensity(
-                model, filter.proposal, iter, prev_state, new_state, observation; kwargs...
-            )
-
-            (log_f - log_q)
-        end
+    state.log_weights .+=
+        SSMProblems.logdensity.(
+            Ref(model.dyn),
+            Ref(iter),
+            state.particles,
+            proposed_particles,
+            kwargs...,
+        )
+    state.log_weights .-=
+        SSMProblems.logdensity.(
+            Ref(model),
+            Ref(filter.proposal),
+            Ref(iter),
+            state.particles,
+            proposed_particles,
+            Ref(observation);
+            kwargs...,
+        )

     state.particles = proposed_particles

@@ -156,10 +168,10 @@ function update(
     observation;
     kwargs...,
 ) where {T}
-    log_increments = map(
-        x -> SSMProblems.logdensity(model.obs, iter, x, observation; kwargs...),
-        state.particles,
-    )
+    log_increments =
+        SSMProblems.logdensity.(
+            Ref(model.obs), Ref(iter), state.particles, Ref(observation); kwargs...
+        )

     state.log_weights += log_increments

@@ -207,12 +219,12 @@ function predict(
     ref_state::Union{Nothing,AbstractVector}=nothing,
     kwargs...,
 )
-    state.particles = map(enumerate(state.particles)) do (i, particle)
-        if !isnothing(ref_state) && i == 1
-            ref_state[iter]
-        else
-            SSMProblems.simulate(rng, model.dyn, iter, particle; kwargs...)
- end + state.particles = + SSMProblems.simulate.( + Ref(rng), Ref(model.dyn), Ref(iter), state.particles; kwargs... + ) + if !isnothing(ref_state) + state.particles[1] = ref_state[iter] end return state From c95d3a6adf11068c52c668889cb049f10e32f7cd Mon Sep 17 00:00:00 2001 From: Tim Hargreaves Date: Wed, 11 Jun 2025 15:52:29 +0100 Subject: [PATCH 6/7] Modify BF to unify CPU and GPU filtering and add unit test --- GeneralisedFilters/src/GeneralisedFilters.jl | 2 + .../src/algorithms/particles.jl | 48 ++++++- .../src/batching/batched_CUDA.jl | 133 +++++++++++++++++- GeneralisedFilters/src/batching/batched_SA.jl | 8 ++ GeneralisedFilters/src/batching/batching.jl | 2 + GeneralisedFilters/src/containers.jl | 8 +- .../src/models/linear_gaussian.jl | 11 +- GeneralisedFilters/test/runtests.jl | 55 ++++++++ 8 files changed, 250 insertions(+), 17 deletions(-) create mode 100644 GeneralisedFilters/src/batching/batched_SA.jl create mode 100644 GeneralisedFilters/src/batching/batching.jl diff --git a/GeneralisedFilters/src/GeneralisedFilters.jl b/GeneralisedFilters/src/GeneralisedFilters.jl index e10b3fe8..e1801c37 100644 --- a/GeneralisedFilters/src/GeneralisedFilters.jl +++ b/GeneralisedFilters/src/GeneralisedFilters.jl @@ -18,7 +18,9 @@ include("containers.jl") include("resamplers.jl") # Batching utilities +include("batching/batching.jl") include("batching/batched_CUDA.jl") +include("batching/batched_SA.jl") ## FILTERING BASE ########################################################################## diff --git a/GeneralisedFilters/src/algorithms/particles.jl b/GeneralisedFilters/src/algorithms/particles.jl index 0fcc9f25..46468199 100644 --- a/GeneralisedFilters/src/algorithms/particles.jl +++ b/GeneralisedFilters/src/algorithms/particles.jl @@ -138,11 +138,7 @@ function predict( state.log_weights .+= SSMProblems.logdensity.( - Ref(model.dyn), - Ref(iter), - state.particles, - proposed_particles, - kwargs..., + Ref(model.dyn), Ref(iter), state.particles, proposed_particles, kwargs... ) state.log_weights .-= SSMProblems.logdensity.( @@ -245,3 +241,45 @@ function filter( ) return filter(rng, ssm, algo, observations; ref_state=ref_state, kwargs...) end + +# Broadcast wrapper for batched types +# TODO: this can likely be replaced with a broadcast style +function Base.Broadcast.broadcasted( + ::typeof(SSMProblems.simulate), + rng_ref::Base.RefValue, + model_dyn_ref::Base.RefValue, + iter_ref::Base.RefValue, + particles::BatchedVector; + kwargs..., +) + # Extract values from Ref and call non-broadcasted version + return SSMProblems.simulate( + rng_ref[], model_dyn_ref[], iter_ref[], particles; kwargs... + ) +end +function Base.Broadcast.broadcasted( + ::typeof(SSMProblems.logdensity), + model_obs_ref::Base.RefValue, + iter_ref::Base.RefValue, + particles::BatchedVector, + observation::Base.RefValue; + kwargs..., +) + # Extract values from Ref and call non-broadcasted version + return SSMProblems.logdensity( + model_obs_ref[], iter_ref[], particles, observation[]; kwargs... + ) +end +function Base.Broadcast.broadcasted( + ::typeof(SSMProblems.logdensity), + model_dyn_ref::Base.RefValue, + iter_ref::Base.RefValue, + prev_particles::BatchedVector, + new_particles::BatchedVector; + kwargs..., +) + # Extract values from Ref and call non-broadcasted version + return SSMProblems.logdensity( + model_dyn_ref[], iter_ref[], prev_particles, new_particles; kwargs... 
+ ) +end diff --git a/GeneralisedFilters/src/batching/batched_CUDA.jl b/GeneralisedFilters/src/batching/batched_CUDA.jl index d53b756f..1d23f676 100644 --- a/GeneralisedFilters/src/batching/batched_CUDA.jl +++ b/GeneralisedFilters/src/batching/batched_CUDA.jl @@ -1,11 +1,10 @@ import Base: *, +, -, transpose, getindex import LinearAlgebra: Transpose, cholesky, \, /, I, UniformScaling, dot +import Distributions: logpdf +import Random: rand export BatchedCuVector, BatchedCuMatrix, BatchedCuCholesky -abstract type BatchedVector{T} end -abstract type BatchedMatrix{T} end - ########################### #### VECTOR OPERATIONS #### ########################### @@ -19,6 +18,7 @@ function BatchedCuVector(data::CuArray{T,2}) where {T} return BatchedCuVector{T}(data, ptrs) end Base.eltype(::BatchedCuVector{T}) where {T} = T +Base.length(x::BatchedCuVector) = size(x.data, 2) function +(x::BatchedCuVector{T}, y::BatchedCuVector{T}) where {T} z_data = x.data .+ y.data @@ -51,6 +51,7 @@ function BatchedCuMatrix(data::CuArray{T,3}) where {T} return BatchedCuMatrix{T}(data, ptrs) end Base.eltype(::BatchedCuMatrix{T}) where {T} = T +Base.length(A::BatchedCuMatrix) = size(A.data, 3) transpose(A::BatchedCuMatrix{T}) where {T} = Transpose{T,BatchedCuMatrix{T}}(A) @@ -141,6 +142,7 @@ function BatchedCuCholesky(data::CuArray{T,3}) where {T} return BatchedCuCholesky{T}(data, ptrs) end Base.eltype(::BatchedCuCholesky{T}) where {T} = T +Base.length(P::BatchedCuCholesky) = size(P.data, 3) for (fname, elty) in ( (:cusolverDnSpotrfBatched, :Float32), @@ -221,11 +223,78 @@ function -(x::CuVector{T}, y::BatchedCuVector{T}) where {T} z_data = x .- y.data return BatchedCuVector(z_data) end +function +(x::CuVector{T}, y::BatchedCuVector{T}) where {T} + z_data = x .+ y.data + return BatchedCuVector(z_data) +end +function +(x::BatchedCuVector{T}, y::CuVector{T}) where {T} + z_data = x.data .+ y + return BatchedCuVector(z_data) +end +# TODO: these need to be generated automatically and call a common function +# TODO: are we best using the strided or non-strided version. 
The former don't need pointer duplication +for (fname, elty, gemv_batched) in ( + (:cublasSgemvBatched_64, :Float32, CUDA.CUBLAS.cublasSgemvBatched_64), + (:cublasDgemvBatched_64, :Float64, CUDA.CUBLAS.cublasDgemvBatched_64), + (:cublasCgemvBatched_64, :ComplexF32, CUDA.CUBLAS.cublasCgemvBatched_64), + (:cublasZgemvBatched_64, :ComplexF64, CUDA.CUBLAS.cublasZgemvBatched_64), +) + @eval begin + function *(A::BatchedCuMatrix{$elty}, x::CuVector{$elty}) + m, n, b = size(A.data) + y_data = CuArray{$elty}(undef, m, b) + y = BatchedCuVector(y_data) + + # Call gemv directly + x_ptrs = batch_singleton(x, b) + h = CUDA.CUBLAS.handle() + $gemv_batched( + h, 'N', m, n, $elty(1.0), A.ptrs, m, x_ptrs, 1, $elty(0.0), y.ptrs, 1, b + ) + return y + end + + function *(A::CuMatrix{$elty}, x::BatchedCuVector{$elty}) + m, n = size(A) + b = size(x.data, 2) + y_data = CuArray{$elty}(undef, m, b) + y = BatchedCuVector(y_data) + + # Call gemv directly + A_ptrs = batch_singleton(A, b) + h = CUDA.CUBLAS.handle() + $gemv_batched( + h, 'N', m, n, $elty(1.0), A_ptrs, m, x.ptrs, 1, $elty(0.0), y.ptrs, 1, b + ) + return y + end + end +end + +@inline function batch_singleton(array::DenseCuArray{T}, N::Int) where {T} + ptrs = CuArray{CuPtr{T}}(undef, N) + function compute_pointers() + i = (blockIdx().x - 1i32) * blockDim().x + threadIdx().x + grid_stride = gridDim().x * blockDim().x + while i <= length(ptrs) + @inbounds ptrs[i] = reinterpret(CuPtr{T}, pointer(array)) + i += grid_stride + end + return nothing + end + kernel = @cuda launch = false compute_pointers() + config = launch_configuration(kernel.fun) + threads = min(config.threads, N) + blocks = min(config.blocks, cld(N, threads)) + @cuda threads blocks compute_pointers() + return ptrs +end ################################### #### DISTRIBUTIONAL OPERATIONS #### ################################### +# Can likely replace by using Gaussian function gaussian_likelihood( m::BatchedCuVector{T}, S::BatchedCuMatrix{T}, y::Union{BatchedCuVector{T},CuVector{T}} ) where {T} @@ -250,3 +319,61 @@ function gaussian_likelihood( return log_likes end + +# HACK: this is more hard-coded than it needs to be. Can generalise to other batched types +# by taking advantage of the internal calls used in GaussianDistributions.jl +function Distributions.logpdf( + P::Gaussian{BatchedCuVector{T},BatchedCuMatrix{T}}, + y::Union{BatchedCuVector{T},CuVector{T}}, +) where {T} + return gaussian_likelihood(P.μ, P.Σ, y) +end +# HACK: MAJOR — this is just to handle a special case for bootstrap filter unit test until +# we have a general approach to this +function Distributions.logpdf( + P::Gaussian{BatchedCuVector{T},<:CuMatrix{T}}, y::CuVector{T} +) where {T} + # Stack Σ to form a batched matrix + Σ_data = CuArray{T}(undef, size(P.Σ)..., size(P.μ.data, 2)) + Σ_data[:, :, :] .= P.Σ + return gaussian_likelihood(P.μ, BatchedCuMatrix(Σ_data), y) +end + +# TODO: need to generalise to only one argument being batched +function Random.rand( + ::AbstractRNG, P::Gaussian{BatchedCuVector{T},BatchedCuMatrix{T}} +) where {T} + D, N = size(P.μ.data) + Σ_chol = cholesky(P.Σ) + Z = BatchedCuVector(CUDA.randn(T, D, N)) + # HACK: CUBLAS doesn't have batched trmv so we'll use gemm with zeroing out for now. 
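+    # (potrf only writes the 'L' triangle, so the copied upper triangle still holds
+    # stale entries of Σ that the full multiply below would otherwise pick up.)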
+ # Should later replace with MAGMA + L = BatchedCuMatrix(Σ_chol.data) + zero_upper_triangle!(L.data) + return P.μ + L * Z +end +# TODO: the singleton Cholesky should probably be handled on the CPU +function Random.rand(::AbstractRNG, P::Gaussian{BatchedCuVector{T},<:CuMatrix{T}}) where {T} + D, N = size(P.μ.data) + Σ_L = cholesky(P.Σ).L + Z = BatchedCuVector(CUDA.randn(T, D, N)) + return P.μ + CuArray(Σ_L) * Z +end + +function zero_upper_triangle!(A::CuArray{T,3}) where {T} + D, _, N = size(A) + + function kernel_zero_upper_triangle!(A, D) + i = threadIdx().x + j = threadIdx().y + k = blockIdx().x + + if i < j && i <= D && j <= D + A[i, j, k] = zero(eltype(A)) + end + return nothing + end + + @cuda threads = (D, D) blocks = N kernel_zero_upper_triangle!(A, D) + return nothing +end diff --git a/GeneralisedFilters/src/batching/batched_SA.jl b/GeneralisedFilters/src/batching/batched_SA.jl new file mode 100644 index 00000000..9fb4e74f --- /dev/null +++ b/GeneralisedFilters/src/batching/batched_SA.jl @@ -0,0 +1,8 @@ +import StaticArrays: SVector, SMatrix + +struct BatchedSAVector{T} <: BatchedVector{T} + data::Array{T,2} +end +Base.eltype(::BatchedSAVector{T}) where {T} = T + + diff --git a/GeneralisedFilters/src/batching/batching.jl b/GeneralisedFilters/src/batching/batching.jl new file mode 100644 index 00000000..7cee4234 --- /dev/null +++ b/GeneralisedFilters/src/batching/batching.jl @@ -0,0 +1,2 @@ +abstract type BatchedVector{T} end +abstract type BatchedMatrix{T} end diff --git a/GeneralisedFilters/src/containers.jl b/GeneralisedFilters/src/containers.jl index 9920b9be..51c6cf78 100644 --- a/GeneralisedFilters/src/containers.jl +++ b/GeneralisedFilters/src/containers.jl @@ -8,12 +8,12 @@ A container for particle filters which composes the weighted sample into a distibution-like object, with the states (or particles) distributed accoring to their log-weights. """ -mutable struct ParticleDistribution{PT,WT<:Real} - particles::Vector{PT} +mutable struct ParticleDistribution{PT,WT} + particles::PT ancestors::Vector{Int} - log_weights::Vector{WT} + log_weights::WT end -function ParticleDistribution(particles::Vector{PT}, log_weights::Vector{WT}) where {PT,WT} +function ParticleDistribution(particles, log_weights) N = length(particles) return ParticleDistribution(particles, Vector{Int}(1:N), log_weights) end diff --git a/GeneralisedFilters/src/models/linear_gaussian.jl b/GeneralisedFilters/src/models/linear_gaussian.jl index 50aa291b..59a9cedb 100644 --- a/GeneralisedFilters/src/models/linear_gaussian.jl +++ b/GeneralisedFilters/src/models/linear_gaussian.jl @@ -52,23 +52,24 @@ end #### DISTRIBUTIONS #### ####################### +# We choose Gaussian over MvNormal since it allows for batched types function SSMProblems.distribution(dyn::LinearGaussianLatentDynamics; kwargs...) μ0, Σ0 = calc_initial(dyn; kwargs...) - return MvNormal(μ0, Σ0) + return Gaussian(μ0, Σ0) end function SSMProblems.distribution( - dyn::LinearGaussianLatentDynamics, step::Integer, state::AbstractVector; kwargs... + dyn::LinearGaussianLatentDynamics, step::Integer, state; kwargs... ) A, b, Q = calc_params(dyn, step; kwargs...) - return MvNormal(A * state + b, Q) + return Gaussian(A * state + b, Q) end function SSMProblems.distribution( - obs::LinearGaussianObservationProcess, step::Integer, state::AbstractVector; kwargs... + obs::LinearGaussianObservationProcess, step::Integer, state; kwargs... ) H, c, R = calc_params(obs, step; kwargs...) 
- return MvNormal(H * state + c, R) + return Gaussian(H * state + c, R) end ########################################### diff --git a/GeneralisedFilters/test/runtests.jl b/GeneralisedFilters/test/runtests.jl index ea8dfa26..660d612e 100644 --- a/GeneralisedFilters/test/runtests.jl +++ b/GeneralisedFilters/test/runtests.jl @@ -102,6 +102,61 @@ end @test llkf ≈ llbf atol = 1e-1 end +@testitem "GPU bootstrap filter test" tags = [:gpu] begin + using SSMProblems + using StableRNGs + using LogExpFunctions: softmax + using CUDA + using LinearAlgebra + using GeneralisedFilters + + rng = StableRNG(1234) + model = GeneralisedFilters.GFTest.create_linear_gaussian_model(rng, 1, 1) + _, _, ys = sample(rng, model, 10) + + # HACK: convert model to CUDA + gpu_model = StateSpaceModel( + GeneralisedFilters.HomogeneousLinearGaussianLatentDynamics( + cu(model.dyn.μ0), + cu(model.dyn.Σ0), + cu(model.dyn.A), + cu(model.dyn.b), + cu(model.dyn.Q), + ), + GeneralisedFilters.HomogeneousLinearGaussianObservationProcess( + cu(model.obs.H), cu(model.obs.c), cu(model.obs.R) + ), + ) + + # HACK: disabling resampling for now + bf = BF(2^12; threshold=0.0) + # HACK: run BF manually until initialisation interface is finalised + # Initialisation + Z = BatchedCuVector(CUDA.randn(Float32, 1, bf.N)) + Σ_L = cu(cholesky(model.dyn.Σ0).L) + initial_particles = cu(model.dyn.μ0) + Σ_L.data * Z + bf_state = GeneralisedFilters.ParticleDistribution( + initial_particles, CUDA.zeros(Float32, bf.N) + ) + llbf = 0.0 + for t in eachindex(ys) + global bf_state, llbf + bf_state, ll_increment = GeneralisedFilters.step( + rng, gpu_model, bf, t, bf_state, cu(ys[t]) + ) + llbf += ll_increment + end + + kf_state, llkf = GeneralisedFilters.filter(rng, model, KF(), ys) + + xs = bf_state.particles.data[1, :] + ws = softmax(bf_state.log_weights) + + # Compare log-likelihood and states + @test first(kf_state.μ) ≈ sum(xs .* ws) rtol = 1e-2 + @test llkf ≈ llbf atol = 1e-1 +end + @testitem "Guided filter test" begin using SSMProblems using LogExpFunctions: softmax From ae9b3c14b22bfc7f8f26dda16b46d2433a6fff9f Mon Sep 17 00:00:00 2001 From: Tim Hargreaves Date: Wed, 11 Jun 2025 20:06:19 +0100 Subject: [PATCH 7/7] Modify particle container and resampler to allow resampling of batched particles --- GeneralisedFilters/src/batching/batched_CUDA.jl | 3 +++ GeneralisedFilters/src/containers.jl | 11 ++++++++--- GeneralisedFilters/src/resamplers.jl | 2 +- GeneralisedFilters/test/runtests.jl | 4 +--- 4 files changed, 13 insertions(+), 7 deletions(-) diff --git a/GeneralisedFilters/src/batching/batched_CUDA.jl b/GeneralisedFilters/src/batching/batched_CUDA.jl index 1d23f676..715179fe 100644 --- a/GeneralisedFilters/src/batching/batched_CUDA.jl +++ b/GeneralisedFilters/src/batching/batched_CUDA.jl @@ -19,6 +19,7 @@ function BatchedCuVector(data::CuArray{T,2}) where {T} end Base.eltype(::BatchedCuVector{T}) where {T} = T Base.length(x::BatchedCuVector) = size(x.data, 2) +Base.getindex(x::BatchedCuVector, idxs) = BatchedCuVector(x.data[:, idxs]) function +(x::BatchedCuVector{T}, y::BatchedCuVector{T}) where {T} z_data = x.data .+ y.data @@ -52,6 +53,7 @@ function BatchedCuMatrix(data::CuArray{T,3}) where {T} end Base.eltype(::BatchedCuMatrix{T}) where {T} = T Base.length(A::BatchedCuMatrix) = size(A.data, 3) +Base.getindex(A::BatchedCuMatrix, idxs) = BatchedCuMatrix(A.data[:, :, idxs]) transpose(A::BatchedCuMatrix{T}) where {T} = Transpose{T,BatchedCuMatrix{T}}(A) @@ -143,6 +145,7 @@ function BatchedCuCholesky(data::CuArray{T,3}) where {T} end 
Base.eltype(::BatchedCuCholesky{T}) where {T} = T Base.length(P::BatchedCuCholesky) = size(P.data, 3) +Base.getindex(P::BatchedCuCholesky, idxs) = BatchedCuCholesky(P.data[:, :, idxs]) for (fname, elty) in ( (:cusolverDnSpotrfBatched, :Float32), diff --git a/GeneralisedFilters/src/containers.jl b/GeneralisedFilters/src/containers.jl index 51c6cf78..e829e546 100644 --- a/GeneralisedFilters/src/containers.jl +++ b/GeneralisedFilters/src/containers.jl @@ -8,15 +8,20 @@ A container for particle filters which composes the weighted sample into a distibution-like object, with the states (or particles) distributed accoring to their log-weights. """ -mutable struct ParticleDistribution{PT,WT} +mutable struct ParticleDistribution{PT,AT,WT} particles::PT - ancestors::Vector{Int} + ancestors::AT log_weights::WT end -function ParticleDistribution(particles, log_weights) +# TODO: these helpers might be more confusing than they're worth — and don't cover all cases +function ParticleDistribution(particles, log_weights::Vector) N = length(particles) return ParticleDistribution(particles, Vector{Int}(1:N), log_weights) end +function ParticleDistribution(particles, log_weights::CuVector) + N = length(particles) + return ParticleDistribution(particles, CuVector{Int}(1:N), log_weights) +end StatsBase.weights(state::ParticleDistribution) = softmax(state.log_weights) diff --git a/GeneralisedFilters/src/resamplers.jl b/GeneralisedFilters/src/resamplers.jl index 143a2da4..a8f2e4c9 100644 --- a/GeneralisedFilters/src/resamplers.jl +++ b/GeneralisedFilters/src/resamplers.jl @@ -23,7 +23,7 @@ function resample( end function construct_new_state(states::ParticleDistribution{PT,WT}, idxs) where {PT,WT} - return ParticleDistribution(states.particles[idxs], idxs, zeros(WT, length(states))) + return ParticleDistribution(states.particles[idxs], idxs, zero(states.log_weights)) end function construct_new_state( diff --git a/GeneralisedFilters/test/runtests.jl b/GeneralisedFilters/test/runtests.jl index 660d612e..bfa37edb 100644 --- a/GeneralisedFilters/test/runtests.jl +++ b/GeneralisedFilters/test/runtests.jl @@ -108,7 +108,6 @@ end using LogExpFunctions: softmax using CUDA using LinearAlgebra - using GeneralisedFilters rng = StableRNG(1234) model = GeneralisedFilters.GFTest.create_linear_gaussian_model(rng, 1, 1) @@ -128,8 +127,7 @@ end ), ) - # HACK: disabling resampling for now - bf = BF(2^12; threshold=0.0) + bf = BF(2^12; threshold=1.0) # HACK: run BF manually until initialisation interface is finalised # Initialisation Z = BatchedCuVector(CUDA.randn(Float32, 1, bf.N))
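Taken together, the series lets the stock KalmanFilter run unchanged over a batch of independent LGSSMs. A condensed sketch of that end-to-end flow, modelled on the updated batch_kalman_test.jl (the dimensions and parameters below are placeholders, and it assumes the initialise_log_evidence override from that test so that the running log-evidence starts as a length-K CuVector):

    using CUDA, LinearAlgebra, StableRNGs
    using SSMProblems, GeneralisedFilters

    rng = StableRNG(1234)
    K, Dx, Dy = 8, 2, 2
    eyes(d) = [Matrix{Float32}(I, d, d) for _ in 1:K]  # helper for this sketch

    dyn = GeneralisedFilters.HomogeneousLinearGaussianLatentDynamics(
        BatchedCuVector(CUDA.zeros(Float32, Dx, K)),    # μ0: one column per system
        BatchedCuMatrix(cu(stack(eyes(Dx)))),           # Σ0
        BatchedCuMatrix(cu(stack(0.9f0 .* eyes(Dx)))),  # A
        BatchedCuVector(CUDA.zeros(Float32, Dx, K)),    # b
        BatchedCuMatrix(cu(stack(eyes(Dx)))),           # Q
    )
    obs = GeneralisedFilters.HomogeneousLinearGaussianObservationProcess(
        BatchedCuMatrix(cu(stack(eyes(Dy)))),           # H (Dx == Dy here)
        BatchedCuVector(CUDA.zeros(Float32, Dy, K)),    # c
        BatchedCuMatrix(cu(stack(eyes(Dy)))),           # R
    )
    ssm = StateSpaceModel(dyn, obs)

    ys = [cu(rand(rng, Float32, Dy)) for _ in 1:5]  # one shared observation per step
    state, ll = GeneralisedFilters.filter(rng, ssm, KalmanFilter(), ys)
    # state.μ.data :: Dx×K CuArray of filtered means; ll :: length-K CuVector of log-likelihoods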