Skip to content

Commit 75f1874

Browse files
committed
Use macro for shared caches
1 parent dccc1dd commit 75f1874

File tree

9 files changed

+72
-78
lines changed

9 files changed

+72
-78
lines changed

src/algorithms/multistep.jl

+4-2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
function MultiStepNonlinearSolver(; concrete_jac = nothing, linsolve = nothing,
22
scheme = MSS.PotraPtak3, precs = DEFAULT_PRECS, autodiff = nothing,
33
vjp_autodiff = nothing, linesearch = NoLineSearch())
4-
scheme_concrete = apply_patch(scheme, (; autodiff, vjp_autodiff))
4+
forward_ad = ifelse(autodiff isa ADTypes.AbstractForwardMode, autodiff, nothing)
5+
scheme_concrete = apply_patch(
6+
scheme, (; autodiff, vjp_autodiff, jvp_autodiff = forward_ad))
57
descent = GenericMultiStepDescent(; scheme = scheme_concrete, linsolve, precs)
68
return GeneralizedFirstOrderAlgorithm(; concrete_jac, name = MSS.display_name(scheme),
7-
descent, jacobian_ad = autodiff, linesearch, reverse_ad = vjp_autodiff)
9+
descent, jacobian_ad = autodiff, linesearch, reverse_ad = vjp_autodiff, forward_ad)
810
end

src/descent/damped_newton.jl

+1-5
Original file line numberDiff line numberDiff line change
@@ -58,11 +58,7 @@ function __internal_init(
5858
shared::Val{N} = Val(1), kwargs...) where {INV, N}
5959
length(fu) != length(u) &&
6060
@assert !INV "Precomputed Inverse for Non-Square Jacobian doesn't make sense."
61-
@bb δu = similar(u)
62-
δus = N 1 ? nothing : map(2:N) do i
63-
@bb δu_ = similar(u)
64-
end
65-
61+
δu, δus = @shared_caches N (@bb δu = similar(u))
6662
normal_form_damping = returns_norm_form_damping(alg.damping_fn)
6763
normal_form_linsolve = __needs_square_A(alg.linsolve, u)
6864
if u isa Number

src/descent/dogleg.jl

+1-4
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,7 @@ function __internal_init(prob::AbstractNonlinearProblem, alg::Dogleg, J, fu, u;
5656
linsolve_kwargs, abstol, reltol, shared, kwargs...)
5757
cauchy_cache = __internal_init(prob, alg.steepest_descent, J, fu, u; pre_inverted,
5858
linsolve_kwargs, abstol, reltol, shared, kwargs...)
59-
@bb δu = similar(u)
60-
δus = N 1 ? nothing : map(2:N) do i
61-
@bb δu_ = similar(u)
62-
end
59+
δu, δus = @shared_caches N (@bb δu = similar(u))
6360
@bb δu_cache_1 = similar(u)
6461
@bb δu_cache_2 = similar(u)
6562
@bb δu_cache_mul = similar(u)

src/descent/geodesic_acceleration.jl

+1-4
Original file line numberDiff line numberDiff line change
@@ -89,10 +89,7 @@ function __internal_init(prob::AbstractNonlinearProblem, alg::GeodesicAccelerati
8989
abstol = nothing, reltol = nothing, internalnorm::F = DEFAULT_NORM,
9090
kwargs...) where {INV, N, F}
9191
T = promote_type(eltype(u), eltype(fu))
92-
@bb δu = similar(u)
93-
δus = N 1 ? nothing : map(2:N) do i
94-
@bb δu_ = similar(u)
95-
end
92+
δu, δus = @shared_caches N (@bb δu = similar(u))
9693
descent_cache = __internal_init(prob, alg.descent, J, fu, u; shared = Val(N * 2),
9794
pre_inverted, linsolve_kwargs, abstol, reltol, kwargs...)
9895
@bb Jv = similar(fu)

src/descent/multistep.jl

+24-32
Original file line numberDiff line numberDiff line change
@@ -15,34 +15,36 @@ function Base.show(io::IO, mss::AbstractMultiStepScheme)
1515
print(io, "MultiStepSchemes.$(string(nameof(typeof(mss)))[3:end])")
1616
end
1717

18-
alg_steps(::Type{T}) where {T <: AbstractMultiStepScheme} = alg_steps(T())
18+
newton_steps(::Type{T}) where {T <: AbstractMultiStepScheme} = newton_steps(T())
1919

2020
struct __PotraPtak3 <: AbstractMultiStepScheme end
2121
const PotraPtak3 = __PotraPtak3()
2222

23-
alg_steps(::__PotraPtak3) = 2
23+
newton_steps(::__PotraPtak3) = 2
2424
nintermediates(::__PotraPtak3) = 1
2525

2626
@kwdef @concrete struct __SinghSharma4 <: AbstractMultiStepScheme
2727
jvp_autodiff = nothing
2828
end
2929
const SinghSharma4 = __SinghSharma4()
3030

31-
alg_steps(::__SinghSharma4) = 3
31+
newton_steps(::__SinghSharma4) = 4
32+
nintermediates(::__SinghSharma4) = 2
3233

3334
@kwdef @concrete struct __SinghSharma5 <: AbstractMultiStepScheme
3435
jvp_autodiff = nothing
3536
end
3637
const SinghSharma5 = __SinghSharma5()
3738

38-
alg_steps(::__SinghSharma5) = 3
39+
newton_steps(::__SinghSharma5) = 4
40+
nintermediates(::__SinghSharma5) = 2
3941

4042
@kwdef @concrete struct __SinghSharma7 <: AbstractMultiStepScheme
4143
jvp_autodiff = nothing
4244
end
4345
const SinghSharma7 = __SinghSharma7()
4446

45-
alg_steps(::__SinghSharma7) = 4
47+
newton_steps(::__SinghSharma7) = 6
4648

4749
@generated function display_name(alg::T) where {T <: AbstractMultiStepScheme}
4850
res = Symbol(first(split(last(split(string(T), ".")), "{"; limit = 2))[3:end])
@@ -75,6 +77,8 @@ supports_trust_region(::GenericMultiStepDescent) = false
7577
fus
7678
internal_cache
7779
internal_caches
80+
extra
81+
extras
7882
scheme::S
7983
timer
8084
nf::Int
@@ -91,49 +95,37 @@ function __reinit_internal!(cache::GenericMultiStepDescentCache, args...; p = ca
9195
end
9296

9397
function __internal_multistep_caches(
94-
scheme::MSS.__PotraPtak3, alg::GenericMultiStepDescent,
95-
prob, args...; shared::Val{N} = Val(1), kwargs...) where {N}
98+
scheme::Union{MSS.__PotraPtak3, MSS.__SinghSharma4, MSS.__SinghSharma5},
99+
alg::GenericMultiStepDescent, prob, args...;
100+
shared::Val{N} = Val(1), kwargs...) where {N}
96101
internal_descent = NewtonDescent(; alg.linsolve, alg.precs)
97-
internal_cache = __internal_init(
102+
return @shared_caches N __internal_init(
98103
prob, internal_descent, args...; kwargs..., shared = Val(2))
99-
internal_caches = N 1 ? nothing :
100-
map(2:N) do i
101-
__internal_init(prob, internal_descent, args...; kwargs..., shared = Val(2))
102-
end
103-
return internal_cache, internal_caches
104104
end
105105

106+
__extras_cache(::MSS.AbstractMultiStepScheme, args...; kwargs...) = nothing, nothing
107+
106108
function __internal_init(prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem},
107109
alg::GenericMultiStepDescent, J, fu, u; shared::Val{N} = Val(1),
108110
pre_inverted::Val{INV} = False, linsolve_kwargs = (;),
109111
abstol = nothing, reltol = nothing, timer = get_timer_output(),
110112
kwargs...) where {INV, N}
111-
@bb δu = similar(u)
112-
δus = N 1 ? nothing : map(2:N) do i
113-
@bb δu_ = similar(u)
114-
end
115-
fu_cache = ntuple(MSS.nintermediates(alg.scheme)) do i
113+
δu, δus = @shared_caches N (@bb δu = similar(u))
114+
fu_cache, fus_cache = @shared_caches N (ntuple(MSS.nintermediates(alg.scheme)) do i
116115
@bb xx = similar(fu)
117-
end
118-
fus_cache = N 1 ? nothing : map(2:N) do i
119-
ntuple(MSS.nintermediates(alg.scheme)) do j
120-
@bb xx = similar(fu)
121-
end
122-
end
123-
u_cache = ntuple(MSS.nintermediates(alg.scheme)) do i
116+
end)
117+
u_cache, us_cache = @shared_caches N (ntuple(MSS.nintermediates(alg.scheme)) do i
124118
@bb xx = similar(u)
125-
end
126-
us_cache = N 1 ? nothing : map(2:N) do i
127-
ntuple(MSS.nintermediates(alg.scheme)) do j
128-
@bb xx = similar(u)
129-
end
130-
end
119+
end)
131120
internal_cache, internal_caches = __internal_multistep_caches(
132121
alg.scheme, alg, prob, J, fu, u; shared, pre_inverted, linsolve_kwargs,
133122
abstol, reltol, timer, kwargs...)
123+
extra, extras = __extras_cache(
124+
alg.scheme, alg, prob, J, fu, u; shared, pre_inverted, linsolve_kwargs,
125+
abstol, reltol, timer, kwargs...)
134126
return GenericMultiStepDescentCache(
135127
prob.f, prob.p, δu, δus, u_cache, us_cache, fu_cache, fus_cache,
136-
internal_cache, internal_caches, alg.scheme, timer, 0)
128+
internal_cache, internal_caches, extra, extras, alg.scheme, timer, 0)
137129
end
138130

139131
function __internal_solve!(cache::GenericMultiStepDescentCache{MSS.__PotraPtak3, INV}, J,

src/descent/newton.jl

+2-8
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,7 @@ function __internal_init(prob::NonlinearProblem, alg::NewtonDescent, J, fu, u;
3636
shared::Val{N} = Val(1), pre_inverted::Val{INV} = False, linsolve_kwargs = (;),
3737
abstol = nothing, reltol = nothing, timer = get_timer_output(),
3838
kwargs...) where {INV, N}
39-
@bb δu = similar(u)
40-
δus = N 1 ? nothing : map(2:N) do i
41-
@bb δu_ = similar(u)
42-
end
39+
δu, δus = @shared_caches N (@bb δu = similar(u))
4340
INV && return NewtonDescentCache{true, false}(δu, δus, nothing, nothing, nothing, timer)
4441
lincache = LinearSolverCache(alg, alg.linsolve, J, _vec(fu), _vec(u); abstol, reltol,
4542
linsolve_kwargs...)
@@ -64,10 +61,7 @@ function __internal_init(prob::NonlinearLeastSquaresProblem, alg::NewtonDescent,
6461
end
6562
lincache = LinearSolverCache(alg, alg.linsolve, A, b, _vec(u); abstol, reltol,
6663
linsolve_kwargs...)
67-
@bb δu = similar(u)
68-
δus = N 1 ? nothing : map(2:N) do i
69-
@bb δu_ = similar(u)
70-
end
64+
δu, δus = @shared_caches N (@bb δu = similar(u))
7165
return NewtonDescentCache{false, normal_form}(δu, δus, lincache, JᵀJ, Jᵀfu, timer)
7266
end
7367

src/descent/steepest.jl

+1-4
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,7 @@ end
3434
linsolve_kwargs = (;), abstol = nothing, reltol = nothing,
3535
timer = get_timer_output(), kwargs...) where {INV, N}
3636
INV && @assert length(fu)==length(u) "Non-Square Jacobian Inverse doesn't make sense."
37-
@bb δu = similar(u)
38-
δus = N 1 ? nothing : map(2:N) do i
39-
@bb δu_ = similar(u)
40-
end
37+
δu, δus = @shared_caches N (@bb δu = similar(u))
4138
if INV
4239
lincache = LinearSolverCache(alg, alg.linsolve, transpose(J), _vec(fu), _vec(u);
4340
abstol, reltol, linsolve_kwargs...)

src/internal/helpers.jl

+38
Original file line numberDiff line numberDiff line change
@@ -268,3 +268,41 @@ function __internal_caches(__source__, __module__, cType, internal_cache_names::
268268
end
269269
end)
270270
end
271+
272+
"""
273+
apply_patch(scheme, patch::NamedTuple{names})
274+
275+
Applies the patch to the scheme, returning the new scheme. If some of the `names` are not,
276+
present in the scheme, they are ignored.
277+
"""
278+
@generated function apply_patch(scheme, patch::NamedTuple{names}) where {names}
279+
exprs = []
280+
for name in names
281+
hasfield(scheme, name) || continue
282+
push!(exprs, quote
283+
lens = PropertyLens{$(Meta.quot(name))}()
284+
return set(scheme, lens, getfield(patch, $(Meta.quot(name))))
285+
end)
286+
end
287+
push!(exprs, :(return scheme))
288+
return Expr(:block, exprs...)
289+
end
290+
291+
"""
292+
@shared_caches N expr
293+
294+
Create a shared cache and a vector of caches. If `N` is 1, then the vector of caches is
295+
`nothing`.
296+
"""
297+
macro shared_caches(N, expr)
298+
@gensym cache caches
299+
return esc(quote
300+
begin
301+
$(cache) = $(expr)
302+
$(caches) = $(N) 1 ? nothing : map(2:($(N))) do i
303+
$(expr)
304+
end
305+
($cache, $caches)
306+
end
307+
end)
308+
end

src/utils.jl

-19
Original file line numberDiff line numberDiff line change
@@ -158,22 +158,3 @@ Determine the chunk size for ForwardDiff and PolyesterForwardDiff based on the i
158158
"""
159159
@inline pickchunksize(x) = pickchunksize(length(x))
160160
@inline pickchunksize(x::Int) = ForwardDiff.pickchunksize(x)
161-
162-
"""
163-
apply_patch(scheme, patch::NamedTuple{names})
164-
165-
Applies the patch to the scheme, returning the new scheme. If some of the `names` are not,
166-
present in the scheme, they are ignored.
167-
"""
168-
@generated function apply_patch(scheme, patch::NamedTuple{names}) where {names}
169-
exprs = []
170-
for name in names
171-
hasfield(scheme, name) || continue
172-
push!(exprs, quote
173-
lens = PropertyLens{$(Meta.quot(name))}()
174-
return set(scheme, lens, getfield(patch, $(Meta.quot(name))))
175-
end)
176-
end
177-
push!(exprs, :(return scheme))
178-
return Expr(:block, exprs...)
179-
end

0 commit comments

Comments
 (0)