@@ -46,14 +46,7 @@ function initialise(
4646 initialise_particle (rng, prior, algo, ref; kwargs... )
4747 end
4848
49- # Set equal weights: log_w = -log(N) so weights sum to 1
50- log_weight = - log (N)
51- for p in particles
52- p. log_w = log_weight
53- end
54-
55- # Initialize with ll_baseline = 0.0
56- return ParticleDistribution (particles, 0.0 )
49+ return ParticleDistribution (particles, TypelessZero ())
5750end
5851
5952function predict (
@@ -66,20 +59,18 @@ function predict(
6659 ref_state:: Union{Nothing,AbstractVector} = nothing ,
6760 kwargs... ,
6861)
69- particles = map (1 : length (state . particles )) do i
62+ particles = map (1 : num_particles (algo )) do i
7063 particle = state. particles[i]
7164 ref = ! isnothing (ref_state) && i == 1 ? ref_state[iter] : nothing
7265 predict_particle (rng, dyn, algo, iter, particle, observation, ref; kwargs... )
7366 end
74- state. particles = particles
7567
7668 # Accumulate the baseline with LSE of weights after prediction (before update)
7769 # For plain PF/guided: ll_baseline is 0.0 on entry, becomes LSE_before
7870 # For APF with resample: ll_baseline already stores negative correction; add LSE_before
79- LSE_before = logsumexp (map (p -> p. log_w, state. particles))
80- state. ll_baseline += LSE_before
81-
82- return state
71+ return ParticleDistribution (
72+ particles, logsumexp (log_weights (state)) + state. ll_baseline
73+ )
8374end
8475
8576function update (
@@ -93,10 +84,9 @@ function update(
9384 particles = map (state. particles) do particle
9485 update_particle (obs, algo, iter, particle, observation; kwargs... )
9586 end
96- state. particles = particles
97- ll_increment = marginalise! (state)
87+ new_state, ll_increment = marginalise! (state, particles)
9888
99- return state , ll_increment
89+ return new_state , ll_increment
10090end
10191
10292struct ParticleFilter{RS,PT} <: AbstractParticleFilter
@@ -121,8 +111,7 @@ function initialise_particle(
121111 rng:: AbstractRNG , prior:: StatePrior , algo:: ParticleFilter , ref_state; kwargs...
122112)
123113 x = sample_prior (rng, prior, algo, ref_state; kwargs... )
124- # TODO (RB): determine the correct type for the log_w field or use a NoWeight type
125- return Particle (x, 0.0 , 0 )
114+ return Particle (x, 0 )
126115end
127116
128117function predict_particle (
@@ -135,10 +124,10 @@ function predict_particle(
135124 ref_state;
136125 kwargs... ,
137126)
138- new_x, logw_inc = propogate (
127+ new_x, log_increment = propogate (
139128 rng, dyn, algo, iter, particle. state, observation, ref_state; kwargs...
140129 )
141- return Particle (new_x, particle. log_w + logw_inc , particle. ancestor)
130+ return Particle (new_x, log_weight ( particle) + log_increment , particle. ancestor)
142131end
143132
144133function update_particle (
@@ -152,7 +141,7 @@ function update_particle(
152141 log_increment = SSMProblems. logdensity (
153142 obs, iter, particle. state, observation; kwargs...
154143 )
155- return Particle (particle. state, particle. log_w + log_increment, particle. ancestor)
144+ return Particle (particle. state, log_weight ( particle) + log_increment, particle. ancestor)
156145end
157146
158147function step (
@@ -252,10 +241,7 @@ function propogate(
252241 ref_state
253242 end
254243
255- # TODO : make this type consistent
256- # Will have to do a lazy zero or change propogate to accept a particle (in which case
257- # we'll need to construct a particle in the RBPF predict method)
258- return new_x, 0.0
244+ return new_x, TypelessZero ()
259245end
260246
261247# TODO : I feel like we shouldn't need to do this conversion. It should be handled by dispatch
@@ -353,7 +339,7 @@ function step(
353339 kwargs... ,
354340)
355341 # Compute lookahead weights approximating log p(y_{t+1} | x_{t}^(i))
356- log_ξs = map (state. particles) do particle
342+ log_ηs = map (state. particles) do particle
357343 compute_logeta (
358344 rng,
359345 algo. weight_strategy,
@@ -366,40 +352,8 @@ function step(
366352 )
367353 end
368354
369- # Log normalizer for current weights
370- LSE_w = logsumexp (map (p -> p. log_w, state. particles))
371-
372- # Incorporate lookahead weights into current weights
373- for (i, particle) in enumerate (state. particles)
374- particle. log_w += log_ξs[i]
375- end
376-
377- # Compute lookahead evidence
378- LSE_lookahead = logsumexp (map (p -> p. log_w, state. particles)) - LSE_w
379-
380- resample_flag = will_resample (resampler (algo), state)
381- if resample_flag
382- state = resample (rng, resampler (algo), state; ref_state)
383- else
384- # Not resampling: preserve ll_baseline and set ancestors to self
385- n = length (state. particles)
386- new_particles = similar (state. particles)
387- for i in 1 : n
388- new_particles[i] = set_ancestor (state. particles[i], i)
389- end
390- state = ParticleDistribution (new_particles, state. ll_baseline)
391- end
392-
393- # Compensate for lookahead weights
394- for particle in state. particles
395- particle. log_w -= log_ξs[particle. ancestor]
396- end
397-
398- # Compute compensation log normalizer
399- if resample_flag
400- LSE_comp = logsumexp (map (p -> p. log_w, state. particles))
401- state. ll_baseline = - (LSE_lookahead + (LSE_comp - log (num_particles (algo))))
402- end
355+ rs = AuxiliaryResampler (resampler (algo), log_ηs)
356+ state = maybe_resample (rng, rs, state; ref_state)
403357
404358 callback (model, algo, iter, state, observation, PostResample; kwargs... )
405359 return move (
0 commit comments