Hi, I was following the Pallas quickstart guide and tried modifying the iota kernel to produce fp32 results. But I got the following error:
Traceback (most recent call last):
File "/mnt/ssd0/anaconda3/envs/vllm/lib/python3.10/site-packages/jax/_src/pallas/mosaic/lowering.py", line 732, in lower_jaxpr_to_func
body.func_op.verify()
jaxlib.mlir._mlir_libs._site_initialize.<locals>.MLIRError: Verification failed:
error: "/swap"(callsite("grid_filler_kernel"("/mnt/ssd0/pallas_stuff/hello_grid.py":11:0) at callsite("grid_filler"("/mnt/ssd0/pallas_stuff/hello_grid.py":14:0) at "<module>"("/mnt/ssd0/pallas_stuff/hello_grid.py":21:0)))): 'arith.index_cast' op operand #0 must be signless-integer-like or memref of signless-integer, but got 'f32'
note: "/swap"(callsite("grid_filler_kernel"("/mnt/ssd0/pallas_stuff/hello_grid.py":11:0) at callsite("grid_filler"("/mnt/ssd0/pallas_stuff/hello_grid.py":14:0) at "<module>"("/mnt/ssd0/pallas_stuff/hello_grid.py":21:0)))): see current operation: %1 = "arith.index_cast"(%0) : (f32) -> index
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/mnt/ssd0/pallas_stuff/hello_grid.py", line 21, in <module>
print(grid_filler(4))
File "/mnt/ssd0/pallas_stuff/hello_grid.py", line 14, in grid_filler
return pl.pallas_call(
File "/mnt/ssd0/anaconda3/envs/vllm/lib/python3.10/site-packages/jax/_src/pallas/pallas_call.py", line 1882, in wrapped
out_flat = pallas_call_p.bind(
jax._src.source_info_util.JaxStackTraceBeforeTransformation: jax._src.pallas.mosaic.error_handling.VerificationError: Pallas encountered an internal verification error.Please file a bug at https://github.com/jax-ml/jax/issues. Error details: 'arith.index_cast' op operand #0 must be signless-integer-like or memref of signless-integer, but got 'f32' see current operation: %1 = "arith.index_cast"(%0) : (f32) -> index
The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
--------------------
The above exception was the direct cause of the following exception:
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/mnt/ssd0/pallas_stuff/hello_grid.py", line 21, in <module>
print(grid_filler(4))
File "/mnt/ssd0/pallas_stuff/hello_grid.py", line 14, in grid_filler
return pl.pallas_call(
File "/mnt/ssd0/anaconda3/envs/vllm/lib/python3.10/site-packages/jax/_src/pallas/pallas_call.py", line 1520, in _pallas_call_lowering
return mlir.lower_per_platform(ctx, "pallas_call",
File "/mnt/ssd0/anaconda3/envs/vllm/lib/python3.10/site-packages/jax/_src/pallas/pallas_call.py", line 1493, in tpu_lowering
return mosaic_tpu_backend.pallas_call_tpu_lowering_rule(
File "/mnt/ssd0/anaconda3/envs/vllm/lib/python3.10/site-packages/jax/_src/pallas/mosaic/pallas_call_registration.py", line 147, in pallas_call_tpu_lowering_rule
mosaic_module, extra_args = lower_module(for_verification=False)
File "/mnt/ssd0/anaconda3/envs/vllm/lib/python3.10/site-packages/jax/_src/pallas/mosaic/pallas_call_registration.py", line 142, in lower_module
return lowering.lower_jaxpr_to_module(
File "/mnt/ssd0/anaconda3/envs/vllm/lib/python3.10/site-packages/jax/_src/pallas/mosaic/lowering.py", line 550, in lower_jaxpr_to_module
func_op = lower_jaxpr_to_func(
File "/mnt/ssd0/anaconda3/envs/vllm/lib/python3.10/site-packages/jax/_src/pallas/mosaic/lowering.py", line 734, in lower_jaxpr_to_func
raise error_handling.mlir_error_to_verification_error(e) from e
File "/mnt/ssd0/pallas_stuff/hello_grid.py", line 21, in <module>
print(grid_filler(4))
File "/mnt/ssd0/pallas_stuff/hello_grid.py", line 14, in grid_filler
return pl.pallas_call(
File "/mnt/ssd0/pallas_stuff/hello_grid.py", line 11, in grid_filler_kernel
o_ref[i] = i
jax._src.pallas.mosaic.error_handling.VerificationError: Pallas encountered an internal verification error.Please file a bug at https://github.com/jax-ml/jax/issues. Error details: 'arith.index_cast' op operand #0 must be signless-integer-like or memref of signless-integer, but got 'f32' see current operation: %1 = "arith.index_cast"(%0) : (f32) -> index
The error here is on the index_cast (since you are storing to o_ref[i]), not on the cast produced by astype. The message isn't very clear, but it's essentially saying you can't index into an array with a float.
That being said, this is an error that we should catch in Pallas itself and raise a better error instead of outputting an invalid MLIR program that fails during the verification check.
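For reference, here is a minimal sketch of the working variant, where only the stored value is cast and the index stays an integer. The original script was not captured in this thread, so the kernel body is reconstructed from the traceback (o_ref[i] = i) and the reporter's note about moving .astype(o_ref.dtype); treat the exact shapes and arguments as assumptions. interpret=True is also an assumption, added only so the sketch can run without an accelerator; on TPU you would drop it and keep whatever out_specs your original call used.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def grid_filler_kernel(o_ref):
    # Keep the program id as an integer so it remains a valid index.
    i = pl.program_id(0)
    # Cast only the value being stored, not the index used for the store.
    o_ref[i] = i.astype(o_ref.dtype)


def grid_filler(size: int):
    return pl.pallas_call(
        grid_filler_kernel,
        out_shape=jax.ShapeDtypeStruct((size,), jnp.float32),
        grid=(size,),
        interpret=True,  # assumption: run in interpreter mode, no TPU needed
    )()


result = grid_filler(4)
print(result)
```

Writing i = pl.program_id(0).astype(o_ref.dtype) instead makes i a float, and the later o_ref[i] store then tries to index with an f32, which is what the index_cast verification error is complaining about.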
Description
Hi, I was following the Pallas quickstart guide and tried modifying the iota kernel to produce fp32 results. But I got the following error:
I ran the following code.
What's strange is that if I move the .astype(o_ref.dtype) down to the next line, the error goes away.
System info (python version, jaxlib version, accelerator, etc.)