@@ -14,8 +14,8 @@ pub fn matmul(
14
14
#[ spirv( storage_buffer, descriptor_set = 0 , binding = 2 ) ] b : & [ f32 ] ,
15
15
#[ spirv( storage_buffer, descriptor_set = 0 , binding = 3 ) ] result : & mut [ f32 ] ,
16
16
) {
17
- let row = ( global_id. y * TILE_M as u32 ) as usize ;
18
- let col = ( global_id. x * TILE_N as u32 ) as usize ;
17
+ let row = ( global_id. y * TILE_M ) as usize ;
18
+ let col = ( global_id. x * TILE_N ) as usize ;
19
19
20
20
// Initialize sums array to zeros
21
21
// Note: This is uglier than it needs to be to work around
@@ -33,7 +33,7 @@ pub fn matmul(
33
33
34
34
for j in 0 ..TILE_N as usize {
35
35
let b_element = if col + j < dimensions. n as usize {
36
- b[ k * dimensions. n as usize + ( col + j as usize ) ]
36
+ b[ k * dimensions. n as usize + ( col + j) ]
37
37
} else {
38
38
0.0
39
39
} ;
@@ -46,8 +46,8 @@ pub fn matmul(
46
46
// Write results
47
47
for i in 0 ..TILE_M as usize {
48
48
for j in 0 ..TILE_N as usize {
49
- let output_row = row + i as usize ;
50
- let output_col = col + j as usize ;
49
+ let output_row = row + i;
50
+ let output_col = col + j;
51
51
52
52
if output_row < dimensions. m as usize && output_col < dimensions. n as usize {
53
53
result[ output_row * dimensions. n as usize + output_col] = sums[ i] [ j] ;
0 commit comments