Skip to content

Commit 29152da

Browse files
new interface, fix pre-sampling
1 parent d0ab700 commit 29152da

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

src/algorithms/apf.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
export AuxiliaryParticleFilter, APF
22

3-
struct AuxiliaryParticleFilter{RS<:AbstractConditionalResampler} <: AbstractFilter
3+
mutable struct AuxiliaryParticleFilter{RS<:AbstractConditionalResampler} <: AbstractFilter
44
N::Integer
55
resampler::RS
66
aux::Vector # Auxiliary weights
77
end
88

99
function AuxiliaryParticleFilter(
10-
N::Integer, threshold::Real=1.0, resampler::AbstractResampler=Systematic()
10+
N::Integer; threshold::Real=1.0, resampler::AbstractResampler=Systematic()
1111
)
1212
conditional_resampler = ESSResampler(threshold, resampler)
1313
return AuxiliaryParticleFilter(N, conditional_resampler, zeros(N))
@@ -25,7 +25,7 @@ function initialise(
2525
initial_states = map(x -> SSMProblems.simulate(rng, model.dyn; kwargs...), 1:(filter.N))
2626
initial_weights = fill(-log(T(filter.N)), filter.N)
2727

28-
return update_ref!(ParticleContainer(initial_states, initial_weights), ref_state)
28+
return update_ref!(ParticleContainer(initial_states, initial_weights), ref_state, filter)
2929
end
3030

3131
function update_weights!(
@@ -57,16 +57,16 @@ function predict(
5757
auxiliary_weights = map(
5858
x -> SSMProblems.logdensity(model.obs, step, x, observation; kwargs...), predicted
5959
)
60-
state.filtered.log_weights .+= auxiliary_weights
60+
states.filtered.log_weights .+= auxiliary_weights
6161
filter.aux = auxiliary_weights
6262

63-
states.proposed = resample(rng, filter.resampler, states.filtered, filter)
63+
states.proposed, states.ancestors = resample(rng, filter.resampler, states.filtered, filter)
6464
states.proposed.particles = map(
6565
x -> SSMProblems.simulate(rng, model.dyn, step, x; kwargs...),
6666
states.proposed.particles,
6767
)
6868

69-
return update_ref!(states, ref_state, step)
69+
return update_ref!(states, ref_state, filter, step)
7070
end
7171

7272
function update(

0 commit comments

Comments
 (0)