Skip to content

Commit 1d50a2b

Browse files
Merge pull request #1733 from SciML/exponential
Fix interface breaks with Exponential integrators
2 parents 92d4329 + 0aef92f commit 1d50a2b

File tree

3 files changed

+81
-46
lines changed

3 files changed

+81
-46
lines changed

src/alg_utils.jl

+52-16
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ SciMLBase.allows_arbitrary_number_types(alg::Union{OrdinaryDiffEqAlgorithm,DAEAl
55
SciMLBase.allowscomplex(alg::Union{OrdinaryDiffEqAlgorithm,DAEAlgorithm,FunctionMap}) = true
66
SciMLBase.isdiscrete(alg::FunctionMap) = true
77
SciMLBase.forwarddiffs_model(alg::Union{OrdinaryDiffEqAdaptiveImplicitAlgorithm,
8-
DAEAlgorithm,OrdinaryDiffEqImplicitAlgorithm,
9-
ExponentialAlgorithm}) = alg_autodiff(alg)
8+
DAEAlgorithm,OrdinaryDiffEqImplicitAlgorithm,
9+
ExponentialAlgorithm}) = alg_autodiff(alg)
1010
SciMLBase.forwarddiffs_model_time(alg::RosenbrockAlgorithm) = true
1111

1212
# isadaptive is defined below.
@@ -162,21 +162,30 @@ get_chunksize(alg::OrdinaryDiffEqAlgorithm) = error("This algorithm does not hav
162162
get_chunksize(alg::OrdinaryDiffEqAdaptiveImplicitAlgorithm{CS,AD}) where {CS,AD} = Val(CS)
163163
get_chunksize(alg::OrdinaryDiffEqImplicitAlgorithm{CS,AD}) where {CS,AD} = Val(CS)
164164
get_chunksize(alg::DAEAlgorithm{CS,AD}) where {CS,AD} = Val(CS)
165-
get_chunksize(alg::ExponentialAlgorithm) = Val(alg.chunksize)
165+
function get_chunksize(alg::Union{OrdinaryDiffEqExponentialAlgorithm{CS,AD},
166+
OrdinaryDiffEqAdaptiveExponentialAlgorithm{CS,AD}}) where {CS,AD}
167+
Val(CS)
168+
end
166169

167170
get_chunksize_int(alg::OrdinaryDiffEqAlgorithm) = error("This algorithm does not have a chunk size defined.")
168171
get_chunksize_int(alg::OrdinaryDiffEqAdaptiveImplicitAlgorithm{CS,AD}) where {CS,AD} = CS
169172
get_chunksize_int(alg::OrdinaryDiffEqImplicitAlgorithm{CS,AD}) where {CS,AD} = CS
170173
get_chunksize_int(alg::DAEAlgorithm{CS,AD}) where {CS,AD} = CS
171-
get_chunksize_int(alg::ExponentialAlgorithm) = alg.chunksize
174+
function get_chunksize_int(alg::Union{OrdinaryDiffEqExponentialAlgorithm{CS,AD},
175+
OrdinaryDiffEqAdaptiveExponentialAlgorithm{CS,AD}}) where {CS,AD}
176+
CS
177+
end
172178
# get_chunksize(alg::CompositeAlgorithm) = get_chunksize(alg.algs[alg.current_alg])
173179

174180
function DiffEqBase.prepare_alg(alg::Union{OrdinaryDiffEqAdaptiveImplicitAlgorithm{0,AD,FDT},
175181
OrdinaryDiffEqImplicitAlgorithm{0,AD,FDT},
176-
DAEAlgorithm{0,AD,FDT}}, u0::AbstractArray{T}, p, prob) where {AD,FDT,T}
182+
DAEAlgorithm{0,AD,FDT},OrdinaryDiffEqExponentialAlgorithm{0,AD,FDT}}, u0::AbstractArray{T},
183+
p, prob) where {AD,FDT,T}
177184
alg isa OrdinaryDiffEqImplicitExtrapolationAlgorithm && return alg # remake fails, should get fixed
178185

179-
if alg.linsolve === nothing
186+
if alg isa OrdinaryDiffEqExponentialAlgorithm
187+
linsolve = nothing
188+
elseif alg.linsolve === nothing
180189
if (prob.f isa ODEFunction && prob.f.f isa SciMLBase.AbstractDiffEqOperator)
181190
linsolve = LinearSolve.defaultalg(prob.f.f, u0)
182191
elseif (prob.f isa SplitFunction && prob.f.f1.f isa SciMLBase.AbstractDiffEqOperator)
@@ -187,8 +196,8 @@ function DiffEqBase.prepare_alg(alg::Union{OrdinaryDiffEqAdaptiveImplicitAlgorit
187196
linsolve = KrylovJL()
188197
end
189198
elseif (prob isa ODEProblem || prob isa DDEProblem) && (prob.f.mass_matrix === nothing ||
190-
(prob.f.mass_matrix !== nothing &&
191-
!(typeof(prob.f.jac_prototype) <: SciMLBase.AbstractDiffEqOperator)))
199+
(prob.f.mass_matrix !== nothing &&
200+
!(typeof(prob.f.jac_prototype) <: SciMLBase.AbstractDiffEqOperator)))
192201
linsolve = LinearSolve.defaultalg(prob.f.jac_prototype, u0)
193202
else
194203
# If mm is a sparse matrix and A is a DiffEqArrayOperator, then let linear
@@ -202,8 +211,12 @@ function DiffEqBase.prepare_alg(alg::Union{OrdinaryDiffEqAdaptiveImplicitAlgorit
202211
# If norecompile mode or very large bitsize, like a dual number u0 already, then
203212
# don't use a large chunksize as it will either error or not be beneficial
204213
if (isbitstype(T) && sizeof(T) > 24) || (prob.f isa ODEFunction && prob.f.f isa
205-
FunctionWrappersWrappers.FunctionWrappersWrapper)
206-
return remake(alg, chunk_size=Val{1}(), linsolve=linsolve)
214+
FunctionWrappersWrappers.FunctionWrappersWrapper)
215+
if alg isa OrdinaryDiffEqExponentialAlgorithm
216+
return remake(alg, chunk_size=Val{1}())
217+
else
218+
return remake(alg, chunk_size=Val{1}(), linsolve=linsolve)
219+
end
207220
end
208221

209222
L = ArrayInterface.known_length(typeof(u0))
@@ -218,13 +231,29 @@ function DiffEqBase.prepare_alg(alg::Union{OrdinaryDiffEqAdaptiveImplicitAlgorit
218231
end
219232

220233
cs = ForwardDiff.pickchunksize(x)
221-
remake(alg, chunk_size=Val{cs}(), linsolve=linsolve)
234+
235+
if alg isa OrdinaryDiffEqExponentialAlgorithm
236+
return remake(alg, chunk_size=Val{cs}())
237+
else
238+
return remake(alg, chunk_size=Val{cs}(), linsolve=linsolve)
239+
end
222240
else # statically sized
223241
cs = pick_static_chunksize(Val{L}())
224-
remake(alg, chunk_size=cs, linsolve=linsolve)
242+
if alg isa OrdinaryDiffEqExponentialAlgorithm
243+
return remake(alg, chunk_size=cs)
244+
else
245+
return remake(alg, chunk_size=cs, linsolve=linsolve)
246+
end
225247
end
226248
end
227249

250+
# Linear Exponential doesn't have any of the AD stuff
251+
function DiffEqBase.prepare_alg(alg::Union{ETD2,SplitEuler,LinearExponential,
252+
OrdinaryDiffEqLinearExponentialAlgorithm}, u0::AbstractArray,
253+
p, prob)
254+
alg
255+
end
256+
228257
@generated function pick_static_chunksize(::Val{chunksize}) where {chunksize}
229258
x = ForwardDiff.pickchunksize(chunksize)
230259
:(Val{$x}())
@@ -239,24 +268,31 @@ alg_autodiff(alg::OrdinaryDiffEqAlgorithm) = error("This algorithm does not have
239268
alg_autodiff(alg::OrdinaryDiffEqAdaptiveImplicitAlgorithm{CS,AD}) where {CS,AD} = AD
240269
alg_autodiff(alg::DAEAlgorithm{CS,AD}) where {CS,AD} = AD
241270
alg_autodiff(alg::OrdinaryDiffEqImplicitAlgorithm{CS,AD}) where {CS,AD} = AD
242-
alg_autodiff(alg::ExponentialAlgorithm) = alg.autodiff
271+
function alg_autodiff(alg::Union{OrdinaryDiffEqExponentialAlgorithm{CS,AD},
272+
OrdinaryDiffEqAdaptiveExponentialAlgorithm{CS,AD}}) where {CS,AD}
273+
AD
274+
end
275+
243276
# alg_autodiff(alg::CompositeAlgorithm) = alg_autodiff(alg.algs[alg.current_alg])
244277
get_current_alg_autodiff(alg, cache) = alg_autodiff(alg)
245278
get_current_alg_autodiff(alg::CompositeAlgorithm, cache) = alg_autodiff(alg.algs[cache.current])
246279

247280
alg_difftype(alg::Union{OrdinaryDiffEqAdaptiveImplicitAlgorithm{CS,AD,FDT,ST,CJ},
248281
OrdinaryDiffEqImplicitAlgorithm{CS,AD,FDT,ST,CJ},
249-
OrdinaryDiffEqExponentialAlgorithm{FDT,ST,CJ},
282+
OrdinaryDiffEqExponentialAlgorithm{CS,AD,FDT,ST,CJ},
283+
OrdinaryDiffEqAdaptiveExponentialAlgorithm{CS,AD,FDT,ST,CJ},
250284
DAEAlgorithm{CS,AD,FDT,ST,CJ}}) where {CS,AD,FDT,ST,CJ} = FDT
251285

252286
standardtag(alg::Union{OrdinaryDiffEqAdaptiveImplicitAlgorithm{CS,AD,FDT,ST,CJ},
253287
OrdinaryDiffEqImplicitAlgorithm{CS,AD,FDT,ST,CJ},
254-
OrdinaryDiffEqExponentialAlgorithm{FDT,ST,CJ},
288+
OrdinaryDiffEqExponentialAlgorithm{CS,AD,FDT,ST,CJ},
289+
OrdinaryDiffEqAdaptiveExponentialAlgorithm{CS,AD,FDT,ST,CJ},
255290
DAEAlgorithm{CS,AD,FDT,ST,CJ}}) where {CS,AD,FDT,ST,CJ} = ST
256291

257292
concrete_jac(alg::Union{OrdinaryDiffEqAdaptiveImplicitAlgorithm{CS,AD,FDT,ST,CJ},
258293
OrdinaryDiffEqImplicitAlgorithm{CS,AD,FDT,ST,CJ},
259-
OrdinaryDiffEqExponentialAlgorithm{FDT,ST,CJ},
294+
OrdinaryDiffEqExponentialAlgorithm{CS,AD,FDT,ST,CJ},
295+
OrdinaryDiffEqAdaptiveExponentialAlgorithm{CS,AD,FDT,ST,CJ},
260296
DAEAlgorithm{CS,AD,FDT,ST,CJ}}) where {CS,AD,FDT,ST,CJ} = CJ
261297

262298
alg_extrapolates(alg::Union{OrdinaryDiffEqAlgorithm,DAEAlgorithm}) = false

src/algorithms.jl

+27-30
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@ abstract type OrdinaryDiffEqRosenbrockAlgorithm{CS,AD,FDT,ST,CJ} <: OrdinaryDif
1212
const NewtonAlgorithm = Union{OrdinaryDiffEqNewtonAlgorithm,OrdinaryDiffEqNewtonAdaptiveAlgorithm}
1313
const RosenbrockAlgorithm = Union{OrdinaryDiffEqRosenbrockAlgorithm,OrdinaryDiffEqRosenbrockAdaptiveAlgorithm}
1414

15-
abstract type OrdinaryDiffEqExponentialAlgorithm{FDT,ST,CJ} <: OrdinaryDiffEqAlgorithm end
16-
abstract type OrdinaryDiffEqAdaptiveExponentialAlgorithm{FDT,ST,CJ} <: OrdinaryDiffEqAdaptiveAlgorithm end
17-
abstract type OrdinaryDiffEqLinearExponentialAlgorithm <: OrdinaryDiffEqExponentialAlgorithm{Val{:forward},Val{true},nothing} end
15+
abstract type OrdinaryDiffEqExponentialAlgorithm{CS,AD,FDT,ST,CJ} <: OrdinaryDiffEqAlgorithm end
16+
abstract type OrdinaryDiffEqAdaptiveExponentialAlgorithm{CS,AD,FDT,ST,CJ} <: OrdinaryDiffEqAdaptiveAlgorithm end
17+
abstract type OrdinaryDiffEqLinearExponentialAlgorithm <: OrdinaryDiffEqExponentialAlgorithm{0,false,Val{:forward},Val{true},nothing} end
1818
const ExponentialAlgorithm = Union{OrdinaryDiffEqExponentialAlgorithm,OrdinaryDiffEqAdaptiveExponentialAlgorithm}
1919

2020
abstract type OrdinaryDiffEqAdamsVarOrderVarStepAlgorithm <: OrdinaryDiffEqAdaptiveAlgorithm end
@@ -3130,7 +3130,7 @@ end
31303130

31313131
struct MagnusAdapt4 <: OrdinaryDiffEqAdaptiveAlgorithm end
31323132

3133-
struct LinearExponential <: OrdinaryDiffEqExponentialAlgorithm{Val{:forward},Val{true},nothing}
3133+
struct LinearExponential <: OrdinaryDiffEqExponentialAlgorithm{1,false,Val{:forward},Val{true},nothing}
31343134
krylov::Symbol
31353135
m::Int
31363136
iop::Int
@@ -3856,50 +3856,47 @@ RosenbrockW6S4OS(;chunk_size=Val{0}(),autodiff=true, standardtag = Val{true}(),
38563856

38573857
for Alg in [:LawsonEuler, :NorsettEuler, :ETDRK2, :ETDRK3, :ETDRK4, :HochOst4]
38583858

3859-
"""
3860-
Hochbruck, Marlis, and Alexander Ostermann. “Exponential Integrators.” Acta
3861-
Numerica 19 (2010): 209–86. doi:10.1017/S0962492910000048.
3862-
"""
3863-
@eval struct $Alg{FDT,ST,CJ} <: OrdinaryDiffEqExponentialAlgorithm{FDT,ST,CJ}
3864-
krylov::Bool
3865-
m::Int
3866-
iop::Int
3867-
autodiff::Bool
3868-
chunksize::Int
3869-
end
3870-
@eval $Alg(;krylov=false, m=30, iop=0, autodiff=true, standardtag = Val{true}(), concrete_jac = nothing, chunksize=0,
3871-
diff_type = Val{:forward}) = $Alg{diff_type,_unwrap_val(standardtag),_unwrap_val(concrete_jac)}(krylov, m, iop, _unwrap_val(autodiff),
3872-
chunksize)
3859+
"""
3860+
Hochbruck, Marlis, and Alexander Ostermann. “Exponential Integrators.” Acta
3861+
Numerica 19 (2010): 209–86. doi:10.1017/S0962492910000048.
3862+
"""
3863+
@eval struct $Alg{CS,AD,FDT,ST,CJ} <: OrdinaryDiffEqExponentialAlgorithm{CS,AD,FDT,ST,CJ}
3864+
krylov::Bool
3865+
m::Int
3866+
iop::Int
3867+
end
3868+
@eval $Alg(; krylov=false, m=30, iop=0, autodiff=true, standardtag=Val{true}(), concrete_jac=nothing, chunk_size=Val{0}(),
3869+
diff_type=Val{:forward}) = $Alg{_unwrap_val(chunk_size),_unwrap_val(autodiff),
3870+
diff_type,_unwrap_val(standardtag),_unwrap_val(concrete_jac)}(krylov, m, iop)
38733871
end
38743872
const ETD1 = NorsettEuler # alias
38753873
for Alg in [:Exprb32, :Exprb43]
3876-
@eval struct $Alg{FDT,ST,CJ} <: OrdinaryDiffEqAdaptiveExponentialAlgorithm{FDT,ST,CJ}
3874+
@eval struct $Alg{CS,AD,FDT,ST,CJ} <: OrdinaryDiffEqAdaptiveExponentialAlgorithm{CS,AD,FDT,ST,CJ}
38773875
m::Int
38783876
iop::Int
3879-
autodiff::Bool
3880-
chunksize::Int
38813877
end
3882-
@eval $Alg(;m=30, iop=0, autodiff=true, standardtag = Val{true}(), concrete_jac = nothing, chunksize=0,
3883-
diff_type = Val{:forward}) = $Alg{diff_type,_unwrap_val(standardtag),_unwrap_val(concrete_jac)}(m, iop, _unwrap_val(autodiff), chunksize)
3878+
@eval $Alg(;m=30, iop=0, autodiff=true, standardtag = Val{true}(), concrete_jac = nothing, chunk_size=Val{0}(),
3879+
diff_type = Val{:forward}) = $Alg{_unwrap_val(chunk_size),_unwrap_val(autodiff),
3880+
diff_type,_unwrap_val(standardtag),
3881+
_unwrap_val(concrete_jac)}(m, iop)
38843882
end
38853883
for Alg in [:Exp4, :EPIRK4s3A, :EPIRK4s3B, :EPIRK5s3, :EXPRB53s3, :EPIRK5P1, :EPIRK5P2]
3886-
@eval struct $Alg{FDT,ST,CJ} <: OrdinaryDiffEqExponentialAlgorithm{FDT,ST,CJ}
3884+
@eval struct $Alg{CS,AD,FDT,ST,CJ} <: OrdinaryDiffEqExponentialAlgorithm{CS,AD,FDT,ST,CJ}
38873885
adaptive_krylov::Bool
38883886
m::Int
38893887
iop::Int
3890-
autodiff::Bool
3891-
chunksize::Int
38923888
end
38933889
@eval $Alg(;adaptive_krylov=true, m=30, iop=0, autodiff=true, standardtag = Val{true}(), concrete_jac = nothing,
3894-
chunksize=0, diff_type = Val{:forward}) =
3895-
$Alg{diff_type,_unwrap_val(standardtag),_unwrap_val(concrete_jac)}(adaptive_krylov, m, iop, _unwrap_val(autodiff), chunksize)
3890+
chunk_size=Val{0}(), diff_type = Val{:forward}) =
3891+
$Alg{_unwrap_val(chunk_size),_unwrap_val(autodiff),diff_type,
3892+
_unwrap_val(standardtag),_unwrap_val(concrete_jac)}(adaptive_krylov, m, iop)
38963893
end
3897-
struct SplitEuler <: OrdinaryDiffEqExponentialAlgorithm{Val{:forward},Val{true},nothing} end
3894+
struct SplitEuler <: OrdinaryDiffEqExponentialAlgorithm{0,false,Val{:forward},Val{true},nothing} end
38983895
"""
38993896
ETD2: Exponential Runge-Kutta Method
39003897
Second order Exponential Time Differencing method (in development).
39013898
"""
3902-
struct ETD2 <: OrdinaryDiffEqExponentialAlgorithm{Val{:forward},Val{true},nothing} end
3899+
struct ETD2 <: OrdinaryDiffEqExponentialAlgorithm{0,false,Val{:forward},Val{true},nothing} end
39033900

39043901
#########################################
39053902

test/interface/norecompile.jl

+2
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ if VERSION >= v"1.8"
2828
@test t2 < t4
2929
end
3030

31+
solve(prob, EPIRK4s3A(), dt=1e-1)
32+
3133
function f_oop(u, p, t)
3234
[0.2u[1], 0.4u[2]]
3335
end

0 commit comments

Comments
 (0)