Skip to content
This repository was archived by the owner on May 27, 2021. It is now read-only.

Commit 3f64767

Browse files
Split translate function
1 parent 11edd5e commit 3f64767

File tree

2 files changed

+23
-5
lines changed

2 files changed

+23
-5
lines changed

src/device/matmul_kernels/kernel.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ function matmul_impl(a, b, c, d,
4545

4646
@unroll for i = 1 : NUM_FRAGMENTS_M
4747
@unroll for j = 1 : NUM_FRAGMENTS_N
48-
tile = translate(warp_tile, (M = (i-1)*COMPUTE_OP_SHAPE.M, N = (j-1)*COMPUTE_OP_SHAPE.N))
48+
tile = translate_const(warp_tile, (M = (i-1)*COMPUTE_OP_SHAPE.M, N = (j-1)*COMPUTE_OP_SHAPE.N))
4949
@inbounds c_frags[i, j] = transf_sh2rf_c(Operator.load_c(OPERATOR, SHARED_C_LAYOUT, shmem_c, tile), tile)
5050
end
5151
end
@@ -84,15 +84,15 @@ function matmul_impl(a, b, c, d,
8484
a_frags = MArray{Tuple{NUM_FRAGMENTS_M}, Operator.fragtype_a(OPERATOR, SHARED_A_LAYOUT)}(undef)
8585

8686
@unroll for i = 1 : NUM_FRAGMENTS_M
87-
a_tile = translate(warp_tile.MK, (M = (i-1)*COMPUTE_OP_SHAPE.M, K = 0))
87+
a_tile = translate_const(warp_tile.MK, (M = (i-1)*COMPUTE_OP_SHAPE.M, K = 0))
8888
@inbounds a_frags[i] = transf_sh2rf_a(Operator.load_a(OPERATOR, SHARED_A_LAYOUT, shmem_a, a_tile), a_tile)
8989
end
9090

9191
# (3.3.2) Load a COMPUTE_WARP.K x COMPUTE_WARP.N tile of B from shared memory into registers
9292
b_frags = MArray{Tuple{NUM_FRAGMENTS_N}, Operator.fragtype_b(OPERATOR, SHARED_B_LAYOUT)}(undef)
9393

9494
@unroll for j = 1 : NUM_FRAGMENTS_N
95-
b_tile = translate(warp_tile.KN, (K = 0, N = (j-1)*COMPUTE_OP_SHAPE.N))
95+
b_tile = translate_const(warp_tile.KN, (K = 0, N = (j-1)*COMPUTE_OP_SHAPE.N))
9696
@inbounds b_frags[j] = transf_sh2rf_b(Operator.load_b(OPERATOR, SHARED_B_LAYOUT, shmem_b, b_tile), b_tile)
9797
end
9898

@@ -114,7 +114,7 @@ function matmul_impl(a, b, c, d,
114114

115115
@unroll for i = 1 : NUM_FRAGMENTS_M
116116
@unroll for j = 1 : NUM_FRAGMENTS_N
117-
tile = translate(warp_tile, (M = (i-1)*COMPUTE_OP_SHAPE.M, N = (j-1)*COMPUTE_OP_SHAPE.N))
117+
tile = translate_const(warp_tile, (M = (i-1)*COMPUTE_OP_SHAPE.M, N = (j-1)*COMPUTE_OP_SHAPE.N))
118118
Operator.store_d(OPERATOR, SHARED_D_LAYOUT, shmem_d, transf_rf2sh_d(c_frags[i, j], tile), tile)
119119
end
120120
end

src/device/tiling.jl

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ export translate
119119
"""
120120
translate(tile::Tile{size, names, T}, offset::NamedTuple{names, T})
121121
122-
Translate (i.e. move) a [`Tile`](@ref) by a constant `offset`.
122+
Translate (i.e. move) a [`Tile`](@ref) by an `offset`.
123123
124124
# Arguments
125125
- `tile`: The [`Tile`](@ref) to translate.
@@ -132,6 +132,24 @@ end
132132

133133
@inline function translate(
    tile::Tile{size, names, T}, offset::Tuple
) where {names, T, size}
    # Label the positional offsets with the tile's dimension names, then
    # defer to the NamedTuple method.
    named_offset = NamedTuple{names}(offset)
    return translate(tile, named_offset)
end
134134

135+
export translate_const
136+
137+
"""
138+
translate_const(tile::Tile{size, names, T}, offset::NamedTuple{names, T})
139+
140+
Translate (i.e. move) a [`Tile`](@ref) by a constant `offset`.
141+
142+
# Arguments
143+
- `tile`: The [`Tile`](@ref) to translate.
144+
- `offset`: The `offset` in each dimension.
145+
"""
146+
@inline function translate_const(
    tile::Tile{size, names, T}, offset::NamedTuple{names, T}
) where {names, T, size}
    # Add the requested shift to the tile's current offset, dimension by dimension.
    shifted = map(+, tile.offset, offset)
    # The base is carried over unchanged; only the offset differs in the result.
    return Tile{size, names, T}(tile.base, shifted)
end
150+
151+
@inline function translate_const(
    tile::Tile{size, names, T}, offset::Tuple
) where {names, T, size}
    # Label the positional offsets with the tile's dimension names, then
    # defer to the NamedTuple method.
    named_offset = NamedTuple{names}(offset)
    return translate_const(tile, named_offset)
end
152+
135153
# -------------
136154
# TileIterators
137155
# -------------

0 commit comments

Comments
 (0)