Skip to content

Commit 03461c9

Browse files
authored
[mlir][gpu][spirv] Remove rotation semantics of gpu.shuffle up/down (#139105)
From the description of gpu.shuffle operation, shuffle up/down rotates values in the subgroup because it applies modulo on the shifted value to calculate the result lane ID. It is inconsistent with the definition of SPIR-V shuffle up/down and NVVM data movement definitions within subgroup. In NVVM, it says "If the computed source lane index j is in range, the returned i32 value will be the value of %a from lane j; otherwise, it will be the the value of %a from the current thread." It will keep the original value if the result land ID is out of range. In SPIR-V OpGroupNonUniformShuffleUp and OpGroupNonUniformShuffleDown, it says "The resulting value is undefined if Delta is greater than the current invocation’s id within the scope or if the identified invocation is not in scope restricted tangle." It's an undefined value if the result land ID is out of range. Anyway, there is no circular movement in shuffle up/down from these 2 specifications. This patch removes the circular movement in gpu.shuffle up/down and lower gpu.shuffle up/down to SPIR-V OpGroupNonUniformShuffleUp and OpGroupNonUniformShuffleDown directly. Reference: https://docs.nvidia.com/cuda/archive/12.2.1/nvvm-ir-spec/index.html#data-movement https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpGroupNonUniformShuffleUp https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpGroupNonUniformShuffleDown
1 parent 590066b commit 03461c9

File tree

3 files changed

+111
-11
lines changed

3 files changed

+111
-11
lines changed

mlir/include/mlir/Dialect/GPU/IR/GPUOps.td

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1332,7 +1332,8 @@ def GPU_ShuffleOp : GPU_Op<
13321332
%3, %4 = gpu.shuffle down %0, %cst1, %width : f32
13331333
```
13341334

1335-
For lane `k`, returns the value from lane `(k + 1) % width`.
1335+
For lane `k`, returns the value from lane `(k + cst1)`. If `(k + cst1)` is
1336+
bigger than or equal to `width`, the value is poison and `valid` is `false`.
13361337

13371338
`up` example:
13381339

@@ -1341,7 +1342,8 @@ def GPU_ShuffleOp : GPU_Op<
13411342
%5, %6 = gpu.shuffle up %0, %cst1, %width : f32
13421343
```
13431344

1344-
For lane `k`, returns the value from lane `(k - 1) % width`.
1345+
For lane `k`, returns the value from lane `(k - cst1)`. If `(k - cst1)` is
1346+
smaller than `0`, the value is poison and `valid` is `false`.
13451347

13461348
`idx` example:
13471349

mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp

Lines changed: 38 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -435,26 +435,57 @@ LogicalResult GPUShuffleConversion::matchAndRewrite(
435435
return rewriter.notifyMatchFailure(
436436
shuffleOp, "shuffle width and target subgroup size mismatch");
437437

438+
assert(!adaptor.getOffset().getType().isSignedInteger() &&
439+
"shuffle offset must be a signless/unsigned integer");
440+
438441
Location loc = shuffleOp.getLoc();
439-
Value trueVal = spirv::ConstantOp::getOne(rewriter.getI1Type(),
440-
shuffleOp.getLoc(), rewriter);
441442
auto scope = rewriter.getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup);
442443
Value result;
444+
Value validVal;
443445

444446
switch (shuffleOp.getMode()) {
445-
case gpu::ShuffleMode::XOR:
447+
case gpu::ShuffleMode::XOR: {
446448
result = rewriter.create<spirv::GroupNonUniformShuffleXorOp>(
447449
loc, scope, adaptor.getValue(), adaptor.getOffset());
450+
validVal = spirv::ConstantOp::getOne(rewriter.getI1Type(),
451+
shuffleOp.getLoc(), rewriter);
448452
break;
449-
case gpu::ShuffleMode::IDX:
453+
}
454+
case gpu::ShuffleMode::IDX: {
450455
result = rewriter.create<spirv::GroupNonUniformShuffleOp>(
451456
loc, scope, adaptor.getValue(), adaptor.getOffset());
457+
validVal = spirv::ConstantOp::getOne(rewriter.getI1Type(),
458+
shuffleOp.getLoc(), rewriter);
459+
break;
460+
}
461+
case gpu::ShuffleMode::DOWN: {
462+
result = rewriter.create<spirv::GroupNonUniformShuffleDownOp>(
463+
loc, scope, adaptor.getValue(), adaptor.getOffset());
464+
465+
Value laneId = rewriter.create<gpu::LaneIdOp>(loc, widthAttr);
466+
Value resultLaneId =
467+
rewriter.create<arith::AddIOp>(loc, laneId, adaptor.getOffset());
468+
validVal = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
469+
resultLaneId, adaptor.getWidth());
452470
break;
453-
default:
454-
return rewriter.notifyMatchFailure(shuffleOp, "unimplemented shuffle mode");
471+
}
472+
case gpu::ShuffleMode::UP: {
473+
result = rewriter.create<spirv::GroupNonUniformShuffleUpOp>(
474+
loc, scope, adaptor.getValue(), adaptor.getOffset());
475+
476+
Value laneId = rewriter.create<gpu::LaneIdOp>(loc, widthAttr);
477+
Value resultLaneId =
478+
rewriter.create<arith::SubIOp>(loc, laneId, adaptor.getOffset());
479+
auto i32Type = rewriter.getIntegerType(32);
480+
validVal = rewriter.create<arith::CmpIOp>(
481+
loc, arith::CmpIPredicate::sge, resultLaneId,
482+
rewriter.create<arith::ConstantOp>(
483+
loc, i32Type, rewriter.getIntegerAttr(i32Type, 0)));
484+
break;
485+
}
455486
}
456487

457-
rewriter.replaceOp(shuffleOp, {result, trueVal});
488+
rewriter.replaceOp(shuffleOp, {result, validVal});
458489
return success();
459490
}
460491

mlir/test/Conversion/GPUToSPIRV/shuffle.mlir

Lines changed: 69 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@ gpu.module @kernels {
1515

1616
// CHECK: %[[MASK:.+]] = spirv.Constant 8 : i32
1717
// CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32
18-
// CHECK: %{{.+}} = spirv.Constant true
1918
// CHECK: %{{.+}} = spirv.GroupNonUniformShuffleXor <Subgroup> %[[VAL]], %[[MASK]] : f32, i32
19+
// CHECK: %{{.+}} = spirv.Constant true
2020
%result, %valid = gpu.shuffle xor %val, %mask, %width : f32
2121
gpu.return
2222
}
@@ -64,11 +64,78 @@ gpu.module @kernels {
6464

6565
// CHECK: %[[MASK:.+]] = spirv.Constant 8 : i32
6666
// CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32
67-
// CHECK: %{{.+}} = spirv.Constant true
6867
// CHECK: %{{.+}} = spirv.GroupNonUniformShuffle <Subgroup> %[[VAL]], %[[MASK]] : f32, i32
68+
// CHECK: %{{.+}} = spirv.Constant true
6969
%result, %valid = gpu.shuffle idx %val, %mask, %width : f32
7070
gpu.return
7171
}
7272
}
7373

7474
}
75+
76+
// -----
77+
78+
module attributes {
79+
gpu.container_module,
80+
spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, GroupNonUniformShuffle, GroupNonUniformShuffleRelative], []>,
81+
#spirv.resource_limits<subgroup_size = 16>>
82+
} {
83+
84+
gpu.module @kernels {
85+
// CHECK-LABEL: spirv.func @shuffle_down()
86+
gpu.func @shuffle_down() kernel
87+
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
88+
%offset = arith.constant 4 : i32
89+
%width = arith.constant 16 : i32
90+
%val = arith.constant 42.0 : f32
91+
92+
// CHECK: %[[OFFSET:.+]] = spirv.Constant 4 : i32
93+
// CHECK: %[[WIDTH:.+]] = spirv.Constant 16 : i32
94+
// CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32
95+
// CHECK: %{{.+}} = spirv.GroupNonUniformShuffleDown <Subgroup> %[[VAL]], %[[OFFSET]] : f32, i32
96+
97+
// CHECK: %[[INVOCATION_ID_ADDR:.+]] = spirv.mlir.addressof @__builtin__SubgroupLocalInvocationId__ : !spirv.ptr<i32, Input>
98+
// CHECK: %[[LANE_ID:.+]] = spirv.Load "Input" %[[INVOCATION_ID_ADDR]] : i32
99+
// CHECK: %[[VAL_LANE_ID:.+]] = spirv.IAdd %[[LANE_ID]], %[[OFFSET]] : i32
100+
// CHECK: %[[VALID:.+]] = spirv.ULessThan %[[VAL_LANE_ID]], %[[WIDTH]] : i32
101+
102+
%result, %valid = gpu.shuffle down %val, %offset, %width : f32
103+
gpu.return
104+
}
105+
}
106+
107+
}
108+
109+
// -----
110+
111+
module attributes {
112+
gpu.container_module,
113+
spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, GroupNonUniformShuffle, GroupNonUniformShuffleRelative], []>,
114+
#spirv.resource_limits<subgroup_size = 16>>
115+
} {
116+
117+
gpu.module @kernels {
118+
// CHECK-LABEL: spirv.func @shuffle_up()
119+
gpu.func @shuffle_up() kernel
120+
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
121+
%offset = arith.constant 4 : i32
122+
%width = arith.constant 16 : i32
123+
%val = arith.constant 42.0 : f32
124+
125+
// CHECK: %[[OFFSET:.+]] = spirv.Constant 4 : i32
126+
// CHECK: %[[WIDTH:.+]] = spirv.Constant 16 : i32
127+
// CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32
128+
// CHECK: %{{.+}} = spirv.GroupNonUniformShuffleUp <Subgroup> %[[VAL]], %[[OFFSET]] : f32, i32
129+
130+
// CHECK: %[[INVOCATION_ID_ADDR:.+]] = spirv.mlir.addressof @__builtin__SubgroupLocalInvocationId__ : !spirv.ptr<i32, Input>
131+
// CHECK: %[[LANE_ID:.+]] = spirv.Load "Input" %[[INVOCATION_ID_ADDR]] : i32
132+
// CHECK: %[[VAL_LANE_ID:.+]] = spirv.ISub %[[LANE_ID]], %[[OFFSET]] : i32
133+
// CHECK: %[[CST0:.+]] = spirv.Constant 0 : i32
134+
// CHECK: %[[VALID:.+]] = spirv.SGreaterThanEqual %[[VAL_LANE_ID]], %[[CST0]] : i32
135+
136+
%result, %valid = gpu.shuffle up %val, %offset, %width : f32
137+
gpu.return
138+
}
139+
}
140+
141+
}

0 commit comments

Comments
 (0)