Skip to content

[mlir][gpu][spirv] Remove rotation semantics of gpu.shuffle up/down #139105

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Jun 19, 2025
6 changes: 4 additions & 2 deletions mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1332,7 +1332,8 @@ def GPU_ShuffleOp : GPU_Op<
%3, %4 = gpu.shuffle down %0, %cst1, %width : f32
```

For lane `k`, returns the value from lane `(k + 1) % width`.
For lane `k`, returns the value from lane `(k + cst1)`. If `(k + cst1)` is
bigger than or equal to `width`, the value is poison and `valid` is `false`.

`up` example:

Expand All @@ -1341,7 +1342,8 @@ def GPU_ShuffleOp : GPU_Op<
%5, %6 = gpu.shuffle up %0, %cst1, %width : f32
```

For lane `k`, returns the value from lane `(k - 1) % width`.
For lane `k`, returns the value from lane `(k - cst1)`. If `(k - cst1)` is
smaller than `0`, the value is poison and `valid` is `false`.

`idx` example:

Expand Down
45 changes: 38 additions & 7 deletions mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -435,26 +435,57 @@ LogicalResult GPUShuffleConversion::matchAndRewrite(
return rewriter.notifyMatchFailure(
shuffleOp, "shuffle width and target subgroup size mismatch");

assert(!adaptor.getOffset().getType().isSignedInteger() &&
"shuffle offset must be a signless/unsigned integer");

Location loc = shuffleOp.getLoc();
Value trueVal = spirv::ConstantOp::getOne(rewriter.getI1Type(),
shuffleOp.getLoc(), rewriter);
auto scope = rewriter.getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup);
Value result;
Value validVal;

switch (shuffleOp.getMode()) {
case gpu::ShuffleMode::XOR:
case gpu::ShuffleMode::XOR: {
result = rewriter.create<spirv::GroupNonUniformShuffleXorOp>(
loc, scope, adaptor.getValue(), adaptor.getOffset());
validVal = spirv::ConstantOp::getOne(rewriter.getI1Type(),
shuffleOp.getLoc(), rewriter);
break;
case gpu::ShuffleMode::IDX:
}
case gpu::ShuffleMode::IDX: {
result = rewriter.create<spirv::GroupNonUniformShuffleOp>(
loc, scope, adaptor.getValue(), adaptor.getOffset());
validVal = spirv::ConstantOp::getOne(rewriter.getI1Type(),
shuffleOp.getLoc(), rewriter);
break;
}
case gpu::ShuffleMode::DOWN: {
result = rewriter.create<spirv::GroupNonUniformShuffleDownOp>(
loc, scope, adaptor.getValue(), adaptor.getOffset());

Value laneId = rewriter.create<gpu::LaneIdOp>(loc, widthAttr);
Value resultLaneId =
rewriter.create<arith::AddIOp>(loc, laneId, adaptor.getOffset());
validVal = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
resultLaneId, adaptor.getWidth());
break;
default:
return rewriter.notifyMatchFailure(shuffleOp, "unimplemented shuffle mode");
}
case gpu::ShuffleMode::UP: {
result = rewriter.create<spirv::GroupNonUniformShuffleUpOp>(
loc, scope, adaptor.getValue(), adaptor.getOffset());

Value laneId = rewriter.create<gpu::LaneIdOp>(loc, widthAttr);
Value resultLaneId =
rewriter.create<arith::SubIOp>(loc, laneId, adaptor.getOffset());
auto i32Type = rewriter.getIntegerType(32);
validVal = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::sge, resultLaneId,
rewriter.create<arith::ConstantOp>(
loc, i32Type, rewriter.getIntegerAttr(i32Type, 0)));
break;
}
}

rewriter.replaceOp(shuffleOp, {result, trueVal});
rewriter.replaceOp(shuffleOp, {result, validVal});
return success();
}

Expand Down
71 changes: 69 additions & 2 deletions mlir/test/Conversion/GPUToSPIRV/shuffle.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ gpu.module @kernels {

// CHECK: %[[MASK:.+]] = spirv.Constant 8 : i32
// CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32
// CHECK: %{{.+}} = spirv.Constant true
// CHECK: %{{.+}} = spirv.GroupNonUniformShuffleXor <Subgroup> %[[VAL]], %[[MASK]] : f32, i32
// CHECK: %{{.+}} = spirv.Constant true
%result, %valid = gpu.shuffle xor %val, %mask, %width : f32
gpu.return
}
Expand Down Expand Up @@ -64,11 +64,78 @@ gpu.module @kernels {

// CHECK: %[[MASK:.+]] = spirv.Constant 8 : i32
// CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32
// CHECK: %{{.+}} = spirv.Constant true
// CHECK: %{{.+}} = spirv.GroupNonUniformShuffle <Subgroup> %[[VAL]], %[[MASK]] : f32, i32
// CHECK: %{{.+}} = spirv.Constant true
%result, %valid = gpu.shuffle idx %val, %mask, %width : f32
gpu.return
}
}

}

// -----

module attributes {
gpu.container_module,
spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, GroupNonUniformShuffle, GroupNonUniformShuffleRelative], []>,
#spirv.resource_limits<subgroup_size = 16>>
} {

gpu.module @kernels {
// CHECK-LABEL: spirv.func @shuffle_down()
gpu.func @shuffle_down() kernel
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
%offset = arith.constant 4 : i32
%width = arith.constant 16 : i32
%val = arith.constant 42.0 : f32

// CHECK: %[[OFFSET:.+]] = spirv.Constant 4 : i32
// CHECK: %[[WIDTH:.+]] = spirv.Constant 16 : i32
// CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32
// CHECK: %{{.+}} = spirv.GroupNonUniformShuffleDown <Subgroup> %[[VAL]], %[[OFFSET]] : f32, i32

// CHECK: %[[INVOCATION_ID_ADDR:.+]] = spirv.mlir.addressof @__builtin__SubgroupLocalInvocationId__ : !spirv.ptr<i32, Input>
// CHECK: %[[LANE_ID:.+]] = spirv.Load "Input" %[[INVOCATION_ID_ADDR]] : i32
// CHECK: %[[VAL_LANE_ID:.+]] = spirv.IAdd %[[LANE_ID]], %[[OFFSET]] : i32
// CHECK: %[[VALID:.+]] = spirv.ULessThan %[[VAL_LANE_ID]], %[[WIDTH]] : i32

%result, %valid = gpu.shuffle down %val, %offset, %width : f32
gpu.return
}
}

}

// -----

module attributes {
gpu.container_module,
spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, GroupNonUniformShuffle, GroupNonUniformShuffleRelative], []>,
#spirv.resource_limits<subgroup_size = 16>>
} {

gpu.module @kernels {
// CHECK-LABEL: spirv.func @shuffle_up()
gpu.func @shuffle_up() kernel
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
%offset = arith.constant 4 : i32
%width = arith.constant 16 : i32
%val = arith.constant 42.0 : f32

// CHECK: %[[OFFSET:.+]] = spirv.Constant 4 : i32
// CHECK: %[[WIDTH:.+]] = spirv.Constant 16 : i32
// CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32
// CHECK: %{{.+}} = spirv.GroupNonUniformShuffleUp <Subgroup> %[[VAL]], %[[OFFSET]] : f32, i32

// CHECK: %[[INVOCATION_ID_ADDR:.+]] = spirv.mlir.addressof @__builtin__SubgroupLocalInvocationId__ : !spirv.ptr<i32, Input>
// CHECK: %[[LANE_ID:.+]] = spirv.Load "Input" %[[INVOCATION_ID_ADDR]] : i32
// CHECK: %[[VAL_LANE_ID:.+]] = spirv.ISub %[[LANE_ID]], %[[OFFSET]] : i32
// CHECK: %[[CST0:.+]] = spirv.Constant 0 : i32
// CHECK: %[[VALID:.+]] = spirv.SGreaterThanEqual %[[VAL_LANE_ID]], %[[CST0]] : i32

%result, %valid = gpu.shuffle up %val, %offset, %width : f32
gpu.return
}
}

}