diff --git a/magnetron/magnetron_cpu_blas.inl b/magnetron/magnetron_cpu_blas.inl
index 6d9427b..40d8484 100644
--- a/magnetron/magnetron_cpu_blas.inl
+++ b/magnetron/magnetron_cpu_blas.inl
@@ -1408,42 +1408,24 @@ static void MAG_HOTPROC mag_blas_matmul_f32(const mag_compute_payload_t* payload
     int64_t ra = chunk*ti;
     int64_t rb = mag_xmin(ra+chunk, numel);
     bool tx = mag_tensor_is_transposed(x);
-
-    // For each row index (i) in the result.
-    for (int64_t i = ra; i < rb; ++i) {
-        // Initialize row i in R to 0.
+    for (int64_t i = ra; i < rb; ++i) { /* Rows */
         for (int64_t j = 0; j < yd1; ++j) {
-            /* Using the computed strides:
-             * rs[0] == 1 and rs[1] == M.
-             * Thus, element (i,j) is at offset: i + M*j.
-             */
-            float* pr = br + rs0 * i + rs1 * j;
-            mag_bnd_chk(pr, br, mag_tensor_data_size(r));
-            *pr = 0.0f;
+            float* xo = br + rd1*i + j;
+            mag_bnd_chk(xo, br, mag_tensor_data_size(r));
+            *xo = 0.0f;
         }
-        // Multiply: R(i,j) += X(i,k) * Y(k,j)
-        for (int64_t k = 0; k < xd1; ++k) {
-            const mag_f32_t* px;
-            if (tx) {
-                // If X is transposed, treat it as X^T,
-                // so element (i,k) is stored at (k,i).
-                px = bx + xs0 * k + xs1 * i;
-            } else {
-                // Otherwise, access X(i,k) as: i + M*k.
-                px = bx + xs0 * i + xs1 * k;
-            }
+        for (int64_t k = 0; k < xd1; ++k) { /* Inner dim */
+            const mag_f32_t* px = bx + (tx ? k*xd0 + i : xd1*i + k);
             mag_bnd_chk(px, bx, mag_tensor_data_size(x));
-            for (int64_t j = 0; j < yd1; ++j) {
-                float* pr = br + rs0 * i + rs1 * j; // R(i,j) = i + M*j.
-                // Y is [K, N] stored with strides: ys0 == 1 and ys1 == K.
-                const mag_f32_t* py = by + ys0 * k + ys1 * j; // Y(k,j) = k + K*j.
+            for (int64_t j = 0; j < yd1; ++j) { /* Columns */
+                mag_f32_t* pr = br + rd1*i + j;
+                const mag_f32_t* py = by + yd1*k + j;
                 mag_bnd_chk(pr, br, mag_tensor_data_size(r));
                 mag_bnd_chk(py, by, mag_tensor_data_size(y));
                 *pr += (*px) * (*py);
             }
         }
     }
-    }

 #ifndef MAG_BLAS_SPECIALIZATION
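
Note on the change above: the old loop addressed all three tensors through per-tensor strides (rs0/rs1, xs0/xs1, ys0/ys1), with comments describing element (i,j) at offset i + M*j, i.e. a column-major layout. The new loop instead hard-codes dense row-major indexing: element (i,j) of a matrix with rd1 columns lives at flat offset rd1*i + j, and likewise yd1*k + j for y and xd1*i + k (or k*xd0 + i when x is stored transposed) for x.

A minimal, self-contained sketch of the same indexing scheme, single-threaded over all rows rather than the per-thread row chunk ra..rb used in the patch; the names mat_mul_rowmajor, M, K and N are illustrative and not from the magnetron codebase:

#include <stdint.h>

/* r = x * y with dense row-major storage: element (i,j) of an
 * R-row, C-column matrix lives at flat offset i*C + j. */
void mat_mul_rowmajor(float* r,        /* result, M x N                     */
                      const float* x,  /* left operand, M x K (K x M if tx) */
                      const float* y,  /* right operand, K x N              */
                      int64_t M, int64_t K, int64_t N,
                      int tx) {        /* nonzero: x stored transposed      */
    for (int64_t i = 0; i < M; ++i) {           /* rows of r                */
        for (int64_t j = 0; j < N; ++j)         /* zero row i of r          */
            r[i*N + j] = 0.0f;
        for (int64_t k = 0; k < K; ++k) {       /* inner dimension          */
            /* x(i,k) sits at k*M + i when x is stored as its transpose
             * (K x M), mirroring tx ? k*xd0 + i : xd1*i + k above.         */
            float xv = tx ? x[k*M + i] : x[i*K + k];
            for (int64_t j = 0; j < N; ++j)     /* columns of r             */
                r[i*N + j] += xv * y[k*N + j];  /* y(k,j) = k*N + j         */
        }
    }
}

Keeping j innermost means both r and y are walked sequentially within each k step, which is the cache-friendly order for row-major data; only the transposed-x access strides by xd0 between consecutive k.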