Skip to content

Commit 2282c15

Browse files
charlesknippgithub-actions[bot]THargreaves
authored
Restore Type Stability (#122)
* first pass * revert change to auxiliary filter * added fix for APF * correct likelihoods for APF * Formatter Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * added type stability with Float32s * fixed typo * added type stable `AuxiliaryResampler` * introduce `UnweightedParticle` type * minor adjustments Just some fixes I noticed when going through my review * merge conflict issue * remove abstract particle * fix example script * revert `ScalMat` * never resample uniform particles * added typeless initial weights and normalizing constants * formatter Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * remove unused code * fixed single precision static arrays * added naive type stability tests * Fix Aqua tests * added `@test_opt` to unit tests * cleaning before merge --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Tim Hargreaves <[email protected]>
1 parent 018f554 commit 2282c15

File tree

12 files changed

+273
-143
lines changed

12 files changed

+273
-143
lines changed

GeneralisedFilters/Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
5050
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
5151
TestItemRunner = "f8b46487-2199-4994-9208-9a1283c18c0a"
5252
TestItems = "1c621080-faea-4a02-84b6-bbd5e436b8fe"
53+
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
5354

5455
[targets]
55-
test = ["Aqua", "PDMats", "StableRNGs", "Test", "TestItemRunner", "TestItems"]
56+
test = ["Aqua", "PDMats", "StableRNGs", "Test", "TestItemRunner", "TestItems", "JET"]

GeneralisedFilters/examples/trend-inflation/script.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@ using PDMats
1414

1515
const GF = GeneralisedFilters
1616

17-
# INFL_PATH = joinpath(@__DIR__, "..", "..", "..", "examples", "trend-inflation"); #hide
18-
INFL_PATH = joinpath(@__DIR__)
17+
INFL_PATH = joinpath(@__DIR__, "..", "..", "..", "examples", "trend-inflation"); #hide
18+
# INFL_PATH = joinpath(@__DIR__)
1919
include(joinpath(INFL_PATH, "utilities.jl")); #hide
2020

2121
# ## Model Definition

GeneralisedFilters/src/GFTest/resamplers.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,16 +18,22 @@ mutable struct AlternatingResampler <: GeneralisedFilters.AbstractConditionalRes
1818
end
1919
end
2020

21-
function GeneralisedFilters.will_resample(alt_resampler::AlternatingResampler, state)
21+
function GeneralisedFilters.will_resample(
22+
alt_resampler::AlternatingResampler, state, weights
23+
)
2224
return alt_resampler.resample_next
2325
end
2426

2527
function GeneralisedFilters.resample(
2628
rng::AbstractRNG,
2729
alt_resampler::AlternatingResampler,
28-
state;
30+
state,
31+
weights;
2932
ref_state::Union{Nothing,AbstractVector}=nothing,
33+
kwargs...,
3034
)
3135
alt_resampler.resample_next = !alt_resampler.resample_next
32-
return GeneralisedFilters.resample(rng, alt_resampler.resampler, state; ref_state)
36+
return GeneralisedFilters.resample(
37+
rng, alt_resampler.resampler, state, weights; ref_state, kwargs...
38+
)
3339
end

GeneralisedFilters/src/algorithms/particles.jl

Lines changed: 15 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -46,14 +46,7 @@ function initialise(
4646
initialise_particle(rng, prior, algo, ref; kwargs...)
4747
end
4848

49-
# Set equal weights: log_w = -log(N) so weights sum to 1
50-
log_weight = -log(N)
51-
for p in particles
52-
p.log_w = log_weight
53-
end
54-
55-
# Initialize with ll_baseline = 0.0
56-
return ParticleDistribution(particles, 0.0)
49+
return ParticleDistribution(particles, TypelessZero())
5750
end
5851

5952
function predict(
@@ -66,20 +59,18 @@ function predict(
6659
ref_state::Union{Nothing,AbstractVector}=nothing,
6760
kwargs...,
6861
)
69-
particles = map(1:length(state.particles)) do i
62+
particles = map(1:num_particles(algo)) do i
7063
particle = state.particles[i]
7164
ref = !isnothing(ref_state) && i == 1 ? ref_state[iter] : nothing
7265
predict_particle(rng, dyn, algo, iter, particle, observation, ref; kwargs...)
7366
end
74-
state.particles = particles
7567

7668
# Accumulate the baseline with LSE of weights after prediction (before update)
7769
# For plain PF/guided: ll_baseline is 0.0 on entry, becomes LSE_before
7870
# For APF with resample: ll_baseline already stores negative correction; add LSE_before
79-
LSE_before = logsumexp(map(p -> p.log_w, state.particles))
80-
state.ll_baseline += LSE_before
81-
82-
return state
71+
return ParticleDistribution(
72+
particles, logsumexp(log_weights(state)) + state.ll_baseline
73+
)
8374
end
8475

8576
function update(
@@ -93,10 +84,9 @@ function update(
9384
particles = map(state.particles) do particle
9485
update_particle(obs, algo, iter, particle, observation; kwargs...)
9586
end
96-
state.particles = particles
97-
ll_increment = marginalise!(state)
87+
new_state, ll_increment = marginalise!(state, particles)
9888

99-
return state, ll_increment
89+
return new_state, ll_increment
10090
end
10191

10292
struct ParticleFilter{RS,PT} <: AbstractParticleFilter
@@ -121,8 +111,7 @@ function initialise_particle(
121111
rng::AbstractRNG, prior::StatePrior, algo::ParticleFilter, ref_state; kwargs...
122112
)
123113
x = sample_prior(rng, prior, algo, ref_state; kwargs...)
124-
# TODO (RB): determine the correct type for the log_w field or use a NoWeight type
125-
return Particle(x, 0.0, 0)
114+
return Particle(x, 0)
126115
end
127116

128117
function predict_particle(
@@ -135,10 +124,10 @@ function predict_particle(
135124
ref_state;
136125
kwargs...,
137126
)
138-
new_x, logw_inc = propogate(
127+
new_x, log_increment = propogate(
139128
rng, dyn, algo, iter, particle.state, observation, ref_state; kwargs...
140129
)
141-
return Particle(new_x, particle.log_w + logw_inc, particle.ancestor)
130+
return Particle(new_x, log_weight(particle) + log_increment, particle.ancestor)
142131
end
143132

144133
function update_particle(
@@ -152,7 +141,7 @@ function update_particle(
152141
log_increment = SSMProblems.logdensity(
153142
obs, iter, particle.state, observation; kwargs...
154143
)
155-
return Particle(particle.state, particle.log_w + log_increment, particle.ancestor)
144+
return Particle(particle.state, log_weight(particle) + log_increment, particle.ancestor)
156145
end
157146

158147
function step(
@@ -252,10 +241,7 @@ function propogate(
252241
ref_state
253242
end
254243

255-
# TODO: make this type consistent
256-
# Will have to do a lazy zero or change propogate to accept a particle (in which case
257-
# we'll need to construct a particle in the RBPF predict method)
258-
return new_x, 0.0
244+
return new_x, TypelessZero()
259245
end
260246

261247
# TODO: I feel like we shouldn't need to do this conversion. It should be handled by dispatch
@@ -353,7 +339,7 @@ function step(
353339
kwargs...,
354340
)
355341
# Compute lookahead weights approximating log p(y_{t+1} | x_{t}^(i))
356-
log_ξs = map(state.particles) do particle
342+
log_ηs = map(state.particles) do particle
357343
compute_logeta(
358344
rng,
359345
algo.weight_strategy,
@@ -366,40 +352,8 @@ function step(
366352
)
367353
end
368354

369-
# Log normalizer for current weights
370-
LSE_w = logsumexp(map(p -> p.log_w, state.particles))
371-
372-
# Incorporate lookahead weights into current weights
373-
for (i, particle) in enumerate(state.particles)
374-
particle.log_w += log_ξs[i]
375-
end
376-
377-
# Compute lookahead evidence
378-
LSE_lookahead = logsumexp(map(p -> p.log_w, state.particles)) - LSE_w
379-
380-
resample_flag = will_resample(resampler(algo), state)
381-
if resample_flag
382-
state = resample(rng, resampler(algo), state; ref_state)
383-
else
384-
# Not resampling: preserve ll_baseline and set ancestors to self
385-
n = length(state.particles)
386-
new_particles = similar(state.particles)
387-
for i in 1:n
388-
new_particles[i] = set_ancestor(state.particles[i], i)
389-
end
390-
state = ParticleDistribution(new_particles, state.ll_baseline)
391-
end
392-
393-
# Compensate for lookahead weights
394-
for particle in state.particles
395-
particle.log_w -= log_ξs[particle.ancestor]
396-
end
397-
398-
# Compute compensation log normalizer
399-
if resample_flag
400-
LSE_comp = logsumexp(map(p -> p.log_w, state.particles))
401-
state.ll_baseline = -(LSE_lookahead + (LSE_comp - log(num_particles(algo))))
402-
end
355+
rs = AuxiliaryResampler(resampler(algo), log_ηs)
356+
state = maybe_resample(rng, rs, state; ref_state)
403357

404358
callback(model, algo, iter, state, observation, PostResample; kwargs...)
405359
return move(

GeneralisedFilters/src/algorithms/rbpf.jl

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,16 +18,15 @@ function initialise_particle(
1818
)
1919
x = sample_prior(rng, prior.outer_prior, algo.pf, ref_state; kwargs...)
2020
z = initialise(rng, prior.inner_prior, algo.af; new_outer=x, kwargs...)
21-
# TODO (RB): determine the correct type for the log_w field or use a NoWeight type
22-
return Particle(RBState(x, z), 0.0, 0)
21+
return Particle(RBState(x, z), 0)
2322
end
2423

2524
function predict_particle(
2625
rng::AbstractRNG,
2726
dyn::HierarchicalDynamics,
2827
algo::RBPF,
2928
iter::Integer,
30-
particle::RBParticle,
29+
particle::Particle{<:RBState},
3130
observation,
3231
ref_state;
3332
kwargs...,
@@ -55,14 +54,16 @@ function predict_particle(
5554
kwargs...,
5655
)
5756

58-
return Particle(RBState(new_x, new_z), particle.log_w + logw_inc, particle.ancestor)
57+
return Particle(
58+
RBState(new_x, new_z), log_weight(particle) + logw_inc, particle.ancestor
59+
)
5960
end
6061

6162
function update_particle(
6263
obs::ObservationProcess,
6364
algo::RBPF,
6465
iter::Integer,
65-
particle::RBParticle,
66+
particle::Particle{<:RBState},
6667
observation;
6768
kwargs...,
6869
)
@@ -76,7 +77,9 @@ function update_particle(
7677
kwargs...,
7778
)
7879
return Particle(
79-
RBState(particle.state.x, new_z), particle.log_w + log_increment, particle.ancestor
80+
RBState(particle.state.x, new_z),
81+
log_weight(particle) + log_increment,
82+
particle.ancestor,
8083
)
8184
end
8285

GeneralisedFilters/src/containers.jl

Lines changed: 84 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,68 @@
1+
using LogExpFunctions
2+
13
"""Containers used for storing representations of the filtering distribution."""
24

5+
## TYPELESS INITIALIZERS ###################################################################
6+
7+
"""
8+
TypelessZero
9+
10+
A lazy promotion for uninitialized particle weights whos type is not yet known at the first
11+
simulation of a particle filter.
12+
"""
13+
struct TypelessZero <: Number end
14+
15+
Base.convert(::Type{T}, ::TypelessZero) where {T<:Number} = zero(T)
16+
Base.convert(::Type{TypelessZero}, ::TypelessZero) = TypelessZero()
17+
18+
Base.:+(::TypelessZero, ::TypelessZero) = TypelessZero()
19+
20+
Base.promote_rule(::Type{TypelessZero}, ::Type{T}) where {T<:Number} = T
21+
Base.promote_rule(::Type{TypelessZero}, ::Type{TypelessZero}) = TypelessZero
22+
23+
Base.zero(::TypelessZero) = TypelessZero()
24+
Base.zero(::Type{TypelessZero}) = TypelessZero()
25+
26+
Base.iszero(::TypelessZero) = true
27+
Base.isone(::TypelessZero) = false
28+
29+
Base.show(io::IO, ::TypelessZero) = print(io, "TypelessZero()")
30+
31+
"""
32+
TypelessBaseline
33+
34+
A lazy promotion for the computation of log-likelihood baslines given a collection of
35+
unweighted particles.
36+
"""
37+
struct TypelessBaseline <: Number
38+
N::Int64
39+
end
40+
41+
# Constructors for compatibility with Base.Number
42+
TypelessBaseline(x::TypelessBaseline) = x
43+
TypelessBaseline(x::Base.TwicePrecision) = TypelessBaseline(Int64(x))
44+
TypelessBaseline(x::AbstractChar) = TypelessBaseline(Int64(x))
45+
46+
Base.convert(::Type{T}, b::TypelessBaseline) where {T<:Number} = T(log(b.N))
47+
Base.promote_rule(::Type{TypelessBaseline}, ::Type{T}) where {T<:Number} = T
48+
49+
Base.iszero(::TypelessBaseline) = false
50+
Base.isone(::TypelessBaseline) = false
51+
52+
function LogExpFunctions.logsumexp(weights::AbstractVector{TypelessZero})
53+
return TypelessBaseline(length(weights))
54+
end
55+
56+
function LogExpFunctions.softmax(x::AbstractVector{TypelessZero})
57+
# TODO: horrible, but theoretically never used... except in the unit tests
58+
return fill(1 / length(x), length(x))
59+
end
60+
61+
Base.:+(::TypelessZero, b::TypelessBaseline) = b
62+
Base.:+(b::TypelessBaseline, ::TypelessZero) = b
63+
64+
Base.show(io::IO, b::TypelessBaseline) = print(io, "Typeless(log($(b.N)))")
65+
366
## PARTICLES ###############################################################################
467

568
"""
@@ -14,6 +77,17 @@ mutable struct Particle{ST,WT,AT<:Integer}
1477
ancestor::AT
1578
end
1679

80+
# NOTE: this is only ever used for initializing a particle filter
81+
const UnweightedParticle{ST,AT} = Particle{ST,TypelessZero,AT}
82+
83+
Particle(state, ancestor) = Particle(state, TypelessZero(), ancestor)
84+
Particle(particle::UnweightedParticle, ancestor) = Particle(particle.state, ancestor)
85+
function Particle(particle::Particle{<:Any,WT}, ancestor) where {WT<:Real}
86+
return Particle(particle.state, zero(WT), ancestor)
87+
end
88+
89+
log_weight(p::Particle) = p.log_w
90+
1791
"""
1892
RBState
1993
@@ -30,8 +104,6 @@ mutable struct RBState{XT,ZT}
30104
z::ZT
31105
end
32106

33-
const RBParticle{XT,ZT,WT} = Particle{RBState{XT,ZT},WT}
34-
35107
"""
36108
ParticleDistribution
37109
@@ -44,7 +116,7 @@ their ancestories) into a distibution-like object.
44116
the unnormalized logsumexp of weights before update (for standard PF/guided filters)
45117
or a modified value that includes APF first-stage correction (for auxiliary PF).
46118
"""
47-
mutable struct ParticleDistribution{WT,PT<:Particle{<:Any,WT},VT<:AbstractVector{PT}}
119+
mutable struct ParticleDistribution{WT,PT<:Particle,VT<:AbstractVector{PT}}
48120
particles::VT
49121
ll_baseline::WT
50122
end
@@ -59,10 +131,11 @@ Base.iterate(state::ParticleDistribution) = iterate(state.particles)
59131
# Not sure if this is kosher, since it doesn't follow the convention of Base.getindex
60132
Base.@propagate_inbounds Base.getindex(state::ParticleDistribution, i) = state.particles[i]
61133

134+
log_weights(state::ParticleDistribution) = map(p -> log_weight(p), state.particles)
135+
get_weights(state::ParticleDistribution) = softmax(log_weights(state))
136+
62137
# Helpers for StatsBase compatibility
63-
function StatsBase.weights(state::ParticleDistribution)
64-
return Weights(softmax(map(p -> p.log_w, state.particles)))
65-
end
138+
StatsBase.weights(state::ParticleDistribution) = StatsBase.Weights(get_weights(state))
66139

67140
"""
68141
marginalise!(state::ParticleDistribution)
@@ -78,22 +151,21 @@ cases through a single-scalar caching mechanism. For standard PF, ll_baseline eq
78151
LSE before adding observation weights. For APF with resampling, it includes first-stage
79152
correction terms computed during the APF resampling step.
80153
"""
81-
function marginalise!(state::ParticleDistribution)
154+
function marginalise!(state::ParticleDistribution, particles)
82155
# Compute logsumexp after adding observation likelihoods
83-
LSE_after = logsumexp(map(p -> p.log_w, state.particles))
156+
LSE_after = logsumexp(log_weight.(particles))
84157

85158
# Compute log-likelihood increment: works for both PF and APF cases
86159
ll_increment = LSE_after - state.ll_baseline
87160

88161
# Normalize weights
89-
for p in state.particles
162+
for p in particles
90163
p.log_w -= LSE_after
91164
end
92165

93166
# Reset baseline for next iteration
94-
state.ll_baseline = 0.0
95-
96-
return ll_increment
167+
new_state = ParticleDistribution(particles, zero(ll_increment))
168+
return new_state, ll_increment
97169
end
98170

99171
## GAUSSIAN STATES #########################################################################

0 commit comments

Comments
 (0)