Skip to content

[mlir][PartialReductionTilingInterface] Add support for ReductionTilingStrategy::PartialReductionOuterParallel in tileUsingSCF. #143988

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -2019,6 +2019,7 @@ def TileReductionUsingForallOp :

// TODO: support mixed static-dynamic (see TileUsingForallOp).
let arguments = (ins TransformHandleTypeInterface:$target,
DefaultValuedAttr<I64ArrayAttr, "{}">:$reduction_dims,
DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$num_threads,
DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$tile_sizes,
OptionalAttr<DeviceMappingArrayAttr>:$mapping);
@@ -2036,10 +2037,11 @@ def TileReductionUsingForallOp :

let assemblyFormat = [{
$target
(`reduction_dims` `=` $reduction_dims^)?
`by`
(`num_threads` `=` $num_threads^)?
(`,` `tile_sizes` `=` $tile_sizes^)?
(`,` `mapping` `=` $mapping^)?
(`tile_sizes` `=` $tile_sizes^)?
(`mapping` `=` $mapping^)?
attr-dict
`:` functional-type(operands, results)
}];
2 changes: 1 addition & 1 deletion mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
Original file line number Diff line number Diff line change
@@ -156,7 +156,7 @@ SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
/// corresponding pair of arrays. This is the inverse function of
/// `getMixedValues`.
std::pair<SmallVector<int64_t>, SmallVector<Value>>
decomposeMixedValues(const SmallVectorImpl<OpFoldResult> &mixedValues);
decomposeMixedValues(ArrayRef<OpFoldResult> mixedValues);

/// Helper to sort `values` according to matching `keys`.
SmallVector<Value>
28 changes: 23 additions & 5 deletions mlir/include/mlir/Interfaces/TilingInterface.td
Original file line number Diff line number Diff line change
@@ -367,23 +367,28 @@ def PartialReductionOpInterface :
OpInterface<"PartialReductionOpInterface", [TilingInterface]> {
let description = [{
Interface for allowing operations to expose information needed to
tile reductions using partial reduction followed by merge. This is
complementary to TilingInterface to tile reductions.
tile reductions using partial reduction followed by merge. This
extends the `TilingInterface` to allow splitting a reduction
dimension into a parallel dimension and reduction dimension.
The materialized inter-tile loop could either be the reduction dimension
(i.e. `ReductionTilingStrategy::PartialReductionOuterReduction`) or
the parallel dimension (i.e.
`ReductionTilingStrategy::PartialReductionOuterParallel`).
}];
let cppNamespace = "::mlir";
let methods = [
InterfaceMethod<
/*desc=*/[{
Method to generate a tensor initialized with the identity value of the
operation reduction. The tensor shape is equal to operation result
reduction operator. The tensor shape is equal to operation result
shape with new dimension for each non zero tile size.
}],
/*retType=*/"::mlir::FailureOr<SmallVector<Value>>",
/*methodName=*/"generateInitialTensorForPartialReduction",
/*args=*/(ins
"::mlir::OpBuilder &":$b,
"Location":$loc,
"::mlir::ArrayRef<::mlir::OpFoldResult>":$sizes,
"::mlir::ArrayRef<::mlir::OpFoldResult>":$tileSizes,
"const ::mlir::SetVector<unsigned> &":$reductionDims),
/*methodBody=*/"",
/*defaultImplementation=*/[{
@@ -396,6 +401,11 @@ def PartialReductionOpInterface :
reduction dimension are converted to parallel dimensions with a size
less or equal to the tile size. This is meant to be used with
`mergeReductions` method which will combine the partial reductions.
The method receives the `offsets` and `sizes` for all iteration space
dimensions, as well as the iteration number of the tiled reduction
dimensions (which is the induction variable of the inter-tile loop
for the reduction dimension divided by the step of the loop) in
`splitReductionIvs`.
}],
/*retType=*/"::mlir::FailureOr<TilingResult>",
/*methodName=*/"tileToPartialReduction",
@@ -406,7 +416,8 @@ def PartialReductionOpInterface :
"ValueRange":$init,
"::mlir::ArrayRef<::mlir::OpFoldResult>":$offsets,
"::mlir::ArrayRef<::mlir::OpFoldResult>":$sizes,
"const ::llvm::SetVector<unsigned> &":$reductionDims),
"const ::llvm::SetVector<unsigned> &":$reductionDims,
"::mlir::ArrayRef<::mlir::OpFoldResult>":$splitReductionIvs),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return failure();
@@ -436,15 +447,22 @@ def PartialReductionOpInterface :
the tiled operation. This is same as
TilingInterface:::getResultTilePosition, but determines the result
tile position for partial reduction.
The method receives the `offsets` and `sizes` for all iteration space
dimensions, as well as the iteration number of the tiled reduction
dimensions (which is the induction variable of the inter-tile loop
for the reduction dimension divided by the tile size specified) in
`splitReductionIvs`.
}],
/*retType=*/"::llvm::LogicalResult",
/*methodName=*/"getPartialResultTilePosition",
/*args=*/(ins
"::mlir::OpBuilder &":$b,
"unsigned":$resultNumber,
"ReductionTilingStrategy":$tilingStrategy,
"::mlir::ArrayRef<::mlir::OpFoldResult> ":$offsets,
"::mlir::ArrayRef<::mlir::OpFoldResult> ":$sizes,
"const ::mlir::SetVector<unsigned> &":$reductionDims,
"::mlir::ArrayRef<::mlir::OpFoldResult>":$splitReductionIvs,
"::mlir::SmallVector<::mlir::OpFoldResult> &":$resultOffsets,
"::mlir::SmallVector<::mlir::OpFoldResult> &":$resultSizes),
/*methodBody=*/"",
37 changes: 30 additions & 7 deletions mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
Original file line number Diff line number Diff line change
@@ -3022,6 +3022,7 @@ void transform::TileReductionUsingForallOp::build(
build(builder, result,
/*resultTypes=*/TypeRange{opTy, opTy, opTy, opTy},
/*target=*/target,
/*reduction_dims=*/{},
/*num_threads=*/staticNumThreadsAttr,
/*tile_sizes=*/staticTileSizesAttr,
/*mapping=*/mapping);
@@ -3036,23 +3037,45 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne(
getAsOpFoldResult(rewriter.getI64ArrayAttr(getNumThreads()));
SmallVector<OpFoldResult> tileSizes =
getAsOpFoldResult(rewriter.getI64ArrayAttr(getTileSizes()));
FailureOr<linalg::ForallReductionTilingResult> result =
linalg::tileReductionUsingForall(
rewriter, cast<PartialReductionOpInterface>(target.getOperation()),
numThreads, tileSizes, getMapping());

scf::SCFTilingOptions options;
options.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
options.setReductionTilingStrategy(
ReductionTilingStrategy::PartialReductionOuterParallel);
if (!getNumThreads().empty()) {
options.setNumThreads(numThreads);
} else {
options.setTileSizes(tileSizes);
}
if (auto mapping = getMapping()) {
options.setMapping(mapping.value().getValue());
}
SmallVector<unsigned> reductionDims =
extractFromIntegerArrayAttr<unsigned>(getReductionDims());
if (reductionDims.empty()) {
for (auto [idx, iteratorType] :
llvm::enumerate(target.getIteratorTypesArray())) {
if (iteratorType == utils::IteratorType::reduction)
reductionDims.push_back(idx);
}
}
options.setReductionDims(reductionDims);
FailureOr<scf::SCFTilingResult> result = scf::tileUsingSCF(
rewriter, cast<TilingInterface>(target.getOperation()), options);

if (failed(result)) {
auto diag = emitSilenceableError() << "could not tile reduction";
diag.attachNote(target.getLoc()) << "target operation";
return diag;
}
rewriter.replaceOp(target, result->replacements);

for (Value initValue : result->initialValues)
results.push_back(initValue.getDefiningOp());
for (auto parallelTiledOp : result->parallelTiledOps)
for (auto parallelTiledOp : result->tiledOps)
results.push_back(parallelTiledOp);
for (auto mergeOp : result->mergeOps)
results.push_back(mergeOp);
results.push_back(result->loops);
results.push_back(result->loops.front());
return DiagnosedSilenceableFailure::success();
}

196 changes: 133 additions & 63 deletions mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
Original file line number Diff line number Diff line change
@@ -328,6 +328,17 @@ struct LinalgOpTilingInterface
// External Model for implementing `PartialReductionInterface` for `LinalgOp`s.
//===----------------------------------------------------------------------===//

/// In a given set vector, get the position of a particular element.
std::optional<int> getPositionIn(const llvm::SetVector<unsigned> &reductionDims,
unsigned value) {
for (auto [index, reductionDim] : llvm::enumerate(reductionDims)) {
if (reductionDim == value) {
return index;
}
}
return std::nullopt;
}

/// Return an AffineMaps to use for the `outs` operands of the linalg op
/// generated for partial results. The new AffineMap is the AffineMap of the
/// untiled op with reduction dimensions appended at end in order in which they
@@ -348,28 +359,86 @@ getPartialResultAffineMaps(LinalgOp linalgOp,
return partialReductionMaps;
}

/// Return the slice of the `initValue` to use as input to the partial reduction
/// op generated.
static Operation *getInitSliceForOuterReduction(
OpBuilder &b, Location loc, Value initValue, ArrayRef<OpFoldResult> offsets,
struct InitSliceInfo {
SmallVector<int64_t> resultShape;
SmallVector<OpFoldResult> offsets;
SmallVector<OpFoldResult> sizes;
SmallVector<OpFoldResult> strides;
};

/// Return the result shape, offsets, sizes and strides of the slice of the
/// `initValue` to use as the destination of the partial reduction op generated
/// with outer reduction strategy.
static InitSliceInfo getInitSliceInfoForOuterReduction(
MLIRContext *context, ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes, const SetVector<unsigned> &reductionDims,
AffineMap partialReductionMap) {
ArrayRef<OpFoldResult> splitReductionIvs, AffineMap partialReductionMap) {
int64_t initRank = partialReductionMap.getNumResults();
SmallVector<OpFoldResult> initOffsets, initSizes;
SmallVector<OpFoldResult> initStrides(initRank, b.getIndexAttr(1));
Attribute zero = IntegerAttr::get(IndexType::get(context), 0);
Attribute one = IntegerAttr::get(IndexType::get(context), 1);
SmallVector<OpFoldResult> initStrides(initRank, one);
for (AffineExpr dimExpr : partialReductionMap.getResults()) {
unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
if (reductionDims.contains(dim)) {
initOffsets.push_back(b.getIndexAttr(0));
initOffsets.push_back(zero);
} else {
initOffsets.push_back(offsets[dim]);
}
initSizes.push_back(sizes[dim]);
}
// TODO: Use SubsetExtractOpInterface here once available.
auto extractSlice = b.create<tensor::ExtractSliceOp>(
loc, initValue, initOffsets, initSizes, initStrides);
return extractSlice;
SmallVector<int64_t> resultShape;
std::tie(resultShape, std::ignore) = decomposeMixedValues(initSizes);
return {resultShape, initOffsets, initSizes, initStrides};
}

/// Return the result shape, offsets, sizes and strides of the slice of the
/// `initValue` to use as destination of the partial reduction op generated with
/// outer parallel strategy.
static InitSliceInfo getInitSliceInfoForOuterParallel(
MLIRContext *context, ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes, const SetVector<unsigned> &reductionDims,
ArrayRef<OpFoldResult> splitReductionIvs, AffineMap partialReductionMap) {
int64_t initRank = partialReductionMap.getNumResults();
SmallVector<OpFoldResult> initOffsets, initSizes;
Attribute one = IntegerAttr::get(IndexType::get(context), 1);
SmallVector<OpFoldResult> initStrides(initRank, one);
SmallVector<OpFoldResult> resultShape;
for (AffineExpr dimExpr : partialReductionMap.getResults()) {
unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
if (std::optional<unsigned> dimPos = getPositionIn(reductionDims, dim)) {
initOffsets.push_back(splitReductionIvs[dimPos.value()]);
initSizes.push_back(one);
} else {
initOffsets.push_back(offsets[dim]);
initSizes.push_back(sizes[dim]);
resultShape.push_back(sizes[dim]);
}
}
SmallVector<int64_t> staticShapes;
std::tie(staticShapes, std::ignore) = decomposeMixedValues(resultShape);
return {staticShapes, initOffsets, initSizes, initStrides};
}

/// Return the result shape, offsets, sizes and strides of the slice of the
/// `initValue` to use as destination of the partial reduction op.
static InitSliceInfo getInitSliceInfo(MLIRContext *context,
ReductionTilingStrategy strategy,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes,
const SetVector<unsigned> &reductionDims,
ArrayRef<OpFoldResult> splitReductionIvs,
AffineMap partialReductionMap) {
if (strategy == ReductionTilingStrategy::PartialReductionOuterReduction) {
return getInitSliceInfoForOuterReduction(context, offsets, sizes,
reductionDims, splitReductionIvs,
partialReductionMap);
}
assert(strategy == ReductionTilingStrategy::PartialReductionOuterParallel &&
"unexpected ReductionTilingStrategy");
return getInitSliceInfoForOuterParallel(context, offsets, sizes,
reductionDims, splitReductionIvs,
partialReductionMap);
}

/// External model implementation of PartialReductionInterface for
@@ -390,21 +459,6 @@ struct LinalgOpPartialReductionInterface
SmallVector<AffineMap> partialResultMaps =
getPartialResultAffineMaps(linalgOp, reductionDims);

// LinalgOp implements TilingInterface.
auto tilingInterfaceOp = cast<TilingInterface>(linalgOp.getOperation());
SmallVector<OpFoldResult> shape =
llvm::map_to_vector(tilingInterfaceOp.getIterationDomain(b),
[](Range x) { return x.size; });

SmallVector<OpFoldResult> tiledShape;
for (auto [tileSize, dimSize] : llvm::zip_equal(sizes, shape)) {
if (isZeroInteger(tileSize)) {
tiledShape.push_back(dimSize);
} else {
tiledShape.push_back(tileSize);
}
}

SmallVector<Value> inits;
for (auto [initIdx, result, partialMap] :
llvm::enumerate(linalgOp->getResults(), partialResultMaps)) {
@@ -424,7 +478,7 @@ struct LinalgOpPartialReductionInterface
SmallVector<OpFoldResult> partialResultShape;
for (AffineExpr dimExpr : partialMap.getResults()) {
auto dim = cast<AffineDimExpr>(dimExpr);
partialResultShape.push_back(tiledShape[dim.getPosition()]);
partialResultShape.push_back(sizes[dim.getPosition()]);
}

Type elType = getElementTypeOrSelf(result.getType());
@@ -444,13 +498,8 @@ struct LinalgOpPartialReductionInterface
ReductionTilingStrategy tilingStrategy,
ValueRange init, ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes,
const SetVector<unsigned> &reductionDims) const {
if (tilingStrategy !=
ReductionTilingStrategy::PartialReductionOuterReduction) {
// TODO: Add support for `PartialReductionOuterParallel` strategy.
return op->emitOpError("unsupported partial reduction tiling with "
"`PartialReductionOuterParallel` strategy");
}
const SetVector<unsigned> &reductionDims,
ArrayRef<OpFoldResult> splitReductionIvs) const {
OpBuilder::InsertionGuard guard(b);
auto linalgOp = cast<LinalgOp>(op);

@@ -459,7 +508,16 @@ struct LinalgOpPartialReductionInterface

// Step 1. Extend init maps to have reduction dimension dims, since we
// are converting them to parallel dimensions.
SmallVector<AffineMap> newInitMaps = partialReductionMaps;
SmallVector<AffineMap> newInitMaps;
if (tilingStrategy ==
ReductionTilingStrategy::PartialReductionOuterReduction) {
newInitMaps = llvm::to_vector(partialReductionMaps);
} else {
newInitMaps = llvm::map_to_vector(
linalgOp.getDpsInitsMutable(), [&](OpOperand &opOperand) {
return linalgOp.getMatchingIndexingMap(&opOperand);
});
}

// Step 2a: Extract a slice of the input operands.
SmallVector<Value> tiledInputs = makeTiledShapes(
@@ -473,10 +531,17 @@ struct LinalgOpPartialReductionInterface
SmallVector<Value, 1> tiledInits;
for (auto [partialReductionMap, valueToTile] :
llvm::zip_equal(partialReductionMaps, init)) {
Operation *sliceOp =
getInitSliceForOuterReduction(b, loc, valueToTile, offsets, sizes,
reductionDims, partialReductionMap);
tiledInits.push_back(sliceOp->getResult(0));
InitSliceInfo sliceInfo = getInitSliceInfo(
b.getContext(), tilingStrategy, offsets, sizes, reductionDims,
splitReductionIvs, partialReductionMap);
auto valueToTileType = cast<RankedTensorType>(valueToTile.getType());
RankedTensorType sliceResultType = RankedTensorType::get(
sliceInfo.resultShape, valueToTileType.getElementType(),
valueToTileType.getEncoding());
auto sliceOp = b.create<tensor::ExtractSliceOp>(
loc, sliceResultType, valueToTile, sliceInfo.offsets, sliceInfo.sizes,
sliceInfo.strides);
tiledInits.push_back(sliceOp.getResult());
generatedSlices.push_back(sliceOp);
}

@@ -491,19 +556,31 @@ struct LinalgOpPartialReductionInterface
// Step 3. Change the reduction dim iterator types.
SmallVector<utils::IteratorType> newIteratorTypes =
linalgOp.getIteratorTypesArray();
for (int dim : reductionDims)
newIteratorTypes[dim] = utils::IteratorType::parallel;
if (tilingStrategy ==
ReductionTilingStrategy::PartialReductionOuterReduction) {
for (int dim : reductionDims)
newIteratorTypes[dim] = utils::IteratorType::parallel;
}

// Step 4. Create the new generic op.
Operation *partialReductionOp;
auto resultTypes = ValueRange(tiledInits).getTypes();
auto genericOp = b.create<GenericOp>(loc, resultTypes, tiledInputs,
tiledInits, newMaps, newIteratorTypes);
IRMapping mapping;
op->getRegion(0).cloneInto(&genericOp.getRegion(),
genericOp.getRegion().begin(), mapping);
if (tilingStrategy ==
ReductionTilingStrategy::PartialReductionOuterReduction) {
auto genericOp = b.create<GenericOp>(
loc, resultTypes, tiledInputs, tiledInits, newMaps, newIteratorTypes);
IRMapping mapping;
op->getRegion(0).cloneInto(&genericOp.getRegion(),
genericOp.getRegion().begin(), mapping);
partialReductionOp = genericOp.getOperation();
} else {
SmallVector<Value> operands = std::move(tiledInputs);
llvm::append_range(operands, tiledInits);
partialReductionOp = mlir::clone(b, op, resultTypes, operands);
}
return TilingResult{
{genericOp.getOperation()},
llvm::map_to_vector(genericOp->getResults(),
{partialReductionOp},
llvm::map_to_vector(partialReductionOp->getResults(),
[](OpResult r) -> Value { return r; }),
generatedSlices};
}
@@ -558,26 +635,19 @@ struct LinalgOpPartialReductionInterface

LogicalResult getPartialResultTilePosition(
Operation *op, OpBuilder &b, unsigned resultNumber,
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
const SetVector<unsigned> &reductionDims,
ReductionTilingStrategy tilingStrategy, ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes, const SetVector<unsigned> &reductionDims,
ArrayRef<OpFoldResult> splitReductionIvs,
SmallVector<OpFoldResult> &resultOffsets,
SmallVector<OpFoldResult> &resultSizes) const {
auto linalgOp = cast<LinalgOp>(op);
SmallVector<AffineMap> partialReductionMaps =
getPartialResultAffineMaps(linalgOp, reductionDims);

for (AffineExpr dimExpr : partialReductionMaps[resultNumber].getResults()) {
unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
resultSizes.push_back(sizes[dim]);

if (llvm::is_contained(reductionDims, dim)) {
// Reduction dims are reduced, and are always outputed in the same
// place. So use offset 0 for them.
resultOffsets.push_back(b.getIndexAttr(0));
} else {
resultOffsets.push_back(offsets[dim]);
}
}
InitSliceInfo sliceInfo = getInitSliceInfo(
b.getContext(), tilingStrategy, offsets, sizes, reductionDims,
splitReductionIvs, partialReductionMaps[resultNumber]);
std::swap(resultOffsets, sliceInfo.offsets);
std::swap(resultSizes, sliceInfo.sizes);

return success();
}
229 changes: 169 additions & 60 deletions mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

Large diffs are not rendered by default.

14 changes: 7 additions & 7 deletions mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
Original file line number Diff line number Diff line change
@@ -2315,13 +2315,13 @@ RankedTensorType ExtractSliceOp::inferResultType(
RankedTensorType ExtractSliceOp::inferResultType(
RankedTensorType sourceTensorType, ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides) {
SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
return ExtractSliceOp::inferResultType(sourceTensorType, staticOffsets,
staticSizes, staticStrides);
SmallVector<int64_t> staticSizes;
std::tie(staticSizes, std::ignore) = decomposeMixedValues(sizes);
assert(static_cast<int64_t>(staticSizes.size()) ==
sourceTensorType.getRank() &&
"unexpected staticSizes not equal to rank of source");
return RankedTensorType::get(staticSizes, sourceTensorType.getElementType(),
sourceTensorType.getEncoding());
}

/// If the rank is reduced (i.e. the desiredResultRank is smaller than the
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/Utils/StaticValueUtils.cpp
Original file line number Diff line number Diff line change
@@ -208,7 +208,7 @@ SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
/// Decompose a vector of mixed static or dynamic values into the corresponding
/// pair of arrays. This is the inverse function of `getMixedValues`.
std::pair<SmallVector<int64_t>, SmallVector<Value>>
decomposeMixedValues(const SmallVectorImpl<OpFoldResult> &mixedValues) {
decomposeMixedValues(ArrayRef<OpFoldResult> mixedValues) {
SmallVector<int64_t> staticValues;
SmallVector<Value> dynamicValues;
for (const auto &it : mixedValues) {
286 changes: 167 additions & 119 deletions mlir/test/Dialect/Linalg/transform-tile-reduction.mlir

Large diffs are not rendered by default.