Skip to content

Commit a93e59d

Browse files
authored
[SYCL][Fusion] Kernel Fusion support for CUDA backend (#8747)
Extend kernel fusion for the CUDA backend. In contrast to the existing SPIR-V based backends, the default binary format for the CUDA backend (PTX or CUBIN) is not suitable as input for the kernel fusion JIT compiler. This PR therefore extends the driver to **additionally** embed LLVM IR in the fat binary if the user specifies the `-fsycl-embed-ir` option during compilation, by taking the output of the `sycl-post-link` step for the CUDA backend. The JIT compiler has been extended to handle LLVM IR as input format and PTX assembly as output format (including translation via the NVPTX backend). Target-specific parts of the fusion process have been refactored to `TargetFusionInformation`. The connecting logic to the JIT compiler in the SYCL RT has been extended to produce valid PI device binaries for the CUDA backend/PI. Heterogeneous ND ranges are not yet supported for the CUDA backend. --------- Signed-off-by: Lukas Sommer <[email protected]>
1 parent 6953c46 commit a93e59d

File tree

83 files changed

+1805
-712
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below for content that may be hidden.

83 files changed

+1805
-712
lines changed

clang/include/clang/Driver/Action.h

+6-1
Original file line numberDiff line numberDiff line change
@@ -660,9 +660,14 @@ class OffloadUnbundlingJobAction final : public JobAction {
660660
class OffloadWrapperJobAction : public JobAction {
661661
void anchor() override;
662662

663+
bool EmbedIR;
664+
663665
public:
664666
OffloadWrapperJobAction(ActionList &Inputs, types::ID Type);
665-
OffloadWrapperJobAction(Action *Input, types::ID OutputType);
667+
OffloadWrapperJobAction(Action *Input, types::ID OutputType,
668+
bool EmbedIR = false);
669+
670+
bool isEmbeddedIR() const { return EmbedIR; }
666671

667672
static bool classof(const Action *A) {
668673
return A->getKind() == OffloadWrapperJobClass;

clang/include/clang/Driver/Options.td

+2
Original file line numberDiff line numberDiff line change
@@ -2973,6 +2973,8 @@ def fintelfpga : Flag<["-"], "fintelfpga">, Group<f_Group>,
29732973
HelpText<"Perform ahead-of-time compilation for FPGA">;
29742974
def fsycl_device_only : Flag<["-"], "fsycl-device-only">, Flags<[CoreOption]>,
29752975
HelpText<"Compile SYCL kernels for device">;
2976+
def fsycl_embed_ir : Flag<["-"], "fsycl-embed-ir">, Flags<[CoreOption]>,
2977+
HelpText<"Embed LLVM IR for runtime kernel fusion">;
29762978
defm sycl_esimd_force_stateless_mem : BoolFOption<"sycl-esimd-force-stateless-mem",
29772979
LangOpts<"SYCLESIMDForceStatelessMem">, DefaultFalse,
29782980
PosFlag<SetTrue, [], "Enforce using stateless memory accesses. "

clang/lib/Driver/Action.cpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -478,11 +478,11 @@ void OffloadWrapperJobAction::anchor() {}
478478

479479
OffloadWrapperJobAction::OffloadWrapperJobAction(ActionList &Inputs,
480480
types::ID Type)
481-
: JobAction(OffloadWrapperJobClass, Inputs, Type) {}
481+
: JobAction(OffloadWrapperJobClass, Inputs, Type), EmbedIR(false) {}
482482

483-
OffloadWrapperJobAction::OffloadWrapperJobAction(Action *Input,
484-
types::ID Type)
485-
: JobAction(OffloadWrapperJobClass, Input, Type) {}
483+
OffloadWrapperJobAction::OffloadWrapperJobAction(Action *Input, types::ID Type,
484+
bool IsEmbeddedIR)
485+
: JobAction(OffloadWrapperJobClass, Input, Type), EmbedIR(IsEmbeddedIR) {}
486486

487487
void OffloadPackagerJobAction::anchor() {}
488488

clang/lib/Driver/Driver.cpp

+64-52
Original file line numberDiff line numberDiff line change
@@ -5516,6 +5516,8 @@ class OffloadingActionBuilder final {
55165516
// s - device code split requested
55175517
// r - relocatable device code is requested
55185518
// f - link object output type is TY_Tempfilelist (fat archive)
5519+
// e - Embedded IR for fusion (-fsycl-embed-ir) was requested
5520+
// and target is NVPTX.
55195521
// * - "all other cases"
55205522
// - no condition means output/input is "always" present
55215523
// First symbol indicates output/input type
@@ -5535,58 +5537,58 @@ class OffloadingActionBuilder final {
55355537
// | |
55365538
// | |
55375539
// .---------------------------------------.
5538-
// | PostLink |
5539-
// .---------------------------------------.
5540-
// [+*] [+]
5541-
// | |
5542-
// | |
5543-
// |--------- |
5544-
// | | |
5545-
// | | |
5546-
// | [+!rf] |
5547-
// | .-------------. |
5548-
// | | llvm-foreach| |
5549-
// | .-------------. |
5550-
// | | |
5551-
// [+*] [+!rf] |
5552-
// .-----------------. |
5553-
// | FileTableTform | |
5554-
// | (extract "Code")| |
5555-
// .-----------------. |
5556-
// [-] |-----------
5557-
// --------------------| |
5558-
// | | |
5559-
// | |----------------- |
5560-
// | | | |
5561-
// | | [-!rf] |
5562-
// | | .--------------. |
5563-
// | | |FileTableTform| |
5564-
// | | | (merge) | |
5565-
// | | .--------------. |
5566-
// | | [-] |-------
5567-
// | | | | |
5568-
// | | | ------| |
5569-
// | | --------| | |
5570-
// [.] [-*] [-!rf] [+!rf] |
5571-
// .---------------. .-------------------. .--------------. |
5572-
// | finalizeNVPTX | | SPIRVTranslator | |FileTableTform| |
5573-
// | finalizeAMDGCN | | | | (merge) | |
5574-
// .---------------. .-------------------. . -------------. |
5575-
// [.] [-as] [-!a] | |
5576-
// | | | | |
5577-
// | [-s] | | |
5578-
// | .----------------. | | |
5579-
// | | BackendCompile | | | |
5580-
// | .----------------. | ------| |
5581-
// | [-s] | | |
5582-
// | | | | |
5583-
// | [-a] [-!a] [-!rf] |
5584-
// | .--------------------. |
5585-
// -----------[-n]| FileTableTform |[+*]--------------|
5586-
// | (replace "Code") |
5587-
// .--------------------.
5588-
// |
5589-
// [+*]
5540+
// | PostLink |[+e]----------------
5541+
// .---------------------------------------. |
5542+
// [+*] [+] |
5543+
// | | |
5544+
// | | |
5545+
// |--------- | |
5546+
// | | | |
5547+
// | | | |
5548+
// | [+!rf] | |
5549+
// | .-------------. | |
5550+
// | | llvm-foreach| | |
5551+
// | .-------------. | |
5552+
// | | | |
5553+
// [+*] [+!rf] | |
5554+
// .-----------------. | |
5555+
// | FileTableTform | | |
5556+
// | (extract "Code")| | |
5557+
// .-----------------. | |
5558+
// [-] |----------- |
5559+
// --------------------| | |
5560+
// | | | |
5561+
// | |----------------- | |
5562+
// | | | | |
5563+
// | | [-!rf] | |
5564+
// | | .--------------. | |
5565+
// | | |FileTableTform| | |
5566+
// | | | (merge) | | |
5567+
// | | .--------------. | |
5568+
// | | [-] |------- |
5569+
// | | | | | |
5570+
// | | | ------| | |
5571+
// | | --------| | | |
5572+
// [.] [-*] [-!rf] [+!rf] | |
5573+
// .---------------. .-------------------. .--------------. | |
5574+
// | finalizeNVPTX | | SPIRVTranslator | |FileTableTform| | |
5575+
// | finalizeAMDGCN | | | | (merge) | | |
5576+
// .---------------. .-------------------. . -------------. | |
5577+
// [.] [-as] [-!a] | | |
5578+
// | | | | | |
5579+
// | [-s] | | | |
5580+
// | .----------------. | | | |
5581+
// | | BackendCompile | | | | |
5582+
// | .----------------. | ------| | |
5583+
// | [-s] | | | |
5584+
// | | | | | |
5585+
// | [-a] [-!a] [-!rf] | |
5586+
// | .--------------------. | |
5587+
// -----------[-n]| FileTableTform |[+*]--------------| |
5588+
// | (replace "Code") | |
5589+
// .--------------------. |
5590+
// | -------------------------
5591+
// [+*] | [+e]
55905592
// .--------------------------------------.
55915593
// | OffloadWrapper |
55925594
// .--------------------------------------.
@@ -5693,6 +5695,16 @@ class OffloadingActionBuilder final {
56935695
return TypedPostLinkAction;
56945696
};
56955697
Action *PostLinkAction = createPostLinkAction();
5698+
if (isNVPTX && Args.hasArg(options::OPT_fsycl_embed_ir)) {
5699+
// When compiling for Nvidia/CUDA devices and the user requested the
5700+
// IR to be embedded in the application (via option), run the output
5701+
// of sycl-post-link (filetable referencing LLVM Bitcode + symbols)
5702+
// through the offload wrapper and link the resulting object to the
5703+
// application.
5704+
auto *WrapBitcodeAction = C.MakeAction<OffloadWrapperJobAction>(
5705+
PostLinkAction, types::TY_Object, true);
5706+
DA.add(*WrapBitcodeAction, *TC, BoundArch, Action::OFK_SYCL);
5707+
}
56965708
bool NoRDCFatStaticArchive =
56975709
!IsRDC &&
56985710
FullDeviceLinkAction->getType() == types::TY_Tempfilelist;

clang/lib/Driver/ToolChains/Clang.cpp

+9-1
Original file line numberDiff line numberDiff line change
@@ -9272,6 +9272,14 @@ void OffloadWrapper::ConstructJob(Compilation &C, const JobAction &JA,
92729272
createArgString("-link-opts=");
92739273
}
92749274

9275+
bool IsEmbeddedIR = cast<OffloadWrapperJobAction>(JA).isEmbeddedIR();
9276+
if (IsEmbeddedIR) {
9277+
// When the offload-wrapper is called to embed LLVM IR, add a prefix to
9278+
// the target triple to distinguish the LLVM IR from the actual device
9279+
// binary for that target.
9280+
TargetTripleOpt = ("llvm_" + TargetTripleOpt).str();
9281+
}
9282+
92759283
WrapperArgs.push_back(
92769284
C.getArgs().MakeArgString(Twine("-target=") + TargetTripleOpt));
92779285

@@ -9293,7 +9301,7 @@ void OffloadWrapper::ConstructJob(Compilation &C, const JobAction &JA,
92939301
assert(I.isFilename() && "Invalid input.");
92949302

92959303
if (I.getType() == types::TY_Tempfiletable ||
9296-
I.getType() == types::TY_Tempfilelist)
9304+
I.getType() == types::TY_Tempfilelist || IsEmbeddedIR)
92979305
// wrapper actual input files are passed via the batch job file table:
92989306
WrapperArgs.push_back(C.getArgs().MakeArgString("-batch"));
92999307
WrapperArgs.push_back(C.getArgs().MakeArgString(I.getFilename()));

sycl-fusion/common/include/Kernel.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ enum class ParameterKind : uint32_t {
3434
};
3535

3636
/// Different binary formats supported as input to or output of the JIT compiler.
37-
enum class BinaryFormat : uint32_t { INVALID, LLVM, SPIRV };
37+
enum class BinaryFormat : uint32_t { INVALID, LLVM, SPIRV, PTX };
3838

3939
/// Information about a device intermediate representation module (e.g., SPIR-V,
4040
/// LLVM IR) from DPC++.

sycl-fusion/common/lib/KernelIO.h

+1
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ template <> struct ScalarEnumerationTraits<jit_compiler::BinaryFormat> {
4747
static void enumeration(IO &IO, jit_compiler::BinaryFormat &BF) {
4848
IO.enumCase(BF, "LLVM", jit_compiler::BinaryFormat::LLVM);
4949
IO.enumCase(BF, "SPIRV", jit_compiler::BinaryFormat::SPIRV);
50+
IO.enumCase(BF, "PTX", jit_compiler::BinaryFormat::PTX);
5051
IO.enumCase(BF, "INVALID", jit_compiler::BinaryFormat::INVALID);
5152
}
5253
};

sycl-fusion/jit-compiler/CMakeLists.txt

+11-1
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,15 @@
22
add_llvm_library(sycl-fusion
33
lib/KernelFusion.cpp
44
lib/JITContext.cpp
5+
lib/translation/KernelTranslation.cpp
56
lib/translation/SPIRVLLVMTranslation.cpp
67
lib/fusion/FusionPipeline.cpp
78
lib/fusion/FusionHelper.cpp
89
lib/fusion/ModuleHelper.cpp
910
lib/helper/ConfigHelper.cpp
1011

11-
LINK_COMPONENTS
12+
LINK_COMPONENTS
13+
BitReader
1214
Core
1315
Support
1416
Analysis
@@ -18,6 +20,10 @@ add_llvm_library(sycl-fusion
1820
Linker
1921
ScalarOpts
2022
InstCombine
23+
Target
24+
TargetParser
25+
MC
26+
${LLVM_TARGETS_TO_BUILD}
2127
)
2228

2329
target_include_directories(sycl-fusion
@@ -40,6 +46,10 @@ target_link_libraries(sycl-fusion
4046
${CMAKE_THREAD_LIBS_INIT}
4147
)
4248

49+
if("NVPTX" IN_LIST LLVM_TARGETS_TO_BUILD)
50+
target_compile_definitions(sycl-fusion PRIVATE FUSION_JIT_SUPPORT_PTX)
51+
endif()
52+
4353
if (BUILD_SHARED_LIBS)
4454
if(NOT MSVC AND NOT APPLE)
4555
# Manage symbol visibility through the linker to make sure no LLVM symbols

sycl-fusion/jit-compiler/include/JITContext.h

+12-5
Original file line numberDiff line numberDiff line change
@@ -36,17 +36,21 @@ using CacheKeyT =
3636
std::optional<std::vector<NDRange>>>;
3737

3838
///
39-
/// Wrapper around a SPIR-V binary.
40-
class SPIRVBinary {
39+
/// Wrapper around a kernel binary.
40+
class KernelBinary {
4141
public:
42-
explicit SPIRVBinary(std::string Binary);
42+
explicit KernelBinary(std::string &&Binary, BinaryFormat Format);
4343

4444
jit_compiler::BinaryAddress address() const;
4545

4646
size_t size() const;
4747

48+
BinaryFormat format() const;
49+
4850
private:
4951
std::string Blob;
52+
53+
BinaryFormat Format;
5054
};
5155

5256
///
@@ -61,7 +65,10 @@ class JITContext {
6165

6266
llvm::LLVMContext *getLLVMContext();
6367

64-
SPIRVBinary &emplaceSPIRVBinary(std::string Binary);
68+
template <typename... Ts> KernelBinary &emplaceKernelBinary(Ts &&...Args) {
69+
WriteLockT WriteLock{BinariesMutex};
70+
return Binaries.emplace_back(std::forward<Ts>(Args)...);
71+
}
6572

6673
std::optional<SYCLKernelInfo> getCacheEntry(CacheKeyT &Identifier) const;
6774

@@ -79,7 +86,7 @@ class JITContext {
7986

8087
MutexT BinariesMutex;
8188

82-
std::vector<SPIRVBinary> Binaries;
89+
std::vector<KernelBinary> Binaries;
8390

8491
mutable MutexT CacheMutex;
8592

sycl-fusion/jit-compiler/include/Options.h

+6-1
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,14 @@
99
#ifndef SYCL_FUSION_JIT_COMPILER_OPTIONS_H
1010
#define SYCL_FUSION_JIT_COMPILER_OPTIONS_H
1111

12+
#include "Kernel.h"
13+
1214
#include <memory>
1315
#include <unordered_map>
1416

1517
namespace jit_compiler {
1618

17-
enum OptionID { VerboseOutput, EnableCaching };
19+
enum OptionID { VerboseOutput, EnableCaching, TargetFormat };
1820

1921
class OptionPtrBase {};
2022

@@ -78,6 +80,9 @@ struct JITEnableVerbose : public OptionBase<OptionID::VerboseOutput, bool> {};
7880

7981
struct JITEnableCaching : public OptionBase<OptionID::EnableCaching, bool> {};
8082

83+
struct JITTargetFormat
84+
: public OptionBase<OptionID::TargetFormat, BinaryFormat> {};
85+
8186
} // namespace option
8287
} // namespace jit_compiler
8388

sycl-fusion/jit-compiler/lib/JITContext.cpp

+6-11
Original file line numberDiff line numberDiff line change
@@ -11,29 +11,24 @@
1111

1212
using namespace jit_compiler;
1313

14-
SPIRVBinary::SPIRVBinary(std::string Binary) : Blob{std::move(Binary)} {}
14+
KernelBinary::KernelBinary(std::string &&Binary, BinaryFormat Fmt)
15+
: Blob{std::move(Binary)}, Format{Fmt} {}
1516

16-
jit_compiler::BinaryAddress SPIRVBinary::address() const {
17+
jit_compiler::BinaryAddress KernelBinary::address() const {
1718
// FIXME: Verify it's a good idea to perform this reinterpret_cast here.
1819
return reinterpret_cast<jit_compiler::BinaryAddress>(Blob.c_str());
1920
}
2021

21-
size_t SPIRVBinary::size() const { return Blob.size(); }
22+
size_t KernelBinary::size() const { return Blob.size(); }
23+
24+
BinaryFormat KernelBinary::format() const { return Format; }
2225

2326
JITContext::JITContext() : LLVMCtx{new llvm::LLVMContext}, Binaries{} {}
2427

2528
JITContext::~JITContext() = default;
2629

2730
llvm::LLVMContext *JITContext::getLLVMContext() { return LLVMCtx.get(); }
2831

29-
SPIRVBinary &JITContext::emplaceSPIRVBinary(std::string Binary) {
30-
WriteLockT WriteLock{BinariesMutex};
31-
// NOTE: With C++17, which returns a reference from emplace_back, the
32-
// following code would be even simpler.
33-
Binaries.emplace_back(std::move(Binary));
34-
return Binaries.back();
35-
}
36-
3732
std::optional<SYCLKernelInfo>
3833
JITContext::getCacheEntry(CacheKeyT &Identifier) const {
3934
ReadLockT ReadLock{CacheMutex};

0 commit comments

Comments
 (0)