Skip to content

[Clang][OpenMP][LoopTransformations] Add support for "#pragma omp fuse" loop transformation directive and "looprange" clause #139293

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions clang/docs/OpenMPSupport.rst
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,8 @@ implementation.
+-------------------------------------------------------------+---------------------------+---------------------------+--------------------------------------------------------------------------+
| loop stripe transformation | :good:`done` | https://github.com/llvm/llvm-project/pull/119891 |
+-------------------------------------------------------------+---------------------------+---------------------------+--------------------------------------------------------------------------+
| loop fuse transformation | :good:`prototyped` | :none:`unclaimed` | |
+-------------------------------------------------------------+---------------------------+---------------------------+--------------------------------------------------------------------------+
| work distribute construct | :none:`unclaimed` | :none:`unclaimed` | |
+-------------------------------------------------------------+---------------------------+---------------------------+--------------------------------------------------------------------------+
| task_iteration | :none:`unclaimed` | :none:`unclaimed` | |
Expand Down
1 change: 1 addition & 0 deletions clang/docs/ReleaseNotes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1128,6 +1128,7 @@ OpenMP Support
- Fixed mapping of arrays of structs containing nested structs with user defined
mappers, by using compiler-generated default mappers for the outer structs for
such maps.
- Added support for 'omp fuse' directive.

Improvements
^^^^^^^^^^^^
Expand Down
4 changes: 4 additions & 0 deletions clang/include/clang-c/Index.h
Original file line number Diff line number Diff line change
Expand Up @@ -2162,6 +2162,10 @@ enum CXCursorKind {
*/
CXCursor_OMPStripeDirective = 310,

/** OpenMP fuse directive
*/
CXCursor_OMPFuseDirective = 318,

/** OpenACC Compute Construct.
*/
CXCursor_OpenACCComputeConstruct = 320,
Expand Down
95 changes: 95 additions & 0 deletions clang/include/clang/AST/OpenMPClause.h
Original file line number Diff line number Diff line change
Expand Up @@ -1143,6 +1143,101 @@ class OMPFullClause final : public OMPNoChildClause<llvm::omp::OMPC_full> {
static OMPFullClause *CreateEmpty(const ASTContext &C);
};

/// This class represents the 'looprange' clause in the
/// '#pragma omp fuse' directive
///
/// \code {c}
/// #pragma omp fuse looprange(1,2)
/// {
/// for(int i = 0; i < 64; ++i)
/// for(int j = 0; j < 256; j+=2)
/// for(int k = 127; k >= 0; --k)
/// \endcode
class OMPLoopRangeClause final
: public OMPClause,
private llvm::TrailingObjects<OMPLoopRangeClause, Expr *> {
friend class OMPClauseReader;
friend class llvm::TrailingObjects<OMPLoopRangeClause, Expr *>;

/// Location of '('
SourceLocation LParenLoc;

/// Location of first and count expressions
SourceLocation FirstLoc, CountLoc;

/// Number of looprange arguments (always 2: first, count)
unsigned NumArgs = 2;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it always constant? In this case you don't need to use tail allocation, use fixed array


/// Set the argument expressions.
void setArgs(ArrayRef<Expr *> Args) {
assert(Args.size() == NumArgs && "Expected exactly 2 looprange arguments");
std::copy(Args.begin(), Args.end(), getTrailingObjects<Expr *>());
}

/// Build an empty clause for deserialization.
explicit OMPLoopRangeClause()
: OMPClause(llvm::omp::OMPC_looprange, {}, {}), NumArgs(2) {}

public:
/// Build a 'looprange' clause AST node.
static OMPLoopRangeClause *
Create(const ASTContext &C, SourceLocation StartLoc, SourceLocation LParenLoc,
SourceLocation FirstLoc, SourceLocation CountLoc,
SourceLocation EndLoc, ArrayRef<Expr *> Args);

/// Build an empty 'looprange' clause node.
static OMPLoopRangeClause *CreateEmpty(const ASTContext &C);

// Location getters/setters
SourceLocation getLParenLoc() const { return LParenLoc; }
SourceLocation getFirstLoc() const { return FirstLoc; }
SourceLocation getCountLoc() const { return CountLoc; }

void setLParenLoc(SourceLocation Loc) { LParenLoc = Loc; }
void setFirstLoc(SourceLocation Loc) { FirstLoc = Loc; }
void setCountLoc(SourceLocation Loc) { CountLoc = Loc; }

/// Get looprange 'first' expression
Expr *getFirst() const { return getArgs()[0]; }

/// Get looprange 'count' expression
Expr *getCount() const { return getArgs()[1]; }

/// Set looprange 'first' expression
void setFirst(Expr *E) { getArgs()[0] = E; }

/// Set looprange 'count' expression
void setCount(Expr *E) { getArgs()[1] = E; }

Comment on lines +1206 to +1211
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These functions should be private

MutableArrayRef<Expr *> getArgs() {
return MutableArrayRef<Expr *>(getTrailingObjects<Expr *>(), NumArgs);
}
ArrayRef<Expr *> getArgs() const {
return ArrayRef<Expr *>(getTrailingObjects<Expr *>(), NumArgs);
}
Comment on lines +1212 to +1217
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doubt you need these functions


child_range children() {
return child_range(reinterpret_cast<Stmt **>(getArgs().begin()),
reinterpret_cast<Stmt **>(getArgs().end()));
}
const_child_range children() const {
auto AR = getArgs();
return const_child_range(reinterpret_cast<Stmt *const *>(AR.begin()),
reinterpret_cast<Stmt *const *>(AR.end()));
}

child_range used_children() {
return child_range(child_iterator(), child_iterator());
}
const_child_range used_children() const {
return const_child_range(const_child_iterator(), const_child_iterator());
}

static bool classof(const OMPClause *T) {
return T->getClauseKind() == llvm::omp::OMPC_looprange;
}
};

/// Representation of the 'partial' clause of the '#pragma omp unroll'
/// directive.
///
Expand Down
11 changes: 11 additions & 0 deletions clang/include/clang/AST/RecursiveASTVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -3090,6 +3090,9 @@ DEF_TRAVERSE_STMT(OMPUnrollDirective,
DEF_TRAVERSE_STMT(OMPReverseDirective,
{ TRY_TO(TraverseOMPExecutableDirective(S)); })

DEF_TRAVERSE_STMT(OMPFuseDirective,
{ TRY_TO(TraverseOMPExecutableDirective(S)); })

DEF_TRAVERSE_STMT(OMPInterchangeDirective,
{ TRY_TO(TraverseOMPExecutableDirective(S)); })

Expand Down Expand Up @@ -3407,6 +3410,14 @@ bool RecursiveASTVisitor<Derived>::VisitOMPFullClause(OMPFullClause *C) {
return true;
}

template <typename Derived>
bool RecursiveASTVisitor<Derived>::VisitOMPLoopRangeClause(
OMPLoopRangeClause *C) {
TRY_TO(TraverseStmt(C->getFirst()));
TRY_TO(TraverseStmt(C->getCount()));
return true;
}

template <typename Derived>
bool RecursiveASTVisitor<Derived>::VisitOMPPartialClause(OMPPartialClause *C) {
TRY_TO(TraverseStmt(C->getFactor()));
Expand Down
110 changes: 108 additions & 2 deletions clang/include/clang/AST/StmtOpenMP.h
Original file line number Diff line number Diff line change
Expand Up @@ -962,6 +962,9 @@ class OMPLoopTransformationDirective : public OMPLoopBasedDirective {

/// Number of loops generated by this loop transformation.
unsigned NumGeneratedLoops = 0;
/// Number of top level canonical loop nests generated by this loop
/// transformation
unsigned NumGeneratedLoopNests = 0;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do you need this new field?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe the name is a bit unfortunate and could be improved, but they are 2 completely different fields conceptually. This top level loops are the ones actually managed by loop Sequence constructs like fuse and the upcoming split. A loop sequence contains loops which may contain several inner nestes loops, but these should not be taken into account for performing fusion or splitting. This was not taken into account originally due to all transformations having a fixed number of generated top level nests (1). However fuse or split may generate several loop nests with inner nested loops.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that unroll is an exception, it could have 0 or 1 but it coincides perfectly with the original number of loops .

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The question is how it is used. I did not see it is being read anywhere

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This distinction is indeed important and actively used in SemaOpenMP.cpp file, particularly within the AnalyzeLoopSequence function (starting at line 14284). For example, it's referenced in lines 14344 and 14364 to differentiate between specific loop transformations.


protected:
explicit OMPLoopTransformationDirective(StmtClass SC,
Expand All @@ -974,13 +977,21 @@ class OMPLoopTransformationDirective : public OMPLoopBasedDirective {
/// Set the number of loops generated by this loop transformation.
void setNumGeneratedLoops(unsigned Num) { NumGeneratedLoops = Num; }

/// Set the number of top level canonical loop nests generated by this loop
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nit] empty line for separating functions

/// transformation
void setNumGeneratedLoopNests(unsigned Num) { NumGeneratedLoopNests = Num; }

public:
/// Return the number of associated (consumed) loops.
unsigned getNumAssociatedLoops() const { return getLoopsNumber(); }

/// Return the number of loops generated by this loop transformation.
unsigned getNumGeneratedLoops() const { return NumGeneratedLoops; }

/// Return the number of top level canonical loop nests generated by this loop
/// transformation
unsigned getNumGeneratedLoopNests() const { return NumGeneratedLoopNests; }

/// Get the de-sugared statements after the loop transformation.
///
/// Might be nullptr if either the directive generates no loops and is handled
Expand All @@ -995,7 +1006,7 @@ class OMPLoopTransformationDirective : public OMPLoopBasedDirective {
Stmt::StmtClass C = T->getStmtClass();
return C == OMPTileDirectiveClass || C == OMPUnrollDirectiveClass ||
C == OMPReverseDirectiveClass || C == OMPInterchangeDirectiveClass ||
C == OMPStripeDirectiveClass;
C == OMPStripeDirectiveClass || C == OMPFuseDirectiveClass;
}
};

Expand Down Expand Up @@ -5561,7 +5572,10 @@ class OMPTileDirective final : public OMPLoopTransformationDirective {
: OMPLoopTransformationDirective(OMPTileDirectiveClass,
llvm::omp::OMPD_tile, StartLoc, EndLoc,
NumLoops) {
// Tiling doubles the original number of loops
setNumGeneratedLoops(2 * NumLoops);
// Produces a single top-level canonical loop nest
setNumGeneratedLoopNests(1);
}

void setPreInits(Stmt *PreInits) {
Expand Down Expand Up @@ -5639,6 +5653,8 @@ class OMPStripeDirective final : public OMPLoopTransformationDirective {
llvm::omp::OMPD_stripe, StartLoc, EndLoc,
NumLoops) {
setNumGeneratedLoops(2 * NumLoops);
// Similar to Tile, it only generates a single top level loop nest
setNumGeneratedLoopNests(1);
}

void setPreInits(Stmt *PreInits) {
Expand Down Expand Up @@ -5792,7 +5808,9 @@ class OMPReverseDirective final : public OMPLoopTransformationDirective {
: OMPLoopTransformationDirective(OMPReverseDirectiveClass,
llvm::omp::OMPD_reverse, StartLoc,
EndLoc, NumLoops) {
// Reverse produces a single top-level canonical loop nest
setNumGeneratedLoops(NumLoops);
setNumGeneratedLoopNests(1);
}

void setPreInits(Stmt *PreInits) {
Expand Down Expand Up @@ -5864,7 +5882,10 @@ class OMPInterchangeDirective final : public OMPLoopTransformationDirective {
: OMPLoopTransformationDirective(OMPInterchangeDirectiveClass,
llvm::omp::OMPD_interchange, StartLoc,
EndLoc, NumLoops) {
setNumGeneratedLoops(NumLoops);
// Interchange produces a single top-level canonical loop
// nest, with the exact same amount of total loops
setNumGeneratedLoops(3 * NumLoops);
setNumGeneratedLoopNests(1);
}

void setPreInits(Stmt *PreInits) {
Expand Down Expand Up @@ -5915,6 +5936,91 @@ class OMPInterchangeDirective final : public OMPLoopTransformationDirective {
}
};

/// Represents the '#pragma omp fuse' loop transformation directive
///
/// \code{c}
/// #pragma omp fuse
/// {
/// for(int i = 0; i < m1; ++i) {...}
/// for(int j = 0; j < m2; ++j) {...}
/// ...
/// }
/// \endcode

class OMPFuseDirective final : public OMPLoopTransformationDirective {
friend class ASTStmtReader;
friend class OMPExecutableDirective;

// Offsets of child members.
enum {
PreInitsOffset = 0,
TransformedStmtOffset,
};

explicit OMPFuseDirective(SourceLocation StartLoc, SourceLocation EndLoc,
unsigned NumLoops)
: OMPLoopTransformationDirective(OMPFuseDirectiveClass,
llvm::omp::OMPD_fuse, StartLoc, EndLoc,
NumLoops) {
// This default initialization assumes simple loop fusion.
// If a 'looprange' clause is specified, these values must be explicitly set
setNumGeneratedLoopNests(1);
setNumGeneratedLoops(NumLoops);
}

void setPreInits(Stmt *PreInits) {
Data->getChildren()[PreInitsOffset] = PreInits;
}

void setTransformedStmt(Stmt *S) {
Data->getChildren()[TransformedStmtOffset] = S;
}

public:
/// Create a new AST node representation for #pragma omp fuse'
///
/// \param C Context of the AST
/// \param StartLoc Location of the introducer (e.g the 'omp' token)
/// \param EndLoc Location of the directive's end (e.g the tok::eod)
/// \param Clauses The directive's clauses
/// \param NumLoops Number of total affected loops
/// \param NumLoopNests Number of affected top level canonical loops
/// (number of items in the 'looprange' clause if present)
/// \param AssociatedStmt The outermost associated loop
/// \param TransformedStmt The loop nest after fusion, or nullptr in
/// dependent
/// \param PreInits Helper preinits statements for the loop nest
static OMPFuseDirective *Create(const ASTContext &C, SourceLocation StartLoc,
SourceLocation EndLoc,
ArrayRef<OMPClause *> Clauses,
unsigned NumLoops, unsigned NumLoopNests,
Stmt *AssociatedStmt, Stmt *TransformedStmt,
Stmt *PreInits);

/// Build an empty '#pragma omp fuse' AST node for deserialization
///
/// \param C Context of the AST
/// \param NumClauses Number of clauses to allocate
/// \param NumLoops Number of associated loops to allocate
/// \param NumLoopNests Number of top level loops to allocate
static OMPFuseDirective *CreateEmpty(const ASTContext &C, unsigned NumClauses,
unsigned NumLoops,
unsigned NumLoopNests);

/// Gets the associated loops after the transformation. This is the de-sugared
/// replacement or nulltpr in dependent contexts.
Stmt *getTransformedStmt() const {
return Data->getChildren()[TransformedStmtOffset];
}

/// Return preinits statement.
Stmt *getPreInits() const { return Data->getChildren()[PreInitsOffset]; }

static bool classof(const Stmt *T) {
return T->getStmtClass() == OMPFuseDirectiveClass;
}
};

/// This represents '#pragma omp scan' directive.
///
/// \code
Expand Down
12 changes: 12 additions & 0 deletions clang/include/clang/Basic/DiagnosticSemaKinds.td
Original file line number Diff line number Diff line change
Expand Up @@ -11612,6 +11612,18 @@ def note_omp_implicit_dsa : Note<
"implicitly determined as %0">;
def err_omp_loop_var_dsa : Error<
"loop iteration variable in the associated loop of 'omp %1' directive may not be %0, predetermined as %2">;
def err_omp_not_canonical_loop : Error<
"loop after '#pragma omp %0' is not in canonical form">;
def err_omp_not_a_loop_sequence : Error<
"statement after '#pragma omp %0' must be a loop sequence containing canonical loops or loop-generating constructs">;
def err_omp_empty_loop_sequence : Error<
"loop sequence after '#pragma omp %0' must contain at least 1 canonical loop or loop-generating construct">;
def err_omp_invalid_looprange : Error<
"loop range in '#pragma omp %0' exceeds the number of available loops: "
"range end '%1' is greater than the total number of loops '%2'">;
def warn_omp_redundant_fusion : Warning<
"loop range in '#pragma omp %0' contains only a single loop, resulting in redundant fusion">,
InGroup<OpenMPClauses>;
def err_omp_not_for : Error<
"%select{statement after '#pragma omp %1' must be a for loop|"
"expected %2 for loops after '#pragma omp %1'%select{|, but found only %4}3}0">;
Expand Down
1 change: 1 addition & 0 deletions clang/include/clang/Basic/StmtNodes.td
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ def OMPStripeDirective : StmtNode<OMPLoopTransformationDirective>;
def OMPUnrollDirective : StmtNode<OMPLoopTransformationDirective>;
def OMPReverseDirective : StmtNode<OMPLoopTransformationDirective>;
def OMPInterchangeDirective : StmtNode<OMPLoopTransformationDirective>;
def OMPFuseDirective : StmtNode<OMPLoopTransformationDirective>;
def OMPForDirective : StmtNode<OMPLoopDirective>;
def OMPForSimdDirective : StmtNode<OMPLoopDirective>;
def OMPSectionsDirective : StmtNode<OMPExecutableDirective>;
Expand Down
3 changes: 3 additions & 0 deletions clang/include/clang/Parse/Parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -6735,6 +6735,9 @@ class Parser : public CodeCompletionHandler {
OpenMPClauseKind Kind,
bool ParseOnly);

/// Parses the 'looprange' clause of a '#pragma omp fuse' directive.
OMPClause *ParseOpenMPLoopRangeClause();

/// Parses the 'sizes' clause of a '#pragma omp tile' directive.
OMPClause *ParseOpenMPSizesClause();

Expand Down
Loading
Loading