@@ -339,9 +339,12 @@ static omp::ReductionDeclareOp declareReduction(PatternRewriter &builder,
339
339
namespace {
340
340
341
341
struct ParallelOpLowering : public OpRewritePattern <scf::ParallelOp> {
342
+ static constexpr unsigned kUseOpenMPDefaultNumThreads = 0 ;
343
+ unsigned numThreads;
342
344
343
- ParallelOpLowering (MLIRContext *context)
344
- : OpRewritePattern<scf::ParallelOp>(context) {}
345
+ ParallelOpLowering (MLIRContext *context,
346
+ unsigned numThreads = kUseOpenMPDefaultNumThreads )
347
+ : OpRewritePattern<scf::ParallelOp>(context), numThreads(numThreads) {}
345
348
346
349
LogicalResult matchAndRewrite (scf::ParallelOp parallelOp,
347
350
PatternRewriter &rewriter) const override {
@@ -388,8 +391,21 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
388
391
reduceOp, reduceOp.getOperand (), std::get<1 >(pair));
389
392
}
390
393
394
+ Value numThreadsVar;
395
+ if (numThreads > 0 ) {
396
+ numThreadsVar = rewriter.create <LLVM::ConstantOp>(
397
+ loc, rewriter.getI32IntegerAttr (numThreads));
398
+ }
391
399
// Create the parallel wrapper.
392
- auto ompParallel = rewriter.create <omp::ParallelOp>(loc);
400
+ auto ompParallel = rewriter.create <omp::ParallelOp>(
401
+ loc,
402
+ /* if_expr_var = */ Value{},
403
+ /* num_threads_var = */ numThreadsVar,
404
+ /* allocate_vars = */ llvm::SmallVector<Value>{},
405
+ /* allocators_vars = */ llvm::SmallVector<Value>{},
406
+ /* reduction_vars = */ llvm::SmallVector<Value>{},
407
+ /* reductions = */ ArrayAttr{},
408
+ /* proc_bind_val = */ omp::ClauseProcBindKindAttr{});
393
409
{
394
410
395
411
OpBuilder::InsertionGuard guard (rewriter);
@@ -443,14 +459,14 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
443
459
};
444
460
445
461
// / Applies the conversion patterns in the given function.
446
- static LogicalResult applyPatterns (ModuleOp module) {
462
+ static LogicalResult applyPatterns (ModuleOp module, unsigned numThreads ) {
447
463
ConversionTarget target (*module.getContext ());
448
464
target.addIllegalOp <scf::ReduceOp, scf::ReduceReturnOp, scf::ParallelOp>();
449
465
target.addLegalDialect <omp::OpenMPDialect, LLVM::LLVMDialect,
450
466
memref::MemRefDialect>();
451
467
452
468
RewritePatternSet patterns (module.getContext ());
453
- patterns.add <ParallelOpLowering>(module.getContext ());
469
+ patterns.add <ParallelOpLowering>(module.getContext (), numThreads );
454
470
FrozenRewritePatternSet frozen (std::move (patterns));
455
471
return applyPartialConversion (module, target, frozen);
456
472
}
@@ -463,7 +479,7 @@ struct SCFToOpenMPPass
463
479
464
480
// / Pass entry point.
465
481
void runOnOperation () override {
466
- if (failed (applyPatterns (getOperation ())))
482
+ if (failed (applyPatterns (getOperation (), numThreads )))
467
483
signalPassFailure ();
468
484
}
469
485
};
0 commit comments