Skip to content

Commit b8c6192

Browse files
committed
use _old_to_new in Optimisers.setup too
1 parent a52463f commit b8c6192

File tree

1 file changed

+26
-15
lines changed

1 file changed

+26
-15
lines changed

src/deprecations.jl

+26-15
Original file line numberDiff line numberDiff line change
@@ -86,29 +86,34 @@ Base.@deprecate_binding ADADelta AdaDelta
8686
#=
8787
# Valid method in Optimise, old implicit style, is:
8888
train!(loss, ps::Params, data, opt::AbstractOptimiser; cb = () -> ())
89+
8990
# Valid methods in Train, new explict style, are:
90-
train!(loss, model, data, opt)
91-
train!(loss, model, data, opt::Optimisers.AbstractRule)
91+
train!(loss, model, data, opt) # preferred
92+
train!(loss, model, data, opt::Optimisers.AbstractRule) # if you forget setup
93+
9294
# Provide friendly errors for what happens if you mix these up:
9395
=#
9496
import .Optimise: train!
95-
train!(loss, ps::Params, data, opt) = error("can't mix implict Params with explict state")
9697

97-
train!(loss, ps::Params, data, opt::Optimisers.AbstractRule) = error("can't mix implict Params with explict rule")
98+
train!(loss, ps::Params, data, opt) = error(
99+
"""can't mix implict Params with explict state!
100+
To use `Flux.params(m)` in `train!`, the 4th argument must be from the old `Flux.Optimise` sub-module.
101+
But better to use the new explicit style, in which `m` itself is the 2nd argument.
102+
""")
98103

99-
train!(loss, model, data, opt::Optimise.AbstractOptimiser) = train!(loss, model, data, _old_to_new(opt))
104+
train!(loss, ps::Params, data, opt::Optimisers.AbstractRule) = error(
105+
"""can't mix implict Params with explict rule from Optimisers.jl
106+
To use `Flux.params(m)` in `train!`, the 4th argument must be from the old `Flux.Optimise` sub-mo$
107+
But better to use the new explicit style, in which `m` itself is the 2nd argument.
108+
""")
100109

101-
# train!(loss::Function, ps::Zygote.Params, data, opt) = throw(ArgumentError(
102-
# """On Flux 0.14, `train!` no longer accepts implicit `Zygote.Params`.
103-
# Instead of `train!(loss_xy, Flux.params(model), data, Adam())`
104-
# it now needs `opt = Flux.setup(Adam(), model); train!(loss_mxy, model, data, opt)`
105-
# where `loss_mxy` accepts the model as its first argument.
106-
# """
107-
# ))
110+
train!(loss, model, data, opt::Optimise.AbstractOptimiser) = train!(loss, model, data, _old_to_new(opt))
108111

109-
# Next, to use the new `setup` with the still-exported old-style Adam etc:
112+
# Next, to use the new `setup` with the still-exported old-style `Adam` etc:
110113
import .Train: setup
111114
setup(rule::Optimise.AbstractOptimiser, model) = setup(_old_to_new(rule), model)
115+
# ... and allow accidental use of `Optimisers.setup` to do the same:
116+
Optimisers.setup(rule::Optimise.AbstractOptimiser, model) = setup(_old_to_new(rule), model)
112117

113118
for T in [:Descent, :Adam, :Momentum, :Nesterov,
114119
:AdaGrad, :AdaMax, :AdaDelta, :AMSGrad, :NAdam, :RAdam, :OAdam, :AdaBelief,
@@ -129,10 +134,16 @@ _old_to_new(rule::RMSProp) = Optimisers.RMSProp(rule.eta, rule.rho, rule.epsilon
129134

130135
_old_to_new(rule) = error("Flux.setup does not know how to translate this old-style implicit rule to a new-style Optimisers.jl explicit rule")
131136

132-
Optimisers.setup(rule::Optimise.AbstractOptimiser, model) = error("please use Flux.setup not Optimisers.setup, it may be able to translate this rule")
133-
134137
# v0.14 deprecations
135138

136139
# Enable these when 0.14 is released, and delete const ClipGrad = Optimise.ClipValue etc:
137140
# Base.@deprecate_binding Optimiser OptimiserChain
138141
# Base.@deprecate_binding ClipValue ClipGrad
142+
143+
# train!(loss::Function, ps::Zygote.Params, data, opt) = throw(ArgumentError(
144+
# """On Flux 0.14, `train!` no longer accepts implicit `Zygote.Params`.
145+
# Instead of `train!(loss_xy, Flux.params(model), data, Adam())`
146+
# it now needs `opt = Flux.setup(Adam(), model); train!(loss_mxy, model, data, opt)`
147+
# where `loss_mxy` accepts the model as its first argument.
148+
# """
149+
# ))

0 commit comments

Comments
 (0)