Skip to content

Commit 00c18d0

Browse files
[mlir][Transforms] Add a PadTilingInterface transformation and hook i… (#144991)
…t up to the transform dialect This revision revisits the padding transformation from first principles and prepares it to work more generally with TilingInterface. Compared to structured.transform.pad it has the following differences: - no support for nofold, copy-back, transpose and hoisting: these have been carried by the padding op in the very early days of StructuredOps and have since then been separated out as independent transformations that compose. - no conflated static bounding box derivation attempts: pad_tiling_interface composes more naturally with or without tiling. - properly derives padding size on outputs where multiple dimensions contribute: this is not supported in structured.transform.pad - geared towards supporting TilingInterface once the proper control mechanisms are supported through a PadSizeComputationFunction (supports LinalgOp by default) This will gradually replace structured.transform.pad as it is fleshed out and tested more comprehensively. In the future this should be moved out of a specific Linalg implementation file and into a more general "Structured" file.
1 parent 7085065 commit 00c18d0

File tree

7 files changed

+906
-3
lines changed

7 files changed

+906
-3
lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1186,6 +1186,85 @@ def PadOp : Op<Transform_Dialect, "structured.pad",
11861186
}];
11871187
}
11881188

1189+
//===----------------------------------------------------------------------===//
1190+
// PadTilingInterfaceOp
1191+
//===----------------------------------------------------------------------===//
1192+
1193+
def PadTilingInterfaceOp : Op<Transform_Dialect, "structured.pad_tiling_interface",
1194+
[FunctionalStyleTransformOpTrait, DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
1195+
TransformOpInterface,
1196+
ReportTrackingListenerFailuresOpTrait]> {
1197+
let description = [{
1198+
Pads the operations pointed to by the target handle using the options
1199+
provided as operation attributes. The operation returns a handle to the
1200+
padded operation and to the padding operation ("tensor.pad").
1201+
1202+
TODO: in the future this should be moved out of a specific Linalg
1203+
implementation file and into a more general "Structured" file.
1204+
1205+
#### Return modes
1206+
1207+
This operation ignores non-Linalg ops and drops them in the return.
1208+
In the future, this operation will support all TilingInterfaceOps.
1209+
1210+
This operation may produce a definite failure if the padding fails for any
1211+
reason.
1212+
1213+
If all the operations referred to by the `target` handle pad properly, the
1214+
transform succeeds. Otherwise the transform produces a silenceable failure.
1215+
The return handle points to only the subset of successfully produced
1216+
padded operations, which can be empty.
1217+
}];
1218+
1219+
let arguments =
1220+
(ins TransformHandleTypeInterface:$target,
1221+
DefaultValuedAttr<ArrayAttr, "{}">:$padding_values,
1222+
DefaultValuedAttr<I64ArrayAttr, "{}">:$padding_dimensions,
1223+
Variadic<TransformAnyParamTypeOrAnyHandle>:$padding_sizes,
1224+
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:
1225+
$static_padding_sizes,
1226+
DefaultValuedAttr<UnitAttr, "false">:$pad_to_multiple_of);
1227+
let results = (outs TransformHandleTypeInterface:$padded,
1228+
TransformHandleTypeInterface:$pad);
1229+
1230+
let assemblyFormat = [{
1231+
$target
1232+
`to`
1233+
(`padding_sizes` custom<DynamicIndexList>($padding_sizes, $static_padding_sizes)^)?
1234+
(`pad_to_multiple_of` $pad_to_multiple_of^)?
1235+
attr-dict
1236+
`:` functional-type(operands, results)
1237+
}];
1238+
1239+
let hasVerifier = 1;
1240+
1241+
let builders = [
1242+
// Builder for a transform::PadOp with automatic inference of padding
1243+
// value. Warning: this will set the value 0 for the inferred elemental
1244+
// type without taking the op into account and thus only work for the
1245+
// add/mul ring at the moment.
1246+
// TODO: support other operations (e.g. min, max etc).
1247+
OpBuilder<(ins "Value":$target,
1248+
"ArrayRef<int64_t>":$paddingDimensions,
1249+
CArg<"ArrayRef<int64_t>", "{}">:$staticPaddingSizes,
1250+
CArg<"bool", "false">:$padToMultipleOf)>,
1251+
OpBuilder<(ins "Value":$target,
1252+
"ArrayRef<int64_t>":$paddingDimensions,
1253+
"ArrayRef<OpFoldResult>":$mixedPadPaddingSizes,
1254+
CArg<"bool", "false">:$usePrescribedTensorShapes)>
1255+
];
1256+
1257+
let extraClassDeclaration = [{
1258+
/// Returns a mix of dynamic `padding_sizes` and static `static_padding_sizes`.
1259+
SmallVector<OpFoldResult> getMixedPaddingSizes();
1260+
1261+
::mlir::DiagnosedSilenceableFailure apply(
1262+
::mlir::transform::TransformRewriter &rewriter,
1263+
::mlir::transform::TransformResults &results,
1264+
::mlir::transform::TransformState &state);
1265+
}];
1266+
}
1267+
11891268
//===----------------------------------------------------------------------===//
11901269
// HoistPadOp
11911270
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 76 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "mlir/Dialect/Utils/StaticValueUtils.h"
2121
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
2222
#include "mlir/Dialect/X86Vector/Transforms.h"
23+
#include "mlir/IR/OpDefinition.h"
2324
#include "mlir/IR/PatternMatch.h"
2425
#include "mlir/Interfaces/TilingInterface.h"
2526
#include "mlir/Transforms/DialectConversion.h"
@@ -347,6 +348,34 @@ struct LinalgPaddingOptions {
347348
}
348349
};
349350

351+
struct PadTilingInterfaceOptions {
352+
/// A padding value for every operand.
353+
SmallVector<Attribute> paddingValues;
354+
PadTilingInterfaceOptions &setPaddingValues(ArrayRef<Attribute> pv) {
355+
paddingValues.assign(pv.begin(), pv.end());
356+
return *this;
357+
}
358+
/// A list of iterator dimensions to pad.
359+
SmallVector<int64_t> paddingDimensions;
360+
PadTilingInterfaceOptions &setPaddingDimensions(ArrayRef<int64_t> pd) {
361+
paddingDimensions.assign(pd.begin(), pd.end());
362+
return *this;
363+
}
364+
/// A list of iterator dimensions sizes to pad to.
365+
SmallVector<OpFoldResult> paddingSizes;
366+
PadTilingInterfaceOptions &setPaddingSizes(ArrayRef<OpFoldResult> m) {
367+
paddingSizes.assign(m.begin(), m.end());
368+
return *this;
369+
}
370+
/// Pad iterator `paddingDimension[i]` to next multiple of `paddingSizes[i]`
371+
/// if true. Otherwise pad to `paddingSizes[i]`.
372+
bool padToMultipleOf;
373+
PadTilingInterfaceOptions &setPadToMultipleOf(bool b) {
374+
padToMultipleOf = b;
375+
return *this;
376+
}
377+
};
378+
350379
/// Callback function type used to perform the allocation for the promoted
351380
/// `subView`. In `boundingSubViewsize` a best attempt is made to find the
352381
/// smallest constant value for the size of the buffer needed for each
@@ -542,9 +571,9 @@ SmallVector<Value> peelLoop(RewriterBase &rewriter, Operation *op);
542571
/// where relevant.
543572
void peelLoops(RewriterBase &rewriter, ArrayRef<scf::ForOp> loops);
544573

545-
/// Pad the iterator dimensions `paddingDimensions` of all `opToPad` operands
546-
/// to a static bounding box. The original `opToPad` is cloned and operates on
547-
/// the padded tensors.
574+
/// Pad the iterator dimensions `options.paddingDimensions` of all `opToPad`
575+
/// operands to a static bounding box. The original `opToPad` is cloned and
576+
/// operates on the padded tensors.
548577
///
549578
/// * "options.padToMultipleOf" indicates that each padding dimension should be
550579
/// padded to the specified multiple.
@@ -561,6 +590,50 @@ LogicalResult rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad,
561590
SmallVector<Value> &replacements,
562591
SmallVector<tensor::PadOp> &padOps);
563592

593+
/// Helper function to compute the padded shape of the given value `v` of
594+
/// `RankedTensorType` given:
595+
/// - the `indexingSizes` as a list of OpFoldResult.
596+
/// - an `indexingMap` that encodes how the padded shape varies with
597+
/// increases in `indexingSizes`.
598+
/// The implementation iteratively combines increases from contributing using
599+
/// affine.apply operations.
600+
/// The `indexingMap` + `indexingSizes` encoding suits StructuredOps and
601+
/// provides a gentle portability path for Linalg-like ops with affine maps.
602+
/// In the future, more general interfaces can be devised to encode similar
603+
/// shape evolutions and map between an op and its operands.
604+
SmallVector<OpFoldResult>
605+
computePaddedShape(RewriterBase &rewriter, TypedValue<RankedTensorType> v,
606+
AffineMap indexingMap, ArrayRef<OpFoldResult> indexingSizes,
607+
const PadTilingInterfaceOptions &options);
608+
609+
using PadSizeComputationFunction =
610+
std::function<FailureOr<SmallVector<OpFoldResult>>(
611+
RewriterBase &, OpOperand &, ArrayRef<Range>,
612+
const PadTilingInterfaceOptions &)>;
613+
614+
/// Specific helper for Linalg ops.
615+
FailureOr<SmallVector<OpFoldResult>>
616+
computeLinalgPaddedShape(RewriterBase &rewriter, OpOperand &operandToPad,
617+
ArrayRef<Range> iterationDomain,
618+
const PadTilingInterfaceOptions &options);
619+
620+
/// Pad the iterator dimensions `options.paddingDimensions` of `opToPad`.
621+
///
622+
/// * "options.paddingSizes" indicates that each padding dimension should be
623+
/// padded to the specified padding size.
624+
/// * "options.padToMultipleOf" indicates that the paddingSizes should be
625+
// interpreted as the bounding box (dynamic) value to pad to.
626+
/// * Use "options.paddingValues" to set the padding value of the created
627+
// tensor::PadOp.
628+
/// * The tensor::PadOp is returned on success.
629+
630+
FailureOr<TilingInterface>
631+
rewriteAsPaddedOp(RewriterBase &rewriter, TilingInterface opToPad,
632+
const PadTilingInterfaceOptions &constOptions,
633+
SmallVector<tensor::PadOp> &padOps,
634+
PadSizeComputationFunction computePaddingSizeFun =
635+
&computeLinalgPaddedShape);
636+
564637
namespace detail {
565638

566639
/// Helper struct to hold the results of building a packing loop nest.

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
#include "llvm/ADT/ScopeExit.h"
4646
#include "llvm/ADT/TypeSwitch.h"
4747
#include "llvm/Support/Debug.h"
48+
#include "llvm/Support/LogicalResult.h"
4849
#include <type_traits>
4950

5051
using namespace mlir;
@@ -2155,6 +2156,166 @@ LogicalResult transform::PadOp::verify() {
21552156
return success();
21562157
}
21572158

2159+
//===---------------------------------------------------------------------===//
2160+
// PadTilingInterfaceOp
2161+
//===---------------------------------------------------------------------===//
2162+
2163+
void transform::PadTilingInterfaceOp::build(OpBuilder &b,
2164+
OperationState &result,
2165+
Value target,
2166+
ArrayRef<int64_t> paddingDimensions,
2167+
ArrayRef<int64_t> paddingSizes,
2168+
bool padToMultipleOf) {
2169+
auto resultType = transform::AnyOpType::get(b.getContext());
2170+
return build(/*builder=*/b,
2171+
/*result=*/result,
2172+
/*types=*/TypeRange{resultType, resultType},
2173+
/*target=*/target,
2174+
/*paddingValues=*/ArrayAttr(), // let inference handle this
2175+
/*paddingDimensions=*/b.getI64ArrayAttr(paddingDimensions),
2176+
/*paddingSizes=*/ValueRange{},
2177+
/*paddingSizes=*/
2178+
(paddingSizes.empty() ? DenseI64ArrayAttr()
2179+
: b.getDenseI64ArrayAttr(paddingSizes)),
2180+
/*padToMultipleOf=*/
2181+
padToMultipleOf ? b.getUnitAttr() : nullptr);
2182+
}
2183+
2184+
void transform::PadTilingInterfaceOp::build(
2185+
OpBuilder &b, OperationState &result, Value target,
2186+
ArrayRef<int64_t> paddingDimensions,
2187+
ArrayRef<OpFoldResult> mixedPaddingSizes, bool padToMultipleOf) {
2188+
auto resultType = transform::AnyOpType::get(b.getContext());
2189+
SmallVector<int64_t> staticPaddingSizes;
2190+
SmallVector<Value> dynamicPaddingSizes;
2191+
dispatchIndexOpFoldResults(mixedPaddingSizes, dynamicPaddingSizes,
2192+
staticPaddingSizes);
2193+
return build(/*builder=*/b,
2194+
/*result=*/result,
2195+
/*types=*/TypeRange{resultType, resultType},
2196+
/*target=*/target,
2197+
/*paddingValues=*/ArrayAttr(), // let inference handle this
2198+
/*paddingDimensions=*/b.getI64ArrayAttr(paddingDimensions),
2199+
/*paddingSizes=*/dynamicPaddingSizes,
2200+
/*paddingSizes=*/staticPaddingSizes,
2201+
/*usePrescribedTensorShapes=*/padToMultipleOf);
2202+
}
2203+
2204+
void transform::PadTilingInterfaceOp::getEffects(
2205+
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2206+
consumesHandle(getTargetMutable(), effects);
2207+
onlyReadsHandle(getPaddingSizesMutable(), effects);
2208+
producesHandle(getOperation()->getOpResults(), effects);
2209+
modifiesPayload(effects);
2210+
}
2211+
2212+
SmallVector<OpFoldResult>
2213+
transform::PadTilingInterfaceOp::getMixedPaddingSizes() {
2214+
Builder b(getContext());
2215+
return getMixedValues(getStaticPaddingSizes(), getPaddingSizes(), b);
2216+
}
2217+
2218+
DiagnosedSilenceableFailure
2219+
transform::PadTilingInterfaceOp::apply(transform::TransformRewriter &rewriter,
2220+
transform::TransformResults &results,
2221+
transform::TransformState &state) {
2222+
SmallVector<Operation *> paddedOps, padOps;
2223+
2224+
for (Operation *target : state.getPayloadOps(getTarget())) {
2225+
auto targetOp = dyn_cast<TilingInterface>(target);
2226+
if (!targetOp) {
2227+
auto diag = emitSilenceableError() << "expected TilingInterface target";
2228+
diag.attachNote(target->getLoc()) << "target op";
2229+
return diag;
2230+
}
2231+
2232+
// Only Linalg ops for now, until TilingInterface exposes a loopsToOperand
2233+
// map / C++ APIs to compute the effect of padding on operands.
2234+
if (!isa<LinalgOp>(targetOp.getOperation())) {
2235+
auto diag = emitSilenceableError() << "only LinalgOp supported atm";
2236+
diag.attachNote(target->getLoc()) << "target op";
2237+
return diag;
2238+
}
2239+
2240+
// Convert the padding values to attributes.
2241+
SmallVector<Attribute> paddingValues;
2242+
for (auto const &[untypedAttr, elementOrTensorType] :
2243+
llvm::zip(getPaddingValues(), targetOp->getOperandTypes())) {
2244+
auto attr = dyn_cast<TypedAttr>(untypedAttr);
2245+
Type elementType = getElementTypeOrSelf(elementOrTensorType);
2246+
if (!attr) {
2247+
emitOpError("expects padding values to be typed attributes");
2248+
return DiagnosedSilenceableFailure::definiteFailure();
2249+
}
2250+
// Try to parse string attributes to obtain an attribute of element type.
2251+
if (auto stringAttr = dyn_cast<StringAttr>(attr)) {
2252+
auto parsedAttr = dyn_cast_if_present<TypedAttr>(parseAttribute(
2253+
stringAttr, getContext(), elementType,
2254+
/*numRead=*/nullptr, /*isKnownNullTerminated=*/true));
2255+
if (!parsedAttr || parsedAttr.getType() != elementType) {
2256+
auto diag = this->emitOpError("expects a padding that parses to ")
2257+
<< elementType << ", got " << attr;
2258+
diag.attachNote(targetOp.getLoc()) << "when applied to this op";
2259+
return DiagnosedSilenceableFailure::definiteFailure();
2260+
}
2261+
paddingValues.push_back(parsedAttr);
2262+
continue;
2263+
}
2264+
// Otherwise, add the attribute directly.
2265+
if (attr.getType() != elementType) {
2266+
auto diag = this->emitOpError("expects a padding value of type ")
2267+
<< elementType << ", got " << attr;
2268+
diag.attachNote(targetOp.getLoc()) << "when applied to this op";
2269+
return DiagnosedSilenceableFailure::definiteFailure();
2270+
}
2271+
paddingValues.push_back(attr);
2272+
}
2273+
2274+
// Set options.
2275+
TilingInterface paddedOp;
2276+
PadTilingInterfaceOptions options;
2277+
options.setPaddingValues(paddingValues)
2278+
.setPaddingDimensions(
2279+
extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions()))
2280+
.setPaddingSizes(getMixedPaddingSizes())
2281+
.setPadToMultipleOf(getPadToMultipleOf());
2282+
2283+
// Apply padding.
2284+
SmallVector<tensor::PadOp> newPadOps;
2285+
FailureOr<TilingInterface> maybePaddedOp = rewriteAsPaddedOp(
2286+
rewriter, cast<TilingInterface>(targetOp.getOperation()), options,
2287+
newPadOps);
2288+
if (failed(maybePaddedOp)) {
2289+
auto diag = emitSilenceableError() << "failed to pad op";
2290+
diag.attachNote(target->getLoc()) << "target op";
2291+
return diag;
2292+
}
2293+
2294+
// Set transform results.
2295+
paddedOps.push_back(cast<TilingInterface>(maybePaddedOp->getOperation()));
2296+
padOps.append(newPadOps.begin(), newPadOps.end());
2297+
}
2298+
2299+
results.set(cast<OpResult>(getPadded()), paddedOps);
2300+
results.set(cast<OpResult>(getPad()), padOps);
2301+
return DiagnosedSilenceableFailure::success();
2302+
}
2303+
2304+
LogicalResult transform::PadTilingInterfaceOp::verify() {
2305+
SmallVector<int64_t> paddingDimensions =
2306+
extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
2307+
if (any_of(paddingDimensions,
2308+
[](int64_t paddingDimension) { return paddingDimension < 0; })) {
2309+
return emitOpError() << "expects padding_dimensions to contain positive "
2310+
"integers, found "
2311+
<< getPaddingDimensions();
2312+
}
2313+
if (getMixedPaddingSizes().size() != paddingDimensions.size()) {
2314+
return emitOpError() << "expects as many multiples as padding_dimensions";
2315+
}
2316+
return success();
2317+
}
2318+
21582319
//===---------------------------------------------------------------------===//
21592320
// HoistPadOp
21602321
//===---------------------------------------------------------------------===//

mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
2929
BlockPackMatmul.cpp
3030
PackAndUnpackPatterns.cpp
3131
Padding.cpp
32+
PadTilingInterface.cpp
3233
Promotion.cpp
3334
RuntimeOpVerification.cpp
3435
Specialize.cpp

0 commit comments

Comments
 (0)