Skip to content

Commit 259194a

Browse files
justinjfuhawkinsp
authored andcommitted
[Pallas] Fix shard_axis in dma_start interpret mode rule.
PiperOrigin-RevId: 703192497
1 parent 7e6620a commit 259194a

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

jax/_src/pallas/mosaic/primitives.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -615,7 +615,7 @@ def dma_start_discharge_rule(in_avals, out_avals,
615615
if device_id_len > 1 or len(nonempty_axes) > 1:
616616
raise NotImplementedError("Meshes with more than 1 named dimension not "
617617
"implemented in dma_start_p")
618-
shard_axis = nonempty_axes[0].name
618+
shard_axis = nonempty_axes[0]
619619
my_axis = jax.lax.axis_index(shard_axis)
620620
else:
621621
raise ValueError(f"Unknown device_id_type: {device_id_type}")

0 commit comments

Comments
 (0)