Skip to content

Commit 6c7be29

Browse files
committed
Use gpu::LaneIdOp
1 parent 1ae89dc commit 6c7be29

File tree

2 files changed

+17
-66
lines changed

2 files changed

+17
-66
lines changed

mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp

Lines changed: 11 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -416,30 +416,6 @@ LogicalResult GPUBarrierConversion::matchAndRewrite(
416416
return success();
417417
}
418418

419-
template <typename T>
420-
Value getDimOp(OpBuilder &builder, MLIRContext *ctx, Location loc,
421-
gpu::Dimension dimension) {
422-
Type indexType = IndexType::get(ctx);
423-
IntegerType i32Type = IntegerType::get(ctx, 32);
424-
Value dim = builder.create<T>(loc, indexType, dimension);
425-
return builder.create<arith::IndexCastOp>(loc, i32Type, dim);
426-
}
427-
428-
Value getLaneId(OpBuilder &rewriter, MLIRContext *ctx, Location loc) {
429-
Value dimX = getDimOp<gpu::BlockDimOp>(rewriter, ctx, loc, gpu::Dimension::x);
430-
Value dimY = getDimOp<gpu::BlockDimOp>(rewriter, ctx, loc, gpu::Dimension::y);
431-
Value tidX = getDimOp<gpu::ThreadIdOp>(rewriter, ctx, loc, gpu::Dimension::x);
432-
Value tidY = getDimOp<gpu::ThreadIdOp>(rewriter, ctx, loc, gpu::Dimension::y);
433-
Value tidZ = getDimOp<gpu::ThreadIdOp>(rewriter, ctx, loc, gpu::Dimension::z);
434-
auto i32Type = rewriter.getIntegerType(32);
435-
Value tmp1 = rewriter.create<arith::MulIOp>(loc, i32Type, tidZ, dimY);
436-
Value tmp2 = rewriter.create<arith::AddIOp>(loc, i32Type, tmp1, tidY);
437-
Value tmp3 = rewriter.create<arith::MulIOp>(loc, i32Type, tmp2, dimX);
438-
Value laneId = rewriter.create<arith::AddIOp>(loc, i32Type, tmp3, tidX);
439-
440-
return laneId;
441-
}
442-
443419
//===----------------------------------------------------------------------===//
444420
// Shuffle
445421
//===----------------------------------------------------------------------===//
@@ -460,26 +436,30 @@ LogicalResult GPUShuffleConversion::matchAndRewrite(
460436
shuffleOp, "shuffle width and target subgroup size mismatch");
461437

462438
Location loc = shuffleOp.getLoc();
463-
Value validVal = spirv::ConstantOp::getOne(rewriter.getI1Type(),
464-
shuffleOp.getLoc(), rewriter);
465439
auto scope = rewriter.getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup);
466440
Value result;
441+
Value validVal;
467442

468443
switch (shuffleOp.getMode()) {
469-
case gpu::ShuffleMode::XOR:
444+
case gpu::ShuffleMode::XOR: {
470445
result = rewriter.create<spirv::GroupNonUniformShuffleXorOp>(
471446
loc, scope, adaptor.getValue(), adaptor.getOffset());
447+
validVal = spirv::ConstantOp::getOne(rewriter.getI1Type(),
448+
shuffleOp.getLoc(), rewriter);
472449
break;
473-
case gpu::ShuffleMode::IDX:
450+
}
451+
case gpu::ShuffleMode::IDX: {
474452
result = rewriter.create<spirv::GroupNonUniformShuffleOp>(
475453
loc, scope, adaptor.getValue(), adaptor.getOffset());
454+
validVal = spirv::ConstantOp::getOne(rewriter.getI1Type(),
455+
shuffleOp.getLoc(), rewriter);
476456
break;
457+
}
477458
case gpu::ShuffleMode::DOWN: {
478459
result = rewriter.create<spirv::GroupNonUniformShuffleDownOp>(
479460
loc, scope, adaptor.getValue(), adaptor.getOffset());
480461

481-
MLIRContext *ctx = shuffleOp.getContext();
482-
Value laneId = getLaneId(rewriter, ctx, loc);
462+
Value laneId = rewriter.create<gpu::LaneIdOp>(loc, widthAttr);
483463
Value resultLandId =
484464
rewriter.create<arith::AddIOp>(loc, laneId, adaptor.getOffset());
485465
validVal = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
@@ -490,8 +470,7 @@ LogicalResult GPUShuffleConversion::matchAndRewrite(
490470
result = rewriter.create<spirv::GroupNonUniformShuffleUpOp>(
491471
loc, scope, adaptor.getValue(), adaptor.getOffset());
492472

493-
MLIRContext *ctx = shuffleOp.getContext();
494-
Value laneId = getLaneId(rewriter, ctx, loc);
473+
Value laneId = rewriter.create<gpu::LaneIdOp>(loc, widthAttr);
495474
Value resultLandId =
496475
rewriter.create<arith::SubIOp>(loc, laneId, adaptor.getOffset());
497476
auto i32Type = rewriter.getIntegerType(32);

mlir/test/Conversion/GPUToSPIRV/shuffle.mlir

Lines changed: 6 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@ gpu.module @kernels {
1515

1616
// CHECK: %[[MASK:.+]] = spirv.Constant 8 : i32
1717
// CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32
18-
// CHECK: %{{.+}} = spirv.Constant true
1918
// CHECK: %{{.+}} = spirv.GroupNonUniformShuffleXor <Subgroup> %[[VAL]], %[[MASK]] : f32, i32
19+
// CHECK: %{{.+}} = spirv.Constant true
2020
%result, %valid = gpu.shuffle xor %val, %mask, %width : f32
2121
gpu.return
2222
}
@@ -64,8 +64,8 @@ gpu.module @kernels {
6464

6565
// CHECK: %[[MASK:.+]] = spirv.Constant 8 : i32
6666
// CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32
67-
// CHECK: %{{.+}} = spirv.Constant true
6867
// CHECK: %{{.+}} = spirv.GroupNonUniformShuffle <Subgroup> %[[VAL]], %[[MASK]] : f32, i32
68+
// CHECK: %{{.+}} = spirv.Constant true
6969
%result, %valid = gpu.shuffle idx %val, %mask, %width : f32
7070
gpu.return
7171
}
@@ -92,24 +92,10 @@ gpu.module @kernels {
9292
// CHECK: %[[OFFSET:.+]] = spirv.Constant 4 : i32
9393
// CHECK: %[[WIDTH:.+]] = spirv.Constant 16 : i32
9494
// CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32
95-
// CHECK: %{{.+}} = spirv.Constant true
9695
// CHECK: %{{.+}} = spirv.GroupNonUniformShuffleDown <Subgroup> %[[VAL]], %[[OFFSET]] : f32, i32
9796

98-
// CHECK: %[[BLOCK_SIZE_X:.+]] = spirv.Constant 16 : i32
99-
// CHECK: %[[BLOCK_SIZE_Y:.+]] = spirv.Constant 1 : i32
100-
// CHECK: %__builtin__LocalInvocationId___addr = spirv.mlir.addressof @__builtin__LocalInvocationId__ : !spirv.ptr<vector<3xi32>, Input>
101-
// CHECK: %[[WORKGROUP:.+]] = spirv.Load "Input" %__builtin__LocalInvocationId___addr : vector<3xi32>
102-
// CHECK: %[[THREAD_X:.+]] = spirv.CompositeExtract %[[WORKGROUP]][0 : i32] : vector<3xi32>
103-
// CHECK: %__builtin__LocalInvocationId___addr_1 = spirv.mlir.addressof @__builtin__LocalInvocationId__ : !spirv.ptr<vector<3xi32>, Input>
104-
// CHECK: %[[WORKGROUP_1:.+]] = spirv.Load "Input" %__builtin__LocalInvocationId___addr_1 : vector<3xi32>
105-
// CHECK: %[[THREAD_Y:.+]] = spirv.CompositeExtract %[[WORKGROUP_1]][1 : i32] : vector<3xi32>
106-
// CHECK: %__builtin__LocalInvocationId___addr_2 = spirv.mlir.addressof @__builtin__LocalInvocationId__ : !spirv.ptr<vector<3xi32>, Input>
107-
// CHECK: %[[WORKGROUP_2:.+]] = spirv.Load "Input" %__builtin__LocalInvocationId___addr_2 : vector<3xi32>
108-
// CHECK: %[[THREAD_Z:.+]] = spirv.CompositeExtract %[[WORKGROUP_2]][2 : i32] : vector<3xi32>
109-
// CHECK: %[[S0:.+]] = spirv.IMul %[[THREAD_Z]], %[[BLOCK_SIZE_Y]] : i32
110-
// CHECK: %[[S1:.+]] = spirv.IAdd %[[S0]], %[[THREAD_Y]] : i32
111-
// CHECK: %[[S2:.+]] = spirv.IMul %[[S1]], %[[BLOCK_SIZE_X]] : i32
112-
// CHECK: %[[LANE_ID:.+]] = spirv.IAdd %[[S2]], %[[THREAD_X]] : i32
97+
// CHECK: %[[INVOCATION_ID_ADDR:.+]] = spirv.mlir.addressof @__builtin__SubgroupLocalInvocationId__ : !spirv.ptr<i32, Input>
98+
// CHECK: %[[LANE_ID:.+]] = spirv.Load "Input" %[[INVOCATION_ID_ADDR]] : i32
11399
// CHECK: %[[VAL_LANE_ID:.+]] = spirv.IAdd %[[LANE_ID]], %[[OFFSET]] : i32
114100
// CHECK: %[[VALID:.+]] = spirv.ULessThan %[[VAL_LANE_ID]], %[[WIDTH]] : i32
115101

@@ -139,24 +125,10 @@ gpu.module @kernels {
139125
// CHECK: %[[OFFSET:.+]] = spirv.Constant 4 : i32
140126
// CHECK: %[[WIDTH:.+]] = spirv.Constant 16 : i32
141127
// CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32
142-
// CHECK: %{{.+}} = spirv.Constant true
143128
// CHECK: %{{.+}} = spirv.GroupNonUniformShuffleUp <Subgroup> %[[VAL]], %[[OFFSET]] : f32, i32
144129

145-
// CHECK: %[[BLOCK_SIZE_X:.+]] = spirv.Constant 16 : i32
146-
// CHECK: %[[BLOCK_SIZE_Y:.+]] = spirv.Constant 1 : i32
147-
// CHECK: %__builtin__LocalInvocationId___addr = spirv.mlir.addressof @__builtin__LocalInvocationId__ : !spirv.ptr<vector<3xi32>, Input>
148-
// CHECK: %[[WORKGROUP:.+]] = spirv.Load "Input" %__builtin__LocalInvocationId___addr : vector<3xi32>
149-
// CHECK: %[[THREAD_X:.+]] = spirv.CompositeExtract %[[WORKGROUP]][0 : i32] : vector<3xi32>
150-
// CHECK: %__builtin__LocalInvocationId___addr_1 = spirv.mlir.addressof @__builtin__LocalInvocationId__ : !spirv.ptr<vector<3xi32>, Input>
151-
// CHECK: %[[WORKGROUP_1:.+]] = spirv.Load "Input" %__builtin__LocalInvocationId___addr_1 : vector<3xi32>
152-
// CHECK: %[[THREAD_Y:.+]] = spirv.CompositeExtract %[[WORKGROUP_1]][1 : i32] : vector<3xi32>
153-
// CHECK: %__builtin__LocalInvocationId___addr_2 = spirv.mlir.addressof @__builtin__LocalInvocationId__ : !spirv.ptr<vector<3xi32>, Input>
154-
// CHECK: %[[WORKGROUP_2:.+]] = spirv.Load "Input" %__builtin__LocalInvocationId___addr_2 : vector<3xi32>
155-
// CHECK: %[[THREAD_Z:.+]] = spirv.CompositeExtract %[[WORKGROUP_2]][2 : i32] : vector<3xi32>
156-
// CHECK: %[[S0:.+]] = spirv.IMul %[[THREAD_Z]], %[[BLOCK_SIZE_Y]] : i32
157-
// CHECK: %[[S1:.+]] = spirv.IAdd %[[S0]], %[[THREAD_Y]] : i32
158-
// CHECK: %[[S2:.+]] = spirv.IMul %[[S1]], %[[BLOCK_SIZE_X]] : i32
159-
// CHECK: %[[LANE_ID:.+]] = spirv.IAdd %[[S2]], %[[THREAD_X]] : i32
130+
// CHECK: %[[INVOCATION_ID_ADDR:.+]] = spirv.mlir.addressof @__builtin__SubgroupLocalInvocationId__ : !spirv.ptr<i32, Input>
131+
// CHECK: %[[LANE_ID:.+]] = spirv.Load "Input" %[[INVOCATION_ID_ADDR]] : i32
160132
// CHECK: %[[VAL_LANE_ID:.+]] = spirv.ISub %[[LANE_ID]], %[[OFFSET]] : i32
161133
// CHECK: %[[CST0:.+]] = spirv.Constant 0 : i32
162134
// CHECK: %[[VALID:.+]] = spirv.SGreaterThanEqual %[[VAL_LANE_ID]], %[[CST0]] : i32

0 commit comments

Comments
 (0)