From a30b3c9872fa809ecc0708905892673782f664a7 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Mon, 23 May 2022 12:53:06 -0400 Subject: [PATCH] add a d to keyword --- src/rules.jl | 18 +++++++++--------- test/rules.jl | 2 +- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/rules.jl b/src/rules.jl index 872f1819..0ec856f0 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -84,7 +84,7 @@ function apply!(o::Nesterov, state, x, dx) end """ - RMSProp(η = 1f-3, ρ = 9f-1, ϵ = eps(typeof(η)); centre = false) + RMSProp(η = 1f-3, ρ = 9f-1, ϵ = eps(typeof(η)); centred = false) Optimizer using the [RMSProp](https://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf) @@ -101,26 +101,26 @@ gradients by an estimate their variance, instead of their second moment. prominent direction, in effect dampening oscillations. - Machine epsilon (`ϵ`): Constant to prevent division by zero (no need to change default) -- Keyword `centre` (or `center`): Indicates whether to use centred variant - of the algorithm. +- Keyword `centred` (or `centered`): Indicates whether to use centred variant + of the algorithm. """ struct RMSProp{T} eta::T rho::T epsilon::T - centre::Bool + centred::Bool end -RMSProp(η = 1f-3, ρ = 9f-1, ϵ = eps(typeof(η)); centre = false, center = false) = - RMSProp{typeof(η)}(η, ρ, ϵ, centre | center) +RMSProp(η = 1f-3, ρ = 9f-1, ϵ = eps(typeof(η)); centred::Bool = false, centered::Bool = false) = + RMSProp{typeof(η)}(η, ρ, ϵ, centred | centered) -init(o::RMSProp, x::AbstractArray) = (zero(x), o.centre ? zero(x) : false) +init(o::RMSProp, x::AbstractArray) = (zero(x), o.centred ? zero(x) : false) function apply!(o::RMSProp, state, x, dx) η, ρ, ϵ = o.eta, o.rho, o.epsilon quad, lin = state @.. quad = ρ * quad + (1 - ρ) * abs2(dx) - if o.centre + if o.centred @.. lin = ρ * lin + (1 - ρ) * dx end dx′ = @lazy dx * η / (sqrt(quad - abs2(lin)) + ϵ) @@ -132,7 +132,7 @@ function Base.show(io::IO, o::RMSProp) show(io, typeof(o)) print(io, "(") join(io, [o.eta, o.rho, o.epsilon], ", ") - print(io, "; centre = ", o.centre, ")") + print(io, "; centred = ", o.centred, ")") end """ diff --git a/test/rules.jl b/test/rules.jl index f9652a31..1a3835d1 100644 --- a/test/rules.jl +++ b/test/rules.jl @@ -22,7 +22,7 @@ RULES = [ name(o) = typeof(o).name.name # just for printing testset headings name(o::OptimiserChain) = join(name.(o.opts), " → ") -name(o::RMSProp) = o.centre ? "RMSProp(centre = true)" : :RMSProp +name(o::RMSProp) = o.centred ? "RMSProp(centred = true)" : "RMSProp" LOG = Dict() # for debugging these testsets, this makes it easy to plot each optimiser's loss