Skip to content

Commit 9f9997c

Browse files
committed
Fix 1D tiling
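The original code from the blog looks wrong: it reads and writes columns `col + 1` through `col + 3` without checking them against `dimensions.n`. The code in the repo has these bounds checks, and they make the tests pass.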
1 parent 7dbd06a commit 9f9997c

File tree

3 files changed

+56
-9
lines changed
  • blog/2024-11-21-optimizing-matrix-mul/code

3 files changed

+56
-9
lines changed

blog/2024-11-21-optimizing-matrix-mul/code/bin/blog/src/bin.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ fn main() {
3030
run_tests(matmul::naive::wgpu(), &sizes);
3131
run_tests(matmul::workgroup_256::wgpu(), &sizes);
3232
run_tests(matmul::workgroup_2d::wgpu(), &sizes);
33-
//run_tests(matmul::tiling_1d::wgpu(), &sizes);
33+
run_tests(matmul::tiling_1d::wgpu(), &sizes);
3434
run_tests(matmul::tiling_2d_simd::wgpu(), &sizes);
3535

3636
run_tests(matmul::isomorphic::wgpu(), &sizes);

blog/2024-11-21-optimizing-matrix-mul/code/crates/cpu/matmul/src/backends/cpu.rs

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,37 @@ mod tests {
171171
assert_eq!(result, expected);
172172
}
173173

174+
#[test]
175+
fn test_single_threaded_matmul_4x4() {
176+
let m = 4;
177+
let k = 4;
178+
let n = 4;
179+
180+
// Matrix `a` (4x4) in row-major order: the values 1.0 through 16.0
181+
let a = vec![
182+
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
183+
];
184+
185+
// Matrix `b` (4x4) in row-major order: the values 17.0 through 32.0
186+
let b = vec![
187+
17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0,
188+
31.0, 32.0,
189+
];
190+
191+
// Expected product of `a` and `b` (4x4, row-major); e.g. element (0, 0) is 1*17 + 2*21 + 3*25 + 4*29 = 250
192+
let expected = vec![
193+
250.0, 260.0, 270.0, 280.0, 618.0, 644.0, 670.0, 696.0, 986.0, 1028.0, 1070.0, 1112.0,
194+
1354.0, 1412.0, 1470.0, 1528.0,
195+
];
196+
197+
let variant = crate::variants::Isomorphic;
198+
let matrix_multiplier = futures::executor::block_on(SingleThreadedMatMul::new(variant));
199+
200+
let result = matrix_multiplier.multiply(&a, &b, m, k, n);
201+
202+
assert_eq!(result, expected);
203+
}
204+
174205
#[test]
175206
fn test_multithreaded_matmul_2x1x1() {
176207
let m = 2;

blog/2024-11-21-optimizing-matrix-mul/code/crates/gpu/tiling_1d/src/lib.rs

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,30 @@ pub fn matmul(
2727

2828
for i in 0..dimensions.k as usize {
2929
let a_elem = a[row * dimensions.k as usize + i];
30-
sum00 += a_elem * b[i * dimensions.n as usize + col];
31-
sum01 += a_elem * b[i * dimensions.n as usize + col + 1];
32-
sum02 += a_elem * b[i * dimensions.n as usize + col + 2];
33-
sum03 += a_elem * b[i * dimensions.n as usize + col + 3];
30+
if col < dimensions.n as usize {
31+
sum00 += a_elem * b[i * dimensions.n as usize + col];
32+
}
33+
if col + 1 < dimensions.n as usize {
34+
sum01 += a_elem * b[i * dimensions.n as usize + col + 1];
35+
}
36+
if col + 2 < dimensions.n as usize {
37+
sum02 += a_elem * b[i * dimensions.n as usize + col + 2];
38+
}
39+
if col + 3 < dimensions.n as usize {
40+
sum03 += a_elem * b[i * dimensions.n as usize + col + 3];
41+
}
3442
}
3543

36-
result[row * dimensions.n as usize + col] = sum00;
37-
result[row * dimensions.n as usize + col + 1] = sum01;
38-
result[row * dimensions.n as usize + col + 2] = sum02;
39-
result[row * dimensions.n as usize + col + 3] = sum03;
44+
if col < dimensions.n as usize {
45+
result[row * dimensions.n as usize + col] = sum00;
46+
}
47+
if col + 1 < dimensions.n as usize {
48+
result[row * dimensions.n as usize + col + 1] = sum01;
49+
}
50+
if col + 2 < dimensions.n as usize {
51+
result[row * dimensions.n as usize + col + 2] = sum02;
52+
}
53+
if col + 3 < dimensions.n as usize {
54+
result[row * dimensions.n as usize + col + 3] = sum03;
55+
}
4056
}

0 commit comments

Comments
 (0)