@@ -10,8 +10,7 @@ abstract type AddressSpace end
10
10
11
11
module AS
12
12
13
- using CUDAnative
14
- import CUDAnative: AddressSpace
13
+ import .. AddressSpace
15
14
16
15
struct Generic <: AddressSpace end
17
16
struct Global <: AddressSpace end
26
25
# Device pointer
27
26
#
28
27
29
- struct DevicePtr{T,A}
30
- ptr :: Ptr{T }
28
+ """
29
+ DevicePtr{T,A }
31
30
32
- # inner constructors, fully parameterized
33
- DevicePtr {T,A} (ptr :: Ptr{T} ) where {T,A <: AddressSpace } = new (ptr)
34
- end
35
-
36
- # outer constructors, partially parameterized
37
- DevicePtr {T} (ptr :: Ptr{T} ) where {T} = DevicePtr {T,AS.Generic} (ptr)
31
+ A memory address that refers to data of type `T` that is accessible from the GPU. It is the
32
+ on-device counterpart of `CUDAdrv.CuPtr`, additionally keeping track of the address space
33
+ `A` where the data resides (shared, global, constant, etc). This information is used to
34
+ provide optimized implementations of operations such as `unsafe_load` and `unsafe_store!`.
35
+ """
36
+ DevicePtr
38
37
39
- # outer constructors, non-parameterized
40
- DevicePtr (ptr:: Ptr{T} ) where {T} = DevicePtr {T,AS.Generic} (ptr)
38
+ if sizeof (Ptr{Cvoid}) == 8
39
+ primitive type DevicePtr{T,A} 64 end
40
+ else
41
+ primitive type DevicePtr{T,A} 32 end
42
+ end
41
43
42
- Base. show (io:: IO , dp:: DevicePtr{T,AS} ) where {T,AS} =
43
- print (io, AS. name. name, " Device" , pointer (dp))
44
+ # constructors
45
+ DevicePtr {T,A} (x:: Union{Int,UInt,CuPtr,DevicePtr} ) where {T,A<: AddressSpace } = Base. bitcast (DevicePtr{T,A}, x)
46
+ DevicePtr {T} (ptr:: CuPtr{T} ) where {T} = DevicePtr {T,AS.Generic} (ptr)
47
+ DevicePtr (ptr:: CuPtr{T} ) where {T} = DevicePtr {T,AS.Generic} (ptr)
44
48
45
49
46
50
# # getters
47
51
48
- Base. pointer (p:: DevicePtr ) = p. ptr
49
-
50
52
Base. eltype (:: Type{<:DevicePtr{T}} ) where {T} = T
51
53
52
54
addrspace (x:: DevicePtr ) = addrspace (typeof (x))
@@ -55,20 +57,23 @@ addrspace(::Type{DevicePtr{T,A}}) where {T,A} = A
55
57
56
58
# # conversions
57
59
58
- # between regular and device pointers
59
- # # simple conversions disallowed
60
- Base. convert (:: Type{Ptr{T}} , p:: DevicePtr{T} ) where {T} = throw (InexactError (:convert , Ptr{T}, p))
61
- Base. convert (:: Type{<:DevicePtr{T}} , p:: Ptr{T} ) where {T} = throw (InexactError (:convert , DevicePtr{T}, p))
62
- # # unsafe ones are allowed
63
- Base. unsafe_convert (:: Type{Ptr{T}} , p:: DevicePtr{T} ) where {T} = pointer (p)
60
+ # to and from integers
61
+ # # pointer to integer
62
+ Base. convert (:: Type{T} , x:: DevicePtr ) where {T<: Integer } = T (UInt (x))
63
+ # # integer to pointer
64
+ Base. convert (:: Type{DevicePtr{T,A}} , x:: Union{Int,UInt} ) where {T,A<: AddressSpace } = DevicePtr {T,A} (x)
65
+ Int (x:: DevicePtr ) = Base. bitcast (Int, x)
66
+ UInt (x:: DevicePtr ) = Base. bitcast (UInt, x)
64
67
65
- # defer conversions to DevicePtr to unsafe_convert
66
- Base. cconvert (:: Type{<:DevicePtr} , x) = x
68
+ # between host and device pointers
69
+ Base. convert (:: Type{CuPtr{T}} , p:: DevicePtr ) where {T} = Base. bitcast (CuPtr{T}, p)
70
+ Base. convert (:: Type{DevicePtr{T,A}} , p:: CuPtr ) where {T,A<: AddressSpace } = Base. bitcast (DevicePtr{T,A}, p)
71
+ Base. convert (:: Type{DevicePtr{T}} , p:: CuPtr ) where {T} = Base. bitcast (DevicePtr{T,AS. Generic}, p)
67
72
68
73
# between device pointers
69
- Base. convert (:: Type{<:DevicePtr} , p:: DevicePtr ) = throw (InexactError ( : convert, DevicePtr, p ))
74
+ Base. convert (:: Type{<:DevicePtr} , p:: DevicePtr ) = throw (ArgumentError ( " cannot convert between incompatible device pointer types " ))
70
75
Base. convert (:: Type{DevicePtr{T,A}} , p:: DevicePtr{T,A} ) where {T,A} = p
71
- Base. unsafe_convert (:: Type{DevicePtr{T,A}} , p:: DevicePtr ) where {T,A} = DevicePtr {T,A} ( reinterpret (Ptr{T}, pointer (p)) )
76
+ Base. unsafe_convert (:: Type{DevicePtr{T,A}} , p:: DevicePtr ) where {T,A} = Base . bitcast ( DevicePtr{T,A}, p )
72
77
# # identical addrspaces
73
78
Base. convert (:: Type{DevicePtr{T,A}} , p:: DevicePtr{U,A} ) where {T,U,A} = Base. unsafe_convert (DevicePtr{T,A}, p)
74
79
# # convert to & from generic
@@ -78,19 +83,25 @@ Base.convert(::Type{DevicePtr{T,AS.Generic}}, p::DevicePtr{T,AS.Generic}) where
78
83
# # unspecified, preserve source addrspace
79
84
Base. convert (:: Type{DevicePtr{T}} , p:: DevicePtr{U,A} ) where {T,U,A} = Base. unsafe_convert (DevicePtr{T,A}, p)
80
85
86
+ # defer conversions to DevicePtr to unsafe_convert
87
+ Base. cconvert (:: Type{<:DevicePtr} , x) = x
88
+
81
89
82
90
# # limited pointer arithmetic & comparison
83
91
84
- Base.:(== )(a:: DevicePtr , b:: DevicePtr ) = pointer (a) == pointer (b) && addrspace (a) == addrspace (b)
92
+ # Two device pointers are `isequal` only when they are identical (`===`) and
+ # their address spaces match.
+ # Qualified with `Base.` so this adds a method to `Base.isequal`; unless the
+ # name is explicitly imported elsewhere in the file, an unqualified definition
+ # would not extend the Base function that generic code calls.
+ Base.isequal(x::DevicePtr, y::DevicePtr) = (x === y) && addrspace(x) == addrspace(y)
93
+ # Ordering is only defined between device pointers with the same element type
+ # and address space; it defers to `<` on the two pointers.
+ # Qualified with `Base.` so this adds a method to `Base.isless`; unless the
+ # name is explicitly imported elsewhere in the file, an unqualified definition
+ # would not extend the Base function that generic code calls.
+ Base.isless(x::DevicePtr{T,A}, y::DevicePtr{T,A}) where {T,A<:AddressSpace} = x < y
85
94
86
- Base. isless (x:: DevicePtr , y:: DevicePtr ) = Base. isless (pointer (x), pointer (y))
87
- Base.:(- )(x:: DevicePtr , y:: DevicePtr ) = pointer (x) - pointer (y)
95
+ Base.:(== )(x:: DevicePtr , y:: DevicePtr ) = UInt (x) == UInt (y) && addrspace (x) == addrspace (y)
96
+ Base.:(< )(x:: DevicePtr , y:: DevicePtr ) = UInt (x) < UInt (y)
97
+ Base.:(- )(x:: DevicePtr , y:: DevicePtr ) = UInt (x) - UInt (y)
88
98
89
- Base.:(+ )(x:: DevicePtr{T,A} , y:: Integer ) where {T,A} = DevicePtr {T,A} ( pointer (x) + y )
90
- Base.:(- )(x:: DevicePtr{T,A} , y:: Integer ) where {T,A} = DevicePtr {T,A} ( pointer (x) - y )
99
+ # Pointer plus integer offset; the result has the same `DevicePtr` type as `x`.
+ function Base.:(+)(x::DevicePtr, y::Integer)
+     # Reduce the offset to a `UInt` before handing it to `Base.add_ptr`.
+     offset = (y % UInt) % UInt
+     return oftype(x, Base.add_ptr(UInt(x), offset))
+ end
100
+ # Pointer minus integer offset; the result has the same `DevicePtr` type as `x`.
+ function Base.:(-)(x::DevicePtr, y::Integer)
+     # Reduce the offset to a `UInt` before handing it to `Base.sub_ptr`.
+     offset = (y % UInt) % UInt
+     return oftype(x, Base.sub_ptr(UInt(x), offset))
+ end
91
101
Base.:(+ )(x:: Integer , y:: DevicePtr ) = y + x
92
102
93
103
104
+
94
105
# # memory operations
95
106
96
107
Base. convert (:: Type{Int} , :: Type{AS.Generic} ) = 0
@@ -121,7 +132,7 @@ tbaa_addrspace(as::Type{<:AddressSpace}) = tbaa_make_child(lowercase(String(as.n
121
132
eltyp = convert (LLVMType, T)
122
133
123
134
T_int = convert (LLVMType, Int)
124
- T_ptr = convert (LLVMType, Ptr{T })
135
+ T_ptr = convert (LLVMType, DevicePtr{T,A })
125
136
126
137
T_actual_ptr = LLVM. PointerType (eltyp)
127
138
@@ -148,15 +159,15 @@ tbaa_addrspace(as::Type{<:AddressSpace}) = tbaa_make_child(lowercase(String(as.n
148
159
ret! (builder, ld)
149
160
end
150
161
151
- call_function (llvm_f, T, Tuple{Ptr{T }, Int}, :((pointer (p) , Int (i- one (i)))))
162
+ call_function (llvm_f, T, Tuple{DevicePtr{T,A }, Int}, :((p , Int (i- one (i)))))
152
163
end
153
164
154
165
@generated function Base. unsafe_store! (p:: DevicePtr{T,A} , x, i:: Integer = 1 ,
155
166
:: Val{align} = Val (1 )) where {T,A,align}
156
167
eltyp = convert (LLVMType, T)
157
168
158
169
T_int = convert (LLVMType, Int)
159
- T_ptr = convert (LLVMType, Ptr{T })
170
+ T_ptr = convert (LLVMType, DevicePtr{T,A })
160
171
161
172
T_actual_ptr = LLVM. PointerType (eltyp)
162
173
184
195
ret! (builder)
185
196
end
186
197
187
- call_function (llvm_f, Cvoid, Tuple{Ptr{T}, T, Int}, :((pointer (p), convert (T,x), Int (i- one (i)))))
198
+ call_function (llvm_f, Cvoid, Tuple{DevicePtr{T,A}, T, Int},
199
+ :((p, convert (T,x), Int (i- one (i)))))
188
200
end
189
201
190
202
# # loading through the texture cache
@@ -215,7 +227,7 @@ const CachedLoadPointers = Union{Tuple(DevicePtr{T,AS.Global}
215
227
216
228
T_int = convert (LLVMType, Int)
217
229
T_int32 = LLVM. Int32Type (JuliaContext ())
218
- T_ptr = convert (LLVMType, Ptr{T })
230
+ T_ptr = convert (LLVMType, DevicePtr{T,AS . Global })
219
231
220
232
T_actual_ptr = LLVM. PointerType (eltyp)
221
233
T_actual_ptr_as = LLVM. PointerType (eltyp, convert (Int, AS. Global))
@@ -258,7 +270,7 @@ const CachedLoadPointers = Union{Tuple(DevicePtr{T,AS.Global}
258
270
ret! (builder, ld)
259
271
end
260
272
261
- call_function (llvm_f, T, Tuple{Ptr{T }, Int}, :((pointer (p) , Int (i- one (i)))))
273
+ call_function (llvm_f, T, Tuple{DevicePtr{T,AS . Global }, Int}, :((p , Int (i- one (i)))))
262
274
end
263
275
264
276
@inline unsafe_cached_load (p:: DevicePtr{T,AS.Global} , i:: Integer = 1 , args... ) where {T} =
0 commit comments