Skip to content

Commit 3369f34

Browse files
committed
refactor and update wording
1 parent 09d2ef8 commit 3369f34

File tree

2 files changed

+22
-37
lines changed

2 files changed

+22
-37
lines changed

mlir/include/mlir/Dialect/GPU/IR/GPUOps.td

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1333,8 +1333,7 @@ def GPU_ShuffleOp : GPU_Op<
13331333
```
13341334

13351335
For lane `k`, returns the value from lane `(k + cst1)`. If `(k + cst1)` is
1336-
bigger than or equal to `width`, the value is unspecified and `valid` is
1337-
`false`.
1336+
bigger than or equal to `width`, the value is poison and `valid` is `false`.
13381337

13391338
`up` example:
13401339

@@ -1344,7 +1343,7 @@ def GPU_ShuffleOp : GPU_Op<
13441343
```
13451344

13461345
For lane `k`, returns the value from lane `(k - cst1)`. If `(k - cst1)` is
1347-
smaller than `0`, the value is unspecified and `valid` is `false`.
1346+
smaller than `0`, the value is poison and `valid` is `false`.
13481347

13491348
`idx` example:
13501349

mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp

Lines changed: 20 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -425,6 +425,21 @@ Value getDimOp(OpBuilder &builder, MLIRContext *ctx, Location loc,
425425
return builder.create<arith::IndexCastOp>(loc, i32Type, dim);
426426
}
427427

428+
Value getLaneId(OpBuilder &rewriter, MLIRContext *ctx, Location loc) {
429+
Value dimX = getDimOp<gpu::BlockDimOp>(rewriter, ctx, loc, gpu::Dimension::x);
430+
Value dimY = getDimOp<gpu::BlockDimOp>(rewriter, ctx, loc, gpu::Dimension::y);
431+
Value tidX = getDimOp<gpu::ThreadIdOp>(rewriter, ctx, loc, gpu::Dimension::x);
432+
Value tidY = getDimOp<gpu::ThreadIdOp>(rewriter, ctx, loc, gpu::Dimension::y);
433+
Value tidZ = getDimOp<gpu::ThreadIdOp>(rewriter, ctx, loc, gpu::Dimension::z);
434+
auto i32Type = rewriter.getIntegerType(32);
435+
Value tmp1 = rewriter.create<arith::MulIOp>(loc, i32Type, tidZ, dimY);
436+
Value tmp2 = rewriter.create<arith::AddIOp>(loc, i32Type, tmp1, tidY);
437+
Value tmp3 = rewriter.create<arith::MulIOp>(loc, i32Type, tmp2, dimX);
438+
Value laneId = rewriter.create<arith::AddIOp>(loc, i32Type, tmp3, tidX);
439+
440+
return laneId;
441+
}
442+
428443
//===----------------------------------------------------------------------===//
429444
// Shuffle
430445
//===----------------------------------------------------------------------===//
@@ -464,24 +479,9 @@ LogicalResult GPUShuffleConversion::matchAndRewrite(
464479
loc, scope, adaptor.getValue(), adaptor.getOffset());
465480

466481
MLIRContext *ctx = shuffleOp.getContext();
467-
Value dimX =
468-
getDimOp<gpu::BlockDimOp>(rewriter, ctx, loc, gpu::Dimension::x);
469-
Value dimY =
470-
getDimOp<gpu::BlockDimOp>(rewriter, ctx, loc, gpu::Dimension::y);
471-
Value tidX =
472-
getDimOp<gpu::ThreadIdOp>(rewriter, ctx, loc, gpu::Dimension::x);
473-
Value tidY =
474-
getDimOp<gpu::ThreadIdOp>(rewriter, ctx, loc, gpu::Dimension::y);
475-
Value tidZ =
476-
getDimOp<gpu::ThreadIdOp>(rewriter, ctx, loc, gpu::Dimension::z);
477-
auto i32Type = rewriter.getIntegerType(32);
478-
Value tmp1 = rewriter.create<arith::MulIOp>(loc, i32Type, tidZ, dimY);
479-
Value tmp2 = rewriter.create<arith::AddIOp>(loc, i32Type, tmp1, tidY);
480-
Value tmp3 = rewriter.create<arith::MulIOp>(loc, i32Type, tmp2, dimX);
481-
Value landId = rewriter.create<arith::AddIOp>(loc, i32Type, tmp3, tidX);
482-
482+
Value laneId = getLaneId(rewriter, ctx, loc);
483483
Value resultLandId =
484-
rewriter.create<arith::AddIOp>(loc, landId, adaptor.getOffset());
484+
rewriter.create<arith::AddIOp>(loc, laneId, adaptor.getOffset());
485485
validVal = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
486486
resultLandId, adaptor.getWidth());
487487
break;
@@ -491,24 +491,10 @@ LogicalResult GPUShuffleConversion::matchAndRewrite(
491491
loc, scope, adaptor.getValue(), adaptor.getOffset());
492492

493493
MLIRContext *ctx = shuffleOp.getContext();
494-
Value dimX =
495-
getDimOp<gpu::BlockDimOp>(rewriter, ctx, loc, gpu::Dimension::x);
496-
Value dimY =
497-
getDimOp<gpu::BlockDimOp>(rewriter, ctx, loc, gpu::Dimension::y);
498-
Value tidX =
499-
getDimOp<gpu::ThreadIdOp>(rewriter, ctx, loc, gpu::Dimension::x);
500-
Value tidY =
501-
getDimOp<gpu::ThreadIdOp>(rewriter, ctx, loc, gpu::Dimension::y);
502-
Value tidZ =
503-
getDimOp<gpu::ThreadIdOp>(rewriter, ctx, loc, gpu::Dimension::z);
504-
auto i32Type = rewriter.getIntegerType(32);
505-
Value tmp1 = rewriter.create<arith::MulIOp>(loc, i32Type, tidZ, dimY);
506-
Value tmp2 = rewriter.create<arith::AddIOp>(loc, i32Type, tmp1, tidY);
507-
Value tmp3 = rewriter.create<arith::MulIOp>(loc, i32Type, tmp2, dimX);
508-
Value landId = rewriter.create<arith::AddIOp>(loc, i32Type, tmp3, tidX);
509-
494+
Value laneId = getLaneId(rewriter, ctx, loc);
510495
Value resultLandId =
511-
rewriter.create<arith::SubIOp>(loc, landId, adaptor.getOffset());
496+
rewriter.create<arith::SubIOp>(loc, laneId, adaptor.getOffset());
497+
auto i32Type = rewriter.getIntegerType(32);
512498
validVal = rewriter.create<arith::CmpIOp>(
513499
loc, arith::CmpIPredicate::sge, resultLandId,
514500
rewriter.create<arith::ConstantOp>(

0 commit comments

Comments
 (0)