There are a number of tricks one needs to know to use PyType with JAX, especially in combination with Flax. For example, a PyTree needs to be annotated as `Any`.
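To illustrate the `Any` trick mentioned above, here is a minimal sketch. The `PyTree` alias and the `scale_tree` helper are hypothetical names chosen for this example, not part of JAX; the only JAX calls used are `jax.tree_util.tree_map` and `jax.numpy`.

```python
from typing import Any

import jax
import jax.numpy as jnp

# Hypothetical alias: a PyTree can be an arbitrarily nested container of
# arrays, so it is annotated as Any and the alias only documents intent.
PyTree = Any


def scale_tree(tree: PyTree, factor: float) -> PyTree:
    """Multiply every leaf of a PyTree by `factor` (illustrative helper)."""
    return jax.tree_util.tree_map(lambda leaf: leaf * factor, tree)


# Example PyTree: a dict of arrays, as is common for model parameters.
params = {"w": jnp.ones((2, 3)), "b": jnp.zeros((3,))}
scaled = scale_tree(params, 2.0)
```

The alias gives no extra checking over plain `Any`, but it makes signatures easier to read and leaves one place to tighten the type later if a more precise annotation becomes practical.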
Since it's very common to want to use PyType with JAX, it would be useful to have a section of the JAX documentation summarizing these tricks and best practices. I'm not sure what the best way is to handle the JAX/Flax interactions, but it's important for someone to figure out how to document those best practices too.