Skip to content
This repository was archived by the owner on May 27, 2021. It is now read-only.

Commit be21dfa

Browse files
Reintroduce workspace_size
This ensures that the size of the array in global memory is known statically.
1 parent 016c86b commit be21dfa

File tree

3 files changed

+16
-16
lines changed

3 files changed

+16
-16
lines changed

src/device/matmul_kernels/epilogue.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@ struct Default end
2525
# Cooperatively store a BLOCK_SHAPE.M x BLOCK_SHAPE.N tile of D from shared to global memory within one threadblock
2626
@unroll for warp_tile = parallellise(block_tile.MN, Tile(MEM_CD_WARP), warpId, WARPS_PER_BLOCK)
2727
@unroll for thread_tile = parallellise(warp_tile, Tile(MEM_CD_THREAD), laneId, 32)
28-
x = Layout.load(SHARED_D_LAYOUT, shmem_d, thread_tile)
28+
x = Layout.load(SHARED_D_LAYOUT, shmem_d, thread_tile, block_tile.MN.size)
2929
x = transform(x, thread_tile)
30-
Layout.store!(GLOBAL_D_LAYOUT, d, x, translate(thread_tile, (M = block_i, N = block_j)))
30+
Layout.store!(GLOBAL_D_LAYOUT, d, x, translate(thread_tile, (M = block_i, N = block_j)), gemm_sz.MN.size)
3131
end
3232
end
3333
end

src/device/matmul_kernels/kernel.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,9 @@ function matmul_impl(a, b, c, d,
3030

3131
@unroll for warp_tile = parallellise(block_tile.MN, Tile(MEM_CD_WARP), warpId, WARPS_PER_BLOCK)
3232
@unroll for thread_tile = parallellise(warp_tile, Tile(MEM_CD_THREAD), laneId, 32)
33-
x = Layout.load(GLOBAL_C_LAYOUT, c, translate(thread_tile, (M = block_i, N = block_j)))
33+
x = Layout.load(GLOBAL_C_LAYOUT, c, translate(thread_tile, (M = block_i, N = block_j)), gemm_sz.MN.size)
3434
x = transf_gl2sh_c(x, thread_tile)
35-
Layout.store!(SHARED_C_LAYOUT, shmem_c, x, thread_tile)
35+
Layout.store!(SHARED_C_LAYOUT, shmem_c, x, thread_tile, block_tile.MN.size)
3636
end
3737
end
3838

@@ -61,18 +61,18 @@ function matmul_impl(a, b, c, d,
6161
# (3.1) Cooperatively load a BLOCK_SHAPE.M x BLOCK_SHAPE.K tile of A from global to shared memory within one threadblock
6262
@unroll for warp_tile = parallellise(block_tile.MK, Tile(MEM_A_WARP), warpId, WARPS_PER_BLOCK)
6363
@unroll for thread_tile = parallellise(warp_tile, Tile(MEM_A_THREAD), laneId, 32)
64-
x = Layout.load(GLOBAL_A_LAYOUT, a, translate(thread_tile, (M = block_i, K = block_k)))
64+
x = Layout.load(GLOBAL_A_LAYOUT, a, translate(thread_tile, (M = block_i, K = block_k)), gemm_sz.MK.size)
6565
x = transf_gl2sh_a(x, thread_tile)
66-
Layout.store!(SHARED_A_LAYOUT, shmem_a, x, thread_tile)
66+
Layout.store!(SHARED_A_LAYOUT, shmem_a, x, thread_tile, block_tile.MK.size)
6767
end
6868
end
6969

7070
# (3.2) Cooperatively load a BLOCK_SHAPE.K x BLOCK_SHAPE.N tile of B from global to shared memory within one threadblock
7171
@unroll for warp_tile = parallellise(block_tile.KN, Tile(MEM_B_WARP), warpId, WARPS_PER_BLOCK)
7272
@unroll for thread_tile = parallellise(warp_tile, Tile(MEM_B_THREAD), laneId, 32)
73-
x = Layout.load(GLOBAL_B_LAYOUT, b, translate(thread_tile, (K = block_k, N = block_j)))
73+
x = Layout.load(GLOBAL_B_LAYOUT, b, translate(thread_tile, (K = block_k, N = block_j)), gemm_sz.KN.size)
7474
x = transf_gl2sh_b(x, thread_tile)
75-
Layout.store!(SHARED_B_LAYOUT, shmem_b, x, thread_tile)
75+
Layout.store!(SHARED_B_LAYOUT, shmem_b, x, thread_tile, block_tile.KN.size)
7676
end
7777
end
7878

src/device/matmul_kernels/layout.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ end
2828

2929
@inline eltype(::Type{Padded{L, P}}) where {L, P} = eltype(L)
3030
@inline size(::Type{Padded{L, P}}, logical_size::NamedTuple) where {L, P} = size(L, pad_logical_coord(Padded{L, P}, logical_size))
31-
@inline load(::Type{Padded{L, P}}, workspace, tile::Tile, logical_size::NamedTuple) where {L, P} = load(L, workspace, tile)
32-
@inline store!(::Type{Padded{L, P}}, workspace, value, tile::Tile) where {L, P} = store!(L, workspace, value, tile::Tile)
31+
@inline load(::Type{Padded{L, P}}, workspace, tile::Tile, workspace_size::NamedTuple) where {L, P} = load(L, workspace, tile, pad_logical_coord(Padded{L, P}, workspace_size))
32+
@inline store!(::Type{Padded{L, P}}, workspace, value, tile::Tile, workspace_size::NamedTuple) where {L, P} = store!(L, workspace, value, tile::Tile, pad_logical_coord(Padded{L, P}, workspace_size))
3333

3434
# ---------------
3535
# AlignedColMajor
@@ -38,7 +38,7 @@ end
3838
struct AlignedColMajor{T} <: LayoutBase{T} end
3939

4040
# TODO: cleanup vectorisation
41-
@inline function load(::Type{AlignedColMajor{T}}, workspace, tile::Tile{size}) where {T, size}
41+
@inline function load(::Type{AlignedColMajor{T}}, workspace, tile::Tile{size}, workspace_size::NamedTuple) where {T, size}
4242
vec_len = 16 ÷ sizeof(T)
4343
N = (sizeof(T) * vec_len) ÷ sizeof(Float32)
4444
res = MArray{Tuple{size[1] ÷ vec_len, size[2]}, NTuple{N, VecElement{Float32}}}(undef)
@@ -47,8 +47,8 @@ struct AlignedColMajor{T} <: LayoutBase{T} end
4747
@unroll for i = 1 : vec_len : size[1]
4848
t = translate(tile, (i - 1, j - 1))
4949

50-
linear_base = linearise(t.base, Base.size(workspace))
51-
linear_offset = linearise(t.offset, Base.size(workspace))
50+
linear_base = linearise(t.base, workspace_size)
51+
linear_offset = linearise(t.offset, workspace_size)
5252

5353
@inbounds res[i, j] = vloada(Vec{vec_len, T}, pointer(workspace, linear_base), linear_offset)
5454
end
@@ -57,15 +57,15 @@ struct AlignedColMajor{T} <: LayoutBase{T} end
5757
return res
5858
end
5959

60-
@inline function store!(::Type{AlignedColMajor{T}}, workspace, value, tile::Tile{size}) where {T, size}
60+
@inline function store!(::Type{AlignedColMajor{T}}, workspace, value, tile::Tile{size}, workspace_size::NamedTuple) where {T, size}
6161
vec_len = 16 ÷ sizeof(T)
6262

6363
@unroll for j = 1 : size[2]
6464
@unroll for i = 1 : vec_len : size[1]
6565
t = translate(tile, (i - 1, j - 1))
6666

67-
linear_base = linearise(t.base, Base.size(workspace))
68-
linear_offset = linearise(t.offset, Base.size(workspace))
67+
linear_base = linearise(t.base, workspace_size)
68+
linear_offset = linearise(t.offset, workspace_size)
6969

7070
vstorea!(Vec{vec_len, T}, pointer(workspace, linear_base), value[i, j], linear_offset)
7171
end

0 commit comments

Comments
 (0)