Skip to content
This repository was archived by the owner on May 27, 2021. It is now read-only.

Commit b480342

Browse files
authored
Merge pull request #642 from JuliaGPU/tb/rm_addrspacecast
Avoid address space casts.
2 parents d97a5db + b5cce68 commit b480342

File tree

7 files changed

+30
-39
lines changed

7 files changed

+30
-39
lines changed

src/device/cuda.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@ include("cuda/assertion.jl")
1111
include("cuda/memory_dynamic.jl")
1212
include("cuda/atomics.jl")
1313
include("cuda/misc.jl")
14+
if VERSION >= v"1.4.1"
1415
include("cuda/wmma.jl")
16+
end
1517

1618
# functionality from libdevice
1719
#

src/device/cuda/memory_shared.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,6 @@ end
6161
end
6262

6363
T_ptr = convert(LLVMType, DevicePtr{T,AS.Shared})
64-
T_actual_ptr = LLVM.PointerType(eltyp)
6564

6665
# create a function
6766
llvm_f, _ = create_function(T_ptr)
@@ -92,10 +91,9 @@ end
9291
entry = BasicBlock(llvm_f, "entry", JuliaContext())
9392
position!(builder, entry)
9493

95-
ptr_with_as = gep!(builder, gv, [ConstantInt(0, JuliaContext()),
96-
ConstantInt(0, JuliaContext())])
94+
ptr = gep!(builder, gv, [ConstantInt(0, JuliaContext()),
95+
ConstantInt(0, JuliaContext())])
9796

98-
ptr = addrspacecast!(builder, ptr_with_as, T_actual_ptr)
9997
val = ptrtoint!(builder, ptr, T_ptr)
10098
ret!(builder, val)
10199
end

src/device/cuda/wmma.jl

Lines changed: 6 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,6 @@ using CUDAnative: AS, DevicePtr
77
# CONSTANTS
88
################################################################################
99

10-
# Determines whether or not to Core.AddrSpacePtr is available
11-
const addrspaceptr_available = (VERSION >= v"1.5.0-DEV.324")
12-
1310
# Maps PTX types to Julia array types
1411
const map_ptx_to_jl_array = Dict(
1512
"f16" => Float16,
@@ -52,24 +49,14 @@ get_frag_info(matrix, ptx_el_type) = (
5249

5350
get_addrspace_info(addr_space) = convert(Int, map_ptx_as_to_as_ty[addr_space])
5451

55-
if addrspaceptr_available
5652
@generated function Base.cconvert(::Type{Core.AddrSpacePtr{T, as}}, x::DevicePtr{T, AS}) where {T, as, AS}
57-
# Addrspacecast from i8* to i8* is invalid in LLVM
58-
if as == 0
59-
return quote
60-
return Base.bitcast(Core.AddrSpacePtr{T, as}, x)
61-
end
62-
else
63-
ir = "%p = inttoptr i64 %0 to i8*
64-
%ptr = addrspacecast i8* %p to i8 addrspace($as)*
65-
ret i8 addrspace($as)* %ptr"
53+
ir = "%ptr = inttoptr i64 %0 to i8 addrspace($as)*
54+
ret i8 addrspace($as)* %ptr"
6655

67-
return quote
68-
return Base.llvmcall($ir, Core.AddrSpacePtr{T, as}, Tuple{Int64}, Base.bitcast(Int64, x))
69-
end
56+
return quote
57+
return Base.llvmcall($ir, Core.AddrSpacePtr{T, as}, Tuple{Int64}, Base.bitcast(Int64, x))
7058
end
7159
end
72-
end
7360

7461
# Fix for https://github.com/JuliaGPU/CUDAnative.jl/issues/587.
7562
# Instead of ccall'ing the intrinsics with NTuple{N, T} (which gets lowered to
@@ -141,7 +128,7 @@ for mat in ["a", "b", "c"],
141128

142129
ccall_name = "extern $llvm_intr"
143130

144-
ptr_ty = addrspaceptr_available ? Core.AddrSpacePtr{arr_ty, addr_space_int} : Ref{arr_ty}
131+
ptr_ty = Core.AddrSpacePtr{arr_ty, addr_space_int}
145132
struct_ty = Symbol("LLVMStruct$sz")
146133

147134
@eval $func_name(src_addr, stride) = convert(NTuple{$sz, $frag_ty}, ccall($ccall_name, llvmcall, $struct_ty{$frag_ty}, ($ptr_ty, Int32), src_addr, stride))
@@ -196,7 +183,7 @@ for mat in ["d"],
196183
frag_types = ntuple(i -> frag_ty, sz)
197184
frag_vars = ntuple(i -> :(data[$i]), sz)
198185

199-
ptr_ty = addrspaceptr_available ? Core.AddrSpacePtr{arr_ty, addr_space_int} : Ref{arr_ty}
186+
ptr_ty = Core.AddrSpacePtr{arr_ty, addr_space_int}
200187

201188
@eval $func_name(dst_addr, data, stride) = ccall($ccall_name, llvmcall, Nothing, ($ptr_ty, $(frag_types...), Int32), dst_addr, $(frag_vars...), stride)
202189
@eval export $func_name

src/device/pointer.jl

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ Base.:(+)(x::Integer, y::DevicePtr) = y + x
118118
T_int = convert(LLVMType, Int)
119119
T_ptr = convert(LLVMType, DevicePtr{T,A})
120120

121-
T_actual_ptr = LLVM.PointerType(eltyp)
121+
T_actual_ptr = LLVM.PointerType(eltyp, convert(Int, A))
122122

123123
# create a function
124124
param_types = [T_ptr, T_int]
@@ -130,10 +130,8 @@ Base.:(+)(x::Integer, y::DevicePtr) = y + x
130130
position!(builder, entry)
131131

132132
ptr = inttoptr!(builder, parameters(llvm_f)[1], T_actual_ptr)
133-
134133
ptr = gep!(builder, ptr, [parameters(llvm_f)[2]])
135-
ptr_with_as = addrspacecast!(builder, ptr, LLVM.PointerType(eltyp, convert(Int, A)))
136-
ld = load!(builder, ptr_with_as)
134+
ld = load!(builder, ptr)
137135

138136
if A != AS.Generic
139137
metadata(ld)[LLVM.MD_tbaa] = tbaa_addrspace(A)
@@ -153,7 +151,7 @@ end
153151
T_int = convert(LLVMType, Int)
154152
T_ptr = convert(LLVMType, DevicePtr{T,A})
155153

156-
T_actual_ptr = LLVM.PointerType(eltyp)
154+
T_actual_ptr = LLVM.PointerType(eltyp, convert(Int, A))
157155

158156
# create a function
159157
param_types = [T_ptr, eltyp, T_int]
@@ -165,11 +163,9 @@ end
165163
position!(builder, entry)
166164

167165
ptr = inttoptr!(builder, parameters(llvm_f)[1], T_actual_ptr)
168-
169166
ptr = gep!(builder, ptr, [parameters(llvm_f)[3]])
170-
ptr_with_as = addrspacecast!(builder, ptr, LLVM.PointerType(eltyp, convert(Int, A)))
171167
val = parameters(llvm_f)[2]
172-
st = store!(builder, val, ptr_with_as)
168+
st = store!(builder, val, ptr)
173169

174170
if A != AS.Generic
175171
metadata(st)[LLVM.MD_tbaa] = tbaa_addrspace(A)
@@ -201,8 +197,7 @@ const LDGTypes = Union{UInt8, UInt16, UInt32, UInt64,
201197
T_int32 = LLVM.Int32Type(JuliaContext())
202198
T_ptr = convert(LLVMType, DevicePtr{T,AS.Global})
203199

204-
T_actual_ptr = LLVM.PointerType(eltyp)
205-
T_actual_ptr_as = LLVM.PointerType(eltyp, convert(Int, AS.Global))
200+
T_actual_ptr = LLVM.PointerType(eltyp, convert(Int, AS.Global))
206201

207202
# create a function
208203
param_types = [T_ptr, T_int]
@@ -222,7 +217,7 @@ const LDGTypes = Union{UInt8, UInt16, UInt32, UInt64,
222217
"llvm.nvvm.ldg.global.$class.$typ.p1$typ"
223218
end
224219
mod = LLVM.parent(llvm_f)
225-
intrinsic_typ = LLVM.FunctionType(eltyp, [T_actual_ptr_as, T_int32])
220+
intrinsic_typ = LLVM.FunctionType(eltyp, [T_actual_ptr, T_int32])
226221
intrinsic = LLVM.Function(mod, intrinsic_name, intrinsic_typ)
227222

228223
# generate IR
@@ -231,11 +226,9 @@ const LDGTypes = Union{UInt8, UInt16, UInt32, UInt64,
231226
position!(builder, entry)
232227

233228
ptr = inttoptr!(builder, parameters(llvm_f)[1], T_actual_ptr)
234-
235229
ptr = gep!(builder, ptr, [parameters(llvm_f)[2]])
236-
ptr_with_as = addrspacecast!(builder, ptr, T_actual_ptr_as)
237230
ld = call!(builder, intrinsic,
238-
[ptr_with_as, ConstantInt(Int32(align), JuliaContext())])
231+
[ptr, ConstantInt(Int32(align), JuliaContext())])
239232

240233
metadata(ld)[LLVM.MD_tbaa] = tbaa_addrspace(AS.Global)
241234

test/device/cuda.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1133,6 +1133,17 @@ end
11331133
end
11341134
end
11351135

1136+
@testset "shared memory" begin
1137+
function kernel()
1138+
shared = @cuStaticSharedMem(Float32, 1)
1139+
@atomic shared[threadIdx().x] += 0f0
1140+
return
1141+
end
1142+
1143+
@cuda kernel()
1144+
synchronize()
1145+
end
1146+
11361147
end
11371148

11381149
end

test/device/wmma.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
# Need https://github.com/JuliaLang/julia/pull/33970
22
# and https://github.com/JuliaLang/julia/pull/34043
3-
if VERSION >= v"1.4.0-DEV.666" && capability(device()) >= v"7.0"
43

54
using CUDAnative.WMMA
65

@@ -294,4 +293,3 @@ end
294293
################################################################################
295294

296295
end
297-
end

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,9 @@ include("device/execution.jl")
3333
include("device/pointer.jl")
3434
include("device/array.jl")
3535
include("device/cuda.jl")
36+
if VERSION >= v"1.4.1" && capability(device()) >= v"7.0"
3637
include("device/wmma.jl")
38+
end
3739

3840
include("nvtx.jl")
3941

0 commit comments

Comments
 (0)