@@ -425,6 +425,21 @@ Value getDimOp(OpBuilder &builder, MLIRContext *ctx, Location loc,
425
425
return builder.create <arith::IndexCastOp>(loc, i32Type, dim);
426
426
}
427
427
428
+ Value getLaneId (OpBuilder &rewriter, MLIRContext *ctx, Location loc) {
429
+ Value dimX = getDimOp<gpu::BlockDimOp>(rewriter, ctx, loc, gpu::Dimension::x);
430
+ Value dimY = getDimOp<gpu::BlockDimOp>(rewriter, ctx, loc, gpu::Dimension::y);
431
+ Value tidX = getDimOp<gpu::ThreadIdOp>(rewriter, ctx, loc, gpu::Dimension::x);
432
+ Value tidY = getDimOp<gpu::ThreadIdOp>(rewriter, ctx, loc, gpu::Dimension::y);
433
+ Value tidZ = getDimOp<gpu::ThreadIdOp>(rewriter, ctx, loc, gpu::Dimension::z);
434
+ auto i32Type = rewriter.getIntegerType (32 );
435
+ Value tmp1 = rewriter.create <arith::MulIOp>(loc, i32Type, tidZ, dimY);
436
+ Value tmp2 = rewriter.create <arith::AddIOp>(loc, i32Type, tmp1, tidY);
437
+ Value tmp3 = rewriter.create <arith::MulIOp>(loc, i32Type, tmp2, dimX);
438
+ Value laneId = rewriter.create <arith::AddIOp>(loc, i32Type, tmp3, tidX);
439
+
440
+ return laneId;
441
+ }
442
+
428
443
// ===----------------------------------------------------------------------===//
429
444
// Shuffle
430
445
// ===----------------------------------------------------------------------===//
@@ -464,24 +479,9 @@ LogicalResult GPUShuffleConversion::matchAndRewrite(
464
479
loc, scope, adaptor.getValue (), adaptor.getOffset ());
465
480
466
481
MLIRContext *ctx = shuffleOp.getContext ();
467
- Value dimX =
468
- getDimOp<gpu::BlockDimOp>(rewriter, ctx, loc, gpu::Dimension::x);
469
- Value dimY =
470
- getDimOp<gpu::BlockDimOp>(rewriter, ctx, loc, gpu::Dimension::y);
471
- Value tidX =
472
- getDimOp<gpu::ThreadIdOp>(rewriter, ctx, loc, gpu::Dimension::x);
473
- Value tidY =
474
- getDimOp<gpu::ThreadIdOp>(rewriter, ctx, loc, gpu::Dimension::y);
475
- Value tidZ =
476
- getDimOp<gpu::ThreadIdOp>(rewriter, ctx, loc, gpu::Dimension::z);
477
- auto i32Type = rewriter.getIntegerType (32 );
478
- Value tmp1 = rewriter.create <arith::MulIOp>(loc, i32Type, tidZ, dimY);
479
- Value tmp2 = rewriter.create <arith::AddIOp>(loc, i32Type, tmp1, tidY);
480
- Value tmp3 = rewriter.create <arith::MulIOp>(loc, i32Type, tmp2, dimX);
481
- Value landId = rewriter.create <arith::AddIOp>(loc, i32Type, tmp3, tidX);
482
-
482
+ Value laneId = getLaneId (rewriter, ctx, loc);
483
483
Value resultLandId =
484
- rewriter.create <arith::AddIOp>(loc, landId , adaptor.getOffset ());
484
+ rewriter.create <arith::AddIOp>(loc, laneId , adaptor.getOffset ());
485
485
validVal = rewriter.create <arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
486
486
resultLandId, adaptor.getWidth ());
487
487
break ;
@@ -491,24 +491,10 @@ LogicalResult GPUShuffleConversion::matchAndRewrite(
491
491
loc, scope, adaptor.getValue (), adaptor.getOffset ());
492
492
493
493
MLIRContext *ctx = shuffleOp.getContext ();
494
- Value dimX =
495
- getDimOp<gpu::BlockDimOp>(rewriter, ctx, loc, gpu::Dimension::x);
496
- Value dimY =
497
- getDimOp<gpu::BlockDimOp>(rewriter, ctx, loc, gpu::Dimension::y);
498
- Value tidX =
499
- getDimOp<gpu::ThreadIdOp>(rewriter, ctx, loc, gpu::Dimension::x);
500
- Value tidY =
501
- getDimOp<gpu::ThreadIdOp>(rewriter, ctx, loc, gpu::Dimension::y);
502
- Value tidZ =
503
- getDimOp<gpu::ThreadIdOp>(rewriter, ctx, loc, gpu::Dimension::z);
504
- auto i32Type = rewriter.getIntegerType (32 );
505
- Value tmp1 = rewriter.create <arith::MulIOp>(loc, i32Type, tidZ, dimY);
506
- Value tmp2 = rewriter.create <arith::AddIOp>(loc, i32Type, tmp1, tidY);
507
- Value tmp3 = rewriter.create <arith::MulIOp>(loc, i32Type, tmp2, dimX);
508
- Value landId = rewriter.create <arith::AddIOp>(loc, i32Type, tmp3, tidX);
509
-
494
+ Value laneId = getLaneId (rewriter, ctx, loc);
510
495
Value resultLandId =
511
- rewriter.create <arith::SubIOp>(loc, landId, adaptor.getOffset ());
496
+ rewriter.create <arith::SubIOp>(loc, laneId, adaptor.getOffset ());
497
+ auto i32Type = rewriter.getIntegerType (32 );
512
498
validVal = rewriter.create <arith::CmpIOp>(
513
499
loc, arith::CmpIPredicate::sge, resultLandId,
514
500
rewriter.create <arith::ConstantOp>(
0 commit comments