Skip to content

Commit

Permalink
Added support for 20 nvSHMEM API migration
Browse files Browse the repository at this point in the history
  • Loading branch information
TejaX-Alaghari committed Jan 28, 2025
1 parent 1d65b38 commit 41862ac
Show file tree
Hide file tree
Showing 15 changed files with 544 additions and 37 deletions.
46 changes: 28 additions & 18 deletions clang/lib/DPCT/ASTTraversal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,27 +7,28 @@
//===----------------------------------------------------------------------===//

#include "ASTTraversal.h"
#include "RulesLang/RulesLang.h"
#include "AnalysisInfo.h"
#include "CodePin/GenCodePinHeader.h"
#include "MigrationRuleManager.h"
#include "RulesAsm/AsmMigration.h"
#include "RulesCCL/NCCLAPIMigration.h"
#include "RulesDNN/DNNAPIMigration.h"
#include "RulesLang/OptimizeMigration.h"
#include "RulesLang/RulesLang.h"
#include "RulesLang/WMMAAPIMigration.h"
#include "RulesLangLib/LIBCUAPIMigration.h"
#include "RulesLangLib/NvtxAPIMigration.h"
#include "RulesLangLib/ThrustAPIMigration.h"
#include "RulesMathLib/BLASAPIMigration.h"
#include "RulesMathLib/FFTAPIMigration.h"
#include "RulesMathLib/RandomAPIMigration.h"
#include "RulesMathLib/SolverAPIMigration.h"
#include "RulesMathLib/BLASAPIMigration.h"
#include "CodePin/GenCodePinHeader.h"
#include "RulesMathLib/SpBLASAPIMigration.h"
#include "RulesSHMEM/NVSHMEMAPIMigration.h"
#include "RulesSecurity/Homoglyph.h"
#include "RulesLangLib/LIBCUAPIMigration.h"
#include "RulesLangLib/NvtxAPIMigration.h"
#include "MigrationRuleManager.h"
#include "RulesSecurity/MisleadingBidirectional.h"
#include "RulesCCL/NCCLAPIMigration.h"
#include "RulesLang/OptimizeMigration.h"
#include "RulesMathLib/SpBLASAPIMigration.h"
#include "TextModification.h"
#include "RulesLangLib/ThrustAPIMigration.h"
#include "Utility.h"
#include "RulesLang/WMMAAPIMigration.h"

#include <string>
#include <unordered_map>
Expand Down Expand Up @@ -160,32 +161,41 @@ REGISTER_RULE(GraphicsInteropRule, PassKind::PK_Migration)
REGISTER_RULE(RulesLangAddrSpaceConvRule, PassKind::PK_Migration)

REGISTER_RULE(BLASEnumsRule, PassKind::PK_Migration, RuleGroupKind::RK_BLas)
REGISTER_RULE(BLASFunctionCallRule, PassKind::PK_Migration,RuleGroupKind::RK_BLas)
REGISTER_RULE(BLASFunctionCallRule, PassKind::PK_Migration,
RuleGroupKind::RK_BLas)

REGISTER_RULE(SPBLASEnumsRule, PassKind::PK_Migration, RuleGroupKind::RK_Sparse)
REGISTER_RULE(SPBLASFunctionCallRule, PassKind::PK_Migration,RuleGroupKind::RK_Sparse)
REGISTER_RULE(SPBLASFunctionCallRule, PassKind::PK_Migration,
RuleGroupKind::RK_Sparse)

REGISTER_RULE(RandomEnumsRule, PassKind::PK_Migration, RuleGroupKind::RK_Rng)
REGISTER_RULE(RandomFunctionCallRule, PassKind::PK_Migration,RuleGroupKind::RK_Rng)
REGISTER_RULE(DeviceRandomFunctionCallRule, PassKind::PK_Migration,RuleGroupKind::RK_Rng)
REGISTER_RULE(RandomFunctionCallRule, PassKind::PK_Migration,
RuleGroupKind::RK_Rng)
REGISTER_RULE(DeviceRandomFunctionCallRule, PassKind::PK_Migration,
RuleGroupKind::RK_Rng)

REGISTER_RULE(SOLVEREnumsRule, PassKind::PK_Migration, RuleGroupKind::RK_Solver)
REGISTER_RULE(SOLVERFunctionCallRule, PassKind::PK_Migration,RuleGroupKind::RK_Solver)
REGISTER_RULE(SOLVERFunctionCallRule, PassKind::PK_Migration,
RuleGroupKind::RK_Solver)

REGISTER_RULE(LIBCURule, PassKind::PK_Migration, RuleGroupKind::RK_Libcu)
REGISTER_RULE(NvtxRule, PassKind::PK_Migration)

REGISTER_RULE(ThrustAPIRule, PassKind::PK_Migration, RuleGroupKind::RK_Thrust)
REGISTER_RULE(ThrustTypeRule, PassKind::PK_Migration, RuleGroupKind::RK_Thrust)

REGISTER_RULE(ManualMigrateEnumsRule, PassKind::PK_Migration, RuleGroupKind::RK_NCCL)
REGISTER_RULE(ManualMigrateEnumsRule, PassKind::PK_Migration,
RuleGroupKind::RK_NCCL)
REGISTER_RULE(NCCLRule, PassKind::PK_Migration, RuleGroupKind::RK_NCCL)

REGISTER_RULE(FFTEnumsRule, PassKind::PK_Migration, RuleGroupKind::RK_FFT)
REGISTER_RULE(FFTFunctionCallRule, PassKind::PK_Migration,RuleGroupKind::RK_FFT)
REGISTER_RULE(FFTFunctionCallRule, PassKind::PK_Migration,
RuleGroupKind::RK_FFT)

REGISTER_RULE(CuDNNTypeRule, PassKind::PK_Migration, RuleGroupKind::RK_DNN)
REGISTER_RULE(CuDNNAPIRule, PassKind::PK_Migration, RuleGroupKind::RK_DNN)

REGISTER_RULE(NVSHMEMRule, PassKind::PK_Migration, RuleGroupKind::RK_NVSHMEM)

} // namespace dpct
} // namespace clang
2 changes: 2 additions & 0 deletions clang/lib/DPCT/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,7 @@ add_clang_library(DPCT
RulesLang/CallExprRewriterEvent.cpp
RulesLang/CallExprRewriterCG.cpp
RulesLang/CallExprRewriterWmma.cpp
RulesSHMEM/CallExprRewriterNvshmem.cpp
ErrorHandle/CrashRecovery.cpp
Diagnostics/Diagnostics.cpp
ErrorHandle/Error.cpp
Expand Down Expand Up @@ -240,6 +241,7 @@ add_clang_library(DPCT
RulesDNN/DNNAPIMigration.cpp
RulesCCL/NCCLAPIMigration.cpp
RuleInfra/TypeLocRewriters.cpp
RulesSHMEM/NVSHMEMAPIMigration.cpp
Linux/AutoComplete.cpp
RulesAsm/AsmMigration.cpp
QueryAPIMapping/QueryAPIMapping.cpp
Expand Down
3 changes: 3 additions & 0 deletions clang/lib/DPCT/MigrationReport/Libraries.inc
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,7 @@ LIBRARY(DNN, "oneAPI Deep Neural Network Library (oneDNN)",
LIBRARY(CCL, "oneAPI Collective Communications Library (oneCCL)",
RuleGroupKind::RK_NCCL)

LIBRARY(SHMEM, "Intel Shared Memory Library (iSHMEM)",
RuleGroupKind::RK_NVSHMEM)

#undef LIBRARY
1 change: 1 addition & 0 deletions clang/lib/DPCT/RuleInfra/CallExprRewriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ void CallExprRewriterFactoryBase::initRewriterMap() {
initRewriterMapMemory();
initRewriterMapMisc();
initRewriterMapNccl();
initRewriterMapNvshmem();
initRewriterMapStream();
initRewriterMapTexture();
initRewriterMapThrust();
Expand Down
1 change: 1 addition & 0 deletions clang/lib/DPCT/RuleInfra/CallExprRewriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ class CallExprRewriterFactoryBase {
static void initRewriterMapMemory();
static void initRewriterMapMisc();
static void initRewriterMapNccl();
static void initRewriterMapNvshmem();
static void initRewriterMapStream();
static void initRewriterMapTexture();
static void initRewriterMapThrust();
Expand Down
15 changes: 15 additions & 0 deletions clang/lib/DPCT/RuleInfra/MapNames.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -875,6 +875,11 @@ void MapNames::setExplicitNamespaceMap(
{"cudaExternalMemoryHandleType",
std::make_shared<TypeNameRule>(getExpNamespace() +
"external_mem_handle_type")},
{"nvshmem_team_t", std::make_shared<TypeNameRule>("ishmem_team_t")},
{"nvshmem_team_config_t",
std::make_shared<TypeNameRule>("ishmem_team_config_t")},
{"nvshmemx_init_attr_t",
std::make_shared<TypeNameRule>("ishmemx_attr_t")},
// ...
};
// SYCLcompat unsupport types
Expand Down Expand Up @@ -1510,6 +1515,16 @@ void MapNames::setExplicitNamespaceMap(
? getExpNamespace() +
"external_mem_handle_type::win32_nt_dx12_resource"
: "cudaExternalMemoryHandleTypeD3D12Resource")},
{"NVSHMEM_TEAM_WORLD",
std::make_shared<EnumNameRule>("ISHMEM_TEAM_WORLD")},
{"NVSHMEM_TEAM_SHARED",
std::make_shared<EnumNameRule>("ISHMEM_TEAM_SHARED")},
{"NVSHMEM_TEAM_INVALID",
std::make_shared<EnumNameRule>("ISHMEM_TEAM_INVALID")},
{"NVSHMEMX_INIT_WITH_MPI_COMM",
std::make_shared<EnumNameRule>("ISHMEMX_RUNTIME_MPI")},
{"NVSHMEMX_INIT_WITH_SHMEM",
std::make_shared<EnumNameRule>("ISHMEMX_RUNTIME_OPENSHMEM")},
// ...
};

Expand Down
2 changes: 2 additions & 0 deletions clang/lib/DPCT/RulesInclude/HeaderTypes.inc
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ STD_HEADER(DL, "<libloaderapi.h>")
#else
STD_HEADER(DL, "<dlfcn.h>")
#endif
STD_HEADER(SHMEM, "<ishmem.h>")
STD_HEADER(SHMEMX, "<ishmemx.h>")

ONEDPL_HEADER(Algorithm, "<oneapi/dpl/algorithm>")
ONEDPL_HEADER(Execution, "<oneapi/dpl/execution>")
Expand Down
1 change: 1 addition & 0 deletions clang/lib/DPCT/RulesInclude/InclusionHeaders.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ enum class RuleGroupKind : uint8_t {
RK_Thrust,
RK_CUB,
RK_WMMA,
RK_NVSHMEM,
NUM
};

Expand Down
5 changes: 5 additions & 0 deletions clang/lib/DPCT/RulesInclude/InclusionHeaders.inc
Original file line number Diff line number Diff line change
Expand Up @@ -107,3 +107,8 @@ REGIST_INCLUSION("cuda_rutime.h", FullMatch, Libcu, Remove, true)

REGIST_INCLUSION("cublasLt.h", FullMatch, BLas, Replace, false,
HeaderType::HT_DPCT_BLAS_GEMM_Utils)

REGIST_INCLUSION("nvshmem.h", FullMatch, NVSHMEM, Replace, false,
HeaderType::HT_SHMEM)
REGIST_INCLUSION("nvshmemx.h", FullMatch, NVSHMEM, Replace, false,
HeaderType::HT_SHMEMX)
93 changes: 93 additions & 0 deletions clang/lib/DPCT/RulesSHMEM/APINamesNvshmem.inc
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
//===------------------------ APINamesNvshmem.inc -------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

// Library Setup, Exit & Query
FEATURE_REQUEST_FACTORY(HelperFeatureEnum::device_ext,
CALL_FACTORY_ENTRY("nvshmem_init", CALL("ishmem_init")))

FEATURE_REQUEST_FACTORY(HelperFeatureEnum::device_ext,
CALL_FACTORY_ENTRY("nvshmemx_init_attr",
CALL("ishmemx_init_attr", ARG(1))))

FEATURE_REQUEST_FACTORY(HelperFeatureEnum::device_ext,
CALL_FACTORY_ENTRY("nvshmem_my_pe",
CALL("ishmem_my_pe")))

FEATURE_REQUEST_FACTORY(HelperFeatureEnum::device_ext,
CALL_FACTORY_ENTRY("nvshmem_n_pes",
CALL("ishmem_n_pes")))

FEATURE_REQUEST_FACTORY(HelperFeatureEnum::device_ext,
CALL_FACTORY_ENTRY("nvshmem_finalize",
CALL("ishmem_finalize")))

FEATURE_REQUEST_FACTORY(HelperFeatureEnum::device_ext,
CALL_FACTORY_ENTRY("nvshmem_ptr",
CALL("ishmem_ptr", ARG(0), ARG(1))))

FEATURE_REQUEST_FACTORY(HelperFeatureEnum::device_ext,
CALL_FACTORY_ENTRY("nvshmem_info_get_version",
CALL("ishmem_info_get_version",
ARG(0), ARG(1))))

FEATURE_REQUEST_FACTORY(HelperFeatureEnum::device_ext,
CALL_FACTORY_ENTRY("nvshmem_info_get_name",
CALL("ishmem_info_get_name",
ARG(0))))

// Memory Management
FEATURE_REQUEST_FACTORY(HelperFeatureEnum::device_ext,
CALL_FACTORY_ENTRY("nvshmem_malloc",
CALL("ishmem_malloc", ARG(0))))

FEATURE_REQUEST_FACTORY(HelperFeatureEnum::device_ext,
CALL_FACTORY_ENTRY("nvshmem_free",
CALL("ishmem_free", ARG(0))))

FEATURE_REQUEST_FACTORY(
HelperFeatureEnum::device_ext,
CALL_FACTORY_ENTRY("nvshmem_align", CALL("ishmem_align", ARG(0), ARG(1))))

FEATURE_REQUEST_FACTORY(
HelperFeatureEnum::device_ext,
CALL_FACTORY_ENTRY("nvshmem_calloc", CALL("ishmem_calloc", ARG(0), ARG(1))))

// Team Management
FEATURE_REQUEST_FACTORY(HelperFeatureEnum::device_ext,
CALL_FACTORY_ENTRY("nvshmem_team_my_pe",
CALL("ishmem_team_my_pe", ARG(0))))

FEATURE_REQUEST_FACTORY(HelperFeatureEnum::device_ext,
CALL_FACTORY_ENTRY("nvshmem_team_n_pes",
CALL("ishmem_team_n_pes", ARG(0))))

FEATURE_REQUEST_FACTORY(HelperFeatureEnum::device_ext,
CALL_FACTORY_ENTRY("nvshmem_team_get_config",
CALL("ishmem_team_get_config",
ARG(0), LITERAL("0"), ARG(1))))

FEATURE_REQUEST_FACTORY(HelperFeatureEnum::device_ext,
CALL_FACTORY_ENTRY("nvshmem_team_translate_pe",
CALL("ishmem_team_translate_pe",
ARG(0), ARG(1), ARG(2))))

FEATURE_REQUEST_FACTORY(HelperFeatureEnum::device_ext,
CALL_FACTORY_ENTRY("nvshmem_team_split_strided",
CALL("ishmem_team_split_strided",
ARG(0), ARG(1), ARG(2), ARG(3),
ARG(4), ARG(5), ARG(6))))

FEATURE_REQUEST_FACTORY(HelperFeatureEnum::device_ext,
CALL_FACTORY_ENTRY("nvshmem_team_split_2d",
CALL("ishmem_team_split_2d", ARG(0),
ARG(1), ARG(2), ARG(3), ARG(4),
ARG(5), ARG(6), ARG(7))))

FEATURE_REQUEST_FACTORY(HelperFeatureEnum::device_ext,
CALL_FACTORY_ENTRY("nvshmem_team_destroy",
CALL("ishmem_team_destroy", ARG(0))))
23 changes: 23 additions & 0 deletions clang/lib/DPCT/RulesSHMEM/CallExprRewriterNvshmem.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
//===-------------------- CallExprRewriterNvshmem.cpp ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "RuleInfra/CallExprRewriter.h"
#include "RuleInfra/CallExprRewriterCommon.h"

namespace clang {
namespace dpct {
void CallExprRewriterFactoryBase::initRewriterMapNvshmem() {
RewriterMap->merge(
std::unordered_map<std::string,
std::shared_ptr<CallExprRewriterFactoryBase>>({
#include "APINamesNvshmem.inc"
}));
}

} // namespace dpct
} // namespace clang
Loading

0 comments on commit 41862ac

Please sign in to comment.