Skip to content

feat(gh-299): Type hints in distributions modules #2032

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Jun 3, 2025

Conversation

Qazalbash
Copy link
Contributor

@Qazalbash Qazalbash commented May 24, 2025

Hi,

I have added type hints using jaxtyping in numpyro.distributions.*.py modules. I have accordingly updated the setup.py too.

All types of protocols have been transferred to numpyro._typing. I have modified the DistributionLike type along with two new types, TransformLike and ConstraintLike.


This PR is related to #299.

Qazalbash added 4 commits May 23, 2025 22:33
…tion

- Added type annotations to various methods and properties in the Distribution class and its subclasses for better type checking and clarity.
- Updated the `enable_validation` and `validation_enabled` functions to use type hints.
- Enhanced the `arg_constraints`, `support`, and other attributes with appropriate types.
- Modified the `__init__` methods of several distribution classes to include type hints for parameters.
- Improved the `log_prob`, `sample`, and other methods to specify return types.
- Refactored the `clamp_probs` function in the util module to include type annotations.
- Updated the `LeftTruncatedDistribution`, `RightTruncatedDistribution`, and `TwoSidedTruncatedDistribution` classes to include type hints for parameters and return types.
- Introduced a new `ConstraintLike` protocol in `numpyro/typing.py` to standardize constraint types across distributions.
- Added type hints for the `TruncatedDistribution`, `TruncatedCauchy`, `TruncatedNormal`, `TruncatedPolyaGamma`, `DoublyTruncatedPowerLaw`, and `LowerTruncatedPowerLaw` classes.
- Improved type safety and clarity in the distribution methods by specifying expected types for inputs and outputs.
@fehiepsi
Copy link
Member

Just curious what jaxtyping offers?

@Qazalbash
Copy link
Contributor Author

Qazalbash commented May 26, 2025

I prefer Jaxtyping over native types provided by JAX because of the array annotations. Jaxtyping documentation describes it as,

The shape and dtypes of arrays can be annotated in the form dtype[array, shape], such as Float[Array, "batch channels"].

It also has a good systematic way of typing PyTrees, along with variety of annotated types.


I have not utilized array annotations in this PR.

Copy link
Member

@fehiepsi fehiepsi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Qazalbash Could we not use jaxtyping? I feel that it is unnecessary to depend on it. Also we might want to allow sample key to be None for some deterministic distributions like Delta or the default Distribution, TransformedDistribution ones.

@Qazalbash
Copy link
Contributor Author

Sure, we can avoid it!

@Qazalbash
Copy link
Contributor Author

All test cases are passing except for those failing on the master branch. I have figured out the problem and reported it.

Copy link
Member

@fehiepsi fehiepsi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Beautiful! Thanks for putting lots of efforts on this, @Qazalbash!

Copy link
Member

@fehiepsi fehiepsi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM pending the usage of Optional[jax.dtypes.prng_key] at some places.

@fehiepsi fehiepsi merged commit 4c505c1 into pyro-ppl:master Jun 3, 2025
9 checks passed
@Qazalbash Qazalbash deleted the type-hint-distribution branch June 3, 2025 15:07
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants