SentencePieceTokenizer errors with jax-metal #2135

Open
t-kalinowski opened this issue Mar 12, 2025 · 0 comments
Labels
type:Bug Something isn't working


@t-kalinowski

Describe the bug

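On an Apple Silicon (Arm) Mac using the JAX backend, `SentencePieceTokenizer.tokenize()` raises an exception whenever `jax-metal` is installed: `XlaRuntimeError: UNIMPLEMENTED: default_memory_space is not supported.`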

To Reproduce

Given the following file, `bug.py`:

```python
# /// script
# dependencies = [
#   "keras",
#   "keras-hub",
#   "jax"
# ]
# ///

import os
# Select the JAX backend; this must happen before keras is imported.
os.environ["KERAS_BACKEND"] = "jax"

import keras
import keras_hub

# Download the example SentencePiece model proto.
vocabulary_file = keras.utils.get_file(
    origin="https://huggingface.co/mattdangerw/sentencepiece-example/resolve/main/vocabulary.proto"
)

tokenizer = keras_hub.tokenizers.SentencePieceTokenizer(vocabulary_file)
tokens = tokenizer.tokenize("The quick brown fox.")
print(tokens)
```

Running `uv run --python 3.11 bug.py` succeeds without error.
Running `uv run --python 3.11 --with jax-metal bug.py` fails with the output below (the captured log refers to the file as `bug2.py`):

```
tomasz@tomaszkalinows-WQVX deep_learning_with_r_3e % uv run --python 3.11 --with jax-metal bug2.py
Installed 9 packages in 28ms
normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.
WARNING:2025-03-12 07:33:22,783:jax._src.xla_bridge:997: Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
W0000 00:00:1741779202.784024   79609 mps_client.cc:510] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M4 Max

systemMemory: 128.00 GB
maxCacheSize: 48.00 GB

I0000 00:00:1741779202.802302   79609 service.cc:145] XLA service 0x600003880600 initialized for platform METAL (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1741779202.802324   79609 service.cc:153]   StreamExecutor device (0): Metal, <undefined>
I0000 00:00:1741779202.803547   79609 mps_client.cc:406] Using Simple allocator.
I0000 00:00:1741779202.803557   79609 mps_client.cc:384] XLA backend will use up to 103078739968 bytes on device 0 for SimpleAllocator.
Traceback (most recent call last):
  File "/Users/tomasz/github/t-kalinowski/deep_learning_with_r_3e/bug2.py", line 22, in <module>
    tokens = tokenizer.tokenize("The quick brown fox.")
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tomasz/.cache/uv/environments-v2/bug2-977c5b9b87b33b67/lib/python3.11/site-packages/keras_hub/src/utils/tensor_utils.py", line 50, in wrapper
    return convert_preprocessing_outputs(x)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tomasz/.cache/uv/environments-v2/bug2-977c5b9b87b33b67/lib/python3.11/site-packages/keras_hub/src/utils/tensor_utils.py", line 191, in convert_preprocessing_outputs
    return keras.tree.map_structure(convert, x)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tomasz/.cache/uv/environments-v2/bug2-977c5b9b87b33b67/lib/python3.11/site-packages/keras/src/tree/tree_api.py", line 192, in map_structure
    return tree_impl.map_structure(func, *structures)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tomasz/.cache/uv/environments-v2/bug2-977c5b9b87b33b67/lib/python3.11/site-packages/keras/src/tree/optree_impl.py", line 108, in map_structure
    return optree.tree_map(
           ^^^^^^^^^^^^^^^^
  File "/Users/tomasz/.cache/uv/environments-v2/bug2-977c5b9b87b33b67/lib/python3.11/site-packages/optree/ops.py", line 766, in tree_map
    return treespec.unflatten(map(func, *flat_args))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tomasz/.cache/uv/environments-v2/bug2-977c5b9b87b33b67/lib/python3.11/site-packages/keras_hub/src/utils/tensor_utils.py", line 189, in convert
    return ops.convert_to_tensor(x, dtype=dtype)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tomasz/.cache/uv/environments-v2/bug2-977c5b9b87b33b67/lib/python3.11/site-packages/keras/src/ops/core.py", line 958, in convert_to_tensor
    return backend.core.convert_to_tensor(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tomasz/.cache/uv/environments-v2/bug2-977c5b9b87b33b67/lib/python3.11/site-packages/keras/src/backend/jax/core.py", line 80, in convert_to_tensor
    return jnp.asarray(x, dtype=dtype)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tomasz/.cache/uv/archive-v0/Ux38duHD1OrEs5NMTltaO/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py", line 5732, in asarray
    return array(a, dtype=dtype, copy=bool(copy), order=order, device=device)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tomasz/.cache/uv/archive-v0/Ux38duHD1OrEs5NMTltaO/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py", line 5566, in array
    out_array: Array = lax_internal._convert_element_type(
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tomasz/.cache/uv/archive-v0/Ux38duHD1OrEs5NMTltaO/lib/python3.11/site-packages/jax/_src/lax/lax.py", line 1414, in _convert_element_type
    return convert_element_type_p.bind(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tomasz/.cache/uv/archive-v0/Ux38duHD1OrEs5NMTltaO/lib/python3.11/site-packages/jax/_src/core.py", line 502, in bind
    return self._true_bind(*args, **params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tomasz/.cache/uv/archive-v0/Ux38duHD1OrEs5NMTltaO/lib/python3.11/site-packages/jax/_src/core.py", line 520, in _true_bind
    return self.bind_with_trace(prev_trace, args, params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tomasz/.cache/uv/archive-v0/Ux38duHD1OrEs5NMTltaO/lib/python3.11/site-packages/jax/_src/lax/lax.py", line 4371, in _convert_element_type_bind_with_trace
    operand = core.Primitive.bind_with_trace(convert_element_type_p, trace, args, params)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tomasz/.cache/uv/archive-v0/Ux38duHD1OrEs5NMTltaO/lib/python3.11/site-packages/jax/_src/core.py", line 525, in bind_with_trace
    return trace.process_primitive(self, args, params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tomasz/.cache/uv/archive-v0/Ux38duHD1OrEs5NMTltaO/lib/python3.11/site-packages/jax/_src/core.py", line 1024, in process_primitive
    return primitive.impl(*args, **params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tomasz/.cache/uv/archive-v0/Ux38duHD1OrEs5NMTltaO/lib/python3.11/site-packages/jax/_src/dispatch.py", line 90, in apply_primitive
    outs = fun(*args)
           ^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: UNIMPLEMENTED: default_memory_space is not supported.
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
I0000 00:00:1741779203.113335   79609 mps_client.h:209] MetalClient destroyed.
```
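Because the same script runs cleanly when `jax-metal` is absent, one possible stopgap (untested on this setup) is to keep JAX off the METAL platform. The sketch below assumes that setting JAX's `JAX_PLATFORMS` environment variable to `cpu` before `jax` or `keras` is imported stops the experimental METAL backend from handling the `jnp.asarray` call that fails in the traceback. The helper names `jax_metal_installed` and `select_jax_platform` are hypothetical.

```python
import importlib.metadata
import os


def jax_metal_installed() -> bool:
    """Return True if the jax-metal distribution is installed in this environment."""
    try:
        importlib.metadata.version("jax-metal")
    except importlib.metadata.PackageNotFoundError:
        return False
    return True


def select_jax_platform(env=None) -> str:
    """Hypothetical helper: pin JAX to the CPU backend when jax-metal is present.

    Must run before jax or keras is imported, because JAX reads
    JAX_PLATFORMS when it is first imported.
    """
    env = os.environ if env is None else env
    if jax_metal_installed():
        # Assumption: restricting JAX to "cpu" keeps the METAL plugin from
        # being used. setdefault leaves an explicit user choice untouched.
        env.setdefault("JAX_PLATFORMS", "cpu")
    return env.get("JAX_PLATFORMS", "")


if __name__ == "__main__":
    os.environ["KERAS_BACKEND"] = "jax"
    print("JAX_PLATFORMS =", select_jax_platform() or "(unset)")
    # import keras and keras_hub only after this point
```

Calling `select_jax_platform()` at the top of `bug.py`, before `import keras`, is where it would be tried; this trades the GPU for a working tokenizer rather than fixing the underlying METAL incompatibility.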