using ProgressLogging: @progress, @withprogress, @logprogress
import Zygote: Params, gradient, withgradient
# Add methods to Optimisers.jl's function, so that there is just one Flux.update!
# for both explicit and implicit parameters.
import Optimisers.update!
"""
update!(opt, p, g)
update!(opt, ps::Params, gs)
Perform an update step of the parameters `ps` (or the single parameter `p`)
according to optimizer `opt::AbstractOptimiser` and the gradients `gs` (the gradient `g`).
As a result, the parameters are mutated and the optimizer's internal state may change.
The gradient could be mutated as well.
!!! note
This method for implicit `Params` (and `AbstractOptimiser`) will be removed from Flux 0.14.
The explicit method `update!(opt, model, grad)` from Optimisers.jl will remain.
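
# Examples
A minimal sketch of the implicit-parameter update; the `model`, `loss`, `x`, and `y`
below are illustrative placeholders:
```julia
opt = Descent(0.1)
ps = Flux.params(model)
gs = gradient(() -> loss(x, y), ps)
Flux.update!(opt, ps, gs)
```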
"""
function update!(opt::AbstractOptimiser, x::AbstractArray, x̄)
  x̄r = copyto!(similar(x̄), x̄)  # Flux.Optimise assumes it can mutate the gradient. This is not
                               # safe due to aliasing, nor guaranteed to be possible, e.g. Fill.
  x .-= apply!(opt, x, x̄r)
end

function update!(opt::AbstractOptimiser, xs::Params, gs)
  for x in xs
    isnothing(gs[x]) && continue
    update!(opt, x, gs[x])
  end
end

# Callback niceties
call(f, xs...) = f(xs...)
runall(f) = f
runall(fs::AbstractVector) = () -> foreach(call, fs)
struct SkipException <: Exception end
"""
skip()
Call `Flux.skip()` in a callback to indicate when a callback condition is met.
This will trigger the train loop to skip the current data point and not update with the calculated gradient.
!!! note
`Flux.skip()` will be removed from Flux 0.14
# Examples
```julia
cb = function ()
loss() > 1e7 && Flux.skip()
end
```
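
When writing the training loop by hand, the same effect is just `continue`
(a sketch only; `loss`, `ps`, `data`, and `opt` are illustrative):
```julia
for d in data
  l, gs = Flux.withgradient(() -> loss(d...), ps)
  l > 1e7 && continue        # skip this data point instead of calling Flux.skip()
  Flux.update!(opt, ps, gs)
end
```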
"""
function skip()
  Base.depwarn("""Flux.skip() will be removed from Flux 0.14,
    and should be replaced with `continue` in an ordinary `for` loop.""", :skip)
  throw(SkipException())
end

struct StopException <: Exception end
"""
stop()
Call `Flux.stop()` in a callback to indicate when a callback condition is met.
This will trigger the train loop to stop and exit.
!!! note
`Flux.stop()` will be removed from Flux 0.14. It should be replaced with `break` in an ordinary `for` loop.
# Examples
```julia
cb = function ()
accuracy() > 0.9 && Flux.stop()
end
```
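
When writing the training loop by hand, the same effect is just `break`
(a sketch only; `loss`, `accuracy`, `ps`, `data`, and `opt` are illustrative):
```julia
for d in data
  gs = gradient(() -> loss(d...), ps)
  Flux.update!(opt, ps, gs)
  accuracy() > 0.9 && break  # stop training instead of calling Flux.stop()
end
```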
"""
function stop()
  Base.depwarn("""Flux.stop() will be removed from Flux 0.14.
    It should be replaced with `break` in an ordinary `for` loop.""", :stop)
  throw(StopException())
end

# Wrap a single data item in a tuple so it can be splatted into the loss;
# leave already-tupled arguments unchanged.
batchmemaybe(x) = tuple(x)
batchmemaybe(x::Tuple) = x

"""
train!(loss, pars::Params, data, opt::AbstractOptimiser; [cb])
Uses a `loss` function and training `data` to improve the
model's parameters according to a particular optimisation rule `opt`.
!!! note
This method with implicit `Params` will be removed from Flux 0.14.
It should be replaced with the explicit method `train!(loss, model, data, opt)`.
For each `d in data`, first the gradient of the `loss` is computed like this:
```
gradient(() -> loss(d...), pars) # if d isa Tuple
gradient(() -> loss(d), pars) # otherwise
```
Here `pars` is produced by calling [`Flux.params`](@ref) on your model.
(Or just on the layers you want to train, like `train!(loss, params(model[1:end-2]), data, opt)`.)
This is the "implicit" style of parameter handling.
This gradient is then used by optimizer `opt` to update the parameters:
```
update!(opt, pars, grads)
```
The optimiser should be from the `Flux.Optimise` module (see [Optimisers](@ref)).
Different optimisers can be combined using [`Flux.Optimise.Optimiser`](@ref Flux.Optimiser).
This training loop iterates through `data` once.
It will stop with a `DomainError` if the loss is `NaN` or infinite.
You can use [`@epochs`](@ref) to do this several times, or
use for instance `Itertools.ncycle` to make a longer `data` iterator.
## Callbacks
[Callbacks](@ref) are given with the keyword argument `cb`.
For example, this will print "training" every 10 seconds (using [`Flux.throttle`](@ref)):
```
train!(loss, params, data, opt, cb = throttle(() -> println("training"), 10))
```
The callback can call [`Flux.stop`](@ref) to interrupt the training loop.
Multiple callbacks can be passed to `cb` as array.
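
# Examples
A minimal end-to-end sketch of this implicit-parameter method; the model, data,
and loss below are only illustrative:
```julia
using Flux

model = Dense(2 => 1)
data = [(rand(Float32, 2, 16), rand(Float32, 1, 16)) for _ in 1:10]
loss(x, y) = Flux.Losses.mse(model(x), y)
opt = Descent(0.1)

Flux.train!(loss, Flux.params(model), data, opt,
            cb = Flux.throttle(() -> @info("still training"), 10))
```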
"""
function train!(loss, ps::Params, data, opt::AbstractOptimiser; cb = () -> ())
  cb = runall(cb)
  itrsz = Base.IteratorSize(typeof(data))
  # Length of `data` for progress logging; 0 means the length is unknown.
  n = (itrsz == Base.HasLength()) || (itrsz == Base.HasShape{1}()) ? length(data) : 0
  @withprogress for (i, d) in enumerate(data)
    try
      l, gs = withgradient(ps) do
        loss(batchmemaybe(d)...)
      end
      if !isfinite(l)
        throw(DomainError("Loss is $l on data item $i, stopping training"))
      end
      update!(opt, ps, gs)
      cb()
    catch ex
      if ex isa StopException
        break
      elseif ex isa SkipException
        continue
      else
        rethrow(ex)
      end
    end
    @logprogress iszero(n) ? nothing : i / n
  end
end

"""
@epochs N body
Run `body` `N` times. Mainly useful for quickly doing multiple epochs of
training in a REPL.
!!! note
The macro `@epochs` will be removed from Flux 0.14. Please just write an ordinary `for` loop.
# Examples
```julia
julia> Flux.@epochs 2 println("hello")
[ Info: Epoch 1
hello
[ Info: Epoch 2
hello
```
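
As the note says, an ordinary loop does the same thing (a sketch of the
suggested replacement):
```julia
for epoch in 1:2
  @info "Epoch $epoch"
  println("hello")
end
```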
"""
macro epochs(n, ex)
  Base.depwarn("""The macro `@epochs` will be removed from Flux 0.14.
    As an alternative, you can write a simple `for i in 1:epochs` loop.""", Symbol("@epochs"), force=true)
  :(@progress for i = 1:$(esc(n))
      @info "Epoch $i"
      $(esc(ex))
    end)
end