Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 96 additions & 0 deletions tasks/guseva_crs/tbb/include/multiplier_tbb.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
#pragma once
#include <cmath>
#include <cstddef>
#include <vector>

#include "guseva_crs/common/include/common.hpp"
#include "guseva_crs/common/include/multiplier.hpp"
#include "oneapi/tbb/blocked_range.h"
#include "oneapi/tbb/parallel_for.h"

namespace guseva_crs {

class MultiplierTbb : public Multiplier {
static void PerformCalculation(std::size_t k, std::size_t ind3, std::size_t ind4, const CRS &a, const CRS &bt,
double &sum, std::vector<int> &temp) {
for (k = ind3; k < ind4; k++) {
std::size_t bcol = bt.cols[k];
int aind = temp[bcol];
if (aind != -1) {
sum += a.values[aind] * bt.values[k];
}
}
}

static void ProcessRows(const tbb::blocked_range<std::size_t> &range, const CRS &a, const CRS &bt,
std::vector<std::vector<std::size_t>> &columns, std::vector<std::vector<double>> &values,
std::vector<std::size_t> &row_index) {
std::size_t n = a.nrows;
std::vector<int> temp(n);

for (std::size_t i = range.begin(); i != range.end(); ++i) {
for (int &l : temp) {
l = -1;
}
std::size_t ind1 = a.row_ptrs[i];
std::size_t ind2 = a.row_ptrs[i + 1];
for (std::size_t j = ind1; j < ind2; j++) {
std::size_t col = a.cols[j];
temp[col] = static_cast<int>(j);
}

for (std::size_t j = 0; j < n; j++) {
double sum = 0;
std::size_t ind3 = bt.row_ptrs[j];
std::size_t ind4 = bt.row_ptrs[j + 1];

PerformCalculation(0, ind3, ind4, a, bt, sum, temp);

if (std::fabs(sum) > kZERO) {
columns[i].push_back(j);
values[i].push_back(sum);
row_index[i]++;
}
}
}
}

public:
[[nodiscard]] CRS Multiply(const CRS &a, const CRS &b) const override {
std::size_t n = a.nrows;

auto bt = this->Transpose(b);

std::vector<std::vector<std::size_t>> columns(n);
std::vector<std::vector<double>> values(n);
std::vector<std::size_t> row_index(n + 1, 0);

tbb::parallel_for(tbb::blocked_range<std::size_t>(0, n),
[&a, &bt, &columns, &values, &row_index](const tbb::blocked_range<std::size_t> &range) {
guseva_crs::MultiplierTbb::ProcessRows(range, a, bt, columns, values, row_index);
});

std::size_t nz = 0;
for (std::size_t i = 0; i < n; i++) {
std::size_t tmp = row_index[i];
row_index[i] = nz;
nz += tmp;
}
row_index[n] = nz;

CRS result;
result.row_ptrs = row_index;
result.nrows = n;
result.ncols = n;

for (std::size_t i = 0; i < n; i++) {
result.cols.insert(result.cols.end(), columns[i].begin(), columns[i].end());
result.values.insert(result.values.end(), values[i].begin(), values[i].end());
}

result.nz = result.values.size();
return result;
}
};

} // namespace guseva_crs
22 changes: 22 additions & 0 deletions tasks/guseva_crs/tbb/include/ops_tbb.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#pragma once

#include "guseva_crs/common/include/common.hpp"
#include "task/include/task.hpp"

namespace guseva_crs {

class GusevaCRSMatMulTbb : public BaseTask {
public:
static constexpr ppc::task::TypeOfTask GetStaticTypeOfTask() {
return ppc::task::TypeOfTask::kTBB;
}
explicit GusevaCRSMatMulTbb(const InType &in);

private:
bool ValidationImpl() override;
bool PreProcessingImpl() override;
bool RunImpl() override;
bool PostProcessingImpl() override;
};

} // namespace guseva_crs
34 changes: 34 additions & 0 deletions tasks/guseva_crs/tbb/src/ops_tbb.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#include "guseva_crs/tbb/include/ops_tbb.hpp"

#include "guseva_crs/common/include/common.hpp"
#include "guseva_crs/tbb/include/multiplier_tbb.hpp"

namespace guseva_crs {

GusevaCRSMatMulTbb::GusevaCRSMatMulTbb(const InType &in) {
SetTypeOfTask(GetStaticTypeOfTask());
GetInput() = in;
GetOutput();
}

bool GusevaCRSMatMulTbb::ValidationImpl() {
const auto &[a, b] = GetInput();
return a.ncols == b.nrows;
}

bool GusevaCRSMatMulTbb::PreProcessingImpl() {
return true;
}

bool GusevaCRSMatMulTbb::RunImpl() {
const auto &[a, b] = GetInput();
auto mult = MultiplierTbb();
GetOutput() = mult.Multiply(a, b);
return true;
}

bool GusevaCRSMatMulTbb::PostProcessingImpl() {
return true;
}

} // namespace guseva_crs
4 changes: 3 additions & 1 deletion tasks/guseva_crs/tests/functional/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "guseva_crs/common/include/test_reader.hpp"
#include "guseva_crs/omp/include/ops_omp.hpp"
#include "guseva_crs/seq/include/ops_seq.hpp"
#include "guseva_crs/tbb/include/ops_tbb.hpp"
#include "util/include/func_test_util.hpp"
#include "util/include/util.hpp"

Expand Down Expand Up @@ -55,7 +56,8 @@ const std::array<TestType, 6> kTestParam = {"sparse_dense", "dense_sparse",

const auto kTestTasksList =
std::tuple_cat(ppc::util::AddFuncTask<GusevaCRSMatMulSeq, InType>(kTestParam, PPC_SETTINGS_guseva_crs),
ppc::util::AddFuncTask<GusevaCRSMatMulOmp, InType>(kTestParam, PPC_SETTINGS_guseva_crs));
ppc::util::AddFuncTask<GusevaCRSMatMulOmp, InType>(kTestParam, PPC_SETTINGS_guseva_crs),
ppc::util::AddFuncTask<GusevaCRSMatMulTbb, InType>(kTestParam, PPC_SETTINGS_guseva_crs));

const auto kGtestValues = ppc::util::ExpandToValues(kTestTasksList);

Expand Down
4 changes: 3 additions & 1 deletion tasks/guseva_crs/tests/performance/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "guseva_crs/common/include/common.hpp"
#include "guseva_crs/omp/include/ops_omp.hpp"
#include "guseva_crs/seq/include/ops_seq.hpp"
#include "guseva_crs/tbb/include/ops_tbb.hpp"
#include "util/include/perf_test_util.hpp"

namespace guseva_crs {
Expand Down Expand Up @@ -57,7 +58,8 @@ TEST_P(GusevaMatMulCRSPerfTest, G) {
namespace {

const auto kAllPerfTasks =
ppc::util::MakeAllPerfTasks<InType, GusevaCRSMatMulSeq, GusevaCRSMatMulOmp>(PPC_SETTINGS_guseva_crs);
ppc::util::MakeAllPerfTasks<InType, GusevaCRSMatMulSeq, GusevaCRSMatMulOmp, GusevaCRSMatMulTbb>(
PPC_SETTINGS_guseva_crs);

const auto kGtestValues = ppc::util::TupleToGTestValues(kAllPerfTasks);

Expand Down
Loading