From 2bf9e607b68e24b4e5dd842b2184893f3ff0ecc2 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Mon, 23 May 2022 13:48:54 -0400
Subject: [PATCH] Centred RMSProp (#51)

* Centred RMSProp

* try making this a keyword

* description, show

* add a d to keyword

* fixup

* require recent Zygote
---
 Project.toml  |  1 +
 src/rules.jl  | 32 +++++++++++++++++++++++++-------
 test/rules.jl |  3 +++
 3 files changed, 29 insertions(+), 7 deletions(-)

diff --git a/Project.toml b/Project.toml
index 73fe6fba..ca0011a4 100644
--- a/Project.toml
+++ b/Project.toml
@@ -13,6 +13,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 [compat]
 ChainRulesCore = "1"
 Functors = "0.2.8"
+Zygote = "0.6.40"
 julia = "1.6"
 
 [extras]
diff --git a/src/rules.jl b/src/rules.jl
index 80bf1bb1..0ec856f0 100644
--- a/src/rules.jl
+++ b/src/rules.jl
@@ -84,13 +84,16 @@ function apply!(o::Nesterov, state, x, dx)
 end
 
 """
-    RMSProp(η = 1f-3, ρ = 9f-1, ϵ = eps(typeof(η)))
+    RMSProp(η = 1f-3, ρ = 9f-1, ϵ = eps(typeof(η)); centred = false)
 
 Optimizer using the
 [RMSProp](https://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
 algorithm. Often a good choice for recurrent networks. Parameters other than learning rate
 generally don't need tuning.
 
+[Centred RMSProp](http://arxiv.org/abs/1308.0850) is a variant which normalises
+gradients by an estimate of their variance, instead of their second moment.
+
 # Parameters
 - Learning rate (`η`): Amount by which gradients are discounted before updating
   the weights.
@@ -98,23 +101,38 @@ generally don't need tuning.
   prominent direction, in effect dampening oscillations.
 - Machine epsilon (`ϵ`): Constant to prevent division by zero
   (no need to change default)
+- Keyword `centred` (or `centered`): Indicates whether to use the centred variant
+  of the algorithm.
 """
 struct RMSProp{T}
   eta::T
   rho::T
   epsilon::T
+  centred::Bool
 end
-RMSProp(η = 1f-3, ρ = 9f-1, ϵ = eps(typeof(η))) = RMSProp{typeof(η)}(η, ρ, ϵ)
+RMSProp(η = 1f-3, ρ = 9f-1, ϵ = eps(typeof(η)); centred::Bool = false, centered::Bool = false) =
+  RMSProp{typeof(η)}(η, ρ, ϵ, centred | centered)
 
-init(o::RMSProp, x::AbstractArray) = zero(x)
+init(o::RMSProp, x::AbstractArray) = (zero(x), o.centred ? zero(x) : false)
 
 function apply!(o::RMSProp, state, x, dx)
-  η, ρ, ϵ, acc = o.eta, o.rho, o.epsilon, state
+  η, ρ, ϵ = o.eta, o.rho, o.epsilon
+  quad, lin = state
 
-  @.. acc = ρ * acc + (1 - ρ) * abs2(dx)
-  dx′ = @lazy dx * (η / (sqrt(acc) + ϵ))
+  @.. quad = ρ * quad + (1 - ρ) * abs2(dx)
+  if o.centred
+    @.. lin = ρ * lin + (1 - ρ) * dx
+  end
+  dx′ = @lazy dx * η / (sqrt(quad - abs2(lin)) + ϵ)
 
-  return acc, dx′
+  return (quad, lin), dx′
+end
+
+function Base.show(io::IO, o::RMSProp)
+  show(io, typeof(o))
+  print(io, "(")
+  join(io, [o.eta, o.rho, o.epsilon], ", ")
+  print(io, "; centred = ", o.centred, ")")
 end
 
 """
diff --git a/test/rules.jl b/test/rules.jl
index ae0e58e3..553f31e8 100644
--- a/test/rules.jl
+++ b/test/rules.jl
@@ -14,10 +14,13 @@ RULES = [
   OptimiserChain(ClipNorm(), Adam(0.001)),
   OptimiserChain(ClipGrad(0.5), Momentum()),
   OptimiserChain(WeightDecay(), OAdam(), ClipGrad(1)),
+  # Not the default:
+  RMSProp(centred = true),
 ]
 
 name(o) = typeof(o).name.name  # just for printing testset headings
 name(o::OptimiserChain) = join(name.(o.opts), " → ")
+name(o::RMSProp) = o.centred ? "RMSProp(centred = true)" : :RMSProp
 
 LOG = Dict()  # for debugging these testsets, this makes it easy to plot each optimiser's loss
 
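
The centred variant in apply! above keeps two moving averages, quad ≈ E[g²] (the usual
RMSProp accumulator) and lin ≈ E[g], and divides the gradient by sqrt(quad - lin²), an
estimate of its standard deviation rather than its raw second moment. One step can be
written out-of-place, without the @.. / @lazy in-place machinery, roughly as follows
(a minimal sketch; the function name and style are illustrative, not part of the patch):

    # One centred-RMSProp step on plain arrays (illustrative sketch).
    function centred_rmsprop_step(dx, quad, lin; η = 1f-3, ρ = 9f-1, ϵ = eps(Float32))
        quad = ρ .* quad .+ (1 - ρ) .* abs2.(dx)   # second-moment estimate E[g²]
        lin  = ρ .* lin  .+ (1 - ρ) .* dx          # first-moment estimate E[g]
        # quad - lin² ≈ Var[g], so the step divides by the estimated std. deviation
        dx′ = dx .* η ./ (sqrt.(quad .- abs2.(lin)) .+ ϵ)
        return dx′, quad, lin
    end

When centred = false, init stores lin = false, and abs2(false) == 0 in the broadcast,
so the rule reduces to plain RMSProp without allocating a second array.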
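
Usage via the package's setup/update interface might look like this (the toy model and
gradient here are made up for illustration):

    using Optimisers

    model = (w = rand(Float32, 3),)       # any Functors-compatible container of arrays
    opt = RMSProp(1f-3; centred = true)   # new keyword; centered = true is also accepted
    state = Optimisers.setup(opt, model)

    grad = (w = ones(Float32, 3),)        # stand-in for a real gradient
    state, model = Optimisers.update(state, model, grad)

With the new Base.show method, opt prints with its keyword, along the lines of
RMSProp{Float32}(0.001, 0.9, 1.1920929e-7; centred = true).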