@@ -27,28 +27,32 @@ struct WMMAOp{M, N, K} end
27
27
28
28
function load_a (:: Type{WMMAOp{M, N, K}} , :: Type{Layout.AlignedColMajor{Float16}} , workspace, tile:: Tile ) where {M, N, K}
29
29
conf = WMMA. Config{M, N, K, Float32}
30
- linear_index = linearise (tile. index, size (workspace))
30
+ ind = Tuple (tile. index) .+ 1
31
+ @inbounds linear_index = LinearIndices (size (workspace))[ind... ]
31
32
ptr = pointer (workspace, linear_index)
32
33
return WMMA. load_a (ptr, size (workspace, 1 ), WMMA. ColMajor, conf)
33
34
end
34
35
35
36
function load_b (:: Type{WMMAOp{M, N, K}} , :: Type{Layout.AlignedColMajor{Float16}} , workspace, tile:: Tile ) where {M, N, K}
36
37
conf = WMMA. Config{M, N, K, Float32}
37
- linear_index = linearise (tile. index, size (workspace))
38
+ ind = Tuple (tile. index) .+ 1
39
+ @inbounds linear_index = LinearIndices (size (workspace))[ind... ]
38
40
ptr = pointer (workspace, linear_index)
39
41
return WMMA. load_b (ptr, size (workspace, 1 ), WMMA. ColMajor, conf)
40
42
end
41
43
42
44
function load_c (:: Type{WMMAOp{M, N, K}} , :: Type{Layout.AlignedColMajor{Float32}} , workspace, tile:: Tile ) where {M, N, K}
43
45
conf = WMMA. Config{M, N, K, Float32}
44
- linear_index = linearise (tile. index, size (workspace))
46
+ ind = Tuple (tile. index) .+ 1
47
+ @inbounds linear_index = LinearIndices (size (workspace))[ind... ]
45
48
ptr = pointer (workspace, linear_index)
46
49
return WMMA. load_c (ptr, size (workspace, 1 ), WMMA. ColMajor, conf)
47
50
end
48
51
49
52
function store_d (:: Type{WMMAOp{M, N, K}} , :: Type{Layout.AlignedColMajor{Float32}} , workspace, frag, tile:: Tile ) where {M, N, K}
50
53
conf = WMMA. Config{M, N, K, Float32}
51
- linear_index = linearise (tile. index, size (workspace))
54
+ ind = Tuple (tile. index) .+ 1
55
+ @inbounds linear_index = LinearIndices (size (workspace))[ind... ]
52
56
ptr = pointer (workspace, linear_index)
53
57
WMMA. store_d (ptr, frag, size (workspace, 1 ), WMMA. ColMajor, conf)
54
58
end
0 commit comments