Skip to content
This repository was archived by the owner on May 27, 2021. It is now read-only.

Commit 1ca56c4

Browse files
Revert "Add translate variant for offset"
This reverts commit 340e791.
1 parent 340e791 commit 1ca56c4

File tree

3 files changed

+16
-37
lines changed

3 files changed

+16
-37
lines changed

src/device/matmul_kernels/kernel.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ function matmul_impl(a, b, c, d,
4545

4646
@unroll for i = 1 : NUM_FRAGMENTS_M
4747
@unroll for j = 1 : NUM_FRAGMENTS_N
48-
tile = translate_offset(warp_tile, (M = (i-1)*COMPUTE_OP_SHAPE.M, N = (j-1)*COMPUTE_OP_SHAPE.N))
48+
tile = translate(warp_tile, (M = (i-1)*COMPUTE_OP_SHAPE.M, N = (j-1)*COMPUTE_OP_SHAPE.N))
4949
@inbounds c_frags[i, j] = transf_sh2rf_c(Operator.load_c(OPERATOR, SHARED_C_LAYOUT, shmem_c, tile), tile)
5050
end
5151
end
@@ -84,15 +84,15 @@ function matmul_impl(a, b, c, d,
8484
a_frags = MArray{Tuple{NUM_FRAGMENTS_M}, Operator.fragtype_a(OPERATOR, SHARED_A_LAYOUT)}(undef)
8585

8686
@unroll for i = 1 : NUM_FRAGMENTS_M
87-
a_tile = translate_offset(warp_tile.MK, (M = (i-1)*COMPUTE_OP_SHAPE.M, K = 0))
87+
a_tile = translate(warp_tile.MK, (M = (i-1)*COMPUTE_OP_SHAPE.M, K = 0))
8888
@inbounds a_frags[i] = transf_sh2rf_a(Operator.load_a(OPERATOR, SHARED_A_LAYOUT, shmem_a, a_tile), a_tile)
8989
end
9090

9191
# (3.3.2) Load a COMPUTE_WARP.K x COMPUTE_WARP.N tile of B from shared memory into registers
9292
b_frags = MArray{Tuple{NUM_FRAGMENTS_N}, Operator.fragtype_b(OPERATOR, SHARED_B_LAYOUT)}(undef)
9393

9494
@unroll for j = 1 : NUM_FRAGMENTS_N
95-
b_tile = translate_offset(warp_tile.KN, (K = 0, N = (j-1)*COMPUTE_OP_SHAPE.N))
95+
b_tile = translate(warp_tile.KN, (K = 0, N = (j-1)*COMPUTE_OP_SHAPE.N))
9696
@inbounds b_frags[j] = transf_sh2rf_b(Operator.load_b(OPERATOR, SHARED_B_LAYOUT, shmem_b, b_tile), b_tile)
9797
end
9898

@@ -114,7 +114,7 @@ function matmul_impl(a, b, c, d,
114114

115115
@unroll for i = 1 : NUM_FRAGMENTS_M
116116
@unroll for j = 1 : NUM_FRAGMENTS_N
117-
tile = translate_offset(warp_tile, (M = (i-1)*COMPUTE_OP_SHAPE.M, N = (j-1)*COMPUTE_OP_SHAPE.N))
117+
tile = translate(warp_tile, (M = (i-1)*COMPUTE_OP_SHAPE.M, N = (j-1)*COMPUTE_OP_SHAPE.N))
118118
Operator.store_d(OPERATOR, SHARED_D_LAYOUT, shmem_d, transf_rf2sh_d(c_frags[i, j], tile), tile)
119119
end
120120
end

src/device/matmul_kernels/operator.jl

Lines changed: 12 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -25,43 +25,31 @@ struct WMMAOp{M, N, K} end
2525
@inline fragtype_b(::Type{WMMAOp{16, 16, 16}}, ::Type{Layout.AlignedColMajor{Float16}}) = WMMA.Fragment{16, 16, 16, 16, Float16, WMMA.ColMajor, WMMA.MatrixB}
2626
@inline fragtype_accum(::Type{WMMAOp{16, 16, 16}}, ::Type{Layout.AlignedColMajor{Float32}}) = WMMA.Fragment{16, 16, 16, 8, Float32, WMMA.Unspecified, WMMA.Accumulator}
2727

28-
@inline function load_a(::Type{WMMAOp{M, N, K}}, ::Type{Layout.AlignedColMajor{Float16}}, workspace, tile::Tile) where {M, N, K}
28+
function load_a(::Type{WMMAOp{M, N, K}}, ::Type{Layout.AlignedColMajor{Float16}}, workspace, tile::Tile) where {M, N, K}
2929
conf = WMMA.Config{M, N, K, Float32}
30-
31-
linear_base = linearise(tile.base, size(workspace))
32-
linear_offset = linearise(tile.offset, size(workspace))
33-
34-
ptr = pointer(workspace, linear_base) + (linear_offset - 1) * sizeof(Float16)
30+
linear_index = linearise(tile.index, size(workspace))
31+
ptr = pointer(workspace, linear_index)
3532
return WMMA.load_a(ptr, size(workspace, 1), WMMA.ColMajor, conf)
3633
end
3734

38-
@inline function load_b(::Type{WMMAOp{M, N, K}}, ::Type{Layout.AlignedColMajor{Float16}}, workspace, tile::Tile) where {M, N, K}
35+
function load_b(::Type{WMMAOp{M, N, K}}, ::Type{Layout.AlignedColMajor{Float16}}, workspace, tile::Tile) where {M, N, K}
3936
conf = WMMA.Config{M, N, K, Float32}
40-
41-
linear_base = linearise(tile.base, size(workspace))
42-
linear_offset = linearise(tile.offset, size(workspace))
43-
44-
ptr = pointer(workspace, linear_base) + (linear_offset - 1) * sizeof(Float16)
37+
linear_index = linearise(tile.index, size(workspace))
38+
ptr = pointer(workspace, linear_index)
4539
return WMMA.load_b(ptr, size(workspace, 1), WMMA.ColMajor, conf)
4640
end
4741

48-
@inline function load_c(::Type{WMMAOp{M, N, K}}, ::Type{Layout.AlignedColMajor{Float32}}, workspace, tile::Tile) where {M, N, K}
42+
function load_c(::Type{WMMAOp{M, N, K}}, ::Type{Layout.AlignedColMajor{Float32}}, workspace, tile::Tile) where {M, N, K}
4943
conf = WMMA.Config{M, N, K, Float32}
50-
51-
linear_base = linearise(tile.base, size(workspace))
52-
linear_offset = linearise(tile.offset, size(workspace))
53-
54-
ptr = pointer(workspace, linear_base) + (linear_offset - 1) * sizeof(Float32)
44+
linear_index = linearise(tile.index, size(workspace))
45+
ptr = pointer(workspace, linear_index)
5546
return WMMA.load_c(ptr, size(workspace, 1), WMMA.ColMajor, conf)
5647
end
5748

58-
@inline function store_d(::Type{WMMAOp{M, N, K}}, ::Type{Layout.AlignedColMajor{Float32}}, workspace, frag, tile::Tile) where {M, N, K}
49+
function store_d(::Type{WMMAOp{M, N, K}}, ::Type{Layout.AlignedColMajor{Float32}}, workspace, frag, tile::Tile) where {M, N, K}
5950
conf = WMMA.Config{M, N, K, Float32}
60-
61-
linear_base = linearise(tile.base, size(workspace))
62-
linear_offset = linearise(tile.offset, size(workspace))
63-
64-
ptr = pointer(workspace, linear_base) + (linear_offset - 1) * sizeof(Float32)
51+
linear_index = linearise(tile.index, size(workspace))
52+
ptr = pointer(workspace, linear_index)
6553
WMMA.store_d(ptr, frag, size(workspace, 1), WMMA.ColMajor, conf)
6654
end
6755

src/device/tiling.jl

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -132,15 +132,6 @@ end
132132

133133
@inline translate(tile::Tile{size, names, T}, offset::Tuple) where {names, T, size} = translate(tile, NamedTuple{names}(offset))
134134

135-
export translate_offset
136-
137-
@inline function translate_offset(tile::Tile{size, names, T}, offset::NamedTuple{names, T}) where {names, T, size}
138-
new_offset = map(+, tile.offset, offset)
139-
return Tile{size, names, T}(tile.base, new_offset)
140-
end
141-
142-
@inline translate_offset(tile::Tile{size, names, T}, offset::Tuple) where {names, T, size} = translate_offset(tile, NamedTuple{names}(offset))
143-
144135
# -------------
145136
# TileIterators
146137
# -------------

0 commit comments

Comments
 (0)