Skip to content

Commit

Permalink
Annotate cholesky for use with apex/other
Browse files Browse the repository at this point in the history
  • Loading branch information
biddisco committed Mar 30, 2023
1 parent 0f0ae83 commit fda4abb
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 11 deletions.
23 changes: 15 additions & 8 deletions include/dlaf/factorization/cholesky/impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,19 @@ namespace dlaf {
namespace factorization {
namespace internal {

#ifdef PIKA_HAVE_APEX
#define ANNOTATE(NAME) (priority == pika::execution::thread_priority::high ? "HP_" #NAME : #NAME)
#else
#define ANNOTATE(name) nullptr;
#endif

namespace cholesky_l {

template <Backend backend, class MatrixTileSender>
void potrfDiagTile(pika::execution::thread_priority priority, MatrixTileSender&& matrix_tile) {
pika::execution::experimental::start_detached(
dlaf::internal::whenAllLift(blas::Uplo::Lower, std::forward<MatrixTileSender>(matrix_tile)) |
tile::potrf(dlaf::internal::Policy<backend>(priority)));
tile::potrf(dlaf::internal::Policy<backend>(priority, ANNOTATE(potrfDiagTile))));
}

template <Backend backend, class KKTileSender, class MatrixTileSender>
Expand All @@ -52,7 +59,7 @@ void trsmPanelTile(pika::execution::thread_priority priority, KKTileSender&& kk_
blas::Diag::NonUnit, ElementType(1.0),
std::forward<KKTileSender>(kk_tile),
std::forward<MatrixTileSender>(matrix_tile)) |
tile::trsm(dlaf::internal::Policy<backend>(priority)));
tile::trsm(dlaf::internal::Policy<backend>(priority, ANNOTATE(trsmPanelTile))));
}

template <Backend backend, class PanelTileSender, class MatrixTileSender>
Expand All @@ -64,7 +71,7 @@ void herkTrailingDiagTile(pika::execution::thread_priority priority, PanelTileSe
dlaf::internal::whenAllLift(blas::Uplo::Lower, blas::Op::NoTrans, BaseElementType(-1.0),
std::forward<PanelTileSender>(panel_tile), BaseElementType(1.0),
std::forward<MatrixTileSender>(matrix_tile)) |
tile::herk(dlaf::internal::Policy<backend>(priority)));
tile::herk(dlaf::internal::Policy<backend>(priority, ANNOTATE(herkTrailingDiagTile))));
}

template <Backend backend, class PanelTileSender, class ColPanelSender, class MatrixTileSender>
Expand All @@ -77,7 +84,7 @@ void gemmTrailingMatrixTile(pika::execution::thread_priority priority, PanelTile
std::forward<PanelTileSender>(panel_tile),
std::forward<ColPanelSender>(col_panel), ElementType(1.0),
std::forward<MatrixTileSender>(matrix_tile)) |
tile::gemm(dlaf::internal::Policy<backend>(priority)));
tile::gemm(dlaf::internal::Policy<backend>(priority, ANNOTATE(gemmTrailingMatrixTile))));
}
}

Expand All @@ -86,7 +93,7 @@ template <Backend backend, class MatrixTileSender>
void potrfDiagTile(pika::execution::thread_priority priority, MatrixTileSender&& matrix_tile) {
pika::execution::experimental::start_detached(
dlaf::internal::whenAllLift(blas::Uplo::Upper, std::forward<MatrixTileSender>(matrix_tile)) |
tile::potrf(dlaf::internal::Policy<backend>(priority)));
tile::potrf(dlaf::internal::Policy<backend>(priority, ANNOTATE(potrfDiagTile))));
}

template <Backend backend, class KKTileSender, class MatrixTileSender>
Expand All @@ -99,7 +106,7 @@ void trsmPanelTile(pika::execution::thread_priority priority, KKTileSender&& kk_
blas::Diag::NonUnit, ElementType(1.0),
std::forward<KKTileSender>(kk_tile),
std::forward<MatrixTileSender>(matrix_tile)) |
tile::trsm(dlaf::internal::Policy<backend>(priority)));
tile::trsm(dlaf::internal::Policy<backend>(priority, ANNOTATE(trsmPanelTile))));
}

template <Backend backend, class PanelTileSender, class MatrixTileSender>
Expand All @@ -111,7 +118,7 @@ void herkTrailingDiagTile(pika::execution::thread_priority priority, PanelTileSe
dlaf::internal::whenAllLift(blas::Uplo::Upper, blas::Op::ConjTrans, base_element_type(-1.0),
std::forward<PanelTileSender>(panel_tile), base_element_type(1.0),
std::forward<MatrixTileSender>(matrix_tile)) |
tile::herk(dlaf::internal::Policy<backend>(priority)));
tile::herk(dlaf::internal::Policy<backend>(priority, ANNOTATE(herkTrailingDiagTile))));
}

template <Backend backend, class PanelTileSender, class ColPanelSender, class MatrixTileSender>
Expand All @@ -124,7 +131,7 @@ void gemmTrailingMatrixTile(pika::execution::thread_priority priority, PanelTile
std::forward<PanelTileSender>(panel_tile),
std::forward<ColPanelSender>(col_panel), ElementType(1.0),
std::forward<MatrixTileSender>(matrix_tile)) |
tile::gemm(dlaf::internal::Policy<backend>(priority)));
tile::gemm(dlaf::internal::Policy<backend>(priority, ANNOTATE(gemmTrailingMatrixTile))));
}
}

Expand Down
7 changes: 6 additions & 1 deletion include/dlaf/sender/policy.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,12 @@ template <Backend B>
class Policy {
private:
const pika::execution::thread_priority priority_ = pika::execution::thread_priority::normal;
const char* annotation_ = nullptr;

public:
Policy() = default;
explicit Policy(pika::execution::thread_priority priority) : priority_(priority) {}
explicit Policy(pika::execution::thread_priority priority, const char* annotation = nullptr)
: priority_(priority), annotation_(annotation) {}
Policy(Policy&&) = default;
Policy(Policy const&) = default;
Policy& operator=(Policy&&) = default;
Expand All @@ -36,6 +38,9 @@ class Policy {
pika::execution::thread_priority priority() const noexcept {
return priority_;
}
const char* annotation() const noexcept {
return annotation_;
}
};
}
}
7 changes: 5 additions & 2 deletions include/dlaf/sender/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,14 @@ template <TransformDispatchType Tag = TransformDispatchType::Plain, Backend B =
[[nodiscard]] decltype(auto) transform(const Policy<B> policy, F&& f, Sender&& sender) {
using pika::execution::experimental::then;
using pika::execution::experimental::transfer;
using pika::execution::experimental::with_annotation;

auto scheduler = getBackendScheduler<B>(policy.priority());
auto transfer_sender = transfer(std::forward<Sender>(sender), std::move(scheduler));

if constexpr (B == Backend::MC) {
if (policy.annotation()) {
scheduler = with_annotation(scheduler, policy.annotation());
}
auto transfer_sender = transfer(std::forward<Sender>(sender), std::move(scheduler));
return then(std::move(transfer_sender), dlaf::common::internal::Unwrapping{std::forward<F>(f)});
}
else if constexpr (B == Backend::GPU) {
Expand Down

0 comments on commit fda4abb

Please sign in to comment.