Skip to content

Commit a15b685

Browse files
authored
[OpenACC] Implement 'reduction' sema for compute constructs (llvm#92808)
'reduction' has a few restrictions over normal 'var-list' clauses: 1- On parallel, a num_gangs can only have 1 argument when combined with reduction. These two aren't able to be combined on any other of the compute constructs however. 2- The vars all must be 'numerical data types' types of some sort, or a 'composite of numerical data types'. A list of types is given in the standard as a minimum, so we choose 'isScalar', which covers all of these types and keeps types that are actually numeric. Other compilers don't seem to implement the 'composite of numerical data types', though we do. 3- Because of the above restrictions, member-of-composite is not allowed, so any access via a memberexpr is disallowed. Array-element and sub-arrays (aka array sections) are both permitted, so long as they meet the requirements of #2. This patch implements all of these for compute constructs.
1 parent fbc798e commit a15b685

39 files changed

+1005
-157
lines changed

clang/include/clang/AST/OpenACCClause.h

+29
Original file line numberDiff line numberDiff line change
@@ -677,6 +677,35 @@ class OpenACCCreateClause final
677677
ArrayRef<Expr *> VarList, SourceLocation EndLoc);
678678
};
679679

680+
class OpenACCReductionClause final
681+
: public OpenACCClauseWithVarList,
682+
public llvm::TrailingObjects<OpenACCReductionClause, Expr *> {
683+
OpenACCReductionOperator Op;
684+
685+
OpenACCReductionClause(SourceLocation BeginLoc, SourceLocation LParenLoc,
686+
OpenACCReductionOperator Operator,
687+
ArrayRef<Expr *> VarList, SourceLocation EndLoc)
688+
: OpenACCClauseWithVarList(OpenACCClauseKind::Reduction, BeginLoc,
689+
LParenLoc, EndLoc),
690+
Op(Operator) {
691+
std::uninitialized_copy(VarList.begin(), VarList.end(),
692+
getTrailingObjects<Expr *>());
693+
setExprs(MutableArrayRef(getTrailingObjects<Expr *>(), VarList.size()));
694+
}
695+
696+
public:
697+
static bool classof(const OpenACCClause *C) {
698+
return C->getClauseKind() == OpenACCClauseKind::Reduction;
699+
}
700+
701+
static OpenACCReductionClause *
702+
Create(const ASTContext &C, SourceLocation BeginLoc, SourceLocation LParenLoc,
703+
OpenACCReductionOperator Operator, ArrayRef<Expr *> VarList,
704+
SourceLocation EndLoc);
705+
706+
OpenACCReductionOperator getReductionOp() const { return Op; }
707+
};
708+
680709
template <class Impl> class OpenACCClauseVisitor {
681710
Impl &getDerived() { return static_cast<Impl &>(*this); }
682711

clang/include/clang/Basic/DiagnosticSemaKinds.td

+16-2
Original file line numberDiff line numberDiff line change
@@ -12343,7 +12343,8 @@ def err_acc_num_gangs_num_args
1234312343
"provided}0">;
1234412344
def err_acc_not_a_var_ref
1234512345
: Error<"OpenACC variable is not a valid variable name, sub-array, array "
12346-
"element, or composite variable member">;
12346+
"element,%select{| member of a composite variable,}0 or composite "
12347+
"variable member">;
1234712348
def err_acc_typecheck_subarray_value
1234812349
: Error<"OpenACC sub-array subscripted value is not an array or pointer">;
1234912350
def err_acc_subarray_function_type
@@ -12374,5 +12375,18 @@ def note_acc_expected_pointer_var : Note<"expected variable of pointer type">;
1237412375
def err_acc_clause_after_device_type
1237512376
: Error<"OpenACC clause '%0' may not follow a '%1' clause in a "
1237612377
"compute construct">;
12377-
12378+
def err_acc_reduction_num_gangs_conflict
12379+
: Error<
12380+
"OpenACC 'reduction' clause may not appear on a 'parallel' construct "
12381+
"with a 'num_gangs' clause with more than 1 argument, have %0">;
12382+
def err_acc_reduction_type
12383+
: Error<"OpenACC 'reduction' variable must be of scalar type, sub-array, or a "
12384+
"composite of scalar types;%select{| sub-array base}1 type is %0">;
12385+
def err_acc_reduction_composite_type
12386+
: Error<"OpenACC 'reduction' variable must be a composite of scalar types; "
12387+
"%1 %select{is not a class or struct|is incomplete|is not an "
12388+
"aggregate}0">;
12389+
def err_acc_reduction_composite_member_type :Error<
12390+
"OpenACC 'reduction' composite variable must not have non-scalar field">;
12391+
def note_acc_reduction_composite_member_loc : Note<"invalid field is here">;
1237812392
} // end of sema component.

clang/include/clang/Basic/OpenACCClauses.def

+1
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ VISIT_CLAUSE(NumGangs)
4646
VISIT_CLAUSE(NumWorkers)
4747
VISIT_CLAUSE(Present)
4848
VISIT_CLAUSE(Private)
49+
VISIT_CLAUSE(Reduction)
4950
VISIT_CLAUSE(Self)
5051
VISIT_CLAUSE(VectorLength)
5152
VISIT_CLAUSE(Wait)

clang/include/clang/Basic/OpenACCKinds.h

+36
Original file line numberDiff line numberDiff line change
@@ -514,6 +514,42 @@ enum class OpenACCReductionOperator {
514514
/// Invalid Reduction Clause Kind.
515515
Invalid,
516516
};
517+
518+
template <typename StreamTy>
519+
inline StreamTy &printOpenACCReductionOperator(StreamTy &Out,
520+
OpenACCReductionOperator Op) {
521+
switch (Op) {
522+
case OpenACCReductionOperator::Addition:
523+
return Out << "+";
524+
case OpenACCReductionOperator::Multiplication:
525+
return Out << "*";
526+
case OpenACCReductionOperator::Max:
527+
return Out << "max";
528+
case OpenACCReductionOperator::Min:
529+
return Out << "min";
530+
case OpenACCReductionOperator::BitwiseAnd:
531+
return Out << "&";
532+
case OpenACCReductionOperator::BitwiseOr:
533+
return Out << "|";
534+
case OpenACCReductionOperator::BitwiseXOr:
535+
return Out << "^";
536+
case OpenACCReductionOperator::And:
537+
return Out << "&&";
538+
case OpenACCReductionOperator::Or:
539+
return Out << "||";
540+
case OpenACCReductionOperator::Invalid:
541+
return Out << "<invalid>";
542+
}
543+
llvm_unreachable("Unknown reduction operator kind");
544+
}
545+
inline const StreamingDiagnostic &operator<<(const StreamingDiagnostic &Out,
546+
OpenACCReductionOperator Op) {
547+
return printOpenACCReductionOperator(Out, Op);
548+
}
549+
inline llvm::raw_ostream &operator<<(llvm::raw_ostream &Out,
550+
OpenACCReductionOperator Op) {
551+
return printOpenACCReductionOperator(Out, Op);
552+
}
517553
} // namespace clang
518554

519555
#endif // LLVM_CLANG_BASIC_OPENACCKINDS_H

clang/include/clang/Parse/Parser.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -3686,9 +3686,9 @@ class Parser : public CodeCompletionHandler {
36863686

36873687
using OpenACCVarParseResult = std::pair<ExprResult, OpenACCParseCanContinue>;
36883688
/// Parses a single variable in a variable list for OpenACC.
3689-
OpenACCVarParseResult ParseOpenACCVar();
3689+
OpenACCVarParseResult ParseOpenACCVar(OpenACCClauseKind CK);
36903690
/// Parses the variable list for the variety of places that take a var-list.
3691-
llvm::SmallVector<Expr *> ParseOpenACCVarList();
3691+
llvm::SmallVector<Expr *> ParseOpenACCVarList(OpenACCClauseKind CK);
36923692
/// Parses any parameters for an OpenACC Clause, including required/optional
36933693
/// parens.
36943694
OpenACCClauseParseResult

clang/include/clang/Sema/SemaOpenACC.h

+27-2
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,14 @@ class SemaOpenACC : public SemaBase {
6666
struct DeviceTypeDetails {
6767
SmallVector<DeviceTypeArgument> Archs;
6868
};
69+
struct ReductionDetails {
70+
OpenACCReductionOperator Op;
71+
SmallVector<Expr *> VarList;
72+
};
6973

7074
std::variant<std::monostate, DefaultDetails, ConditionDetails,
71-
IntExprDetails, VarListDetails, WaitDetails, DeviceTypeDetails>
75+
IntExprDetails, VarListDetails, WaitDetails, DeviceTypeDetails,
76+
ReductionDetails>
7277
Details = std::monostate{};
7378

7479
public:
@@ -170,6 +175,10 @@ class SemaOpenACC : public SemaBase {
170175
return const_cast<OpenACCParsedClause *>(this)->getIntExprs();
171176
}
172177

178+
OpenACCReductionOperator getReductionOp() const {
179+
return std::get<ReductionDetails>(Details).Op;
180+
}
181+
173182
ArrayRef<Expr *> getVarList() {
174183
assert((ClauseKind == OpenACCClauseKind::Private ||
175184
ClauseKind == OpenACCClauseKind::NoCreate ||
@@ -188,8 +197,13 @@ class SemaOpenACC : public SemaBase {
188197
ClauseKind == OpenACCClauseKind::PresentOrCreate ||
189198
ClauseKind == OpenACCClauseKind::Attach ||
190199
ClauseKind == OpenACCClauseKind::DevicePtr ||
200+
ClauseKind == OpenACCClauseKind::Reduction ||
191201
ClauseKind == OpenACCClauseKind::FirstPrivate) &&
192202
"Parsed clause kind does not have a var-list");
203+
204+
if (ClauseKind == OpenACCClauseKind::Reduction)
205+
return std::get<ReductionDetails>(Details).VarList;
206+
193207
return std::get<VarListDetails>(Details).VarList;
194208
}
195209

@@ -334,6 +348,13 @@ class SemaOpenACC : public SemaBase {
334348
Details = VarListDetails{std::move(VarList), IsReadOnly, IsZero};
335349
}
336350

351+
void setReductionDetails(OpenACCReductionOperator Op,
352+
llvm::SmallVector<Expr *> &&VarList) {
353+
assert(ClauseKind == OpenACCClauseKind::Reduction &&
354+
"reduction details only valid on reduction");
355+
Details = ReductionDetails{Op, std::move(VarList)};
356+
}
357+
337358
void setWaitDetails(Expr *DevNum, SourceLocation QueuesLoc,
338359
llvm::SmallVector<Expr *> &&IntExprs) {
339360
assert(ClauseKind == OpenACCClauseKind::Wait &&
@@ -394,7 +415,11 @@ class SemaOpenACC : public SemaBase {
394415

395416
/// Called when encountering a 'var' for OpenACC, ensures it is actually a
396417
/// declaration reference to a variable of the correct type.
397-
ExprResult ActOnVar(Expr *VarExpr);
418+
ExprResult ActOnVar(OpenACCClauseKind CK, Expr *VarExpr);
419+
420+
/// Called while semantically analyzing the reduction clause, ensuring the var
421+
/// is the correct kind of reference.
422+
ExprResult CheckReductionVar(Expr *VarExpr);
398423

399424
/// Called to check the 'var' type is a variable of pointer type, necessary
400425
/// for 'deviceptr' and 'attach' clauses. Returns true on success.

clang/lib/AST/OpenACCClause.cpp

+19-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ bool OpenACCClauseWithVarList::classof(const OpenACCClause *C) {
3535
OpenACCAttachClause::classof(C) || OpenACCNoCreateClause::classof(C) ||
3636
OpenACCPresentClause::classof(C) || OpenACCCopyClause::classof(C) ||
3737
OpenACCCopyInClause::classof(C) || OpenACCCopyOutClause::classof(C) ||
38-
OpenACCCreateClause::classof(C);
38+
OpenACCReductionClause::classof(C) || OpenACCCreateClause::classof(C);
3939
}
4040
bool OpenACCClauseWithCondition::classof(const OpenACCClause *C) {
4141
return OpenACCIfClause::classof(C) || OpenACCSelfClause::classof(C);
@@ -310,6 +310,16 @@ OpenACCDeviceTypeClause *OpenACCDeviceTypeClause::Create(
310310
OpenACCDeviceTypeClause(K, BeginLoc, LParenLoc, Archs, EndLoc);
311311
}
312312

313+
OpenACCReductionClause *OpenACCReductionClause::Create(
314+
const ASTContext &C, SourceLocation BeginLoc, SourceLocation LParenLoc,
315+
OpenACCReductionOperator Operator, ArrayRef<Expr *> VarList,
316+
SourceLocation EndLoc) {
317+
void *Mem = C.Allocate(
318+
OpenACCReductionClause::totalSizeToAlloc<Expr *>(VarList.size()));
319+
return new (Mem)
320+
OpenACCReductionClause(BeginLoc, LParenLoc, Operator, VarList, EndLoc);
321+
}
322+
313323
//===----------------------------------------------------------------------===//
314324
// OpenACC clauses printing methods
315325
//===----------------------------------------------------------------------===//
@@ -445,6 +455,14 @@ void OpenACCClausePrinter::VisitCreateClause(const OpenACCCreateClause &C) {
445455
OS << ")";
446456
}
447457

458+
void OpenACCClausePrinter::VisitReductionClause(
459+
const OpenACCReductionClause &C) {
460+
OS << "reduction(" << C.getReductionOp() << ": ";
461+
llvm::interleaveComma(C.getVarList(), OS,
462+
[&](const Expr *E) { printExpr(E); });
463+
OS << ")";
464+
}
465+
448466
void OpenACCClausePrinter::VisitWaitClause(const OpenACCWaitClause &C) {
449467
OS << "wait";
450468
if (!C.getLParenLoc().isInvalid()) {

clang/lib/AST/StmtProfile.cpp

+6
Original file line numberDiff line numberDiff line change
@@ -2588,6 +2588,12 @@ void OpenACCClauseProfiler::VisitWaitClause(const OpenACCWaitClause &Clause) {
25882588
/// Nothing to do here, there are no sub-statements.
25892589
void OpenACCClauseProfiler::VisitDeviceTypeClause(
25902590
const OpenACCDeviceTypeClause &Clause) {}
2591+
2592+
void OpenACCClauseProfiler::VisitReductionClause(
2593+
const OpenACCReductionClause &Clause) {
2594+
for (auto *E : Clause.getVarList())
2595+
Profiler.VisitStmt(E);
2596+
}
25912597
} // namespace
25922598

25932599
void StmtProfiler::VisitOpenACCComputeConstruct(

clang/lib/AST/TextNodeDumper.cpp

+4
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,10 @@ void TextNodeDumper::Visit(const OpenACCClause *C) {
457457
});
458458
OS << ")";
459459
break;
460+
case OpenACCClauseKind::Reduction:
461+
OS << " clause Operator: "
462+
<< cast<OpenACCReductionClause>(C)->getReductionOp();
463+
break;
460464
default:
461465
// Nothing to do here.
462466
break;

clang/lib/Parse/ParseOpenACC.cpp

+16-14
Original file line numberDiff line numberDiff line change
@@ -920,7 +920,8 @@ Parser::OpenACCClauseParseResult Parser::ParseOpenACCClauseParams(
920920
case OpenACCClauseKind::PresentOrCopyIn: {
921921
bool IsReadOnly = tryParseAndConsumeSpecialTokenKind(
922922
*this, OpenACCSpecialTokenKind::ReadOnly, ClauseKind);
923-
ParsedClause.setVarListDetails(ParseOpenACCVarList(), IsReadOnly,
923+
ParsedClause.setVarListDetails(ParseOpenACCVarList(ClauseKind),
924+
IsReadOnly,
924925
/*IsZero=*/false);
925926
break;
926927
}
@@ -932,16 +933,17 @@ Parser::OpenACCClauseParseResult Parser::ParseOpenACCClauseParams(
932933
case OpenACCClauseKind::PresentOrCopyOut: {
933934
bool IsZero = tryParseAndConsumeSpecialTokenKind(
934935
*this, OpenACCSpecialTokenKind::Zero, ClauseKind);
935-
ParsedClause.setVarListDetails(ParseOpenACCVarList(),
936+
ParsedClause.setVarListDetails(ParseOpenACCVarList(ClauseKind),
936937
/*IsReadOnly=*/false, IsZero);
937938
break;
938939
}
939-
case OpenACCClauseKind::Reduction:
940+
case OpenACCClauseKind::Reduction: {
940941
// If we're missing a clause-kind (or it is invalid), see if we can parse
941942
// the var-list anyway.
942-
ParseReductionOperator(*this);
943-
ParseOpenACCVarList();
943+
OpenACCReductionOperator Op = ParseReductionOperator(*this);
944+
ParsedClause.setReductionDetails(Op, ParseOpenACCVarList(ClauseKind));
944945
break;
946+
}
945947
case OpenACCClauseKind::Self:
946948
// The 'self' clause is a var-list instead of a 'condition' in the case of
947949
// the 'update' clause, so we have to handle it here. U se an assert to
@@ -955,11 +957,11 @@ Parser::OpenACCClauseParseResult Parser::ParseOpenACCClauseParams(
955957
case OpenACCClauseKind::Host:
956958
case OpenACCClauseKind::Link:
957959
case OpenACCClauseKind::UseDevice:
958-
ParseOpenACCVarList();
960+
ParseOpenACCVarList(ClauseKind);
959961
break;
960962
case OpenACCClauseKind::Attach:
961963
case OpenACCClauseKind::DevicePtr:
962-
ParsedClause.setVarListDetails(ParseOpenACCVarList(),
964+
ParsedClause.setVarListDetails(ParseOpenACCVarList(ClauseKind),
963965
/*IsReadOnly=*/false, /*IsZero=*/false);
964966
break;
965967
case OpenACCClauseKind::Copy:
@@ -969,7 +971,7 @@ Parser::OpenACCClauseParseResult Parser::ParseOpenACCClauseParams(
969971
case OpenACCClauseKind::NoCreate:
970972
case OpenACCClauseKind::Present:
971973
case OpenACCClauseKind::Private:
972-
ParsedClause.setVarListDetails(ParseOpenACCVarList(),
974+
ParsedClause.setVarListDetails(ParseOpenACCVarList(ClauseKind),
973975
/*IsReadOnly=*/false, /*IsZero=*/false);
974976
break;
975977
case OpenACCClauseKind::Collapse: {
@@ -1278,7 +1280,7 @@ ExprResult Parser::ParseOpenACCBindClauseArgument() {
12781280
/// - an array element
12791281
/// - a member of a composite variable
12801282
/// - a common block name between slashes (fortran only)
1281-
Parser::OpenACCVarParseResult Parser::ParseOpenACCVar() {
1283+
Parser::OpenACCVarParseResult Parser::ParseOpenACCVar(OpenACCClauseKind CK) {
12821284
OpenACCArraySectionRAII ArraySections(*this);
12831285

12841286
ExprResult Res = ParseAssignmentExpression();
@@ -1289,15 +1291,15 @@ Parser::OpenACCVarParseResult Parser::ParseOpenACCVar() {
12891291
if (!Res.isUsable())
12901292
return {Res, OpenACCParseCanContinue::Can};
12911293

1292-
Res = getActions().OpenACC().ActOnVar(Res.get());
1294+
Res = getActions().OpenACC().ActOnVar(CK, Res.get());
12931295

12941296
return {Res, OpenACCParseCanContinue::Can};
12951297
}
12961298

1297-
llvm::SmallVector<Expr *> Parser::ParseOpenACCVarList() {
1299+
llvm::SmallVector<Expr *> Parser::ParseOpenACCVarList(OpenACCClauseKind CK) {
12981300
llvm::SmallVector<Expr *> Vars;
12991301

1300-
auto [Res, CanContinue] = ParseOpenACCVar();
1302+
auto [Res, CanContinue] = ParseOpenACCVar(CK);
13011303
if (Res.isUsable()) {
13021304
Vars.push_back(Res.get());
13031305
} else if (CanContinue == OpenACCParseCanContinue::Cannot) {
@@ -1308,7 +1310,7 @@ llvm::SmallVector<Expr *> Parser::ParseOpenACCVarList() {
13081310
while (!getCurToken().isOneOf(tok::r_paren, tok::annot_pragma_openacc_end)) {
13091311
ExpectAndConsume(tok::comma);
13101312

1311-
auto [Res, CanContinue] = ParseOpenACCVar();
1313+
auto [Res, CanContinue] = ParseOpenACCVar(CK);
13121314

13131315
if (Res.isUsable()) {
13141316
Vars.push_back(Res.get());
@@ -1342,7 +1344,7 @@ void Parser::ParseOpenACCCacheVarList() {
13421344

13431345
// ParseOpenACCVarList should leave us before a r-paren, so no need to skip
13441346
// anything here.
1345-
ParseOpenACCVarList();
1347+
ParseOpenACCVarList(OpenACCClauseKind::Invalid);
13461348
}
13471349

13481350
Parser::OpenACCDirectiveParseInfo Parser::ParseOpenACCDirective() {

0 commit comments

Comments
 (0)