Skip to content

[flang][fir] Add MLIR op for do concurrent #130893

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Mar 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 105 additions & 0 deletions flang/include/flang/Optimizer/Dialect/FIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -3446,4 +3446,109 @@ def fir_BoxTotalElementsOp
let hasCanonicalizer = 1;
}

// Wrapper op for `do concurrent` loops. Its single-block region holds the
// allocations for the iteration variables and is terminated by the
// `fir.do_concurrent.loop` op that models the loop itself.
def fir_DoConcurrentOp : fir_Op<"do_concurrent",
    [SingleBlock, AutomaticAllocationScope]> {
  let summary = "do concurrent loop wrapper";

  let description = [{
    A wrapper operation for the actual op modeling `do concurrent` loops:
    `fir.do_concurrent.loop` (see op declaration below for more info about it).

    The `fir.do_concurrent` wrapper op consists of one single-block region with
    the following properties:
    - The first ops in the region are responsible for allocating storage for the
      loop's iteration variables. This property is **not** enforced by the op
      verifier, but is expected to be respected when building the op.
    - The terminator of the region is an instance of `fir.do_concurrent.loop`.

    For example, a 2D loop nest would be represented as follows:
    ```
    fir.do_concurrent {
      %i = fir.alloca i32
      %j = fir.alloca i32
      fir.do_concurrent.loop ...
    }
    ```
  }];

  let regions = (region SizedRegion<1>:$region);

  let assemblyFormat = "$region attr-dict";
  let hasVerifier = 1;
}

// The op modeling the header and body of a `do concurrent` loop nest. It is
// both a terminator (of the enclosing `fir.do_concurrent` wrapper region) and
// the owner of a single-block, terminator-less body region.
def fir_DoConcurrentLoopOp : fir_Op<"do_concurrent.loop",
    [AttrSizedOperandSegments, DeclareOpInterfaceMethods<LoopLikeOpInterface>,
     Terminator, NoTerminator, SingleBlock, ParentOneOf<["DoConcurrentOp"]>]> {
  let summary = "do concurrent loop";

  let description = [{
    An operation that models a Fortran `do concurrent` loop's header and block.
    This is a single-region single-block terminator op that is expected to
    terminate the region of a `fir.do_concurrent` wrapper op.

    This op borrows from both `scf.parallel` and `fir.do_loop` ops. Similar to
    `scf.parallel`, a loop nest takes 3 groups of SSA values as operands that
    represent the lower bounds, upper bounds, and steps. Similar to `fir.do_loop`,
    the op takes one additional group of SSA values to represent reductions.

    The body region **does not** have a terminator.

    For example, a 2D loop nest with 2 reductions (sum and max) would be
    represented as follows:
    ```
    // The wrapper of the loop
    fir.do_concurrent {
      %i = fir.alloca i32
      %j = fir.alloca i32

      // The actual `do concurrent` loop
      fir.do_concurrent.loop
        (%i_iv, %j_iv) = (%i_lb, %j_lb) to (%i_ub, %j_ub) step (%i_st, %j_st)
        reduce(#fir.reduce_attr<add> -> %sum : !fir.ref<i32>,
               #fir.reduce_attr<max> -> %max : !fir.ref<f32>) {

        %0 = fir.convert %i_iv : (index) -> i32
        fir.store %0 to %i : !fir.ref<i32>

        %1 = fir.convert %j_iv : (index) -> i32
        fir.store %1 to %j : !fir.ref<i32>

        // ... loop body goes here ...
      }
    }
    ```

    Description of arguments:
    - `lowerBound`: The group of SSA values for the nest's lower bounds.
    - `upperBound`: The group of SSA values for the nest's upper bounds.
    - `step`: The group of SSA values for the nest's steps.
    - `reduceOperands`: The reduction SSA values, if any.
    - `reduceAttrs`: Attributes to store reduction operations, if any.
    - `loopAnnotation`: Loop metadata to be passed down the compiler pipeline to
      LLVM.
  }];

  let arguments = (ins
    Variadic<Index>:$lowerBound,
    Variadic<Index>:$upperBound,
    Variadic<Index>:$step,
    Variadic<AnyType>:$reduceOperands,
    OptionalAttr<ArrayAttr>:$reduceAttrs,
    OptionalAttr<LoopAnnotationAttr>:$loopAnnotation
  );

  let regions = (region SizedRegion<1>:$region);

  let hasCustomAssemblyFormat = 1;
  let hasVerifier = 1;

  let extraClassDeclaration = [{
    /// Returns the number of reduction operands attached to the loop.
    unsigned getNumReduceOperands() {
      return getReduceOperands().size();
    }
  }];
}

#endif
161 changes: 161 additions & 0 deletions flang/lib/Optimizer/Dialect/FIROps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4748,6 +4748,167 @@ void fir::BoxTotalElementsOp::getCanonicalizationPatterns(
patterns.add<SimplifyBoxTotalElementsOp>(context);
}

//===----------------------------------------------------------------------===//
// DoConcurrentOp
//===----------------------------------------------------------------------===//

/// Verifies the `fir.do_concurrent` wrapper: its block must be non-empty and
/// its last operation must be a `fir.do_concurrent.loop`.
llvm::LogicalResult fir::DoConcurrentOp::verify() {
  mlir::Block *wrapperBlock = getBody();

  if (wrapperBlock->empty())
    return emitOpError("body cannot be empty");

  // The terminator is only queried when the block might actually have one.
  const bool endsWithLoop =
      wrapperBlock->mightHaveTerminator() &&
      mlir::isa<fir::DoConcurrentLoopOp>(wrapperBlock->getTerminator());
  if (!endsWithLoop)
    return emitOpError("must be terminated by 'fir.do_concurrent.loop'");

  return mlir::success();
}

//===----------------------------------------------------------------------===//
// DoConcurrentLoopOp
//===----------------------------------------------------------------------===//

/// Parses the custom assembly form of `fir.do_concurrent.loop`:
///
///   `(` ivs `)` `=` `(` lower-bounds `)` `to` `(` upper-bounds `)`
///   `step` `(` steps `)`
///   [`reduce` `(` attr `->` operand `:` type {`,` ...} `)`]
///   region [attr-dict]
///
/// Operands are appended to `result.operands` in the order lower bounds,
/// upper bounds, steps, reduction operands; the operand segment sizes
/// attribute added near the end lists the group sizes in that same order.
mlir::ParseResult fir::DoConcurrentLoopOp::parse(mlir::OpAsmParser &parser,
mlir::OperationState &result) {
auto &builder = parser.getBuilder();
// Parse an opening `(` followed by induction variables followed by `)`
llvm::SmallVector<mlir::OpAsmParser::Argument, 4> ivs;
if (parser.parseArgumentList(ivs, mlir::OpAsmParser::Delimiter::Paren))
return mlir::failure();

// Parse the lower bounds: `=` followed by a parenthesized operand list.
// `ivs.size()` is passed as the required operand count, and every bound is
// resolved as an `index` value.
llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4> lower;
if (parser.parseEqual() ||
parser.parseOperandList(lower, ivs.size(),
mlir::OpAsmParser::Delimiter::Paren) ||
parser.resolveOperands(lower, builder.getIndexType(), result.operands))
return mlir::failure();

// Parse the upper bounds, introduced by the `to` keyword.
llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4> upper;
if (parser.parseKeyword("to") ||
parser.parseOperandList(upper, ivs.size(),
mlir::OpAsmParser::Delimiter::Paren) ||
parser.resolveOperands(upper, builder.getIndexType(), result.operands))
return mlir::failure();

// Parse step values, introduced by the `step` keyword.
llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4> steps;
if (parser.parseKeyword("step") ||
parser.parseOperandList(steps, ivs.size(),
mlir::OpAsmParser::Delimiter::Paren) ||
parser.resolveOperands(steps, builder.getIndexType(), result.operands))
return mlir::failure();

// The `reduce(...)` clause is optional. When it is absent these vectors stay
// empty, no reduction attribute is added, and the reduction operand segment
// size is recorded as 0.
llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> reduceOperands;
llvm::SmallVector<mlir::Type> reduceArgTypes;
if (succeeded(parser.parseOptionalKeyword("reduce"))) {
// Parse reduction attributes and variables.
// Each list element has the form `attr -> operand : type`.
llvm::SmallVector<fir::ReduceAttr> attributes;
if (failed(parser.parseCommaSeparatedList(
mlir::AsmParser::Delimiter::Paren, [&]() {
if (parser.parseAttribute(attributes.emplace_back()) ||
parser.parseArrow() ||
parser.parseOperand(reduceOperands.emplace_back()) ||
parser.parseColonType(reduceArgTypes.emplace_back()))
return mlir::failure();
return mlir::success();
})))
return mlir::failure();
// Resolve input operands.
// Each reduction operand is resolved against the type written after its `:`.
for (auto operand_type : llvm::zip(reduceOperands, reduceArgTypes))
if (parser.resolveOperand(std::get<0>(operand_type),
std::get<1>(operand_type), result.operands))
return mlir::failure();
// Store the parsed reduction attributes on the op as a single array
// attribute.
llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
attributes.end());
result.addAttribute(getReduceAttrsAttrName(result.name),
builder.getArrayAttr(arrayAttr));
}

// Now parse the body.
// The induction variables become the region's entry block arguments; they
// are given `index` type here before the region is parsed.
mlir::Region *body = result.addRegion();
for (auto &iv : ivs)
iv.type = builder.getIndexType();
if (parser.parseRegion(*body, ivs))
return mlir::failure();

// Set `operandSegmentSizes` attribute.
// Group order: lower bounds, upper bounds, steps, reduction operands.
result.addAttribute(DoConcurrentLoopOp::getOperandSegmentSizeAttr(),
builder.getDenseI32ArrayAttr(
{static_cast<int32_t>(lower.size()),
static_cast<int32_t>(upper.size()),
static_cast<int32_t>(steps.size()),
static_cast<int32_t>(reduceOperands.size())}));

// Parse attributes.
// The optional attribute dictionary comes after the region.
if (parser.parseOptionalAttrDict(result.attributes))
return mlir::failure();

return mlir::success();
}

/// Prints the custom assembly form of `fir.do_concurrent.loop`: the loop
/// header, the optional `reduce(...)` clause, the body region (without its
/// entry block arguments), and finally any remaining attributes.
void fir::DoConcurrentLoopOp::print(mlir::OpAsmPrinter &p) {
  // Loop header: induction variables, then bounds and steps.
  p << " (" << getBody()->getArguments() << ") = (";
  p << getLowerBound() << ") to (";
  p << getUpperBound() << ") step (";
  p << getStep() << ")";

  // Reduction clause, emitted only when the loop has reduction operands.
  auto reductionVars = getReduceOperands();
  if (!reductionVars.empty()) {
    auto reductionKinds = getReduceAttrsAttr();
    p << " reduce(";
    llvm::interleaveComma(
        llvm::zip(reductionKinds, reductionVars), p, [&](auto entry) {
          const auto &kind = std::get<0>(entry);
          const auto &var = std::get<1>(entry);
          p << kind << " -> " << var << " : " << var.getType();
        });
    p << ')';
  }

  // The induction variables were already printed in the header, so the
  // region's entry block arguments are not printed again.
  p << ' ';
  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);

  // The segment sizes and the reduction attributes are elided from the
  // trailing attribute dictionary.
  p.printOptionalAttrDict(
      (*this)->getAttrs(),
      /*elidedAttrs=*/{DoConcurrentLoopOp::getOperandSegmentSizeAttr(),
                       DoConcurrentLoopOp::getReduceAttrsAttrName()});
}

/// Returns the op's region as the only loop region.
llvm::SmallVector<mlir::Region *> fir::DoConcurrentLoopOp::getLoopRegions() {
  llvm::SmallVector<mlir::Region *> loopRegions;
  loopRegions.push_back(&getRegion());
  return loopRegions;
}

/// Verifies the structural invariants of `fir.do_concurrent.loop`:
///  - at least one lower bound is present;
///  - lower bounds, upper bounds and steps have the same count;
///  - the body block has one `index` argument per step value;
///  - the number of reduction operands matches the number of reduction
///    attributes (an absent attribute counts as zero).
llvm::LogicalResult fir::DoConcurrentLoopOp::verify() {
  mlir::Operation::operand_range lowerBounds = getLowerBound();
  mlir::Operation::operand_range upperBounds = getUpperBound();
  mlir::Operation::operand_range stepSizes = getStep();

  if (lowerBounds.empty())
    return emitOpError(
        "needs at least one tuple element for lowerBound, upperBound and step");

  const bool sameBoundCount = lowerBounds.size() == upperBounds.size() &&
                              upperBounds.size() == stepSizes.size();
  if (!sameBoundCount)
    return emitOpError("different number of tuple elements for lowerBound, "
                       "upperBound or step");

  // The body must define exactly one block argument (induction variable) per
  // step value.
  mlir::Block *loopBody = getBody();
  if (loopBody->getNumArguments() != stepSizes.size())
    return emitOpError() << "expects the same number of induction variables: "
                         << loopBody->getNumArguments()
                         << " as bound and step values: " << stepSizes.size();

  // Every induction variable must be of index type.
  const bool hasNonIndexArg =
      llvm::any_of(loopBody->getArguments(), [](mlir::BlockArgument arg) {
        return !arg.getType().isIndex();
      });
  if (hasNonIndexArg)
    return emitOpError(
        "expects arguments for the induction variable to be of index type");

  // A missing reduction attribute array is treated as having zero entries.
  auto reductionAttrs = getReduceAttrsAttr();
  const auto numReductionAttrs = reductionAttrs ? reductionAttrs.size() : 0;
  if (getNumReduceOperands() != numReductionAttrs)
    return emitOpError(
        "mismatch in number of reduction variables and reduction attributes");

  return mlir::success();
}

//===----------------------------------------------------------------------===//
// FIROpsDialect
//===----------------------------------------------------------------------===//
Expand Down
92 changes: 92 additions & 0 deletions flang/test/Fir/do_concurrent.fir
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
// Test fir.do_concurrent operation parse, verify (no errors), and unparse

// RUN: fir-opt %s | fir-opt | FileCheck %s

// One-dimensional case: a single induction variable and no reduction clause.
// The loop body converts the index-typed induction variable to i32 and stores
// it into the iteration variable allocated in the wrapper.
func.func @dc_1d(%i_lb: index, %i_ub: index, %i_st: index) {
fir.do_concurrent {
%i = fir.alloca i32
fir.do_concurrent.loop (%i_iv) = (%i_lb) to (%i_ub) step (%i_st) {
%0 = fir.convert %i_iv : (index) -> i32
fir.store %0 to %i : !fir.ref<i32>
}
}
return
}

// CHECK-LABEL: func.func @dc_1d
// CHECK-SAME: (%[[I_LB:.*]]: index, %[[I_UB:.*]]: index, %[[I_ST:.*]]: index)
// CHECK: fir.do_concurrent {
// CHECK: %[[I:.*]] = fir.alloca i32
// CHECK: fir.do_concurrent.loop (%[[I_IV:.*]]) = (%[[I_LB]]) to (%[[I_UB]]) step (%[[I_ST]]) {
// CHECK: %[[I_IV_CVT:.*]] = fir.convert %[[I_IV]] : (index) -> i32
// CHECK: fir.store %[[I_IV_CVT]] to %[[I]] : !fir.ref<i32>
// CHECK: }
// CHECK: }

// Two-dimensional case: two induction variables with matching groups of lower
// bounds, upper bounds and steps, and no reduction clause.
func.func @dc_2d(%i_lb: index, %i_ub: index, %i_st: index,
%j_lb: index, %j_ub: index, %j_st: index) {
fir.do_concurrent {
%i = fir.alloca i32
%j = fir.alloca i32
fir.do_concurrent.loop
(%i_iv, %j_iv) = (%i_lb, %j_lb) to (%i_ub, %j_ub) step (%i_st, %j_st) {
%0 = fir.convert %i_iv : (index) -> i32
fir.store %0 to %i : !fir.ref<i32>

%1 = fir.convert %j_iv : (index) -> i32
fir.store %1 to %j : !fir.ref<i32>
}
}
return
}

// CHECK-LABEL: func.func @dc_2d
// CHECK-SAME: (%[[I_LB:.*]]: index, %[[I_UB:.*]]: index, %[[I_ST:.*]]: index, %[[J_LB:.*]]: index, %[[J_UB:.*]]: index, %[[J_ST:.*]]: index)
// CHECK: fir.do_concurrent {
// CHECK: %[[I:.*]] = fir.alloca i32
// CHECK: %[[J:.*]] = fir.alloca i32
// CHECK: fir.do_concurrent.loop
// CHECK-SAME: (%[[I_IV:.*]], %[[J_IV:.*]]) = (%[[I_LB]], %[[J_LB]]) to (%[[I_UB]], %[[J_UB]]) step (%[[I_ST]], %[[J_ST]]) {
// CHECK: %[[I_IV_CVT:.*]] = fir.convert %[[I_IV]] : (index) -> i32
// CHECK: fir.store %[[I_IV_CVT]] to %[[I]] : !fir.ref<i32>
// CHECK: %[[J_IV_CVT:.*]] = fir.convert %[[J_IV]] : (index) -> i32
// CHECK: fir.store %[[J_IV_CVT]] to %[[J]] : !fir.ref<i32>
// CHECK: }
// CHECK: }

// Two-dimensional case with a reduction clause: %sum is allocated outside the
// wrapper and attached to the loop with the `add` reduction attribute.
func.func @dc_2d_reduction(%i_lb: index, %i_ub: index, %i_st: index,
%j_lb: index, %j_ub: index, %j_st: index) {
%sum = fir.alloca i32

fir.do_concurrent {
%i = fir.alloca i32
%j = fir.alloca i32
fir.do_concurrent.loop
(%i_iv, %j_iv) = (%i_lb, %j_lb) to (%i_ub, %j_ub) step (%i_st, %j_st)
reduce(#fir.reduce_attr<add> -> %sum : !fir.ref<i32>) {
%0 = fir.convert %i_iv : (index) -> i32
fir.store %0 to %i : !fir.ref<i32>

%1 = fir.convert %j_iv : (index) -> i32
fir.store %1 to %j : !fir.ref<i32>
}
}
return
}

// CHECK-LABEL: func.func @dc_2d_reduction
// CHECK-SAME: (%[[I_LB:.*]]: index, %[[I_UB:.*]]: index, %[[I_ST:.*]]: index, %[[J_LB:.*]]: index, %[[J_UB:.*]]: index, %[[J_ST:.*]]: index)

// CHECK: %[[SUM:.*]] = fir.alloca i32

// CHECK: fir.do_concurrent {
// CHECK: %[[I:.*]] = fir.alloca i32
// CHECK: %[[J:.*]] = fir.alloca i32
// CHECK: fir.do_concurrent.loop
// CHECK-SAME: (%[[I_IV:.*]], %[[J_IV:.*]]) = (%[[I_LB]], %[[J_LB]]) to (%[[I_UB]], %[[J_UB]]) step (%[[I_ST]], %[[J_ST]]) reduce(#fir.reduce_attr<add> -> %[[SUM]] : !fir.ref<i32>) {
// CHECK: %[[I_IV_CVT:.*]] = fir.convert %[[I_IV]] : (index) -> i32
// CHECK: fir.store %[[I_IV_CVT]] to %[[I]] : !fir.ref<i32>
// CHECK: %[[J_IV_CVT:.*]] = fir.convert %[[J_IV]] : (index) -> i32
// CHECK: fir.store %[[J_IV_CVT]] to %[[J]] : !fir.ref<i32>
// CHECK: }
// CHECK: }
Loading
Loading