This repository was archived by the owner on May 27, 2021. It is now read-only.

Commit 8a9bf2e
Only run WMMA tests on 1.5.
1 parent: 1481594

3 files changed (+37, -58 lines)

src/device/cuda/wmma.jl

Lines changed: 2 additions & 7 deletions
@@ -7,9 +7,6 @@ using CUDAnative: AS, DevicePtr
 # CONSTANTS
 ################################################################################
 
-# Determines whether or not to Core.AddrSpacePtr is available
-const addrspaceptr_available = (VERSION >= v"1.5.0-DEV.324")
-
 # Maps PTX types to Julia array types
 const map_ptx_to_jl_array = Dict(
     "f16" => Float16,
@@ -52,7 +49,6 @@ get_frag_info(matrix, ptx_el_type) = (
 
 get_addrspace_info(addr_space) = convert(Int, map_ptx_as_to_as_ty[addr_space])
 
-if addrspaceptr_available
 @generated function Base.cconvert(::Type{Core.AddrSpacePtr{T, as}}, x::DevicePtr{T, AS}) where {T, as, AS}
     ir = "%ptr = inttoptr i64 %0 to i8 addrspace($as)*
           ret i8 addrspace($as)* %ptr"
@@ -61,7 +57,6 @@ if addrspaceptr_available
         return Base.llvmcall($ir, Core.AddrSpacePtr{T, as}, Tuple{Int64}, Base.bitcast(Int64, x))
     end
 end
-end
 
 # Fix for https://github.com/JuliaGPU/CUDAnative.jl/issues/587.
 # Instead of ccall'ing the intrinsics with NTuple{N, T} (which gets lowered to
@@ -133,7 +128,7 @@ for mat in ["a", "b", "c"],
 
     ccall_name = "extern $llvm_intr"
 
-    ptr_ty = addrspaceptr_available ? Core.AddrSpacePtr{arr_ty, addr_space_int} : Ref{arr_ty}
+    ptr_ty = Core.AddrSpacePtr{arr_ty, addr_space_int}
     struct_ty = Symbol("LLVMStruct$sz")
 
     @eval $func_name(src_addr, stride) = convert(NTuple{$sz, $frag_ty}, ccall($ccall_name, llvmcall, $struct_ty{$frag_ty}, ($ptr_ty, Int32), src_addr, stride))
@@ -188,7 +183,7 @@ for mat in ["d"],
     frag_types = ntuple(i -> frag_ty, sz)
     frag_vars = ntuple(i -> :(data[$i]), sz)
 
-    ptr_ty = addrspaceptr_available ? Core.AddrSpacePtr{arr_ty, addr_space_int} : Ref{arr_ty}
+    ptr_ty = Core.AddrSpacePtr{arr_ty, addr_space_int}
 
    @eval $func_name(dst_addr, data, stride) = ccall($ccall_name, llvmcall, Nothing, ($ptr_ty, $(frag_types...), Int32), dst_addr, $(frag_vars...), stride)
    @eval export $func_name
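
The device-side change drops the load-time feature switch: ptr_ty is now always Core.AddrSpacePtr{arr_ty, addr_space_int}, since the WMMA wrappers now assume Core.AddrSpacePtr is available (it appeared around v1.5.0-DEV.324, per the removed comment). As a minimal sketch of what the removed gate amounted to (the element type and the address-space value 1 are placeholders for illustration, not part of this diff):

    # Sketch of the removed feature gate; Float16 and address space 1 (global)
    # are placeholder values chosen for illustration only.
    const addrspaceptr_available = VERSION >= v"1.5.0-DEV.324"  # per the removed comment

    ptr_ty = addrspaceptr_available ?
        Core.AddrSpacePtr{Float16, 1} :  # keeps the address space in the pointer type
        Ref{Float16}                     # pre-1.5 fallback, dropped by this commit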

test/device/wmma.jl

Lines changed: 35 additions & 40 deletions
@@ -1,11 +1,10 @@
 # Need https://github.com/JuliaLang/julia/pull/33970
 # and https://github.com/JuliaLang/julia/pull/34043
-if VERSION >= v"1.4.0-DEV.666" && capability(device()) >= v"7.0"
+if VERSION >= v"1.5.0-DEV.437" && capability(device()) >= v"7.0"
 
 using CUDAnative.WMMA
 
-is_debug = ccall(:jl_is_debugbuild, Cint, ()) != 0
-(is_debug && VERSION < v"1.5.0-DEV.437") ? @warn("Skipping WMMA tests due to incompatible Julia") : @testset "WMMA" begin
+@testset "WMMA" begin
 
 ################################################################################
 
@@ -231,20 +230,18 @@ is_debug = ccall(:jl_is_debugbuild, Cint, ()) != 0
             return
         end
 
-        @test_broken_if VERSION >= v"1.5.0-DEV.393" begin
-            @cuda threads=32 kernel(a_dev, b_dev, c_dev, d_dev, alpha, beta)
-            d = Array(d_dev)
+        @cuda threads=32 kernel(a_dev, b_dev, c_dev, d_dev, alpha, beta)
+        d = Array(d_dev)
 
-            new_a = (a_layout == ColMajor) ? a : transpose(a)
-            new_b = (b_layout == ColMajor) ? b : transpose(b)
-            new_c = (c_layout == ColMajor) ? c : transpose(c)
-            new_d = (d_layout == ColMajor) ? d : transpose(d)
+        new_a = (a_layout == ColMajor) ? a : transpose(a)
+        new_b = (b_layout == ColMajor) ? b : transpose(b)
+        new_c = (c_layout == ColMajor) ? c : transpose(c)
+        new_d = (d_layout == ColMajor) ? d : transpose(d)
 
-            if do_mac
-                all(isapprox.(alpha * new_a * new_b + beta * new_c, new_d; rtol=sqrt(eps(Float16))))
-            else
-                all(isapprox.(alpha * new_a * new_b, new_d; rtol=sqrt(eps(Float16))))
-            end
+        if do_mac
+            @test_broken all(isapprox.(alpha * new_a * new_b + beta * new_c, new_d; rtol=sqrt(eps(Float16))))
+        else
+            @test_broken all(isapprox.(alpha * new_a * new_b, new_d; rtol=sqrt(eps(Float16))))
         end
     end
 
@@ -254,40 +251,38 @@ is_debug = ccall(:jl_is_debugbuild, Cint, ()) != 0
 
 # Need https://github.com/JuliaLang/julia/pull/34760
 # See https://github.com/JuliaGPU/CUDAnative.jl/issues/548
-if VERSION >= v"1.5.0-DEV.324"
-    @testset "Codegen addressing" begin
-        @testset "Global" begin
-            function kernel(d)
-                conf = WMMA.Config{16, 16, 16, Float32}
-
-                d_frag = WMMA.fill_c(Float32(0), conf)
-                WMMA.store_d(pointer(d), d_frag, 16, WMMA.ColMajor, conf)
-
-                return
-            end
+@testset "Codegen addressing" begin
+    @testset "Global" begin
+        function kernel(d)
+            conf = WMMA.Config{16, 16, 16, Float32}
 
-            ptx = sprint(io -> CUDAnative.code_ptx(io, kernel, (CuDeviceArray{Float32,1,CUDAnative.AS.Global},)))
+            d_frag = WMMA.fill_c(Float32(0), conf)
+            WMMA.store_d(pointer(d), d_frag, 16, WMMA.ColMajor, conf)
 
-            @test !occursin(r"wmma.store.d.sync(.aligned)?.col.m16n16k16.f32", ptx)
-            @test occursin(r"wmma.store.d.sync(.aligned)?.col.m16n16k16.global.f32", ptx)
+            return
         end
 
-        @testset "Shared" begin
-            function kernel()
-                shmem = @cuStaticSharedMem(Float32, (16, 16))
-                conf = WMMA.Config{16, 16, 16, Float32}
+        ptx = sprint(io -> CUDAnative.code_ptx(io, kernel, (CuDeviceArray{Float32,1,CUDAnative.AS.Global},)))
 
-                d_frag = WMMA.fill_c(Float32(0), conf)
-                WMMA.store_d(pointer(shmem), d_frag, 16, WMMA.ColMajor, conf)
+        @test !occursin(r"wmma.store.d.sync(.aligned)?.col.m16n16k16.f32", ptx)
+        @test occursin(r"wmma.store.d.sync(.aligned)?.col.m16n16k16.global.f32", ptx)
+    end
 
-                return
-            end
+    @testset "Shared" begin
+        function kernel()
+            shmem = @cuStaticSharedMem(Float32, (16, 16))
+            conf = WMMA.Config{16, 16, 16, Float32}
 
-            ptx = sprint(io -> CUDAnative.code_ptx(io, kernel, ()))
+            d_frag = WMMA.fill_c(Float32(0), conf)
+            WMMA.store_d(pointer(shmem), d_frag, 16, WMMA.ColMajor, conf)
 
-            @test !occursin(r"wmma.store.d.sync(.aligned)?.col.m16n16k16.f32", ptx)
-            @test occursin(r"wmma.store.d.sync(.aligned)?.col.m16n16k16.shared.f32", ptx)
+            return
         end
+
+        ptx = sprint(io -> CUDAnative.code_ptx(io, kernel, ()))
+
+        @test !occursin(r"wmma.store.d.sync(.aligned)?.col.m16n16k16.f32", ptx)
+        @test occursin(r"wmma.store.d.sync(.aligned)?.col.m16n16k16.shared.f32", ptx)
    end
 end
 
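Because the whole test file is now guarded by VERSION >= v"1.5.0-DEV.437", the old condition VERSION >= v"1.5.0-DEV.393" holds whenever the tests run at all, so the conditional @test_broken_if wrapper collapses into an unconditional @test_broken around each comparison. A standalone illustration of how Test.@test_broken reports (not code from this repository):

    using Test

    # @test_broken records a Broken result when the expression is false, so the
    # surrounding testset still passes; an unexpected pass is reported as a failure.
    @testset "broken example" begin
        @test_broken 1 + 1 == 3  # known-bad expression: counted as Broken
        @test 1 + 1 == 2         # ordinary passing test
    end
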
test/util.jl

Lines changed: 0 additions & 11 deletions
@@ -91,14 +91,3 @@ function julia_script(code, args=``)
     wait(proc)
     proc.exitcode, read(out, String), read(err, String)
 end
-
-# tests that are conditionall broken
-macro test_broken_if(cond, ex...)
-    quote
-        if $(esc(cond))
-            @test_broken $(map(esc, ex)...)
-        else
-            @test $(map(esc, ex)...)
-        end
-    end
-end
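
The deleted helper picked between @test_broken and a plain @test at run time; the old test file invoked it as @test_broken_if VERSION >= v"1.5.0-DEV.393" begin ... end. A small usage sketch, assuming the macro definition shown above has been evaluated (the always-true condition is chosen so the example passes on any current Julia):

    using Test

    # With a true condition the (false) expression is recorded as Broken;
    # with a false condition it would have to pass as a regular @test.
    @testset "conditionally broken" begin
        @test_broken_if VERSION >= v"1.0" begin
            1 + 1 == 3
        end
    end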

0 commit comments
