Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 2009ede

Browse files
committedJun 23, 2025·
[AMDGPU] Adding AMDGPU dialect wrapper for ROCDL transpose loads.
* 1-to-1 mapping wrapper op. * Direct lowering from AMDGPU wrapper to ROCDL intrinsics.
1 parent 1128a4f commit 2009ede

File tree

3 files changed

+84
-2
lines changed

3 files changed

+84
-2
lines changed
 

‎mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td‎

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -898,6 +898,27 @@ def AMDGPU_GatherToLDSOp :
898898
let hasVerifier = 1;
899899
}
900900

901+
def AMDGPU_TransposeLoadOp :
902+
AMDGPU_Op<"transpose_load", [SameVariadicOperandSize]>,
903+
Arguments<(ins Arg<AnyMemRef, "buffer to transpose load from", [MemRead]>:$src, Variadic<Index>:$srcIndices)>,
904+
Results<(outs MFMAInTypes:$dst)> {
905+
let summary = "MLIR wrapper for CDNA Transpose Load instructions";
906+
let description = [{
907+
The `amdgpu.transpose_load` op is a wrapper around the `ds_read_tr` instructions.
908+
909+
Operands:
910+
* `$src`: LDS memref to read from.
911+
* `$srcIndices`: indices into `$src` to read from for this thread.
912+
* `$dst`: target register this transpose load instruction will write to.
913+
914+
Note: Lowering is only supported on gfx950 and up.
915+
}];
916+
let assemblyFormat = [{
917+
$src `[` $srcIndices `]` attr-dict `:` type($src) `->` type($dst)
918+
}];
919+
let hasVerifier = 1;
920+
}
921+
901922
def AMDGPU_ScaledMFMAOp :
902923
AMDGPU_Op<"scaled_mfma", [AllTypesMatch<["destC", "destD"]>,
903924
Pure]>,

‎mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp‎

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1100,6 +1100,49 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
11001100
}
11011101
};
11021102

1103+
struct TransposeLoadOpLowering
1104+
: public ConvertOpToLLVMPattern<TransposeLoadOp> {
1105+
TransposeLoadOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
1106+
: ConvertOpToLLVMPattern<TransposeLoadOp>(converter), chipset(chipset) {}
1107+
1108+
Chipset chipset;
1109+
1110+
LogicalResult
1111+
matchAndRewrite(TransposeLoadOp op, TransposeLoadOpAdaptor adaptor,
1112+
ConversionPatternRewriter &rewriter) const override {
1113+
if (chipset < kGfx950)
1114+
return op.emitOpError("Non-gfx950 chipset not supported");
1115+
1116+
Location loc = op.getLoc();
1117+
auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
1118+
Value srcPtr =
1119+
getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(),
1120+
(adaptor.getSrcIndices()));
1121+
auto elementTypeSize = cast<VectorType>(op.getDst().getType())
1122+
.getElementType()
1123+
.getIntOrFloatBitWidth();
1124+
1125+
// TODO: support ds_read_tr16_b64 intrinsic.
1126+
switch (elementTypeSize) {
1127+
case 4:
1128+
rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr4_b64>(
1129+
op, op.getDst().getType(), srcPtr);
1130+
break;
1131+
case 8:
1132+
rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr8_b64>(
1133+
op, op.getDst().getType(), srcPtr);
1134+
break;
1135+
case 16:
1136+
rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr16_b64>(
1137+
op, op.getDst().getType(), srcPtr);
1138+
break;
1139+
default:
1140+
return op.emitOpError("Unsupported element size for transpose load");
1141+
}
1142+
return success();
1143+
}
1144+
};
1145+
11031146
struct GatherToLDSOpLowering : public ConvertOpToLLVMPattern<GatherToLDSOp> {
11041147
GatherToLDSOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
11051148
: ConvertOpToLLVMPattern<GatherToLDSOp>(converter), chipset(chipset) {}
@@ -1749,7 +1792,7 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
17491792
MFMAOpLowering, ScaledMFMAOpLowering, WMMAOpLowering,
17501793
ExtPackedFp8OpLowering, ScaledExtPackedOpLowering,
17511794
PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
1752-
PackedStochRoundFp8OpLowering, GatherToLDSOpLowering>(converter,
1753-
chipset);
1795+
PackedStochRoundFp8OpLowering, GatherToLDSOpLowering,
1796+
TransposeLoadOpLowering>(converter, chipset);
17541797
patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
17551798
}

‎mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp‎

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -524,6 +524,24 @@ LogicalResult GatherToLDSOp::verify() {
524524
return success();
525525
}
526526

527+
LogicalResult TransposeLoadOp::verify() {
528+
MemRefType srcType = cast<MemRefType>(getSrc().getType());
529+
530+
if (!hasWorkgroupMemorySpace(srcType.getMemorySpace()))
531+
return emitOpError("source memory address space must be Workgroup");
532+
533+
// TODO: support 6-bit element type vectors.
534+
auto transferType = dyn_cast<VectorType>(getDst().getType());
535+
if (!transferType)
536+
return emitOpError("destination type must be a vector type");
537+
size_t transferSize =
538+
transferType.getNumElements() * transferType.getElementTypeBitWidth();
539+
if (transferSize != 64)
540+
return emitOpError("Transferring type size must be 64 bits");
541+
542+
return success();
543+
}
544+
527545
#include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"
528546

529547
#define GET_ATTRDEF_CLASSES

0 commit comments

Comments
 (0)
Please sign in to comment.