Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pass value arg to optax, allowing use of reduce_on_plateau #1974

Open
wants to merge 3 commits into
base: master
Choose a base branch
from

Conversation

zmbc
Copy link

@zmbc zmbc commented Feb 12, 2025

Addresses #1955 using the approach from this comment: #1955 (comment)

@zmbc
Copy link
Author

zmbc commented Feb 12, 2025

@fehiepsi Let me know what you think! I haven't tested this code, and I might need some help with adding types as well.

numpyro/optim.py Outdated
@@ -76,7 +76,7 @@ def update(self, g: _Params, state: _IterOptState) -> _IterOptState:
:return: new optimizer state after the update.
"""
i, opt_state = state
opt_state = self.update_fn(i, g, opt_state)
opt_state = self.update_fn(i, g, opt_state, value=value)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can introduce an attribute update_with_value (false by default) to control the behavior here

if self.update_with_value=True:
    opt_state = self.update_fn(i, g, opt_state, value)
else:
    opt_state = self.update_fn(i, g, opt_state)

you can then use the typing

self.update_fn: Union[Callable[[ArrayLike, _Params, _OptState], _OptState], Callable[[ArrayLike, _Params, _OptState, ArrayLike], _OptState]]

In optax optimizer, you can set it to True

numpyro_optim = _NumPyroOptim(lambda x, y, z: (x, y, z), init_fn, update_fn, get_params_fn)
numpyro_optim.update_with_value = True
return numpyro_optim

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@fehiepsi I can't easily get this typing to pass mypy. I'd need to use some kind of custom TypeGuard to narrow this Union type down to the specific Callable signature at runtime before it is called. Do you want me to do that? Or would you rather I changed this to simply Callable[..., _OptState]?

@zmbc
Copy link
Author

zmbc commented Feb 20, 2025

@fehiepsi thanks for your comments, want to take another look?

Copy link
Member

@fehiepsi fehiepsi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add a simple test for this?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants