Skip to content
This repository was archived by the owner on May 27, 2021. It is now read-only.

Commit 39c4b7a

Browse files
committed
Use Cassette for contextual dispatch.
1 parent a5eebb4 commit 39c4b7a

File tree

6 files changed

+90
-6
lines changed

6 files changed

+90
-6
lines changed

Manifest.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,11 @@ git-tree-sha1 = "1fce616fa0806c67c133eb1d2f68f0f1a7504665"
2626
uuid = "c5f51814-7f29-56b8-a69c-e4d8f6be1fde"
2727
version = "5.0.1"
2828

29+
[[Cassette]]
30+
git-tree-sha1 = "36bd4e0088652b0b2d25a03e531f0d04258feb78"
31+
uuid = "7057c7e9-c182-5462-911a-8362d720325c"
32+
version = "0.3.0"
33+
2934
[[DataStructures]]
3035
deps = ["InteractiveUtils", "OrderedCollections"]
3136
git-tree-sha1 = "b7720de347734f4716d1815b00ce5664ed6bbfd4"

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
77
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
88
CUDAapi = "3895d2a7-ec45-59b8-82bb-cfc6a382f9b3"
99
CUDAdrv = "c5f51814-7f29-56b8-a69c-e4d8f6be1fde"
10+
Cassette = "7057c7e9-c182-5462-911a-8362d720325c"
1011
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
1112
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
1213
LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"

src/CUDAnative.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ const ptxas = Ref{String}()
3434
include("utils.jl")
3535

3636
# needs to be loaded _before_ the compiler infrastructure, because of generated functions
37+
# Host-side definition: always returns `false`. When a call is executed under `CUDACtx`,
# the `Cassette.overdub` method for `typeof(isdevice)` in context.jl returns `true` instead.
isdevice() = false
3738
include("device/tools.jl")
3839
include("device/pointer.jl")
3940
include("device/array.jl")
@@ -44,6 +45,7 @@ include("device/runtime.jl")
4445
include("init.jl")
4546

4647
include("compiler.jl")
48+
include("context.jl")
4749
include("execution.jl")
4850
include("exceptions.jl")
4951
include("reflection.jl")

src/compiler/driver.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,11 +62,12 @@ function codegen(target::Symbol, job::CompilerJob;
6262
@timeit_debug to "validation" check_method(job)
6363

6464
@timeit_debug to "Julia front-end" begin
65+
f = contextualize(job.f)
6566

6667
# get the method instance
6768
world = typemax(UInt)
68-
meth = which(job.f, job.tt)
69-
sig = Base.signature_type(job.f, job.tt)::Type
69+
meth = which(f, job.tt)
70+
sig = Base.signature_type(f, job.tt)::Type
7071
(ti, env) = ccall(:jl_type_intersection_with_env, Any,
7172
(Any, Any), sig, meth.sig)::Core.SimpleVector
7273
if VERSION >= v"1.2.0-DEV.320"

src/context.jl

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
##
2+
# Implements contextual dispatch through Cassette.jl
3+
# Goals:
4+
# - Rewrite common CPU functions to appropriate GPU intrinsics
5+
#
6+
# TODO:
7+
# - error (erf, ...)
8+
# - pow
9+
# - min, max
10+
# - mod, rem
11+
# - gamma
12+
# - bessel
13+
# - distributions
14+
# - unsorted
15+
16+
using Cassette
17+
18+
"""
    transform(ctx, ref)

Cassette pass body: take the code info stored in `ref.code_info`, mark it as inlineable
unless one of its statements is a `:meta` expression whose first argument is `:noinline`,
reset `ssavaluetypes` to the number of statements, and return the (mutated) code info.

`ctx` is accepted to satisfy the pass signature but is not used.
"""
function transform(ctx, ref)
    src = ref.code_info

    # look for an explicit `@noinline` marker among the statements
    has_noinline = false
    for stmt in src.code
        if Core.Compiler.isexpr(stmt, :meta) && stmt.args[1] == :noinline
            has_noinline = true
            break
        end
    end
    src.inlineable = !has_noinline

    src.ssavaluetypes = length(src.code)
    # Core.Compiler.validate_code(src)
    return src
end
30+
31+
# Cassette pass built from `transform`; it is attached to the context instance created
# below through the `pass = InlinePass` keyword.
const InlinePass = Cassette.@pass transform
32+
33+
# Declare the Cassette context type `CUDACtx`; every `Cassette.overdub` method in this
# file dispatches on it.
Cassette.@context CUDACtx
34+
# Shared context instance used by `contextualize`: a `CUDACtx` carrying `InlinePass`,
# passed through `Cassette.disablehooks`.
const cudactx = Cassette.disablehooks(CUDACtx(pass = InlinePass))
35+
36+
###
37+
# Cassette fixes
38+
###
39+
40+
# kwfunc fix
41+
# Do not recurse into `Core.kwfunc`; call it directly and hand back its result.
function Cassette.overdub(::CUDACtx, ::typeof(Core.kwfunc), f)
    return Core.kwfunc(f)
end
42+
43+
# the functions below are marked `@pure` and by rewriting them we hide that from
44+
# inference so we leave them alone (see https://github.com/jrevels/Cassette.jl/issues/108).
45+
# Pass-through methods: each listed Base predicate is called directly under `CUDACtx`
# instead of being overdubbed recursively.
for f in (:isimmutable, :isstructtype, :isprimitivetype, :isbitstype, :isbits)
    @eval begin
        @inline Cassette.overdub(::CUDACtx, ::typeof(Base.$f), x) = return Base.$f(x)
    end
end
50+
51+
# Call `datatype_align` directly on the type argument rather than overdubbing into it.
@inline function Cassette.overdub(::CUDACtx, ::typeof(datatype_align), ::Type{T}) where {T}
    return datatype_align(T)
end
52+
53+
###
54+
# Rewrite functions
55+
###
56+
# Under `CUDACtx`, `isdevice()` evaluates to `true`.
Cassette.overdub(::CUDACtx, ::typeof(isdevice)) = true
57+
58+
# libdevice.jl
59+
# Redirect unary Base math functions called with a `Float32` or `Float64` argument to the
# CUDAnative function of the same name when running under `CUDACtx`.
#
# Only single-argument functions belong in this list, because the generated method takes
# exactly one floating-point argument. `ldexp` used to be listed here, but `Base.ldexp`
# takes a value and an exponent, so the generated one-argument method could never match
# a valid call.
# TODO(review): add a dedicated two-argument `ldexp` overdub once the matching
# `CUDAnative.ldexp` signature has been confirmed.
for f in (:cos, :cospi, :sin, :sinpi, :tan,
          :acos, :asin, :atan,
          :cosh, :sinh, :tanh,
          :acosh, :asinh, :atanh,
          :log, :log10, :log1p, :log2,
          :exp, :exp2, :exp10, :expm1,
          :isfinite, :isinf, :isnan,
          :signbit, :abs,
          :sqrt, :cbrt,
          :ceil, :floor,)
    @eval function Cassette.overdub(ctx::CUDACtx, ::typeof(Base.$f), x::Union{Float32, Float64})
        @Base._inline_meta
        return CUDAnative.$f(x)
    end
end
74+
75+
"""
    contextualize(f)

Return a closure that forwards its arguments to `Cassette.overdub(cudactx, f, args...)`,
i.e. runs `f` under the shared `CUDACtx` instance.
"""
function contextualize(f::F) where F
    return (args...) -> Cassette.overdub(cudactx, f, args...)
end

test/device/execution.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -463,7 +463,7 @@ end
463463
val_dev = CuArray(val)
464464
cuda_ptr = pointer(val_dev)
465465
ptr = CUDAnative.DevicePtr{Int}(cuda_ptr)
466-
for i in (1, 10, 20, 35)
466+
for i in (1, 10, 20, 32)
467467
variables = ('a':'z'..., 'A':'Z'...)
468468
params = [Symbol(variables[j]) for j in 1:i]
469469
# generate a kernel
@@ -553,11 +553,11 @@ let (code, out, err) = julia_script(script, `-g2`)
553553
@test occursin("ERROR: KernelException: exception thrown during kernel execution on device", err)
554554
@test occursin("ERROR: a exception was thrown during kernel execution", out)
555555
if VERSION < v"1.3.0-DEV.270"
556-
@test occursin("[1] Type at float.jl", out)
556+
@test occursin(r"\[.\] Type at float.jl", out)
557557
else
558-
@test occursin("[1] Int64 at float.jl", out)
558+
@test occursin(r"\[.\] Int64 at float.jl", out)
559559
end
560-
@test occursin("[2] kernel at none:2", out)
560+
@test occursin(r"\[.\] kernel at none:2", out)
561561
end
562562

563563
end

0 commit comments

Comments
 (0)