Skip to content
This repository was archived by the owner on May 27, 2021. It is now read-only.

Commit 1481594

Browse files
committed
Avoid address space casts.
1 parent d97a5db commit 1481594

File tree

4 files changed, with 24 additions and 30 deletions.

src/device/cuda/memory_shared.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,6 @@ end
6161
end
6262

6363
T_ptr = convert(LLVMType, DevicePtr{T,AS.Shared})
64-
T_actual_ptr = LLVM.PointerType(eltyp)
6564

6665
# create a function
6766
llvm_f, _ = create_function(T_ptr)
@@ -92,10 +91,9 @@ end
9291
entry = BasicBlock(llvm_f, "entry", JuliaContext())
9392
position!(builder, entry)
9493

95-
ptr_with_as = gep!(builder, gv, [ConstantInt(0, JuliaContext()),
96-
ConstantInt(0, JuliaContext())])
94+
ptr = gep!(builder, gv, [ConstantInt(0, JuliaContext()),
95+
ConstantInt(0, JuliaContext())])
9796

98-
ptr = addrspacecast!(builder, ptr_with_as, T_actual_ptr)
9997
val = ptrtoint!(builder, ptr, T_ptr)
10098
ret!(builder, val)
10199
end

src/device/cuda/wmma.jl

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -54,19 +54,11 @@ get_addrspace_info(addr_space) = convert(Int, map_ptx_as_to_as_ty[addr_space])
5454

5555
if addrspaceptr_available
5656
@generated function Base.cconvert(::Type{Core.AddrSpacePtr{T, as}}, x::DevicePtr{T, AS}) where {T, as, AS}
57-
# Addrspacecast from i8* to i8* is invalid in LLVM
58-
if as == 0
59-
return quote
60-
return Base.bitcast(Core.AddrSpacePtr{T, as}, x)
61-
end
62-
else
63-
ir = "%p = inttoptr i64 %0 to i8*
64-
%ptr = addrspacecast i8* %p to i8 addrspace($as)*
65-
ret i8 addrspace($as)* %ptr"
57+
ir = "%ptr = inttoptr i64 %0 to i8 addrspace($as)*
58+
ret i8 addrspace($as)* %ptr"
6659

67-
return quote
68-
return Base.llvmcall($ir, Core.AddrSpacePtr{T, as}, Tuple{Int64}, Base.bitcast(Int64, x))
69-
end
60+
return quote
61+
return Base.llvmcall($ir, Core.AddrSpacePtr{T, as}, Tuple{Int64}, Base.bitcast(Int64, x))
7062
end
7163
end
7264
end

src/device/pointer.jl

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ Base.:(+)(x::Integer, y::DevicePtr) = y + x
118118
T_int = convert(LLVMType, Int)
119119
T_ptr = convert(LLVMType, DevicePtr{T,A})
120120

121-
T_actual_ptr = LLVM.PointerType(eltyp)
121+
T_actual_ptr = LLVM.PointerType(eltyp, convert(Int, A))
122122

123123
# create a function
124124
param_types = [T_ptr, T_int]
@@ -130,10 +130,8 @@ Base.:(+)(x::Integer, y::DevicePtr) = y + x
130130
position!(builder, entry)
131131

132132
ptr = inttoptr!(builder, parameters(llvm_f)[1], T_actual_ptr)
133-
134133
ptr = gep!(builder, ptr, [parameters(llvm_f)[2]])
135-
ptr_with_as = addrspacecast!(builder, ptr, LLVM.PointerType(eltyp, convert(Int, A)))
136-
ld = load!(builder, ptr_with_as)
134+
ld = load!(builder, ptr)
137135

138136
if A != AS.Generic
139137
metadata(ld)[LLVM.MD_tbaa] = tbaa_addrspace(A)
@@ -153,7 +151,7 @@ end
153151
T_int = convert(LLVMType, Int)
154152
T_ptr = convert(LLVMType, DevicePtr{T,A})
155153

156-
T_actual_ptr = LLVM.PointerType(eltyp)
154+
T_actual_ptr = LLVM.PointerType(eltyp, convert(Int, A))
157155

158156
# create a function
159157
param_types = [T_ptr, eltyp, T_int]
@@ -165,11 +163,9 @@ end
165163
position!(builder, entry)
166164

167165
ptr = inttoptr!(builder, parameters(llvm_f)[1], T_actual_ptr)
168-
169166
ptr = gep!(builder, ptr, [parameters(llvm_f)[3]])
170-
ptr_with_as = addrspacecast!(builder, ptr, LLVM.PointerType(eltyp, convert(Int, A)))
171167
val = parameters(llvm_f)[2]
172-
st = store!(builder, val, ptr_with_as)
168+
st = store!(builder, val, ptr)
173169

174170
if A != AS.Generic
175171
metadata(st)[LLVM.MD_tbaa] = tbaa_addrspace(A)
@@ -201,8 +197,7 @@ const LDGTypes = Union{UInt8, UInt16, UInt32, UInt64,
201197
T_int32 = LLVM.Int32Type(JuliaContext())
202198
T_ptr = convert(LLVMType, DevicePtr{T,AS.Global})
203199

204-
T_actual_ptr = LLVM.PointerType(eltyp)
205-
T_actual_ptr_as = LLVM.PointerType(eltyp, convert(Int, AS.Global))
200+
T_actual_ptr = LLVM.PointerType(eltyp, convert(Int, AS.Global))
206201

207202
# create a function
208203
param_types = [T_ptr, T_int]
@@ -222,7 +217,7 @@ const LDGTypes = Union{UInt8, UInt16, UInt32, UInt64,
222217
"llvm.nvvm.ldg.global.$class.$typ.p1$typ"
223218
end
224219
mod = LLVM.parent(llvm_f)
225-
intrinsic_typ = LLVM.FunctionType(eltyp, [T_actual_ptr_as, T_int32])
220+
intrinsic_typ = LLVM.FunctionType(eltyp, [T_actual_ptr, T_int32])
226221
intrinsic = LLVM.Function(mod, intrinsic_name, intrinsic_typ)
227222

228223
# generate IR
@@ -231,11 +226,9 @@ const LDGTypes = Union{UInt8, UInt16, UInt32, UInt64,
231226
position!(builder, entry)
232227

233228
ptr = inttoptr!(builder, parameters(llvm_f)[1], T_actual_ptr)
234-
235229
ptr = gep!(builder, ptr, [parameters(llvm_f)[2]])
236-
ptr_with_as = addrspacecast!(builder, ptr, T_actual_ptr_as)
237230
ld = call!(builder, intrinsic,
238-
[ptr_with_as, ConstantInt(Int32(align), JuliaContext())])
231+
[ptr, ConstantInt(Int32(align), JuliaContext())])
239232

240233
metadata(ld)[LLVM.MD_tbaa] = tbaa_addrspace(AS.Global)
241234

test/device/cuda.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1133,6 +1133,17 @@ end
11331133
end
11341134
end
11351135

1136+
@testset "shared memory" begin
1137+
function kernel()
1138+
shared = @cuStaticSharedMem(Float32, 1)
1139+
@atomic shared[threadIdx().x] += 0f0
1140+
return
1141+
end
1142+
1143+
@cuda kernel()
1144+
synchronize()
1145+
end
1146+
11361147
end
11371148

11381149
end

0 commit comments

Comments
 (0)