-
Notifications
You must be signed in to change notification settings - Fork 259
feat(gh-299): Type hints in distributions modules #2032
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
…tion - Added type annotations to various methods and properties in the Distribution class and its subclasses for better type checking and clarity. - Updated the `enable_validation` and `validation_enabled` functions to use type hints. - Enhanced the `arg_constraints`, `support`, and other attributes with appropriate types. - Modified the `__init__` methods of several distribution classes to include type hints for parameters. - Improved the `log_prob`, `sample`, and other methods to specify return types. - Refactored the `clamp_probs` function in the util module to include type annotations.
- Updated the `LeftTruncatedDistribution`, `RightTruncatedDistribution`, and `TwoSidedTruncatedDistribution` classes to include type hints for parameters and return types. - Introduced a new `ConstraintLike` protocol in `numpyro/typing.py` to standardize constraint types across distributions. - Added type hints for the `TruncatedDistribution`, `TruncatedCauchy`, `TruncatedNormal`, `TruncatedPolyaGamma`, `DoublyTruncatedPowerLaw`, and `LowerTruncatedPowerLaw` classes. - Improved type safety and clarity in the distribution methods by specifying expected types for inputs and outputs.
Just curious what jaxtyping offers? |
I prefer Jaxtyping over native types provided by JAX because of the array annotations. Jaxtyping documentation describes it as,
It also has a good systematic way of typing PyTrees, along with variety of annotated types. I have not utilized array annotations in this PR. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@Qazalbash Could we not use jaxtyping? I feel that it is unnecessary to depend on it. Also we might want to allow sample key to be None for some deterministic distributions like Delta or the default Distribution, TransformedDistribution ones.
Sure, we can avoid it! |
All test cases are passing except for those failing on the master branch. I have figured out the problem and reported it. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Beautiful! Thanks for putting lots of efforts on this, @Qazalbash!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM pending the usage of Optional[jax.dtypes.prng_key] at some places.
…ition" This reverts commit 51e2569.
Hi,
I have added type hints using
jaxtyping
innumpyro.distributions.*.py
modules. I have accordingly updated thesetup.py
too.All types of protocols have been transferred to
numpyro._typing
. I have modified theDistributionLike
type along with two new types,TransformLike
andConstraintLike
.This PR is related to #299.