
Unified interface for batched filters #105


Status: Open · wants to merge 7 commits into main

Conversation

@THargreaves (Collaborator) commented Jun 6, 2025

This PR aims to solve #47 by handling batched filtering through dispatch rather than defining replacement algorithms.

The core idea of the PR is already there and demonstrated in the batch_kalman_test.jl file. Note that BatchKalmanFilter is now gone, which is great for reducing code duplication!

Instead, the batched behaviour is handled through special batch types that use dispatch to create wrappers around CuBLAS and CuSolver functions. These types and wrappers are included in this package for now for convenience, but the goal is to separate them out into their own package in the near future, since they have far wider use cases.
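
For concreteness, here is a minimal sketch of what such a batch type might look like, reconstructed from how the code later in this thread uses the .data and .ptrs fields; the PR's actual definition may differ:

using CUDA

# Hypothetical sketch: a batch of b m×n matrices stored as one m×n×b CuArray,
# plus a device-side array of per-slice pointers for pointer-based CUBLAS calls.
struct BatchedCuMatrix{T}
    data::CuArray{T,3}
    ptrs::CuVector{CuPtr{T}}
end

function BatchedCuMatrix(data::CuArray{T,3}) where {T}
    m, n, b = size(data)
    # Pointer to the first element of each m×n slice
    ptrs = CuArray([pointer(data, (i - 1) * m * n + 1) for i in 1:b])
    return BatchedCuMatrix{T}(data, ptrs)
end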

I still have a few more things to do:

  • Update resampling to follow the batched interface
  • Integrate these changes with the RBPF
  • Remove legacy batching code
  • Create a CPU batched type using static arrays
  • Create a MAGMA wrapper and use LocalPreferences to determine which backend to use (see the sketch after this list)
  • Bring in my fused CUDA kernels (again using LocalPreferences to enable)
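
As mentioned in the list above, backend selection could run through LocalPreferences. A sketch of how that switch might look (the key name and file layout are assumptions, not from the PR):

using Preferences

# Hypothetical backend switch via Preferences.jl / LocalPreferences.toml
const BATCHED_BACKEND = @load_preference("batched_backend", "cublas")

@static if BATCHED_BACKEND == "magma"
    include("backends/magma.jl")
else
    include("backends/cublas.jl")
end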

I would like to stress that this PR only implements the bare basics of this interface. That is, it is not necessarily meant to be fast, clean, or feature-complete, but rather just to define the interface. Most of the remaining improvements are not really the responsibility of GeneralisedFilters, but rather of the batching interface.

The main improvements to be made (in a future PR) are:

  • Tighten up the transposition/factorization/adjoint interface, basing the approach on how LinearAlgebra handles these. The current approach is fairly hacked together around our current needs; notably, I haven't defined Adjoint, just Transpose, so the KF had to be changed to use transpose for now.
  • Automatically generate methods for mixed batched + singular operations (currently just manually added the ones we use)
  • Maybe make the pointer creation lazy, since the pointer arrays aren't needed for "strided" kernels
  • Error checking (e.g. dims, batch count) is currently handled by the low-level CUDA functions, which aren't always the clearest.

Finally, I initially planned to have BatchedCuMatrix be a subtype of AbstractMatrix, since that's roughly how it behaves. However, I think this might not be possible, as AbstractMatrix defines specific behaviour for how matrices are displayed in the REPL, and I doubt our batched matrices can conform to this interface (which leads to errors when printed). For now, I'm using Union{AbstractMatrix, BatchedMatrix} in type signatures.

Thoughts and feedback would be greatly appreciated.

ys = cu.(ys_cpu)

# Hack: manually setting of initialisation for this model
function GeneralisedFilters.initialise_log_evidence(
@THargreaves (Collaborator, Author):

@charlesknipp

This is probably the section most relevant to your work on initialisation. This is a hack that I've put in to ensure that we start from the correct ll type.

As we've discussed, it would be best to refactor the code so that we don't need to specify the ll type at this point, and instead just use whatever is returned by the first update.
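
To illustrate the shape of that refactor (the function names here are assumed for illustration, not the package's API), the driver loop could take its ll type from whatever the first update returns:

# Illustrative only: the log-likelihood type is inferred from the first update
state = initialise(rng, model, algo; kwargs...)
state, ll = step(rng, model, algo, 1, state, observations[1]; kwargs...)
for t in 2:length(observations)
    state, ll_t = step(rng, model, algo, t, state, observations[t]; kwargs...)
    ll += ll_t
end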

@charlesknipp (Member):

There's a lot here for me to dig into. From what it looks like, it really doesn't interfere with my PR, so hopefully I can tie things together pretty nicely.

@THargreaves (Collaborator, Author):

Indeed, I agree that they are largely independent. I more wanted to get this out there to demonstrate how the initialisation of the state impacts the filtering behaviour through dispatch.

In this case, because we start off with a batched Gaussian state and a CuVector for the lls, dispatch handles the rest and we get batched filtering.

But with your changes, we'd only need to worry about defining the initial state, and the ll type will just come from whatever logdensity returns (which will be handled by dispatch).
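
To make the dispatch point concrete, a sketch of such an initial state (type names from this PR, values hypothetical; the shared CuMatrix covariance matches the rand method shown later in this thread):

using CUDA, GaussianDistributions, LinearAlgebra

D, N = 2, 1024                                   # state dimension, batch size
μ0 = BatchedCuVector(CUDA.zeros(Float32, D, N))  # one mean per batch element
Σ0 = CuArray(Matrix{Float32}(I, D, D))           # covariance shared across the batch
state = Gaussian(μ0, Σ0)                         # batched Gaussian state
lls = CUDA.zeros(Float32, N)                     # batched log-evidence accumulator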

github-actions bot (Contributor) commented Jun 6, 2025

SSMProblems.jl/SSMProblems documentation for PR #105 is available at:
https://TuringLang.github.io/SSMProblems.jl/SSMProblems/previews/PR105/

github-actions bot (Contributor) commented Jun 6, 2025

SSMProblems.jl/GeneralisedFilters documentation for PR #105 is available at:
https://TuringLang.github.io/SSMProblems.jl/GeneralisedFilters/previews/PR105/

Comment on lines +125 to +134
proposed_particles =
SSMProblems.simulate.(
Ref(rng),
Ref(model),
Ref(filter.proposal),
Ref(iter),
state.particles,
Ref(observation),
kwargs...,
)
[JuliaFormatter] reported by reviewdog 🐶

Suggested change
proposed_particles =
SSMProblems.simulate.(
Ref(rng),
Ref(model),
Ref(filter.proposal),
Ref(iter),
state.particles,
Ref(observation),
kwargs...,
)
proposed_particles = SSMProblems.simulate.(
Ref(rng),
Ref(model),
Ref(filter.proposal),
Ref(iter),
state.particles,
Ref(observation),
kwargs...,
)

Comment on lines +139 to +152
state.log_weights .+=
SSMProblems.logdensity.(
Ref(model.dyn), Ref(iter), state.particles, proposed_particles, kwargs...
)
state.log_weights .-=
SSMProblems.logdensity.(
Ref(model),
Ref(filter.proposal),
Ref(iter),
state.particles,
proposed_particles,
Ref(observation);
kwargs...,
)
[JuliaFormatter] reported by reviewdog 🐶

Suggested change
state.log_weights .+=
SSMProblems.logdensity.(
Ref(model.dyn), Ref(iter), state.particles, proposed_particles, kwargs...
)
state.log_weights .-=
SSMProblems.logdensity.(
Ref(model),
Ref(filter.proposal),
Ref(iter),
state.particles,
proposed_particles,
Ref(observation);
kwargs...,
)
state.log_weights .+= SSMProblems.logdensity.(
Ref(model.dyn), Ref(iter), state.particles, proposed_particles, kwargs...
)
state.log_weights .-= SSMProblems.logdensity.(
Ref(model),
Ref(filter.proposal),
Ref(iter),
state.particles,
proposed_particles,
Ref(observation);
kwargs...,
)

Comment on lines +167 to +170
log_increments =
SSMProblems.logdensity.(
Ref(model.obs), Ref(iter), state.particles, Ref(observation); kwargs...
)
[JuliaFormatter] reported by reviewdog 🐶

Suggested change
log_increments =
SSMProblems.logdensity.(
Ref(model.obs), Ref(iter), state.particles, Ref(observation); kwargs...
)
log_increments = SSMProblems.logdensity.(
Ref(model.obs), Ref(iter), state.particles, Ref(observation); kwargs...
)

Comment on lines +218 to +221
state.particles =
SSMProblems.simulate.(
Ref(rng), Ref(model.dyn), Ref(iter), state.particles; kwargs...
)
[JuliaFormatter] reported by reviewdog 🐶

Suggested change
state.particles =
SSMProblems.simulate.(
Ref(rng), Ref(model.dyn), Ref(iter), state.particles; kwargs...
)
state.particles = SSMProblems.simulate.(
Ref(rng), Ref(model.dyn), Ref(iter), state.particles; kwargs...
)

Comment on lines +7 to +8


[JuliaFormatter] reported by reviewdog 🐶

Suggested change: remove the extra blank lines.
@THargreaves (Collaborator, Author) commented Jun 11, 2025

Quite an exciting change in the last commit: the BF now works with both the CPU and GPU interfaces.

Methodology

Rather than using a map inside the BF for propagation and log-density calculations, we use broadcasting:

    proposed_particles =
        SSMProblems.simulate.(
            Ref(rng),
            Ref(model),
            Ref(filter.proposal),
            Ref(iter),
            state.particles,
            Ref(observation),
            kwargs...,
        )
    if !isnothing(ref_state)
        proposed_particles[1] = ref_state[iter]
    end

This makes no difference to the CPU version, but can be intercepted when using a batched type.

function Base.Broadcast.broadcasted(
    ::typeof(SSMProblems.simulate),
    rng_ref::Base.RefValue,
    model_dyn_ref::Base.RefValue,
    iter_ref::Base.RefValue,
    particles::BatchedVector;
    kwargs...,
)
    # Extract values from Ref and call non-broadcasted version
    return SSMProblems.simulate(
        rng_ref[], model_dyn_ref[], iter_ref[], particles; kwargs...
    )
end

I'm not thrilled that this code has to be defined specifically for the simulate and logdensity functions, but it works for now. I think we can get around this by using a broadcast style, which should hopefully define this behaviour for all functions at once.
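
For reference, a rough sketch of the broadcast-style idea (hypothetical and untested; not part of this PR yet):

# Give batched types their own broadcast style, then route every broadcast
# over them to the plain (non-broadcasted) call in a single place.
struct BatchedStyle <: Base.Broadcast.BroadcastStyle end
Base.Broadcast.BroadcastStyle(::Type{<:BatchedVector}) = BatchedStyle()
Base.Broadcast.BroadcastStyle(s::BatchedStyle, ::Base.Broadcast.BroadcastStyle) = s

# Skip axis checks, since batched containers aren't ordinary arrays
Base.Broadcast.instantiate(bc::Base.Broadcast.Broadcasted{BatchedStyle}) = bc

function Base.copy(bc::Base.Broadcast.Broadcasted{BatchedStyle})
    # Unwrap Refs and call the underlying function once on the whole batch
    args = map(a -> a isa Base.RefValue ? a[] : a, bc.args)
    return bc.f(args...)
end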

The rest of the code is basically just defining wrappers for batched operations. The only other important change is that I've swapped out calls to MvNormal for Gaussian, since the latter doesn't place any type restrictions on the mean and covariance (so we can use a batched type).

Calls to rand and logpdf on the CPU are handled by GaussianDistributions itself, and we can add methods for batched arrays. I've currently hard-coded these for the ones we need, but it shouldn't be difficult to do this in a generic way.

function Random.rand(::AbstractRNG, P::Gaussian{BatchedCuVector{T},<:CuMatrix{T}}) where {T}
    # NB: the rng argument is ignored; CUDA.randn uses CUDA's default RNG
    D, N = size(P.μ.data)                     # state dimension and batch size
    Σ_L = cholesky(P.Σ).L                     # covariance factor shared across the batch
    Z = BatchedCuVector(CUDA.randn(T, D, N))  # one standard-normal draw per batch element
    return P.μ + CuArray(Σ_L) * Z
end

Hacks and Caveats

To avoid this one PR being the size of a new package, I've had to hard-code a load of batched wrappers. For example, for batched-matrix by singleton-vector multiplication I've hard-coded:

for (elty, gemv_batched) in (
    (:Float32, CUDA.CUBLAS.cublasSgemvBatched_64),
    (:Float64, CUDA.CUBLAS.cublasDgemvBatched_64),
    (:ComplexF32, CUDA.CUBLAS.cublasCgemvBatched_64),
    (:ComplexF64, CUDA.CUBLAS.cublasZgemvBatched_64),
)
    @eval begin
        # Extends Base.:* (assumes import Base: * at the top of the module)
        function *(A::BatchedCuMatrix{$elty}, x::CuVector{$elty})
            m, n, b = size(A.data)
            y_data = CuArray{$elty}(undef, m, b)
            y = BatchedCuVector(y_data)

            # Call gemv directly
            x_ptrs = batch_singleton(x, b)
            h = CUDA.CUBLAS.handle()
            $gemv_batched(
                h, 'N', m, n, $elty(1.0), A.ptrs, m, x_ptrs, 1, $elty(0.0), y.ptrs, 1, b
            )
            return y
        end

        function *(A::CuMatrix{$elty}, x::BatchedCuVector{$elty})
            m, n = size(A)
            b = size(x.data, 2)
            y_data = CuArray{$elty}(undef, m, b)
            y = BatchedCuVector(y_data)

            # Call gemv directly
            A_ptrs = batch_singleton(A, b)
            h = CUDA.CUBLAS.handle()
            $gemv_batched(
                h, 'N', m, n, $elty(1.0), A_ptrs, m, x.ptrs, 1, $elty(0.0), y.ptrs, 1, b
            )
            return y
        end
    end
end
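
batch_singleton is used above but not shown in this excerpt; presumably it replicates one array's device pointer b times so that a non-batched operand can enter a batched CUBLAS call, along these lines (hypothetical reconstruction):

function batch_singleton(x::CuArray{T}, b::Integer) where {T}
    # b copies of the same base pointer, uploaded as a device array
    return CuArray(fill(pointer(x), b))
end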

I've only done this for the combinations needed to pass the unit tests, so if you step outside of those, you'll get MethodErrors. These all follow the same pattern, so we should be able to create some wrapper/method generator that handles all of these automatically.

The BF doesn't work with resampling yet, but will with a few small additions to the batching interface.

Immediate Next Steps

Tomorrow I'm going to generalise this to the guided PF and the RBPF. Shouldn't be too difficult now that the main idea is in place 🤞

@charlesknipp (Member) left a comment:

This is great stuff which makes the interface easier to work with and reduces quite a bit of redundancy.

I think it still needs some work with broadcasting and RBPF, but once it's more finalized I may go ahead and propose some minor optimizations for the particle filtering.

One minor thing: I wonder if this breaks Zygote compatibility. This doesn't really bother me since I stick to either Mooncake or ForwardDiff.

Comment on lines +80 to +85
T<:Real,
μT<:Union{AbstractVector{T},BatchedVector{T}},
ΣT<:Union{AbstractMatrix{T},BatchedMatrix{T}},
AT<:Union{AbstractMatrix{T},BatchedMatrix{T}},
bT<:Union{AbstractVector{T},BatchedVector{T}},
QT<:Union{AbstractMatrix{T},BatchedMatrix{T}},
@charlesknipp (Member):

Types are definitely subject to change when my PR is merged, so we can keep this even more abstract.

@THargreaves (Collaborator, Author):

I'm not sure what would be more abstract than this, other than Any, though I'd be happy with that. That's what Gaussian does, which made it very easy to integrate with.

#### DISTRIBUTIONAL OPERATIONS ####
###################################

# Can likely replace by using Gaussian
@charlesknipp (Member):

I was actually wondering if we could remove the dependency on GaussianDistributions.jl and just write our own. The package is super lightweight (and also unmaintained) so it shouldn't be more than 150 lines of code.
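
For scale, the core of such a replacement could be as small as this (a sketch; GaussianDistributions.jl's own Gaussian is essentially the same unconstrained two-field struct):

using Random, LinearAlgebra

# Minimal homegrown Gaussian: no type restrictions on the mean or covariance
struct Gaussian{Tμ,TΣ}
    μ::Tμ
    Σ::TΣ
end

mean(p::Gaussian) = p.μ
cov(p::Gaussian) = p.Σ

# Dense CPU sampling; batched methods would be added exactly as in this PR
function Random.rand(rng::Random.AbstractRNG, p::Gaussian{<:AbstractVector})
    return p.μ + cholesky(Symmetric(p.Σ)).L * randn(rng, eltype(p.μ), length(p.μ))
end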

@@ -233,3 +241,45 @@ function filter(
)
return filter(rng, ssm, algo, observations; ref_state=ref_state, kwargs...)
end

# Broadcast wrapper for batched types
# TODO: this can likely be replaced with a broadcast style
@charlesknipp (Member):

I want to see this implemented before I go ahead and merge anything

@@ -44,89 +44,21 @@ function update(
# Update state
m = H * μ + c
y = observation - m
S = hermitianpart(H * Σ * H' + R)
K = Σ * H' / S
S = H * Σ * transpose(H) + R
Member:

Any reason for replacing x' with transpose(x)? For real numbers it should be the same thing

@FredericWantiez (Member) commented Jun 12, 2025:

You'd need to define ' or adjoint on a BatchedVector, I think.

@THargreaves (Collaborator, Author):

This was just to do with how I defined my CuBLAS wrappers. I will revert these back to adjoints once I have the full set of wrappers.
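
For real element types, the stopgap could be as simple as forwarding (a sketch, not from the PR):

# adjoint and transpose coincide for real eltypes, so forward one to the other
Base.adjoint(x::BatchedVector{<:Real}) = transpose(x)
Base.adjoint(A::BatchedMatrix{<:Real}) = transpose(A)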

Comment on lines +139 to +152
state.log_weights .+=
SSMProblems.logdensity.(
Ref(model.dyn), Ref(iter), state.particles, proposed_particles, kwargs...
)
state.log_weights .-=
SSMProblems.logdensity.(
Ref(model),
Ref(filter.proposal),
Ref(iter),
state.particles,
proposed_particles,
Ref(observation);
kwargs...,
)
Member:

Is there slowness induced by the need for an additional loop?

I liked the map block for a couple of reasons: (1) no need for an additional loop, and (2) the code contains far fewer getindex calls, relying only on setindex for the RBPF.

@THargreaves (Collaborator, Author):

> Is there slowness induced by the need for an additional loop?

Yeah, and it's actually fairly substantial in the batched case. I think we can get around this by just having a function that does both log density calculations in one and broadcasting over that instead.
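
A sketch of that fused function (hypothetical name), so only one broadcast runs per step:

# Both log-density terms in one call: a single broadcast, and in the batched
# case a single interception/kernel pass instead of two.
function log_weight_increment(model, proposal, iter, particle, proposed, observation; kwargs...)
    return SSMProblems.logdensity(model.dyn, iter, particle, proposed; kwargs...) -
           SSMProblems.logdensity(model, proposal, iter, particle, proposed, observation; kwargs...)
end

# state.log_weights .+= log_weight_increment.(
#     Ref(model), Ref(filter.proposal), Ref(iter),
#     state.particles, proposed_particles, Ref(observation); kwargs...,
# )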
