|
| 1 | +// RUN: iree-opt %s --split-input-file --verify-diagnostics \ |
| 2 | +// RUN: --pass-pipeline="builtin.module(util.func(iree-dispatch-creation-form-dispatch-regions{aggressive-fusion=true}, \ |
| 3 | +// RUN: iree-dispatch-creation-clone-producers-into-dispatch-regions), cse, canonicalize, cse)" \ |
| 4 | +// RUN: | FileCheck %s |
| 5 | + |
| 6 | +// Check that a simple elementwise sign-extension producer (arith.extsi in a linalg.generic) feeding an |
| 7 | +// iree_linalg_ext.scan is assigned to a separate dispatch region from the scan (until fusion is supported). |
| 8 | +#map = affine_map<(d0, d1) -> (d0, d1)> // identity map, used for both the input and the output of the generic |
| 9 | +util.func public @linalgext_scan_inclusive_dispatch_non_fusable(%arg0: tensor<8x32xi32>) -> tensor<8x32xi64> { |
| 10 | + %c0_i64 = arith.constant 0 : i64 // zero value used by both fills below |
| 11 | + %0 = tensor.empty() : tensor<8x32xi64> // used as the outs operand of both the generic and the first fill |
| 12 | + %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<8x32xi32>) outs(%0 : tensor<8x32xi64>) { |
| 13 | + ^bb0(%in: i32, %out: i64): |
| 14 | + %6 = arith.extsi %in : i32 to i64 // sign-extend each i32 element to i64 |
| 15 | + linalg.yield %6 : i64 |
| 16 | + } -> tensor<8x32xi64> |
| 17 | + %2 = linalg.fill ins(%c0_i64 : i64) outs(%0 : tensor<8x32xi64>) -> tensor<8x32xi64> // first scan output operand, zero-filled |
| 18 | + %3 = tensor.empty() : tensor<8xi64> |
| 19 | + %4 = linalg.fill ins(%c0_i64 : i64) outs(%3 : tensor<8xi64>) -> tensor<8xi64> // second scan output operand, zero-filled |
| 20 | + %5:2 = iree_linalg_ext.scan dimension(1) inclusive(true) ins(%1 : tensor<8x32xi64>) outs(%2, %4 : tensor<8x32xi64>, tensor<8xi64>) { // inclusive scan along dimension 1 over the producer result %1 |
| 21 | + ^bb0(%arg3: i64, %arg4: i64): |
| 22 | + %6 = arith.addi %arg3, %arg4 : i64 // scan combinator: integer addition |
| 23 | + iree_linalg_ext.yield %6 : i64 |
| 24 | + } -> tensor<8x32xi64>, tensor<8xi64> |
| 25 | + util.return %5#0 : tensor<8x32xi64> // only the first scan result is returned |
| 26 | +} |
| 27 | +// Expected: the extsi generic sits alone in the first flow.dispatch.region; both fills and the scan sit in a second one. |
| 28 | +// CHECK-LABEL: util.func public @linalgext_scan_inclusive_dispatch_non_fusable( |
| 29 | +// CHECK-SAME: %[[ARG:.+]]: tensor<8x32xi32>) -> tensor<8x32xi64> |
| 30 | +// CHECK: %[[ZERO_CONST:.+]] = arith.constant 0 : i64 |
| 31 | +// CHECK: %[[PRODUCER_REGION:.+]] = flow.dispatch.region -> (tensor<8x32xi64>) |
| 32 | +// CHECK: %[[EMPTY:.+]] = tensor.empty() |
| 33 | +// CHECK: %[[PRODUCER:.+]] = linalg.generic {{.+}} ins(%[[ARG]] : tensor<8x32xi32>) outs(%[[EMPTY]] : tensor<8x32xi64>) |
| 34 | +// CHECK: flow.return %[[PRODUCER]] |
| 35 | +// CHECK: %[[LINALGEXT_REGION:.+]] = flow.dispatch.region -> (tensor<8x32xi64>) |
| 36 | +// CHECK: %[[CUMULATIVE_FILL:.+]] = linalg.fill ins(%[[ZERO_CONST]] : i64) outs(%{{.+}} : tensor<8x32xi64>) |
| 37 | +// CHECK: %[[REDUCED_FILL:.+]] = linalg.fill ins(%[[ZERO_CONST]] : i64) outs(%{{.+}} : tensor<8xi64>) |
| 38 | +// CHECK: %[[SCAN_RESULT:.+]]:2 = iree_linalg_ext.scan dimension(1) inclusive(true) |
| 39 | +// CHECK-SAME: ins(%[[PRODUCER_REGION]] : tensor<8x32xi64>) outs(%[[CUMULATIVE_FILL]], %[[REDUCED_FILL]] : {{.+}}) { |
| 40 | +// CHECK: flow.return %[[SCAN_RESULT]]#0 |
| 41 | +// CHECK: util.return %[[LINALGEXT_REGION]] : tensor<8x32xi64> |
0 commit comments