Commit 8fb4230

[SYCL] AST support for SYCL kernel entry point functions. (#122379)
A SYCL kernel entry point function is a non-member function or a static member function declared with the `sycl_kernel_entry_point` attribute. Such functions define a pattern for an offload kernel entry point function to be generated to enable execution of a SYCL kernel on a device. A SYCL library implementation orchestrates the invocation of these functions with corresponding SYCL kernel arguments in response to calls to SYCL kernel invocation functions specified by the SYCL 2020 specification.

The offload kernel entry point function (sometimes referred to as the SYCL kernel caller function) is generated from the SYCL kernel entry point function by a transformation of the function parameters followed by a transformation of the function body to replace references to the original parameters with references to the transformed ones. Exactly how parameters are transformed will be explained in a future change that implements non-trivial transformations. For now, it suffices to state that a given parameter of the SYCL kernel entry point function may be transformed to multiple parameters of the offload kernel entry point as needed to satisfy offload kernel argument passing requirements. Parameters that are decomposed in this way are reconstituted as local variables in the body of the generated offload kernel entry point function.

For example, given the following SYCL kernel entry point function definition:

```
template<typename KernelNameType, typename KernelType>
[[clang::sycl_kernel_entry_point(KernelNameType)]]
void sycl_kernel_entry_point(KernelType kernel) {
  kernel();
}
```

and the following call:

```
struct Kernel {
  int dm1;
  int dm2;
  void operator()() const;
};
Kernel k;
sycl_kernel_entry_point<class kernel_name>(k);
```

the corresponding offload kernel entry point function that is generated might look as follows (assuming `Kernel` is a type that requires decomposition):

```
void offload_kernel_entry_point_for_kernel_name(int dm1, int dm2) {
  Kernel kernel{dm1, dm2};
  kernel();
}
```

Other details of the generated offload kernel entry point function, such as its name and calling convention, are implementation details that need not be reflected in the AST and may differ across target devices. For that reason, only the transformation described above is represented in the AST; other details will be filled in during code generation.

These transformations are represented using new AST nodes introduced with this change. `OutlinedFunctionDecl` holds a sequence of `ImplicitParamDecl` nodes and a sequence of statement nodes that correspond to the transformed parameters and function body. `SYCLKernelCallStmt` wraps the original function body and associates it with an `OutlinedFunctionDecl` instance.

For the example above, the AST generated for the `sycl_kernel_entry_point<kernel_name>` specialization would look as follows:

```
FunctionDecl 'sycl_kernel_entry_point<kernel_name>(Kernel)'
  TemplateArgument type 'kernel_name'
  TemplateArgument type 'Kernel'
  ParmVarDecl kernel 'Kernel'
  SYCLKernelCallStmt
    CompoundStmt
      <original statements>
    OutlinedFunctionDecl
      ImplicitParamDecl 'dm1' 'int'
      ImplicitParamDecl 'dm2' 'int'
      CompoundStmt
        VarDecl 'kernel' 'Kernel'
          <initialization of 'kernel' with 'dm1' and 'dm2'>
        <transformed statements with redirected references of 'kernel'>
```

Any ODR-use of the SYCL kernel entry point function will (with future changes) suffice for the offload kernel entry point to be emitted. An actual call to the SYCL kernel entry point function will result in a call to the function. However, evaluation of a `SYCLKernelCallStmt` statement is a no-op, so such calls will have no effect other than to trigger emission of the offload kernel entry point.

Additionally, as a related change inspired by code review feedback, these changes disallow use of the `sycl_kernel_entry_point` attribute with functions defined with a _function-try-block_. The SYCL 2020 specification prohibits the use of C++ exceptions in device functions. Even if exceptions were not prohibited, it is unclear what the semantics would be for an exception that escapes the SYCL kernel entry point function; the boundary between host and device code could be an implicit noexcept boundary that results in program termination if violated, or the exception could perhaps be propagated to host code via the SYCL library. Pending support for C++ exceptions in device code and clear semantics for handling them at the host-device boundary, this change makes use of the `sycl_kernel_entry_point` attribute with a function defined with a _function-try-block_ an error.
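
For readers unfamiliar with the term, a function-try-block is standard C++ syntax in which `try` introduces the function body, so the handlers cover the whole body. The sketch below is plain host C++ and does not use the attribute; `checked_half` and `half_or_default` are hypothetical names chosen for illustration. It only shows the definition form that this change rejects when the function also carries `sycl_kernel_entry_point`.

```cpp
#include <cassert>
#include <stdexcept>

// Hypothetical helper that throws on invalid input.
int checked_half(int value) {
  if (value < 0)
    throw std::invalid_argument("negative input");
  return value / 2;
}

// Defined with a function-try-block: the handler below covers the entire
// function body rather than a nested statement.
int half_or_default(int value) try {
  return checked_half(value);
} catch (const std::invalid_argument &) {
  // Returning from the handler of a non-constructor function is allowed.
  return -1;
}
```

Rewriting such a function with an ordinary compound statement that contains a nested `try` block changes only the syntax of the definition, not its handling of exceptions thrown inside the body.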
1 parent 2656928 commit 8fb4230

36 files changed, +737 -12 lines changed
clang/include/clang/AST/ASTNodeTraverser.h

+14 -2
@@ -158,8 +158,8 @@ class ASTNodeTraverser
       ConstStmtVisitor<Derived>::Visit(S);
 
       // Some statements have custom mechanisms for dumping their children.
-      if (isa<DeclStmt>(S) || isa<GenericSelectionExpr>(S) ||
-          isa<RequiresExpr>(S) || isa<OpenACCWaitConstruct>(S))
+      if (isa<DeclStmt, GenericSelectionExpr, RequiresExpr,
+              OpenACCWaitConstruct, SYCLKernelCallStmt>(S))
         return;
 
       if (Traversal == TK_IgnoreUnlessSpelledInSource &&
@@ -585,6 +585,12 @@ class ASTNodeTraverser
 
   void VisitTopLevelStmtDecl(const TopLevelStmtDecl *D) { Visit(D->getStmt()); }
 
+  void VisitOutlinedFunctionDecl(const OutlinedFunctionDecl *D) {
+    for (const ImplicitParamDecl *Parameter : D->parameters())
+      Visit(Parameter);
+    Visit(D->getBody());
+  }
+
   void VisitCapturedDecl(const CapturedDecl *D) { Visit(D->getBody()); }
 
   void VisitOMPThreadPrivateDecl(const OMPThreadPrivateDecl *D) {
@@ -815,6 +821,12 @@ class ASTNodeTraverser
     Visit(Node->getCapturedDecl());
   }
 
+  void VisitSYCLKernelCallStmt(const SYCLKernelCallStmt *Node) {
+    Visit(Node->getOriginalStmt());
+    if (Traversal != TK_IgnoreUnlessSpelledInSource)
+      Visit(Node->getOutlinedFunctionDecl());
+  }
+
   void VisitOMPExecutableDirective(const OMPExecutableDirective *Node) {
     for (const auto *C : Node->clauses())
       Visit(C);
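
The traversal rule added in `VisitSYCLKernelCallStmt` can be modeled without any Clang dependency. The following is a minimal sketch, assuming a stand-in node type (`KernelCallSketch`) and plain strings in place of AST nodes; none of these names exist in Clang. It mirrors only the control flow in the hunk above: the original body is always visited, and the outlined declaration is skipped in the spelled-in-source mode.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Stand-in for the two traversal kinds relevant here.
enum class TraversalMode { AsIs, IgnoreUnlessSpelledInSource };

// Hypothetical model of a kernel call statement: two labeled children.
struct KernelCallSketch {
  std::string OriginalBody;
  std::string OutlinedDecl;
};

// Returns the children visited, in order, for the given mode.
std::vector<std::string> visitKernelCall(const KernelCallSketch &Node,
                                         TraversalMode Mode) {
  std::vector<std::string> Visited;
  Visited.push_back(Node.OriginalBody); // always visited
  if (Mode != TraversalMode::IgnoreUnlessSpelledInSource)
    Visited.push_back(Node.OutlinedDecl); // implicit, compiler-generated part
  return Visited;
}
```

The design choice reflected here is that the outlined declaration is synthesized by the compiler, so a traversal restricted to what the user wrote has no reason to descend into it.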

clang/include/clang/AST/Decl.h

+77
@@ -4688,6 +4688,83 @@ class BlockDecl : public Decl, public DeclContext {
   }
 };
 
+/// Represents a partial function definition.
+///
+/// An outlined function declaration contains the parameters and body of
+/// a function independent of other function definition concerns such
+/// as function name, type, and calling convention. Such declarations may
+/// be used to hold a parameterized and transformed sequence of statements
+/// used to generate a target dependent function definition without losing
+/// association with the original statements. See SYCLKernelCallStmt as an
+/// example.
+class OutlinedFunctionDecl final
+    : public Decl,
+      public DeclContext,
+      private llvm::TrailingObjects<OutlinedFunctionDecl, ImplicitParamDecl *> {
+private:
+  /// The number of parameters to the outlined function.
+  unsigned NumParams;
+
+  /// The body of the outlined function.
+  llvm::PointerIntPair<Stmt *, 1, bool> BodyAndNothrow;
+
+  explicit OutlinedFunctionDecl(DeclContext *DC, unsigned NumParams);
+
+  ImplicitParamDecl *const *getParams() const {
+    return getTrailingObjects<ImplicitParamDecl *>();
+  }
+
+  ImplicitParamDecl **getParams() {
+    return getTrailingObjects<ImplicitParamDecl *>();
+  }
+
+public:
+  friend class ASTDeclReader;
+  friend class ASTDeclWriter;
+  friend TrailingObjects;
+
+  static OutlinedFunctionDecl *Create(ASTContext &C, DeclContext *DC,
+                                      unsigned NumParams);
+  static OutlinedFunctionDecl *
+  CreateDeserialized(ASTContext &C, GlobalDeclID ID, unsigned NumParams);
+
+  Stmt *getBody() const override;
+  void setBody(Stmt *B);
+
+  bool isNothrow() const;
+  void setNothrow(bool Nothrow = true);
+
+  unsigned getNumParams() const { return NumParams; }
+
+  ImplicitParamDecl *getParam(unsigned i) const {
+    assert(i < NumParams);
+    return getParams()[i];
+  }
+  void setParam(unsigned i, ImplicitParamDecl *P) {
+    assert(i < NumParams);
+    getParams()[i] = P;
+  }
+
+  // Range interface to parameters.
+  using parameter_const_iterator = const ImplicitParamDecl *const *;
+  using parameter_const_range = llvm::iterator_range<parameter_const_iterator>;
+  parameter_const_range parameters() const {
+    return {param_begin(), param_end()};
+  }
+  parameter_const_iterator param_begin() const { return getParams(); }
+  parameter_const_iterator param_end() const { return getParams() + NumParams; }
+
+  // Implement isa/cast/dyncast/etc.
+  static bool classof(const Decl *D) { return classofKind(D->getKind()); }
+  static bool classofKind(Kind K) { return K == OutlinedFunction; }
+  static DeclContext *castToDeclContext(const OutlinedFunctionDecl *D) {
+    return static_cast<DeclContext *>(const_cast<OutlinedFunctionDecl *>(D));
+  }
+  static OutlinedFunctionDecl *castFromDeclContext(const DeclContext *DC) {
+    return static_cast<OutlinedFunctionDecl *>(const_cast<DeclContext *>(DC));
+  }
+};
+
 /// Represents the body of a CapturedStmt, and serves as its DeclContext.
 class CapturedDecl final
     : public Decl,
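
The `BodyAndNothrow` member above stores the body pointer and the nothrow flag in a single word. As a rough illustration of that idea, here is a simplified stand-in written from scratch; `PointerAndFlag` and `FakeStmt` are hypothetical names, and this is not how `llvm::PointerIntPair` is implemented in detail. It assumes the pointee type is at least 2-byte aligned so that the lowest pointer bit is always zero and can carry the flag.

```cpp
#include <cassert>
#include <cstdint>

// Simplified sketch of packing a pointer and one bool into one word.
template <typename T> class PointerAndFlag {
  static_assert(alignof(T) >= 2, "need one spare low bit in the pointer");
  std::uintptr_t Value = 0;

  void set(T *Ptr, bool Flag) {
    auto Bits = reinterpret_cast<std::uintptr_t>(Ptr);
    assert((Bits & 1) == 0 && "pointer must leave the low bit free");
    Value = Bits | (Flag ? 1u : 0u);
  }

public:
  PointerAndFlag(T *Ptr, bool Flag) { set(Ptr, Flag); }

  // Mask off the flag bit to recover the pointer.
  T *getPointer() const {
    return reinterpret_cast<T *>(Value & ~std::uintptr_t(1));
  }
  // The flag lives in the lowest bit.
  bool getInt() const { return (Value & 1) != 0; }

  void setPointer(T *Ptr) { set(Ptr, getInt()); }
  void setInt(bool Flag) { set(getPointer(), Flag); }
};

// Hypothetical pointee type with enough alignment for the trick to work.
struct alignas(8) FakeStmt {
  int Id;
};
```

Setting the pointer preserves the flag and vice versa, which matches how `setBody` and `setNothrow` update the two halves independently in the commit.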

clang/include/clang/AST/RecursiveASTVisitor.h

+14
@@ -37,6 +37,7 @@
 #include "clang/AST/StmtObjC.h"
 #include "clang/AST/StmtOpenACC.h"
 #include "clang/AST/StmtOpenMP.h"
+#include "clang/AST/StmtSYCL.h"
 #include "clang/AST/TemplateBase.h"
 #include "clang/AST/TemplateName.h"
 #include "clang/AST/Type.h"
@@ -1581,6 +1582,11 @@ DEF_TRAVERSE_DECL(BlockDecl, {
   ShouldVisitChildren = false;
 })
 
+DEF_TRAVERSE_DECL(OutlinedFunctionDecl, {
+  TRY_TO(TraverseStmt(D->getBody()));
+  ShouldVisitChildren = false;
+})
+
 DEF_TRAVERSE_DECL(CapturedDecl, {
   TRY_TO(TraverseStmt(D->getBody()));
   ShouldVisitChildren = false;
@@ -2904,6 +2910,14 @@ DEF_TRAVERSE_STMT(SEHFinallyStmt, {})
 DEF_TRAVERSE_STMT(SEHLeaveStmt, {})
 DEF_TRAVERSE_STMT(CapturedStmt, { TRY_TO(TraverseDecl(S->getCapturedDecl())); })
 
+DEF_TRAVERSE_STMT(SYCLKernelCallStmt, {
+  if (getDerived().shouldVisitImplicitCode()) {
+    TRY_TO(TraverseStmt(S->getOriginalStmt()));
+    TRY_TO(TraverseDecl(S->getOutlinedFunctionDecl()));
+    ShouldVisitChildren = false;
+  }
+})
+
 DEF_TRAVERSE_STMT(CXXOperatorCallExpr, {})
 DEF_TRAVERSE_STMT(CXXRewrittenBinaryOperator, {
   if (!getDerived().shouldVisitImplicitCode()) {

clang/include/clang/AST/StmtSYCL.h

+94
@@ -0,0 +1,94 @@
+//===- StmtSYCL.h - Classes for SYCL kernel calls ---------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+/// \file
+/// This file defines SYCL AST classes used to represent calls to SYCL kernels.
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CLANG_AST_STMTSYCL_H
+#define LLVM_CLANG_AST_STMTSYCL_H
+
+#include "clang/AST/ASTContext.h"
+#include "clang/AST/Decl.h"
+#include "clang/AST/Stmt.h"
+#include "clang/Basic/SourceLocation.h"
+
+namespace clang {
+
+//===----------------------------------------------------------------------===//
+// AST classes for SYCL kernel calls.
+//===----------------------------------------------------------------------===//
+
+/// SYCLKernelCallStmt represents the transformation that is applied to the body
+/// of a function declared with the sycl_kernel_entry_point attribute. The body
+/// of such a function specifies the statements to be executed on a SYCL device
+/// to invoke a SYCL kernel with a particular set of kernel arguments. The
+/// SYCLKernelCallStmt associates an original statement (the compound statement
+/// that is the function body) with an OutlinedFunctionDecl that holds the
+/// kernel parameters and the transformed body. During code generation, the
+/// OutlinedFunctionDecl is used to emit an offload kernel entry point suitable
+/// for invocation from a SYCL library implementation. If executed, the
+/// SYCLKernelCallStmt behaves as a no-op; no code generation is performed for
+/// it.
+class SYCLKernelCallStmt : public Stmt {
+  friend class ASTStmtReader;
+  friend class ASTStmtWriter;
+
+private:
+  Stmt *OriginalStmt = nullptr;
+  OutlinedFunctionDecl *OFDecl = nullptr;
+
+public:
+  /// Construct a SYCL kernel call statement.
+  SYCLKernelCallStmt(CompoundStmt *CS, OutlinedFunctionDecl *OFD)
+      : Stmt(SYCLKernelCallStmtClass), OriginalStmt(CS), OFDecl(OFD) {}
+
+  /// Construct an empty SYCL kernel call statement.
+  SYCLKernelCallStmt(EmptyShell Empty) : Stmt(SYCLKernelCallStmtClass, Empty) {}
+
+  /// Retrieve the model statement.
+  CompoundStmt *getOriginalStmt() { return cast<CompoundStmt>(OriginalStmt); }
+  const CompoundStmt *getOriginalStmt() const {
+    return cast<CompoundStmt>(OriginalStmt);
+  }
+  void setOriginalStmt(CompoundStmt *CS) { OriginalStmt = CS; }
+
+  /// Retrieve the outlined function declaration.
+  OutlinedFunctionDecl *getOutlinedFunctionDecl() { return OFDecl; }
+  const OutlinedFunctionDecl *getOutlinedFunctionDecl() const { return OFDecl; }
+
+  /// Set the outlined function declaration.
+  void setOutlinedFunctionDecl(OutlinedFunctionDecl *OFD) { OFDecl = OFD; }
+
+  SourceLocation getBeginLoc() const LLVM_READONLY {
+    return getOriginalStmt()->getBeginLoc();
+  }
+
+  SourceLocation getEndLoc() const LLVM_READONLY {
+    return getOriginalStmt()->getEndLoc();
+  }
+
+  SourceRange getSourceRange() const LLVM_READONLY {
+    return getOriginalStmt()->getSourceRange();
+  }
+
+  static bool classof(const Stmt *T) {
+    return T->getStmtClass() == SYCLKernelCallStmtClass;
+  }
+
+  child_range children() {
+    return child_range(&OriginalStmt, &OriginalStmt + 1);
+  }
+
+  const_child_range children() const {
+    return const_child_range(&OriginalStmt, &OriginalStmt + 1);
+  }
+};
+
+} // end namespace clang
+
+#endif // LLVM_CLANG_AST_STMTSYCL_H

clang/include/clang/AST/StmtVisitor.h

+1
@@ -22,6 +22,7 @@
 #include "clang/AST/StmtObjC.h"
 #include "clang/AST/StmtOpenACC.h"
 #include "clang/AST/StmtOpenMP.h"
+#include "clang/AST/StmtSYCL.h"
 #include "clang/Basic/LLVM.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/Casting.h"

clang/include/clang/Basic/AttrDocs.td

+1
@@ -487,6 +487,7 @@ following requirements.
 * Is not a C variadic function.
 * Is not a coroutine.
 * Is not defined as deleted or as defaulted.
+* Is not defined with a function try block.
 * Is not declared with the ``constexpr`` or ``consteval`` specifiers.
 * Is not declared with the ``[[noreturn]]`` attribute.
 

clang/include/clang/Basic/DeclNodes.td

+1
@@ -101,6 +101,7 @@ def Friend : DeclNode<Decl>;
 def FriendTemplate : DeclNode<Decl>;
 def StaticAssert : DeclNode<Decl>;
 def Block : DeclNode<Decl, "blocks">, DeclContext;
+def OutlinedFunction : DeclNode<Decl>, DeclContext;
 def Captured : DeclNode<Decl>, DeclContext;
 def Import : DeclNode<Decl>;
 def OMPThreadPrivate : DeclNode<Decl>;

clang/include/clang/Basic/DiagnosticSemaKinds.td

+2 -1
@@ -12457,7 +12457,8 @@ def err_sycl_entry_point_invalid : Error<
   "'sycl_kernel_entry_point' attribute cannot be applied to a"
   " %select{non-static member function|variadic function|deleted function|"
   "defaulted function|constexpr function|consteval function|"
-  "function declared with the 'noreturn' attribute|coroutine}0">;
+  "function declared with the 'noreturn' attribute|coroutine|"
+  "function defined with a function try block}0">;
 def err_sycl_entry_point_invalid_redeclaration : Error<
   "'sycl_kernel_entry_point' kernel name argument does not match prior"
   " declaration%diff{: $ vs $|}0,1">;

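The `%select{...}0` construct in the `err_sycl_entry_point_invalid` diagnostic picks one alternative by the zero-based integer passed as argument 0, so appending an alternative gives the new case the next index. The sketch below imitates that selection in ordinary C++; `entryPointInvalidMessage` is a hypothetical helper, not part of Clang's diagnostics engine, and the index 8 for the new case is inferred from the position of the alternative in the list rather than from the Sema code, which is not shown here.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical imitation of how a %select index chooses the message tail.
std::string entryPointInvalidMessage(unsigned Selector) {
  static const std::vector<std::string> Options = {
      "non-static member function",
      "variadic function",
      "deleted function",
      "defaulted function",
      "constexpr function",
      "consteval function",
      "function declared with the 'noreturn' attribute",
      "coroutine",
      "function defined with a function try block", // added by this commit
  };
  assert(Selector < Options.size() && "selector out of range");
  return "'sycl_kernel_entry_point' attribute cannot be applied to a " +
         Options[Selector];
}
```

Because the new alternative is appended at the end, the indices of the existing alternatives are unchanged and existing call sites that pass them keep their meaning.
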
clang/include/clang/Basic/StmtNodes.td

+1
@@ -24,6 +24,7 @@ def SwitchCase : StmtNode<Stmt, 1>;
 def CaseStmt : StmtNode<SwitchCase>;
 def DefaultStmt : StmtNode<SwitchCase>;
 def CapturedStmt : StmtNode<Stmt>;
+def SYCLKernelCallStmt : StmtNode<Stmt>;
 
 // Statements that might produce a value (for example, as the last non-null
 // statement in a GNU statement-expression).

clang/include/clang/Sema/SemaSYCL.h

+1
@@ -65,6 +65,7 @@ class SemaSYCL : public SemaBase {
   void handleKernelEntryPointAttr(Decl *D, const ParsedAttr &AL);
 
   void CheckSYCLEntryPointFunctionDecl(FunctionDecl *FD);
+  StmtResult BuildSYCLKernelCallStmt(FunctionDecl *FD, CompoundStmt *Body);
 };
 
 } // namespace clang

clang/include/clang/Sema/Template.h

+4 -1
@@ -627,7 +627,10 @@ enum class TemplateSubstitutionKind : char {
 #define EMPTY(DERIVED, BASE)
 #define LIFETIMEEXTENDEDTEMPORARY(DERIVED, BASE)
 
-// Decls which use special-case instantiation code.
+// Decls which never appear inside a template.
+#define OUTLINEDFUNCTION(DERIVED, BASE)
+
+// Decls which use special-case instantiation code.
 #define BLOCK(DERIVED, BASE)
 #define CAPTURED(DERIVED, BASE)
 #define IMPLICITPARAM(DERIVED, BASE)

clang/include/clang/Serialization/ASTBitCodes.h

+6
@@ -1316,6 +1316,9 @@ enum DeclCode {
   /// A BlockDecl record.
   DECL_BLOCK,
 
+  /// A OutlinedFunctionDecl record.
+  DECL_OUTLINEDFUNCTION,
+
   /// A CapturedDecl record.
   DECL_CAPTURED,
 
@@ -1600,6 +1603,9 @@ enum StmtCode {
   /// A CapturedStmt record.
   STMT_CAPTURED,
 
+  /// A SYCLKernelCallStmt record.
+  STMT_SYCLKERNELCALL,
+
   /// A GCC-style AsmStmt record.
   STMT_GCCASM,
 

clang/lib/AST/ASTStructuralEquivalence.cpp

+1
@@ -76,6 +76,7 @@
 #include "clang/AST/StmtObjC.h"
 #include "clang/AST/StmtOpenACC.h"
 #include "clang/AST/StmtOpenMP.h"
+#include "clang/AST/StmtSYCL.h"
 #include "clang/AST/TemplateBase.h"
 #include "clang/AST/TemplateName.h"
 #include "clang/AST/Type.h"

clang/lib/AST/Decl.cpp

+29
@@ -5459,6 +5459,35 @@ BlockDecl *BlockDecl::CreateDeserialized(ASTContext &C, GlobalDeclID ID) {
   return new (C, ID) BlockDecl(nullptr, SourceLocation());
 }
 
+OutlinedFunctionDecl::OutlinedFunctionDecl(DeclContext *DC, unsigned NumParams)
+    : Decl(OutlinedFunction, DC, SourceLocation()),
+      DeclContext(OutlinedFunction), NumParams(NumParams),
+      BodyAndNothrow(nullptr, false) {}
+
+OutlinedFunctionDecl *OutlinedFunctionDecl::Create(ASTContext &C,
+                                                   DeclContext *DC,
+                                                   unsigned NumParams) {
+  return new (C, DC, additionalSizeToAlloc<ImplicitParamDecl *>(NumParams))
+      OutlinedFunctionDecl(DC, NumParams);
+}
+
+OutlinedFunctionDecl *
+OutlinedFunctionDecl::CreateDeserialized(ASTContext &C, GlobalDeclID ID,
+                                         unsigned NumParams) {
+  return new (C, ID, additionalSizeToAlloc<ImplicitParamDecl *>(NumParams))
+      OutlinedFunctionDecl(nullptr, NumParams);
+}
+
+Stmt *OutlinedFunctionDecl::getBody() const {
+  return BodyAndNothrow.getPointer();
+}
+void OutlinedFunctionDecl::setBody(Stmt *B) { BodyAndNothrow.setPointer(B); }
+
+bool OutlinedFunctionDecl::isNothrow() const { return BodyAndNothrow.getInt(); }
+void OutlinedFunctionDecl::setNothrow(bool Nothrow) {
+  BodyAndNothrow.setInt(Nothrow);
+}
+
 CapturedDecl::CapturedDecl(DeclContext *DC, unsigned NumParams)
     : Decl(Captured, DC, SourceLocation()), DeclContext(Captured),
       NumParams(NumParams), ContextParam(0), BodyAndNothrow(nullptr, false) {}
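
`OutlinedFunctionDecl::Create` asks for extra storage via `additionalSizeToAlloc<ImplicitParamDecl *>(NumParams)` so that the parameter pointers live directly after the object in one allocation. The sketch below illustrates that layout idea with hand-written code; `OutlinedSketch` and `Param` are hypothetical stand-ins, the allocation uses the global `operator new` rather than an `ASTContext`, and it is not a reimplementation of `llvm::TrailingObjects`.

```cpp
#include <cassert>
#include <new>

// Hypothetical stand-in for ImplicitParamDecl.
struct Param {
  int Index;
};

using ParamPtr = Param *;

// The class is aligned like a pointer so the trailing array that starts at
// `this + 1` is suitably aligned for ParamPtr elements.
class alignas(ParamPtr) OutlinedSketch {
  unsigned NumParams;

  explicit OutlinedSketch(unsigned N) : NumParams(N) {}

  ParamPtr *trailing() { return reinterpret_cast<ParamPtr *>(this + 1); }
  const ParamPtr *trailing() const {
    return reinterpret_cast<const ParamPtr *>(this + 1);
  }

public:
  // One allocation holds the object followed by N parameter pointers.
  static OutlinedSketch *create(unsigned N) {
    void *Mem = ::operator new(sizeof(OutlinedSketch) + N * sizeof(ParamPtr));
    OutlinedSketch *D = new (Mem) OutlinedSketch(N);
    for (unsigned I = 0; I != N; ++I)
      new (D->trailing() + I) ParamPtr(nullptr);
    return D;
  }

  static void destroy(OutlinedSketch *D) {
    D->~OutlinedSketch();
    ::operator delete(D);
  }

  unsigned getNumParams() const { return NumParams; }

  Param *getParam(unsigned I) const {
    assert(I < NumParams);
    return trailing()[I];
  }
  void setParam(unsigned I, Param *P) {
    assert(I < NumParams);
    trailing()[I] = P;
  }
};
```

Keeping the parameter array inline avoids a second allocation and a separate pointer member, at the cost of requiring the parameter count to be known when the node is created, which is why both `Create` and `CreateDeserialized` take `NumParams`.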
