Skip to content

Commit

Permalink
add a d to keyword
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed May 23, 2022
1 parent dbe3aa0 commit a30b3c9
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 10 deletions.
18 changes: 9 additions & 9 deletions src/rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ function apply!(o::Nesterov, state, x, dx)
end

"""
RMSProp(η = 1f-3, ρ = 9f-1, ϵ = eps(typeof(η)); centre = false)
RMSProp(η = 1f-3, ρ = 9f-1, ϵ = eps(typeof(η)); centred = false)
Optimizer using the
[RMSProp](https://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
Expand All @@ -101,26 +101,26 @@ gradients by an estimate their variance, instead of their second moment.
prominent direction, in effect dampening oscillations.
- Machine epsilon (`ϵ`): Constant to prevent division by zero
(no need to change default)
- Keyword `centre` (or `center`): Indicates whether to use centred variant
of the algorithm.
- Keyword `centred` (or `centered`): Indicates whether to use centred variant
of the algorithm.
"""
struct RMSProp{T}
eta::T
rho::T
epsilon::T
centre::Bool
centred::Bool
end
RMSProp= 1f-3, ρ = 9f-1, ϵ = eps(typeof(η)); centre = false, center = false) =
RMSProp{typeof(η)}(η, ρ, ϵ, centre | center)
RMSProp= 1f-3, ρ = 9f-1, ϵ = eps(typeof(η)); centred::Bool = false, centered::Bool = false) =
RMSProp{typeof(η)}(η, ρ, ϵ, centred | centered)

init(o::RMSProp, x::AbstractArray) = (zero(x), o.centre ? zero(x) : false)
init(o::RMSProp, x::AbstractArray) = (zero(x), o.centred ? zero(x) : false)

function apply!(o::RMSProp, state, x, dx)
η, ρ, ϵ = o.eta, o.rho, o.epsilon
quad, lin = state

@.. quad = ρ * quad + (1 - ρ) * abs2(dx)
if o.centre
if o.centred
@.. lin = ρ * lin + (1 - ρ) * dx
end
dx′ = @lazy dx * η / (sqrt(quad - abs2(lin)) + ϵ)
Expand All @@ -132,7 +132,7 @@ function Base.show(io::IO, o::RMSProp)
show(io, typeof(o))
print(io, "(")
join(io, [o.eta, o.rho, o.epsilon], ", ")
print(io, "; centre = ", o.centre, ")")
print(io, "; centred = ", o.centred, ")")
end

"""
Expand Down
2 changes: 1 addition & 1 deletion test/rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ RULES = [

name(o) = typeof(o).name.name # just for printing testset headings
name(o::OptimiserChain) = join(name.(o.opts), "")
name(o::RMSProp) = o.centre ? "RMSProp(centre = true)" : :RMSProp
name(o::RMSProp) = o.centred ? "RMSProp(centred = true)" : "RMSProp"

LOG = Dict() # for debugging these testsets, this makes it easy to plot each optimiser's loss

Expand Down

0 comments on commit a30b3c9

Please sign in to comment.