Skip to content

Commit 2e6ccab

Browse files
authored
Hacky fix for approx periodic (#122)
* Add failing test and fix formatting * Write to_sde for approx periodic * Bump patch * Restrict usage to Arrays * Remove known failure case * Fix ambiguity * Remove test for disallowed implementation
1 parent 3c56beb commit 2e6ccab

File tree

3 files changed

+44
-50
lines changed

3 files changed

+44
-50
lines changed

Diff for: Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "TemporalGPs"
22
uuid = "e155a3c4-0841-43e1-8b83-a0e4f03cc18f"
33
authors = ["willtebbutt <[email protected]> and contributors"]
4-
version = "0.6.5"
4+
version = "0.6.6"
55

66
[deps]
77
AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"

Diff for: src/gp/lti_sde.jl

+24-39
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,8 @@ function stationary_distribution(::CosineKernel, ::SArrayStorage{T}) where {T<:R
256256
return Gaussian(m, P)
257257
end
258258

259-
# Approximate Periodic Kernel
259+
# ApproxPeriodicKernel
260+
260261
# The periodic kernel is approximated by a sum of cosine kernels with different frequencies.
261262
struct ApproxPeriodicKernel{N,K<:PeriodicKernel} <: KernelFunctions.SimpleKernel
262263
kernel::K
@@ -279,53 +280,37 @@ function Base.show(io::IO, κ::ApproxPeriodicKernel{N}) where {N}
279280
return print(io, "Approximate Periodic Kernel, (r = $(only.kernel.r))) approximated with $N cosine kernels")
280281
end
281282

282-
function lgssm_components(approx::ApproxPeriodicKernel{N}, t::Union{StepRangeLen, RegularSpacing}, storage::StorageType{T}) where {N,T<:Real}
283-
Fs, Hs, ms, Ps = _init_periodic_kernel_lgssm(approx.kernel, storage, N)
284-
nt = length(t)
285-
As = map(F -> Fill(time_exp(F, T(step(t))), nt), Fs)
286-
return _reduce_sum_cosine_kernel_lgssm(As, Hs, ms, Ps, N, nt, T)
287-
end
288-
function lgssm_components(approx::ApproxPeriodicKernel{N}, t::AbstractVector{<:Real}, storage::StorageType{T}) where {N,T<:Real}
289-
Fs, Hs, ms, Ps = _init_periodic_kernel_lgssm(approx.kernel, storage, N)
290-
t = vcat([first(t) - 1], t)
291-
nt = length(diff(t))
292-
As = _map(F -> _map(Δt -> time_exp(F, T(Δt)), diff(t)), Fs)
293-
return _reduce_sum_cosine_kernel_lgssm(As, Hs, ms, Ps, N, nt, T)
294-
end
283+
# Can't use approx periodic kernel with static arrays -- the dimensions become too large.
284+
_ap_error() = throw(error("Unable to construct an ApproxPeriodicKernel for SArrayStorage"))
285+
to_sde(::ApproxPeriodicKernel, ::SArrayStorage) = _ap_error()
286+
stationary_distribution(::ApproxPeriodicKernel, ::SArrayStorage) = _ap_error()
295287

296-
function _init_periodic_kernel_lgssm(kernel::PeriodicKernel, storage, N::Int=7)
297-
r = kernel.r
298-
l⁻² = inv(4 * only(r)^2)
299-
288+
function to_sde(::ApproxPeriodicKernel{N}, storage::ArrayStorage{T}) where {T<:Real, N}
289+
290+
# Compute F and H for component processes.
300291
F, _, H = to_sde(CosineKernel(), storage)
301292
Fs = ntuple(N) do i
302293
2π * (i - 1) * F
303294
end
304-
Hs = Fill(H, N)
305295

296+
# Combine component processes into a single whole.
297+
F = block_diagonal(collect.(Fs)...)
298+
q = zero(T)
299+
H = repeat(collect(H), N)
300+
return F, q, H
301+
end
302+
303+
function stationary_distribution(kernel::ApproxPeriodicKernel{N}, storage::ArrayStorage{<:Real}) where {N}
306304
x0 = stationary_distribution(CosineKernel(), storage)
307-
ms = Fill(x0.m, N)
308-
P = x0.P
305+
m = collect(repeat(x0.m, N))
306+
r = kernel.kernel.r
307+
l⁻² = inv(4 * only(r)^2)
309308
Ps = ntuple(N) do j
310309
qⱼ = (1 + (j !== 1) ) * besseli(j - 1, l⁻²) / exp(l⁻²)
311-
qⱼ * P
312-
end
313-
314-
Fs, Hs, ms, Ps
315-
end
316-
317-
function _reduce_sum_cosine_kernel_lgssm(As, Hs, ms, Ps, N, nt, T)
318-
as = Fill(Fill(Zeros{T}(size(first(first(As)), 1)), nt), N)
319-
Qs = _map((P, A) -> _map(A -> Symmetric(P) - A * Symmetric(P) * A', A), Ps, As)
320-
Hs = Fill(vcat(Hs...), nt)
321-
h = Fill(zero(T), nt)
322-
As = _map(block_diagonal, As...)
323-
as = -map(vcat, as...)
324-
Qs = _map(block_diagonal, Qs...)
325-
m = reduce(vcat, ms)
326-
P = block_diagonal(Ps...)
327-
x0 = Gaussian(m, P)
328-
return As, as, Qs, (Hs, h), x0
310+
return qⱼ * x0.P
311+
end
312+
P = collect(block_diagonal(Ps...))
313+
return Gaussian(m, P)
329314
end
330315

331316
# Constant

Diff for: test/gp/lti_sde.jl

+19-10
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ end
2929
@testset "$(typeof(t)), $storage, $N" for t in (
3030
sort(rand(Nt)), RegularSpacing(0.0, 0.1, Nt)
3131
),
32-
storage in (SArrayStorage{Float64}(), ArrayStorage{Float64}()),
32+
storage in (ArrayStorage{Float64}(), ),
3333
N in (5, 8)
3434

3535
k = ApproxPeriodicKernel{N}()
@@ -131,6 +131,12 @@ println("lti_sde:")
131131
val=3.0 * Matern32Kernel() * Matern52Kernel() * ConstantKernel(),
132132
to_vec_grad=nothing,
133133
),
134+
# THIS IS KNOWN NOT TO WORK!
135+
# (
136+
# name="prod-(Matern32Kernel + ConstantKernel) * Matern52Kernel",
137+
# val=(Matern32Kernel() + ConstantKernel()) * Matern52Kernel(),
138+
# to_vec_grad=nothing,
139+
# ),
134140

135141
# Summed kernels.
136142
(
@@ -149,18 +155,21 @@ println("lti_sde:")
149155
)
150156

151157
# Construct a Gauss-Markov model with either dense storage or static storage.
152-
storages = ((name="dense storage Float64", val=ArrayStorage(Float64)),
153-
# (name="static storage Float64", val=SArrayStorage(Float64)),
154-
)
158+
storages = (
159+
(name="dense storage Float64", val=ArrayStorage(Float64)),
160+
# (name="static storage Float64", val=SArrayStorage(Float64)),
161+
)
155162

156163
# Either regular spacing or irregular spacing in time.
157-
ts = ((name="irregular spacing", val=collect(RegularSpacing(0.0, 0.3, N))),
158-
# (name="regular spacing", val=RegularSpacing(0.0, 0.3, N)),
159-
)
164+
ts = (
165+
(name="irregular spacing", val=collect(RegularSpacing(0.0, 0.3, N))),
166+
# (name="regular spacing", val=RegularSpacing(0.0, 0.3, N)),
167+
)
160168

161-
σ²s = ((name="homoscedastic noise", val=(0.1,)),
162-
# (name="heteroscedastic noise", val=(rand(rng, N) .+ 1e-1, )),
163-
)
169+
σ²s = (
170+
(name="homoscedastic noise", val=(0.1,)),
171+
# (name="heteroscedastic noise", val=(rand(rng, N) .+ 1e-1, )),
172+
)
164173

165174
means = (
166175
(name="Zero Mean", val=ZeroMean()),

0 commit comments

Comments
 (0)