Skip to content

[mlir][TilingInterface] Handle multi operand consumer fusion. #145193

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 13 additions & 9 deletions mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -319,19 +319,23 @@ tileConsumerAndFuseProducersUsingSCF(RewriterBase &rewriter,
TilingInterface consumer,
const SCFTileAndFuseOptions &options);

/// Fuse the consumer of the source of `candidateSliceOp` by computing the
/// required slice of the consumer in-place. Note that the method
/// replaces the uses of `candidateSliceOp` with the tiled and fused consumer
/// value but does not delete the slice operation.
/// Fuse the consumer `candidateSlices` by computing the required slice of the
/// consumer in-place. All the entries of `candidateSlices` are expected to map
/// to the same consumer. The method returns an error if the consumer cannot be
/// tiled in a manner that is consistent for all the passed slices. Note that
/// the method replaces the uses of `candidateSlices` with the tiled and fused
/// consumer value but does not delete the slice operations.
struct SCFFuseConsumerOfSliceResult {
OpOperand *origConsumerOperand; // Original untiled consumer's operand.
OpOperand
*tiledAndFusedConsumerOperand; // Tiled and fused consumer's operand.
// Original untiled consumer operands.
SmallVector<OpOperand *> origConsumerOperands;
// Tiled and fused consumer operands.
SmallVector<OpOperand *> tiledAndFusedConsumerOperands;
SmallVector<Operation *> tiledOps;
};
FailureOr<scf::SCFFuseConsumerOfSliceResult>
tileAndFuseConsumerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp,
MutableArrayRef<LoopLikeOpInterface> loops);
tileAndFuseConsumerOfSlices(RewriterBase &rewriter,
ArrayRef<Operation *> candidateSlices,
MutableArrayRef<LoopLikeOpInterface> loops);

/// Method to lower an `op` that implements the `TilingInterface` to
/// loops/scalars.
Expand Down
14 changes: 9 additions & 5 deletions mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,16 @@ namespace tensor {
FailureOr<TilingResult> replaceExtractSliceWithTiledProducer(
OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producerOp);

/// Method to swap an `tensor.insert_slice` with its consumer when the
/// consumer implements the `TilingInterface`.
/// Method to swap an `tensor.insert_slice`s with its consumer when the
/// consumer implements the `TilingInterface`. The size of `sliceOps` and
/// `consumerOperands` is expected to be the same. Every entry in
/// `consumerOperands` represents a use of the the corresponding
/// entry in `sliceOps` in the consumer. All entries of `consumerOperands` is
/// expected to be uses in the same consumer.
FailureOr<TilingResult>
replaceInsertSliceWithTiledConsumer(OpBuilder &builder,
OffsetSizeAndStrideOpInterface sliceOp,
OpOperand &consumerOp);
replaceInsertSlicesWithTiledConsumer(OpBuilder &builder,
ArrayRef<tensor::InsertSliceOp> sliceOps,
ArrayRef<OpOperand *> consumerOperands);

//===----------------------------------------------------------------------===//
// Populate functions.
Expand Down
2 changes: 1 addition & 1 deletion mlir/include/mlir/IR/OpDefinition.h
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ class OpFoldResult : public PointerUnion<Attribute, Value> {
using PointerUnion<Attribute, Value>::PointerUnion;

public:
void dump() const { llvm::errs() << *this << "\n"; }
LLVM_DUMP_METHOD void dump() const { llvm::errs() << *this << "\n"; }

MLIRContext *getContext() const {
PointerUnion pu = *this;
Expand Down
55 changes: 28 additions & 27 deletions mlir/include/mlir/Interfaces/TilingInterface.td
Original file line number Diff line number Diff line change
Expand Up @@ -202,28 +202,28 @@ def TilingInterface : OpInterface<"TilingInterface"> {
InterfaceMethod<
/*desc=*/[{
Method to generate the tiled implementation of an operation that uses
exactly a tile of the given operand.
exactly tiles of the given operands.

This method is required to allow operations to be "tiled and fused"
with an (already tiled) producer. Given a tile of the producer, this
method generates the tile of the consumer that uses exactly this
produced tile. In some sense it is the "reverse" of
with an (already tiled) producer. Given tiles of the producer, this
method generates the tile of the consumer that uses exactly these
produced tiles. In some sense it is the "reverse" of
`generateResultTileValue`.
- `operandNumber` is the result of the producer used by the consumer.
- `offsets` is the offset of the slice of the producer result used by
the tiled implementation of the consumer.
- `sizes` is the size of the slice of the producer result used by the
- `operandNumbers` is the list of operands whose tiles are "producers".
- `allOffsets` is the offset of the slice of the producer used by the
tiled implementation of the consumer.
- `allSizes` is the size of the slice of the producer used by the
consumer.
If it is illegal to fuse with a producer along the given operand for
If it is illegal to fuse with a producer along the given operand tiles for
an operation, the implementation should return a failure.
}],
/*retType=*/"::mlir::FailureOr<::mlir::TilingResult>",
/*methodName=*/"getTiledImplementationFromOperandTile",
/*methodName=*/"getTiledImplementationFromOperandTiles",
/*args=*/(ins
"::mlir::OpBuilder &":$b,
"unsigned":$operandNumber,
"::mlir::ArrayRef<::mlir::OpFoldResult>":$offsets,
"::mlir::ArrayRef<::mlir::OpFoldResult>":$sizes),
"::mlir::ArrayRef<unsigned>":$operandNumbers,
"::mlir::ArrayRef<::mlir::SmallVector<::mlir::OpFoldResult>>":$allOffsets,
"::mlir::ArrayRef<::mlir::SmallVector<::mlir::OpFoldResult>>":$allSizes),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return failure();
Expand All @@ -235,16 +235,17 @@ def TilingInterface : OpInterface<"TilingInterface"> {
tile of the operand.

This method is required to allow operations to be "tiled and fused"
with an (already tiled) producer. Given a tile of an operand,
returns the tile of the iteration space that uses this tile.
- `operandNumber` is the result of the producer used by the consumer.
- `offsets` is the offset of the slice of the producer result used by
with an (already tiled) producer. Given tiles of operands,
returns the tile of the iteration space that uses these tiles.
- `operandNumbers` is the list of operands whose tiles are "produced"
by the producer(s).
- `allOffsets` is the offset of the slice of the producers used by
the tiled implementation of the consumer.
- `sizes` is the size of the slice of the producer result used by the
- `allSizes` is the size of the slice of the producers used by the
consumer.
If it is illegal to fuse with a producer along the given operand for
an operation, or if this mapping cannot be computed, the
implementation should return a failure.
If it is illegal to fuse with the producer slices for an operation,
or if this mapping cannot be computed, the implementation should
return a failure.

Note that unlike the "tile consumer and fuse producer" case, the
"tile producer and fuse consumer" requires an additional method to get
Expand Down Expand Up @@ -285,17 +286,17 @@ def TilingInterface : OpInterface<"TilingInterface"> {
transformation. It does not provide guarantees on whether such a
transformation is profitable.

For most cases `getTiledImplementationFromOperandTile` could be a
implemented using `getIterationDomainTileFromOperandTile` +
For most cases `getTiledImplementationFromOperandTiles` could be a
implemented using `getIterationDomainTileFromOperandTiles` +
`getTiledImplementation` methods.
}],
/*retType=*/"::llvm::LogicalResult",
/*methodName=*/"getIterationDomainTileFromOperandTile",
/*methodName=*/"getIterationDomainTileFromOperandTiles",
/*args=*/(ins
"::mlir::OpBuilder &":$b,
"unsigned":$operandNumber,
"::mlir::ArrayRef<::mlir::OpFoldResult> ":$offsets,
"::mlir::ArrayRef<::mlir::OpFoldResult> ":$sizes,
"::mlir::ArrayRef<unsigned>":$operandNumbers,
"::mlir::ArrayRef<::mlir::SmallVector<::mlir::OpFoldResult>> ":$allOffsets,
"::mlir::ArrayRef<::mlir::SmallVector<::mlir::OpFoldResult>> ":$allSizes,
"::mlir::SmallVectorImpl<::mlir::OpFoldResult> &":$iterDomainOffsets,
"::mlir::SmallVectorImpl<::mlir::OpFoldResult> &":$iterDomainSizes),
/*methodBody=*/"",
Expand Down
Loading