Commit c86e95b

[mlir][TilingInterface] Handle multi operand consumer fusion.
For consumer fusion cases of this form

```
%0:2 = scf.forall .. shared_outs(%arg0 = ..., %arg1 = ...) {
  tensor.parallel_insert_slice ... into %arg0
  tensor.parallel_insert_slice ... into %arg1
}
%1 = linalg.generic ... ins(%0#0, %0#1)
```

the current consumer fusion, which handles one slice at a time, cannot fuse the consumer into the loop, since fusing along one slice will create an SSA violation on the other use from the `scf.forall`. The solution is to let consumer fusion consider multiple slices at once. This PR changes the `TilingInterface` methods related to consumer fusion, i.e.

- `getTiledImplementationFromOperandTile`
- `getIterationDomainTileFromOperandTile`

to allow fusion while considering multiple operands. It is up to the `TilingInterface` implementation to return an error if a list of operand tiles cannot result in a consistent implementation of the tiled operation. The Linalg implementation of `TilingInterface` has been modified to account for these changes and to handle the cases where the operand tiles can result in a consistent tiling implementation.

Signed-off-by: MaheshRavishankar <[email protected]>
1 parent f280d3b commit c86e95b

9 files changed: +667 -205 lines changed

mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h

Lines changed: 13 additions & 9 deletions
@@ -319,19 +319,23 @@ tileConsumerAndFuseProducersUsingSCF(RewriterBase &rewriter,
                                      TilingInterface consumer,
                                      const SCFTileAndFuseOptions &options);

-/// Fuse the consumer of the source of `candidateSliceOp` by computing the
-/// required slice of the consumer in-place. Note that the method
-/// replaces the uses of `candidateSliceOp` with the tiled and fused consumer
-/// value but does not delete the slice operation.
+/// Fuse the consumer `candidateSlices` by computing the required slice of the
+/// consumer in-place. All the entries of `candidateSlices` are expected to map
+/// to the same consumer. The method returns an error if the consumer cannot be
+/// tiled in a manner that is consistent for all the passed slices. Note that
+/// the method replaces the uses of `candidateSlices` with the tiled and fused
+/// consumer value but does not delete the slice operations.
 struct SCFFuseConsumerOfSliceResult {
-  OpOperand *origConsumerOperand; // Original untiled consumer's operand.
-  OpOperand
-      *tiledAndFusedConsumerOperand; // Tiled and fused consumer's operand.
+  // Original untiled consumer operands.
+  SmallVector<OpOperand *> origConsumerOperands;
+  // Tiled and fused consumer operands.
+  SmallVector<OpOperand *> tiledAndFusedConsumerOperands;
   SmallVector<Operation *> tiledOps;
 };
 FailureOr<scf::SCFFuseConsumerOfSliceResult>
-tileAndFuseConsumerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp,
-                           MutableArrayRef<LoopLikeOpInterface> loops);
+tileAndFuseConsumerOfSlices(RewriterBase &rewriter,
+                            ArrayRef<Operation *> candidateSlices,
+                            MutableArrayRef<LoopLikeOpInterface> loops);

 /// Method to lower an `op` that implements the `TilingInterface` to
 /// loops/scalars.

mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h

Lines changed: 9 additions & 5 deletions
@@ -31,12 +31,16 @@ namespace tensor {
 FailureOr<TilingResult> replaceExtractSliceWithTiledProducer(
     OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producerOp);

-/// Method to swap an `tensor.insert_slice` with its consumer when the
-/// consumer implements the `TilingInterface`.
+/// Method to swap an `tensor.insert_slice`s with its consumer when the
+/// consumer implements the `TilingInterface`. The size of `sliceOps` and
+/// `consumerOperands` is expected to be the same. Every entry in
+/// `consumerOperands` represents a use of the the corresponding
+/// entry in `sliceOps` in the consumer. All entries of `consumerOperands` is
+/// expected to be uses in the same consumer.
 FailureOr<TilingResult>
-replaceInsertSliceWithTiledConsumer(OpBuilder &builder,
-                                    OffsetSizeAndStrideOpInterface sliceOp,
-                                    OpOperand &consumerOp);
+replaceInsertSlicesWithTiledConsumer(OpBuilder &builder,
+                                     ArrayRef<tensor::InsertSliceOp> sliceOps,
+                                     ArrayRef<OpOperand *> consumerOperands);

 //===----------------------------------------------------------------------===//
 // Populate functions.
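
The new doc comment on `replaceInsertSlicesWithTiledConsumer` states two preconditions: `sliceOps` and `consumerOperands` have the same size, and every operand is a use in the same consumer. A minimal sketch of that check follows, using plain structs in place of the MLIR types. `ToyOp`, `ToyOperand`, and `haveSingleCommonConsumer` are hypothetical names for illustration and are not part of this commit.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical stand-ins for mlir::Operation and mlir::OpOperand.
struct ToyOp {
  int id;
};
struct ToyOperand {
  ToyOp *owner;           // The operation that uses this operand.
  unsigned operandNumber; // Position of the operand in its owner.
};

// Returns true when there is exactly one operand per slice and all the
// operands are uses in one and the same consumer operation.
bool haveSingleCommonConsumer(
    std::size_t numSlices,
    const std::vector<ToyOperand *> &consumerOperands) {
  if (consumerOperands.empty() || consumerOperands.size() != numSlices)
    return false;
  ToyOp *consumer = consumerOperands.front()->owner;
  for (const ToyOperand *operand : consumerOperands)
    if (operand->owner != consumer)
      return false;
  return true;
}
```

A caller would presumably run a check of this kind before attempting the swap and bail out with a failure when it does not hold.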

mlir/include/mlir/Interfaces/TilingInterface.td

Lines changed: 28 additions & 27 deletions
@@ -202,28 +202,28 @@ def TilingInterface : OpInterface<"TilingInterface"> {
       InterfaceMethod<
         /*desc=*/[{
           Method to generate the tiled implementation of an operation that uses
-          exactly a tile of the given operand.
+          exactly tiles of the given operands.

           This method is required to allow operations to be "tiled and fused"
-          with an (already tiled) producer. Given a tile of the producer, this
-          method generates the tile of the consumer that uses exactly this
-          produced tile. In some sense it is the "reverse" of
+          with an (already tiled) producer. Given tiles of the producer, this
+          method generates the tile of the consumer that uses exactly these
+          produced tiles. In some sense it is the "reverse" of
           `generateResultTileValue`.
-          - `operandNumber` is the result of the producer used by the consumer.
-          - `offsets` is the offset of the slice of the producer result used by
-            the tiled implementation of the consumer.
-          - `sizes` is the size of the slice of the producer result used by the
+          - `operandNumbers` is the list of operands whose tiles are "producers".
+          - `allOffsets` is the offset of the slice of the producer used by the
+            tiled implementation of the consumer.
+          - `allSizes` is the size of the slice of the producer used by the
             consumer.
-          If it is illegal to fuse with a producer along the given operand for
+          If it is illegal to fuse with a producer along the given operand tiles for
           an operation, the implementation should return a failure.
         }],
         /*retType=*/"::mlir::FailureOr<::mlir::TilingResult>",
-        /*methodName=*/"getTiledImplementationFromOperandTile",
+        /*methodName=*/"getTiledImplementationFromOperandTiles",
         /*args=*/(ins
           "::mlir::OpBuilder &":$b,
-          "unsigned":$operandNumber,
-          "::mlir::ArrayRef<::mlir::OpFoldResult>":$offsets,
-          "::mlir::ArrayRef<::mlir::OpFoldResult>":$sizes),
+          "::mlir::ArrayRef<unsigned>":$operandNumbers,
+          "::mlir::ArrayRef<::mlir::SmallVector<::mlir::OpFoldResult>>":$allOffsets,
+          "::mlir::ArrayRef<::mlir::SmallVector<::mlir::OpFoldResult>>":$allSizes),
         /*methodBody=*/"",
         /*defaultImplementation=*/[{
           return failure();
@@ -235,16 +235,17 @@ def TilingInterface : OpInterface<"TilingInterface"> {
           tile of the operand.

           This method is required to allow operations to be "tiled and fused"
-          with an (already tiled) producer. Given a tile of an operand,
-          returns the tile of the iteration space that uses this tile.
-          - `operandNumber` is the result of the producer used by the consumer.
-          - `offsets` is the offset of the slice of the producer result used by
+          with an (already tiled) producer. Given tiles of operands,
+          returns the tile of the iteration space that uses these tiles.
+          - `operandNumbers` is the list of operands whose tiles are "produced"
+            by the producer(s).
+          - `allOffsets` is the offset of the slice of the producers used by
             the tiled implementation of the consumer.
-          - `sizes` is the size of the slice of the producer result used by the
+          - `allSizes` is the size of the slice of the producers used by the
             consumer.
-          If it is illegal to fuse with a producer along the given operand for
-          an operation, or if this mapping cannot be computed, the
-          implementation should return a failure.
+          If it is illegal to fuse with the producer slices for an operation,
+          or if this mapping cannot be computed, the implementation should
+          return a failure.

           Note that unlike the "tile consumer and fuse producer" case, the
           "tile producer and fuse consumer" requires an additional method to get
@@ -285,17 +286,17 @@ def TilingInterface : OpInterface<"TilingInterface"> {
           transformation. It does not provide guarantees on whether such a
           transformation is profitable.

-          For most cases `getTiledImplementationFromOperandTile` could be a
-          implemented using `getIterationDomainTileFromOperandTile` +
+          For most cases `getTiledImplementationFromOperandTiles` could be a
+          implemented using `getIterationDomainTileFromOperandTiles` +
           `getTiledImplementation` methods.
         }],
         /*retType=*/"::llvm::LogicalResult",
-        /*methodName=*/"getIterationDomainTileFromOperandTile",
+        /*methodName=*/"getIterationDomainTileFromOperandTiles",
         /*args=*/(ins
           "::mlir::OpBuilder &":$b,
-          "unsigned":$operandNumber,
-          "::mlir::ArrayRef<::mlir::OpFoldResult> ":$offsets,
-          "::mlir::ArrayRef<::mlir::OpFoldResult> ":$sizes,
+          "::mlir::ArrayRef<unsigned>":$operandNumbers,
+          "::mlir::ArrayRef<::mlir::SmallVector<::mlir::OpFoldResult>> ":$allOffsets,
+          "::mlir::ArrayRef<::mlir::SmallVector<::mlir::OpFoldResult>> ":$allSizes,
           "::mlir::SmallVectorImpl<::mlir::OpFoldResult> &":$iterDomainOffsets,
           "::mlir::SmallVectorImpl<::mlir::OpFoldResult> &":$iterDomainSizes),
         /*methodBody=*/"",
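
The interface description leaves it to each implementation to decide when a list of operand tiles maps to one consistent iteration-domain tile. The sketch below illustrates one way such a check could look in a simplified setting. It is not the Linalg implementation from this commit. It assumes static integer offsets and sizes, and that each operand dimension indexes exactly one loop dimension (a projected permutation). `OperandTile`, `IterDomainTile`, and `iterationDomainFromOperandTiles` are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Toy operand tile: iterDims[i] is the loop dimension indexed by operand
// dimension i; offsets[i] and sizes[i] describe the tile along that dimension.
struct OperandTile {
  std::vector<int> iterDims;
  std::vector<int64_t> offsets;
  std::vector<int64_t> sizes;
};

struct IterDomainTile {
  std::vector<int64_t> offsets;
  std::vector<int64_t> sizes;
};

// Returns the iteration-domain tile implied by all operand tiles, or
// std::nullopt if two tiles disagree on a loop dimension they share.
// Loop dimensions that no tile constrains keep their full extent.
std::optional<IterDomainTile>
iterationDomainFromOperandTiles(const std::vector<int64_t> &loopExtents,
                                const std::vector<OperandTile> &tiles) {
  const std::size_t numLoops = loopExtents.size();
  IterDomainTile result{std::vector<int64_t>(numLoops, 0), loopExtents};
  std::vector<bool> seen(numLoops, false);
  for (const OperandTile &tile : tiles) {
    // Malformed input is treated as a programming error, not a fusion failure.
    assert(tile.iterDims.size() == tile.offsets.size() &&
           tile.iterDims.size() == tile.sizes.size());
    for (std::size_t i = 0; i < tile.iterDims.size(); ++i) {
      assert(tile.iterDims[i] >= 0);
      const std::size_t d = static_cast<std::size_t>(tile.iterDims[i]);
      assert(d < numLoops);
      if (seen[d] && (result.offsets[d] != tile.offsets[i] ||
                      result.sizes[d] != tile.sizes[i]))
        return std::nullopt; // Inconsistent tiles: report failure.
      seen[d] = true;
      result.offsets[d] = tile.offsets[i];
      result.sizes[d] = tile.sizes[i];
    }
  }
  return result;
}
```

In this model, two operand tiles that both index loop dimension 0 must carry the same offset and size along it. Otherwise the function reports failure, mirroring the requirement that an implementation return a failure when the tiles cannot produce a consistent tiling.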
