
Add documentation on how to use PyType with JAX (and also common add-on libraries such as Flax) #8224

Open
@billmark


There are a number of tricks one needs to know to use PyType with JAX (especially in combination with Flax). For example, a PyTree needs to be annotated as `Any`.
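A minimal sketch of the PyTree workaround, assuming a recent JAX release that exposes `jax.Array`. The `PyTree` alias and the functions `scale_params` and `l2_norm` are hypothetical names for illustration, not part of JAX's API. The idea is that PyTrees (nested dicts, lists, tuples and registered classes with array leaves) have no convenient static type, so they are annotated as `Any`, while functions that operate on a single leaf can use the concrete `jax.Array` type.

```python
from typing import Any

import jax
import jax.numpy as jnp

# Hypothetical alias: a PyTree can be any nesting of containers with array
# leaves, which is hard to express precisely, so it is typed as Any.
PyTree = Any


def scale_params(params: PyTree, factor: float) -> PyTree:
    """Multiply every leaf of a parameter PyTree by `factor`."""
    return jax.tree_util.tree_map(lambda leaf: leaf * factor, params)


def l2_norm(x: jax.Array) -> jax.Array:
    """Leaf-level functions can use the concrete jax.Array type."""
    return jnp.sqrt(jnp.sum(x ** 2))


# Usage: a dict of arrays is a valid PyTree.
params = {"w": jnp.ones((2, 2)), "b": jnp.zeros(2)}
scaled = scale_params(params, 3.0)
```

Typing the PyTree as `Any` gives up checking of its structure, but avoids spurious errors on the many container shapes JAX accepts.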

Since many users want to type-check JAX code with PyType, it would be useful to have a section of the JAX documentation summarizing these tricks and best practices. I'm not sure how best to handle the JAX/Flax interactions, but someone should work out and document those best practices too.


Labels: contributions welcome, documentation
