Description

The following minimized script seems to fail whenever `dot_general` is invoked where the second argument is a vector:
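The script itself did not survive in this copy of the issue, so what follows is a sketch reconstructed from the traceback and jaxpr below. The names (`matmul_kernel2`, `indexmul`, `U`, `V`) and the float32[64] shapes are taken from the trace; the exact `BlockSpec` arguments and the input values are assumptions.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def matmul_kernel2(U_ref, V_ref, O_ref):
    # U_ref and V_ref hold rank-1 blocks of shape (64,). jnp.dot on two
    # vectors emits dot_general with dimension_numbers=(([0], [0]), ([], [])),
    # i.e. a rank-1 second operand, which the Triton lowering rule fails on.
    O_ref[0] = jnp.dot(U_ref[...], V_ref[...])


def indexmul(U, V):
    return pl.pallas_call(
        matmul_kernel2,
        out_shape=jax.ShapeDtypeStruct((64,), jnp.float32),
        in_specs=[
            pl.BlockSpec((64,), lambda: (0,)),  # index maps inferred from the trace
            pl.BlockSpec((64,), lambda: (0,)),
        ],
        out_specs=pl.BlockSpec((64,), lambda: (0,)),
    )(U, V)


U = jnp.ones((64,), dtype=jnp.float32)  # input values assumed
V = jnp.ones((64,), dtype=jnp.float32)
z = indexmul(U, V)
```

The trace points at `_dot_general_lowering` unpacking `_, n = b_type.shape`, which assumes a rank-2 second operand. As an untested workaround, writing the reduction elementwise, e.g. `O_ref[0] = jnp.sum(U_ref[...] * V_ref[...])`, should avoid emitting a rank-1 `dot_general`.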
Traceback (most recent call last):
  File "/root/test2.py", line 24, in <module>
    z = indexmul(U, V)
  File "/root/test2.py", line 9, in indexmul
    return pl.pallas_call(
  File "/root/venv/lib/python3.11/site-packages/jax/_src/pallas/pallas_call.py", line 1841, in wrapped
    jaxpr, consts = _trace_kernel_to_jaxpr(
  File "/root/venv/lib/python3.11/site-packages/jax/_src/pallas/pallas_call.py", line 1415, in _trace_kernel_to_jaxpr
    jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_kernel_fun,
  File "/root/venv/lib/python3.11/site-packages/jax/_src/pallas/primitives.py", line 836, in wrap_with_transforms
    return f(*new_args)
  File "/root/test2.py", line 6, in matmul_kernel2
    O_ref[0] = jnp.dot(U_ref[...], V_ref[...])
jax._src.source_info_util.JaxStackTraceBeforeTransformation: ValueError: not enough values to unpack (expected 2, got 1)

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/root/venv/lib/python3.11/site-packages/jax/_src/pallas/triton/lowering.py", line 392, in lower_jaxpr_to_triton_ir
    outvals = rule(rule_ctx, *invals, **eqn.params)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/jax/_src/pallas/triton/lowering.py", line 2088, in _dot_general_lowering
    _, n = b_type.shape
    ^^^^
ValueError: not enough values to unpack (expected 2, got 1)

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/root/test2.py", line 24, in <module>
    z = indexmul(U, V)
        ^^^^^^^^^^^^^^
  File "/root/test2.py", line 9, in indexmul
    return pl.pallas_call(
           ^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/jax/_src/pallas/pallas_call.py", line 1518, in _pallas_call_lowering
    return mlir.lower_per_platform(ctx, "pallas_call",
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/jax/_src/pallas/pallas_call.py", line 1514, in gpu_lowering
    return pallas_call_registration.pallas_call_lowering(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/jax/_src/pallas/triton/pallas_call_registration.py", line 80, in pallas_call_lowering
    lowering_result = lowering.lower_jaxpr_to_triton_module(
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/jax/_src/pallas/triton/lowering.py", line 345, in lower_jaxpr_to_triton_module
    () = lower_jaxpr_to_triton_ir(ctx, jaxpr, block_infos, *entry.arguments)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/venv/lib/python3.11/site-packages/jax/_src/pallas/triton/lowering.py", line 399, in lower_jaxpr_to_triton_ir
    raise LoweringError(
jax._src.pallas.triton.lowering.LoweringError: Exception while lowering eqn:
  a:f32[] = dot_general[
    dimension_numbers=(([0], [0]), ([], []))
    preferred_element_type=float32
  ] b c
With context:
LoweringRuleContext(context=ModuleContext(name='matmul_kernel2', grid_mapping=GridMapping(grid=(), grid_names=None, block_mappings=(BlockMapping(block_shape=(64,), transformed_block_aval=MemRef<None>{float32[64]}, index_map_jaxpr={ lambda ; . let in (0,) }, index_map_src_info=<lambda> at /root/test2.py:13, indexing_mode=Blocked, array_shape_dtype=ShapeDtypeStruct(shape=(64,), dtype=float32), origin='U_ref', transforms=()), BlockMapping(block_shape=(64,), transformed_block_aval=MemRef<None>{float32[64]}, index_map_jaxpr={ lambda ; . let in (0,) }, index_map_src_info=<lambda> at /root/test2.py:14, indexing_mode=Blocked, array_shape_dtype=ShapeDtypeStruct(shape=(64,), dtype=float32), origin='V_ref', transforms=()), BlockMapping(block_shape=(64,), transformed_block_aval=MemRef<None>{float32[64]}, index_map_jaxpr={ lambda ; . let in (0,) }, index_map_src_info=<lambda> at /root/test2.py:17, indexing_mode=Blocked, array_shape_dtype=ShapeDtypeStruct(shape=(64,), dtype=float32), origin='outputs', transforms=())), index_map_tree=PyTreeDef(((), {})), index_map_avals=(), vmapped_dims=(), num_index_operands=0, num_inputs=2, num_outputs=1, num_scratch_operands=0, get_grid_indices=None, local_grid_env=None), program_ids=[], platform='cuda'), avals_in=[ShapedArray(float32[64]), ShapedArray(float32[64])], avals_out=[ShapedArray(float32[])], block_infos=[None, None])
With inval types=[RankedTensorType(tensor<64xf32>), RankedTensorType(tensor<64xf32>)]
In jaxpr:
{ lambda ; a:MemRef<None>{float32[64]} b:MemRef<None>{float32[64]} c:MemRef<None>{float32[64]}. let
    d:f32[64] <- a[:]
    e:f32[64] <- b[:]
    f:f32[] = dot_general[
      dimension_numbers=(([0], [0]), ([], []))
      preferred_element_type=float32
    ] d e
    c[0] <- f
  in () }
msg=not enough values to unpack (expected 2, got 1)
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
System info (python version, jaxlib version, accelerator, etc.)
Python 3.11.11
Jaxlib 0.5.0
Accelerator: GPU