
Centred RMSProp #51

Merged: 6 commits into FluxML:master, May 23, 2022
Conversation

@mcabbott (Member) commented Feb 5, 2022

Parallel to FluxML/Flux.jl#1778.

But if "Parameters other than learning rate generally don't need tuning", then having to type them out to get to the boolean one seems awkward. Cleaner to call it a new optimiser?

@ToucheSir (Member)

I have a bit of a hard time believing the "no tuning required" part... but we could always make a kwarg constructor if it gets annoying.

@mcabbott (Member, Author) commented Feb 5, 2022

Maybe a keyword is best, it's different from the others and 4 positional arguments is a lot. Probably FluxML/Flux.jl#1778 should match this.

Needs a few words before merging.

@cossio (Contributor) commented May 11, 2022

Why centre instead of centered?

FluxML/Flux.jl#1778 uses centered, and so does PyTorch: https://pytorch.org/docs/stable/generated/torch.optim.RMSprop.html.

Besides this point, is there anything else holding this back from merging?

@mcabbott (Member, Author)

centred would be fine.

All that holds this back is needing a sentence or two saying what this option actually does.

@cossio (Contributor) commented May 11, 2022

> centred would be fine.

Honestly, centered seems more common; see for instance https://proceedings.neurips.cc/paper/2021/hash/eddea82ad2755b24c4e168c5fc2ebd40-Abstract.html and http://arxiv.org/abs/2010.07468, and also the Flux PR.

> All that holds this back is needing a sentence or two saying what this option actually does.

I could write it. I'd probably also change how the rule is implemented, since it should look very similar to AdaBelief. I think it's easier if I make another PR?

@mcabbott (Member, Author)

Note the package name, with an "s": it does not follow US spellings, although the keyword here accepts both.

Maybe make suggestions if you have ideas for what to change.

@mcabbott (Member, Author) commented May 11, 2022

Besides matters of taste on naming things, this scatter of suggestions makes some logic changes. Can you please write these (and these alone) clearly in one place with an explanation of what & why? Before/after. What source is this following? Etc.

Is there a paper with clear formulas?

Make it easy for when someone has bandwidth to look closely.

@cossio (Contributor) commented May 11, 2022

> Besides matters of taste on naming things, this scatter of suggestions makes some logic changes. Can you please write these (and these alone) clearly in one place with an explanation of what & why? What source is this following? Etc.

Ok, I agree; that's why I suggested a separate PR. The naming is of course not essential, but the logic changes amount to just this: the rule now carries an estimate of the variance of the gradient directly, instead of carrying the second moment and then subtracting the squared mean. As far as I understand, that's how Jax implements it: https://github.com/deepmind/optax/blob/b4aa6657bbf79985279dea76eaf6d53b25d7e8d9/optax/_src/transform.py#L247.

I can make a separate PR since I think that'll make it easier to compare.
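For concreteness, the two bookkeeping choices can be sketched as follows. This is a hedged illustration, not the Optimisers.jl or optax code: the function names, the decay `rho`, and the zero initialisation are all assumptions made for the example.

```python
import numpy as np

def step_second_moment(n, m, g, rho=0.9):
    """Track the second moment n and the mean m separately; the variance
    is recovered as n - m^2 when the update is applied (optax style)."""
    n = rho * n + (1 - rho) * g**2
    m = rho * m + (1 - rho) * g
    return n, m, n - m**2           # last entry is the variance estimate

def step_variance(s, m, g, rho=0.9):
    """Track a variance estimate s directly, AdaBelief-style with beta1 = 0:
    accumulate the squared deviation from the running mean."""
    m = rho * m + (1 - rho) * g
    s = rho * s + (1 - rho) * (g - m)**2
    return s, m, s                  # s is the variance estimate

rng = np.random.default_rng(0)
n = m1 = s = m2 = 0.0
for _ in range(1000):
    g = rng.normal(loc=0.5)         # gradients with nonzero mean, unit variance
    n, m1, v1 = step_second_moment(n, m1, g)
    s, m2, v2 = step_variance(s, m2, g)

# Both estimates hover near the true gradient variance (1.0 here), but the
# two recursions are not identical, so the resulting optimisers differ slightly.
print(v1, v2)
```

Note the two recursions agree only approximately: the difference-of-squares form subtracts the squared running mean after the fact, while the AdaBelief-style form measures deviations from the running mean as it goes.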

@cossio (Contributor) commented May 11, 2022

> Is there a paper we can follow with clear formulas?

I don't think so. Centered RMSprop was introduced by http://arxiv.org/abs/1308.0850 without discussing details or giving an implementation. Graves states the formulas:

```
nᵢ = 0.95 nᵢ + 0.05 εᵢ²
gᵢ = 0.95 gᵢ + 0.05 εᵢ
Δᵢ = 0.9 Δᵢ − 10⁻⁴ εᵢ / sqrt(nᵢ − gᵢ² + 10⁻⁴)
wᵢ = wᵢ + Δᵢ
```

which coincide with the current status of this PR (notation: εᵢ = minibatch gradient, nᵢ = gradient second moment, gᵢ = gradient first moment).

However in Jax they implement it by estimating the variance of the gradient directly.

This is the same thing AdaBelief does:

```
mₜ = β₁ mₜ₋₁ + (1 − β₁) gₜ
sₜ = β₂ sₜ₋₁ + (1 − β₂)(gₜ − mₜ)² + ε
θₜ = θₜ₋₁ − α mₜ / (sqrt(sₜ) + ε)
```

(bias corrections omitted). So what I was proposing amounts to setting β₁ = 0 in this AdaBelief pseudo-code. You also need to put ε inside the square root, because the variance estimate can become numerically negative.

Sorry for the noise with all the scattered suggestions.

However I'm not sure if this is actually better in practice.

@cossio (Contributor) commented May 11, 2022

Actually, looking at the Optax code more closely, I now think they do this difference-of-squares thing rather than estimating the variance directly. See https://github.com/deepmind/optax/blob/a124552d0fc9f81812cd82da0d22528b7a17a847/optax/_src/transform.py#L247.

So then that's probably the way to go here too. I have removed my previous suggestions.
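A minimal sketch of one update step in that difference-of-squares form, assuming illustrative hyperparameters; `eta`, `rho`, and `eps` here are placeholder values for the example, not the package's defaults:

```python
import numpy as np

def centred_rmsprop_step(w, grad, state, eta=1e-2, rho=0.9, eps=1e-8):
    """One centred-RMSProp step in the difference-of-squares form."""
    n, m = state
    n = rho * n + (1 - rho) * grad**2    # EMA of the second moment
    m = rho * m + (1 - rho) * grad       # EMA of the first moment
    # eps sits inside the sqrt: n - m^2 can dip below zero in floating point
    w = w - eta * grad / np.sqrt(n - m**2 + eps)
    return w, (n, m)

# Usage: one step from w = 5 with gradient 10 and fresh (zero) state.
# Here n = 10.0, m = 1.0, so the step size is 0.01 * 10 / sqrt(9 + 1e-8) ≈ 0.0333.
w1, state = centred_rmsprop_step(5.0, 10.0, (0.0, 0.0))
print(w1)
```

With zero-initialised state, the first denominator is dominated by the raw second moment, so early steps look like plain RMSProp; the centring only matters once the mean estimate has warmed up.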

(Several review suggestions on src/rules.jl and test/rules.jl, now outdated and resolved.)
@cossio (Contributor) commented May 12, 2022

I re-added the suggestions for the name (centre -> centred), without the logic changes.

A suggestion for the docstring (I cannot add this as a GitHub suggestion because it's not in the PR):

```
# Parameters
...
- Centred RMSProp (`centred`): if `false` (default), gradients are normalised by an estimate of their second moment; if `true`, they are normalised by an estimate of their variance instead (http://arxiv.org/abs/1308.0850).
```

Perhaps we can merge this?

@ToucheSir (Member)

@mcabbott saw you added a couple of changes. Is there anything left on the docket, or is this good to go?

@mcabbott (Member, Author)

Maybe the only question is whether you give the constructor a verb ("centre this") or an adjective ("make the centred version").

@ToucheSir (Member)

PyTorch and TF both use the adjective form, let's go with that.

@mcabbott (Member, Author) commented May 23, 2022

Done. But what is wrong with the tests? (Locally fine now.)

It's getting IRTools v0.3.3 and Zygote v0.4.20, maybe because of Compat v4.1.0.

@mcabbott mcabbott merged commit 2bf9e60 into FluxML:master May 23, 2022
@mcabbott mcabbott deleted the centred branch May 23, 2022 17:48
Labels: enhancement (New feature or request)
4 participants