Skip to content

Commit aa2b629

Browse files
committed
[SYCL][Fusion] Rebase and address feedback
Signed-off-by: Lukas Sommer <[email protected]>
1 parent 26978b2 commit aa2b629

File tree

20 files changed

+412
-350
lines changed

20 files changed

+412
-350
lines changed

clang/include/clang/Driver/Action.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -665,7 +665,7 @@ class OffloadWrapperJobAction : public JobAction {
665665
public:
666666
OffloadWrapperJobAction(ActionList &Inputs, types::ID Type);
667667
OffloadWrapperJobAction(Action *Input, types::ID OutputType,
668-
bool IsEmbeddedIR = false);
668+
bool EmbedIR = false);
669669

670670
bool isEmbeddedIR() const { return EmbedIR; }
671671

clang/lib/Driver/Action.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -478,7 +478,7 @@ void OffloadWrapperJobAction::anchor() {}
478478

479479
OffloadWrapperJobAction::OffloadWrapperJobAction(ActionList &Inputs,
480480
types::ID Type)
481-
: JobAction(OffloadWrapperJobClass, Inputs, Type) {}
481+
: JobAction(OffloadWrapperJobClass, Inputs, Type), EmbedIR(false) {}
482482

483483
OffloadWrapperJobAction::OffloadWrapperJobAction(Action *Input, types::ID Type,
484484
bool IsEmbeddedIR)

sycl-fusion/jit-compiler/CMakeLists.txt

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,8 @@ add_llvm_library(sycl-fusion
2222
InstCombine
2323
Target
2424
TargetParser
25-
NVPTX
26-
X86
2725
MC
26+
${LLVM_TARGETS_TO_BUILD}
2827
)
2928

3029
target_include_directories(sycl-fusion
@@ -47,6 +46,10 @@ target_link_libraries(sycl-fusion
4746
${CMAKE_THREAD_LIBS_INIT}
4847
)
4948

49+
if("NVPTX" IN_LIST LLVM_TARGETS_TO_BUILD)
50+
target_compile_definitions(sycl-fusion PRIVATE FUSION_JIT_SUPPORT_PTX)
51+
endif()
52+
5053
if (BUILD_SHARED_LIBS)
5154
if(NOT MSVC AND NOT APPLE)
5255
# Manage symbol visibility through the linker to make sure no LLVM symbols

sycl-fusion/jit-compiler/include/JITContext.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ using CacheKeyT =
3939
/// Wrapper around a kernel binary.
4040
class KernelBinary {
4141
public:
42-
explicit KernelBinary(std::string Binary, BinaryFormat Format);
42+
explicit KernelBinary(std::string &&Binary, BinaryFormat Format);
4343

4444
jit_compiler::BinaryAddress address() const;
4545

@@ -65,7 +65,10 @@ class JITContext {
6565

6666
llvm::LLVMContext *getLLVMContext();
6767

68-
KernelBinary &emplaceSPIRVBinary(std::string Binary, BinaryFormat Format);
68+
template <typename... Ts> KernelBinary &emplaceKernelBinary(Ts &&...Args) {
69+
WriteLockT WriteLock{BinariesMutex};
70+
return Binaries.emplace_back(std::forward<Ts>(Args)...);
71+
}
6972

7073
std::optional<SYCLKernelInfo> getCacheEntry(CacheKeyT &Identifier) const;
7174

sycl-fusion/jit-compiler/include/Options.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#define SYCL_FUSION_JIT_COMPILER_OPTIONS_H
1111

1212
#include "Kernel.h"
13+
1314
#include <memory>
1415
#include <unordered_map>
1516

sycl-fusion/jit-compiler/lib/JITContext.cpp

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
using namespace jit_compiler;
1313

14-
KernelBinary::KernelBinary(std::string Binary, BinaryFormat Fmt)
14+
KernelBinary::KernelBinary(std::string &&Binary, BinaryFormat Fmt)
1515
: Blob{std::move(Binary)}, Format{Fmt} {}
1616

1717
jit_compiler::BinaryAddress KernelBinary::address() const {
@@ -29,15 +29,6 @@ JITContext::~JITContext() = default;
2929

3030
llvm::LLVMContext *JITContext::getLLVMContext() { return LLVMCtx.get(); }
3131

32-
KernelBinary &JITContext::emplaceSPIRVBinary(std::string Binary,
33-
BinaryFormat Format) {
34-
WriteLockT WriteLock{BinariesMutex};
35-
// NOTE: With C++17, which returns a reference from emplace_back, the
36-
// following code would be even simpler.
37-
Binaries.emplace_back(std::move(Binary), Format);
38-
return Binaries.back();
39-
}
40-
4132
std::optional<SYCLKernelInfo>
4233
JITContext::getCacheEntry(CacheKeyT &Identifier) const {
4334
ReadLockT ReadLock{CacheMutex};

sycl-fusion/jit-compiler/lib/KernelFusion.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,22 @@ gatherNDRanges(llvm::ArrayRef<SYCLKernelInfo> KernelInformation) {
4848
return NDRanges;
4949
}
5050

51+
static bool isTargetFormatSupported(BinaryFormat TargetFormat) {
52+
switch (TargetFormat) {
53+
case BinaryFormat::SPIRV:
54+
return true;
55+
case BinaryFormat::PTX: {
56+
#ifdef FUSION_JIT_SUPPORT_PTX
57+
return true;
58+
#else // FUSION_JIT_SUPPORT_PTX
59+
return false;
60+
#endif // FUSION_JIT_SUPPORT_PTX
61+
}
62+
default:
63+
return false;
64+
}
65+
}
66+
5167
FusionResult KernelFusion::fuseKernels(
5268
JITContext &JITCtx, Config &&JITConfig,
5369
const std::vector<SYCLKernelInfo> &KernelInformation,
@@ -71,6 +87,12 @@ FusionResult KernelFusion::fuseKernels(
7187
bool IsHeterogeneousList = jit_compiler::isHeterogeneousList(NDRanges);
7288

7389
BinaryFormat TargetFormat = ConfigHelper::get<option::JITTargetFormat>();
90+
91+
if (!isTargetFormatSupported(TargetFormat)) {
92+
return FusionResult(
93+
"Fusion output target format not supported by this build");
94+
}
95+
7496
if (TargetFormat == BinaryFormat::PTX && IsHeterogeneousList) {
7597
return FusionResult{"Heterogeneous ND ranges not supported for CUDA"};
7698
}

sycl-fusion/jit-compiler/lib/fusion/FusionPipeline.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@
2323
#ifndef NDEBUG
2424
#include "llvm/IR/Verifier.h"
2525
#endif // NDEBUG
26-
#include "llvm/ADT/Triple.h"
2726
#include "llvm/Passes/PassBuilder.h"
27+
#include "llvm/TargetParser/Triple.h"
2828
#include "llvm/Transforms/InstCombine/InstCombine.h"
2929
#include "llvm/Transforms/Scalar/ADCE.h"
3030
#include "llvm/Transforms/Scalar/EarlyCSE.h"
@@ -103,7 +103,7 @@ FusionPipeline::runFusionPasses(Module &Mod, SYCLModuleInfo &InputInfo,
103103
// to/from generic address-space as possible, because these hinder
104104
// internalization.
105105
// Ideally, the static compiler should have performed that job.
106-
unsigned FlatAddressSpace = getFlatAddressSpace(Mod);
106+
const unsigned FlatAddressSpace = getFlatAddressSpace(Mod);
107107
FPM.addPass(InferAddressSpacesPass(FlatAddressSpace));
108108
MPM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM)));
109109
}

sycl-fusion/jit-compiler/lib/fusion/ModuleHelper.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ helper::ModuleHelper::cloneAndPruneModule(Module *Mod,
2424
identifyUnusedFunctions(Mod, CGRoots, UnusedFunctions);
2525

2626
{
27-
auto TFI = llvm::TargetFusionInfo::getTargetFusionInfo(Mod);
27+
TargetFusionInfo TFI{Mod};
2828
SmallVector<Function *> Unused{UnusedFunctions.begin(),
2929
UnusedFunctions.end()};
3030
TFI.notifyFunctionsDelete(Unused);

sycl-fusion/jit-compiler/lib/translation/KernelTranslation.cpp

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "KernelTranslation.h"
10+
1011
#include "SPIRVLLVMTranslation.h"
1112
#include "llvm/Bitcode/BitcodeReader.h"
1213
#include "llvm/IR/Constants.h"
@@ -182,7 +183,8 @@ llvm::Error KernelTranslator::translateKernel(SYCLKernelInfo &Kernel,
182183
break;
183184
}
184185
case BinaryFormat::PTX: {
185-
llvm::Expected<KernelBinary *> BinaryOrError = translateToPTX(Mod, JITCtx);
186+
llvm::Expected<KernelBinary *> BinaryOrError =
187+
translateToPTX(Kernel, Mod, JITCtx);
186188
if (auto Error = BinaryOrError.takeError()) {
187189
return Error;
188190
}
@@ -215,12 +217,20 @@ KernelTranslator::translateToSPIRV(llvm::Module &Mod, JITContext &JITCtx) {
215217
}
216218

217219
llvm::Expected<KernelBinary *>
218-
KernelTranslator::translateToPTX(llvm::Module &Mod, JITContext &JITCtx) {
219-
// FIXME: Can we limit this to the NVPTX specific target?
220-
llvm::InitializeAllTargets();
221-
llvm::InitializeAllAsmParsers();
222-
llvm::InitializeAllAsmPrinters();
223-
llvm::InitializeAllTargetMCs();
220+
KernelTranslator::translateToPTX(SYCLKernelInfo &KernelInfo, llvm::Module &Mod,
221+
JITContext &JITCtx) {
222+
#ifndef FUSION_JIT_SUPPORT_PTX
223+
return createStringError(inconvertibleErrorCode(),
224+
"PTX translation not supported in this build");
225+
#else // FUSION_JIT_SUPPORT_PTX
226+
LLVMInitializeNVPTXTargetInfo();
227+
LLVMInitializeNVPTXTarget();
228+
LLVMInitializeNVPTXAsmPrinter();
229+
LLVMInitializeNVPTXTargetMC();
230+
#endif // FUSION_JIT_SUPPORT_PTX
231+
232+
static const char *TARGET_CPU_ATTRIBUTE = "target-cpu";
233+
static const char *TARGET_FEATURE_ATTRIBUTE = "target-features";
224234

225235
std::string TargetTriple{"nvptx64-nvidia-cuda"};
226236

@@ -231,13 +241,26 @@ KernelTranslator::translateToPTX(llvm::Module &Mod, JITContext &JITCtx) {
231241
if (!Target) {
232242
return createStringError(
233243
inconvertibleErrorCode(),
234-
"Failed to load and translate SPIR-V module with error %s",
244+
"Failed to load and translate PTX LLVM IR module with error %s",
235245
ErrorMessage.c_str());
236246
}
237247

248+
llvm::StringRef TargetCPU{"sm_50"};
249+
llvm::StringRef TargetFeatures{"+sm_50,+ptx76"};
250+
if (auto *KernelFunc = Mod.getFunction(KernelInfo.Name)) {
251+
if (KernelFunc->hasFnAttribute(TARGET_CPU_ATTRIBUTE)) {
252+
TargetCPU =
253+
KernelFunc->getFnAttribute(TARGET_CPU_ATTRIBUTE).getValueAsString();
254+
}
255+
if (KernelFunc->hasFnAttribute(TARGET_FEATURE_ATTRIBUTE)) {
256+
TargetFeatures = KernelFunc->getFnAttribute(TARGET_FEATURE_ATTRIBUTE)
257+
.getValueAsString();
258+
}
259+
}
260+
238261
// FIXME: Check whether we can provide more accurate target information here
239262
auto *TargetMachine = Target->createTargetMachine(
240-
TargetTriple, "sm_50", "+sm_50,+ptx76", {}, llvm::Reloc::PIC_,
263+
TargetTriple, TargetCPU, TargetFeatures, {}, llvm::Reloc::PIC_,
241264
std::nullopt, llvm::CodeGenOpt::Default);
242265

243266
llvm::legacy::PassManager PM;
@@ -259,5 +282,5 @@ KernelTranslator::translateToPTX(llvm::Module &Mod, JITContext &JITCtx) {
259282
ASMStream.flush();
260283
}
261284

262-
return &JITCtx.emplaceSPIRVBinary(PTXASM, BinaryFormat::PTX);
285+
return &JITCtx.emplaceKernelBinary(std::move(PTXASM), BinaryFormat::PTX);
263286
}

sycl-fusion/jit-compiler/lib/translation/KernelTranslation.h

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,14 @@
55
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66
//
77
//===----------------------------------------------------------------------===//
8+
#ifndef SYCL_FUSION_JIT_COMPILER_TRANSLATION_KERNELTRANSLATION_H
9+
#define SYCL_FUSION_JIT_COMPILER_TRANSLATION_KERNELTRANSLATION_H
810

911
#include "JITContext.h"
1012
#include "Kernel.h"
1113
#include "llvm/IR/LLVMContext.h"
1214
#include "llvm/IR/Module.h"
13-
#include <llvm/Support/Error.h>
15+
#include "llvm/Support/Error.h"
1416
#include <vector>
1517

1618
namespace jit_compiler {
@@ -39,8 +41,10 @@ class KernelTranslator {
3941
static llvm::Expected<KernelBinary *> translateToSPIRV(llvm::Module &Mod,
4042
JITContext &JITCtx);
4143

42-
static llvm::Expected<KernelBinary *> translateToPTX(llvm::Module &Mod,
43-
JITContext &JITCtx);
44+
static llvm::Expected<KernelBinary *>
45+
translateToPTX(SYCLKernelInfo &Kernel, llvm::Module &Mod, JITContext &JITCtx);
4446
};
4547
} // namespace translation
4648
} // namespace jit_compiler
49+
50+
#endif // SYCL_FUSION_JIT_COMPILER_TRANSLATION_KERNELTRANSLATION_H

sycl-fusion/jit-compiler/lib/translation/SPIRVLLVMTranslation.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,5 +94,5 @@ SPIRVLLVMTranslator::translateLLVMtoSPIRV(Module &Mod, JITContext &JITCtx) {
9494
"Translation of LLVM IR to SPIR-V failed with error %s",
9595
ErrMsg.c_str());
9696
}
97-
return &JITCtx.emplaceSPIRVBinary(BinaryStream.str(), BinaryFormat::SPIRV);
97+
return &JITCtx.emplaceKernelBinary(BinaryStream.str(), BinaryFormat::SPIRV);
9898
}

sycl-fusion/passes/internalization/Internalization.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -631,7 +631,7 @@ static void moduleCleanup(Module &M, ModuleAnalysisManager &AM,
631631

632632
PreservedAnalyses llvm::SYCLInternalizer::run(Module &M,
633633
ModuleAnalysisManager &AM) {
634-
auto TFI = TargetFusionInfo::getTargetFusionInfo(&M);
634+
TargetFusionInfo TFI{&M};
635635
// Private promotion
636636
const PreservedAnalyses Tmp = SYCLInternalizerImpl{
637637
TFI.getPrivateAddressSpace(), PrivatePromotion, true, TFI}(M, AM);

sycl-fusion/passes/kernel-fusion/SYCLKernelFusion.cpp

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ PreservedAnalyses SYCLKernelFusion::run(Module &M, ModuleAnalysisManager &AM) {
138138
AM.getResult<SYCLModuleInfoAnalysis>(M).ModuleInfo;
139139
assert(ModuleInfo && "No module information available");
140140

141-
auto TFI = TargetFusionInfo::getTargetFusionInfo(&M);
141+
TargetFusionInfo TFI{&M};
142142

143143
// Iterate over the functions in the module and locate all
144144
// stub functions identified by metadata.
@@ -456,11 +456,20 @@ Error SYCLKernelFusion::fuseKernel(
456456
FT, GlobalValue::LinkageTypes::ExternalLinkage,
457457
M.getDataLayout().getProgramAddressSpace(), KernelName->getString(), &M);
458458
{
459+
auto DefaultAttr = FusedFunction->getAttributes();
460+
// Add uniform function attributes, i.e., attributes with identical value on
461+
// each input function, to the fused function.
462+
auto *FirstFunction = InputFunctions.front().F;
463+
for (const auto &UniformKey : TargetInfo.getUniformKernelAttributes()) {
464+
if (FirstFunction->hasFnAttribute(UniformKey)) {
465+
DefaultAttr = DefaultAttr.addFnAttribute(
466+
LLVMCtx, FirstFunction->getFnAttribute(UniformKey));
467+
}
468+
}
459469
// Add the collected parameter attributes to the fused function.
460470
// Copying the parameter attributes from their original definition in the
461471
// input kernels should be safe and they most likely can't be deducted later
462472
// on, as no caller is present in the module.
463-
auto DefaultAttr = FusedFunction->getAttributes();
464473
auto FusedFnAttrs =
465474
AttributeList::get(LLVMCtx, DefaultAttr.getFnAttrs(),
466475
DefaultAttr.getRetAttrs(), FusedParamAttributes);

sycl-fusion/passes/syclcp/SYCLCP.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ PreservedAnalyses SYCLCP::run(Module &M, ModuleAnalysisManager &AM) {
249249
Changed = propagateConstants(F, *ConstantsOrErr) || Changed;
250250
}
251251

252-
auto TFI = TargetFusionInfo::getTargetFusionInfo(&M);
252+
TargetFusionInfo TFI{&M};
253253

254254
if (Changed) {
255255
moduleCleanup(M, AM, TFI);

0 commit comments

Comments
 (0)