Skip to content

Commit 8205148

Browse files
committed
[bench] add dense (add|sub)_diag to blas
1 parent 3be7b60 commit 8205148

File tree

2 files changed

+84
-0
lines changed

2 files changed

+84
-0
lines changed

benchmark/blas/blas.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,16 @@ std::map<std::string, std::function<std::unique_ptr<BenchmarkOperation>(
7777
exec, Generator{}, dims.n, dims.k, dims.m, dims.stride_A,
7878
dims.stride_B, dims.stride_C);
7979
}},
80+
{"add_diag",
81+
[](std::shared_ptr<const gko::Executor> exec, dimensions dims) {
82+
return std::make_unique<AddDiagOperation<Generator>>(
83+
exec, Generator{}, dims.n, dims.stride_A);
84+
}},
85+
{"sub_diag",
86+
[](std::shared_ptr<const gko::Executor> exec, dimensions dims) {
87+
return std::make_unique<SubDiagOperation<Generator>>(
88+
exec, Generator{}, dims.n, dims.stride_A);
89+
}},
8090
{"prefix_sum32",
8191
[](std::shared_ptr<const gko::Executor> exec, dimensions dims) {
8292
return std::make_unique<PrefixSumOperation<gko::int32>>(exec,

benchmark/blas/blas_common.hpp

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ DEFINE_string(
3636
" norm (a = sqrt(x' * x)),\n"
3737
" mm (C = A * B),\n"
3838
" gemm (C = a * A * B + b * C)\n"
39+
" add_diag (A = A + a * D)\n"
40+
" sub_diag (A = A - a * D)\n"
3941
"Non-numerical algorithms:\n"
4042
" prefix_sum32 (x_i <- sum_{j=0}^{i-1} x_i, 32 bit indices)\n"
4143
" prefix_sum64 ( 64 bit indices)\n"
@@ -373,6 +375,78 @@ class AdvancedApplyOperation : public BenchmarkOperation {
373375
};
374376

375377

378+
template <typename Generator>
379+
class AddDiagOperation : public BenchmarkOperation {
380+
public:
381+
AddDiagOperation(std::shared_ptr<const gko::Executor> exec,
382+
const Generator& generator, gko::size_type n,
383+
gko::size_type stride)
384+
{
385+
// Since dense distributed matrices are not supported we can use
386+
// local_size == global_size
387+
y_ = generator.create_multi_vector_strided(exec, gko::dim<2>{n, n},
388+
gko::dim<2>{n, n}, stride);
389+
D_ = gko::matrix::Diagonal<etype>::create(exec, n);
390+
alpha_ = gko::matrix::Dense<etype>::create(exec, gko::dim<2>{1, 1});
391+
392+
as_vector<Generator>(y_)->fill(1);
393+
D_->read(gko::matrix_data<etype, itype>::diag(gko::dim<2>{n, n},
394+
etype{2.2}));
395+
alpha_->fill(1);
396+
}
397+
398+
gko::size_type get_flops() const override { return y_->get_size()[0] * 2; }
399+
400+
gko::size_type get_memory() const override
401+
{
402+
return y_->get_size()[0] * 3 * sizeof(etype);
403+
}
404+
405+
void run() override { as_vector<Generator>(y_)->add_scaled(alpha_, D_); }
406+
407+
private:
408+
std::unique_ptr<gko::matrix::Dense<etype>> alpha_;
409+
std::unique_ptr<gko::LinOp> y_;
410+
std::unique_ptr<gko::matrix::Diagonal<etype>> D_;
411+
};
412+
413+
414+
template <typename Generator>
415+
class SubDiagOperation : public BenchmarkOperation {
416+
public:
417+
SubDiagOperation(std::shared_ptr<const gko::Executor> exec,
418+
const Generator& generator, gko::size_type n,
419+
gko::size_type stride)
420+
{
421+
// Since dense distributed matrices are not supported we can use
422+
// local_size == global_size
423+
y_ = generator.create_multi_vector_strided(exec, gko::dim<2>{n, n},
424+
gko::dim<2>{n, n}, stride);
425+
D_ = gko::matrix::Diagonal<etype>::create(exec, n);
426+
alpha_ = gko::matrix::Dense<etype>::create(exec, gko::dim<2>{1, 1});
427+
428+
as_vector<Generator>(y_)->fill(1);
429+
D_->read(gko::matrix_data<etype, itype>::diag(gko::dim<2>{n, n},
430+
etype{2.2}));
431+
alpha_->fill(1);
432+
}
433+
434+
gko::size_type get_flops() const override { return y_->get_size()[0] * 2; }
435+
436+
gko::size_type get_memory() const override
437+
{
438+
return y_->get_size()[0] * 3 * sizeof(etype);
439+
}
440+
441+
void run() override { as_vector<Generator>(y_)->sub_scaled(alpha_, D_); }
442+
443+
private:
444+
std::unique_ptr<gko::matrix::Dense<etype>> alpha_;
445+
std::unique_ptr<gko::LinOp> y_;
446+
std::unique_ptr<gko::matrix::Diagonal<etype>> D_;
447+
};
448+
449+
376450
GKO_REGISTER_OPERATION(prefix_sum_nonnegative,
377451
components::prefix_sum_nonnegative);
378452

0 commit comments

Comments
 (0)