Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions GeneralisedFilters/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SSMProblems = "26aad666-b158-4e64-9d35-0e672562fa48"
Expand All @@ -34,7 +33,6 @@ DataStructures = "0.18.20, 0.19"
Distributions = "0.25"
LogExpFunctions = "0.3"
NNlib = "0.9"
OffsetArrays = "1.14.1"
PDMats = "0.11.35"
SSMProblems = "0.6"
StaticArrays = "1.9.14"
Expand Down
2 changes: 1 addition & 1 deletion GeneralisedFilters/src/GFTest/resamplers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ function GeneralisedFilters.resample(
alt_resampler::AlternatingResampler,
state,
weights;
ref_state::Union{Nothing,AbstractVector}=nothing,
ref_state=nothing,
kwargs...,
)
alt_resampler.resample_next = !alt_resampler.resample_next
Expand Down
1 change: 0 additions & 1 deletion GeneralisedFilters/src/GeneralisedFilters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ using AbstractMCMC: AbstractMCMC, AbstractSampler
import Distributions: MvNormal, params
import Random: AbstractRNG, default_rng, rand
import SSMProblems: prior, dyn, obs
using OffsetArrays
using SSMProblems
using StatsBase

Expand Down
16 changes: 8 additions & 8 deletions GeneralisedFilters/src/algorithms/particles.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,12 @@ function initialise(
rng::AbstractRNG,
prior::StatePrior,
algo::AbstractParticleFilter;
ref_state::Union{Nothing,AbstractVector}=nothing,
ref_state=nothing,
kwargs...,
)
N = num_particles(algo)
particles = map(1:N) do i
ref = !isnothing(ref_state) && i == 1 ? ref_state[0] : nothing
ref = !isnothing(ref_state) && i == 1 ? ref_state.x0 : nothing
initialise_particle(rng, prior, algo, ref; kwargs...)
end

Expand All @@ -56,12 +56,12 @@ function predict(
iter::Integer,
state,
observation;
ref_state::Union{Nothing,AbstractVector}=nothing,
ref_state=nothing,
kwargs...,
)
particles = map(1:num_particles(algo)) do i
particle = state.particles[i]
ref = !isnothing(ref_state) && i == 1 ? ref_state[iter] : nothing
ref = !isnothing(ref_state) && i == 1 ? ref_state.xs[iter] : nothing
predict_particle(rng, dyn, algo, iter, particle, observation, ref; kwargs...)
end

Expand Down Expand Up @@ -151,7 +151,7 @@ function step(
iter::Integer,
state,
observation;
ref_state::Union{Nothing,AbstractVector}=nothing,
ref_state=nothing,
callback::CallbackType=nothing,
kwargs...,
)
Expand Down Expand Up @@ -251,7 +251,7 @@ function filter(
model::HierarchicalSSM,
algo::ParticleFilter,
observations::AbstractVector;
ref_state::Union{Nothing,AbstractVector}=nothing,
ref_state=nothing,
kwargs...,
)
ssm = StateSpaceModel(
Expand Down Expand Up @@ -321,7 +321,7 @@ function initialise(
rng::AbstractRNG,
prior::StatePrior,
algo::AuxiliaryParticleFilter;
ref_state::Union{Nothing,AbstractVector}=nothing,
ref_state=nothing,
kwargs...,
)
return initialise(rng, prior, algo.pf; ref_state, kwargs...)
Expand All @@ -334,7 +334,7 @@ function step(
iter::Integer,
state,
observation;
ref_state::Union{Nothing,AbstractVector}=nothing,
ref_state=nothing,
callback::CallbackType=nothing,
kwargs...,
)
Expand Down
53 changes: 37 additions & 16 deletions GeneralisedFilters/src/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,21 +50,26 @@ end

## DENSE PARTICLE STORAGE ##################################################################

struct DenseParticleContainer{T}
particles::OffsetVector{Vector{T},Vector{Vector{T}}}
mutable struct DenseParticleContainer{T0,T1}
x0s::Vector{T0}
xs::Union{Vector{Vector{T1}},Nothing}
ancestors::Vector{Vector{Int}}
end

function get_ancestry(container::DenseParticleContainer{T}, i::Integer) where {T}
# Partial constructor when only x0s are known
function DenseParticleContainer(x0s::Vector{T0}) where {T0}
return DenseParticleContainer{T0,Any}(x0s, nothing, Vector{Int}[])
end

function get_ancestry(container::DenseParticleContainer{T0,T1}, i::Integer) where {T0,T1}
a = i
v = Vector{T}(undef, length(container.particles))
ancestry = OffsetVector(v, -1)
for t in length(container.ancestors):-1:1
ancestry[t] = container.particles[t][a]
a = container.ancestors[t][a]
xs = Vector{T1}(undef, length(container.xs))
for k in length(container.ancestors):-1:1
xs[k] = container.xs[k][a]
a = container.ancestors[k][a]
end
ancestry[0] = container.particles[0][a]
return ancestry

return ReferenceTrajectory(container.x0s[a], xs)
end

"""
Expand All @@ -73,25 +78,41 @@ end
A callback for dense ancestry storage, which fills a `DenseParticleContainer`.
"""
mutable struct DenseAncestorCallback <: AbstractCallback
container
container::Union{DenseParticleContainer,Nothing}
end

# Default to initialising with no container
function DenseAncestorCallback()
return DenseAncestorCallback(nothing)
end

function (c::DenseAncestorCallback)(
model, filter, state, data, ::PostInitCallback; kwargs...
)
particles = state.particles
c.container = DenseParticleContainer(
OffsetVector([deepcopy(getfield.(particles, :state))], -1), Vector{Int}[]
)
x0s = deepcopy(getfield.(particles, :state))
# Partially construct container using just initial states
c.container = DenseParticleContainer(x0s)
return nothing
end

function (c::DenseAncestorCallback)(
model, filter, step, state, data, ::PostUpdateCallback; kwargs...
)
particles = state.particles
push!(c.container.particles, deepcopy(getfield.(particles, :state)))
push!(c.container.ancestors, deepcopy(getfield.(particles, :ancestor)))
states = deepcopy(getfield.(particles, :state))

# Re-instantiate trajectory storage with concrete type on first call
if c.container.xs === nothing
T0 = eltype(c.container.x0s)
T1 = eltype(states)
c.container = DenseParticleContainer{T0,T1}(
c.container.x0s, Vector{Vector{T1}}(), c.container.ancestors
)
end
push!(c.container.xs, states)
push!(c.container.ancestors, getfield.(particles, :ancestor))

return nothing
end

Expand Down
19 changes: 19 additions & 0 deletions GeneralisedFilters/src/containers.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
using LogExpFunctions

export ReferenceTrajectory

"""Containers used for storing representations of the filtering distribution."""

## TYPELESS INITIALIZERS ###################################################################
Expand Down Expand Up @@ -207,3 +209,20 @@ information/precision matrix.
function natural_params(state::InformationLikelihood)
return state.λ, state.Ω
end

## REFERENCE TRAJECTORIES ##################################################################

"""
ReferenceTrajectory

A container representing a sampled trajectory from a state-space model, typically used for
particle smoothing or conditional SMC.

Fields:
- `x0`: Initial state at time 0
- `xs`: Vector of states at times 1:T
"""
struct ReferenceTrajectory{ST,VT<:AbstractVector}
x0::ST
xs::VT
end
17 changes: 6 additions & 11 deletions GeneralisedFilters/src/resamplers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ end
rng::AbstractRNG,
resampler::AbstractResampler,
state::ParticleDistribution;
ref_state::Union{Nothing,AbstractVector}=nothing,
ref_state=nothing,
) -> ParticleDistribution

Perform resampling if the resampler's condition is met (for conditional resamplers),
Expand All @@ -37,7 +37,7 @@ function maybe_resample(
resampler::AbstractResampler,
state,
weights=get_weights(state);
ref_state::Union{Nothing,AbstractVector}=nothing,
ref_state=nothing,
auxiliary_weights=nothing,
)
return resample(rng, resampler, state, weights; ref_state, auxiliary_weights)
Expand All @@ -48,7 +48,7 @@ function resample(
resampler::AbstractResampler,
state,
weights=get_weights(state);
ref_state::Union{Nothing,AbstractVector}=nothing,
ref_state=nothing,
auxiliary_weights::Union{Nothing,AbstractVector}=nothing,
kwargs...,
)
Expand Down Expand Up @@ -85,12 +85,7 @@ struct AuxiliaryResampler <: AbstractResampler
log_weights::AbstractVector
end

function resample(
rng::AbstractRNG,
auxiliary::AuxiliaryResampler,
state;
ref_state::Union{Nothing,AbstractVector}=nothing,
)
function resample(rng::AbstractRNG, auxiliary::AuxiliaryResampler, state; ref_state=nothing)
weights = softmax(log_weights(state) + auxiliary.log_weights)
auxiliary_weights = auxiliary.log_weights
return resample(rng, auxiliary.resampler, state, weights; ref_state, auxiliary_weights)
Expand Down Expand Up @@ -149,7 +144,7 @@ function maybe_resample(
cond_resampler::AbstractConditionalResampler,
state,
weights=get_weights(state);
ref_state::Union{Nothing,AbstractVector}=nothing,
ref_state=nothing,
auxiliary_weights::Union{Nothing,AbstractVector}=nothing,
)
if will_resample(cond_resampler, state, weights)
Expand Down Expand Up @@ -178,7 +173,7 @@ function resample(
cond_resampler::ESSResampler,
state,
weights=get_weights(state);
ref_state::Union{Nothing,AbstractVector}=nothing,
ref_state=nothing,
auxiliary_weights::Union{Nothing,AbstractVector}=nothing,
)
return resample(
Expand Down
Loading
Loading