-
Notifications
You must be signed in to change notification settings - Fork 13.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SYCL] AST support for SYCL kernel entry point functions. #122379
Conversation
@llvm/pr-subscribers-clang-static-analyzer-1 @llvm/pr-subscribers-clang Author: Tom Honermann (tahonermann) ChangesA SYCL kernel entry point function is a non-member function or a static member function declared with the The offload kernel entry point function (sometimes referred to as the SYCL kernel caller function) is generated from the SYCL kernel entry point function by a transformation of the function parameters followed by a transformation of the function body to replace references to the original parameters with references to the transformed ones. Exactly how parameters are transformed will be explained in a future change that implements non-trivial transformations. For now, it suffices to state that a given parameter of the SYCL kernel entry point function may be transformed to multiple parameters of the offload kernel entry point as needed to satisfy offload kernel argument passing requirements. Parameters that are decomposed in this way are reconstituted as local variables in the body of the generated offload kernel entry point function. For example, given the following SYCL kernel entry point function definition:
and the following call:
the corresponding offload kernel entry point function that is generated might look as follows (assuming
Other details of the generated offload kernel entry point function, such as its name and calling convention, are implementation details that need not be reflected in the AST and may differ across target devices. For that reason, only the transformation described above is represented in the AST; other details will be filled in during code generation. These transformations are represented using new AST nodes introduced with this change.
Any ODR-use of the SYCL kernel entry point function will (with future changes) suffice for the offload kernel entry point to be emitted. An actual call to the SYCL kernel entry point function will result in a call to the function. However, evaluation of a Patch is 52.12 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/122379.diff 34 Files Affected:
diff --git a/clang/include/clang/AST/ASTNodeTraverser.h b/clang/include/clang/AST/ASTNodeTraverser.h
index f5652b295de168..3bc0bdff2bdd12 100644
--- a/clang/include/clang/AST/ASTNodeTraverser.h
+++ b/clang/include/clang/AST/ASTNodeTraverser.h
@@ -158,8 +158,8 @@ class ASTNodeTraverser
ConstStmtVisitor<Derived>::Visit(S);
// Some statements have custom mechanisms for dumping their children.
- if (isa<DeclStmt>(S) || isa<GenericSelectionExpr>(S) ||
- isa<RequiresExpr>(S) || isa<OpenACCWaitConstruct>(S))
+ if (isa<DeclStmt, GenericSelectionExpr, RequiresExpr,
+ OpenACCWaitConstruct, SYCLKernelCallStmt>(S))
return;
if (Traversal == TK_IgnoreUnlessSpelledInSource &&
@@ -585,6 +585,12 @@ class ASTNodeTraverser
void VisitTopLevelStmtDecl(const TopLevelStmtDecl *D) { Visit(D->getStmt()); }
+ void VisitOutlinedFunctionDecl(const OutlinedFunctionDecl *D) {
+ for (const ImplicitParamDecl *Parameter : D->parameters())
+ Visit(Parameter);
+ Visit(D->getBody());
+ }
+
void VisitCapturedDecl(const CapturedDecl *D) { Visit(D->getBody()); }
void VisitOMPThreadPrivateDecl(const OMPThreadPrivateDecl *D) {
@@ -815,6 +821,12 @@ class ASTNodeTraverser
Visit(Node->getCapturedDecl());
}
+ void VisitSYCLKernelCallStmt(const SYCLKernelCallStmt *Node) {
+ Visit(Node->getOriginalStmt());
+ if (Traversal != TK_IgnoreUnlessSpelledInSource)
+ Visit(Node->getOutlinedFunctionDecl());
+ }
+
void VisitOMPExecutableDirective(const OMPExecutableDirective *Node) {
for (const auto *C : Node->clauses())
Visit(C);
diff --git a/clang/include/clang/AST/Decl.h b/clang/include/clang/AST/Decl.h
index 16fc98aa1a57f3..901ec1e48ca080 100644
--- a/clang/include/clang/AST/Decl.h
+++ b/clang/include/clang/AST/Decl.h
@@ -4678,6 +4678,96 @@ class BlockDecl : public Decl, public DeclContext {
}
};
+/// Represents a partial function definition.
+///
+/// An outlined function declaration contains the parameters and body of
+/// a function independent of other function definition concerns such
+/// as function name, type, and calling convention. Such declarations may
+/// be used to hold a parameterized and transformed sequence of statements
+/// used to generate a target dependent function definition without losing
+/// association with the original statements. See SYCLKernelCallStmt as an
+/// example.
+class OutlinedFunctionDecl final
+ : public Decl,
+ public DeclContext,
+ private llvm::TrailingObjects<OutlinedFunctionDecl, ImplicitParamDecl *> {
+protected:
+ size_t numTrailingObjects(OverloadToken<ImplicitParamDecl>) {
+ return NumParams;
+ }
+
+private:
+ /// The number of parameters to the outlined function.
+ unsigned NumParams;
+
+ /// The body of the outlined function.
+ llvm::PointerIntPair<Stmt *, 1, bool> BodyAndNothrow;
+
+ explicit OutlinedFunctionDecl(DeclContext *DC, unsigned NumParams);
+
+ ImplicitParamDecl *const *getParams() const {
+ return getTrailingObjects<ImplicitParamDecl *>();
+ }
+
+ ImplicitParamDecl **getParams() {
+ return getTrailingObjects<ImplicitParamDecl *>();
+ }
+
+public:
+ friend class ASTDeclReader;
+ friend class ASTDeclWriter;
+ friend TrailingObjects;
+
+ static OutlinedFunctionDecl *Create(ASTContext &C, DeclContext *DC,
+ unsigned NumParams);
+ static OutlinedFunctionDecl *CreateDeserialized(ASTContext &C,
+ GlobalDeclID ID,
+ unsigned NumParams);
+
+ Stmt *getBody() const override;
+ void setBody(Stmt *B);
+
+ bool isNothrow() const;
+ void setNothrow(bool Nothrow = true);
+
+ unsigned getNumParams() const { return NumParams; }
+
+ ImplicitParamDecl *getParam(unsigned i) const {
+ assert(i < NumParams);
+ return getParams()[i];
+ }
+ void setParam(unsigned i, ImplicitParamDecl *P) {
+ assert(i < NumParams);
+ getParams()[i] = P;
+ }
+
+ // ArrayRef interface to parameters.
+ ArrayRef<ImplicitParamDecl *> parameters() const {
+ return {getParams(), getNumParams()};
+ }
+ MutableArrayRef<ImplicitParamDecl *> parameters() {
+ return {getParams(), getNumParams()};
+ }
+
+ using param_iterator = ImplicitParamDecl *const *;
+ using param_range = llvm::iterator_range<param_iterator>;
+
+ /// Retrieve an iterator pointing to the first parameter decl.
+ param_iterator param_begin() const { return getParams(); }
+ /// Retrieve an iterator one past the last parameter decl.
+ param_iterator param_end() const { return getParams() + NumParams; }
+
+ // Implement isa/cast/dyncast/etc.
+ static bool classof(const Decl *D) { return classofKind(D->getKind()); }
+ static bool classofKind(Kind K) { return K == OutlinedFunction; }
+ static DeclContext *castToDeclContext(const OutlinedFunctionDecl *D) {
+ return static_cast<DeclContext *>(const_cast<OutlinedFunctionDecl *>(D));
+ }
+ static OutlinedFunctionDecl *castFromDeclContext(const DeclContext *DC) {
+ return static_cast<OutlinedFunctionDecl *>(const_cast<DeclContext *>(DC));
+ }
+};
+
/// Represents the body of a CapturedStmt, and serves as its DeclContext.
class CapturedDecl final
: public Decl,
diff --git a/clang/include/clang/AST/RecursiveASTVisitor.h b/clang/include/clang/AST/RecursiveASTVisitor.h
index d500f4eadef757..c4a1d03f1b3d10 100644
--- a/clang/include/clang/AST/RecursiveASTVisitor.h
+++ b/clang/include/clang/AST/RecursiveASTVisitor.h
@@ -37,6 +37,7 @@
#include "clang/AST/StmtObjC.h"
#include "clang/AST/StmtOpenACC.h"
#include "clang/AST/StmtOpenMP.h"
+#include "clang/AST/StmtSYCL.h"
#include "clang/AST/TemplateBase.h"
#include "clang/AST/TemplateName.h"
#include "clang/AST/Type.h"
@@ -1581,6 +1582,11 @@ DEF_TRAVERSE_DECL(BlockDecl, {
ShouldVisitChildren = false;
})
+DEF_TRAVERSE_DECL(OutlinedFunctionDecl, {
+ TRY_TO(TraverseStmt(D->getBody()));
+ ShouldVisitChildren = false;
+})
+
DEF_TRAVERSE_DECL(CapturedDecl, {
TRY_TO(TraverseStmt(D->getBody()));
ShouldVisitChildren = false;
@@ -2904,6 +2910,14 @@ DEF_TRAVERSE_STMT(SEHFinallyStmt, {})
DEF_TRAVERSE_STMT(SEHLeaveStmt, {})
DEF_TRAVERSE_STMT(CapturedStmt, { TRY_TO(TraverseDecl(S->getCapturedDecl())); })
+DEF_TRAVERSE_STMT(SYCLKernelCallStmt, {
+ if (getDerived().shouldVisitImplicitCode()) {
+ TRY_TO(TraverseStmt(S->getOriginalStmt()));
+ TRY_TO(TraverseDecl(S->getOutlinedFunctionDecl()));
+ ShouldVisitChildren = false;
+ }
+})
+
DEF_TRAVERSE_STMT(CXXOperatorCallExpr, {})
DEF_TRAVERSE_STMT(CXXRewrittenBinaryOperator, {
if (!getDerived().shouldVisitImplicitCode()) {
diff --git a/clang/include/clang/AST/StmtSYCL.h b/clang/include/clang/AST/StmtSYCL.h
new file mode 100644
index 00000000000000..ac356cb0bf3841
--- /dev/null
+++ b/clang/include/clang/AST/StmtSYCL.h
@@ -0,0 +1,95 @@
+//===- StmtSYCL.h - Classes for SYCL kernel calls ---------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+/// \file
+/// This file defines SYCL AST classes used to represent calls to SYCL kernels.
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CLANG_AST_STMTSYCL_H
+#define LLVM_CLANG_AST_STMTSYCL_H
+
+#include "clang/AST/ASTContext.h"
+#include "clang/AST/Decl.h"
+#include "clang/AST/Stmt.h"
+#include "clang/Basic/SourceLocation.h"
+
+namespace clang {
+
+//===----------------------------------------------------------------------===//
+// AST classes for SYCL kernel calls.
+//===----------------------------------------------------------------------===//
+
+/// SYCLKernelCallStmt represents the transformation that is applied to the body
+/// of a function declared with the sycl_kernel_entry_point attribute. The body
+/// of such a function specifies the statements to be executed on a SYCL device
+/// to invoke a SYCL kernel with a particular set of kernel arguments. The
+/// SYCLKernelCallStmt associates an original statement (the compound statement
+/// that is the function body) with an OutlinedFunctionDecl that holds the
+/// kernel parameters and the transformed body. During code generation, the
+/// OutlinedFunctionDecl is used to emit an offload kernel entry point suitable
+/// for invocation from a SYCL library implementation. If executed, the
+/// SYCLKernelCallStmt behaves as a no-op; no code generation is performed for
+/// it.
+class SYCLKernelCallStmt : public Stmt {
+ friend class ASTStmtReader;
+ friend class ASTStmtWriter;
+
+private:
+ Stmt *OriginalStmt = nullptr;
+ OutlinedFunctionDecl *OFDecl = nullptr;
+
+public:
+ /// Construct a SYCL kernel call statement.
+ SYCLKernelCallStmt(Stmt *OS, OutlinedFunctionDecl *OFD)
+ : Stmt(SYCLKernelCallStmtClass), OriginalStmt(OS), OFDecl(OFD) {}
+
+ /// Construct an empty SYCL kernel call statement.
+ SYCLKernelCallStmt(EmptyShell Empty)
+ : Stmt(SYCLKernelCallStmtClass, Empty) {}
+
+ /// Retrieve the model statement.
+ Stmt *getOriginalStmt() { return OriginalStmt; }
+ const Stmt *getOriginalStmt() const { return OriginalStmt; }
+ void setOriginalStmt(Stmt *S) { OriginalStmt = S; }
+
+ /// Retrieve the outlined function declaration.
+ OutlinedFunctionDecl *getOutlinedFunctionDecl() { return OFDecl; }
+ const OutlinedFunctionDecl *getOutlinedFunctionDecl() const { return OFDecl; }
+
+ /// Set the outlined function declaration.
+ void setOutlinedFunctionDecl(OutlinedFunctionDecl *OFD) {
+ OFDecl = OFD;
+ }
+
+ SourceLocation getBeginLoc() const LLVM_READONLY {
+ return getOriginalStmt()->getBeginLoc();
+ }
+
+ SourceLocation getEndLoc() const LLVM_READONLY {
+ return getOriginalStmt()->getEndLoc();
+ }
+
+ SourceRange getSourceRange() const LLVM_READONLY {
+ return getOriginalStmt()->getSourceRange();
+ }
+
+ static bool classof(const Stmt *T) {
+ return T->getStmtClass() == SYCLKernelCallStmtClass;
+ }
+
+ child_range children() {
+ return child_range(&OriginalStmt, &OriginalStmt + 1);
+ }
+
+ const_child_range children() const {
+ return const_child_range(&OriginalStmt, &OriginalStmt + 1);
+ }
+};
+
+} // end namespace clang
+
+#endif
diff --git a/clang/include/clang/AST/StmtVisitor.h b/clang/include/clang/AST/StmtVisitor.h
index 990aa2df180d43..8b7b728deaff27 100644
--- a/clang/include/clang/AST/StmtVisitor.h
+++ b/clang/include/clang/AST/StmtVisitor.h
@@ -22,6 +22,7 @@
#include "clang/AST/StmtObjC.h"
#include "clang/AST/StmtOpenACC.h"
#include "clang/AST/StmtOpenMP.h"
+#include "clang/AST/StmtSYCL.h"
#include "clang/Basic/LLVM.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Casting.h"
diff --git a/clang/include/clang/Basic/DeclNodes.td b/clang/include/clang/Basic/DeclNodes.td
index 48396e85c5adac..723113dc2486e0 100644
--- a/clang/include/clang/Basic/DeclNodes.td
+++ b/clang/include/clang/Basic/DeclNodes.td
@@ -101,6 +101,7 @@ def Friend : DeclNode<Decl>;
def FriendTemplate : DeclNode<Decl>;
def StaticAssert : DeclNode<Decl>;
def Block : DeclNode<Decl, "blocks">, DeclContext;
+def OutlinedFunction : DeclNode<Decl>, DeclContext;
def Captured : DeclNode<Decl>, DeclContext;
def Import : DeclNode<Decl>;
def OMPThreadPrivate : DeclNode<Decl>;
diff --git a/clang/include/clang/Basic/StmtNodes.td b/clang/include/clang/Basic/StmtNodes.td
index ce2c48bd3c84e9..53fc77bbbcecc1 100644
--- a/clang/include/clang/Basic/StmtNodes.td
+++ b/clang/include/clang/Basic/StmtNodes.td
@@ -24,6 +24,7 @@ def SwitchCase : StmtNode<Stmt, 1>;
def CaseStmt : StmtNode<SwitchCase>;
def DefaultStmt : StmtNode<SwitchCase>;
def CapturedStmt : StmtNode<Stmt>;
+def SYCLKernelCallStmt : StmtNode<Stmt>;
// Statements that might produce a value (for example, as the last non-null
// statement in a GNU statement-expression).
diff --git a/clang/include/clang/Sema/SemaSYCL.h b/clang/include/clang/Sema/SemaSYCL.h
index 5bb0de40c886c7..b4f607d1287bc7 100644
--- a/clang/include/clang/Sema/SemaSYCL.h
+++ b/clang/include/clang/Sema/SemaSYCL.h
@@ -65,6 +65,7 @@ class SemaSYCL : public SemaBase {
void handleKernelEntryPointAttr(Decl *D, const ParsedAttr &AL);
void CheckSYCLEntryPointFunctionDecl(FunctionDecl *FD);
+ StmtResult BuildSYCLKernelCallStmt(FunctionDecl *FD, Stmt *Body);
};
} // namespace clang
diff --git a/clang/include/clang/Sema/Template.h b/clang/include/clang/Sema/Template.h
index 9800f75f676aaf..4206bd50b13dd6 100644
--- a/clang/include/clang/Sema/Template.h
+++ b/clang/include/clang/Sema/Template.h
@@ -627,7 +627,10 @@ enum class TemplateSubstitutionKind : char {
#define EMPTY(DERIVED, BASE)
#define LIFETIMEEXTENDEDTEMPORARY(DERIVED, BASE)
- // Decls which use special-case instantiation code.
+// Decls which never appear inside a template.
+#define OUTLINEDFUNCTION(DERIVED, BASE)
+
+// Decls which use special-case instantiation code.
#define BLOCK(DERIVED, BASE)
#define CAPTURED(DERIVED, BASE)
#define IMPLICITPARAM(DERIVED, BASE)
diff --git a/clang/include/clang/Serialization/ASTBitCodes.h b/clang/include/clang/Serialization/ASTBitCodes.h
index aac165130b7192..87ac62551a1424 100644
--- a/clang/include/clang/Serialization/ASTBitCodes.h
+++ b/clang/include/clang/Serialization/ASTBitCodes.h
@@ -1312,6 +1312,9 @@ enum DeclCode {
/// A BlockDecl record.
DECL_BLOCK,
+ /// A OutlinedFunctionDecl record.
+ DECL_OUTLINEDFUNCTION,
+
/// A CapturedDecl record.
DECL_CAPTURED,
@@ -1588,6 +1591,9 @@ enum StmtCode {
/// A CapturedStmt record.
STMT_CAPTURED,
+ /// A SYCLKernelCallStmt record.
+ STMT_SYCLKERNELCALL,
+
/// A GCC-style AsmStmt record.
STMT_GCCASM,
diff --git a/clang/lib/AST/ASTStructuralEquivalence.cpp b/clang/lib/AST/ASTStructuralEquivalence.cpp
index 308551c3061510..eaf0748395268b 100644
--- a/clang/lib/AST/ASTStructuralEquivalence.cpp
+++ b/clang/lib/AST/ASTStructuralEquivalence.cpp
@@ -76,6 +76,7 @@
#include "clang/AST/StmtObjC.h"
#include "clang/AST/StmtOpenACC.h"
#include "clang/AST/StmtOpenMP.h"
+#include "clang/AST/StmtSYCL.h"
#include "clang/AST/TemplateBase.h"
#include "clang/AST/TemplateName.h"
#include "clang/AST/Type.h"
diff --git a/clang/lib/AST/Decl.cpp b/clang/lib/AST/Decl.cpp
index 97e23dd1aaa920..5bce2c37bf0584 100644
--- a/clang/lib/AST/Decl.cpp
+++ b/clang/lib/AST/Decl.cpp
@@ -5440,6 +5440,31 @@ BlockDecl *BlockDecl::CreateDeserialized(ASTContext &C, GlobalDeclID ID) {
return new (C, ID) BlockDecl(nullptr, SourceLocation());
}
+
+OutlinedFunctionDecl::OutlinedFunctionDecl(DeclContext *DC, unsigned NumParams)
+ : Decl(OutlinedFunction, DC, SourceLocation()), DeclContext(OutlinedFunction),
+ NumParams(NumParams), BodyAndNothrow(nullptr, false) {}
+
+OutlinedFunctionDecl *OutlinedFunctionDecl::Create(ASTContext &C, DeclContext *DC,
+ unsigned NumParams) {
+ return new (C, DC, additionalSizeToAlloc<ImplicitParamDecl *>(NumParams))
+ OutlinedFunctionDecl(DC, NumParams);
+}
+
+OutlinedFunctionDecl *OutlinedFunctionDecl::CreateDeserialized(ASTContext &C,
+ GlobalDeclID ID,
+ unsigned NumParams) {
+ return new (C, ID, additionalSizeToAlloc<ImplicitParamDecl *>(NumParams))
+ OutlinedFunctionDecl(nullptr, NumParams);
+}
+
+Stmt *OutlinedFunctionDecl::getBody() const { return BodyAndNothrow.getPointer(); }
+void OutlinedFunctionDecl::setBody(Stmt *B) { BodyAndNothrow.setPointer(B); }
+
+bool OutlinedFunctionDecl::isNothrow() const { return BodyAndNothrow.getInt(); }
+void OutlinedFunctionDecl::setNothrow(bool Nothrow) { BodyAndNothrow.setInt(Nothrow); }
+
+
CapturedDecl::CapturedDecl(DeclContext *DC, unsigned NumParams)
: Decl(Captured, DC, SourceLocation()), DeclContext(Captured),
NumParams(NumParams), ContextParam(0), BodyAndNothrow(nullptr, false) {}
diff --git a/clang/lib/AST/DeclBase.cpp b/clang/lib/AST/DeclBase.cpp
index fb701f76231bcd..77ca8c5c8accd0 100644
--- a/clang/lib/AST/DeclBase.cpp
+++ b/clang/lib/AST/DeclBase.cpp
@@ -958,6 +958,7 @@ unsigned Decl::getIdentifierNamespaceForKind(Kind DeclKind) {
case PragmaDetectMismatch:
case Block:
case Captured:
+ case OutlinedFunction:
case TranslationUnit:
case ExternCContext:
case Decomposition:
@@ -1237,6 +1238,8 @@ template <class T> static Decl *getNonClosureContext(T *D) {
return getNonClosureContext(BD->getParent());
if (auto *CD = dyn_cast<CapturedDecl>(D))
return getNonClosureContext(CD->getParent());
+ if (auto *OFD = dyn_cast<OutlinedFunctionDecl>(D))
+ return getNonClosureContext(OFD->getParent());
return nullptr;
}
@@ -1429,6 +1432,7 @@ DeclContext *DeclContext::getPrimaryContext() {
case Decl::TopLevelStmt:
case Decl::Block:
case Decl::Captured:
+ case Decl::OutlinedFunction:
case Decl::OMPDeclareReduction:
case Decl::OMPDeclareMapper:
case Decl::RequiresExprBody:
diff --git a/clang/lib/AST/Stmt.cpp b/clang/lib/AST/Stmt.cpp
index d6a351a78c7ba8..685c00d0cb44f8 100644
--- a/clang/lib/AST/Stmt.cpp
+++ b/clang/lib/AST/Stmt.cpp
@@ -25,6 +25,7 @@
#include "clang/AST/StmtObjC.h"
#include "clang/AST/StmtOpenACC.h"
#include "clang/AST/StmtOpenMP.h"
+#include "clang/AST/StmtSYCL.h"
#include "clang/AST/Type.h"
#include "clang/Basic/CharInfo.h"
#include "clang/Basic/LLVM.h"
diff --git a/clang/lib/AST/StmtPrinter.cpp b/clang/lib/AST/StmtPrinter.cpp
index 52bcb5135d3513..b5def6fbe525c3 100644
--- a/clang/lib/AST/StmtPrinter.cpp
+++ b/clang/lib/AST/StmtPrinter.cpp
@@ -30,6 +30,7 @@
#include "clang/AST/StmtCXX.h"
#include "clang/AST/StmtObjC.h"
#include "clang/AST/StmtOpenMP.h"
+#include "clang/AST/StmtSYCL.h"
#include "clang/AST/StmtVisitor.h"
#include "clang/AST/TemplateBase.h"
#include "clang/AST/Type.h"
@@ -582,6 +583,10 @@ void StmtPrinter::VisitCapturedStmt(CapturedStmt *Node) {
PrintStmt(Node->getCapturedDecl()->getBody());
}
+void StmtPrinter::VisitSYCLKernelCallStmt(SYCLKernelCallStmt *Node) {
+ PrintStmt(Node->getOutlinedFunctionDecl()->getBody());
+}
+
void StmtPrinter::VisitObjCAtTryStmt(ObjCAtTryStmt *Node) {
Indent() << "@try";
if (auto *TS = dyn_cast<CompoundStmt>(Node->getTryBody())) {
diff --git a/clang/lib/AST/StmtProfile.cpp b/clang/lib/AST/StmtProfile.cpp
index 0f1ebc68a4f762..85b59f714ba845 100644
--- a/clang/lib/AST/StmtProfile.cpp
+++ b/clang/lib/AST/StmtProfile.cpp
@@ -392,6 +392,10 @@ void StmtProfiler::VisitCapturedStmt(const CapturedStmt *S) {
VisitStmt(S);
}
+void StmtProfiler::VisitSYCLKernelCallStmt(const SYCLKernelCallStmt *S) {
+ VisitStmt(S);
+}
+
void StmtProfiler::VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
VisitStmt(S);
}
diff --git a/clang/lib/CodeGen/CGDecl.cpp b/clang/lib/CodeGen/CGDecl.cpp
index 6f3ff050cb6978..cda1a15e92f622 100644
--- a/clang/lib/CodeGen/CGDecl.cpp
+++ b/clang/lib/CodeGen/CGDecl.cpp
@@ -97,6 +97,7 @@ void CodeGenFunction::EmitDecl(const Decl &D) {
case Decl::Friend:
case Decl::FriendTemplate:
case Decl::Block:
+ case Decl::OutlinedFunction:
case Decl::Captured:
case Decl::UsingShadow:
case Decl::ConstructorUsingShadow:
diff --git a/clang/lib/CodeGen/CGStmt.cpp b/clang/lib/CodeGen/CGStmt.cpp
index c87ec899798aa2..9d8baa85ba5923 100644
--- a/clang/lib/CodeGen/CGStmt.cpp
+++ b/clang/lib/CodeGen/CGStmt.cpp
@@ -114,6 +114,7 @@ void CodeGenFunction::EmitStmt(const Stmt *S, ArrayRef<const Attr *> Attrs) {
case Stmt::DefaultStmtClass:
case Stmt::CaseStmtClass:
case Stmt::SEHLeaveStmtClass:
+ case Stmt::SYCLKernelCallStmtClass:
llvm_unreachable("should have emitted these statements as simple");
#define STMT(Type, Base)
@@ -527,6 +528,23 @@ bool CodeGenFunction::EmitSimpleStmt(const Stmt *S,
case Stmt::SEHLeaveStmtClass:
EmitSEHLeaveStmt(cast<SEHLeaveStmt>(*S));
break;
+ case Stmt::SYCLKernelCallStmtClass:
+ // SYCL kernel call statements are generated as wrappers around the body
+ // of functions declared with the sycl_kernel_entry_point attribute. Such
+ // functions are used to specify how a SYCL kernel (a function object) is
+ // to be invoked; the SYCL kernel call statement contains a transformed
+ // variation of the function body and...
[truncated]
|
@llvm/pr-subscribers-clang-codegen Author: Tom Honermann (tahonermann) ChangesA SYCL kernel entry point function is a non-member function or a static member function declared with the The offload kernel entry point function (sometimes referred to as the SYCL kernel caller function) is generated from the SYCL kernel entry point function by a transformation of the function parameters followed by a transformation of the function body to replace references to the original parameters with references to the transformed ones. Exactly how parameters are transformed will be explained in a future change that implements non-trivial transformations. For now, it suffices to state that a given parameter of the SYCL kernel entry point function may be transformed to multiple parameters of the offload kernel entry point as needed to satisfy offload kernel argument passing requirements. Parameters that are decomposed in this way are reconstituted as local variables in the body of the generated offload kernel entry point function. For example, given the following SYCL kernel entry point function definition:
and the following call:
the corresponding offload kernel entry point function that is generated might look as follows (assuming
Other details of the generated offload kernel entry point function, such as its name and calling convention, are implementation details that need not be reflected in the AST and may differ across target devices. For that reason, only the transformation described above is represented in the AST; other details will be filled in during code generation. These transformations are represented using new AST nodes introduced with this change.
Any ODR-use of the SYCL kernel entry point function will (with future changes) suffice for the offload kernel entry point to be emitted. An actual call to the SYCL kernel entry point function will result in a call to the function. However, evaluation of a Patch is 52.12 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/122379.diff 34 Files Affected:
diff --git a/clang/include/clang/AST/ASTNodeTraverser.h b/clang/include/clang/AST/ASTNodeTraverser.h
index f5652b295de168..3bc0bdff2bdd12 100644
--- a/clang/include/clang/AST/ASTNodeTraverser.h
+++ b/clang/include/clang/AST/ASTNodeTraverser.h
@@ -158,8 +158,8 @@ class ASTNodeTraverser
ConstStmtVisitor<Derived>::Visit(S);
// Some statements have custom mechanisms for dumping their children.
- if (isa<DeclStmt>(S) || isa<GenericSelectionExpr>(S) ||
- isa<RequiresExpr>(S) || isa<OpenACCWaitConstruct>(S))
+ if (isa<DeclStmt, GenericSelectionExpr, RequiresExpr,
+ OpenACCWaitConstruct, SYCLKernelCallStmt>(S))
return;
if (Traversal == TK_IgnoreUnlessSpelledInSource &&
@@ -585,6 +585,12 @@ class ASTNodeTraverser
void VisitTopLevelStmtDecl(const TopLevelStmtDecl *D) { Visit(D->getStmt()); }
+ void VisitOutlinedFunctionDecl(const OutlinedFunctionDecl *D) {
+ for (const ImplicitParamDecl *Parameter : D->parameters())
+ Visit(Parameter);
+ Visit(D->getBody());
+ }
+
void VisitCapturedDecl(const CapturedDecl *D) { Visit(D->getBody()); }
void VisitOMPThreadPrivateDecl(const OMPThreadPrivateDecl *D) {
@@ -815,6 +821,12 @@ class ASTNodeTraverser
Visit(Node->getCapturedDecl());
}
+ void VisitSYCLKernelCallStmt(const SYCLKernelCallStmt *Node) {
+ Visit(Node->getOriginalStmt());
+ if (Traversal != TK_IgnoreUnlessSpelledInSource)
+ Visit(Node->getOutlinedFunctionDecl());
+ }
+
void VisitOMPExecutableDirective(const OMPExecutableDirective *Node) {
for (const auto *C : Node->clauses())
Visit(C);
diff --git a/clang/include/clang/AST/Decl.h b/clang/include/clang/AST/Decl.h
index 16fc98aa1a57f3..901ec1e48ca080 100644
--- a/clang/include/clang/AST/Decl.h
+++ b/clang/include/clang/AST/Decl.h
@@ -4678,6 +4678,96 @@ class BlockDecl : public Decl, public DeclContext {
}
};
+/// Represents a partial function definition.
+///
+/// An outlined function declaration contains the parameters and body of
+/// a function independent of other function definition concerns such
+/// as function name, type, and calling convention. Such declarations may
+/// be used to hold a parameterized and transformed sequence of statements
+/// used to generate a target dependent function definition without losing
+/// association with the original statements. See SYCLKernelCallStmt as an
+/// example.
+class OutlinedFunctionDecl final
+ : public Decl,
+ public DeclContext,
+ private llvm::TrailingObjects<OutlinedFunctionDecl, ImplicitParamDecl *> {
+protected:
+ size_t numTrailingObjects(OverloadToken<ImplicitParamDecl>) {
+ return NumParams;
+ }
+
+private:
+ /// The number of parameters to the outlined function.
+ unsigned NumParams;
+
+ /// The body of the outlined function.
+ llvm::PointerIntPair<Stmt *, 1, bool> BodyAndNothrow;
+
+ explicit OutlinedFunctionDecl(DeclContext *DC, unsigned NumParams);
+
+ ImplicitParamDecl *const *getParams() const {
+ return getTrailingObjects<ImplicitParamDecl *>();
+ }
+
+ ImplicitParamDecl **getParams() {
+ return getTrailingObjects<ImplicitParamDecl *>();
+ }
+
+public:
+ friend class ASTDeclReader;
+ friend class ASTDeclWriter;
+ friend TrailingObjects;
+
+ static OutlinedFunctionDecl *Create(ASTContext &C, DeclContext *DC,
+ unsigned NumParams);
+ static OutlinedFunctionDecl *CreateDeserialized(ASTContext &C,
+ GlobalDeclID ID,
+ unsigned NumParams);
+
+ Stmt *getBody() const override;
+ void setBody(Stmt *B);
+
+ bool isNothrow() const;
+ void setNothrow(bool Nothrow = true);
+
+ unsigned getNumParams() const { return NumParams; }
+
+ ImplicitParamDecl *getParam(unsigned i) const {
+ assert(i < NumParams);
+ return getParams()[i];
+ }
+ void setParam(unsigned i, ImplicitParamDecl *P) {
+ assert(i < NumParams);
+ getParams()[i] = P;
+ }
+
+ // ArrayRef interface to parameters.
+ ArrayRef<ImplicitParamDecl *> parameters() const {
+ return {getParams(), getNumParams()};
+ }
+ MutableArrayRef<ImplicitParamDecl *> parameters() {
+ return {getParams(), getNumParams()};
+ }
+
+ using param_iterator = ImplicitParamDecl *const *;
+ using param_range = llvm::iterator_range<param_iterator>;
+
+ /// Retrieve an iterator pointing to the first parameter decl.
+ param_iterator param_begin() const { return getParams(); }
+ /// Retrieve an iterator one past the last parameter decl.
+ param_iterator param_end() const { return getParams() + NumParams; }
+
+ // Implement isa/cast/dyncast/etc.
+ static bool classof(const Decl *D) { return classofKind(D->getKind()); }
+ static bool classofKind(Kind K) { return K == OutlinedFunction; }
+ static DeclContext *castToDeclContext(const OutlinedFunctionDecl *D) {
+ return static_cast<DeclContext *>(const_cast<OutlinedFunctionDecl *>(D));
+ }
+ static OutlinedFunctionDecl *castFromDeclContext(const DeclContext *DC) {
+ return static_cast<OutlinedFunctionDecl *>(const_cast<DeclContext *>(DC));
+ }
+};
+
/// Represents the body of a CapturedStmt, and serves as its DeclContext.
class CapturedDecl final
: public Decl,
diff --git a/clang/include/clang/AST/RecursiveASTVisitor.h b/clang/include/clang/AST/RecursiveASTVisitor.h
index d500f4eadef757..c4a1d03f1b3d10 100644
--- a/clang/include/clang/AST/RecursiveASTVisitor.h
+++ b/clang/include/clang/AST/RecursiveASTVisitor.h
@@ -37,6 +37,7 @@
#include "clang/AST/StmtObjC.h"
#include "clang/AST/StmtOpenACC.h"
#include "clang/AST/StmtOpenMP.h"
+#include "clang/AST/StmtSYCL.h"
#include "clang/AST/TemplateBase.h"
#include "clang/AST/TemplateName.h"
#include "clang/AST/Type.h"
@@ -1581,6 +1582,11 @@ DEF_TRAVERSE_DECL(BlockDecl, {
ShouldVisitChildren = false;
})
+DEF_TRAVERSE_DECL(OutlinedFunctionDecl, {
+ TRY_TO(TraverseStmt(D->getBody()));
+ ShouldVisitChildren = false;
+})
+
DEF_TRAVERSE_DECL(CapturedDecl, {
TRY_TO(TraverseStmt(D->getBody()));
ShouldVisitChildren = false;
@@ -2904,6 +2910,14 @@ DEF_TRAVERSE_STMT(SEHFinallyStmt, {})
DEF_TRAVERSE_STMT(SEHLeaveStmt, {})
DEF_TRAVERSE_STMT(CapturedStmt, { TRY_TO(TraverseDecl(S->getCapturedDecl())); })
+DEF_TRAVERSE_STMT(SYCLKernelCallStmt, {
+ if (getDerived().shouldVisitImplicitCode()) {
+ TRY_TO(TraverseStmt(S->getOriginalStmt()));
+ TRY_TO(TraverseDecl(S->getOutlinedFunctionDecl()));
+ ShouldVisitChildren = false;
+ }
+})
+
DEF_TRAVERSE_STMT(CXXOperatorCallExpr, {})
DEF_TRAVERSE_STMT(CXXRewrittenBinaryOperator, {
if (!getDerived().shouldVisitImplicitCode()) {
diff --git a/clang/include/clang/AST/StmtSYCL.h b/clang/include/clang/AST/StmtSYCL.h
new file mode 100644
index 00000000000000..ac356cb0bf3841
--- /dev/null
+++ b/clang/include/clang/AST/StmtSYCL.h
@@ -0,0 +1,95 @@
+//===- StmtSYCL.h - Classes for SYCL kernel calls ---------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+/// \file
+/// This file defines SYCL AST classes used to represent calls to SYCL kernels.
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CLANG_AST_STMTSYCL_H
+#define LLVM_CLANG_AST_STMTSYCL_H
+
+#include "clang/AST/ASTContext.h"
+#include "clang/AST/Decl.h"
+#include "clang/AST/Stmt.h"
+#include "clang/Basic/SourceLocation.h"
+
+namespace clang {
+
+//===----------------------------------------------------------------------===//
+// AST classes for SYCL kernel calls.
+//===----------------------------------------------------------------------===//
+
+/// SYCLKernelCallStmt represents the transformation that is applied to the body
+/// of a function declared with the sycl_kernel_entry_point attribute. The body
+/// of such a function specifies the statements to be executed on a SYCL device
+/// to invoke a SYCL kernel with a particular set of kernel arguments. The
+/// SYCLKernelCallStmt associates an original statement (the compound statement
+/// that is the function body) with an OutlinedFunctionDecl that holds the
+/// kernel parameters and the transformed body. During code generation, the
+/// OutlinedFunctionDecl is used to emit an offload kernel entry point suitable
+/// for invocation from a SYCL library implementation. If executed, the
+/// SYCLKernelCallStmt behaves as a no-op; no code generation is performed for
+/// it.
+class SYCLKernelCallStmt : public Stmt {
+ friend class ASTStmtReader;
+ friend class ASTStmtWriter;
+
+private:
+ Stmt *OriginalStmt = nullptr;
+ OutlinedFunctionDecl *OFDecl = nullptr;
+
+public:
+ /// Construct a SYCL kernel call statement.
+ SYCLKernelCallStmt(Stmt *OS, OutlinedFunctionDecl *OFD)
+ : Stmt(SYCLKernelCallStmtClass), OriginalStmt(OS), OFDecl(OFD) {}
+
+ /// Construct an empty SYCL kernel call statement.
+ SYCLKernelCallStmt(EmptyShell Empty)
+ : Stmt(SYCLKernelCallStmtClass, Empty) {}
+
+ /// Retrieve the model statement.
+ Stmt *getOriginalStmt() { return OriginalStmt; }
+ const Stmt *getOriginalStmt() const { return OriginalStmt; }
+ void setOriginalStmt(Stmt *S) { OriginalStmt = S; }
+
+ /// Retrieve the outlined function declaration.
+ OutlinedFunctionDecl *getOutlinedFunctionDecl() { return OFDecl; }
+ const OutlinedFunctionDecl *getOutlinedFunctionDecl() const { return OFDecl; }
+
+ /// Set the outlined function declaration.
+ void setOutlinedFunctionDecl(OutlinedFunctionDecl *OFD) {
+ OFDecl = OFD;
+ }
+
+ SourceLocation getBeginLoc() const LLVM_READONLY {
+ return getOriginalStmt()->getBeginLoc();
+ }
+
+ SourceLocation getEndLoc() const LLVM_READONLY {
+ return getOriginalStmt()->getEndLoc();
+ }
+
+ SourceRange getSourceRange() const LLVM_READONLY {
+ return getOriginalStmt()->getSourceRange();
+ }
+
+ static bool classof(const Stmt *T) {
+ return T->getStmtClass() == SYCLKernelCallStmtClass;
+ }
+
+ child_range children() {
+ return child_range(&OriginalStmt, &OriginalStmt + 1);
+ }
+
+ const_child_range children() const {
+ return const_child_range(&OriginalStmt, &OriginalStmt + 1);
+ }
+};
+
+} // end namespace clang
+
+#endif
diff --git a/clang/include/clang/AST/StmtVisitor.h b/clang/include/clang/AST/StmtVisitor.h
index 990aa2df180d43..8b7b728deaff27 100644
--- a/clang/include/clang/AST/StmtVisitor.h
+++ b/clang/include/clang/AST/StmtVisitor.h
@@ -22,6 +22,7 @@
#include "clang/AST/StmtObjC.h"
#include "clang/AST/StmtOpenACC.h"
#include "clang/AST/StmtOpenMP.h"
+#include "clang/AST/StmtSYCL.h"
#include "clang/Basic/LLVM.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Casting.h"
diff --git a/clang/include/clang/Basic/DeclNodes.td b/clang/include/clang/Basic/DeclNodes.td
index 48396e85c5adac..723113dc2486e0 100644
--- a/clang/include/clang/Basic/DeclNodes.td
+++ b/clang/include/clang/Basic/DeclNodes.td
@@ -101,6 +101,7 @@ def Friend : DeclNode<Decl>;
def FriendTemplate : DeclNode<Decl>;
def StaticAssert : DeclNode<Decl>;
def Block : DeclNode<Decl, "blocks">, DeclContext;
+def OutlinedFunction : DeclNode<Decl>, DeclContext;
def Captured : DeclNode<Decl>, DeclContext;
def Import : DeclNode<Decl>;
def OMPThreadPrivate : DeclNode<Decl>;
diff --git a/clang/include/clang/Basic/StmtNodes.td b/clang/include/clang/Basic/StmtNodes.td
index ce2c48bd3c84e9..53fc77bbbcecc1 100644
--- a/clang/include/clang/Basic/StmtNodes.td
+++ b/clang/include/clang/Basic/StmtNodes.td
@@ -24,6 +24,7 @@ def SwitchCase : StmtNode<Stmt, 1>;
def CaseStmt : StmtNode<SwitchCase>;
def DefaultStmt : StmtNode<SwitchCase>;
def CapturedStmt : StmtNode<Stmt>;
+def SYCLKernelCallStmt : StmtNode<Stmt>;
// Statements that might produce a value (for example, as the last non-null
// statement in a GNU statement-expression).
diff --git a/clang/include/clang/Sema/SemaSYCL.h b/clang/include/clang/Sema/SemaSYCL.h
index 5bb0de40c886c7..b4f607d1287bc7 100644
--- a/clang/include/clang/Sema/SemaSYCL.h
+++ b/clang/include/clang/Sema/SemaSYCL.h
@@ -65,6 +65,7 @@ class SemaSYCL : public SemaBase {
void handleKernelEntryPointAttr(Decl *D, const ParsedAttr &AL);
void CheckSYCLEntryPointFunctionDecl(FunctionDecl *FD);
+ StmtResult BuildSYCLKernelCallStmt(FunctionDecl *FD, Stmt *Body);
};
} // namespace clang
diff --git a/clang/include/clang/Sema/Template.h b/clang/include/clang/Sema/Template.h
index 9800f75f676aaf..4206bd50b13dd6 100644
--- a/clang/include/clang/Sema/Template.h
+++ b/clang/include/clang/Sema/Template.h
@@ -627,7 +627,10 @@ enum class TemplateSubstitutionKind : char {
#define EMPTY(DERIVED, BASE)
#define LIFETIMEEXTENDEDTEMPORARY(DERIVED, BASE)
- // Decls which use special-case instantiation code.
+// Decls which never appear inside a template.
+#define OUTLINEDFUNCTION(DERIVED, BASE)
+
+// Decls which use special-case instantiation code.
#define BLOCK(DERIVED, BASE)
#define CAPTURED(DERIVED, BASE)
#define IMPLICITPARAM(DERIVED, BASE)
diff --git a/clang/include/clang/Serialization/ASTBitCodes.h b/clang/include/clang/Serialization/ASTBitCodes.h
index aac165130b7192..87ac62551a1424 100644
--- a/clang/include/clang/Serialization/ASTBitCodes.h
+++ b/clang/include/clang/Serialization/ASTBitCodes.h
@@ -1312,6 +1312,9 @@ enum DeclCode {
/// A BlockDecl record.
DECL_BLOCK,
+ /// A OutlinedFunctionDecl record.
+ DECL_OUTLINEDFUNCTION,
+
/// A CapturedDecl record.
DECL_CAPTURED,
@@ -1588,6 +1591,9 @@ enum StmtCode {
/// A CapturedStmt record.
STMT_CAPTURED,
+ /// A SYCLKernelCallStmt record.
+ STMT_SYCLKERNELCALL,
+
/// A GCC-style AsmStmt record.
STMT_GCCASM,
diff --git a/clang/lib/AST/ASTStructuralEquivalence.cpp b/clang/lib/AST/ASTStructuralEquivalence.cpp
index 308551c3061510..eaf0748395268b 100644
--- a/clang/lib/AST/ASTStructuralEquivalence.cpp
+++ b/clang/lib/AST/ASTStructuralEquivalence.cpp
@@ -76,6 +76,7 @@
#include "clang/AST/StmtObjC.h"
#include "clang/AST/StmtOpenACC.h"
#include "clang/AST/StmtOpenMP.h"
+#include "clang/AST/StmtSYCL.h"
#include "clang/AST/TemplateBase.h"
#include "clang/AST/TemplateName.h"
#include "clang/AST/Type.h"
diff --git a/clang/lib/AST/Decl.cpp b/clang/lib/AST/Decl.cpp
index 97e23dd1aaa920..5bce2c37bf0584 100644
--- a/clang/lib/AST/Decl.cpp
+++ b/clang/lib/AST/Decl.cpp
@@ -5440,6 +5440,31 @@ BlockDecl *BlockDecl::CreateDeserialized(ASTContext &C, GlobalDeclID ID) {
return new (C, ID) BlockDecl(nullptr, SourceLocation());
}
+
+OutlinedFunctionDecl::OutlinedFunctionDecl(DeclContext *DC, unsigned NumParams)
+ : Decl(OutlinedFunction, DC, SourceLocation()), DeclContext(OutlinedFunction),
+ NumParams(NumParams), BodyAndNothrow(nullptr, false) {}
+
+OutlinedFunctionDecl *OutlinedFunctionDecl::Create(ASTContext &C, DeclContext *DC,
+ unsigned NumParams) {
+ return new (C, DC, additionalSizeToAlloc<ImplicitParamDecl *>(NumParams))
+ OutlinedFunctionDecl(DC, NumParams);
+}
+
+OutlinedFunctionDecl *OutlinedFunctionDecl::CreateDeserialized(ASTContext &C,
+ GlobalDeclID ID,
+ unsigned NumParams) {
+ return new (C, ID, additionalSizeToAlloc<ImplicitParamDecl *>(NumParams))
+ OutlinedFunctionDecl(nullptr, NumParams);
+}
+
+Stmt *OutlinedFunctionDecl::getBody() const { return BodyAndNothrow.getPointer(); }
+void OutlinedFunctionDecl::setBody(Stmt *B) { BodyAndNothrow.setPointer(B); }
+
+bool OutlinedFunctionDecl::isNothrow() const { return BodyAndNothrow.getInt(); }
+void OutlinedFunctionDecl::setNothrow(bool Nothrow) { BodyAndNothrow.setInt(Nothrow); }
+
+
CapturedDecl::CapturedDecl(DeclContext *DC, unsigned NumParams)
: Decl(Captured, DC, SourceLocation()), DeclContext(Captured),
NumParams(NumParams), ContextParam(0), BodyAndNothrow(nullptr, false) {}
diff --git a/clang/lib/AST/DeclBase.cpp b/clang/lib/AST/DeclBase.cpp
index fb701f76231bcd..77ca8c5c8accd0 100644
--- a/clang/lib/AST/DeclBase.cpp
+++ b/clang/lib/AST/DeclBase.cpp
@@ -958,6 +958,7 @@ unsigned Decl::getIdentifierNamespaceForKind(Kind DeclKind) {
case PragmaDetectMismatch:
case Block:
case Captured:
+ case OutlinedFunction:
case TranslationUnit:
case ExternCContext:
case Decomposition:
@@ -1237,6 +1238,8 @@ template <class T> static Decl *getNonClosureContext(T *D) {
return getNonClosureContext(BD->getParent());
if (auto *CD = dyn_cast<CapturedDecl>(D))
return getNonClosureContext(CD->getParent());
+ if (auto *OFD = dyn_cast<OutlinedFunctionDecl>(D))
+ return getNonClosureContext(OFD->getParent());
return nullptr;
}
@@ -1429,6 +1432,7 @@ DeclContext *DeclContext::getPrimaryContext() {
case Decl::TopLevelStmt:
case Decl::Block:
case Decl::Captured:
+ case Decl::OutlinedFunction:
case Decl::OMPDeclareReduction:
case Decl::OMPDeclareMapper:
case Decl::RequiresExprBody:
diff --git a/clang/lib/AST/Stmt.cpp b/clang/lib/AST/Stmt.cpp
index d6a351a78c7ba8..685c00d0cb44f8 100644
--- a/clang/lib/AST/Stmt.cpp
+++ b/clang/lib/AST/Stmt.cpp
@@ -25,6 +25,7 @@
#include "clang/AST/StmtObjC.h"
#include "clang/AST/StmtOpenACC.h"
#include "clang/AST/StmtOpenMP.h"
+#include "clang/AST/StmtSYCL.h"
#include "clang/AST/Type.h"
#include "clang/Basic/CharInfo.h"
#include "clang/Basic/LLVM.h"
diff --git a/clang/lib/AST/StmtPrinter.cpp b/clang/lib/AST/StmtPrinter.cpp
index 52bcb5135d3513..b5def6fbe525c3 100644
--- a/clang/lib/AST/StmtPrinter.cpp
+++ b/clang/lib/AST/StmtPrinter.cpp
@@ -30,6 +30,7 @@
#include "clang/AST/StmtCXX.h"
#include "clang/AST/StmtObjC.h"
#include "clang/AST/StmtOpenMP.h"
+#include "clang/AST/StmtSYCL.h"
#include "clang/AST/StmtVisitor.h"
#include "clang/AST/TemplateBase.h"
#include "clang/AST/Type.h"
@@ -582,6 +583,10 @@ void StmtPrinter::VisitCapturedStmt(CapturedStmt *Node) {
PrintStmt(Node->getCapturedDecl()->getBody());
}
+void StmtPrinter::VisitSYCLKernelCallStmt(SYCLKernelCallStmt *Node) {
+ PrintStmt(Node->getOutlinedFunctionDecl()->getBody());
+}
+
void StmtPrinter::VisitObjCAtTryStmt(ObjCAtTryStmt *Node) {
Indent() << "@try";
if (auto *TS = dyn_cast<CompoundStmt>(Node->getTryBody())) {
diff --git a/clang/lib/AST/StmtProfile.cpp b/clang/lib/AST/StmtProfile.cpp
index 0f1ebc68a4f762..85b59f714ba845 100644
--- a/clang/lib/AST/StmtProfile.cpp
+++ b/clang/lib/AST/StmtProfile.cpp
@@ -392,6 +392,10 @@ void StmtProfiler::VisitCapturedStmt(const CapturedStmt *S) {
VisitStmt(S);
}
+void StmtProfiler::VisitSYCLKernelCallStmt(const SYCLKernelCallStmt *S) {
+ VisitStmt(S);
+}
+
void StmtProfiler::VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
VisitStmt(S);
}
diff --git a/clang/lib/CodeGen/CGDecl.cpp b/clang/lib/CodeGen/CGDecl.cpp
index 6f3ff050cb6978..cda1a15e92f622 100644
--- a/clang/lib/CodeGen/CGDecl.cpp
+++ b/clang/lib/CodeGen/CGDecl.cpp
@@ -97,6 +97,7 @@ void CodeGenFunction::EmitDecl(const Decl &D) {
case Decl::Friend:
case Decl::FriendTemplate:
case Decl::Block:
+ case Decl::OutlinedFunction:
case Decl::Captured:
case Decl::UsingShadow:
case Decl::ConstructorUsingShadow:
diff --git a/clang/lib/CodeGen/CGStmt.cpp b/clang/lib/CodeGen/CGStmt.cpp
index c87ec899798aa2..9d8baa85ba5923 100644
--- a/clang/lib/CodeGen/CGStmt.cpp
+++ b/clang/lib/CodeGen/CGStmt.cpp
@@ -114,6 +114,7 @@ void CodeGenFunction::EmitStmt(const Stmt *S, ArrayRef<const Attr *> Attrs) {
case Stmt::DefaultStmtClass:
case Stmt::CaseStmtClass:
case Stmt::SEHLeaveStmtClass:
+ case Stmt::SYCLKernelCallStmtClass:
llvm_unreachable("should have emitted these statements as simple");
#define STMT(Type, Base)
@@ -527,6 +528,23 @@ bool CodeGenFunction::EmitSimpleStmt(const Stmt *S,
case Stmt::SEHLeaveStmtClass:
EmitSEHLeaveStmt(cast<SEHLeaveStmt>(*S));
break;
+ case Stmt::SYCLKernelCallStmtClass:
+ // SYCL kernel call statements are generated as wrappers around the body
+ // of functions declared with the sycl_kernel_entry_point attribute. Such
+ // functions are used to specify how a SYCL kernel (a function object) is
+ // to be invoked; the SYCL kernel call statement contains a transformed
+ // variation of the function body and...
[truncated]
|
You can test this locally with the following command:git-clang-format --diff 3630d9ef65b30af7e4ca78e668649bbc48b5be66 48b62dcd5acc8873a57886d98c6d2ce724dfd218 --extensions h,cpp -- clang/include/clang/AST/StmtSYCL.h clang/test/ASTSYCL/ast-dump-sycl-kernel-call-stmt.cpp clang/include/clang/AST/ASTNodeTraverser.h clang/include/clang/AST/Decl.h clang/include/clang/AST/RecursiveASTVisitor.h clang/include/clang/AST/StmtVisitor.h clang/include/clang/Sema/SemaSYCL.h clang/include/clang/Sema/Template.h clang/include/clang/Serialization/ASTBitCodes.h clang/lib/AST/ASTStructuralEquivalence.cpp clang/lib/AST/Decl.cpp clang/lib/AST/DeclBase.cpp clang/lib/AST/Stmt.cpp clang/lib/AST/StmtPrinter.cpp clang/lib/AST/StmtProfile.cpp clang/lib/CodeGen/CGDecl.cpp clang/lib/CodeGen/CGStmt.cpp clang/lib/CodeGen/CodeGenFunction.h clang/lib/Sema/SemaDecl.cpp clang/lib/Sema/SemaExceptionSpec.cpp clang/lib/Sema/SemaSYCL.cpp clang/lib/Sema/TreeTransform.h clang/lib/Serialization/ASTCommon.cpp clang/lib/Serialization/ASTReaderDecl.cpp clang/lib/Serialization/ASTReaderStmt.cpp clang/lib/Serialization/ASTWriterDecl.cpp clang/lib/Serialization/ASTWriterStmt.cpp clang/lib/StaticAnalyzer/Core/ExprEngine.cpp clang/test/ASTSYCL/ast-dump-sycl-kernel-entry-point.cpp clang/test/SemaSYCL/sycl-kernel-entry-point-attr-appertainment.cpp clang/tools/libclang/CIndex.cpp clang/tools/libclang/CXCursor.cpp View the diff from clang-format here.diff --git a/clang/include/clang/Sema/Template.h b/clang/include/clang/Sema/Template.h
index 4206bd50b1..070dba8a50 100644
--- a/clang/include/clang/Sema/Template.h
+++ b/clang/include/clang/Sema/Template.h
@@ -627,10 +627,10 @@ enum class TemplateSubstitutionKind : char {
#define EMPTY(DERIVED, BASE)
#define LIFETIMEEXTENDEDTEMPORARY(DERIVED, BASE)
-// Decls which never appear inside a template.
+ // Decls which never appear inside a template.
#define OUTLINEDFUNCTION(DERIVED, BASE)
-// Decls which use special-case instantiation code.
+ // Decls which use special-case instantiation code.
#define BLOCK(DERIVED, BASE)
#define CAPTURED(DERIVED, BASE)
#define IMPLICITPARAM(DERIVED, BASE)
|
199cbd1
to
20e04ac
Compare
@erichkeane, thank you for the prior code review. I believe I've addressed all the concerns you raised. There are two conversations I left unresolved pending any further response from you. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I only saw 1 unresolved, not 2? But I resolved it. Can you point out the other?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
3 comments are all 'nits', so once those conversations are resolved, this LGTM. Please resolve those 3 before merging.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I only saw 1 unresolved, not 2? But I resolved it. Can you point out the other?
The other one was this one; You gave it a thumbs up, so I went ahead and resolved it now.
A SYCL kernel entry point function is a non-member function or a static member function declared with the `sycl_kernel_entry_point` attribute. Such functions define a pattern for an offload kernel entry point function to be generated to enable execution of a SYCL kernel on a device. A SYCL library implementation orchestrates the invocation of these functions with corresponding SYCL kernel arguments in response to calls to SYCL kernel invocation functions specified by the SYCL 2020 specification. The offload kernel entry point function (sometimes referred to as the SYCL kernel caller function) is generated from the SYCL kernel entry point function by a transformation of the function parameters followed by a transformation of the function body to replace references to the original parameters with references to the transformed ones. Exactly how parameters are transformed will be explained in a future change that implements non-trivial transformations. For now, it suffices to state that a given parameter of the SYCL kernel entry point function may be transformed to multiple parameters of the offload kernel entry point as needed to satisfy offload kernel argument passing requirements. Parameters that are decomposed in this way are reconstituted as local variables in the body of the generated offload kernel entry point function. For example, given the following SYCL kernel entry point function definition: template<typename KernelNameType, typename KernelType> [[clang::sycl_kernel_entry_point(KernelNameType)]] void sycl_kernel_entry_point(KernelType kernel) { kernel(); } and the following call: struct Kernel { int dm1; int dm2; void operator()() const; }; Kernel k; sycl_kernel_entry_point<class kernel_name>(k); the corresponding offload kernel entry point function that is generated might look as follows (assuming `Kernel` is a type that requires decomposition): void offload_kernel_entry_point_for_kernel_name(int dm1, int dm2) { Kernel kernel{dm1, dm2}; kernel(); } Other details of the generated offload kernel entry point function, such as its name and callng convention, are implementation details that need not be reflected in the AST and may differ across target devices. For that reason, only the transformation described above is represented in the AST; other details will be filled in during code generation. These transformations are represented using new AST nodes introduced with this change. `OutlinedFunctionDecl` holds a sequence of `ImplicitParamDecl` nodes and a sequence of statement nodes that correspond to the transformed parameters and function body. `SYCLKernelCallStmt` wraps the original function body and associates it with an `OutlinedFunctionDecl` instance. For the example above, the AST generated for the `sycl_kernel_entry_point<kernel_name>` specialization would look as follows: FunctionDecl 'sycl_kernel_entry_point<kernel_name>(Kernel)' TemplateArgument type 'kernel_name' TemplateArgument type 'Kernel' ParmVarDecl kernel 'Kernel' SYCLKernelCallStmt CompoundStmt <original statements> OutlinedFunctionDecl ImplicitParamDecl 'dm1' 'int' ImplicitParamDecl 'dm2' 'int' CompoundStmt VarDecl 'kernel' 'Kernel' <initialization of 'kernel' with 'dm1' and 'dm2'> <transformed statements with redirected references of 'kernel'> Any ODR-use of the SYCL kernel entry point function will (with future changes) suffice for the offload kernel entry point to be emitted. An actual call to the SYCL kernel entry point function will result in a call to the function. However, evaluation of a `SYCLKernelCallStmt` statement is a no-op, so such calls will have no effect other than to trigger emission of the offload kernel entry point.
…with a function-try-block. The SYCL 2020 specification prohibits the use of C++ exceptions in device functions. Even if exceptions were not prohibited, it is unclear what the semantics would be for an exception that escapes the SYCL kernel entry point function; the boundary between host and device code could be an implicit noexcept boundary that results in program termination if violated, or the exception could perhaps be propagated to host code via the SYCL library. The semantics of function-try-blocks is that an exception raised in the function body that escapes to a function-try-block handler is implicitly rethrown. Pending support for C++ exceptions in device code and clear semantics for handling them at the host-device boundary, this change makes use of the sycl_kernel_entry_point attribute with a function defined with a function-try-block an error.
cbf49f7
to
dc93ffa
Compare
No idea why I can't respond to the individual threads but...
It kinda does? Clang-format is supposed to be adding those and coding standard is that we're supposed to be clang-format clean. That said, I've got no idea why it didn't try to add it here, I've definitely seen it do so before. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
make format happy, then LGTM.
@erichkeane, in order to reduce notifications, I replied to your comments as a review. You should generally be able to reply to them individually or as part of a batched review from the "Files changed" tab, but GitHub doesn't permit such replies from the "Conversation" tab (which is weird and seems dumb). However, if the comments are considered "outdated" because of new commits that changed the code the thread was anchored too, then such comments appear to become orphaned; they still appear in the "Conversation" tab, but not in the "Filed changed" tab and therefore can't be replied to. GitHub is a less than perfect tool, IMHO of course.
Thank you. I addressed the non-dumb clang-format complaints. I'll merge once @bader is happy with the changes. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@tahonermann thanks a lot for the work. I really appreciate the detailed description and comments!
@erichkeane did a great job with the code review, so I have nothing to add.
I just have one small question that does not affect the merge.
// for SYCL kernel parameters or local variables that reconstitute a decomposed | ||
// SYCL kernel argument. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I just want to confirm that we are not limited to adding new transforms (even ones that are not currently used in our downstream repository).
We may need new transforms to implement more efficient host <-> device SYCL kernel argument data transfer. We should be able to "compound" SYCL kernel arguments (i.e. the opposite of decomposing) if we find that more efficient. Right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@bader, we don't have an upstream design/implementation yet for handling decomposition and reconstitution of kernel arguments. My expectation is that we'll do something similar to what DPC++ does; there will be default rules for types generally with support for some types to be tagged as special SYCL types (types that have to be passed as independent kernel arguments) that force decomposition for some types that otherwise wouldn't require it. We'll likewise need to add support for types that have custom decomposition behavior similar to what DPC++ provides. I'm not aware of prior discussion regarding implicitly aggregating multiple kernel arguments into a single bundle argument. I'm curious how often doing so would actually be advantageous.
I am uneasy with the current DPC++ implementation since it requires the SYCL library and the front end to separately implement the decomposition rules. That seems like a recipe for bugs and means the library and front end have to be changed in lock step. I've been thinking about ways we can have the front end expose an interface to the library to perform or guide the decomposition so that the library can remain mostly agnostic of such details, but I don't have a design to share yet. The more we can isolate the library from such details, the more options we have for making changes.
A SYCL kernel entry point function is a non-member function or a static member function declared with the
sycl_kernel_entry_point
attribute. Such functions define a pattern for an offload kernel entry point function to be generated to enable execution of a SYCL kernel on a device. A SYCL library implementation orchestrates the invocation of these functions with corresponding SYCL kernel arguments in response to calls to SYCL kernel invocation functions specified by the SYCL 2020 specification.The offload kernel entry point function (sometimes referred to as the SYCL kernel caller function) is generated from the SYCL kernel entry point function by a transformation of the function parameters followed by a transformation of the function body to replace references to the original parameters with references to the transformed ones. Exactly how parameters are transformed will be explained in a future change that implements non-trivial transformations. For now, it suffices to state that a given parameter of the SYCL kernel entry point function may be transformed to multiple parameters of the offload kernel entry point as needed to satisfy offload kernel argument passing requirements. Parameters that are decomposed in this way are reconstituted as local variables in the body of the generated offload kernel entry point function.
For example, given the following SYCL kernel entry point function definition:
and the following call:
the corresponding offload kernel entry point function that is generated might look as follows (assuming
Kernel
is a type that requires decomposition):Other details of the generated offload kernel entry point function, such as its name and calling convention, are implementation details that need not be reflected in the AST and may differ across target devices. For that reason, only the transformation described above is represented in the AST; other details will be filled in during code generation.
These transformations are represented using new AST nodes introduced with this change.
OutlinedFunctionDecl
holds a sequence ofImplicitParamDecl
nodes and a sequence of statement nodes that correspond to the transformed parameters and function body.SYCLKernelCallStmt
wraps the original function body and associates it with anOutlinedFunctionDecl
instance. For the example above, the AST generated for thesycl_kernel_entry_point<kernel_name>
specialization would look as follows:Any ODR-use of the SYCL kernel entry point function will (with future changes) suffice for the offload kernel entry point to be emitted. An actual call to the SYCL kernel entry point function will result in a call to the function. However, evaluation of a
SYCLKernelCallStmt
statement is a no-op, so such calls will have no effect other than to trigger emission of the offload kernel entry point.Additionally, as a related change inspired by code review feedback, these changes disallow use of the
sycl_kernel_entry_point
attribute with functions defined with a function-try-block. The SYCL 2020 specification prohibits the use of C++ exceptions in device functions. Even if exceptions were not prohibited, it is unclear what the semantics would be for an exception that escapes the SYCL kernel entry point function; the boundary between host and device code could be an implicit noexcept boundary that results in program termination if violated, or the exception could perhaps be propagated to host code via the SYCL library. Pending support for C++ exceptions in device code and clear semantics for handling them at the host-device boundary, this change makes use of thesycl_kernel_entry_point
attribute with a function defined with a function-try-block an error.