Skip to content

Commit 9bba79f

Browse files
committed
Update
1 parent 60e2c56 commit 9bba79f

File tree

3 files changed

+5
-13
lines changed

3 files changed

+5
-13
lines changed

mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -898,16 +898,6 @@ def AMDGPU_GatherToLDSOp :
898898
let hasVerifier = 1;
899899
}
900900

901-
def F8Types : AnyTypeOf<[
902-
F8E8M0FNU, // 8 exponent, 0 mantissa
903-
F8E5M2, // 5 exponent, 2 mantissa
904-
F8E5M2FNUZ, // 5 exponent, 2 mantissa
905-
F8E4M3, // 4 exponent, 3 mantissa
906-
F8E4M3FN, // 4 exponent, 3 mantissa
907-
F8E4M3B11FNUZ, // 4 exponent, 3 mantissa (with bias 11)
908-
F8E3M4 // 3 exponent, 4 mantissa
909-
]>;
910-
911901
def AMDGPU_TransposeLoadOp :
912902
AMDGPU_Op<"transpose_load", [SameVariadicOperandSize]>,
913903
Arguments<(ins Arg<AnyMemRef, "buffer to transpose load from", [MemRead]>:$src, Variadic<Index>:$srcIndices)>,

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1110,7 +1110,7 @@ struct TransposeLoadOpLowering
11101110
LogicalResult
11111111
matchAndRewrite(TransposeLoadOp op, TransposeLoadOpAdaptor adaptor,
11121112
ConversionPatternRewriter &rewriter) const override {
1113-
if (chipset < kGfx950)
1113+
if (chipset != kGfx950)
11141114
return op.emitOpError("Non-gfx950 chipset not supported");
11151115

11161116
Location loc = op.getLoc();

mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include "mlir/IR/OpImplementation.h"
2525
#include "mlir/IR/PatternMatch.h"
2626
#include "mlir/IR/TypeUtilities.h"
27+
#include "llvm/ADT/DenseMap.h"
2728
#include "llvm/ADT/TypeSwitch.h"
2829

2930
#include <limits>
@@ -534,10 +535,11 @@ LogicalResult TransposeLoadOp::verify() {
534535
if (!transferType)
535536
return emitOpError("destination type must be a vector type");
536537
size_t numElements = transferType.getNumElements();
537-
size_t elementTypeSize = transferType.getElementType().getIntOrFloatBitWidth();
538+
size_t elementTypeSize =
539+
transferType.getElementType().getIntOrFloatBitWidth();
538540

539541
// ElementSize -> NumElements
540-
const std::map<size_t, size_t> KValidLoadSizeMap = {
542+
const llvm::SmallDenseMap<size_t, size_t> KValidLoadSizeMap = {
541543
{4, 16},
542544
{32, 3}, // 6-bit element loads use casted vector<3xi32>
543545
{8, 8},

0 commit comments

Comments
 (0)