diff --git a/CMakeLists.txt b/CMakeLists.txt
index 50ad65a..ae7632c 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -19,6 +19,7 @@ option(MAGNETRON_DEBUG "Enable debug mode" OFF)
 option(MAGNETRON_CPU_APPROX_MATH "Trade precision for performance" ON) # (CPU only) Enable SIMD math function approximations. Greatly increases performance. Try disabling if you encounter numerical instability. Does NOT enable -ffast-math or similar compiler flags.
 option(MAGNETRON_ENABLE_CUDA "Enable CUDA support" ON) # Enable CUDA support
 option(MAGNETRON_ENABLE_ACCELERATE "Use Apple's Accelerate framework" ON) # Use Apple's Accelerate framework for optimized math functions (only on Apple platforms)
+option(MAGNETRON_ENABLE_OPENBLAS "Use OpenBLAS" ON) # Use OpenBLAS as BLAS backend for matmul
 option(MAGNETRON_ENABLE_MIMALLOC "Use mimalloc as memory allocator" ON) # Use mimalloc as memory allocator for faster memory allocation
 
 set(MAGNETRON_CUDA_COMPILER "/usr/local/cuda-12.6/bin/nvcc" CACHE STRING "Path to the CUDA compiler") # Set to your CUDA compiler path
@@ -43,6 +44,10 @@ if (${MAGNETRON_ENABLE_CUDA})
     include(cmake/cuda.cmake)
 endif()
 
+if (${MAGNETRON_ENABLE_OPENBLAS})
+    include(cmake/openblas.cmake)
+endif()
+
 if (${MAGNETRON_ENABLE_ACCELERATE} AND APPLE)
     include(cmake/accelerate.cmake)
 endif()
diff --git a/cmake/openblas.cmake b/cmake/openblas.cmake
new file mode 100644
index 0000000..f2dbd85
--- /dev/null
+++ b/cmake/openblas.cmake
@@ -0,0 +1,23 @@
+# (c) 2025 Mario "Neo" Sieg.
+
+set(OPENBLAS_INCLUDE_SEARCH_PATHS # Candidate directories searched for cblas.h
+    /usr/include
+    /usr/include/openblas
+    /usr/include/openblas-base
+    /usr/local/include
+    /usr/local/include/openblas
+    /usr/local/include/openblas-base
+    /opt/OpenBLAS/include
+    $ENV{OpenBLAS_HOME}
+    $ENV{OpenBLAS_HOME}/include
+)
+find_path(OPENBLAS_INC NAMES cblas.h PATHS ${OPENBLAS_INCLUDE_SEARCH_PATHS})
+find_library(OPENBLAS_LIB NAMES openblas libopenblas PATHS /opt/OpenBLAS/lib $ENV{OpenBLAS_HOME} $ENV{OpenBLAS_HOME}/lib) # Honor OpenBLAS_HOME for the library as well as the header
+if (OPENBLAS_INC AND OPENBLAS_LIB)
+    message(STATUS "Found OpenBLAS: ${OPENBLAS_LIB}")
+    target_include_directories(magnetron PRIVATE ${OPENBLAS_INC}) # Target-scoped: only magnetron is compiled with MAG_OPENBLAS
+    target_link_libraries(magnetron ${OPENBLAS_LIB})
+    target_compile_definitions(magnetron PRIVATE MAG_OPENBLAS) # Enables the cblas_sgemm matmul kernel
+else()
+    message(WARNING "OpenBLAS not found, using fallback") # NOTE(review): confirm a non-BLAS matmul kernel still exists when MAG_OPENBLAS is undefined
+endif()
diff --git a/magnetron/magnetron_cpu_blas.inl b/magnetron/magnetron_cpu_blas.inl
index 40d8484..d9c3650 100644
--- a/magnetron/magnetron_cpu_blas.inl
+++ b/magnetron/magnetron_cpu_blas.inl
@@ -1386,6 +1386,10 @@ mag_cpu_blas_impl_binary(f32, sub, -)
 mag_cpu_blas_impl_binary(f32, mul, *)
 mag_cpu_blas_impl_binary(f32, div, /)
 
+#ifdef MAG_OPENBLAS /* NOTE(review): the previous non-BLAS kernel is removed below, so mag_blas_matmul_f32 only exists when MAG_OPENBLAS is defined - confirm. */
+
+#include <cblas.h>
+
 static void MAG_HOTPROC mag_blas_matmul_f32(const mag_compute_payload_t* payload, mag_kernel_context_t* ctx) {
     mag_tensor_t* r = payload->node;
     const mag_tensor_t* x = r->op_inputs[0];
@@ -1399,35 +1403,52 @@ static void MAG_HOTPROC mag_blas_matmul_f32(const mag_compute_payload_t* payload
     mag_load_local_storage_group(x, xs, strides);
     mag_load_local_storage_group(y, yd, shape);
     mag_load_local_storage_group(y, ys, strides);
-    mag_assert2(xd2 == 1 && xd3 == 1 && xd4 == 1&& xd5 == 1);
-    mag_assert2(yd2 == 1 && yd3 == 1 && yd4 == 1&& yd5 == 1);
-    int64_t tc = payload->thread_num;
     int64_t ti = payload->thread_idx;
-    int64_t numel = xd0;
-    int64_t chunk = (numel + tc - 1)/tc;
-    int64_t ra = chunk*ti;
-    int64_t rb = mag_xmin(ra+chunk, numel);
-    bool tx = mag_tensor_is_transposed(x);
-    for (int64_t i = ra; i < rb; ++i) { /* Rows */
-        for (int64_t j = 0; j < yd1; ++j) {
-            float* xo = br + rd1*i + j;
-            mag_bnd_chk(xo, br, mag_tensor_data_size(r));
-            *xo = 0.0f;
-        }
-        for (int64_t k = 0; k < xd1; ++k) { /* Inner dim */
-            const mag_f32_t* px = bx + (tx ? k*xd0 + i : xd1*i + k);
-            mag_bnd_chk(px, bx, mag_tensor_data_size(x));
-            for (int64_t j = 0; j < yd1; ++j) { /* Columns */
-                mag_f32_t* pr = br + rd1*i + j;
-                const mag_f32_t* py = by + yd1*k + j;
-                mag_bnd_chk(pr, br, mag_tensor_data_size(r));
-                mag_bnd_chk(py, by, mag_tensor_data_size(y));
-                *pr += (*px) * (*py);
+    if (ti != 0) return; /* Only thread index 0 issues the GEMM calls; every other worker returns immediately. */
+    mag_assert2(mag_tensor_is_contiguous(x) && mag_tensor_is_contiguous(y) && mag_tensor_is_contiguous(r));
+    bool trans_a = mag_tensor_is_transposed(x);
+    if (x->op == MAG_OP_CLONE && x->op_inputs[0]) trans_a |= mag_tensor_is_transposed(x->op_inputs[0]); /* A clone of a transposed tensor also counts as transposed. */
+    bool trans_b = mag_tensor_is_transposed(y);
+    if (y->op == MAG_OP_CLONE && y->op_inputs[0]) trans_b |= mag_tensor_is_transposed(y->op_inputs[0]); /* Same rule for y; must update trans_b, not trans_a. */
+    int64_t b2 = yd2/xd2; /* Per-dimension ratio of y's extent to x's extent. */
+    int64_t b3 = yd3/xd3;
+    int64_t b4 = yd4/xd4;
+    int64_t b5 = yd5/xd5;
+    for (int64_t i5=0; i5 < xd5; ++i5) { /* NOTE(review): loops run over x's extents while x is indexed by i/b and y by i - confirm the intended broadcast direction. */
+        for (int64_t i4=0; i4 < xd4; ++i4) {
+            for (int64_t i3=0; i3 < xd3; ++i3) {
+                for (int64_t i2=0; i2 < xd2; ++i2) {
+                    int64_t xi5 = i5/b5;
+                    int64_t xi4 = i4/b4;
+                    int64_t xi3 = i3/b3;
+                    int64_t xi2 = i2/b2;
+                    const mag_f32_t* px = bx + xi5*xs5 + xi4*xs4 + xi3*xs3 + xi2*xs2;
+                    const mag_f32_t* py = by + i5*ys5 + i4*ys4 + i3*ys3 + i2*ys2;
+                    mag_f32_t* pr = br + i5*rs5 + i4*rs4 + i3*rs3 + i2*rs2;
+                    mag_bnd_chk(pr, br, mag_tensor_data_size(r));
+                    mag_bnd_chk(px, bx, mag_tensor_data_size(x));
+                    mag_bnd_chk(py, by, mag_tensor_data_size(y));
+                    cblas_sgemm( /* r = 1.0 * op(x) * op(y) + 0.0 * r for this 2D slice */
+                        CblasRowMajor,
+                        trans_a ? CblasTrans : CblasNoTrans,
+                        trans_b ? CblasTrans : CblasNoTrans,
+                        rd0, /* M */
+                        yd1, /* N */
+                        xd1, /* K */
+                        1.0f, /* alpha */
+                        px, xd1, /* A, lda - NOTE(review): lda/ldb are not adjusted when a transpose flag is set; confirm. */
+                        py, yd1, /* B, ldb */
+                        0.0f, /* beta */
+                        pr, yd1 /* C, ldc */
+                    );
+                }
             }
         }
     }
 }
 
+#endif
+
 #ifndef MAG_BLAS_SPECIALIZATION
 #error "BLAS specialization undefined"
 #endif
diff --git a/python/examples/xor.py b/python/examples/xor.py
index e5d0341..4e3cfd9 100644
--- a/python/examples/xor.py
+++ b/python/examples/xor.py
@@ -39,7 +39,6 @@ def forward(self, x: Tensor) -> Tensor:
 for epoch in range(epochs):
     y_hat = model(x)
     loss = mse_loss(y_hat, y)
-    loss.export_graphviz('loss.dot')
     loss.backward()
     optim.step()
     optim.zero_grad()
diff --git a/python/magnetron_framework/setup.py b/python/magnetron_framework/setup.py
index 6f0354b..358ffd9 100644
--- a/python/magnetron_framework/setup.py
+++ b/python/magnetron_framework/setup.py
@@ -46,7 +46,7 @@ def build_extension(self, ext):
         cmake_args = [
             '-DMAGNETRON_ENABLE_CUDA=OFF', # TODO: Fix cuda compilation
             f'-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={os.path.abspath(os.path.join(self.build_lib, "magnetron"))}',
-            '-DCMAKE_BUILD_TYPE=Release',
+            '-DCMAKE_BUILD_TYPE=Debug', # NOTE(review): switched from Release - confirm this is intended for packaged builds
         ]
         build_args = [
             '--target magnetron', # Only build the magnetron library
diff --git a/python/tests/tensor_ops3.py b/python/tests/tensor_ops3.py
index 671b9c2..c02aebf 100644
--- a/python/tests/tensor_ops3.py
+++ b/python/tests/tensor_ops3.py
@@ -115,7 +115,7 @@ def test_matmul_x_transposed():
     mag_b = Tensor.uniform(shape_b)
     np_a = tonumpy(mag_a)
     np_b = tonumpy(mag_b)
-    mag_result = mag_a.T @ mag_b
+    mag_result = mag_a.T.clone() @ mag_b
     np_result = np.matmul(np_a.T, np_b)
     assert mag_result.shape == np_result.shape
     assert mag_result.shape == (2, 4)
diff --git a/test/unit/cuda.cpp b/test/unit/cuda.cpp
index dfc54ee..1265767 100644
--- a/test/unit/cuda.cpp
+++ b/test/unit/cuda.cpp
@@ -9,10 +9,10 @@
 TEST(cuda, simple_add) {
     mag_ctx_t* ctx = mag_ctx_create(MAG_COMPUTE_DEVICE_TYPE_GPU_CUDA);
 
-    mag_tensor_t* a = mag_tensor_create_3d(ctx, MAG_DTYPE_F32, 4096, 4096, 16);
+    mag_tensor_t* a = mag_tensor_create_3d(ctx, MAG_DTYPE_F32, 512, 512, 16);
     mag_tensor_fill(a, 1.0f);
 
-    mag_tensor_t* b = mag_tensor_create_3d(ctx, MAG_DTYPE_F32, 4096, 4096, 16);
+    mag_tensor_t* b = mag_tensor_create_3d(ctx, MAG_DTYPE_F32, 512, 512, 16);
     mag_tensor_fill(b, 1.0f);
 
     printf("Computing...\n");
diff --git a/test/unit/tensor_ops_1_cpu.cpp b/test/unit/tensor_ops_1_cpu.cpp
index da65696..6a2824a 100644
--- a/test/unit/tensor_ops_1_cpu.cpp
+++ b/test/unit/tensor_ops_1_cpu.cpp
@@ -159,7 +159,7 @@ impl_test_unary_op(tanh, 1e-3, tanh, [](float x) -> float {
     return std::tanh(x);
 })
 impl_test_unary_op(tanh_dv, 1e-9, tanh_dv, [](float x) -> float {
-    return 1.0f / (std::cosh(x)*std::cosh(x));
+    return 1.0f - (std::tanh(x) * std::tanh(x));
 })
 
 impl_test_unary_op(relu, 1e-9, relu, [](float x) -> float {
diff --git a/test/unit/tensor_ops_mm_cpu.cpp b/test/unit/tensor_ops_mm_cpu.cpp
index 04a5b7f..00eedeb 100644
--- a/test/unit/tensor_ops_mm_cpu.cpp
+++ b/test/unit/tensor_ops_mm_cpu.cpp
@@ -95,7 +95,7 @@ TEST(compute_cpu, mm_square_2x2_transpose_x) {
     mag_tensor_copy_buffer_from(A, A_data, sizeof(A_data));
     mag_tensor_t* B = mag_tensor_create_2d(ctx, MAG_DTYPE_F32, M, N);
     mag_tensor_copy_buffer_from(B, B_data, sizeof(B_data));
-    mag_tensor_t* R = mag_matmul(mag_transpose(A), B);
+    mag_tensor_t* R = mag_matmul(mag_clone(mag_transpose(A)), B);
     ASSERT_EQ(R->rank, 2);
     ASSERT_EQ(R->shape[0], 2);
     ASSERT_EQ(R->shape[1], 2);