
Commit b3405d0

Remove Zygote as a dep (#129)
* Remove literal_getfield usage
* Improve perf
* Progress
* Lots of changes
* Add Test as test dep
* Fix typo
* Add Pkg to examples
* Add Pkg to test deps
* Require Mooncake 0.4.3
* Import more names
* Remove Mooncake as direct dep
* Formatting
* Formatting
* Tidy up + enable all tests
* Enable all tests
* Add JET as test dep
* Tidy up and use JET rather than inferred
* Some fixes
* Discuss the changes in this release
* Figure out how to avoid bad gradients
1 parent: 784dbad


54 files changed (+571, -2862 lines)

.github/workflows/ci.yml (-2)

```diff
@@ -20,15 +20,13 @@ jobs:
       matrix:
         version:
           - '1'
-          - '1.6'
         os:
           - ubuntu-latest
         arch:
           - x64
         group:
           - 'test util'
           - 'test models'
-          - 'test models-lgssm'
           - 'test gp'
           - 'test space_time'
     steps:
```

NEWS.md (+15)

```diff
@@ -1,3 +1,18 @@
+# 0.7
+
+Mooncake.jl (and probably Enzyme.jl) is now able to differentiate everything in
+TemporalGPs.jl _reasonably_ efficiently, and only requires a single rule (for time_exp).
+This is in stark contrast with Zygote.jl, which required roughly 2.5k lines of rules to
+achieve reasonable performance. This code was not robust, required maintenance from time
+to time, and generally made it hard to make progress on improvements to this library.
+Consequently, in this version of TemporalGPs, we have removed all Zygote-related
+functionality, and now recommend that Mooncake.jl (or perhaps Enzyme.jl) is used to
+differentiate code in this package. In some places Mooncake.jl achieves worse performance
+than Zygote.jl, but it is worth it for the amount of code that has been removed.
+
+If you wish to use Zygote + TemporalGPs, you should restrict yourself to the 0.6 series of
+this package.
+
 # 0.5.12
 
 - A collection of examples of inference, and inference + learning, have been added.
```
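
For context, and not part of this diff: with the Zygote rules gone, differentiating a TemporalGPs log marginal likelihood is meant to go through Mooncake's generic machinery. The sketch below shows roughly what that might look like. The data, kernel, and parameter choices are invented for illustration, and the `Mooncake.build_rrule` / `Mooncake.value_and_gradient!!` calls follow Mooncake's documented interface; treat it as a sketch rather than the package's official example.

```julia
# Minimal sketch (not from this commit): differentiate a TemporalGPs negative
# log marginal likelihood with Mooncake. Data, kernel, and parameters are
# purely illustrative; `build_rrule` / `value_and_gradient!!` are assumed to
# be available as documented in Mooncake.
using AbstractGPs, TemporalGPs
import Mooncake

const x = RegularSpacing(0.0, 0.1, 1_000)                # regularly spaced inputs
const y = sin.(0.1 .* (0:999)) .+ 0.1 .* randn(1_000)    # synthetic observations

# Negative log marginal likelihood as a function of log-scale parameters.
function objective(θ::Vector{Float64})
    kernel = exp(θ[1]) * Matern52Kernel()
    f = to_sde(GP(kernel), SArrayStorage(Float64))
    return -logpdf(f(x, exp(θ[2])), y)
end

θ0 = [0.0, log(0.1)]

# Build a reverse-mode rule for this call signature once, then evaluate the
# objective and its gradient with respect to θ0.
rule = Mooncake.build_rrule(objective, θ0)
nlml, (_, dθ) = Mooncake.value_and_gradient!!(rule, objective, θ0)
```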

Project.toml (+21, -6)

```diff
@@ -1,29 +1,44 @@
 name = "TemporalGPs"
 uuid = "e155a3c4-0841-43e1-8b83-a0e4f03cc18f"
-authors = ["willtebbutt <[email protected]> and contributors"]
-version = "0.6.8"
+authors = ["Will Tebbutt and contributors"]
+version = "0.7.0"
 
 [deps]
 AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"
 Bessels = "0e736298-9ec6-45e8-9647-e4fc86a2fe38"
 BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
-ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
 KernelFunctions = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
 StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
-Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
+
+[weakdeps]
+Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
+
+[extensions]
+TemporalGPsMooncakeExt = "Mooncake"
 
 [compat]
 AbstractGPs = "0.5.17"
+BenchmarkTools = "1"
 Bessels = "0.2.8"
 BlockDiagonals = "0.1.7"
-ChainRulesCore = "1"
 FillArrays = "0.13.0 - 0.13.7, 1"
+JET = "0.9"
 KernelFunctions = "0.9, 0.10.1"
+Mooncake = "0.4.3"
 StaticArrays = "1"
 StructArrays = "0.5, 0.6"
-Zygote = "0.6.65"
 julia = "1.6"
+
+[extras]
+BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
+JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
+Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
+Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
+Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+
+[targets]
+test = ["BenchmarkTools", "JET", "Mooncake", "Pkg", "Test"]
```
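
Zygote and ChainRulesCore are dropped as hard dependencies, and Mooncake is added only as a weak dependency: the `TemporalGPsMooncakeExt` extension loads when both TemporalGPs and Mooncake are present in the environment (Julia 1.9+). The extension's source file is one of the 54 changed files and is not shown in this excerpt; structurally, a package extension is just a module under `ext/` whose name matches the `[extensions]` entry, roughly as in the hypothetical skeleton below.

```julia
# ext/TemporalGPsMooncakeExt.jl — hypothetical skeleton only. The real
# extension added by this commit is not shown here; this just illustrates how
# the [weakdeps]/[extensions] entries above map to a source file.
module TemporalGPsMooncakeExt

using Mooncake, TemporalGPs

# Julia loads this module automatically (1.9+) once both TemporalGPs and
# Mooncake are in the active environment. Per the release notes, this is the
# natural home for the single hand-written rule (for `time_exp`).

end
```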

README.md (+5, -5)

````diff
@@ -41,7 +41,7 @@ f = to_sde(f_naive, SArrayStorage(Float64))
 
 # Project onto finite-dimensional distribution as usual.
 # x = range(-5.0; step=0.1, length=10_000)
-x = RegularSpacing(0.0, 0.1, 10_000) # Hack for Zygote.
+x = RegularSpacing(0.0, 0.1, 10_000) # Hack for AD.
 fx = f(x, 0.1)
 
 # Sample from the prior as usual.
@@ -63,7 +63,7 @@ rand(f_post(x))
 logpdf(f_post(x), y)
 ```
 
-## Learning kernel parameters with [Optim.jl](https://github.com/JuliaNLSolvers/Optim.jl), [ParameterHandling.jl](https://github.com/invenia/ParameterHandling.jl), and [Zygote.jl](https://github.com/FluxML/Zygote.jl/)
+## Learning kernel parameters with [Optim.jl](https://github.com/JuliaNLSolvers/Optim.jl), [ParameterHandling.jl](https://github.com/invenia/ParameterHandling.jl), and [Mooncake.jl](https://github.com/compintell/Mooncake.jl/)
 
 TemporalGPs.jl doesn't provide scikit-learn-like functionality to train your model (find good kernel parameter settings).
 Instead, we offer the functionality needed to easily implement your own training functionality using standard tools from the Julia ecosystem, as shown below.
@@ -76,7 +76,7 @@ using TemporalGPs
 # Load standard packages from the Julia ecosystem
 using Optim # Standard optimisation algorithms.
 using ParameterHandling # Helper functionality for dealing with model parameters.
-using Zygote # Algorithmic Differentiation
+using Mooncake # Algorithmic Differentiation
 
 using ParameterHandling: flatten
 
@@ -115,7 +115,7 @@ objective(params)
 # Optim.jl for more info on available optimisers and their properties.
 training_results = Optim.optimize(
     objective ∘ unpack,
-    θ -> only(Zygote.gradient(objective ∘ unpack, θ)),
+    θ -> only(Mooncake.gradient(objective ∘ unpack, θ)),
     flat_initial_params + randn(3), # Add some noise to make learning non-trivial
     BFGS(
         alphaguess = Optim.LineSearches.InitialStatic(scaled=true),
@@ -152,7 +152,7 @@ This tells TemporalGPs that you want all parameters of `f` and anything derived
 
 "naive" timings are with the usual [AbstractGPs.jl](https://https://github.com/JuliaGaussianProcesses/AbstractGPs.jl/) inference routines, and is the default implementation for GPs. "lgssm" timings are conducted using `to_sde` with no additional arguments. "static-lgssm" uses the `SArrayStorage(Float64)` option discussed above.
 
-Gradient computations use Zygote. Custom adjoints have been implemented to achieve this level of performance.
+Gradient computations use Mooncake. Custom adjoints have been implemented to achieve this level of performance.
 
 
 
````