@@ -5,8 +5,8 @@ SciMLBase.allows_arbitrary_number_types(alg::Union{OrdinaryDiffEqAlgorithm,DAEAl
5
5
SciMLBase. allowscomplex (alg:: Union{OrdinaryDiffEqAlgorithm,DAEAlgorithm,FunctionMap} ) = true
6
6
SciMLBase. isdiscrete (alg:: FunctionMap ) = true
7
7
SciMLBase. forwarddiffs_model (alg:: Union {OrdinaryDiffEqAdaptiveImplicitAlgorithm,
8
- DAEAlgorithm,OrdinaryDiffEqImplicitAlgorithm,
9
- ExponentialAlgorithm}) = alg_autodiff (alg)
8
+ DAEAlgorithm,OrdinaryDiffEqImplicitAlgorithm,
9
+ ExponentialAlgorithm}) = alg_autodiff (alg)
10
10
SciMLBase. forwarddiffs_model_time (alg:: RosenbrockAlgorithm ) = true
11
11
12
12
# isadaptive is defined below.
@@ -162,21 +162,30 @@ get_chunksize(alg::OrdinaryDiffEqAlgorithm) = error("This algorithm does not hav
162
162
get_chunksize (alg:: OrdinaryDiffEqAdaptiveImplicitAlgorithm{CS,AD} ) where {CS,AD} = Val (CS)
163
163
get_chunksize (alg:: OrdinaryDiffEqImplicitAlgorithm{CS,AD} ) where {CS,AD} = Val (CS)
164
164
get_chunksize (alg:: DAEAlgorithm{CS,AD} ) where {CS,AD} = Val (CS)
165
- get_chunksize (alg:: ExponentialAlgorithm ) = Val (alg. chunksize)
165
+ function get_chunksize (alg:: Union {OrdinaryDiffEqExponentialAlgorithm{CS,AD},
166
+ OrdinaryDiffEqAdaptiveExponentialAlgorithm{CS,AD}}) where {CS,AD}
167
+ Val (CS)
168
+ end
166
169
167
170
get_chunksize_int (alg:: OrdinaryDiffEqAlgorithm ) = error (" This algorithm does not have a chunk size defined." )
168
171
get_chunksize_int (alg:: OrdinaryDiffEqAdaptiveImplicitAlgorithm{CS,AD} ) where {CS,AD} = CS
169
172
get_chunksize_int (alg:: OrdinaryDiffEqImplicitAlgorithm{CS,AD} ) where {CS,AD} = CS
170
173
get_chunksize_int (alg:: DAEAlgorithm{CS,AD} ) where {CS,AD} = CS
171
- get_chunksize_int (alg:: ExponentialAlgorithm ) = alg. chunksize
174
+ function get_chunksize_int (alg:: Union {OrdinaryDiffEqExponentialAlgorithm{CS,AD},
175
+ OrdinaryDiffEqAdaptiveExponentialAlgorithm{CS,AD}}) where {CS,AD}
176
+ CS
177
+ end
172
178
# get_chunksize(alg::CompositeAlgorithm) = get_chunksize(alg.algs[alg.current_alg])
173
179
174
180
function DiffEqBase. prepare_alg (alg:: Union {OrdinaryDiffEqAdaptiveImplicitAlgorithm{0 ,AD,FDT},
175
181
OrdinaryDiffEqImplicitAlgorithm{0 ,AD,FDT},
176
- DAEAlgorithm{0 ,AD,FDT}}, u0:: AbstractArray{T} , p, prob) where {AD,FDT,T}
182
+ DAEAlgorithm{0 ,AD,FDT},OrdinaryDiffEqExponentialAlgorithm{0 ,AD,FDT}}, u0:: AbstractArray{T} ,
183
+ p, prob) where {AD,FDT,T}
177
184
alg isa OrdinaryDiffEqImplicitExtrapolationAlgorithm && return alg # remake fails, should get fixed
178
185
179
- if alg. linsolve === nothing
186
+ if alg isa OrdinaryDiffEqExponentialAlgorithm
187
+ linsolve = nothing
188
+ elseif alg. linsolve === nothing
180
189
if (prob. f isa ODEFunction && prob. f. f isa SciMLBase. AbstractDiffEqOperator)
181
190
linsolve = LinearSolve. defaultalg (prob. f. f, u0)
182
191
elseif (prob. f isa SplitFunction && prob. f. f1. f isa SciMLBase. AbstractDiffEqOperator)
@@ -187,8 +196,8 @@ function DiffEqBase.prepare_alg(alg::Union{OrdinaryDiffEqAdaptiveImplicitAlgorit
187
196
linsolve = KrylovJL ()
188
197
end
189
198
elseif (prob isa ODEProblem || prob isa DDEProblem) && (prob. f. mass_matrix === nothing ||
190
- (prob. f. mass_matrix != = nothing &&
191
- ! (typeof (prob. f. jac_prototype) <: SciMLBase.AbstractDiffEqOperator )))
199
+ (prob. f. mass_matrix != = nothing &&
200
+ ! (typeof (prob. f. jac_prototype) <: SciMLBase.AbstractDiffEqOperator )))
192
201
linsolve = LinearSolve. defaultalg (prob. f. jac_prototype, u0)
193
202
else
194
203
# If mm is a sparse matrix and A is a DiffEqArrayOperator, then let linear
@@ -202,8 +211,12 @@ function DiffEqBase.prepare_alg(alg::Union{OrdinaryDiffEqAdaptiveImplicitAlgorit
202
211
# If norecompile mode or very large bitsize, like a dual number u0 already, then
203
212
# don't use a large chunksize as it will either error or not be beneficial
204
213
if (isbitstype (T) && sizeof (T) > 24 ) || (prob. f isa ODEFunction && prob. f. f isa
205
- FunctionWrappersWrappers. FunctionWrappersWrapper)
206
- return remake (alg, chunk_size= Val {1} (), linsolve= linsolve)
214
+ FunctionWrappersWrappers. FunctionWrappersWrapper)
215
+ if alg isa OrdinaryDiffEqExponentialAlgorithm
216
+ return remake (alg, chunk_size= Val {1} ())
217
+ else
218
+ return remake (alg, chunk_size= Val {1} (), linsolve= linsolve)
219
+ end
207
220
end
208
221
209
222
L = ArrayInterface. known_length (typeof (u0))
@@ -218,13 +231,29 @@ function DiffEqBase.prepare_alg(alg::Union{OrdinaryDiffEqAdaptiveImplicitAlgorit
218
231
end
219
232
220
233
cs = ForwardDiff. pickchunksize (x)
221
- remake (alg, chunk_size= Val {cs} (), linsolve= linsolve)
234
+
235
+ if alg isa OrdinaryDiffEqExponentialAlgorithm
236
+ return remake (alg, chunk_size= Val {cs} ())
237
+ else
238
+ return remake (alg, chunk_size= Val {cs} (), linsolve= linsolve)
239
+ end
222
240
else # statically sized
223
241
cs = pick_static_chunksize (Val {L} ())
224
- remake (alg, chunk_size= cs, linsolve= linsolve)
242
+ if alg isa OrdinaryDiffEqExponentialAlgorithm
243
+ return remake (alg, chunk_size= cs)
244
+ else
245
+ return remake (alg, chunk_size= cs, linsolve= linsolve)
246
+ end
225
247
end
226
248
end
227
249
250
+ # Linear Exponential doesn't have any of the AD stuff
251
+ function DiffEqBase. prepare_alg (alg:: Union {ETD2,SplitEuler,LinearExponential,
252
+ OrdinaryDiffEqLinearExponentialAlgorithm}, u0:: AbstractArray ,
253
+ p, prob)
254
+ alg
255
+ end
256
+
228
257
@generated function pick_static_chunksize (:: Val{chunksize} ) where {chunksize}
229
258
x = ForwardDiff. pickchunksize (chunksize)
230
259
:(Val {$x} ())
@@ -239,24 +268,31 @@ alg_autodiff(alg::OrdinaryDiffEqAlgorithm) = error("This algorithm does not have
239
268
alg_autodiff (alg:: OrdinaryDiffEqAdaptiveImplicitAlgorithm{CS,AD} ) where {CS,AD} = AD
240
269
alg_autodiff (alg:: DAEAlgorithm{CS,AD} ) where {CS,AD} = AD
241
270
alg_autodiff (alg:: OrdinaryDiffEqImplicitAlgorithm{CS,AD} ) where {CS,AD} = AD
242
- alg_autodiff (alg:: ExponentialAlgorithm ) = alg. autodiff
271
+ function alg_autodiff (alg:: Union {OrdinaryDiffEqExponentialAlgorithm{CS,AD},
272
+ OrdinaryDiffEqAdaptiveExponentialAlgorithm{CS,AD}}) where {CS,AD}
273
+ AD
274
+ end
275
+
243
276
# alg_autodiff(alg::CompositeAlgorithm) = alg_autodiff(alg.algs[alg.current_alg])
244
277
get_current_alg_autodiff (alg, cache) = alg_autodiff (alg)
245
278
get_current_alg_autodiff (alg:: CompositeAlgorithm , cache) = alg_autodiff (alg. algs[cache. current])
246
279
247
280
alg_difftype (alg:: Union {OrdinaryDiffEqAdaptiveImplicitAlgorithm{CS,AD,FDT,ST,CJ},
248
281
OrdinaryDiffEqImplicitAlgorithm{CS,AD,FDT,ST,CJ},
249
- OrdinaryDiffEqExponentialAlgorithm{FDT,ST,CJ},
282
+ OrdinaryDiffEqExponentialAlgorithm{CS,AD,FDT,ST,CJ},
283
+ OrdinaryDiffEqAdaptiveExponentialAlgorithm{CS,AD,FDT,ST,CJ},
250
284
DAEAlgorithm{CS,AD,FDT,ST,CJ}}) where {CS,AD,FDT,ST,CJ} = FDT
251
285
252
286
standardtag (alg:: Union {OrdinaryDiffEqAdaptiveImplicitAlgorithm{CS,AD,FDT,ST,CJ},
253
287
OrdinaryDiffEqImplicitAlgorithm{CS,AD,FDT,ST,CJ},
254
- OrdinaryDiffEqExponentialAlgorithm{FDT,ST,CJ},
288
+ OrdinaryDiffEqExponentialAlgorithm{CS,AD,FDT,ST,CJ},
289
+ OrdinaryDiffEqAdaptiveExponentialAlgorithm{CS,AD,FDT,ST,CJ},
255
290
DAEAlgorithm{CS,AD,FDT,ST,CJ}}) where {CS,AD,FDT,ST,CJ} = ST
256
291
257
292
concrete_jac (alg:: Union {OrdinaryDiffEqAdaptiveImplicitAlgorithm{CS,AD,FDT,ST,CJ},
258
293
OrdinaryDiffEqImplicitAlgorithm{CS,AD,FDT,ST,CJ},
259
- OrdinaryDiffEqExponentialAlgorithm{FDT,ST,CJ},
294
+ OrdinaryDiffEqExponentialAlgorithm{CS,AD,FDT,ST,CJ},
295
+ OrdinaryDiffEqAdaptiveExponentialAlgorithm{CS,AD,FDT,ST,CJ},
260
296
DAEAlgorithm{CS,AD,FDT,ST,CJ}}) where {CS,AD,FDT,ST,CJ} = CJ
261
297
262
298
alg_extrapolates (alg:: Union{OrdinaryDiffEqAlgorithm,DAEAlgorithm} ) = false
0 commit comments