-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][TilingInterface] Handle multi operand consumer fusion. #145193
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[mlir][TilingInterface] Handle multi operand consumer fusion. #145193
Conversation
@llvm/pr-subscribers-mlir-tensor @llvm/pr-subscribers-mlir-scf Author: None (MaheshRavishankar) ChangesFor consumer fusion cases of this form
the current consumer fusion that handles one slice at a time cannot fuse the consumer into the loop, since fusing along one slice will create and SSA violation on the other use from the
to allow fusion while considering multiple operands. It is upto the The Linalg operation implementation of Additional change : Add Patch is 60.31 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/145193.diff 10 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index f686ae07b9a99..7b6e3cba5723d 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -319,19 +319,24 @@ tileConsumerAndFuseProducersUsingSCF(RewriterBase &rewriter,
TilingInterface consumer,
const SCFTileAndFuseOptions &options);
-/// Fuse the consumer of the source of `candidateSliceOp` by computing the
-/// required slice of the consumer in-place. Note that the method
-/// replaces the uses of `candidateSliceOp` with the tiled and fused consumer
-/// value but does not delete the slice operation.
+/// Fuse the consumer of the result of every element of `candidateSliceOp` by
+/// computing the required slice of the consumer in-place. All the entries of
+/// `candidateSlices` are expected to map to the same consumer. The method
+/// returns an error if the consumer cannot be tiled in a manner that is
+/// consistent for all the passed slices. Note that the method replaces the uses
+/// of `candidateSliceOp` with the tiled and fused consumer value but does not
+/// delete the slice operation.
struct SCFFuseConsumerOfSliceResult {
- OpOperand *origConsumerOperand; // Original untiled consumer's operand.
- OpOperand
- *tiledAndFusedConsumerOperand; // Tiled and fused consumer's operand.
+ // Original untiled consumer's operand.
+ SmallVector<OpOperand *> origConsumerOperands;
+ // Tiled and fused consumer's operand.
+ SmallVector<OpOperand *> tiledAndFusedConsumerOperands;
SmallVector<Operation *> tiledOps;
};
FailureOr<scf::SCFFuseConsumerOfSliceResult>
-tileAndFuseConsumerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp,
- MutableArrayRef<LoopLikeOpInterface> loops);
+tileAndFuseConsumerOfSlices(RewriterBase &rewriter,
+ ArrayRef<Operation *> candidateSlices,
+ MutableArrayRef<LoopLikeOpInterface> loops);
/// Method to lower an `op` that implements the `TilingInterface` to
/// loops/scalars.
diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
index 18981337742eb..8f6eb1bd47782 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -31,12 +31,16 @@ namespace tensor {
FailureOr<TilingResult> replaceExtractSliceWithTiledProducer(
OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producerOp);
-/// Method to swap an `tensor.insert_slice` with its consumer when the
-/// consumer implements the `TilingInterface`.
+/// Method to swap an `tensor.insert_slice`s with its consumer when the
+/// consumer implements the `TilingInterface`. The size of `sliceOps` and
+/// `consumerOperands` is expected to be the same. Every entry in
+/// `consumerOperands` represents the use of the result of the corresponding
+/// entry in `sliceOps`. All entries of `consumerOperands` is expected to be
+/// uses in the same consumer.
FailureOr<TilingResult>
-replaceInsertSliceWithTiledConsumer(OpBuilder &builder,
- OffsetSizeAndStrideOpInterface sliceOp,
- OpOperand &consumerOp);
+replaceInsertSlicesWithTiledConsumer(OpBuilder &builder,
+ ArrayRef<tensor::InsertSliceOp> sliceOps,
+ ArrayRef<OpOperand *> consumerOperands);
//===----------------------------------------------------------------------===//
// Populate functions.
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index 31f54413a5ff0..663c256c848df 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -272,7 +272,7 @@ class OpFoldResult : public PointerUnion<Attribute, Value> {
using PointerUnion<Attribute, Value>::PointerUnion;
public:
- void dump() const { llvm::errs() << *this << "\n"; }
+ LLVM_DUMP_METHOD void dump() const { llvm::errs() << *this << "\n"; }
MLIRContext *getContext() const {
PointerUnion pu = *this;
diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td
index cdf3d01ce8a84..7ebdd8907e964 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.td
+++ b/mlir/include/mlir/Interfaces/TilingInterface.td
@@ -202,28 +202,28 @@ def TilingInterface : OpInterface<"TilingInterface"> {
InterfaceMethod<
/*desc=*/[{
Method to generate the tiled implementation of an operation that uses
- exactly a tile of the given operand.
+ exactly tiles of the given operands.
This method is required to allow operations to be "tiled and fused"
- with an (already tiled) producer. Given a tile of the producer, this
- method generates the tile of the consumer that uses exactly this
- produced tile. In some sense it is the "reverse" of
+ with an (already tiled) producer. Given a tiles of the producer, this
+ method generates the tile of the consumer that uses exactly these
+ produced tiles. In some sense it is the "reverse" of
`generateResultTileValue`.
- - `operandNumber` is the result of the producer used by the consumer.
- - `offsets` is the offset of the slice of the producer result used by
- the tiled implementation of the consumer.
- - `sizes` is the size of the slice of the producer result used by the
+ - `operandNumbers` is the list of operands whose tiles are "produced".
+ - `allOffsets` is the offset of the slice of the producer results used
+ by the tiled implementation of the consumer.
+ - `allSizes` is the size of the slice of the producer results used by the
consumer.
- If it is illegal to fuse with a producer along the given operand for
+ If it is illegal to fuse with a producer along the given operand tiles for
an operation, the implementation should return a failure.
}],
/*retType=*/"::mlir::FailureOr<::mlir::TilingResult>",
- /*methodName=*/"getTiledImplementationFromOperandTile",
+ /*methodName=*/"getTiledImplementationFromOperandTiles",
/*args=*/(ins
"::mlir::OpBuilder &":$b,
- "unsigned":$operandNumber,
- "::mlir::ArrayRef<::mlir::OpFoldResult>":$offsets,
- "::mlir::ArrayRef<::mlir::OpFoldResult>":$sizes),
+ "::mlir::ArrayRef<unsigned>":$operandNumbers,
+ "::mlir::ArrayRef<::mlir::SmallVector<::mlir::OpFoldResult>>":$allOffsets,
+ "::mlir::ArrayRef<::mlir::SmallVector<::mlir::OpFoldResult>>":$allSizes),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return failure();
@@ -235,13 +235,14 @@ def TilingInterface : OpInterface<"TilingInterface"> {
tile of the operand.
This method is required to allow operations to be "tiled and fused"
- with an (already tiled) producer. Given a tile of an operand,
- returns the tile of the iteration space that uses this tile.
- - `operandNumber` is the result of the producer used by the consumer.
- - `offsets` is the offset of the slice of the producer result used by
- the tiled implementation of the consumer.
- - `sizes` is the size of the slice of the producer result used by the
- consumer.
+ with an (already tiled) producer. Given tiles of an operand,
+ returns the tile of the iteration space that uses these tiles.
+ - `operandNumbers` is the list of operands whose tiles are "produced"
+ by the producer(s).
+ - `allOffsets` is the offset of the slice of the producer results
+ used by the tiled implementation of the consumer.
+ - `allSizes` is the size of the slice of the producer results used by
+ the consumer.
If it is illegal to fuse with a producer along the given operand for
an operation, or if this mapping cannot be computed, the
implementation should return a failure.
@@ -285,17 +286,17 @@ def TilingInterface : OpInterface<"TilingInterface"> {
transformation. It does not provide guarantees on whether such a
transformation is profitable.
- For most cases `getTiledImplementationFromOperandTile` could be a
- implemented using `getIterationDomainTileFromOperandTile` +
+ For most cases `getTiledImplementationFromOperandTiles` could be a
+ implemented using `getIterationDomainTileFromOperandTiles` +
`getTiledImplementation` methods.
}],
/*retType=*/"::llvm::LogicalResult",
- /*methodName=*/"getIterationDomainTileFromOperandTile",
+ /*methodName=*/"getIterationDomainTileFromOperandTiles",
/*args=*/(ins
"::mlir::OpBuilder &":$b,
- "unsigned":$operandNumber,
- "::mlir::ArrayRef<::mlir::OpFoldResult> ":$offsets,
- "::mlir::ArrayRef<::mlir::OpFoldResult> ":$sizes,
+ "::mlir::ArrayRef<unsigned>":$operandNumbers,
+ "::mlir::ArrayRef<::mlir::SmallVector<::mlir::OpFoldResult>> ":$allOffsets,
+ "::mlir::ArrayRef<::mlir::SmallVector<::mlir::OpFoldResult>> ":$allSizes,
"::mlir::SmallVectorImpl<::mlir::OpFoldResult> &":$iterDomainOffsets,
"::mlir::SmallVectorImpl<::mlir::OpFoldResult> &":$iterDomainSizes),
/*methodBody=*/"",
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index 7c14cc16437fe..86045b54075bc 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -147,55 +147,80 @@ struct LinalgOpTilingInterface
/// Utility to fetch the offsets and sizes when applied as per the indexing
/// map of the linalg op. This helps in fusing the linalg op as a consumer of
/// a given slice op.
- void
- getMappedOffsetAndSize(LinalgOp linalgOp, OpBuilder &b, AffineMap indexingMap,
- ArrayRef<OpFoldResult> offsets,
- ArrayRef<OpFoldResult> sizes,
- SmallVectorImpl<OpFoldResult> &mappedOffsets,
- SmallVectorImpl<OpFoldResult> &mappedSizes) const {
- unsigned numLoops = linalgOp.getNumLoops();
- auto tilingInterfaceOp = cast<TilingInterface>(linalgOp.getOperation());
- mappedOffsets.resize(numLoops);
- mappedSizes.resize(numLoops);
- if (!indexingMap.isPermutation()) {
- SmallVector<Range> iterationDomain =
- tilingInterfaceOp.getIterationDomain(b);
- for (const auto &&[index, value] : llvm::enumerate(iterationDomain)) {
- mappedOffsets[index] = value.offset;
- mappedSizes[index] = value.size;
+ static LogicalResult
+ getMappedOffsetAndSize(LinalgOp linalgOp, OpBuilder &b,
+ ArrayRef<AffineMap> indexingMaps,
+ ArrayRef<SmallVector<OpFoldResult>> allOffsets,
+ ArrayRef<SmallVector<OpFoldResult>> allSizes,
+ SmallVectorImpl<OpFoldResult> &mappedOffsetsVec,
+ SmallVectorImpl<OpFoldResult> &mappedSizesVec) {
+ DenseMap<unsigned, OpFoldResult> mappedOffsets, mappedSizes;
+
+ for (auto [indexingMap, offsets, sizes] :
+ llvm::zip_equal(indexingMaps, allOffsets, allSizes)) {
+ for (auto [resultExpr, offset, size] :
+ llvm::zip_equal(indexingMap.getResults(), offsets, sizes)) {
+ auto dimExpr = dyn_cast<AffineDimExpr>(resultExpr);
+ if (!dimExpr)
+ continue;
+ unsigned position = dimExpr.getPosition();
+ auto it = mappedOffsets.find(position);
+ if (it != mappedOffsets.end()) {
+ OpFoldResult seenOffset = it->second;
+ OpFoldResult seenSize = mappedSizes.lookup(position);
+ if (seenOffset != offset || seenSize != size) {
+ return linalgOp->emitOpError(
+ "inconsistent iteration space mapping from offsets/sizes of "
+ "operands/results");
+ }
+ } else {
+ mappedOffsets[position] = offset;
+ mappedSizes[position] = size;
+ }
}
}
- for (const auto &&[index, value] :
- llvm::enumerate(indexingMap.getResults())) {
- unsigned dimPosition = cast<AffineDimExpr>(value).getPosition();
- mappedOffsets[dimPosition] = offsets[index];
- mappedSizes[dimPosition] = sizes[index];
+
+ // Aggregate from the given operand offsets and sizes, or default to
+ // iteration space values.
+ SmallVector<Range> iterationDomain =
+ cast<TilingInterface>(linalgOp.getOperation()).getIterationDomain(b);
+ mappedOffsetsVec.resize(iterationDomain.size());
+ mappedSizesVec.resize(iterationDomain.size());
+ for (auto [index, domain] : llvm::enumerate(iterationDomain)) {
+ auto it = mappedOffsets.find(index);
+ if (it != mappedOffsets.end()) {
+ mappedOffsetsVec[index] = it->second;
+ mappedSizesVec[index] = mappedSizes.lookup(index);
+ continue;
+ }
+ mappedOffsetsVec[index] = domain.offset;
+ mappedOffsetsVec[index] = domain.size;
}
+ return success();
}
/// Method to return the position of the result tile computed by the tiled
/// operation.
- LogicalResult getIterationDomainTileFromOperandTile(
- Operation *op, OpBuilder &b, unsigned operandNumber,
- ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+ LogicalResult getIterationDomainTileFromOperandTiles(
+ Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
+ ArrayRef<SmallVector<OpFoldResult>> allOffsets,
+ ArrayRef<SmallVector<OpFoldResult>> allSizes,
SmallVectorImpl<OpFoldResult> &iterDomainOffsets,
SmallVectorImpl<OpFoldResult> &iterDomainSizes) const {
auto linalgOp = cast<LinalgOp>(op);
- // Check that the indexing map used for the operand is a projected
- // permutation. This could be relaxed with a more general approach that can
- // map the offsets and sizes from the operand to iteration space tiles
- // (filling in full extent for dimensions not used to access the result).
- AffineMap indexingMap =
- linalgOp.getMatchingIndexingMap(&op->getOpOperand(operandNumber));
- if (!indexingMap.isProjectedPermutation()) {
- return op->emitError()
- << "unhandled get iter domain position when operand is not "
- "accessed using a permuted projection";
+ std::optional<SmallVector<OpFoldResult>> iterationSpaceOffsets,
+ iterationSpaceSizes;
+ SmallVector<AffineMap> indexingMaps =
+ llvm::map_to_vector(operandNumbers, [&](unsigned operandNumber) {
+ OpOperand &opOperand = linalgOp->getOpOperand(operandNumber);
+ return linalgOp.getMatchingIndexingMap(&opOperand);
+ });
+ if (failed(getMappedOffsetAndSize(linalgOp, b, indexingMaps, allOffsets,
+ allSizes, iterDomainOffsets,
+ iterDomainSizes))) {
+ return failure();
}
-
- getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
- iterDomainOffsets, iterDomainSizes);
return success();
}
@@ -246,8 +271,13 @@ struct LinalgOpTilingInterface
"accessed using a permuted projection");
}
- getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
- iterDomainOffsets, iterDomainSizes);
+ SmallVector<OpFoldResult> allOffsets = llvm::to_vector(offsets);
+ SmallVector<OpFoldResult> allSizes = llvm::to_vector(sizes);
+ auto status =
+ getMappedOffsetAndSize(linalgOp, b, indexingMap, {allOffsets},
+ {allSizes}, iterDomainOffsets, iterDomainSizes);
+ (void)status;
+ assert(succeeded(status) && "unexpected error in offset calculation");
return success();
}
@@ -278,12 +308,13 @@ struct LinalgOpTilingInterface
/// Method to generate the tiled implementation of an operation from the tile
/// of the operand.
- FailureOr<TilingResult> getTiledImplementationFromOperandTile(
- Operation *op, OpBuilder &b, unsigned operandNumber,
- ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
+ FailureOr<TilingResult> getTiledImplementationFromOperandTiles(
+ Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
+ ArrayRef<SmallVector<OpFoldResult>> allOffsets,
+ ArrayRef<SmallVector<OpFoldResult>> allSizes) const {
SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
- if (failed(getIterationDomainTileFromOperandTile(
- op, b, operandNumber, offsets, sizes, mappedOffsets,
+ if (failed(getIterationDomainTileFromOperandTiles(
+ op, b, operandNumbers, allOffsets, allSizes, mappedOffsets,
mappedSizes))) {
return failure();
}
@@ -750,13 +781,17 @@ struct PackOpTiling
/// Method to return the position of iteration domain tile computed by the
/// tiled operation. In current `tensor.pack` context, the `resultOffsets` and
/// `resultSizes` only cover outer dimensions.
- LogicalResult getIterationDomainTileFromOperandTile(
- Operation *op, OpBuilder &b, unsigned operandNumber,
- ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+ LogicalResult getIterationDomainTileFromOperandTiles(
+ Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
+ ArrayRef<SmallVector<OpFoldResult>> allOffsets,
+ ArrayRef<SmallVector<OpFoldResult>> allSizes,
SmallVectorImpl<OpFoldResult> &resultOffsets,
SmallVectorImpl<OpFoldResult> &resultSizes) const {
- if (operandNumber != 0)
- return failure();
+ if (operandNumbers.size() != 1 || operandNumbers[0] != 0)
+ return op->emitOpError("unsupporeted operands for consumer fusion");
+
+ ArrayRef<OpFoldResult> offsets(allOffsets[0]);
+ ArrayRef<OpFoldResult> sizes(allSizes[0]);
auto packOp = cast<PackOp>(op);
// It is not trivial to infer dest tile from source tile if `packOp` has
@@ -817,11 +852,15 @@ struct PackOpTiling
}
/// Method to return the tiled implementation of tensor.pack as a consumer.
- FailureOr<TilingResult> getTiledImplementationFromOperandTile(
- Operation *op, OpBuilder &b, unsigned operandNumber,
- ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
- if (operandNumber != 0)
- return failure();
+ FailureOr<TilingResult> getTiledImplementationFromOperandTiles(
+ Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
+ ArrayRef<SmallVector<OpFoldResult>> allOffsets,
+ ArrayRef<SmallVector<OpFoldResult>> allSizes) const {
+ if (operandNumbers.size() != 1 || operandNumbers[0] != 0)
+ return op->emitOpError("unhandled operands for consumer fusion");
+
+ ArrayRef<OpFoldResult> offsets(allOffsets[0]);
+ ArrayRef<OpFoldResult> sizes(allSizes[0]);
auto packOp = cast<PackOp>(op);
Location loc = packOp.getLoc();
@@ -836,8 +875,8 @@ struct PackOpTiling
tiledOperands.push_back(sourceSlice);
SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes;
- if (failed(getIterationDomainTileFromOperandTile(
- op, b, /*operandNumber=*/0, offsets, sizes, outerDimOffsets,
+ if (failed(getIterationDomainTileFromOperandTiles(
+ op, b, operandNumbers, allOffsets, allSizes, outerDimOffsets,
outerDimSizes)))
return failure();
@@ -1095,12 +1134,20 @@ struct UnPackOpTiling
/// Method to return the position of iteration domain tile computed by the
/// tiled operation.
- LogicalResult getIterationDomainTileFromOperandTile(
- Operation *op, OpBuilder &b, unsigned operandNumber,
- ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+ LogicalResult getIterationDomainTileFromOperandTiles(
+ Operation *op, OpB...
[truncated]
|
@llvm/pr-subscribers-mlir-linalg Author: None (MaheshRavishankar) ChangesFor consumer fusion cases of this form
the current consumer fusion that handles one slice at a time cannot fuse the consumer into the loop, since fusing along one slice will create and SSA violation on the other use from the
to allow fusion while considering multiple operands. It is upto the The Linalg operation implementation of Additional change : Add Patch is 60.31 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/145193.diff 10 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index f686ae07b9a99..7b6e3cba5723d 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -319,19 +319,24 @@ tileConsumerAndFuseProducersUsingSCF(RewriterBase &rewriter,
TilingInterface consumer,
const SCFTileAndFuseOptions &options);
-/// Fuse the consumer of the source of `candidateSliceOp` by computing the
-/// required slice of the consumer in-place. Note that the method
-/// replaces the uses of `candidateSliceOp` with the tiled and fused consumer
-/// value but does not delete the slice operation.
+/// Fuse the consumer of the result of every element of `candidateSliceOp` by
+/// computing the required slice of the consumer in-place. All the entries of
+/// `candidateSlices` are expected to map to the same consumer. The method
+/// returns an error if the consumer cannot be tiled in a manner that is
+/// consistent for all the passed slices. Note that the method replaces the uses
+/// of `candidateSliceOp` with the tiled and fused consumer value but does not
+/// delete the slice operation.
struct SCFFuseConsumerOfSliceResult {
- OpOperand *origConsumerOperand; // Original untiled consumer's operand.
- OpOperand
- *tiledAndFusedConsumerOperand; // Tiled and fused consumer's operand.
+ // Original untiled consumer's operand.
+ SmallVector<OpOperand *> origConsumerOperands;
+ // Tiled and fused consumer's operand.
+ SmallVector<OpOperand *> tiledAndFusedConsumerOperands;
SmallVector<Operation *> tiledOps;
};
FailureOr<scf::SCFFuseConsumerOfSliceResult>
-tileAndFuseConsumerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp,
- MutableArrayRef<LoopLikeOpInterface> loops);
+tileAndFuseConsumerOfSlices(RewriterBase &rewriter,
+ ArrayRef<Operation *> candidateSlices,
+ MutableArrayRef<LoopLikeOpInterface> loops);
/// Method to lower an `op` that implements the `TilingInterface` to
/// loops/scalars.
diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
index 18981337742eb..8f6eb1bd47782 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -31,12 +31,16 @@ namespace tensor {
FailureOr<TilingResult> replaceExtractSliceWithTiledProducer(
OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producerOp);
-/// Method to swap an `tensor.insert_slice` with its consumer when the
-/// consumer implements the `TilingInterface`.
+/// Method to swap an `tensor.insert_slice`s with its consumer when the
+/// consumer implements the `TilingInterface`. The size of `sliceOps` and
+/// `consumerOperands` is expected to be the same. Every entry in
+/// `consumerOperands` represents the use of the result of the corresponding
+/// entry in `sliceOps`. All entries of `consumerOperands` is expected to be
+/// uses in the same consumer.
FailureOr<TilingResult>
-replaceInsertSliceWithTiledConsumer(OpBuilder &builder,
- OffsetSizeAndStrideOpInterface sliceOp,
- OpOperand &consumerOp);
+replaceInsertSlicesWithTiledConsumer(OpBuilder &builder,
+ ArrayRef<tensor::InsertSliceOp> sliceOps,
+ ArrayRef<OpOperand *> consumerOperands);
//===----------------------------------------------------------------------===//
// Populate functions.
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index 31f54413a5ff0..663c256c848df 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -272,7 +272,7 @@ class OpFoldResult : public PointerUnion<Attribute, Value> {
using PointerUnion<Attribute, Value>::PointerUnion;
public:
- void dump() const { llvm::errs() << *this << "\n"; }
+ LLVM_DUMP_METHOD void dump() const { llvm::errs() << *this << "\n"; }
MLIRContext *getContext() const {
PointerUnion pu = *this;
diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td
index cdf3d01ce8a84..7ebdd8907e964 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.td
+++ b/mlir/include/mlir/Interfaces/TilingInterface.td
@@ -202,28 +202,28 @@ def TilingInterface : OpInterface<"TilingInterface"> {
InterfaceMethod<
/*desc=*/[{
Method to generate the tiled implementation of an operation that uses
- exactly a tile of the given operand.
+ exactly tiles of the given operands.
This method is required to allow operations to be "tiled and fused"
- with an (already tiled) producer. Given a tile of the producer, this
- method generates the tile of the consumer that uses exactly this
- produced tile. In some sense it is the "reverse" of
+ with an (already tiled) producer. Given a tiles of the producer, this
+ method generates the tile of the consumer that uses exactly these
+ produced tiles. In some sense it is the "reverse" of
`generateResultTileValue`.
- - `operandNumber` is the result of the producer used by the consumer.
- - `offsets` is the offset of the slice of the producer result used by
- the tiled implementation of the consumer.
- - `sizes` is the size of the slice of the producer result used by the
+ - `operandNumbers` is the list of operands whose tiles are "produced".
+ - `allOffsets` is the offset of the slice of the producer results used
+ by the tiled implementation of the consumer.
+ - `allSizes` is the size of the slice of the producer results used by the
consumer.
- If it is illegal to fuse with a producer along the given operand for
+ If it is illegal to fuse with a producer along the given operand tiles for
an operation, the implementation should return a failure.
}],
/*retType=*/"::mlir::FailureOr<::mlir::TilingResult>",
- /*methodName=*/"getTiledImplementationFromOperandTile",
+ /*methodName=*/"getTiledImplementationFromOperandTiles",
/*args=*/(ins
"::mlir::OpBuilder &":$b,
- "unsigned":$operandNumber,
- "::mlir::ArrayRef<::mlir::OpFoldResult>":$offsets,
- "::mlir::ArrayRef<::mlir::OpFoldResult>":$sizes),
+ "::mlir::ArrayRef<unsigned>":$operandNumbers,
+ "::mlir::ArrayRef<::mlir::SmallVector<::mlir::OpFoldResult>>":$allOffsets,
+ "::mlir::ArrayRef<::mlir::SmallVector<::mlir::OpFoldResult>>":$allSizes),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return failure();
@@ -235,13 +235,14 @@ def TilingInterface : OpInterface<"TilingInterface"> {
tile of the operand.
This method is required to allow operations to be "tiled and fused"
- with an (already tiled) producer. Given a tile of an operand,
- returns the tile of the iteration space that uses this tile.
- - `operandNumber` is the result of the producer used by the consumer.
- - `offsets` is the offset of the slice of the producer result used by
- the tiled implementation of the consumer.
- - `sizes` is the size of the slice of the producer result used by the
- consumer.
+ with an (already tiled) producer. Given tiles of an operand,
+ returns the tile of the iteration space that uses these tiles.
+ - `operandNumbers` is the list of operands whose tiles are "produced"
+ by the producer(s).
+ - `allOffsets` is the offset of the slice of the producer results
+ used by the tiled implementation of the consumer.
+ - `allSizes` is the size of the slice of the producer results used by
+ the consumer.
If it is illegal to fuse with a producer along the given operand for
an operation, or if this mapping cannot be computed, the
implementation should return a failure.
@@ -285,17 +286,17 @@ def TilingInterface : OpInterface<"TilingInterface"> {
transformation. It does not provide guarantees on whether such a
transformation is profitable.
- For most cases `getTiledImplementationFromOperandTile` could be a
- implemented using `getIterationDomainTileFromOperandTile` +
+ For most cases `getTiledImplementationFromOperandTiles` could be a
+ implemented using `getIterationDomainTileFromOperandTiles` +
`getTiledImplementation` methods.
}],
/*retType=*/"::llvm::LogicalResult",
- /*methodName=*/"getIterationDomainTileFromOperandTile",
+ /*methodName=*/"getIterationDomainTileFromOperandTiles",
/*args=*/(ins
"::mlir::OpBuilder &":$b,
- "unsigned":$operandNumber,
- "::mlir::ArrayRef<::mlir::OpFoldResult> ":$offsets,
- "::mlir::ArrayRef<::mlir::OpFoldResult> ":$sizes,
+ "::mlir::ArrayRef<unsigned>":$operandNumbers,
+ "::mlir::ArrayRef<::mlir::SmallVector<::mlir::OpFoldResult>> ":$allOffsets,
+ "::mlir::ArrayRef<::mlir::SmallVector<::mlir::OpFoldResult>> ":$allSizes,
"::mlir::SmallVectorImpl<::mlir::OpFoldResult> &":$iterDomainOffsets,
"::mlir::SmallVectorImpl<::mlir::OpFoldResult> &":$iterDomainSizes),
/*methodBody=*/"",
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index 7c14cc16437fe..86045b54075bc 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -147,55 +147,80 @@ struct LinalgOpTilingInterface
/// Utility to fetch the offsets and sizes when applied as per the indexing
/// map of the linalg op. This helps in fusing the linalg op as a consumer of
/// a given slice op.
- void
- getMappedOffsetAndSize(LinalgOp linalgOp, OpBuilder &b, AffineMap indexingMap,
- ArrayRef<OpFoldResult> offsets,
- ArrayRef<OpFoldResult> sizes,
- SmallVectorImpl<OpFoldResult> &mappedOffsets,
- SmallVectorImpl<OpFoldResult> &mappedSizes) const {
- unsigned numLoops = linalgOp.getNumLoops();
- auto tilingInterfaceOp = cast<TilingInterface>(linalgOp.getOperation());
- mappedOffsets.resize(numLoops);
- mappedSizes.resize(numLoops);
- if (!indexingMap.isPermutation()) {
- SmallVector<Range> iterationDomain =
- tilingInterfaceOp.getIterationDomain(b);
- for (const auto &&[index, value] : llvm::enumerate(iterationDomain)) {
- mappedOffsets[index] = value.offset;
- mappedSizes[index] = value.size;
+ static LogicalResult
+ getMappedOffsetAndSize(LinalgOp linalgOp, OpBuilder &b,
+ ArrayRef<AffineMap> indexingMaps,
+ ArrayRef<SmallVector<OpFoldResult>> allOffsets,
+ ArrayRef<SmallVector<OpFoldResult>> allSizes,
+ SmallVectorImpl<OpFoldResult> &mappedOffsetsVec,
+ SmallVectorImpl<OpFoldResult> &mappedSizesVec) {
+ DenseMap<unsigned, OpFoldResult> mappedOffsets, mappedSizes;
+
+ for (auto [indexingMap, offsets, sizes] :
+ llvm::zip_equal(indexingMaps, allOffsets, allSizes)) {
+ for (auto [resultExpr, offset, size] :
+ llvm::zip_equal(indexingMap.getResults(), offsets, sizes)) {
+ auto dimExpr = dyn_cast<AffineDimExpr>(resultExpr);
+ if (!dimExpr)
+ continue;
+ unsigned position = dimExpr.getPosition();
+ auto it = mappedOffsets.find(position);
+ if (it != mappedOffsets.end()) {
+ OpFoldResult seenOffset = it->second;
+ OpFoldResult seenSize = mappedSizes.lookup(position);
+ if (seenOffset != offset || seenSize != size) {
+ return linalgOp->emitOpError(
+ "inconsistent iteration space mapping from offsets/sizes of "
+ "operands/results");
+ }
+ } else {
+ mappedOffsets[position] = offset;
+ mappedSizes[position] = size;
+ }
}
}
- for (const auto &&[index, value] :
- llvm::enumerate(indexingMap.getResults())) {
- unsigned dimPosition = cast<AffineDimExpr>(value).getPosition();
- mappedOffsets[dimPosition] = offsets[index];
- mappedSizes[dimPosition] = sizes[index];
+
+ // Aggregate from the given operand offsets and sizes, or default to
+ // iteration space values.
+ SmallVector<Range> iterationDomain =
+ cast<TilingInterface>(linalgOp.getOperation()).getIterationDomain(b);
+ mappedOffsetsVec.resize(iterationDomain.size());
+ mappedSizesVec.resize(iterationDomain.size());
+ for (auto [index, domain] : llvm::enumerate(iterationDomain)) {
+ auto it = mappedOffsets.find(index);
+ if (it != mappedOffsets.end()) {
+ mappedOffsetsVec[index] = it->second;
+ mappedSizesVec[index] = mappedSizes.lookup(index);
+ continue;
+ }
+ mappedOffsetsVec[index] = domain.offset;
+ mappedOffsetsVec[index] = domain.size;
}
+ return success();
}
/// Method to return the position of the result tile computed by the tiled
/// operation.
- LogicalResult getIterationDomainTileFromOperandTile(
- Operation *op, OpBuilder &b, unsigned operandNumber,
- ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+ LogicalResult getIterationDomainTileFromOperandTiles(
+ Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
+ ArrayRef<SmallVector<OpFoldResult>> allOffsets,
+ ArrayRef<SmallVector<OpFoldResult>> allSizes,
SmallVectorImpl<OpFoldResult> &iterDomainOffsets,
SmallVectorImpl<OpFoldResult> &iterDomainSizes) const {
auto linalgOp = cast<LinalgOp>(op);
- // Check that the indexing map used for the operand is a projected
- // permutation. This could be relaxed with a more general approach that can
- // map the offsets and sizes from the operand to iteration space tiles
- // (filling in full extent for dimensions not used to access the result).
- AffineMap indexingMap =
- linalgOp.getMatchingIndexingMap(&op->getOpOperand(operandNumber));
- if (!indexingMap.isProjectedPermutation()) {
- return op->emitError()
- << "unhandled get iter domain position when operand is not "
- "accessed using a permuted projection";
+ std::optional<SmallVector<OpFoldResult>> iterationSpaceOffsets,
+ iterationSpaceSizes;
+ SmallVector<AffineMap> indexingMaps =
+ llvm::map_to_vector(operandNumbers, [&](unsigned operandNumber) {
+ OpOperand &opOperand = linalgOp->getOpOperand(operandNumber);
+ return linalgOp.getMatchingIndexingMap(&opOperand);
+ });
+ if (failed(getMappedOffsetAndSize(linalgOp, b, indexingMaps, allOffsets,
+ allSizes, iterDomainOffsets,
+ iterDomainSizes))) {
+ return failure();
}
-
- getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
- iterDomainOffsets, iterDomainSizes);
return success();
}
@@ -246,8 +271,13 @@ struct LinalgOpTilingInterface
"accessed using a permuted projection");
}
- getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
- iterDomainOffsets, iterDomainSizes);
+ SmallVector<OpFoldResult> allOffsets = llvm::to_vector(offsets);
+ SmallVector<OpFoldResult> allSizes = llvm::to_vector(sizes);
+ auto status =
+ getMappedOffsetAndSize(linalgOp, b, indexingMap, {allOffsets},
+ {allSizes}, iterDomainOffsets, iterDomainSizes);
+ (void)status;
+ assert(succeeded(status) && "unexpected error in offset calculation");
return success();
}
@@ -278,12 +308,13 @@ struct LinalgOpTilingInterface
/// Method to generate the tiled implementation of an operation from the tile
/// of the operand.
- FailureOr<TilingResult> getTiledImplementationFromOperandTile(
- Operation *op, OpBuilder &b, unsigned operandNumber,
- ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
+ FailureOr<TilingResult> getTiledImplementationFromOperandTiles(
+ Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
+ ArrayRef<SmallVector<OpFoldResult>> allOffsets,
+ ArrayRef<SmallVector<OpFoldResult>> allSizes) const {
SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
- if (failed(getIterationDomainTileFromOperandTile(
- op, b, operandNumber, offsets, sizes, mappedOffsets,
+ if (failed(getIterationDomainTileFromOperandTiles(
+ op, b, operandNumbers, allOffsets, allSizes, mappedOffsets,
mappedSizes))) {
return failure();
}
@@ -750,13 +781,17 @@ struct PackOpTiling
/// Method to return the position of iteration domain tile computed by the
/// tiled operation. In current `tensor.pack` context, the `resultOffsets` and
/// `resultSizes` only cover outer dimensions.
- LogicalResult getIterationDomainTileFromOperandTile(
- Operation *op, OpBuilder &b, unsigned operandNumber,
- ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+ LogicalResult getIterationDomainTileFromOperandTiles(
+ Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
+ ArrayRef<SmallVector<OpFoldResult>> allOffsets,
+ ArrayRef<SmallVector<OpFoldResult>> allSizes,
SmallVectorImpl<OpFoldResult> &resultOffsets,
SmallVectorImpl<OpFoldResult> &resultSizes) const {
- if (operandNumber != 0)
- return failure();
+ if (operandNumbers.size() != 1 || operandNumbers[0] != 0)
+ return op->emitOpError("unsupporeted operands for consumer fusion");
+
+ ArrayRef<OpFoldResult> offsets(allOffsets[0]);
+ ArrayRef<OpFoldResult> sizes(allSizes[0]);
auto packOp = cast<PackOp>(op);
// It is not trivial to infer dest tile from source tile if `packOp` has
@@ -817,11 +852,15 @@ struct PackOpTiling
}
/// Method to return the tiled implementation of tensor.pack as a consumer.
- FailureOr<TilingResult> getTiledImplementationFromOperandTile(
- Operation *op, OpBuilder &b, unsigned operandNumber,
- ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
- if (operandNumber != 0)
- return failure();
+ FailureOr<TilingResult> getTiledImplementationFromOperandTiles(
+ Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
+ ArrayRef<SmallVector<OpFoldResult>> allOffsets,
+ ArrayRef<SmallVector<OpFoldResult>> allSizes) const {
+ if (operandNumbers.size() != 1 || operandNumbers[0] != 0)
+ return op->emitOpError("unhandled operands for consumer fusion");
+
+ ArrayRef<OpFoldResult> offsets(allOffsets[0]);
+ ArrayRef<OpFoldResult> sizes(allSizes[0]);
auto packOp = cast<PackOp>(op);
Location loc = packOp.getLoc();
@@ -836,8 +875,8 @@ struct PackOpTiling
tiledOperands.push_back(sourceSlice);
SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes;
- if (failed(getIterationDomainTileFromOperandTile(
- op, b, /*operandNumber=*/0, offsets, sizes, outerDimOffsets,
+ if (failed(getIterationDomainTileFromOperandTiles(
+ op, b, operandNumbers, allOffsets, allSizes, outerDimOffsets,
outerDimSizes)))
return failure();
@@ -1095,12 +1134,20 @@ struct UnPackOpTiling
/// Method to return the position of iteration domain tile computed by the
/// tiled operation.
- LogicalResult getIterationDomainTileFromOperandTile(
- Operation *op, OpBuilder &b, unsigned operandNumber,
- ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+ LogicalResult getIterationDomainTileFromOperandTiles(
+ Operation *op, OpB...
[truncated]
|
9eb6341
to
c86e95b
Compare
For consumer fusion cases of this form ``` %0:2 = scf.forall .. shared_outs(%arg0 = ..., %arg0 = ...) { tensor.parallel_insert_slice ... into %arg0 tensor.parallel_insert_slice ... into %arg1 } %1 = linalg.generic ... ins(%0#0, %0#1) ``` the current consumer fusion that handles one slice at a time cannot fuse the consumer into the loop, since fusing along one slice will create and SSA violation on the other use from the `scf.forall`. The solution is to allow consumer fusion to allow considering multiple slices at once. This PR changes the `TilingInterface` methods related to consumer fusion, i.e. - `getTiledImplementationFromOperandTile` - `getIterationDomainFromOperandTile` to allow fusion while considering multiple operands. It is upto the `TilingInterface` implementation to return an error if a list of tiles of the operands cannot result in a consistent implementation of the tiled operation. The Linalg operation implementation of `TilingInterface` has been modified to account for these changes and allow cases where operand tiles that can result in a consistent tiling implementation are handled. Signed-off-by: MaheshRavishankar <[email protected]>
c86e95b
to
f6a0713
Compare
For consumer fusion cases of this form
the current consumer fusion that handles one slice at a time cannot fuse the consumer into the loop, since fusing along one slice will create and SSA violation on the other use from the
scf.forall
. The solution is to allow consumer fusion to allow considering multiple slices at once. This PR changes theTilingInterface
methods related to consumer fusion, i.e.getTiledImplementationFromOperandTile
getIterationDomainFromOperandTile
to allow fusion while considering multiple operands. It is upto the
TilingInterface
implementation to return an error if a list of tiles of the operands cannot result in a consistent implementation of the tiled operation.The Linalg operation implementation of
TilingInterface
has been modified to account for these changes and allow cases where operand tiles that can result in a consistent tiling implementation are handled.