Skip to content

Commit

Permalink
[Codegen][GPU] Fail vector distribution if any live conversion ops remain (iree-org#16487)
Browse files Browse the repository at this point in the history

The `to_simt` and `to_simd` ops have no execution semantics (they are
effectively `builtin.unrealized_conversion_cast` ops), so any live
instances of such ops should be treated as errors. We most likely should
just replace them with `builtin.unrealized_conversion_cast` but updating
the tests at this point is unwieldy.
  • Loading branch information
qedawkins authored Feb 20, 2024
1 parent cd18fdf commit e4ae2b7
Show file tree
Hide file tree
Showing 9 changed files with 80 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,17 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.h"
#include "iree-dialects/Dialect/VectorExt/IR/VectorExtDialect.h"
#include "iree-dialects/Dialect/VectorExt/IR/VectorExtOps.h"
#include "iree/compiler/Codegen/Common/VectorLayoutAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Verifier.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Rewrite/PatternApplicator.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#define DEBUG_TYPE "iree-codegen-gpu-vector-distribution"

Expand Down Expand Up @@ -220,15 +225,15 @@ static bool canDistribute(Operation *op, VectorLayoutAnalysis &analysis) {
});
}

void distributeVectorOps(Operation *root,
RewritePatternSet &distributionPatterns,
VectorLayoutOptions &options) {
LogicalResult distributeVectorOps(Operation *root,
RewritePatternSet &distributionPatterns,
VectorLayoutOptions &options) {
// Run the analysis and determine the layouts.
LLVM_DEBUG(llvm::dbgs() << "Running Layout Analysis\n");
VectorLayoutAnalysis analysis(root);
options.setAnchorOps(analysis);
if (failed(analysis.run()))
return;
return failure();
LLVM_DEBUG(llvm::dbgs() << "Layout Analysis Succeded\n");
LLVM_DEBUG(llvm::dbgs() << "\n\n");

Expand All @@ -245,7 +250,38 @@ void distributeVectorOps(Operation *root,
LLVM_DEBUG(llvm::dbgs() << "\n\n");

FrozenRewritePatternSet frozenPatterns(std::move(distributionPatterns));
return applyVectorDistribution(root, frozenPatterns);
applyVectorDistribution(root, frozenPatterns);

RewritePatternSet patterns(root->getContext());
IREE::VectorExt::ToSIMDOp::getCanonicalizationPatterns(patterns,
root->getContext());
IREE::VectorExt::ToSIMTOp::getCanonicalizationPatterns(patterns,
root->getContext());
if (failed(applyPatternsAndFoldGreedily(root, std::move(patterns)))) {
return failure();
}

if (options.verifyConversion()) {
WalkResult hasConversionOp = root->walk([](Operation *op) {
if (isa<IREE::VectorExt::ToSIMDOp, IREE::VectorExt::ToSIMTOp>(op)) {
for (auto user : op->getUsers()) {
if (!isa<IREE::VectorExt::ToSIMDOp, IREE::VectorExt::ToSIMTOp>(
user)) {
LLVM_DEBUG({
llvm::dbgs() << "Found live cast op: " << *op << "\n";
llvm::dbgs() << "With live user: " << *user << "\n";
});
return WalkResult::interrupt();
}
}
}
return WalkResult::advance();
});
if (hasConversionOp.wasInterrupted()) {
return failure();
}
}
return success();
}

} // namespace mlir::iree_compiler
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,11 @@ class OpTraitDistributionPattern : public DistributionPattern {
/// distribution.
class VectorLayoutOptions {
public:
VectorLayoutOptions(Operation *root) : root(root) {
VectorLayoutOptions(Operation *root) : root(root), fullConversion(true) {
assert(root && "root operation must be non-null");
}
VectorLayoutOptions(Operation *root, bool fullConversion)
: root(root), fullConversion(fullConversion) {
assert(root && "root operation must be non-null");
}

Expand All @@ -96,8 +100,11 @@ class VectorLayoutOptions {
/// Set the anchor ops in the analysis rooted on the root operation.
virtual void setAnchorOps(VectorLayoutAnalysis &analysis) = 0;

bool verifyConversion() const { return fullConversion; }

protected:
Operation *root;
bool fullConversion = true;
}; // namespace iree_compiler

/// Distribute vector operations in the IR rooted at `root`.
Expand All @@ -112,9 +119,9 @@ class VectorLayoutOptions {
/// - Run a global analysis to determine how to distribute rest of the vector
/// values keeping the initial anchors in mind.
/// - Use the analysis information to distribute each operation.
void distributeVectorOps(Operation *root,
RewritePatternSet &distributionPatterns,
VectorLayoutOptions &options);
LogicalResult distributeVectorOps(Operation *root,
RewritePatternSet &distributionPatterns,
VectorLayoutOptions &options);

} // namespace mlir::iree_compiler

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -387,13 +387,11 @@ builtin.module attributes { transform.with_named_sequence } {
// CHECK: %[[A0_CAST:.+]] = vector.shape_cast %[[A_SLICE0]] : vector<1x1x1x4xf16> to vector<4xf16>
// CHECK: %[[B0_CAST:.+]] = vector.shape_cast %[[B_SLICE0]] : vector<1x1x1x4xf16> to vector<4xf16>
// CHECK: %[[MFMA0:.+]] = amdgpu.mfma %[[A0_CAST]] * %[[B0_CAST]] + %{{.+}}
// CHECK: %[[R0_CAST:.+]] = vector.shape_cast %[[MFMA0]] : vector<4x4xf32> to vector<4x1x1x4xf32>
// CHECK: %[[A_SLICE1:.+]] = vector.extract %[[A_SIMT]][0, 1] : vector<1x1x1x4xf16> from vector<1x2x1x1x1x4xf16>
// CHECK: %[[B_SLICE1:.+]] = vector.extract %[[B_SIMT]][1, 0] : vector<1x1x1x4xf16> from vector<2x1x1x1x1x4xf16>
// CHECK: %[[A1_CAST:.+]] = vector.shape_cast %[[A_SLICE1]] : vector<1x1x1x4xf16> to vector<4xf16>
// CHECK: %[[B1_CAST:.+]] = vector.shape_cast %[[B_SLICE1]] : vector<1x1x1x4xf16> to vector<4xf16>
// CHECK: %[[CAST:.+]] = vector.shape_cast %[[R0_CAST]] : vector<4x1x1x4xf32> to vector<4x4xf32>
// CHECK: %[[MFMA1:.+]] = amdgpu.mfma %[[A1_CAST]] * %[[B1_CAST]] + %[[CAST]]
// CHECK: %[[MFMA1:.+]] = amdgpu.mfma %[[A1_CAST]] * %[[B1_CAST]] + %[[MFMA0]]
// CHECK: %[[R_CAST:.+]] = vector.shape_cast %[[MFMA1]] : vector<4x4xf32> to vector<4x1x1x4xf32>
// CHECK: %[[INSERT:.+]] = vector.insert %[[R_CAST]], %{{.+}} [0, 0] : vector<4x1x1x4xf32> into vector<1x1x4x1x1x4xf32>
// CHECK: %[[R:.+]] = iree_vector_ext.to_simd %[[INSERT]] : vector<1x1x4x1x1x4xf32> -> vector<32x32xf32>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -939,7 +939,8 @@ void transform_dialect::TestVectorLayoutAnalysisOp::getEffects(

class TestVectorLayoutOptions : public VectorLayoutOptions {
public:
TestVectorLayoutOptions(Operation *root) : VectorLayoutOptions(root) {}
TestVectorLayoutOptions(Operation *root)
: VectorLayoutOptions(root, /*fullConversion=*/false) {}

void setAnchorOps(VectorLayoutAnalysis &analysis) override {
setAnchorOpsFromAttributes(analysis, root);
Expand Down Expand Up @@ -970,7 +971,9 @@ transform_dialect::TestGpuVectorDistribution::applyToOne(
populateGPUDistributeNestedLayoutContractAMDGPUPatterns(patterns);
if (getExperimental())
populateGPULayoutResolutionDistributionPatterns(patterns);
distributeVectorOps(target, patterns, options);
if (failed(distributeVectorOps(target, patterns, options))) {
return emitDefaultDefiniteFailure(target);
}
return DiagnosedSilenceableFailure::success();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -351,8 +351,10 @@ struct LLVMGPUVectorDistributePass

ContractionVectorLayoutOptions options(func, *maybeSupportedTypes,
workgroupSize, laneVal);
// TODO: This should return failure when distribution fails for any op.
distributeVectorOps(func, options.getPatterns(), options);
if (failed(distributeVectorOps(func, options.getPatterns(), options))) {
func->emitOpError() << "failed to distribute";
return signalPassFailure();
}
}
};
} // namespace
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1501,9 +1501,10 @@ void transform_dialect::PackSharedMemoryAllocOp::getEffects(
transform::modifiesPayload(effects);
}

class TestVectorLayoutOptions : public VectorLayoutOptions {
class TransformVectorLayoutOptions : public VectorLayoutOptions {
public:
TestVectorLayoutOptions(Operation *root) : VectorLayoutOptions(root) {}
TransformVectorLayoutOptions(Operation *root, bool fullConversion)
: VectorLayoutOptions(root, fullConversion) {}

void setAnchorOps(VectorLayoutAnalysis &analysis) override {
setAnchorOpsFromAttributes(analysis, root);
Expand All @@ -1515,7 +1516,7 @@ transform_dialect::AMDGPUDistributeVectorsOp::applyToOne(
transform::TransformRewriter &rewriter, mlir::FunctionOpInterface target,
transform::ApplyToEachResultList &results,
transform::TransformState &state) {
TestVectorLayoutOptions options(target);
TransformVectorLayoutOptions options(target, !getTestConversion());
RewritePatternSet patterns(target.getContext());

rewriter.setInsertionPointToStart(&target.getFunctionBody().front());
Expand All @@ -1527,7 +1528,9 @@ transform_dialect::AMDGPUDistributeVectorsOp::applyToOne(
populateGPUReductionDistributionPatterns(patterns);
populateGPUDistributeNestedLayoutAttrPatterns(laneId, patterns);
populateAMDGPUDistributionPatterns(patterns);
distributeVectorOps(target, patterns, options);
if (failed(distributeVectorOps(target, patterns, options))) {
return emitDefaultSilenceableFailure(target);
}
return DiagnosedSilenceableFailure::success();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -706,10 +706,14 @@ def AMDGPUDistributeVectorsOp :
This transform does not consume the target handle and always return success.
}];

let arguments = (ins TransformHandleTypeInterface:$target);
let arguments = (ins TransformHandleTypeInterface:$target,
UnitAttr:$test_conversion);
let results = (outs);

let assemblyFormat = [{ $target attr-dict `:` type($target)}];
let assemblyFormat = [{
$target (`test_conversion` $test_conversion^)?
attr-dict `:` type($target)
}];
let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";

let extraClassDeclaration = [{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ builtin.module attributes { transform.with_named_sequence } {
}
transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
%top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
transform.iree.amdgpu_distribute_vectors %top_level_func : !transform.any_op
transform.iree.amdgpu_distribute_vectors %top_level_func test_conversion : !transform.any_op
transform.yield
}
}
Expand Down Expand Up @@ -84,7 +84,7 @@ builtin.module attributes { transform.with_named_sequence } {
}
transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
%top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
transform.iree.amdgpu_distribute_vectors %top_level_func : !transform.any_op
transform.iree.amdgpu_distribute_vectors %top_level_func test_conversion : !transform.any_op
transform.yield
}
}
Expand Down Expand Up @@ -132,7 +132,7 @@ builtin.module attributes { transform.with_named_sequence } {
}
transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
%top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
transform.iree.amdgpu_distribute_vectors %top_level_func : !transform.any_op
transform.iree.amdgpu_distribute_vectors %top_level_func test_conversion : !transform.any_op
transform.yield
}
}
Expand Down Expand Up @@ -179,7 +179,7 @@ builtin.module attributes { transform.with_named_sequence } {
}
transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
%top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
transform.iree.amdgpu_distribute_vectors %top_level_func : !transform.any_op
transform.iree.amdgpu_distribute_vectors %top_level_func test_conversion : !transform.any_op
transform.yield
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ module attributes { transform.with_named_sequence } {
transform.iree.set_contraction_layout_attributes %contracts, %layout16x16x16 : !transform.any_op, !transform.any_param

%distribute_func = transform.structured.match ops{["func.func"]} in %variant_op_3 : (!transform.any_op) -> !transform.any_op
transform.iree.amdgpu_distribute_vectors %distribute_func : !transform.any_op
transform.iree.amdgpu_distribute_vectors %distribute_func test_conversion : !transform.any_op

transform.apply_patterns to %distribute_func {
transform.apply_patterns.canonicalization
Expand Down

0 comments on commit e4ae2b7

Please sign in to comment.