Centred RMSProp #51
Conversation
I have a bit of a hard time believing the "no tuning required" part... we could always make a kwarg constructor if it gets annoying. |
Maybe a keyword is best; it's different from the others, and 4 positional arguments is a lot. Probably FluxML/Flux.jl#1778 should match this. Needs a few words before merging. |
Why is FluxML/Flux.jl#1778 using …? Besides this point, is there anything else holding this from merge? |
All that holds this back is needing a sentence or two saying what this option actually does. |
Honestly, I could write it. I'd probably also change how the rule is implemented, since it should look very similar to AdaBelief. I think it's easier if I make another PR? |
Note the package name, with an "s" --- it does not follow US spellings, although the keyword here accepts both. Maybe make suggestions if you have ideas for what to change. |
Besides matters of taste on naming things, this scatter of suggestions makes some logic changes. Can you please write these (and these alone) clearly in one place with an explanation of what & why? Before/after. What source is this following? Etc. Is there a paper with clear formulas? Make it easy for when someone has bandwidth to look closely. |
OK, I agree; that's why I suggested a separate PR. The naming is of course not essential, but the logic change amounts to just this: it now carries an estimate of the variance of the gradient directly, instead of carrying the second moment and then subtracting the squared mean. As far as I understand, that's how it's implemented in Jax: https://github.com/deepmind/optax/blob/b4aa6657bbf79985279dea76eaf6d53b25d7e8d9/optax/_src/transform.py#L247. I can make a separate PR, since I think that'll make it easier to compare. |
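(For concreteness, a minimal sketch of the two accumulations being discussed; the names `centred_step_a`, `centred_step_b`, `quad`, `lin`, `var` are made up for illustration and are not the PR's actual code.)

```julia
# Sketch only: dx is the minibatch gradient; ρ, η, ϵ are decay, learning rate, small constant.

# (a) Difference of squares: track E[g²] and E[g], subtract at update time.
function centred_step_a(dx, quad, lin; ρ = 0.9, η = 0.001, ϵ = 1e-8)
    quad = ρ .* quad .+ (1 - ρ) .* abs2.(dx)             # second-moment estimate
    lin  = ρ .* lin  .+ (1 - ρ) .* dx                    # first-moment estimate
    return η .* dx ./ sqrt.(quad .- abs2.(lin) .+ ϵ), quad, lin   # ϵ inside the sqrt
end

# (b) Variance tracked directly: accumulate the squared deviation from the running mean.
function centred_step_b(dx, var, lin; ρ = 0.9, η = 0.001, ϵ = 1e-8)
    lin = ρ .* lin .+ (1 - ρ) .* dx                      # first-moment estimate
    var = ρ .* var .+ (1 - ρ) .* abs2.(dx .- lin)        # variance estimate
    return η .* dx ./ (sqrt.(var) .+ ϵ), var, lin
end
```

Both versions carry two arrays of state per parameter; the only difference is whether the second array holds the raw second moment or the variance estimate.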
I don't think so. Centered RMSprop was introduced by http://arxiv.org/abs/1308.0850 without discussing details and without giving an implementation. There he gives the formulas (reproduced below), which coincide with the current status of this PR (notation: epsilon = minibatch gradient, n_i = gradient second moment, g_i = gradient first moment). However, in Jax they implement it by estimating the variance of the gradient directly. This is the same thing AdaBelief does (also below). So what I was proposing amounts to setting beta1 = 0 in the AdaBelief pseudo-code. You also need to put epsilon inside the square root, because the variance can get numerically negative. Sorry for the noise with all the scattered suggestions. However, I'm not sure if this is actually better in practice. |
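(For reference, the two sets of formulas referred to above, as I recall them from the cited papers; worth checking against the sources.)

```latex
% Graves (2013), arXiv:1308.0850: centred RMSProp, with \epsilon_i the minibatch gradient,
% n_i the second-moment estimate, g_i the first-moment estimate (constants as in the paper):
\begin{aligned}
n_i &\leftarrow 0.95\, n_i + 0.05\, \epsilon_i^2 \\
g_i &\leftarrow 0.95\, g_i + 0.05\, \epsilon_i \\
\Delta_i &\leftarrow 0.9\, \Delta_i - 10^{-4} \frac{\epsilon_i}{\sqrt{n_i - g_i^2 + 10^{-4}}} \\
w_i &\leftarrow w_i + \Delta_i
\end{aligned}

% AdaBelief (Zhuang et al., 2020): core recurrence without bias correction, with g_t the
% gradient, m_t its moving mean, s_t the variance estimate, \alpha the step size:
\begin{aligned}
m_t &\leftarrow \beta_1 m_{t-1} + (1 - \beta_1)\, g_t \\
s_t &\leftarrow \beta_2 s_{t-1} + (1 - \beta_2)\,(g_t - m_t)^2 + \epsilon \\
\theta_t &\leftarrow \theta_{t-1} - \alpha\, \frac{m_t}{\sqrt{s_t} + \epsilon}
\end{aligned}
```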
Actually, looking at the Optax code more closely, I now think they do this difference-of-squares thing instead of estimating the variance directly. See https://github.com/deepmind/optax/blob/a124552d0fc9f81812cd82da0d22528b7a17a847/optax/_src/transform.py#L247. So that's probably the way to go here too. I have removed my previous suggestions. |
I re-added the suggestions for the name (…). A suggestion for the docstring (I cannot add this as a GitHub suggestion because it's not in the PR):
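(A rough sketch of what such a docstring could say; the signature shown, the keyword name `centred`, and the defaults are assumptions, not the wording actually proposed.)

```julia
"""
    RMSProp(η = 0.001, ρ = 0.9, ϵ = 1e-8; centred = false)

Optimisation rule using the RMSProp algorithm: gradients are divided by a moving
average of their recent magnitude.

If `centred = true` (`centered` is also accepted), the rule additionally keeps a moving
average of the gradient itself and subtracts its square from the moving average of the
squared gradient. The step is then divided by an estimate of the variance of the
gradient, rather than by its uncentred second moment.
"""
```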
Perhaps we can merge this? |
@mcabbott, saw you added a couple of changes. Is there anything left on the docket, or is this good to go? |
Maybe the only Q is whether you give the constructor a verb "centre this" or an adjective "make the centred version". |
PyTorch and TF both use the adjective form, let's go with that. |
Done. But what is wrong with the tests? (Locally fine now.) It's getting IRTools v0.3.3, Zygote v0.4.20, maybe because of Compat v4.1.0 |
Parallel to FluxML/Flux.jl#1778.
But if "Parameters other than learning rate generally don't need tuning", then having to type them out to get to the boolean one seems awkward. Cleaner to call it a new optimiser?