[AMDGPU] Adding AMDGPU dialect wrapper for ROCDL transpose loads. #145395

Merged
lialan merged 18 commits into llvm:main from lialan/tr_load on Jun 26, 2025

@lialan lialan commented Jun 23, 2025

  • 1-to-1 mapping wrapper op.
  • Direct lowering from AMDGPU wrapper to ROCDL intrinsics.

@llvmbot
llvmbot commented Jun 23, 2025

@llvm/pr-subscribers-backend-amdgpu
@llvm/pr-subscribers-mlir-gpu

@llvm/pr-subscribers-mlir-amdgpu

Author: Alan Li (lialan)

Changes
  • 1-to-1 mapping wrapper op.
  • Direct lowering from AMDGPU wrapper to ROCDL intrinsics.

Full diff: https://github.com/llvm/llvm-project/pull/145395.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td (+21)
  • (modified) mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp (+45-2)
  • (modified) mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp (+18)
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index d58558ac32884..003aff6d38da0 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -898,6 +898,27 @@ def AMDGPU_GatherToLDSOp :
   let hasVerifier = 1;
 }
 
+def AMDGPU_TransposeLoadOp :
+    AMDGPU_Op<"transpose_load", [SameVariadicOperandSize]>,
+    Arguments<(ins Arg<AnyMemRef, "buffer to transpose load from", [MemRead]>:$src, Variadic<Index>:$srcIndices)>,
+    Results<(outs MFMAInTypes:$dst)> {
+  let summary = "MLIR wrapper for CDNA Transpose Load instructions";
+  let description = [{
+    The `amdgpu.transpose_load` op is a wrapper around the `ds_read_tr` instructions.
+
+    Operands:
+    * `$src`: LDS memref to read from.
+    * `$srcIndices`: indices into `$src` to read from for this thread.
+    * `$dst`: target register this transpose load instruction will write to.
+
+    Note: Lowering is only supported on gfx950 and up.
+  }];
+  let assemblyFormat = [{
+    $src `[` $srcIndices `]` attr-dict `:` type($src) `->` type($dst)
+  }];
+  let hasVerifier = 1;
+}
+
 def AMDGPU_ScaledMFMAOp :
     AMDGPU_Op<"scaled_mfma", [AllTypesMatch<["destC", "destD"]>,
                         Pure]>,
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 700563460f525..62ed1d871bcfd 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1100,6 +1100,49 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
   }
 };
 
+struct TransposeLoadOpLowering
+    : public ConvertOpToLLVMPattern<TransposeLoadOp> {
+  TransposeLoadOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
+      : ConvertOpToLLVMPattern<TransposeLoadOp>(converter), chipset(chipset) {}
+
+  Chipset chipset;
+
+  LogicalResult
+  matchAndRewrite(TransposeLoadOp op, TransposeLoadOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (chipset < kGfx950)
+      return op.emitOpError("Non-gfx950 chipset not supported");
+
+    Location loc = op.getLoc();
+    auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
+    Value srcPtr =
+        getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(),
+                             (adaptor.getSrcIndices()));
+    auto elementTypeSize = cast<VectorType>(op.getDst().getType())
+                               .getElementType()
+                               .getIntOrFloatBitWidth();
+
+    // TODO: support ds_read_tr16_b64 intrinsic.
+    switch (elementTypeSize) {
+    case 4:
+      rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr4_b64>(
+          op, op.getDst().getType(), srcPtr);
+      break;
+    case 8:
+      rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr8_b64>(
+          op, op.getDst().getType(), srcPtr);
+      break;
+    case 16:
+      rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr16_b64>(
+          op, op.getDst().getType(), srcPtr);
+      break;
+    default:
+      return op.emitOpError("Unsupported element size for transpose load");
+    }
+    return success();
+  }
+};
+
 struct GatherToLDSOpLowering : public ConvertOpToLLVMPattern<GatherToLDSOp> {
   GatherToLDSOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
       : ConvertOpToLLVMPattern<GatherToLDSOp>(converter), chipset(chipset) {}
@@ -1749,7 +1792,7 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
            MFMAOpLowering, ScaledMFMAOpLowering, WMMAOpLowering,
            ExtPackedFp8OpLowering, ScaledExtPackedOpLowering,
            PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
-           PackedStochRoundFp8OpLowering, GatherToLDSOpLowering>(converter,
-                                                                 chipset);
+           PackedStochRoundFp8OpLowering, GatherToLDSOpLowering,
+           TransposeLoadOpLowering>(converter, chipset);
   patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
 }
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 0d0add3094666..00e9019b79647 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -524,6 +524,24 @@ LogicalResult GatherToLDSOp::verify() {
   return success();
 }
 
+LogicalResult TransposeLoadOp::verify() {
+  MemRefType srcType = cast<MemRefType>(getSrc().getType());
+
+  if (!hasWorkgroupMemorySpace(srcType.getMemorySpace()))
+    return emitOpError("source memory address space must be Workgroup");
+
+  // TODO: support 6-bit element type vectors.
+  auto transferType = dyn_cast<VectorType>(getDst().getType());
+  if (!transferType)
+    return emitOpError("destination type must be a vector type");
+  size_t transferSize =
+      transferType.getNumElements() * transferType.getElementTypeBitWidth();
+  if (transferSize != 64)
+    return emitOpError("Transferring type size must be 64 bits");
+
+  return success();
+}
+
 #include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"
 
 #define GET_ATTRDEF_CLASSES

@llvmbot
llvmbot commented Jun 23, 2025

@llvm/pr-subscribers-mlir


@Copilot Copilot AI left a comment

Pull Request Overview

This PR introduces a new amdgpu.transpose_load wrapper operation with verification, TableGen definition, and direct lowering to ROCDL intrinsics.

  • Added TableGen op definition for TransposeLoadOp in the AMDGPU dialect.
  • Implemented TransposeLoadOp::verify() to enforce memory space and type constraints.
  • Created a conversion pattern to lower TransposeLoadOp to ROCDL ds_read_tr intrinsics.
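The gating and dispatch logic of the lowering can be sketched outside MLIR. This is a minimal standalone model, not the PR's code: the `Chipset` struct here is a hypothetical stand-in ordered by (major, minor, stepping), `selectTransposeLoad` is an illustrative name, and the returned strings are only the C++ op class names that appear in the diff.

```cpp
#include <optional>
#include <string_view>
#include <tuple>

// Hypothetical stand-in for the chipset type used by the lowering:
// chipsets compare lexicographically by (major, minor, stepping).
struct Chipset {
  unsigned majorVersion, minorVersion, steppingVersion;
  friend constexpr bool operator<(const Chipset &a, const Chipset &b) {
    return std::tie(a.majorVersion, a.minorVersion, a.steppingVersion) <
           std::tie(b.majorVersion, b.minorVersion, b.steppingVersion);
  }
};

constexpr Chipset kGfx950{9, 5, 0};

// Returns the name of the ROCDL op class the pattern in the diff would
// create, or std::nullopt in the cases where the pattern emits an error.
constexpr std::optional<std::string_view>
selectTransposeLoad(Chipset chipset, unsigned elementBits) {
  if (chipset < kGfx950)
    return std::nullopt; // "Non-gfx950 chipset not supported"
  switch (elementBits) {
  case 4:
    return "ROCDL::ds_read_tr4_b64";
  case 8:
    return "ROCDL::ds_read_tr8_b64";
  case 16:
    return "ROCDL::ds_read_tr16_b64";
  default:
    return std::nullopt; // "Unsupported element size for transpose load"
  }
}
```

Under this model, an 8-bit element on gfx950 selects the `tr8` variant, while an older chipset or a 12-bit element yields no op, matching the two error paths in the pattern.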

Reviewed Changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 3 comments.

| File | Description |
| --- | --- |
| mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | Added TransposeLoadOp::verify() |
| mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | Defined AMDGPU_TransposeLoadOp in TableGen |
| mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | Added TransposeLoadOpLowering and registered it |
Comments suppressed due to low confidence (1)

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp:1796

  • There are no corresponding tests for the new TransposeLoadOp and its lowering. Consider adding unit tests to cover verification and lowering paths for different element sizes and unsupported cases.
           TransposeLoadOpLowering>(converter, chipset);

@lialan lialan requested review from krzysz00 and kuhar June 23, 2025 19:37
@krzysz00 krzysz00 left a comment

Missing AMDGPU dialect tests to show the op

Missing tests for the lowering

Maybe missing a narrow type emulation pattern

@lialan

lialan commented Jun 23, 2025

@krzysz00 My bad, I forgot to include the test file in the PR. Updated.

But what do we need for emulating narrow types?

@krzysz00
We'll want to make a pattern on this op that's analogous to the ones in mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp ... there's a reason I was working on an interface for this sort of thing that I never got around to.

In short, this pass turns memref<...x[small type]> into memref<N x i8> and rewrites the indexing accordingly. We'll want to do the indexing adjustments, but keep returning the <L x {i4, f4E2M1FN, i6, ...}> directly.
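The indexing adjustment described in this comment can be sketched in isolation. This is a rough model under stated assumptions, not the code in EmulateNarrowType.cpp: it assumes elements are packed contiguously, that the element width divides 8 (so i4 fits, i6 would need a different scheme), and that the bit offset counts from the least significant bit. `PackedIndex`, `linearize2D`, and `toPackedIndex` are hypothetical names.

```cpp
#include <cassert>
#include <cstdint>

// Where a sub-byte element lands once the buffer is viewed as bytes.
struct PackedIndex {
  std::int64_t byteIndex; // index into the emulated memref<N x i8>
  unsigned bitOffset;     // position of the element inside that byte (assumed LSB-first)
};

// Row-major linearization of a 2-D element index.
inline std::int64_t linearize2D(std::int64_t row, std::int64_t col,
                                std::int64_t numCols) {
  return row * numCols + col;
}

// Maps a linear element index to a byte index plus an intra-byte offset.
inline PackedIndex toPackedIndex(std::int64_t linearIndex,
                                 unsigned elementBits) {
  assert(elementBits > 0 && 8 % elementBits == 0 &&
         "element width must tile a byte in this sketch");
  assert(linearIndex >= 0 && "negative indices are out of scope here");
  const std::int64_t elementsPerByte = 8 / elementBits;
  return {linearIndex / elementsPerByte,
          static_cast<unsigned>(linearIndex % elementsPerByte) * elementBits};
}
```

For example, element (2, 5) of a 16-column i4 buffer has linear index 37, which this model places in byte 18 at bit offset 4. A pattern for this op would rewrite only the address computation this way and leave the result vector type unchanged.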


Note: Lowering is only supported on gfx950 and up.
}];
let assemblyFormat = [{
I know other ops here don't provide examples, but I think it would be worth adding going forward -- I rely on these all the time

I like your idea. So I tried to add a very simple example to show the format of the op. In terms of the semantics of the instruction, it is too hard to explain in a few sentences so I wrote that "please refer to the actual document for detailed explanation".

Probably call out that you mean the CDNA4 ISA manual

F8E3M4 // 3 exponent, 4 mantissa
]>;
def F6Types : AnyTypeOf<[F6E2M3FN, F6E3M2FN]>;
def TrLoadTypes : AnyTypeOf<[VectorOfLengthAndType<[4], [F16, AnyI<16>]>,
BF16 exists ... and also, we can probably leave this open and rely on a getIntOrFloatBitWidth() check in the verifier?

yeah, now it accepts any vectors and the verifier will serve as the checker.
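Leaving the result type open and checking bit widths in the verifier can be modeled as below. This is a hedged sketch that mirrors only the 64-bit total-size check visible in the diff posted in this thread; the merged verifier may check more. `VectorShape`, `hasValidTransferSize`, and `expectedNumElements` are hypothetical names.

```cpp
#include <optional>

// Hypothetical description of the result vector: how many elements it has
// and how wide each element is, regardless of the element's kind.
struct VectorShape {
  unsigned numElements;
  unsigned elementBits;
};

// Mirrors the check in the diff: the whole transfer must be 64 bits wide.
constexpr bool hasValidTransferSize(VectorShape v) {
  return v.numElements * v.elementBits == 64;
}

// Element count that a 64-bit transfer implies for a given element width,
// or std::nullopt when the width does not divide 64 evenly.
constexpr std::optional<unsigned> expectedNumElements(unsigned elementBits) {
  if (elementBits == 0 || 64 % elementBits != 0)
    return std::nullopt;
  return 64 / elementBits;
}
```

Because the check looks only at widths, a four-element vector of any 16-bit type (f16, bf16, or i16) passes without being listed in a TableGen type constraint, which is the point raised in the review comment above.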

github-actions bot commented Jun 24, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@krzysz00 krzysz00 left a comment

Blocking because the lowering's got some footguns in it we need to get rid of


@lialan lialan requested review from krzysz00, Copilot and kuhar June 24, 2025 17:37
@lialan lialan requested a review from Copilot June 24, 2025 17:48
@Copilot Copilot AI left a comment

Pull Request Overview

Adds a 1:1 AMDGPU dialect wrapper for ROCDL transpose-load instructions, with direct lowering on gfx950+ and accompanying tests.

  • Introduce amdgpu.transpose_load op in the AMDGPU dialect (.td), implement semantic verification and conversion lowering to ROCDL intrinsics.
  • Add positive and negative MLIR tests for valid and unsupported element sizes in the conversion suite.
  • Wire up the new pattern in the AMDGPU→ROCDL conversion pipeline.

Reviewed Changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 1 comment.

| File | Description |
| --- | --- |
| mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | Define AMDGPU_TransposeLoadOp with assembly format and docs |
| mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | Implement TransposeLoadOp::verify() for basic checks |
| mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | Add TransposeLoadOpLowering to lower to ROCDL::ds_read_tr* |
| mlir/test/Conversion/AMDGPUToROCDL/transpose_load.mlir | Positive tests for various supported bit-width patterns |
| mlir/test/Conversion/AMDGPUToROCDL/transpose_load_reject.mlir | Negative tests rejecting sub-byte element sizes |
Comments suppressed due to low confidence (2)

mlir/test/Conversion/AMDGPUToROCDL/transpose_load_reject.mlir:1

  • No test covers the default fallback for unsupported element sizes (e.g., 12-bit); consider adding a case to exercise the Unsupported element size for transpose load path.
// RUN: not mlir-opt %s --split-input-file -convert-amdgpu-to-rocdl=chipset=gfx950 2>&1 | FileCheck %s

mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp:27

  • SmallDenseMap is used below but the header for it isn’t included; add #include "llvm/ADT/SmallDenseMap.h" to avoid compilation errors.
#include "llvm/ADT/DenseMap.h"

@krzysz00 krzysz00 left a comment

Let's add tests to mlir/test/Dialect/AMDGPU/ops.mlir and invalid.mlir

Code itself LGTM

@lialan lialan requested a review from krzysz00 June 25, 2025 19:48
@krzysz00 krzysz00 left a comment

Approved, thanks for making this happen!

@lialan lialan merged commit 3f3282c into llvm:main Jun 26, 2025
7 checks passed
@lialan lialan deleted the lialan/tr_load branch June 26, 2025 02:58