Skip to content

Commit

Permalink
CBLAS matmul backend
Browse files Browse the repository at this point in the history
  • Loading branch information
MarioSieg committed Feb 14, 2025
1 parent 3872f2e commit 735bc9e
Show file tree
Hide file tree
Showing 9 changed files with 78 additions and 30 deletions.
5 changes: 5 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ option(MAGNETRON_DEBUG "Enable debug mode" OFF)
option(MAGNETRON_CPU_APPROX_MATH "Trade precision for performance" ON) # (CPU only) Enable SIMD math function approximations. Greatly increases performance. Try disabling if you encounter numerical instability. Does NOT enable -ffast-math or similar compiler flags.
option(MAGNETRON_ENABLE_CUDA "Enable CUDA support" ON) # Enable CUDA support
option(MAGNETRON_ENABLE_ACCELERATE "Use Apple's Accelerate framework" ON) # Use Apple's Accelerate framework for optimized math functions (only on Apple platforms)
option(MAGNETRON_ENABLE_OPENBLAS "Use OpenBLAS" ON) # Use OpenBLAS as BLAS backend for matmul
option(MAGNETRON_ENABLE_MIMALLOC "Use mimalloc as memory allocator" ON) # Use mimalloc as memory allocator for faster memory allocation

set(MAGNETRON_CUDA_COMPILER "/usr/local/cuda-12.6/bin/nvcc" CACHE STRING "Path to the CUDA compiler") # Set to your CUDA compiler path
Expand All @@ -43,6 +44,10 @@ if (${MAGNETRON_ENABLE_CUDA})
include(cmake/cuda.cmake)
endif()

# Pull in the OpenBLAS detection script when the OpenBLAS backend is requested.
# The option name is passed unexpanded so if() dereferences it itself; the
# `${...}` form breaks when the value is empty or happens to name another variable.
if (MAGNETRON_ENABLE_OPENBLAS)
    include(cmake/openblas.cmake)
endif()

# Include the Accelerate setup script only when the option is on and the build targets an Apple platform.
# The option name is passed unexpanded so if() dereferences it itself; the
# `${...}` form breaks when the value is empty or happens to name another variable.
if (MAGNETRON_ENABLE_ACCELERATE AND APPLE)
    include(cmake/accelerate.cmake)
endif()
Expand Down
23 changes: 23 additions & 0 deletions cmake/openblas.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# (c) 2025 Mario "Neo" Sieg. <[email protected]>

# Locates OpenBLAS (the cblas.h header plus the openblas library). When both are
# found they are attached to the `magnetron` target and MAG_OPENBLAS is defined
# for it; otherwise only a warning is printed and the target is left untouched.
# Set the OpenBLAS_HOME environment variable to point at a custom install prefix.

set(OPENBLAS_INCLUDE_SEARCH_PATHS
    /usr/include
    /usr/include/openblas
    /usr/include/openblas-base
    /usr/local/include
    /usr/local/include/openblas
    /usr/local/include/openblas-base
    /opt/OpenBLAS/include
)
set(OPENBLAS_LIB_SEARCH_PATHS
    /opt/OpenBLAS/lib
)

# Only add OpenBLAS_HOME-derived entries when the variable is actually set:
# an empty expansion would otherwise put "" and "/include" on the search list.
# The same prefix is used for the library so header and binary come from one install.
if (NOT "$ENV{OpenBLAS_HOME}" STREQUAL "")
    list(APPEND OPENBLAS_INCLUDE_SEARCH_PATHS
        "$ENV{OpenBLAS_HOME}"
        "$ENV{OpenBLAS_HOME}/include"
    )
    list(APPEND OPENBLAS_LIB_SEARCH_PATHS
        "$ENV{OpenBLAS_HOME}"
        "$ENV{OpenBLAS_HOME}/lib"
    )
endif()

find_path(OPENBLAS_INC NAMES cblas.h PATHS ${OPENBLAS_INCLUDE_SEARCH_PATHS})
find_library(OPENBLAS_LIB NAMES openblas libopenblas PATHS ${OPENBLAS_LIB_SEARCH_PATHS})

if (OPENBLAS_INC AND OPENBLAS_LIB)
    message(STATUS "Found OpenBLAS: ${OPENBLAS_LIB}")
    # Scope the include path to magnetron instead of the whole directory.
    target_include_directories(magnetron PRIVATE "${OPENBLAS_INC}")
    # NOTE(review): keyword-less signature kept on purpose. CMake forbids mixing
    # plain and keyword signatures on one target; switch to PRIVATE if the other
    # target_link_libraries(magnetron ...) calls in the project use keywords.
    target_link_libraries(magnetron "${OPENBLAS_LIB}")
    target_compile_definitions(magnetron PRIVATE MAG_OPENBLAS)
else()
    message(WARNING "OpenBLAS not found, using fallback")
endif()
67 changes: 44 additions & 23 deletions magnetron/magnetron_cpu_blas.inl
Original file line number Diff line number Diff line change
Expand Up @@ -1386,6 +1386,10 @@ mag_cpu_blas_impl_binary(f32, sub, -)
mag_cpu_blas_impl_binary(f32, mul, *)
mag_cpu_blas_impl_binary(f32, div, /)

#ifdef MAG_OPENBLAS

#include <cblas.h>

static void MAG_HOTPROC mag_blas_matmul_f32(const mag_compute_payload_t* payload, mag_kernel_context_t* ctx) {
mag_tensor_t* r = payload->node;
const mag_tensor_t* x = r->op_inputs[0];
Expand All @@ -1399,35 +1403,52 @@ static void MAG_HOTPROC mag_blas_matmul_f32(const mag_compute_payload_t* payload
mag_load_local_storage_group(x, xs, strides);
mag_load_local_storage_group(y, yd, shape);
mag_load_local_storage_group(y, ys, strides);
mag_assert2(xd2 == 1 && xd3 == 1 && xd4 == 1&& xd5 == 1);
mag_assert2(yd2 == 1 && yd3 == 1 && yd4 == 1&& yd5 == 1);
int64_t tc = payload->thread_num;
int64_t ti = payload->thread_idx;
int64_t numel = xd0;
int64_t chunk = (numel + tc - 1)/tc;
int64_t ra = chunk*ti;
int64_t rb = mag_xmin(ra+chunk, numel);
bool tx = mag_tensor_is_transposed(x);
for (int64_t i = ra; i < rb; ++i) { /* Rows */
for (int64_t j = 0; j < yd1; ++j) {
float* xo = br + rd1*i + j;
mag_bnd_chk(xo, br, mag_tensor_data_size(r));
*xo = 0.0f;
}
for (int64_t k = 0; k < xd1; ++k) { /* Inner dim */
const mag_f32_t* px = bx + (tx ? k*xd0 + i : xd1*i + k);
mag_bnd_chk(px, bx, mag_tensor_data_size(x));
for (int64_t j = 0; j < yd1; ++j) { /* Columns */
mag_f32_t* pr = br + rd1*i + j;
const mag_f32_t* py = by + yd1*k + j;
mag_bnd_chk(pr, br, mag_tensor_data_size(r));
mag_bnd_chk(py, by, mag_tensor_data_size(y));
*pr += (*px) * (*py);
if (ti != 0) return;
mag_assert2(mag_tensor_is_contiguous(x) && mag_tensor_is_contiguous(y) && mag_tensor_is_contiguous(r));
/* Derive the transpose flags passed to cblas_sgemm for A (x) and B (y). */
/* A tensor produced by MAG_OP_CLONE also picks up the flag of the tensor it was cloned from. */
bool trans_a = mag_tensor_is_transposed(x);
if (x->op == MAG_OP_CLONE && x->op_inputs[0]) trans_a |= mag_tensor_is_transposed(x->op_inputs[0]);
bool trans_b = mag_tensor_is_transposed(y);
/* Fix: the clone source of y must feed trans_b; it previously OR-ed into trans_a (copy-paste slip). */
if (y->op == MAG_OP_CLONE && y->op_inputs[0]) trans_b |= mag_tensor_is_transposed(y->op_inputs[0]);
int64_t b2 = yd2/xd2;
int64_t b3 = yd3/xd3;
int64_t b4 = yd4/xd4;
int64_t b5 = yd5/xd5;
for (int64_t i5=0; i5 < xd5; ++i5) {
for (int64_t i4=0; i4 < xd4; ++i4) {
for (int64_t i3=0; i3 < xd3; ++i3) {
for (int64_t i2=0; i2 < xd2; ++i2) {
int64_t xi5 = i5/b5;
int64_t xi4 = i4/b4;
int64_t xi3 = i3/b3;
int64_t xi2 = i2/b2;
const mag_f32_t* px = bx + xi5*xs5 + xi4*xs4 + xi3*xs3 + xi2*xs2;
const mag_f32_t* py = by + i5*ys5 + i4*ys4 + i3*ys3 + i2*ys2;
mag_f32_t* pr = br + i5*rs5 + i4*rs4 + i3*rs3 + i2*rs2;
mag_bnd_chk(pr, br, mag_tensor_data_size(r));
mag_bnd_chk(px, bx, mag_tensor_data_size(x));
mag_bnd_chk(py, by, mag_tensor_data_size(y));
cblas_sgemm(
CblasRowMajor,
trans_a ? CblasTrans : CblasNoTrans,
trans_b ? CblasTrans : CblasNoTrans,
rd0,
yd1,
xd1,
1.0f,
px, xd1,
py, yd1,
0.0f,
pr, yd1
);
}
}
}
}
}

#endif

#ifndef MAG_BLAS_SPECIALIZATION
#error "BLAS specialization undefined"
#endif
Expand Down
1 change: 0 additions & 1 deletion python/examples/xor.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ def forward(self, x: Tensor) -> Tensor:
for epoch in range(epochs):
y_hat = model(x)
loss = mse_loss(y_hat, y)
loss.export_graphviz('loss.dot')
loss.backward()
optim.step()
optim.zero_grad()
Expand Down
2 changes: 1 addition & 1 deletion python/magnetron_framework/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def build_extension(self, ext):
cmake_args = [
'-DMAGNETRON_ENABLE_CUDA=OFF', # TODO: Fix cuda compilation
f'-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={os.path.abspath(os.path.join(self.build_lib, "magnetron"))}',
'-DCMAKE_BUILD_TYPE=Release',
'-DCMAKE_BUILD_TYPE=Debug',
]
build_args = [
'--target magnetron', # Only build the magnetron library
Expand Down
2 changes: 1 addition & 1 deletion python/tests/tensor_ops3.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def test_matmul_x_transposed():
mag_b = Tensor.uniform(shape_b)
np_a = tonumpy(mag_a)
np_b = tonumpy(mag_b)
mag_result = mag_a.T @ mag_b
mag_result = mag_a.T.clone() @ mag_b
np_result = np.matmul(np_a.T, np_b)
assert mag_result.shape == np_result.shape
assert mag_result.shape == (2, 4)
Expand Down
4 changes: 2 additions & 2 deletions test/unit/cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@ TEST(cuda, simple_add) {

mag_ctx_t* ctx = mag_ctx_create(MAG_COMPUTE_DEVICE_TYPE_GPU_CUDA);

mag_tensor_t* a = mag_tensor_create_3d(ctx, MAG_DTYPE_F32, 4096, 4096, 16);
mag_tensor_t* a = mag_tensor_create_3d(ctx, MAG_DTYPE_F32, 512, 512, 16);
mag_tensor_fill(a, 1.0f);

mag_tensor_t* b = mag_tensor_create_3d(ctx, MAG_DTYPE_F32, 4096, 4096, 16);
mag_tensor_t* b = mag_tensor_create_3d(ctx, MAG_DTYPE_F32, 512, 512, 16);
mag_tensor_fill(b, 1.0f);

printf("Computing...\n");
Expand Down
2 changes: 1 addition & 1 deletion test/unit/tensor_ops_1_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ impl_test_unary_op(tanh, 1e-3, tanh, [](float x) -> float {
return std::tanh(x);
})
impl_test_unary_op(tanh_dv, 1e-9, tanh_dv, [](float x) -> float {
return 1.0f / (std::cosh(x)*std::cosh(x));
return 1.0f - (std::tanh(x) * std::tanh(x));
})

impl_test_unary_op(relu, 1e-9, relu, [](float x) -> float {
Expand Down
2 changes: 1 addition & 1 deletion test/unit/tensor_ops_mm_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ TEST(compute_cpu, mm_square_2x2_transpose_x) {
mag_tensor_copy_buffer_from(A, A_data, sizeof(A_data));
mag_tensor_t* B = mag_tensor_create_2d(ctx, MAG_DTYPE_F32, M, N);
mag_tensor_copy_buffer_from(B, B_data, sizeof(B_data));
mag_tensor_t* R = mag_matmul(mag_transpose(A), B);
mag_tensor_t* R = mag_matmul(mag_clone(mag_transpose(A)), B);
ASSERT_EQ(R->rank, 2);
ASSERT_EQ(R->shape[0], 2);
ASSERT_EQ(R->shape[1], 2);
Expand Down

0 comments on commit 735bc9e

Please sign in to comment.