Skip to content

Commit c6d4697

Browse files
jakeharmon8tree-math authors
authored and
tree-math authors
committed
Update references to JAX's GitHub repo
JAX has moved from https://github.com/google/jax to https://github.com/jax-ml/jax PiperOrigin-RevId: 702886847
1 parent f08636e commit c6d4697

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

Diff for: README.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ One important difference between `tree_math` and `jax.numpy` is that dot
7373
products in `tree_math` default to full precision on all platforms, rather
7474
than defaulting to bfloat16 precision on TPUs. This is useful for writing most
7575
numerical algorithms, and will likely be JAX's default behavior
76-
[in the future](https://github.com/google/jax/pull/7859).
76+
[in the future](https://github.com/jax-ml/jax/pull/7859).
7777

7878
It would be nice to have a `Matrix` class to make it possible to use tree-math
7979
for numerical algorithms such as
@@ -86,7 +86,7 @@ feature, please comment on [this GitHub issue](https://github.com/google/tree-ma
8686
Here is how we could write the preconditioned conjugate gradient
8787
method. Notice how similar the implementation is to the [pseudocode from
8888
Wikipedia](https://en.wikipedia.org/wiki/Conjugate_gradient_method#The_preconditioned_conjugate_gradient_method),
89-
unlike the [implementation in JAX](https://github.com/google/jax/blob/b5aea7bc2da4fb5ef96c87a59bfd1486d8958dd7/jax/_src/scipy/sparse/linalg.py#L111-L121).
89+
unlike the [implementation in JAX](https://github.com/jax-ml/jax/blob/b5aea7bc2da4fb5ef96c87a59bfd1486d8958dd7/jax/_src/scipy/sparse/linalg.py#L111-L121).
9090
Both versions support arbitrary pytrees as input:
9191

9292
```python

Diff for: tree_math/_src/vector.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def dot(left, right, *, precision="highest"):
104104
Note that unlike jax.numpy.dot, tree_math.dot defaults to full (highest)
105105
precision. This is more useful for numerical algorithms and will be the
106106
default for jax.numpy in the future:
107-
https://github.com/google/jax/pull/7859
107+
https://github.com/jax-ml/jax/pull/7859
108108
109109
Args:
110110
left: left argument.

0 commit comments

Comments
 (0)