Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SYCL] AST support for SYCL kernel entry point functions. #122379

Merged
merged 6 commits into from
Jan 22, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions clang/include/clang/AST/ASTNodeTraverser.h
Original file line number Diff line number Diff line change
@@ -158,8 +158,8 @@ class ASTNodeTraverser
ConstStmtVisitor<Derived>::Visit(S);

// Some statements have custom mechanisms for dumping their children.
if (isa<DeclStmt>(S) || isa<GenericSelectionExpr>(S) ||
isa<RequiresExpr>(S) || isa<OpenACCWaitConstruct>(S))
if (isa<DeclStmt, GenericSelectionExpr, RequiresExpr,
OpenACCWaitConstruct, SYCLKernelCallStmt>(S))
return;

if (Traversal == TK_IgnoreUnlessSpelledInSource &&
@@ -585,6 +585,12 @@ class ASTNodeTraverser

void VisitTopLevelStmtDecl(const TopLevelStmtDecl *D) { Visit(D->getStmt()); }

void VisitOutlinedFunctionDecl(const OutlinedFunctionDecl *D) {
for (const ImplicitParamDecl *Parameter : D->parameters())
Visit(Parameter);
Visit(D->getBody());
}

void VisitCapturedDecl(const CapturedDecl *D) { Visit(D->getBody()); }

void VisitOMPThreadPrivateDecl(const OMPThreadPrivateDecl *D) {
@@ -815,6 +821,12 @@ class ASTNodeTraverser
Visit(Node->getCapturedDecl());
}

void VisitSYCLKernelCallStmt(const SYCLKernelCallStmt *Node) {
Visit(Node->getOriginalStmt());
if (Traversal != TK_IgnoreUnlessSpelledInSource)
Visit(Node->getOutlinedFunctionDecl());
}

void VisitOMPExecutableDirective(const OMPExecutableDirective *Node) {
for (const auto *C : Node->clauses())
Visit(C);
77 changes: 77 additions & 0 deletions clang/include/clang/AST/Decl.h
Original file line number Diff line number Diff line change
@@ -4688,6 +4688,83 @@ class BlockDecl : public Decl, public DeclContext {
}
};

/// Represents a partial function definition.
///
/// An outlined function declaration contains the parameters and body of
/// a function independent of other function definition concerns such
/// as function name, type, and calling convention. Such declarations may
/// be used to hold a parameterized and transformed sequence of statements
/// used to generate a target dependent function definition without losing
/// association with the original statements. See SYCLKernelCallStmt as an
/// example.
class OutlinedFunctionDecl final
: public Decl,
public DeclContext,
private llvm::TrailingObjects<OutlinedFunctionDecl, ImplicitParamDecl *> {
private:
/// The number of parameters to the outlined function.
unsigned NumParams;

/// The body of the outlined function.
llvm::PointerIntPair<Stmt *, 1, bool> BodyAndNothrow;

explicit OutlinedFunctionDecl(DeclContext *DC, unsigned NumParams);

ImplicitParamDecl *const *getParams() const {
return getTrailingObjects<ImplicitParamDecl *>();
}

ImplicitParamDecl **getParams() {
return getTrailingObjects<ImplicitParamDecl *>();
}

public:
friend class ASTDeclReader;
friend class ASTDeclWriter;
friend TrailingObjects;

static OutlinedFunctionDecl *Create(ASTContext &C, DeclContext *DC,
unsigned NumParams);
static OutlinedFunctionDecl *
CreateDeserialized(ASTContext &C, GlobalDeclID ID, unsigned NumParams);

Stmt *getBody() const override;
void setBody(Stmt *B);

bool isNothrow() const;
void setNothrow(bool Nothrow = true);

unsigned getNumParams() const { return NumParams; }

ImplicitParamDecl *getParam(unsigned i) const {
assert(i < NumParams);
return getParams()[i];
}
void setParam(unsigned i, ImplicitParamDecl *P) {
assert(i < NumParams);
getParams()[i] = P;
}

// Range interface to parameters.
using parameter_const_iterator = const ImplicitParamDecl *const *;
using parameter_const_range = llvm::iterator_range<parameter_const_iterator>;
parameter_const_range parameters() const {
return {param_begin(), param_end()};
}
parameter_const_iterator param_begin() const { return getParams(); }
parameter_const_iterator param_end() const { return getParams() + NumParams; }

// Implement isa/cast/dyncast/etc.
static bool classof(const Decl *D) { return classofKind(D->getKind()); }
static bool classofKind(Kind K) { return K == OutlinedFunction; }
static DeclContext *castToDeclContext(const OutlinedFunctionDecl *D) {
return static_cast<DeclContext *>(const_cast<OutlinedFunctionDecl *>(D));
}
static OutlinedFunctionDecl *castFromDeclContext(const DeclContext *DC) {
return static_cast<OutlinedFunctionDecl *>(const_cast<DeclContext *>(DC));
}
};

/// Represents the body of a CapturedStmt, and serves as its DeclContext.
class CapturedDecl final
: public Decl,
14 changes: 14 additions & 0 deletions clang/include/clang/AST/RecursiveASTVisitor.h
Original file line number Diff line number Diff line change
@@ -37,6 +37,7 @@
#include "clang/AST/StmtObjC.h"
#include "clang/AST/StmtOpenACC.h"
#include "clang/AST/StmtOpenMP.h"
#include "clang/AST/StmtSYCL.h"
#include "clang/AST/TemplateBase.h"
#include "clang/AST/TemplateName.h"
#include "clang/AST/Type.h"
@@ -1581,6 +1582,11 @@ DEF_TRAVERSE_DECL(BlockDecl, {
ShouldVisitChildren = false;
})

DEF_TRAVERSE_DECL(OutlinedFunctionDecl, {
TRY_TO(TraverseStmt(D->getBody()));
ShouldVisitChildren = false;
})

DEF_TRAVERSE_DECL(CapturedDecl, {
TRY_TO(TraverseStmt(D->getBody()));
ShouldVisitChildren = false;
@@ -2904,6 +2910,14 @@ DEF_TRAVERSE_STMT(SEHFinallyStmt, {})
DEF_TRAVERSE_STMT(SEHLeaveStmt, {})
DEF_TRAVERSE_STMT(CapturedStmt, { TRY_TO(TraverseDecl(S->getCapturedDecl())); })

DEF_TRAVERSE_STMT(SYCLKernelCallStmt, {
if (getDerived().shouldVisitImplicitCode()) {
TRY_TO(TraverseStmt(S->getOriginalStmt()));
TRY_TO(TraverseDecl(S->getOutlinedFunctionDecl()));
ShouldVisitChildren = false;
}
})

DEF_TRAVERSE_STMT(CXXOperatorCallExpr, {})
DEF_TRAVERSE_STMT(CXXRewrittenBinaryOperator, {
if (!getDerived().shouldVisitImplicitCode()) {
94 changes: 94 additions & 0 deletions clang/include/clang/AST/StmtSYCL.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
//===- StmtSYCL.h - Classes for SYCL kernel calls ---------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
/// \file
/// This file defines SYCL AST classes used to represent calls to SYCL kernels.
//===----------------------------------------------------------------------===//

#ifndef LLVM_CLANG_AST_STMTSYCL_H
#define LLVM_CLANG_AST_STMTSYCL_H

#include "clang/AST/ASTContext.h"
#include "clang/AST/Decl.h"
#include "clang/AST/Stmt.h"
#include "clang/Basic/SourceLocation.h"

namespace clang {

//===----------------------------------------------------------------------===//
// AST classes for SYCL kernel calls.
//===----------------------------------------------------------------------===//

/// SYCLKernelCallStmt represents the transformation that is applied to the body
/// of a function declared with the sycl_kernel_entry_point attribute. The body
/// of such a function specifies the statements to be executed on a SYCL device
/// to invoke a SYCL kernel with a particular set of kernel arguments. The
/// SYCLKernelCallStmt associates an original statement (the compound statement
/// that is the function body) with an OutlinedFunctionDecl that holds the
/// kernel parameters and the transformed body. During code generation, the
/// OutlinedFunctionDecl is used to emit an offload kernel entry point suitable
/// for invocation from a SYCL library implementation. If executed, the
/// SYCLKernelCallStmt behaves as a no-op; no code generation is performed for
/// it.
class SYCLKernelCallStmt : public Stmt {
friend class ASTStmtReader;
friend class ASTStmtWriter;

private:
Stmt *OriginalStmt = nullptr;
OutlinedFunctionDecl *OFDecl = nullptr;

public:
/// Construct a SYCL kernel call statement.
SYCLKernelCallStmt(CompoundStmt *CS, OutlinedFunctionDecl *OFD)
: Stmt(SYCLKernelCallStmtClass), OriginalStmt(CS), OFDecl(OFD) {}

/// Construct an empty SYCL kernel call statement.
SYCLKernelCallStmt(EmptyShell Empty) : Stmt(SYCLKernelCallStmtClass, Empty) {}

/// Retrieve the model statement.
CompoundStmt *getOriginalStmt() { return cast<CompoundStmt>(OriginalStmt); }
const CompoundStmt *getOriginalStmt() const {
return cast<CompoundStmt>(OriginalStmt);
}
void setOriginalStmt(CompoundStmt *CS) { OriginalStmt = CS; }

/// Retrieve the outlined function declaration.
OutlinedFunctionDecl *getOutlinedFunctionDecl() { return OFDecl; }
const OutlinedFunctionDecl *getOutlinedFunctionDecl() const { return OFDecl; }

/// Set the outlined function declaration.
void setOutlinedFunctionDecl(OutlinedFunctionDecl *OFD) { OFDecl = OFD; }

SourceLocation getBeginLoc() const LLVM_READONLY {
return getOriginalStmt()->getBeginLoc();
}

SourceLocation getEndLoc() const LLVM_READONLY {
return getOriginalStmt()->getEndLoc();
}

SourceRange getSourceRange() const LLVM_READONLY {
return getOriginalStmt()->getSourceRange();
}

static bool classof(const Stmt *T) {
return T->getStmtClass() == SYCLKernelCallStmtClass;
}

child_range children() {
return child_range(&OriginalStmt, &OriginalStmt + 1);
}

const_child_range children() const {
return const_child_range(&OriginalStmt, &OriginalStmt + 1);
}
};

} // end namespace clang

#endif // LLVM_CLANG_AST_STMTSYCL_H
1 change: 1 addition & 0 deletions clang/include/clang/AST/StmtVisitor.h
Original file line number Diff line number Diff line change
@@ -22,6 +22,7 @@
#include "clang/AST/StmtObjC.h"
#include "clang/AST/StmtOpenACC.h"
#include "clang/AST/StmtOpenMP.h"
#include "clang/AST/StmtSYCL.h"
#include "clang/Basic/LLVM.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Casting.h"
1 change: 1 addition & 0 deletions clang/include/clang/Basic/AttrDocs.td
Original file line number Diff line number Diff line change
@@ -487,6 +487,7 @@ following requirements.
* Is not a C variadic function.
* Is not a coroutine.
* Is not defined as deleted or as defaulted.
* Is not defined with a function try block.
* Is not declared with the ``constexpr`` or ``consteval`` specifiers.
* Is not declared with the ``[[noreturn]]`` attribute.

1 change: 1 addition & 0 deletions clang/include/clang/Basic/DeclNodes.td
Original file line number Diff line number Diff line change
@@ -101,6 +101,7 @@ def Friend : DeclNode<Decl>;
def FriendTemplate : DeclNode<Decl>;
def StaticAssert : DeclNode<Decl>;
def Block : DeclNode<Decl, "blocks">, DeclContext;
def OutlinedFunction : DeclNode<Decl>, DeclContext;
def Captured : DeclNode<Decl>, DeclContext;
def Import : DeclNode<Decl>;
def OMPThreadPrivate : DeclNode<Decl>;
3 changes: 2 additions & 1 deletion clang/include/clang/Basic/DiagnosticSemaKinds.td
Original file line number Diff line number Diff line change
@@ -12457,7 +12457,8 @@ def err_sycl_entry_point_invalid : Error<
"'sycl_kernel_entry_point' attribute cannot be applied to a"
" %select{non-static member function|variadic function|deleted function|"
"defaulted function|constexpr function|consteval function|"
"function declared with the 'noreturn' attribute|coroutine}0">;
"function declared with the 'noreturn' attribute|coroutine|"
"function defined with a function try block}0">;
def err_sycl_entry_point_invalid_redeclaration : Error<
"'sycl_kernel_entry_point' kernel name argument does not match prior"
" declaration%diff{: $ vs $|}0,1">;
1 change: 1 addition & 0 deletions clang/include/clang/Basic/StmtNodes.td
Original file line number Diff line number Diff line change
@@ -24,6 +24,7 @@ def SwitchCase : StmtNode<Stmt, 1>;
def CaseStmt : StmtNode<SwitchCase>;
def DefaultStmt : StmtNode<SwitchCase>;
def CapturedStmt : StmtNode<Stmt>;
def SYCLKernelCallStmt : StmtNode<Stmt>;

// Statements that might produce a value (for example, as the last non-null
// statement in a GNU statement-expression).
1 change: 1 addition & 0 deletions clang/include/clang/Sema/SemaSYCL.h
Original file line number Diff line number Diff line change
@@ -65,6 +65,7 @@ class SemaSYCL : public SemaBase {
void handleKernelEntryPointAttr(Decl *D, const ParsedAttr &AL);

void CheckSYCLEntryPointFunctionDecl(FunctionDecl *FD);
StmtResult BuildSYCLKernelCallStmt(FunctionDecl *FD, CompoundStmt *Body);
};

} // namespace clang
5 changes: 4 additions & 1 deletion clang/include/clang/Sema/Template.h
Original file line number Diff line number Diff line change
@@ -627,7 +627,10 @@ enum class TemplateSubstitutionKind : char {
#define EMPTY(DERIVED, BASE)
#define LIFETIMEEXTENDEDTEMPORARY(DERIVED, BASE)

// Decls which use special-case instantiation code.
// Decls which never appear inside a template.
#define OUTLINEDFUNCTION(DERIVED, BASE)

// Decls which use special-case instantiation code.
#define BLOCK(DERIVED, BASE)
#define CAPTURED(DERIVED, BASE)
#define IMPLICITPARAM(DERIVED, BASE)
6 changes: 6 additions & 0 deletions clang/include/clang/Serialization/ASTBitCodes.h
Original file line number Diff line number Diff line change
@@ -1316,6 +1316,9 @@ enum DeclCode {
/// A BlockDecl record.
DECL_BLOCK,

/// A OutlinedFunctionDecl record.
DECL_OUTLINEDFUNCTION,

/// A CapturedDecl record.
DECL_CAPTURED,

@@ -1600,6 +1603,9 @@ enum StmtCode {
/// A CapturedStmt record.
STMT_CAPTURED,

/// A SYCLKernelCallStmt record.
STMT_SYCLKERNELCALL,

/// A GCC-style AsmStmt record.
STMT_GCCASM,

1 change: 1 addition & 0 deletions clang/lib/AST/ASTStructuralEquivalence.cpp
Original file line number Diff line number Diff line change
@@ -76,6 +76,7 @@
#include "clang/AST/StmtObjC.h"
#include "clang/AST/StmtOpenACC.h"
#include "clang/AST/StmtOpenMP.h"
#include "clang/AST/StmtSYCL.h"
#include "clang/AST/TemplateBase.h"
#include "clang/AST/TemplateName.h"
#include "clang/AST/Type.h"
29 changes: 29 additions & 0 deletions clang/lib/AST/Decl.cpp
Original file line number Diff line number Diff line change
@@ -5448,6 +5448,35 @@ BlockDecl *BlockDecl::CreateDeserialized(ASTContext &C, GlobalDeclID ID) {
return new (C, ID) BlockDecl(nullptr, SourceLocation());
}

OutlinedFunctionDecl::OutlinedFunctionDecl(DeclContext *DC, unsigned NumParams)
: Decl(OutlinedFunction, DC, SourceLocation()),
DeclContext(OutlinedFunction), NumParams(NumParams),
BodyAndNothrow(nullptr, false) {}

OutlinedFunctionDecl *OutlinedFunctionDecl::Create(ASTContext &C,
DeclContext *DC,
unsigned NumParams) {
return new (C, DC, additionalSizeToAlloc<ImplicitParamDecl *>(NumParams))
OutlinedFunctionDecl(DC, NumParams);
}

OutlinedFunctionDecl *
OutlinedFunctionDecl::CreateDeserialized(ASTContext &C, GlobalDeclID ID,
unsigned NumParams) {
return new (C, ID, additionalSizeToAlloc<ImplicitParamDecl *>(NumParams))
OutlinedFunctionDecl(nullptr, NumParams);
}

Stmt *OutlinedFunctionDecl::getBody() const {
return BodyAndNothrow.getPointer();
}
void OutlinedFunctionDecl::setBody(Stmt *B) { BodyAndNothrow.setPointer(B); }

bool OutlinedFunctionDecl::isNothrow() const { return BodyAndNothrow.getInt(); }
void OutlinedFunctionDecl::setNothrow(bool Nothrow) {
BodyAndNothrow.setInt(Nothrow);
}

CapturedDecl::CapturedDecl(DeclContext *DC, unsigned NumParams)
: Decl(Captured, DC, SourceLocation()), DeclContext(Captured),
NumParams(NumParams), ContextParam(0), BodyAndNothrow(nullptr, false) {}
Loading