Skip to content
forked from iree-org/iree

Commit 8daa67e

Browse files
authored
fix(iree,cpu): put linalg producers for iree_linalg_ext.scan into separate dispatch (#40)
This fix is necessitated by lack of fusion support for LinalgExt Scan - since the default LLVMCPU tiling does not respect the fusability of producer ops, nor the limit for stack-bound allocations, we force this non-fusable "cumulative reduction"-style operation to be dispatched separately from non-trivial linalg operations. Further enhancements should include tiling config fine-tuning for LinalgExt operations, conscious restrictions of work-group level tiling depending on the predicted size of stack-bound allocas within a dispatch, and further adoption of `LinalgFusionOpInterface` for LinalgExt operations that cannot be expressed through simple reduction iterators. Possible adjustments to the FormDispatchRegion algorithm itself are noted as TODO items.
1 parent b06346c commit 8daa67e

File tree

5 files changed

+72
-0
lines changed

5 files changed

+72
-0
lines changed

compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp

+22
Original file line numberDiff line numberDiff line change
@@ -809,6 +809,23 @@ static bool isAttentionMaskGenerator(Operation *op) {
809809
return false;
810810
}
811811

812+
static bool hasExplicitNonFusableUsers(Operation *op) {
813+
bool hasNonFusableUse = false;
814+
for (Operation *user : op->getUsers()) {
815+
if (isa<IREE::LinalgExt::LinalgFusionOpInterface>(user))
816+
continue;
817+
// TODO: The issue with iree_linalg_ext.scan compared to other non-fusable
818+
// ops comes down to poor support in tiling configs, leading to overly large
819+
// stack-bound allocations. In practice, backend compilers should often cope
820+
// with fusing the resulting loops even without Linalg-level fusion.
821+
// So long-term, we improve the tiling logic for ScanOp's, while also
822+
// considering ways to express simple fusions within this cumulative
823+
// reduction intrinsic.
824+
hasNonFusableUse |= isa<IREE::LinalgExt::ScanOp>(user);
825+
}
826+
return hasNonFusableUse;
827+
}
828+
812829
/// Operations that are cloned into dispatch regions formed with other
813830
/// operations as roots.
814831
bool isClonableIntoDispatchOp(Operation *op,
@@ -825,6 +842,11 @@ bool isClonableIntoDispatchOp(Operation *op,
825842
complex::CreateOp>(op)) {
826843
return true;
827844
}
845+
// TODO: Tune the cases excluded through the hasExplicitNonFusableUsers
846+
// condition in a more targeted manner, then remove the condition.
847+
if (isa<linalg::LinalgOp>(op) && hasExplicitNonFusableUsers(op)) {
848+
return false;
849+
}
828850
if (LinalgExt::isBitExtendOp(op)) {
829851
return true;
830852
}

compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp

+7
Original file line numberDiff line numberDiff line change
@@ -766,6 +766,13 @@ decideFusableLinalgOps(Region &region, DominanceInfo const &dominanceInfo,
766766

767767
// Once all root linalg ops have been tagged, put all remaining generic ops
768768
// into their own dispatches.
769+
// TODO: When identifying root operations, there are multiple
770+
// isFusableWith[Producer|Consumer] checks being invoked, however we only
771+
// access the positive results of those checks (i.e. explicit assignment to
772+
// the fusion group). We should consider assigning *all* operations to fusion
773+
// groups during the fuseRootsWith...() phase, creating new groups for more
774+
// negative fusion cases, and potentially eliding the second iteration over
775+
// the region altogether.
769776
for (Block &block : region) {
770777
SmallVector<Operation *> roots;
771778
for (Operation &op : llvm::reverse(block)) {

compiler/src/iree/compiler/DispatchCreation/test/BUILD.bazel

+1
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ iree_lit_test_suite(
3434
"hoist_encoding_ops.mlir",
3535
"dispatch_linalg_on_tensors_default.mlir",
3636
"dispatch_linalg_on_tensors_fusion_with_transpose.mlir",
37+
"dispatch_non_fusable.mlir",
3738
"form_scalar_dispatches.mlir",
3839
"fuse_encoding_ops_into_dispatch_regions.mlir",
3940
"fuse_horizontal_contractions.mlir",

compiler/src/iree/compiler/DispatchCreation/test/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ iree_lit_test_suite(
2626
"dispatch_linalg_on_tensors.mlir"
2727
"dispatch_linalg_on_tensors_default.mlir"
2828
"dispatch_linalg_on_tensors_fusion_with_transpose.mlir"
29+
"dispatch_non_fusable.mlir"
2930
"dispatch_region_formation_preprocessing.mlir"
3031
"elementwise_op_fusion.mlir"
3132
"fold_unit_dims.mlir"
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
// RUN: iree-opt %s --split-input-file --verify-diagnostics \
2+
// RUN: --pass-pipeline="builtin.module(util.func(iree-dispatch-creation-form-dispatch-regions{aggressive-fusion=true}, \
3+
// RUN: iree-dispatch-creation-clone-producers-into-dispatch-regions), cse, canonicalize, cse)" \
4+
// RUN: | FileCheck %s
5+
6+
// Check that a simple elementwise bit extend producer is assigned to a separate dispatch
7+
// (until fusion is supported)
8+
#map = affine_map<(d0, d1) -> (d0, d1)>
9+
util.func public @linalgext_scan_inclusive_dispatch_non_fusable(%arg0: tensor<8x32xi32>) -> tensor<8x32xi64> {
10+
%c0_i64 = arith.constant 0 : i64
11+
%0 = tensor.empty() : tensor<8x32xi64>
12+
%1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<8x32xi32>) outs(%0 : tensor<8x32xi64>) {
13+
^bb0(%in: i32, %out: i64):
14+
%6 = arith.extsi %in : i32 to i64
15+
linalg.yield %6 : i64
16+
} -> tensor<8x32xi64>
17+
%2 = linalg.fill ins(%c0_i64 : i64) outs(%0 : tensor<8x32xi64>) -> tensor<8x32xi64>
18+
%3 = tensor.empty() : tensor<8xi64>
19+
%4 = linalg.fill ins(%c0_i64 : i64) outs(%3 : tensor<8xi64>) -> tensor<8xi64>
20+
%5:2 = iree_linalg_ext.scan dimension(1) inclusive(true) ins(%1 : tensor<8x32xi64>) outs(%2, %4 : tensor<8x32xi64>, tensor<8xi64>) {
21+
^bb0(%arg3: i64, %arg4: i64):
22+
%6 = arith.addi %arg3, %arg4 : i64
23+
iree_linalg_ext.yield %6 : i64
24+
} -> tensor<8x32xi64>, tensor<8xi64>
25+
util.return %5#0 : tensor<8x32xi64>
26+
}
27+
28+
// CHECK-LABEL: util.func public @linalgext_scan_inclusive_dispatch_non_fusable(
29+
// CHECK-SAME: %[[ARG:.+]]: tensor<8x32xi32>) -> tensor<8x32xi64>
30+
// CHECK: %[[ZERO_CONST:.+]] = arith.constant 0 : i64
31+
// CHECK: %[[PRODUCER_REGION:.+]] = flow.dispatch.region -> (tensor<8x32xi64>)
32+
// CHECK: %[[EMPTY:.+]] = tensor.empty()
33+
// CHECK: %[[PRODUCER:.+]] = linalg.generic {{.+}} ins(%[[ARG]] : tensor<8x32xi32>) outs(%[[EMPTY]] : tensor<8x32xi64>)
34+
// CHECK: flow.return %[[PRODUCER]]
35+
// CHECK: %[[LINALGEXT_REGION:.+]] = flow.dispatch.region -> (tensor<8x32xi64>)
36+
// CHECK: %[[CUMULATIVE_FILL:.+]] = linalg.fill ins(%[[ZERO_CONST]] : i64) outs(%{{.+}} : tensor<8x32xi64>)
37+
// CHECK: %[[REDUCED_FILL:.+]] = linalg.fill ins(%[[ZERO_CONST]] : i64) outs(%{{.+}} : tensor<8xi64>)
38+
// CHECK: %[[SCAN_RESULT:.+]]:2 = iree_linalg_ext.scan dimension(1) inclusive(true)
39+
// CHECK-SAME: ins(%[[PRODUCER_REGION]] : tensor<8x32xi64>) outs(%[[CUMULATIVE_FILL]], %[[REDUCED_FILL]] : {{.+}}) {
40+
// CHECK: flow.return %[[SCAN_RESULT]]#0
41+
// CHECK: util.return %[[LINALGEXT_REGION]] : tensor<8x32xi64>

0 commit comments

Comments
 (0)