@@ -25,43 +25,31 @@ struct WMMAOp{M, N, K} end
25
25
@inline fragtype_b (:: Type{WMMAOp{16, 16, 16}} , :: Type{Layout.AlignedColMajor{Float16}} ) = WMMA. Fragment{16 , 16 , 16 , 16 , Float16, WMMA. ColMajor, WMMA. MatrixB}
26
26
@inline fragtype_accum (:: Type{WMMAOp{16, 16, 16}} , :: Type{Layout.AlignedColMajor{Float32}} ) = WMMA. Fragment{16 , 16 , 16 , 8 , Float32, WMMA. Unspecified, WMMA. Accumulator}
27
27
28
- @inline function load_a (:: Type{WMMAOp{M, N, K}} , :: Type{Layout.AlignedColMajor{Float16}} , workspace, tile:: Tile ) where {M, N, K}
28
+ function load_a (:: Type{WMMAOp{M, N, K}} , :: Type{Layout.AlignedColMajor{Float16}} , workspace, tile:: Tile ) where {M, N, K}
29
29
conf = WMMA. Config{M, N, K, Float32}
30
-
31
- linear_base = linearise (tile. base, size (workspace))
32
- linear_offset = linearise (tile. offset, size (workspace))
33
-
34
- ptr = pointer (workspace, linear_base) + (linear_offset - 1 ) * sizeof (Float16)
30
+ linear_index = linearise (tile. index, size (workspace))
31
+ ptr = pointer (workspace, linear_index)
35
32
return WMMA. load_a (ptr, size (workspace, 1 ), WMMA. ColMajor, conf)
36
33
end
37
34
38
- @inline function load_b (:: Type{WMMAOp{M, N, K}} , :: Type{Layout.AlignedColMajor{Float16}} , workspace, tile:: Tile ) where {M, N, K}
35
+ function load_b (:: Type{WMMAOp{M, N, K}} , :: Type{Layout.AlignedColMajor{Float16}} , workspace, tile:: Tile ) where {M, N, K}
39
36
conf = WMMA. Config{M, N, K, Float32}
40
-
41
- linear_base = linearise (tile. base, size (workspace))
42
- linear_offset = linearise (tile. offset, size (workspace))
43
-
44
- ptr = pointer (workspace, linear_base) + (linear_offset - 1 ) * sizeof (Float16)
37
+ linear_index = linearise (tile. index, size (workspace))
38
+ ptr = pointer (workspace, linear_index)
45
39
return WMMA. load_b (ptr, size (workspace, 1 ), WMMA. ColMajor, conf)
46
40
end
47
41
48
- @inline function load_c (:: Type{WMMAOp{M, N, K}} , :: Type{Layout.AlignedColMajor{Float32}} , workspace, tile:: Tile ) where {M, N, K}
42
+ function load_c (:: Type{WMMAOp{M, N, K}} , :: Type{Layout.AlignedColMajor{Float32}} , workspace, tile:: Tile ) where {M, N, K}
49
43
conf = WMMA. Config{M, N, K, Float32}
50
-
51
- linear_base = linearise (tile. base, size (workspace))
52
- linear_offset = linearise (tile. offset, size (workspace))
53
-
54
- ptr = pointer (workspace, linear_base) + (linear_offset - 1 ) * sizeof (Float32)
44
+ linear_index = linearise (tile. index, size (workspace))
45
+ ptr = pointer (workspace, linear_index)
55
46
return WMMA. load_c (ptr, size (workspace, 1 ), WMMA. ColMajor, conf)
56
47
end
57
48
58
- @inline function store_d (:: Type{WMMAOp{M, N, K}} , :: Type{Layout.AlignedColMajor{Float32}} , workspace, frag, tile:: Tile ) where {M, N, K}
49
+ function store_d (:: Type{WMMAOp{M, N, K}} , :: Type{Layout.AlignedColMajor{Float32}} , workspace, frag, tile:: Tile ) where {M, N, K}
59
50
conf = WMMA. Config{M, N, K, Float32}
60
-
61
- linear_base = linearise (tile. base, size (workspace))
62
- linear_offset = linearise (tile. offset, size (workspace))
63
-
64
- ptr = pointer (workspace, linear_base) + (linear_offset - 1 ) * sizeof (Float32)
51
+ linear_index = linearise (tile. index, size (workspace))
52
+ ptr = pointer (workspace, linear_index)
65
53
WMMA. store_d (ptr, frag, size (workspace, 1 ), WMMA. ColMajor, conf)
66
54
end
67
55
0 commit comments