@@ -6,36 +6,41 @@ using Functors: fmap
6
6
7
7
import .. Flux. Optimise: train!, update! # during 0.13, we add methods to the old functions
8
8
9
- export setup, @train_autodiff
9
+ export setup, train!
10
10
11
11
using ProgressLogging: @progress , @withprogress , @logprogress
12
12
using Zygote: Zygote, Params
13
13
14
14
"""
15
15
opt = setup(rule, model)
16
16
17
- This is a version of `Optimisers.setup`, and is the first step before using `train!`.
17
+ This is a version of `Optimisers.setup`, and is the first step before using [`train!`](@ref Flux.train!).
18
18
It differs from `Optimisers.setup` in that it:
19
19
* has one extra check for mutability
20
20
* has methods which accept Flux's old optimisers, and convert them.
21
21
22
+ # Example
22
23
```jldoctest
23
24
julia> model = Dense(2=>1, leakyrelu; init=Flux.ones32);
24
25
25
- julia> opt = Flux.setup(Momentum(0.11 ), model)
26
- (weight = Leaf(Momentum{Float64}(0.11 , 0.9), Float32[0.0 0.0]), bias = Leaf(Momentum{Float64}(0.11 , 0.9), Float32[0.0]), σ = ())
26
+ julia> opt = Flux.setup(Momentum(0.1 ), model) # this encodes the optimiser and its state
27
+ (weight = Leaf(Momentum{Float64}(0.1 , 0.9), Float32[0.0 0.0]), bias = Leaf(Momentum{Float64}(0.1 , 0.9), Float32[0.0]), σ = ())
27
28
28
- julia> Flux.train!(model, opt) do m # 3-arg train!, for one data point (x = [0.2, -0.3], y = [0.4])
29
- sum(m([0.2, -0.3]) .- [0.4]) * 100
29
+ julia> x1, y1 = [0.2, -0.3], [0.4]; # use the same data for two steps:
30
+
31
+ julia> Flux.train!(model, [(x1, y1), (x1, y1)], opt) do m, x, y
32
+ sum(abs.(m(x) .- y)) * 100
30
33
end
31
- -40.1
34
+ 2-element Vector{Float32}:
35
+ 40.1
36
+ 38.7
32
37
33
38
julia> model.bias # was zero, mutated by Flux.train!
34
39
1-element Vector{Float32}:
35
- -0.11
40
+ 10.190001
36
41
37
42
julia> opt # mutated by Flux.train!
38
- (weight = Leaf(Momentum{Float64}(0.11 , 0.9), Float32[0.022 -0.033 ]), bias = Leaf(Momentum{Float64}(0.11 , 0.9), Float32[0.11 ]), σ = ())
43
+ (weight = Leaf(Momentum{Float64}(0.1 , 0.9), Float32[-2.018 3.027 ]), bias = Leaf(Momentum{Float64}(0.1 , 0.9), Float32[-10.09 ]), σ = ())
39
44
```
40
45
"""
41
46
function setup (rule:: Optimisers.AbstractRule , model)
51
56
train!(loss, model, data, opt)
52
57
53
58
Uses a `loss` function and training `data` to improve the `model`'s parameters
54
- according to a particular optimisation rule `opt`.
55
-
56
- !!! note
57
- This method has significant changes from the one in Flux ≤ 0.13:
58
- * It now takes the `model` itself, not the result of [`Flux.params`](@ref).
59
- (This is to move away from Zygote's implicit parameter handling.)
60
- * Instead of `loss` being a function which typically accepts two arguments
61
- (the input `x` and expected output `y` from each element of `data`)
62
- now it should typically accept three, the first of which is the `model` itself.
63
- * `data` must iterate tuples. Each `d in data` is used as `loss(model, d...)`.
64
- * `opt` should be the result of [`Flux.setup`](@ref), it will warn you if not.
65
- * Callback functions are not supported.
59
+ according to a particular optimisation rule `opt`. Iterates through `data` once,
60
+ evaluating `loss(model, d...)` for each `d` in data.
66
61
67
62
For example, with these definitions...
68
63
```
@@ -72,15 +67,17 @@ loss3(m, x, y) = norm(m(x) .- y) # the model is the first argument
72
67
73
68
opt = Flux.setup(Adam(), model) # explicit setup of optimiser momenta
74
69
```
75
- ...calling `train!(loss3, model, data, opt)` runs a loop much like this:
70
+ ...calling `Flux.train!(loss3, model, data, opt)` runs a loop much like this,
71
+ using Zygote's "explicit" mode for the gradient:
76
72
```
77
73
for d in data
78
- ∂L∂m = Zygote. gradient(loss3, model, d...)[1]
79
- Optimisers. update!(opt, model, ∂L∂m)
74
+ ∂L∂m = gradient(loss3, model, d...)[1]
75
+ update!(opt, model, ∂L∂m) # method for "explicit" gradient
80
76
end
81
77
```
82
78
You can also write this loop yourself, if you need more flexibility.
83
- Besides the loop, `train!` will:
79
+ For this reason `train!` is not highly extensible.
80
+ It adds only a few features to the loop above:
84
81
85
82
* Stop with a `DomainError` if the loss is infinite or `NaN` at any point.
86
83
@@ -91,20 +88,36 @@ Besides the loop, `train!` will:
91
88
Note that the built-in loss functions accept 3 arguments, allowing for instance
92
89
`train!(Flux.Losses.mse, model, data, opt)` instead of defining `loss3` as above.
93
90
94
- Note that callback functions are not supported. But arbitrary code can be inserted into the loop.
91
+ !!! note
92
+ This method has significant changes from the one in Flux ≤ 0.13:
93
+ * It now takes the `model` itself, not the result of [`Flux.params`](@ref).
94
+ (This is to move away from Zygote's "implicit" parameter handling, with `Grads`.)
95
+ * Instead of `loss` being a function which typically accepts two arguments
96
+ (the input `x` and expected output `y` from each element of `data`)
97
+ now it should typically accept three, the first of which is the `model` itself.
98
+ * `data` must iterate tuples, otherwise you get an error.
99
+ (Previously non-tuple types were not splatted into the loss.
100
+ Pass in `((d,) for d in data)` to simulate this.)
101
+ * `opt` should be the result of [`Flux.setup`](@ref). Using an optimiser
102
+ such as `Adam()` without this step should give you a warning.
103
+ * Callback functions are not supported.
104
+ But any code can be included in the above `for` loop.
95
105
"""
96
function train!(loss, model, data, opt; cb = nothing)
  # `cb` is accepted only so that old-style calls passing a callback fail
  # with an explanatory message rather than an unknown-keyword error.
  isnothing(cb) || error("""train! does not support callback functions.
                            For more control use a loop with `gradient` and `update!`.""")
  losses = Float32[]  # one loss value is pushed per element of `data`
  @withprogress for (i, d) in enumerate(data)
    # Each element is splatted into the loss, so it must be a tuple.
    d isa Tuple || error("""train! expects as data an iterator producing tuples, but got $(typeof(d)).
                            Pass it `((d,) for d in data)`, or use `gradient` and `update!` for more control.""")
    # Differentiate a closure over `d` instead of calling
    # `explicit_withgradient(loss, model, d...)`: the latter also un-thunks
    # the gradient with respect to the data, which is never used here.
    l, (g, _...) = explicit_withgradient(m -> loss(m, d...), model)
    # Stop on an infinite or NaN loss; the offending value is carried in the
    # exception's `val` field, with the message alongside it.
    isfinite(l) || throw(DomainError(l, "loss function returned $l, stopping training"))
    opt, model = Optimisers.update!(opt, model, g)
    push!(losses, l)
    # Report fractional progress only when `data` has a known length.
    @logprogress Base.haslength(data) ? i / length(data) : nothing
  end
  # Not entirely sure returning losses is a good idea, as it may conflict with
  # later returning immutable models à la Optimisers.jl
  return losses
end
109
122
110
123
# This method lets you use Optimisers.Descent() without setup, when there is no state
0 commit comments