Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
d9158db
first pass
charlesknipp Nov 19, 2025
3dbb6ad
revert change to auxiliary filter
charlesknipp Nov 19, 2025
ddf6c1d
added fix for APF
charlesknipp Nov 19, 2025
30783aa
correct likelihoods for APF
charlesknipp Nov 20, 2025
794d81a
Formatter
charlesknipp Nov 20, 2025
e3ccaa4
added type stability with Float32s
charlesknipp Nov 20, 2025
e61acae
fixed typo
charlesknipp Nov 20, 2025
95f88ab
added type stable `AuxiliaryResampler`
charlesknipp Nov 20, 2025
63ba171
introduce `UnweightedParticle` type
charlesknipp Nov 20, 2025
adcd8a4
minor adjustments
charlesknipp Nov 21, 2025
88c484c
Merge remote-tracking branch 'origin/main' into ck/type-stability
charlesknipp Nov 27, 2025
02f76d4
merge conflict issue
charlesknipp Nov 27, 2025
17170df
Merge remote-tracking branch 'origin/main' into ck/type-stability
charlesknipp Nov 27, 2025
adb795e
remove abstract particle
charlesknipp Nov 27, 2025
d917450
fix example script
charlesknipp Nov 27, 2025
3591631
revert `ScalMat`
charlesknipp Nov 27, 2025
709a5f7
never resample uniform particles
charlesknipp Dec 2, 2025
c06b310
added typeless initial weights and normalizing constants
charlesknipp Dec 2, 2025
9a866f6
formatter
charlesknipp Dec 2, 2025
5aa67f9
remove unused code
charlesknipp Dec 2, 2025
1ea3c46
fixed single precision static arrays
charlesknipp Dec 2, 2025
d3fe836
added naive type stability tests
charlesknipp Dec 2, 2025
8ba55c2
Fix Aqua tests
THargreaves Dec 4, 2025
2e90f39
added `@test_opt` to unit tests
charlesknipp Dec 4, 2025
98b6c2a
cleaning before merge
charlesknipp Dec 4, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions GeneralisedFilters/src/GFTest/resamplers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,22 @@ mutable struct AlternatingResampler <: GeneralisedFilters.AbstractConditionalRes
end
end

function GeneralisedFilters.will_resample(alt_resampler::AlternatingResampler, state)
function GeneralisedFilters.will_resample(
alt_resampler::AlternatingResampler, state, weights
)
return alt_resampler.resample_next
end

function GeneralisedFilters.resample(
rng::AbstractRNG,
alt_resampler::AlternatingResampler,
state;
state,
weights;
ref_state::Union{Nothing,AbstractVector}=nothing,
kwargs...,
)
alt_resampler.resample_next = !alt_resampler.resample_next
return GeneralisedFilters.resample(rng, alt_resampler.resampler, state; ref_state)
return GeneralisedFilters.resample(
rng, alt_resampler.resampler, state, weights; ref_state, kwargs...
)
end
83 changes: 19 additions & 64 deletions GeneralisedFilters/src/algorithms/particles.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,7 @@ function initialise(
initialise_particle(rng, prior, algo, ref; kwargs...)
end

# Set equal weights: log_w = -log(N) so weights sum to 1
log_weight = -log(N)
for p in particles
p.log_w = log_weight
end

# Initialize with ll_baseline = 0.0
return ParticleDistribution(particles, 0.0)
return ParticleDistribution(particles, false)
end

function predict(
Expand All @@ -64,20 +57,18 @@ function predict(
ref_state::Union{Nothing,AbstractVector}=nothing,
kwargs...,
)
particles = map(1:length(state.particles)) do i
particles = map(1:num_particles(algo)) do i
particle = state.particles[i]
ref = !isnothing(ref_state) && i == 1 ? ref_state[iter] : nothing
predict_particle(rng, dyn, algo, iter, particle, observation, ref; kwargs...)
end
state.particles = particles

# Accumulate the baseline with LSE of weights after prediction (before update)
# For plain PF/guided: ll_baseline is 0.0 on entry, becomes LSE_before
# For APF with resample: ll_baseline already stores negative correction; add LSE_before
LSE_before = logsumexp(map(p -> p.log_w, state.particles))
state.ll_baseline += LSE_before

return state
return ParticleDistribution(
particles, logsumexp(log_weights(state)) + state.ll_baseline
)
end

function update(
Expand All @@ -91,10 +82,9 @@ function update(
particles = map(state.particles) do particle
update_particle(obs, algo, iter, particle, observation; kwargs...)
end
state.particles = particles
ll_increment = marginalise!(state)
new_state, ll_increment = marginalise!(state, particles)

return state, ll_increment
return new_state, ll_increment
end

struct ParticleFilter{RS,PT} <: AbstractParticleFilter
Expand All @@ -119,38 +109,37 @@ function initialise_particle(
rng::AbstractRNG, prior::StatePrior, algo::ParticleFilter, ref_state; kwargs...
)
x = sample_prior(rng, prior, algo, ref_state; kwargs...)
# TODO (RB): determine the correct type for the log_w field or use a NoWeight type
return Particle(x, 0.0, 0)
return Particle(x, 0)
end

function predict_particle(
rng::AbstractRNG,
dyn::LatentDynamics,
algo::ParticleFilter,
iter::Integer,
particle::Particle,
particle::AbstractParticle,
observation,
ref_state;
kwargs...,
)
new_x, logw_inc = propogate(
new_x, log_increment = propogate(
rng, dyn, algo, iter, particle.state, observation, ref_state; kwargs...
)
return Particle(new_x, particle.log_w + logw_inc, particle.ancestor)
return Particle(new_x, log_weight(particle) + log_increment, particle.ancestor)
end

function update_particle(
obs::ObservationProcess,
::ParticleFilter,
iter::Integer,
particle::Particle,
particle::AbstractParticle,
observation;
kwargs...,
)
log_increment = SSMProblems.logdensity(
obs, iter, particle.state, observation; kwargs...
)
return Particle(particle.state, particle.log_w + log_increment, particle.ancestor)
return Particle(particle.state, log_weight(particle) + log_increment, particle.ancestor)
end

function step(
Expand Down Expand Up @@ -250,10 +239,8 @@ function propogate(
ref_state
end

# TODO: make this type consistent
# Will have to do a lazy zero or change propogate to accept a particle (in which case
# we'll need to construct a particle in the RBPF predict method)
return new_x, 0.0
# TODO: replace this with nothing (unweighted particle)
return new_x, 0
end

# TODO: I feel like we shouldn't need to do this conversion. It should be handled by dispatch
Expand Down Expand Up @@ -312,40 +299,8 @@ function step(
predictive_loglik(obs(model), algo.pf, iter, p_star, observation; kwargs...)
end

# Log normalizer for current weights
LSE_w = logsumexp(map(p -> p.log_w, state.particles))

# Incorporate lookahead weights into current weights
for (i, particle) in enumerate(state.particles)
particle.log_w += log_ξs[i]
end

# Compute lookahead evidence
LSE_lookahead = logsumexp(map(p -> p.log_w, state.particles)) - LSE_w

resample_flag = will_resample(resampler(algo), state)
if resample_flag
state = resample(rng, resampler(algo), state; ref_state)
else
# Not resampling: preserve ll_baseline and set ancestors to self
n = length(state.particles)
new_particles = similar(state.particles)
for i in 1:n
new_particles[i] = set_ancestor(state.particles[i], i)
end
state = ParticleDistribution(new_particles, state.ll_baseline)
end

# Compensate for lookahead weights
for particle in state.particles
particle.log_w -= log_ξs[particle.ancestor]
end

# Compute compensation log normalizer
if resample_flag
LSE_comp = logsumexp(map(p -> p.log_w, state.particles))
state.ll_baseline = -(LSE_lookahead + (LSE_comp - log(num_particles(algo))))
end
rs = AuxiliaryResampler(resampler(algo), log_ξs)
state = maybe_resample(rng, rs, state; ref_state)

callback(model, algo, iter, state, observation, PostResample; kwargs...)
return move(
Expand Down Expand Up @@ -386,11 +341,11 @@ function predictive_state(
dyn::LatentDynamics,
apf::AuxiliaryParticleFilter{<:AbstractParticleFilter},
iter::Integer,
particle::Particle;
particle::AbstractParticle;
kwargs...,
)
x_star = predictive_statistic(rng, apf.pp, dyn, iter, particle.state; kwargs...)
return Particle(x_star, particle.log_w, particle.ancestor)
return Particle(x_star, log_weight(particle), particle.ancestor)
end

function predictive_loglik(
Expand Down
20 changes: 13 additions & 7 deletions GeneralisedFilters/src/algorithms/rbpf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,20 @@ resampler(algo::RBPF) = resampler(algo.pf)
function initialise_particle(
rng::AbstractRNG, prior::HierarchicalPrior, algo::RBPF, ref_state; kwargs...
)
N = num_particles(algo)
x = sample_prior(rng, prior.outer_prior, algo.pf, ref_state; kwargs...)
z = initialise(rng, prior.inner_prior, algo.af; new_outer=x, kwargs...)
# TODO (RB): determine the correct type for the log_w field or use a NoWeight type
return Particle(RBState(x, z), 0.0, 0)
# return Particle(RBState(x, z), -log(N), 0)
return Particle(RBState(x, z), 0)
end

function predict_particle(
rng::AbstractRNG,
dyn::HierarchicalDynamics,
algo::RBPF,
iter::Integer,
particle::RBParticle,
particle::AbstractParticle{<:RBState},
observation,
ref_state;
kwargs...,
Expand Down Expand Up @@ -55,14 +57,16 @@ function predict_particle(
kwargs...,
)

return Particle(RBState(new_x, new_z), particle.log_w + logw_inc, particle.ancestor)
return Particle(
RBState(new_x, new_z), log_weight(particle) + logw_inc, particle.ancestor
)
end

function update_particle(
obs::ObservationProcess,
algo::RBPF,
iter::Integer,
particle::RBParticle,
particle::AbstractParticle{<:RBState},
observation;
kwargs...,
)
Expand All @@ -76,7 +80,9 @@ function update_particle(
kwargs...,
)
return Particle(
RBState(particle.state.x, new_z), particle.log_w + log_increment, particle.ancestor
RBState(particle.state.x, new_z),
log_weight(particle) + log_increment,
particle.ancestor,
)
end

Expand All @@ -85,7 +91,7 @@ function predictive_state(
dyn::HierarchicalDynamics,
apf::AuxiliaryParticleFilter{<:RBPF},
iter::Integer,
particle::RBParticle;
particle::AbstractParticle{<:RBState};
kwargs...,
)
rbpf = apf.pf
Expand All @@ -110,7 +116,7 @@ function predictive_loglik(
obs::ObservationProcess,
algo::RBPF,
iter::Integer,
p_star::RBParticle,
p_star::AbstractParticle{<:RBState},
observation;
kwargs...,
)
Expand Down
38 changes: 25 additions & 13 deletions GeneralisedFilters/src/containers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,32 @@

## PARTICLES ###############################################################################

abstract type AbstractParticle{T} end

"""
Particle

A container representing a single particle in a particle filter distribution, composed of a
weighted sampled (stored as a log weight) and its ancestor index.
"""
mutable struct Particle{ST,WT,AT<:Integer}
mutable struct Particle{ST,WT,AT<:Integer} <: AbstractParticle{ST}
state::ST
log_w::WT
ancestor::AT
end

# NOTE: this is only ever used for initializing a particle filter
const UnweightedParticle{ST,AT} = Particle{ST,Nothing,AT}

Particle(state, ancestor) = Particle(state, nothing, ancestor)
Particle(particle::UnweightedParticle, ancestor) = Particle(particle.state, ancestor)
function Particle(particle::Particle{<:Any,WT}, ancestor) where {WT<:Real}
return Particle(particle.state, zero(WT), ancestor)
end

log_weight(p::Particle{<:Any,<:Real}) = p.log_w
log_weight(::UnweightedParticle) = false

"""
RBState

Expand All @@ -30,8 +44,6 @@ mutable struct RBState{XT,ZT}
z::ZT
end

const RBParticle{XT,ZT,WT} = Particle{RBState{XT,ZT},WT}

"""
ParticleDistribution

Expand All @@ -44,7 +56,7 @@ their ancestories) into a distibution-like object.
the unnormalized logsumexp of weights before update (for standard PF/guided filters)
or a modified value that includes APF first-stage correction (for auxiliary PF).
"""
mutable struct ParticleDistribution{WT,PT<:Particle{<:Any,WT},VT<:AbstractVector{PT}}
mutable struct ParticleDistribution{WT,PT<:AbstractParticle,VT<:AbstractVector{PT}}
particles::VT
ll_baseline::WT
end
Expand All @@ -59,10 +71,11 @@ Base.iterate(state::ParticleDistribution) = iterate(state.particles)
# Not sure if this is kosher, since it doesn't follow the convention of Base.getindex
Base.@propagate_inbounds Base.getindex(state::ParticleDistribution, i) = state.particles[i]

log_weights(state::ParticleDistribution) = map(p -> log_weight(p), state.particles)
get_weights(state::ParticleDistribution) = softmax(log_weights(state))

# Helpers for StatsBase compatibility
function StatsBase.weights(state::ParticleDistribution)
return Weights(softmax(map(p -> p.log_w, state.particles)))
end
StatsBase.weights(state::ParticleDistribution) = StatsBase.Weights(get_weights(state))

"""
marginalise!(state::ParticleDistribution)
Expand All @@ -78,22 +91,21 @@ cases through a single-scalar caching mechanism. For standard PF, ll_baseline eq
LSE before adding observation weights. For APF with resampling, it includes first-stage
correction terms computed during the APF resampling step.
"""
function marginalise!(state::ParticleDistribution)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason for passing in both state and particles. Can we not use particles = state.particles?

function marginalise!(state::ParticleDistribution, particles)
# Compute logsumexp after adding observation likelihoods
LSE_after = logsumexp(map(p -> p.log_w, state.particles))
LSE_after = logsumexp(log_weight.(particles))

# Compute log-likelihood increment: works for both PF and APF cases
ll_increment = LSE_after - state.ll_baseline

# Normalize weights
for p in state.particles
for p in particles
p.log_w -= LSE_after
end
Comment on lines +162 to 164
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My only issue with this entire PR is this single line of code. If we could get rid of this, we no longer have to rely on the mutation of Particles and ParticleDistributions

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I might be on the complete wrong page, but what's stopping us from just constructing new particles using the old weights - LSE_after, rather than mutating?


# Reset baseline for next iteration
state.ll_baseline = 0.0

return ll_increment
new_state = ParticleDistribution(particles, zero(ll_increment))
return new_state, ll_increment
end

## GAUSSIAN STATES #########################################################################
Expand Down
Loading
Loading