-
Notifications
You must be signed in to change notification settings - Fork 3
Unified interface for batched filters #105
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
2917471
646f4de
70d2056
dec6b20
dcd2f4b
c95d3a6
ae9b3c1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -122,26 +122,34 @@ function predict( | |||||||||||||||||||||||||||||||||||||||||||||||||||||
ref_state::Union{Nothing,AbstractVector}=nothing, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
kwargs..., | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
proposed_particles = map(enumerate(state.particles)) do (i, particle) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
if !isnothing(ref_state) && i == 1 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
ref_state[iter] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
else | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
simulate(rng, model, filter.proposal, iter, particle, observation; kwargs...) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
proposed_particles = | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
SSMProblems.simulate.( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
Ref(rng), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
Ref(model), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
Ref(filter.proposal), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
Ref(iter), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
state.particles, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
Ref(observation), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
kwargs..., | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+125
to
+134
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
if !isnothing(ref_state) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
proposed_particles[1] = ref_state[iter] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
state.log_weights += | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
map(zip(proposed_particles, state.particles)) do (new_state, prev_state) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
log_f = SSMProblems.logdensity( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
model.dyn, iter, prev_state, new_state; kwargs... | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
log_q = SSMProblems.logdensity( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
model, filter.proposal, iter, prev_state, new_state, observation; kwargs... | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
(log_f - log_q) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
state.log_weights .+= | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
SSMProblems.logdensity.( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
Ref(model.dyn), Ref(iter), state.particles, proposed_particles, kwargs... | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
state.log_weights .-= | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
SSMProblems.logdensity.( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
Ref(model), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
Ref(filter.proposal), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
Ref(iter), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
state.particles, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
proposed_particles, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
Ref(observation); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
kwargs..., | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+139
to
+152
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
Comment on lines
+139
to
+152
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is there slowness induced by the need for an additional loop? I liked the map block for a couple reasons (1) no need for an additional loop and (2) the code contains far fewer There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Yeah, and it's actually fairly substantial in the batched case. I think we can get around this by just having a function that does both log density calculations in one and broadcasting over that instead. |
||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
state.particles = proposed_particles | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -156,10 +164,10 @@ function update( | |||||||||||||||||||||||||||||||||||||||||||||||||||||
observation; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
kwargs..., | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
) where {T} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
log_increments = map( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
x -> SSMProblems.logdensity(model.obs, iter, x, observation; kwargs...), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
state.particles, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
log_increments = | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
SSMProblems.logdensity.( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
Ref(model.obs), Ref(iter), state.particles, Ref(observation); kwargs... | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+167
to
+170
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
state.log_weights += log_increments | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -207,12 +215,12 @@ function predict( | |||||||||||||||||||||||||||||||||||||||||||||||||||||
ref_state::Union{Nothing,AbstractVector}=nothing, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
kwargs..., | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
state.particles = map(enumerate(state.particles)) do (i, particle) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
if !isnothing(ref_state) && i == 1 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
ref_state[iter] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
else | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
SSMProblems.simulate(rng, model.dyn, iter, particle; kwargs...) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
state.particles = | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
SSMProblems.simulate.( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
Ref(rng), Ref(model.dyn), Ref(iter), state.particles; kwargs... | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+218
to
+221
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
if !isnothing(ref_state) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
state.particles[1] = ref_state[iter] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
return state | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -233,3 +241,45 @@ function filter( | |||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
return filter(rng, ssm, algo, observations; ref_state=ref_state, kwargs...) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Broadcast wrapper for batched types | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
# TODO: this can likely be replaced with a broadcast style | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I want to see this implemented before I go ahead and merge anything |
||||||||||||||||||||||||||||||||||||||||||||||||||||||
function Base.Broadcast.broadcasted( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
::typeof(SSMProblems.simulate), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
rng_ref::Base.RefValue, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
model_dyn_ref::Base.RefValue, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
iter_ref::Base.RefValue, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
particles::BatchedVector; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
kwargs..., | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Extract values from Ref and call non-broadcasted version | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
return SSMProblems.simulate( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
rng_ref[], model_dyn_ref[], iter_ref[], particles; kwargs... | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
function Base.Broadcast.broadcasted( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
::typeof(SSMProblems.logdensity), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
model_obs_ref::Base.RefValue, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
iter_ref::Base.RefValue, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
particles::BatchedVector, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
observation::Base.RefValue; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
kwargs..., | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Extract values from Ref and call non-broadcasted version | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
return SSMProblems.logdensity( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
model_obs_ref[], iter_ref[], particles, observation[]; kwargs... | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
function Base.Broadcast.broadcasted( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
::typeof(SSMProblems.logdensity), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
model_dyn_ref::Base.RefValue, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
iter_ref::Base.RefValue, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
prev_particles::BatchedVector, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
new_particles::BatchedVector; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
kwargs..., | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Extract values from Ref and call non-broadcasted version | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
return SSMProblems.logdensity( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
model_dyn_ref[], iter_ref[], prev_particles, new_particles; kwargs... | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Any reason for replacing
x'
withtranspose(x)
? For real numbers it should be the same thingUh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You'd need to define
'
oradjoint
on anBatchedVector
I thinkThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This was just to do with how I defined my CuBLAS wrappers. I will revert these back to adjoints once I have the full set of wrappers.