
Commit 322d8c4

Use contextual dispatch for device functions.

1 parent fe62319 · commit 322d8c4

17 files changed, +290 −516 lines
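
In practice, this commit lets device code call plain Base functions and have the GPU compiler substitute the libdevice overrides during codegen, instead of rewriting call sites through `cufunc`/`@cufunc`. A minimal sketch of the user-visible effect, assuming a working CUDA setup (not part of the diff):

    using CUDA

    function kernel!(y, x)
        i = threadIdx().x
        @inbounds y[i] = sin(x[i])   # written as Base.sin; the device override is picked at GPU codegen time
        return nothing
    end

    x = CUDA.rand(Float32, 4)
    y = similar(x)
    @cuda threads=length(x) kernel!(y, x)
    @assert Array(y) ≈ sin.(Array(x))   # host reference computed with Base.sin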

Manifest.toml (+21 −2)

@@ -83,14 +83,21 @@ version = "6.2.0"
 
 [[GPUCompiler]]
 deps = ["DataStructures", "ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "Scratch", "Serialization", "TimerOutputs", "UUIDs"]
-git-tree-sha1 = "ef2839b063e158672583b9c09d2cf4876a8d3d55"
+git-tree-sha1 = "6e74ec73289b4db63d0dfff67cbfbd06042e331f"
+repo-rev = "tb/cache_spoof"
+repo-url = "https://github.com/JuliaGPU/GPUCompiler.jl.git"
 uuid = "61eb1bfa-7361-4325-ad38-22787b887f55"
-version = "0.10.0"
+version = "0.10.1"
 
 [[InteractiveUtils]]
 deps = ["Markdown"]
 uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
 
+[[JLLWrappers]]
+git-tree-sha1 = "a431f5f2ca3f4feef3bd7a5e94b8b8d4f2f647a0"
+uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210"
+version = "1.2.0"
+
 [[LLVM]]
 deps = ["CEnum", "Libdl", "Printf", "Unicode"]
 git-tree-sha1 = "b616937c31337576360cb9fb872ec7633af7b194"
@@ -162,6 +169,12 @@ version = "0.7.14"
 [[NetworkOptions]]
 uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908"
 
+[[OpenSpecFun_jll]]
+deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"]
+git-tree-sha1 = "9db77584158d0ab52307f8c04f8e7c08ca76b5b3"
+uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e"
+version = "0.5.3+4"
+
 [[OrderedCollections]]
 git-tree-sha1 = "4fa2ba51070ec13fcc7517db714445b4ab986bdf"
 uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
@@ -217,6 +230,12 @@ uuid = "6462fe0b-24de-5631-8697-dd941f90decc"
 deps = ["LinearAlgebra", "Random"]
 uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
 
+[[SpecialFunctions]]
+deps = ["ChainRulesCore", "OpenSpecFun_jll"]
+git-tree-sha1 = "5919936c0e92cff40e57d0ddf0ceb667d42e5902"
+uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
+version = "1.3.0"
+
 [[Statistics]]
 deps = ["LinearAlgebra", "SparseArrays"]
 uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

Project.toml (+1)

@@ -25,6 +25,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
 Requires = "ae029012-a4dd-5104-9daa-d747884805df"
 SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
+SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
 

src/CUDA.jl (+13)

@@ -18,6 +18,19 @@ using BFloat16s
 
 using Memoize
 
+using ExprTools
+
+
+##
+
+const ci_cache = GPUCompiler.CodeCache()
+
+@static if VERSION >= v"1.7-"
+    Base.Experimental.@MethodTable(method_table)
+else
+    const method_table = nothing
+end
+
 
 ## source code includes
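
The two globals added here are the backbone of the new mechanism: `ci_cache` keeps GPU-side inference results out of Julia's native method cache, and `method_table` is the overlay table (Julia 1.7+) that holds the device overrides; on older versions it is `nothing` and overrides are collected differently (see src/device/intrinsics.jl further down). A quick way to poke at them from the REPL, assuming this branch is loaded (illustrative only, not part of the diff):

    using CUDA, GPUCompiler

    CUDA.ci_cache isa GPUCompiler.CodeCache                   # true
    CUDA.method_table isa Union{Nothing, Core.MethodTable}    # true on either Julia version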

src/accumulate.jl (−2)

@@ -134,8 +134,6 @@ function scan!(f::Function, output::AnyCuArray{T}, input::AnyCuArray;
     dims > ndims(input) && return copyto!(output, input)
     isempty(inds_t[dims]) && return output
 
-    f = cufunc(f)
-
     # iteration domain across the main dimension
     Rdim = CartesianIndices((size(input, dims),))

src/broadcast.jl (+3 −96)

@@ -14,99 +14,6 @@ Base.similar(bc::Broadcasted{CuArrayStyle{N}}, ::Type{T}) where {N,T} =
 Base.similar(bc::Broadcasted{CuArrayStyle{N}}, ::Type{T}, dims) where {N,T} =
     CuArray{T}(undef, dims)
 
-
-## replace base functions with libdevice alternatives
-
-cufunc(f) = f
-cufunc(::Type{T}) where T = (x...) -> T(x...) # broadcasting type ctors isn't GPU compatible
-
-Broadcast.broadcasted(::CuArrayStyle{N}, f, args...) where {N} =
-    Broadcasted{CuArrayStyle{N}}(cufunc(f), args, nothing)
-
-const device_intrinsics = :[
-    cos, cospi, sin, sinpi, tan, acos, asin, atan,
-    cosh, sinh, tanh, acosh, asinh, atanh, angle,
-    log, log10, log1p, log2, logb, ilogb,
-    exp, exp2, exp10, expm1, ldexp,
-    erf, erfinv, erfc, erfcinv, erfcx,
-    brev, clz, ffs, byte_perm, popc,
-    isfinite, isinf, isnan, nearbyint,
-    nextafter, signbit, copysign, abs,
-    sqrt, rsqrt, cbrt, rcbrt, pow,
-    ceil, floor, saturate,
-    lgamma, tgamma,
-    j0, j1, jn, y0, y1, yn,
-    normcdf, normcdfinv, hypot,
-    fma, sad, dim, mul24, mul64hi, hadd, rhadd, scalbn].args
-
-for f in device_intrinsics
-    isdefined(Base, f) || continue
-    @eval cufunc(::typeof(Base.$f)) = $f
-end
-
-# broadcast ^
-
-culiteral_pow(::typeof(^), x::T, ::Val{0}) where {T<:Real} = one(x)
-culiteral_pow(::typeof(^), x::T, ::Val{1}) where {T<:Real} = x
-culiteral_pow(::typeof(^), x::T, ::Val{2}) where {T<:Real} = x * x
-culiteral_pow(::typeof(^), x::T, ::Val{3}) where {T<:Real} = x * x * x
-culiteral_pow(::typeof(^), x::T, ::Val{p}) where {T<:Real,p} = pow(x, Int32(p))
-
-cufunc(::typeof(Base.literal_pow)) = culiteral_pow
-cufunc(::typeof(Base.:(^))) = pow
-
-using MacroTools
-
-const _cufuncs = [copy(device_intrinsics); :^]
-cufuncs() = (global _cufuncs; _cufuncs)
-
-_cuint(x::Int) = Int32(x)
-_cuint(x::Expr) = x.head == :call && x.args[1] == :Int32 && x.args[2] isa Int ? Int32(x.args[2]) : x
-_cuint(x) = x
-
-function _cupowliteral(x::Expr)
-    if x.head == :call && x.args[1] == :(CUDA.cufunc(^)) && x.args[3] isa Int32
-        num = x.args[3]
-        if 0 <= num <= 3
-            sym = gensym(:x)
-            new_x = Expr(:block, :($sym = $(x.args[2])))
-
-            if iszero(num)
-                push!(new_x.args, :(one($sym)))
-            else
-                unroll = Expr(:call, :*)
-                for x = one(num):num
-                    push!(unroll.args, sym)
-                end
-                push!(new_x.args, unroll)
-            end
-
-            x = new_x
-        end
-    end
-    x
-end
-_cupowliteral(x) = x
-
-function replace_device(ex)
-    global _cufuncs
-    MacroTools.postwalk(ex) do x
-        x = x in _cufuncs ? :(CUDA.cufunc($x)) : x
-        x = _cuint(x)
-        x = _cupowliteral(x)
-        x
-    end
-end
-
-macro cufunc(ex)
-    global _cufuncs
-    def = MacroTools.splitdef(ex)
-    f = def[:name]
-    def[:name] = Symbol(:cu, f)
-    def[:body] = replace_device(def[:body])
-    push!(_cufuncs, f)
-    quote
-        $(esc(MacroTools.combinedef(def)))
-        CUDA.cufunc(::typeof($(esc(f)))) = $(esc(def[:name]))
-    end
-end
+# broadcasting type ctors isn't GPU compatible
+Broadcast.broadcasted(::CuArrayStyle{N}, f::Type{T}, args...) where {N, T} =
+    Broadcasted{CuArrayStyle{N}}((x...) -> T(x...), args, nothing)
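
With the `cufunc` machinery gone, broadcasting over CuArrays passes the original functions through unchanged and relies on contextual dispatch during GPU compilation; only type constructors are still wrapped, as the added method shows. A small sketch of the resulting behavior, assuming a working CUDA setup (not part of the diff):

    using CUDA

    x = CUDA.rand(Float32, 8)
    y = sin.(x)          # broadcasts Base.sin itself; the device override is applied at codegen time
    z = Float64.(x)      # a type constructor: wrapped in (x...) -> Float64(x...) by the method above
    @assert eltype(z) == Float64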

src/compiler/gpucompiler.jl (+4)

@@ -39,3 +39,7 @@ function GPUCompiler.link_libraries!(job::CUDACompilerJob, mod::LLVM.Module,
                    job, mod, undefined_fns)
     link_libdevice!(mod, job.target.cap, undefined_fns)
 end
+
+GPUCompiler.ci_cache(::CUDACompilerJob) = ci_cache
+
+GPUCompiler.method_table(::CUDACompilerJob) = method_table

src/device/intrinsics.jl (+29)

@@ -1,5 +1,34 @@
 # wrappers for functionality provided by the CUDA toolkit
 
+const overrides = quote end
+
+macro device_override(ex)
+    code = quote
+        GPUCompiler.@override($method_table, $ex)
+    end
+    if VERSION >= v"1.7-"
+        return esc(code)
+    else
+        push!(overrides.args, code)
+        return
+    end
+end
+
+macro device_function(ex)
+    ex = macroexpand(__module__, ex)
+    def = splitdef(ex)
+
+    # generate a function that errors
+    def[:body] = quote
+        error("This function is not intended for use on the CPU")
+    end
+
+    esc(quote
+        $(combinedef(def))
+        @device_override $ex
+    end)
+end
+
 # extensions to the C language
 include("intrinsics/memory_shared.jl")
 include("intrinsics/indexing.jl")
