Skip to content
forked from iree-org/iree
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(iree,cpu): put linalg producers for iree_linalg_ext.scan into separate dispatch #40

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -809,6 +809,23 @@ static bool isAttentionMaskGenerator(Operation *op) {
return false;
}

/// Returns true if `op` has at least one user that is explicitly known to be
/// non-fusable with it (currently only iree_linalg_ext.scan). Users that
/// implement LinalgFusionOpInterface are treated as fusable and skipped.
static bool hasExplicitNonFusableUsers(Operation *op) {
  for (Operation *user : op->getUsers()) {
    // Users implementing the fusion interface can be fused as usual.
    if (isa<IREE::LinalgExt::LinalgFusionOpInterface>(user))
      continue;
    // TODO: The issue with iree_linalg_ext.scan compared to other non-fusable
    // ops comes down to poor support in tiling configs, leading to overly large
    // stack-bound allocations. In practice, backend compilers should often cope
    // with fusing the resulting loops even without Linalg-level fusion.
    // So long-term, we should improve the tiling logic for ScanOp's, while also
    // considering ways to express simple fusions within this cumulative
    // reduction intrinsic.
    if (isa<IREE::LinalgExt::ScanOp>(user))
      return true; // Early exit: one non-fusable user is enough.
  }
  return false;
}

/// Operations that are cloned into dispatch regions formed with other
/// operations as roots.
bool isClonableIntoDispatchOp(Operation *op,
Expand All @@ -825,6 +842,11 @@ bool isClonableIntoDispatchOp(Operation *op,
complex::CreateOp>(op)) {
return true;
}
// TODO: Tune the cases excluded through the hasExplicitNonFusableUsers
// condition in a more targeted manner, then remove the condition.
if (isa<linalg::LinalgOp>(op) && hasExplicitNonFusableUsers(op)) {
return false;
}
if (LinalgExt::isBitExtendOp(op)) {
return true;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -766,6 +766,13 @@ decideFusableLinalgOps(Region &region, DominanceInfo const &dominanceInfo,

// Once all root linalg ops have been tagged, put all remaining generic ops
// into their own dispatches.
// TODO: when identifying root operations, there are multiple
// isFusableWith[Producer|Consumer] checks being invoked; however, we only
// access the positive results of those checks (i.e. explicit assignment to
// the fusion group). We should consider assigning *all* operations to fusion
// groups during the fuseRootsWith...() phase, creating new groups for the
// negative fusion cases, and potentially eliding the second iteration over
// the region altogether.
for (Block &block : region) {
SmallVector<Operation *> roots;
for (Operation &op : llvm::reverse(block)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ iree_lit_test_suite(
"hoist_encoding_ops.mlir",
"dispatch_linalg_on_tensors_default.mlir",
"dispatch_linalg_on_tensors_fusion_with_transpose.mlir",
"dispatch_non_fusable.mlir",
"form_scalar_dispatches.mlir",
"fuse_encoding_ops_into_dispatch_regions.mlir",
"fuse_horizontal_contractions.mlir",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ iree_lit_test_suite(
"dispatch_linalg_on_tensors.mlir"
"dispatch_linalg_on_tensors_default.mlir"
"dispatch_linalg_on_tensors_fusion_with_transpose.mlir"
"dispatch_non_fusable.mlir"
"dispatch_region_formation_preprocessing.mlir"
"elementwise_op_fusion.mlir"
"fold_unit_dims.mlir"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// RUN: iree-opt %s --split-input-file --verify-diagnostics \
// RUN: --pass-pipeline="builtin.module(util.func(iree-dispatch-creation-form-dispatch-regions{aggressive-fusion=true}, \
// RUN: iree-dispatch-creation-clone-producers-into-dispatch-regions), cse, canonicalize, cse)" \
// RUN: | FileCheck %s

// Check that a simple elementwise bit extend producer is assigned to a separate dispatch
// (until fusion is supported)
//
// The producer (%1) extends i32 -> i64 and its only consumer is an
// iree_linalg_ext.scan; the scan is treated as an explicitly non-fusable
// user, so the producer must NOT be cloned into the scan's dispatch region.
#map = affine_map<(d0, d1) -> (d0, d1)>
util.func public @linalgext_scan_inclusive_dispatch_non_fusable(%arg0: tensor<8x32xi32>) -> tensor<8x32xi64> {
  %c0_i64 = arith.constant 0 : i64
  %0 = tensor.empty() : tensor<8x32xi64>
  %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<8x32xi32>) outs(%0 : tensor<8x32xi64>) {
  ^bb0(%in: i32, %out: i64):
    %6 = arith.extsi %in : i32 to i64
    linalg.yield %6 : i64
  } -> tensor<8x32xi64>
  // Fills initialize the scan's cumulative output and the reduced carry.
  %2 = linalg.fill ins(%c0_i64 : i64) outs(%0 : tensor<8x32xi64>) -> tensor<8x32xi64>
  %3 = tensor.empty() : tensor<8xi64>
  %4 = linalg.fill ins(%c0_i64 : i64) outs(%3 : tensor<8xi64>) -> tensor<8xi64>
  %5:2 = iree_linalg_ext.scan dimension(1) inclusive(true) ins(%1 : tensor<8x32xi64>) outs(%2, %4 : tensor<8x32xi64>, tensor<8xi64>) {
  ^bb0(%arg3: i64, %arg4: i64):
    %6 = arith.addi %arg3, %arg4 : i64
    iree_linalg_ext.yield %6 : i64
  } -> tensor<8x32xi64>, tensor<8xi64>
  util.return %5#0 : tensor<8x32xi64>
}

// Expect two dispatch regions: one holding only the extsi producer, and one
// holding the fills plus the scan that consumes the first region's result.
// CHECK-LABEL: util.func public @linalgext_scan_inclusive_dispatch_non_fusable(
//  CHECK-SAME:     %[[ARG:.+]]: tensor<8x32xi32>) -> tensor<8x32xi64>
//       CHECK:   %[[ZERO_CONST:.+]] = arith.constant 0 : i64
//       CHECK:   %[[PRODUCER_REGION:.+]] = flow.dispatch.region -> (tensor<8x32xi64>)
//       CHECK:     %[[EMPTY:.+]] = tensor.empty()
//       CHECK:     %[[PRODUCER:.+]] = linalg.generic {{.+}} ins(%[[ARG]] : tensor<8x32xi32>) outs(%[[EMPTY]] : tensor<8x32xi64>)
//       CHECK:     flow.return %[[PRODUCER]]
//       CHECK:   %[[LINALGEXT_REGION:.+]] = flow.dispatch.region -> (tensor<8x32xi64>)
//       CHECK:     %[[CUMULATIVE_FILL:.+]] = linalg.fill ins(%[[ZERO_CONST]] : i64) outs(%{{.+}} : tensor<8x32xi64>)
//       CHECK:     %[[REDUCED_FILL:.+]] = linalg.fill ins(%[[ZERO_CONST]] : i64) outs(%{{.+}} : tensor<8xi64>)
//       CHECK:     %[[SCAN_RESULT:.+]]:2 = iree_linalg_ext.scan dimension(1) inclusive(true)
//  CHECK-SAME:         ins(%[[PRODUCER_REGION]] : tensor<8x32xi64>) outs(%[[CUMULATIVE_FILL]], %[[REDUCED_FILL]] : {{.+}}) {
//       CHECK:     flow.return %[[SCAN_RESULT]]#0
//       CHECK:   util.return %[[LINALGEXT_REGION]] : tensor<8x32xi64>