Skip to content

Commit 87750c4

Browse files
committed
Dont dispatch on init and solve!
1 parent ca15839 commit 87750c4

17 files changed

+175
-111
lines changed

src/abstract_types.jl

+30-27
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
function __internal_init end
2+
function __internal_solve! end
3+
14
"""
25
AbstractDescentAlgorithm
36
@@ -10,15 +13,15 @@ in which case we use the normal form equations ``JᵀJ δu = Jᵀ fu``. Note tha
1013
factorization is often the faster choice, but it is not as numerically stable as the least
1114
squares solver.
1215
13-
### `SciMLBase.init` specification
16+
### `__internal_init` specification
1417
1518
```julia
16-
SciMLBase.init(prob::NonlinearProblem{uType, iip}, alg::AbstractDescentAlgorithm, J, fu, u;
19+
__internal_init(prob::NonlinearProblem{uType, iip}, alg::AbstractDescentAlgorithm, J, fu, u;
1720
pre_inverted::Val{INV} = Val(false), linsolve_kwargs = (;), abstol = nothing,
1821
reltol = nothing, alias_J::Bool = true, shared::Val{N} = Val(1),
1922
kwargs...) where {INV, N, uType, iip} --> AbstractDescentCache
2023
21-
SciMLBase.init(prob::NonlinearLeastSquaresProblem{uType, iip},
24+
__internal_init(prob::NonlinearLeastSquaresProblem{uType, iip},
2225
alg::AbstractDescentAlgorithm, J, fu, u; pre_inverted::Val{INV} = Val(false),
2326
linsolve_kwargs = (;), abstol = nothing, reltol = nothing, alias_J::Bool = true,
2427
shared::Val{N} = Val(1), kwargs...) where {INV, N, uType, iip} --> AbstractDescentCache
@@ -59,10 +62,10 @@ get_linear_solver(alg::AbstractDescentAlgorithm) = __getproperty(alg, Val(:linso
5962
6063
Abstract Type for all Descent Caches.
6164
62-
### `SciMLBase.solve!` specification
65+
### `__internal_solve!` specification
6366
6467
```julia
65-
δu, success, intermediates = SciMLBase.solve!(cache::AbstractDescentCache, J, fu, u,
68+
δu, success, intermediates = __internal_solve!(cache::AbstractDescentCache, J, fu, u,
6669
idx::Val; skip_solve::Bool = false, kwargs...)
6770
```
6871
@@ -112,10 +115,10 @@ end
112115
113116
Abstract Type for all Line Search Algorithms used in NonlinearSolve.jl.
114117
115-
### `SciMLBase.init` specification
118+
### `__internal_init` specification
116119
117120
```julia
118-
SciMLBase.init(prob::AbstractNonlinearProblem,
121+
__internal_init(prob::AbstractNonlinearProblem,
119122
alg::AbstractNonlinearSolveLineSearchAlgorithm, f::F, fu, u, p, args...;
120123
internalnorm::IN = DEFAULT_NORM,
121124
kwargs...) where {F, IN} --> AbstractNonlinearSolveLineSearchCache
@@ -128,10 +131,10 @@ abstract type AbstractNonlinearSolveLineSearchAlgorithm end
128131
129132
Abstract Type for all Line Search Caches used in NonlinearSolve.jl.
130133
131-
### `SciMLBase.solve!` specification
134+
### `__internal_solve!` specification
132135
133136
```julia
134-
SciMLBase.solve!(cache::AbstractNonlinearSolveLineSearchCache, u, du; kwargs...)
137+
__internal_solve!(cache::AbstractNonlinearSolveLineSearchCache, u, du; kwargs...)
135138
```
136139
137140
Returns 2 values:
@@ -226,10 +229,10 @@ abstract type AbstractLinearSolverCache <: Function end
226229
227230
Abstract Type for Damping Functions in DampedNewton.
228231
229-
### `SciMLBase.init` specification
232+
### `__internal_init` specification
230233
231234
```julia
232-
SciMLBase.init(prob::AbstractNonlinearProblem, f::AbstractDampingFunction, initial_damping,
235+
__internal_init(prob::AbstractNonlinearProblem, f::AbstractDampingFunction, initial_damping,
233236
J, fu, u, args...; internal_norm = DEFAULT_NORM,
234237
kwargs...) --> AbstractDampingFunctionCache
235238
```
@@ -254,10 +257,10 @@ Abstract Type for the Caches created by AbstractDampingFunctions
254257
- `(cache::AbstractDampingFunctionCache)(::Nothing)`: returns the damping factor. The type
255258
of the damping factor returned from `solve!` is guaranteed to be the same as this.
256259
257-
### `SciMLBase.solve!` specification
260+
### `__internal_solve!` specification
258261
259262
```julia
260-
SciMLBase.solve!(cache::AbstractDampingFunctionCache, J, fu, args...; kwargs...)
263+
__internal_solve!(cache::AbstractDampingFunctionCache, J, fu, args...; kwargs...)
261264
```
262265
263266
Returns the damping factor.
@@ -310,10 +313,10 @@ Abstract Type for all Jacobian Initialization Algorithms used in NonlinearSolve.
310313
- `jacobian_initialized_preinverted(alg)`: whether or not the Jacobian is initialized
311314
preinverted. Defaults to `false`.
312315
313-
### `SciMLBase.init` specification
316+
### `__internal_init` specification
314317
315318
```julia
316-
SciMLBase.init(prob::AbstractNonlinearProblem, alg::AbstractJacobianInitialization,
319+
__internal_init(prob::AbstractNonlinearProblem, alg::AbstractJacobianInitialization,
317320
solver, f::F, fu, u, p; linsolve = missing, internalnorm::IN = DEFAULT_NORM,
318321
kwargs...)
319322
```
@@ -345,10 +348,10 @@ Abstract Type for all Approximate Jacobian Update Rules used in NonlinearSolve.j
345348
346349
- `store_inverse_jacobian(alg)`: Return `INV`
347350
348-
### `SciMLBase.init` specification
351+
### `__internal_init` specification
349352
350353
```julia
351-
SciMLBase.init(prob::AbstractNonlinearProblem,
354+
__internal_init(prob::AbstractNonlinearProblem,
352355
alg::AbstractApproximateJacobianUpdateRule, J, fu, u, du, args...;
353356
internalnorm::F = DEFAULT_NORM,
354357
kwargs...) where {F} --> AbstractApproximateJacobianUpdateRuleCache{INV}
@@ -367,10 +370,10 @@ Abstract Type for all Approximate Jacobian Update Rule Caches used in NonlinearS
367370
368371
- `store_inverse_jacobian(alg)`: Return `INV`
369372
370-
### `SciMLBase.solve!` specification
373+
### `__internal_solve!` specification
371374
372375
```julia
373-
SciMLBase.solve!(cache::AbstractApproximateJacobianUpdateRuleCache, J, fu, u, du;
376+
__internal_solve!(cache::AbstractApproximateJacobianUpdateRuleCache, J, fu, u, du;
374377
kwargs...) --> J / J⁻¹
375378
```
376379
"""
@@ -383,17 +386,17 @@ store_inverse_jacobian(::AbstractApproximateJacobianUpdateRuleCache{INV}) where
383386
384387
Condition for resetting the Jacobian in Quasi-Newton's methods.
385388
386-
### `SciMLBase.init` specification
389+
### `__internal_init` specification
387390
388391
```julia
389-
SciMLBase.init(alg::AbstractResetCondition, J, fu, u, du, args...;
392+
__internal_init(alg::AbstractResetCondition, J, fu, u, du, args...;
390393
kwargs...) --> ResetCache
391394
```
392395
393-
### `SciMLBase.solve!` specification
396+
### `__internal_solve!` specification
394397
395398
```julia
396-
SciMLBase.solve!(cache::ResetCache, J, fu, u, du) --> Bool
399+
__internal_solve!(cache::ResetCache, J, fu, u, du) --> Bool
397400
```
398401
"""
399402
abstract type AbstractResetCondition end
@@ -403,10 +406,10 @@ abstract type AbstractResetCondition end
403406
404407
Abstract Type for all Trust Region Methods used in NonlinearSolve.jl.
405408
406-
### `SciMLBase.init` specification
409+
### `__internal_init` specification
407410
408411
```julia
409-
SciMLBase.init(prob::AbstractNonlinearProblem, alg::AbstractTrustRegionMethod,
412+
__internal_init(prob::AbstractNonlinearProblem, alg::AbstractTrustRegionMethod,
410413
f::F, fu, u, p, args...; internalnorm::IF = DEFAULT_NORM,
411414
kwargs...) where {F, IF} --> AbstractTrustRegionMethodCache
412415
```
@@ -423,10 +426,10 @@ Abstract Type for all Trust Region Method Caches used in NonlinearSolve.jl.
423426
- `last_step_accepted(cache)`: whether or not the last step was accepted. Defaults to
424427
`cache.last_step_accepted`. Should if overloaded if the field is not present.
425428
426-
### `SciMLBase.solve!` specification
429+
### `__internal_solve!` specification
427430
428431
```julia
429-
SciMLBase.solve!(cache::AbstractTrustRegionMethodCache, J, fu, u, δu, descent_stats)
432+
__internal_solve!(cache::AbstractTrustRegionMethodCache, J, fu, u, δu, descent_stats)
430433
```
431434
432435
Returns `last_step_accepted`, updated `u_cache` and `fu_cache`. If the last step was

src/algorithms/broyden.jl

+5-5
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ function reinit_cache!(cache::NoChangeInStateResetCache, args...; kwargs...)
9898
cache.steps_since_change_dfu = 0
9999
end
100100

101-
function SciMLBase.init(alg::NoChangeInStateReset, J, fu, u, du, args...; kwargs...)
101+
function __internal_init(alg::NoChangeInStateReset, J, fu, u, du, args...; kwargs...)
102102
if alg.check_dfu
103103
@bb dfu = copy(fu)
104104
else
@@ -110,7 +110,7 @@ function SciMLBase.init(alg::NoChangeInStateReset, J, fu, u, du, args...; kwargs
110110
0)
111111
end
112112

113-
function SciMLBase.solve!(cache::NoChangeInStateResetCache, J, fu, u, du)
113+
function __internal_solve!(cache::NoChangeInStateResetCache, J, fu, u, du)
114114
reset_tolerance = cache.reset_tolerance
115115
if cache.check_du
116116
if any(@closure(x->abs(x) reset_tolerance), du)
@@ -168,7 +168,7 @@ Broyden Update Rule corresponding to "good broyden's method" [broyden1965class](
168168
internalnorm
169169
end
170170

171-
function SciMLBase.init(prob::AbstractNonlinearProblem,
171+
function __internal_init(prob::AbstractNonlinearProblem,
172172
alg::Union{GoodBroydenUpdateRule, BadBroydenUpdateRule}, J, fu, u, du, args...;
173173
internalnorm::F = DEFAULT_NORM, kwargs...) where {F}
174174
@bb J⁻¹dfu = similar(u)
@@ -187,7 +187,7 @@ function SciMLBase.init(prob::AbstractNonlinearProblem,
187187
return BroydenUpdateRuleCache{mode}(J⁻¹dfu, dfu, u_cache, du_cache, internalnorm)
188188
end
189189

190-
function SciMLBase.solve!(cache::BroydenUpdateRuleCache{mode}, J⁻¹, fu, u, du) where {mode}
190+
function __internal_solve!(cache::BroydenUpdateRuleCache{mode}, J⁻¹, fu, u, du) where {mode}
191191
T = eltype(u)
192192
@bb @. cache.dfu = fu - cache.dfu
193193
@bb cache.J⁻¹dfu = J⁻¹ × vec(cache.dfu)
@@ -205,7 +205,7 @@ function SciMLBase.solve!(cache::BroydenUpdateRuleCache{mode}, J⁻¹, fu, u, du
205205
return J⁻¹
206206
end
207207

208-
function SciMLBase.solve!(cache::BroydenUpdateRuleCache{mode}, J⁻¹::Diagonal, fu, u,
208+
function __internal_solve!(cache::BroydenUpdateRuleCache{mode}, J⁻¹::Diagonal, fu, u,
209209
du) where {mode}
210210
T = eltype(u)
211211
@bb @. cache.dfu = fu - cache.dfu

src/algorithms/klement.jl

+6-6
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ struct IllConditionedJacobianReset <: AbstractResetCondition end
6464
condition_number_threshold
6565
end
6666

67-
function SciMLBase.init(alg::IllConditionedJacobianReset, J, fu, u, du, args...; kwargs...)
67+
function __internal_init(alg::IllConditionedJacobianReset, J, fu, u, du, args...; kwargs...)
6868
condition_number_threshold = if J isa AbstractMatrix
6969
inv(eps(real(eltype(J)))^(1 // 2))
7070
else
@@ -73,7 +73,7 @@ function SciMLBase.init(alg::IllConditionedJacobianReset, J, fu, u, du, args...;
7373
return IllConditionedJacobianResetCache(condition_number_threshold)
7474
end
7575

76-
function SciMLBase.solve!(cache::IllConditionedJacobianResetCache, J, fu, u, du)
76+
function __internal_solve!(cache::IllConditionedJacobianResetCache, J, fu, u, du)
7777
J isa Number && return iszero(J)
7878
J isa Diagonal && return any(iszero, diag(J))
7979
J isa AbstractMatrix && return cond(J) cache.condition_number_threshold
@@ -98,7 +98,7 @@ Update rule for [`Klement`](@ref).
9898
fu_cache
9999
end
100100

101-
function SciMLBase.init(prob::AbstractNonlinearProblem, alg::KlementUpdateRule, J, fu, u,
101+
function __internal_init(prob::AbstractNonlinearProblem, alg::KlementUpdateRule, J, fu, u,
102102
du, args...; kwargs...)
103103
@bb Jdu = similar(fu)
104104
if J isa Diagonal || J isa Number
@@ -112,14 +112,14 @@ function SciMLBase.init(prob::AbstractNonlinearProblem, alg::KlementUpdateRule,
112112
return KlementUpdateRuleCache(Jdu, J_cache, J_cache_2, Jdu_cache, fu_cache)
113113
end
114114

115-
function SciMLBase.solve!(cache::KlementUpdateRuleCache, J::Number, fu, u, du)
115+
function __internal_solve!(cache::KlementUpdateRuleCache, J::Number, fu, u, du)
116116
Jdu = J^2 * du^2
117117
J = J + ((fu - cache.fu_cache - J * du) / ifelse(iszero(Jdu), 1e-5, Jdu)) * du * J^2
118118
cache.fu_cache = fu
119119
return J
120120
end
121121

122-
function SciMLBase.solve!(cache::KlementUpdateRuleCache, J_::Diagonal, fu, u, du)
122+
function __internal_solve!(cache::KlementUpdateRuleCache, J_::Diagonal, fu, u, du)
123123
T = eltype(u)
124124
J = _restructure(u, diag(J_))
125125
@bb @. cache.Jdu = (J^2) * (du^2)
@@ -129,7 +129,7 @@ function SciMLBase.solve!(cache::KlementUpdateRuleCache, J_::Diagonal, fu, u, du
129129
return Diagonal(vec(J))
130130
end
131131

132-
function SciMLBase.solve!(cache::KlementUpdateRuleCache, J::AbstractMatrix, fu, u, du)
132+
function __internal_solve!(cache::KlementUpdateRuleCache, J::AbstractMatrix, fu, u, du)
133133
T = eltype(u)
134134
@bb @. cache.J_cache = J'^2
135135
@bb @. cache.Jdu = du^2

src/algorithms/lbroyden.jl

+3-2
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,12 @@ end
3838

3939
jacobian_initialized_preinverted(::BroydenLowRankInitialization) = true
4040

41-
function SciMLBase.init(prob::AbstractNonlinearProblem,
41+
function __internal_init(prob::AbstractNonlinearProblem,
4242
alg::BroydenLowRankInitialization{T}, solver, f::F, fu, u, p; maxiters = 1000,
4343
internalnorm::IN = DEFAULT_NORM, kwargs...) where {T, F, IN}
4444
if u isa Number # Use the standard broyden
45-
return init(prob, IdentityInitialization(true, FullStructure()), solver, f, fu, u,
45+
return __internal_init(prob, IdentityInitialization(true, FullStructure()), solver,
46+
f, fu, u,
4647
p; maxiters, kwargs...)
4748
end
4849
# Pay to cost of slightly more allocations to prevent type-instability for StaticArrays

src/algorithms/levenberg_marquardt.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ function returns_norm_form_damping(::Union{LevenbergMarquardtDampingFunction,
9797
return true
9898
end
9999

100-
function SciMLBase.init(prob::AbstractNonlinearProblem,
100+
function __internal_init(prob::AbstractNonlinearProblem,
101101
f::LevenbergMarquardtDampingFunction, initial_damping, J, fu, u, ::Val{NF};
102102
internalnorm::F = DEFAULT_NORM, kwargs...) where {F, NF}
103103
T = promote_type(eltype(u), eltype(fu))
@@ -115,7 +115,7 @@ end
115115

116116
(damping::LevenbergMarquardtDampingCache)(::Nothing) = damping.J_damped
117117

118-
function SciMLBase.solve!(damping::LevenbergMarquardtDampingCache, J, fu, ::Val{false};
118+
function __internal_solve!(damping::LevenbergMarquardtDampingCache, J, fu, ::Val{false};
119119
kwargs...)
120120
if __can_setindex(damping.J_diag_cache)
121121
sum!(abs2, _vec(damping.J_diag_cache), J')
@@ -129,7 +129,7 @@ function SciMLBase.solve!(damping::LevenbergMarquardtDampingCache, J, fu, ::Val{
129129
return damping.J_damped
130130
end
131131

132-
function SciMLBase.solve!(damping::LevenbergMarquardtDampingCache, JᵀJ, fu, ::Val{true};
132+
function __internal_solve!(damping::LevenbergMarquardtDampingCache, JᵀJ, fu, ::Val{true};
133133
kwargs...)
134134
damping.DᵀD = __update_LM_diagonal!!(damping.DᵀD, JᵀJ)
135135
@bb @. damping.J_damped = damping.λ * damping.DᵀD

src/algorithms/pseudo_transient.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ function requires_normal_form_rhs(cache::Union{SwitchedEvolutionRelaxation,
5252
return false
5353
end
5454

55-
function SciMLBase.init(prob::AbstractNonlinearProblem, f::SwitchedEvolutionRelaxation,
55+
function __internal_init(prob::AbstractNonlinearProblem, f::SwitchedEvolutionRelaxation,
5656
initial_damping, J, fu, u, args...; internalnorm::F = DEFAULT_NORM,
5757
kwargs...) where {F}
5858
T = promote_type(eltype(u), eltype(fu))
@@ -62,7 +62,7 @@ end
6262

6363
(damping::SwitchedEvolutionRelaxationCache)(::Nothing) = damping.α⁻¹
6464

65-
function SciMLBase.solve!(damping::SwitchedEvolutionRelaxationCache, J, fu, args...;
65+
function __internal_solve!(damping::SwitchedEvolutionRelaxationCache, J, fu, args...;
6666
kwargs...)
6767
res_norm = damping.internalnorm(fu)
6868
damping.α⁻¹ *= res_norm / damping.res_norm

0 commit comments

Comments
 (0)