Skip to content
This repository was archived by the owner on May 27, 2021. It is now read-only.

Commit 11edd5e

Browse files
Use linearise(...) in operator
1 parent be21dfa commit 11edd5e

File tree

1 file changed

+4
-8
lines changed

1 file changed

+4
-8
lines changed

src/device/matmul_kernels/operator.jl

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,32 +27,28 @@ struct WMMAOp{M, N, K} end
2727

2828
function load_a(::Type{WMMAOp{M, N, K}}, ::Type{Layout.AlignedColMajor{Float16}}, workspace, tile::Tile) where {M, N, K}
2929
conf = WMMA.Config{M, N, K, Float32}
30-
ind = Tuple(tile.index) .+ 1
31-
@inbounds linear_index = LinearIndices(size(workspace))[ind...]
30+
linear_index = linearise(tile.index, size(workspace))
3231
ptr = pointer(workspace, linear_index)
3332
return WMMA.load_a(ptr, size(workspace, 1), WMMA.ColMajor, conf)
3433
end
3534

3635
function load_b(::Type{WMMAOp{M, N, K}}, ::Type{Layout.AlignedColMajor{Float16}}, workspace, tile::Tile) where {M, N, K}
3736
conf = WMMA.Config{M, N, K, Float32}
38-
ind = Tuple(tile.index) .+ 1
39-
@inbounds linear_index = LinearIndices(size(workspace))[ind...]
37+
linear_index = linearise(tile.index, size(workspace))
4038
ptr = pointer(workspace, linear_index)
4139
return WMMA.load_b(ptr, size(workspace, 1), WMMA.ColMajor, conf)
4240
end
4341

4442
function load_c(::Type{WMMAOp{M, N, K}}, ::Type{Layout.AlignedColMajor{Float32}}, workspace, tile::Tile) where {M, N, K}
4543
conf = WMMA.Config{M, N, K, Float32}
46-
ind = Tuple(tile.index) .+ 1
47-
@inbounds linear_index = LinearIndices(size(workspace))[ind...]
44+
linear_index = linearise(tile.index, size(workspace))
4845
ptr = pointer(workspace, linear_index)
4946
return WMMA.load_c(ptr, size(workspace, 1), WMMA.ColMajor, conf)
5047
end
5148

5249
function store_d(::Type{WMMAOp{M, N, K}}, ::Type{Layout.AlignedColMajor{Float32}}, workspace, frag, tile::Tile) where {M, N, K}
5350
conf = WMMA.Config{M, N, K, Float32}
54-
ind = Tuple(tile.index) .+ 1
55-
@inbounds linear_index = LinearIndices(size(workspace))[ind...]
51+
linear_index = linearise(tile.index, size(workspace))
5652
ptr = pointer(workspace, linear_index)
5753
WMMA.store_d(ptr, frag, size(workspace, 1), WMMA.ColMajor, conf)
5854
end

0 commit comments

Comments
 (0)