This repository was archived by the owner on May 27, 2021. It is now read-only.

Avoid address space casts. #642

Merged: 2 commits, May 5, 2020. Changes shown from all commits.
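The gist of the change: instead of materializing a generic pointer and then addrspacecast-ing it into the target address space, pointers are now created directly in the right address space with a single inttoptr. A minimal before/after sketch of the IR pattern (illustrative only; addrspace(1), global memory, stands in for whichever space a given generator targets):

# Before: generic pointer first, then an address space cast.
ir_before = "%p = inttoptr i64 %0 to i8*
             %ptr = addrspacecast i8* %p to i8 addrspace(1)*
             ret i8 addrspace(1)* %ptr"

# After: inttoptr straight into the target address space.
ir_after = "%ptr = inttoptr i64 %0 to i8 addrspace(1)*
            ret i8 addrspace(1)* %ptr"

A side effect: the WMMA code now relies unconditionally on Core.AddrSpacePtr, so its inclusion is gated on Julia 1.4.1, presumably the first release carrying the AddrSpacePtr backport.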
2 changes: 2 additions & 0 deletions src/device/cuda.jl
@@ -11,7 +11,9 @@ include("cuda/assertion.jl")
 include("cuda/memory_dynamic.jl")
 include("cuda/atomics.jl")
 include("cuda/misc.jl")
+if VERSION >= v"1.4.1"
 include("cuda/wmma.jl")
+end
 
 # functionality from libdevice
 #
6 changes: 2 additions & 4 deletions src/device/cuda/memory_shared.jl
@@ -61,7 +61,6 @@ end
 end
 
 T_ptr = convert(LLVMType, DevicePtr{T,AS.Shared})
-T_actual_ptr = LLVM.PointerType(eltyp)
 
 # create a function
 llvm_f, _ = create_function(T_ptr)
@@ -92,10 +91,9 @@
 entry = BasicBlock(llvm_f, "entry", JuliaContext())
 position!(builder, entry)
 
-ptr_with_as = gep!(builder, gv, [ConstantInt(0, JuliaContext()),
-                                 ConstantInt(0, JuliaContext())])
+ptr = gep!(builder, gv, [ConstantInt(0, JuliaContext()),
+           ConstantInt(0, JuliaContext())])
 
-ptr = addrspacecast!(builder, ptr_with_as, T_actual_ptr)
 val = ptrtoint!(builder, ptr, T_ptr)
 ret!(builder, val)
 end
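No cast is needed here because the global variable backing the shared-memory array already lives in the shared address space: gep! on it yields an addrspace(3) pointer, which ptrtoint! consumes directly. Roughly the IR this generator now emits, shown for an invented one-element Float32 array (names and layout are made up):

# @shmem = internal addrspace(3) global [1 x float] zeroinitializer, align 4
#
# %ptr = getelementptr [1 x float], [1 x float] addrspace(3)* @shmem, i64 0, i64 0
# %val = ptrtoint float addrspace(3)* %ptr to i64
# ret i64 %val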
25 changes: 6 additions & 19 deletions src/device/cuda/wmma.jl
@@ -7,9 +7,6 @@ using CUDAnative: AS, DevicePtr
 # CONSTANTS
 ################################################################################
 
-# Determines whether or not to Core.AddrSpacePtr is available
-const addrspaceptr_available = (VERSION >= v"1.5.0-DEV.324")
-
 # Maps PTX types to Julia array types
 const map_ptx_to_jl_array = Dict(
     "f16" => Float16,
@@ -52,24 +49,14 @@ get_frag_info(matrix, ptx_el_type) = (
 
 get_addrspace_info(addr_space) = convert(Int, map_ptx_as_to_as_ty[addr_space])
 
-if addrspaceptr_available
 @generated function Base.cconvert(::Type{Core.AddrSpacePtr{T, as}}, x::DevicePtr{T, AS}) where {T, as, AS}
-    # Addrspacecast from i8* to i8* is invalid in LLVM
-    if as == 0
-        return quote
-            return Base.bitcast(Core.AddrSpacePtr{T, as}, x)
-        end
-    else
-        ir = "%p = inttoptr i64 %0 to i8*
-              %ptr = addrspacecast i8* %p to i8 addrspace($as)*
-              ret i8 addrspace($as)* %ptr"
+    ir = "%ptr = inttoptr i64 %0 to i8 addrspace($as)*
+          ret i8 addrspace($as)* %ptr"
 
-        return quote
-            return Base.llvmcall($ir, Core.AddrSpacePtr{T, as}, Tuple{Int64}, Base.bitcast(Int64, x))
-        end
-    end
+    return quote
+        return Base.llvmcall($ir, Core.AddrSpacePtr{T, as}, Tuple{Int64}, Base.bitcast(Int64, x))
+    end
 end
-end
 
 # Fix for https://github.com/JuliaGPU/CUDAnative.jl/issues/587.
 # Instead of ccall'ing the intrinsics with NTuple{N, T} (which gets lowered to
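For any concrete instantiation, the @generated function above now reduces to a single llvmcall, and the old as == 0 special case disappears: inttoptr into any one address space is valid IR, whereas an addrspacecast from i8* to i8* was not. A minimal standalone sketch, assuming a Julia build that has Core.AddrSpacePtr (the function name and the Float16/global-memory instantiation are invented for illustration):

# Hypothetical expansion for T = Float16, as = 1 (global memory):
function to_global_addrspaceptr(x::Int64)
    ir = "%ptr = inttoptr i64 %0 to i8 addrspace(1)*
          ret i8 addrspace(1)* %ptr"
    return Base.llvmcall(ir, Core.AddrSpacePtr{Float16, 1}, Tuple{Int64}, x)
end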
@@ -141,7 +128,7 @@ for mat in ["a", "b", "c"],
 
 ccall_name = "extern $llvm_intr"
 
-ptr_ty = addrspaceptr_available ? Core.AddrSpacePtr{arr_ty, addr_space_int} : Ref{arr_ty}
+ptr_ty = Core.AddrSpacePtr{arr_ty, addr_space_int}
 struct_ty = Symbol("LLVMStruct$sz")
 
 @eval $func_name(src_addr, stride) = convert(NTuple{$sz, $frag_ty}, ccall($ccall_name, llvmcall, $struct_ty{$frag_ty}, ($ptr_ty, Int32), src_addr, stride))
@@ -196,7 +183,7 @@ for mat in ["d"],
 frag_types = ntuple(i -> frag_ty, sz)
 frag_vars = ntuple(i -> :(data[$i]), sz)
 
-ptr_ty = addrspaceptr_available ? Core.AddrSpacePtr{arr_ty, addr_space_int} : Ref{arr_ty}
+ptr_ty = Core.AddrSpacePtr{arr_ty, addr_space_int}
 
 @eval $func_name(dst_addr, data, stride) = ccall($ccall_name, llvmcall, Nothing, ($ptr_ty, $(frag_types...), Int32), dst_addr, $(frag_vars...), stride)
 @eval export $func_name
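These wrappers are what ties the Base.cconvert definition above into the call chain: ccall sees a Core.AddrSpacePtr argument type and a DevicePtr argument, and inserts the conversion automatically. A schematic sketch (intrinsic and wrapper names are invented here; the real ones are assembled by the surrounding @eval loop):

# Schematic only: with ptr_ty = Core.AddrSpacePtr{Float16, 1}, calling this
# with a DevicePtr{Float16, AS.Global} first goes through the cconvert above.
example_store(dst_addr::DevicePtr{Float16, AS.Global}, stride) =
    ccall("extern llvm.nvvm.hypothetical.wmma.store", llvmcall, Nothing,
          (Core.AddrSpacePtr{Float16, 1}, Int32), dst_addr, stride)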
21 changes: 7 additions & 14 deletions src/device/pointer.jl
@@ -118,7 +118,7 @@ Base.:(+)(x::Integer, y::DevicePtr) = y + x
 T_int = convert(LLVMType, Int)
 T_ptr = convert(LLVMType, DevicePtr{T,A})
 
-T_actual_ptr = LLVM.PointerType(eltyp)
+T_actual_ptr = LLVM.PointerType(eltyp, convert(Int, A))
 
 # create a function
 param_types = [T_ptr, T_int]
@@ -130,10 +130,8 @@ Base.:(+)(x::Integer, y::DevicePtr) = y + x
 position!(builder, entry)
 
 ptr = inttoptr!(builder, parameters(llvm_f)[1], T_actual_ptr)
-
 ptr = gep!(builder, ptr, [parameters(llvm_f)[2]])
-ptr_with_as = addrspacecast!(builder, ptr, LLVM.PointerType(eltyp, convert(Int, A)))
-ld = load!(builder, ptr_with_as)
+ld = load!(builder, ptr)
 
 if A != AS.Generic
     metadata(ld)[LLVM.MD_tbaa] = tbaa_addrspace(A)
@@ -153,7 +151,7 @@ end
 T_int = convert(LLVMType, Int)
 T_ptr = convert(LLVMType, DevicePtr{T,A})
 
-T_actual_ptr = LLVM.PointerType(eltyp)
+T_actual_ptr = LLVM.PointerType(eltyp, convert(Int, A))
 
 # create a function
 param_types = [T_ptr, eltyp, T_int]
@@ -165,11 +163,9 @@ end
 position!(builder, entry)
 
 ptr = inttoptr!(builder, parameters(llvm_f)[1], T_actual_ptr)
-
 ptr = gep!(builder, ptr, [parameters(llvm_f)[3]])
-ptr_with_as = addrspacecast!(builder, ptr, LLVM.PointerType(eltyp, convert(Int, A)))
 val = parameters(llvm_f)[2]
-st = store!(builder, val, ptr_with_as)
+st = store!(builder, val, ptr)
 
 if A != AS.Generic
     metadata(st)[LLVM.MD_tbaa] = tbaa_addrspace(A)
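With T_actual_ptr carrying the address space from the start, every load and store drops one addrspacecast: the generated IR is now inttoptr, gep, and the tbaa-tagged access, all in address space A. A device-side sketch of the user-facing behavior these generators implement (illustrative kernel; assumes using CUDAnative plus the internal DevicePtr and AS names):

# unsafe_load/unsafe_store! through a global-memory DevicePtr now lower
# to IR that stays in addrspace(1) throughout:
function kernel(p::DevicePtr{Float32, AS.Global})
    i = threadIdx().x
    val = unsafe_load(p, i)          # addrspace(1) load, tbaa-tagged
    unsafe_store!(p, val + 1f0, i)   # addrspace(1) store, tbaa-tagged
    return
end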
@@ -201,8 +197,7 @@ const LDGTypes = Union{UInt8, UInt16, UInt32, UInt64,
 T_int32 = LLVM.Int32Type(JuliaContext())
 T_ptr = convert(LLVMType, DevicePtr{T,AS.Global})
 
-T_actual_ptr = LLVM.PointerType(eltyp)
-T_actual_ptr_as = LLVM.PointerType(eltyp, convert(Int, AS.Global))
+T_actual_ptr = LLVM.PointerType(eltyp, convert(Int, AS.Global))
 
 # create a function
 param_types = [T_ptr, T_int]
@@ -222,7 +217,7 @@ const LDGTypes = Union{UInt8, UInt16, UInt32, UInt64,
     "llvm.nvvm.ldg.global.$class.$typ.p1$typ"
 end
 mod = LLVM.parent(llvm_f)
-intrinsic_typ = LLVM.FunctionType(eltyp, [T_actual_ptr_as, T_int32])
+intrinsic_typ = LLVM.FunctionType(eltyp, [T_actual_ptr, T_int32])
 intrinsic = LLVM.Function(mod, intrinsic_name, intrinsic_typ)
 
 # generate IR
@@ -231,11 +226,9 @@ const LDGTypes = Union{UInt8, UInt16, UInt32, UInt64,
 position!(builder, entry)
 
 ptr = inttoptr!(builder, parameters(llvm_f)[1], T_actual_ptr)
-
 ptr = gep!(builder, ptr, [parameters(llvm_f)[2]])
-ptr_with_as = addrspacecast!(builder, ptr, T_actual_ptr_as)
 ld = call!(builder, intrinsic,
-           [ptr_with_as, ConstantInt(Int32(align), JuliaContext())])
+           [ptr, ConstantInt(Int32(align), JuliaContext())])
 
 metadata(ld)[LLVM.MD_tbaa] = tbaa_addrspace(AS.Global)
 
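The ldg intrinsic already takes its operand in the global address space (the p1 in its mangled name), so the addrspace(1) pointer is passed straight through; for Float32, the interpolation above produces "llvm.nvvm.ldg.global.f.f32.p1f32". A device-side usage sketch, assuming unsafe_cached_load is the public entry point to this generator:

function kernel(p::DevicePtr{Float32, AS.Global})
    val = unsafe_cached_load(p, threadIdx().x)   # read-only (LDG) load
    return
end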
11 changes: 11 additions & 0 deletions test/device/cuda.jl
@@ -1133,6 +1133,17 @@ end
 end
 end
 
+@testset "shared memory" begin
+    function kernel()
+        shared = @cuStaticSharedMem(Float32, 1)
+        @atomic shared[threadIdx().x] += 0f0
+        return
+    end
+
+    @cuda kernel()
+    synchronize()
+end
+
 end
 
 end
2 changes: 0 additions & 2 deletions test/device/wmma.jl
@@ -1,6 +1,5 @@
 # Need https://github.com/JuliaLang/julia/pull/33970
 # and https://github.com/JuliaLang/julia/pull/34043
-if VERSION >= v"1.4.0-DEV.666" && capability(device()) >= v"7.0"
 
 using CUDAnative.WMMA
 
@@ -294,4 +293,3 @@ end
 ################################################################################
 
 end
-end
2 changes: 2 additions & 0 deletions test/runtests.jl
@@ -33,7 +33,9 @@ include("device/execution.jl")
 include("device/pointer.jl")
 include("device/array.jl")
 include("device/cuda.jl")
+if VERSION >= v"1.4.1" && capability(device()) >= v"7.0"
 include("device/wmma.jl")
+end
 
 include("nvtx.jl")
 