-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathreduce_blas.hpp
45 lines (39 loc) · 1.34 KB
/
reduce_blas.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
/**
* @date 19/01/2025
* @file reduce_blas.hpp
* @brief BLAS-based reductions
* @author Ash Vardanian
*/
#pragma once
#include <cstddef>   // `std::size_t`, `std::ptrdiff_t`
#include <limits>    // `std::numeric_limits`
#include <stdexcept> // `std::length_error`

#include <cblas.h> // `cblas_sdot`
namespace ashvardanian::reduce {
/**
* @brief Using BLAS dot-product interface to accumulate a vector.
*
* BLAS interfaces have a convenient "stride" parameter that can be used to
* apply the kernel to various data layouts. Similarly, if we set the stride
* to @b zero, we can fool the kernels into thinking that a scalar is a vector.
*/
class blas_dot_t {
float const *const begin_ = nullptr;
float const *const end_ = nullptr;
#if defined(CBLAS_INDEX)
using blas_dim_t = CBLAS_INDEX;
#else
using blas_dim_t = blasint;
#endif
public:
blas_dot_t() = default;
blas_dot_t(float const *b, float const *e) : begin_(b), end_(e) {
constexpr std::size_t max_length_k = static_cast<std::size_t>(std::numeric_limits<blas_dim_t>::max());
if (end_ - begin_ > max_length_k) throw std::length_error("BLAS not configured for 64-bit sizes");
}
float operator()() const noexcept {
float repeated_ones[1];
repeated_ones[0] = 1.0f;
return cblas_sdot(end_ - begin_, begin_, 1, &repeated_ones[0], 0);
}
};
} // namespace ashvardanian::reduce