Unified interface for batched filters #105
```julia
ys = cu.(ys_cpu)
```

```julia
# Hack: manually setting of initialisation for this model
function GeneralisedFilters.initialise_log_evidence(
```
This is probably the section most relevant to your work on initialisation. This is a hack that I've put in to ensure that we start from the correct ll type.
As we've discussed, it would be best to refactor the code so that we don't need to specify the ll type at this point, and instead just use whatever is returned by the first update.
There's a lot here for me to dig into. From what it looks like, it really doesn't interfere with my PR, so hopefully I can tie things together pretty nicely.
Indeed, I agree that they are largely independent. I more wanted to get this out there to demonstrate how the initialisation of the state impacts the filtering behaviour through dispatch.
In this case, because we start off with a batched Gaussian state and a CuVector for the lls, dispatch handles the rest and we get batched filtering.
But with your changes, we'd only need to worry about defining the initial state, and the ll type will just come from whatever `logdensity` outputs (which will be handled by dispatch).
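As a sketch of that refactor (function names here are illustrative, not part of the PR), the ll could simply be taken from whatever the first step returns, so no manual type annotation is needed at initialisation:

```julia
# Hypothetical: accumulate the log-evidence starting from the first step's
# return value, so its type is whatever `logdensity` (via `step`) produces:
# a Float64 on the CPU, or a CuVector for a batched state.
function filter_from_first_step(rng, model, algo, observations; kwargs...)
    state = GeneralisedFilters.initialise(rng, model, algo; kwargs...)
    state, ll = GeneralisedFilters.step(rng, model, algo, 1, state, observations[1]; kwargs...)
    for t in 2:length(observations)
        state, ll_t = GeneralisedFilters.step(rng, model, algo, t, state, observations[t]; kwargs...)
        ll += ll_t  # dispatch keeps this batched or scalar as appropriate
    end
    return state, ll
end
```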
SSMProblems.jl/SSMProblems documentation for PR #105 is available at:
SSMProblems.jl/GeneralisedFilters documentation for PR #105 is available at:
```julia
proposed_particles =
    SSMProblems.simulate.(
        Ref(rng),
        Ref(model),
        Ref(filter.proposal),
        Ref(iter),
        state.particles,
        Ref(observation),
        kwargs...,
    )
```
[JuliaFormatter] reported by reviewdog 🐶 Suggested formatting:

```julia
proposed_particles = SSMProblems.simulate.(
    Ref(rng),
    Ref(model),
    Ref(filter.proposal),
    Ref(iter),
    state.particles,
    Ref(observation),
    kwargs...,
)
```
```julia
state.log_weights .+=
    SSMProblems.logdensity.(
        Ref(model.dyn), Ref(iter), state.particles, proposed_particles, kwargs...
    )
state.log_weights .-=
    SSMProblems.logdensity.(
        Ref(model),
        Ref(filter.proposal),
        Ref(iter),
        state.particles,
        proposed_particles,
        Ref(observation);
        kwargs...,
    )
```
[JuliaFormatter] reported by reviewdog 🐶 Suggested formatting:

```julia
state.log_weights .+= SSMProblems.logdensity.(
    Ref(model.dyn), Ref(iter), state.particles, proposed_particles, kwargs...
)
state.log_weights .-= SSMProblems.logdensity.(
    Ref(model),
    Ref(filter.proposal),
    Ref(iter),
    state.particles,
    proposed_particles,
    Ref(observation);
    kwargs...,
)
```
```julia
log_increments =
    SSMProblems.logdensity.(
        Ref(model.obs), Ref(iter), state.particles, Ref(observation); kwargs...
    )
```
[JuliaFormatter] reported by reviewdog 🐶 Suggested formatting:

```julia
log_increments = SSMProblems.logdensity.(
    Ref(model.obs), Ref(iter), state.particles, Ref(observation); kwargs...
)
```
```julia
state.particles =
    SSMProblems.simulate.(
        Ref(rng), Ref(model.dyn), Ref(iter), state.particles; kwargs...
    )
```
[JuliaFormatter] reported by reviewdog 🐶 Suggested formatting:

```julia
state.particles = SSMProblems.simulate.(
    Ref(rng), Ref(model.dyn), Ref(iter), state.particles; kwargs...
)
```
Quite an exciting change in the last commit: the BF now works with both the CPU and GPU interface.

### Methodology

Rather than using a `map` block, the proposal step is now written as a broadcast:

```julia
proposed_particles =
    SSMProblems.simulate.(
        Ref(rng),
        Ref(model),
        Ref(filter.proposal),
        Ref(iter),
        state.particles,
        Ref(observation),
        kwargs...,
    )
if !isnothing(ref_state)
    proposed_particles[1] = ref_state[iter]
end
```

This makes no difference to the CPU version, but can be intercepted when using a batched type:

```julia
function Base.Broadcast.broadcasted(
    ::typeof(SSMProblems.simulate),
    rng_ref::Base.RefValue,
    model_dyn_ref::Base.RefValue,
    iter_ref::Base.RefValue,
    particles::BatchedVector;
    kwargs...,
)
    # Extract values from Ref and call non-broadcasted version
    return SSMProblems.simulate(
        rng_ref[], model_dyn_ref[], iter_ref[], particles; kwargs...
    )
end
```

I'm not thrilled that this code has to be defined specifically for this call pattern. The rest of the code is basically just defining wrappers for batched operations.

The only other important change is how sampling is called: calls to `rand` on a batched Gaussian now dispatch to

```julia
function Random.rand(::AbstractRNG, P::Gaussian{BatchedCuVector{T},<:CuMatrix{T}}) where {T}
    D, N = size(P.μ.data)
    Σ_L = cholesky(P.Σ).L
    Z = BatchedCuVector(CUDA.randn(T, D, N))
    return P.μ + CuArray(Σ_L) * Z
end
```

### Hacks and Caveats

To avoid this one PR being the size of a new package, I've had to hard-code a load of batched wrappers. For example, for batched-matrix-singleton-vector multiplication I've hardcoded:

```julia
for (fname, elty, gemv_batched) in (
    (:cublasSgemvBatched_64, :Float32, CUDA.CUBLAS.cublasSgemvBatched_64),
    (:cublasDgemvBatched_64, :Float64, CUDA.CUBLAS.cublasDgemvBatched_64),
    (:cublasCgemvBatched_64, :ComplexF32, CUDA.CUBLAS.cublasCgemvBatched_64),
    (:cublasZgemvBatched_64, :ComplexF64, CUDA.CUBLAS.cublasZgemvBatched_64),
)
    @eval begin
        function *(A::BatchedCuMatrix{$elty}, x::CuVector{$elty})
            m, n, b = size(A.data)
            y_data = CuArray{$elty}(undef, m, b)
            y = BatchedCuVector(y_data)
            # Call gemv directly
            x_ptrs = batch_singleton(x, b)
            h = CUDA.CUBLAS.handle()
            $gemv_batched(
                h, 'N', m, n, $elty(1.0), A.ptrs, m, x_ptrs, 1, $elty(0.0), y.ptrs, 1, b
            )
            return y
        end
        function *(A::CuMatrix{$elty}, x::BatchedCuVector{$elty})
            m, n = size(A)
            b = size(x.data, 2)
            y_data = CuArray{$elty}(undef, m, b)
            y = BatchedCuVector(y_data)
            # Call gemv directly
            A_ptrs = batch_singleton(A, b)
            h = CUDA.CUBLAS.handle()
            $gemv_batched(
                h, 'N', m, n, $elty(1.0), A_ptrs, m, x.ptrs, 1, $elty(0.0), y.ptrs, 1, b
            )
            return y
        end
    end
end
```

I've only done this for the combinations I need to pass the unit tests, so if you step outside of that, you'll get `MethodError`s. These all follow the same pattern, so we should be able to create some wrapper/method generator that handles all of these automatically.

The BF doesn't work with resampling yet, but will with a few small additions to the batching interface.

### Immediate Next Steps

Tomorrow I'm going to generalise this to the guided PF and the RBPF. Shouldn't be too difficult now that the main idea is in place 🤞
This is great stuff which makes the interface easier to work with and reduces quite a bit of redundancy.
I think it still needs some work with broadcasting and RBPF, but once it's more finalized I may go ahead and propose some minor optimizations for the particle filtering.
One minor thing: I wonder if this breaks Zygote compatibility. This doesn't really bother me since I stick to either Mooncake or ForwardDiff.
```julia
    T<:Real,
    μT<:Union{AbstractVector{T},BatchedVector{T}},
    ΣT<:Union{AbstractMatrix{T},BatchedMatrix{T}},
    AT<:Union{AbstractMatrix{T},BatchedMatrix{T}},
    bT<:Union{AbstractVector{T},BatchedVector{T}},
    QT<:Union{AbstractMatrix{T},BatchedMatrix{T}},
```
Types are definitely subject to change when my PR is merged, so we can keep this even more abstract.
I'm not sure what would be more abstract than this, other than `Any`. Though I'd be happy with that. That's what `Gaussian` does, which made it very easy to integrate with.
```julia
#### DISTRIBUTIONAL OPERATIONS ####
###################################

# Can likely replace by using Gaussian
```
I was actually wondering if we could remove the dependency on GaussianDistributions.jl and just write our own. The package is super lightweight (and also unmaintained), so it shouldn't be more than 150 lines of code.
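A minimal in-house replacement could look something like this sketch (the field names `μ`/`Σ` follow GaussianDistributions.jl; everything else is illustrative):

```julia
using Statistics

# Sketch of a self-contained Gaussian type, mirroring the untyped-field
# permissiveness of GaussianDistributions.jl that made batched states easy
# to slot in. Sampling, logpdf, etc. would be added on top of this.
struct Gaussian{Tμ,TΣ}
    μ::Tμ
    Σ::TΣ
end

Statistics.mean(P::Gaussian) = P.μ
Statistics.cov(P::Gaussian) = P.Σ

Base.:(==)(P::Gaussian, Q::Gaussian) = P.μ == Q.μ && P.Σ == Q.Σ
```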
```diff
@@ -233,3 +241,45 @@ function filter(
 )
     return filter(rng, ssm, algo, observations; ref_state=ref_state, kwargs...)
 end
+
+# Broadcast wrapper for batched types
+# TODO: this can likely be replaced with a broadcast style
```
I want to see this implemented before I go ahead and merge anything
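A rough shape for that broadcast-style approach (a sketch under the assumption that a batched argument should take over the whole broadcast; `BatchedVector` follows the PR, everything else is illustrative and untested against its types):

```julia
# Hypothetical: give BatchedVector its own BroadcastStyle so every broadcast
# over it funnels through one `copy` method, instead of defining
# `Base.Broadcast.broadcasted` per function as the PR currently does.
struct BatchedStyle <: Base.Broadcast.BroadcastStyle end

Base.Broadcast.BroadcastStyle(::Type{<:BatchedVector}) = BatchedStyle()
# Batched style wins against any other style in mixed broadcasts
Base.Broadcast.BroadcastStyle(s::BatchedStyle, ::Base.Broadcast.BroadcastStyle) = s

function Base.copy(bc::Base.Broadcast.Broadcasted{BatchedStyle})
    # Unwrap Refs and call the non-broadcasted, batched implementation directly
    args = map(a -> a isa Base.RefValue ? a[] : a, bc.args)
    return bc.f(args...)
end
```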
```diff
@@ -44,89 +44,21 @@ function update(
     # Update state
     m = H * μ + c
     y = observation - m
-    S = hermitianpart(H * Σ * H' + R)
-    K = Σ * H' / S
+    S = H * Σ * transpose(H) + R
```
Any reason for replacing `x'` with `transpose(x)`? For real numbers it should be the same thing.
You'd need to define `'` or `adjoint` on a `BatchedVector`, I think.
This was just to do with how I defined my CuBLAS wrappers. I will revert these back to adjoints once I have the full set of wrappers.
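In the meantime, for real element types one stopgap (a sketch assuming the PR's `BatchedCuMatrix`/`BatchedCuVector` types and their existing `transpose` wrappers) would be to forward `adjoint` to `transpose`:

```julia
# Sketch only: for real eltypes adjoint and transpose coincide, so the
# existing Transpose-based CuBLAS wrappers can back both until proper
# adjoint wrappers are written. Complex eltypes must NOT take this path,
# since adjoint also conjugates.
Base.adjoint(A::BatchedCuMatrix{T}) where {T<:Real} = transpose(A)
Base.adjoint(x::BatchedCuVector{T}) where {T<:Real} = transpose(x)
```

This would let the KF keep using `'` without waiting for the full wrapper set.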
```julia
state.log_weights .+=
    SSMProblems.logdensity.(
        Ref(model.dyn), Ref(iter), state.particles, proposed_particles, kwargs...
    )
state.log_weights .-=
    SSMProblems.logdensity.(
        Ref(model),
        Ref(filter.proposal),
        Ref(iter),
        state.particles,
        proposed_particles,
        Ref(observation);
        kwargs...,
    )
```
Is there slowness induced by the need for an additional loop?

I liked the `map` block for a couple of reasons: (1) no need for an additional loop, and (2) the code contains far fewer `getindex` calls, only relying on `setindex` for the RBPF.
> Is there slowness induced by the need for an additional loop?

Yeah, and it's actually fairly substantial in the batched case. I think we can get around this by just having a function that does both log-density calculations in one and broadcasting over that instead.
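That fused function might look something like this sketch (the name `log_weight_increment` is illustrative, not part of the PR):

```julia
# Hypothetical fusion of the two logdensity broadcasts into a single
# per-particle function, so the batched path only crosses the broadcast
# machinery once.
function log_weight_increment(
    model, proposal, iter, particle, proposed, observation; kwargs...
)
    return SSMProblems.logdensity(model.dyn, iter, particle, proposed; kwargs...) -
           SSMProblems.logdensity(
               model, proposal, iter, particle, proposed, observation; kwargs...
           )
end

# One broadcast instead of two:
state.log_weights .+= log_weight_increment.(
    Ref(model), Ref(filter.proposal), Ref(iter),
    state.particles, proposed_particles, Ref(observation),
)
```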
This PR aims to solve #47 by handling batched filtering through dispatch rather than defining replacement algorithms.

The core idea of the PR is already there and demonstrated in the `batch_kalman_test.jl` file. Note that `BatchKalmanFilter` is now gone, which is great for reducing code duplication! Instead, the batched behaviour is taken care of through special batch types that use dispatch to create wrappers around CuBLAS and CuSolver functions. These types and wrappers are included in this package for now for ease, but the goal is to separate them out into their own package in the near future, since they have far wider use cases.

I still have a few more things to do:

I would like to stress that this PR only implements the bare basics of this interface. That is, it is not necessarily meant to be fast, clean, or feature-complete, but rather just to define the interface. These improvements are not really the responsibility of GeneralisedFilters, but rather the batching interface.

The main improvements to be made (in a future PR) are:

- Following how `LinearAlgebra` handles these operations (the current approach is fairly hacked together based around our current needs). Notably, I haven't defined `Adjoint`, just `Transpose`, and so the KF had to be changed to use `transpose` for now.

Finally, I initially planned to have `BatchedCuMatrix` be a subtype of `AbstractMatrix`, since that's kind of how it behaves. I think this might not be possible, however, as `AbstractMatrix` defines specific behaviour for how matrices get displayed in the REPL, and I doubt our batched matrices can conform to this interface (which leads to errors when printed). For now, I'm using `Union{AbstractMatrix, BatchedMatrix}` in type signatures.

Thoughts and feedback would be greatly appreciated.
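As a concrete illustration of the `Union` signature approach (a sketch; `predict_mean` and the aliases are hypothetical, only the `Union` pattern itself is from the PR):

```julia
# Sketch: one method accepts both plain and batched inputs; the operations
# inside (`*`, `+`) dispatch to either regular LinearAlgebra or the batched
# CuBLAS wrappers depending on the argument types.
const MaybeBatchedMatrix{T} = Union{AbstractMatrix{T},BatchedMatrix{T}}
const MaybeBatchedVector{T} = Union{AbstractVector{T},BatchedVector{T}}

function predict_mean(
    A::MaybeBatchedMatrix{T}, μ::MaybeBatchedVector{T}, b::MaybeBatchedVector{T}
) where {T}
    return A * μ + b  # batched gemv when inputs are batched, BLAS otherwise
end
```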