@@ -2128,6 +2128,161 @@ genLoopOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
2128
2128
return loopOp;
2129
2129
}
2130
2130
2131
+ static mlir::omp::CanonicalLoopOp
2132
+ genCanonicalLoopOp (lower::AbstractConverter &converter, lower::SymMap &symTable,
2133
+ semantics::SemanticsContext &semaCtx,
2134
+ lower::pft::Evaluation &eval, mlir::Location loc,
2135
+ const ConstructQueue &queue,
2136
+ ConstructQueue::const_iterator item,
2137
+ llvm::ArrayRef<const semantics::Symbol *> ivs,
2138
+ llvm::omp::Directive directive, DataSharingProcessor &dsp) {
2139
+ fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder ();
2140
+
2141
+ assert (ivs.size () == 1 && " Nested loops not yet implemented" );
2142
+ const semantics::Symbol *iv = ivs[0 ];
2143
+
2144
+ auto &nestedEval = eval.getFirstNestedEvaluation ();
2145
+ if (nestedEval.getIf <parser::DoConstruct>()->IsDoConcurrent ()) {
2146
+ TODO (loc, " Do Concurrent in unroll construct" );
2147
+ }
2148
+
2149
+ // Get the loop bounds (and increment)
2150
+ auto &doLoopEval = nestedEval.getFirstNestedEvaluation ();
2151
+ auto *doStmt = doLoopEval.getIf <parser::NonLabelDoStmt>();
2152
+ assert (doStmt && " Expected do loop to be in the nested evaluation" );
2153
+ auto &loopControl = std::get<std::optional<parser::LoopControl>>(doStmt->t );
2154
+ assert (loopControl.has_value ());
2155
+ auto *bounds = std::get_if<parser::LoopControl::Bounds>(&loopControl->u );
2156
+ assert (bounds && " Expected bounds for canonical loop" );
2157
+ lower::StatementContext stmtCtx;
2158
+ mlir::Value loopLBVar = fir::getBase (
2159
+ converter.genExprValue (*semantics::GetExpr (bounds->lower ), stmtCtx));
2160
+ mlir::Value loopUBVar = fir::getBase (
2161
+ converter.genExprValue (*semantics::GetExpr (bounds->upper ), stmtCtx));
2162
+ mlir::Value loopStepVar = [&]() {
2163
+ if (bounds->step ) {
2164
+ return fir::getBase (
2165
+ converter.genExprValue (*semantics::GetExpr (bounds->step ), stmtCtx));
2166
+ } else {
2167
+ // If `step` is not present, assume it is `1`.
2168
+ return firOpBuilder.createIntegerConstant (loc, firOpBuilder.getI32Type (),
2169
+ 1 );
2170
+ }
2171
+ }();
2172
+
2173
+ // Get the integer kind for the loop variable and cast the loop bounds
2174
+ size_t loopVarTypeSize = bounds->name .thing .symbol ->GetUltimate ().size ();
2175
+ mlir::Type loopVarType = getLoopVarType (converter, loopVarTypeSize);
2176
+ loopLBVar = firOpBuilder.createConvert (loc, loopVarType, loopLBVar);
2177
+ loopUBVar = firOpBuilder.createConvert (loc, loopVarType, loopUBVar);
2178
+ loopStepVar = firOpBuilder.createConvert (loc, loopVarType, loopStepVar);
2179
+
2180
+ // Start lowering
2181
+ mlir::Value zero = firOpBuilder.createIntegerConstant (loc, loopVarType, 0 );
2182
+ mlir::Value one = firOpBuilder.createIntegerConstant (loc, loopVarType, 1 );
2183
+ mlir::Value isDownwards = firOpBuilder.create <mlir::arith::CmpIOp>(
2184
+ loc, mlir::arith::CmpIPredicate::slt, loopStepVar, zero);
2185
+
2186
+ // Ensure we are counting upwards. If not, negate step and swap lb and ub.
2187
+ mlir::Value negStep =
2188
+ firOpBuilder.create <mlir::arith::SubIOp>(loc, zero, loopStepVar);
2189
+ mlir::Value incr = firOpBuilder.create <mlir::arith::SelectOp>(
2190
+ loc, isDownwards, negStep, loopStepVar);
2191
+ mlir::Value lb = firOpBuilder.create <mlir::arith::SelectOp>(
2192
+ loc, isDownwards, loopUBVar, loopLBVar);
2193
+ mlir::Value ub = firOpBuilder.create <mlir::arith::SelectOp>(
2194
+ loc, isDownwards, loopLBVar, loopUBVar);
2195
+
2196
+ // Compute the trip count assuming lb <= ub. This guarantees that the result
2197
+ // is non-negative and we can use unsigned arithmetic.
2198
+ mlir::Value span = firOpBuilder.create <mlir::arith::SubIOp>(
2199
+ loc, ub, lb, ::mlir::arith::IntegerOverflowFlags::nuw);
2200
+ mlir::Value tcMinusOne =
2201
+ firOpBuilder.create <mlir::arith::DivUIOp>(loc, span, incr);
2202
+ mlir::Value tcIfLooping = firOpBuilder.create <mlir::arith::AddIOp>(
2203
+ loc, tcMinusOne, one, ::mlir::arith::IntegerOverflowFlags::nuw);
2204
+
2205
+ // Fall back to 0 if lb > ub
2206
+ mlir::Value isZeroTC = firOpBuilder.create <mlir::arith::CmpIOp>(
2207
+ loc, mlir::arith::CmpIPredicate::slt, ub, lb);
2208
+ mlir::Value tripcount = firOpBuilder.create <mlir::arith::SelectOp>(
2209
+ loc, isZeroTC, zero, tcIfLooping);
2210
+
2211
+ // Create the CLI handle.
2212
+ auto newcli = firOpBuilder.create <mlir::omp::NewCliOp>(loc);
2213
+ mlir::Value cli = newcli.getResult ();
2214
+
2215
+ auto ivCallback = [&](mlir::Operation *op)
2216
+ -> llvm::SmallVector<const Fortran::semantics::Symbol *> {
2217
+ mlir::Region ®ion = op->getRegion (0 );
2218
+
2219
+ // Create the op's region skeleton (BB taking the iv as argument)
2220
+ firOpBuilder.createBlock (®ion, {}, {loopVarType}, {loc});
2221
+
2222
+ // Compute the value of the loop variable from the logical iteration number.
2223
+ mlir::Value natIterNum = fir::getBase (region.front ().getArgument (0 ));
2224
+ mlir::Value scaled =
2225
+ firOpBuilder.create <mlir::arith::MulIOp>(loc, natIterNum, loopStepVar);
2226
+ mlir::Value userVal =
2227
+ firOpBuilder.create <mlir::arith::AddIOp>(loc, loopLBVar, scaled);
2228
+
2229
+ // The argument is not currently in memory, so make a temporary for the
2230
+ // argument, and store it there, then bind that location to the argument.
2231
+ mlir::Operation *storeOp =
2232
+ createAndSetPrivatizedLoopVar (converter, loc, userVal, iv);
2233
+
2234
+ firOpBuilder.setInsertionPointAfter (storeOp);
2235
+ return {iv};
2236
+ };
2237
+
2238
+ // Create the omp.canonical_loop operation
2239
+ auto canonLoop = genOpWithBody<mlir::omp::CanonicalLoopOp>(
2240
+ OpWithBodyGenInfo (converter, symTable, semaCtx, loc, nestedEval,
2241
+ directive)
2242
+ .setClauses (&item->clauses )
2243
+ .setDataSharingProcessor (&dsp)
2244
+ .setGenRegionEntryCb (ivCallback),
2245
+ queue, item, tripcount, cli);
2246
+
2247
+ firOpBuilder.setInsertionPointAfter (canonLoop);
2248
+ return canonLoop;
2249
+ }
2250
+
2251
+ static void genUnrollOp (Fortran::lower::AbstractConverter &converter,
2252
+ Fortran::lower::SymMap &symTable,
2253
+ lower::StatementContext &stmtCtx,
2254
+ Fortran::semantics::SemanticsContext &semaCtx,
2255
+ Fortran::lower::pft::Evaluation &eval,
2256
+ mlir::Location loc, const ConstructQueue &queue,
2257
+ ConstructQueue::const_iterator item) {
2258
+ fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder ();
2259
+
2260
+ mlir::omp::LoopRelatedClauseOps loopInfo;
2261
+ llvm::SmallVector<const semantics::Symbol *> iv;
2262
+ collectLoopRelatedInfo (converter, loc, eval, item->clauses , loopInfo, iv);
2263
+
2264
+ // Clauses for unrolling not yet implemnted
2265
+ ClauseProcessor cp (converter, semaCtx, item->clauses );
2266
+ cp.processTODO <clause::Partial, clause::Full>(
2267
+ loc, llvm::omp::Directive::OMPD_unroll);
2268
+
2269
+ // Even though unroll does not support data-sharing clauses, but this is
2270
+ // required to fill the symbol table.
2271
+ DataSharingProcessor dsp (converter, semaCtx, item->clauses , eval,
2272
+ /* shouldCollectPreDeterminedSymbols=*/ true ,
2273
+ /* useDelayedPrivatization=*/ false , symTable);
2274
+ dsp.processStep1 ();
2275
+
2276
+ // Emit the associated loop
2277
+ auto canonLoop =
2278
+ genCanonicalLoopOp (converter, symTable, semaCtx, eval, loc, queue, item,
2279
+ iv, llvm::omp::Directive::OMPD_unroll, dsp);
2280
+
2281
+ // Apply unrolling to it
2282
+ auto cli = canonLoop.getCli ();
2283
+ firOpBuilder.create <mlir::omp::UnrollHeuristicOp>(loc, cli);
2284
+ }
2285
+
2131
2286
static mlir::omp::MaskedOp
2132
2287
genMaskedOp (lower::AbstractConverter &converter, lower::SymMap &symTable,
2133
2288
lower::StatementContext &stmtCtx,
@@ -3516,12 +3671,9 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
3516
3671
newOp = genTeamsOp (converter, symTable, stmtCtx, semaCtx, eval, loc, queue,
3517
3672
item);
3518
3673
break ;
3519
- case llvm::omp::Directive::OMPD_tile:
3520
- case llvm::omp::Directive::OMPD_unroll: {
3521
- unsigned version = semaCtx.langOptions ().OpenMPVersion ;
3522
- TODO (loc, " Unhandled loop directive (" +
3523
- llvm::omp::getOpenMPDirectiveName (dir, version) + " )" );
3524
- }
3674
+ case llvm::omp::Directive::OMPD_unroll:
3675
+ genUnrollOp (converter, symTable, stmtCtx, semaCtx, eval, loc, queue, item);
3676
+ break ;
3525
3677
// case llvm::omp::Directive::OMPD_workdistribute:
3526
3678
case llvm::omp::Directive::OMPD_workshare:
3527
3679
newOp = genWorkshareOp (converter, symTable, stmtCtx, semaCtx, eval, loc,
0 commit comments