
Commit af65b30

fix: offload all patches to Optimisers
1 parent: 2e6d90b

File tree

ext/LuxReactantExt/LuxReactantExt.jl
ext/LuxReactantExt/patches.jl
src/helpers/optimizers.jl
src/helpers/training.jl

4 files changed: +16 -46 lines


ext/LuxReactantExt/LuxReactantExt.jl

+2 -1

@@ -2,7 +2,8 @@ module LuxReactantExt
 
 using Enzyme: Enzyme, Const
 using Optimisers: Optimisers
-using Reactant: Reactant, @compile, @code_hlo, AnyTracedRArray, TracedRArray, TracedRNumber
+using Reactant:
+    Reactant, @compile, @code_hlo, @jit, AnyTracedRArray, TracedRArray, TracedRNumber
 using ReactantCore: ReactantCore, @trace
 using Setfield: @set!
 using Static: True, False
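
The only functional change in this file is the new `@jit` import, which the rewritten `patches.jl` below relies on. As a rough, hypothetical illustration of what `Reactant.@jit` provides (the array and function here are placeholders, not from this commit):

using Reactant

# @jit traces the call, compiles it through XLA, and runs it on the
# device-backed arrays in a single step.
x = Reactant.to_rarray(rand(Float32, 4, 4))
y = @jit sum(abs2, x)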

ext/LuxReactantExt/patches.jl

+3 -16

@@ -3,20 +3,7 @@ Utils.vec(x::AnyTracedRArray) = Reactant.TracedUtils.materialize_traced_array(ve
 # XXX: Use PoolDims once EnzymeJAX supports stablehlo.reduce_window adjoint
 Lux.calculate_pool_dims(g::Lux.GlobalPoolMode, ::TracedRArray) = g
 
-# Optimisers AccumGrad
-function Optimisers.apply!(opt::Lux.ReactantCompatibleOptimisers.AccumGrad, state, x, dx)
-    accum_dx, counter = state
-    @trace if counter == 1
-        @. accum_dx = dx / opt.n
-    else
-        @. accum_dx += dx / opt.n
-    end
-    @trace if counter == opt.n
-        dx_final = dx
-        counter = 1
-    else
-        dx_final = zero.(dx)
-        counter += 1
-    end
-    return (accum_dx, counter), dx_final
+# Optimisers setup
+function Lux.Training.optimisers_setup_with_jit(opt, ps)
+    return @jit Optimisers.setup(opt, ps)
 end
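
The deleted `Optimisers.apply!` overload is the patch the commit title refers to: `AccumGrad` now lives upstream in Optimisers.jl (hence the `Optimisers.AccumGrad` references below), so the extension only needs to compile `Optimisers.setup` with `@jit`. A minimal sketch of what that helper amounts to, assuming parameters already converted with `Reactant.to_rarray` and an arbitrary rule such as `Momentum`:

using Optimisers, Reactant

ps = (; weight = Reactant.to_rarray(rand(Float32, 3, 3)))

# Compiling `setup` means the optimiser state (e.g. Momentum's velocity
# buffer) is created directly on the device rather than on the host.
st_opt = @jit Optimisers.setup(Optimisers.Momentum(0.01f0), ps)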

src/helpers/optimizers.jl

+5 -23

@@ -2,8 +2,6 @@
 # We can remove this once https://github.com/FluxML/Optimisers.jl/issues/205 is resolved.
 module ReactantCompatibleOptimisers
 
-using ConcreteStructs: @concrete
-using Functors: fmap
 using Optimisers: Optimisers, AbstractRule
 
 using ..Lux: Lux, Utils
@@ -21,32 +19,16 @@ function make_reactant_compatible(leaf::Optimisers.Leaf{<:AbstractRule})
     return Optimisers.Leaf(rule, state, leaf.frozen)
 end
 
-function make_reactant_compatible(opt::Optimisers.OptimiserChain, state)
-    res = make_reactant_compatible.(opt.opts, state)
-    new_opts = first.(res)
-    new_state = last.(res)
-    return Optimisers.OptimiserChain(new_opts...), new_state
+function make_reactant_compatible(opt::Optimisers.OptimiserChain)
+    return Optimisers.OptimiserChain(make_reactant_compatible.(opt.opts)...)
 end
 
-function make_reactant_compatible(opt::Optimisers.AbstractRule, state)
-    return (
-        Utils.to_rarray(opt; track_numbers = AbstractFloat),
-        Utils.to_rarray(state; track_numbers = AbstractFloat),
-    )
+function make_reactant_compatible(opt::Optimisers.AbstractRule)
+    return Utils.to_rarray(opt; track_numbers = AbstractFloat)
 end
 
 function make_reactant_compatible(opt::Optimisers.AccumGrad, state)
-    return (
-        AccumGrad(Utils.to_rarray(opt.n; track_numbers = Integer)),
-        Utils.to_rarray(state; track_numbers = Integer),
-    )
+    return Utils.to_rarray(opt.n; track_numbers = Integer)
 end
 
-@concrete struct AccumGrad <: AbstractRule
-    n
-end
-
-# XXX: the counter needs to match the client / device?
-Optimisers.init(::AccumGrad, x) = zero(x), Utils.to_rarray(1; track_numbers = Integer)
-
 end
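
With the state handling dropped, `make_reactant_compatible` is now a pure rule-to-rule transform: it recurses through an `OptimiserChain` and converts each rule's floating-point hyperparameters into Reactant-tracked numbers, so that e.g. changing a learning rate does not invalidate the compiled step. A hedged sketch of how this internal module might be exercised (assuming Reactant is loaded so `Utils.to_rarray` is active; the rules are illustrative):

using Lux, Optimisers, Reactant

# ClipGrad's threshold and Adam's learning rate become tracked numbers;
# the chain structure itself is preserved.
rule = OptimiserChain(ClipGrad(1.0f0), Adam(1.0f-3))
reactant_rule = Lux.ReactantCompatibleOptimisers.make_reactant_compatible(rule)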

src/helpers/training.jl

+6 -6

@@ -63,17 +63,17 @@ Constructor for [`TrainState`](@ref).
 [`TrainState`](@ref) object.
 """
 function TrainState(model::AbstractLuxLayer, ps, st, optimizer::Optimisers.AbstractRule)
-    st_opt = Optimisers.setup(optimizer, ps)
     if get_device_type(ps) <: ReactantDevice
-        st_opt = fmap(
-            ReactantCompatibleOptimisers.make_reactant_compatible,
-            st_opt;
-            exclude = Base.Fix2(isa, Optimisers.Leaf)
-        )
+        optimizer = ReactantCompatibleOptimisers.make_reactant_compatible(optimizer)
+        st_opt = optimisers_setup_with_jit(optimizer, ps)
+    else
+        st_opt = Optimisers.setup(optimizer, ps)
     end
     return TrainState(nothing, nothing, model, ps, st, optimizer, st_opt, 0)
 end
 
+function optimisers_setup_with_jit end
+
 @concrete struct TrainingBackendCache
     backend
     first_try <: StaticBool