[CPU] Do not fuse ukernel ops into tiling loops. (iree-org#16054)
It is a step towards iree-org#16025.
hanhanW authored Jan 10, 2024
1 parent baa911e commit dc81beb
Showing 2 changed files with 39 additions and 1 deletion.
@@ -4,6 +4,7 @@
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

+#include "iree/compiler/Codegen/Dialect/Codegen/IR/UKernelOps.h"
 #include "iree/compiler/Codegen/LLVMCPU/PassDetail.h"
 #include "iree/compiler/Codegen/LLVMCPU/Passes.h"
 #include "iree/compiler/Codegen/LLVMCPU/Utils.h"
@@ -99,7 +100,8 @@ LogicalResult applyTileAndFuse(RewriterBase &rewriter, Operation *rootOp,
   llvm::SmallDenseSet<Operation *> origTiledAndFusedOps;
   collectTiledAndFusedOps(rootOp, origTiledAndFusedOps);
   auto isIgnoredUser = [&](Operation *user, scf::ForOp outerMostTiledLoop) {
-    return origTiledAndFusedOps.count(user) || isa<tensor::DimOp>(user);
+    return origTiledAndFusedOps.count(user) ||
+           isa<tensor::DimOp, IREE::Codegen::UKernelGenericOp>(user);
   };

   // The rest of this method is similar to
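
For quick reference, the changed predicate can be read in isolation as below. This is a restated sketch, not code from the commit: the free-function shape and the includes are this editor's assumptions, the lambda's unused outerMostTiledLoop parameter is dropped, and origTiledAndFusedOps is the set built by collectTiledAndFusedOps() above.

// Sketch only: a standalone restatement of the isIgnoredUser lambda above.
#include "iree/compiler/Codegen/Dialect/Codegen/IR/UKernelOps.h"
#include "llvm/ADT/DenseSet.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Operation.h"

static bool isIgnoredUser(
    mlir::Operation *user,
    const llvm::SmallDenseSet<mlir::Operation *> &origTiledAndFusedOps) {
  // A user is "ignored" if it is already part of the tiled-and-fused set,
  // is a tensor.dim shape query, or (new in this commit) an
  // iree_codegen.ukernel.generic op. Per the commit title, treating ukernel
  // users this way keeps ukernel ops from being fused into the tiling loops.
  return origTiledAndFusedOps.count(user) ||
         llvm::isa<mlir::tensor::DimOp,
                   mlir::iree_compiler::IREE::Codegen::UKernelGenericOp>(user);
}

The new test below exercises exactly this behavior: the ukernel op stays outside the generated scf.for nest, and only slices of its result feed the tiled linalg.generic.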
compiler/src/iree/compiler/Codegen/LLVMCPU/test/tile_and_fuse.mlir (36 additions, 0 deletions)
@@ -233,3 +233,39 @@ func.func @scalable_matmul(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<
 // CHECK-SAME: step %[[SCALABLE_TILE_SIZE]]
 // CHECK: scf.for
 // CHECK-SAME: step %[[C1]]
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+func.func @ukernel_generic(%arg0: tensor<1x192x1x16xf32>, %arg1: tensor<1x768x1x1xf32>, %arg2: tensor<192x768x16x1xf32>, %arg3: tensor<1x192x1x16xf32>) -> tensor<1x192x1x16xf32> {
+  %c1 = arith.constant 1 : index
+  %c192 = arith.constant 192 : index
+  %c768 = arith.constant 768 : index
+  %c1_i32 = arith.constant 1 : i32
+  %c16_i32 = arith.constant 16 : i32
+  %c1025_i32 = arith.constant 1025 : i32
+  %0 = tensor.empty() : tensor<1x192x1x16xf32>
+  %1 = iree_codegen.ukernel.generic "iree_uk_mmt4d" ins(%arg1, %arg2 : tensor<1x768x1x1xf32>, tensor<192x768x16x1xf32>) outs(%0 : tensor<1x192x1x16xf32>) (%c1, %c192, %c768, %c1_i32, %c16_i32, %c1_i32, %c1025_i32 : index, index, index, i32, i32, i32, i32) fn_def_attrs {hal.import.bitcode = true, hal.import.cconv = 1 : i32, hal.import.fields = ["processor_data"]} strided_outer_dims(1) -> tensor<1x192x1x16xf32>
+  %2 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+    ins(%1, %arg3 : tensor<1x192x1x16xf32>, tensor<1x192x1x16xf32>)
+    outs(%arg0 : tensor<1x192x1x16xf32>)
+    attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[0, 20, 0]]>} {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %3 = arith.addf %in, %in_0 : f32
+    linalg.yield %3 : f32
+  } -> tensor<1x192x1x16xf32>
+  return %2 : tensor<1x192x1x16xf32>
+}
+// CHECK-LABEL: func.func @ukernel_generic
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]
+// CHECK: %[[UK:.+]] = iree_codegen.ukernel.generic "iree_uk_mmt4d"
+// CHECK: scf.for {{.+}} iter_args(%[[ITER:.+]] = %[[ARG0]])
+// CHECK: %[[UK_SLICE:.+]] = tensor.extract_slice %[[UK]]
+// CHECK: %[[ARG3_SLICE:.+]] = tensor.extract_slice %[[ARG3]]
+// CHECK: %[[ITER_SLICE:.+]] = tensor.extract_slice %[[ITER]]
+// CHECK: linalg.generic
+// CHECK-SAME: ins(%[[UK_SLICE]], %[[ARG3_SLICE]]
+// CHECK-SAME: outs(%[[ITER_SLICE]]
