From fda4abbe370f77e39413268473274ba605bbb43b Mon Sep 17 00:00:00 2001
From: John Biddiscombe
Date: Thu, 30 Mar 2023 14:00:02 +0200
Subject: [PATCH] Annotate cholesky for use with apex/other

Policy gains an optional `const char*` annotation (default nullptr) and an
annotation() accessor. On the MC backend, transform() applies a non-null
annotation to the scheduler with with_annotation() before the transfer.

The Cholesky tile helpers pass ANNOTATE(<helper name>) to their Policy. With
PIKA_HAVE_APEX defined this is the helper name as a string literal, prefixed
with "HP_" when the priority is thread_priority::high; otherwise it is nullptr.
---
Review notes:
- The non-APEX ANNOTATE must not end in ';'. It is expanded inside the
  Policy(...) argument list, where `nullptr;` would not compile.
- ANNOTATE reads a variable named `priority` from the expansion site.
- NOTE(review): transfer_sender is now declared inside the MC branch only;
  confirm the GPU branch (outside this hunk) does not reference it.
- NOTE(review): no #undef ANNOTATE is visible in these hunks; confirm the
  macro is not meant to leak into files that include this header.

 include/dlaf/factorization/cholesky/impl.h | 23 ++++++++++++++--------
 include/dlaf/sender/policy.h               |  7 ++++++-
 include/dlaf/sender/transform.h            |  7 +++++--
 3 files changed, 26 insertions(+), 11 deletions(-)

diff --git a/include/dlaf/factorization/cholesky/impl.h b/include/dlaf/factorization/cholesky/impl.h
index 92426cb6a9..d499ac309e 100644
--- a/include/dlaf/factorization/cholesky/impl.h
+++ b/include/dlaf/factorization/cholesky/impl.h
@@ -34,12 +34,19 @@ namespace dlaf {
 namespace factorization {
 namespace internal {
 
+#ifdef PIKA_HAVE_APEX
+#define ANNOTATE(NAME) (priority == pika::execution::thread_priority::high ? "HP_" #NAME : #NAME)
+#else
+#define ANNOTATE(NAME) nullptr
+#endif
+
 namespace cholesky_l {
+
 template
 void potrfDiagTile(pika::execution::thread_priority priority, MatrixTileSender&& matrix_tile) {
   pika::execution::experimental::start_detached(
       dlaf::internal::whenAllLift(blas::Uplo::Lower, std::forward(matrix_tile)) |
-      tile::potrf(dlaf::internal::Policy(priority)));
+      tile::potrf(dlaf::internal::Policy(priority, ANNOTATE(potrfDiagTile))));
 }
 
 template
@@ -52,7 +59,7 @@ void trsmPanelTile(pika::execution::thread_priority priority, KKTileSender&& kk_
                                   blas::Diag::NonUnit, ElementType(1.0),
                                   std::forward(kk_tile),
                                   std::forward(matrix_tile)) |
-      tile::trsm(dlaf::internal::Policy(priority)));
+      tile::trsm(dlaf::internal::Policy(priority, ANNOTATE(trsmPanelTile))));
 }
 
 template
@@ -64,7 +71,7 @@ void herkTrailingDiagTile(pika::execution::thread_priority priority, PanelTileSe
       dlaf::internal::whenAllLift(blas::Uplo::Lower, blas::Op::NoTrans, BaseElementType(-1.0),
                                   std::forward(panel_tile), BaseElementType(1.0),
                                   std::forward(matrix_tile)) |
-      tile::herk(dlaf::internal::Policy(priority)));
+      tile::herk(dlaf::internal::Policy(priority, ANNOTATE(herkTrailingDiagTile))));
 }
 
 template
@@ -77,7 +84,7 @@ void gemmTrailingMatrixTile(pika::execution::thread_priority priority, PanelTile
                                   std::forward(panel_tile),
                                   std::forward(col_panel), ElementType(1.0),
                                   std::forward(matrix_tile)) |
-      tile::gemm(dlaf::internal::Policy(priority)));
+      tile::gemm(dlaf::internal::Policy(priority, ANNOTATE(gemmTrailingMatrixTile))));
 }
 }
 
@@ -86,7 +93,7 @@ template
 void potrfDiagTile(pika::execution::thread_priority priority, MatrixTileSender&& matrix_tile) {
   pika::execution::experimental::start_detached(
       dlaf::internal::whenAllLift(blas::Uplo::Upper, std::forward(matrix_tile)) |
-      tile::potrf(dlaf::internal::Policy(priority)));
+      tile::potrf(dlaf::internal::Policy(priority, ANNOTATE(potrfDiagTile))));
 }
 
 template
@@ -99,7 +106,7 @@ void trsmPanelTile(pika::execution::thread_priority priority, KKTileSender&& kk_
                                   blas::Diag::NonUnit, ElementType(1.0),
                                   std::forward(kk_tile),
                                   std::forward(matrix_tile)) |
-      tile::trsm(dlaf::internal::Policy(priority)));
+      tile::trsm(dlaf::internal::Policy(priority, ANNOTATE(trsmPanelTile))));
 }
 
 template
@@ -111,7 +118,7 @@ void herkTrailingDiagTile(pika::execution::thread_priority priority, PanelTileSe
       dlaf::internal::whenAllLift(blas::Uplo::Upper, blas::Op::ConjTrans, base_element_type(-1.0),
                                   std::forward(panel_tile), base_element_type(1.0),
                                   std::forward(matrix_tile)) |
-      tile::herk(dlaf::internal::Policy(priority)));
+      tile::herk(dlaf::internal::Policy(priority, ANNOTATE(herkTrailingDiagTile))));
 }
 
 template
@@ -124,7 +131,7 @@ void gemmTrailingMatrixTile(pika::execution::thread_priority priority, PanelTile
                                   std::forward(panel_tile),
                                   std::forward(col_panel), ElementType(1.0),
                                   std::forward(matrix_tile)) |
-      tile::gemm(dlaf::internal::Policy(priority)));
+      tile::gemm(dlaf::internal::Policy(priority, ANNOTATE(gemmTrailingMatrixTile))));
 }
 }
 
diff --git a/include/dlaf/sender/policy.h b/include/dlaf/sender/policy.h
index 3378ccb7ce..859a74250d 100644
--- a/include/dlaf/sender/policy.h
+++ b/include/dlaf/sender/policy.h
@@ -24,10 +24,12 @@ template
 class Policy {
 private:
   const pika::execution::thread_priority priority_ = pika::execution::thread_priority::normal;
+  const char* annotation_ = nullptr;
 
 public:
   Policy() = default;
-  explicit Policy(pika::execution::thread_priority priority) : priority_(priority) {}
+  explicit Policy(pika::execution::thread_priority priority, const char* annotation = nullptr)
+      : priority_(priority), annotation_(annotation) {}
   Policy(Policy&&) = default;
   Policy(Policy const&) = default;
   Policy& operator=(Policy&&) = default;
@@ -36,6 +38,9 @@ class Policy {
   pika::execution::thread_priority priority() const noexcept {
     return priority_;
   }
+  const char* annotation() const noexcept {
+    return annotation_;
+  }
 };
 }
 }
diff --git a/include/dlaf/sender/transform.h b/include/dlaf/sender/transform.h
index 8d41626802..7838ca3bb8 100644
--- a/include/dlaf/sender/transform.h
+++ b/include/dlaf/sender/transform.h
@@ -52,11 +52,14 @@ template
  policy, F&& f, Sender&& sender) {
   using pika::execution::experimental::then;
   using pika::execution::experimental::transfer;
+  using pika::execution::experimental::with_annotation;
 
   auto scheduler = getBackendScheduler(policy.priority());
-  auto transfer_sender = transfer(std::forward(sender), std::move(scheduler));
-
   if constexpr (B == Backend::MC) {
+    if (policy.annotation()) {
+      scheduler = with_annotation(scheduler, policy.annotation());
+    }
+    auto transfer_sender = transfer(std::forward(sender), std::move(scheduler));
     return then(std::move(transfer_sender), dlaf::common::internal::Unwrapping{std::forward(f)});
   }
   else if constexpr (B == Backend::GPU) {