From 821c78c1bc4b36f274ad7694b75b5a6ad243a84b Mon Sep 17 00:00:00 2001 From: Tim Hargreaves Date: Sun, 7 Dec 2025 11:10:19 +0000 Subject: [PATCH 1/2] Implement type consistent ancestor callback --- GeneralisedFilters/Project.toml | 2 - GeneralisedFilters/src/GFTest/resamplers.jl | 2 +- GeneralisedFilters/src/GeneralisedFilters.jl | 1 - .../src/algorithms/particles.jl | 16 ++-- GeneralisedFilters/src/callbacks.jl | 57 ++++++++----- GeneralisedFilters/src/containers.jl | 19 +++++ GeneralisedFilters/src/resamplers.jl | 12 +-- GeneralisedFilters/test/runtests.jl | 79 +++++++++++-------- 8 files changed, 120 insertions(+), 68 deletions(-) diff --git a/GeneralisedFilters/Project.toml b/GeneralisedFilters/Project.toml index 6bb33841..75906e60 100644 --- a/GeneralisedFilters/Project.toml +++ b/GeneralisedFilters/Project.toml @@ -17,7 +17,6 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SSMProblems = "26aad666-b158-4e64-9d35-0e672562fa48" @@ -34,7 +33,6 @@ DataStructures = "0.18.20, 0.19" Distributions = "0.25" LogExpFunctions = "0.3" NNlib = "0.9" -OffsetArrays = "1.14.1" PDMats = "0.11.35" SSMProblems = "0.6" StaticArrays = "1.9.14" diff --git a/GeneralisedFilters/src/GFTest/resamplers.jl b/GeneralisedFilters/src/GFTest/resamplers.jl index 1037a928..5e8ddf63 100644 --- a/GeneralisedFilters/src/GFTest/resamplers.jl +++ b/GeneralisedFilters/src/GFTest/resamplers.jl @@ -29,7 +29,7 @@ function GeneralisedFilters.resample( alt_resampler::AlternatingResampler, state, weights; - ref_state::Union{Nothing,AbstractVector}=nothing, + ref_state=nothing, kwargs..., ) alt_resampler.resample_next = !alt_resampler.resample_next diff --git a/GeneralisedFilters/src/GeneralisedFilters.jl b/GeneralisedFilters/src/GeneralisedFilters.jl index a3b3464c..70bcb957 100644 --- a/GeneralisedFilters/src/GeneralisedFilters.jl +++ b/GeneralisedFilters/src/GeneralisedFilters.jl @@ -4,7 +4,6 @@ using AbstractMCMC: AbstractMCMC, AbstractSampler import Distributions: MvNormal, params import Random: AbstractRNG, default_rng, rand import SSMProblems: prior, dyn, obs -using OffsetArrays using SSMProblems using StatsBase diff --git a/GeneralisedFilters/src/algorithms/particles.jl b/GeneralisedFilters/src/algorithms/particles.jl index 101f773c..e4cab575 100644 --- a/GeneralisedFilters/src/algorithms/particles.jl +++ b/GeneralisedFilters/src/algorithms/particles.jl @@ -37,12 +37,12 @@ function initialise( rng::AbstractRNG, prior::StatePrior, algo::AbstractParticleFilter; - ref_state::Union{Nothing,AbstractVector}=nothing, + ref_state=nothing, kwargs..., ) N = num_particles(algo) particles = map(1:N) do i - ref = !isnothing(ref_state) && i == 1 ? ref_state[0] : nothing + ref = !isnothing(ref_state) && i == 1 ? ref_state.x0 : nothing initialise_particle(rng, prior, algo, ref; kwargs...) end @@ -56,12 +56,12 @@ function predict( iter::Integer, state, observation; - ref_state::Union{Nothing,AbstractVector}=nothing, + ref_state=nothing, kwargs..., ) particles = map(1:num_particles(algo)) do i particle = state.particles[i] - ref = !isnothing(ref_state) && i == 1 ? ref_state[iter] : nothing + ref = !isnothing(ref_state) && i == 1 ? ref_state.xs[iter] : nothing predict_particle(rng, dyn, algo, iter, particle, observation, ref; kwargs...) end @@ -151,7 +151,7 @@ function step( iter::Integer, state, observation; - ref_state::Union{Nothing,AbstractVector}=nothing, + ref_state=nothing, callback::CallbackType=nothing, kwargs..., ) @@ -251,7 +251,7 @@ function filter( model::HierarchicalSSM, algo::ParticleFilter, observations::AbstractVector; - ref_state::Union{Nothing,AbstractVector}=nothing, + ref_state=nothing, kwargs..., ) ssm = StateSpaceModel( @@ -321,7 +321,7 @@ function initialise( rng::AbstractRNG, prior::StatePrior, algo::AuxiliaryParticleFilter; - ref_state::Union{Nothing,AbstractVector}=nothing, + ref_state=nothing, kwargs..., ) return initialise(rng, prior, algo.pf; ref_state, kwargs...) @@ -334,7 +334,7 @@ function step( iter::Integer, state, observation; - ref_state::Union{Nothing,AbstractVector}=nothing, + ref_state=nothing, callback::CallbackType=nothing, kwargs..., ) diff --git a/GeneralisedFilters/src/callbacks.jl b/GeneralisedFilters/src/callbacks.jl index c53b4bdd..7333b3d5 100644 --- a/GeneralisedFilters/src/callbacks.jl +++ b/GeneralisedFilters/src/callbacks.jl @@ -50,21 +50,26 @@ end ## DENSE PARTICLE STORAGE ################################################################## -struct DenseParticleContainer{T} - particles::OffsetVector{Vector{T},Vector{Vector{T}}} +mutable struct DenseParticleContainer{T0,T1} + x0s::Vector{T0} + xs::Union{Vector{Vector{T1}},Nothing} ancestors::Vector{Vector{Int}} end -function get_ancestry(container::DenseParticleContainer{T}, i::Integer) where {T} +# Partial constructor when only x0s are known +function DenseParticleContainer(x0s::Vector{T0}) where {T0} + return DenseParticleContainer{T0,Any}(x0s, nothing, Vector{Int}[]) +end + +function get_ancestry(container::DenseParticleContainer{T0,T1}, i::Integer) where {T0,T1} a = i - v = Vector{T}(undef, length(container.particles)) - ancestry = OffsetVector(v, -1) - for t in length(container.ancestors):-1:1 - ancestry[t] = container.particles[t][a] - a = container.ancestors[t][a] + xs = Vector{T1}(undef, length(container.xs)) + for k in length(container.ancestors):-1:1 + xs[k] = container.xs[k][a] + a = container.ancestors[k][a] end - ancestry[0] = container.particles[0][a] - return ancestry + + return ReferenceTrajectory(container.x0s[a], xs) end """ @@ -73,16 +78,21 @@ end A callback for dense ancestry storage, which fills a `DenseParticleContainer`. """ mutable struct DenseAncestorCallback <: AbstractCallback - container + container::Union{DenseParticleContainer,Nothing} +end + +# Default to initialising with no container +function DenseAncestorCallback() + return DenseAncestorCallback(nothing) end function (c::DenseAncestorCallback)( model, filter, state, data, ::PostInitCallback; kwargs... ) particles = state.particles - c.container = DenseParticleContainer( - OffsetVector([deepcopy(getfield.(particles, :state))], -1), Vector{Int}[] - ) + x0s = deepcopy(getfield.(particles, :state)) + # Partially construct container using just initial states + c.container = DenseParticleContainer(x0s) return nothing end @@ -90,8 +100,19 @@ function (c::DenseAncestorCallback)( model, filter, step, state, data, ::PostUpdateCallback; kwargs... ) particles = state.particles - push!(c.container.particles, deepcopy(getfield.(particles, :state))) - push!(c.container.ancestors, deepcopy(getfield.(particles, :ancestor))) + states = deepcopy(getfield.(particles, :state)) + + # Re-instantiate trajectory storage with concrete type on first call + if c.container.xs === nothing + T0 = eltype(c.container.x0s) + T1 = eltype(states) + c.container = DenseParticleContainer{T0,T1}( + c.container.x0s, Vector{Vector{T1}}(), c.container.ancestors + ) + end + push!(c.container.xs, states) + push!(c.container.ancestors, getfield.(particles, :ancestor)) + return nothing end @@ -314,7 +335,7 @@ function get_ancestry(tree::ParallelParticleTree{ST}, T::Integer) where {ST} paths[t] = tree.states[parents] gather!(parents, tree.parents, parents) end - return paths + return paths end # Get ancestory of a single particle @@ -328,7 +349,7 @@ function get_ancestry( path[t] = container.states[ancestor_index] ancestor_index = container.parents[ancestor_index] end - return path + return path end end diff --git a/GeneralisedFilters/src/containers.jl b/GeneralisedFilters/src/containers.jl index 545112cc..e2fb3724 100644 --- a/GeneralisedFilters/src/containers.jl +++ b/GeneralisedFilters/src/containers.jl @@ -1,5 +1,7 @@ using LogExpFunctions +export ReferenceTrajectory + """Containers used for storing representations of the filtering distribution.""" ## TYPELESS INITIALIZERS ################################################################### @@ -207,3 +209,20 @@ information/precision matrix. function natural_params(state::InformationLikelihood) return state.λ, state.Ω end + +## REFERENCE TRAJECTORIES ################################################################## + +""" + ReferenceTrajectory + +A container representing a sampled trajectory from a state-space model, typically used for +particle smoothing or conditional SMC. + +Fields: +- `x0`: Initial state at time 0 +- `xs`: Vector of states at times 1:T +""" +struct ReferenceTrajectory{ST,VT<:AbstractVector} + x0::ST + xs::VT +end diff --git a/GeneralisedFilters/src/resamplers.jl b/GeneralisedFilters/src/resamplers.jl index 29237cad..19094050 100644 --- a/GeneralisedFilters/src/resamplers.jl +++ b/GeneralisedFilters/src/resamplers.jl @@ -25,7 +25,7 @@ end rng::AbstractRNG, resampler::AbstractResampler, state::ParticleDistribution; - ref_state::Union{Nothing,AbstractVector}=nothing, + ref_state=nothing, ) -> ParticleDistribution Perform resampling if the resampler's condition is met (for conditional resamplers), @@ -37,7 +37,7 @@ function maybe_resample( resampler::AbstractResampler, state, weights=get_weights(state); - ref_state::Union{Nothing,AbstractVector}=nothing, + ref_state=nothing, auxiliary_weights=nothing, ) return resample(rng, resampler, state, weights; ref_state, auxiliary_weights) @@ -48,7 +48,7 @@ function resample( resampler::AbstractResampler, state, weights=get_weights(state); - ref_state::Union{Nothing,AbstractVector}=nothing, + ref_state=nothing, auxiliary_weights::Union{Nothing,AbstractVector}=nothing, kwargs..., ) @@ -89,7 +89,7 @@ function resample( rng::AbstractRNG, auxiliary::AuxiliaryResampler, state; - ref_state::Union{Nothing,AbstractVector}=nothing, + ref_state=nothing, ) weights = softmax(log_weights(state) + auxiliary.log_weights) auxiliary_weights = auxiliary.log_weights @@ -149,7 +149,7 @@ function maybe_resample( cond_resampler::AbstractConditionalResampler, state, weights=get_weights(state); - ref_state::Union{Nothing,AbstractVector}=nothing, + ref_state=nothing, auxiliary_weights::Union{Nothing,AbstractVector}=nothing, ) if will_resample(cond_resampler, state, weights) @@ -178,7 +178,7 @@ function resample( cond_resampler::ESSResampler, state, weights=get_weights(state); - ref_state::Union{Nothing,AbstractVector}=nothing, + ref_state=nothing, auxiliary_weights::Union{Nothing,AbstractVector}=nothing, ) return resample( diff --git a/GeneralisedFilters/test/runtests.jl b/GeneralisedFilters/test/runtests.jl index 14733c0a..39adf520 100644 --- a/GeneralisedFilters/test/runtests.jl +++ b/GeneralisedFilters/test/runtests.jl @@ -464,7 +464,6 @@ end using Random: randexp, AbstractRNG using StatsBase: sample, Weights - using OffsetArrays struct DummyResampler <: GeneralisedFilters.AbstractResampler end @@ -482,19 +481,29 @@ end model = GeneralisedFilters.GFTest.create_linear_gaussian_model(rng, 1, 1) _, _, ys = sample(rng, model, K) - ref_traj = OffsetVector([rand(rng, 1) for _ in 0:K], -1) + # Create reference trajectory + ref_states = [rand(rng, 1) for _ in 0:K] + ref_traj = ReferenceTrajectory(ref_states[1], ref_states[2:end]) bf = BF(N_particles; threshold=1.0, resampler=DummyResampler()) - cb = GeneralisedFilters.DenseAncestorCallback(nothing) + cb = GeneralisedFilters.DenseAncestorCallback() bf_state, _ = GeneralisedFilters.filter( rng, model, bf, ys; ref_state=ref_traj, callback=cb ) traj = GeneralisedFilters.get_ancestry(cb.container, N_particles) - true_traj = [cb.container.particles[t][N_particles - K + t] for t in 0:K] - @test traj.parent == true_traj - @test GeneralisedFilters.get_ancestry(cb.container, 1) == ref_traj + # Construct expected trajectory manually + true_x0 = cb.container.x0s[N_particles - K] + true_xs = [cb.container.xs[t][N_particles - K + t] for t in 1:K] + + @test traj.x0 == true_x0 + @test traj.xs == true_xs + + # Test that particle 1 retrieves the reference trajectory + ref_traj_reconstructed = GeneralisedFilters.get_ancestry(cb.container, 1) + @test ref_traj_reconstructed.x0 == ref_traj.x0 + @test ref_traj_reconstructed.xs == ref_traj.xs end @testitem "CSMC test" begin @@ -506,7 +515,6 @@ end using Random: randexp using StatsBase: sample, weights - using OffsetArrays SEED = 1234 Dx = 1 @@ -534,7 +542,7 @@ end lls = [] for i in 1:N_steps - cb = GeneralisedFilters.DenseAncestorCallback(nothing) + cb = GeneralisedFilters.DenseAncestorCallback() bf_state, ll = GeneralisedFilters.filter( rng, model, bf, ys; ref_state=ref_traj, callback=cb ) @@ -551,7 +559,7 @@ end # unbiased estimate of 1 / Z. See Elements of Sequential Monte Carlo (Section 5.2) log_recip_likelihood_estimate = logsumexp(-lls) - log(length(lls)) - csmc_mean = sum(getindex.(trajectory_samples, t_smooth)) / N_sample + csmc_mean = sum([traj.xs[t_smooth] for traj in trajectory_samples]) / N_sample @test csmc_mean ≈ state.μ rtol = 1e-3 @test log_recip_likelihood_estimate ≈ -ks_ll rtol = 1e-3 end @@ -566,7 +574,6 @@ end using StaticArrays using Statistics - using OffsetArrays SEED = 1234 D_outer = 1 @@ -597,7 +604,7 @@ end ref_traj = nothing trajectory_samples = [] - cb = GeneralisedFilters.DenseAncestorCallback(nothing) + cb = GeneralisedFilters.DenseAncestorCallback() for i in 1:N_steps bf_state, _ = GeneralisedFilters.filter( rng, hier_model, rbpf, ys; ref_state=ref_traj, callback=cb @@ -610,11 +617,14 @@ end push!(trajectory_samples, deepcopy(ref_traj)) end # Reference trajectory should only be nonlinear state for RBPF - ref_traj = getproperty.(ref_traj, :x) + # Extract outer states from the hierarchical states + ref_traj = ReferenceTrajectory( + getproperty(ref_traj.x0, :x), getproperty.(ref_traj.xs, :x) + ) end - # Extract inner and outer trajectories - x_trajectories = getproperty.(getindex.(trajectory_samples, t_smooth), :x) + # Extract inner and outer trajectories at time t_smooth + x_trajectories = [getproperty(traj.xs[t_smooth], :x) for traj in trajectory_samples] # Manually perform smoothing until we have a cleaner interface A = hier_model.inner_model.dyn.A @@ -623,13 +633,13 @@ end Q = hier_model.inner_model.dyn.Q z_smoothed_means = Vector{T}(undef, N_sample) for i in 1:N_sample - μ = trajectory_samples[i][K].z.μ - Σ = trajectory_samples[i][K].z.Σ + μ = trajectory_samples[i].xs[K].z.μ + Σ = trajectory_samples[i].xs[K].z.Σ for t in (K - 1):-1:t_smooth - μ_filt = trajectory_samples[i][t].z.μ - Σ_filt = trajectory_samples[i][t].z.Σ - μ_pred = A * μ_filt + b + C * trajectory_samples[i][t].x + μ_filt = trajectory_samples[i].xs[t].z.μ + Σ_filt = trajectory_samples[i].xs[t].z.Σ + μ_pred = A * μ_filt + b + C * trajectory_samples[i].xs[t].x Σ_pred = X_A_Xt(Σ_filt, A) + Q Σ_pred = PDMat(Symmetric(Σ_pred)) @@ -660,7 +670,6 @@ end import SSMProblems: prior, dyn, obs import GeneralisedFilters: resampler, resample, move, RBState, InformationLikelihood - using OffsetArrays SEED = 1234 D_outer = 1 @@ -692,7 +701,7 @@ end for i in 1:N_steps global predictive_likelihoods - cb = GeneralisedFilters.DenseAncestorCallback(nothing) + cb = GeneralisedFilters.DenseAncestorCallback() # Manual filtering with ancestor resampling bf_state = initialise(rng, prior(hier_model), rbpf; ref_state=ref_traj) @@ -710,7 +719,7 @@ end rbpf, t, particle.state, - RBState(ref_traj[t], predictive_likelihoods[t]), + RBState(ref_traj.xs[t], predictive_likelihoods[t]), ) end ancestor_idx = sample( @@ -739,7 +748,10 @@ end push!(trajectory_samples, deepcopy(ref_traj)) end # Reference trajectory should only be nonlinear state for RBPF - ref_traj = getproperty.(ref_traj, :x) + # Extract outer states from the hierarchical states + ref_traj_outer = ReferenceTrajectory( + getproperty(ref_traj.x0, :x), getproperty.(ref_traj.xs, :x) + ) pred_lik = backward_initialise( rng, hier_model.inner_model.obs, BackwardInformationPredictor(), K, ys[K] @@ -752,8 +764,8 @@ end BackwardInformationPredictor(), t, pred_lik; - prev_outer=ref_traj[t], - next_outer=ref_traj[t + 1], + prev_outer=ref_traj_outer.xs[t], + next_outer=ref_traj_outer.xs[t + 1], ) pred_lik = backward_update( hier_model.inner_model.obs, @@ -764,10 +776,13 @@ end ) predictive_likelihoods[t] = deepcopy(pred_lik) end + + # Update ref_traj for next iteration + global ref_traj = ref_traj_outer end - # Extract inner and outer trajectories - x_trajectories = getproperty.(getindex.(trajectory_samples, t_smooth), :x) + # Extract inner and outer trajectories at time t_smooth + x_trajectories = [getproperty(traj.xs[t_smooth], :x) for traj in trajectory_samples] # Manually perform smoothing until we have a cleaner interface A = hier_model.inner_model.dyn.A @@ -776,13 +791,13 @@ end Q = hier_model.inner_model.dyn.Q z_smoothed_means = Vector{T}(undef, N_sample) for i in 1:N_sample - μ = trajectory_samples[i][K].z.μ - Σ = trajectory_samples[i][K].z.Σ + μ = trajectory_samples[i].xs[K].z.μ + Σ = trajectory_samples[i].xs[K].z.Σ for t in (K - 1):-1:t_smooth - μ_filt = trajectory_samples[i][t].z.μ - Σ_filt = trajectory_samples[i][t].z.Σ - μ_pred = A * μ_filt + b + C * trajectory_samples[i][t].x + μ_filt = trajectory_samples[i].xs[t].z.μ + Σ_filt = trajectory_samples[i].xs[t].z.Σ + μ_pred = A * μ_filt + b + C * trajectory_samples[i].xs[t].x Σ_pred = A * Σ_filt * A' + Q G = Σ_filt * A' * inv(Σ_pred) From b8a27e09b6b2b6a2eb7a1a4e8e91898cdc92eaee Mon Sep 17 00:00:00 2001 From: Tim Hargreaves Date: Mon, 8 Dec 2025 07:40:28 +0000 Subject: [PATCH 2/2] Correct formatting --- GeneralisedFilters/src/callbacks.jl | 4 ++-- GeneralisedFilters/src/resamplers.jl | 7 +------ GeneralisedFilters/test/runtests.jl | 4 ---- 3 files changed, 3 insertions(+), 12 deletions(-) diff --git a/GeneralisedFilters/src/callbacks.jl b/GeneralisedFilters/src/callbacks.jl index 7333b3d5..0f285539 100644 --- a/GeneralisedFilters/src/callbacks.jl +++ b/GeneralisedFilters/src/callbacks.jl @@ -335,7 +335,7 @@ function get_ancestry(tree::ParallelParticleTree{ST}, T::Integer) where {ST} paths[t] = tree.states[parents] gather!(parents, tree.parents, parents) end - return paths + return paths end # Get ancestory of a single particle @@ -349,7 +349,7 @@ function get_ancestry( path[t] = container.states[ancestor_index] ancestor_index = container.parents[ancestor_index] end - return path + return path end end diff --git a/GeneralisedFilters/src/resamplers.jl b/GeneralisedFilters/src/resamplers.jl index 19094050..6bd2f151 100644 --- a/GeneralisedFilters/src/resamplers.jl +++ b/GeneralisedFilters/src/resamplers.jl @@ -85,12 +85,7 @@ struct AuxiliaryResampler <: AbstractResampler log_weights::AbstractVector end -function resample( - rng::AbstractRNG, - auxiliary::AuxiliaryResampler, - state; - ref_state=nothing, -) +function resample(rng::AbstractRNG, auxiliary::AuxiliaryResampler, state; ref_state=nothing) weights = softmax(log_weights(state) + auxiliary.log_weights) auxiliary_weights = auxiliary.log_weights return resample(rng, auxiliary.resampler, state, weights; ref_state, auxiliary_weights) diff --git a/GeneralisedFilters/test/runtests.jl b/GeneralisedFilters/test/runtests.jl index 39adf520..e8a713d4 100644 --- a/GeneralisedFilters/test/runtests.jl +++ b/GeneralisedFilters/test/runtests.jl @@ -464,7 +464,6 @@ end using Random: randexp, AbstractRNG using StatsBase: sample, Weights - struct DummyResampler <: GeneralisedFilters.AbstractResampler end function GeneralisedFilters.sample_ancestors( @@ -515,7 +514,6 @@ end using Random: randexp using StatsBase: sample, weights - SEED = 1234 Dx = 1 Dy = 1 @@ -574,7 +572,6 @@ end using StaticArrays using Statistics - SEED = 1234 D_outer = 1 D_inner = 1 @@ -670,7 +667,6 @@ end import SSMProblems: prior, dyn, obs import GeneralisedFilters: resampler, resample, move, RBState, InformationLikelihood - SEED = 1234 D_outer = 1 D_inner = 1