Skip to content

Update references to JAX's GitHub repo #25

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ One important difference between `tree_math` and `jax.numpy` is that dot
products in `tree_math` default to full precision on all platforms, rather
than defaulting to bfloat16 precision on TPUs. This is useful for writing most
numerical algorithms, and will likely be JAX's default behavior
[in the future](https://github.com/google/jax/pull/7859).
[in the future](https://github.com/jax-ml/jax/pull/7859).

It would be nice to have a `Matrix` class to make it possible to use tree-math
for numerical algorithms such as
Expand All @@ -86,7 +86,7 @@ feature, please comment on [this GitHub issue](https://github.com/google/tree-ma
Here is how we could write the preconditioned conjugate gradient
method. Notice how similar the implementation is to the [pseudocode from
Wikipedia](https://en.wikipedia.org/wiki/Conjugate_gradient_method#The_preconditioned_conjugate_gradient_method),
unlike the [implementation in JAX](https://github.com/google/jax/blob/b5aea7bc2da4fb5ef96c87a59bfd1486d8958dd7/jax/_src/scipy/sparse/linalg.py#L111-L121).
unlike the [implementation in JAX](https://github.com/jax-ml/jax/blob/b5aea7bc2da4fb5ef96c87a59bfd1486d8958dd7/jax/_src/scipy/sparse/linalg.py#L111-L121).
Both versions support arbitrary pytrees as input:

```python
Expand Down
2 changes: 1 addition & 1 deletion tree_math/_src/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def dot(left, right, *, precision="highest"):
Note that unlike jax.numpy.dot, tree_math.dot defaults to full (highest)
precision. This is more useful for numerical algorithms and will be the
default for jax.numpy in the future:
https://github.com/google/jax/pull/7859
https://github.com/jax-ml/jax/pull/7859

Args:
left: left argument.
Expand Down
Loading