Skip to content
This repository was archived by the owner on May 27, 2021. It is now read-only.

Commit e687697

Browse files
bors[bot] and maleadt
committed
Merge #337
337: WIP: Adapt to the new CUDAdrv.CuPtr pointer type. r=maleadt a=maleadt JuliaGPU/CUDAdrv.jl#125 Co-authored-by: Tim Besard <[email protected]>
2 parents 95fbf93 + a0cf604 commit e687697

File tree

12 files changed

+134
-100
lines changed

12 files changed

+134
-100
lines changed

REQUIRE

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
julia 1.0
2-
CUDAdrv 1.0
2+
CUDAdrv 1.1
33
LLVM 0.9.14
44
CUDAapi 0.4.0
55
Adapt 0.4

docs/make.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1-
using Documenter, CUDAnative
1+
using Documenter
22

33
using Pkg
44
if haskey(ENV, "GITLAB_CI")
55
Pkg.add([PackageSpec(name = x; rev = "master") for x in ["CUDAdrv", "LLVM"]])
66
end
77

8+
using CUDAnative
9+
810
makedocs(
911
modules = [CUDAnative],
1012
format = Documenter.HTML(prettyurls = get(ENV, "CI", nothing) == "true"),

src/device/cuda_intrinsics/memory_shared.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ macro cuStaticSharedMem(T, dims)
2323
quote
2424
len = prod($(esc(dims)))
2525
ptr = _shmem(Val($id), $(esc(T)), Val(len))
26-
CuDeviceArray($(esc(dims)), DevicePtr{$(esc(T)), AS.Shared}(ptr))
26+
CuDeviceArray($(esc(dims)), ptr)
2727
end
2828
end
2929

@@ -49,15 +49,15 @@ macro cuDynamicSharedMem(T, dims, offset=0)
4949
quote
5050
len = prod($(esc(dims)))
5151
ptr = _shmem(Val($id), $(esc(T))) + $(esc(offset))
52-
CuDeviceArray($(esc(dims)), DevicePtr{$(esc(T)), AS.Shared}(ptr))
52+
CuDeviceArray($(esc(dims)), ptr)
5353
end
5454
end
5555

5656
# get a pointer to shared memory, with known (static) or zero length (dynamic shared memory)
5757
@generated function _shmem(::Val{id}, ::Type{T}, ::Val{len}=Val(0)) where {id,T,len}
5858
eltyp = convert(LLVMType, T)
5959

60-
T_ptr = convert(LLVMType, Ptr{T})
60+
T_ptr = convert(LLVMType, DevicePtr{T,AS.Shared})
6161
T_actual_ptr = LLVM.PointerType(eltyp)
6262

6363
# create a function
@@ -96,5 +96,5 @@ end
9696
ret!(builder, val)
9797
end
9898

99-
call_function(llvm_f, Ptr{T})
99+
call_function(llvm_f, DevicePtr{T,AS.Shared})
100100
end

src/device/pointer.jl

Lines changed: 49 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,7 @@ abstract type AddressSpace end
1010

1111
module AS
1212

13-
using CUDAnative
14-
import CUDAnative: AddressSpace
13+
import ..AddressSpace
1514

1615
struct Generic <: AddressSpace end
1716
struct Global <: AddressSpace end
@@ -26,27 +25,30 @@ end
2625
# Device pointer
2726
#
2827

29-
struct DevicePtr{T,A}
30-
ptr::Ptr{T}
28+
"""
29+
DevicePtr{T,A}
3130
32-
# inner constructors, fully parameterized
33-
DevicePtr{T,A}(ptr::Ptr{T}) where {T,A<:AddressSpace} = new(ptr)
34-
end
35-
36-
# outer constructors, partially parameterized
37-
DevicePtr{T}(ptr::Ptr{T}) where {T} = DevicePtr{T,AS.Generic}(ptr)
31+
A memory address that refers to data of type `T` that is accessible from the GPU. It is the
32+
on-device counterpart of `CUDAdrv.CuPtr`, additionally keeping track of the address space
33+
`A` where the data resides (shared, global, constant, etc). This information is used to
34+
provide optimized implementations of operations such as `unsafe_load` and `unsafe_store!`.
35+
"""
36+
DevicePtr
3837

39-
# outer constructors, non-parameterized
40-
DevicePtr(ptr::Ptr{T}) where {T} = DevicePtr{T,AS.Generic}(ptr)
38+
if sizeof(Ptr{Cvoid}) == 8
39+
primitive type DevicePtr{T,A} 64 end
40+
else
41+
primitive type DevicePtr{T,A} 32 end
42+
end
4143

42-
Base.show(io::IO, dp::DevicePtr{T,AS}) where {T,AS} =
43-
print(io, AS.name.name, " Device", pointer(dp))
44+
# constructors
45+
DevicePtr{T,A}(x::Union{Int,UInt,CuPtr,DevicePtr}) where {T,A<:AddressSpace} = Base.bitcast(DevicePtr{T,A}, x)
46+
DevicePtr{T}(ptr::CuPtr{T}) where {T} = DevicePtr{T,AS.Generic}(ptr)
47+
DevicePtr(ptr::CuPtr{T}) where {T} = DevicePtr{T,AS.Generic}(ptr)
4448

4549

4650
## getters
4751

48-
Base.pointer(p::DevicePtr) = p.ptr
49-
5052
Base.eltype(::Type{<:DevicePtr{T}}) where {T} = T
5153

5254
addrspace(x::DevicePtr) = addrspace(typeof(x))
@@ -55,20 +57,23 @@ addrspace(::Type{DevicePtr{T,A}}) where {T,A} = A
5557

5658
## conversions
5759

58-
# between regular and device pointers
59-
## simple conversions disallowed
60-
Base.convert(::Type{Ptr{T}}, p::DevicePtr{T}) where {T} = throw(InexactError(:convert, Ptr{T}, p))
61-
Base.convert(::Type{<:DevicePtr{T}}, p::Ptr{T}) where {T} = throw(InexactError(:convert, DevicePtr{T}, p))
62-
## unsafe ones are allowed
63-
Base.unsafe_convert(::Type{Ptr{T}}, p::DevicePtr{T}) where {T} = pointer(p)
60+
# to and from integers
61+
## pointer to integer
62+
Base.convert(::Type{T}, x::DevicePtr) where {T<:Integer} = T(UInt(x))
63+
## integer to pointer
64+
Base.convert(::Type{DevicePtr{T,A}}, x::Union{Int,UInt}) where {T,A<:AddressSpace} = DevicePtr{T,A}(x)
65+
Int(x::DevicePtr) = Base.bitcast(Int, x)
66+
UInt(x::DevicePtr) = Base.bitcast(UInt, x)
6467

65-
# defer conversions to DevicePtr to unsafe_convert
66-
Base.cconvert(::Type{<:DevicePtr}, x) = x
68+
# between host and device pointers
69+
Base.convert(::Type{CuPtr{T}}, p::DevicePtr) where {T} = Base.bitcast(CuPtr{T}, p)
70+
Base.convert(::Type{DevicePtr{T,A}}, p::CuPtr) where {T,A<:AddressSpace} = Base.bitcast(DevicePtr{T,A}, p)
71+
Base.convert(::Type{DevicePtr{T}}, p::CuPtr) where {T} = Base.bitcast(DevicePtr{T,AS.Generic}, p)
6772

6873
# between device pointers
69-
Base.convert(::Type{<:DevicePtr}, p::DevicePtr) = throw(InexactError(:convert, DevicePtr, p))
74+
Base.convert(::Type{<:DevicePtr}, p::DevicePtr) = throw(ArgumentError("cannot convert between incompatible device pointer types"))
7075
Base.convert(::Type{DevicePtr{T,A}}, p::DevicePtr{T,A}) where {T,A} = p
71-
Base.unsafe_convert(::Type{DevicePtr{T,A}}, p::DevicePtr) where {T,A} = DevicePtr{T,A}(reinterpret(Ptr{T}, pointer(p)))
76+
Base.unsafe_convert(::Type{DevicePtr{T,A}}, p::DevicePtr) where {T,A} = Base.bitcast(DevicePtr{T,A}, p)
7277
## identical addrspaces
7378
Base.convert(::Type{DevicePtr{T,A}}, p::DevicePtr{U,A}) where {T,U,A} = Base.unsafe_convert(DevicePtr{T,A}, p)
7479
## convert to & from generic
@@ -78,19 +83,25 @@ Base.convert(::Type{DevicePtr{T,AS.Generic}}, p::DevicePtr{T,AS.Generic}) where
7883
## unspecified, preserve source addrspace
7984
Base.convert(::Type{DevicePtr{T}}, p::DevicePtr{U,A}) where {T,U,A} = Base.unsafe_convert(DevicePtr{T,A}, p)
8085

86+
# defer conversions to DevicePtr to unsafe_convert
87+
Base.cconvert(::Type{<:DevicePtr}, x) = x
88+
8189

8290
## limited pointer arithmetic & comparison
8391

84-
Base.:(==)(a::DevicePtr, b::DevicePtr) = pointer(a) == pointer(b) && addrspace(a) == addrspace(b)
92+
isequal(x::DevicePtr, y::DevicePtr) = (x === y) && addrspace(x) == addrspace(y)
93+
isless(x::DevicePtr{T,A}, y::DevicePtr{T,A}) where {T,A<:AddressSpace} = x < y
8594

86-
Base.isless(x::DevicePtr, y::DevicePtr) = Base.isless(pointer(x), pointer(y))
87-
Base.:(-)(x::DevicePtr, y::DevicePtr) = pointer(x) - pointer(y)
95+
Base.:(==)(x::DevicePtr, y::DevicePtr) = UInt(x) == UInt(y) && addrspace(x) == addrspace(y)
96+
Base.:(<)(x::DevicePtr, y::DevicePtr) = UInt(x) < UInt(y)
97+
Base.:(-)(x::DevicePtr, y::DevicePtr) = UInt(x) - UInt(y)
8898

89-
Base.:(+)(x::DevicePtr{T,A}, y::Integer) where {T,A} = DevicePtr{T,A}(pointer(x) + y)
90-
Base.:(-)(x::DevicePtr{T,A}, y::Integer) where {T,A} = DevicePtr{T,A}(pointer(x) - y)
99+
Base.:(+)(x::DevicePtr, y::Integer) = oftype(x, Base.add_ptr(UInt(x), (y % UInt) % UInt))
100+
Base.:(-)(x::DevicePtr, y::Integer) = oftype(x, Base.sub_ptr(UInt(x), (y % UInt) % UInt))
91101
Base.:(+)(x::Integer, y::DevicePtr) = y + x
92102

93103

104+
94105
## memory operations
95106

96107
Base.convert(::Type{Int}, ::Type{AS.Generic}) = 0
@@ -121,7 +132,7 @@ tbaa_addrspace(as::Type{<:AddressSpace}) = tbaa_make_child(lowercase(String(as.n
121132
eltyp = convert(LLVMType, T)
122133

123134
T_int = convert(LLVMType, Int)
124-
T_ptr = convert(LLVMType, Ptr{T})
135+
T_ptr = convert(LLVMType, DevicePtr{T,A})
125136

126137
T_actual_ptr = LLVM.PointerType(eltyp)
127138

@@ -148,15 +159,15 @@ tbaa_addrspace(as::Type{<:AddressSpace}) = tbaa_make_child(lowercase(String(as.n
148159
ret!(builder, ld)
149160
end
150161

151-
call_function(llvm_f, T, Tuple{Ptr{T}, Int}, :((pointer(p), Int(i-one(i)))))
162+
call_function(llvm_f, T, Tuple{DevicePtr{T,A}, Int}, :((p, Int(i-one(i)))))
152163
end
153164

154165
@generated function Base.unsafe_store!(p::DevicePtr{T,A}, x, i::Integer=1,
155166
::Val{align}=Val(1)) where {T,A,align}
156167
eltyp = convert(LLVMType, T)
157168

158169
T_int = convert(LLVMType, Int)
159-
T_ptr = convert(LLVMType, Ptr{T})
170+
T_ptr = convert(LLVMType, DevicePtr{T,A})
160171

161172
T_actual_ptr = LLVM.PointerType(eltyp)
162173

@@ -184,7 +195,8 @@ end
184195
ret!(builder)
185196
end
186197

187-
call_function(llvm_f, Cvoid, Tuple{Ptr{T}, T, Int}, :((pointer(p), convert(T,x), Int(i-one(i)))))
198+
call_function(llvm_f, Cvoid, Tuple{DevicePtr{T,A}, T, Int},
199+
:((p, convert(T,x), Int(i-one(i)))))
188200
end
189201

190202
## loading through the texture cache
@@ -215,7 +227,7 @@ const CachedLoadPointers = Union{Tuple(DevicePtr{T,AS.Global}
215227

216228
T_int = convert(LLVMType, Int)
217229
T_int32 = LLVM.Int32Type(JuliaContext())
218-
T_ptr = convert(LLVMType, Ptr{T})
230+
T_ptr = convert(LLVMType, DevicePtr{T,AS.Global})
219231

220232
T_actual_ptr = LLVM.PointerType(eltyp)
221233
T_actual_ptr_as = LLVM.PointerType(eltyp, convert(Int, AS.Global))
@@ -258,7 +270,7 @@ const CachedLoadPointers = Union{Tuple(DevicePtr{T,AS.Global}
258270
ret!(builder, ld)
259271
end
260272

261-
call_function(llvm_f, T, Tuple{Ptr{T}, Int}, :((pointer(p), Int(i-one(i)))))
273+
call_function(llvm_f, T, Tuple{DevicePtr{T,AS.Global}, Int}, :((p, Int(i-one(i)))))
262274
end
263275

264276
@inline unsafe_cached_load(p::DevicePtr{T,AS.Global}, i::Integer=1, args...) where {T} =

src/execution.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,9 @@ end
128128

129129
struct Adaptor end
130130

131+
# convert CUDAdrv pointers to CUDAnative pointers
132+
Adapt.adapt_storage(to::Adaptor, p::CuPtr{T}) where {T} = DevicePtr{T,AS.Generic}(p)
133+
131134
# Base.RefValue isn't GPU compatible, so provide a compatible alternative
132135
struct CuRefValue{T} <: Ref{T}
133136
x::T

test/device/array.jl

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,7 @@
33
@testset "constructors" begin
44
# inner constructors
55
let
6-
p = Ptr{Int}(C_NULL)
7-
dp = CUDAnative.DevicePtr(p)
6+
dp = CUDAnative.DevicePtr{Int,AS.Generic}(0)
87
CuDeviceArray{Int,1,AS.Generic}((1,), dp)
98
end
109

@@ -13,8 +12,7 @@
1312
a = I(1)
1413
b = I(2)
1514

16-
p = Ptr{I}(C_NULL)
17-
dp = CUDAnative.DevicePtr(p)
15+
dp = CUDAnative.DevicePtr{I,AS.Generic}(0)
1816

1917
# not parameterized
2018
CuDeviceArray(b, dp)
@@ -138,7 +136,7 @@ end
138136

139137
a = [1]
140138
p = pointer(a)
141-
dp = CUDAnative.DevicePtr(p)
139+
dp = Base.bitcast(CUDAnative.DevicePtr{eltype(p), AS.Generic}, p)
142140
da = CUDAnative.CuDeviceArray(1, dp)
143141
load_index(da)
144142
end

test/device/codegen.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@
1212
end
1313

1414
buf = Mem.alloc(Float64)
15-
@cuda kernel(convert(Ptr{Float64}, buf.ptr), (1., 2., ))
15+
ptr = Base.unsafe_convert(CuPtr{Float64}, buf)
16+
17+
@cuda kernel(ptr, (1., 2., ))
1618
@test Mem.download(Float64, buf) == [1.]
1719
end
1820

@@ -32,7 +34,9 @@ end
3234
return
3335
end
3436

35-
@cuda threads=2 kernel(CuDeviceArray((2,1), CUDAnative.DevicePtr(convert(Ptr{Int}, out_buf.ptr))), a, b)
37+
ptr = Base.unsafe_convert(CuPtr{Int}, out_buf)
38+
39+
@cuda threads=2 kernel(CuDeviceArray((2,1), CUDAnative.DevicePtr(ptr)), a, b)
3640
@test Mem.download(Int, out_buf, 2) == (_a .+ 1)[1:2]
3741
end
3842

@@ -62,7 +66,9 @@ end
6266
function gpu(input)
6367
output = Mem.alloc(Int, 2)
6468

65-
@cuda threads=2 kernel(input, convert(Ptr{eltype(input)}, output.ptr), 99)
69+
ptr = Base.unsafe_convert(CuPtr{eltype(input)}, output)
70+
71+
@cuda threads=2 kernel(input, ptr, 99)
6672

6773
return Mem.download(Int, output, 2)
6874
end

0 commit comments

Comments
 (0)