@@ -27,32 +27,28 @@ struct WMMAOp{M, N, K} end
27
27
28
28
function load_a (:: Type{WMMAOp{M, N, K}} , :: Type{Layout.AlignedColMajor{Float16}} , workspace, tile:: Tile ) where {M, N, K}
29
29
conf = WMMA. Config{M, N, K, Float32}
30
- ind = Tuple (tile. index) .+ 1
31
- @inbounds linear_index = LinearIndices (size (workspace))[ind... ]
30
+ linear_index = linearise (tile. index, size (workspace))
32
31
ptr = pointer (workspace, linear_index)
33
32
return WMMA. load_a (ptr, size (workspace, 1 ), WMMA. ColMajor, conf)
34
33
end
35
34
36
35
function load_b (:: Type{WMMAOp{M, N, K}} , :: Type{Layout.AlignedColMajor{Float16}} , workspace, tile:: Tile ) where {M, N, K}
37
36
conf = WMMA. Config{M, N, K, Float32}
38
- ind = Tuple (tile. index) .+ 1
39
- @inbounds linear_index = LinearIndices (size (workspace))[ind... ]
37
+ linear_index = linearise (tile. index, size (workspace))
40
38
ptr = pointer (workspace, linear_index)
41
39
return WMMA. load_b (ptr, size (workspace, 1 ), WMMA. ColMajor, conf)
42
40
end
43
41
44
42
function load_c (:: Type{WMMAOp{M, N, K}} , :: Type{Layout.AlignedColMajor{Float32}} , workspace, tile:: Tile ) where {M, N, K}
45
43
conf = WMMA. Config{M, N, K, Float32}
46
- ind = Tuple (tile. index) .+ 1
47
- @inbounds linear_index = LinearIndices (size (workspace))[ind... ]
44
+ linear_index = linearise (tile. index, size (workspace))
48
45
ptr = pointer (workspace, linear_index)
49
46
return WMMA. load_c (ptr, size (workspace, 1 ), WMMA. ColMajor, conf)
50
47
end
51
48
52
49
function store_d (:: Type{WMMAOp{M, N, K}} , :: Type{Layout.AlignedColMajor{Float32}} , workspace, frag, tile:: Tile ) where {M, N, K}
53
50
conf = WMMA. Config{M, N, K, Float32}
54
- ind = Tuple (tile. index) .+ 1
55
- @inbounds linear_index = LinearIndices (size (workspace))[ind... ]
51
+ linear_index = linearise (tile. index, size (workspace))
56
52
ptr = pointer (workspace, linear_index)
57
53
WMMA. store_d (ptr, frag, size (workspace, 1 ), WMMA. ColMajor, conf)
58
54
end
0 commit comments