Skip to content

Commit 80e1c94

Browse files
committed
Prepare for v0.4.33 release.
This release is branched off the v0.4.32 release, with two changes: a) a fixed libtpu pin, and b) a patch to revert an F64 tanh issue on CPU.
1 parent 1594d2f commit 80e1c94

File tree

5 files changed

+43
-6
lines changed

5 files changed

+43
-6
lines changed

CHANGELOG.md

+22-2
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,24 @@ Remember to align the itemized text with the first line of an item within a list
1010
When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.md.
1111
-->
1212

13-
## jax 0.4.32
13+
## jax 0.4.33
14+
15+
This is a patch release on top of jax 0.4.32, that fixes two bugs found in that
16+
release.
17+
18+
A TPU-only data corruption bug was found in the version of libtpu pinned by
19+
JAX 0.4.32, which manifested only if multiple TPU slices were present in the
20+
same job, for example, if training on multiple v5e slices.
21+
This release fixes that issue by pinning a fixed version of `libtpu`.
22+
23+
## jaxlib 0.4.33
24+
25+
This release fixes an inaccurate result for F64 tanh on CPU (#23590).
26+
27+
## jax 0.4.32 (September 11, 2024)
28+
29+
Note: This release was yanked from PyPi because of a data corruption bug on TPU.
30+
See the 0.4.33 release notes for more details.
1431

1532
* New Functionality
1633
* Added {func}`jax.extend.ffi.ffi_call` and {func}`jax.extend.ffi.ffi_lowering`
@@ -65,7 +82,10 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
6582
The argument to {func}`jax.dlpack.from_dlpack` should be an array from
6683
another framework that implements the ``__dlpack__`` protocol.
6784

68-
## jaxlib 0.4.32
85+
## jaxlib 0.4.32 (September 11, 2024)
86+
87+
Note: This release was yanked from PyPi because of a data corruption bug on TPU.
88+
See the 0.4.33 release notes for more details.
6989

7090
* Breaking changes
7191
* Hermetic CUDA support is added.

jax/version.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import pathlib
2222
import subprocess
2323

24-
_version = "0.4.32"
24+
_version = "0.4.33"
2525
# The following line is overwritten by build scripts in distributions &
2626
# releases. Do not modify this manually, or jax/jaxlib build will fail.
2727
_release_version: str | None = None
@@ -133,7 +133,7 @@ def make_release_tree(self, base_dir, files):
133133

134134

135135
__version__ = _get_version_string()
136-
_minimum_jaxlib_version = "0.4.32"
136+
_minimum_jaxlib_version = "0.4.33"
137137

138138
def _version_as_tuple(version_str):
139139
return tuple(int(i) for i in version_str.split(".") if i.isdigit())

setup.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,10 @@
1919

2020
project_name = 'jax'
2121

22-
_current_jaxlib_version = '0.4.32'
22+
_current_jaxlib_version = '0.4.33'
2323
# The following should be updated after each new jaxlib release.
2424
_latest_jaxlib_version_on_pypi = '0.4.31'
25-
_libtpu_version = '0.1.dev20240911'
25+
_libtpu_version = '0.1.dev20240916'
2626

2727
def load_version_module(pkg_path):
2828
spec = importlib.util.spec_from_file_location(

third_party/xla/tanh.patch

+14
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
diff --git a/xla/service/cpu/llvm_ir_runtime.cc b/xla/service/cpu/llvm_ir_runtime.cc
2+
index 89b40b915caa3..25541c16bfd61 100644
3+
--- a/xla/service/cpu/llvm_ir_runtime.cc
4+
+++ b/xla/service/cpu/llvm_ir_runtime.cc
5+
@@ -410,7 +410,8 @@ void RewriteIRRuntimeFunctions(llvm::Module* module,
6+
rewrite_calls(kTanhV8F32SymbolName, GenerateVF32Tanh, /*vector_width=*/8);
7+
rewrite_calls(kTanhV16F32SymbolName, GenerateVF32Tanh, /*vector_width=*/16);
8+
9+
- rewrite_calls("tanh", GenerateVF64Tanh, /*vector_width=*/1);
10+
+ // TODO(penporn): Re-enable after fixing JAX issue #23590.
11+
+ // rewrite_calls("tanh", GenerateVF64Tanh, /*vector_width=*/1);
12+
13+
rewrite_calls("expf", GenerateVF32Exp, /*vector_width=*/1);
14+
rewrite_calls("llvm.exp.f32", GenerateVF32Exp, /*vector_width=*/1);

third_party/xla/workspace.bzl

+3
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@ def repo():
3030
sha256 = XLA_SHA256,
3131
strip_prefix = "xla-{commit}".format(commit = XLA_COMMIT),
3232
urls = tf_mirror_urls("https://github.com/openxla/xla/archive/{commit}.tar.gz".format(commit = XLA_COMMIT)),
33+
patch_file = [
34+
"//third_party/xla:tanh.patch",
35+
],
3336
)
3437

3538
# For development, one often wants to make changes to the TF repository as well

0 commit comments

Comments
 (0)