Skip to content

Conversation

@danielpmorton
Copy link

If "jax_enable_x64" is True, the slice_indices passed to dynamic_update_slice have mismatched dtypes (a mix of int64 and int32), which raises a TypeError. Wrapping the indices in an array gives them a single common dtype and fixes the error.

See associated traceback:

Traceback (most recent call last):
  File "/home/dmorton/gemma/scripts/demo.py", line 33, in <module>
    out = sampler.sample(prompt, max_new_tokens=1000, rng=key)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dmorton/gemma/gemma/gm/text/_sampler.py", line 311, in sample
    init_state = _prefill.prefill(
                 ^^^^^^^^^^^^^^^^^
  File "/home/dmorton/gemma/gemma/gm/text/_prefill.py", line 110, in prefill
    out = model.apply(
          ^^^^^^^^^^^^
  File "/home/dmorton/.pyenv/versions/gemma/lib/python3.11/site-packages/kauldron/utils/train_property.py", line 141, in decorated
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/dmorton/gemma/gemma/gm/utils/_jax_utils.py", line 96, in decorated
    output = fn(*bound_args.args, **bound_args.kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dmorton/.pyenv/versions/gemma/lib/python3.11/site-packages/kauldron/typing/type_check.py", line 270, in _reraise_with_shape_info
    retval = fn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^
  File "/home/dmorton/gemma/gemma/gm/nn/_transformer.py", line 322, in __call__
    x, new_cache = self._apply_attention(inputs, cache)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dmorton/gemma/gemma/gm/nn/_transformer.py", line 367, in _apply_attention
    layer_cache, x = block(
                     ^^^^^^
  File "/home/dmorton/gemma/gemma/gm/nn/_modules.py", line 468, in __call__
    cache, attn_output = self.attn(
                         ^^^^^^^^^^
  File "/home/dmorton/gemma/gemma/gm/nn/_modules.py", line 233, in __call__
    value_proj = jax.lax.dynamic_update_slice(
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: index arguments to dynamic_update_slice must be integers of the same type, got int64, int32, int64, int64

If `"jax_enable_x64"` is `True`, the `slice_indices` passed to `dynamic_update_slice` have mismatched dtypes (a mix of int64 and int32), which raises a `TypeError`. Wrapping the indices in an array gives them a single common dtype and fixes the error.

See associated traceback:
```
Traceback (most recent call last):
  File "/home/dmorton/gemma/scripts/demo.py", line 33, in <module>
    out = sampler.sample(prompt, max_new_tokens=1000, rng=key)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dmorton/gemma/gemma/gm/text/_sampler.py", line 311, in sample
    init_state = _prefill.prefill(
                 ^^^^^^^^^^^^^^^^^
  File "/home/dmorton/gemma/gemma/gm/text/_prefill.py", line 110, in prefill
    out = model.apply(
          ^^^^^^^^^^^^
  File "/home/dmorton/.pyenv/versions/gemma/lib/python3.11/site-packages/kauldron/utils/train_property.py", line 141, in decorated
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/dmorton/gemma/gemma/gm/utils/_jax_utils.py", line 96, in decorated
    output = fn(*bound_args.args, **bound_args.kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dmorton/.pyenv/versions/gemma/lib/python3.11/site-packages/kauldron/typing/type_check.py", line 270, in _reraise_with_shape_info
    retval = fn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^
  File "/home/dmorton/gemma/gemma/gm/nn/_transformer.py", line 322, in __call__
    x, new_cache = self._apply_attention(inputs, cache)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dmorton/gemma/gemma/gm/nn/_transformer.py", line 367, in _apply_attention
    layer_cache, x = block(
                     ^^^^^^
  File "/home/dmorton/gemma/gemma/gm/nn/_modules.py", line 468, in __call__
    cache, attn_output = self.attn(
                         ^^^^^^^^^^
  File "/home/dmorton/gemma/gemma/gm/nn/_modules.py", line 233, in __call__
    value_proj = jax.lax.dynamic_update_slice(
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: index arguments to dynamic_update_slice must be integers of the same type, got int64, int32, int64, int64
```
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant