[rocdl] Deduce and plumb MMA schedule in vector distribution (iree-org#16482)

This commit reuses for ROCDL the basic heuristics that deduce cooperative
matrix configurations in SPIR-V. We deduce a full tiling schedule and
attach it as the newly introduced `iree_gpu.mma_schedule` attribute, which
the vector distribution pipeline picks up and converts into concrete
`iree_gpu.mfma_layout` attributes to drive contraction vector distribution.
antiagainst authored Feb 20, 2024
1 parent e4ae2b7 commit b42e627
Showing 18 changed files with 442 additions and 292 deletions.
@@ -12,7 +12,7 @@
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"

#define DEBUG_TYPE "iree-amdgpu-distribute-contract"
#define DEBUG_TYPE "iree-codegen-amdgpu-distribute-contract"

namespace mlir::iree_compiler {
namespace {
16 changes: 8 additions & 8 deletions compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp
@@ -20,7 +20,7 @@ std::optional<GPUMMASchedule>
deduceMMASchedule(const GPUMatmulShapeType &problem,
ArrayRef<GPUMatmulShapeType> intrinsics,
const GPUMMAHeuristicSeeds &seeds) {
-for (const GPUMatmulShapeType &intrinsic : intrinsics) {
+for (auto [index, intrinsic] : llvm::enumerate(intrinsics)) {
if (problem.aType != intrinsic.aType || problem.bType != intrinsic.bType ||
problem.cType != intrinsic.cType) {
continue; // Cannot use this intrinsic for mismatched types
@@ -35,8 +35,8 @@ deduceMMASchedule(const GPUMatmulShapeType &problem,
int64_t mTotalTileCount = problem.mSize / intrinsic.mSize;
int64_t nTotalTileCount = problem.nSize / intrinsic.nSize;

-int64_t remainingWarps = seeds.numSubgroupsPerWorkgroup;
-int64_t remainingTiles = seeds.numMNTilesPerSubgroup;
+int64_t remainingWarps = seeds.bestSubgroupCountPerWorkgroup;
+int64_t remainingTiles = seeds.bestMNTileCountPerSubgroup;
// Assign more warps to the M dimension (used later) to balance thread
// counts along X and Y dimensions.
int64_t warpSqrt = 1ull
@@ -90,8 +90,8 @@ deduceMMASchedule(const GPUMatmulShapeType &problem,
}

const uint64_t kTotalTileCount = problem.kSize / intrinsic.kSize;
-APInt kGCD = GreatestCommonDivisor(APInt(64, kTotalTileCount),
-APInt(64, seeds.numKTilesPerSubgroup));
+APInt kGCD = GreatestCommonDivisor(
+APInt(64, kTotalTileCount), APInt(64, seeds.bestKTileCountPerSubgroup));
int64_t kTileCount = kGCD.getSExtValue();

LLVM_DEBUG({
@@ -103,9 +103,9 @@ deduceMMASchedule(const GPUMatmulShapeType &problem,
llvm::dbgs() << " subgroup tile count (M, N, K) = (" << mTileCount
<< ", " << nTileCount << ", " << kTileCount << ")\n";
});
-return GPUMMASchedule{intrinsic.mSize, intrinsic.nSize, intrinsic.kSize,
-mWarpCount, nWarpCount, mTileCount,
-nTileCount, kTileCount};
+return GPUMMASchedule{index, intrinsic.mSize, intrinsic.nSize,
+intrinsic.kSize, mWarpCount, nWarpCount,
+mTileCount, nTileCount, kTileCount};
}
return std::nullopt;
}
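To make the renamed seeds concrete, here is a minimal sketch (not part of this commit) of driving `deduceMMASchedule` by hand. The brace-initialization order for `GPUMatmulShapeType` is an assumption; consult GPUHeuristics.h for the actual declaration order.

#include <optional>

#include "iree/compiler/Codegen/Common/GPU/GPUHeuristics.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"

using namespace mlir::iree_compiler;

// Sketch: deduce a schedule for a 1024x1024x1024 f16 x f16 -> f32 matmul
// against the 16x16x16 MFMA intrinsic shape used in this commit.
std::optional<GPUMMASchedule> exampleSchedule(mlir::MLIRContext &ctx) {
  mlir::Type f16 = mlir::Float16Type::get(&ctx);
  mlir::Type f32 = mlir::Float32Type::get(&ctx);
  GPUMatmulShapeType problem{1024, 1024, 1024, f16, f16, f32}; // field order assumed
  GPUMatmulShapeType intrinsic{16, 16, 16, f16, f16, f32};
  GPUMMAHeuristicSeeds seeds{/*bestSubgroupCountPerWorkgroup=*/4,
                             /*bestMNTileCountPerSubgroup=*/8,
                             /*bestKTileCountPerSubgroup=*/2};
  // With these numbers, kTileCount = gcd(1024/16, 2) = 2 per the GCD logic
  // above, and index = 0 records that the first (only) intrinsic matched.
  return deduceMMASchedule(problem, {intrinsic}, seeds);
}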
14 changes: 8 additions & 6 deletions compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.h
@@ -23,15 +23,17 @@ struct GPUMatmulShapeType {

/// Struct containing seed tile sizes for GPU MMA heuristics deduction logic.
struct GPUMMAHeuristicSeeds {
-// The default number of subgroups to use per workgroup
-int64_t numSubgroupsPerWorkgroup;
-// The default number of tiles along M/N dimension to use per workgroup
-int64_t numMNTilesPerSubgroup;
-// The default number of tiles along K dimension to use per subgroup
-int64_t numKTilesPerSubgroup;
+// The best number of subgroups to use per workgroup
+int64_t bestSubgroupCountPerWorkgroup;
+// The best number of total tiles along M*N dimensions per subgroup
+int64_t bestMNTileCountPerSubgroup;
+// The best number of tiles along K dimension per subgroup
+int64_t bestKTileCountPerSubgroup;
};

struct GPUMMASchedule {
+// Index of the chosen intrinsic into the list of given MMA intrinsics
+uint64_t index;
int64_t mSize; // Native MMA size along M dimension
int64_t nSize; // Native MMA size along N dimension
int64_t kSize; // Native MMA size along K dimension
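One way to read a deduced schedule back into workgroup-level tile sizes, as a sketch (not in the commit): the counts compose multiplicatively, and only M and N are split across subgroups (the warp distribution above never assigns warps to K). Field names past kSize are assumed to mirror the constructor arguments seen in GPUHeuristics.cpp.

#include "iree/compiler/Codegen/Common/GPU/GPUHeuristics.h"
#include "llvm/Support/raw_ostream.h"

// Sketch: workgroup tile sizes implied by a GPUMMASchedule s.
void printWorkgroupTile(const mlir::iree_compiler::GPUMMASchedule &s) {
  int64_t wgTileM = s.mWarpCount * s.mTileCount * s.mSize; // subgroups * tiles * native size
  int64_t wgTileN = s.nWarpCount * s.nTileCount * s.nSize;
  int64_t kTile = s.kTileCount * s.kSize; // K advances per subgroup, not split across them
  llvm::outs() << "workgroup tile " << wgTileM << "x" << wgTileN
               << ", K step " << kTile << "\n";
}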
@@ -60,6 +60,7 @@ iree_compiler_cc_library(
":IREEGPUInterfaces",
"//compiler/src/iree/compiler/Codegen/Common",
"//compiler/src/iree/compiler/Codegen/Utils",
"//compiler/src/iree/compiler/Codegen/Utils:VectorOpUtils",
"//llvm-external-projects/iree-dialects:IREEVectorExtDialect",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:DialectUtils",
@@ -41,6 +41,7 @@ iree_cc_library(
MLIRVectorDialect
iree::compiler::Codegen::Common
iree::compiler::Codegen::Utils
+iree::compiler::Codegen::Utils::VectorOpUtils
PUBLIC
)

225 changes: 151 additions & 74 deletions compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
@@ -6,12 +6,14 @@

#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"

#include "iree-dialects/Dialect/VectorExt/IR/VectorExtDialect.h"
#include "iree/compiler/Codegen/Common/VectorLayoutAnalysis.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUDialect.h"
#include "iree/compiler/Codegen/Utils/VectorOpUtils.h"
#include "llvm/Support/ErrorHandling.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/TypeUtilities.h"
@@ -27,6 +29,7 @@ using VectorLayoutInterface =
mlir::iree_compiler::IREE::VectorExt::VectorLayoutInterface;
using PerDimLayoutAttr = mlir::iree_compiler::IREE::VectorExt::PerDimLayoutAttr;
using LayoutAttr = mlir::iree_compiler::IREE::VectorExt::LayoutAttr;
+using NestedLayoutAttr = mlir::iree_compiler::IREE::VectorExt::NestedLayoutAttr;

namespace mlir::iree_compiler::IREE::GPU {

@@ -324,7 +327,7 @@ MFMAAttr::getContractionLayout(vector::ContractionOp contract) const {
return IREE::GPU::getContractionLayout(contract, layout);
}

-int64_t MFMAAttr::getBlockSize() {
+int64_t MFMAAttr::getBlockSize() const {
switch (getIntrinsic().getValue()) {
case MFMAIntrinsic::F16_16x16x16_F32: {
return 1;
@@ -337,95 +340,169 @@ int64_t MFMAAttr::getBlockSize() {
return 0;
}

-//===----------------------------------------------------------------------===//
-// Initialize attributes
-//===----------------------------------------------------------------------===//
-
-void IREEGPUDialect::registerAttributes() {
-addAttributes<
-#define GET_ATTRDEF_LIST
-#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp.inc" // IWYU pragma: keep
->();
+MFMAAttr::SingleSubgroupLayout MFMAAttr::getASingleSubgroupLayoutCount() const {
+switch (getIntrinsic().getValue()) {
+case MFMAIntrinsic::F16_16x16x16_F32: {
+return {/*outer=*/{1, 1}, /*thread=*/{16, 4}, /*element=*/{1, 4}};
+}
+case MFMAIntrinsic::F16_32x32x8_F32: {
+return {/*outer=*/{1, 1}, /*thread=*/{32, 2}, /*element=*/{1, 4}};
+}
+}
+return {};
}

-} // namespace mlir::iree_compiler::IREE::GPU
-
-namespace mlir::iree_compiler {
-
-std::optional<IREE::GPU::MmaAttr>
-getCompatibleMmaAttr(ArrayAttr mmaKinds, vector::ContractionOp contract) {
-SmallVector<int64_t> iterationBounds;
-contract.getIterationBounds(iterationBounds);
-return getCompatibleMmaAttr(mmaKinds, contract.getIndexingMapsArray(),
-iterationBounds, contract->getOperandTypes());
+MFMAAttr::SingleSubgroupLayout MFMAAttr::getBSingleSubgroupLayoutCount() const {
+switch (getIntrinsic().getValue()) {
+case MFMAIntrinsic::F16_16x16x16_F32: {
+return {/*outer=*/{1, 1}, /*thread=*/{4, 16}, /*element=*/{4, 1}};
+}
+case MFMAIntrinsic::F16_32x32x8_F32: {
+return {/*outer=*/{1, 1}, /*thread=*/{2, 32}, /*element=*/{4, 1}};
+}
+}
+return {};
}

-std::optional<IREE::GPU::MmaAttr>
-getCompatibleMmaAttr(ArrayAttr mmaKinds, linalg::LinalgOp linalgOp) {
-return getCompatibleMmaAttr(mmaKinds, linalgOp.getIndexingMapsArray(),
-linalgOp.getStaticLoopRanges(),
-linalgOp->getOperandTypes());
+MFMAAttr::SingleSubgroupLayout MFMAAttr::getCSingleSubgroupLayoutCount() const {
+switch (getIntrinsic().getValue()) {
+case MFMAIntrinsic::F16_16x16x16_F32: {
+return {/*outer=*/{1, 1}, /*thread=*/{4, 16}, /*element=*/{4, 1}};
+}
+case MFMAIntrinsic::F16_32x32x8_F32: {
+return {/*outer=*/{4, 1}, /*thread=*/{2, 32}, /*element=*/{4, 1}};
+}
+}
+return {};
}

-std::optional<IREE::GPU::MmaAttr>
-getCompatibleMmaAttr(ArrayAttr mmaKinds, ArrayRef<AffineMap> indexingMaps,
-ArrayRef<int64_t> iterationBounds, TypeRange inputTypes) {
-FailureOr<linalg::ContractionDimensions> maybeContractionDims =
-linalg::inferContractionDims(indexingMaps);
-if (failed(maybeContractionDims)) {
-return std::nullopt;
+MFMAAttr::SingleSubgroupLayout MFMAAttr::getASingleSubgroupLayoutOrder() const {
+switch (getIntrinsic().getValue()) {
+case MFMAIntrinsic::F16_16x16x16_F32:
+case MFMAIntrinsic::F16_32x32x8_F32: {
+return {/*outer=*/{0, 1}, /*thread=*/{1, 0}, /*element=*/{0, 1}};
}
-auto contractionDims = *maybeContractionDims;
+}
+return {};
+}

-// TODO: Relax this condition once distribution supports it.
-if (contractionDims.k.size() != 1 || contractionDims.m.size() != 1 ||
-contractionDims.n.size() != 1) {
-return std::nullopt;
+MFMAAttr::SingleSubgroupLayout MFMAAttr::getBSingleSubgroupLayoutOrder() const {
+switch (getIntrinsic().getValue()) {
+case MFMAIntrinsic::F16_16x16x16_F32:
+case MFMAIntrinsic::F16_32x32x8_F32: {
+return {/*outer=*/{0, 1}, /*thread=*/{0, 1}, /*element=*/{1, 0}};
}
+}
+return {};
+}

-unsigned mDim = contractionDims.m[0];
-unsigned nDim = contractionDims.n[0];
-unsigned kDim = contractionDims.k[0];
+MFMAAttr::SingleSubgroupLayout MFMAAttr::getCSingleSubgroupLayoutOrder() const {
+switch (getIntrinsic().getValue()) {
+case MFMAIntrinsic::F16_16x16x16_F32:
+case MFMAIntrinsic::F16_32x32x8_F32: {
+return {/*outer=*/{0, 1}, /*thread=*/{0, 1}, /*element=*/{1, 0}};
+}
+}
+return {};
+}
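// Sketch, not part of this commit: a consistency check on the counts above.
// For each operand, outer * thread * element multiplies out to the
// intrinsic's operand shape, and the thread grid covers a full 64-lane
// wavefront (64 is an assumption stated here, not in the diff).
//   A of F16_16x16x16_F32: outer {1,1}, thread {16,4}, element {1,4}
//     rows = 1*16*1 = 16, cols = 1*4*4 = 16, lanes = 16*4 = 64
//   C of F16_32x32x8_F32:  outer {4,1}, thread {2,32}, element {4,1}
//     rows = 4*2*4 = 32, cols = 1*32*1 = 32, lanes = 2*32 = 64
static_assert(1 * 16 * 1 == 16 && 1 * 4 * 4 == 16 && 16 * 4 == 64,
              "A tile of the 16x16x16 intrinsic is fully covered");
static_assert(4 * 2 * 4 == 32 && 1 * 32 * 1 == 32 && 2 * 32 == 64,
              "C tile of the 32x32x8 intrinsic is fully covered");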

-int64_t problemMSize = iterationBounds[mDim];
-int64_t problemNSize = iterationBounds[nDim];
-int64_t problemKSize = iterationBounds[kDim];
+//===----------------------------------------------------------------------===//
+// MMA Schedule Attributes
+//===----------------------------------------------------------------------===//

-// Bail on dynamic shapes. Once better support for dynamic cases is in place,
-// a separate helper should be added for dynamic and unaligned.
-if (ShapedType::isDynamic(problemMSize) ||
-ShapedType::isDynamic(problemNSize) ||
-ShapedType::isDynamic(problemKSize)) {
+std::optional<std::tuple<VectorExt::VectorLayoutInterface,
+VectorExt::VectorLayoutInterface,
+VectorExt::VectorLayoutInterface>>
+MMAScheduleAttr::getContractionLayout(vector::ContractionOp contractOp) const {
+VectorContractOpInfo opInfo(contractOp);
+if (opInfo.getOpKind() == VectorContractOpInfo::OpKind::UNKNOWN)
return std::nullopt;
-}

-if (inputTypes.size() != 3) {
-return std::nullopt;
-}
+auto [aM, bN] = *opInfo.getOperandMNIndex();
+auto [aK, bK] = *opInfo.getOperandKIndex();
+auto [cM, cN] = *opInfo.getResultMNIndex();
+SmallVector<int64_t, 2> aPermute = {aM, aK};
+SmallVector<int64_t, 2> bPermute = {bK, bN};
+SmallVector<int64_t, 2> cPermute = {cM, cN};

-Type lhsType = getElementTypeOrSelf(inputTypes[0]);
-Type rhsType = getElementTypeOrSelf(inputTypes[1]);
-Type accType = getElementTypeOrSelf(inputTypes[2]);
+// TODO: drop this and permute the following fields.
+if (!isIdentityPermutation(aPermute) || !isIdentityPermutation(bPermute) ||
+!isIdentityPermutation(cPermute))
+return std::nullopt;

-for (Attribute a : mmaKinds.getValue()) {
-auto mmaKind = dyn_cast<IREE::GPU::MmaAttr>(a);
-if (!mmaKind) {
-return std::nullopt;
-}
+auto mfmaAttr = llvm::cast<MFMAAttr>(getIntrinsic());
+
+// TODO: revisit the handling of subgroup/thread basis.
+SmallVector<int64_t, 2> subgroupBasis = {getSubgroupMCount(),
+getSubgroupNCount()};
+
+// C matrix layout
+MFMAAttr::SingleSubgroupLayout cCounts =
+mfmaAttr.getCSingleSubgroupLayoutCount();
+MFMAAttr::SingleSubgroupLayout cOrders =
+mfmaAttr.getCSingleSubgroupLayoutOrder();
+
+SmallVector<int64_t, 2> cSubgroupPerWorkgroup = {getSubgroupMCount(),
+getSubgroupNCount()};
+SmallVector<int64_t, 2> cBatchesPerSubgroup = {getSubgroupMTileCount(),
+getSubgroupNTileCount()};
+SmallVector<int64_t, 2> cSubgroupOrder = {0, 1};
+SmallVector<int64_t, 2> cBatchOrder = {0, 1};
+SmallVector<int64_t, 2> cThreadBasis = cCounts.thread;
+
+auto cLayout = NestedLayoutAttr::get(
+getContext(), cSubgroupPerWorkgroup, cSubgroupOrder, cBatchesPerSubgroup,
+cBatchOrder, cCounts.outer, cOrders.outer, cCounts.thread, cOrders.thread,
+cCounts.element, cOrders.element, subgroupBasis, cThreadBasis);
+
+// A matrix layout
+MFMAAttr::SingleSubgroupLayout aCounts =
+mfmaAttr.getASingleSubgroupLayoutCount();
+MFMAAttr::SingleSubgroupLayout aOrders =
+mfmaAttr.getASingleSubgroupLayoutOrder();
+
+SmallVector<int64_t, 2> aSubgroupPerWorkgroup = {getSubgroupMCount(), 1};
+SmallVector<int64_t, 2> aBatchesPerSubgroup = {getSubgroupMTileCount(),
+getSubgroupKTileCount()};
+SmallVector<int64_t, 2> aSubgroupOrder = {0, 1};
+SmallVector<int64_t, 2> aBatchOrder = {0, 1};
+SmallVector<int64_t, 2> aThreadBasis = aCounts.thread;
+
+auto aLayout = NestedLayoutAttr::get(
+getContext(), aSubgroupPerWorkgroup, aSubgroupOrder, aBatchesPerSubgroup,
+aBatchOrder, aCounts.outer, aOrders.outer, aCounts.thread, aOrders.thread,
+aCounts.element, aOrders.element, subgroupBasis, aThreadBasis);
+
+// B matrix layout
+MFMAAttr::SingleSubgroupLayout bCounts =
+mfmaAttr.getBSingleSubgroupLayoutCount();
+MFMAAttr::SingleSubgroupLayout bOrders =
+mfmaAttr.getBSingleSubgroupLayoutOrder();
+
+SmallVector<int64_t, 2> bSubgroupPerWorkgroup = {1, getSubgroupNCount()};
+SmallVector<int64_t, 2> bBatchesPerSubgroup = {getSubgroupKTileCount(),
+getSubgroupNTileCount()};
+SmallVector<int64_t, 2> bSubgroupOrder = {0, 1};
+SmallVector<int64_t, 2> bBatchOrder = {0, 1};
+SmallVector<int64_t, 2> bThreadBasis = bCounts.thread;
+
+auto bLayout = NestedLayoutAttr::get(
+getContext(), bSubgroupPerWorkgroup, bSubgroupOrder, bBatchesPerSubgroup,
+bBatchOrder, bCounts.outer, bOrders.outer, bCounts.thread, bOrders.thread,
+bCounts.element, bOrders.element, subgroupBasis, bThreadBasis);
+
+return std::make_tuple(aLayout, bLayout, cLayout);
+}
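// Sketch, not part of this commit: intended consumption of the tuple built
// above, per the commit message. Hypothetical caller, names illustrative:
//
//   if (auto layouts = scheduleAttr.getContractionLayout(contractOp)) {
//     auto [aLayout, bLayout, cLayout] = *layouts;
//     // Anchor these on the contraction's lhs/rhs/acc so the vector
//     // distribution pipeline can materialize the MFMA layout.
//   }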

-auto [typeA, typeB, typeC] = mmaKind.getABCElementTypes();
-if (typeA != lhsType || typeB != rhsType || typeC != accType) {
-continue;
-}
+//===----------------------------------------------------------------------===//
+// Attribute Registration
+//===----------------------------------------------------------------------===//

-auto [sizeM, sizeN, sizeK] = mmaKind.getMNKShape();
-if (problemMSize % sizeM != 0 || problemNSize % sizeN != 0 ||
-problemKSize % sizeK != 0) {
-continue;
-}
-return mmaKind;
-}
-return std::nullopt;
+void IREEGPUDialect::registerAttributes() {
+addAttributes<
+#define GET_ATTRDEF_LIST
+#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp.inc" // IWYU pragma: keep
+>();
}

-} // namespace mlir::iree_compiler
+} // namespace mlir::iree_compiler::IREE::GPU