
Pallas dot_general lowering bug for vectors #26013

Open
LasseBlaauwbroek opened this issue Jan 21, 2025 · 1 comment

Description

The following minimized script triggers an error; it seems to happen whenever dot_general is invoked with a vector as the second argument:

import jax
from jax.experimental import pallas as pl
import jax.numpy as jnp

def matmul_kernel2(U_ref, V_ref, O_ref):
  # Dot product of two rank-1 blocks; this is the op that fails to lower.
  O_ref[0] = jnp.dot(U_ref[...], V_ref[...])

def indexmul(U: jax.Array, V: jax.Array):
  return pl.pallas_call(
    matmul_kernel2,
    grid=(),
    in_specs=[
        pl.BlockSpec((U.shape[0],), lambda: (0,)),
        pl.BlockSpec((V.shape[0],), lambda: (0,)),
    ],
    out_shape=jax.ShapeDtypeStruct((U.shape[0],), V.dtype),
    out_specs=pl.BlockSpec((U.shape[0],), lambda: (0,)),
  )(U, V)

k1, k2 = jax.random.split(jax.random.key(0))
U = jax.random.normal(k1, (64,))
V = jax.random.normal(k2, (64,))

z = indexmul(U, V)

Running this on GPU fails with the following traceback:

Traceback (most recent call last):
  File "/root/test2.py", line 24, in <module>
    z = indexmul(U, V)
  File "/root/test2.py", line 9, in indexmul
    return pl.pallas_call(
  File "/root/venv/lib/python3.11/site-packages/jax/_src/pallas/pallas_call.py", line 1841, in wrapped
    jaxpr, consts = _trace_kernel_to_jaxpr(
  File "/root/venv/lib/python3.11/site-packages/jax/_src/pallas/pallas_call.py", line 1415, in _trace_kernel_to_jaxpr
    jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_kernel_fun,
  File "/root/venv/lib/python3.11/site-packages/jax/_src/pallas/primitives.py", line 836, in wrap_with_transforms
    return f(*new_args)
  File "/root/test2.py", line 6, in matmul_kernel2
    O_ref[0] = jnp.dot(U_ref[...], V_ref[...])
jax._src.source_info_util.JaxStackTraceBeforeTransformation: ValueError: not enough values to unpack (expected 2, got 1)

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/root/venv/lib/python3.11/site-packages/jax/_src/pallas/triton/lowering.py", line 392, in lower_jaxpr_to_triton_ir
    outvals = rule(rule_ctx, *invals, **eqn.params)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/jax/_src/pallas/triton/lowering.py", line 2088, in _dot_general_lowering
    _, n = b_type.shape
    ^^^^
ValueError: not enough values to unpack (expected 2, got 1)

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/root/test2.py", line 24, in <module>
    z = indexmul(U, V)
        ^^^^^^^^^^^^^^
  File "/root/test2.py", line 9, in indexmul
    return pl.pallas_call(
           ^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/jax/_src/pallas/pallas_call.py", line 1518, in _pallas_call_lowering
    return mlir.lower_per_platform(ctx, "pallas_call",
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/jax/_src/pallas/pallas_call.py", line 1514, in gpu_lowering
    return pallas_call_registration.pallas_call_lowering(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/jax/_src/pallas/triton/pallas_call_registration.py", line 80, in pallas_call_lowering
    lowering_result = lowering.lower_jaxpr_to_triton_module(
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/jax/_src/pallas/triton/lowering.py", line 345, in lower_jaxpr_to_triton_module
    () = lower_jaxpr_to_triton_ir(ctx, jaxpr, block_infos, *entry.arguments)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/jax/_src/pallas/triton/lowering.py", line 399, in lower_jaxpr_to_triton_ir
    raise LoweringError(
jax._src.pallas.triton.lowering.LoweringError: Exception while lowering eqn:
  a:f32[] = dot_general[
  dimension_numbers=(([0], [0]), ([], []))
  preferred_element_type=float32
] b c
With context:
  LoweringRuleContext(context=ModuleContext(name='matmul_kernel2', grid_mapping=GridMapping(grid=(), grid_names=None, block_mappings=(BlockMapping(block_shape=(64,), transformed_block_aval=MemRef<None>{float32[64]}, index_map_jaxpr={ lambda ; . let  in (0,) }, index_map_src_info=<lambda> at /root/test2.py:13, indexing_mode=Blocked, array_shape_dtype=ShapeDtypeStruct(shape=(64,), dtype=float32), origin='U_ref', transforms=()), BlockMapping(block_shape=(64,), transformed_block_aval=MemRef<None>{float32[64]}, index_map_jaxpr={ lambda ; . let  in (0,) }, index_map_src_info=<lambda> at /root/test2.py:14, indexing_mode=Blocked, array_shape_dtype=ShapeDtypeStruct(shape=(64,), dtype=float32), origin='V_ref', transforms=()), BlockMapping(block_shape=(64,), transformed_block_aval=MemRef<None>{float32[64]}, index_map_jaxpr={ lambda ; . let  in (0,) }, index_map_src_info=<lambda> at /root/test2.py:17, indexing_mode=Blocked, array_shape_dtype=ShapeDtypeStruct(shape=(64,), dtype=float32), origin='outputs', transforms=())), index_map_tree=PyTreeDef(((), {})), index_map_avals=(), vmapped_dims=(), num_index_operands=0, num_inputs=2, num_outputs=1, num_scratch_operands=0, get_grid_indices=None, local_grid_env=None), program_ids=[], platform='cuda'), avals_in=[ShapedArray(float32[64]), ShapedArray(float32[64])], avals_out=[ShapedArray(float32[])], block_infos=[None, None])
With inval types=[RankedTensorType(tensor<64xf32>), RankedTensorType(tensor<64xf32>)]
In jaxpr:
{ lambda ; a:MemRef<None>{float32[64]} b:MemRef<None>{float32[64]} c:MemRef<None>{float32[64]}. let
    d:f32[64] <- a[:]
    e:f32[64] <- b[:]
    f:f32[] = dot_general[
      dimension_numbers=(([0], [0]), ([], []))
      preferred_element_type=float32
    ] d e
    c[0] <- f
  in () }
msg=not enough values to unpack (expected 2, got 1)
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
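For reference, the same rank-1 contraction lowers fine outside of pallas_call, so this appears specific to the Triton lowering. A quick check, assuming the same U and V as above:

import jax
import jax.numpy as jnp

k1, k2 = jax.random.split(jax.random.key(0))
U = jax.random.normal(k1, (64,))
V = jax.random.normal(k2, (64,))

# XLA handles the rank-1 dot_general without issue.
print(jnp.dot(U, V))  # scalar result, shape ()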

System info (python version, jaxlib version, accelerator, etc.)

Python 3.11.11
Jaxlib 0.5.0
Accelerator: GPU

@LasseBlaauwbroek added the bug label on Jan 21, 2025
@justinjfu (Collaborator)

Triton itself doesn't support 1D arrays in dot products (https://triton-lang.org/main/python-api/generated/triton.language.dot.html#triton.language.dot), so the solution here should probably be to disallow this case in Pallas and raise a better error.

For your use case, would a reshape + dot or a multiply + sum be sufficient to replicate what you are trying to do?
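For concreteness, here is a minimal sketch of the multiply + sum variant, reusing the pallas_call wrapper from the reproduction unchanged (dot1d_kernel is a hypothetical name for illustration):

import jax
from jax.experimental import pallas as pl
import jax.numpy as jnp

def dot1d_kernel(U_ref, V_ref, O_ref):
  # Elementwise multiply, then reduce. No rank-1 dot_general is emitted,
  # so the Triton lowering never hits the failing rank-2 unpack.
  O_ref[0] = jnp.sum(U_ref[...] * V_ref[...])

def indexmul(U: jax.Array, V: jax.Array):
  return pl.pallas_call(
      dot1d_kernel,
      grid=(),
      in_specs=[
          pl.BlockSpec((U.shape[0],), lambda: (0,)),
          pl.BlockSpec((V.shape[0],), lambda: (0,)),
      ],
      out_shape=jax.ShapeDtypeStruct((U.shape[0],), V.dtype),
      out_specs=pl.BlockSpec((U.shape[0],), lambda: (0,)),
  )(U, V)

Both the elementwise product and the sum reduction should lower to existing Triton primitives, so this sidesteps the 1D restriction on tl.dot.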

@justinjfu self-assigned this on Jan 22, 2025