Skip to content
This repository was archived by the owner on May 27, 2021. It is now read-only.

Commit 01d968e

Browse files
committed
Try #334:
2 parents e687697 + 94c4b7c commit 01d968e

File tree

5 files changed

+63
-3
lines changed

5 files changed

+63
-3
lines changed

Project.toml

+4
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ uuid = "be33ccc6-a3ff-5ff2-a52e-74243cff1e17"
55
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
66
CUDAapi = "3895d2a7-ec45-59b8-82bb-cfc6a382f9b3"
77
CUDAdrv = "c5f51814-7f29-56b8-a69c-e4d8f6be1fde"
8+
Cassette = "7057c7e9-c182-5462-911a-8362d720325c"
89
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
910
LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
1011
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
@@ -20,3 +21,6 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2021

2122
[targets]
2223
test = ["Test", "BenchmarkTools", "SpecialFunctions"]
24+
25+
[compat]
26+
julia = ">= 1.1"

REQUIRE

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
julia 1.0
1+
julia 1.1
22
CUDAdrv 1.1
33
LLVM 0.9.14
44
CUDAapi 0.4.0

src/CUDAnative.jl

+2
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ end
2323
include("utils.jl")
2424

2525
# needs to be loaded _before_ the compiler infrastructure, because of generated functions
26+
isdevice() = false
2627
include(joinpath("device", "tools.jl"))
2728
include(joinpath("device", "pointer.jl"))
2829
include(joinpath("device", "array.jl"))
@@ -31,6 +32,7 @@ include(joinpath("device", "cuda_intrinsics.jl"))
3132
include(joinpath("device", "runtime_intrinsics.jl"))
3233

3334
include("compiler.jl")
35+
include("context.jl")
3436
include("execution.jl")
3537
include("reflection.jl")
3638

src/context.jl

+52
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
##
2+
# Implements contextual dispatch through Cassette.jl
3+
# Goals:
4+
# - Rewrite common CPU functions to appropriate GPU intrinsics
5+
#
6+
# TODO:
7+
# - error (erf, ...)
8+
# - pow
9+
# - min, max
10+
# - mod, rem
11+
# - gamma
12+
# - bessel
13+
# - distributions
14+
# - unsorted
15+
16+
using Cassette
17+
18+
function transform(ctx, ref)
19+
ci = ref.code_info
20+
noinline = any(@nospecialize(x) -> Core.Compiler.isexpr(x, :meta) && x.args[1] == :noinline, ci.code)
21+
if !noinline
22+
ci.inlineable = true
23+
end
24+
return ci
25+
end
26+
const InlinePass = Cassette.@pass transform
27+
28+
Cassette.@context CUDACtx
29+
const cudactx = Cassette.disablehooks(CUDACtx(pass = InlinePass))
30+
31+
Cassette.overdub(::CUDACtx, ::typeof(datatype_align), ::Type{T}) where {T} = datatype_align(T)
32+
Cassette.overdub(ctx::CUDACtx, ::typeof(isdevice)) = true
33+
34+
# libdevice.jl
35+
for f in (:cos, :cospi, :sin, :sinpi, :tan,
36+
:acos, :asin, :atan,
37+
:cosh, :sinh, :tanh,
38+
:acosh, :asinh, :atanh,
39+
:log, :log10, :log1p, :log2,
40+
:exp, :exp2, :exp10, :expm1, :ldexp,
41+
:isfinite, :isinf, :isnan,
42+
:signbit, :abs,
43+
:sqrt, :cbrt,
44+
:ceil, :floor,)
45+
@eval function Cassette.overdub(ctx::CUDACtx, ::typeof(Base.$f), x::Union{Float32, Float64})
46+
@Base._inline_meta
47+
return CUDAnative.$f(x)
48+
end
49+
end
50+
51+
contextualize(f::F) where F = (args...) -> Cassette.overdub(cudactx, f, args...)
52+

src/execution.jl

+4-2
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,8 @@ kernel to determine the launch configuration:
175175
GC.@preserve args begin
176176
kernel_args = cudaconvert.(args)
177177
kernel_tt = Tuple{Core.Typeof.(kernel_args)...}
178-
kernel = cufunction(f, kernel_tt; compilation_kwargs)
178+
kernel_f = contextualize(f)
179+
kernel = cufunction(kernel_f, kernel_tt; compilation_kwargs)
179180
kernel(kernel_args...; launch_kwargs)
180181
end
181182
"""
@@ -205,7 +206,8 @@ macro cuda(ex...)
205206
GC.@preserve $(vars...) begin
206207
local kernel_args = cudaconvert.(($(var_exprs...),))
207208
local kernel_tt = Tuple{Core.Typeof.(kernel_args)...}
208-
local kernel = cufunction($(esc(f)), kernel_tt; $(map(esc, compiler_kwargs)...))
209+
local kernel_f = contextualize($(esc(f)))
210+
local kernel = cufunction(kernel_f, kernel_tt; $(map(esc, compiler_kwargs)...))
209211
kernel(kernel_args...; $(map(esc, call_kwargs)...))
210212
end
211213
end)

0 commit comments

Comments
 (0)