Skip to content

Commit f8b8302

Browse files
authored
Merge pull request #115 from JuliaGaussianProcesses/fix-CI
Fix CI
2 parents d32aaa8 + 07588b1 commit f8b8302

16 files changed

+46
-38
lines changed

Project.toml

+2-2
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,13 @@ StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
1717
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
1818

1919
[compat]
20-
AbstractGPs = "0.5.15"
20+
AbstractGPs = "0.5.17"
2121
Bessels = "0.2.8"
2222
BlockDiagonals = "0.1.7"
2323
ChainRulesCore = "1"
2424
FillArrays = "0.13.0 - 0.13.7, 1"
2525
KernelFunctions = "0.9, 0.10.1"
2626
StaticArrays = "1"
2727
StructArrays = "0.5, 0.6"
28-
Zygote = "0.6"
28+
Zygote = "0.6.65"
2929
julia = "1.6"

README.md

-4
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,6 @@ TemporalGPs.jl is a tool to make Gaussian processes (GPs) defined using [Abstrac
88

99
[JuliaCon 2020 Talk](https://www.youtube.com/watch?v=dysmEpX1QoE)
1010

11-
# Dependency Status
12-
13-
In the interest of managing expectations, please note that TemporalGPs does not currently operate with the most current version of AbstractGPs / Zygote / ChainRules. I (Will) am aware of this problem, and will sort it out as soon as I have the time!
14-
1511
# Installation
1612

1713
TemporalGPs.jl is registered, so simply type the following at the REPL:

examples/approx_space_time_inference.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,6 @@ if get(ENV, "TESTING", "FALSE") == "FALSE"
6161
heatmap(reshape(σ_post_marginals, N_pr, T));
6262
layout=(1, 2),
6363
),
64-
"posterior.png",
64+
"approx_space_time_inference.png",
6565
);
6666
end

examples/approx_space_time_learning.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,6 @@ if get(ENV, "TESTING", "FALSE") == "FALSE"
100100
heatmap(reshape(σ_post_marginals, N_pr, T));
101101
layout=(1, 2),
102102
),
103-
"posterior.png",
103+
"approx_space_time_learning.png",
104104
);
105105
end

examples/augmented_inference.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using AbstractGPs
22
using TemporalGPs
3-
using Distributions
3+
using Distributions: Bernoulli
44
using StatsFuns: logistic
55

66
# In this example we are showing how to work with non-Gaussian likelihoods,
@@ -73,5 +73,5 @@ if get(ENV, "TESTING", "FALSE") == "FALSE"
7373
plot!(plt, x_pr, f_post_samples; color=:red, alpha=0.3, label="");
7474
plot!(plt, x, f_true; label="", lw=2.0, color=:blue); # Plot the true latent GP on top
7575
scatter!(plt, x, y; label="", markersize=1.0, alpha=1.0); # Plot the data
76-
savefig(plt, "posterior.png");
76+
savefig(plt, "augmented_inference.png");
7777
end

examples/exact_space_time_inference.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,6 @@ if get(ENV, "TESTING", "FALSE") == "FALSE"
5959
heatmap(reshape(σ_post_marginals, N, T_pr));
6060
layout=(1, 2),
6161
),
62-
"posterior.png",
62+
"exact_space_time_inference.png",
6363
);
6464
end

examples/exact_space_time_learning.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ end
3939
# Exact inference only works for such grids.
4040
# Times must be increasing, points in space can be anywhere.
4141
N = 50;
42-
T = 1_000;
42+
T = 500;
4343
points_in_space = collect(range(-3.0, 3.0; length=N));
4444
points_in_time = RegularSpacing(0.0, 0.01, T);
4545
x = RectilinearGrid(points_in_space, points_in_time);
@@ -73,7 +73,7 @@ final_params = unpack(training_results.minimizer)
7373
f_post = posterior(build_gp(final_params)(x, final_params.var_noise), y);
7474

7575
# Specify some locations at which to make predictions.
76-
T_pr = 1200;
76+
T_pr = 600;
7777
points_in_time_pr = RegularSpacing(0.0, 0.01, T_pr);
7878
x_pr = RectilinearGrid(points_in_space, points_in_time_pr);
7979

@@ -93,6 +93,6 @@ if get(ENV, "TESTING", "FALSE") == "FALSE"
9393
heatmap(reshape(σ_post_marginals, N, T_pr));
9494
layout=(1, 2),
9595
),
96-
"posterior.png",
96+
"exact_space_time_learning.png",
9797
);
9898
end

examples/exact_time_inference.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -48,5 +48,5 @@ if get(ENV, "TESTING", "FALSE") == "FALSE"
4848
scatter!(plt, x, y; label="", markersize=0.1, alpha=0.1);
4949
plot!(plt, f_post(x_pr); ribbon_scale=3.0, label="");
5050
plot!(x_pr, f_post_samples; color=:red, label="");
51-
savefig(plt, "posterior.png");
51+
savefig(plt, "exact_time_inference.png");
5252
end

examples/exact_time_learning.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -86,5 +86,5 @@ if get(ENV, "TESTING", "FALSE") == "FALSE"
8686
scatter!(plt, x, y; label="", markersize=0.1, alpha=0.1);
8787
plot!(plt, f_post(x_pr); ribbon_scale=3.0, label="");
8888
plot!(plt, x_pr, f_post_samples; color=:red, label="");
89-
savefig(plt, "posterior.png");
89+
savefig(plt, "exact_time_learning.png");
9090
end

src/models/linear_gaussian_conditionals.jl

+1
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ be equivalent to
4646
function predict(x::Gaussian, f::AbstractLGC)
4747
A, a, Q = get_fields(f)
4848
m, P = get_fields(x)
49+
4950
# Symmetric wrapper needed for numerical stability. Do not unwrap.
5051
return Gaussian(A * m + a, (A * symmetric(P)) * A' + Q)
5152
end

src/space_time/pseudo_point.jl

+6-5
Original file line numberDiff line numberDiff line change
@@ -103,10 +103,11 @@ function kernel_diagonals(k::DTCSeparable, x::RegularInTime)
103103
space_kernel = k.k.l
104104
time_kernel = k.k.r
105105
time_vars = kernelmatrix_diag(time_kernel, get_times(x))
106-
return map(
107-
(s_t, x_r) -> Diagonal(kernelmatrix_diag(space_kernel, x_r) * s_t),
108-
time_vars,
109-
x.vs,
106+
return Diagonal.(
107+
kernelmatrix_diag.(
108+
Ref(space_kernel),
109+
x.vs
110+
) .* time_vars
110111
)
111112
end
112113

@@ -185,7 +186,7 @@ function lgssm_components(k_dtc::DTCSeparable, x::RegularInTime, storage::Storag
185186
C = \(K_space_z_chol, C__)
186187
Cs = partition(ChainRulesCore.ignore_derivatives(map(length, x.vs)), C)
187188

188-
cs = _map((h, v) -> fill(h, length(v)), hs_t, x.vs) # This should currently be zero.
189+
cs = fill.(hs_t, length.(x.vs)) # This should currently be zero.
189190
Hs = _map(
190191
((I, H_t), ) -> kron(I, H_t),
191192
zip(Fill(ident_M, N), Hs_t),

src/space_time/regular_in_time.jl

+10-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,16 @@ function Base.collect(x::RegularInTime)
2424
return [(x, t) for (x, t) in zip(space_inputs, time_inputs)]
2525
end
2626

27-
Base.getindex(x::RegularInTime, n::Int) = collect(x)[n]
27+
function Base.getindex(x::RegularInTime, n::Int)
28+
n 0 && throw(BoundsError(x, n))
29+
sum_of_lengths = 0
30+
for (i, v) in enumerate(x.vs)
31+
temp = sum_of_lengths + length(v)
32+
temp n && return (v[n - sum_of_lengths], x.ts[i])
33+
sum_of_lengths = temp
34+
end
35+
throw(BoundsError(x, n))
36+
end
2837

2938
Base.show(io::IO, x::RegularInTime) = Base.show(io::IO, collect(x))
3039

src/util/chainrules.jl

-7
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,6 @@ Zygote.accum(a::Tuple, b::Tuple, c::Tuple) = map(Zygote.accum, a, b, c)
2828
# StaticArrays #
2929
# ---------------------------------------------------------------------------- #
3030

31-
function ProjectTo(x::SArray{S,T}) where {S, T}
32-
return ProjectTo{SArray}(; element=_eltype_projectto(T), axes=axes(x), static_size=S)
33-
end
34-
35-
(proj::ProjectTo{SArray})(dx::SArray) = SArray{proj.static_size}(dx.data)
36-
(proj::ProjectTo{SArray})(dx::AbstractArray) = SArray{proj.static_size}(Tuple(dx))
37-
3831
function rrule(::Type{T}, x::Tuple) where {T<:SArray}
3932
SArray_rrule(Δ) = begin
4033
(NoTangent(), Tangent{typeof(x)}(unthunk(Δ).data...))

test/runtests.jl

+12-7
Original file line numberDiff line numberDiff line change
@@ -110,11 +110,16 @@ if GROUP == "examples"
110110
Pkg.resolve()
111111
Pkg.instantiate()
112112

113-
include(joinpath(pkgpath, "examples", "exact_time_inference.jl"))
114-
include(joinpath(pkgpath, "examples", "exact_time_learning.jl"))
115-
include(joinpath(pkgpath, "examples", "exact_space_time_inference.jl"))
116-
include(joinpath(pkgpath, "examples", "exact_space_time_learning.jl"))
117-
include(joinpath(pkgpath, "examples", "approx_space_time_inference.jl"))
118-
include(joinpath(pkgpath, "examples", "approx_space_time_learning.jl"))
119-
include(joinpath(pkgpath, "examples", "augmented_inference.jl"))
113+
function include_with_info(filename)
114+
@info "Running examples/$filename"
115+
include(joinpath(pkgpath, "examples", filename))
116+
end
117+
118+
include_with_info("exact_time_inference.jl")
119+
include_with_info("exact_time_learning.jl")
120+
include_with_info("exact_space_time_inference.jl")
121+
include_with_info("exact_space_time_learning.jl")
122+
include_with_info("approx_space_time_inference.jl")
123+
include_with_info("approx_space_time_learning.jl")
124+
include_with_info("augmented_inference.jl")
120125
end

test/space_time/pseudo_point.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ include("../models/model_test_utils.jl")
9696
validate_dims(lgssm)
9797

9898
# The two approaches to DTC computation should be equivalent up to roundoff error.
99-
dtc_naive = dtc(VFE(f_naive(z_naive)), fx_naive, y)
99+
dtc_naive = approx_log_evidence(DTC(f_naive(z_naive)), fx_naive, y)
100100
dtc_sde = dtc(fx, y, z_r)
101101
@test dtc_naive dtc_sde rtol=1e-6
102102

@@ -150,7 +150,7 @@ include("../models/model_test_utils.jl")
150150
fx_naive = f_naive(naive_inputs_missings, 0.1)
151151

152152
# Compute DTC using both approaches.
153-
dtc_naive = dtc(VFE(f_naive(z_naive)), fx_naive, naive_y_missings)
153+
dtc_naive = approx_log_evidence(DTC(f_naive(z_naive)), fx_naive, naive_y_missings)
154154
dtc_sde = dtc(fx, y_missing, z_r)
155155
@test dtc_naive dtc_sde rtol=1e-7 atol=1e-7
156156

test/space_time/regular_in_time.jl

+3
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,7 @@ using TemporalGPs: RegularInTime
1010
@test prod(size(x)) == length(collect(x))
1111

1212
@test all([getindex(x, n) for n in 1:length(x)] .== collect(x))
13+
@test_throws BoundsError x[0]
14+
@test_throws BoundsError x[-1]
15+
@test_throws BoundsError x[length(x) + 1]
1316
end

0 commit comments

Comments
 (0)