-
Notifications
You must be signed in to change notification settings - Fork 4
Restore Type Stability #122
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 9 commits
d9158db
3dbb6ad
ddf6c1d
30783aa
794d81a
e3ccaa4
e61acae
95f88ab
63ba171
adcd8a4
88c484c
02f76d4
17170df
adb795e
d917450
3591631
709a5f7
c06b310
9a866f6
5aa67f9
1ea3c46
d3fe836
8ba55c2
2e90f39
98b6c2a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,18 +2,32 @@ | |
|
|
||
| ## PARTICLES ############################################################################### | ||
|
|
||
| abstract type AbstractParticle{T} end | ||
|
|
||
| """ | ||
| Particle | ||
|
|
||
| A container representing a single particle in a particle filter distribution, composed of a | ||
| weighted sampled (stored as a log weight) and its ancestor index. | ||
| """ | ||
| mutable struct Particle{ST,WT,AT<:Integer} | ||
| mutable struct Particle{ST,WT,AT<:Integer} <: AbstractParticle{ST} | ||
| state::ST | ||
| log_w::WT | ||
| ancestor::AT | ||
| end | ||
|
|
||
| # NOTE: this is only ever used for initializing a particle filter | ||
| const UnweightedParticle{ST,AT} = Particle{ST,Nothing,AT} | ||
|
|
||
| Particle(state, ancestor) = Particle(state, nothing, ancestor) | ||
| Particle(particle::UnweightedParticle, ancestor) = Particle(particle.state, ancestor) | ||
| function Particle(particle::Particle{<:Any,WT}, ancestor) where {WT<:Real} | ||
| return Particle(particle.state, zero(WT), ancestor) | ||
| end | ||
|
|
||
| log_weight(p::Particle{<:Any,<:Real}) = p.log_w | ||
| log_weight(::UnweightedParticle) = false | ||
charlesknipp marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| """ | ||
| RBState | ||
|
|
||
|
|
@@ -30,8 +44,6 @@ mutable struct RBState{XT,ZT} | |
| z::ZT | ||
| end | ||
|
|
||
| const RBParticle{XT,ZT,WT} = Particle{RBState{XT,ZT},WT} | ||
|
|
||
| """ | ||
| ParticleDistribution | ||
|
|
||
|
|
@@ -44,7 +56,7 @@ their ancestories) into a distibution-like object. | |
| the unnormalized logsumexp of weights before update (for standard PF/guided filters) | ||
| or a modified value that includes APF first-stage correction (for auxiliary PF). | ||
| """ | ||
| mutable struct ParticleDistribution{WT,PT<:Particle{<:Any,WT},VT<:AbstractVector{PT}} | ||
| mutable struct ParticleDistribution{WT,PT<:AbstractParticle,VT<:AbstractVector{PT}} | ||
charlesknipp marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| particles::VT | ||
| ll_baseline::WT | ||
| end | ||
|
|
@@ -59,10 +71,11 @@ Base.iterate(state::ParticleDistribution) = iterate(state.particles) | |
| # Not sure if this is kosher, since it doesn't follow the convention of Base.getindex | ||
| Base.@propagate_inbounds Base.getindex(state::ParticleDistribution, i) = state.particles[i] | ||
|
|
||
| log_weights(state::ParticleDistribution) = map(p -> log_weight(p), state.particles) | ||
| get_weights(state::ParticleDistribution) = softmax(log_weights(state)) | ||
|
|
||
| # Helpers for StatsBase compatibility | ||
| function StatsBase.weights(state::ParticleDistribution) | ||
| return Weights(softmax(map(p -> p.log_w, state.particles))) | ||
| end | ||
| StatsBase.weights(state::ParticleDistribution) = StatsBase.Weights(get_weights(state)) | ||
|
|
||
| """ | ||
| marginalise!(state::ParticleDistribution) | ||
|
|
@@ -78,22 +91,21 @@ cases through a single-scalar caching mechanism. For standard PF, ll_baseline eq | |
| LSE before adding observation weights. For APF with resampling, it includes first-stage | ||
| correction terms computed during the APF resampling step. | ||
| """ | ||
| function marginalise!(state::ParticleDistribution) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is there a reason for passing in both |
||
| function marginalise!(state::ParticleDistribution, particles) | ||
| # Compute logsumexp after adding observation likelihoods | ||
| LSE_after = logsumexp(map(p -> p.log_w, state.particles)) | ||
| LSE_after = logsumexp(log_weight.(particles)) | ||
|
|
||
| # Compute log-likelihood increment: works for both PF and APF cases | ||
| ll_increment = LSE_after - state.ll_baseline | ||
|
|
||
| # Normalize weights | ||
| for p in state.particles | ||
| for p in particles | ||
| p.log_w -= LSE_after | ||
| end | ||
|
Comment on lines
+162
to
164
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. My only issue with this entire PR is this single line of code. If we could get rid of this, we no longer have to rely on the mutation of
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I might be on the complete wrong page, but what's stopping us from just constructing new particles using the old weights - LSE_after, rather than mutating? |
||
|
|
||
| # Reset baseline for next iteration | ||
| state.ll_baseline = 0.0 | ||
|
|
||
| return ll_increment | ||
| new_state = ParticleDistribution(particles, zero(ll_increment)) | ||
THargreaves marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| return new_state, ll_increment | ||
| end | ||
|
|
||
| ## GAUSSIAN STATES ######################################################################### | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.