Commit f38628d: tidy up
1 parent 2994d74

2 files changed (+43, -34 lines)

Diff for: Project.toml (+1, -5)

@@ -41,8 +41,6 @@ ProgressLogging = "0.1"
 Reexport = "0.2, 1.0"
 SpecialFunctions = "1.8.2, 2.1.2"
 StatsBase = "0.33"
-Tracker = "0.2.22"
-Yota = "0.8.1"
 Zygote = "0.6.34"
 julia = "1.6"
 
@@ -52,9 +50,7 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
 FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
 IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
-Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
-Yota = "cd998857-8626-517d-b929-70ad188a48f0"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 
 [targets]
-test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays", "Tracker", "Yota"]
+test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays"]
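With Tracker and Yota dropped from `[extras]` and the test target, a quick way to sanity-check the trimmed test environment from a local checkout of this branch would be something like the following (an illustrative check using standard Pkg commands, not part of the commit):

```julia
using Pkg

Pkg.activate(".")   # activate the Flux project at the repository root
Pkg.resolve()       # should now resolve without Tracker or Yota
Pkg.test()          # runs the [targets] test list shown above
```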

Diff for: src/train.jl (+42, -29)

@@ -6,36 +6,41 @@ using Functors: fmap
 
 import ..Flux.Optimise: train!, update!  # during 0.13, we add methods to the old functions
 
-export setup, @train_autodiff
+export setup, train!
 
 using ProgressLogging: @progress, @withprogress, @logprogress
 using Zygote: Zygote, Params
 
 """
     opt = setup(rule, model)
 
-This is a version of `Optimisers.setup`, and is the first step before using `train!`.
+This is a version of `Optimisers.setup`, and is the first step before using [`train!`](@ref Flux.train!).
 It differs from `Optimisers.setup` in that it:
 * has one extra check for mutability
 * has methods which accept Flux's old optimisers, and convert them.
 
+# Example
 ```jldoctest
 julia> model = Dense(2=>1, leakyrelu; init=Flux.ones32);
 
-julia> opt = Flux.setup(Momentum(0.11), model)
-(weight = Leaf(Momentum{Float64}(0.11, 0.9), Float32[0.0 0.0]), bias = Leaf(Momentum{Float64}(0.11, 0.9), Float32[0.0]), σ = ())
+julia> opt = Flux.setup(Momentum(0.1), model)  # this encodes the optimiser and its state
+(weight = Leaf(Momentum{Float64}(0.1, 0.9), Float32[0.0 0.0]), bias = Leaf(Momentum{Float64}(0.1, 0.9), Float32[0.0]), σ = ())
 
-julia> Flux.train!(model, opt) do m  # 3-arg train!, for one data point (x = [0.2, -0.3], y = [0.4])
-         sum(m([0.2, -0.3]) .- [0.4]) * 100
+julia> x1, y1 = [0.2, -0.3], [0.4];  # use the same data for two steps:
+
+julia> Flux.train!(model, [(x1, y1), (x1, y1)], opt) do m, x, y
+         sum(abs.(m(x) .- y)) * 100
        end
--40.1
+2-element Vector{Float32}:
+ 40.1
+ 38.7
 
 julia> model.bias  # was zero, mutated by Flux.train!
 1-element Vector{Float32}:
- -0.11
+ 10.190001
 
 julia> opt  # mutated by Flux.train!
-(weight = Leaf(Momentum{Float64}(0.11, 0.9), Float32[0.022 -0.033]), bias = Leaf(Momentum{Float64}(0.11, 0.9), Float32[0.11]), σ = ())
+(weight = Leaf(Momentum{Float64}(0.1, 0.9), Float32[-2.018 3.027]), bias = Leaf(Momentum{Float64}(0.1, 0.9), Float32[-10.09]), σ = ())
 ```
 """
 function setup(rule::Optimisers.AbstractRule, model)
@@ -51,18 +56,8 @@ end
     train!(loss, model, data, opt)
 
 Uses a `loss` function and training `data` to improve the `model`'s parameters
-according to a particular optimisation rule `opt`.
-
-!!! note
-    This method has significant changes from the one in Flux ≤ 0.13:
-    * It now takes the `model` itself, not the result of [`Flux.params`](@ref).
-      (This is to move away from Zygote's implicit parameter handling.)
-    * Instead of `loss` being a function which typically accepts two arguments
-      (the input `x` and expected output `y` from each element of `data`)
-      now it should typically accept three, the first of which is the `model` itself.
-    * `data` must iterate tuples. Each `d in data` is used as `loss(model, d...)`.
-    * `opt` should be the result of [`Flux.setup`](@ref), it will warn you if not.
-    * Callback functions are not supported.
+according to a particular optimisation rule `opt`. Iterates through `data` once,
+evaluating `loss(model, d...)` for each `d` in data.
 
 For example, with these definitions...
 ```
@@ -72,15 +67,17 @@ loss3(m, x, y) = norm(m(x) .- y)  # the model is the first argument
 
 opt = Flux.setup(Adam(), model)  # explicit setup of optimiser momenta
 ```
-...calling `train!(loss3, model, data, opt)` runs a loop much like this:
+...calling `Flux.train!(loss3, model, data, opt)` runs a loop much like this,
+using Zygote's "explicit" mode for the gradient:
 ```
 for d in data
-    ∂L∂m = Zygote.gradient(loss3, model, d...)[1]
-    Optimisers.update!(opt, model, ∂L∂m)
+    ∂L∂m = gradient(loss3, model, d...)[1]
+    update!(opt, model, ∂L∂m)  # method for "explicit" gradient
 end
 ```
 You can also write this loop yourself, if you need more flexibility.
-Besides the loop, `train!` will:
+For this reason `train!` is not highly extensible.
+It adds only a few features to the loop above:
 
 * Stop with a `DomainError` if the loss is infinite or `NaN` at any point.
 
@@ -91,20 +88,36 @@ Besides the loop, `train!` will:
 Note that the built-in loss functions accept 3 arguments, allowing for instance
 `train!(Flux.Losses.mse, model, data, opt)` instead of defining `loss3` as above.
 
-Note that callback functions are not supported. But arbitrary code can be inserted into the loop.
+!!! note
+    This method has significant changes from the one in Flux ≤ 0.13:
+    * It now takes the `model` itself, not the result of [`Flux.params`](@ref).
+      (This is to move away from Zygote's "implicit" parameter handling, with `Grads`.)
+    * Instead of `loss` being a function which typically accepts two arguments
+      (the input `x` and expected output `y` from each element of `data`)
+      now it should typically accept three, the first of which is the `model` itself.
+    * `data` must iterate tuples, otherwise you get an error.
+      (Previously non-tuple types were not splatted into the loss.
+      Pass in `((d,) for d in data)` to simulate this.)
+    * `opt` should be the result of [`Flux.setup`](@ref). Using an optimiser
+      such as `Adam()` without this step should give you a warning.
+    * Callback functions are not supported.
+      But any code can be included in the above `for` loop.
 """
-function train!(loss, model, data, opt)
+function train!(loss, model, data, opt; cb = nothing)
+  isnothing(cb) || error("""train! does not support callback functions.
+                            For more control use a loop with `gradient` and `update!`.""")
   losses = Float32[]
   @withprogress for (i,d) in enumerate(data)
     d isa Tuple || error("""train! expects as data an iterator producing tuples, but got $(typeof(d)).
                             Pass it `((d,) for d in data)`, or use `gradient` and `update!` for more control.""")
-    l, (g, _...) = explicit_withgradient(loss, model, d...)
+    # l, (g, _...) = explicit_withgradient(loss, model, d...)  # BTW this un-thunks gradient w.r.t. data. Could avoid that
+    l, (g, _...) = explicit_withgradient(m -> loss(m, d...), model)
     isfinite(l) || throw(DomainError("loss function returned $l, stopping training"))
     opt, model = Optimisers.update!(opt, model, g)
     push!(losses, l)
     @logprogress Base.haslength(data) ? i/length(data) : nothing
   end
-  return losses  # Not entirely sure returning losses is a good idea
+  return losses  # Not entirely sure returning losses is a good idea, as it may conflict with later returning immutable models alla Optimisers.jl
 end
 
 # This method lets you use Optimisers.Descent() without setup, when there is no state
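For reference, here is a minimal sketch (not part of the diff) of how the reworked explicit-mode API is meant to be called, following the docstring above; the model, data, and loss below are illustrative:

```julia
using Flux

model = Dense(2 => 1)                                           # any Flux model
data = [(randn(Float32, 2), randn(Float32, 1)) for _ in 1:8]    # an iterator of (x, y) tuples

opt = Flux.setup(Adam(0.01), model)    # first step: optimiser rule plus its state for this model

loss(m, x, y) = Flux.mse(m(x), y)      # the model is the first argument

losses = Flux.train!(loss, model, data, opt)   # one pass over data; this version collects per-step losses

# Callbacks are gone, so anything extra lives in an ordinary loop around train!:
for epoch in 1:3
    Flux.train!(loss, model, data, opt)
    @info "finished epoch" epoch
end
```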
