Skip to content

Commit 269cb22

Browse files
[mlir][transform] extract a minimal DomainAndOperandsAffineMapTransferInterface out of LinalgStructuredInterface and use that for PadTilingInterface (#145034)
Extract a minimal DomainAndOperandsAffineMapTransferInterface out of LinalgStructuredInterface and use it for PadTilingInterface. Along the way, a bug was found in the handling of scalar values; fix it and add a test.
1 parent 376b714 commit 269cb22

File tree

5 files changed: +112 −54 lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td

Lines changed: 51 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -222,9 +222,59 @@ def LinalgFillOpInterface : OpInterface<"FillOpInterface"> {
222222
];
223223
}
224224

225+
def IndexingMapOpInterface : OpInterface<"IndexingMapOpInterface"> {
226+
let description = [{
227+
Interface for operations that connect an iteration domain to operands via
228+
affine maps. Provides methods to access indexing maps between iteration
229+
domain and operand index spaces.
230+
}];
231+
let cppNamespace = "::mlir::linalg";
232+
let methods = [
233+
InterfaceMethod<
234+
/*desc=*/[{
235+
Return the indexing maps attribute within the current operation.
236+
}],
237+
/*retTy=*/"ArrayAttr",
238+
/*methodName=*/"getIndexingMaps"
239+
>,
240+
InterfaceMethod<
241+
/*desc=*/[{
242+
Return the indexing maps within the current operation.
243+
}],
244+
/*retTy=*/"SmallVector<AffineMap>",
245+
/*methodName=*/"getIndexingMapsArray",
246+
/*args=*/(ins),
247+
/*methodBody=*/"",
248+
/*defaultImplementation=*/[{
249+
auto range = $_op.getIndexingMaps()
250+
.template getAsValueRange<AffineMapAttr>();
251+
return {range.begin(), range.end()};
252+
}]
253+
>,
254+
InterfaceMethod<
255+
/*desc=*/[{
256+
Return the input or output indexing map for `opOperand`.
257+
}],
258+
/*retTy=*/"AffineMap",
259+
/*methodName=*/"getMatchingIndexingMap",
260+
/*args=*/(ins "OpOperand*":$opOperand),
261+
/*methodBody=*/"",
262+
/*defaultImplementation=*/[{
263+
assert(opOperand->getOwner() == this->getOperation());
264+
auto indexingMaps =
265+
$_op.getIndexingMaps().template getAsValueRange<AffineMapAttr>();
266+
return *(indexingMaps.begin() + opOperand->getOperandNumber());
267+
}]
268+
>,
269+
];
270+
}
271+
225272
// The 'LinalgStructuredInterface' provides access to the 'LinalgOp' interface.
226273
def LinalgStructuredInterface
227-
: OpInterface<"LinalgOp", [DestinationStyleOpInterface]> {
274+
: OpInterface<"LinalgOp", [
275+
DestinationStyleOpInterface,
276+
IndexingMapOpInterface
277+
]> {
228278
let cppNamespace = "::mlir::linalg";
229279
let methods = [
230280
//===------------------------------------------------------------------===//
@@ -465,21 +515,6 @@ def LinalgStructuredInterface
465515
blockArgument.getArgNumber());
466516
}]
467517
>,
468-
InterfaceMethod<
469-
/*desc=*/[{
470-
Return the input or output indexing map for `opOperand`.
471-
}],
472-
/*retTy=*/"AffineMap",
473-
/*methodName=*/"getMatchingIndexingMap",
474-
/*args=*/(ins "OpOperand*":$opOperand),
475-
/*methodBody=*/"",
476-
/*defaultImplementation=*/[{
477-
assert(opOperand->getOwner() == this->getOperation());
478-
auto indexingMaps =
479-
$_op.getIndexingMaps().template getAsValueRange<AffineMapAttr>();
480-
return *(indexingMaps.begin() + opOperand->getOperandNumber());
481-
}]
482-
>,
483518
InterfaceMethod<
484519
/*desc=*/[{
485520
Return the indexing map for a `result`.
@@ -576,27 +611,6 @@ def LinalgStructuredInterface
576611
/*methodBody=*/"",
577612
/*defaultImplementation=*/[{ return success(); }]
578613
>,
579-
InterfaceMethod<
580-
/*desc=*/[{
581-
Return the indexing maps attribute within the current operation.
582-
}],
583-
/*retTy=*/"ArrayAttr",
584-
/*methodName=*/"getIndexingMaps"
585-
>,
586-
InterfaceMethod<
587-
/*desc=*/[{
588-
Return the indexing maps within the current operation.
589-
}],
590-
/*retTy=*/"SmallVector<AffineMap>",
591-
/*methodName=*/"getIndexingMapsArray",
592-
/*args=*/(ins),
593-
/*methodBody=*/"",
594-
/*defaultImplementation=*/[{
595-
auto range = $_op.getIndexingMaps()
596-
.template getAsValueRange<AffineMapAttr>();
597-
return {range.begin(), range.end()};
598-
}]
599-
>,
600614
InterfaceMethod<
601615
/*desc=*/[{
602616
Return true if any of the operands has a dynamic shape.

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -612,10 +612,9 @@ using PadSizeComputationFunction =
612612
const PadTilingInterfaceOptions &)>;
613613

614614
/// Specific helper for Linalg ops.
615-
FailureOr<SmallVector<OpFoldResult>>
616-
computeLinalgPaddedShape(RewriterBase &rewriter, OpOperand &operandToPad,
617-
ArrayRef<Range> iterationDomain,
618-
const PadTilingInterfaceOptions &options);
615+
FailureOr<SmallVector<OpFoldResult>> computeIndexingMapOpInterfacePaddedShape(
616+
RewriterBase &rewriter, OpOperand &operandToPad,
617+
ArrayRef<Range> iterationDomain, const PadTilingInterfaceOptions &options);
619618

620619
/// Pad the iterator dimensions `options.paddingDimensions` of `opToPad`.
621620
///
@@ -632,7 +631,7 @@ rewriteAsPaddedOp(RewriterBase &rewriter, TilingInterface opToPad,
632631
const PadTilingInterfaceOptions &constOptions,
633632
SmallVector<tensor::PadOp> &padOps,
634633
PadSizeComputationFunction computePaddingSizeFun =
635-
&computeLinalgPaddedShape);
634+
&computeIndexingMapOpInterfacePaddedShape);
636635

637636
namespace detail {
638637

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2229,10 +2229,12 @@ transform::PadTilingInterfaceOp::apply(transform::TransformRewriter &rewriter,
22292229
return diag;
22302230
}
22312231

2232-
// Only Linalg ops for now, until TilingInterface exposes a loopsToOperand
2233-
// map / C++ APIs to compute the effect of padding on operands.
2234-
if (!isa<LinalgOp>(targetOp.getOperation())) {
2235-
auto diag = emitSilenceableError() << "only LinalgOp supported atm";
2232+
// Only IndexingMapOpInterface ops for now, until TilingInterface exposes a
2233+
// loopsToOperand map / C++ APIs to compute the effect of padding on
2234+
// operands.
2235+
if (!isa<IndexingMapOpInterface>(targetOp.getOperation())) {
2236+
auto diag = emitSilenceableError() << "only IndexingMapOpInterface ops "
2237+
"supported atm";
22362238
diag.attachNote(target->getLoc()) << "target op";
22372239
return diag;
22382240
}

mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -155,11 +155,13 @@ SmallVector<OpFoldResult> linalg::computePaddedShape(
155155
return paddedShape;
156156
}
157157

158-
FailureOr<SmallVector<OpFoldResult>> linalg::computeLinalgPaddedShape(
158+
FailureOr<SmallVector<OpFoldResult>>
159+
linalg::computeIndexingMapOpInterfacePaddedShape(
159160
RewriterBase &rewriter, OpOperand &operandToPad,
160161
ArrayRef<Range> iterationDomain, const PadTilingInterfaceOptions &options) {
161-
auto linalgOp = llvm::dyn_cast<LinalgOp>(operandToPad.getOwner());
162-
if (!linalgOp)
162+
auto transferOp =
163+
llvm::dyn_cast<IndexingMapOpInterface>(operandToPad.getOwner());
164+
if (!transferOp)
163165
return failure();
164166

165167
// clang-format off
@@ -173,7 +175,7 @@ FailureOr<SmallVector<OpFoldResult>> linalg::computeLinalgPaddedShape(
173175
for (const Range &range : iterationDomain)
174176
loopUpperBounds.push_back(range.size);
175177

176-
AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&operandToPad);
178+
AffineMap indexingMap = transferOp.getMatchingIndexingMap(&operandToPad);
177179
return computePaddedShape(
178180
rewriter, cast<TypedValue<RankedTensorType>>(operandToPad.get()),
179181
indexingMap, loopUpperBounds, options);
@@ -255,7 +257,18 @@ linalg::rewriteAsPaddedOp(RewriterBase &rewriter, TilingInterface opToPad,
255257
SmallVector<Value> newOperands;
256258
newOperands.reserve(opToPad->getNumOperands());
257259
for (OpOperand &opOperand : opToPad->getOpOperands()) {
258-
LLVM_DEBUG(DBGS() << "--start padding oprd: " << opOperand.get() << "\n");
260+
Value operand = opOperand.get();
261+
LLVM_DEBUG(DBGS() << "--start padding oprd: " << operand << "\n");
262+
263+
// 2.a. Skip scalar-like operands.
264+
Type operandType = operand.getType();
265+
if (!isa<RankedTensorType>(operandType)) {
266+
assert(!isa<ShapedType>(operandType) ||
267+
isa<VectorType>(operandType) &&
268+
"Unexpected non-vector ShapedType");
269+
newOperands.push_back(operand);
270+
continue;
271+
}
259272
// 2.a. Compute padded shape.
260273
FailureOr<SmallVector<OpFoldResult>> maybePaddedShape =
261274
computePaddingSizeFun(rewriter, opOperand, iterationDomain, options);
@@ -266,14 +279,16 @@ linalg::rewriteAsPaddedOp(RewriterBase &rewriter, TilingInterface opToPad,
266279
// 2.b. Expect proper `paddingValues`.
267280
// TODO: we may want to allow garbage padding in the future, in which case
268281
// we would just not assert.
269-
assert(opOperand.getOperandNumber() < options.paddingValues.size() &&
270-
"--no padding value specified");
282+
if (opOperand.getOperandNumber() >= options.paddingValues.size()) {
283+
return rewriter.notifyMatchFailure(opToPad,
284+
"--no padding value specified");
285+
}
271286
Attribute paddingValueAttr =
272287
options.paddingValues[opOperand.getOperandNumber()];
273288

274289
// 2.c. Perform actual padding.
275290
Value paddedOperand = padOperand(
276-
rewriter, opToPad, cast<TypedValue<RankedTensorType>>(opOperand.get()),
291+
rewriter, opToPad, cast<TypedValue<RankedTensorType>>(operand),
277292
*maybePaddedShape, paddingValueAttr);
278293
LLVM_DEBUG(DBGS() << "--done padding operand: " << paddedOperand << "\n");
279294

mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,33 @@
11
// RUN: mlir-opt --transform-interpreter -canonicalize -split-input-file --verify-diagnostics %s | FileCheck %s
22

3+
// CHECK-LABEL: pad_fill
4+
// CHECK: linalg.fill ins(%{{.*}} : f32) outs(%{{.*}} : tensor<8x25xf32>) -> tensor<8x25xf32>
5+
func.func @pad_fill(%value: f32, %output: tensor<24x25xf32>) -> tensor<24x25xf32>
6+
{
7+
%0 = linalg.fill ins(%value : f32) outs(%output : tensor<24x25xf32>) -> tensor<24x25xf32>
8+
func.return %0 : tensor<24x25xf32>
9+
}
10+
11+
module attributes {transform.with_named_sequence} {
12+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
13+
%fill = transform.structured.match ops{["linalg.fill"]} in %arg1
14+
: (!transform.any_op) -> !transform.any_op
15+
16+
// Tile to 5 then pad to 8
17+
%fill_l1, %loops_l1 = transform.structured.tile_using_for %fill tile_sizes [5]
18+
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
19+
20+
%fill_padded, %_ = transform.structured.pad_tiling_interface %fill_l1 to padding_sizes [8] {
21+
padding_values=[0.0 : f32, 0.0 : f32],
22+
padding_dimensions=[0]
23+
} : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
24+
25+
transform.yield
26+
}
27+
}
28+
29+
// -----
30+
331
// CHECK-LABEL: pad_lhs
432
func.func @pad_lhs(
533
%arg0: tensor<24x12xf32>, %arg1: tensor<12x25xf32>, %arg2: tensor<24x25xf32>)

0 commit comments

Comments
 (0)