Skip to content

Commit e59ab6f

Browse files
wsmoses and vchuravy authored
Enzyme support older versions (#537)
Co-authored-by: Valentin Churavy <[email protected]>
1 parent dd3044a commit e59ab6f

File tree

6 files changed

+671
-326
lines changed

6 files changed

+671
-326
lines changed

Diff for: .github/workflows/ci.yml

+1-4
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ jobs:
4848
show-versioninfo: true
4949
- uses: julia-actions/cache@v2
5050
- run: |
51-
julia -e '@static if VERSION >= v"1.10"
51+
julia -e '
5252
using Pkg
5353
withenv("JULIA_PKG_PRECOMPILE_AUTO" => 0) do
5454
Pkg.activate("test")
@@ -61,11 +61,8 @@ jobs:
6161
try
6262
Pkg.develop([PackageSpec("Enzyme"), PackageSpec("EnzymeCore")])
6363
catch err
64-
@error "Could not install Enzyme" exception=(err,catch_backtrace())
65-
exit(3)
6664
end
6765
end
68-
end
6966
'
7067
- uses: julia-actions/julia-buildpkg@v1
7168
- uses: julia-actions/julia-runtest@v1

Diff for: Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ UnsafeAtomicsLLVM = "d80eeb9a-aca5-4d75-85e5-170c8b632249"
2121
[compat]
2222
Adapt = "0.4, 1.0, 2.0, 3.0, 4"
2323
Atomix = "0.1"
24-
EnzymeCore = "0.8.1"
24+
EnzymeCore = "0.7, 0.8.1"
2525
InteractiveUtils = "1.6"
2626
LinearAlgebra = "1.6"
2727
MacroTools = "0.5"

Diff for: ext/EnzymeCore07Ext.jl

+342
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,342 @@
1+
# https://github.com/EnzymeAD/Enzyme.jl/issues/1516
2+
# On the CPU `autodiff_deferred` can deadlock.
3+
# Hence a specialized CPU version
4+
# Forward-mode kernel body for the CPU backend.
# Differentiates `f` through the non-deferred `EnzymeCore.autodiff` entry point,
# because `autodiff_deferred` can deadlock on the CPU
# (https://github.com/EnzymeAD/Enzyme.jl/issues/1516).
# `f` and `ctx` are passed as `Const`; `args` are forwarded unchanged.
# Always returns `nothing`.
function cpu_fwd(ctx, f, args...)
    const_f = Const(f)
    const_ctx = Const(ctx)
    EnzymeCore.autodiff(Forward, const_f, Const{Nothing}, const_ctx, args...)
    return nothing
end
8+
9+
# Forward-mode kernel body for GPU backends.
# Same shape as the CPU variant, but it goes through
# `EnzymeCore.autodiff_deferred`. `f` and `ctx` are passed as `Const`;
# `args` are forwarded unchanged. Always returns `nothing`.
function gpu_fwd(ctx, f, args...)
    const_f = Const(f)
    const_ctx = Const(ctx)
    EnzymeCore.autodiff_deferred(Forward, const_f, Const{Nothing}, const_ctx, args...)
    return nothing
end
13+
14+
# Forward-mode rule for launching a CPU kernel whose return is `Const{Nothing}`.
# A new kernel is derived from the original with `similar(kernel, cpu_fwd)` and
# launched with the original kernel function (`kernel.f`) as its first argument,
# followed by the annotated `args`. `ndrange` and `workgroupsize` are forwarded
# to the launch unchanged. Returns whatever the launch returns.
function EnzymeRules.forward(
        func::Const{<:Kernel{CPU}},
        ::Type{Const{Nothing}},
        args...;
        ndrange = nothing,
        workgroupsize = nothing,
    )
    primal_kernel = func.val
    tangent_kernel = similar(primal_kernel, cpu_fwd)
    primal_f = primal_kernel.f

    return tangent_kernel(
        primal_f,
        args...;
        ndrange = ndrange,
        workgroupsize = workgroupsize,
    )
end
27+
28+
# Forward-mode rule for launching a GPU kernel whose return is `Const{Nothing}`.
# A new kernel is derived from the original with `similar(kernel, gpu_fwd)` and
# launched with the original kernel function (`kernel.f`) as its first argument,
# followed by the annotated `args`. `ndrange` and `workgroupsize` are forwarded
# to the launch unchanged. Returns whatever the launch returns.
function EnzymeRules.forward(
        func::Const{<:Kernel{<:GPU}},
        ::Type{Const{Nothing}},
        args...;
        ndrange = nothing,
        workgroupsize = nothing,
    )
    primal_kernel = func.val
    tangent_kernel = similar(primal_kernel, gpu_fwd)
    primal_f = primal_kernel.f

    return tangent_kernel(
        primal_f,
        args...;
        ndrange = ndrange,
        workgroupsize = workgroupsize,
    )
end
41+
42+
# CPU variant: build the kernel context via `mkcontext`, using the first block
# of `iterspace` and passing `ndrange`, `iterspace` and `dynamic` along.
function _enzyme_mkcontext(kernel::Kernel{CPU}, ndrange, iterspace, dynamic)
    first_block = first(blocks(iterspace))
    return mkcontext(kernel, first_block, ndrange, iterspace, dynamic)
end
44+
# GPU variant: build the kernel context via `mkcontext` from `ndrange` and
# `iterspace` only; `dynamic` is accepted for a uniform signature but not used.
function _enzyme_mkcontext(kernel::Kernel{<:GPU}, ndrange, iterspace, dynamic)
    return mkcontext(kernel, ndrange, iterspace)
end
46+
47+
# CPU variant: wrap the tape triple `(subtape, arg_refs, tape_type)` in an
# `AugmentedReturn` with no primal and no shadow. The tape type parameter is
# spelled out as `Tuple{Array, typeof(arg_refs), typeof(tape_type)}`.
function _augmented_return(::Kernel{CPU}, subtape, arg_refs, tape_type)
    TapeTuple = Tuple{Array, typeof(arg_refs), typeof(tape_type)}
    tape = (subtape, arg_refs, tape_type)
    return AugmentedReturn{Nothing, Nothing, TapeTuple}(nothing, nothing, tape)
end
53+
# GPU variant: wrap the tape triple `(subtape, arg_refs, tape_type)` in an
# `AugmentedReturn` with no primal and no shadow. The tape type parameter is
# deliberately `Any` here.
function _augmented_return(::Kernel{<:GPU}, subtape, arg_refs, tape_type)
    tape = (subtape, arg_refs, tape_type)
    return AugmentedReturn{Nothing, Nothing, Any}(nothing, nothing, tape)
end
55+
56+
# CPU variant: determine the tape type, allocate tape storage and derive the
# augmented-forward kernel.
#
# The tape type is queried from `EnzymeCore.tape_type` for the split reverse
# mode, using `FT` as the function annotation, a `Const{Nothing}` return, the
# context type wrapped in `Const`, and the types of `args2`.
# Storage is an uninitialized `Array{TapeType}` with the same size as
# `blocks(iterspace)`. `ndrange` is not used by this method.
#
# Returns `(TapeType, subtape, aug_kernel)`.
function _create_tape_kernel(
        kernel::Kernel{CPU},
        ModifiedBetween,
        FT,
        ctxTy,
        ndrange,
        iterspace,
        args2...,
    )
    mode = ReverseSplitModified(ReverseSplitWithPrimal, ModifiedBetween)
    arg_types = map(Core.Typeof, args2)
    TapeType = EnzymeCore.tape_type(mode, FT, Const{Nothing}, Const{ctxTy}, arg_types...)

    block_dims = size(blocks(iterspace))
    subtape = Array{TapeType}(undef, block_dims)

    aug_kernel = similar(kernel, cpu_aug_fwd)
    return TapeType, subtape, aug_kernel
end
76+
77+
# GPU variant: determine the tape type, allocate tape storage and derive the
# augmented-forward kernel.
#
# Returns `(TapeType, subtape, aug_kernel)`. `iterspace` is not used by this
# method.
function _create_tape_kernel(
        kernel::Kernel{<:GPU},
        ModifiedBetween,
        FT,
        ctxTy,
        ndrange,
        iterspace,
        args2...,
    )
    # Querying the tape type needs a correct compilation job, which in turn
    # needs the device-side representation of the arguments. Converting them
    # here duplicates work the later `aug_kernel` launch does again.
    dev_args2 = map(a -> argconvert(kernel, a), args2)
    dev_TT = map(Core.Typeof, dev_args2)

    job = EnzymeCore.compiler_job_from_backend(
        backend(kernel),
        typeof(() -> return),
        Tuple{},
    )
    mode = ReverseSplitModified(ReverseSplitWithPrimal, ModifiedBetween)
    TapeType = EnzymeCore.tape_type(
        job,
        mode,
        FT,
        Const{Nothing},
        Const{ctxTy},
        dev_TT...,
    )

    # One tape slot per element of the ndrange (i.e. per thread).
    slot_count = prod(ndrange)
    subtape = allocate(backend(kernel), TapeType, slot_count)

    aug_kernel = similar(kernel, gpu_aug_fwd)
    return TapeType, subtape, aug_kernel
end
110+
111+
# CPU variant: derive the reverse-pass kernel, with `cpu_rev` as its body.
function _create_rev_kernel(kernel::Kernel{CPU})
    return similar(kernel, cpu_rev)
end
112+
# GPU variant: derive the reverse-pass kernel, with `gpu_rev` as its body.
function _create_rev_kernel(kernel::Kernel{<:GPU})
    return similar(kernel, gpu_rev)
end
113+
114+
# Augmented forward pass that runs as the body of the CPU kernel.
#
# Builds the split reverse-mode thunk pair for `f` with
# `EnzymeCore.autodiff_thunk`, runs the forward half, and stores the first
# element of its result in `subtape` at the group index of `ctx`.
# The return annotation is `Const{Nothing}` because `f` returns nothing.
# `TapeType` is not used by this method. Always returns `nothing`.
function cpu_aug_fwd(
        ctx,
        f::FT,
        ::Val{ModifiedBetween},
        subtape,
        ::Val{TapeType},
        args...,
    ) where {ModifiedBetween, FT, TapeType}
    mode = ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween))
    fwd_thunk, _ = EnzymeCore.autodiff_thunk(
        mode,
        Const{Core.Typeof(f)},
        Const{Nothing},
        Const{Core.Typeof(ctx)},
        map(Core.Typeof, args)...,
    )

    # On the CPU the kernel function is applied per block, so the tape is
    # indexed by block. The `CartesianIndex(1, 1)` argument is a dummy ("fake")
    # value.
    block_index = __index_Group_Cartesian(ctx, CartesianIndex(1, 1))
    fwd_result = fwd_thunk(Const(f), Const(ctx), args...)
    subtape[block_index] = fwd_result[1]
    return nothing
end
137+
138+
# Reverse pass that runs as the body of the CPU kernel.
#
# Rebuilds the split reverse-mode thunk pair for `f` with
# `EnzymeCore.autodiff_thunk`, fetches the tape entry stored for the group index
# of `ctx`, and runs the reverse half with that entry as its last argument.
# `TapeType` is not used by this method. Always returns `nothing`.
function cpu_rev(
        ctx,
        f::FT,
        ::Val{ModifiedBetween},
        subtape,
        ::Val{TapeType},
        args...,
    ) where {ModifiedBetween, FT, TapeType}
    mode = ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween))
    _, rev_thunk = EnzymeCore.autodiff_thunk(
        mode,
        Const{Core.Typeof(f)},
        Const{Nothing},
        Const{Core.Typeof(ctx)},
        map(Core.Typeof, args)...,
    )

    # The `CartesianIndex(1, 1)` argument is a dummy ("fake") value.
    block_index = __index_Group_Cartesian(ctx, CartesianIndex(1, 1))
    block_tape = subtape[block_index]
    rev_thunk(Const(f), Const(ctx), args..., block_tape)
    return nothing
end
158+
159+
# GPU support
160+
# Augmented forward pass that runs as the body of the GPU kernel.
#
# Builds the split reverse-mode thunk pair for `f` with
# `EnzymeCore.autodiff_deferred_thunk` (which also takes `TapeType`). For
# work-items with a valid index it runs the forward half and stores the first
# element of its result in `subtape` at the global linear index.
# The return annotation is `Const{Nothing}` because `f` returns nothing.
# Always returns `nothing`.
function gpu_aug_fwd(
        ctx,
        f::FT,
        ::Val{ModifiedBetween},
        subtape,
        ::Val{TapeType},
        args...,
    ) where {ModifiedBetween, FT, TapeType}
    mode = ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween))
    fwd_thunk, _ = EnzymeCore.autodiff_deferred_thunk(
        mode,
        TapeType,
        Const{Core.Typeof(f)},
        Const{Nothing},
        Const{Core.Typeof(ctx)},
        map(Core.Typeof, args)...,
    )

    # On the GPU the kernel function is applied per thread, so the tape is a
    # flat collection indexed by the global linear index.
    __validindex(ctx) || return nothing

    thread_index = __index_Global_Linear(ctx)
    fwd_result = fwd_thunk(Const(f), Const(ctx), args...)
    subtape[thread_index] = fwd_result[1]
    return nothing
end
186+
187+
# Reverse pass that runs as the body of the GPU kernel.
#
# Rebuilds the split reverse-mode thunk pair for `f` with
# `EnzymeCore.autodiff_deferred_thunk`. For work-items with a valid index it
# loads the tape entry at the global linear index and runs the reverse half
# with that entry as its last argument. Always returns `nothing`.
function gpu_rev(
        ctx,
        f::FT,
        ::Val{ModifiedBetween},
        subtape,
        ::Val{TapeType},
        args...,
    ) where {ModifiedBetween, FT, TapeType}
    # XXX: `TapeType` and the return annotation are passed as arguments to
    # `autodiff_deferred_thunk`.
    mode = ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween))
    _, rev_thunk = EnzymeCore.autodiff_deferred_thunk(
        mode,
        TapeType,
        Const{Core.Typeof(f)},
        Const{Nothing},
        Const{Core.Typeof(ctx)},
        map(Core.Typeof, args)...,
    )

    __validindex(ctx) || return nothing

    thread_index = __index_Global_Linear(ctx)
    thread_tape = subtape[thread_index]
    rev_thunk(Const(f), Const(ctx), args..., thread_tape)
    return nothing
end
211+
212+
# Augmented-primal rule for launching a kernel under split reverse mode.
#
# Steps:
#   1. Resolve the launch configuration and build a context, to obtain the
#      context type.
#   2. For every `Active` argument allocate a zeroed `Ref` that serves as its
#      shadow, and pass the argument on as `MixedDuplicated`. `Active`
#      arguments raise an error for GPU kernels.
#   3. Ask `_create_tape_kernel` for the tape type, tape storage and the
#      augmented-forward kernel, then launch that kernel.
#   4. Return an `AugmentedReturn` whose tape is `(subtape, arg_refs, TapeType)`.
function EnzymeRules.augmented_primal(
        config::Config,
        func::Const{<:Kernel},
        ::Type{Const{Nothing}},
        args::Vararg{Any, N};
        ndrange = nothing,
        workgroupsize = nothing,
    ) where {N}
    kernel = func.val
    f = kernel.f

    ndrange, workgroupsize, iterspace, dynamic =
        launch_config(kernel, ndrange, workgroupsize)
    ctx = _enzyme_mkcontext(kernel, ndrange, iterspace, dynamic)
    ctxTy = Core.Typeof(ctx) # CompilerMetadata{ndrange(kernel), Core.Typeof(dynamic)}

    # TODO autodiff_deferred on the func.val
    # A `false` is spliced in after the first overwritten flag; in the thunk
    # argument order that position corresponds to the context argument.
    ow = overwritten(config)
    ModifiedBetween = Val((ow[1], false, ow[2:end]...))

    FT = Const{Core.Typeof(f)}

    # Shadow storage for `Active` arguments; `nothing` for everything else.
    arg_refs = ntuple(Val(N)) do i
        Base.@_inline_meta
        arg = args[i]
        arg isa Active || return nothing
        func.val isa Kernel{<:GPU} &&
            error("Active kernel arguments not supported on GPU")
        return Ref(EnzymeCore.make_zero(arg.val))
    end

    # `Active` arguments are re-wrapped as `MixedDuplicated` with their shadow Ref.
    args2 = ntuple(Val(N)) do i
        Base.@_inline_meta
        arg = args[i]
        return arg isa Active ? MixedDuplicated(arg.val, arg_refs[i]) : arg
    end

    TapeType, subtape, aug_kernel =
        _create_tape_kernel(kernel, ModifiedBetween, FT, ctxTy, ndrange, iterspace, args2...)
    aug_kernel(
        f,
        ModifiedBetween,
        subtape,
        Val(TapeType),
        args2...;
        ndrange = ndrange,
        workgroupsize = workgroupsize,
    )

    # TODO the fact that ctxTy is type unstable means this is all type unstable.
    # Custom rules require a fixed return type, so the return is built by
    # `_augmented_return`, which pins the tape type parameter instead of
    # returning an AugmentedReturn{Nothing, Nothing, T} where T.
    return _augmented_return(kernel, subtape, arg_refs, TapeType)
end
269+
270+
# Reverse rule for launching a kernel under split reverse mode.
#
# `tape` is the triple `(subtape, arg_refs, tape_type)` produced by the matching
# augmented-primal rule. `Active` arguments are re-wrapped as `MixedDuplicated`
# with their shadow `Ref` from `arg_refs`, the reverse kernel is launched, and
# the backend is synchronized.
#
# Returns an `N`-tuple: the accumulated value read from the shadow `Ref` for
# each `Active` argument, `nothing` for every other argument.
function EnzymeRules.reverse(
        config::Config,
        func::Const{<:Kernel},
        ::Type{<:EnzymeCore.Annotation},
        tape,
        args::Vararg{Any, N};
        ndrange = nothing,
        workgroupsize = nothing,
    ) where {N}
    subtape, arg_refs, tape_type = tape

    args2 = ntuple(Val(N)) do i
        Base.@_inline_meta
        if args[i] isa Active
            MixedDuplicated(args[i].val, arg_refs[i])
        else
            args[i]
        end
    end

    kernel = func.val
    f = kernel.f

    # A `false` is spliced in after the first overwritten flag; in the thunk
    # argument order that position corresponds to the context argument.
    ModifiedBetween = Val((overwritten(config)[1], false, overwritten(config)[2:end]...))

    rev_kernel = _create_rev_kernel(kernel)
    rev_kernel(
        f,
        ModifiedBetween,
        subtape,
        Val(tape_type),
        args2...;
        ndrange,
        workgroupsize,
    )

    # Reverse synchronization right after the kernel launch. This must happen
    # before the shadow Refs are read below, so the results are only collected
    # once the reverse kernel is known to have finished.
    synchronize(backend(kernel))

    res = ntuple(Val(N)) do i
        Base.@_inline_meta
        if args[i] isa Active
            arg_refs[i][]
        else
            nothing
        end
    end
    return res
end
317+
318+
# Synchronize rules
319+
# TODO: Right now we do the synchronization as part of the kernel launch in the augmented primal
320+
# and reverse rules. This is not ideal, as we would want to launch the kernel in the reverse
321+
# synchronize rule and then synchronize where the launch was. However, with the current
322+
# kernel semantics this ensures correctness for now.
323+
# Augmented-primal rule for `synchronize`: unwrap the annotated backend,
# synchronize it, and return an `AugmentedReturn` with no primal, no shadow and
# no tape.
function EnzymeRules.augmented_primal(
        config::Config,
        func::Const{typeof(synchronize)},
        ::Type{Const{Nothing}},
        backend::T,
    ) where {T <: EnzymeCore.Annotation}
    device = backend.val
    synchronize(device)
    return AugmentedReturn(nothing, nothing, nothing)
end
332+
333+
# Reverse rule for `synchronize`: currently a no-op. Returns a one-element
# tuple holding `nothing`, i.e. no gradient for the single `backend` argument.
function EnzymeRules.reverse(
        config::Config,
        func::Const{typeof(synchronize)},
        ::Type{Const{Nothing}},
        tape,
        backend,
    )
    # noop for now
    backend_gradient = nothing
    return (backend_gradient,)
end

0 commit comments

Comments
 (0)