Skip to content

Commit bbb57ea

Browse files
committed
Update
1 parent c8157f0 commit bbb57ea

File tree

2 files changed

+13
-20
lines changed

2 files changed

+13
-20
lines changed

mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -907,17 +907,11 @@ def F8Types : AnyTypeOf<[
907907
F8E4M3B11FNUZ, // 4 exponent, 3 mantissa (with bias 11)
908908
F8E3M4 // 3 exponent, 4 mantissa
909909
]>;
910-
def F6Types : AnyTypeOf<[F6E2M3FN, F6E3M2FN]>;
911-
def TrLoadTypes : AnyTypeOf<[VectorOfLengthAndType<[4], [F16, AnyI<16>]>,
912-
VectorOfLengthAndType<[8], [F8Types, AnyI<8>]>,
913-
VectorOfLengthAndType<[16], [AnyI<4>, F6Types]>,
914-
VectorOfLengthAndType<[3], [I32]>,
915-
]>;
916910

917911
def AMDGPU_TransposeLoadOp :
918912
AMDGPU_Op<"transpose_load", [SameVariadicOperandSize]>,
919913
Arguments<(ins Arg<AnyMemRef, "buffer to transpose load from", [MemRead]>:$src, Variadic<Index>:$srcIndices)>,
920-
Results<(outs TrLoadTypes:$result)> {
914+
Results<(outs AnyTypeOf<[AnyVectorOfNonZeroRank]>:$result)> {
921915
let summary = "MLIR wrapper for CDNA Transpose Load instructions";
922916
let description = [{
923917
The `amdgpu.transpose_load` op is a wrapper around the `ds_read_tr` instructions.

mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -530,30 +530,29 @@ LogicalResult TransposeLoadOp::verify() {
530530
if (!hasWorkgroupMemorySpace(srcType.getMemorySpace()))
531531
return emitOpError("source memory address space must be Workgroup");
532532

533-
// TODO: support 6-bit element type vectors.
534533
auto transferType = dyn_cast<VectorType>(getType());
535534
if (!transferType)
536535
return emitOpError("destination type must be a vector type");
537-
size_t transferSize =
538-
transferType.getNumElements() * transferType.getElementTypeBitWidth();
536+
size_t numElements = transferType.getNumElements();
539537
size_t elementTypeSize = srcType.getElementType().getIntOrFloatBitWidth();
540538

541-
// ElementSize -> LoadSize
539+
// ElementSize -> NumElements
542540
const std::map<size_t, size_t> KValidLoadSizeMap = {
543-
{4, 64},
544-
{32, 96}, // 6-bit element loads use casted vector<3xi32>
545-
{8, 64},
546-
{16, 64},
541+
{4, 16},
542+
{32, 3}, // 6-bit element loads use casted vector<3xi32>
543+
{8, 8},
544+
{16, 4},
547545
};
548546

549-
auto validLoadSize = KValidLoadSizeMap.find(elementTypeSize);
550-
if (validLoadSize == KValidLoadSizeMap.end()) {
547+
auto validNumElems = KValidLoadSizeMap.find(elementTypeSize);
548+
if (validNumElems == KValidLoadSizeMap.end()) {
551549
return emitOpError("Unsupported element type size for transpose load: ")
552550
<< elementTypeSize << " bits";
553551
}
554-
if (transferSize != validLoadSize->second) {
555-
return emitOpError("Transferring type size must be ")
556-
<< validLoadSize->second << " bits for element type size ";
552+
if (numElements != validNumElems->second) {
553+
return emitOpError(
554+
"Transferring type size mismatch: expected num of elements: ")
555+
<< validNumElems->second;
557556
}
558557

559558
return success();

0 commit comments

Comments
 (0)