Skip to content

Commit

Permalink
Fixed init_attr bug
Browse files Browse the repository at this point in the history
  • Loading branch information
TejaX-Alaghari committed Jan 28, 2025
1 parent 41862ac commit 3c6f51a
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 3 deletions.
2 changes: 1 addition & 1 deletion clang/lib/DPCT/ASTTraversal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -198,4 +198,4 @@ REGISTER_RULE(CuDNNAPIRule, PassKind::PK_Migration, RuleGroupKind::RK_DNN)
REGISTER_RULE(NVSHMEMRule, PassKind::PK_Migration, RuleGroupKind::RK_NVSHMEM)

} // namespace dpct
} // namespace clang
} // namespace clang
2 changes: 1 addition & 1 deletion clang/lib/DPCT/MigrationReport/Libraries.inc
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,4 @@ LIBRARY(CCL, "oneAPI Collective Communications Library (oneCCL)",
LIBRARY(SHMEM, "Intel Shared Memory Library (iSHMEM)",
RuleGroupKind::RK_NVSHMEM)

#undef LIBRARY
#undef LIBRARY
5 changes: 4 additions & 1 deletion clang/lib/DPCT/RulesSHMEM/NVSHMEMAPIMigration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,10 @@ void clang::dpct::NVSHMEMRule::runRule(
}
}

llvm::outs() << "nvshmem_rt: " << nvshmem_rt << "\n";
if (nvshmem_rt == "0" || nvshmem_init_rt == "0") {
emplaceTransformation(new ReplaceStmt(CE, "ishmem_init()"));
return;
}

std::string ishmem_rt = "";
if (nvshmem_rt == "NVSHMEMX_INIT_WITH_MPI_COMM") {
Expand Down
8 changes: 8 additions & 0 deletions clang/test/dpct/nvshmem.cu
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,11 @@ int main() {
// CHECK-NEXT: ishmemx_init_attr(&attr);
nvshmemx_init_attr(NVSHMEMX_INIT_WITH_MPI_COMM, &attr);

// CHECK: /*
// CHECK-NEXT: DPCT1007:{{[0-9]+}}: Migration of nvshmemx_init_attr is not supported.
// CHECK-NEXT: */
nvshmemx_init_attr(NVSHMEMX_INIT_WITH_MPI_COMM | NVSHMEMX_INIT_WITH_SHMEM, &attr);

// CHECK: unsigned int rt = ISHMEMX_RUNTIME_MPI;
// CHECK-NEXT: (&attr)->runtime = static_cast<ishmemx_runtime_type_t>(rt);
// CHECK-NEXT: ishmemx_init_attr(&attr);
Expand All @@ -74,6 +79,9 @@ int main() {
rt = NVSHMEMX_INIT_WITH_SHMEM;
nvshmemx_init_attr(rt, &attr);

// CHECK: ishmem_init();
nvshmemx_init_attr(0, &attr);

// CHECK: ishmem_init();
nvshmem_init();

Expand Down

0 comments on commit 3c6f51a

Please sign in to comment.