'arith.index_cast' op operand #0 must be signless-integer-like or memref of signless-integer, but got 'f32' #26034

Open
Akshat-Tripathi opened this issue Jan 22, 2025 · 2 comments
Labels: bug

Description

Hi, I was following the Pallas quickstart guide and tried modifying the iota kernel to produce fp32 results, but I got the following error:

Traceback (most recent call last):
  File "/mnt/ssd0/anaconda3/envs/vllm/lib/python3.10/site-packages/jax/_src/pallas/mosaic/lowering.py", line 732, in lower_jaxpr_to_func
    body.func_op.verify()
jaxlib.mlir._mlir_libs._site_initialize.<locals>.MLIRError: Verification failed:
error: "/swap"(callsite("grid_filler_kernel"("/mnt/ssd0/pallas_stuff/hello_grid.py":11:0) at callsite("grid_filler"("/mnt/ssd0/pallas_stuff/hello_grid.py":14:0) at "<module>"("/mnt/ssd0/pallas_stuff/hello_grid.py":21:0)))): 'arith.index_cast' op operand #0 must be signless-integer-like or memref of signless-integer, but got 'f32'
note: "/swap"(callsite("grid_filler_kernel"("/mnt/ssd0/pallas_stuff/hello_grid.py":11:0) at callsite("grid_filler"("/mnt/ssd0/pallas_stuff/hello_grid.py":14:0) at "<module>"("/mnt/ssd0/pallas_stuff/hello_grid.py":21:0)))): see current operation: %1 = "arith.index_cast"(%0) : (f32) -> index
The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/mnt/ssd0/pallas_stuff/hello_grid.py", line 21, in <module>
    print(grid_filler(4))
  File "/mnt/ssd0/pallas_stuff/hello_grid.py", line 14, in grid_filler
    return pl.pallas_call(
  File "/mnt/ssd0/anaconda3/envs/vllm/lib/python3.10/site-packages/jax/_src/pallas/pallas_call.py", line 1882, in wrapped
    out_flat = pallas_call_p.bind(
jax._src.source_info_util.JaxStackTraceBeforeTransformation: jax._src.pallas.mosaic.error_handling.VerificationError: Pallas encountered an internal verification error.Please file a bug at https://github.com/jax-ml/jax/issues. Error details: 'arith.index_cast' op operand #0 must be signless-integer-like or memref of signless-integer, but got 'f32'
see current operation: %1 = "arith.index_cast"(%0) : (f32) -> index

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/mnt/ssd0/pallas_stuff/hello_grid.py", line 21, in <module>
    print(grid_filler(4))
  File "/mnt/ssd0/pallas_stuff/hello_grid.py", line 14, in grid_filler
    return pl.pallas_call(
  File "/mnt/ssd0/anaconda3/envs/vllm/lib/python3.10/site-packages/jax/_src/pallas/pallas_call.py", line 1520, in _pallas_call_lowering
    return mlir.lower_per_platform(ctx, "pallas_call",
  File "/mnt/ssd0/anaconda3/envs/vllm/lib/python3.10/site-packages/jax/_src/pallas/pallas_call.py", line 1493, in tpu_lowering
    return mosaic_tpu_backend.pallas_call_tpu_lowering_rule(
  File "/mnt/ssd0/anaconda3/envs/vllm/lib/python3.10/site-packages/jax/_src/pallas/mosaic/pallas_call_registration.py", line 147, in pallas_call_tpu_lowering_rule
    mosaic_module, extra_args = lower_module(for_verification=False)
  File "/mnt/ssd0/anaconda3/envs/vllm/lib/python3.10/site-packages/jax/_src/pallas/mosaic/pallas_call_registration.py", line 142, in lower_module
    return lowering.lower_jaxpr_to_module(
  File "/mnt/ssd0/anaconda3/envs/vllm/lib/python3.10/site-packages/jax/_src/pallas/mosaic/lowering.py", line 550, in lower_jaxpr_to_module
    func_op = lower_jaxpr_to_func(
  File "/mnt/ssd0/anaconda3/envs/vllm/lib/python3.10/site-packages/jax/_src/pallas/mosaic/lowering.py", line 734, in lower_jaxpr_to_func
    raise error_handling.mlir_error_to_verification_error(e) from e
  File "/mnt/ssd0/pallas_stuff/hello_grid.py", line 21, in <module>
    print(grid_filler(4))
  File "/mnt/ssd0/pallas_stuff/hello_grid.py", line 14, in grid_filler
    return pl.pallas_call(
  File "/mnt/ssd0/pallas_stuff/hello_grid.py", line 11, in grid_filler_kernel
    o_ref[i] = i
jax._src.pallas.mosaic.error_handling.VerificationError: Pallas encountered an internal verification error.Please file a bug at https://github.com/jax-ml/jax/issues. Error details: 'arith.index_cast' op operand #0 must be signless-integer-like or memref of signless-integer, but got 'f32'
see current operation: %1 = "arith.index_cast"(%0) : (f32) -> index

I ran the following code.

from functools import partial

import jax
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
import jax.numpy as jnp
import numpy as np

def iota_kernel(o_ref):
    i = pl.program_id(0).astype(o_ref.dtype) # errors
    o_ref[i] = i # .astype(o_ref.dtype) - this works
    
def iota(i: int) -> jax.Array:
    return pl.pallas_call(
        iota_kernel,
        out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM),
        out_shape=jax.ShapeDtypeStruct((i,), jnp.float32),
        grid=(i,)
    )()
    
print(iota(4))

What's strange is that if I move the .astype(o_ref.dtype) down to the next line, so that it applies only to the stored value, the error goes away.

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.36.dev20241122
jaxlib: 0.4.36.dev20241122
numpy:  1.26.4
python: 3.10.16 (main, Dec 11 2024, 16:24:50) [GCC 11.2.0]
device info: TPU v6 lite-1, 1 local devices
process_count: 1
platform: uname_result(system='Linux', node='t1v-n-8ff16d39-w-0', release='6.8.0-1015-gcp', version='#17~22.04.1-Ubuntu SMP Tue Sep  3 16:11:52 UTC 2024', machine='x86_64')

Akshat-Tripathi added the bug label Jan 22, 2025

justinjfu (Collaborator) commented Jan 22, 2025

The error here is on the index_cast (emitted because you are storing to o_ref[i]), not on the cast produced by astype. The error message isn't very clear, but it's essentially saying you can't index into an array with a float.
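
The same distinction can be reproduced off-TPU. The block below is only an analogy, assuming plain NumPy in place of Pallas: the loop variable stands in for pl.program_id(0), the float32 array stands in for o_ref, and the function names are made up for illustration. It shows that casting the index first makes the store fail, while casting only the stored value works.

```python
import numpy as np


def iota_cast_index_first(n: int) -> np.ndarray:
    """Mirrors the failing kernel: the index itself is cast to float32."""
    out = np.zeros((n,), dtype=np.float32)
    for program_id in range(n):  # stands in for pl.program_id(0)
        i = np.int32(program_id).astype(out.dtype)  # i is now a float32 scalar
        out[i] = i  # a float used as an index raises IndexError in NumPy
    return out


def iota_cast_value_only(n: int) -> np.ndarray:
    """Mirrors the working kernel: integer index, float32 stored value."""
    out = np.zeros((n,), dtype=np.float32)
    for program_id in range(n):
        i = np.int32(program_id)  # the index stays an integer
        out[i] = i.astype(out.dtype)  # only the stored value is cast
    return out


try:
    iota_cast_index_first(4)
except IndexError as err:
    print("cast index first:", type(err).__name__)

print(iota_cast_value_only(4))
```

NumPy rejects the float index at runtime with an IndexError, whereas on TPU the same mistake surfaces later, as the MLIR verification failure in the traceback above.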

That being said, this is an error that we should catch in Pallas itself and raise a better error instead of outputting an invalid MLIR program that fails during the verification check.
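
As a rough illustration of what such an early check might look like, here is a minimal sketch. It is not Pallas's actual implementation: check_index_dtype and its message are hypothetical, and it uses NumPy dtype utilities only to decide whether a dtype is an integer type.

```python
import numpy as np


def check_index_dtype(index_dtype, ref_name: str = "o_ref") -> None:
    """Hypothetical early check: reject non-integer index dtypes with a clear message."""
    dtype = np.dtype(index_dtype)
    if not np.issubdtype(dtype, np.integer):
        raise TypeError(
            f"Cannot index {ref_name} with a value of dtype {dtype.name}; "
            "indices must be integers. Cast the stored value, not the index."
        )


check_index_dtype(np.int32)  # an integer index passes silently

try:
    check_index_dtype(np.float32)  # a float index is rejected up front
except TypeError as err:
    print(err)
```

A check of this kind would fail at trace time, pointing at the offending index, rather than at MLIR verification.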

Akshat-Tripathi (Author):

Oh yes, right, I didn't see that.

justinjfu self-assigned this Jan 22, 2025