-
Notifications
You must be signed in to change notification settings - Fork 249
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Pass value arg to optax, allowing use of reduce_on_plateau #1974
base: master
Are you sure you want to change the base?
Conversation
@fehiepsi Let me know what you think! I haven't tested this code, and I might need some help with adding types as well. |
numpyro/optim.py
Outdated
@@ -76,7 +76,7 @@ def update(self, g: _Params, state: _IterOptState) -> _IterOptState: | |||
:return: new optimizer state after the update. | |||
""" | |||
i, opt_state = state | |||
opt_state = self.update_fn(i, g, opt_state) | |||
opt_state = self.update_fn(i, g, opt_state, value=value) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can introduce an attribute update_with_value
(false by default) to control the behavior here
if self.update_with_value=True:
opt_state = self.update_fn(i, g, opt_state, value)
else:
opt_state = self.update_fn(i, g, opt_state)
you can then use the typing
self.update_fn: Union[Callable[[ArrayLike, _Params, _OptState], _OptState], Callable[[ArrayLike, _Params, _OptState, ArrayLike], _OptState]]
In optax optimizer, you can set it to True
numpyro_optim = _NumPyroOptim(lambda x, y, z: (x, y, z), init_fn, update_fn, get_params_fn)
numpyro_optim.update_with_value = True
return numpyro_optim
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@fehiepsi I can't easily get this typing to pass mypy. I'd need to use some kind of custom TypeGuard
to narrow this Union type down to the specific Callable signature at runtime before it is called. Do you want me to do that? Or would you rather I changed this to simply Callable[..., _OptState]
?
@fehiepsi thanks for your comments, want to take another look? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you add a simple test for this?
Addresses #1955 using the approach from this comment: #1955 (comment)