Skip to content

Commit 7f4f75c

Browse files
[MLIR][SCFToOpenMP] Add num-threads option (#74854)
Add `num-threads` option to the `-convert-scf-to-openmp` pass, allowing to set the number of threads to be used in the `omp.parallel` to a fixed value.
1 parent bf5d96c commit 7f4f75c

File tree

3 files changed

+37
-11
lines changed

3 files changed

+37
-11
lines changed

mlir/include/mlir/Conversion/Passes.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -886,6 +886,12 @@ def ConvertSCFToOpenMPPass : Pass<"convert-scf-to-openmp", "ModuleOp"> {
886886
let summary = "Convert SCF parallel loop to OpenMP parallel + workshare "
887887
"constructs.";
888888

889+
let options = [
890+
Option<"numThreads", "num-threads", "unsigned",
891+
/*default=kUseOpenMPDefaultNumThreads*/"0",
892+
"Number of threads to use">
893+
];
894+
889895
let dependentDialects = ["omp::OpenMPDialect", "LLVM::LLVMDialect",
890896
"memref::MemRefDialect"];
891897
}

mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -339,9 +339,12 @@ static omp::ReductionDeclareOp declareReduction(PatternRewriter &builder,
339339
namespace {
340340

341341
struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
342+
static constexpr unsigned kUseOpenMPDefaultNumThreads = 0;
343+
unsigned numThreads;
342344

343-
ParallelOpLowering(MLIRContext *context)
344-
: OpRewritePattern<scf::ParallelOp>(context) {}
345+
ParallelOpLowering(MLIRContext *context,
346+
unsigned numThreads = kUseOpenMPDefaultNumThreads)
347+
: OpRewritePattern<scf::ParallelOp>(context), numThreads(numThreads) {}
345348

346349
LogicalResult matchAndRewrite(scf::ParallelOp parallelOp,
347350
PatternRewriter &rewriter) const override {
@@ -388,8 +391,21 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
388391
reduceOp, reduceOp.getOperand(), std::get<1>(pair));
389392
}
390393

394+
Value numThreadsVar;
395+
if (numThreads > 0) {
396+
numThreadsVar = rewriter.create<LLVM::ConstantOp>(
397+
loc, rewriter.getI32IntegerAttr(numThreads));
398+
}
391399
// Create the parallel wrapper.
392-
auto ompParallel = rewriter.create<omp::ParallelOp>(loc);
400+
auto ompParallel = rewriter.create<omp::ParallelOp>(
401+
loc,
402+
/* if_expr_var = */ Value{},
403+
/* num_threads_var = */ numThreadsVar,
404+
/* allocate_vars = */ llvm::SmallVector<Value>{},
405+
/* allocators_vars = */ llvm::SmallVector<Value>{},
406+
/* reduction_vars = */ llvm::SmallVector<Value>{},
407+
/* reductions = */ ArrayAttr{},
408+
/* proc_bind_val = */ omp::ClauseProcBindKindAttr{});
393409
{
394410

395411
OpBuilder::InsertionGuard guard(rewriter);
@@ -443,14 +459,14 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
443459
};
444460

445461
/// Applies the conversion patterns in the given function.
446-
static LogicalResult applyPatterns(ModuleOp module) {
462+
static LogicalResult applyPatterns(ModuleOp module, unsigned numThreads) {
447463
ConversionTarget target(*module.getContext());
448464
target.addIllegalOp<scf::ReduceOp, scf::ReduceReturnOp, scf::ParallelOp>();
449465
target.addLegalDialect<omp::OpenMPDialect, LLVM::LLVMDialect,
450466
memref::MemRefDialect>();
451467

452468
RewritePatternSet patterns(module.getContext());
453-
patterns.add<ParallelOpLowering>(module.getContext());
469+
patterns.add<ParallelOpLowering>(module.getContext(), numThreads);
454470
FrozenRewritePatternSet frozen(std::move(patterns));
455471
return applyPartialConversion(module, target, frozen);
456472
}
@@ -463,7 +479,7 @@ struct SCFToOpenMPPass
463479

464480
/// Pass entry point.
465481
void runOnOperation() override {
466-
if (failed(applyPatterns(getOperation())))
482+
if (failed(applyPatterns(getOperation(), numThreads)))
467483
signalPassFailure();
468484
}
469485
};

mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1-
// RUN: mlir-opt -convert-scf-to-openmp %s | FileCheck %s
1+
// RUN: mlir-opt -convert-scf-to-openmp='num-threads=4' %s | FileCheck %s
22

33
// CHECK-LABEL: @parallel
44
func.func @parallel(%arg0: index, %arg1: index, %arg2: index,
55
%arg3: index, %arg4: index, %arg5: index) {
6-
// CHECK: omp.parallel {
6+
// CHECK: %[[FOUR:.+]] = llvm.mlir.constant(4 : i32) : i32
7+
// CHECK: omp.parallel num_threads(%[[FOUR]] : i32) {
78
// CHECK: omp.wsloop for (%[[LVAR1:.*]], %[[LVAR2:.*]]) : index = (%arg0, %arg1) to (%arg2, %arg3) step (%arg4, %arg5) {
89
// CHECK: memref.alloca_scope
910
scf.parallel (%i, %j) = (%arg0, %arg1) to (%arg2, %arg3) step (%arg4, %arg5) {
@@ -20,7 +21,8 @@ func.func @parallel(%arg0: index, %arg1: index, %arg2: index,
2021
// CHECK-LABEL: @nested_loops
2122
func.func @nested_loops(%arg0: index, %arg1: index, %arg2: index,
2223
%arg3: index, %arg4: index, %arg5: index) {
23-
// CHECK: omp.parallel {
24+
// CHECK: %[[FOUR:.+]] = llvm.mlir.constant(4 : i32) : i32
25+
// CHECK: omp.parallel num_threads(%[[FOUR]] : i32) {
2426
// CHECK: omp.wsloop for (%[[LVAR_OUT1:.*]]) : index = (%arg0) to (%arg2) step (%arg4) {
2527
// CHECK: memref.alloca_scope
2628
scf.parallel (%i) = (%arg0) to (%arg2) step (%arg4) {
@@ -43,7 +45,8 @@ func.func @nested_loops(%arg0: index, %arg1: index, %arg2: index,
4345
// CHECK-LABEL: @adjacent_loops
4446
func.func @adjacent_loops(%arg0: index, %arg1: index, %arg2: index,
4547
%arg3: index, %arg4: index, %arg5: index) {
46-
// CHECK: omp.parallel {
48+
// CHECK: %[[FOUR:.+]] = llvm.mlir.constant(4 : i32) : i32
49+
// CHECK: omp.parallel num_threads(%[[FOUR]] : i32) {
4750
// CHECK: omp.wsloop for (%[[LVAR_AL1:.*]]) : index = (%arg0) to (%arg2) step (%arg4) {
4851
// CHECK: memref.alloca_scope
4952
scf.parallel (%i) = (%arg0) to (%arg2) step (%arg4) {
@@ -55,7 +58,8 @@ func.func @adjacent_loops(%arg0: index, %arg1: index, %arg2: index,
5558
// CHECK: omp.terminator
5659
// CHECK: }
5760

58-
// CHECK: omp.parallel {
61+
// CHECK: %[[FOUR:.+]] = llvm.mlir.constant(4 : i32) : i32
62+
// CHECK: omp.parallel num_threads(%[[FOUR]] : i32) {
5963
// CHECK: omp.wsloop for (%[[LVAR_AL2:.*]]) : index = (%arg1) to (%arg3) step (%arg5) {
6064
// CHECK: memref.alloca_scope
6165
scf.parallel (%j) = (%arg1) to (%arg3) step (%arg5) {

0 commit comments

Comments
 (0)