Skip to content

Unified interface for batched filters #105

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions GeneralisedFilters/src/GeneralisedFilters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@ include("callbacks.jl")
include("containers.jl")
include("resamplers.jl")

# Batching utilities
include("batching/batching.jl")
include("batching/batched_CUDA.jl")
include("batching/batched_SA.jl")

## FILTERING BASE ##########################################################################

abstract type AbstractFilter <: AbstractSampler end
Expand Down
88 changes: 10 additions & 78 deletions GeneralisedFilters/src/algorithms/kalman.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
export KalmanFilter, filter, BatchKalmanFilter
using GaussianDistributions
using CUDA: i32
import LinearAlgebra: hermitianpart
import LinearAlgebra: hermitianpart, transpose, Cholesky

export KalmanFilter, KF, KalmanSmoother, KS

Expand All @@ -27,7 +27,7 @@ function predict(
)
μ, Σ = GaussianDistributions.pair(state)
A, b, Q = calc_params(model.dyn, iter; kwargs...)
return Gaussian(A * μ + b, A * Σ * A' + Q)
return Gaussian(A * μ + b, A * Σ * transpose(A) + Q)
end

function update(
Expand All @@ -44,89 +44,21 @@ function update(
# Update state
m = H * μ + c
y = observation - m
S = hermitianpart(H * Σ * H' + R)
K = Σ * H' / S
S = H * Σ * transpose(H) + R
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any reason for replacing x' with transpose(x)? For real numbers it should be the same thing

Copy link
Member

@FredericWantiez FredericWantiez Jun 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You'd need to define ' or adjoint on a BatchedVector I think

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was just to do with how I defined my CuBLAS wrappers. I will revert these back to adjoints once I have the full set of wrappers.

S = (S + transpose(S)) / 2 # force symmetry
S_chol = cholesky(S)
KT = S_chol \ H * Σ # TODO: only using `\` for better integration with CuSolver

state = Gaussian(μ + K * y, Σ - K * H * Σ)
state = Gaussian(μ + transpose(KT) * y, Σ - transpose(KT) * H * Σ)

# Compute log-likelihood
ll = logpdf(MvNormal(m, S), observation)
ll = gaussian_likelihood(m, S, observation)

return state, ll
end

"""
    BatchKalmanFilter(batch_size)

Kalman filter variant that runs `batch_size` independent linear-Gaussian
filters in lockstep, operating on batched (3-dimensional) parameter and
state arrays rather than one model at a time.
"""
struct BatchKalmanFilter <: AbstractBatchFilter
    batch_size::Int  # number of filters processed per batched operation
end

"""
    initialise(rng, model, algo::BatchKalmanFilter; kwargs...)

Build the batched Gaussian prior for a linear-Gaussian state space model,
holding one initial mean/covariance pair per element of the batch.
"""
function initialise(
    rng::AbstractRNG,
    model::LinearGaussianStateSpaceModel{T},
    algo::BatchKalmanFilter;
    kwargs...,
) where {T}
    # All batch elements' initial moments are computed in a single bulk call.
    prior_means, prior_covs = batch_calc_initial(model.dyn, algo.batch_size; kwargs...)
    return BatchGaussianDistribution(prior_means, prior_covs)
end

"""
    predict(rng, model, algo::BatchKalmanFilter, iter, state, observation; kwargs...)

Propagate every Gaussian in the batched filtering distribution `state`
through the linear dynamics: mean `A μ + b` and covariance `A Σ Aᵀ + Q`,
evaluated for all batch elements at once via batched matrix products.
"""
function predict(
    rng::AbstractRNG,
    model::LinearGaussianStateSpaceModel{T},
    algo::BatchKalmanFilter,
    iter::Integer,
    state::BatchGaussianDistribution,
    observation;
    kwargs...,
) where {T}
    As, bs, Qs = batch_calc_params(model.dyn, iter, algo.batch_size; kwargs...)
    # Predicted means: one matrix-vector product per batch element.
    predicted_means = NNlib.batched_vec(As, state.μs) .+ bs
    # Predicted covariances: A Σ Aᵀ + Q, with the Aᵀ taken lazily per batch.
    AΣs = NNlib.batched_mul(As, state.Σs)
    predicted_covs = NNlib.batched_mul(AΣs, NNlib.batched_transpose(As)) .+ Qs
    return BatchGaussianDistribution(predicted_means, predicted_covs)
end

"""
    update(model, algo::BatchKalmanFilter, iter, state, observation; kwargs...)

Batched Kalman update: condition every Gaussian in `state` on `observation`
and return the filtered `BatchGaussianDistribution` together with the
per-batch-element observation log-likelihoods.
"""
function update(
    model::LinearGaussianStateSpaceModel{T},
    algo::BatchKalmanFilter,
    iter::Integer,
    state::BatchGaussianDistribution,
    observation;
    kwargs...,
) where {T}
    μs, Σs = state.μs, state.Σs
    Hs, cs, Rs = batch_calc_params(model.obs, iter, algo.batch_size; kwargs...)
    D = size(observation, 1)  # observation dimension

    # Predicted observation mean and innovation (residual) per batch element.
    m = NNlib.batched_vec(Hs, μs) .+ cs
    y_res = cu(observation) .- m
    # Innovation covariance S = H Σ Hᵀ + R, batched over the trailing dim.
    S = NNlib.batched_mul(Hs, NNlib.batched_mul(Σs, NNlib.batched_transpose(Hs))) .+ Rs

    # Cross term Σ Hᵀ, reused below to form the Kalman gain.
    ΣH_T = NNlib.batched_mul(Σs, NNlib.batched_transpose(Hs))

    # Invert S explicitly via strided-batched LU factorisation (getrf)
    # followed by batched inversion (getri) on the GPU.
    S_inv = CUDA.similar(S)
    d_ipiv, _, d_S = CUDA.CUBLAS.getrf_strided_batched(S, true)
    CUDA.CUBLAS.getri_strided_batched!(d_S, S_inv, d_ipiv)

    # Pull out the diagonal of each LU factor: |det S| is the product of the
    # absolute diagonal entries of U.
    diags = CuArray{eltype(S)}(undef, size(S, 1), size(S, 3))
    for i in 1:size(S, 1)
        diags[i, :] .= d_S[i, i, :]
    end

    log_dets = sum(log ∘ abs, diags; dims=1)

    # Kalman gain K = Σ Hᵀ S⁻¹.
    K = NNlib.batched_mul(ΣH_T, S_inv)

    # Filtered moments: μ + K y and Σ - K H Σ.
    μ_filt = μs .+ NNlib.batched_vec(K, y_res)
    Σ_filt = Σs .- NNlib.batched_mul(K, NNlib.batched_mul(Hs, Σs))

    # Gaussian log-likelihood: -½ (yᵀ S⁻¹ y + log|S| + D log 2π), per element.
    inv_term = NNlib.batched_vec(S_inv, y_res)
    log_likes = -T(0.5) * NNlib.batched_vec(reshape(y_res, 1, D, size(S, 3)), inv_term)
    log_likes = log_likes .- T(0.5) * (log_dets .+ D * log(T(2π)))

    # HACK: NaNs only seem to arise from numerical instability, so overwrite
    # them with -Inf rather than erroring.
    log_likes[isnan.(log_likes)] .= -Inf

    return BatchGaussianDistribution(μ_filt, Σ_filt), dropdims(log_likes; dims=1)
"""
    gaussian_likelihood(m, S, y)

Return the log-density of the observation `y` under a multivariate normal
distribution with mean `m` and covariance `S`.
"""
function gaussian_likelihood(m::AbstractVector, S::AbstractMatrix, y::AbstractVector)
    observation_dist = MvNormal(m, S)
    return logpdf(observation_dist, y)
end

## KALMAN SMOOTHER #########################################################################
Expand Down
106 changes: 78 additions & 28 deletions GeneralisedFilters/src/algorithms/particles.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,26 +122,34 @@ function predict(
ref_state::Union{Nothing,AbstractVector}=nothing,
kwargs...,
)
proposed_particles = map(enumerate(state.particles)) do (i, particle)
if !isnothing(ref_state) && i == 1
ref_state[iter]
else
simulate(rng, model, filter.proposal, iter, particle, observation; kwargs...)
end
proposed_particles =
SSMProblems.simulate.(
Ref(rng),
Ref(model),
Ref(filter.proposal),
Ref(iter),
state.particles,
Ref(observation),
kwargs...,
)
Comment on lines +125 to +134
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
proposed_particles =
SSMProblems.simulate.(
Ref(rng),
Ref(model),
Ref(filter.proposal),
Ref(iter),
state.particles,
Ref(observation),
kwargs...,
)
proposed_particles = SSMProblems.simulate.(
Ref(rng),
Ref(model),
Ref(filter.proposal),
Ref(iter),
state.particles,
Ref(observation),
kwargs...,
)

if !isnothing(ref_state)
proposed_particles[1] = ref_state[iter]
end

state.log_weights +=
map(zip(proposed_particles, state.particles)) do (new_state, prev_state)
log_f = SSMProblems.logdensity(
model.dyn, iter, prev_state, new_state; kwargs...
)

log_q = SSMProblems.logdensity(
model, filter.proposal, iter, prev_state, new_state, observation; kwargs...
)

(log_f - log_q)
end
state.log_weights .+=
SSMProblems.logdensity.(
Ref(model.dyn), Ref(iter), state.particles, proposed_particles, kwargs...
)
state.log_weights .-=
SSMProblems.logdensity.(
Ref(model),
Ref(filter.proposal),
Ref(iter),
state.particles,
proposed_particles,
Ref(observation);
kwargs...,
)
Comment on lines +139 to +152
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
state.log_weights .+=
SSMProblems.logdensity.(
Ref(model.dyn), Ref(iter), state.particles, proposed_particles, kwargs...
)
state.log_weights .-=
SSMProblems.logdensity.(
Ref(model),
Ref(filter.proposal),
Ref(iter),
state.particles,
proposed_particles,
Ref(observation);
kwargs...,
)
state.log_weights .+= SSMProblems.logdensity.(
Ref(model.dyn), Ref(iter), state.particles, proposed_particles, kwargs...
)
state.log_weights .-= SSMProblems.logdensity.(
Ref(model),
Ref(filter.proposal),
Ref(iter),
state.particles,
proposed_particles,
Ref(observation);
kwargs...,
)

Comment on lines +139 to +152
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there slowness induced by the need for an additional loop?

I liked the map block for a couple reasons (1) no need for an additional loop and (2) the code contains far fewer getindex calls, only relying on setindex for the RBPF.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there slowness induced by the need for an additional loop?

Yeah, and it's actually fairly substantial in the batched case. I think we can get around this by just having a function that does both log density calculations in one and broadcasting over that instead.


state.particles = proposed_particles

Expand All @@ -156,10 +164,10 @@ function update(
observation;
kwargs...,
) where {T}
log_increments = map(
x -> SSMProblems.logdensity(model.obs, iter, x, observation; kwargs...),
state.particles,
)
log_increments =
SSMProblems.logdensity.(
Ref(model.obs), Ref(iter), state.particles, Ref(observation); kwargs...
)
Comment on lines +167 to +170
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
log_increments =
SSMProblems.logdensity.(
Ref(model.obs), Ref(iter), state.particles, Ref(observation); kwargs...
)
log_increments = SSMProblems.logdensity.(
Ref(model.obs), Ref(iter), state.particles, Ref(observation); kwargs...
)


state.log_weights += log_increments

Expand Down Expand Up @@ -207,12 +215,12 @@ function predict(
ref_state::Union{Nothing,AbstractVector}=nothing,
kwargs...,
)
state.particles = map(enumerate(state.particles)) do (i, particle)
if !isnothing(ref_state) && i == 1
ref_state[iter]
else
SSMProblems.simulate(rng, model.dyn, iter, particle; kwargs...)
end
state.particles =
SSMProblems.simulate.(
Ref(rng), Ref(model.dyn), Ref(iter), state.particles; kwargs...
)
Comment on lines +218 to +221
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
state.particles =
SSMProblems.simulate.(
Ref(rng), Ref(model.dyn), Ref(iter), state.particles; kwargs...
)
state.particles = SSMProblems.simulate.(
Ref(rng), Ref(model.dyn), Ref(iter), state.particles; kwargs...
)

if !isnothing(ref_state)
state.particles[1] = ref_state[iter]
end

return state
Expand All @@ -233,3 +241,45 @@ function filter(
)
return filter(rng, ssm, algo, observations; ref_state=ref_state, kwargs...)
end

# Broadcast wrapper for batched types
# TODO: this can likely be replaced with a broadcast style
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I want to see this implemented before I go ahead and merge anything

# Broadcasting `simulate.` over a `BatchedVector` bypasses the element-wise
# machinery: unwrap the Ref-protected scalar arguments and evaluate the plain
# batched `simulate` eagerly, returning its result instead of a lazy
# `Broadcasted` object.
function Base.Broadcast.broadcasted(
    ::typeof(SSMProblems.simulate),
    rng_ref::Base.RefValue,
    model_dyn_ref::Base.RefValue,
    iter_ref::Base.RefValue,
    particles::BatchedVector;
    kwargs...,
)
    rng = rng_ref[]
    dyn = model_dyn_ref[]
    iteration = iter_ref[]
    return SSMProblems.simulate(rng, dyn, iteration, particles; kwargs...)
end
# Broadcasting `logdensity.` of an observation model over a `BatchedVector`:
# unwrap the Ref-protected arguments and evaluate the batched `logdensity`
# eagerly in a single call.
function Base.Broadcast.broadcasted(
    ::typeof(SSMProblems.logdensity),
    model_obs_ref::Base.RefValue,
    iter_ref::Base.RefValue,
    particles::BatchedVector,
    observation::Base.RefValue;
    kwargs...,
)
    obs_model = model_obs_ref[]
    iteration = iter_ref[]
    y = observation[]
    return SSMProblems.logdensity(obs_model, iteration, particles, y; kwargs...)
end
# Broadcasting `logdensity.` of the dynamics over two `BatchedVector`s
# (previous and proposed particles): unwrap the Ref-protected arguments and
# evaluate the batched `logdensity` eagerly in a single call.
function Base.Broadcast.broadcasted(
    ::typeof(SSMProblems.logdensity),
    model_dyn_ref::Base.RefValue,
    iter_ref::Base.RefValue,
    prev_particles::BatchedVector,
    new_particles::BatchedVector;
    kwargs...,
)
    dyn = model_dyn_ref[]
    iteration = iter_ref[]
    return SSMProblems.logdensity(
        dyn, iteration, prev_particles, new_particles; kwargs...
    )
end
Loading
Loading