Skip to content
This repository was archived by the owner on May 27, 2021. It is now read-only.

Commit 6ea499d

Browse files
Revert "Use linearise(...) in operator"
This reverts commit 11edd5e.
1 parent 3f64767 commit 6ea499d

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

src/device/matmul_kernels/operator.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,28 +27,32 @@ struct WMMAOp{M, N, K} end
2727

2828
function load_a(::Type{WMMAOp{M, N, K}}, ::Type{Layout.AlignedColMajor{Float16}}, workspace, tile::Tile) where {M, N, K}
2929
conf = WMMA.Config{M, N, K, Float32}
30-
linear_index = linearise(tile.index, size(workspace))
30+
ind = Tuple(tile.index) .+ 1
31+
@inbounds linear_index = LinearIndices(size(workspace))[ind...]
3132
ptr = pointer(workspace, linear_index)
3233
return WMMA.load_a(ptr, size(workspace, 1), WMMA.ColMajor, conf)
3334
end
3435

3536
function load_b(::Type{WMMAOp{M, N, K}}, ::Type{Layout.AlignedColMajor{Float16}}, workspace, tile::Tile) where {M, N, K}
3637
conf = WMMA.Config{M, N, K, Float32}
37-
linear_index = linearise(tile.index, size(workspace))
38+
ind = Tuple(tile.index) .+ 1
39+
@inbounds linear_index = LinearIndices(size(workspace))[ind...]
3840
ptr = pointer(workspace, linear_index)
3941
return WMMA.load_b(ptr, size(workspace, 1), WMMA.ColMajor, conf)
4042
end
4143

4244
function load_c(::Type{WMMAOp{M, N, K}}, ::Type{Layout.AlignedColMajor{Float32}}, workspace, tile::Tile) where {M, N, K}
4345
conf = WMMA.Config{M, N, K, Float32}
44-
linear_index = linearise(tile.index, size(workspace))
46+
ind = Tuple(tile.index) .+ 1
47+
@inbounds linear_index = LinearIndices(size(workspace))[ind...]
4548
ptr = pointer(workspace, linear_index)
4649
return WMMA.load_c(ptr, size(workspace, 1), WMMA.ColMajor, conf)
4750
end
4851

4952
function store_d(::Type{WMMAOp{M, N, K}}, ::Type{Layout.AlignedColMajor{Float32}}, workspace, frag, tile::Tile) where {M, N, K}
5053
conf = WMMA.Config{M, N, K, Float32}
51-
linear_index = linearise(tile.index, size(workspace))
54+
ind = Tuple(tile.index) .+ 1
55+
@inbounds linear_index = LinearIndices(size(workspace))[ind...]
5256
ptr = pointer(workspace, linear_index)
5357
WMMA.store_d(ptr, frag, size(workspace, 1), WMMA.ColMajor, conf)
5458
end

0 commit comments

Comments
 (0)