Running uv run --python 3.11 bug.py succeeds without error.
Running uv run --python 3.11 --with jax-metal bug.py produces:
tomasz@tomaszkalinows-WQVX deep_learning_with_r_3e % uv run --python 3.11 --with jax-metal bug2.py
Installed 9 packages in 28ms
normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.
WARNING:2025-03-12 07:33:22,783:jax._src.xla_bridge:997: Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
W0000 00:00:1741779202.784024 79609 mps_client.cc:510] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M4 Max
systemMemory: 128.00 GB
maxCacheSize: 48.00 GB
I0000 00:00:1741779202.802302 79609 service.cc:145] XLA service 0x600003880600 initialized for platform METAL (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1741779202.802324 79609 service.cc:153] StreamExecutor device (0): Metal, <undefined>
I0000 00:00:1741779202.803547 79609 mps_client.cc:406] Using Simple allocator.
I0000 00:00:1741779202.803557 79609 mps_client.cc:384] XLA backend will use up to 103078739968 bytes on device 0 for SimpleAllocator.
Traceback (most recent call last):
  File "/Users/tomasz/github/t-kalinowski/deep_learning_with_r_3e/bug2.py", line 22, in <module>
    tokens = tokenizer.tokenize("The quick brown fox.")
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tomasz/.cache/uv/environments-v2/bug2-977c5b9b87b33b67/lib/python3.11/site-packages/keras_hub/src/utils/tensor_utils.py", line 50, in wrapper
    return convert_preprocessing_outputs(x)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tomasz/.cache/uv/environments-v2/bug2-977c5b9b87b33b67/lib/python3.11/site-packages/keras_hub/src/utils/tensor_utils.py", line 191, in convert_preprocessing_outputs
    return keras.tree.map_structure(convert, x)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tomasz/.cache/uv/environments-v2/bug2-977c5b9b87b33b67/lib/python3.11/site-packages/keras/src/tree/tree_api.py", line 192, in map_structure
    return tree_impl.map_structure(func, *structures)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tomasz/.cache/uv/environments-v2/bug2-977c5b9b87b33b67/lib/python3.11/site-packages/keras/src/tree/optree_impl.py", line 108, in map_structure
    return optree.tree_map(
           ^^^^^^^^^^^^^^^^
  File "/Users/tomasz/.cache/uv/environments-v2/bug2-977c5b9b87b33b67/lib/python3.11/site-packages/optree/ops.py", line 766, in tree_map
    return treespec.unflatten(map(func, *flat_args))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tomasz/.cache/uv/environments-v2/bug2-977c5b9b87b33b67/lib/python3.11/site-packages/keras_hub/src/utils/tensor_utils.py", line 189, in convert
    return ops.convert_to_tensor(x, dtype=dtype)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tomasz/.cache/uv/environments-v2/bug2-977c5b9b87b33b67/lib/python3.11/site-packages/keras/src/ops/core.py", line 958, in convert_to_tensor
    return backend.core.convert_to_tensor(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tomasz/.cache/uv/environments-v2/bug2-977c5b9b87b33b67/lib/python3.11/site-packages/keras/src/backend/jax/core.py", line 80, in convert_to_tensor
    return jnp.asarray(x, dtype=dtype)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tomasz/.cache/uv/archive-v0/Ux38duHD1OrEs5NMTltaO/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py", line 5732, in asarray
    return array(a, dtype=dtype, copy=bool(copy), order=order, device=device)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tomasz/.cache/uv/archive-v0/Ux38duHD1OrEs5NMTltaO/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py", line 5566, in array
    out_array: Array = lax_internal._convert_element_type(
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tomasz/.cache/uv/archive-v0/Ux38duHD1OrEs5NMTltaO/lib/python3.11/site-packages/jax/_src/lax/lax.py", line 1414, in _convert_element_type
    return convert_element_type_p.bind(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tomasz/.cache/uv/archive-v0/Ux38duHD1OrEs5NMTltaO/lib/python3.11/site-packages/jax/_src/core.py", line 502, in bind
    return self._true_bind(*args, **params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tomasz/.cache/uv/archive-v0/Ux38duHD1OrEs5NMTltaO/lib/python3.11/site-packages/jax/_src/core.py", line 520, in _true_bind
    return self.bind_with_trace(prev_trace, args, params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tomasz/.cache/uv/archive-v0/Ux38duHD1OrEs5NMTltaO/lib/python3.11/site-packages/jax/_src/lax/lax.py", line 4371, in _convert_element_type_bind_with_trace
    operand = core.Primitive.bind_with_trace(convert_element_type_p, trace, args, params)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tomasz/.cache/uv/archive-v0/Ux38duHD1OrEs5NMTltaO/lib/python3.11/site-packages/jax/_src/core.py", line 525, in bind_with_trace
    return trace.process_primitive(self, args, params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tomasz/.cache/uv/archive-v0/Ux38duHD1OrEs5NMTltaO/lib/python3.11/site-packages/jax/_src/core.py", line 1024, in process_primitive
    return primitive.impl(*args, **params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tomasz/.cache/uv/archive-v0/Ux38duHD1OrEs5NMTltaO/lib/python3.11/site-packages/jax/_src/dispatch.py", line 90, in apply_primitive
    outs = fun(*args)
           ^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: UNIMPLEMENTED: default_memory_space is not supported.
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
I0000 00:00:1741779203.113335 79609 mps_client.h:209] MetalClient destroyed.
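The innermost Keras frame in the traceback above is `jnp.asarray(x, dtype=dtype)` in keras/src/backend/jax/core.py (line 80), so the failure occurs while the tokenizer's host-side output is being converted into a JAX array. The sketch below exercises that same call without keras_hub, which may help tell whether the problem lies in jax-metal itself rather than in SentencePieceTokenizer. The input values and the int32 dtype are hypothetical stand-ins, since the report does not show what the tokenizer actually returned, and it is an assumption, not a confirmed result, that this call alone reproduces the error under jax-metal.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the tokenizer output: a small batch of integer
# token ids in host (NumPy) memory. The real values are not known.
token_ids = np.array([[1, 2, 3, 4]], dtype=np.int32)

# The same call Keras's JAX backend makes in convert_to_tensor. In the report,
# this path reached lax._convert_element_type and, under jax-metal, raised
# "UNIMPLEMENTED: default_memory_space is not supported."
arr = jnp.asarray(token_ids, dtype="int32")

# Reports which backend handled the conversion ("cpu" on a plain JAX install).
print(jax.default_backend(), arr.dtype, arr.shape)
```

Running it once with `uv run --python 3.11` and once with `uv run --python 3.11 --with jax-metal`, mirroring the two commands in this report, should show whether the bare conversion already fails on Metal.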
Describe the bug
On an Arm Mac with the JAX backend, if jax-metal is installed, SentencePieceTokenizer throws exceptions.
To Reproduce
Given a file bug.py, running uv run --python 3.11 bug.py succeeds without error, while running uv run --python 3.11 --with jax-metal bug.py fails with the XlaRuntimeError traceback shown above (UNIMPLEMENTED: default_memory_space is not supported).
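If the goal is only to keep the script running on a machine where jax-metal happens to be installed, one possible workaround is to force JAX onto its CPU backend through the standard `JAX_PLATFORMS` environment variable. This is a sketch under that assumption and has not been checked against bug.py itself; it sidesteps the experimental METAL platform rather than fixing it.

```python
import os

# Restrict JAX to the CPU platform. This must be set before jax (or Keras /
# keras_hub with the JAX backend) is imported for the first time, because
# JAX picks up JAX_PLATFORMS from the environment at import time.
os.environ["JAX_PLATFORMS"] = "cpu"

import jax  # noqa: E402  (deliberately imported after setting the variable)

# With only the CPU platform allowed, the METAL plugin is not used.
print(jax.default_backend())  # prints "cpu"
```

The same effect should be reachable from the shell without editing the script, e.g. `JAX_PLATFORMS=cpu uv run --python 3.11 --with jax-metal bug.py`, at the cost of losing Metal GPU acceleration.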