Skip to content

Commit f8c057c

Browse files
authored
[SYCL-MLIR] Extend and fix SYCL inline pass options (#8724)
We add an `-inline-sycl-method-ops` option both to the pass itself and to `cgeist` to decide whether to inline `SYCLMethodOps` (on by default). The pipeline needs further modification as `-no-mangled-function-name` is incompatible with inlining `SYCLMethodOps`, so we have to convert to `sycl.call` beforehand. Now the default `-remove-dead-callees` will be used if no argument is provided. Signed-off-by: Victor Perez <[email protected]>
1 parent 9703fc5 commit f8c057c

File tree

6 files changed

+57
-27
lines changed

6 files changed

+57
-27
lines changed

mlir-sycl/include/mlir/Dialect/SYCL/Transforms/Passes.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,7 @@ enum InlineMode { AlwaysInline, Simple, Aggressive, Ludicrous };
3333
//===----------------------------------------------------------------------===//
3434

3535
std::unique_ptr<Pass> createInlinePass();
36-
std::unique_ptr<Pass> createInlinePass(enum InlineMode InlineMode,
37-
bool RemoveDeadCallees);
36+
std::unique_ptr<Pass> createInlinePass(const InlinePassOptions &options);
3837

3938
std::unique_ptr<Pass> createSYCLMethodToSYCLCallPass();
4039

mlir-sycl/include/mlir/Dialect/SYCL/Transforms/Passes.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@ def InlinePass : Pass<"inliner"> {
3131
"Remove callees that become unreachable after inlining them">,
3232
Option<"MaxIterationCount", "max-num-iters", "unsigned", /*default=*/"3",
3333
"Maximum number of inlining iterations for each SCC">,
34+
Option<"InlineSYCLMethodOps", "inline-sycl-method-ops",
35+
"bool", /*default=*/"true",
36+
"Whether to inline SYCLMethodOp operations">,
3437
];
3538
let dependentDialects = ["memref::MemRefDialect"];
3639

mlir-sycl/lib/Transforms/Inliner.cpp

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -291,9 +291,9 @@ operator<<(llvm::raw_ostream &OS, const InlineHeuristic &Heuristic) {
291291
class Inliner : public InlinerInterface {
292292
public:
293293
Inliner(MLIRContext *Ctx, CallGraph &CG, SymbolTableCollection &SymTable,
294-
const InlineHeuristic &Heuristic)
294+
const InlineHeuristic &Heuristic, bool InlineSYCLMethodOps)
295295
: InlinerInterface(Ctx), CG(CG), SymbolTable(SymTable),
296-
Heuristic(Heuristic) {}
296+
Heuristic(Heuristic), InlineSYCLMethodOps(InlineSYCLMethodOps) {}
297297

298298
ResolvedCall &getCall(unsigned Index) {
299299
assert(Index < Calls.size() && "Out of bound index");
@@ -360,6 +360,9 @@ class Inliner : public InlinerInterface {
360360

361361
/// The inline heuristic controlling when to inline a call edge.
362362
const InlineHeuristic &Heuristic;
363+
364+
/// Whether to inline SYCLMethodOp operations.
365+
const bool InlineSYCLMethodOps;
363366
};
364367

365368
class InlinePass : public sycl::impl::InlinePassBase<InlinePass> {
@@ -762,6 +765,10 @@ void Inliner::collectCallOps(CallGraphNode &SrcNode, CallGraph &CG,
762765
if (CGN->isExternal())
763766
return false;
764767

768+
const auto *Op = Call.getOperation();
769+
if (!InlineSYCLMethodOps && isa<sycl::SYCLMethodOpInterface>(Op))
770+
return false;
771+
765772
// Always inline calls to "alwaysinline" functions.
766773
if (const auto PassThroughAttrs = getPassThroughAttrs(Call);
767774
PassThroughAttrs &&
@@ -770,7 +777,6 @@ void Inliner::collectCallOps(CallGraphNode &SrcNode, CallGraph &CG,
770777
return true;
771778

772779
// Select which call operations to collect based on heuristics.
773-
const auto *Op = Call.getOperation();
774780
switch (Heuristic.InlineMode) {
775781
case sycl::InlineMode::Ludicrous:
776782
return true;
@@ -811,7 +817,7 @@ void InlinePass::runOnOperation() {
811817
SymbolTableCollection SymTable;
812818
CGUseList UseList(getOperation(), CG, SymTable);
813819
InlineHeuristic Heuristic(InlineMode);
814-
Inliner Inliner(Ctx, CG, SymTable, Heuristic);
820+
Inliner Inliner(Ctx, CG, SymTable, Heuristic, InlineSYCLMethodOps);
815821

816822
LLVM_DEBUG(llvm::dbgs() << "Inline Heuristic: " << Heuristic << "\n");
817823

@@ -865,13 +871,9 @@ bool InlinePass::checkForSymbolTable(Operation &Op) {
865871
}
866872

867873
std::unique_ptr<Pass> sycl::createInlinePass() {
868-
const sycl::InlinePassOptions &Options = {InlineMode::Simple,
869-
/* RemoveDeadCallees */ false};
870-
return std::make_unique<InlinePass>(Options);
874+
return createInlinePass(InlinePassOptions{});
871875
}
872876

873-
std::unique_ptr<Pass> sycl::createInlinePass(enum InlineMode InlineMode,
874-
bool RemoveDeadCallees) {
875-
const sycl::InlinePassOptions &Options = {InlineMode, RemoveDeadCallees};
877+
std::unique_ptr<Pass> sycl::createInlinePass(const InlinePassOptions &Options) {
876878
return std::make_unique<InlinePass>(Options);
877879
}

mlir-sycl/test/Transforms/inliner.mlir

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
// RUN: sycl-mlir-opt -split-input-file -inliner="mode=alwaysinline remove-dead-callees=false" -verify-diagnostics -mlir-pass-statistics %s 2>&1 | FileCheck --check-prefixes=ALWAYS-INLINE,CHECK-ALL %s
2-
// RUN: sycl-mlir-opt -split-input-file -inliner="mode=simple remove-dead-callees=true" -verify-diagnostics -mlir-pass-statistics %s 2>&1 | FileCheck --check-prefixes=INLINE,CHECK-ALL %s
3-
// RUN: sycl-mlir-opt -split-input-file -inliner="mode=aggressive remove-dead-callees=true" -verify-diagnostics -mlir-pass-statistics %s 2>&1 | FileCheck --check-prefixes=AGGRESSIVE,CHECK-ALL %s
1+
// RUN: sycl-mlir-opt -split-input-file -inliner="mode=alwaysinline remove-dead-callees=false inline-sycl-method-ops=false" -verify-diagnostics -mlir-pass-statistics %s 2>&1 | FileCheck --check-prefixes=ALWAYS-INLINE,CHECK-ALL %s
2+
// RUN: sycl-mlir-opt -split-input-file -inliner="mode=simple remove-dead-callees=true inline-sycl-method-ops=false" -verify-diagnostics -mlir-pass-statistics %s 2>&1 | FileCheck --check-prefixes=INLINE,CHECK-ALL %s
3+
// RUN: sycl-mlir-opt -split-input-file -inliner="mode=aggressive remove-dead-callees=true inline-sycl-method-ops=false" -verify-diagnostics -mlir-pass-statistics %s 2>&1 | FileCheck --check-prefixes=AGGRESSIVE,CHECK-ALL %s
4+
// RUN: sycl-mlir-opt -split-input-file -inliner="mode=aggressive remove-dead-callees=true inline-sycl-method-ops=true" -verify-diagnostics -mlir-pass-statistics %s 2>&1 | FileCheck --check-prefixes=AGGRESSIVE-METHODOPS,CHECK-ALL %s
45

56
// COM: Ensure a func.func can be inlined in a func.func caller iff the callee is 'alwaysinline'.
67
// COM: Ensure a gpu.func cannot be inlined in a func.func caller (even if it has the 'alwaysinline' attribute).
@@ -171,6 +172,13 @@ gpu.func @gpu_func_callee() -> i32 attributes {passthrough = ["alwaysinline"]} {
171172

172173
// -----
173174

175+
// AGGRESSIVE-METHODOPS-NOT: func.func private @inline_hint_callee
176+
// AGGRESSIVE-METHODOPS-NOT: func.func private @private_callee
177+
// AGGRESSIVE-METHODOPS-NOT: func.func private @get
178+
179+
// AGGRESSIVE-NOT: func.func private @inline_hint_callee
180+
// AGGRESSIVE-NOT: func.func private @private_callee
181+
174182
// COM: Ensure functions in a SCC are fully inlined (requires multiple inlining iterations).
175183
// CHECK-ALL-LABEL: func.func @main(
176184
// CHECK-ALL-SAME: %[[VAL_0:.*]]: memref<?x!sycl_id_1_>) -> (i32, i64) {
@@ -196,20 +204,27 @@ gpu.func @gpu_func_callee() -> i32 attributes {passthrough = ["alwaysinline"]} {
196204
// INLINE-NOT: func.func private @inline_hint_callee
197205
// INLINE-NOT: func.func private @private_callee
198206

199-
// AGGRESSIVE-NOT: func.func private @inline_hint_callee
200-
// AGGRESSIVE-NOT: func.func private @private_callee
201-
// AGGRESSIVE-NOT: func.func private @get
202-
203-
// AGGRESSIVE-DAG: %[[VAL_1:.*]] = arith.constant 1 : i32
204-
// AGGRESSIVE-DAG: %[[VAL_2:.*]] = arith.constant 1 : i32
205-
// AGGRESSIVE-DAG: %[[VAL_3:.*]] = arith.constant 2 : i32
207+
// AGGRESSIVE-METHODOPS-DAG: %[[VAL_1:.*]] = arith.constant 1 : i32
208+
// AGGRESSIVE-METHODOPS-DAG: %[[VAL_2:.*]] = arith.constant 1 : i32
209+
// AGGRESSIVE-METHODOPS-DAG: %[[VAL_3:.*]] = arith.constant 2 : i32
210+
// AGGRESSIVE-METHODOPS: %[[VAL_4:.*]] = sycl.call @main_() {MangledFunctionName = @main, TypeName = @A} : () -> i32
211+
// AGGRESSIVE-METHODOPS: %[[VAL_5:.*]] = arith.addi %[[VAL_3]], %[[VAL_4]] : i32
212+
// AGGRESSIVE-METHODOPS: %[[VAL_6:.*]] = arith.addi %[[VAL_2]], %[[VAL_5]] : i32
213+
// AGGRESSIVE-METHODOPS: call @foo(%[[VAL_0]], %[[VAL_1]]) : (memref<?x!sycl_id_1_>, i32) -> ()
214+
// AGGRESSIVE-METHODOPS: %[[VAL_7:.*]] = memref.memory_space_cast %[[VAL_0]] : memref<?x!sycl_id_1_> to memref<?x!sycl_id_1_, 4>
215+
// AGGRESSIVE-METHODOPS: %[[VAL_8:.*]] = arith.constant 2 : i64
216+
// AGGRESSIVE-METHODOPS: return %[[VAL_6]], %[[VAL_8]] : i32, i64
217+
// AGGRESSIVE-METHODOPS: }
218+
219+
// AGGRESSIVE: %[[VAL_1:.*]] = arith.constant 1 : i32
220+
// AGGRESSIVE: %[[VAL_2:.*]] = arith.constant 1 : i32
221+
// AGGRESSIVE: %[[VAL_3:.*]] = arith.constant 2 : i32
206222
// AGGRESSIVE: %[[VAL_4:.*]] = sycl.call @main_() {MangledFunctionName = @main, TypeName = @A} : () -> i32
207223
// AGGRESSIVE: %[[VAL_5:.*]] = arith.addi %[[VAL_3]], %[[VAL_4]] : i32
208224
// AGGRESSIVE: %[[VAL_6:.*]] = arith.addi %[[VAL_2]], %[[VAL_5]] : i32
209225
// AGGRESSIVE: call @foo(%[[VAL_0]], %[[VAL_1]]) : (memref<?x!sycl_id_1_>, i32) -> ()
210-
// AGGRESSIVE: %[[VAL_7:.*]] = memref.memory_space_cast %[[VAL_0]] : memref<?x!sycl_id_1_> to memref<?x!sycl_id_1_, 4>
211-
// AGGRESSIVE: %[[VAL_8:.*]] = arith.constant 2 : i64
212-
// AGGRESSIVE: return %[[VAL_6]], %[[VAL_8]] : i32, i64
226+
// AGGRESSIVE: %[[VAL_7:.*]] = sycl.id.get %[[VAL_0]]{{\[}}%[[VAL_1]]] {ArgumentTypes = [memref<?x!sycl_id_1_, 4>, i32], FunctionName = @get, MangledFunctionName = @get, TypeName = @id} : (memref<?x!sycl_id_1_>, i32) -> i64
227+
// AGGRESSIVE: return %[[VAL_6]], %[[VAL_7]] : i32, i64
213228
// AGGRESSIVE: }
214229

215230
!sycl_array_1_ = !sycl.array<[1], (memref<1xi64, 4>)>

polygeist/tools/cgeist/Options.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,11 @@ static llvm::cl::opt<std::string> McpuOpt("mcpu", llvm::cl::init(""),
222222
llvm::cl::desc("Target CPU"),
223223
llvm::cl::cat(ToolOptions));
224224

225+
static llvm::cl::opt<bool> InlineSYCLMethodOps(
226+
"inline-sycl-method-ops", llvm::cl::init(true),
227+
llvm::cl::desc("Whether to inline SYCLMethodOp operations"),
228+
llvm::cl::cat(ToolOptions));
229+
225230
llvm::cl::opt<bool> OmitOptionalMangledFunctionName(
226231
"no-mangled-function-name", llvm::cl::init(false),
227232
llvm::cl::desc("Whether to omit optional \"MangledFunctionName\" fields"));

polygeist/tools/cgeist/driver.cc

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -374,8 +374,14 @@ static int optimize(mlir::MLIRContext &Ctx,
374374
// operations to be inlined.
375375
if (RaiseToAffine)
376376
OptPM.addPass(mlir::createLowerAffinePass());
377-
PM.addPass(sycl::createInlinePass(sycl::InlineMode::Simple,
378-
/* RemoveDeadCallees */ true));
377+
if (OmitOptionalMangledFunctionName) {
378+
// Needed as the inliner pass needs the `MangledFunctionName` attribute to
379+
// build the call graph.
380+
PM.addPass(mlir::sycl::createSYCLMethodToSYCLCallPass());
381+
}
382+
PM.addPass(sycl::createInlinePass({sycl::InlineMode::Simple,
383+
/* RemoveDeadCallees */ true,
384+
InlineSYCLMethodOps}));
379385

380386
mlir::OpPassManager &OptPM2 = PM.nestAny();
381387
OptPM2.addPass(mlir::createCanonicalizerPass(CanonicalizerConfig, {}, {}));

0 commit comments

Comments
 (0)